Skip to content

Commit

Permalink
fix: apply_semaphore (#446)
Browse files Browse the repository at this point in the history
* fix: apply_semaphore breaks with our cython semaphore

* feat(test): test_apply_semaphore

* feat(test): test for failure cases too

* chore: `black .`

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
BobTheBuidler and github-actions[bot] authored Nov 26, 2024
1 parent 792fd99 commit 1db391b
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 5 deletions.
7 changes: 2 additions & 5 deletions a_sync/a_sync/modifiers/semaphores.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
from a_sync import exceptions, primitives
from a_sync._typing import *

# We keep this here for now so we don't break downstream deps. Eventually will be removed.
from a_sync.primitives import ThreadsafeSemaphore, DummySemaphore


@overload
def apply_semaphore( # type: ignore [misc]
Expand Down Expand Up @@ -134,7 +131,7 @@ def apply_semaphore(
`primitives.Semaphore` is a subclass of `asyncio.Semaphore`. Therefore, when the documentation refers to `asyncio.Semaphore`, it also includes `primitives.Semaphore` and any other subclasses.
"""
# Parse Inputs
if isinstance(coro_fn, (int, asyncio.Semaphore)):
if isinstance(coro_fn, (int, asyncio.Semaphore, primitives.Semaphore)):
if semaphore is not None:
raise ValueError("You can only pass in one arg.")
semaphore = coro_fn
Expand All @@ -146,7 +143,7 @@ def apply_semaphore(
# Create the semaphore if necessary
if isinstance(semaphore, int):
semaphore = primitives.ThreadsafeSemaphore(semaphore)
elif not isinstance(semaphore, asyncio.Semaphore):
elif not isinstance(semaphore, (asyncio.Semaphore, primitives.Semaphore)):
raise TypeError(
f"'semaphore' must either be an integer or a Semaphore object. You passed {semaphore}"
)
Expand Down
42 changes: 42 additions & 0 deletions tests/a_sync/modifiers/test_apply_semaphore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import asyncio
import pytest

from a_sync import Semaphore, apply_semaphore
from a_sync.exceptions import FunctionNotAsync


@pytest.mark.asyncio_cooperative
async def test_apply_semaphore_int():
apply_semaphore(asyncio.sleep, 1)


@pytest.mark.asyncio_cooperative
async def test_apply_semaphore_asyncio_semaphore():
apply_semaphore(asyncio.sleep, asyncio.Semaphore(1))


@pytest.mark.asyncio_cooperative
async def test_apply_semaphore_a_sync_semaphore():
apply_semaphore(asyncio.sleep, Semaphore(1))


def fail():
pass


@pytest.mark.asyncio_cooperative
async def test_apply_semaphore_failure_int():
with pytest.raises(FunctionNotAsync):
apply_semaphore(fail, 1)


@pytest.mark.asyncio_cooperative
async def test_apply_semaphore_failure_asyncio_semaphore():
with pytest.raises(FunctionNotAsync):
apply_semaphore(fail, asyncio.Semaphore(1))


@pytest.mark.asyncio_cooperative
async def test_apply_semaphore_failure_a_sync_semaphore():
with pytest.raises(FunctionNotAsync):
apply_semaphore(fail, Semaphore(1))

0 comments on commit 1db391b

Please sign in to comment.