|
6 | 6 |
|
7 | 7 | from __future__ import annotations as _annotations
|
8 | 8 |
|
| 9 | +import asyncio |
9 | 10 | from abc import ABC, abstractmethod
|
10 | 11 | from collections.abc import AsyncIterator, Iterator
|
11 | 12 | from contextlib import asynccontextmanager, contextmanager
|
12 | 13 | from dataclasses import dataclass, field
|
13 | 14 | from datetime import datetime
|
14 |
| -from functools import cache |
| 15 | +from functools import cache, lru_cache |
| 16 | +from types import TracebackType |
15 | 17 |
|
16 | 18 | import httpx
|
17 | 19 | from typing_extensions import Literal, TypeAliasType
|
@@ -506,17 +508,47 @@ def cached_async_http_client(*, provider: str | None = None, timeout: int = 600,
|
506 | 508 | @cache
|
507 | 509 | def _cached_async_http_client(provider: str | None, timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
|
508 | 510 | return httpx.AsyncClient(
|
509 |
| - transport=_cached_async_http_transport(), |
| 511 | + transport=_PerLoopTransport(), |
510 | 512 | timeout=httpx.Timeout(timeout=timeout, connect=connect),
|
511 | 513 | headers={'User-Agent': get_user_agent()},
|
512 | 514 | )
|
513 | 515 |
|
514 | 516 |
|
515 |
| -@cache |
516 |
| -def _cached_async_http_transport() -> httpx.AsyncHTTPTransport: |
| 517 | +@lru_cache(maxsize=32) |
| 518 | +def _get_transport_for_loop(loop: asyncio.AbstractEventLoop) -> httpx.AsyncHTTPTransport: |
517 | 519 | return httpx.AsyncHTTPTransport()
|
518 | 520 |
|
519 | 521 |
|
| 522 | +class _PerLoopTransport(httpx.AsyncBaseTransport): |
| 523 | + def get_transport(self) -> httpx.AsyncHTTPTransport: |
| 524 | + return _get_transport_for_loop(asyncio.get_running_loop()) |
| 525 | + |
| 526 | + def __init__(self): |
| 527 | + # We need to ensure that if we call __aenter__ on a transport, it won't be removed |
| 528 | + # from the lru cache until __aexit__ is called, so ins |
| 529 | + self._currently_opened: dict[asyncio.AbstractEventLoop, httpx.AsyncHTTPTransport] = {} |
| 530 | + |
| 531 | + async def __aenter__(self): |
| 532 | + loop = asyncio.get_running_loop() |
| 533 | + transport = self._currently_opened[loop] = self.get_transport() |
| 534 | + return transport |
| 535 | + |
| 536 | + async def __aexit__( |
| 537 | + self, |
| 538 | + exc_type: type[BaseException] | None = None, |
| 539 | + exc_value: BaseException | None = None, |
| 540 | + traceback: TracebackType | None = None, |
| 541 | + ) -> None: |
| 542 | + if transport := self._currently_opened.pop(asyncio.get_running_loop(), None): |
| 543 | + await transport.__aexit__(exc_type, exc_value, traceback) |
| 544 | + |
| 545 | + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: |
| 546 | + return await self.get_transport().handle_async_request(request) |
| 547 | + |
| 548 | + async def aclose(self) -> None: |
| 549 | + await self.get_transport().aclose() |
| 550 | + |
| 551 | + |
520 | 552 | @cache
|
521 | 553 | def get_user_agent() -> str:
|
522 | 554 | """Get the user agent string for the HTTP client."""
|
|
0 commit comments