diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 17f4eed..0000000 --- a/MANIFEST.in +++ /dev/null @@ -1 +0,0 @@ -include src/aioacme/py.typed diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index aabb912..0000000 --- a/mypy.ini +++ /dev/null @@ -1,5 +0,0 @@ -[mypy] -mypy_path = $MYPY_CONFIG_FILE_DIR/src -packages = aioacme -strict = true -pretty = true diff --git a/pyproject.toml b/pyproject.toml index 95e460d..b0c1054 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,8 @@ authors = [{ name = "Timofei Kukushkin", email = "tima@kukushkin.me" }] description = "Async ACME client implementation" readme = "README.md" dependencies = [ - "aiohttp", + "anyio", + "httpx", "cryptography", "typing-extensions; python_version<'3.11'", "orjson", @@ -28,6 +29,8 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Topic :: Software Development :: Libraries :: Python Modules", + "Framework :: AsyncIO", + "Framework :: Trio", ] [project.urls] diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 0000000..46a1db9 --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,14 @@ +{ + "include": ["src", "tests"], + "executionEnvironments": [ + { + "root": "src", + "typeCheckingMode": "strict" + }, + { + "root": "tests", + "extraPaths": ["."], + "typeCheckingMode": "basic" + } + ] +} diff --git a/pytest.ini b/pytest.ini index 91d8d88..a02f0ea 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,6 +1,3 @@ [pytest] -asyncio_mode = auto -testpaths = - tests - src -addopts = --tb=short +testpaths = tests +addopts = --tb=short -v diff --git a/requirements/lint.txt b/requirements/lint.txt index 01112e2..992dbf8 100644 --- a/requirements/lint.txt +++ b/requirements/lint.txt @@ -1,4 +1,3 @@ -mypy +-r tests.txt pyright ruff -types-python-dateutil diff --git a/requirements/tests.txt b/requirements/tests.txt index 7de5204..8aa087e 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -1,4 +1,5 @@ pytest -pytest-asyncio==0.21.2 +pytest-anyio pytest-cov lovely-pytest-docker +trio diff --git a/ruff.toml b/ruff.toml index c3f0435..4426f83 100644 --- a/ruff.toml +++ b/ruff.toml @@ -24,6 +24,7 @@ select = [ "RSE", # flake8-raise "RET", # flake8-return "SIM", # flake8-simplify + "TID", # flake8-tidy-imports "PTH", # flake8-use-pathlib "PGH", # pygrep-hooks "PL", # Pylint @@ -51,5 +52,9 @@ inline-quotes = "single" combine-as-imports = true no-lines-before = ["local-folder"] +[lint.flake8-tidy-imports.banned-api] +"asyncio".msg = "Use anyio instead." +"trio".msg = "Use anyio instead." + [format] quote-style = "single" diff --git a/src/aioacme/_client.py b/src/aioacme/_client.py index e5185be..91676e8 100644 --- a/src/aioacme/_client.py +++ b/src/aioacme/_client.py @@ -1,16 +1,15 @@ -import asyncio import hashlib import sys from base64 import b64decode -from collections.abc import AsyncIterator, Mapping, Sequence -from contextlib import asynccontextmanager +from collections.abc import Mapping, Sequence from dataclasses import dataclass from datetime import datetime from ssl import SSLContext from types import TracebackType from typing import Any, Final, Literal -import aiohttp +import anyio +import httpx import orjson import serpyco_rs from cryptography import x509 @@ -81,12 +80,10 @@ def __init__( self._account_key_jwk_thumbprint = jwk_thumbprint(self._account_key_jwk) self._directory: _Directory | None = None - self._session: aiohttp.ClientSession = aiohttp.ClientSession( - connector=aiohttp.TCPConnector(ssl=ssl), cookie_jar=aiohttp.DummyCookieJar() - ) + self._client = httpx.AsyncClient(verify=ssl) self._nonces: list[str] = [] - self._get_nonce_lock = asyncio.Lock() + self._get_nonce_lock = anyio.Lock() async def get_terms_of_service(self) -> str | None: return (await self._get_directory()).terms_of_service @@ -117,8 +114,8 @@ async def get_account(self) -> Account: data = b'' jwk = None - async with self._request(url=url, data=data, jwk=jwk) as response: - response_data = await response.json(loads=orjson.loads) + response = await self._request(url=url, data=data, jwk=jwk) + response_data = orjson.loads(response.content) account = _account_serializer.load({**response_data, 'uri': self.account_uri or response.headers['Location']}) self.account_uri = account.uri @@ -139,9 +136,8 @@ async def update_account( data['contact'] = contact if status is not None: data['status'] = status.value - async with self._request(url=uri, data=data) as response: - response_data = await response.json(loads=orjson.loads) - + response = await self._request(url=uri, data=data) + response_data = orjson.loads(response.content) return _account_serializer.load({**response_data, 'uri': uri}) async def new_order( @@ -166,8 +162,8 @@ async def new_order( if not_after is not None: data['notAfter'] = not_after.isoformat() - async with self._request(url, data=data) as response: - response_data = await response.json(loads=orjson.loads) + response = await self._request(url, data=data) + response_data = orjson.loads(response.content) return _order_serializer.load({**response_data, 'uri': response.headers['Location']}) @@ -177,8 +173,8 @@ async def get_order(self, order_uri: str) -> Order: :param order_uri: order URI. """ - async with self._request(order_uri) as response: - response_data = await response.json(loads=orjson.loads) + response = await self._request(order_uri) + response_data = orjson.loads(response.content) return _order_serializer.load({**response_data, 'uri': order_uri}) async def get_authorization(self, authorization_uri: str) -> Authorization: @@ -187,8 +183,8 @@ async def get_authorization(self, authorization_uri: str) -> Authorization: :param authorization_uri: authorization URI. """ - async with self._request(authorization_uri) as response: - response_data = await response.json(loads=orjson.loads) + response = await self._request(authorization_uri) + response_data = orjson.loads(response.content) return _authorization_serializer.load({**response_data, 'uri': authorization_uri}) def get_dns_challenge_domain(self, domain: str) -> str: @@ -217,8 +213,8 @@ async def answer_challenge(self, url: str) -> Challenge: :param url: challenge url. """ - async with self._request(url, data={}) as response: - response_data = await response.json(loads=orjson.loads) + response = await self._request(url, data={}) + response_data = orjson.loads(response.content) return _challenge_serializer.load(response_data) async def deactivate_authorization(self, uri: str) -> Authorization: @@ -229,8 +225,8 @@ async def deactivate_authorization(self, uri: str) -> Authorization: :param uri: authorization URI. """ - async with self._request(uri, data={'status': 'deactivated'}) as response: - response_data = await response.json(loads=orjson.loads) + response = await self._request(uri, data={'status': 'deactivated'}) + response_data = orjson.loads(response.content) return _authorization_serializer.load({**response_data, 'uri': uri}) async def finalize_order(self, finalize: str, csr: x509.CertificateSigningRequest) -> Order: @@ -240,11 +236,11 @@ async def finalize_order(self, finalize: str, csr: x509.CertificateSigningReques :param finalize: finalize uri. :param csr: CSR. """ - async with self._request( + response = await self._request( finalize, data={'csr': b64_encode(csr.public_bytes(serialization.Encoding.DER)).decode('ascii')}, - ) as response: - response_data = await response.json(loads=orjson.loads) + ) + response_data = orjson.loads(response.content) return _order_serializer.load({**response_data, 'uri': response.headers['Location']}) async def get_certificate(self, certificate: str) -> bytes: @@ -253,8 +249,8 @@ async def get_certificate(self, certificate: str) -> bytes: :param certificate: certificate URL. """ - async with self._request(certificate) as response: - return await response.read() + response = await self._request(certificate) + return response.content async def revoke_certificate( self, @@ -271,7 +267,7 @@ async def revoke_certificate( :param key: private key, if you want to revoke certificate without account key. :param reason: reason. """ - async with self._request( + await self._request( (await self._get_directory()).revoke_cert, data={ 'certificate': b64_encode(certificate.public_bytes(serialization.Encoding.DER)).decode('ascii'), @@ -279,8 +275,7 @@ async def revoke_certificate( }, jwk=make_jwk(key) if key else None, key=key, - ): - pass + ) async def change_key(self, new_account_key: PrivateKeyTypes) -> None: """ @@ -304,8 +299,7 @@ async def change_key(self, new_account_key: PrivateKeyTypes) -> None: jwk=new_account_key_jwk, add_nonce=False, ) - async with self._request(url, data=payload): - pass + await self._request(url, data=payload) self.account_key = new_account_key self._account_key_jwk = new_account_key_jwk self._account_key_jwk_thumbprint = jwk_thumbprint(new_account_key_jwk) @@ -313,7 +307,6 @@ async def change_key(self, new_account_key: PrivateKeyTypes) -> None: async def _get_account_uri(self) -> str: return self.account_uri or (await self.get_account()).uri - @asynccontextmanager async def _request( self, url: str, @@ -321,29 +314,30 @@ async def _request( data: Mapping[str, Any] | bytes = b'', jwk: JWK | None = None, key: PrivateKeyTypes | None = None, - ) -> AsyncIterator[aiohttp.ClientResponse]: - async with self._session.post( + ) -> httpx.Response: + response = await self._client.post( url, - data=await self._wrap_in_jws(url=url, data=data, jwk=jwk, key=key), + content=await self._wrap_in_jws(url=url, data=data, jwk=jwk, key=key), headers={'Content-Type': 'application/jose+json'}, - ) as response: - self._add_nounce(response) + ) + + self._add_nounce(response) - if response.status < 300: - yield response - return + if response.status_code < 300: + return response - try: - error = _error_serializer.load(await response.json(loads=orjson.loads)) - except aiohttp.ContentTypeError: - error = Error(type='unknown', detail='') + try: + response_data = orjson.loads(response.content) + except orjson.JSONDecodeError: + error = Error(type='unknown', detail=response.text) + else: + error = _error_serializer.load(response_data) - if error.type != 'urn:ietf:params:acme:error:badNonce': - raise AcmeError(error) + if error.type != 'urn:ietf:params:acme:error:badNonce': + raise AcmeError(error) - # retry bad nonce - async with self._request(url, data=data, jwk=jwk, key=key) as retried_response: - yield retried_response + # retry bad nonce + return await self._request(url, data=data, jwk=jwk, key=key) async def _wrap_in_jws( self, @@ -369,18 +363,18 @@ async def _wrap_in_jws( async def _get_nonce(self) -> str: async with self._get_nonce_lock: if not self._nonces: - async with self._session.head((await self._get_directory()).new_nonce) as response: - self._nonces.append(response.headers['Replay-Nonce']) + response = await self._client.head((await self._get_directory()).new_nonce) + self._add_nounce(response) return self._nonces.pop() - def _add_nounce(self, response: aiohttp.ClientResponse) -> None: + def _add_nounce(self, response: httpx.Response) -> None: self._nonces.append(response.headers['Replay-Nonce']) async def _get_directory(self) -> '_Directory': if self._directory is None: - async with self._session.get(self.directory_url) as response: - response_data = await response.json(loads=orjson.loads) - + response = await self._client.get(self.directory_url) + response.raise_for_status() + response_data = orjson.loads(response.content) self._directory = _Directory( new_account=response_data['newAccount'], new_nonce=response_data['newNonce'], @@ -393,7 +387,7 @@ async def _get_directory(self) -> '_Directory': return self._directory async def close(self) -> None: - await self._session.close() + await self._client.aclose() async def __aenter__(self) -> Self: return self @@ -401,7 +395,8 @@ async def __aenter__(self) -> Self: async def __aexit__( self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None ) -> None: - await asyncio.shield(self.close()) + with anyio.CancelScope(shield=True): + await self.close() @dataclass(frozen=True, slots=True) diff --git a/src/aioacme/_jws.py b/src/aioacme/_jws.py index 4734ec5..152399f 100644 --- a/src/aioacme/_jws.py +++ b/src/aioacme/_jws.py @@ -17,6 +17,13 @@ else: from typing_extensions import assert_never +_EC_ALGS: Mapping[type[ec.EllipticCurve], str] = {ec.SECP256R1: 'ES256', ec.SECP384R1: 'ES384', ec.SECP521R1: 'ES512'} +_EC_HASHES: Mapping[type[ec.EllipticCurve], hashes.HashAlgorithm] = { + ec.SECP256R1: hashes.SHA256(), + ec.SECP384R1: hashes.SHA384(), + ec.SECP521R1: hashes.SHA512(), +} + def jws_encode( payload: bytes, @@ -24,7 +31,7 @@ def jws_encode( headers: Mapping[str, Any], ) -> bytes: if isinstance(key, ec.EllipticCurvePrivateKey): - alg = {ec.SECP256R1.name: 'ES256', ec.SECP384R1.name: 'ES384', ec.SECP521R1.name: 'ES512'}[key.curve.name] + alg = _EC_ALGS[type(key.curve)] elif isinstance(key, rsa.RSAPrivateKey): alg = 'RS256' elif isinstance(key, ed25519.Ed25519PrivateKey): @@ -43,11 +50,7 @@ def jws_encode( if isinstance(key, rsa.RSAPrivateKey): signature = key.sign(signing_input, PKCS1v15(), hashes.SHA256()) elif isinstance(key, ec.EllipticCurvePrivateKey): - hash_alg = { - ec.SECP256R1.name: hashes.SHA256(), - ec.SECP384R1.name: hashes.SHA384(), - ec.SECP521R1.name: hashes.SHA512(), - }[key.curve.name] + hash_alg = _EC_HASHES[type(key.curve)] signature = key.sign(signing_input, ec.ECDSA(hash_alg)) signature = _der_to_raw_signature(signature, key.curve) elif isinstance(key, ed25519.Ed25519PrivateKey): @@ -56,7 +59,6 @@ def jws_encode( key = key.copy() key.update(signing_input) signature = key.finalize() - else: assert_never(key) diff --git a/tests/conftest.py b/tests/conftest.py index 233acc6..a591dec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,8 @@ -import asyncio +import ssl import uuid -import aiohttp +import anyio +import httpx import pytest from cryptography import x509 from cryptography.hazmat.primitives import hashes @@ -11,11 +12,9 @@ import aioacme -@pytest.fixture(scope='session') -def event_loop(): - loop = asyncio.new_event_loop() - yield loop - loop.close() +@pytest.fixture(autouse=True, scope='session', params=['asyncio', 'trio']) +def anyio_backend(request) -> str: + return request.param @pytest.fixture(scope='session') @@ -34,21 +33,23 @@ async def _get_pebble_url(docker_services: Services, docker_ip: str, name: str) url = f'https://{docker_ip}:{port}' # wait for pebble to be ready - async with aiohttp.ClientSession() as session: + last_exc = None + async with httpx.AsyncClient(verify=False) as client: for _ in range(100): try: - async with session.get(url, ssl=False): - break - except aiohttp.ClientError as exc: + await client.get(url) + except (httpx.HTTPError, ssl.SSLZeroReturnError) as exc: last_exc = exc - await asyncio.sleep(0.1) + await anyio.sleep(0.1) + else: + break else: raise TimeoutError from last_exc return url @pytest.fixture() -async def client(pebble_url) -> aioacme.Client: +async def client(pebble_url): account_key = ec.generate_private_key(ec.SECP256R1()) async with aioacme.Client(directory_url=f'{pebble_url}/dir', ssl=False, account_key=account_key) as client: yield client @@ -80,9 +81,10 @@ def add_txt_fixture(docker_services, docker_ip): port = docker_services.wait_for_service('challtestsrv', 8055) async def add_txt(domain: str, value: str) -> None: - async with aiohttp.request( - 'POST', f'http://{docker_ip}:{port}/set-txt', json={'host': domain + '.', 'value': value} - ) as response: + async with httpx.AsyncClient() as client: + response = await client.post( + f'http://{docker_ip}:{port}/set-txt', json={'host': domain + '.', 'value': value} + ) response.raise_for_status() return add_txt diff --git a/tests/test_client.py b/tests/test_client.py index d35f752..4791e16 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,7 +1,7 @@ -import asyncio from datetime import datetime, timedelta, timezone from unittest import mock +import anyio import pytest from cryptography import x509 from cryptography.hazmat.primitives.asymmetric import ec, ed25519, rsa @@ -215,7 +215,7 @@ async def test_answer_challenge__invalid_record_value__challenge_error(client, d challenge = await client.answer_challenge(challenge.url) assert challenge.status is aioacme.ChallengeStatus.processing - await asyncio.sleep(0.5) + await anyio.sleep(0.5) authorization = await client.get_authorization(authorization.uri) assert authorization.status is aioacme.AuthorizationStatus.invalid @@ -241,7 +241,7 @@ async def test_deactivate_authorization(client, domain, add_txt): challenge = await client.answer_challenge(challenge.url) assert challenge.status is aioacme.ChallengeStatus.processing - await asyncio.sleep(0.5) + await anyio.sleep(0.5) authorization = await client.get_authorization(authorization.uri) assert authorization.status is aioacme.AuthorizationStatus.valid @@ -264,7 +264,7 @@ async def test_integrational_ok(client, domain, csr, add_txt, private_key): challenge = await client.answer_challenge(challenge.url) assert challenge.status is aioacme.ChallengeStatus.processing - await asyncio.sleep(0.5) + await anyio.sleep(0.5) authorization = await client.get_authorization(authorization.uri) assert authorization.status is aioacme.AuthorizationStatus.valid @@ -285,7 +285,7 @@ async def test_integrational_ok(client, domain, csr, add_txt, private_key): order = await client.finalize_order(order.finalize, csr) assert order.status is aioacme.OrderStatus.processing - await asyncio.sleep(0.5) + await anyio.sleep(0.5) order = await client.get_order(order.uri) assert order.status is aioacme.OrderStatus.valid diff --git a/tox.ini b/tox.ini index 367044c..284e233 100644 --- a/tox.ini +++ b/tox.ini @@ -39,5 +39,5 @@ deps = commands = ruff check --no-fix . ruff format --check . - mypy + pyright pyright --verifytypes aioacme --ignoreexternal