From b75487fe0a85a80b545757fc69a5598b262a54e2 Mon Sep 17 00:00:00 2001 From: Spencer Brown Date: Tue, 8 Aug 2023 18:40:46 +1000 Subject: [PATCH 01/18] Add types to trio.testing._trio_test --- trio/testing/_trio_test.py | 41 ++++++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/trio/testing/_trio_test.py b/trio/testing/_trio_test.py index b4ef69ef09..177648d552 100644 --- a/trio/testing/_trio_test.py +++ b/trio/testing/_trio_test.py @@ -1,20 +1,37 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, TypeVar +from collections.abc import Callable, Awaitable from functools import partial, wraps + from .. import _core from ..abc import Clock, Instrument -# Use: -# -# @trio_test -# async def test_whatever(): -# await ... -# -# Also: if a pytest fixture is passed in that subclasses the Clock abc, then -# that clock is passed to trio.run(). -def trio_test(fn): +if TYPE_CHECKING: + from typing_extensions import ParamSpec + + ArgsT = ParamSpec("ArgsT") + + +RetT = TypeVar("RetT") + + +def trio_test(fn: Callable[ArgsT, Awaitable[RetT]]) -> Callable[ArgsT, RetT]: + """Converts an async test function to be synchronous, running via Trio. + + Usage:: + + @trio_test + async def test_whatever(): + await ... + + If a pytest fixture is passed in that subclasses the :class:`~trio.abc.Clock` or + :class:`~trio.abc.Instrument` ABCs, then those are passed to :meth:`trio.run()`. + """ + @wraps(fn) - def wrapper(**kwargs): + def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: __tracebackhide__ = True clocks = [c for c in kwargs.values() if isinstance(c, Clock)] if not clocks: @@ -24,6 +41,8 @@ def wrapper(**kwargs): else: raise ValueError("too many clocks spoil the broth!") instruments = [i for i in kwargs.values() if isinstance(i, Instrument)] - return _core.run(partial(fn, **kwargs), clock=clock, instruments=instruments) + return _core.run( + partial(fn, *args, **kwargs), clock=clock, instruments=instruments + ) return wrapper From 391f42bcb1a3c763282af302e316aa79c856a29d Mon Sep 17 00:00:00 2001 From: Spencer Brown Date: Tue, 8 Aug 2023 18:55:08 +1000 Subject: [PATCH 02/18] Add types to trio.testing._check_streams --- trio/testing/_check_streams.py | 98 +++++++++++++++++++++------------- 1 file changed, 62 insertions(+), 36 deletions(-) diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py index 401b8ef0c2..7ef4bfaa1a 100644 --- a/trio/testing/_check_streams.py +++ b/trio/testing/_check_streams.py @@ -2,24 +2,31 @@ from __future__ import annotations import random +from collections.abc import Awaitable, Callable from contextlib import contextmanager -from typing import TYPE_CHECKING +from typing import Generic, TYPE_CHECKING, TypeVar -from .. import _core -from .._abc import HalfCloseableStream, ReceiveStream, SendStream, Stream +from .. import _core, CancelScope +from .._abc import AsyncResource, HalfCloseableStream, ReceiveStream, SendStream, Stream from .._highlevel_generic import aclose_forcefully from ._checkpoints import assert_checkpoints if TYPE_CHECKING: from types import TracebackType + from typing_extensions import ParamSpec + ArgsT = ParamSpec("ArgsT") -class _ForceCloseBoth: - def __init__(self, both): - self._both = list(both) +Res1 = TypeVar("Res1", bound=AsyncResource) +Res2 = TypeVar("Res2", bound=AsyncResource) - async def __aenter__(self): - return self._both + +class _ForceCloseBoth(Generic[Res1, Res2]): + def __init__(self, both: tuple[Res1, Res2]) -> None: + self._first, self._second = both + + async def __aenter__(self) -> tuple[Res1, Res2]: + return self._first, self._second async def __aexit__( self, @@ -28,9 +35,9 @@ async def __aexit__( traceback: TracebackType | None, ) -> None: try: - await aclose_forcefully(self._both[0]) + await aclose_forcefully(self._first) finally: - await aclose_forcefully(self._both[1]) + await aclose_forcefully(self._second) @contextmanager @@ -44,7 +51,11 @@ def _assert_raises(exc): raise AssertionError(f"expected exception: {exc}") -async def check_one_way_stream(stream_maker, clogged_stream_maker): +async def check_one_way_stream( + stream_maker: Callable[[], Awaitable[tuple[SendStream, ReceiveStream]]], + clogged_stream_maker: Callable[[], Awaitable[tuple[SendStream, ReceiveStream]]] + | None, +): """Perform a number of generic tests on a custom one-way stream implementation. @@ -67,18 +78,18 @@ async def check_one_way_stream(stream_maker, clogged_stream_maker): assert isinstance(s, SendStream) assert isinstance(r, ReceiveStream) - async def do_send_all(data): + async def do_send_all(data: bytes) -> None: with assert_checkpoints(): assert await s.send_all(data) is None - async def do_receive_some(*args): + async def do_receive_some(max_bytes: int | None = None) -> bytes | bytearray: with assert_checkpoints(): - return await r.receive_some(*args) + return await r.receive_some(max_bytes) - async def checked_receive_1(expected): + async def checked_receive_1(expected: bytes) -> None: assert await do_receive_some(1) == expected - async def do_aclose(resource): + async def do_aclose(resource: AsyncResource) -> None: with assert_checkpoints(): await resource.aclose() @@ -87,7 +98,7 @@ async def do_aclose(resource): nursery.start_soon(do_send_all, b"x") nursery.start_soon(checked_receive_1, b"x") - async def send_empty_then_y(): + async def send_empty_then_y() -> None: # Streams should tolerate sending b"" without giving it any # special meaning. await do_send_all(b"") @@ -258,9 +269,13 @@ async def receive_send_then_close(): # https://github.com/python-trio/trio/issues/77 async with _ForceCloseBoth(await stream_maker()) as (s, r): - async def expect_cancelled(afn, *args): + async def expect_cancelled( + afn: Callable[ArgsT, Awaitable[object]], + *args: ArgsT.args, + **kwargs: ArgsT.kwargs, + ) -> None: with _assert_raises(_core.Cancelled): - await afn(*args) + await afn(*args, **kwargs) with _core.CancelScope() as scope: scope.cancel() @@ -288,16 +303,16 @@ async def receive_expecting_closed(): # check wait_send_all_might_not_block, if we can if clogged_stream_maker is not None: async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): - record = [] + record: list[str] = [] - async def waiter(cancel_scope): + async def waiter(cancel_scope: CancelScope) -> None: record.append("waiter sleeping") with assert_checkpoints(): await s.wait_send_all_might_not_block() record.append("waiter wokeup") cancel_scope.cancel() - async def receiver(): + async def receiver() -> None: # give wait_send_all_might_not_block a chance to block await _core.wait_all_tasks_blocked() record.append("receiver starting") @@ -343,14 +358,14 @@ async def receiver(): # with or without an exception async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): - async def sender(): + async def sender() -> None: try: with assert_checkpoints(): await s.wait_send_all_might_not_block() except _core.BrokenResourceError: # pragma: no cover pass - async def receiver(): + async def receiver() -> None: await _core.wait_all_tasks_blocked() await aclose_forcefully(r) @@ -369,7 +384,7 @@ async def receiver(): # Check that if a task is blocked in a send-side method, then closing # the send stream causes it to wake up. - async def close_soon(s): + async def close_soon(s: SendStream) -> None: await _core.wait_all_tasks_blocked() await aclose_forcefully(s) @@ -386,7 +401,10 @@ async def close_soon(s): await s.wait_send_all_might_not_block() -async def check_two_way_stream(stream_maker, clogged_stream_maker): +async def check_two_way_stream( + stream_maker: Callable[[], Awaitable[tuple[Stream, Stream]]], + clogged_stream_maker: Callable[[], Awaitable[tuple[Stream, Stream]]] | None, +) -> None: """Perform a number of generic tests on a custom two-way stream implementation. @@ -401,13 +419,13 @@ async def check_two_way_stream(stream_maker, clogged_stream_maker): """ await check_one_way_stream(stream_maker, clogged_stream_maker) - async def flipped_stream_maker(): - return reversed(await stream_maker()) + async def flipped_stream_maker() -> tuple[Stream, Stream]: + return (await stream_maker())[::-1] if clogged_stream_maker is not None: - async def flipped_clogged_stream_maker(): - return reversed(await clogged_stream_maker()) + async def flipped_clogged_stream_maker() -> tuple[Stream, Stream]: + return (await clogged_stream_maker())[::-1] else: flipped_clogged_stream_maker = None @@ -425,7 +443,7 @@ async def flipped_clogged_stream_maker(): i = r.getrandbits(8 * DUPLEX_TEST_SIZE) test_data = i.to_bytes(DUPLEX_TEST_SIZE, "little") - async def sender(s, data, seed): + async def sender(s: Stream, data: bytes, seed: int) -> None: r = random.Random(seed) m = memoryview(data) while m: @@ -433,7 +451,7 @@ async def sender(s, data, seed): await s.send_all(m[:chunk_size]) m = m[chunk_size:] - async def receiver(s, data, seed): + async def receiver(s: Stream, data: bytes, seed: int) -> None: r = random.Random(seed) got = bytearray() while len(got) < len(data): @@ -448,7 +466,7 @@ async def receiver(s, data, seed): nursery.start_soon(receiver, s1, test_data[::-1], 2) nursery.start_soon(receiver, s2, test_data, 3) - async def expect_receive_some_empty(): + async def expect_receive_some_empty() -> None: assert await s2.receive_some(10) == b"" await s2.aclose() @@ -457,7 +475,15 @@ async def expect_receive_some_empty(): nursery.start_soon(s1.aclose) -async def check_half_closeable_stream(stream_maker, clogged_stream_maker): +async def check_half_closeable_stream( + stream_maker: Callable[ + [], Awaitable[tuple[HalfCloseableStream, HalfCloseableStream]] + ], + clogged_stream_maker: Callable[ + [], Awaitable[tuple[HalfCloseableStream, HalfCloseableStream]] + ] + | None, +) -> None: """Perform a number of generic tests on a custom half-closeable stream implementation. @@ -476,12 +502,12 @@ async def check_half_closeable_stream(stream_maker, clogged_stream_maker): assert isinstance(s1, HalfCloseableStream) assert isinstance(s2, HalfCloseableStream) - async def send_x_then_eof(s): + async def send_x_then_eof(s: HalfCloseableStream) -> None: await s.send_all(b"x") with assert_checkpoints(): await s.send_eof() - async def expect_x_then_eof(r): + async def expect_x_then_eof(r: HalfCloseableStream) -> None: await _core.wait_all_tasks_blocked() assert await r.receive_some(10) == b"x" assert await r.receive_some(10) == b"" From e3cb71e643f8fa0eef21cfaa424994c0b11c2136 Mon Sep 17 00:00:00 2001 From: Spencer Brown Date: Tue, 8 Aug 2023 18:57:37 +1000 Subject: [PATCH 03/18] Add types to trio.testing._checkpoints --- trio/testing/_checkpoints.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/trio/testing/_checkpoints.py b/trio/testing/_checkpoints.py index 5804295300..4a4047813b 100644 --- a/trio/testing/_checkpoints.py +++ b/trio/testing/_checkpoints.py @@ -1,10 +1,14 @@ -from contextlib import contextmanager +from __future__ import annotations + +from collections.abc import Generator +from contextlib import AbstractContextManager, contextmanager from .. import _core @contextmanager -def _assert_yields_or_not(expected): +def _assert_yields_or_not(expected: bool) -> Generator[None, None, None]: + """Check if checkpoints are executed in a block of code.""" __tracebackhide__ = True task = _core.current_task() orig_cancel = task._cancel_points @@ -22,7 +26,7 @@ def _assert_yields_or_not(expected): raise AssertionError("assert_no_checkpoints block yielded!") -def assert_checkpoints(): +def assert_checkpoints() -> AbstractContextManager[None]: """Use as a context manager to check that the code inside the ``with`` block either exits with an exception or executes at least one :ref:`checkpoint `. @@ -42,7 +46,7 @@ def assert_checkpoints(): return _assert_yields_or_not(True) -def assert_no_checkpoints(): +def assert_no_checkpoints() -> AbstractContextManager[None]: """Use as a context manager to check that the code inside the ``with`` block does not execute any :ref:`checkpoints `. From 737b8e2b659311cd7761f97b6503053c31f6949f Mon Sep 17 00:00:00 2001 From: Spencer Brown Date: Tue, 8 Aug 2023 18:58:34 +1000 Subject: [PATCH 04/18] Add types to trio.testing._network --- trio/testing/_network.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/trio/testing/_network.py b/trio/testing/_network.py index 615ce2effb..fddbbf0fdc 100644 --- a/trio/testing/_network.py +++ b/trio/testing/_network.py @@ -1,8 +1,10 @@ from .. import socket as tsocket -from .._highlevel_socket import SocketStream +from .._highlevel_socket import SocketListener, SocketStream -async def open_stream_to_socket_listener(socket_listener): +async def open_stream_to_socket_listener( + socket_listener: SocketListener, +) -> SocketStream: """Connect to the given :class:`~trio.SocketListener`. This is particularly useful in tests when you want to let a server pick From e1e5aa39884fb851a6132daa7a8df80f59c6a870 Mon Sep 17 00:00:00 2001 From: Spencer Brown Date: Tue, 8 Aug 2023 19:17:06 +1000 Subject: [PATCH 05/18] Add types to trio.testing._memory_streams --- trio/testing/_memory_streams.py | 117 ++++++++++++++++++-------------- 1 file changed, 66 insertions(+), 51 deletions(-) diff --git a/trio/testing/_memory_streams.py b/trio/testing/_memory_streams.py index 38e8e54de8..599f77ed40 100644 --- a/trio/testing/_memory_streams.py +++ b/trio/testing/_memory_streams.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import operator +from collections.abc import Awaitable, Callable from .. import _core, _util from .._highlevel_generic import StapledStream @@ -10,7 +13,7 @@ class _UnboundedByteQueue: - def __init__(self): + def __init__(self) -> None: self._data = bytearray() self._closed = False self._lot = _core.ParkingLot() @@ -22,28 +25,28 @@ def __init__(self): # channel: so after close(), calling put() raises ClosedResourceError, and # calling the get() variants drains the buffer and then returns an empty # bytearray. - def close(self): + def close(self) -> None: self._closed = True self._lot.unpark_all() - def close_and_wipe(self): + def close_and_wipe(self) -> None: self._data = bytearray() self.close() - def put(self, data): + def put(self, data: bytes | bytearray) -> None: if self._closed: raise _core.ClosedResourceError("virtual connection closed") self._data += data self._lot.unpark_all() - def _check_max_bytes(self, max_bytes): + def _check_max_bytes(self, max_bytes: int | None) -> None: if max_bytes is None: return max_bytes = operator.index(max_bytes) if max_bytes < 1: raise ValueError("max_bytes must be >= 1") - def _get_impl(self, max_bytes): + def _get_impl(self, max_bytes: int | None) -> bytearray: assert self._closed or self._data if max_bytes is None: max_bytes = len(self._data) @@ -55,14 +58,14 @@ def _get_impl(self, max_bytes): else: return bytearray() - def get_nowait(self, max_bytes=None): + def get_nowait(self, max_bytes: int | None = None) -> bytearray: with self._fetch_lock: self._check_max_bytes(max_bytes) if not self._closed and not self._data: raise _core.WouldBlock return self._get_impl(max_bytes) - async def get(self, max_bytes=None): + async def get(self, max_bytes: int | None = None) -> bytearray: with self._fetch_lock: self._check_max_bytes(max_bytes) if not self._closed and not self._data: @@ -95,9 +98,10 @@ class MemorySendStream(SendStream, metaclass=_util.Final): def __init__( self, - send_all_hook=None, - wait_send_all_might_not_block_hook=None, - close_hook=None, + send_all_hook: Callable[[], Awaitable[object]] | None = None, + wait_send_all_might_not_block_hook: Callable[[], Awaitable[object]] + | None = None, + close_hook: Callable[[], object] | None = None, ): self._conflict_detector = _util.ConflictDetector( "another task is using this stream" @@ -107,7 +111,7 @@ def __init__( self.wait_send_all_might_not_block_hook = wait_send_all_might_not_block_hook self.close_hook = close_hook - async def send_all(self, data): + async def send_all(self, data: bytes | bytearray | memoryview) -> None: """Places the given data into the object's internal buffer, and then calls the :attr:`send_all_hook` (if any). @@ -121,12 +125,12 @@ async def send_all(self, data): if self.send_all_hook is not None: await self.send_all_hook() - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: """Calls the :attr:`wait_send_all_might_not_block_hook` (if any), and then returns immediately. """ - # Execute two checkpoints so we have more of a chance to detect + # Execute two checkpoints so that we have more of a chance to detect # buggy user code that calls this twice at the same time. with self._conflict_detector: await _core.checkpoint() @@ -136,7 +140,7 @@ async def wait_send_all_might_not_block(self): if self.wait_send_all_might_not_block_hook is not None: await self.wait_send_all_might_not_block_hook() - def close(self): + def close(self) -> None: """Marks this stream as closed, and then calls the :attr:`close_hook` (if any). @@ -153,12 +157,12 @@ def close(self): if self.close_hook is not None: self.close_hook() - async def aclose(self): + async def aclose(self) -> None: """Same as :meth:`close`, but async.""" self.close() await _core.checkpoint() - async def get_data(self, max_bytes=None): + async def get_data(self, max_bytes: int | None = None) -> bytearray: """Retrieves data from the internal buffer, blocking if necessary. Args: @@ -174,7 +178,7 @@ async def get_data(self, max_bytes=None): """ return await self._outgoing.get(max_bytes) - def get_data_nowait(self, max_bytes=None): + def get_data_nowait(self, max_bytes: int | None = None) -> bytearray: """Retrieves data from the internal buffer, but doesn't block. See :meth:`get_data` for details. @@ -203,7 +207,11 @@ class MemoryReceiveStream(ReceiveStream, metaclass=_util.Final): """ - def __init__(self, receive_some_hook=None, close_hook=None): + def __init__( + self, + receive_some_hook: Callable[[], Awaitable[object]] | None = None, + close_hook: Callable[[], object] | None = None, + ): self._conflict_detector = _util.ConflictDetector( "another task is using this stream" ) @@ -212,7 +220,7 @@ def __init__(self, receive_some_hook=None, close_hook=None): self.receive_some_hook = receive_some_hook self.close_hook = close_hook - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: int | None = None) -> bytearray: """Calls the :attr:`receive_some_hook` (if any), and then retrieves data from the internal buffer, blocking if necessary. @@ -235,7 +243,7 @@ async def receive_some(self, max_bytes=None): raise _core.ClosedResourceError return data - def close(self): + def close(self) -> None: """Discards any pending data from the internal buffer, and marks this stream as closed. @@ -245,21 +253,26 @@ def close(self): if self.close_hook is not None: self.close_hook() - async def aclose(self): + async def aclose(self) -> None: """Same as :meth:`close`, but async.""" self.close() await _core.checkpoint() - def put_data(self, data): + def put_data(self, data: bytes | bytearray) -> None: """Appends the given data to the internal buffer.""" self._incoming.put(data) - def put_eof(self): + def put_eof(self) -> None: """Adds an end-of-file marker to the internal buffer.""" self._incoming.close() -def memory_stream_pump(memory_send_stream, memory_receive_stream, *, max_bytes=None): +def memory_stream_pump( + memory_send_stream: MemorySendStream, + memory_receive_stream: MemoryReceiveStream, + *, + max_bytes: int | None = None, +) -> bool: """Take data out of the given :class:`MemorySendStream`'s internal buffer, and put it into the given :class:`MemoryReceiveStream`'s internal buffer. @@ -292,7 +305,7 @@ def memory_stream_pump(memory_send_stream, memory_receive_stream, *, max_bytes=N return True -def memory_stream_one_way_pair(): +def memory_stream_one_way_pair() -> tuple[MemorySendStream, MemoryReceiveStream]: """Create a connected, pure-Python, unidirectional stream with infinite buffering and flexible configuration options. @@ -319,10 +332,10 @@ def memory_stream_one_way_pair(): send_stream = MemorySendStream() recv_stream = MemoryReceiveStream() - def pump_from_send_stream_to_recv_stream(): + def pump_from_send_stream_to_recv_stream() -> None: memory_stream_pump(send_stream, recv_stream) - async def async_pump_from_send_stream_to_recv_stream(): + async def async_pump_from_send_stream_to_recv_stream() -> None: pump_from_send_stream_to_recv_stream() send_stream.send_all_hook = async_pump_from_send_stream_to_recv_stream @@ -330,7 +343,9 @@ async def async_pump_from_send_stream_to_recv_stream(): return send_stream, recv_stream -def _make_stapled_pair(one_way_pair): +def _make_stapled_pair( + one_way_pair: Callable[[], tuple[SendStream, ReceiveStream]] +) -> tuple[StapledStream, StapledStream]: pipe1_send, pipe1_recv = one_way_pair() pipe2_send, pipe2_recv = one_way_pair() stream1 = StapledStream(pipe1_send, pipe2_recv) @@ -338,7 +353,7 @@ def _make_stapled_pair(one_way_pair): return stream1, stream2 -def memory_stream_pair(): +def memory_stream_pair() -> tuple[SendStream, SendStream]: """Create a connected, pure-Python, bidirectional stream with infinite buffering and flexible configuration options. @@ -421,7 +436,7 @@ async def receiver(): class _LockstepByteQueue: - def __init__(self): + def __init__(self) -> None: self._data = bytearray() self._sender_closed = False self._receiver_closed = False @@ -434,12 +449,12 @@ def __init__(self): "another task is already receiving" ) - def _something_happened(self): + def _something_happened(self) -> None: self._waiters.unpark_all() # Always wakes up when one side is closed, because everyone always reacts # to that. - async def _wait_for(self, fn): + async def _wait_for(self, fn: Callable[[], bool]) -> None: while True: if fn(): break @@ -448,15 +463,15 @@ async def _wait_for(self, fn): await self._waiters.park() await _core.checkpoint() - def close_sender(self): + def close_sender(self) -> None: self._sender_closed = True self._something_happened() - def close_receiver(self): + def close_receiver(self) -> None: self._receiver_closed = True self._something_happened() - async def send_all(self, data): + async def send_all(self, data: bytes | bytearray) -> None: with self._send_conflict_detector: if self._sender_closed: raise _core.ClosedResourceError @@ -465,13 +480,13 @@ async def send_all(self, data): assert not self._data self._data += data self._something_happened() - await self._wait_for(lambda: not self._data) + await self._wait_for(lambda: self._data == b"") if self._sender_closed: raise _core.ClosedResourceError if self._data and self._receiver_closed: raise _core.BrokenResourceError - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: with self._send_conflict_detector: if self._sender_closed: raise _core.ClosedResourceError @@ -482,7 +497,7 @@ async def wait_send_all_might_not_block(self): if self._sender_closed: raise _core.ClosedResourceError - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: with self._receive_conflict_detector: # Argument validation if max_bytes is not None: @@ -496,7 +511,7 @@ async def receive_some(self, max_bytes=None): self._receiver_waiting = True self._something_happened() try: - await self._wait_for(lambda: self._data) + await self._wait_for(lambda: self._data != b"") finally: self._receiver_waiting = False if self._receiver_closed: @@ -515,39 +530,39 @@ async def receive_some(self, max_bytes=None): class _LockstepSendStream(SendStream): - def __init__(self, lbq): + def __init__(self, lbq: _LockstepByteQueue): self._lbq = lbq - def close(self): + def close(self) -> None: self._lbq.close_sender() - async def aclose(self): + async def aclose(self) -> None: self.close() await _core.checkpoint() - async def send_all(self, data): + async def send_all(self, data: bytes | bytearray) -> None: await self._lbq.send_all(data) - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: await self._lbq.wait_send_all_might_not_block() class _LockstepReceiveStream(ReceiveStream): - def __init__(self, lbq): + def __init__(self, lbq: _LockstepByteQueue): self._lbq = lbq - def close(self): + def close(self) -> None: self._lbq.close_receiver() - async def aclose(self): + async def aclose(self) -> None: self.close() await _core.checkpoint() - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: return await self._lbq.receive_some(max_bytes) -def lockstep_stream_one_way_pair(): +def lockstep_stream_one_way_pair() -> tuple[SendStream, ReceiveStream]: """Create a connected, pure Python, unidirectional stream where data flows in lockstep. @@ -574,7 +589,7 @@ def lockstep_stream_one_way_pair(): return _LockstepSendStream(lbq), _LockstepReceiveStream(lbq) -def lockstep_stream_pair(): +def lockstep_stream_pair() -> tuple[StapledStream, StapledStream]: """Create a connected, pure-Python, bidirectional stream where data flows in lockstep. From fcd68315093b3238c9f141467cf27a0545bf3461 Mon Sep 17 00:00:00 2001 From: Spencer Brown Date: Tue, 8 Aug 2023 19:17:35 +1000 Subject: [PATCH 06/18] Enable strict tests for most of trio.testing, fill in some missing ones --- pyproject.toml | 7 +++++++ trio/testing/_check_streams.py | 24 ++++++++++++++---------- trio/testing/_trio_test.py | 3 ++- 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a23f7f5db9..2366bc61f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,13 @@ module = [ "trio._sync", "trio._tools.gen_exports", "trio._util", + "trio.testing", + "trio.testing._check_streams", + "trio.testing._checkpoints", + "trio.testing._memory_streams", + "trio.testing._network", + "trio.testing._sequencer", + "trio.testing._trio_test", ] disallow_incomplete_defs = true disallow_untyped_defs = true diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py index 7ef4bfaa1a..6e4a946521 100644 --- a/trio/testing/_check_streams.py +++ b/trio/testing/_check_streams.py @@ -2,7 +2,7 @@ from __future__ import annotations import random -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Generator from contextlib import contextmanager from typing import Generic, TYPE_CHECKING, TypeVar @@ -41,7 +41,7 @@ async def __aexit__( @contextmanager -def _assert_raises(exc): +def _assert_raises(exc: type[BaseException]) -> Generator[None, None, None]: __tracebackhide__ = True try: yield @@ -55,7 +55,7 @@ async def check_one_way_stream( stream_maker: Callable[[], Awaitable[tuple[SendStream, ReceiveStream]]], clogged_stream_maker: Callable[[], Awaitable[tuple[SendStream, ReceiveStream]]] | None, -): +) -> None: """Perform a number of generic tests on a custom one-way stream implementation. @@ -79,8 +79,8 @@ async def check_one_way_stream( assert isinstance(r, ReceiveStream) async def do_send_all(data: bytes) -> None: - with assert_checkpoints(): - assert await s.send_all(data) is None + with assert_checkpoints(): # We're testing that it doesn't return anything. + assert await s.send_all(data) is None # type: ignore[func-returns-value] async def do_receive_some(max_bytes: int | None = None) -> bytes | bytearray: with assert_checkpoints(): @@ -125,7 +125,7 @@ async def send_empty_then_y() -> None: with _assert_raises(ValueError): await r.receive_some(0) with _assert_raises(TypeError): - await r.receive_some(1.5) + await r.receive_some(1.5) # type: ignore[arg-type] # it can also be missing or None async with _core.open_nursery() as nursery: nursery.start_soon(do_send_all, b"x") @@ -144,7 +144,9 @@ async def send_empty_then_y() -> None: # for send_all to wait until receive_some is called to run, though; a # stream doesn't *have* to have any internal buffering. That's why we # start a concurrent receive_some call, then cancel it.) - async def simple_check_wait_send_all_might_not_block(scope): + async def simple_check_wait_send_all_might_not_block( + scope: CancelScope, + ) -> None: with assert_checkpoints(): await s.wait_send_all_might_not_block() scope.cancel() @@ -157,7 +159,7 @@ async def simple_check_wait_send_all_might_not_block(scope): # closing the r side leads to BrokenResourceError on the s side # (eventually) - async def expect_broken_stream_on_send(): + async def expect_broken_stream_on_send() -> None: with _assert_raises(_core.BrokenResourceError): while True: await do_send_all(b"x" * 100) @@ -200,11 +202,11 @@ async def expect_broken_stream_on_send(): async with _ForceCloseBoth(await stream_maker()) as (s, r): # if send-then-graceful-close, receiver gets data then b"" - async def send_then_close(): + async def send_then_close() -> None: await do_send_all(b"y") await do_aclose(s) - async def receive_send_then_close(): + async def receive_send_then_close() -> None: # We want to make sure that if the sender closes the stream before # we read anything, then we still get all the data. But some # streams might block on the do_send_all call. So we let the @@ -422,6 +424,8 @@ async def check_two_way_stream( async def flipped_stream_maker() -> tuple[Stream, Stream]: return (await stream_maker())[::-1] + flipped_clogged_stream_maker: Callable[[], Awaitable[tuple[Stream, Stream]]] | None + if clogged_stream_maker is not None: async def flipped_clogged_stream_maker() -> tuple[Stream, Stream]: diff --git a/trio/testing/_trio_test.py b/trio/testing/_trio_test.py index 177648d552..3b75fef12c 100644 --- a/trio/testing/_trio_test.py +++ b/trio/testing/_trio_test.py @@ -41,7 +41,8 @@ def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: else: raise ValueError("too many clocks spoil the broth!") instruments = [i for i in kwargs.values() if isinstance(i, Instrument)] - return _core.run( + # trio.run isn't typed yet. + return _core.run( # type: ignore[no-any-return] partial(fn, *args, **kwargs), clock=clock, instruments=instruments ) From 11ce4fff1843e98f4bc7b4db7384e797ad5a6853 Mon Sep 17 00:00:00 2001 From: Spencer Brown Date: Tue, 8 Aug 2023 19:19:20 +1000 Subject: [PATCH 07/18] Run linters --- trio/_tests/verify_types.json | 38 ++++------------------------------ trio/testing/_check_streams.py | 5 +++-- trio/testing/_trio_test.py | 7 +++---- 3 files changed, 10 insertions(+), 40 deletions(-) diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index 5a0c59e33e..fe84b1ebb3 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,16 +7,16 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.9202551834130781, + "completenessScore": 0.9266347687400319, "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 577, - "withUnknownType": 50 + "withKnownType": 581, + "withUnknownType": 46 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 1, "missingDefaultParamCount": 0, - "missingFunctionDocStringCount": 4, + "missingFunctionDocStringCount": 3, "moduleName": "trio", "modules": [ { @@ -112,36 +112,6 @@ "trio.serve_listeners", "trio.serve_ssl_over_tcp", "trio.serve_tcp", - "trio.testing._memory_streams.MemoryReceiveStream.__init__", - "trio.testing._memory_streams.MemoryReceiveStream.aclose", - "trio.testing._memory_streams.MemoryReceiveStream.close", - "trio.testing._memory_streams.MemoryReceiveStream.close_hook", - "trio.testing._memory_streams.MemoryReceiveStream.put_data", - "trio.testing._memory_streams.MemoryReceiveStream.put_eof", - "trio.testing._memory_streams.MemoryReceiveStream.receive_some", - "trio.testing._memory_streams.MemoryReceiveStream.receive_some_hook", - "trio.testing._memory_streams.MemorySendStream.__init__", - "trio.testing._memory_streams.MemorySendStream.aclose", - "trio.testing._memory_streams.MemorySendStream.close", - "trio.testing._memory_streams.MemorySendStream.close_hook", - "trio.testing._memory_streams.MemorySendStream.get_data", - "trio.testing._memory_streams.MemorySendStream.get_data_nowait", - "trio.testing._memory_streams.MemorySendStream.send_all", - "trio.testing._memory_streams.MemorySendStream.send_all_hook", - "trio.testing._memory_streams.MemorySendStream.wait_send_all_might_not_block", - "trio.testing._memory_streams.MemorySendStream.wait_send_all_might_not_block_hook", - "trio.testing.assert_checkpoints", - "trio.testing.assert_no_checkpoints", - "trio.testing.check_half_closeable_stream", - "trio.testing.check_one_way_stream", - "trio.testing.check_two_way_stream", - "trio.testing.lockstep_stream_one_way_pair", - "trio.testing.lockstep_stream_pair", - "trio.testing.memory_stream_one_way_pair", - "trio.testing.memory_stream_pair", - "trio.testing.memory_stream_pump", - "trio.testing.open_stream_to_socket_listener", - "trio.testing.trio_test", "trio.testing.wait_all_tasks_blocked", "trio.tests.TestsDeprecationWrapper", "trio.to_thread.current_default_thread_limiter" diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py index 6e4a946521..aeed087eae 100644 --- a/trio/testing/_check_streams.py +++ b/trio/testing/_check_streams.py @@ -4,15 +4,16 @@ import random from collections.abc import Awaitable, Callable, Generator from contextlib import contextmanager -from typing import Generic, TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar -from .. import _core, CancelScope +from .. import CancelScope, _core from .._abc import AsyncResource, HalfCloseableStream, ReceiveStream, SendStream, Stream from .._highlevel_generic import aclose_forcefully from ._checkpoints import assert_checkpoints if TYPE_CHECKING: from types import TracebackType + from typing_extensions import ParamSpec ArgsT = ParamSpec("ArgsT") diff --git a/trio/testing/_trio_test.py b/trio/testing/_trio_test.py index 3b75fef12c..e72d1a6df0 100644 --- a/trio/testing/_trio_test.py +++ b/trio/testing/_trio_test.py @@ -1,13 +1,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING, TypeVar -from collections.abc import Callable, Awaitable -from functools import partial, wraps +from collections.abc import Awaitable, Callable +from functools import partial, wraps +from typing import TYPE_CHECKING, TypeVar from .. import _core from ..abc import Clock, Instrument - if TYPE_CHECKING: from typing_extensions import ParamSpec From 073cae1343459f87cdee0e5f2757815c496f40cb Mon Sep 17 00:00:00 2001 From: Spencer Brown Date: Tue, 8 Aug 2023 19:31:29 +1000 Subject: [PATCH 08/18] Update verify_types.json --- trio/_tests/verify_types.json | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index fe84b1ebb3..d9a5cad892 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.9266347687400319, + "completenessScore": 0.9425837320574163, "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 581, - "withUnknownType": 46 + "withKnownType": 591, + "withUnknownType": 36 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 1, @@ -46,8 +46,8 @@ ], "otherSymbolCounts": { "withAmbiguousType": 3, - "withKnownType": 602, - "withUnknownType": 61 + "withKnownType": 622, + "withUnknownType": 41 }, "packageName": "trio", "symbols": [ From 68eaad79d1d5b22270693bf1221a7eed7639728b Mon Sep 17 00:00:00 2001 From: Spencer Brown Date: Wed, 9 Aug 2023 11:20:36 +1000 Subject: [PATCH 09/18] Introduce some type aliases to make code more readable --- trio/testing/_check_streams.py | 21 ++++++++------------- trio/testing/_memory_streams.py | 22 ++++++++++++++++------ 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py index aeed087eae..074a0a59b2 100644 --- a/trio/testing/_check_streams.py +++ b/trio/testing/_check_streams.py @@ -14,12 +14,13 @@ if TYPE_CHECKING: from types import TracebackType - from typing_extensions import ParamSpec + from typing_extensions import ParamSpec, TypeAlias ArgsT = ParamSpec("ArgsT") Res1 = TypeVar("Res1", bound=AsyncResource) Res2 = TypeVar("Res2", bound=AsyncResource) +StreamMaker: TypeAlias = "Callable[[], Awaitable[tuple[Res1, Res2]]]" class _ForceCloseBoth(Generic[Res1, Res2]): @@ -53,9 +54,8 @@ def _assert_raises(exc: type[BaseException]) -> Generator[None, None, None]: async def check_one_way_stream( - stream_maker: Callable[[], Awaitable[tuple[SendStream, ReceiveStream]]], - clogged_stream_maker: Callable[[], Awaitable[tuple[SendStream, ReceiveStream]]] - | None, + stream_maker: StreamMaker[SendStream, ReceiveStream], + clogged_stream_maker: StreamMaker[SendStream, ReceiveStream] | None, ) -> None: """Perform a number of generic tests on a custom one-way stream implementation. @@ -405,8 +405,8 @@ async def close_soon(s: SendStream) -> None: async def check_two_way_stream( - stream_maker: Callable[[], Awaitable[tuple[Stream, Stream]]], - clogged_stream_maker: Callable[[], Awaitable[tuple[Stream, Stream]]] | None, + stream_maker: StreamMaker[Stream, Stream], + clogged_stream_maker: StreamMaker[Stream, Stream] | None, ) -> None: """Perform a number of generic tests on a custom two-way stream implementation. @@ -481,13 +481,8 @@ async def expect_receive_some_empty() -> None: async def check_half_closeable_stream( - stream_maker: Callable[ - [], Awaitable[tuple[HalfCloseableStream, HalfCloseableStream]] - ], - clogged_stream_maker: Callable[ - [], Awaitable[tuple[HalfCloseableStream, HalfCloseableStream]] - ] - | None, + stream_maker: StreamMaker[HalfCloseableStream, HalfCloseableStream], + clogged_stream_maker: StreamMaker[HalfCloseableStream, HalfCloseableStream] | None, ) -> None: """Perform a number of generic tests on a custom half-closeable stream implementation. diff --git a/trio/testing/_memory_streams.py b/trio/testing/_memory_streams.py index 599f77ed40..8ba91ad330 100644 --- a/trio/testing/_memory_streams.py +++ b/trio/testing/_memory_streams.py @@ -2,11 +2,22 @@ import operator from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING from .. import _core, _util from .._highlevel_generic import StapledStream from ..abc import ReceiveStream, SendStream + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + +AsyncHook: TypeAlias = "Callable[[], Awaitable[object]]" +# Would be nice to exclude awaitable here, but currently not possible. +SyncHook: TypeAlias = "Callable[[], object]" + + ################################################################ # In-memory streams - Unbounded buffer version ################################################################ @@ -98,10 +109,9 @@ class MemorySendStream(SendStream, metaclass=_util.Final): def __init__( self, - send_all_hook: Callable[[], Awaitable[object]] | None = None, - wait_send_all_might_not_block_hook: Callable[[], Awaitable[object]] - | None = None, - close_hook: Callable[[], object] | None = None, + send_all_hook: AsyncHook | None = None, + wait_send_all_might_not_block_hook: AsyncHook | None = None, + close_hook: SyncHook | None = None, ): self._conflict_detector = _util.ConflictDetector( "another task is using this stream" @@ -209,8 +219,8 @@ class MemoryReceiveStream(ReceiveStream, metaclass=_util.Final): def __init__( self, - receive_some_hook: Callable[[], Awaitable[object]] | None = None, - close_hook: Callable[[], object] | None = None, + receive_some_hook: AsyncHook | None = None, + close_hook: SyncHook | None = None, ): self._conflict_detector = _util.ConflictDetector( "another task is using this stream" From 3c6b14c05e7706ff9c967a83496d64fd2bc79534 Mon Sep 17 00:00:00 2001 From: Spencer Brown Date: Wed, 9 Aug 2023 14:52:26 +1000 Subject: [PATCH 10/18] Accept bytearray/memoryview in functions where applicable --- trio/testing/_check_streams.py | 8 +++++--- trio/testing/_memory_streams.py | 9 ++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py index 074a0a59b2..d9abf19f14 100644 --- a/trio/testing/_check_streams.py +++ b/trio/testing/_check_streams.py @@ -79,7 +79,7 @@ async def check_one_way_stream( assert isinstance(s, SendStream) assert isinstance(r, ReceiveStream) - async def do_send_all(data: bytes) -> None: + async def do_send_all(data: bytes | bytearray | memoryview) -> None: with assert_checkpoints(): # We're testing that it doesn't return anything. assert await s.send_all(data) is None # type: ignore[func-returns-value] @@ -448,7 +448,9 @@ async def flipped_clogged_stream_maker() -> tuple[Stream, Stream]: i = r.getrandbits(8 * DUPLEX_TEST_SIZE) test_data = i.to_bytes(DUPLEX_TEST_SIZE, "little") - async def sender(s: Stream, data: bytes, seed: int) -> None: + async def sender( + s: Stream, data: bytes | bytearray | memoryview, seed: int + ) -> None: r = random.Random(seed) m = memoryview(data) while m: @@ -456,7 +458,7 @@ async def sender(s: Stream, data: bytes, seed: int) -> None: await s.send_all(m[:chunk_size]) m = m[chunk_size:] - async def receiver(s: Stream, data: bytes, seed: int) -> None: + async def receiver(s: Stream, data: bytes | bytearray, seed: int) -> None: r = random.Random(seed) got = bytearray() while len(got) < len(data): diff --git a/trio/testing/_memory_streams.py b/trio/testing/_memory_streams.py index 8ba91ad330..956f2e8fab 100644 --- a/trio/testing/_memory_streams.py +++ b/trio/testing/_memory_streams.py @@ -8,7 +8,6 @@ from .._highlevel_generic import StapledStream from ..abc import ReceiveStream, SendStream - if TYPE_CHECKING: from typing_extensions import TypeAlias @@ -44,7 +43,7 @@ def close_and_wipe(self) -> None: self._data = bytearray() self.close() - def put(self, data: bytes | bytearray) -> None: + def put(self, data: bytes | bytearray | memoryview) -> None: if self._closed: raise _core.ClosedResourceError("virtual connection closed") self._data += data @@ -268,7 +267,7 @@ async def aclose(self) -> None: self.close() await _core.checkpoint() - def put_data(self, data: bytes | bytearray) -> None: + def put_data(self, data: bytes | bytearray | memoryview) -> None: """Appends the given data to the internal buffer.""" self._incoming.put(data) @@ -481,7 +480,7 @@ def close_receiver(self) -> None: self._receiver_closed = True self._something_happened() - async def send_all(self, data: bytes | bytearray) -> None: + async def send_all(self, data: bytes | bytearray | memoryview) -> None: with self._send_conflict_detector: if self._sender_closed: raise _core.ClosedResourceError @@ -550,7 +549,7 @@ async def aclose(self) -> None: self.close() await _core.checkpoint() - async def send_all(self, data: bytes | bytearray) -> None: + async def send_all(self, data: bytes | bytearray | memoryview) -> None: await self._lbq.send_all(data) async def wait_send_all_might_not_block(self) -> None: From c1fcd4a27ad402fde7e85f8bba2cbe79b3bcab77 Mon Sep 17 00:00:00 2001 From: Spencer Brown Date: Wed, 9 Aug 2023 17:00:24 +1000 Subject: [PATCH 11/18] Expand the type aliases in trio.testing for docs --- docs/source/conf.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/docs/source/conf.py b/docs/source/conf.py index 39eda12ad5..f3e0e79627 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -18,6 +18,7 @@ # import os import sys +import re # For our local_customization module sys.path.insert(0, os.path.abspath(".")) @@ -90,6 +91,26 @@ } +def autodoc_process_signature( + app, what, name, obj, options, signature, return_annotation +): + """Modify found signatures to fix various issues.""" + if signature is not None: + if name.startswith("trio.testing"): + # Expand type aliases + signature = re.sub( + r"StreamMaker\[([a-zA-Z ,]+)]", + lambda match: f"typing.Callable[[], typing.Awaitable[tuple[{match.group(1)}]]]", + signature, + ) + signature = signature.replace( + "AsyncHook", "typing.Callable[[], typing.Awaitable[object]]" + ) + signature = signature.replace("SyncHook", "typing.Callable[[], object]") + + return signature, return_annotation + + # XX hack the RTD theme until # https://github.com/rtfd/sphinx_rtd_theme/pull/382 # is shipped (should be in the release after 0.2.4) @@ -97,6 +118,7 @@ # though. def setup(app): app.add_css_file("hackrtd.css") + app.connect("autodoc-process-signature", autodoc_process_signature) # -- General configuration ------------------------------------------------ From d1a2c99668744e95b27ba0d433d06620c4c52c1b Mon Sep 17 00:00:00 2001 From: jakkdl Date: Mon, 14 Aug 2023 14:30:38 +0200 Subject: [PATCH 12/18] fix return type of trio/_core/_run to not be an awaitable, remove type: ignore --- trio/_core/_run.py | 2 +- trio/testing/_trio_test.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 1ba88da85e..3a4751254f 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -2152,7 +2152,7 @@ def setup_runner( def run( - async_fn: Callable[..., RetT], + async_fn: Callable[..., Awaitable[RetT]], *args: object, clock: Clock | None = None, instruments: Sequence[Instrument] = (), diff --git a/trio/testing/_trio_test.py b/trio/testing/_trio_test.py index e72d1a6df0..5619352846 100644 --- a/trio/testing/_trio_test.py +++ b/trio/testing/_trio_test.py @@ -40,8 +40,7 @@ def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: else: raise ValueError("too many clocks spoil the broth!") instruments = [i for i in kwargs.values() if isinstance(i, Instrument)] - # trio.run isn't typed yet. - return _core.run( # type: ignore[no-any-return] + return _core.run( partial(fn, *args, **kwargs), clock=clock, instruments=instruments ) From e8fbc12d22af1604b8c72aad2c93b00f9af5982d Mon Sep 17 00:00:00 2001 From: Spencer Brown Date: Tue, 15 Aug 2023 13:40:12 +1000 Subject: [PATCH 13/18] Un-stringify these aliases --- docs/source/conf.py | 11 ----------- trio/testing/_check_streams.py | 6 +++--- trio/testing/_memory_streams.py | 7 +++---- 3 files changed, 6 insertions(+), 18 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index e83d2be96a..bade5a8ff1 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -109,17 +109,6 @@ def autodoc_process_signature( # Strip the type from the union, make it look like = ... signature = signature.replace(" | type[trio._core._local._NoValue]", "") signature = signature.replace("", "...") - if name.startswith("trio.testing"): - # Expand type aliases - signature = re.sub( - r"StreamMaker\[([a-zA-Z ,]+)]", - lambda match: f"typing.Callable[[], typing.Awaitable[tuple[{match.group(1)}]]]", - signature, - ) - signature = signature.replace( - "AsyncHook", "typing.Callable[[], typing.Awaitable[object]]" - ) - signature = signature.replace("SyncHook", "typing.Callable[[], object]") return signature, return_annotation diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py index d9abf19f14..33947ccc55 100644 --- a/trio/testing/_check_streams.py +++ b/trio/testing/_check_streams.py @@ -2,9 +2,9 @@ from __future__ import annotations import random -from collections.abc import Awaitable, Callable, Generator +from collections.abc import Generator from contextlib import contextmanager -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import TYPE_CHECKING, Awaitable, Callable, Generic, Tuple, TypeVar from .. import CancelScope, _core from .._abc import AsyncResource, HalfCloseableStream, ReceiveStream, SendStream, Stream @@ -20,7 +20,7 @@ Res1 = TypeVar("Res1", bound=AsyncResource) Res2 = TypeVar("Res2", bound=AsyncResource) -StreamMaker: TypeAlias = "Callable[[], Awaitable[tuple[Res1, Res2]]]" +StreamMaker: TypeAlias = Callable[[], Awaitable[Tuple[Res1, Res2]]] class _ForceCloseBoth(Generic[Res1, Res2]): diff --git a/trio/testing/_memory_streams.py b/trio/testing/_memory_streams.py index 956f2e8fab..719b82e0b5 100644 --- a/trio/testing/_memory_streams.py +++ b/trio/testing/_memory_streams.py @@ -1,8 +1,7 @@ from __future__ import annotations import operator -from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Awaitable, Callable from .. import _core, _util from .._highlevel_generic import StapledStream @@ -12,9 +11,9 @@ from typing_extensions import TypeAlias -AsyncHook: TypeAlias = "Callable[[], Awaitable[object]]" +AsyncHook: TypeAlias = Callable[[], Awaitable[object]] # Would be nice to exclude awaitable here, but currently not possible. -SyncHook: TypeAlias = "Callable[[], object]" +SyncHook: TypeAlias = Callable[[], object] ################################################################ From 787759100c4767928028e84f651098d885fc3dfb Mon Sep 17 00:00:00 2001 From: Spencer Brown Date: Thu, 17 Aug 2023 09:57:44 +1000 Subject: [PATCH 14/18] Make StapledStream generic, to preserve the type of the component streams --- docs/source/conf.py | 5 +++++ trio/_highlevel_generic.py | 19 ++++++++++++------- trio/testing/_memory_streams.py | 25 ++++++++++++++++++++----- 3 files changed, 37 insertions(+), 12 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index bade5a8ff1..25f428169a 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -63,10 +63,15 @@ ("py:obj", "trio._abc.SendType"), ("py:obj", "trio._abc.T"), ("py:obj", "trio._abc.T_resource"), + ("py:obj", "trio._highlevel_generic.SendStreamT"), + ("py:obj", "trio._highlevel_generic.ReceiveStreamT"), + ("py:class", "trio._core._thread_cache.RetT"), ("py:class", "trio._core._run.StatusT"), ("py:class", "trio._core._run.StatusT_co"), ("py:class", "trio._core._run.StatusT_contra"), ("py:class", "trio._core._run.RetT"), + ("py:class", "trio._highlevel_generic.SendStreamT"), + ("py:class", "trio._highlevel_generic.ReceiveStreamT"), ("py:class", "trio._threads.T"), ("py:class", "P.args"), ("py:class", "P.kwargs"), diff --git a/trio/_highlevel_generic.py b/trio/_highlevel_generic.py index e1ac378c6a..f47df9a381 100644 --- a/trio/_highlevel_generic.py +++ b/trio/_highlevel_generic.py @@ -1,16 +1,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import Generic, TypeVar import attr import trio from trio._util import Final -if TYPE_CHECKING: - from .abc import SendStream, ReceiveStream, AsyncResource +from .abc import AsyncResource, HalfCloseableStream, ReceiveStream, SendStream -from .abc import HalfCloseableStream + +SendStreamT = TypeVar("SendStreamT", bound=SendStream) +ReceiveStreamT = TypeVar("ReceiveStreamT", bound=ReceiveStream) async def aclose_forcefully(resource: AsyncResource) -> None: @@ -44,7 +45,11 @@ async def aclose_forcefully(resource: AsyncResource) -> None: @attr.s(eq=False, hash=False) -class StapledStream(HalfCloseableStream, metaclass=Final): +class StapledStream( + HalfCloseableStream, + Generic[SendStreamT, ReceiveStreamT], + metaclass=Final, +): """This class `staples `__ together two unidirectional streams to make single bidirectional stream. @@ -79,8 +84,8 @@ class StapledStream(HalfCloseableStream, metaclass=Final): """ - send_stream: SendStream = attr.ib() - receive_stream: ReceiveStream = attr.ib() + send_stream: SendStreamT = attr.ib() + receive_stream: ReceiveStreamT = attr.ib() async def send_all(self, data: bytes | bytearray | memoryview) -> None: """Calls ``self.send_stream.send_all``.""" diff --git a/trio/testing/_memory_streams.py b/trio/testing/_memory_streams.py index 719b82e0b5..fc23fae842 100644 --- a/trio/testing/_memory_streams.py +++ b/trio/testing/_memory_streams.py @@ -1,7 +1,7 @@ from __future__ import annotations import operator -from typing import TYPE_CHECKING, Awaitable, Callable +from typing import TYPE_CHECKING, Awaitable, Callable, TypeVar from .. import _core, _util from .._highlevel_generic import StapledStream @@ -14,6 +14,8 @@ AsyncHook: TypeAlias = Callable[[], Awaitable[object]] # Would be nice to exclude awaitable here, but currently not possible. SyncHook: TypeAlias = Callable[[], object] +SendStreamT = TypeVar("SendStreamT", bound=SendStream) +ReceiveStreamT = TypeVar("ReceiveStreamT", bound=ReceiveStream) ################################################################ @@ -352,8 +354,11 @@ async def async_pump_from_send_stream_to_recv_stream() -> None: def _make_stapled_pair( - one_way_pair: Callable[[], tuple[SendStream, ReceiveStream]] -) -> tuple[StapledStream, StapledStream]: + one_way_pair: Callable[[], tuple[SendStreamT, ReceiveStreamT]] +) -> tuple[ + StapledStream[SendStreamT, ReceiveStreamT], + StapledStream[SendStreamT, ReceiveStreamT], +]: pipe1_send, pipe1_recv = one_way_pair() pipe2_send, pipe2_recv = one_way_pair() stream1 = StapledStream(pipe1_send, pipe2_recv) @@ -361,7 +366,12 @@ def _make_stapled_pair( return stream1, stream2 -def memory_stream_pair() -> tuple[SendStream, SendStream]: +def memory_stream_pair() -> ( + tuple[ + StapledStream[MemorySendStream, MemoryReceiveStream], + StapledStream[MemorySendStream, MemoryReceiveStream], + ] +): """Create a connected, pure-Python, bidirectional stream with infinite buffering and flexible configuration options. @@ -597,7 +607,12 @@ def lockstep_stream_one_way_pair() -> tuple[SendStream, ReceiveStream]: return _LockstepSendStream(lbq), _LockstepReceiveStream(lbq) -def lockstep_stream_pair() -> tuple[StapledStream, StapledStream]: +def lockstep_stream_pair() -> ( + tuple[ + StapledStream[SendStream, ReceiveStream], + StapledStream[SendStream, ReceiveStream], + ] +): """Create a connected, pure-Python, bidirectional stream where data flows in lockstep. From 2b669c7ef636ef4c9be0151c5bf40c71925887fd Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 17 Aug 2023 12:30:13 +0200 Subject: [PATCH 15/18] add _highlevel_generic to mypy strictlist --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 03ae322c56..7bc0cb1ad5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ module = [ "trio._deprecate", "trio._dtls", "trio._file_io", + "trio._highlevel_generic", "trio._highlevel_open_tcp_stream", "trio._ki", "trio._socket", From d191c39a4f712363f716af1e94e696ed5e888af3 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 17 Aug 2023 12:39:09 +0200 Subject: [PATCH 16/18] fix missing generic parameter --- trio/_subprocess.py | 2 +- trio/_tests/verify_types.json | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/trio/_subprocess.py b/trio/_subprocess.py index 7cf990fa53..978f7e6188 100644 --- a/trio/_subprocess.py +++ b/trio/_subprocess.py @@ -149,7 +149,7 @@ def __init__( self.stdout = stdout self.stderr = stderr - self.stdio: StapledStream | None = None + self.stdio: StapledStream[SendStream, ReceiveStream] | None = None if self.stdin is not None and self.stdout is not None: self.stdio = StapledStream(self.stdin, self.stdout) diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index 7fe3faaac6..b61b28a428 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.9856687898089171, + "completenessScore": 0.9872611464968153, "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 619, - "withUnknownType": 9 + "withKnownType": 620, + "withUnknownType": 8 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 1, @@ -46,8 +46,8 @@ ], "otherSymbolCounts": { "withAmbiguousType": 1, - "withKnownType": 660, - "withUnknownType": 21 + "withKnownType": 662, + "withUnknownType": 19 }, "packageName": "trio", "symbols": [ @@ -68,7 +68,6 @@ "trio._ssl.SSLStream.transport_stream", "trio._ssl.SSLStream.unwrap", "trio._ssl.SSLStream.wait_send_all_might_not_block", - "trio._subprocess.Process.stdio", "trio.lowlevel.notify_closing", "trio.lowlevel.wait_readable", "trio.lowlevel.wait_writable", From ad4157151fc53aff11eb398438e62934dcedea59 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 17 Aug 2023 12:47:34 +0200 Subject: [PATCH 17/18] fix CI --- trio/_highlevel_generic.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trio/_highlevel_generic.py b/trio/_highlevel_generic.py index f47df9a381..4269f90bae 100644 --- a/trio/_highlevel_generic.py +++ b/trio/_highlevel_generic.py @@ -9,7 +9,6 @@ from .abc import AsyncResource, HalfCloseableStream, ReceiveStream, SendStream - SendStreamT = TypeVar("SendStreamT", bound=SendStream) ReceiveStreamT = TypeVar("ReceiveStreamT", bound=ReceiveStream) From 054fe701dda5d96e78c16615022325b639c06683 Mon Sep 17 00:00:00 2001 From: John Litborn <11260241+jakkdl@users.noreply.github.com> Date: Fri, 18 Aug 2023 17:39:46 +0200 Subject: [PATCH 18/18] Update pyproject.toml remove testing/* files from the ignorelist in pyproject.toml for mypy --- pyproject.toml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a175593dbd..6893927337 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,12 +50,6 @@ disallow_untyped_calls = false # files not yet fully typed [[tool.mypy.overrides]] module = [ -# 2747 -"trio/testing/_network", -"trio/testing/_trio_test", -"trio/testing/_checkpoints", -"trio/testing/_check_streams", -"trio/testing/_memory_streams", # 2745 "trio/_ssl", # 2756