Skip to content

Commit

Permalink
Improve typing against Trio
Browse files Browse the repository at this point in the history
  • Loading branch information
jakkdl authored and pgjones committed Jun 4, 2024
1 parent 4cf3528 commit 84d06b8
Show file tree
Hide file tree
Showing 13 changed files with 80 additions and 46 deletions.
9 changes: 8 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,16 @@ warn_unused_configs = true
warn_unused_ignores = true

[[tool.mypy.overrides]]
module =["aioquic.*", "cryptography.*", "h11.*", "h2.*", "priority.*", "pytest_asyncio.*", "trio.*", "uvloop.*"]
module =["aioquic.*", "cryptography.*", "h11.*", "h2.*", "priority.*", "pytest_asyncio.*", "uvloop.*"]
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = ["trio.*", "tests.trio.*"]
disallow_any_generics = true
disallow_untyped_calls = true
strict_optional = true
warn_return_any = true

[tool.pytest.ini_options]
addopts = "--no-cov-on-fail --showlocals --strict-markers"
asyncio_mode = "strict"
Expand Down
6 changes: 4 additions & 2 deletions src/hypercorn/middleware/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Callable, Dict

from ..asyncio.task_group import TaskGroup
from ..typing import ASGIFramework, Scope
from ..typing import ASGIFramework, ASGIReceiveEvent, Scope

MAX_QUEUE_SIZE = 10

Expand Down Expand Up @@ -74,7 +74,9 @@ class TrioDispatcherMiddleware(_DispatcherMiddleware):
async def _handle_lifespan(self, scope: Scope, receive: Callable, send: Callable) -> None:
import trio

self.app_queues = {path: trio.open_memory_channel(MAX_QUEUE_SIZE) for path in self.mounts}
self.app_queues = {
path: trio.open_memory_channel[ASGIReceiveEvent](MAX_QUEUE_SIZE) for path in self.mounts
}
self.startup_complete = {path: False for path in self.mounts}
self.shutdown_complete = {path: False for path in self.mounts}

