diff --git a/httpcore/_async/http11.py b/httpcore/_async/http11.py index e6d6d709..0ffb2170 100644 --- a/httpcore/_async/http11.py +++ b/httpcore/_async/http11.py @@ -17,10 +17,10 @@ WriteError, map_exceptions, ) -from .._models import Origin, Request, Response -from .._synchronization import AsyncLock, AsyncShieldCancellation +from .._models import Origin, Request +from .._synchronization import AsyncSemaphore from .._trace import Trace -from .interfaces import AsyncConnectionInterface +from .interfaces import AsyncConnectionInterface, StartResponse logger = logging.getLogger("httpcore.http11") @@ -55,7 +55,7 @@ def __init__( self._keepalive_expiry: float | None = keepalive_expiry self._expire_at: float | None = None self._state = HTTPConnectionState.NEW - self._state_lock = AsyncLock() + self._request_lock = AsyncSemaphore(bound=1) self._request_count = 0 self._h11_state = h11.Connection( our_role=h11.CLIENT, @@ -63,13 +63,25 @@ def __init__( ) async def handle_async_request(self, request: Request) -> Response: + iterator = self.iterate_response(request) + resp = await anext(iterator) + return Response( + status=resp.status, + headers=resp.headers, + content=iterator, + extensions=resp.extensions, + ) + + async def iterate_response( + self, request: Request + ) -> typing.AsyncIterator[StartResponse | bytes]: if not self.can_handle_request(request.url.origin): raise RuntimeError( f"Attempted to send request to {request.url.origin} on connection " f"to {self._origin}" ) - async with self._state_lock: + async with self._request_lock: if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE): self._request_count += 1 self._state = HTTPConnectionState.ACTIVE @@ -77,63 +89,69 @@ async def handle_async_request(self, request: Request) -> Response: else: raise ConnectionNotAvailable() - try: - kwargs = {"request": request} try: + kwargs = {"request": request} + try: + async with Trace( + "send_request_headers", logger, request, kwargs + ) as trace: + await self._send_request_headers(**kwargs) + async with Trace( + "send_request_body", logger, request, kwargs + ) as trace: + await self._send_request_body(**kwargs) + except WriteError: + # If we get a write error while we're writing the request, + # then we supress this error and move on to attempting to + # read the response. Servers can sometimes close the request + # pre-emptively and then respond with a well formed HTTP + # error response. + pass + async with Trace( - "send_request_headers", logger, request, kwargs + "receive_response_headers", logger, request, kwargs ) as trace: - await self._send_request_headers(**kwargs) - async with Trace("send_request_body", logger, request, kwargs) as trace: - await self._send_request_body(**kwargs) - except WriteError: - # If we get a write error while we're writing the request, - # then we supress this error and move on to attempting to - # read the response. Servers can sometimes close the request - # pre-emptively and then respond with a well formed HTTP - # error response. - pass - - async with Trace( - "receive_response_headers", logger, request, kwargs - ) as trace: - ( - http_version, - status, - reason_phrase, - headers, - trailing_data, - ) = await self._receive_response_headers(**kwargs) - trace.return_value = ( - http_version, - status, - reason_phrase, - headers, + ( + http_version, + status, + reason_phrase, + headers, + trailing_data, + ) = await self._receive_response_headers(**kwargs) + trace.return_value = ( + http_version, + status, + reason_phrase, + headers, + ) + + network_stream = self._network_stream + + # CONNECT or Upgrade request + if (status == 101) or ( + (request.method == b"CONNECT") and (200 <= status < 300) + ): + network_stream = AsyncHTTP11UpgradeStream( + network_stream, trailing_data + ) + + yield Response( + status=status, + headers=headers, + extensions={ + "http_version": http_version, + "reason_phrase": reason_phrase, + "network_stream": network_stream, + }, ) - - network_stream = self._network_stream - - # CONNECT or Upgrade request - if (status == 101) or ( - (request.method == b"CONNECT") and (200 <= status < 300) - ): - network_stream = AsyncHTTP11UpgradeStream(network_stream, trailing_data) - - return Response( - status=status, - headers=headers, - content=HTTP11ConnectionByteStream(self, request), - extensions={ - "http_version": http_version, - "reason_phrase": reason_phrase, - "network_stream": network_stream, - }, - ) - except BaseException as exc: - with AsyncShieldCancellation(): + async with Trace("receive_response_body", logger, request, kwargs): + async for chunk in self._receive_response_body(**kwargs): + yield chunk + finally: + await self._response_closed() async with Trace("response_closed", logger, request) as trace: - await self._response_closed() - raise exc + if self.is_closed(): + await self.aclose() # Sending the request... @@ -236,18 +254,17 @@ async def _receive_event( return event # type: ignore[return-value] async def _response_closed(self) -> None: - async with self._state_lock: - if ( - self._h11_state.our_state is h11.DONE - and self._h11_state.their_state is h11.DONE - ): - self._state = HTTPConnectionState.IDLE - self._h11_state.start_next_cycle() - if self._keepalive_expiry is not None: - now = time.monotonic() - self._expire_at = now + self._keepalive_expiry - else: - await self.aclose() + if ( + self._h11_state.our_state is h11.DONE + and self._h11_state.their_state is h11.DONE + ): + self._state = HTTPConnectionState.IDLE + self._h11_state.start_next_cycle() + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry + else: + self._state = HTTPConnectionState.CLOSED # Once the connection is no longer required... @@ -321,33 +338,6 @@ async def __aexit__( await self.aclose() -class HTTP11ConnectionByteStream: - def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None: - self._connection = connection - self._request = request - self._closed = False - - async def __aiter__(self) -> typing.AsyncIterator[bytes]: - kwargs = {"request": self._request} - try: - async with Trace("receive_response_body", logger, self._request, kwargs): - async for chunk in self._connection._receive_response_body(**kwargs): - yield chunk - except BaseException as exc: - # If we get an exception while streaming the response, - # we want to close the response (and possibly the connection) - # before raising that exception. - with AsyncShieldCancellation(): - await self.aclose() - raise exc - - async def aclose(self) -> None: - if not self._closed: - self._closed = True - async with Trace("response_closed", logger, self._request): - await self._connection._response_closed() - - class AsyncHTTP11UpgradeStream(AsyncNetworkStream): def __init__(self, stream: AsyncNetworkStream, leading_data: bytes) -> None: self._stream = stream