Skip to content

Commit 4f2fc4c

Browse files
jrvb-rl and claude authored
feat: share HTTP connection pool across SDK instances; refactor polling (#797)
Co-authored-by: Claude Opus 4.6 <[email protected]>
1 parent 7656890 commit 4f2fc4c

9 files changed

Lines changed: 736 additions & 169 deletions

File tree

src/runloop_api_client/_base_client.py

Lines changed: 220 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
import asyncio
99
import inspect
1010
import logging
11+
import weakref
1112
import platform
1213
import warnings
14+
import threading
1315
import email.utils
1416
from types import TracebackType
1517
from random import random
@@ -90,6 +92,88 @@
9092

9193
log: logging.Logger = logging.getLogger(__name__)
9294

95+
# Shared HTTP transport state. We share transports (connection pools) rather
96+
# than full httpx clients so each SDK instance keeps its own cookie jar and
97+
# mutable client state. Refcounted wrappers close the real transport only
98+
# when the last user releases it.
99+
# The async transport is keyed by event loop because connections bind to the
100+
# loop that created them and cannot be reused across asyncio.run() calls.
101+
_pool_lock = threading.Lock()
102+
103+
104+
class _SharedTransport(httpx.BaseTransport):
    """Refcounted wrapper around a real synchronous transport.

    Requests are delegated unchanged to the wrapped transport. The wrapper
    starts with a reference count of 1 (held by whoever constructed it);
    each additional user must call :meth:`acquire`, and every user releases
    its reference by calling :meth:`close`. The wrapped transport is closed
    exactly once, when the last reference is released.
    """

    def __init__(self, transport: httpx.BaseTransport) -> None:
        """Wrap *transport*; the creator implicitly holds the first reference."""
        self._transport = transport
        self._refcount = 1
        # Guards every read-modify-write of ``_refcount``.
        self._lock = threading.Lock()

    @property
    def refcount(self) -> int:
        """Current number of outstanding references (unlocked snapshot)."""
        return self._refcount

    def acquire(self) -> bool:
        """Take an additional reference.

        Returns:
            ``True`` if the reference was taken, ``False`` if the wrapper has
            already been fully released and must not be reused.
        """
        with self._lock:
            if self._refcount <= 0:
                return False
            self._refcount += 1
            return True

    @override
    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """Forward *request* to the wrapped transport and return its response."""
        return self._transport.handle_request(request)

    @override
    def close(self) -> None:
        """Release one reference; close the wrapped transport on the last one.

        Releasing an already fully-released wrapper is a no-op, so the count
        never goes negative and the wrapped transport is never closed twice.
        """
        with self._lock:
            if self._refcount <= 0:
                # Over-release: nothing left to give back.
                return
            self._refcount -= 1
            should_close = self._refcount == 0
        # Close outside the lock so acquire() callers are not blocked on it.
        if should_close:
            self._transport.close()
136+
137+
138+
class _SharedAsyncTransport(httpx.AsyncBaseTransport):
    """Refcounted wrapper around a real asynchronous transport.

    Requests are delegated unchanged to the wrapped transport. The wrapper
    starts with a reference count of 1 (held by whoever constructed it);
    each additional user must call :meth:`acquire`, and every user releases
    its reference by awaiting :meth:`aclose`. The wrapped transport is closed
    exactly once, when the last reference is released.
    """

    def __init__(self, transport: httpx.AsyncBaseTransport) -> None:
        """Wrap *transport*; the creator implicitly holds the first reference."""
        self._transport = transport
        self._refcount = 1
        # Guards every read-modify-write of ``_refcount``. It is never held
        # across an ``await``.
        self._lock = threading.Lock()

    @property
    def refcount(self) -> int:
        """Current number of outstanding references (unlocked snapshot)."""
        return self._refcount

    def acquire(self) -> bool:
        """Take an additional reference.

        Returns:
            ``True`` if the reference was taken, ``False`` if the wrapper has
            already been fully released and must not be reused.
        """
        with self._lock:
            if self._refcount <= 0:
                return False
            self._refcount += 1
            return True

    @override
    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Forward *request* to the wrapped transport and return its response."""
        return await self._transport.handle_async_request(request)

    @override
    async def aclose(self) -> None:
        """Release one reference; close the wrapped transport on the last one.

        Releasing an already fully-released wrapper is a no-op, so the count
        never goes negative and the wrapped transport is never closed twice.
        """
        with self._lock:
            if self._refcount <= 0:
                # Over-release: nothing left to give back.
                return
            self._refcount -= 1
            should_close = self._refcount == 0
        # Await outside the lock: the lock must not be held across an await.
        if should_close:
            await self._transport.aclose()
170+
171+
172+
_shared_sync_transport: _SharedTransport | None = None
173+
_shared_async_transports: weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, _SharedAsyncTransport] = (
174+
weakref.WeakKeyDictionary()
175+
)
176+
93177
# TODO: make base page type vars covariant
94178
SyncPageT = TypeVar("SyncPageT", bound="BaseSyncPage[Any]")
95179
AsyncPageT = TypeVar("AsyncPageT", bound="BaseAsyncPage[Any]")
@@ -816,6 +900,7 @@ def __init__(self, **kwargs: Any) -> None:
816900
kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
817901
kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS)
818902
kwargs.setdefault("follow_redirects", True)
903+
kwargs.setdefault("http2", True)
819904
super().__init__(**kwargs)
820905

821906

@@ -845,6 +930,8 @@ def __del__(self) -> None:
845930
class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
846931
_client: httpx.Client
847932
_default_stream_cls: type[Stream[Any]] | None = None
933+
_uses_shared_pool: bool
934+
_closed: bool
848935

849936
def __init__(
850937
self,
@@ -857,6 +944,7 @@ def __init__(
857944
custom_headers: Mapping[str, str] | None = None,
858945
custom_query: Mapping[str, object] | None = None,
859946
_strict_response_validation: bool,
947+
shared_http_pool: bool = True,
860948
) -> None:
861949
if not is_given(timeout):
862950
# if the user passed in a custom http client with a non-default
@@ -886,24 +974,46 @@ def __init__(
886974
custom_headers=custom_headers,
887975
_strict_response_validation=_strict_response_validation,
888976
)
889-
self._client = http_client or SyncHttpxClientWrapper(
890-
base_url=base_url,
891-
# cast to a valid type because mypy doesn't understand our type narrowing
892-
timeout=cast(Timeout, timeout),
893-
)
977+
978+
self._closed = False
979+
980+
if http_client is not None:
981+
self._client = http_client
982+
self._uses_shared_pool = False
983+
elif shared_http_pool:
984+
global _shared_sync_transport
985+
with _pool_lock:
986+
if _shared_sync_transport is None or not _shared_sync_transport.acquire():
987+
_shared_sync_transport = _SharedTransport(
988+
httpx.HTTPTransport(limits=DEFAULT_CONNECTION_LIMITS, http2=True),
989+
)
990+
self._client = SyncHttpxClientWrapper(
991+
base_url=base_url,
992+
timeout=cast(Timeout, timeout),
993+
transport=_shared_sync_transport,
994+
)
995+
self._uses_shared_pool = True
996+
else:
997+
self._client = SyncHttpxClientWrapper(
998+
base_url=base_url,
999+
timeout=cast(Timeout, timeout),
1000+
)
1001+
self._uses_shared_pool = False
8941002

8951003
def is_closed(self) -> bool:
896-
return self._client.is_closed
1004+
return self._closed or self._client.is_closed
8971005

8981006
def close(self) -> None:
8991007
"""Close the underlying HTTPX client.
9001008
9011009
The client will *not* be usable after this.
9021010
"""
903-
# If an error is thrown while constructing a client, self._client
904-
# may not be present
905-
if hasattr(self, "_client"):
906-
self._client.close()
1011+
if not hasattr(self, "_client"):
1012+
return
1013+
if self._closed:
1014+
return
1015+
self._closed = True
1016+
self._client.close()
9071017

9081018
def __enter__(self: _T) -> _T:
9091019
return self
@@ -1018,6 +1128,7 @@ def request(
10181128
max_retries=max_retries,
10191129
options=input_options,
10201130
response=None,
1131+
error=err,
10211132
)
10221133
continue
10231134

@@ -1032,6 +1143,7 @@ def request(
10321143
max_retries=max_retries,
10331144
options=input_options,
10341145
response=None,
1146+
error=err,
10351147
)
10361148
continue
10371149

@@ -1083,7 +1195,13 @@ def request(
10831195
)
10841196

10851197
def _sleep_for_retry(
1086-
self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None
1198+
self,
1199+
*,
1200+
retries_taken: int,
1201+
max_retries: int,
1202+
options: FinalRequestOptions,
1203+
response: httpx.Response | None,
1204+
error: BaseException | None = None,
10871205
) -> None:
10881206
remaining_retries = max_retries - retries_taken
10891207
if remaining_retries == 1:
@@ -1092,7 +1210,23 @@ def _sleep_for_retry(
10921210
log.debug("%i retries left", remaining_retries)
10931211

10941212
timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None)
1095-
log.info("Retrying request to %s in %f seconds", options.url, timeout)
1213+
if response is not None:
1214+
log.info(
1215+
"Retrying request to %s in %f seconds (status %d)",
1216+
options.url,
1217+
timeout,
1218+
response.status_code,
1219+
)
1220+
elif error is not None:
1221+
log.info(
1222+
"Retrying request to %s in %f seconds (%s: %s)",
1223+
options.url,
1224+
timeout,
1225+
type(error).__name__,
1226+
error,
1227+
)
1228+
else:
1229+
log.info("Retrying request to %s in %f seconds", options.url, timeout)
10961230

10971231
time.sleep(timeout)
10981232

@@ -1428,6 +1562,8 @@ def __del__(self) -> None:
14281562
class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
14291563
_client: httpx.AsyncClient
14301564
_default_stream_cls: type[AsyncStream[Any]] | None = None
1565+
_uses_shared_pool: bool
1566+
_closed: bool
14311567

14321568
def __init__(
14331569
self,
@@ -1440,6 +1576,7 @@ def __init__(
14401576
http_client: httpx.AsyncClient | None = None,
14411577
custom_headers: Mapping[str, str] | None = None,
14421578
custom_query: Mapping[str, object] | None = None,
1579+
shared_http_pool: bool = True,
14431580
) -> None:
14441581
if not is_given(timeout):
14451582
# if the user passed in a custom http client with a non-default
@@ -1469,20 +1606,59 @@ def __init__(
14691606
custom_headers=custom_headers,
14701607
_strict_response_validation=_strict_response_validation,
14711608
)
1472-
self._client = http_client or AsyncHttpxClientWrapper(
1473-
base_url=base_url,
1474-
# cast to a valid type because mypy doesn't understand our type narrowing
1475-
timeout=cast(Timeout, timeout),
1476-
)
1609+
1610+
self._closed = False
1611+
1612+
if http_client is not None:
1613+
self._client = http_client
1614+
self._uses_shared_pool = False
1615+
elif shared_http_pool:
1616+
try:
1617+
loop: asyncio.AbstractEventLoop | None = asyncio.get_running_loop()
1618+
except RuntimeError:
1619+
loop = None
1620+
if loop is not None:
1621+
with _pool_lock:
1622+
existing = _shared_async_transports.get(loop)
1623+
if existing is not None and existing.acquire():
1624+
transport: _SharedAsyncTransport = existing
1625+
else:
1626+
transport = _SharedAsyncTransport(
1627+
httpx.AsyncHTTPTransport(limits=DEFAULT_CONNECTION_LIMITS, http2=True),
1628+
)
1629+
_shared_async_transports[loop] = transport
1630+
self._client = AsyncHttpxClientWrapper(
1631+
base_url=base_url,
1632+
timeout=cast(Timeout, timeout),
1633+
transport=transport,
1634+
)
1635+
self._uses_shared_pool = True
1636+
else:
1637+
self._client = AsyncHttpxClientWrapper(
1638+
base_url=base_url,
1639+
timeout=cast(Timeout, timeout),
1640+
)
1641+
self._uses_shared_pool = False
1642+
else:
1643+
self._client = AsyncHttpxClientWrapper(
1644+
base_url=base_url,
1645+
timeout=cast(Timeout, timeout),
1646+
)
1647+
self._uses_shared_pool = False
14771648

14781649
def is_closed(self) -> bool:
1479-
return self._client.is_closed
1650+
return self._closed or self._client.is_closed
14801651

14811652
async def close(self) -> None:
14821653
"""Close the underlying HTTPX client.
14831654
14841655
The client will *not* be usable after this.
14851656
"""
1657+
if not hasattr(self, "_client"):
1658+
return
1659+
if self._closed:
1660+
return
1661+
self._closed = True
14861662
await self._client.aclose()
14871663

14881664
async def __aenter__(self: _T) -> _T:
@@ -1603,6 +1779,7 @@ async def request(
16031779
max_retries=max_retries,
16041780
options=input_options,
16051781
response=None,
1782+
error=err,
16061783
)
16071784
continue
16081785

@@ -1617,6 +1794,7 @@ async def request(
16171794
max_retries=max_retries,
16181795
options=input_options,
16191796
response=None,
1797+
error=err,
16201798
)
16211799
continue
16221800

@@ -1668,7 +1846,13 @@ async def request(
16681846
)
16691847

16701848
async def _sleep_for_retry(
1671-
self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None
1849+
self,
1850+
*,
1851+
retries_taken: int,
1852+
max_retries: int,
1853+
options: FinalRequestOptions,
1854+
response: httpx.Response | None,
1855+
error: BaseException | None = None,
16721856
) -> None:
16731857
remaining_retries = max_retries - retries_taken
16741858
if remaining_retries == 1:
@@ -1677,7 +1861,23 @@ async def _sleep_for_retry(
16771861
log.debug("%i retries left", remaining_retries)
16781862

16791863
timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None)
1680-
log.info("Retrying request to %s in %f seconds", options.url, timeout)
1864+
if response is not None:
1865+
log.info(
1866+
"Retrying request to %s in %f seconds (status %d)",
1867+
options.url,
1868+
timeout,
1869+
response.status_code,
1870+
)
1871+
elif error is not None:
1872+
log.info(
1873+
"Retrying request to %s in %f seconds (%s: %s)",
1874+
options.url,
1875+
timeout,
1876+
type(error).__name__,
1877+
error,
1878+
)
1879+
else:
1880+
log.info("Retrying request to %s in %f seconds", options.url, timeout)
16811881

16821882
await anyio.sleep(timeout)
16831883

0 commit comments

Comments
 (0)