diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 963353d0..e288831b 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,11 +25,23 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. -12.1 +13.0 ---- *In development* +Backwards-incompatible changes +.............................. + +.. admonition:: The ``ssl_context`` argument of :func:`~sync.client.connect` + and :func:`~sync.server.serve` is renamed to ``ssl``. + :class: note + + This aligns the API of the :mod:`threading` implementation with the + :mod:`asyncio` implementation. + + For backwards compatibility, ``ssl_context`` is still supported. + New features ............ diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 79af0132..6faca778 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -1,8 +1,9 @@ from __future__ import annotations import socket -import ssl +import ssl as ssl_module import threading +import warnings from typing import Any, Optional, Sequence, Type from ..client import ClientProtocol @@ -128,7 +129,7 @@ def connect( *, # TCP/TLS sock: Optional[socket.socket] = None, - ssl_context: Optional[ssl.SSLContext] = None, + ssl: Optional[ssl_module.SSLContext] = None, server_hostname: Optional[str] = None, # WebSocket origin: Optional[Origin] = None, @@ -166,7 +167,7 @@ def connect( sock: Preexisting TCP socket. ``sock`` overrides the host and port from ``uri``. You may call :func:`socket.create_connection` to create a suitable TCP socket. - ssl_context: Configuration for enabling TLS on the connection. + ssl: Configuration for enabling TLS on the connection. server_hostname: Host name for the TLS handshake. ``server_hostname`` overrides the host name from ``uri``. origin: Value of the ``Origin`` header, for servers that require it. @@ -207,9 +208,14 @@ def connect( # Process parameters + # Backwards compatibility: ssl used to be called ssl_context. + if ssl is None and "ssl_context" in kwargs: + ssl = kwargs.pop("ssl_context") + warnings.warn("ssl_context was renamed to ssl", DeprecationWarning) + wsuri = parse_uri(uri) - if not wsuri.secure and ssl_context is not None: - raise TypeError("ssl_context argument is incompatible with a ws:// URI") + if not wsuri.secure and ssl is not None: + raise TypeError("ssl argument is incompatible with a ws:// URI") # Private APIs for unix_connect() unix: bool = kwargs.pop("unix", False) @@ -259,12 +265,12 @@ def connect( # Initialize TLS wrapper and perform TLS handshake if wsuri.secure: - if ssl_context is None: - ssl_context = ssl.create_default_context() + if ssl is None: + ssl = ssl_module.create_default_context() if server_hostname is None: server_hostname = wsuri.host sock.settimeout(deadline.timeout()) - sock = ssl_context.wrap_socket(sock, server_hostname=server_hostname) + sock = ssl.wrap_socket(sock, server_hostname=server_hostname) sock.settimeout(None) # Initialize WebSocket connection @@ -318,12 +324,13 @@ def unix_connect( Args: path: File system path to the Unix socket. uri: URI of the WebSocket server. ``uri`` defaults to - ``ws://localhost/`` or, when a ``ssl_context`` is provided, to + ``ws://localhost/`` or, when a ``ssl`` is provided, to ``wss://localhost/``. """ if uri is None: - if kwargs.get("ssl_context") is None: + # Backwards compatibility: ssl used to be called ssl_context. + if kwargs.get("ssl") is None and kwargs.get("ssl_context") is None: uri = "ws://localhost/" else: uri = "wss://localhost/" diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index c1999284..fa6087d5 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -5,9 +5,10 @@ import os import selectors import socket -import ssl +import ssl as ssl_module import sys import threading +import warnings from types import TracebackType from typing import Any, Callable, Optional, Sequence, Type @@ -268,7 +269,7 @@ def serve( *, # TCP/TLS sock: Optional[socket.socket] = None, - ssl_context: Optional[ssl.SSLContext] = None, + ssl: Optional[ssl_module.SSLContext] = None, # WebSocket origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, @@ -337,7 +338,7 @@ def handler(websocket): sock: Preexisting TCP socket. ``sock`` replaces ``host`` and ``port``. You may call :func:`socket.create_server` to create a suitable TCP socket. - ssl_context: Configuration for enabling TLS on the connection. + ssl: Configuration for enabling TLS on the connection. origins: Acceptable values of the ``Origin`` header, for defending against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` in the list if the lack of an origin is acceptable. @@ -386,6 +387,11 @@ def handler(websocket): # Process parameters + # Backwards compatibility: ssl used to be called ssl_context. + if ssl is None and "ssl_context" in kwargs: + ssl = kwargs.pop("ssl_context") + warnings.warn("ssl_context was renamed to ssl", DeprecationWarning) + if subprotocols is not None: validate_subprotocols(subprotocols) @@ -417,8 +423,8 @@ def handler(websocket): # Initialize TLS wrapper - if ssl_context is not None: - sock = ssl_context.wrap_socket( + if ssl is not None: + sock = ssl.wrap_socket( sock, server_side=True, # Delay TLS handshake until after we set a timeout on the socket. @@ -441,9 +447,10 @@ def conn_handler(sock: socket.socket, addr: Any) -> None: # Perform TLS handshake - if ssl_context is not None: + if ssl is not None: sock.settimeout(deadline.timeout()) - assert isinstance(sock, ssl.SSLSocket) # mypy cannot figure this out + # mypy cannot figure this out + assert isinstance(sock, ssl_module.SSLSocket) sock.do_handshake() sock.settimeout(None) diff --git a/tests/sync/client.py b/tests/sync/client.py index 683893e8..bb4855c7 100644 --- a/tests/sync/client.py +++ b/tests/sync/client.py @@ -25,7 +25,8 @@ def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): else: assert isinstance(wsuri_or_server, WebSocketServer) if secure is None: - secure = "ssl_context" in kwargs + # Backwards compatibility: ssl used to be called ssl_context. + secure = "ssl" in kwargs or "ssl_context" in kwargs protocol = "wss" if secure else "ws" host, port = wsuri_or_server.socket.getsockname() wsuri = f"{protocol}://{host}:{port}{resource_name}" diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index c900f3b0..fa363deb 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -7,7 +7,7 @@ from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * -from ..utils import MS, temp_unix_socket_path +from ..utils import MS, DeprecationTestCase, temp_unix_socket_path from .client import CLIENT_CONTEXT, run_client, run_unix_client from .server import SERVER_CONTEXT, do_nothing, run_server, run_unix_server @@ -137,18 +137,18 @@ def close_connection(self, request): class SecureClientTests(unittest.TestCase): def test_connection(self): """Client connects to server securely.""" - with run_server(ssl_context=SERVER_CONTEXT) as server: - with run_client(server, ssl_context=CLIENT_CONTEXT) as client: + with run_server(ssl=SERVER_CONTEXT) as server: + with run_client(server, ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(client.socket.version()[:3], "TLS") def test_set_server_hostname_implicitly(self): """Client sets server_hostname to the host in the WebSocket URI.""" with temp_unix_socket_path() as path: - with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_server(path, ssl=SERVER_CONTEXT): with run_unix_client( path, - ssl_context=CLIENT_CONTEXT, + ssl=CLIENT_CONTEXT, uri="wss://overridden/", ) as client: self.assertEqual(client.socket.server_hostname, "overridden") @@ -156,17 +156,17 @@ def test_set_server_hostname_implicitly(self): def test_set_server_hostname_explicitly(self): """Client sets server_hostname to the value provided in argument.""" with temp_unix_socket_path() as path: - with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_server(path, ssl=SERVER_CONTEXT): with run_unix_client( path, - ssl_context=CLIENT_CONTEXT, + ssl=CLIENT_CONTEXT, server_hostname="overridden", ) as client: self.assertEqual(client.socket.server_hostname, "overridden") def test_reject_invalid_server_certificate(self): """Client rejects certificate where server certificate isn't trusted.""" - with run_server(ssl_context=SERVER_CONTEXT) as server: + with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaisesRegex( ssl.SSLCertVerificationError, r"certificate verify failed: self[ -]signed certificate", @@ -177,15 +177,13 @@ def test_reject_invalid_server_certificate(self): def test_reject_invalid_server_hostname(self): """Client rejects certificate where server hostname doesn't match.""" - with run_server(ssl_context=SERVER_CONTEXT) as server: + with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaisesRegex( ssl.SSLCertVerificationError, r"certificate verify failed: Hostname mismatch", ): # This hostname isn't included in the test certificate. - with run_client( - server, ssl_context=CLIENT_CONTEXT, server_hostname="invalid" - ): + with run_client(server, ssl=CLIENT_CONTEXT, server_hostname="invalid"): self.fail("did not raise") @@ -212,8 +210,8 @@ class SecureUnixClientTests(unittest.TestCase): def test_connection(self): """Client connects to server securely over a Unix socket.""" with temp_unix_socket_path() as path: - with run_unix_server(path, ssl_context=SERVER_CONTEXT): - with run_unix_client(path, ssl_context=CLIENT_CONTEXT) as client: + with run_unix_server(path, ssl=SERVER_CONTEXT): + with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(client.socket.version()[:3], "TLS") @@ -221,23 +219,23 @@ def test_set_server_hostname(self): """Client sets server_hostname to the host in the WebSocket URI.""" # This is part of the documented behavior of unix_connect(). with temp_unix_socket_path() as path: - with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_server(path, ssl=SERVER_CONTEXT): with run_unix_client( path, - ssl_context=CLIENT_CONTEXT, + ssl=CLIENT_CONTEXT, uri="wss://overridden/", ) as client: self.assertEqual(client.socket.server_hostname, "overridden") class ClientUsageErrorsTests(unittest.TestCase): - def test_ssl_context_without_secure_uri(self): - """Client rejects ssl_context when URI isn't secure.""" + def test_ssl_without_secure_uri(self): + """Client rejects ssl when URI isn't secure.""" with self.assertRaisesRegex( TypeError, - "ssl_context argument is incompatible with a ws:// URI", + "ssl argument is incompatible with a ws:// URI", ): - connect("ws://localhost/", ssl_context=CLIENT_CONTEXT) + connect("ws://localhost/", ssl=CLIENT_CONTEXT) def test_unix_without_path_or_sock(self): """Unix client requires path when sock isn't provided.""" @@ -272,3 +270,12 @@ def test_unsupported_compression(self): "unsupported compression: False", ): connect("ws://localhost/", compression=False) + + +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_ssl_context_argument(self): + """Client supports the deprecated ssl_context argument.""" + with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertDeprecationWarning("ssl_context was renamed to ssl"): + with run_client(server, ssl_context=CLIENT_CONTEXT): + pass diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index f9db8424..5e7e79c5 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -14,7 +14,7 @@ from websockets.http11 import Request, Response from websockets.sync.server import * -from ..utils import MS, temp_unix_socket_path +from ..utils import MS, DeprecationTestCase, temp_unix_socket_path from .client import CLIENT_CONTEXT, run_client, run_unix_client from .server import ( SERVER_CONTEXT, @@ -274,20 +274,20 @@ def handler(sock, addr): class SecureServerTests(EvalShellMixin, unittest.TestCase): def test_connection(self): """Server receives secure connection from client.""" - with run_server(ssl_context=SERVER_CONTEXT) as server: - with run_client(server, ssl_context=CLIENT_CONTEXT) as client: + with run_server(ssl=SERVER_CONTEXT) as server: + with run_client(server, ssl=CLIENT_CONTEXT) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") self.assertEval(client, "ws.socket.version()[:3]", "TLS") def test_timeout_during_tls_handshake(self): """Server times out before receiving TLS handshake request from client.""" - with run_server(ssl_context=SERVER_CONTEXT, open_timeout=MS) as server: + with run_server(ssl=SERVER_CONTEXT, open_timeout=MS) as server: with socket.create_connection(server.socket.getsockname()) as sock: self.assertEqual(sock.recv(4096), b"") def test_connection_closed_during_tls_handshake(self): """Server reads EOF before receiving TLS handshake request from client.""" - with run_server(ssl_context=SERVER_CONTEXT) as server: + with run_server(ssl=SERVER_CONTEXT) as server: # Patch handler to record a reference to the thread running it. server_thread = None conn_received = threading.Event() @@ -325,8 +325,8 @@ class SecureUnixServerTests(EvalShellMixin, unittest.TestCase): def test_connection(self): """Server receives secure connection from client over a Unix socket.""" with temp_unix_socket_path() as path: - with run_unix_server(path, ssl_context=SERVER_CONTEXT): - with run_unix_client(path, ssl_context=CLIENT_CONTEXT) as client: + with run_unix_server(path, ssl=SERVER_CONTEXT): + with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") self.assertEval(client, "ws.socket.version()[:3]", "TLS") @@ -386,3 +386,12 @@ def test_shutdown(self): # Check that the server socket is closed. with self.assertRaises(OSError): server.socket.accept() + + +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_ssl_context_argument(self): + """Client supports the deprecated ssl_context argument.""" + with self.assertDeprecationWarning("ssl_context was renamed to ssl"): + with run_server(ssl_context=SERVER_CONTEXT) as server: + with run_client(server, ssl=CLIENT_CONTEXT): + pass