Skip to content

Commit

Permalink
Fix task cancellation propagation to subtasks when using sync middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
ttys0dev committed Jan 16, 2024
1 parent 19e14e7 commit f2563b1
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 13 deletions.
51 changes: 39 additions & 12 deletions asgiref/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,11 @@ async def main_wrap(
if context is not None:
_restore_context(context[0])

current_task = asyncio.current_task()
if current_task is not None:
task_context = SyncToAsync.task_context.get()
task_context.append(current_task)

try:
# If we have an exception, run the function inside the except block
# after raising it so exc_info is correctly populated.
Expand All @@ -324,6 +329,8 @@ async def main_wrap(
else:
call_result.set_result(result)
finally:
if current_task is not None:
task_context.remove(current_task)
context[0] = contextvars.copy_context()


Expand Down Expand Up @@ -355,6 +362,10 @@ class SyncToAsync(Generic[_P, _R]):
# Single-thread executor for thread-sensitive code
single_thread_executor = ThreadPoolExecutor(max_workers=1)

task_context: "contextvars.ContextVar[List[asyncio.Task[Any]]]" = (
contextvars.ContextVar("task_context", default=[])
)

# Maintain a contextvar for the current execution context. Optionally used
# for thread sensitive mode.
thread_sensitive_context: "contextvars.ContextVar[ThreadSensitiveContext]" = (
Expand Down Expand Up @@ -438,19 +449,35 @@ async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
child = functools.partial(self.func, *args, **kwargs)
func = context.run

# Run the code in the right thread
exec_coro = loop.run_in_executor(
executor,
functools.partial(
self.thread_handler,
loop,
sys.exc_info(),
func,
child,
),
)
ret: _R
try:
# Run the code in the right thread
ret: _R = await loop.run_in_executor(
executor,
functools.partial(
self.thread_handler,
loop,
sys.exc_info(),
func,
child,
),
)

ret = await asyncio.shield(exec_coro)
except asyncio.CancelledError:
tasks = self.task_context.get()
cancel_parent = True
cancel_parent_force = False
for task in tasks:
task.cancel()
for task in tasks:
try:
await task
cancel_parent = False
except asyncio.CancelledError:
cancel_parent_force = True
if cancel_parent or cancel_parent_force:
exec_coro.cancel()
ret = await exec_coro
finally:
_restore_context(context)
self.deadlock_context.set(False)
Expand Down
1 change: 0 additions & 1 deletion tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,6 @@ def sync_task():


@pytest.mark.asyncio
@pytest.mark.skip(reason="deadlocks")
async def test_inner_shield_sync_middleware():
"""
Tests that asyncio.shield is capable of preventing http.disconnect from
Expand Down

0 comments on commit f2563b1

Please sign in to comment.