diff --git a/docs/network-backends.md b/docs/network-backends.md
index fbb6bfdbf..21a930b65 100644
--- a/docs/network-backends.md
+++ b/docs/network-backends.md
@@ -73,8 +73,32 @@ while True:
 
 If we're working with an `async` codebase, then we need to select a different backend.
 
-The `httpcore.AnyIOBackend` is suitable for usage if you're running under `asyncio`. This is a networking backend implemented using [the `anyio` package](https://anyio.readthedocs.io/en/3.x/).
+These `async` network backends are available:
+
+- `httpcore.AsyncIOBackend`: a networking backend implemented using Python's native `asyncio`.
+- `httpcore.AnyIOBackend`: implemented using [the `anyio` package](https://anyio.readthedocs.io/en/3.x/).
+- `httpcore.TrioBackend`: implemented using [`trio`](https://trio.readthedocs.io/en/stable/).
+
+By default, `AnyIOBackend` is currently used when running under `asyncio` (this may change).
+`TrioBackend` is used by default when running under `trio`.
+
+Using `httpcore.AsyncIOBackend`:
+
+```python
+import httpcore
+import asyncio
+
+async def main():
+    network_backend = httpcore.AsyncIOBackend()
+    async with httpcore.AsyncConnectionPool(network_backend=network_backend) as http:
+        response = await http.request('GET', 'https://www.example.com')
+        print(response)
+
+asyncio.run(main())
+```
+
+Using `httpcore.AnyIOBackend`:
 
 ```python
 import httpcore
 import asyncio
diff --git a/httpcore/__init__.py b/httpcore/__init__.py
index 014213bae..a2bb1d8a8 100644
--- a/httpcore/__init__.py
+++ b/httpcore/__init__.py
@@ -8,6 +8,7 @@
     AsyncHTTPProxy,
     AsyncSOCKSProxy,
 )
+from ._backends.asyncio import AsyncIOBackend
 from ._backends.base import (
     SOCKET_OPTION,
     AsyncNetworkBackend,
@@ -97,6 +98,7 @@ def __init__(self, *args, **kwargs):  # type: ignore
     "SOCKSProxy",
     # network backends, implementations
     "SyncBackend",
+    "AsyncIOBackend",
     "AnyIOBackend",
     "TrioBackend",
     # network backends, mock implementations
diff --git a/httpcore/_backends/asyncio.py b/httpcore/_backends/asyncio.py
new file mode 100644
index 000000000..ba6becb06
--- /dev/null
+++ b/httpcore/_backends/asyncio.py
@@ -0,0 +1,223 @@
+import asyncio
+import socket
+import ssl
+from typing import Any, Dict, Iterable, Optional, Type
+
+from .._exceptions import (
+    ConnectError,
+    ConnectTimeout,
+    ReadError,
+    ReadTimeout,
+    WriteError,
+    WriteTimeout,
+    map_exceptions,
+)
+from .._utils import is_socket_readable
+from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
+
+
+class AsyncIOStream(AsyncNetworkStream):
+    def __init__(
+        self, stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter
+    ):
+        self._stream_reader = stream_reader
+        self._stream_writer = stream_writer
+        self._inner: Optional[AsyncIOStream] = None
+
+    async def start_tls(
+        self,
+        ssl_context: ssl.SSLContext,
+        server_hostname: Optional[str] = None,
+        timeout: Optional[float] = None,
+    ) -> AsyncNetworkStream:
+        loop = asyncio.get_event_loop()
+
+        stream_reader = asyncio.StreamReader()
+        protocol = asyncio.StreamReaderProtocol(stream_reader)
+
+        exc_map: Dict[Type[Exception], Type[Exception]] = {
+            asyncio.TimeoutError: ConnectTimeout,
+            OSError: ConnectError,
+        }
+        with map_exceptions(exc_map):
+            transport_ssl = await asyncio.wait_for(
+                loop.start_tls(
+                    self._stream_writer.transport,
+                    protocol,
+                    ssl_context,
+                    server_hostname=server_hostname,
+                ),
+                timeout,
+            )
+        if transport_ssl is None:
+            # https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.start_tls
+            raise ConnectError("Transport closed while 
starting TLS") # pragma: nocover + + # Initialize the protocol, so it is made aware of being tied to + # a TLS connection. + # See: https://github.com/encode/httpx/issues/859 + protocol.connection_made(transport_ssl) + + stream_writer = asyncio.StreamWriter( + transport=transport_ssl, protocol=protocol, reader=stream_reader, loop=loop + ) + + ssl_stream = AsyncIOStream(stream_reader, stream_writer) + # When we return a new SocketStream with new StreamReader/StreamWriter instances + # we need to keep references to the old StreamReader/StreamWriter so that they + # are not garbage collected and closed while we're still using them. + ssl_stream._inner = self + return ssl_stream + + async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes: + exc_map: Dict[Type[Exception], Type[Exception]] = { + asyncio.TimeoutError: ReadTimeout, + OSError: ReadError, + } + with map_exceptions(exc_map): + try: + return await asyncio.wait_for( + self._stream_reader.read(max_bytes), timeout + ) + except AttributeError as exc: # pragma: nocover + if "resume_reading" in str(exc): + # Python's asyncio has a bug that can occur when a + # connection has been closed, while it is paused. + # See: https://github.com/encode/httpx/issues/1213 + # + # Returning an empty byte-string to indicate connection + # close will eventually raise an httpcore.RemoteProtocolError + # to the user when this goes through our HTTP parsing layer. + return b"" + raise + + async def write(self, data: bytes, timeout: Optional[float] = None) -> None: + if not data: + return + + exc_map: Dict[Type[Exception], Type[Exception]] = { + asyncio.TimeoutError: WriteTimeout, + OSError: WriteError, + } + with map_exceptions(exc_map): + self._stream_writer.write(data) + return await asyncio.wait_for(self._stream_writer.drain(), timeout) + + async def aclose(self) -> None: + # SSL connections should issue the close and then abort, rather than + # waiting for the remote end of the connection to signal the EOF. + # + # See: + # + # * https://bugs.python.org/issue39758 + # * https://github.com/python-trio/trio/blob/ + # 31e2ae866ad549f1927d45ce073d4f0ea9f12419/trio/_ssl.py#L779-L829 + # + # And related issues caused if we simply omit the 'wait_closed' call, + # without first using `.abort()` + # + # * https://github.com/encode/httpx/issues/825 + # * https://github.com/encode/httpx/issues/914 + is_ssl = self._sslobj is not None + + try: + self._stream_writer.close() + if is_ssl: + # Give the connection a chance to write any data in the buffer, + # and then forcibly tear down the SSL connection. 
+                await asyncio.sleep(0)
+                self._stream_writer.transport.abort()
+            await self._stream_writer.wait_closed()
+        except OSError:  # pragma: nocover
+            pass
+
+    def get_extra_info(self, info: str) -> Any:
+        if info == "is_readable":
+            return is_socket_readable(self._raw_socket)
+        if info == "ssl_object":
+            return self._sslobj
+        if info in ("client_addr", "server_addr"):
+            sock = self._raw_socket
+            if sock is None:  # pragma: nocover
+                # TODO replace with an explicit error such as BrokenSocketError
+                raise ConnectError()
+            return sock.getsockname() if info == "client_addr" else sock.getpeername()
+        if info == "socket":
+            return self._raw_socket
+        return None
+
+    @property
+    def _raw_socket(self) -> Optional[socket.socket]:
+        transport = self._stream_writer.transport
+        sock: Optional[socket.socket] = transport.get_extra_info("socket")
+        return sock
+
+    @property
+    def _sslobj(self) -> Optional[ssl.SSLObject]:
+        transport = self._stream_writer.transport
+        sslobj: Optional[ssl.SSLObject] = transport.get_extra_info("ssl_object")
+        return sslobj
+
+
+class AsyncIOBackend(AsyncNetworkBackend):
+    async def connect_tcp(
+        self,
+        host: str,
+        port: int,
+        timeout: Optional[float] = None,
+        local_address: Optional[str] = None,
+        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
+    ) -> AsyncNetworkStream:
+        local_addr = None if local_address is None else (local_address, 0)
+
+        exc_map: Dict[Type[Exception], Type[Exception]] = {
+            asyncio.TimeoutError: ConnectTimeout,
+            OSError: ConnectError,
+        }
+        with map_exceptions(exc_map):
+            stream_reader, stream_writer = await asyncio.wait_for(
+                asyncio.open_connection(host, port, local_addr=local_addr),
+                timeout,
+            )
+            self._set_socket_options(stream_writer, socket_options)
+            return AsyncIOStream(
+                stream_reader=stream_reader, stream_writer=stream_writer
+            )
+
+    async def connect_unix_socket(
+        self,
+        path: str,
+        timeout: Optional[float] = None,
+        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
+    ) -> AsyncNetworkStream:
+        exc_map: Dict[Type[Exception], Type[Exception]] = {
+            asyncio.TimeoutError: ConnectTimeout,
+            OSError: ConnectError,
+        }
+        with map_exceptions(exc_map):
+            stream_reader, stream_writer = await asyncio.wait_for(
+                asyncio.open_unix_connection(path), timeout
+            )
+            self._set_socket_options(stream_writer, socket_options)
+            return AsyncIOStream(
+                stream_reader=stream_reader, stream_writer=stream_writer
+            )
+
+    async def sleep(self, seconds: float) -> None:
+        await asyncio.sleep(seconds)  # pragma: nocover
+
+    def _set_socket_options(
+        self,
+        stream: asyncio.StreamWriter,
+        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
+    ) -> None:
+        if not socket_options:
+            return
+
+        sock = stream.get_extra_info("socket")
+        if sock is None:  # pragma: nocover
+            # TODO replace with an explicit error such as BrokenSocketError
+            raise ConnectError()
+
+        for option in socket_options:
+            sock.setsockopt(*option)
diff --git a/pyproject.toml b/pyproject.toml
index d6573dd84..f90f64b8f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -91,7 +91,10 @@ ignore_missing_imports = true
 
 [tool.pytest.ini_options]
 addopts = ["-rxXs", "--strict-config", "--strict-markers"]
-markers = ["copied_from(source, changes=None): mark test as copied from somewhere else, along with a description of changes made to accodomate e.g. our test setup"]
+markers = [
+    "copied_from(source, changes=None): mark test as copied from somewhere else, along with a description of changes made to accommodate e.g.
our test setup", + "no_auto_backend_patch", # TODO remove this marker once we have a way to define the asyncio backend in AutoBackend +] filterwarnings = ["error"] [tool.coverage.run] diff --git a/tests/_async/test_integration.py b/tests/_async/test_integration.py index 1970531d5..797933e4a 100644 --- a/tests/_async/test_integration.py +++ b/tests/_async/test_integration.py @@ -1,8 +1,13 @@ +import os +import socket import ssl +from tempfile import gettempdir import pytest +import uvicorn import httpcore +from tests.conftest import Server @pytest.mark.anyio @@ -49,3 +54,82 @@ async def test_extra_info(httpbin_secure): assert invalid is None stream.get_extra_info("is_readable") + + +@pytest.mark.anyio +@pytest.mark.parametrize("keep_alive_enabled", [True, False]) +async def test_socket_options( + server: Server, server_url: str, keep_alive_enabled: bool +) -> None: + socket_options = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, int(keep_alive_enabled))] + async with httpcore.AsyncConnectionPool(socket_options=socket_options) as pool: + response = await pool.request("GET", server_url) + assert response.status == 200 + + stream = response.extensions["network_stream"] + sock = stream.get_extra_info("socket") + opt = sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) + assert bool(opt) is keep_alive_enabled + + +@pytest.mark.anyio +async def test_socket_no_nagle(server: Server, server_url: str) -> None: + async with httpcore.AsyncConnectionPool() as pool: + response = await pool.request("GET", server_url) + assert response.status == 200 + + stream = response.extensions["network_stream"] + sock = stream.get_extra_info("socket") + opt = sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + assert bool(opt) is True + + +@pytest.mark.anyio +async def test_pool_recovers_from_connection_breakage( + server_config: uvicorn.Config, server_url: str +) -> None: + async with httpcore.AsyncConnectionPool( + max_connections=1, max_keepalive_connections=1, keepalive_expiry=10 + ) as pool: + with Server(server_config).run_in_thread(): + response = await pool.request("GET", server_url) + assert response.status == 200 + + assert len(pool.connections) == 1 + conn = pool.connections[0] + + stream = response.extensions["network_stream"] + assert stream.get_extra_info("is_readable") is False + + assert ( + stream.get_extra_info("is_readable") is True + ), "Should break by coming readable" + + with Server(server_config).run_in_thread(): + assert len(pool.connections) == 1 + assert pool.connections[0] is conn, "Should be the broken connection" + + response = await pool.request("GET", server_url) + assert response.status == 200 + + assert len(pool.connections) == 1 + assert pool.connections[0] is not conn, "Should be a new connection" + + +@pytest.mark.anyio +async def test_unix_domain_socket(server_port, server_config, server_url): + uds = f"{gettempdir()}/test_httpcore_app.sock" + if os.path.exists(uds): + os.remove(uds) # pragma: nocover + + uds_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + uds_sock.bind(uds) + + with Server(server_config).run_in_thread(sockets=[uds_sock]): + async with httpcore.AsyncConnectionPool(uds=uds) as pool: + response = await pool.request("GET", server_url) + assert response.status == 200 + finally: + uds_sock.close() + os.remove(uds) diff --git a/tests/_sync/test_integration.py b/tests/_sync/test_integration.py index e3327e696..d114f878c 100644 --- a/tests/_sync/test_integration.py +++ b/tests/_sync/test_integration.py @@ -1,8 +1,13 @@ +import os +import socket 
import ssl +from tempfile import gettempdir import pytest +import uvicorn import httpcore +from tests.conftest import Server @@ -49,3 +54,82 @@ def test_extra_info(httpbin_secure): assert invalid is None stream.get_extra_info("is_readable") + + + +@pytest.mark.parametrize("keep_alive_enabled", [True, False]) +def test_socket_options( + server: Server, server_url: str, keep_alive_enabled: bool +) -> None: + socket_options = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, int(keep_alive_enabled))] + with httpcore.ConnectionPool(socket_options=socket_options) as pool: + response = pool.request("GET", server_url) + assert response.status == 200 + + stream = response.extensions["network_stream"] + sock = stream.get_extra_info("socket") + opt = sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) + assert bool(opt) is keep_alive_enabled + + + +def test_socket_no_nagle(server: Server, server_url: str) -> None: + with httpcore.ConnectionPool() as pool: + response = pool.request("GET", server_url) + assert response.status == 200 + + stream = response.extensions["network_stream"] + sock = stream.get_extra_info("socket") + opt = sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + assert bool(opt) is True + + + +def test_pool_recovers_from_connection_breakage( + server_config: uvicorn.Config, server_url: str +) -> None: + with httpcore.ConnectionPool( + max_connections=1, max_keepalive_connections=1, keepalive_expiry=10 + ) as pool: + with Server(server_config).run_in_thread(): + response = pool.request("GET", server_url) + assert response.status == 200 + + assert len(pool.connections) == 1 + conn = pool.connections[0] + + stream = response.extensions["network_stream"] + assert stream.get_extra_info("is_readable") is False + + assert ( + stream.get_extra_info("is_readable") is True + ), "Should break by coming readable" + + with Server(server_config).run_in_thread(): + assert len(pool.connections) == 1 + assert pool.connections[0] is conn, "Should be the broken connection" + + response = pool.request("GET", server_url) + assert response.status == 200 + + assert len(pool.connections) == 1 + assert pool.connections[0] is not conn, "Should be a new connection" + + + +def test_unix_domain_socket(server_port, server_config, server_url): + uds = f"{gettempdir()}/test_httpcore_app.sock" + if os.path.exists(uds): + os.remove(uds) # pragma: nocover + + uds_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + uds_sock.bind(uds) + + with Server(server_config).run_in_thread(sockets=[uds_sock]): + with httpcore.ConnectionPool(uds=uds) as pool: + response = pool.request("GET", server_url) + assert response.status == 200 + finally: + uds_sock.close() + os.remove(uds) diff --git a/tests/benchmark/client.py b/tests/benchmark/client.py index d07802b01..e62968c97 100644 --- a/tests/benchmark/client.py +++ b/tests/benchmark/client.py @@ -1,5 +1,4 @@ import asyncio -import os import sys import time from concurrent.futures import ThreadPoolExecutor @@ -21,7 +20,7 @@ CONCURRENCY = 20 POOL_LIMIT = 100 PROFILE = False -os.environ["HTTPCORE_PREFER_ANYIO"] = "0" +NET_BACKEND = httpcore.AsyncIOBackend def duration(start: float) -> int: @@ -66,7 +65,9 @@ async def aiohttp_get(session: aiohttp.ClientSession, timings: List[int]) -> Non assert res.status == 200, f"status={res.status}" timings.append(duration(start)) - async with httpcore.AsyncConnectionPool(max_connections=POOL_LIMIT) as pool: + async with httpcore.AsyncConnectionPool( + max_connections=POOL_LIMIT, network_backend=NET_BACKEND() + ) as pool: # warmup 
await gather_limited_concurrency( (httpcore_get(pool, []) for _ in range(REQUESTS)), CONCURRENCY * 2 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..355eeca81 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,100 @@ +import socket +import time +from contextlib import contextmanager +from threading import Thread +from typing import Any, Awaitable, Callable, Generator, Iterator, List, Optional + +import pytest +import uvicorn + +from httpcore import AnyIOBackend, AsyncIOBackend +from httpcore._backends.auto import AutoBackend + + +@pytest.fixture( + params=[ + pytest.param(("asyncio", {"httpcore_use_anyio": False}), id="asyncio"), + pytest.param(("asyncio", {"httpcore_use_anyio": True}), id="asyncio+anyio"), + pytest.param(("trio", {}), id="trio"), + ] +) +def anyio_backend(request, monkeypatch): + backend_name, options = request.param + options = {**options} + use_anyio = options.pop("httpcore_use_anyio", False) + + # TODO remove this marker once we have a way to define the asyncio backend in AutoBackend + no_auto_backend_patch = bool( + request.node.get_closest_marker("no_auto_backend_patch") + ) + + if backend_name != "trio" and not no_auto_backend_patch: + # TODO replace with a proper interface in AutoBackend to setup either the AnyIO or asyncio backend + async def patch_init_backend(auto_backend: AutoBackend) -> None: + if hasattr(auto_backend, "_backend"): + return + auto_backend._backend = AnyIOBackend() if use_anyio else AsyncIOBackend() + + monkeypatch.setattr(AutoBackend, "_init_backend", patch_init_backend) + + return backend_name, options + + +class Server(uvicorn.Server): + @contextmanager + def run_in_thread( + self, sockets: Optional[List[socket.socket]] = None + ) -> Generator[None, None, None]: + thread = Thread(target=lambda: self.run(sockets)) + thread.start() + start_time = time.monotonic() + try: + while not self.started: + time.sleep(0.01) + if (time.monotonic() - start_time) > 5: + raise TimeoutError() # pragma: nocover + yield + finally: + self.should_exit = True + thread.join() + + +@pytest.fixture +def server_port() -> int: + return 1111 + + +@pytest.fixture +def server_url(server_port: int) -> str: + return f"http://127.0.0.1:{server_port}" + + +@pytest.fixture +def server_app() -> Callable[[Any, Any, Any], Awaitable[None]]: + async def app(scope, receive, send): + assert scope["type"] == "http" + assert not (await receive()).get("more_body", False) + + start = { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain"]], + } + body = {"type": "http.response.body", "body": b"Hello World"} + await send(start) + await send(body) + + return app + + +@pytest.fixture +def server_config( + server_port: int, server_app: Callable[[Any, Any, Any], Awaitable[None]] +) -> uvicorn.Config: + return uvicorn.Config(server_app, port=server_port, log_level="error") + + +@pytest.fixture +def server(server_config: uvicorn.Config) -> Iterator[None]: + with Server(server_config).run_in_thread(): + yield diff --git a/tests/test_auto_backend.py b/tests/test_auto_backend.py new file mode 100644 index 000000000..6643d9943 --- /dev/null +++ b/tests/test_auto_backend.py @@ -0,0 +1,41 @@ +from typing import Generator, List + +import pytest +from sniffio import current_async_library + +from httpcore import AnyIOBackend, AsyncIOBackend, AsyncNetworkBackend, TrioBackend +from httpcore._backends.auto import AutoBackend + + +@pytest.fixture(scope="session", autouse=True) +def check_tested_backends() -> 
Generator[List[AsyncNetworkBackend], None, None]: + # Ensure tests cover all supported backend variants + backends: List[AsyncNetworkBackend] = [] + yield backends + assert {b.__class__ for b in backends} == { + AsyncIOBackend, + AnyIOBackend, + TrioBackend, + } + + +@pytest.mark.anyio +async def test_init_backend(check_tested_backends: List[AsyncNetworkBackend]) -> None: + auto = AutoBackend() + await auto._init_backend() + assert auto._backend is not None + check_tested_backends.append(auto._backend) + + +@pytest.mark.anyio +@pytest.mark.no_auto_backend_patch +async def test_auto_backend_uses_expected_backend(monkeypatch): + auto = AutoBackend() + await auto._init_backend() + assert auto._backend is not None + + if current_async_library() == "trio": + assert isinstance(auto._backend, TrioBackend) + else: + # TODO add support for choosing the AsyncIOBackend in AutoBackend + assert isinstance(auto._backend, AnyIOBackend)
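
For completeness, the documentation hunk above only shows `asyncio`-based usage. The `trio` equivalent would look like the sketch below; this is illustrative only and not part of the changeset. It uses the existing `httpcore.TrioBackend` (unchanged by this diff) and assumes `trio` is installed.

```python
import httpcore
import trio

async def main():
    # Explicitly select the trio networking backend rather than relying
    # on AutoBackend's runtime detection.
    network_backend = httpcore.TrioBackend()
    async with httpcore.AsyncConnectionPool(network_backend=network_backend) as http:
        response = await http.request('GET', 'https://www.example.com')
        print(response)

trio.run(main)
```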