diff --git a/src/hypercorn/asyncio/tcp_server.py b/src/hypercorn/asyncio/tcp_server.py index ed9d710f..91f0a117 100644 --- a/src/hypercorn/asyncio/tcp_server.py +++ b/src/hypercorn/asyncio/tcp_server.py @@ -2,10 +2,10 @@ import asyncio from ssl import SSLError -from typing import Any, Generator, Optional +from typing import Any, Generator from .task_group import TaskGroup -from .worker_context import WorkerContext +from .worker_context import AsyncioSingleTask, WorkerContext from ..config import Config from ..events import Closed, Event, RawData, Updated from ..protocol import ProtocolWrapper @@ -33,9 +33,7 @@ def __init__( self.reader = reader self.writer = writer self.send_lock = asyncio.Lock() - self.idle_lock = asyncio.Lock() - - self._idle_handle: Optional[asyncio.Task] = None + self.idle_task = AsyncioSingleTask() def __await__(self) -> Generator[Any, None, None]: return self.run().__await__() @@ -54,6 +52,7 @@ async def run(self) -> None: alpn_protocol = "http/1.1" async with TaskGroup(self.loop) as task_group: + self._task_group = task_group self.protocol = ProtocolWrapper( self.app, self.config, @@ -66,7 +65,7 @@ async def run(self) -> None: alpn_protocol, ) await self.protocol.initiate() - await self._start_idle() + await self.idle_task.restart(task_group, self._idle_timeout) await self._read_data() except OSError: pass @@ -85,9 +84,9 @@ async def protocol_send(self, event: Event) -> None: await self._close() elif isinstance(event, Updated): if event.idle: - await self._start_idle() + await self.idle_task.restart(self._task_group, self._idle_timeout) else: - await self._stop_idle() + await self.idle_task.stop() async def _read_data(self) -> None: while not self.reader.at_eof(): @@ -124,28 +123,13 @@ async def _close(self) -> None: ): pass # Already closed finally: - await self._stop_idle() + await self.idle_task.stop() async def _initiate_server_close(self) -> None: await self.protocol.handle(Closed()) self.writer.close() - async def _start_idle(self) -> None: - async with self.idle_lock: - if self._idle_handle is None: - self._idle_handle = self.loop.create_task(self._run_idle()) - - async def _stop_idle(self) -> None: - async with self.idle_lock: - if self._idle_handle is not None: - self._idle_handle.cancel() - try: - await self._idle_handle - except asyncio.CancelledError: - pass - self._idle_handle = None - - async def _run_idle(self) -> None: + async def _idle_timeout(self) -> None: try: await asyncio.wait_for(self.context.terminated.wait(), self.config.keep_alive_timeout) except asyncio.TimeoutError: diff --git a/src/hypercorn/asyncio/worker_context.py b/src/hypercorn/asyncio/worker_context.py index d16f76ba..31e98779 100644 --- a/src/hypercorn/asyncio/worker_context.py +++ b/src/hypercorn/asyncio/worker_context.py @@ -1,9 +1,37 @@ from __future__ import annotations import asyncio -from typing import Optional, Type, Union +from typing import Callable, Optional, Type, Union -from ..typing import Event +from ..typing import Event, SingleTask, TaskGroup + + +class AsyncioSingleTask: + def __init__(self) -> None: + self._handle: Optional[asyncio.Task] = None + self._lock = asyncio.Lock() + + async def restart(self, task_group: TaskGroup, action: Callable) -> None: + async with self._lock: + if self._handle is not None: + self._handle.cancel() + try: + await self._handle + except asyncio.CancelledError: + pass + + self._handle = task_group._task_group.create_task(action()) # type: ignore + + async def stop(self) -> None: + async with self._lock: + if self._handle is not None: + self._handle.cancel() + try: + await self._handle + except asyncio.CancelledError: + pass + + self._handle = None class EventWrapper: @@ -25,6 +53,7 @@ def is_set(self) -> bool: class WorkerContext: event_class: Type[Event] = EventWrapper + single_task_class: Type[SingleTask] = AsyncioSingleTask def __init__(self, max_requests: Optional[int]) -> None: self.max_requests = max_requests diff --git a/src/hypercorn/protocol/quic.py b/src/hypercorn/protocol/quic.py index 3d16e54d..98bd8b48 100644 --- a/src/hypercorn/protocol/quic.py +++ b/src/hypercorn/protocol/quic.py @@ -1,7 +1,8 @@ from __future__ import annotations +from dataclasses import dataclass from functools import partial -from typing import Awaitable, Callable, Dict, Optional, Tuple +from typing import Awaitable, Callable, Dict, Optional, Set, Tuple from aioquic.buffer import Buffer from aioquic.h3.connection import H3_ALPN @@ -22,7 +23,15 @@ from .h3 import H3Protocol from ..config import Config from ..events import Closed, Event, RawData -from ..typing import AppWrapper, TaskGroup, WorkerContext +from ..typing import AppWrapper, SingleTask, TaskGroup, WorkerContext + + +@dataclass +class _Connection: + cids: Set[bytes] + quic: QuicConnection + task: SingleTask + h3: Optional[H3Protocol] = None class QuicProtocol: @@ -38,8 +47,7 @@ def __init__( self.app = app self.config = config self.context = context - self.connections: Dict[bytes, QuicConnection] = {} - self.http_connections: Dict[QuicConnection, H3Protocol] = {} + self.connections: Dict[bytes, _Connection] = {} self.send = send self.server = server self.task_group = task_group @@ -49,7 +57,7 @@ def __init__( @property def idle(self) -> bool: - return len(self.connections) == 0 and len(self.http_connections) == 0 + return len(self.connections) == 0 async def handle(self, event: Event) -> None: if isinstance(event, RawData): @@ -76,32 +84,46 @@ async def handle(self, event: Event) -> None: and header.packet_type == PACKET_TYPE_INITIAL and not self.context.terminated.is_set() ): - connection = QuicConnection( + quic_connection = QuicConnection( configuration=self.quic_config, original_destination_connection_id=header.destination_cid, ) + connection = _Connection( + cids={header.destination_cid, quic_connection.host_cid}, + quic=quic_connection, + task=self.context.single_task_class(), + ) self.connections[header.destination_cid] = connection - self.connections[connection.host_cid] = connection + self.connections[quic_connection.host_cid] = connection if connection is not None: - connection.receive_datagram(event.data, event.address, now=self.context.time()) + connection.quic.receive_datagram(event.data, event.address, now=self.context.time()) await self._handle_events(connection, event.address) elif isinstance(event, Closed): pass - async def send_all(self, connection: QuicConnection) -> None: - for data, address in connection.datagrams_to_send(now=self.context.time()): + async def send_all(self, connection: _Connection) -> None: + for data, address in connection.quic.datagrams_to_send(now=self.context.time()): await self.send(RawData(data=data, address=address)) + timer = connection.quic.get_timer() + if timer is not None: + await connection.task.restart( + self.task_group, partial(self._handle_timer, timer, connection) + ) + async def _handle_events( - self, connection: QuicConnection, client: Optional[Tuple[str, int]] = None + self, connection: _Connection, client: Optional[Tuple[str, int]] = None ) -> None: - event = connection.next_event() + event = connection.quic.next_event() while event is not None: if isinstance(event, ConnectionTerminated): - pass + await connection.task.stop() + for cid in connection.cids: + del self.connections[cid] + connection.cids = set() elif isinstance(event, ProtocolNegotiated): - self.http_connections[connection] = H3Protocol( + connection.h3 = H3Protocol( self.app, self.config, self.context, @@ -112,24 +134,22 @@ async def _handle_events( partial(self.send_all, connection), ) elif isinstance(event, ConnectionIdIssued): + connection.cids.add(event.connection_id) self.connections[event.connection_id] = connection elif isinstance(event, ConnectionIdRetired): + connection.cids.remove(event.connection_id) del self.connections[event.connection_id] - if connection in self.http_connections: - await self.http_connections[connection].handle(event) + if connection.h3 is not None: + await connection.h3.handle(event) - event = connection.next_event() + event = connection.quic.next_event() await self.send_all(connection) - timer = connection.get_timer() - if timer is not None: - self.task_group.spawn(self._handle_timer, timer, connection) - - async def _handle_timer(self, timer: float, connection: QuicConnection) -> None: + async def _handle_timer(self, timer: float, connection: _Connection) -> None: wait = max(0, timer - self.context.time()) await self.context.sleep(wait) - if connection._close_at is not None: - connection.handle_timer(now=self.context.time()) + if connection.quic._close_at is not None: + connection.quic.handle_timer(now=self.context.time()) await self._handle_events(connection, None) diff --git a/src/hypercorn/trio/tcp_server.py b/src/hypercorn/trio/tcp_server.py index dbcc7a12..5c6c68e9 100644 --- a/src/hypercorn/trio/tcp_server.py +++ b/src/hypercorn/trio/tcp_server.py @@ -1,12 +1,12 @@ from __future__ import annotations from math import inf -from typing import Any, Generator, Optional +from typing import Any, Generator import trio from .task_group import TaskGroup -from .worker_context import WorkerContext +from .worker_context import TrioSingleTask, WorkerContext from ..config import Config from ..events import Closed, Event, RawData, Updated from ..protocol import ProtocolWrapper @@ -25,11 +25,9 @@ def __init__( self.context = context self.protocol: ProtocolWrapper self.send_lock = trio.Lock() - self.idle_lock = trio.Lock() + self.idle_task = TrioSingleTask() self.stream = stream - self._idle_handle: Optional[trio.CancelScope] = None - def __await__(self) -> Generator[Any, None, None]: return self.run().__await__() @@ -66,7 +64,7 @@ async def run(self) -> None: alpn_protocol, ) await self.protocol.initiate() - await self._start_idle() + await self.idle_task.restart(self._task_group, self._idle_timeout) await self._read_data() except OSError: pass @@ -87,9 +85,9 @@ async def protocol_send(self, event: Event) -> None: await self.protocol.handle(Closed()) elif isinstance(event, Updated): if event.idle: - await self._start_idle() + await self.idle_task.restart(self._task_group, self._idle_timeout) else: - await self._stop_idle() + await self.idle_task.stop() async def _read_data(self) -> None: while True: @@ -122,30 +120,13 @@ async def _close(self) -> None: pass await self.stream.aclose() + async def _idle_timeout(self) -> None: + with trio.move_on_after(self.config.keep_alive_timeout): + await self.context.terminated.wait() + + with trio.CancelScope(shield=True): + await self._initiate_server_close() + async def _initiate_server_close(self) -> None: await self.protocol.handle(Closed()) await self.stream.aclose() - - async def _start_idle(self) -> None: - async with self.idle_lock: - if self._idle_handle is None: - self._idle_handle = await self._task_group._nursery.start(self._run_idle) - - async def _stop_idle(self) -> None: - async with self.idle_lock: - if self._idle_handle is not None: - self._idle_handle.cancel() - self._idle_handle = None - - async def _run_idle( - self, - task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED, - ) -> None: - cancel_scope = trio.CancelScope() - task_status.started(cancel_scope) - with cancel_scope: - with trio.move_on_after(self.config.keep_alive_timeout): - await self.context.terminated.wait() - - cancel_scope.shield = True - await self._initiate_server_close() diff --git a/src/hypercorn/trio/worker_context.py b/src/hypercorn/trio/worker_context.py index c09c4fb6..dddcf427 100644 --- a/src/hypercorn/trio/worker_context.py +++ b/src/hypercorn/trio/worker_context.py @@ -1,10 +1,42 @@ from __future__ import annotations -from typing import Optional, Type, Union +from functools import wraps +from typing import Awaitable, Callable, Optional, Type, Union import trio -from ..typing import Event +from ..typing import Event, SingleTask, TaskGroup + + +def _cancel_wrapper(func: Callable[[], Awaitable[None]]) -> Callable[[], Awaitable[None]]: + @wraps(func) + async def wrapper( + task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED, + ) -> None: + cancel_scope = trio.CancelScope() + task_status.started(cancel_scope) + with cancel_scope: + await func() + + return wrapper + + +class TrioSingleTask: + def __init__(self) -> None: + self._handle: Optional[trio.CancelScope] = None + self._lock = trio.Lock() + + async def restart(self, task_group: TaskGroup, action: Callable) -> None: + async with self._lock: + if self._handle is not None: + self._handle.cancel() + self._handle = await task_group._nursery.start(_cancel_wrapper(action)) # type: ignore + + async def stop(self) -> None: + async with self._lock: + if self._handle is not None: + self._handle.cancel() + self._handle = None class EventWrapper: @@ -26,6 +58,7 @@ def is_set(self) -> bool: class WorkerContext: event_class: Type[Event] = EventWrapper + single_task_class: Type[SingleTask] = TrioSingleTask def __init__(self, max_requests: Optional[int]) -> None: self.max_requests = max_requests diff --git a/src/hypercorn/typing.py b/src/hypercorn/typing.py index ccbb50dc..e733d7ad 100644 --- a/src/hypercorn/typing.py +++ b/src/hypercorn/typing.py @@ -290,6 +290,7 @@ def is_set(self) -> bool: class WorkerContext(Protocol): event_class: Type[Event] + single_task_class: Type[SingleTask] terminate: Event terminated: Event @@ -340,3 +341,14 @@ async def __call__( call_soon: Callable, ) -> None: pass + + +class SingleTask(Protocol): + def __init__(self) -> None: + pass + + async def restart(self, task_group: TaskGroup, action: Callable) -> None: + pass + + async def stop(self) -> None: + pass