Skip to content

Commit

Permalink
Trio support
Browse files Browse the repository at this point in the history
  • Loading branch information
tkukushkin committed Oct 12, 2024
1 parent 5be7127 commit d257d7c
Show file tree
Hide file tree
Showing 13 changed files with 115 additions and 103 deletions.
1 change: 0 additions & 1 deletion MANIFEST.in

This file was deleted.

5 changes: 0 additions & 5 deletions mypy.ini

This file was deleted.

5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ authors = [{ name = "Timofei Kukushkin", email = "[email protected]" }]
description = "Async ACME client implementation"
readme = "README.md"
dependencies = [
"aiohttp",
"anyio",
"httpx",
"cryptography",
"typing-extensions; python_version<'3.11'",
"orjson",
Expand All @@ -28,6 +29,8 @@ classifiers = [
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Software Development :: Libraries :: Python Modules",
"Framework :: AsyncIO",
"Framework :: Trio",
]

[project.urls]
Expand Down
14 changes: 14 additions & 0 deletions pyrightconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"include": ["src", "tests"],
"executionEnvironments": [
{
"root": "src",
"typeCheckingMode": "strict"
},
{
"root": "tests",
"extraPaths": ["."],
"typeCheckingMode": "basic"
}
]
}
7 changes: 2 additions & 5 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
[pytest]
asyncio_mode = auto
testpaths =
tests
src
addopts = --tb=short
testpaths = tests
addopts = --tb=short -v
3 changes: 1 addition & 2 deletions requirements/lint.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
mypy
-r tests.txt
pyright
ruff
types-python-dateutil
3 changes: 2 additions & 1 deletion requirements/tests.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pytest
pytest-asyncio==0.21.2
pytest-anyio
pytest-cov
lovely-pytest-docker
trio
5 changes: 5 additions & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ select = [
"RSE", # flake8-raise
"RET", # flake8-return
"SIM", # flake8-simplify
"TID", # flake8-tidy-imports
"PTH", # flake8-use-pathlib
"PGH", # pygrep-hooks
"PL", # Pylint
Expand Down Expand Up @@ -51,5 +52,9 @@ inline-quotes = "single"
combine-as-imports = true
no-lines-before = ["local-folder"]

[lint.flake8-tidy-imports.banned-api]
"asyncio".msg = "Use anyio instead."
"trio".msg = "Use anyio instead."

[format]
quote-style = "single"
113 changes: 54 additions & 59 deletions src/aioacme/_client.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import asyncio
import hashlib
import sys
from base64 import b64decode
from collections.abc import AsyncIterator, Mapping, Sequence
from contextlib import asynccontextmanager
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime
from ssl import SSLContext
from types import TracebackType
from typing import Any, Final, Literal

import aiohttp
import anyio
import httpx
import orjson
import serpyco_rs
from cryptography import x509
Expand Down Expand Up @@ -81,12 +80,10 @@ def __init__(
self._account_key_jwk_thumbprint = jwk_thumbprint(self._account_key_jwk)

self._directory: _Directory | None = None
self._session: aiohttp.ClientSession = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(ssl=ssl), cookie_jar=aiohttp.DummyCookieJar()
)
self._client = httpx.AsyncClient(verify=ssl)

self._nonces: list[str] = []
self._get_nonce_lock = asyncio.Lock()
self._get_nonce_lock = anyio.Lock()

async def get_terms_of_service(self) -> str | None:
return (await self._get_directory()).terms_of_service
Expand Down Expand Up @@ -117,8 +114,8 @@ async def get_account(self) -> Account:
data = b''
jwk = None

async with self._request(url=url, data=data, jwk=jwk) as response:
response_data = await response.json(loads=orjson.loads)
response = await self._request(url=url, data=data, jwk=jwk)
response_data = orjson.loads(response.content)

account = _account_serializer.load({**response_data, 'uri': self.account_uri or response.headers['Location']})
self.account_uri = account.uri
Expand All @@ -139,9 +136,8 @@ async def update_account(
data['contact'] = contact
if status is not None:
data['status'] = status.value
async with self._request(url=uri, data=data) as response:
response_data = await response.json(loads=orjson.loads)

response = await self._request(url=uri, data=data)
response_data = orjson.loads(response.content)
return _account_serializer.load({**response_data, 'uri': uri})

async def new_order(
Expand All @@ -166,8 +162,8 @@ async def new_order(
if not_after is not None:
data['notAfter'] = not_after.isoformat()

async with self._request(url, data=data) as response:
response_data = await response.json(loads=orjson.loads)
response = await self._request(url, data=data)
response_data = orjson.loads(response.content)

return _order_serializer.load({**response_data, 'uri': response.headers['Location']})

Expand All @@ -177,8 +173,8 @@ async def get_order(self, order_uri: str) -> Order:
:param order_uri: order URI.
"""
async with self._request(order_uri) as response:
response_data = await response.json(loads=orjson.loads)
response = await self._request(order_uri)
response_data = orjson.loads(response.content)
return _order_serializer.load({**response_data, 'uri': order_uri})

async def get_authorization(self, authorization_uri: str) -> Authorization:
Expand All @@ -187,8 +183,8 @@ async def get_authorization(self, authorization_uri: str) -> Authorization:
:param authorization_uri: authorization URI.
"""
async with self._request(authorization_uri) as response:
response_data = await response.json(loads=orjson.loads)
response = await self._request(authorization_uri)
response_data = orjson.loads(response.content)
return _authorization_serializer.load({**response_data, 'uri': authorization_uri})

def get_dns_challenge_domain(self, domain: str) -> str:
Expand Down Expand Up @@ -217,8 +213,8 @@ async def answer_challenge(self, url: str) -> Challenge:
:param url: challenge url.
"""
async with self._request(url, data={}) as response:
response_data = await response.json(loads=orjson.loads)
response = await self._request(url, data={})
response_data = orjson.loads(response.content)
return _challenge_serializer.load(response_data)

async def deactivate_authorization(self, uri: str) -> Authorization:
Expand All @@ -229,8 +225,8 @@ async def deactivate_authorization(self, uri: str) -> Authorization:
:param uri: authorization URI.
"""
async with self._request(uri, data={'status': 'deactivated'}) as response:
response_data = await response.json(loads=orjson.loads)
response = await self._request(uri, data={'status': 'deactivated'})
response_data = orjson.loads(response.content)
return _authorization_serializer.load({**response_data, 'uri': uri})

async def finalize_order(self, finalize: str, csr: x509.CertificateSigningRequest) -> Order:
Expand All @@ -240,11 +236,11 @@ async def finalize_order(self, finalize: str, csr: x509.CertificateSigningReques
:param finalize: finalize uri.
:param csr: CSR.
"""
async with self._request(
response = await self._request(
finalize,
data={'csr': b64_encode(csr.public_bytes(serialization.Encoding.DER)).decode('ascii')},
) as response:
response_data = await response.json(loads=orjson.loads)
)
response_data = orjson.loads(response.content)
return _order_serializer.load({**response_data, 'uri': response.headers['Location']})

async def get_certificate(self, certificate: str) -> bytes:
Expand All @@ -253,8 +249,8 @@ async def get_certificate(self, certificate: str) -> bytes:
:param certificate: certificate URL.
"""
async with self._request(certificate) as response:
return await response.read()
response = await self._request(certificate)
return response.content

async def revoke_certificate(
self,
Expand All @@ -271,16 +267,15 @@ async def revoke_certificate(
:param key: private key, if you want to revoke certificate without account key.
:param reason: reason.
"""
async with self._request(
await self._request(
(await self._get_directory()).revoke_cert,
data={
'certificate': b64_encode(certificate.public_bytes(serialization.Encoding.DER)).decode('ascii'),
'reason': reason.value,
},
jwk=make_jwk(key) if key else None,
key=key,
):
pass
)

async def change_key(self, new_account_key: PrivateKeyTypes) -> None:
"""
Expand All @@ -304,46 +299,45 @@ async def change_key(self, new_account_key: PrivateKeyTypes) -> None:
jwk=new_account_key_jwk,
add_nonce=False,
)
async with self._request(url, data=payload):
pass
await self._request(url, data=payload)
self.account_key = new_account_key
self._account_key_jwk = new_account_key_jwk
self._account_key_jwk_thumbprint = jwk_thumbprint(new_account_key_jwk)

async def _get_account_uri(self) -> str:
return self.account_uri or (await self.get_account()).uri

@asynccontextmanager
async def _request(
self,
url: str,
*,
data: Mapping[str, Any] | bytes = b'',
jwk: JWK | None = None,
key: PrivateKeyTypes | None = None,
) -> AsyncIterator[aiohttp.ClientResponse]:
async with self._session.post(
) -> httpx.Response:
response = await self._client.post(
url,
data=await self._wrap_in_jws(url=url, data=data, jwk=jwk, key=key),
content=await self._wrap_in_jws(url=url, data=data, jwk=jwk, key=key),
headers={'Content-Type': 'application/jose+json'},
) as response:
self._add_nounce(response)
)

self._add_nounce(response)

if response.status < 300:
yield response
return
if response.status_code < 300:
return response

try:
error = _error_serializer.load(await response.json(loads=orjson.loads))
except aiohttp.ContentTypeError:
error = Error(type='unknown', detail='')
try:
response_data = orjson.loads(response.content)
except orjson.JSONDecodeError:
error = Error(type='unknown', detail=response.text)
else:
error = _error_serializer.load(response_data)

if error.type != 'urn:ietf:params:acme:error:badNonce':
raise AcmeError(error)
if error.type != 'urn:ietf:params:acme:error:badNonce':
raise AcmeError(error)

# retry bad nonce
async with self._request(url, data=data, jwk=jwk, key=key) as retried_response:
yield retried_response
# retry bad nonce
return await self._request(url, data=data, jwk=jwk, key=key)

async def _wrap_in_jws(
self,
Expand All @@ -369,18 +363,18 @@ async def _wrap_in_jws(
async def _get_nonce(self) -> str:
async with self._get_nonce_lock:
if not self._nonces:
async with self._session.head((await self._get_directory()).new_nonce) as response:
self._nonces.append(response.headers['Replay-Nonce'])
response = await self._client.head((await self._get_directory()).new_nonce)
self._add_nounce(response)
return self._nonces.pop()

def _add_nounce(self, response: aiohttp.ClientResponse) -> None:
def _add_nounce(self, response: httpx.Response) -> None:
self._nonces.append(response.headers['Replay-Nonce'])

async def _get_directory(self) -> '_Directory':
if self._directory is None:
async with self._session.get(self.directory_url) as response:
response_data = await response.json(loads=orjson.loads)

response = await self._client.get(self.directory_url)
response.raise_for_status()
response_data = orjson.loads(response.content)
self._directory = _Directory(
new_account=response_data['newAccount'],
new_nonce=response_data['newNonce'],
Expand All @@ -393,15 +387,16 @@ async def _get_directory(self) -> '_Directory':
return self._directory

async def close(self) -> None:
await self._session.close()
await self._client.aclose()

async def __aenter__(self) -> Self:
return self

async def __aexit__(
self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None
) -> None:
await asyncio.shield(self.close())
with anyio.CancelScope(shield=True):
await self.close()


@dataclass(frozen=True, slots=True)
Expand Down
16 changes: 9 additions & 7 deletions src/aioacme/_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,21 @@
else:
from typing_extensions import assert_never

_EC_ALGS: Mapping[type[ec.EllipticCurve], str] = {ec.SECP256R1: 'ES256', ec.SECP384R1: 'ES384', ec.SECP521R1: 'ES512'}
_EC_HASHES: Mapping[type[ec.EllipticCurve], hashes.HashAlgorithm] = {
ec.SECP256R1: hashes.SHA256(),
ec.SECP384R1: hashes.SHA384(),
ec.SECP521R1: hashes.SHA512(),
}


def jws_encode(
payload: bytes,
key: PrivateKeyTypes | hmac.HMAC,
headers: Mapping[str, Any],
) -> bytes:
if isinstance(key, ec.EllipticCurvePrivateKey):
alg = {ec.SECP256R1.name: 'ES256', ec.SECP384R1.name: 'ES384', ec.SECP521R1.name: 'ES512'}[key.curve.name]
alg = _EC_ALGS[type(key.curve)]
elif isinstance(key, rsa.RSAPrivateKey):
alg = 'RS256'
elif isinstance(key, ed25519.Ed25519PrivateKey):
Expand All @@ -43,11 +50,7 @@ def jws_encode(
if isinstance(key, rsa.RSAPrivateKey):
signature = key.sign(signing_input, PKCS1v15(), hashes.SHA256())
elif isinstance(key, ec.EllipticCurvePrivateKey):
hash_alg = {
ec.SECP256R1.name: hashes.SHA256(),
ec.SECP384R1.name: hashes.SHA384(),
ec.SECP521R1.name: hashes.SHA512(),
}[key.curve.name]
hash_alg = _EC_HASHES[type(key.curve)]
signature = key.sign(signing_input, ec.ECDSA(hash_alg))
signature = _der_to_raw_signature(signature, key.curve)
elif isinstance(key, ed25519.Ed25519PrivateKey):
Expand All @@ -56,7 +59,6 @@ def jws_encode(
key = key.copy()
key.update(signing_input)
signature = key.finalize()

else:
assert_never(key)

Expand Down
Loading

0 comments on commit d257d7c

Please sign in to comment.