Skip to content

Commit

Permalink
Add asyncio backend. Add integration testing
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkusSintonen committed Jun 17, 2024
1 parent da86ca4 commit 81b6b1f
Show file tree
Hide file tree
Showing 13 changed files with 627 additions and 51 deletions.
1 change: 1 addition & 0 deletions docs/exceptions.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The following exceptions may be raised when sending a request:
* `httpcore.WriteTimeout`
* `httpcore.NetworkError`
* `httpcore.ConnectError`
* `httpcore.BrokenSocketError`
* `httpcore.ReadError`
* `httpcore.WriteError`
* `httpcore.ProtocolError`
Expand Down
6 changes: 6 additions & 0 deletions httpcore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
AsyncHTTPProxy,
AsyncSOCKSProxy,
)
from ._backends.asyncio import AsyncioBackend
from ._backends.auto import AutoBackend
from ._backends.base import (
SOCKET_OPTION,
AsyncNetworkBackend,
Expand All @@ -18,6 +20,7 @@
from ._backends.mock import AsyncMockBackend, AsyncMockStream, MockBackend, MockStream
from ._backends.sync import SyncBackend
from ._exceptions import (
BrokenSocketError,
ConnectError,
ConnectionNotAvailable,
ConnectTimeout,
Expand Down Expand Up @@ -97,6 +100,8 @@ def __init__(self, *args, **kwargs): # type: ignore
"SOCKSProxy",
# network backends, implementations
"SyncBackend",
"AutoBackend",
"AsyncioBackend",
"AnyIOBackend",
"TrioBackend",
# network backends, mock implementations
Expand Down Expand Up @@ -126,6 +131,7 @@ def __init__(self, *args, **kwargs): # type: ignore
"WriteTimeout",
"NetworkError",
"ConnectError",
"BrokenSocketError",
"ReadError",
"WriteError",
]
Expand Down
34 changes: 22 additions & 12 deletions httpcore/_backends/anyio.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import anyio

from .._exceptions import (
BrokenSocketError,
ConnectError,
ConnectTimeout,
ReadError,
Expand Down Expand Up @@ -82,6 +83,9 @@ async def start_tls(
return AnyIOStream(ssl_stream)

def get_extra_info(self, info: str) -> typing.Any:
if info == "is_readable":
sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
return is_socket_readable(sock)
if info == "ssl_object":
return self._stream.extra(anyio.streams.tls.TLSAttribute.ssl_object, None)
if info == "client_addr":
Expand All @@ -90,9 +94,6 @@ def get_extra_info(self, info: str) -> typing.Any:
return self._stream.extra(anyio.abc.SocketAttribute.remote_address, None)
if info == "socket":
return self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
if info == "is_readable":
sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
return is_socket_readable(sock)
return None


Expand All @@ -105,8 +106,6 @@ async def connect_tcp(
local_address: typing.Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream:
if socket_options is None:
socket_options = [] # pragma: no cover
exc_map = {
TimeoutError: ConnectTimeout,
OSError: ConnectError,
Expand All @@ -120,18 +119,15 @@ async def connect_tcp(
local_host=local_address,
)
# By default TCP sockets opened in `asyncio` include TCP_NODELAY.
for option in socket_options:
stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
self._set_socket_options(stream, socket_options)
return AnyIOStream(stream)

async def connect_unix_socket(
self,
path: str,
timeout: typing.Optional[float] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream: # pragma: nocover
if socket_options is None:
socket_options = []
) -> AsyncNetworkStream:
exc_map = {
TimeoutError: ConnectTimeout,
OSError: ConnectError,
Expand All @@ -140,9 +136,23 @@ async def connect_unix_socket(
with map_exceptions(exc_map):
with anyio.fail_after(timeout):
stream: anyio.abc.ByteStream = await anyio.connect_unix(path)
for option in socket_options:
stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
self._set_socket_options(stream, socket_options)
return AnyIOStream(stream)

async def sleep(self, seconds: float) -> None:
await anyio.sleep(seconds) # pragma: nocover

def _set_socket_options(
self,
stream: anyio.abc.ByteStream,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> None:
if not socket_options:
return

sock = stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
if sock is None:
raise BrokenSocketError() # pragma: nocover

for option in socket_options:
sock.setsockopt(*option)
227 changes: 227 additions & 0 deletions httpcore/_backends/asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
import asyncio
import socket
import ssl
from typing import Any, Dict, Iterable, Optional, Type

from .._exceptions import (
BrokenSocketError,
ConnectError,
ConnectTimeout,
ReadError,
ReadTimeout,
WriteError,
WriteTimeout,
map_exceptions,
)
from .._utils import is_socket_readable
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream


class AsyncIOStream(AsyncNetworkStream):
def __init__(
self, stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter
):
self._stream_reader = stream_reader
self._stream_writer = stream_writer
self._read_lock = asyncio.Lock()
self._write_lock = asyncio.Lock()
self._inner: Optional[AsyncIOStream] = None

async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: Optional[str] = None,
timeout: Optional[float] = None,
) -> AsyncNetworkStream:
loop = asyncio.get_event_loop()

stream_reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(stream_reader)

exc_map: Dict[Type[Exception], Type[Exception]] = {
asyncio.TimeoutError: ConnectTimeout,
OSError: ConnectError,
}
with map_exceptions(exc_map):
transport_ssl = await asyncio.wait_for(
loop.start_tls(
self._stream_writer.transport,
protocol,
ssl_context,
server_hostname=server_hostname,
),
timeout,
)
if transport_ssl is None:
# https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.start_tls
raise ConnectError("Transport closed while starting TLS") # pragma: nocover

# Initialize the protocol, so it is made aware of being tied to
# a TLS connection.
# See: https://github.com/encode/httpx/issues/859
protocol.connection_made(transport_ssl)

stream_writer = asyncio.StreamWriter(
transport=transport_ssl, protocol=protocol, reader=stream_reader, loop=loop
)

ssl_stream = AsyncIOStream(stream_reader, stream_writer)
# When we return a new SocketStream with new StreamReader/StreamWriter instances
# we need to keep references to the old StreamReader/StreamWriter so that they
# are not garbage collected and closed while we're still using them.
ssl_stream._inner = self
return ssl_stream

async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
exc_map: Dict[Type[Exception], Type[Exception]] = {
asyncio.TimeoutError: ReadTimeout,
OSError: ReadError,
}
async with self._read_lock:
with map_exceptions(exc_map):
try:
return await asyncio.wait_for(
self._stream_reader.read(max_bytes), timeout
)
except AttributeError as exc: # pragma: nocover
if "resume_reading" in str(exc):
# Python's asyncio has a bug that can occur when a
# connection has been closed, while it is paused.
# See: https://github.com/encode/httpx/issues/1213
#
# Returning an empty byte-string to indicate connection
# close will eventually raise an httpcore.RemoteProtocolError
# to the user when this goes through our HTTP parsing layer.
return b""
raise

async def write(self, data: bytes, timeout: Optional[float] = None) -> None:
if not data:
return

exc_map: Dict[Type[Exception], Type[Exception]] = {
asyncio.TimeoutError: WriteTimeout,
OSError: WriteError,
}
async with self._write_lock:
with map_exceptions(exc_map):
self._stream_writer.write(data)
return await asyncio.wait_for(self._stream_writer.drain(), timeout)

async def aclose(self) -> None:
# SSL connections should issue the close and then abort, rather than
# waiting for the remote end of the connection to signal the EOF.
#
# See:
#
# * https://bugs.python.org/issue39758
# * https://github.com/python-trio/trio/blob/
# 31e2ae866ad549f1927d45ce073d4f0ea9f12419/trio/_ssl.py#L779-L829
#
# And related issues caused if we simply omit the 'wait_closed' call,
# without first using `.abort()`
#
# * https://github.com/encode/httpx/issues/825
# * https://github.com/encode/httpx/issues/914
is_ssl = self._sslobj is not None

async with self._write_lock:
try:
self._stream_writer.close()
if is_ssl:
# Give the connection a chance to write any data in the buffer,
# and then forcibly tear down the SSL connection.
await asyncio.sleep(0)
self._stream_writer.transport.abort()
await self._stream_writer.wait_closed()
except OSError: # pragma: nocover
pass

def get_extra_info(self, info: str) -> Any:
if info == "is_readable":
return is_socket_readable(self._raw_socket)
if info == "ssl_object":
return self._sslobj
if info in ("client_addr", "server_addr"):
sock = self._raw_socket
if sock is None:
raise BrokenSocketError() # pragma: nocover
return sock.getsockname() if info == "client_addr" else sock.getpeername()
if info == "socket":
return self._raw_socket
return None

@property
def _raw_socket(self) -> Optional[socket.socket]:
transport = self._stream_writer.transport
sock: Optional[socket.socket] = transport.get_extra_info("socket")
return sock

@property
def _sslobj(self) -> Optional[ssl.SSLObject]:
transport = self._stream_writer.transport
sslobj: Optional[ssl.SSLObject] = transport.get_extra_info("ssl_object")
return sslobj


class AsyncioBackend(AsyncNetworkBackend):
async def connect_tcp(
self,
host: str,
port: int,
timeout: Optional[float] = None,
local_address: Optional[str] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream:
local_addr = None if local_address is None else (local_address, 0)

exc_map: Dict[Type[Exception], Type[Exception]] = {
asyncio.TimeoutError: ConnectTimeout,
OSError: ConnectError,
}
with map_exceptions(exc_map):
stream_reader, stream_writer = await asyncio.wait_for(
asyncio.open_connection(host, port, local_addr=local_addr),
timeout,
)
self._set_socket_options(stream_writer, socket_options)
return AsyncIOStream(
stream_reader=stream_reader, stream_writer=stream_writer
)

async def connect_unix_socket(
self,
path: str,
timeout: Optional[float] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream:
exc_map: Dict[Type[Exception], Type[Exception]] = {
asyncio.TimeoutError: ConnectTimeout,
OSError: ConnectError,
}
with map_exceptions(exc_map):
stream_reader, stream_writer = await asyncio.wait_for(
asyncio.open_unix_connection(path), timeout
)
self._set_socket_options(stream_writer, socket_options)
return AsyncIOStream(
stream_reader=stream_reader, stream_writer=stream_writer
)

async def sleep(self, seconds: float) -> None:
await asyncio.sleep(seconds) # pragma: nocover

def _set_socket_options(
self,
stream: asyncio.StreamWriter,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
) -> None:
if not socket_options:
return

sock = stream.get_extra_info("socket")
if sock is None:
raise BrokenSocketError() # pragma: nocover

for option in socket_options:
sock.setsockopt(*option)
Loading

0 comments on commit 81b6b1f

Please sign in to comment.