diff --git a/asgiref/sync.py b/asgiref/sync.py index 5406b7d3..92d25112 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -309,6 +309,11 @@ async def main_wrap( if context is not None: _restore_context(context[0]) + current_task = asyncio.current_task() + if current_task is not None: + task_context = SyncToAsync.task_context.get() + task_context.append(current_task) + try: # If we have an exception, run the function inside the except block # after raising it so exc_info is correctly populated. @@ -324,6 +329,8 @@ async def main_wrap( else: call_result.set_result(result) finally: + if current_task is not None: + task_context.remove(current_task) context[0] = contextvars.copy_context() @@ -355,6 +362,10 @@ class SyncToAsync(Generic[_P, _R]): # Single-thread executor for thread-sensitive code single_thread_executor = ThreadPoolExecutor(max_workers=1) + task_context: "contextvars.ContextVar[List[asyncio.Task[Any]]]" = ( + contextvars.ContextVar("task_context", default=[]) + ) + # Maintain a contextvar for the current execution context. Optionally used # for thread sensitive mode. thread_sensitive_context: "contextvars.ContextVar[ThreadSensitiveContext]" = ( @@ -438,19 +449,35 @@ async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: child = functools.partial(self.func, *args, **kwargs) func = context.run + # Run the code in the right thread + exec_coro = loop.run_in_executor( + executor, + functools.partial( + self.thread_handler, + loop, + sys.exc_info(), + func, + child, + ), + ) + ret: _R try: - # Run the code in the right thread - ret: _R = await loop.run_in_executor( - executor, - functools.partial( - self.thread_handler, - loop, - sys.exc_info(), - func, - child, - ), - ) - + ret = await asyncio.shield(exec_coro) + except asyncio.CancelledError: + tasks = self.task_context.get() + cancel_parent = True + cancel_parent_force = False + for task in tasks: + task.cancel() + for task in tasks: + try: + await task + cancel_parent = False + except asyncio.CancelledError: + cancel_parent_force = True + if cancel_parent or cancel_parent_force: + exec_coro.cancel() + ret = await exec_coro finally: _restore_context(context) self.deadlock_context.set(False) diff --git a/tests/test_sync.py b/tests/test_sync.py index daed8c4d..a6587f74 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -852,7 +852,6 @@ def sync_task(): @pytest.mark.asyncio -@pytest.mark.skip(reason="deadlocks") async def test_inner_shield_sync_middleware(): """ Tests that asyncio.shield is capable of preventing http.disconnect from