From b6b119c42f0af5c9bd935efa93da77bf114495ba Mon Sep 17 00:00:00 2001
From: Markus Sintonen <markus.sintonen@gmail.com>
Date: Mon, 10 Jun 2024 21:32:07 +0300
Subject: [PATCH] Optimize connection pool

---
 httpcore/_async/connection_pool.py |  40 +++++------
 httpcore/_async/http11.py          |  31 +++++++--
 httpcore/_sync/connection_pool.py  |  40 +++++------
 httpcore/_sync/http11.py           |  31 +++++++--
 tests/_async/test_http11.py        | 102 +++++++++++++++++++++++++++--
 tests/_sync/test_http11.py         | 102 +++++++++++++++++++++++++++--
 6 files changed, 286 insertions(+), 60 deletions(-)

diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py
index 214dfc4be..5b603fd11 100644
--- a/httpcore/_async/connection_pool.py
+++ b/httpcore/_async/connection_pool.py
@@ -238,6 +238,7 @@ def _assign_requests_to_connections(self) -> List[AsyncConnectionInterface]:
         those connections to be handled seperately.
         """
         closing_connections = []
+        idling_count = 0
 
         # First we handle cleaning up any connections that are closed,
         # have expired their keep-alive, or surplus idle connections.
@@ -249,27 +250,25 @@ def _assign_requests_to_connections(self) -> List[AsyncConnectionInterface]:
                 # log: "closing expired connection"
                 self._connections.remove(connection)
                 closing_connections.append(connection)
-            elif (
-                connection.is_idle()
-                and len([connection.is_idle() for connection in self._connections])
-                > self._max_keepalive_connections
-            ):
+            elif connection.is_idle():
+                if idling_count < self._max_keepalive_connections:
+                    idling_count += 1
+                    continue
                 # log: "closing idle connection"
                 self._connections.remove(connection)
                 closing_connections.append(connection)
 
         # Assign queued requests to connections.
-        queued_requests = [request for request in self._requests if request.is_queued()]
-        for pool_request in queued_requests:
+        for pool_request in list(self._requests):
+            if not pool_request.is_queued():
+                continue
+
             origin = pool_request.request.url.origin
             available_connections = [
                 connection
                 for connection in self._connections
                 if connection.can_handle_request(origin) and connection.is_available()
             ]
-            idle_connections = [
-                connection for connection in self._connections if connection.is_idle()
-            ]
 
             # There are three cases for how we may be able to handle the request:
             #
@@ -286,15 +285,18 @@ def _assign_requests_to_connections(self) -> List[AsyncConnectionInterface]:
                 connection = self.create_connection(origin)
                 self._connections.append(connection)
                 pool_request.assign_to_connection(connection)
-            elif idle_connections:
-                # log: "closing idle connection"
-                connection = idle_connections[0]
-                self._connections.remove(connection)
-                closing_connections.append(connection)
-                # log: "creating new connection"
-                connection = self.create_connection(origin)
-                self._connections.append(connection)
-                pool_request.assign_to_connection(connection)
+            else:
+                idling_connection = next(
+                    (c for c in self._connections if c.is_idle()), None
+                )
+                if idling_connection is not None:
+                    # log: "closing idle connection"
+                    self._connections.remove(idling_connection)
+                    closing_connections.append(idling_connection)
+                    # log: "creating new connection"
+                    new_connection = self.create_connection(origin)
+                    self._connections.append(new_connection)
+                    pool_request.assign_to_connection(new_connection)
 
         return closing_connections
 
diff --git a/httpcore/_async/http11.py b/httpcore/_async/http11.py
index 0493a923d..38f4a292d 100644
--- a/httpcore/_async/http11.py
+++ b/httpcore/_async/http11.py
@@ -1,5 +1,6 @@
 import enum
 import logging
+import random
 import ssl
 import time
 from types import TracebackType
@@ -56,10 +57,12 @@ def __init__(
         origin: Origin,
         stream: AsyncNetworkStream,
         keepalive_expiry: Optional[float] = None,
+        socket_poll_interval_between: Tuple[float, float] = (1, 3),
     ) -> None:
         self._origin = origin
         self._network_stream = stream
-        self._keepalive_expiry: Optional[float] = keepalive_expiry
+        self._keepalive_expiry = keepalive_expiry
+        self._socket_poll_interval_between = socket_poll_interval_between
         self._expire_at: Optional[float] = None
         self._state = HTTPConnectionState.NEW
         self._state_lock = AsyncLock()
@@ -68,6 +71,8 @@ def __init__(
             our_role=h11.CLIENT,
             max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE,
         )
+        # Assuming we were just connected
+        self._network_stream_used_at = time.monotonic()
 
     async def handle_async_request(self, request: Request) -> Response:
         if not self.can_handle_request(request.url.origin):
@@ -173,6 +178,7 @@ async def _send_event(
         bytes_to_send = self._h11_state.send(event)
         if bytes_to_send is not None:
             await self._network_stream.write(bytes_to_send, timeout=timeout)
+            self._network_stream_used_at = time.monotonic()
 
     # Receiving the response...
 
@@ -224,6 +230,7 @@ async def _receive_event(
                 data = await self._network_stream.read(
                     self.READ_NUM_BYTES, timeout=timeout
                 )
+                self._network_stream_used_at = time.monotonic()
 
                 # If we feed this case through h11 we'll raise an exception like:
                 #
@@ -281,16 +288,28 @@ def is_available(self) -> bool:
     def has_expired(self) -> bool:
         now = time.monotonic()
         keepalive_expired = self._expire_at is not None and now > self._expire_at
+        if keepalive_expired:
+            return True
 
         # If the HTTP connection is idle but the socket is readable, then the
         # only valid state is that the socket is about to return b"", indicating
         # a server-initiated disconnect.
-        server_disconnected = (
-            self._state == HTTPConnectionState.IDLE
-            and self._network_stream.get_extra_info("is_readable")
-        )
+        # Checking the readable status is relatively expensive so check it at a lower frequency.
+        if (now - self._network_stream_used_at) > self._socket_poll_interval():
+            self._network_stream_used_at = now
+            server_disconnected = (
+                self._state == HTTPConnectionState.IDLE
+                and self._network_stream.get_extra_info("is_readable")
+            )
+            if server_disconnected:
+                return True
+
+        return False
 
-        return keepalive_expired or server_disconnected
+    def _socket_poll_interval(self) -> float:
+        # Randomize to avoid polling for all the connections at once
+        low, high = self._socket_poll_interval_between
+        return random.uniform(low, high)
 
     def is_idle(self) -> bool:
         return self._state == HTTPConnectionState.IDLE
diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py
index 01bec59e8..7d7a9156b 100644
--- a/httpcore/_sync/connection_pool.py
+++ b/httpcore/_sync/connection_pool.py
@@ -238,6 +238,7 @@ def _assign_requests_to_connections(self) -> List[ConnectionInterface]:
         those connections to be handled seperately.
         """
         closing_connections = []
