Skip to content

Commit 45b685d

Browse files
committed
don't share httpx.AsyncHTTPTransport between event loops
1 parent 63af922 commit 45b685d

File tree

1 file changed

+36
-4
lines changed

1 file changed

+36
-4
lines changed

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
from __future__ import annotations as _annotations
88

9+
import asyncio
910
from abc import ABC, abstractmethod
1011
from collections.abc import AsyncIterator, Iterator
1112
from contextlib import asynccontextmanager, contextmanager
1213
from dataclasses import dataclass, field
1314
from datetime import datetime
14-
from functools import cache
15+
from functools import cache, lru_cache
16+
from types import TracebackType
1517

1618
import httpx
1719
from typing_extensions import Literal, TypeAliasType
@@ -506,17 +508,47 @@ def cached_async_http_client(*, provider: str | None = None, timeout: int = 600,
506508
@cache
507509
def _cached_async_http_client(provider: str | None, timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
508510
return httpx.AsyncClient(
509-
transport=_cached_async_http_transport(),
511+
transport=_PerLoopTransport(),
510512
timeout=httpx.Timeout(timeout=timeout, connect=connect),
511513
headers={'User-Agent': get_user_agent()},
512514
)
513515

514516

515-
@cache
516-
def _cached_async_http_transport() -> httpx.AsyncHTTPTransport:
517+
@lru_cache(maxsize=32)
518+
def _get_transport_for_loop(loop: asyncio.AbstractEventLoop) -> httpx.AsyncHTTPTransport:
517519
return httpx.AsyncHTTPTransport()
518520

519521

522+
class _PerLoopTransport(httpx.AsyncBaseTransport):
523+
def get_transport(self) -> httpx.AsyncHTTPTransport:
524+
return _get_transport_for_loop(asyncio.get_running_loop())
525+
526+
def __init__(self):
527+
# We need to ensure that if we call __aenter__ on a transport, it won't be removed
528+
# from the lru cache until __aexit__ is called, so ins
529+
self._currently_opened: dict[asyncio.AbstractEventLoop, httpx.AsyncHTTPTransport] = {}
530+
531+
async def __aenter__(self):
532+
loop = asyncio.get_running_loop()
533+
transport = self._currently_opened[loop] = self.get_transport()
534+
return transport
535+
536+
async def __aexit__(
537+
self,
538+
exc_type: type[BaseException] | None = None,
539+
exc_value: BaseException | None = None,
540+
traceback: TracebackType | None = None,
541+
) -> None:
542+
if transport := self._currently_opened.pop(asyncio.get_running_loop(), None):
543+
await transport.__aexit__(exc_type, exc_value, traceback)
544+
545+
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
546+
return await self.get_transport().handle_async_request(request)
547+
548+
async def aclose(self) -> None:
549+
await self.get_transport().aclose()
550+
551+
520552
@cache
521553
def get_user_agent() -> str:
522554
"""Get the user agent string for the HTTP client."""

0 commit comments

Comments
 (0)