From e19273cff732771f9b00774dd94319ca47b8f70e Mon Sep 17 00:00:00 2001 From: Ian Thomas Date: Fri, 13 Sep 2024 12:47:13 +0100 Subject: [PATCH] Replace use of zmq Poller with better anyio task handling --- ipykernel/subshell_manager.py | 66 +++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 26 deletions(-) diff --git a/ipykernel/subshell_manager.py b/ipykernel/subshell_manager.py index e91ec636..00809c7f 100644 --- a/ipykernel/subshell_manager.py +++ b/ipykernel/subshell_manager.py @@ -11,7 +11,7 @@ import zmq import zmq.asyncio -from anyio import Event, create_task_group +from anyio import create_memory_object_stream, create_task_group from .subshell import SubshellThread @@ -50,14 +50,17 @@ def __init__(self, context: zmq.asyncio.Context, shell_socket): self._parent_shell_channel_socket = self._create_inproc_pair_socket(None, True) self._parent_other_socket = self._create_inproc_pair_socket(None, False) - self._poller = zmq.asyncio.Poller() - self._poller.register(self._parent_shell_channel_socket, zmq.POLLIN) - self._subshell_change = Event() + # anyio memory object stream for async queue-like communication between tasks. + # Used by _create_subshell to tell listen_from_subshells to spawn a new task. + self._send_stream, self._receive_stream = create_memory_object_stream[str]() def close(self): """Stop all subshells and close all resources.""" assert current_thread().name == "Shell channel" + self._send_stream.close() + self._receive_stream.close() + for socket in ( self._control_shell_channel_socket, self._control_other_socket, @@ -126,14 +129,10 @@ async def listen_from_subshells(self): """ assert current_thread().name == "Shell channel" - while True: - async with create_task_group() as tg: - tg.start_soon(self._listen_from_subshells) - await self._subshell_change.wait() - tg.cancel_scope.cancel() - - # anyio.Event is single use, so recreate to reuse - self._subshell_change = Event() + async with create_task_group() as tg: + tg.start_soon(self._listen_for_subshell_reply, None) + async for subshell_id in self._receive_stream: + tg.start_soon(self._listen_for_subshell_reply, subshell_id) def subshell_id_from_thread_id(self, thread_id) -> str | None: """Return subshell_id of the specified thread_id. @@ -173,14 +172,15 @@ async def _create_subshell(self, subshell_task) -> str: shell_channel_socket = self._create_inproc_pair_socket(subshell_id, True) self._cache[subshell_id] = Subshell(thread, shell_channel_socket) + # Tell task running listen_from_subshells to create a new task to listen for + # reply messages from the new subshell to resend to the client. + await self._send_stream.send(subshell_id) + address = self._get_inproc_socket_address(subshell_id) thread.add_task(thread.create_pair_socket, self._context, address) thread.add_task(subshell_task, subshell_id) thread.start() - self._poller.register(shell_channel_socket, zmq.POLLIN) - self._subshell_change.set() - return subshell_id def _delete_subshell(self, subshell_id: str) -> None: @@ -199,22 +199,37 @@ def _get_inproc_socket_address(self, name: str | None): full_name = f"subshell-{name}" if name else "subshell" return f"inproc://{full_name}" - async def _listen_from_subshells(self): - """Await next reply message from any subshell (parent or child) and resend - to the client via the shell_socket. + def _get_shell_channel_socket(self, subshell_id: str | None) -> zmq.asyncio.Socket: + if subshell_id is None: + return self._parent_shell_channel_socket + with self._lock: + return self._cache[subshell_id].shell_channel_socket + + def _is_subshell(self, subshell_id: str | None) -> bool: + if subshell_id is None: + return True + with self._lock: + return subshell_id in self._cache - If a subshell is created or deleted then the poller is updated and the task - executing this function is cancelled and then rescheduled with the updated - poller. + async def _listen_for_subshell_reply(self, subshell_id: str | None): + """Listen for reply messages on specified subshell inproc socket and + resend to the client via the shell_socket. Runs in the shell channel thread. """ assert current_thread().name == "Shell channel" - while True: - for socket, _ in await self._poller.poll(): - msg = await socket.recv_multipart(copy=False) + shell_channel_socket = self._get_shell_channel_socket(subshell_id) + + try: + while True: + msg = await shell_channel_socket.recv_multipart(copy=False) self._shell_socket.send_multipart(msg) + except BaseException as e: + if not self._is_subshell(subshell_id): + # Subshell no longer exists so exit gracefully + return + raise e async def _process_control_request(self, request, subshell_task): """Process a control request message received on the control inproc @@ -252,6 +267,5 @@ def _stop_subshell(self, subshell: Subshell): thread.stop() thread.join() - self._poller.unregister(subshell.shell_channel_socket) - self._subshell_change.set() + # Closing the shell_channel_socket terminates the task that is listening on it. subshell.shell_channel_socket.close()