+        idling_count = 0
 
         # First we handle cleaning up any connections that are closed,
         # have expired their keep-alive, or surplus idle connections.
@@ -249,27 +250,25 @@ def _assign_requests_to_connections(self) -> List[ConnectionInterface]:
                 # log: "closing expired connection"
                 self._connections.remove(connection)
                 closing_connections.append(connection)
-            elif (
-                connection.is_idle()
-                and len([connection.is_idle() for connection in self._connections])
-                > self._max_keepalive_connections
-            ):
+            elif connection.is_idle():
+                if idling_count < self._max_keepalive_connections:
+                    idling_count += 1
+                    continue
                 # log: "closing idle connection"
                 self._connections.remove(connection)
                 closing_connections.append(connection)
 
         # Assign queued requests to connections.
-        queued_requests = [request for request in self._requests if request.is_queued()]
-        for pool_request in queued_requests:
+        for pool_request in list(self._requests):
+            if not pool_request.is_queued():
+                continue
+
             origin = pool_request.request.url.origin
             available_connections = [
                 connection
                 for connection in self._connections
                 if connection.can_handle_request(origin) and connection.is_available()
             ]
-            idle_connections = [
-                connection for connection in self._connections if connection.is_idle()
-            ]
 
             # There are three cases for how we may be able to handle the request:
             #
@@ -286,15 +285,18 @@ def _assign_requests_to_connections(self) -> List[ConnectionInterface]:
                 connection = self.create_connection(origin)
                 self._connections.append(connection)
                 pool_request.assign_to_connection(connection)
-            elif idle_connections:
-                # log: "closing idle connection"
-                connection = idle_connections[0]
-                self._connections.remove(connection)
-                closing_connections.append(connection)
-                # log: "creating new connection"
-                connection = self.create_connection(origin)
-                self._connections.append(connection)
-                pool_request.assign_to_connection(connection)
+            else:
+                idling_connection = next(
+                    (c for c in self._connections if c.is_idle()), None
+                )
+                if idling_connection is not None:
+                    # log: "closing idle connection"
+                    self._connections.remove(idling_connection)
+                    closing_connections.append(idling_connection)
+                    # log: "creating new connection"
+                    new_connection = self.create_connection(origin)
+                    self._connections.append(new_connection)
+                    pool_request.assign_to_connection(new_connection)
 
         return closing_connections
 