Expand Down
2 changes: 1 addition & 1 deletion src/hypercorn/trio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ async def serve(
config: Config,
*,
shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None,
task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED,
task_status: trio.TaskStatus = trio.TASK_STATUS_IGNORED,
mode: Optional[Literal["asgi", "wsgi"]] = None,
) -> None:
"""Serve an ASGI framework app given the config.
Expand Down
8 changes: 4 additions & 4 deletions src/hypercorn/trio/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ def __init__(self, app: AppWrapper, config: Config, state: LifespanState) -> Non
self.config = config
self.startup = trio.Event()
self.shutdown = trio.Event()
self.app_send_channel, self.app_receive_channel = trio.open_memory_channel(
config.max_app_queue_size
)
self.app_send_channel, self.app_receive_channel = trio.open_memory_channel[
ASGIReceiveEvent
](config.max_app_queue_size)
self.state = state
self.supported = True

async def handle_lifespan(
self, *, task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED
self, *, task_status: trio.TaskStatus = trio.TASK_STATUS_IGNORED
) -> None:
task_status.started()
scope: LifespanScope = {
Expand Down
4 changes: 2 additions & 2 deletions src/hypercorn/trio/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def worker_serve(
*,
sockets: Optional[Sockets] = None,
shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None,
task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED,
task_status: trio.TaskStatus = trio.TASK_STATUS_IGNORED,
) -> None:
config.set_statsd_logger_class(StatsdLogger)

Expand All @@ -57,7 +57,7 @@ async def worker_serve(
sock.listen(config.backlog)

ssl_context = config.create_ssl_context()
listeners = []
listeners: list[trio.SSLListener[trio.SocketStream] | trio.SocketListener] = []
binds = []
for sock in sockets.secure_sockets:
listeners.append(
Expand Down
9 changes: 6 additions & 3 deletions src/hypercorn/trio/task_group.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import sys
from contextlib import AbstractAsyncContextManager
from types import TracebackType
from typing import Any, Awaitable, Callable, Optional

Expand Down Expand Up @@ -41,8 +42,8 @@ async def _handle(

class TaskGroup:
def __init__(self) -> None:
self._nursery: Optional[trio._core._run.Nursery] = None
self._nursery_manager: Optional[trio._core._run.NurseryManager] = None
self._nursery: trio.Nursery | None = None
self._nursery_manager: AbstractAsyncContextManager[trio.Nursery] | None = None

async def spawn_app(
self,
Expand All @@ -51,7 +52,9 @@ async def spawn_app(
scope: Scope,
send: Callable[[Optional[ASGISendEvent]], Awaitable[None]],
) -> Callable[[ASGIReceiveEvent], Awaitable[None]]:
app_send_channel, app_receive_channel = trio.open_memory_channel(config.max_app_queue_size)
app_send_channel, app_receive_channel = trio.open_memory_channel[ASGIReceiveEvent](
config.max_app_queue_size
)
self._nursery.start_soon(
_handle,
app,
Expand Down
2 changes: 1 addition & 1 deletion src/hypercorn/trio/tcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
config: Config,
context: WorkerContext,
state: LifespanState,
stream: trio.abc.Stream,
stream: trio.SSLStream[trio.SocketStream],
) -> None:
self.app = app
self.config = config
Expand Down
8 changes: 4 additions & 4 deletions src/hypercorn/trio/udp_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import socket

import trio

from .task_group import TaskGroup
Expand All @@ -19,17 +21,15 @@ def __init__(
config: Config,
context: WorkerContext,
state: LifespanState,
socket: trio.socket.socket,
socket: socket.socket,
) -> None:
self.app = app
self.config = config
self.context = context
self.socket = trio.socket.from_stdlib_socket(socket)
self.state = state

async def run(
self, task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED
) -> None:
async def run(self, task_status: trio.TaskStatus = trio.TASK_STATUS_IGNORED) -> None:
from ..protocol.quic import QuicProtocol # h3/Quic is an optional part of Hypercorn

task_status.started()
Expand Down
2 changes: 1 addition & 1 deletion src/hypercorn/trio/worker_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
def _cancel_wrapper(func: Callable[[], Awaitable[None]]) -> Callable[[], Awaitable[None]]:
@wraps(func)
async def wrapper(
task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED,
task_status: trio.TaskStatus = trio.TASK_STATUS_IGNORED,
) -> None:
cancel_scope = trio.CancelScope()
task_status.started(cancel_scope)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_app_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import trio

from hypercorn.app_wrappers import _build_environ, InvalidPathError, WSGIWrapper
from hypercorn.typing import ASGISendEvent, ConnectionState, HTTPScope
from hypercorn.typing import ASGIReceiveEvent, ASGISendEvent, ConnectionState, HTTPScope


def echo_body(environ: dict, start_response: Callable) -> List[bytes]:
Expand Down Expand Up @@ -41,8 +41,8 @@ async def test_wsgi_trio() -> None:
"extensions": {},
"state": ConnectionState({}),
}
send_channel, receive_channel = trio.open_memory_channel(1)
await send_channel.send({"type": "http.request"})
send_channel, receive_channel = trio.open_memory_channel[ASGIReceiveEvent](1)
await send_channel.send({"type": "http.request"}) # type: ignore

messages = []

Expand Down
54 changes: 35 additions & 19 deletions tests/trio/test_keep_alive.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Callable, Generator
from typing import Awaitable, Callable, cast, Generator

import h11
import pytest
Expand All @@ -10,22 +10,36 @@
from hypercorn.config import Config
from hypercorn.trio.tcp_server import TCPServer
from hypercorn.trio.worker_context import WorkerContext
from hypercorn.typing import Scope
from hypercorn.typing import ASGIReceiveEvent, ASGISendEvent, Scope
from ..helpers import MockSocket

try:
from typing import TypeAlias
except ImportError:
from typing_extensions import TypeAlias


KEEP_ALIVE_TIMEOUT = 0.01
REQUEST = h11.Request(method="GET", target="/", headers=[(b"host", b"hypercorn")])

ClientStream: TypeAlias = trio.StapledStream[
trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream
]


async def slow_framework(scope: Scope, receive: Callable, send: Callable) -> None:
async def slow_framework(
scope: Scope,
receive: Callable[[], Awaitable[ASGIReceiveEvent]],
send: Callable[[ASGISendEvent], Awaitable[None]],
) -> None:
while True:
event = await receive()
if event["type"] == "http.disconnect":
break
elif event["type"] == "lifespan.startup":
await send({"type": "lifspan.startup.complete"})
await send({"type": "lifespan.startup.complete"})
elif event["type"] == "lifespan.shutdown":
await send({"type": "lifspan.shutdown.complete"})
await send({"type": "lifespan.shutdown.complete"})
elif event["type"] == "http.request" and not event.get("more_body", False):
await trio.sleep(2 * KEEP_ALIVE_TIMEOUT)
await send(
Expand All @@ -41,21 +55,20 @@ async def slow_framework(scope: Scope, receive: Callable, send: Callable) -> Non

@pytest.fixture(name="client_stream", scope="function")
def _client_stream(
nursery: trio._core._run.Nursery,
) -> Generator[trio.testing._memory_streams.MemorySendStream, None, None]:
nursery: trio.Nursery,
) -> Generator[ClientStream, None, None]:
config = Config()
config.keep_alive_timeout = KEEP_ALIVE_TIMEOUT
client_stream, server_stream = trio.testing.memory_stream_pair()
server_stream = cast("trio.SSLStream[trio.SocketStream]", server_stream)
server_stream.socket = MockSocket()
server = TCPServer(ASGIWrapper(slow_framework), config, WorkerContext(None), {}, server_stream)
nursery.start_soon(server.run)
yield client_stream


@pytest.mark.trio
async def test_http1_keep_alive_pre_request(
client_stream: trio.testing._memory_streams.MemorySendStream,
) -> None:
async def test_http1_keep_alive_pre_request(client_stream: ClientStream) -> None:
await client_stream.send_all(b"GET")
await trio.sleep(2 * KEEP_ALIVE_TIMEOUT)
# Only way to confirm closure is to invoke an error
Expand All @@ -65,23 +78,26 @@ async def test_http1_keep_alive_pre_request(

@pytest.mark.trio
async def test_http1_keep_alive_during(
client_stream: trio.testing._memory_streams.MemorySendStream,
client_stream: ClientStream,
) -> None:
client = h11.Connection(h11.CLIENT)
await client_stream.send_all(client.send(REQUEST))
# client.send(h11.Request) and client.send(h11.EndOfMessage) only returns bytes.
# Fixed on master/ in the h11 repo, once released the ignore's can be removed.
# See https://github.com/python-hyper/h11/issues/175
await client_stream.send_all(client.send(REQUEST)) # type: ignore[arg-type]
await trio.sleep(2 * KEEP_ALIVE_TIMEOUT)
# Key is that this doesn't error
await client_stream.send_all(client.send(h11.EndOfMessage()))
await client_stream.send_all(client.send(h11.EndOfMessage())) # type: ignore[arg-type]


@pytest.mark.trio
async def test_http1_keep_alive(
client_stream: trio.testing._memory_streams.MemorySendStream,
client_stream: ClientStream,
) -> None:
client = h11.Connection(h11.CLIENT)
await client_stream.send_all(client.send(REQUEST))
await client_stream.send_all(client.send(REQUEST)) # type: ignore[arg-type]
await trio.sleep(2 * KEEP_ALIVE_TIMEOUT)
await client_stream.send_all(client.send(h11.EndOfMessage()))
await client_stream.send_all(client.send(h11.EndOfMessage())) # type: ignore[arg-type]
while True:
event = client.next_event()
if event == h11.NEED_DATA:
Expand All @@ -90,15 +106,15 @@ async def test_http1_keep_alive(
elif isinstance(event, h11.EndOfMessage):
break
client.start_next_cycle()
await client_stream.send_all(client.send(REQUEST))
await client_stream.send_all(client.send(REQUEST)) # type: ignore[arg-type]
await trio.sleep(2 * KEEP_ALIVE_TIMEOUT)
# Key is that this doesn't error
await client_stream.send_all(client.send(h11.EndOfMessage()))
await client_stream.send_all(client.send(h11.EndOfMessage())) # type: ignore[arg-type]


@pytest.mark.trio
async def test_http1_keep_alive_pipelining(
client_stream: trio.testing._memory_streams.MemorySendStream,
client_stream: ClientStream,
) -> None:
await client_stream.send_all(
b"GET / HTTP/1.1\r\nHost: hypercorn\r\n\r\nGET / HTTP/1.1\r\nHost: hypercorn\r\n\r\n"
Expand Down
15 changes: 10 additions & 5 deletions tests/trio/test_sanity.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from typing import cast
from unittest.mock import Mock, PropertyMock

import h2
Expand All @@ -24,14 +25,15 @@
@pytest.mark.trio
async def test_http1_request(nursery: trio._core._run.Nursery) -> None:
client_stream, server_stream = trio.testing.memory_stream_pair()
server_stream = cast("trio.SSLStream[trio.SocketStream]", server_stream)
server_stream.socket = MockSocket()
server = TCPServer(
ASGIWrapper(sanity_framework), Config(), WorkerContext(None), {}, server_stream
)
nursery.start_soon(server.run)
client = h11.Connection(h11.CLIENT)
await client_stream.send_all(
client.send(
client.send( # type: ignore[arg-type]
h11.Request(
method="POST",
target="/",
Expand All @@ -43,8 +45,8 @@ async def test_http1_request(nursery: trio._core._run.Nursery) -> None:
)
)
)
await client_stream.send_all(client.send(h11.Data(data=SANITY_BODY)))
await client_stream.send_all(client.send(h11.EndOfMessage()))
await client_stream.send_all(client.send(h11.Data(data=SANITY_BODY))) # type: ignore[arg-type]
await client_stream.send_all(client.send(h11.EndOfMessage())) # type: ignore[arg-type]
events = []
while True:
event = client.next_event()
Expand Down Expand Up @@ -77,6 +79,7 @@ async def test_http1_request(nursery: trio._core._run.Nursery) -> None:
@pytest.mark.trio
async def test_http1_websocket(nursery: trio._core._run.Nursery) -> None:
client_stream, server_stream = trio.testing.memory_stream_pair()
server_stream = cast("trio.SSLStream[trio.SocketStream]", server_stream)
server_stream.socket = MockSocket()
server = TCPServer(
ASGIWrapper(sanity_framework), Config(), WorkerContext(None), {}, server_stream
Expand Down Expand Up @@ -104,8 +107,9 @@ async def test_http1_websocket(nursery: trio._core._run.Nursery) -> None:
@pytest.mark.trio
async def test_http2_request(nursery: trio._core._run.Nursery) -> None:
client_stream, server_stream = trio.testing.memory_stream_pair()
server_stream = cast("trio.SSLStream[trio.SocketStream]", server_stream)
server_stream.transport_stream = Mock(return_value=PropertyMock(return_value=MockSocket()))
server_stream.do_handshake = AsyncMock()
server_stream.do_handshake = AsyncMock() # type: ignore[method-assign]
server_stream.selected_alpn_protocol = Mock(return_value="h2")
server = TCPServer(
ASGIWrapper(sanity_framework), Config(), WorkerContext(None), {}, server_stream
Expand Down Expand Up @@ -161,8 +165,9 @@ async def test_http2_request(nursery: trio._core._run.Nursery) -> None:
@pytest.mark.trio
async def test_http2_websocket(nursery: trio._core._run.Nursery) -> None:
client_stream, server_stream = trio.testing.memory_stream_pair()
server_stream = cast("trio.SSLStream[trio.SocketStream]", server_stream)
server_stream.transport_stream = Mock(return_value=PropertyMock(return_value=MockSocket()))
server_stream.do_handshake = AsyncMock()
server_stream.do_handshake = AsyncMock() # type: ignore[method-assign]
server_stream.selected_alpn_protocol = Mock(return_value="h2")
server = TCPServer(
ASGIWrapper(sanity_framework), Config(), WorkerContext(None), {}, server_stream
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ basepython = python3.12
deps =
mypy
pytest
trio
commands =
mypy src/hypercorn/ tests/

Expand Down

0 comments on commit 84d06b8

Please sign in to comment.