Skip to content

Commit

Permalink
Replace use of zmq Poller with better anyio task handling
Browse files Browse the repository at this point in the history
  • Loading branch information
ianthomas23 committed Sep 13, 2024
1 parent 97e3e91 commit e19273c
Showing 1 changed file with 40 additions and 26 deletions.
66 changes: 40 additions & 26 deletions ipykernel/subshell_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import zmq
import zmq.asyncio
from anyio import Event, create_task_group
from anyio import create_memory_object_stream, create_task_group

from .subshell import SubshellThread

Expand Down Expand Up @@ -50,14 +50,17 @@ def __init__(self, context: zmq.asyncio.Context, shell_socket):
self._parent_shell_channel_socket = self._create_inproc_pair_socket(None, True)
self._parent_other_socket = self._create_inproc_pair_socket(None, False)

self._poller = zmq.asyncio.Poller()
self._poller.register(self._parent_shell_channel_socket, zmq.POLLIN)
self._subshell_change = Event()
# anyio memory object stream for async queue-like communication between tasks.
# Used by _create_subshell to tell listen_from_subshells to spawn a new task.
self._send_stream, self._receive_stream = create_memory_object_stream[str]()

def close(self):
"""Stop all subshells and close all resources."""
assert current_thread().name == "Shell channel"

self._send_stream.close()
self._receive_stream.close()

for socket in (
self._control_shell_channel_socket,
self._control_other_socket,
Expand Down Expand Up @@ -126,14 +129,10 @@ async def listen_from_subshells(self):
"""
assert current_thread().name == "Shell channel"

while True:
async with create_task_group() as tg:
tg.start_soon(self._listen_from_subshells)
await self._subshell_change.wait()
tg.cancel_scope.cancel()

# anyio.Event is single use, so recreate to reuse
self._subshell_change = Event()
async with create_task_group() as tg:
tg.start_soon(self._listen_for_subshell_reply, None)
async for subshell_id in self._receive_stream:
tg.start_soon(self._listen_for_subshell_reply, subshell_id)

def subshell_id_from_thread_id(self, thread_id) -> str | None:
"""Return subshell_id of the specified thread_id.
Expand Down Expand Up @@ -173,14 +172,15 @@ async def _create_subshell(self, subshell_task) -> str:
shell_channel_socket = self._create_inproc_pair_socket(subshell_id, True)
self._cache[subshell_id] = Subshell(thread, shell_channel_socket)

# Tell task running listen_from_subshells to create a new task to listen for
# reply messages from the new subshell to resend to the client.
await self._send_stream.send(subshell_id)

address = self._get_inproc_socket_address(subshell_id)
thread.add_task(thread.create_pair_socket, self._context, address)
thread.add_task(subshell_task, subshell_id)
thread.start()

self._poller.register(shell_channel_socket, zmq.POLLIN)
self._subshell_change.set()

return subshell_id

def _delete_subshell(self, subshell_id: str) -> None:
Expand All @@ -199,22 +199,37 @@ def _get_inproc_socket_address(self, name: str | None):
full_name = f"subshell-{name}" if name else "subshell"
return f"inproc://{full_name}"

async def _listen_from_subshells(self):
"""Await next reply message from any subshell (parent or child) and resend
to the client via the shell_socket.
def _get_shell_channel_socket(self, subshell_id: str | None) -> zmq.asyncio.Socket:
if subshell_id is None:
return self._parent_shell_channel_socket
with self._lock:
return self._cache[subshell_id].shell_channel_socket

def _is_subshell(self, subshell_id: str | None) -> bool:
if subshell_id is None:
return True
with self._lock:
return subshell_id in self._cache

If a subshell is created or deleted then the poller is updated and the task
executing this function is cancelled and then rescheduled with the updated
poller.
async def _listen_for_subshell_reply(self, subshell_id: str | None):
"""Listen for reply messages on specified subshell inproc socket and
resend to the client via the shell_socket.
Runs in the shell channel thread.
"""
assert current_thread().name == "Shell channel"

while True:
for socket, _ in await self._poller.poll():
msg = await socket.recv_multipart(copy=False)
shell_channel_socket = self._get_shell_channel_socket(subshell_id)

try:
while True:
msg = await shell_channel_socket.recv_multipart(copy=False)
self._shell_socket.send_multipart(msg)
except BaseException as e:
if not self._is_subshell(subshell_id):
# Subshell no longer exists so exit gracefully
return
raise e

async def _process_control_request(self, request, subshell_task):
"""Process a control request message received on the control inproc
Expand Down Expand Up @@ -252,6 +267,5 @@ def _stop_subshell(self, subshell: Subshell):
thread.stop()
thread.join()

self._poller.unregister(subshell.shell_channel_socket)
self._subshell_change.set()
# Closing the shell_channel_socket terminates the task that is listening on it.
subshell.shell_channel_socket.close()

0 comments on commit e19273c

Please sign in to comment.