diff --git a/httpcore/_sync/http11.py b/httpcore/_sync/http11.py
index a74ff8e80..eecfd33cc 100644
--- a/httpcore/_sync/http11.py
+++ b/httpcore/_sync/http11.py
@@ -1,5 +1,6 @@
 import enum
 import logging
+import random
 import ssl
 import time
 from types import TracebackType
@@ -56,10 +57,12 @@ def __init__(
         origin: Origin,
         stream: NetworkStream,
         keepalive_expiry: Optional[float] = None,
+        socket_poll_interval_between: Tuple[float, float] = (1, 3),
     ) -> None:
         self._origin = origin
         self._network_stream = stream
-        self._keepalive_expiry: Optional[float] = keepalive_expiry
+        self._keepalive_expiry = keepalive_expiry
+        self._socket_poll_interval_between = socket_poll_interval_between
         self._expire_at: Optional[float] = None
         self._state = HTTPConnectionState.NEW
         self._state_lock = Lock()
@@ -68,6 +71,8 @@ def __init__(
             our_role=h11.CLIENT,
             max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE,
         )
+        # Assuming we were just connected
+        self._network_stream_used_at = time.monotonic()
 
     def handle_request(self, request: Request) -> Response:
         if not self.can_handle_request(request.url.origin):
@@ -173,6 +178,7 @@ def _send_event(
         bytes_to_send = self._h11_state.send(event)
         if bytes_to_send is not None:
             self._network_stream.write(bytes_to_send, timeout=timeout)
+            self._network_stream_used_at = time.monotonic()
 
     # Receiving the response...
 
@@ -224,6 +230,7 @@ def _receive_event(
                 data = self._network_stream.read(
                     self.READ_NUM_BYTES, timeout=timeout
                 )
+                self._network_stream_used_at = time.monotonic()
 
                 # If we feed this case through h11 we'll raise an exception like:
                 #
@@ -281,16 +288,28 @@ def is_available(self) -> bool:
     def has_expired(self) -> bool:
         now = time.monotonic()
         keepalive_expired = self._expire_at is not None and now > self._expire_at
+        if keepalive_expired:
+            return True
 
         # If the HTTP connection is idle but the socket is readable, then the
         # only valid state is that the socket is about to return b"", indicating
         # a server-initiated disconnect.
-        server_disconnected = (
-            self._state == HTTPConnectionState.IDLE
-            and self._network_stream.get_extra_info("is_readable")
-        )
+        # Checking the readable status is relatively expensive so check it at a lower frequency.
+        if (now - self._network_stream_used_at) > self._socket_poll_interval():
+            self._network_stream_used_at = now
+            server_disconnected = (
+                self._state == HTTPConnectionState.IDLE
+                and self._network_stream.get_extra_info("is_readable")
+            )
+            if server_disconnected:
+                return True
+
+        return False
 
-        return keepalive_expired or server_disconnected
+    def _socket_poll_interval(self) -> float:
+        # Randomize to avoid polling for all the connections at once
+        low, high = self._socket_poll_interval_between
+        return random.uniform(low, high)
 
     def is_idle(self) -> bool:
         return self._state == HTTPConnectionState.IDLE
diff --git a/tests/_async/test_http11.py b/tests/_async/test_http11.py
index 94f2febf0..cb275e1af 100644
--- a/tests/_async/test_http11.py
+++ b/tests/_async/test_http11.py
@@ -1,3 +1,5 @@
+from typing import Any, List
+
 import pytest
 
 import httpcore
@@ -16,7 +18,10 @@ async def test_http11_connection():
         ]
     )
     async with httpcore.AsyncHTTP11Connection(
-        origin=origin, stream=stream, keepalive_expiry=5.0
+        origin=origin,
+        stream=stream,
+        keepalive_expiry=5.0,
+        socket_poll_interval_between=(0, 0),
     ) as conn:
         response = await conn.request("GET", "https://example.com/")
         assert response.status == 200
@@ -48,7 +53,9 @@ async def test_http11_connection_unread_response():
             b"Hello, world!",
         ]
     )
-    async with httpcore.AsyncHTTP11Connection(origin=origin, stream=stream) as conn:
+    async with httpcore.AsyncHTTP11Connection(
+        origin=origin, stream=stream, socket_poll_interval_between=(0, 0)
+    ) as conn:
         async with conn.stream("GET", "https://example.com/") as response:
             assert response.status == 200
 
