From b70bc75d6a6efcd6f2ea098aa32306fe31815e57 Mon Sep 17 00:00:00 2001 From: BobTheBuidler <70677534+BobTheBuidler@users.noreply.github.com> Date: Tue, 26 Nov 2024 02:59:44 -0400 Subject: [PATCH] fix: PrioritySemaphore name-mangling issue (#447) --- a_sync/primitives/locks/counter.pyx | 5 - a_sync/primitives/locks/prio_semaphore.pxd | 1 - a_sync/primitives/locks/prio_semaphore.pyx | 63 ++++++----- a_sync/primitives/locks/semaphore.pxd | 5 +- a_sync/primitives/locks/semaphore.pyx | 116 ++++++++++++++------- tests/primitives/test_prio_semaphore.py | 7 ++ 6 files changed, 127 insertions(+), 70 deletions(-) diff --git a/a_sync/primitives/locks/counter.pyx b/a_sync/primitives/locks/counter.pyx index b1316812..9032c46f 100644 --- a/a_sync/primitives/locks/counter.pyx +++ b/a_sync/primitives/locks/counter.pyx @@ -209,11 +209,6 @@ cdef class CounterLock(_DebugDaemonMixin): ) await asyncio.sleep(300) - def __dealloc__(self): - # Free the memory allocated for __name - if self.__name is not NULL: - free(self.__name) - class CounterLockCluster: """ diff --git a/a_sync/primitives/locks/prio_semaphore.pxd b/a_sync/primitives/locks/prio_semaphore.pxd index 5b1cd10c..7a41f741 100644 --- a/a_sync/primitives/locks/prio_semaphore.pxd +++ b/a_sync/primitives/locks/prio_semaphore.pxd @@ -6,7 +6,6 @@ cdef class _AbstractPrioritySemaphore(Semaphore): cdef list _potential_lost_waiters cdef object _top_priority cdef object _context_manager_class - cdef list[_AbstractPrioritySemaphoreContextManager] __waiters cdef object c_getitem(self, object priority) cdef dict[object, int] _count_waiters(self) diff --git a/a_sync/primitives/locks/prio_semaphore.pyx b/a_sync/primitives/locks/prio_semaphore.pyx index eb8a5e51..6a706812 100644 --- a/a_sync/primitives/locks/prio_semaphore.pyx +++ b/a_sync/primitives/locks/prio_semaphore.pyx @@ -44,7 +44,7 @@ cdef class _AbstractPrioritySemaphore(Semaphore): self._context_managers = {} """A dictionary mapping priorities to their context managers.""" - self.__waiters = [] + self._Semaphore__waiters = [] """A heap queue of context managers, sorted by priority.""" # NOTE: This should (hopefully) be temporary @@ -144,7 +144,7 @@ cdef class _AbstractPrioritySemaphore(Semaphore): context_manager = self._context_manager_class( self, priority, name=self.name ) - heappush(self.__waiters, context_manager) # type: ignore [misc] + heappush(self._Semaphore__waiters, context_manager) # type: ignore [misc] self._context_managers[priority] = context_manager return self._context_managers[priority] @@ -163,14 +163,17 @@ cdef class _AbstractPrioritySemaphore(Semaphore): cdef bint c_locked(self): cdef _AbstractPrioritySemaphoreContextManager cm cdef object w - return self._Semaphore__value == 0 or ( - any( - cm._waiters and any(not w.cancelled() for w in cm._waiters) - for cm in (self._context_managers.values() or ()) - ) + if self._Semaphore__value == 0: + return True + cdef dict[object, _AbstractPrioritySemaphoreContextManager] cms = self._context_managers + if not cms: + return False + return any( + cm._Semaphore__waiters and any(not w.cancelled() for w in cm._Semaphore__waiters) + for cm in cms.values() ) - cdef dict[object, int] _count_waiters(self): + cdef dict[object, Py_ssize_t] _count_waiters(self): """Counts the number of waiters for each priority. Returns: @@ -181,10 +184,9 @@ cdef class _AbstractPrioritySemaphore(Semaphore): >>> semaphore._count_waiters() """ cdef _AbstractPrioritySemaphoreContextManager manager - return { - manager._priority: len(manager._waiters) - for manager in sorted(self.__waiters, key=lambda m: m._priority) - } + cdef list[_AbstractPrioritySemaphoreContextManager] waiters = self._Semaphore__waiters + return {manager._priority: len(manager._Semaphore__waiters) for manager in sorted(waiters)} + def _wake_up_next(self) -> None: """Wakes up the next waiter in line. @@ -199,10 +201,16 @@ cdef class _AbstractPrioritySemaphore(Semaphore): >>> semaphore = _AbstractPrioritySemaphore(5) >>> semaphore._wake_up_next() """ + self._c_wake_up_next() + + cdef void _c_wake_up_next(self): cdef _AbstractPrioritySemaphoreContextManager manager + cdef Py_ssize_t start_len, end_len + cdef bint woke_up + cdef bint debug_logs = c_logger.isEnabledFor(DEBUG) - while self.__waiters: - manager = heappop(self.__waiters) + while self._Semaphore__waiters: + manager = heappop(self._Semaphore__waiters) if len(manager) == 0: # There are no more waiters, get rid of the empty manager if debug_logs: @@ -219,11 +227,11 @@ cdef class _AbstractPrioritySemaphore(Semaphore): if debug_logs: c_logger._log(DEBUG, "waking up next for %s", (manager._c_repr_no_parent_(), )) - if not manager._waiters: - c_logger._log(DEBUG, "not manager._waiters") + if not manager._Semaphore__waiters: + c_logger._log(DEBUG, "not manager._Semaphore__waiters") - while manager._waiters: - waiter = manager._waiters.popleft() + while manager._Semaphore__waiters: + waiter = manager._Semaphore__waiters.popleft() self._potential_lost_waiters.remove(waiter) if not waiter.done(): waiter.set_result(None) @@ -242,7 +250,7 @@ cdef class _AbstractPrioritySemaphore(Semaphore): if end_len: # There are still waiters, put the manager back - heappush(self.__waiters, manager) # type: ignore [misc] + heappush(self._Semaphore__waiters, manager) # type: ignore [misc] else: # There are no more waiters, get rid of the empty manager self._context_managers.pop(manager._priority) @@ -255,6 +263,7 @@ cdef class _AbstractPrioritySemaphore(Semaphore): if not waiter.done(): waiter.set_result(None) return + return while self._potential_lost_waiters: waiter = self._potential_lost_waiters.pop(0) @@ -293,14 +302,14 @@ cdef class _AbstractPrioritySemaphoreContextManager(Semaphore): >>> context_manager = _AbstractPrioritySemaphoreContextManager(parent_semaphore, priority=1) """ + super().__init__(0, name=name) + self._parent = parent """The parent semaphore.""" self._priority = priority """The priority associated with this context manager.""" - super().__init__(0, name=name) - def __repr__(self) -> str: """Returns a string representation of the context manager.""" return f"<{self.__class__.__name__} parent={self._parent} {self._priority_name}={self._priority} waiters={len(self)}>" @@ -361,10 +370,10 @@ cdef class _AbstractPrioritySemaphoreContextManager(Semaphore): async def __acquire(self) -> Literal[True]: cdef object loop, fut while self._parent._Semaphore__value <= 0: - if self._AbstractPrioritySemaphoreContextManager__waiters is None: - self._AbstractPrioritySemaphoreContextManager__waiters = deque() + if self._Semaphore__waiters is None: + self._Semaphore__waiters = deque() fut = self._c_get_loop().create_future() - self._AbstractPrioritySemaphoreContextManager__waiters.append(fut) + self._Semaphore__waiters.append(fut) self._parent._potential_lost_waiters.append(fut) try: await fut @@ -406,7 +415,7 @@ cdef class _PrioritySemaphoreContextManager(_AbstractPrioritySemaphoreContextMan def __cinit__(self): self._priority_name = "priority" # Semaphore.__cinit__(self) - self._AbstractPrioritySemaphoreContextManager__waiters = deque() + self._Semaphore__waiters = deque() self._decorated: Set[str] = set() def __lt__(self, _PrioritySemaphoreContextManager other) -> bool: @@ -460,7 +469,7 @@ cdef class PrioritySemaphore(_AbstractPrioritySemaphore): # type: ignore [type- self._context_managers = {} """A dictionary mapping priorities to their context managers.""" - self.__waiters = [] + self._Semaphore__waiters = [] """A heap queue of context managers, sorted by priority.""" # NOTE: This should (hopefully) be temporary @@ -503,7 +512,7 @@ cdef class PrioritySemaphore(_AbstractPrioritySemaphore): # type: ignore [type- if priority not in context_managers: context_manager = _PrioritySemaphoreContextManager(self, priority, name=self.name) heappush( - self.__waiters, + self._Semaphore__waiters, context_manager, ) # type: ignore [misc] context_managers[priority] = context_manager diff --git a/a_sync/primitives/locks/semaphore.pxd b/a_sync/primitives/locks/semaphore.pxd index e15d23a9..3f1486b9 100644 --- a/a_sync/primitives/locks/semaphore.pxd +++ b/a_sync/primitives/locks/semaphore.pxd @@ -1,9 +1,10 @@ from a_sync.primitives._debug cimport _DebugDaemonMixin cdef class Semaphore(_DebugDaemonMixin): - cdef str _name cdef unsigned long long __value - cdef object _waiters + cdef object __waiters + cdef char* _name + cdef str decode_name(self) cdef set _decorated cdef dict __dict__ cpdef bint locked(self) diff --git a/a_sync/primitives/locks/semaphore.pyx b/a_sync/primitives/locks/semaphore.pyx index a40d6656..11e65b2e 100644 --- a/a_sync/primitives/locks/semaphore.pyx +++ b/a_sync/primitives/locks/semaphore.pyx @@ -6,6 +6,8 @@ a dummy semaphore that does nothing, and a threadsafe semaphore for use in multi import asyncio import functools from collections import defaultdict, deque +from libc.string cimport strcpy +from libc.stdlib cimport malloc, free from threading import Thread, current_thread from typing import Container @@ -50,10 +52,15 @@ cdef class Semaphore(_DebugDaemonMixin): """ def __cinit__(self) -> None: - self.__waiters = deque() + self._Semaphore__waiters = deque() self._decorated: Set[str] = set() - def __init__(self, value: int=1, name=None, loop=None, **kwargs) -> None: + def __init__( + self, + value: int = 1, + str name = "", + object loop = None, + ) -> None: """ Initialize the semaphore with a given value and optional name for debugging. @@ -66,14 +73,32 @@ cdef class Semaphore(_DebugDaemonMixin): raise ValueError("Semaphore initial value must be >= 0") try: - self.__value = value + self._Semaphore__value = value except OverflowError as e: + if value < 0: + raise ValueError("'value' must be a positive integer") from e.__cause__ raise OverflowError( {"error": str(e), "value": value, "max value": 18446744073709551615}, "If you need a Semaphore with a larger value, you should just use asyncio.Semaphore", ) from e.__cause__ - self._name = name or getattr(self, "__origin__", "") + # we need a constant to coerce to char* + cdef bytes encoded_name = (name or getattr(self, "__origin__", "")).encode("utf-8") + cdef Py_ssize_t length = len(encoded_name) + + # Allocate memory for the char* and add 1 for the null character + self._name = malloc(length + 1) + """An optional name for the counter, used in debug logs.""" + + if self._name == NULL: + raise MemoryError("Failed to allocate memory for __name.") + # Copy the bytes data into the char* + strcpy(self._name, encoded_name) + + def __dealloc__(self): + # Free the memory allocated for _name + if self._name is not NULL: + free(self._name) def __call__(self, fn: CoroFn[P, T]) -> CoroFn[P, T]: """ @@ -91,7 +116,7 @@ cdef class Semaphore(_DebugDaemonMixin): return self.decorate(fn) # type: ignore [arg-type, return-value] def __repr__(self) -> str: - representation = f"<{self.__class__.__name__} name={self._name} value={self.__value} waiters={len(self)}>" + representation = f"<{self.__class__.__name__} name={self.__name.decode('utf-8')} value={self._Semaphore__value} waiters={len(self)}>" if self._decorated: representation = f"{representation[:-1]} decorates={self._decorated}" return representation @@ -112,14 +137,14 @@ cdef class Semaphore(_DebugDaemonMixin): cdef bint c_locked(self): """Returns True if semaphore cannot be acquired immediately.""" - if self.__value == 0: + if self._Semaphore__value == 0: return True - waiters = self.__waiters + cdef object waiters = self._Semaphore__waiters if any(not w.cancelled() for w in (waiters or ())): return True def __len__(self) -> int: - return len(self.__waiters) + return len(self._Semaphore__waiters) def decorate(self, fn: CoroFn[P, T]) -> CoroFn[P, T]: """ @@ -171,11 +196,11 @@ cdef class Semaphore(_DebugDaemonMixin): Returns: True when the semaphore is successfully acquired. """ - if self.__value <= 0: + if self._Semaphore__value <= 0: self._c_ensure_debug_daemon((),{}) if not self.c_locked(): - self.__value -= 1 + self._Semaphore__value -= 1 return __acquire() return self.__acquire() @@ -186,19 +211,19 @@ cdef class Semaphore(_DebugDaemonMixin): # _wake_up_first() and attempt to wake up itself. cdef object fut = self._c_get_loop().create_future() - self.__waiters.append(fut) + self._Semaphore__waiters.append(fut) try: try: await fut finally: - self.__waiters.remove(fut) + self._Semaphore__waiters.remove(fut) except asyncio.exceptions.CancelledError: if not fut.cancelled(): - self.__value += 1 + self._Semaphore__value += 1 self._c_wake_up_next() raise - if self.__value > 0: + if self._Semaphore__value > 0: self._c_wake_up_next() return True @@ -209,56 +234,59 @@ cdef class Semaphore(_DebugDaemonMixin): When it was zero on entry and another coroutine is waiting for it to become larger than zero again, wake up that coroutine. """ - self.__value += 1 + self._Semaphore__value += 1 self._c_wake_up_next() cdef void c_release(self): - self.__value += 1 + self._Semaphore__value += 1 self._c_wake_up_next() @property def name(self) -> str: - return self._name + return self.decode_name() + + cdef str decode_name(self): + return (self._name or b"").decode("utf-8") @property def _value(self) -> int: # required for subclass compatability - return self.__value + return self._Semaphore__value @_value.setter def _value(self, unsigned long long value): # required for subclass compatability - self.__value = value + self._Semaphore__value = value @property - def _waiters(self) -> int: + def _waiters(self) -> List[asyncio.Future]: # required for subclass compatability - return self.__waiters + return self._Semaphore__waiters @_waiters.setter def _waiters(self, value: Container): # required for subclass compatability - self.__waiters = value + self._Semaphore__waiters = value cpdef void _wake_up_next(self): """Wake up the first waiter that isn't done.""" - if not self.__waiters: + if not self._Semaphore__waiters: return - for fut in self.__waiters: + for fut in self._Semaphore__waiters: if not fut.done(): - self.__value -= 1 + self._Semaphore__value -= 1 fut.set_result(True) return cdef void _c_wake_up_next(self): """Wake up the first waiter that isn't done.""" - if not self.__waiters: + if not self._Semaphore__waiters: return - for fut in self.__waiters: + for fut in self._Semaphore__waiters: if not fut.done(): - self.__value -= 1 + self._Semaphore__value -= 1 fut.set_result(True) return @@ -275,7 +303,7 @@ cdef class Semaphore(_DebugDaemonMixin): async def monitor(): await semaphore._debug_daemon() """ - while self.__waiters: + while self._Semaphore__waiters: await asyncio.sleep(60) self.get_logger().debug( "%s has %s waiters for any of: %s", @@ -300,8 +328,8 @@ cdef class DummySemaphore(Semaphore): """ def __cinit__(self): - self.__value = 0 - self.__waiters = deque() + self._Semaphore__value = 0 + self._Semaphore__waiters = deque() self._decorated = None def __init__(self, name: Optional[str] = None): @@ -311,10 +339,27 @@ cdef class DummySemaphore(Semaphore): Args: name (optional): An optional name for the dummy semaphore. """ - self._name = name + + # we need a constant to coerce to char* + cdef bytes encoded_name = (name or getattr(self, "__origin__", "")).encode("utf-8") + cdef Py_ssize_t length = len(encoded_name) + + # Allocate memory for the char* and add 1 for the null character + self._name = malloc(length + 1) + """An optional name for the counter, used in debug logs.""" + + if self._name == NULL: + raise MemoryError("Failed to allocate memory for __name.") + # Copy the bytes data into the char* + strcpy(self._name, encoded_name) + + def __dealloc__(self): + # Free the memory allocated for _name + if self._name is not NULL: + free(self._name) def __repr__(self) -> str: - return "<{} name={}>".format(self.__class__.__name__, self._name) + return "<{} name={}>".format(self.__class__.__name__, self.decode_name()) async def acquire(self) -> Literal[True]: """Acquire the dummy semaphore, which is a no-op.""" @@ -372,11 +417,12 @@ cdef class ThreadsafeSemaphore(Semaphore): self.semaphores = {} self.dummy = DummySemaphore(name=name) else: - self.semaphores: DefaultDict[Thread, Semaphore] = defaultdict(lambda: Semaphore(value, name=self._name)) # type: ignore [arg-type] + self.semaphores: DefaultDict[Thread, Semaphore] = defaultdict(lambda: Semaphore(value, name=name)) # type: ignore [arg-type] self.dummy = None def __len__(self) -> int: - return sum(len(sem._waiters) for sem in self.semaphores.values()) + cdef dict[object, Semaphore] semaphores = self.semaphores + return sum(len(sem) for sem in semaphores.values()) @property def semaphore(self) -> Semaphore: diff --git a/tests/primitives/test_prio_semaphore.py b/tests/primitives/test_prio_semaphore.py index dd5fb233..01a51d40 100644 --- a/tests/primitives/test_prio_semaphore.py +++ b/tests/primitives/test_prio_semaphore.py @@ -1,3 +1,4 @@ +import asyncio import pytest from a_sync import PrioritySemaphore @@ -8,6 +9,12 @@ def test_prio_semaphore_init(): with_name = PrioritySemaphore(10, name="test") assert with_name._value == 10 and with_name.name == "test" +@pytest.mark.asyncio_cooperative +async def test_prio_semaphore_count_waiters(): + semaphore = PrioritySemaphore(1) + tasks = [asyncio.create_task(semaphore[i].acquire()) for i in range(5)] + await asyncio.sleep(1) + print(semaphore) @pytest.mark.asyncio_cooperative async def test_prio_semaphore_use():