From 7f402303fe1703767d9236494aacc3f197fbc708 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 12 Jul 2024 09:02:07 +0200 Subject: [PATCH] Switch from typing.Optional to | None. --- src/websockets/client.py | 16 +-- src/websockets/exceptions.py | 19 ++- src/websockets/extensions/base.py | 4 +- .../extensions/permessage_deflate.py | 32 ++--- src/websockets/frames.py | 8 +- src/websockets/headers.py | 6 +- src/websockets/http11.py | 14 +- src/websockets/imports.py | 6 +- src/websockets/legacy/auth.py | 18 +-- src/websockets/legacy/client.py | 77 +++++----- src/websockets/legacy/framing.py | 8 +- src/websockets/legacy/protocol.py | 65 ++++----- src/websockets/legacy/server.py | 135 +++++++++--------- src/websockets/protocol.py | 28 ++-- src/websockets/server.py | 35 ++--- src/websockets/sync/client.py | 44 +++--- src/websockets/sync/connection.py | 28 ++-- src/websockets/sync/messages.py | 12 +- src/websockets/sync/server.py | 92 ++++++------ src/websockets/sync/utils.py | 7 +- src/websockets/typing.py | 2 +- src/websockets/uri.py | 8 +- 22 files changed, 334 insertions(+), 330 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 633b1960b..cfb441fd9 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Any, Generator, List, Optional, Sequence +from typing import Any, Generator, List, Sequence from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -73,12 +73,12 @@ def __init__( self, wsuri: WebSocketURI, *, - origin: Optional[Origin] = None, - extensions: Optional[Sequence[ClientExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, state: State = CONNECTING, - max_size: Optional[int] = 2**20, - logger: Optional[LoggerLike] = None, + max_size: int | None = 2**20, + logger: LoggerLike | None = None, ): super().__init__( side=CLIENT, @@ -261,7 +261,7 @@ def process_extensions(self, headers: Headers) -> List[Extension]: return accepted_extensions - def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: + def process_subprotocol(self, headers: Headers) -> Subprotocol | None: """ Handle the Sec-WebSocket-Protocol HTTP response header. @@ -274,7 +274,7 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: Subprotocol, if one was selected. """ - subprotocol: Optional[Subprotocol] = None + subprotocol: Subprotocol | None = None subprotocols = headers.get_all("Sec-WebSocket-Protocol") diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index f7169e3b1..adb66e262 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -31,7 +31,6 @@ from __future__ import annotations import http -from typing import Optional from . import datastructures, frames, http11 from .typing import StatusLike @@ -78,11 +77,11 @@ class ConnectionClosed(WebSocketException): Raised when trying to interact with a closed connection. Attributes: - rcvd (Optional[Close]): if a close frame was received, its code and + rcvd (Close | None): if a close frame was received, its code and reason are available in ``rcvd.code`` and ``rcvd.reason``. - sent (Optional[Close]): if a close frame was sent, its code and reason + sent (Close | None): if a close frame was sent, its code and reason are available in ``sent.code`` and ``sent.reason``. - rcvd_then_sent (Optional[bool]): if close frames were received and + rcvd_then_sent (bool | None): if close frames were received and sent, this attribute tells in which order this happened, from the perspective of this side of the connection. @@ -90,9 +89,9 @@ class ConnectionClosed(WebSocketException): def __init__( self, - rcvd: Optional[frames.Close], - sent: Optional[frames.Close], - rcvd_then_sent: Optional[bool] = None, + rcvd: frames.Close | None, + sent: frames.Close | None, + rcvd_then_sent: bool | None = None, ) -> None: self.rcvd = rcvd self.sent = sent @@ -181,7 +180,7 @@ class InvalidHeader(InvalidHandshake): """ - def __init__(self, name: str, value: Optional[str] = None) -> None: + def __init__(self, name: str, value: str | None = None) -> None: self.name = name self.value = value @@ -221,7 +220,7 @@ class InvalidOrigin(InvalidHeader): """ - def __init__(self, origin: Optional[str]) -> None: + def __init__(self, origin: str | None) -> None: super().__init__("Origin", origin) @@ -301,7 +300,7 @@ class InvalidParameterValue(NegotiationError): """ - def __init__(self, name: str, value: Optional[str]) -> None: + def __init__(self, name: str, value: str | None) -> None: self.name = name self.value = value diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 7446c990c..5b5528a09 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Optional, Sequence, Tuple +from typing import List, Sequence, Tuple from .. import frames from ..typing import ExtensionName, ExtensionParameter @@ -22,7 +22,7 @@ def decode( self, frame: frames.Frame, *, - max_size: Optional[int] = None, + max_size: int | None = None, ) -> frames.Frame: """ Decode an incoming frame. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index edccac3ca..e95b1064b 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -2,7 +2,7 @@ import dataclasses import zlib -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Sequence, Tuple, Union from .. import exceptions, frames from ..typing import ExtensionName, ExtensionParameter @@ -36,7 +36,7 @@ def __init__( local_no_context_takeover: bool, remote_max_window_bits: int, local_max_window_bits: int, - compress_settings: Optional[Dict[Any, Any]] = None, + compress_settings: Dict[Any, Any] | None = None, ) -> None: """ Configure the Per-Message Deflate extension. @@ -84,7 +84,7 @@ def decode( self, frame: frames.Frame, *, - max_size: Optional[int] = None, + max_size: int | None = None, ) -> frames.Frame: """ Decode an incoming frame. @@ -174,8 +174,8 @@ def encode(self, frame: frames.Frame) -> frames.Frame: def _build_parameters( server_no_context_takeover: bool, client_no_context_takeover: bool, - server_max_window_bits: Optional[int], - client_max_window_bits: Optional[Union[int, bool]], + server_max_window_bits: int | None, + client_max_window_bits: Union[int, bool] | None, ) -> List[ExtensionParameter]: """ Build a list of ``(name, value)`` pairs for some compression parameters. @@ -197,7 +197,7 @@ def _build_parameters( def _extract_parameters( params: Sequence[ExtensionParameter], *, is_server: bool -) -> Tuple[bool, bool, Optional[int], Optional[Union[int, bool]]]: +) -> Tuple[bool, bool, int | None, Union[int, bool] | None]: """ Extract compression parameters from a list of ``(name, value)`` pairs. @@ -207,8 +207,8 @@ def _extract_parameters( """ server_no_context_takeover: bool = False client_no_context_takeover: bool = False - server_max_window_bits: Optional[int] = None - client_max_window_bits: Optional[Union[int, bool]] = None + server_max_window_bits: int | None = None + client_max_window_bits: Union[int, bool] | None = None for name, value in params: if name == "server_no_context_takeover": @@ -286,9 +286,9 @@ def __init__( self, server_no_context_takeover: bool = False, client_no_context_takeover: bool = False, - server_max_window_bits: Optional[int] = None, - client_max_window_bits: Optional[Union[int, bool]] = True, - compress_settings: Optional[Dict[str, Any]] = None, + server_max_window_bits: int | None = None, + client_max_window_bits: Union[int, bool] | None = True, + compress_settings: Dict[str, Any] | None = None, ) -> None: """ Configure the Per-Message Deflate extension factory. @@ -433,7 +433,7 @@ def process_response_params( def enable_client_permessage_deflate( - extensions: Optional[Sequence[ClientExtensionFactory]], + extensions: Sequence[ClientExtensionFactory] | None, ) -> Sequence[ClientExtensionFactory]: """ Enable Per-Message Deflate with default settings in client extensions. @@ -489,9 +489,9 @@ def __init__( self, server_no_context_takeover: bool = False, client_no_context_takeover: bool = False, - server_max_window_bits: Optional[int] = None, - client_max_window_bits: Optional[int] = None, - compress_settings: Optional[Dict[str, Any]] = None, + server_max_window_bits: int | None = None, + client_max_window_bits: int | None = None, + compress_settings: Dict[str, Any] | None = None, require_client_max_window_bits: bool = False, ) -> None: """ @@ -635,7 +635,7 @@ def process_request_params( def enable_server_permessage_deflate( - extensions: Optional[Sequence[ServerExtensionFactory]], + extensions: Sequence[ServerExtensionFactory] | None, ) -> Sequence[ServerExtensionFactory]: """ Enable Per-Message Deflate with default settings in server extensions. diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 862eef3aa..5a304d6a7 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -5,7 +5,7 @@ import io import secrets import struct -from typing import Callable, Generator, Optional, Sequence, Tuple +from typing import Callable, Generator, Sequence, Tuple from . import exceptions, extensions from .typing import Data @@ -205,8 +205,8 @@ def parse( read_exact: Callable[[int], Generator[None, None, bytes]], *, mask: bool, - max_size: Optional[int] = None, - extensions: Optional[Sequence[extensions.Extension]] = None, + max_size: int | None = None, + extensions: Sequence[extensions.Extension] | None = None, ) -> Generator[None, None, Frame]: """ Parse a WebSocket frame. @@ -280,7 +280,7 @@ def serialize( self, *, mask: bool, - extensions: Optional[Sequence[extensions.Extension]] = None, + extensions: Sequence[extensions.Extension] | None = None, ) -> bytes: """ Serialize a WebSocket frame. diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 463df3061..3b316e0bf 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -4,7 +4,7 @@ import binascii import ipaddress import re -from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, cast +from typing import Callable, List, Sequence, Tuple, TypeVar, cast from . import exceptions from .typing import ( @@ -63,7 +63,7 @@ def build_host(host: str, port: int, secure: bool) -> str: # https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. -def peek_ahead(header: str, pos: int) -> Optional[str]: +def peek_ahead(header: str, pos: int) -> str | None: """ Return the next character from ``header`` at the given position. @@ -314,7 +314,7 @@ def parse_extension_item_param( name, pos = parse_token(header, pos, header_name) pos = parse_OWS(header, pos) # Extract parameter value, if there is one. - value: Optional[str] = None + value: str | None = None if peek_ahead(header, pos) == "=": pos = parse_OWS(header, pos + 1) if peek_ahead(header, pos) == '"': diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 6fe775eec..a7e9ae682 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -3,7 +3,7 @@ import dataclasses import re import warnings -from typing import Callable, Generator, Optional +from typing import Callable, Generator from . import datastructures, exceptions @@ -62,10 +62,10 @@ class Request: headers: datastructures.Headers # body isn't useful is the context of this library. - _exception: Optional[Exception] = None + _exception: Exception | None = None @property - def exception(self) -> Optional[Exception]: # pragma: no cover + def exception(self) -> Exception | None: # pragma: no cover warnings.warn( "Request.exception is deprecated; " "use ServerProtocol.handshake_exc instead", @@ -164,12 +164,12 @@ class Response: status_code: int reason_phrase: str headers: datastructures.Headers - body: Optional[bytes] = None + body: bytes | None = None - _exception: Optional[Exception] = None + _exception: Exception | None = None @property - def exception(self) -> Optional[Exception]: # pragma: no cover + def exception(self) -> Exception | None: # pragma: no cover warnings.warn( "Response.exception is deprecated; " "use ClientProtocol.handshake_exc instead", @@ -245,7 +245,7 @@ def parse( if 100 <= status_code < 200 or status_code == 204 or status_code == 304: body = None else: - content_length: Optional[int] + content_length: int | None try: # MultipleValuesError is sufficiently unlikely that we don't # attempt to handle it. Instead we document that its parent diff --git a/src/websockets/imports.py b/src/websockets/imports.py index a6a59d4c2..9c05234f5 100644 --- a/src/websockets/imports.py +++ b/src/websockets/imports.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Any, Dict, Iterable, Optional +from typing import Any, Dict, Iterable __all__ = ["lazy_import"] @@ -30,8 +30,8 @@ def import_name(name: str, source: str, namespace: Dict[str, Any]) -> Any: def lazy_import( namespace: Dict[str, Any], - aliases: Optional[Dict[str, str]] = None, - deprecated_aliases: Optional[Dict[str, str]] = None, + aliases: Dict[str, str] | None = None, + deprecated_aliases: Dict[str, str] | None = None, ) -> None: """ Provide lazy, module-level imports. diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 8217afedd..067f9c78c 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -3,7 +3,7 @@ import functools import hmac import http -from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Union, cast +from typing import Any, Awaitable, Callable, Iterable, Tuple, Union, cast from ..datastructures import Headers from ..exceptions import InvalidHeader @@ -39,14 +39,14 @@ class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol): encoding of non-ASCII characters is undefined. """ - username: Optional[str] = None + username: str | None = None """Username of the authenticated user.""" def __init__( self, *args: Any, - realm: Optional[str] = None, - check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, + realm: str | None = None, + check_credentials: Callable[[str, str], Awaitable[bool]] | None = None, **kwargs: Any, ) -> None: if realm is not None: @@ -79,7 +79,7 @@ async def process_request( self, path: str, request_headers: Headers, - ) -> Optional[HTTPResponse]: + ) -> HTTPResponse | None: """ Check HTTP Basic Auth and return an HTTP 401 response if needed. @@ -115,10 +115,10 @@ async def process_request( def basic_auth_protocol_factory( - realm: Optional[str] = None, - credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None, - check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, - create_protocol: Optional[Callable[..., BasicAuthWebSocketServerProtocol]] = None, + realm: str | None = None, + credentials: Union[Credentials, Iterable[Credentials]] | None = None, + check_credentials: Callable[[str, str], Awaitable[bool]] | None = None, + create_protocol: Callable[..., BasicAuthWebSocketServerProtocol] | None = None, ) -> Callable[..., BasicAuthWebSocketServerProtocol]: """ Protocol factory that enforces HTTP Basic Auth. diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index e5da8b13a..f7464368f 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -13,7 +13,6 @@ Callable, Generator, List, - Optional, Sequence, Tuple, Type, @@ -86,12 +85,12 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): def __init__( self, *, - logger: Optional[LoggerLike] = None, - origin: Optional[Origin] = None, - extensions: Optional[Sequence[ClientExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLike] = None, - user_agent_header: Optional[str] = USER_AGENT, + logger: LoggerLike | None = None, + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, **kwargs: Any, ) -> None: if logger is None: @@ -152,7 +151,7 @@ async def read_http_response(self) -> Tuple[int, Headers]: @staticmethod def process_extensions( headers: Headers, - available_extensions: Optional[Sequence[ClientExtensionFactory]], + available_extensions: Sequence[ClientExtensionFactory] | None, ) -> List[Extension]: """ Handle the Sec-WebSocket-Extensions HTTP response header. @@ -224,8 +223,8 @@ def process_extensions( @staticmethod def process_subprotocol( - headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]] - ) -> Optional[Subprotocol]: + headers: Headers, available_subprotocols: Sequence[Subprotocol] | None + ) -> Subprotocol | None: """ Handle the Sec-WebSocket-Protocol HTTP response header. @@ -234,7 +233,7 @@ def process_subprotocol( Return the selected subprotocol. """ - subprotocol: Optional[Subprotocol] = None + subprotocol: Subprotocol | None = None header_values = headers.get_all("Sec-WebSocket-Protocol") @@ -260,10 +259,10 @@ def process_subprotocol( async def handshake( self, wsuri: WebSocketURI, - origin: Optional[Origin] = None, - available_extensions: Optional[Sequence[ClientExtensionFactory]] = None, - available_subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLike] = None, + origin: Origin | None = None, + available_extensions: Sequence[ClientExtensionFactory] | None = None, + available_subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLike | None = None, ) -> None: """ Perform the client side of the opening handshake. @@ -427,26 +426,26 @@ def __init__( self, uri: str, *, - create_protocol: Optional[Callable[..., WebSocketClientProtocol]] = None, - logger: Optional[LoggerLike] = None, - compression: Optional[str] = "deflate", - origin: Optional[Origin] = None, - extensions: Optional[Sequence[ClientExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLike] = None, - user_agent_header: Optional[str] = USER_AGENT, - open_timeout: Optional[float] = 10, - ping_interval: Optional[float] = 20, - ping_timeout: Optional[float] = 20, - close_timeout: Optional[float] = None, - max_size: Optional[int] = 2**20, - max_queue: Optional[int] = 2**5, + create_protocol: Callable[..., WebSocketClientProtocol] | None = None, + logger: LoggerLike | None = None, + compression: str | None = "deflate", + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = None, + max_size: int | None = 2**20, + max_queue: int | None = 2**5, read_limit: int = 2**16, write_limit: int = 2**16, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. - timeout: Optional[float] = kwargs.pop("timeout", None) + timeout: float | None = kwargs.pop("timeout", None) if timeout is None: timeout = 10 else: @@ -456,7 +455,7 @@ def __init__( close_timeout = timeout # Backwards compatibility: create_protocol used to be called klass. - klass: Optional[Type[WebSocketClientProtocol]] = kwargs.pop("klass", None) + klass: Type[WebSocketClientProtocol] | None = kwargs.pop("klass", None) if klass is None: klass = WebSocketClientProtocol else: @@ -469,7 +468,7 @@ def __init__( legacy_recv: bool = kwargs.pop("legacy_recv", False) # Backwards compatibility: the loop parameter used to be supported. - _loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) + _loop: asyncio.AbstractEventLoop | None = kwargs.pop("loop", None) if _loop is None: loop = asyncio.get_event_loop() else: @@ -516,13 +515,13 @@ def __init__( ) if kwargs.pop("unix", False): - path: Optional[str] = kwargs.pop("path", None) + path: str | None = kwargs.pop("path", None) create_connection = functools.partial( loop.create_unix_connection, factory, path, **kwargs ) else: - host: Optional[str] - port: Optional[int] + host: str | None + port: int | None if kwargs.get("sock") is None: host, port = wsuri.host, wsuri.port else: @@ -630,9 +629,9 @@ async def __aenter__(self) -> WebSocketClientProtocol: async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: await self.protocol.close() @@ -679,7 +678,7 @@ async def __await_impl__(self) -> WebSocketClientProtocol: def unix_connect( - path: Optional[str] = None, + path: str | None = None, uri: str = "ws://localhost/", **kwargs: Any, ) -> Connect: diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index b77b869e3..8a13fa446 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -1,7 +1,7 @@ from __future__ import annotations import struct -from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence, Tuple +from typing import Any, Awaitable, Callable, NamedTuple, Sequence, Tuple from .. import extensions, frames from ..exceptions import PayloadTooBig, ProtocolError @@ -44,8 +44,8 @@ async def read( reader: Callable[[int], Awaitable[bytes]], *, mask: bool, - max_size: Optional[int] = None, - extensions: Optional[Sequence[extensions.Extension]] = None, + max_size: int | None = None, + extensions: Sequence[extensions.Extension] | None = None, ) -> Frame: """ Read a WebSocket frame. @@ -122,7 +122,7 @@ def write( write: Callable[[bytes], Any], *, mask: bool, - extensions: Optional[Sequence[extensions.Extension]] = None, + extensions: Sequence[extensions.Extension] | None = None, ) -> None: """ Write a WebSocket frame. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 26d50a2cc..94d42cfdb 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -22,7 +22,6 @@ Iterable, List, Mapping, - Optional, Tuple, Union, cast, @@ -173,21 +172,21 @@ class WebSocketCommonProtocol(asyncio.Protocol): def __init__( self, *, - logger: Optional[LoggerLike] = None, - ping_interval: Optional[float] = 20, - ping_timeout: Optional[float] = 20, - close_timeout: Optional[float] = None, - max_size: Optional[int] = 2**20, - max_queue: Optional[int] = 2**5, + logger: LoggerLike | None = None, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = None, + max_size: int | None = 2**20, + max_queue: int | None = 2**5, read_limit: int = 2**16, write_limit: int = 2**16, # The following arguments are kept only for backwards compatibility. - host: Optional[str] = None, - port: Optional[int] = None, - secure: Optional[bool] = None, + host: str | None = None, + port: int | None = None, + secure: bool | None = None, legacy_recv: bool = False, - loop: Optional[asyncio.AbstractEventLoop] = None, - timeout: Optional[float] = None, + loop: asyncio.AbstractEventLoop | None = None, + timeout: float | None = None, ) -> None: if legacy_recv: # pragma: no cover warnings.warn("legacy_recv is deprecated", DeprecationWarning) @@ -243,7 +242,7 @@ def __init__( # Copied from asyncio.FlowControlMixin self._paused = False - self._drain_waiter: Optional[asyncio.Future[None]] = None + self._drain_waiter: asyncio.Future[None] | None = None self._drain_lock = asyncio.Lock() @@ -265,13 +264,13 @@ def __init__( # WebSocket protocol parameters. self.extensions: List[Extension] = [] - self.subprotocol: Optional[Subprotocol] = None + self.subprotocol: Subprotocol | None = None """Subprotocol, if one was negotiated.""" # Close code and reason, set when a close frame is sent or received. - self.close_rcvd: Optional[Close] = None - self.close_sent: Optional[Close] = None - self.close_rcvd_then_sent: Optional[bool] = None + self.close_rcvd: Close | None = None + self.close_sent: Close | None = None + self.close_rcvd_then_sent: bool | None = None # Completed when the connection state becomes CLOSED. Translates the # :meth:`connection_lost` callback to a :class:`~asyncio.Future` @@ -281,11 +280,11 @@ def __init__( # Queue of received messages. self.messages: Deque[Data] = collections.deque() - self._pop_message_waiter: Optional[asyncio.Future[None]] = None - self._put_message_waiter: Optional[asyncio.Future[None]] = None + self._pop_message_waiter: asyncio.Future[None] | None = None + self._put_message_waiter: asyncio.Future[None] | None = None # Protect sending fragmented messages. - self._fragmented_message_waiter: Optional[asyncio.Future[None]] = None + self._fragmented_message_waiter: asyncio.Future[None] | None = None # Mapping of ping IDs to pong waiters, in chronological order. self.pings: Dict[bytes, Tuple[asyncio.Future[float], float]] = {} @@ -306,7 +305,7 @@ def __init__( self.transfer_data_task: asyncio.Task[None] # Exception that occurred during data transfer, if any. - self.transfer_data_exc: Optional[BaseException] = None + self.transfer_data_exc: BaseException | None = None # Task sending keepalive pings. self.keepalive_ping_task: asyncio.Task[None] @@ -363,19 +362,19 @@ def connection_open(self) -> None: self.close_connection_task = self.loop.create_task(self.close_connection()) @property - def host(self) -> Optional[str]: + def host(self) -> str | None: alternative = "remote_address" if self.is_client else "local_address" warnings.warn(f"use {alternative}[0] instead of host", DeprecationWarning) return self._host @property - def port(self) -> Optional[int]: + def port(self) -> int | None: alternative = "remote_address" if self.is_client else "local_address" warnings.warn(f"use {alternative}[1] instead of port", DeprecationWarning) return self._port @property - def secure(self) -> Optional[bool]: + def secure(self) -> bool | None: warnings.warn("don't use secure", DeprecationWarning) return self._secure @@ -447,7 +446,7 @@ def closed(self) -> bool: return self.state is State.CLOSED @property - def close_code(self) -> Optional[int]: + def close_code(self) -> int | None: """ WebSocket close code, defined in `section 7.1.5 of RFC 6455`_. @@ -465,7 +464,7 @@ def close_code(self) -> Optional[int]: return self.close_rcvd.code @property - def close_reason(self) -> Optional[str]: + def close_reason(self) -> str | None: """ WebSocket close reason, defined in `section 7.1.6 of RFC 6455`_. @@ -804,7 +803,7 @@ async def wait_closed(self) -> None: """ await asyncio.shield(self.connection_lost_waiter) - async def ping(self, data: Optional[Data] = None) -> Awaitable[float]: + async def ping(self, data: Data | None = None) -> Awaitable[float]: """ Send a Ping_. @@ -1017,7 +1016,7 @@ async def transfer_data(self) -> None: self.transfer_data_exc = exc self.fail_connection(CloseCode.INTERNAL_ERROR) - async def read_message(self) -> Optional[Data]: + async def read_message(self) -> Data | None: """ Read a single message from the connection. @@ -1090,7 +1089,7 @@ def append(frame: Frame) -> None: return ("" if text else b"").join(fragments) - async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: + async def read_data_frame(self, max_size: int | None) -> Frame | None: """ Read a single data frame from the connection. @@ -1153,7 +1152,7 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: else: return frame - async def read_frame(self, max_size: Optional[int]) -> Frame: + async def read_frame(self, max_size: int | None) -> Frame: """ Read a single frame from the connection. @@ -1204,9 +1203,7 @@ async def write_frame( self.write_frame_sync(fin, opcode, data) await self.drain() - async def write_close_frame( - self, close: Close, data: Optional[bytes] = None - ) -> None: + async def write_close_frame(self, close: Close, data: bytes | None = None) -> None: """ Write a close frame if and only if the connection state is OPEN. @@ -1484,7 +1481,7 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: # Copied from asyncio.StreamReaderProtocol self.reader.set_transport(transport) - def connection_lost(self, exc: Optional[Exception]) -> None: + def connection_lost(self, exc: Exception | None) -> None: """ 7.1.4. The WebSocket Connection is Closed. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 0f3c1c150..551115174 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -16,7 +16,6 @@ Generator, Iterable, List, - Optional, Sequence, Set, Tuple, @@ -103,19 +102,19 @@ def __init__( ], ws_server: WebSocketServer, *, - logger: Optional[LoggerLike] = None, - origins: Optional[Sequence[Optional[Origin]]] = None, - extensions: Optional[Sequence[ServerExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLikeOrCallable] = None, - server_header: Optional[str] = USER_AGENT, - process_request: Optional[ - Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] - ] = None, - select_subprotocol: Optional[ - Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] - ] = None, - open_timeout: Optional[float] = 10, + logger: LoggerLike | None = None, + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLikeOrCallable | None = None, + server_header: str | None = USER_AGENT, + process_request: ( + Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None + ) = None, + select_subprotocol: ( + Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] | None + ) = None, + open_timeout: float | None = 10, **kwargs: Any, ) -> None: if logger is None: @@ -293,7 +292,7 @@ async def read_http_request(self) -> Tuple[str, Headers]: return path, headers def write_http_response( - self, status: http.HTTPStatus, headers: Headers, body: Optional[bytes] = None + self, status: http.HTTPStatus, headers: Headers, body: bytes | None = None ) -> None: """ Write status line and headers to the HTTP response. @@ -322,7 +321,7 @@ def write_http_response( async def process_request( self, path: str, request_headers: Headers - ) -> Optional[HTTPResponse]: + ) -> HTTPResponse | None: """ Intercept the HTTP request and return an HTTP response if appropriate. @@ -371,8 +370,8 @@ async def process_request( @staticmethod def process_origin( - headers: Headers, origins: Optional[Sequence[Optional[Origin]]] = None - ) -> Optional[Origin]: + headers: Headers, origins: Sequence[Origin | None] | None = None + ) -> Origin | None: """ Handle the Origin HTTP request header. @@ -387,9 +386,11 @@ def process_origin( # "The user agent MUST NOT include more than one Origin header field" # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3. try: - origin = cast(Optional[Origin], headers.get("Origin")) + origin = headers.get("Origin") except MultipleValuesError as exc: raise InvalidHeader("Origin", "more than one Origin header found") from exc + if origin is not None: + origin = cast(Origin, origin) if origins is not None: if origin not in origins: raise InvalidOrigin(origin) @@ -398,8 +399,8 @@ def process_origin( @staticmethod def process_extensions( headers: Headers, - available_extensions: Optional[Sequence[ServerExtensionFactory]], - ) -> Tuple[Optional[str], List[Extension]]: + available_extensions: Sequence[ServerExtensionFactory] | None, + ) -> Tuple[str | None, List[Extension]]: """ Handle the Sec-WebSocket-Extensions HTTP request header. @@ -435,7 +436,7 @@ def process_extensions( InvalidHandshake: To abort the handshake with an HTTP 400 error. """ - response_header_value: Optional[str] = None + response_header_value: str | None = None extension_headers: List[ExtensionHeader] = [] accepted_extensions: List[Extension] = [] @@ -479,8 +480,8 @@ def process_extensions( # Not @staticmethod because it calls self.select_subprotocol() def process_subprotocol( - self, headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]] - ) -> Optional[Subprotocol]: + self, headers: Headers, available_subprotocols: Sequence[Subprotocol] | None + ) -> Subprotocol | None: """ Handle the Sec-WebSocket-Protocol HTTP request header. @@ -495,7 +496,7 @@ def process_subprotocol( InvalidHandshake: To abort the handshake with an HTTP 400 error. """ - subprotocol: Optional[Subprotocol] = None + subprotocol: Subprotocol | None = None header_values = headers.get_all("Sec-WebSocket-Protocol") @@ -514,7 +515,7 @@ def select_subprotocol( self, client_subprotocols: Sequence[Subprotocol], server_subprotocols: Sequence[Subprotocol], - ) -> Optional[Subprotocol]: + ) -> Subprotocol | None: """ Pick a subprotocol among those supported by the client and the server. @@ -552,10 +553,10 @@ def select_subprotocol( async def handshake( self, - origins: Optional[Sequence[Optional[Origin]]] = None, - available_extensions: Optional[Sequence[ServerExtensionFactory]] = None, - available_subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLikeOrCallable] = None, + origins: Sequence[Origin | None] | None = None, + available_extensions: Sequence[ServerExtensionFactory] | None = None, + available_subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLikeOrCallable | None = None, ) -> str: """ Perform the server side of the opening handshake. @@ -661,7 +662,7 @@ class WebSocketServer: """ - def __init__(self, logger: Optional[LoggerLike] = None): + def __init__(self, logger: LoggerLike | None = None): if logger is None: logger = logging.getLogger("websockets.server") self.logger = logger @@ -670,7 +671,7 @@ def __init__(self, logger: Optional[LoggerLike] = None): self.websockets: Set[WebSocketServerProtocol] = set() # Task responsible for closing the server and terminating connections. - self.close_task: Optional[asyncio.Task[None]] = None + self.close_task: asyncio.Task[None] | None = None # Completed when the server is closed and connections are terminated. self.closed_waiter: asyncio.Future[None] @@ -869,9 +870,9 @@ async def __aenter__(self) -> WebSocketServer: # pragma: no cover async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: # pragma: no cover self.close() await self.wait_closed() @@ -941,8 +942,8 @@ class Serve: server_header: Value of the ``Server`` response header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. - process_request (Optional[Callable[[str, Headers], \ - Awaitable[Optional[Tuple[StatusLike, HeadersLike, bytes]]]]]): + process_request (Callable[[str, Headers], \ + Awaitable[Tuple[StatusLike, HeadersLike, bytes] | None]] | None): Intercept HTTP request before the opening handshake. See :meth:`~WebSocketServerProtocol.process_request` for details. select_subprotocol: Select a subprotocol supported by the client. @@ -975,35 +976,35 @@ def __init__( Callable[[WebSocketServerProtocol], Awaitable[Any]], Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated ], - host: Optional[Union[str, Sequence[str]]] = None, - port: Optional[int] = None, + host: Union[str, Sequence[str]] | None = None, + port: int | None = None, *, - create_protocol: Optional[Callable[..., WebSocketServerProtocol]] = None, - logger: Optional[LoggerLike] = None, - compression: Optional[str] = "deflate", - origins: Optional[Sequence[Optional[Origin]]] = None, - extensions: Optional[Sequence[ServerExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLikeOrCallable] = None, - server_header: Optional[str] = USER_AGENT, - process_request: Optional[ - Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] - ] = None, - select_subprotocol: Optional[ - Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] - ] = None, - open_timeout: Optional[float] = 10, - ping_interval: Optional[float] = 20, - ping_timeout: Optional[float] = 20, - close_timeout: Optional[float] = None, - max_size: Optional[int] = 2**20, - max_queue: Optional[int] = 2**5, + create_protocol: Callable[..., WebSocketServerProtocol] | None = None, + logger: LoggerLike | None = None, + compression: str | None = "deflate", + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLikeOrCallable | None = None, + server_header: str | None = USER_AGENT, + process_request: ( + Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None + ) = None, + select_subprotocol: ( + Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] | None + ) = None, + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = None, + max_size: int | None = 2**20, + max_queue: int | None = 2**5, read_limit: int = 2**16, write_limit: int = 2**16, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. - timeout: Optional[float] = kwargs.pop("timeout", None) + timeout: float | None = kwargs.pop("timeout", None) if timeout is None: timeout = 10 else: @@ -1013,7 +1014,7 @@ def __init__( close_timeout = timeout # Backwards compatibility: create_protocol used to be called klass. - klass: Optional[Type[WebSocketServerProtocol]] = kwargs.pop("klass", None) + klass: Type[WebSocketServerProtocol] | None = kwargs.pop("klass", None) if klass is None: klass = WebSocketServerProtocol else: @@ -1026,7 +1027,7 @@ def __init__( legacy_recv: bool = kwargs.pop("legacy_recv", False) # Backwards compatibility: the loop parameter used to be supported. - _loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) + _loop: asyncio.AbstractEventLoop | None = kwargs.pop("loop", None) if _loop is None: loop = asyncio.get_event_loop() else: @@ -1076,7 +1077,7 @@ def __init__( ) if kwargs.pop("unix", False): - path: Optional[str] = kwargs.pop("path", None) + path: str | None = kwargs.pop("path", None) # unix_serve(path) must not specify host and port parameters. assert host is None and port is None create_server = functools.partial( @@ -1098,9 +1099,9 @@ async def __aenter__(self) -> WebSocketServer: async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: self.ws_server.close() await self.ws_server.wait_closed() @@ -1129,7 +1130,7 @@ def unix_serve( Callable[[WebSocketServerProtocol], Awaitable[Any]], Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated ], - path: Optional[str] = None, + path: str | None = None, **kwargs: Any, ) -> Serve: """ diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 0b36202e5..8aa222eeb 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -3,7 +3,7 @@ import enum import logging import uuid -from typing import Generator, List, Optional, Type, Union +from typing import Generator, List, Type, Union from .exceptions import ( ConnectionClosed, @@ -89,8 +89,8 @@ def __init__( side: Side, *, state: State = OPEN, - max_size: Optional[int] = 2**20, - logger: Optional[LoggerLike] = None, + max_size: int | None = 2**20, + logger: LoggerLike | None = None, ) -> None: # Unique identifier. For logs. self.id: uuid.UUID = uuid.uuid4() @@ -116,24 +116,24 @@ def __init__( # Current size of incoming message in bytes. Only set while reading a # fragmented message i.e. a data frames with the FIN bit not set. - self.cur_size: Optional[int] = None + self.cur_size: int | None = None # True while sending a fragmented message i.e. a data frames with the # FIN bit not set. self.expect_continuation_frame = False # WebSocket protocol parameters. - self.origin: Optional[Origin] = None + self.origin: Origin | None = None self.extensions: List[Extension] = [] - self.subprotocol: Optional[Subprotocol] = None + self.subprotocol: Subprotocol | None = None # Close code and reason, set when a close frame is sent or received. - self.close_rcvd: Optional[Close] = None - self.close_sent: Optional[Close] = None - self.close_rcvd_then_sent: Optional[bool] = None + self.close_rcvd: Close | None = None + self.close_sent: Close | None = None + self.close_rcvd_then_sent: bool | None = None # Track if an exception happened during the handshake. - self.handshake_exc: Optional[Exception] = None + self.handshake_exc: Exception | None = None """ Exception to raise if the opening handshake failed. @@ -150,7 +150,7 @@ def __init__( self.writes: List[bytes] = [] self.parser = self.parse() next(self.parser) # start coroutine - self.parser_exc: Optional[Exception] = None + self.parser_exc: Exception | None = None @property def state(self) -> State: @@ -169,7 +169,7 @@ def state(self, state: State) -> None: self._state = state @property - def close_code(self) -> Optional[int]: + def close_code(self) -> int | None: """ `WebSocket close code`_. @@ -187,7 +187,7 @@ def close_code(self) -> Optional[int]: return self.close_rcvd.code @property - def close_reason(self) -> Optional[str]: + def close_reason(self) -> str | None: """ `WebSocket close reason`_. @@ -348,7 +348,7 @@ def send_binary(self, data: bytes, fin: bool = True) -> None: self.expect_continuation_frame = not fin self.send_frame(Frame(OP_BINARY, data, fin)) - def send_close(self, code: Optional[int] = None, reason: str = "") -> None: + def send_close(self, code: int | None = None, reason: str = "") -> None: """ Send a `Close frame`_. diff --git a/src/websockets/server.py b/src/websockets/server.py index 330e54f37..a92541085 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -5,7 +5,7 @@ import email.utils import http import warnings -from typing import Any, Callable, Generator, List, Optional, Sequence, Tuple, cast +from typing import Any, Callable, Generator, List, Sequence, Tuple, cast from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -77,18 +77,19 @@ class ServerProtocol(Protocol): def __init__( self, *, - origins: Optional[Sequence[Optional[Origin]]] = None, - extensions: Optional[Sequence[ServerExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - select_subprotocol: Optional[ + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: ( Callable[ [ServerProtocol, Sequence[Subprotocol]], - Optional[Subprotocol], + Subprotocol | None, ] - ] = None, + | None + ) = None, state: State = CONNECTING, - max_size: Optional[int] = 2**20, - logger: Optional[LoggerLike] = None, + max_size: int | None = 2**20, + logger: LoggerLike | None = None, ): super().__init__( side=SERVER, @@ -200,7 +201,7 @@ def accept(self, request: Request) -> Response: def process_request( self, request: Request, - ) -> Tuple[str, Optional[str], Optional[str]]: + ) -> Tuple[str, str | None, str | None]: """ Check a handshake request and negotiate extensions and subprotocol. @@ -285,7 +286,7 @@ def process_request( protocol_header, ) - def process_origin(self, headers: Headers) -> Optional[Origin]: + def process_origin(self, headers: Headers) -> Origin | None: """ Handle the Origin HTTP request header. @@ -303,9 +304,11 @@ def process_origin(self, headers: Headers) -> Optional[Origin]: # "The user agent MUST NOT include more than one Origin header field" # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3. try: - origin = cast(Optional[Origin], headers.get("Origin")) + origin = headers.get("Origin") except MultipleValuesError as exc: raise InvalidHeader("Origin", "more than one Origin header found") from exc + if origin is not None: + origin = cast(Origin, origin) if self.origins is not None: if origin not in self.origins: raise InvalidOrigin(origin) @@ -314,7 +317,7 @@ def process_origin(self, headers: Headers) -> Optional[Origin]: def process_extensions( self, headers: Headers, - ) -> Tuple[Optional[str], List[Extension]]: + ) -> Tuple[str | None, List[Extension]]: """ Handle the Sec-WebSocket-Extensions HTTP request header. @@ -350,7 +353,7 @@ def process_extensions( InvalidHandshake: If the Sec-WebSocket-Extensions header is invalid. """ - response_header_value: Optional[str] = None + response_header_value: str | None = None extension_headers: List[ExtensionHeader] = [] accepted_extensions: List[Extension] = [] @@ -392,7 +395,7 @@ def process_extensions( return response_header_value, accepted_extensions - def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: + def process_subprotocol(self, headers: Headers) -> Subprotocol | None: """ Handle the Sec-WebSocket-Protocol HTTP request header. @@ -420,7 +423,7 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: def select_subprotocol( self, subprotocols: Sequence[Subprotocol], - ) -> Optional[Subprotocol]: + ) -> Subprotocol | None: """ Pick a subprotocol among those offered by the client. diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 0bb7a76fd..60b49ebc3 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -4,7 +4,7 @@ import ssl as ssl_module import threading import warnings -from typing import Any, Optional, Sequence, Type +from typing import Any, Sequence, Type from ..client import ClientProtocol from ..datastructures import HeadersLike @@ -52,7 +52,7 @@ def __init__( socket: socket.socket, protocol: ClientProtocol, *, - close_timeout: Optional[float] = 10, + close_timeout: float | None = 10, ) -> None: self.protocol: ClientProtocol self.response_rcvd = threading.Event() @@ -64,9 +64,9 @@ def __init__( def handshake( self, - additional_headers: Optional[HeadersLike] = None, - user_agent_header: Optional[str] = USER_AGENT, - timeout: Optional[float] = None, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + timeout: float | None = None, ) -> None: """ Perform the opening handshake. @@ -128,25 +128,25 @@ def connect( uri: str, *, # TCP/TLS - sock: Optional[socket.socket] = None, - ssl: Optional[ssl_module.SSLContext] = None, - server_hostname: Optional[str] = None, + sock: socket.socket | None = None, + ssl: ssl_module.SSLContext | None = None, + server_hostname: str | None = None, # WebSocket - origin: Optional[Origin] = None, - extensions: Optional[Sequence[ClientExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - additional_headers: Optional[HeadersLike] = None, - user_agent_header: Optional[str] = USER_AGENT, - compression: Optional[str] = "deflate", + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + compression: str | None = "deflate", # Timeouts - open_timeout: Optional[float] = 10, - close_timeout: Optional[float] = 10, + open_timeout: float | None = 10, + close_timeout: float | None = 10, # Limits - max_size: Optional[int] = 2**20, + max_size: int | None = 2**20, # Logging - logger: Optional[LoggerLike] = None, + logger: LoggerLike | None = None, # Escape hatch for advanced customization - create_connection: Optional[Type[ClientConnection]] = None, + create_connection: Type[ClientConnection] | None = None, **kwargs: Any, ) -> ClientConnection: """ @@ -219,7 +219,7 @@ def connect( # Private APIs for unix_connect() unix: bool = kwargs.pop("unix", False) - path: Optional[str] = kwargs.pop("path", None) + path: str | None = kwargs.pop("path", None) if unix: if path is None and sock is None: @@ -307,8 +307,8 @@ def connect( def unix_connect( - path: Optional[str] = None, - uri: Optional[str] = None, + path: str | None = None, + uri: str | None = None, **kwargs: Any, ) -> ClientConnection: """ diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index b41202dc9..bb9743181 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -8,7 +8,7 @@ import threading import uuid from types import TracebackType -from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Type, Union +from typing import Any, Dict, Iterable, Iterator, Mapping, Type, Union from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl @@ -42,7 +42,7 @@ def __init__( socket: socket.socket, protocol: Protocol, *, - close_timeout: Optional[float] = 10, + close_timeout: float | None = 10, ) -> None: self.socket = socket self.protocol = protocol @@ -62,9 +62,9 @@ def __init__( self.debug = self.protocol.debug # HTTP handshake request and response. - self.request: Optional[Request] = None + self.request: Request | None = None """Opening handshake request.""" - self.response: Optional[Response] = None + self.response: Response | None = None """Opening handshake response.""" # Mutex serializing interactions with the protocol. @@ -77,7 +77,7 @@ def __init__( self.send_in_progress = False # Deadline for the closing handshake. - self.close_deadline: Optional[Deadline] = None + self.close_deadline: Deadline | None = None # Mapping of ping IDs to pong waiters, in chronological order. self.ping_waiters: Dict[bytes, threading.Event] = {} @@ -93,7 +93,7 @@ def __init__( # Exception raised in recv_events, to be chained to ConnectionClosed # in the user thread in order to show why the TCP connection dropped. - self.recv_exc: Optional[BaseException] = None + self.recv_exc: BaseException | None = None # Public attributes @@ -124,7 +124,7 @@ def remote_address(self) -> Any: return self.socket.getpeername() @property - def subprotocol(self) -> Optional[Subprotocol]: + def subprotocol(self) -> Subprotocol | None: """ Subprotocol negotiated during the opening handshake. @@ -140,9 +140,9 @@ def __enter__(self) -> Connection: def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: if exc_type is None: self.close() @@ -166,7 +166,7 @@ def __iter__(self) -> Iterator[Data]: except ConnectionClosedOK: return - def recv(self, timeout: Optional[float] = None) -> Data: + def recv(self, timeout: float | None = None) -> Data: """ Receive the next message. @@ -420,7 +420,7 @@ def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None: # They mean that the connection is closed, which was the goal. pass - def ping(self, data: Optional[Data] = None) -> threading.Event: + def ping(self, data: Data | None = None) -> threading.Event: """ Send a Ping_. @@ -647,7 +647,7 @@ def send_context( # Should we close the socket and raise ConnectionClosed? raise_close_exc = False # What exception should we chain ConnectionClosed to? - original_exc: Optional[BaseException] = None + original_exc: BaseException | None = None # Acquire the protocol lock. with self.protocol_mutex: @@ -748,7 +748,7 @@ def send_data(self) -> None: except OSError: # socket already closed pass - def set_recv_exc(self, exc: Optional[BaseException]) -> None: + def set_recv_exc(self, exc: BaseException | None) -> None: """ Set recv_exc, if not set yet. diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index dcba183d9..2c604ba09 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -3,7 +3,7 @@ import codecs import queue import threading -from typing import Iterator, List, Optional, cast +from typing import Iterator, List, cast from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data @@ -41,7 +41,7 @@ def __init__(self) -> None: self.put_in_progress = False # Decoder for text frames, None for binary frames. - self.decoder: Optional[codecs.IncrementalDecoder] = None + self.decoder: codecs.IncrementalDecoder | None = None # Buffer of frames belonging to the same message. self.chunks: List[Data] = [] @@ -54,12 +54,12 @@ def __init__(self) -> None: # Stream data from frames belonging to the same message. # Remove quotes around type when dropping Python < 3.9. - self.chunks_queue: Optional["queue.SimpleQueue[Optional[Data]]"] = None + self.chunks_queue: "queue.SimpleQueue[Data | None] | None" = None # This flag marks the end of the connection. self.closed = False - def get(self, timeout: Optional[float] = None) -> Data: + def get(self, timeout: float | None = None) -> Data: """ Read the next message. @@ -151,7 +151,7 @@ def get_iter(self) -> Iterator[Data]: self.chunks = [] self.chunks_queue = cast( # Remove quotes around type when dropping Python < 3.9. - "queue.SimpleQueue[Optional[Data]]", + "queue.SimpleQueue[Data | None]", queue.SimpleQueue(), ) @@ -164,7 +164,7 @@ def get_iter(self) -> Iterator[Data]: self.get_in_progress = True # Locking with get_in_progress ensures only one thread can get here. - chunk: Optional[Data] + chunk: Data | None for chunk in chunks: yield chunk while (chunk := self.chunks_queue.get()) is not None: diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index fd4f5d3bd..b801510b4 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -10,7 +10,7 @@ import threading import warnings from types import TracebackType -from typing import Any, Callable, Optional, Sequence, Type +from typing import Any, Callable, Sequence, Type from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate @@ -57,7 +57,7 @@ def __init__( socket: socket.socket, protocol: ServerProtocol, *, - close_timeout: Optional[float] = 10, + close_timeout: float | None = 10, ) -> None: self.protocol: ServerProtocol self.request_rcvd = threading.Event() @@ -69,20 +69,22 @@ def __init__( def handshake( self, - process_request: Optional[ + process_request: ( Callable[ [ServerConnection, Request], - Optional[Response], + Response | None, ] - ] = None, - process_response: Optional[ + | None + ) = None, + process_response: ( Callable[ [ServerConnection, Request, Response], - Optional[Response], + Response | None, ] - ] = None, - server_header: Optional[str] = USER_AGENT, - timeout: Optional[float] = None, + | None + ) = None, + server_header: str | None = USER_AGENT, + timeout: float | None = None, ) -> None: """ Perform the opening handshake. @@ -197,7 +199,7 @@ def __init__( self, socket: socket.socket, handler: Callable[[socket.socket, Any], None], - logger: Optional[LoggerLike] = None, + logger: LoggerLike | None = None, ): self.socket = socket self.handler = handler @@ -260,54 +262,57 @@ def __enter__(self) -> WebSocketServer: def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: self.shutdown() def serve( handler: Callable[[ServerConnection], None], - host: Optional[str] = None, - port: Optional[int] = None, + host: str | None = None, + port: int | None = None, *, # TCP/TLS - sock: Optional[socket.socket] = None, - ssl: Optional[ssl_module.SSLContext] = None, + sock: socket.socket | None = None, + ssl: ssl_module.SSLContext | None = None, # WebSocket - origins: Optional[Sequence[Optional[Origin]]] = None, - extensions: Optional[Sequence[ServerExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - select_subprotocol: Optional[ + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: ( Callable[ [ServerConnection, Sequence[Subprotocol]], - Optional[Subprotocol], + Subprotocol | None, ] - ] = None, - process_request: Optional[ + | None + ) = None, + process_request: ( Callable[ [ServerConnection, Request], - Optional[Response], + Response | None, ] - ] = None, - process_response: Optional[ + | None + ) = None, + process_response: ( Callable[ [ServerConnection, Request, Response], - Optional[Response], + Response | None, ] - ] = None, - server_header: Optional[str] = USER_AGENT, - compression: Optional[str] = "deflate", + | None + ) = None, + server_header: str | None = USER_AGENT, + compression: str | None = "deflate", # Timeouts - open_timeout: Optional[float] = 10, - close_timeout: Optional[float] = 10, + open_timeout: float | None = 10, + close_timeout: float | None = 10, # Limits - max_size: Optional[int] = 2**20, + max_size: int | None = 2**20, # Logging - logger: Optional[LoggerLike] = None, + logger: LoggerLike | None = None, # Escape hatch for advanced customization - create_connection: Optional[Type[ServerConnection]] = None, + create_connection: Type[ServerConnection] | None = None, **kwargs: Any, ) -> WebSocketServer: """ @@ -412,7 +417,7 @@ def handler(websocket): # Private APIs for unix_connect() unix: bool = kwargs.pop("unix", False) - path: Optional[str] = kwargs.pop("path", None) + path: str | None = kwargs.pop("path", None) if sock is None: if unix: @@ -460,18 +465,19 @@ def conn_handler(sock: socket.socket, addr: Any) -> None: sock.settimeout(None) # Create a closure to give select_subprotocol access to connection. - protocol_select_subprotocol: Optional[ + protocol_select_subprotocol: ( Callable[ [ServerProtocol, Sequence[Subprotocol]], - Optional[Subprotocol], + Subprotocol | None, ] - ] = None + | None + ) = None if select_subprotocol is not None: def protocol_select_subprotocol( protocol: ServerProtocol, subprotocols: Sequence[Subprotocol], - ) -> Optional[Subprotocol]: + ) -> Subprotocol | None: # mypy doesn't know that select_subprotocol is immutable. assert select_subprotocol is not None # Ensure this function is only used in the intended context. @@ -525,7 +531,7 @@ def protocol_select_subprotocol( def unix_serve( handler: Callable[[ServerConnection], None], - path: Optional[str] = None, + path: str | None = None, **kwargs: Any, ) -> WebSocketServer: """ diff --git a/src/websockets/sync/utils.py b/src/websockets/sync/utils.py index 3364bdc2d..00bce2cc6 100644 --- a/src/websockets/sync/utils.py +++ b/src/websockets/sync/utils.py @@ -1,7 +1,6 @@ from __future__ import annotations import time -from typing import Optional __all__ = ["Deadline"] @@ -16,14 +15,14 @@ class Deadline: """ - def __init__(self, timeout: Optional[float]) -> None: - self.deadline: Optional[float] + def __init__(self, timeout: float | None) -> None: + self.deadline: float | None if timeout is None: self.deadline = None else: self.deadline = time.monotonic() + timeout - def timeout(self, *, raise_if_elapsed: bool = True) -> Optional[float]: + def timeout(self, *, raise_if_elapsed: bool = True) -> float | None: """ Calculate a timeout from a deadline. diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 5dfecf66f..7c5b3664d 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -53,7 +53,7 @@ ExtensionName = NewType("ExtensionName", str) """Name of a WebSocket extension.""" - +# Change to str | None when dropping Python < 3.10. ExtensionParameter = Tuple[str, Optional[str]] """Parameter of a WebSocket extension.""" diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 8cf581743..902716066 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -2,7 +2,7 @@ import dataclasses import urllib.parse -from typing import Optional, Tuple +from typing import Tuple from . import exceptions @@ -33,8 +33,8 @@ class WebSocketURI: port: int path: str query: str - username: Optional[str] = None - password: Optional[str] = None + username: str | None = None + password: str | None = None @property def resource_name(self) -> str: @@ -47,7 +47,7 @@ def resource_name(self) -> str: return resource_name @property - def user_info(self) -> Optional[Tuple[str, str]]: + def user_info(self) -> Tuple[str, str] | None: if self.username is None: return None assert self.password is not None