@@ -70,7 +77,9 @@ async def test_http11_connection_with_remote_protocol_error():
     """
     origin = httpcore.Origin(b"https", b"example.com", 443)
     stream = httpcore.AsyncMockStream([b"Wait, this isn't valid HTTP!", b""])
-    async with httpcore.AsyncHTTP11Connection(origin=origin, stream=stream) as conn:
+    async with httpcore.AsyncHTTP11Connection(
+        origin=origin, stream=stream, socket_poll_interval_between=(0, 0)
+    ) as conn:
         with pytest.raises(httpcore.RemoteProtocolError):
             await conn.request("GET", "https://example.com/")
 
@@ -99,7 +108,9 @@ async def test_http11_connection_with_incomplete_response():
             b"Hello, wor",
         ]
     )
-    async with httpcore.AsyncHTTP11Connection(origin=origin, stream=stream) as conn:
+    async with httpcore.AsyncHTTP11Connection(
+        origin=origin, stream=stream, socket_poll_interval_between=(0, 0)
+    ) as conn:
         with pytest.raises(httpcore.RemoteProtocolError):
             await conn.request("GET", "https://example.com/")
 
@@ -129,7 +140,9 @@ async def test_http11_connection_with_local_protocol_error():
             b"Hello, world!",
         ]
     )
-    async with httpcore.AsyncHTTP11Connection(origin=origin, stream=stream) as conn:
+    async with httpcore.AsyncHTTP11Connection(
+        origin=origin, stream=stream, socket_poll_interval_between=(0, 0)
+    ) as conn:
         with pytest.raises(httpcore.LocalProtocolError) as exc_info:
             await conn.request("GET", "https://example.com/", headers={"Host": "\0"})
 
@@ -145,6 +158,85 @@ async def test_http11_connection_with_local_protocol_error():
         )
 
 
+@pytest.mark.anyio
+async def test_http11_has_expired_checks_readable_status():
+    class AsyncMockStreamReadable(httpcore.AsyncMockStream):
+        def __init__(self, buffer: List[bytes]) -> None:
+            super().__init__(buffer)
+            self.is_readable = False
+            self.checks = 0
+
+        def get_extra_info(self, info: str) -> Any:
+            if info == "is_readable":
+                self.checks += 1
+                return self.is_readable
+            return super().get_extra_info(info)  # pragma: nocover
+
+    origin = httpcore.Origin(b"https", b"example.com", 443)
+    stream = AsyncMockStreamReadable(
+        [
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+    )
+    async with httpcore.AsyncHTTP11Connection(
+        origin=origin, stream=stream, socket_poll_interval_between=(0, 0)
+    ) as conn:
+        response = await conn.request("GET", "https://example.com/")
+        assert response.status == 200
+
+        assert stream.checks == 0
+        assert not conn.has_expired()
+        stream.is_readable = True
+        assert conn.has_expired()
+        assert stream.checks == 2
+
+
+@pytest.mark.anyio
+@pytest.mark.parametrize("should_check", [True, False])
+async def test_http11_has_expired_checks_readable_status_by_interval(
+    monkeypatch, should_check
+):
+    origin = httpcore.Origin(b"https", b"example.com", 443)
+    stream = httpcore.AsyncMockStream(
+        [
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+    )
+    async with httpcore.AsyncHTTP11Connection(
+        origin=origin,
+        stream=stream,
+        keepalive_expiry=5.0,
+        socket_poll_interval_between=(0, 0) if should_check else (999, 999),
+    ) as conn:
+        orig = conn._network_stream.get_extra_info
+        calls = []
+
+        def patch_get_extra_info(attr_name: str) -> Any:
+            calls.append(attr_name)
+            return orig(attr_name)
+
+        monkeypatch.setattr(
+            conn._network_stream, "get_extra_info", patch_get_extra_info
+        )
+
+        response = await conn.request("GET", "https://example.com/")
+        assert response.status == 200
+
+        assert "is_readable" not in calls
+        assert not conn.has_expired()
+        assert (
+            ("is_readable" in calls) if should_check else ("is_readable" not in calls)
+        )
+
+
 @pytest.mark.anyio
 async def test_http11_connection_handles_one_active_request():
     """
diff --git a/tests/_sync/test_http11.py b/tests/_sync/test_http11.py
index f2fa28f4c..a870865b6 100644
--- a/tests/_sync/test_http11.py
+++ b/tests/_sync/test_http11.py
@@ -1,3 +1,5 @@
+from typing import Any, List
+
 import pytest
 
 import httpcore
