From 81b6b1f6de93d27ec459f1690dbb2bdbd79370c5 Mon Sep 17 00:00:00 2001 From: Markus Sintonen Date: Sun, 16 Jun 2024 21:26:58 +0300 Subject: [PATCH 1/7] Add asyncio backend. Add integration testing --- docs/exceptions.md | 1 + httpcore/__init__.py | 6 + httpcore/_backends/anyio.py | 34 +++-- httpcore/_backends/asyncio.py | 227 +++++++++++++++++++++++++++++++ httpcore/_backends/auto.py | 39 ++++-- httpcore/_backends/trio.py | 26 ++-- httpcore/_exceptions.py | 4 + httpcore/_synchronization.py | 30 ++-- tests/_async/test_integration.py | 84 ++++++++++++ tests/_sync/test_integration.py | 84 ++++++++++++ tests/benchmark/client.py | 4 +- tests/conftest.py | 90 ++++++++++++ tests/test_auto_backend.py | 49 +++++++ 13 files changed, 627 insertions(+), 51 deletions(-) create mode 100644 httpcore/_backends/asyncio.py create mode 100644 tests/conftest.py create mode 100644 tests/test_auto_backend.py diff --git a/docs/exceptions.md b/docs/exceptions.md index 63ef3f28e..cff979945 100644 --- a/docs/exceptions.md +++ b/docs/exceptions.md @@ -9,6 +9,7 @@ The following exceptions may be raised when sending a request: * `httpcore.WriteTimeout` * `httpcore.NetworkError` * `httpcore.ConnectError` + * `httpcore.BrokenSocketError` * `httpcore.ReadError` * `httpcore.WriteError` * `httpcore.ProtocolError` diff --git a/httpcore/__init__.py b/httpcore/__init__.py index 014213bae..94f4dc0ba 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -8,6 +8,8 @@ AsyncHTTPProxy, AsyncSOCKSProxy, ) +from ._backends.asyncio import AsyncioBackend +from ._backends.auto import AutoBackend from ._backends.base import ( SOCKET_OPTION, AsyncNetworkBackend, @@ -18,6 +20,7 @@ from ._backends.mock import AsyncMockBackend, AsyncMockStream, MockBackend, MockStream from ._backends.sync import SyncBackend from ._exceptions import ( + BrokenSocketError, ConnectError, ConnectionNotAvailable, ConnectTimeout, @@ -97,6 +100,8 @@ def __init__(self, *args, **kwargs): # type: ignore "SOCKSProxy", # network backends, implementations "SyncBackend", + "AutoBackend", + "AsyncioBackend", "AnyIOBackend", "TrioBackend", # network backends, mock implementations @@ -126,6 +131,7 @@ def __init__(self, *args, **kwargs): # type: ignore "WriteTimeout", "NetworkError", "ConnectError", + "BrokenSocketError", "ReadError", "WriteError", ] diff --git a/httpcore/_backends/anyio.py b/httpcore/_backends/anyio.py index 9f4fdf86c..995a3d946 100644 --- a/httpcore/_backends/anyio.py +++ b/httpcore/_backends/anyio.py @@ -4,6 +4,7 @@ import anyio from .._exceptions import ( + BrokenSocketError, ConnectError, ConnectTimeout, ReadError, @@ -82,6 +83,9 @@ async def start_tls( return AnyIOStream(ssl_stream) def get_extra_info(self, info: str) -> typing.Any: + if info == "is_readable": + sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None) + return is_socket_readable(sock) if info == "ssl_object": return self._stream.extra(anyio.streams.tls.TLSAttribute.ssl_object, None) if info == "client_addr": @@ -90,9 +94,6 @@ def get_extra_info(self, info: str) -> typing.Any: return self._stream.extra(anyio.abc.SocketAttribute.remote_address, None) if info == "socket": return self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None) - if info == "is_readable": - sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None) - return is_socket_readable(sock) return None @@ -105,8 +106,6 @@ async def connect_tcp( local_address: typing.Optional[str] = None, socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, ) -> AsyncNetworkStream: - if socket_options is None: - socket_options = [] # pragma: no cover exc_map = { TimeoutError: ConnectTimeout, OSError: ConnectError, @@ -120,8 +119,7 @@ async def connect_tcp( local_host=local_address, ) # By default TCP sockets opened in `asyncio` include TCP_NODELAY. - for option in socket_options: - stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + self._set_socket_options(stream, socket_options) return AnyIOStream(stream) async def connect_unix_socket( @@ -129,9 +127,7 @@ async def connect_unix_socket( path: str, timeout: typing.Optional[float] = None, socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, - ) -> AsyncNetworkStream: # pragma: nocover - if socket_options is None: - socket_options = [] + ) -> AsyncNetworkStream: exc_map = { TimeoutError: ConnectTimeout, OSError: ConnectError, @@ -140,9 +136,23 @@ async def connect_unix_socket( with map_exceptions(exc_map): with anyio.fail_after(timeout): stream: anyio.abc.ByteStream = await anyio.connect_unix(path) - for option in socket_options: - stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + self._set_socket_options(stream, socket_options) return AnyIOStream(stream) async def sleep(self, seconds: float) -> None: await anyio.sleep(seconds) # pragma: nocover + + def _set_socket_options( + self, + stream: anyio.abc.ByteStream, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> None: + if not socket_options: + return + + sock = stream.extra(anyio.abc.SocketAttribute.raw_socket, None) + if sock is None: + raise BrokenSocketError() # pragma: nocover + + for option in socket_options: + sock.setsockopt(*option) diff --git a/httpcore/_backends/asyncio.py b/httpcore/_backends/asyncio.py new file mode 100644 index 000000000..312fb648a --- /dev/null +++ b/httpcore/_backends/asyncio.py @@ -0,0 +1,227 @@ +import asyncio +import socket +import ssl +from typing import Any, Dict, Iterable, Optional, Type + +from .._exceptions import ( + BrokenSocketError, + ConnectError, + ConnectTimeout, + ReadError, + ReadTimeout, + WriteError, + WriteTimeout, + map_exceptions, +) +from .._utils import is_socket_readable +from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream + + +class AsyncIOStream(AsyncNetworkStream): + def __init__( + self, stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter + ): + self._stream_reader = stream_reader + self._stream_writer = stream_writer + self._read_lock = asyncio.Lock() + self._write_lock = asyncio.Lock() + self._inner: Optional[AsyncIOStream] = None + + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: Optional[str] = None, + timeout: Optional[float] = None, + ) -> AsyncNetworkStream: + loop = asyncio.get_event_loop() + + stream_reader = asyncio.StreamReader() + protocol = asyncio.StreamReaderProtocol(stream_reader) + + exc_map: Dict[Type[Exception], Type[Exception]] = { + asyncio.TimeoutError: ConnectTimeout, + OSError: ConnectError, + } + with map_exceptions(exc_map): + transport_ssl = await asyncio.wait_for( + loop.start_tls( + self._stream_writer.transport, + protocol, + ssl_context, + server_hostname=server_hostname, + ), + timeout, + ) + if transport_ssl is None: + # https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.start_tls + raise ConnectError("Transport closed while starting TLS") # pragma: nocover + + # Initialize the protocol, so it is made aware of being tied to + # a TLS connection. + # See: https://github.com/encode/httpx/issues/859 + protocol.connection_made(transport_ssl) + + stream_writer = asyncio.StreamWriter( + transport=transport_ssl, protocol=protocol, reader=stream_reader, loop=loop + ) + + ssl_stream = AsyncIOStream(stream_reader, stream_writer) + # When we return a new SocketStream with new StreamReader/StreamWriter instances + # we need to keep references to the old StreamReader/StreamWriter so that they + # are not garbage collected and closed while we're still using them. + ssl_stream._inner = self + return ssl_stream + + async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes: + exc_map: Dict[Type[Exception], Type[Exception]] = { + asyncio.TimeoutError: ReadTimeout, + OSError: ReadError, + } + async with self._read_lock: + with map_exceptions(exc_map): + try: + return await asyncio.wait_for( + self._stream_reader.read(max_bytes), timeout + ) + except AttributeError as exc: # pragma: nocover + if "resume_reading" in str(exc): + # Python's asyncio has a bug that can occur when a + # connection has been closed, while it is paused. + # See: https://github.com/encode/httpx/issues/1213 + # + # Returning an empty byte-string to indicate connection + # close will eventually raise an httpcore.RemoteProtocolError + # to the user when this goes through our HTTP parsing layer. + return b"" + raise + + async def write(self, data: bytes, timeout: Optional[float] = None) -> None: + if not data: + return + + exc_map: Dict[Type[Exception], Type[Exception]] = { + asyncio.TimeoutError: WriteTimeout, + OSError: WriteError, + } + async with self._write_lock: + with map_exceptions(exc_map): + self._stream_writer.write(data) + return await asyncio.wait_for(self._stream_writer.drain(), timeout) + + async def aclose(self) -> None: + # SSL connections should issue the close and then abort, rather than + # waiting for the remote end of the connection to signal the EOF. + # + # See: + # + # * https://bugs.python.org/issue39758 + # * https://github.com/python-trio/trio/blob/ + # 31e2ae866ad549f1927d45ce073d4f0ea9f12419/trio/_ssl.py#L779-L829 + # + # And related issues caused if we simply omit the 'wait_closed' call, + # without first using `.abort()` + # + # * https://github.com/encode/httpx/issues/825 + # * https://github.com/encode/httpx/issues/914 + is_ssl = self._sslobj is not None + + async with self._write_lock: + try: + self._stream_writer.close() + if is_ssl: + # Give the connection a chance to write any data in the buffer, + # and then forcibly tear down the SSL connection. + await asyncio.sleep(0) + self._stream_writer.transport.abort() + await self._stream_writer.wait_closed() + except OSError: # pragma: nocover + pass + + def get_extra_info(self, info: str) -> Any: + if info == "is_readable": + return is_socket_readable(self._raw_socket) + if info == "ssl_object": + return self._sslobj + if info in ("client_addr", "server_addr"): + sock = self._raw_socket + if sock is None: + raise BrokenSocketError() # pragma: nocover + return sock.getsockname() if info == "client_addr" else sock.getpeername() + if info == "socket": + return self._raw_socket + return None + + @property + def _raw_socket(self) -> Optional[socket.socket]: + transport = self._stream_writer.transport + sock: Optional[socket.socket] = transport.get_extra_info("socket") + return sock + + @property + def _sslobj(self) -> Optional[ssl.SSLObject]: + transport = self._stream_writer.transport + sslobj: Optional[ssl.SSLObject] = transport.get_extra_info("ssl_object") + return sslobj + + +class AsyncioBackend(AsyncNetworkBackend): + async def connect_tcp( + self, + host: str, + port: int, + timeout: Optional[float] = None, + local_address: Optional[str] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: + local_addr = None if local_address is None else (local_address, 0) + + exc_map: Dict[Type[Exception], Type[Exception]] = { + asyncio.TimeoutError: ConnectTimeout, + OSError: ConnectError, + } + with map_exceptions(exc_map): + stream_reader, stream_writer = await asyncio.wait_for( + asyncio.open_connection(host, port, local_addr=local_addr), + timeout, + ) + self._set_socket_options(stream_writer, socket_options) + return AsyncIOStream( + stream_reader=stream_reader, stream_writer=stream_writer + ) + + async def connect_unix_socket( + self, + path: str, + timeout: Optional[float] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: + exc_map: Dict[Type[Exception], Type[Exception]] = { + asyncio.TimeoutError: ConnectTimeout, + OSError: ConnectError, + } + with map_exceptions(exc_map): + stream_reader, stream_writer = await asyncio.wait_for( + asyncio.open_unix_connection(path), timeout + ) + self._set_socket_options(stream_writer, socket_options) + return AsyncIOStream( + stream_reader=stream_reader, stream_writer=stream_writer + ) + + async def sleep(self, seconds: float) -> None: + await asyncio.sleep(seconds) # pragma: nocover + + def _set_socket_options( + self, + stream: asyncio.StreamWriter, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + ) -> None: + if not socket_options: + return + + sock = stream.get_extra_info("socket") + if sock is None: + raise BrokenSocketError() # pragma: nocover + + for option in socket_options: + sock.setsockopt(*option) diff --git a/httpcore/_backends/auto.py b/httpcore/_backends/auto.py index 3ac05f4da..6e266381a 100644 --- a/httpcore/_backends/auto.py +++ b/httpcore/_backends/auto.py @@ -1,22 +1,41 @@ import typing -from typing import Optional +from importlib.util import find_spec +from typing import Optional, Type -from .._synchronization import current_async_library +from .._synchronization import current_async_backend from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream +HAS_ANYIO = find_spec("anyio") is not None + class AutoBackend(AsyncNetworkBackend): + @staticmethod + def set_default_backend(backend_class: Optional[Type[AsyncNetworkBackend]]) -> None: + setattr(AutoBackend, "_default_backend_class", backend_class) + async def _init_backend(self) -> None: - if not (hasattr(self, "_backend")): - backend = current_async_library() - if backend == "trio": - from .trio import TrioBackend + if hasattr(self, "_backend"): + return + + default_backend_class: Optional[Type[AsyncNetworkBackend]] = getattr( + AutoBackend, "_default_backend_class", None + ) + if default_backend_class is not None: + self._backend = default_backend_class() + return + + if current_async_backend() == "trio": + from .trio import TrioBackend + + self._backend = TrioBackend() + elif HAS_ANYIO: + from .anyio import AnyIOBackend - self._backend: AsyncNetworkBackend = TrioBackend() - else: - from .anyio import AnyIOBackend + self._backend = AnyIOBackend() + else: + from .asyncio import AsyncioBackend - self._backend = AnyIOBackend() + self._backend = AsyncioBackend() async def connect_tcp( self, diff --git a/httpcore/_backends/trio.py b/httpcore/_backends/trio.py index b1626d28e..26320e61f 100644 --- a/httpcore/_backends/trio.py +++ b/httpcore/_backends/trio.py @@ -117,10 +117,6 @@ async def connect_tcp( local_address: typing.Optional[str] = None, socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, ) -> AsyncNetworkStream: - # By default for TCP sockets, trio enables TCP_NODELAY. - # https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream - if socket_options is None: - socket_options = [] # pragma: no cover timeout_or_inf = float("inf") if timeout is None else timeout exc_map: ExceptionMapping = { trio.TooSlowError: ConnectTimeout, @@ -132,8 +128,7 @@ async def connect_tcp( stream: trio.abc.Stream = await trio.open_tcp_stream( host=host, port=port, local_address=local_address ) - for option in socket_options: - stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + self._set_socket_options(stream, socket_options) return TrioStream(stream) async def connect_unix_socket( @@ -141,9 +136,7 @@ async def connect_unix_socket( path: str, timeout: typing.Optional[float] = None, socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, - ) -> AsyncNetworkStream: # pragma: nocover - if socket_options is None: - socket_options = [] + ) -> AsyncNetworkStream: timeout_or_inf = float("inf") if timeout is None else timeout exc_map: ExceptionMapping = { trio.TooSlowError: ConnectTimeout, @@ -153,9 +146,20 @@ async def connect_unix_socket( with map_exceptions(exc_map): with trio.fail_after(timeout_or_inf): stream: trio.abc.Stream = await trio.open_unix_socket(path) - for option in socket_options: - stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + self._set_socket_options(stream, socket_options) return TrioStream(stream) async def sleep(self, seconds: float) -> None: await trio.sleep(seconds) # pragma: nocover + + def _set_socket_options( + self, + stream: trio.abc.Stream, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> None: + # By default for TCP sockets, trio enables TCP_NODELAY. + # https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream + if not socket_options: + return + for option in socket_options: + stream.setsockopt(*option) # type: ignore[attr-defined] diff --git a/httpcore/_exceptions.py b/httpcore/_exceptions.py index 81e7fc61d..dc62459b1 100644 --- a/httpcore/_exceptions.py +++ b/httpcore/_exceptions.py @@ -73,6 +73,10 @@ class ConnectError(NetworkError): pass +class BrokenSocketError(ConnectError): + pass + + class ReadError(NetworkError): pass diff --git a/httpcore/_synchronization.py b/httpcore/_synchronization.py index 9619a3983..fe4cbecad 100644 --- a/httpcore/_synchronization.py +++ b/httpcore/_synchronization.py @@ -1,6 +1,6 @@ import threading from types import TracebackType -from typing import Optional, Type +from typing import Literal, Optional, Type from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions @@ -18,30 +18,28 @@ anyio = None # type: ignore -def current_async_library() -> str: +AsyncBackend = Literal["asyncio", "trio"] + + +def current_async_backend() -> AsyncBackend: # Determine if we're running under trio or asyncio. # See https://sniffio.readthedocs.io/en/latest/ try: import sniffio except ImportError: # pragma: nocover - environment = "asyncio" + backend: AsyncBackend = "asyncio" else: - environment = sniffio.current_async_library() + backend = sniffio.current_async_library() # type: ignore[assignment] - if environment not in ("asyncio", "trio"): # pragma: nocover + if backend not in ("asyncio", "trio"): # pragma: nocover raise RuntimeError("Running under an unsupported async environment.") - if environment == "asyncio" and anyio is None: # pragma: nocover - raise RuntimeError( - "Running with asyncio requires installation of 'httpcore[asyncio]'." - ) - - if environment == "trio" and trio is None: # pragma: nocover + if backend == "trio" and trio is None: # pragma: nocover raise RuntimeError( "Running with trio requires installation of 'httpcore[trio]'." ) - return environment + return backend class AsyncLock: @@ -60,7 +58,7 @@ def setup(self) -> None: Detect if we're running under 'asyncio' or 'trio' and create a lock with the correct implementation. """ - self._backend = current_async_library() + self._backend = current_async_backend() if self._backend == "trio": self._trio_lock = trio.Lock() elif self._backend == "asyncio": @@ -118,7 +116,7 @@ def setup(self) -> None: Detect if we're running under 'asyncio' or 'trio' and create a lock with the correct implementation. """ - self._backend = current_async_library() + self._backend = current_async_backend() if self._backend == "trio": self._trio_event = trio.Event() elif self._backend == "asyncio": @@ -160,7 +158,7 @@ def setup(self) -> None: Detect if we're running under 'asyncio' or 'trio' and create a semaphore with the correct implementation. """ - self._backend = current_async_library() + self._backend = current_async_backend() if self._backend == "trio": self._trio_semaphore = trio.Semaphore( initial_value=self._bound, max_value=self._bound @@ -199,7 +197,7 @@ def __init__(self) -> None: Detect if we're running under 'asyncio' or 'trio' and create a shielded scope with the correct implementation. """ - self._backend = current_async_library() + self._backend = current_async_backend() if self._backend == "trio": self._trio_shield = trio.CancelScope(shield=True) diff --git a/tests/_async/test_integration.py b/tests/_async/test_integration.py index 1970531d5..797933e4a 100644 --- a/tests/_async/test_integration.py +++ b/tests/_async/test_integration.py @@ -1,8 +1,13 @@ +import os +import socket import ssl +from tempfile import gettempdir import pytest +import uvicorn import httpcore +from tests.conftest import Server @pytest.mark.anyio @@ -49,3 +54,82 @@ async def test_extra_info(httpbin_secure): assert invalid is None stream.get_extra_info("is_readable") + + +@pytest.mark.anyio +@pytest.mark.parametrize("keep_alive_enabled", [True, False]) +async def test_socket_options( + server: Server, server_url: str, keep_alive_enabled: bool +) -> None: + socket_options = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, int(keep_alive_enabled))] + async with httpcore.AsyncConnectionPool(socket_options=socket_options) as pool: + response = await pool.request("GET", server_url) + assert response.status == 200 + + stream = response.extensions["network_stream"] + sock = stream.get_extra_info("socket") + opt = sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) + assert bool(opt) is keep_alive_enabled + + +@pytest.mark.anyio +async def test_socket_no_nagle(server: Server, server_url: str) -> None: + async with httpcore.AsyncConnectionPool() as pool: + response = await pool.request("GET", server_url) + assert response.status == 200 + + stream = response.extensions["network_stream"] + sock = stream.get_extra_info("socket") + opt = sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + assert bool(opt) is True + + +@pytest.mark.anyio +async def test_pool_recovers_from_connection_breakage( + server_config: uvicorn.Config, server_url: str +) -> None: + async with httpcore.AsyncConnectionPool( + max_connections=1, max_keepalive_connections=1, keepalive_expiry=10 + ) as pool: + with Server(server_config).run_in_thread(): + response = await pool.request("GET", server_url) + assert response.status == 200 + + assert len(pool.connections) == 1 + conn = pool.connections[0] + + stream = response.extensions["network_stream"] + assert stream.get_extra_info("is_readable") is False + + assert ( + stream.get_extra_info("is_readable") is True + ), "Should break by coming readable" + + with Server(server_config).run_in_thread(): + assert len(pool.connections) == 1 + assert pool.connections[0] is conn, "Should be the broken connection" + + response = await pool.request("GET", server_url) + assert response.status == 200 + + assert len(pool.connections) == 1 + assert pool.connections[0] is not conn, "Should be a new connection" + + +@pytest.mark.anyio +async def test_unix_domain_socket(server_port, server_config, server_url): + uds = f"{gettempdir()}/test_httpcore_app.sock" + if os.path.exists(uds): + os.remove(uds) # pragma: nocover + + uds_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + uds_sock.bind(uds) + + with Server(server_config).run_in_thread(sockets=[uds_sock]): + async with httpcore.AsyncConnectionPool(uds=uds) as pool: + response = await pool.request("GET", server_url) + assert response.status == 200 + finally: + uds_sock.close() + os.remove(uds) diff --git a/tests/_sync/test_integration.py b/tests/_sync/test_integration.py index e3327e696..d114f878c 100644 --- a/tests/_sync/test_integration.py +++ b/tests/_sync/test_integration.py @@ -1,8 +1,13 @@ +import os +import socket import ssl +from tempfile import gettempdir import pytest +import uvicorn import httpcore +from tests.conftest import Server @@ -49,3 +54,82 @@ def test_extra_info(httpbin_secure): assert invalid is None stream.get_extra_info("is_readable") + + + +@pytest.mark.parametrize("keep_alive_enabled", [True, False]) +def test_socket_options( + server: Server, server_url: str, keep_alive_enabled: bool +) -> None: + socket_options = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, int(keep_alive_enabled))] + with httpcore.ConnectionPool(socket_options=socket_options) as pool: + response = pool.request("GET", server_url) + assert response.status == 200 + + stream = response.extensions["network_stream"] + sock = stream.get_extra_info("socket") + opt = sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) + assert bool(opt) is keep_alive_enabled + + + +def test_socket_no_nagle(server: Server, server_url: str) -> None: + with httpcore.ConnectionPool() as pool: + response = pool.request("GET", server_url) + assert response.status == 200 + + stream = response.extensions["network_stream"] + sock = stream.get_extra_info("socket") + opt = sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + assert bool(opt) is True + + + +def test_pool_recovers_from_connection_breakage( + server_config: uvicorn.Config, server_url: str +) -> None: + with httpcore.ConnectionPool( + max_connections=1, max_keepalive_connections=1, keepalive_expiry=10 + ) as pool: + with Server(server_config).run_in_thread(): + response = pool.request("GET", server_url) + assert response.status == 200 + + assert len(pool.connections) == 1 + conn = pool.connections[0] + + stream = response.extensions["network_stream"] + assert stream.get_extra_info("is_readable") is False + + assert ( + stream.get_extra_info("is_readable") is True + ), "Should break by coming readable" + + with Server(server_config).run_in_thread(): + assert len(pool.connections) == 1 + assert pool.connections[0] is conn, "Should be the broken connection" + + response = pool.request("GET", server_url) + assert response.status == 200 + + assert len(pool.connections) == 1 + assert pool.connections[0] is not conn, "Should be a new connection" + + + +def test_unix_domain_socket(server_port, server_config, server_url): + uds = f"{gettempdir()}/test_httpcore_app.sock" + if os.path.exists(uds): + os.remove(uds) # pragma: nocover + + uds_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + uds_sock.bind(uds) + + with Server(server_config).run_in_thread(sockets=[uds_sock]): + with httpcore.ConnectionPool(uds=uds) as pool: + response = pool.request("GET", server_url) + assert response.status == 200 + finally: + uds_sock.close() + os.remove(uds) diff --git a/tests/benchmark/client.py b/tests/benchmark/client.py index d07802b01..1a673892b 100644 --- a/tests/benchmark/client.py +++ b/tests/benchmark/client.py @@ -1,5 +1,4 @@ import asyncio -import os import sys import time from concurrent.futures import ThreadPoolExecutor @@ -21,7 +20,8 @@ CONCURRENCY = 20 POOL_LIMIT = 100 PROFILE = False -os.environ["HTTPCORE_PREFER_ANYIO"] = "0" +httpcore.AutoBackend.set_default_backend(httpcore.AsyncioBackend) +# httpcore.AutoBackend.set_default_backend(httpcore.AnyIOBackend) def duration(start: float) -> int: diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..5aa2ed50f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,90 @@ +import socket +import time +from contextlib import contextmanager +from threading import Thread +from typing import Any, Awaitable, Callable, Generator, Iterator, List, Optional + +import pytest +import uvicorn + +from httpcore import AnyIOBackend, AsyncioBackend, AutoBackend + + +@pytest.fixture( + params=[ + pytest.param(("asyncio", {"httpcore_use_anyio": False}), id="asyncio"), + pytest.param(("asyncio", {"httpcore_use_anyio": True}), id="asyncio+anyio"), + pytest.param(("trio", {}), id="trio"), + ] +) +def anyio_backend(request, monkeypatch): + backend_name, options = request.param + options = {**options} + + if backend_name == "trio": + AutoBackend.set_default_backend(None) + else: + use_anyio = options.pop("httpcore_use_anyio", False) + AutoBackend.set_default_backend(AnyIOBackend if use_anyio else AsyncioBackend) + + return backend_name, options + + +class Server(uvicorn.Server): + @contextmanager + def run_in_thread( + self, sockets: Optional[List[socket.socket]] = None + ) -> Generator[None, None, None]: + thread = Thread(target=lambda: self.run(sockets)) + thread.start() + start_time = time.monotonic() + try: + while not self.started: + time.sleep(0.01) + if (time.monotonic() - start_time) > 5: + raise TimeoutError() # pragma: nocover + yield + finally: + self.should_exit = True + thread.join() + + +@pytest.fixture +def server_port() -> int: + return 1111 + + +@pytest.fixture +def server_url(server_port: int) -> str: + return f"http://127.0.0.1:{server_port}" + + +@pytest.fixture +def server_app() -> Callable[[Any, Any, Any], Awaitable[None]]: + async def app(scope, receive, send): + assert scope["type"] == "http" + assert not (await receive()).get("more_body", False) + + start = { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain"]], + } + body = {"type": "http.response.body", "body": b"Hello World"} + await send(start) + await send(body) + + return app + + +@pytest.fixture +def server_config( + server_port: int, server_app: Callable[[Any, Any, Any], Awaitable[None]] +) -> uvicorn.Config: + return uvicorn.Config(server_app, port=server_port, log_level="error") + + +@pytest.fixture +def server(server_config: uvicorn.Config) -> Iterator[None]: + with Server(server_config).run_in_thread(): + yield diff --git a/tests/test_auto_backend.py b/tests/test_auto_backend.py new file mode 100644 index 000000000..c532f17df --- /dev/null +++ b/tests/test_auto_backend.py @@ -0,0 +1,49 @@ +from typing import Generator, List + +import pytest + +import httpcore +from httpcore import ( + AnyIOBackend, + AsyncioBackend, + AsyncNetworkBackend, + AutoBackend, + TrioBackend, +) +from httpcore._synchronization import current_async_backend + + +@pytest.fixture(scope="session", autouse=True) +def check_tested_backends() -> Generator[List[AsyncNetworkBackend], None, None]: + # Ensure tests cover all supported backend variants + backends: List[AsyncNetworkBackend] = [] + yield backends + assert {b.__class__ for b in backends} == { + AsyncioBackend, + AnyIOBackend, + TrioBackend, + } + + +@pytest.mark.anyio +async def test_init_backend(check_tested_backends: List[AsyncNetworkBackend]) -> None: + auto = AutoBackend() + await auto._init_backend() + assert auto._backend is not None + check_tested_backends.append(auto._backend) + + +@pytest.mark.anyio +@pytest.mark.parametrize("has_anyio", [False, True]) +async def test_auto_backend_asyncio(monkeypatch, has_anyio): + if current_async_backend() == "trio": + return + + AutoBackend.set_default_backend(None) + + monkeypatch.setattr(httpcore._backends.auto, "HAS_ANYIO", has_anyio) + + auto = AutoBackend() + await auto._init_backend() + assert auto._backend is not None + assert isinstance(auto._backend, AnyIOBackend if has_anyio else AsyncioBackend) From 828c491731a37923b8bb9e72c6436aa5dff7db42 Mon Sep 17 00:00:00 2001 From: Markus Sintonen Date: Mon, 17 Jun 2024 15:58:34 +0300 Subject: [PATCH 2/7] Simplify patch set --- docs/exceptions.md | 1 - httpcore/__init__.py | 4 ---- httpcore/_backends/anyio.py | 34 +++++++++++------------------- httpcore/_backends/asyncio.py | 11 +++++----- httpcore/_backends/auto.py | 39 +++++++++-------------------------- httpcore/_backends/trio.py | 26 ++++++++++------------- httpcore/_exceptions.py | 4 ---- httpcore/_synchronization.py | 30 ++++++++++++++------------- pyproject.toml | 5 ++++- tests/benchmark/client.py | 4 ++-- tests/conftest.py | 22 ++++++++++++++------ tests/test_auto_backend.py | 30 ++++++++++----------------- 12 files changed, 88 insertions(+), 122 deletions(-) diff --git a/docs/exceptions.md b/docs/exceptions.md index cff979945..63ef3f28e 100644 --- a/docs/exceptions.md +++ b/docs/exceptions.md @@ -9,7 +9,6 @@ The following exceptions may be raised when sending a request: * `httpcore.WriteTimeout` * `httpcore.NetworkError` * `httpcore.ConnectError` - * `httpcore.BrokenSocketError` * `httpcore.ReadError` * `httpcore.WriteError` * `httpcore.ProtocolError` diff --git a/httpcore/__init__.py b/httpcore/__init__.py index 94f4dc0ba..ecd6c537a 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -9,7 +9,6 @@ AsyncSOCKSProxy, ) from ._backends.asyncio import AsyncioBackend -from ._backends.auto import AutoBackend from ._backends.base import ( SOCKET_OPTION, AsyncNetworkBackend, @@ -20,7 +19,6 @@ from ._backends.mock import AsyncMockBackend, AsyncMockStream, MockBackend, MockStream from ._backends.sync import SyncBackend from ._exceptions import ( - BrokenSocketError, ConnectError, ConnectionNotAvailable, ConnectTimeout, @@ -100,7 +98,6 @@ def __init__(self, *args, **kwargs): # type: ignore "SOCKSProxy", # network backends, implementations "SyncBackend", - "AutoBackend", "AsyncioBackend", "AnyIOBackend", "TrioBackend", @@ -131,7 +128,6 @@ def __init__(self, *args, **kwargs): # type: ignore "WriteTimeout", "NetworkError", "ConnectError", - "BrokenSocketError", "ReadError", "WriteError", ] diff --git a/httpcore/_backends/anyio.py b/httpcore/_backends/anyio.py index 995a3d946..9f4fdf86c 100644 --- a/httpcore/_backends/anyio.py +++ b/httpcore/_backends/anyio.py @@ -4,7 +4,6 @@ import anyio from .._exceptions import ( - BrokenSocketError, ConnectError, ConnectTimeout, ReadError, @@ -83,9 +82,6 @@ async def start_tls( return AnyIOStream(ssl_stream) def get_extra_info(self, info: str) -> typing.Any: - if info == "is_readable": - sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None) - return is_socket_readable(sock) if info == "ssl_object": return self._stream.extra(anyio.streams.tls.TLSAttribute.ssl_object, None) if info == "client_addr": @@ -94,6 +90,9 @@ def get_extra_info(self, info: str) -> typing.Any: return self._stream.extra(anyio.abc.SocketAttribute.remote_address, None) if info == "socket": return self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None) + if info == "is_readable": + sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None) + return is_socket_readable(sock) return None @@ -106,6 +105,8 @@ async def connect_tcp( local_address: typing.Optional[str] = None, socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, ) -> AsyncNetworkStream: + if socket_options is None: + socket_options = [] # pragma: no cover exc_map = { TimeoutError: ConnectTimeout, OSError: ConnectError, @@ -119,7 +120,8 @@ async def connect_tcp( local_host=local_address, ) # By default TCP sockets opened in `asyncio` include TCP_NODELAY. - self._set_socket_options(stream, socket_options) + for option in socket_options: + stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover return AnyIOStream(stream) async def connect_unix_socket( @@ -127,7 +129,9 @@ async def connect_unix_socket( path: str, timeout: typing.Optional[float] = None, socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, - ) -> AsyncNetworkStream: + ) -> AsyncNetworkStream: # pragma: nocover + if socket_options is None: + socket_options = [] exc_map = { TimeoutError: ConnectTimeout, OSError: ConnectError, @@ -136,23 +140,9 @@ async def connect_unix_socket( with map_exceptions(exc_map): with anyio.fail_after(timeout): stream: anyio.abc.ByteStream = await anyio.connect_unix(path) - self._set_socket_options(stream, socket_options) + for option in socket_options: + stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover return AnyIOStream(stream) async def sleep(self, seconds: float) -> None: await anyio.sleep(seconds) # pragma: nocover - - def _set_socket_options( - self, - stream: anyio.abc.ByteStream, - socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, - ) -> None: - if not socket_options: - return - - sock = stream.extra(anyio.abc.SocketAttribute.raw_socket, None) - if sock is None: - raise BrokenSocketError() # pragma: nocover - - for option in socket_options: - sock.setsockopt(*option) diff --git a/httpcore/_backends/asyncio.py b/httpcore/_backends/asyncio.py index 312fb648a..dc5b29157 100644 --- a/httpcore/_backends/asyncio.py +++ b/httpcore/_backends/asyncio.py @@ -4,7 +4,6 @@ from typing import Any, Dict, Iterable, Optional, Type from .._exceptions import ( - BrokenSocketError, ConnectError, ConnectTimeout, ReadError, @@ -144,8 +143,9 @@ def get_extra_info(self, info: str) -> Any: return self._sslobj if info in ("client_addr", "server_addr"): sock = self._raw_socket - if sock is None: - raise BrokenSocketError() # pragma: nocover + if sock is None: # pragma: nocover + # TODO replace with an explicit error such as BrokenSocketError + raise ConnectError() return sock.getsockname() if info == "client_addr" else sock.getpeername() if info == "socket": return self._raw_socket @@ -220,8 +220,9 @@ def _set_socket_options( return sock = stream.get_extra_info("socket") - if sock is None: - raise BrokenSocketError() # pragma: nocover + if sock is None: # pragma: nocover + # TODO replace with an explicit error such as BrokenSocketError + raise ConnectError() for option in socket_options: sock.setsockopt(*option) diff --git a/httpcore/_backends/auto.py b/httpcore/_backends/auto.py index 6e266381a..3ac05f4da 100644 --- a/httpcore/_backends/auto.py +++ b/httpcore/_backends/auto.py @@ -1,41 +1,22 @@ import typing -from importlib.util import find_spec -from typing import Optional, Type +from typing import Optional -from .._synchronization import current_async_backend +from .._synchronization import current_async_library from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream -HAS_ANYIO = find_spec("anyio") is not None - class AutoBackend(AsyncNetworkBackend): - @staticmethod - def set_default_backend(backend_class: Optional[Type[AsyncNetworkBackend]]) -> None: - setattr(AutoBackend, "_default_backend_class", backend_class) - async def _init_backend(self) -> None: - if hasattr(self, "_backend"): - return - - default_backend_class: Optional[Type[AsyncNetworkBackend]] = getattr( - AutoBackend, "_default_backend_class", None - ) - if default_backend_class is not None: - self._backend = default_backend_class() - return - - if current_async_backend() == "trio": - from .trio import TrioBackend - - self._backend = TrioBackend() - elif HAS_ANYIO: - from .anyio import AnyIOBackend + if not (hasattr(self, "_backend")): + backend = current_async_library() + if backend == "trio": + from .trio import TrioBackend - self._backend = AnyIOBackend() - else: - from .asyncio import AsyncioBackend + self._backend: AsyncNetworkBackend = TrioBackend() + else: + from .anyio import AnyIOBackend - self._backend = AsyncioBackend() + self._backend = AnyIOBackend() async def connect_tcp( self, diff --git a/httpcore/_backends/trio.py b/httpcore/_backends/trio.py index 26320e61f..b1626d28e 100644 --- a/httpcore/_backends/trio.py +++ b/httpcore/_backends/trio.py @@ -117,6 +117,10 @@ async def connect_tcp( local_address: typing.Optional[str] = None, socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, ) -> AsyncNetworkStream: + # By default for TCP sockets, trio enables TCP_NODELAY. + # https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream + if socket_options is None: + socket_options = [] # pragma: no cover timeout_or_inf = float("inf") if timeout is None else timeout exc_map: ExceptionMapping = { trio.TooSlowError: ConnectTimeout, @@ -128,7 +132,8 @@ async def connect_tcp( stream: trio.abc.Stream = await trio.open_tcp_stream( host=host, port=port, local_address=local_address ) - self._set_socket_options(stream, socket_options) + for option in socket_options: + stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover return TrioStream(stream) async def connect_unix_socket( @@ -136,7 +141,9 @@ async def connect_unix_socket( path: str, timeout: typing.Optional[float] = None, socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, - ) -> AsyncNetworkStream: + ) -> AsyncNetworkStream: # pragma: nocover + if socket_options is None: + socket_options = [] timeout_or_inf = float("inf") if timeout is None else timeout exc_map: ExceptionMapping = { trio.TooSlowError: ConnectTimeout, @@ -146,20 +153,9 @@ async def connect_unix_socket( with map_exceptions(exc_map): with trio.fail_after(timeout_or_inf): stream: trio.abc.Stream = await trio.open_unix_socket(path) - self._set_socket_options(stream, socket_options) + for option in socket_options: + stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover return TrioStream(stream) async def sleep(self, seconds: float) -> None: await trio.sleep(seconds) # pragma: nocover - - def _set_socket_options( - self, - stream: trio.abc.Stream, - socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, - ) -> None: - # By default for TCP sockets, trio enables TCP_NODELAY. - # https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream - if not socket_options: - return - for option in socket_options: - stream.setsockopt(*option) # type: ignore[attr-defined] diff --git a/httpcore/_exceptions.py b/httpcore/_exceptions.py index dc62459b1..81e7fc61d 100644 --- a/httpcore/_exceptions.py +++ b/httpcore/_exceptions.py @@ -73,10 +73,6 @@ class ConnectError(NetworkError): pass -class BrokenSocketError(ConnectError): - pass - - class ReadError(NetworkError): pass diff --git a/httpcore/_synchronization.py b/httpcore/_synchronization.py index fe4cbecad..9619a3983 100644 --- a/httpcore/_synchronization.py +++ b/httpcore/_synchronization.py @@ -1,6 +1,6 @@ import threading from types import TracebackType -from typing import Literal, Optional, Type +from typing import Optional, Type from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions @@ -18,28 +18,30 @@ anyio = None # type: ignore -AsyncBackend = Literal["asyncio", "trio"] - - -def current_async_backend() -> AsyncBackend: +def current_async_library() -> str: # Determine if we're running under trio or asyncio. # See https://sniffio.readthedocs.io/en/latest/ try: import sniffio except ImportError: # pragma: nocover - backend: AsyncBackend = "asyncio" + environment = "asyncio" else: - backend = sniffio.current_async_library() # type: ignore[assignment] + environment = sniffio.current_async_library() - if backend not in ("asyncio", "trio"): # pragma: nocover + if environment not in ("asyncio", "trio"): # pragma: nocover raise RuntimeError("Running under an unsupported async environment.") - if backend == "trio" and trio is None: # pragma: nocover + if environment == "asyncio" and anyio is None: # pragma: nocover + raise RuntimeError( + "Running with asyncio requires installation of 'httpcore[asyncio]'." + ) + + if environment == "trio" and trio is None: # pragma: nocover raise RuntimeError( "Running with trio requires installation of 'httpcore[trio]'." ) - return backend + return environment class AsyncLock: @@ -58,7 +60,7 @@ def setup(self) -> None: Detect if we're running under 'asyncio' or 'trio' and create a lock with the correct implementation. """ - self._backend = current_async_backend() + self._backend = current_async_library() if self._backend == "trio": self._trio_lock = trio.Lock() elif self._backend == "asyncio": @@ -116,7 +118,7 @@ def setup(self) -> None: Detect if we're running under 'asyncio' or 'trio' and create a lock with the correct implementation. """ - self._backend = current_async_backend() + self._backend = current_async_library() if self._backend == "trio": self._trio_event = trio.Event() elif self._backend == "asyncio": @@ -158,7 +160,7 @@ def setup(self) -> None: Detect if we're running under 'asyncio' or 'trio' and create a semaphore with the correct implementation. """ - self._backend = current_async_backend() + self._backend = current_async_library() if self._backend == "trio": self._trio_semaphore = trio.Semaphore( initial_value=self._bound, max_value=self._bound @@ -197,7 +199,7 @@ def __init__(self) -> None: Detect if we're running under 'asyncio' or 'trio' and create a shielded scope with the correct implementation. """ - self._backend = current_async_backend() + self._backend = current_async_library() if self._backend == "trio": self._trio_shield = trio.CancelScope(shield=True) diff --git a/pyproject.toml b/pyproject.toml index 85c787402..4b4d1217e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,7 +91,10 @@ ignore_missing_imports = true [tool.pytest.ini_options] addopts = ["-rxXs", "--strict-config", "--strict-markers"] -markers = ["copied_from(source, changes=None): mark test as copied from somewhere else, along with a description of changes made to accodomate e.g. our test setup"] +markers = [ + "copied_from(source, changes=None): mark test as copied from somewhere else, along with a description of changes made to accodomate e.g. our test setup", + "no_auto_backend_patch", # TODO remove this marker once we have a way to define the asyncio backend in AutoBackend +] filterwarnings = ["error"] [tool.coverage.run] diff --git a/tests/benchmark/client.py b/tests/benchmark/client.py index 1a673892b..d07802b01 100644 --- a/tests/benchmark/client.py +++ b/tests/benchmark/client.py @@ -1,4 +1,5 @@ import asyncio +import os import sys import time from concurrent.futures import ThreadPoolExecutor @@ -20,8 +21,7 @@ CONCURRENCY = 20 POOL_LIMIT = 100 PROFILE = False -httpcore.AutoBackend.set_default_backend(httpcore.AsyncioBackend) -# httpcore.AutoBackend.set_default_backend(httpcore.AnyIOBackend) +os.environ["HTTPCORE_PREFER_ANYIO"] = "0" def duration(start: float) -> int: diff --git a/tests/conftest.py b/tests/conftest.py index 5aa2ed50f..b2dc6fe4e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,8 @@ import pytest import uvicorn -from httpcore import AnyIOBackend, AsyncioBackend, AutoBackend +from httpcore import AnyIOBackend, AsyncioBackend +from httpcore._backends.auto import AutoBackend @pytest.fixture( @@ -20,12 +21,21 @@ def anyio_backend(request, monkeypatch): backend_name, options = request.param options = {**options} + use_anyio = options.pop("httpcore_use_anyio", False) - if backend_name == "trio": - AutoBackend.set_default_backend(None) - else: - use_anyio = options.pop("httpcore_use_anyio", False) - AutoBackend.set_default_backend(AnyIOBackend if use_anyio else AsyncioBackend) + # TODO remove this marker once we have a way to define the asyncio backend in AutoBackend + no_auto_backend_patch = bool( + request.node.get_closest_marker("no_auto_backend_patch") + ) + + if backend_name != "trio" and not no_auto_backend_patch: + # TODO replace with a proper interface in AutoBackend to setup either the AnyIO or asyncio backend + async def patch_init_backend(auto_backend: AutoBackend) -> None: + if hasattr(auto_backend, "_backend"): + return + auto_backend._backend = AnyIOBackend() if use_anyio else AsyncioBackend() + + monkeypatch.setattr(AutoBackend, "_init_backend", patch_init_backend) return backend_name, options diff --git a/tests/test_auto_backend.py b/tests/test_auto_backend.py index c532f17df..6a448881e 100644 --- a/tests/test_auto_backend.py +++ b/tests/test_auto_backend.py @@ -1,16 +1,10 @@ from typing import Generator, List import pytest +from sniffio import current_async_library -import httpcore -from httpcore import ( - AnyIOBackend, - AsyncioBackend, - AsyncNetworkBackend, - AutoBackend, - TrioBackend, -) -from httpcore._synchronization import current_async_backend +from httpcore import AnyIOBackend, AsyncioBackend, AsyncNetworkBackend, TrioBackend +from httpcore._backends.auto import AutoBackend @pytest.fixture(scope="session", autouse=True) @@ -34,16 +28,14 @@ async def test_init_backend(check_tested_backends: List[AsyncNetworkBackend]) -> @pytest.mark.anyio -@pytest.mark.parametrize("has_anyio", [False, True]) -async def test_auto_backend_asyncio(monkeypatch, has_anyio): - if current_async_backend() == "trio": - return - - AutoBackend.set_default_backend(None) - - monkeypatch.setattr(httpcore._backends.auto, "HAS_ANYIO", has_anyio) - +@pytest.mark.no_auto_backend_patch +async def test_auto_backend_uses_expected_backend(monkeypatch): auto = AutoBackend() await auto._init_backend() assert auto._backend is not None - assert isinstance(auto._backend, AnyIOBackend if has_anyio else AsyncioBackend) + + if current_async_library() == "trio": + assert isinstance(auto._backend, TrioBackend) + else: + # TODO add support for choosing the AsyncioBackend in AutoBackend + assert isinstance(auto._backend, AnyIOBackend) From b47b4c2240d66b6b7c8cf44c9bd7fa40388d560d Mon Sep 17 00:00:00 2001 From: Markus Sintonen Date: Mon, 17 Jun 2024 16:48:16 +0300 Subject: [PATCH 3/7] Update doc --- docs/network-backends.md | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/docs/network-backends.md b/docs/network-backends.md index fbb6bfdbf..83ab3869c 100644 --- a/docs/network-backends.md +++ b/docs/network-backends.md @@ -73,8 +73,29 @@ while True: If we're working with an `async` codebase, then we need to select a different backend. -The `httpcore.AnyIOBackend` is suitable for usage if you're running under `asyncio`. This is a networking backend implemented using [the `anyio` package](https://anyio.readthedocs.io/en/3.x/). +These `async` network backends are available: +- `httpcore.AsyncioBackend` This networking backend is implemented using Pythons native `asyncio`. +- `httpcore.AnyIOBackend` This is implemented using [the `anyio` package](https://anyio.readthedocs.io/en/3.x/). +- `httpcore.TrioBackend` This is implemented using [`trio`](https://trio.readthedocs.io/en/stable/). +Currently by default `AnyIOBackend` is used when running with `asyncio` (this may change). +`TrioBackend` is used by default when running with `trio`. + +Using `httpcore.AsyncioBackend`: +```python +import httpcore +import asyncio + +async def main(): + network_backend = httpcore.AsyncioBackend() + async with httpcore.AsyncConnectionPool(network_backend=network_backend) as http: + response = await http.request('GET', 'https://www.example.com') + print(response) + +asyncio.run(main()) +``` + +Using `httpcore.AnyIOBackend`: ```python import httpcore import asyncio From 3814cf4adafab3e933b7614ae706e79bc4f406e5 Mon Sep 17 00:00:00 2001 From: Markus Sintonen Date: Tue, 18 Jun 2024 11:30:25 +0300 Subject: [PATCH 4/7] Add network backend benchmark param --- tests/benchmark/client.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/benchmark/client.py b/tests/benchmark/client.py index d07802b01..14dc73dd7 100644 --- a/tests/benchmark/client.py +++ b/tests/benchmark/client.py @@ -1,5 +1,4 @@ import asyncio -import os import sys import time from concurrent.futures import ThreadPoolExecutor @@ -21,7 +20,7 @@ CONCURRENCY = 20 POOL_LIMIT = 100 PROFILE = False -os.environ["HTTPCORE_PREFER_ANYIO"] = "0" +NET_BACKEND = httpcore.AsyncioBackend def duration(start: float) -> int: @@ -66,7 +65,9 @@ async def aiohttp_get(session: aiohttp.ClientSession, timings: List[int]) -> Non assert res.status == 200, f"status={res.status}" timings.append(duration(start)) - async with httpcore.AsyncConnectionPool(max_connections=POOL_LIMIT) as pool: + async with httpcore.AsyncConnectionPool( + max_connections=POOL_LIMIT, network_backend=NET_BACKEND() + ) as pool: # warmup await gather_limited_concurrency( (httpcore_get(pool, []) for _ in range(REQUESTS)), CONCURRENCY * 2 From 30f9c85bef1213a31c153268dc56e4d93fc36b37 Mon Sep 17 00:00:00 2001 From: Markus Sintonen Date: Tue, 1 Oct 2024 10:52:45 +0300 Subject: [PATCH 5/7] Rename to AsyncIOBackend --- httpcore/__init__.py | 4 ++-- httpcore/_backends/asyncio.py | 2 +- tests/benchmark/client.py | 2 +- tests/conftest.py | 4 ++-- tests/test_auto_backend.py | 6 +++--- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/httpcore/__init__.py b/httpcore/__init__.py index ecd6c537a..a2bb1d8a8 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -8,7 +8,7 @@ AsyncHTTPProxy, AsyncSOCKSProxy, ) -from ._backends.asyncio import AsyncioBackend +from ._backends.asyncio import AsyncIOBackend from ._backends.base import ( SOCKET_OPTION, AsyncNetworkBackend, @@ -98,7 +98,7 @@ def __init__(self, *args, **kwargs): # type: ignore "SOCKSProxy", # network backends, implementations "SyncBackend", - "AsyncioBackend", + "AsyncIOBackend", "AnyIOBackend", "TrioBackend", # network backends, mock implementations diff --git a/httpcore/_backends/asyncio.py b/httpcore/_backends/asyncio.py index dc5b29157..8fbe2b944 100644 --- a/httpcore/_backends/asyncio.py +++ b/httpcore/_backends/asyncio.py @@ -164,7 +164,7 @@ def _sslobj(self) -> Optional[ssl.SSLObject]: return sslobj -class AsyncioBackend(AsyncNetworkBackend): +class AsyncIOBackend(AsyncNetworkBackend): async def connect_tcp( self, host: str, diff --git a/tests/benchmark/client.py b/tests/benchmark/client.py index 14dc73dd7..e62968c97 100644 --- a/tests/benchmark/client.py +++ b/tests/benchmark/client.py @@ -20,7 +20,7 @@ CONCURRENCY = 20 POOL_LIMIT = 100 PROFILE = False -NET_BACKEND = httpcore.AsyncioBackend +NET_BACKEND = httpcore.AsyncIOBackend def duration(start: float) -> int: diff --git a/tests/conftest.py b/tests/conftest.py index b2dc6fe4e..355eeca81 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ import pytest import uvicorn -from httpcore import AnyIOBackend, AsyncioBackend +from httpcore import AnyIOBackend, AsyncIOBackend from httpcore._backends.auto import AutoBackend @@ -33,7 +33,7 @@ def anyio_backend(request, monkeypatch): async def patch_init_backend(auto_backend: AutoBackend) -> None: if hasattr(auto_backend, "_backend"): return - auto_backend._backend = AnyIOBackend() if use_anyio else AsyncioBackend() + auto_backend._backend = AnyIOBackend() if use_anyio else AsyncIOBackend() monkeypatch.setattr(AutoBackend, "_init_backend", patch_init_backend) diff --git a/tests/test_auto_backend.py b/tests/test_auto_backend.py index 6a448881e..6643d9943 100644 --- a/tests/test_auto_backend.py +++ b/tests/test_auto_backend.py @@ -3,7 +3,7 @@ import pytest from sniffio import current_async_library -from httpcore import AnyIOBackend, AsyncioBackend, AsyncNetworkBackend, TrioBackend +from httpcore import AnyIOBackend, AsyncIOBackend, AsyncNetworkBackend, TrioBackend from httpcore._backends.auto import AutoBackend @@ -13,7 +13,7 @@ def check_tested_backends() -> Generator[List[AsyncNetworkBackend], None, None]: backends: List[AsyncNetworkBackend] = [] yield backends assert {b.__class__ for b in backends} == { - AsyncioBackend, + AsyncIOBackend, AnyIOBackend, TrioBackend, } @@ -37,5 +37,5 @@ async def test_auto_backend_uses_expected_backend(monkeypatch): if current_async_library() == "trio": assert isinstance(auto._backend, TrioBackend) else: - # TODO add support for choosing the AsyncioBackend in AutoBackend + # TODO add support for choosing the AsyncIOBackend in AutoBackend assert isinstance(auto._backend, AnyIOBackend) From 07772b19715f790ac7fcefc12c27af1038a666be Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 1 Oct 2024 10:50:03 +0100 Subject: [PATCH 6/7] Apply suggestions from code review --- docs/network-backends.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/network-backends.md b/docs/network-backends.md index 83ab3869c..21a930b65 100644 --- a/docs/network-backends.md +++ b/docs/network-backends.md @@ -74,20 +74,20 @@ while True: If we're working with an `async` codebase, then we need to select a different backend. These `async` network backends are available: -- `httpcore.AsyncioBackend` This networking backend is implemented using Pythons native `asyncio`. +- `httpcore.AsyncIOBackend` This networking backend is implemented using Pythons native `asyncio`. - `httpcore.AnyIOBackend` This is implemented using [the `anyio` package](https://anyio.readthedocs.io/en/3.x/). - `httpcore.TrioBackend` This is implemented using [`trio`](https://trio.readthedocs.io/en/stable/). Currently by default `AnyIOBackend` is used when running with `asyncio` (this may change). `TrioBackend` is used by default when running with `trio`. -Using `httpcore.AsyncioBackend`: +Using `httpcore.AsyncIOBackend`: ```python import httpcore import asyncio async def main(): - network_backend = httpcore.AsyncioBackend() + network_backend = httpcore.AsyncIOBackend() async with httpcore.AsyncConnectionPool(network_backend=network_backend) as http: response = await http.request('GET', 'https://www.example.com') print(response) From ae741c01776f891d677742c4a3e0230cb14c9f8c Mon Sep 17 00:00:00 2001 From: Markus Sintonen Date: Tue, 1 Oct 2024 15:21:32 +0300 Subject: [PATCH 7/7] Remove unneeded locks --- httpcore/_backends/asyncio.py | 63 ++++++++++++++++------------------- 1 file changed, 29 insertions(+), 34 deletions(-) diff --git a/httpcore/_backends/asyncio.py b/httpcore/_backends/asyncio.py index 8fbe2b944..ba6becb06 100644 --- a/httpcore/_backends/asyncio.py +++ b/httpcore/_backends/asyncio.py @@ -22,8 +22,6 @@ def __init__( ): self._stream_reader = stream_reader self._stream_writer = stream_writer - self._read_lock = asyncio.Lock() - self._write_lock = asyncio.Lock() self._inner: Optional[AsyncIOStream] = None async def start_tls( @@ -76,23 +74,22 @@ async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes: asyncio.TimeoutError: ReadTimeout, OSError: ReadError, } - async with self._read_lock: - with map_exceptions(exc_map): - try: - return await asyncio.wait_for( - self._stream_reader.read(max_bytes), timeout - ) - except AttributeError as exc: # pragma: nocover - if "resume_reading" in str(exc): - # Python's asyncio has a bug that can occur when a - # connection has been closed, while it is paused. - # See: https://github.com/encode/httpx/issues/1213 - # - # Returning an empty byte-string to indicate connection - # close will eventually raise an httpcore.RemoteProtocolError - # to the user when this goes through our HTTP parsing layer. - return b"" - raise + with map_exceptions(exc_map): + try: + return await asyncio.wait_for( + self._stream_reader.read(max_bytes), timeout + ) + except AttributeError as exc: # pragma: nocover + if "resume_reading" in str(exc): + # Python's asyncio has a bug that can occur when a + # connection has been closed, while it is paused. + # See: https://github.com/encode/httpx/issues/1213 + # + # Returning an empty byte-string to indicate connection + # close will eventually raise an httpcore.RemoteProtocolError + # to the user when this goes through our HTTP parsing layer. + return b"" + raise async def write(self, data: bytes, timeout: Optional[float] = None) -> None: if not data: @@ -102,10 +99,9 @@ async def write(self, data: bytes, timeout: Optional[float] = None) -> None: asyncio.TimeoutError: WriteTimeout, OSError: WriteError, } - async with self._write_lock: - with map_exceptions(exc_map): - self._stream_writer.write(data) - return await asyncio.wait_for(self._stream_writer.drain(), timeout) + with map_exceptions(exc_map): + self._stream_writer.write(data) + return await asyncio.wait_for(self._stream_writer.drain(), timeout) async def aclose(self) -> None: # SSL connections should issue the close and then abort, rather than @@ -124,17 +120,16 @@ async def aclose(self) -> None: # * https://github.com/encode/httpx/issues/914 is_ssl = self._sslobj is not None - async with self._write_lock: - try: - self._stream_writer.close() - if is_ssl: - # Give the connection a chance to write any data in the buffer, - # and then forcibly tear down the SSL connection. - await asyncio.sleep(0) - self._stream_writer.transport.abort() - await self._stream_writer.wait_closed() - except OSError: # pragma: nocover - pass + try: + self._stream_writer.close() + if is_ssl: + # Give the connection a chance to write any data in the buffer, + # and then forcibly tear down the SSL connection. + await asyncio.sleep(0) + self._stream_writer.transport.abort() + await self._stream_writer.wait_closed() + except OSError: # pragma: nocover + pass def get_extra_info(self, info: str) -> Any: if info == "is_readable":