|
6 | 6 |
|
7 | 7 | from __future__ import annotations as _annotations
|
8 | 8 |
|
| 9 | +import asyncio |
9 | 10 | from abc import ABC, abstractmethod
|
10 | 11 | from collections.abc import AsyncIterator, Iterator
|
11 | 12 | from contextlib import asynccontextmanager, contextmanager
|
12 | 13 | from dataclasses import dataclass, field
|
13 | 14 | from datetime import datetime
|
14 |
| -from functools import cache |
| 15 | +from functools import cache, lru_cache |
15 | 16 |
|
16 | 17 | import httpx
|
17 | 18 | from typing_extensions import Literal, TypeAliasType
|
@@ -495,25 +496,35 @@ def cached_async_http_client(*, provider: str | None = None, timeout: int = 600,
|
495 | 496 | The default timeouts match those of OpenAI,
|
496 | 497 | see <https://github.com/openai/openai-python/blob/v1.54.4/src/openai/_constants.py#L9>.
|
497 | 498 | """
|
498 |
| - client = _cached_async_http_client(provider=provider, timeout=timeout, connect=connect) |
| 499 | + try: |
| 500 | + loop = asyncio.get_running_loop() |
| 501 | + except RuntimeError: |
| 502 | + loop = None |
| 503 | + |
| 504 | + client = _cached_async_http_client(loop=loop, provider=provider, timeout=timeout, connect=connect) |
499 | 505 | if client.is_closed:
|
500 | 506 | # This happens if the context manager is used, so we need to create a new client.
|
501 | 507 | _cached_async_http_client.cache_clear()
|
502 |
| - client = _cached_async_http_client(provider=provider, timeout=timeout, connect=connect) |
| 508 | + client = _cached_async_http_client(loop=loop, provider=provider, timeout=timeout, connect=connect) |
503 | 509 | return client
|
504 | 510 |
|
505 | 511 |
|
506 |
| -@cache |
507 |
| -def _cached_async_http_client(provider: str | None, timeout: int = 600, connect: int = 5) -> httpx.AsyncClient: |
| 512 | +@lru_cache(maxsize=32) |
| 513 | +def _cached_async_http_client( |
| 514 | + loop: asyncio.AbstractEventLoop, provider: str | None, timeout: int = 600, connect: int = 5 |
| 515 | +) -> httpx.AsyncClient: |
508 | 516 | return httpx.AsyncClient(
|
509 |
| - transport=_cached_async_http_transport(), |
| 517 | + transport=_cached_async_http_transport(loop=loop), |
510 | 518 | timeout=httpx.Timeout(timeout=timeout, connect=connect),
|
511 | 519 | headers={'User-Agent': get_user_agent()},
|
512 | 520 | )
|
513 | 521 |
|
514 | 522 |
|
515 |
| -@cache |
516 |
| -def _cached_async_http_transport() -> httpx.AsyncHTTPTransport: |
| 523 | +@lru_cache(maxsize=32) |
| 524 | +def _cached_async_http_transport(loop: asyncio.AbstractEventLoop) -> httpx.AsyncHTTPTransport: |
| 525 | + # The loop argument is unused, but it's here to ensure the cache key is different |
| 526 | + # for each event loop, because a `httpx.AsyncHTTPTransport instanciated in a loop |
| 527 | + # cannot be used in another loop. |
517 | 528 | return httpx.AsyncHTTPTransport()
|
518 | 529 |
|
519 | 530 |
|
|
0 commit comments