@@ -16,7 +18,10 @@ def test_http11_connection():
         ]
     )
     with httpcore.HTTP11Connection(
-        origin=origin, stream=stream, keepalive_expiry=5.0
+        origin=origin,
+        stream=stream,
+        keepalive_expiry=5.0,
+        socket_poll_interval_between=(0, 0),
     ) as conn:
         response = conn.request("GET", "https://example.com/")
         assert response.status == 200
@@ -48,7 +53,9 @@ def test_http11_connection_unread_response():
             b"Hello, world!",
         ]
     )
-    with httpcore.HTTP11Connection(origin=origin, stream=stream) as conn:
+    with httpcore.HTTP11Connection(
+        origin=origin, stream=stream, socket_poll_interval_between=(0, 0)
+    ) as conn:
         with conn.stream("GET", "https://example.com/") as response:
             assert response.status == 200
 
@@ -70,7 +77,9 @@ def test_http11_connection_with_remote_protocol_error():
     """
     origin = httpcore.Origin(b"https", b"example.com", 443)
     stream = httpcore.MockStream([b"Wait, this isn't valid HTTP!", b""])
-    with httpcore.HTTP11Connection(origin=origin, stream=stream) as conn:
+    with httpcore.HTTP11Connection(
+        origin=origin, stream=stream, socket_poll_interval_between=(0, 0)
+    ) as conn:
         with pytest.raises(httpcore.RemoteProtocolError):
             conn.request("GET", "https://example.com/")
 
@@ -99,7 +108,9 @@ def test_http11_connection_with_incomplete_response():
             b"Hello, wor",
         ]
     )
-    with httpcore.HTTP11Connection(origin=origin, stream=stream) as conn:
+    with httpcore.HTTP11Connection(
+        origin=origin, stream=stream, socket_poll_interval_between=(0, 0)
+    ) as conn:
         with pytest.raises(httpcore.RemoteProtocolError):
             conn.request("GET", "https://example.com/")
 
@@ -129,7 +140,9 @@ def test_http11_connection_with_local_protocol_error():
             b"Hello, world!",
         ]
     )
-    with httpcore.HTTP11Connection(origin=origin, stream=stream) as conn:
+    with httpcore.HTTP11Connection(
+        origin=origin, stream=stream, socket_poll_interval_between=(0, 0)
+    ) as conn:
         with pytest.raises(httpcore.LocalProtocolError) as exc_info:
             conn.request("GET", "https://example.com/", headers={"Host": "\0"})
 
@@ -146,6 +159,85 @@ def test_http11_connection_with_local_protocol_error():
 
 
 
+def test_http11_has_expired_checks_readable_status():
+    class MockStreamReadable(httpcore.MockStream):
+        def __init__(self, buffer: List[bytes]) -> None:
+            super().__init__(buffer)
+            self.is_readable = False
+            self.checks = 0
+
+        def get_extra_info(self, info: str) -> Any:
+            if info == "is_readable":
+                self.checks += 1
+                return self.is_readable
+            return super().get_extra_info(info)  # pragma: nocover
+
+    origin = httpcore.Origin(b"https", b"example.com", 443)
+    stream = MockStreamReadable(
+        [
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+    )
+    with httpcore.HTTP11Connection(
+        origin=origin, stream=stream, socket_poll_interval_between=(0, 0)
+    ) as conn:
+        response = conn.request("GET", "https://example.com/")
+        assert response.status == 200
+
+        assert stream.checks == 0
+        assert not conn.has_expired()
+        stream.is_readable = True
+        assert conn.has_expired()
+        assert stream.checks == 2
+
+
+
+@pytest.mark.parametrize("should_check", [True, False])
+def test_http11_has_expired_checks_readable_status_by_interval(
+    monkeypatch, should_check
+):
+    origin = httpcore.Origin(b"https", b"example.com", 443)
+    stream = httpcore.MockStream(
+        [
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+    )
+    with httpcore.HTTP11Connection(
+        origin=origin,
+        stream=stream,
+        keepalive_expiry=5.0,
+        socket_poll_interval_between=(0, 0) if should_check else (999, 999),
+    ) as conn:
+        orig = conn._network_stream.get_extra_info
+        calls = []
+
+        def patch_get_extra_info(attr_name: str) -> Any:
+            calls.append(attr_name)
+            return orig(attr_name)
+
+        monkeypatch.setattr(
+            conn._network_stream, "get_extra_info", patch_get_extra_info
+        )
+
+        response = conn.request("GET", "https://example.com/")
+        assert response.status == 200
+
+        assert "is_readable" not in calls
+        assert not conn.has_expired()
+        assert (
+            ("is_readable" in calls) if should_check else ("is_readable" not in calls)
+        )
+
+
+
 def test_http11_connection_handles_one_active_request():
     """
     Attempting to send a request while one is already in-flight will raise