diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index 214dfc4be..5b603fd11 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -238,6 +238,7 @@ def _assign_requests_to_connections(self) -> List[AsyncConnectionInterface]: those connections to be handled seperately. """ closing_connections = [] + idling_count = 0 # First we handle cleaning up any connections that are closed, # have expired their keep-alive, or surplus idle connections. @@ -249,27 +250,25 @@ def _assign_requests_to_connections(self) -> List[AsyncConnectionInterface]: # log: "closing expired connection" self._connections.remove(connection) closing_connections.append(connection) - elif ( - connection.is_idle() - and len([connection.is_idle() for connection in self._connections]) - > self._max_keepalive_connections - ): + elif connection.is_idle(): + if idling_count < self._max_keepalive_connections: + idling_count += 1 + continue # log: "closing idle connection" self._connections.remove(connection) closing_connections.append(connection) # Assign queued requests to connections. - queued_requests = [request for request in self._requests if request.is_queued()] - for pool_request in queued_requests: + for pool_request in list(self._requests): + if not pool_request.is_queued(): + continue + origin = pool_request.request.url.origin available_connections = [ connection for connection in self._connections if connection.can_handle_request(origin) and connection.is_available() ] - idle_connections = [ - connection for connection in self._connections if connection.is_idle() - ] # There are three cases for how we may be able to handle the request: # @@ -286,15 +285,18 @@ def _assign_requests_to_connections(self) -> List[AsyncConnectionInterface]: connection = self.create_connection(origin) self._connections.append(connection) pool_request.assign_to_connection(connection) - elif idle_connections: - # log: "closing idle connection" - connection = idle_connections[0] - self._connections.remove(connection) - closing_connections.append(connection) - # log: "creating new connection" - connection = self.create_connection(origin) - self._connections.append(connection) - pool_request.assign_to_connection(connection) + else: + idling_connection = next( + (c for c in self._connections if c.is_idle()), None + ) + if idling_connection is not None: + # log: "closing idle connection" + self._connections.remove(idling_connection) + closing_connections.append(idling_connection) + # log: "creating new connection" + new_connection = self.create_connection(origin) + self._connections.append(new_connection) + pool_request.assign_to_connection(new_connection) return closing_connections diff --git a/httpcore/_async/http11.py b/httpcore/_async/http11.py index 0493a923d..38f4a292d 100644 --- a/httpcore/_async/http11.py +++ b/httpcore/_async/http11.py @@ -1,5 +1,6 @@ import enum import logging +import random import ssl import time from types import TracebackType @@ -56,10 +57,12 @@ def __init__( origin: Origin, stream: AsyncNetworkStream, keepalive_expiry: Optional[float] = None, + socket_poll_interval_between: Tuple[float, float] = (1, 3), ) -> None: self._origin = origin self._network_stream = stream - self._keepalive_expiry: Optional[float] = keepalive_expiry + self._keepalive_expiry = keepalive_expiry + self._socket_poll_interval_between = socket_poll_interval_between self._expire_at: Optional[float] = None self._state = HTTPConnectionState.NEW self._state_lock = AsyncLock() @@ -68,6 +71,8 @@ def __init__( our_role=h11.CLIENT, max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE, ) + # Assuming we were just connected + self._network_stream_used_at = time.monotonic() async def handle_async_request(self, request: Request) -> Response: if not self.can_handle_request(request.url.origin): @@ -173,6 +178,7 @@ async def _send_event( bytes_to_send = self._h11_state.send(event) if bytes_to_send is not None: await self._network_stream.write(bytes_to_send, timeout=timeout) + self._network_stream_used_at = time.monotonic() # Receiving the response... @@ -224,6 +230,7 @@ async def _receive_event( data = await self._network_stream.read( self.READ_NUM_BYTES, timeout=timeout ) + self._network_stream_used_at = time.monotonic() # If we feed this case through h11 we'll raise an exception like: # @@ -281,16 +288,28 @@ def is_available(self) -> bool: def has_expired(self) -> bool: now = time.monotonic() keepalive_expired = self._expire_at is not None and now > self._expire_at + if keepalive_expired: + return True # If the HTTP connection is idle but the socket is readable, then the # only valid state is that the socket is about to return b"", indicating # a server-initiated disconnect. - server_disconnected = ( - self._state == HTTPConnectionState.IDLE - and self._network_stream.get_extra_info("is_readable") - ) + # Checking the readable status is relatively expensive so check it at a lower frequency. + if (now - self._network_stream_used_at) > self._socket_poll_interval(): + self._network_stream_used_at = now + server_disconnected = ( + self._state == HTTPConnectionState.IDLE + and self._network_stream.get_extra_info("is_readable") + ) + if server_disconnected: + return True + + return False - return keepalive_expired or server_disconnected + def _socket_poll_interval(self) -> float: + # Randomize to avoid polling for all the connections at once + low, high = self._socket_poll_interval_between + return random.uniform(low, high) def is_idle(self) -> bool: return self._state == HTTPConnectionState.IDLE diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index 01bec59e8..7d7a9156b 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -238,6 +238,7 @@ def _assign_requests_to_connections(self) -> List[ConnectionInterface]: those connections to be handled seperately. """ closing_connections = [] + idling_count = 0 # First we handle cleaning up any connections that are closed, # have expired their keep-alive, or surplus idle connections. @@ -249,27 +250,25 @@ def _assign_requests_to_connections(self) -> List[ConnectionInterface]: # log: "closing expired connection" self._connections.remove(connection) closing_connections.append(connection) - elif ( - connection.is_idle() - and len([connection.is_idle() for connection in self._connections]) - > self._max_keepalive_connections - ): + elif connection.is_idle(): + if idling_count < self._max_keepalive_connections: + idling_count += 1 + continue # log: "closing idle connection" self._connections.remove(connection) closing_connections.append(connection) # Assign queued requests to connections. - queued_requests = [request for request in self._requests if request.is_queued()] - for pool_request in queued_requests: + for pool_request in list(self._requests): + if not pool_request.is_queued(): + continue + origin = pool_request.request.url.origin available_connections = [ connection for connection in self._connections if connection.can_handle_request(origin) and connection.is_available() ] - idle_connections = [ - connection for connection in self._connections if connection.is_idle() - ] # There are three cases for how we may be able to handle the request: # @@ -286,15 +285,18 @@ def _assign_requests_to_connections(self) -> List[ConnectionInterface]: connection = self.create_connection(origin) self._connections.append(connection) pool_request.assign_to_connection(connection) - elif idle_connections: - # log: "closing idle connection" - connection = idle_connections[0] - self._connections.remove(connection) - closing_connections.append(connection) - # log: "creating new connection" - connection = self.create_connection(origin) - self._connections.append(connection) - pool_request.assign_to_connection(connection) + else: + idling_connection = next( + (c for c in self._connections if c.is_idle()), None + ) + if idling_connection is not None: + # log: "closing idle connection" + self._connections.remove(idling_connection) + closing_connections.append(idling_connection) + # log: "creating new connection" + new_connection = self.create_connection(origin) + self._connections.append(new_connection) + pool_request.assign_to_connection(new_connection) return closing_connections diff --git a/httpcore/_sync/http11.py b/httpcore/_sync/http11.py index a74ff8e80..eecfd33cc 100644 --- a/httpcore/_sync/http11.py +++ b/httpcore/_sync/http11.py @@ -1,5 +1,6 @@ import enum import logging +import random import ssl import time from types import TracebackType @@ -56,10 +57,12 @@ def __init__( origin: Origin, stream: NetworkStream, keepalive_expiry: Optional[float] = None, + socket_poll_interval_between: Tuple[float, float] = (1, 3), ) -> None: self._origin = origin self._network_stream = stream - self._keepalive_expiry: Optional[float] = keepalive_expiry + self._keepalive_expiry = keepalive_expiry + self._socket_poll_interval_between = socket_poll_interval_between self._expire_at: Optional[float] = None self._state = HTTPConnectionState.NEW self._state_lock = Lock() @@ -68,6 +71,8 @@ def __init__( our_role=h11.CLIENT, max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE, ) + # Assuming we were just connected + self._network_stream_used_at = time.monotonic() def handle_request(self, request: Request) -> Response: if not self.can_handle_request(request.url.origin): @@ -173,6 +178,7 @@ def _send_event( bytes_to_send = self._h11_state.send(event) if bytes_to_send is not None: self._network_stream.write(bytes_to_send, timeout=timeout) + self._network_stream_used_at = time.monotonic() # Receiving the response... @@ -224,6 +230,7 @@ def _receive_event( data = self._network_stream.read( self.READ_NUM_BYTES, timeout=timeout ) + self._network_stream_used_at = time.monotonic() # If we feed this case through h11 we'll raise an exception like: # @@ -281,16 +288,28 @@ def is_available(self) -> bool: def has_expired(self) -> bool: now = time.monotonic() keepalive_expired = self._expire_at is not None and now > self._expire_at + if keepalive_expired: + return True # If the HTTP connection is idle but the socket is readable, then the # only valid state is that the socket is about to return b"", indicating # a server-initiated disconnect. - server_disconnected = ( - self._state == HTTPConnectionState.IDLE - and self._network_stream.get_extra_info("is_readable") - ) + # Checking the readable status is relatively expensive so check it at a lower frequency. + if (now - self._network_stream_used_at) > self._socket_poll_interval(): + self._network_stream_used_at = now + server_disconnected = ( + self._state == HTTPConnectionState.IDLE + and self._network_stream.get_extra_info("is_readable") + ) + if server_disconnected: + return True + + return False - return keepalive_expired or server_disconnected + def _socket_poll_interval(self) -> float: + # Randomize to avoid polling for all the connections at once + low, high = self._socket_poll_interval_between + return random.uniform(low, high) def is_idle(self) -> bool: return self._state == HTTPConnectionState.IDLE diff --git a/tests/_async/test_http11.py b/tests/_async/test_http11.py index 94f2febf0..cb275e1af 100644 --- a/tests/_async/test_http11.py +++ b/tests/_async/test_http11.py @@ -1,3 +1,5 @@ +from typing import Any, List + import pytest import httpcore @@ -16,7 +18,10 @@ async def test_http11_connection(): ] ) async with httpcore.AsyncHTTP11Connection( - origin=origin, stream=stream, keepalive_expiry=5.0 + origin=origin, + stream=stream, + keepalive_expiry=5.0, + socket_poll_interval_between=(0, 0), ) as conn: response = await conn.request("GET", "https://example.com/") assert response.status == 200 @@ -48,7 +53,9 @@ async def test_http11_connection_unread_response(): b"Hello, world!", ] ) - async with httpcore.AsyncHTTP11Connection(origin=origin, stream=stream) as conn: + async with httpcore.AsyncHTTP11Connection( + origin=origin, stream=stream, socket_poll_interval_between=(0, 0) + ) as conn: async with conn.stream("GET", "https://example.com/") as response: assert response.status == 200 @@ -70,7 +77,9 @@ async def test_http11_connection_with_remote_protocol_error(): """ origin = httpcore.Origin(b"https", b"example.com", 443) stream = httpcore.AsyncMockStream([b"Wait, this isn't valid HTTP!", b""]) - async with httpcore.AsyncHTTP11Connection(origin=origin, stream=stream) as conn: + async with httpcore.AsyncHTTP11Connection( + origin=origin, stream=stream, socket_poll_interval_between=(0, 0) + ) as conn: with pytest.raises(httpcore.RemoteProtocolError): await conn.request("GET", "https://example.com/") @@ -99,7 +108,9 @@ async def test_http11_connection_with_incomplete_response(): b"Hello, wor", ] ) - async with httpcore.AsyncHTTP11Connection(origin=origin, stream=stream) as conn: + async with httpcore.AsyncHTTP11Connection( + origin=origin, stream=stream, socket_poll_interval_between=(0, 0) + ) as conn: with pytest.raises(httpcore.RemoteProtocolError): await conn.request("GET", "https://example.com/") @@ -129,7 +140,9 @@ async def test_http11_connection_with_local_protocol_error(): b"Hello, world!", ] ) - async with httpcore.AsyncHTTP11Connection(origin=origin, stream=stream) as conn: + async with httpcore.AsyncHTTP11Connection( + origin=origin, stream=stream, socket_poll_interval_between=(0, 0) + ) as conn: with pytest.raises(httpcore.LocalProtocolError) as exc_info: await conn.request("GET", "https://example.com/", headers={"Host": "\0"}) @@ -145,6 +158,85 @@ async def test_http11_connection_with_local_protocol_error(): ) +@pytest.mark.anyio +async def test_http11_has_expired_checks_readable_status(): + class AsyncMockStreamReadable(httpcore.AsyncMockStream): + def __init__(self, buffer: List[bytes]) -> None: + super().__init__(buffer) + self.is_readable = False + self.checks = 0 + + def get_extra_info(self, info: str) -> Any: + if info == "is_readable": + self.checks += 1 + return self.is_readable + return super().get_extra_info(info) # pragma: nocover + + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = AsyncMockStreamReadable( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + async with httpcore.AsyncHTTP11Connection( + origin=origin, stream=stream, socket_poll_interval_between=(0, 0) + ) as conn: + response = await conn.request("GET", "https://example.com/") + assert response.status == 200 + + assert stream.checks == 0 + assert not conn.has_expired() + stream.is_readable = True + assert conn.has_expired() + assert stream.checks == 2 + + +@pytest.mark.anyio +@pytest.mark.parametrize("should_check", [True, False]) +async def test_http11_has_expired_checks_readable_status_by_interval( + monkeypatch, should_check +): + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.AsyncMockStream( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + async with httpcore.AsyncHTTP11Connection( + origin=origin, + stream=stream, + keepalive_expiry=5.0, + socket_poll_interval_between=(0, 0) if should_check else (999, 999), + ) as conn: + orig = conn._network_stream.get_extra_info + calls = [] + + def patch_get_extra_info(attr_name: str) -> Any: + calls.append(attr_name) + return orig(attr_name) + + monkeypatch.setattr( + conn._network_stream, "get_extra_info", patch_get_extra_info + ) + + response = await conn.request("GET", "https://example.com/") + assert response.status == 200 + + assert "is_readable" not in calls + assert not conn.has_expired() + assert ( + ("is_readable" in calls) if should_check else ("is_readable" not in calls) + ) + + @pytest.mark.anyio async def test_http11_connection_handles_one_active_request(): """ diff --git a/tests/_sync/test_http11.py b/tests/_sync/test_http11.py index f2fa28f4c..a870865b6 100644 --- a/tests/_sync/test_http11.py +++ b/tests/_sync/test_http11.py @@ -1,3 +1,5 @@ +from typing import Any, List + import pytest import httpcore @@ -16,7 +18,10 @@ def test_http11_connection(): ] ) with httpcore.HTTP11Connection( - origin=origin, stream=stream, keepalive_expiry=5.0 + origin=origin, + stream=stream, + keepalive_expiry=5.0, + socket_poll_interval_between=(0, 0), ) as conn: response = conn.request("GET", "https://example.com/") assert response.status == 200 @@ -48,7 +53,9 @@ def test_http11_connection_unread_response(): b"Hello, world!", ] ) - with httpcore.HTTP11Connection(origin=origin, stream=stream) as conn: + with httpcore.HTTP11Connection( + origin=origin, stream=stream, socket_poll_interval_between=(0, 0) + ) as conn: with conn.stream("GET", "https://example.com/") as response: assert response.status == 200 @@ -70,7 +77,9 @@ def test_http11_connection_with_remote_protocol_error(): """ origin = httpcore.Origin(b"https", b"example.com", 443) stream = httpcore.MockStream([b"Wait, this isn't valid HTTP!", b""]) - with httpcore.HTTP11Connection(origin=origin, stream=stream) as conn: + with httpcore.HTTP11Connection( + origin=origin, stream=stream, socket_poll_interval_between=(0, 0) + ) as conn: with pytest.raises(httpcore.RemoteProtocolError): conn.request("GET", "https://example.com/") @@ -99,7 +108,9 @@ def test_http11_connection_with_incomplete_response(): b"Hello, wor", ] ) - with httpcore.HTTP11Connection(origin=origin, stream=stream) as conn: + with httpcore.HTTP11Connection( + origin=origin, stream=stream, socket_poll_interval_between=(0, 0) + ) as conn: with pytest.raises(httpcore.RemoteProtocolError): conn.request("GET", "https://example.com/") @@ -129,7 +140,9 @@ def test_http11_connection_with_local_protocol_error(): b"Hello, world!", ] ) - with httpcore.HTTP11Connection(origin=origin, stream=stream) as conn: + with httpcore.HTTP11Connection( + origin=origin, stream=stream, socket_poll_interval_between=(0, 0) + ) as conn: with pytest.raises(httpcore.LocalProtocolError) as exc_info: conn.request("GET", "https://example.com/", headers={"Host": "\0"}) @@ -146,6 +159,85 @@ def test_http11_connection_with_local_protocol_error(): +def test_http11_has_expired_checks_readable_status(): + class MockStreamReadable(httpcore.MockStream): + def __init__(self, buffer: List[bytes]) -> None: + super().__init__(buffer) + self.is_readable = False + self.checks = 0 + + def get_extra_info(self, info: str) -> Any: + if info == "is_readable": + self.checks += 1 + return self.is_readable + return super().get_extra_info(info) # pragma: nocover + + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = MockStreamReadable( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + with httpcore.HTTP11Connection( + origin=origin, stream=stream, socket_poll_interval_between=(0, 0) + ) as conn: + response = conn.request("GET", "https://example.com/") + assert response.status == 200 + + assert stream.checks == 0 + assert not conn.has_expired() + stream.is_readable = True + assert conn.has_expired() + assert stream.checks == 2 + + + +@pytest.mark.parametrize("should_check", [True, False]) +def test_http11_has_expired_checks_readable_status_by_interval( + monkeypatch, should_check +): + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.MockStream( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + with httpcore.HTTP11Connection( + origin=origin, + stream=stream, + keepalive_expiry=5.0, + socket_poll_interval_between=(0, 0) if should_check else (999, 999), + ) as conn: + orig = conn._network_stream.get_extra_info + calls = [] + + def patch_get_extra_info(attr_name: str) -> Any: + calls.append(attr_name) + return orig(attr_name) + + monkeypatch.setattr( + conn._network_stream, "get_extra_info", patch_get_extra_info + ) + + response = conn.request("GET", "https://example.com/") + assert response.status == 200 + + assert "is_readable" not in calls + assert not conn.has_expired() + assert ( + ("is_readable" in calls) if should_check else ("is_readable" not in calls) + ) + + + def test_http11_connection_handles_one_active_request(): """ Attempting to send a request while one is already in-flight will raise