Skip to content

Commit 10ceb8b

Browse files
authored
Export user_agent_override contextmanager (#1768)
This allows us to override the User-Agent header when using agents with other OpenAI Python SDKs.
1 parent 827af41 commit 10ceb8b

File tree

7 files changed

+248
-6
lines changed

7 files changed

+248
-6
lines changed

src/agents/extensions/models/litellm_model.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from ...logger import logger
4040
from ...model_settings import ModelSettings
4141
from ...models.chatcmpl_converter import Converter
42-
from ...models.chatcmpl_helpers import HEADERS
42+
from ...models.chatcmpl_helpers import HEADERS, USER_AGENT_OVERRIDE
4343
from ...models.chatcmpl_stream_handler import ChatCmplStreamHandler
4444
from ...models.fake_id import FAKE_RESPONSES_ID
4545
from ...models.interface import Model, ModelTracing
@@ -353,7 +353,7 @@ async def _fetch_response(
353353
stream_options=stream_options,
354354
reasoning_effort=reasoning_effort,
355355
top_logprobs=model_settings.top_logprobs,
356-
extra_headers={**HEADERS, **(model_settings.extra_headers or {})},
356+
extra_headers=self._merge_headers(model_settings),
357357
api_key=self.api_key,
358358
base_url=self.base_url,
359359
**extra_kwargs,
@@ -384,6 +384,13 @@ def _remove_not_given(self, value: Any) -> Any:
384384
return None
385385
return value
386386

387+
def _merge_headers(self, model_settings: ModelSettings):
388+
merged = {**HEADERS, **(model_settings.extra_headers or {})}
389+
ua_ctx = USER_AGENT_OVERRIDE.get()
390+
if ua_ctx is not None:
391+
merged["User-Agent"] = ua_ctx
392+
return merged
393+
387394

388395
class LitellmConverter:
389396
@classmethod

src/agents/models/chatcmpl_helpers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from contextvars import ContextVar
4+
35
from openai import AsyncOpenAI
46

57
from ..model_settings import ModelSettings
@@ -8,6 +10,10 @@
810
_USER_AGENT = f"Agents/Python {__version__}"
911
HEADERS = {"User-Agent": _USER_AGENT}
1012

13+
USER_AGENT_OVERRIDE: ContextVar[str | None] = ContextVar(
14+
"openai_chatcompletions_user_agent_override", default=None
15+
)
16+
1117

1218
class ChatCmplHelpers:
1319
@classmethod

src/agents/models/openai_chatcompletions.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from ..usage import Usage
2626
from ..util._json import _to_dump_compatible
2727
from .chatcmpl_converter import Converter
28-
from .chatcmpl_helpers import HEADERS, ChatCmplHelpers
28+
from .chatcmpl_helpers import HEADERS, USER_AGENT_OVERRIDE, ChatCmplHelpers
2929
from .chatcmpl_stream_handler import ChatCmplStreamHandler
3030
from .fake_id import FAKE_RESPONSES_ID
3131
from .interface import Model, ModelTracing
@@ -306,7 +306,7 @@ async def _fetch_response(
306306
reasoning_effort=self._non_null_or_not_given(reasoning_effort),
307307
verbosity=self._non_null_or_not_given(model_settings.verbosity),
308308
top_logprobs=self._non_null_or_not_given(model_settings.top_logprobs),
309-
extra_headers={**HEADERS, **(model_settings.extra_headers or {})},
309+
extra_headers=self._merge_headers(model_settings),
310310
extra_query=model_settings.extra_query,
311311
extra_body=model_settings.extra_body,
312312
metadata=self._non_null_or_not_given(model_settings.metadata),
@@ -349,3 +349,10 @@ def _get_client(self) -> AsyncOpenAI:
349349
if self._client is None:
350350
self._client = AsyncOpenAI()
351351
return self._client
352+
353+
def _merge_headers(self, model_settings: ModelSettings):
354+
merged = {**HEADERS, **(model_settings.extra_headers or {})}
355+
ua_ctx = USER_AGENT_OVERRIDE.get()
356+
if ua_ctx is not None:
357+
merged["User-Agent"] = ua_ctx
358+
return merged

src/agents/models/openai_responses.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import json
44
from collections.abc import AsyncIterator
5+
from contextvars import ContextVar
56
from dataclasses import dataclass
67
from typing import TYPE_CHECKING, Any, Literal, cast, overload
78

@@ -49,6 +50,11 @@
4950
_USER_AGENT = f"Agents/Python {__version__}"
5051
_HEADERS = {"User-Agent": _USER_AGENT}
5152

53+
# Override for the User-Agent header used by the Responses API.
54+
_USER_AGENT_OVERRIDE: ContextVar[str | None] = ContextVar(
55+
"openai_responses_user_agent_override", default=None
56+
)
57+
5258

5359
class OpenAIResponsesModel(Model):
5460
"""
@@ -312,7 +318,7 @@ async def _fetch_response(
312318
tool_choice=tool_choice,
313319
parallel_tool_calls=parallel_tool_calls,
314320
stream=stream,
315-
extra_headers={**_HEADERS, **(model_settings.extra_headers or {})},
321+
extra_headers=self._merge_headers(model_settings),
316322
extra_query=model_settings.extra_query,
317323
extra_body=model_settings.extra_body,
318324
text=response_format,
@@ -327,6 +333,13 @@ def _get_client(self) -> AsyncOpenAI:
327333
self._client = AsyncOpenAI()
328334
return self._client
329335

336+
def _merge_headers(self, model_settings: ModelSettings):
337+
merged = {**_HEADERS, **(model_settings.extra_headers or {})}
338+
ua_ctx = _USER_AGENT_OVERRIDE.get()
339+
if ua_ctx is not None:
340+
merged["User-Agent"] = ua_ctx
341+
return merged
342+
330343

331344
@dataclass
332345
class ConvertedTools:
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
import pytest
6+
7+
from agents import ModelSettings, ModelTracing, __version__
8+
from agents.models.chatcmpl_helpers import USER_AGENT_OVERRIDE
9+
10+
11+
@pytest.mark.allow_call_model_methods
12+
@pytest.mark.asyncio
13+
@pytest.mark.parametrize("override_ua", [None, "test_user_agent"])
14+
async def test_user_agent_header_litellm(override_ua: str | None, monkeypatch):
15+
called_kwargs: dict[str, Any] = {}
16+
expected_ua = override_ua or f"Agents/Python {__version__}"
17+
18+
import importlib
19+
import sys
20+
import types as pytypes
21+
22+
litellm_fake: Any = pytypes.ModuleType("litellm")
23+
24+
class DummyMessage:
25+
role = "assistant"
26+
content = "Hello"
27+
tool_calls: list[Any] | None = None
28+
29+
def get(self, _key, _default=None):
30+
return None
31+
32+
def model_dump(self):
33+
return {"role": self.role, "content": self.content}
34+
35+
class Choices: # noqa: N801 - mimic litellm naming
36+
def __init__(self):
37+
self.message = DummyMessage()
38+
39+
class DummyModelResponse:
40+
def __init__(self):
41+
self.choices = [Choices()]
42+
43+
async def acompletion(**kwargs):
44+
nonlocal called_kwargs
45+
called_kwargs = kwargs
46+
return DummyModelResponse()
47+
48+
utils_ns = pytypes.SimpleNamespace()
49+
utils_ns.Choices = Choices
50+
utils_ns.ModelResponse = DummyModelResponse
51+
52+
litellm_types = pytypes.SimpleNamespace(
53+
utils=utils_ns,
54+
llms=pytypes.SimpleNamespace(openai=pytypes.SimpleNamespace(ChatCompletionAnnotation=dict)),
55+
)
56+
litellm_fake.acompletion = acompletion
57+
litellm_fake.types = litellm_types
58+
59+
monkeypatch.setitem(sys.modules, "litellm", litellm_fake)
60+
61+
litellm_mod = importlib.import_module("agents.extensions.models.litellm_model")
62+
monkeypatch.setattr(litellm_mod, "litellm", litellm_fake, raising=True)
63+
LitellmModel = litellm_mod.LitellmModel
64+
65+
model = LitellmModel(model="gpt-4")
66+
67+
if override_ua is not None:
68+
token = USER_AGENT_OVERRIDE.set(override_ua)
69+
else:
70+
token = None
71+
try:
72+
await model.get_response(
73+
system_instructions=None,
74+
input="hi",
75+
model_settings=ModelSettings(),
76+
tools=[],
77+
output_schema=None,
78+
handoffs=[],
79+
tracing=ModelTracing.DISABLED,
80+
previous_response_id=None,
81+
conversation_id=None,
82+
prompt=None,
83+
)
84+
finally:
85+
if token is not None:
86+
USER_AGENT_OVERRIDE.reset(token)
87+
88+
assert "extra_headers" in called_kwargs
89+
assert called_kwargs["extra_headers"]["User-Agent"] == expected_ua

tests/test_openai_chatcompletions.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,10 @@
3131
ModelTracing,
3232
OpenAIChatCompletionsModel,
3333
OpenAIProvider,
34+
__version__,
3435
generation_span,
3536
)
36-
from agents.models.chatcmpl_helpers import ChatCmplHelpers
37+
from agents.models.chatcmpl_helpers import USER_AGENT_OVERRIDE, ChatCmplHelpers
3738
from agents.models.fake_id import FAKE_RESPONSES_ID
3839

3940

@@ -370,6 +371,60 @@ def test_store_param():
370371
"Should respect explicitly set store=True"
371372
)
372373

374+
375+
@pytest.mark.allow_call_model_methods
376+
@pytest.mark.asyncio
377+
@pytest.mark.parametrize("override_ua", [None, "test_user_agent"])
378+
async def test_user_agent_header_chat_completions(override_ua):
379+
called_kwargs: dict[str, Any] = {}
380+
expected_ua = override_ua or f"Agents/Python {__version__}"
381+
382+
class DummyCompletions:
383+
async def create(self, **kwargs):
384+
nonlocal called_kwargs
385+
called_kwargs = kwargs
386+
msg = ChatCompletionMessage(role="assistant", content="Hello")
387+
choice = Choice(index=0, finish_reason="stop", message=msg)
388+
return ChatCompletion(
389+
id="resp-id",
390+
created=0,
391+
model="fake",
392+
object="chat.completion",
393+
choices=[choice],
394+
usage=None,
395+
)
396+
397+
class DummyChatClient:
398+
def __init__(self):
399+
self.chat = type("_Chat", (), {"completions": DummyCompletions()})()
400+
self.base_url = "https://api.openai.com"
401+
402+
model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyChatClient()) # type: ignore
403+
404+
if override_ua is not None:
405+
token = USER_AGENT_OVERRIDE.set(override_ua)
406+
else:
407+
token = None
408+
409+
try:
410+
await model.get_response(
411+
system_instructions=None,
412+
input="hi",
413+
model_settings=ModelSettings(),
414+
tools=[],
415+
output_schema=None,
416+
handoffs=[],
417+
tracing=ModelTracing.DISABLED,
418+
previous_response_id=None,
419+
conversation_id=None,
420+
)
421+
finally:
422+
if token is not None:
423+
USER_AGENT_OVERRIDE.reset(token)
424+
425+
assert "extra_headers" in called_kwargs
426+
assert called_kwargs["extra_headers"]["User-Agent"] == expected_ua
427+
373428
client = AsyncOpenAI(base_url="http://www.notopenai.com")
374429
model_settings = ModelSettings()
375430
assert ChatCmplHelpers.get_store_param(client, model_settings) is None, (

tests/test_openai_responses.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
import pytest
6+
from openai.types.responses import ResponseCompletedEvent
7+
8+
from agents import ModelSettings, ModelTracing, __version__
9+
from agents.models.openai_responses import _USER_AGENT_OVERRIDE as RESP_UA, OpenAIResponsesModel
10+
from tests.fake_model import get_response_obj
11+
12+
13+
@pytest.mark.allow_call_model_methods
14+
@pytest.mark.asyncio
15+
@pytest.mark.parametrize("override_ua", [None, "test_user_agent"])
16+
async def test_user_agent_header_responses(override_ua: str | None):
17+
called_kwargs: dict[str, Any] = {}
18+
expected_ua = override_ua or f"Agents/Python {__version__}"
19+
20+
class DummyStream:
21+
def __aiter__(self):
22+
async def gen():
23+
yield ResponseCompletedEvent(
24+
type="response.completed",
25+
response=get_response_obj([]),
26+
sequence_number=0,
27+
)
28+
29+
return gen()
30+
31+
class DummyResponses:
32+
async def create(self, **kwargs):
33+
nonlocal called_kwargs
34+
called_kwargs = kwargs
35+
return DummyStream()
36+
37+
class DummyResponsesClient:
38+
def __init__(self):
39+
self.responses = DummyResponses()
40+
41+
model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyResponsesClient()) # type: ignore
42+
43+
if override_ua is not None:
44+
token = RESP_UA.set(override_ua)
45+
else:
46+
token = None
47+
48+
try:
49+
stream = model.stream_response(
50+
system_instructions=None,
51+
input="hi",
52+
model_settings=ModelSettings(),
53+
tools=[],
54+
output_schema=None,
55+
handoffs=[],
56+
tracing=ModelTracing.DISABLED,
57+
)
58+
async for _ in stream:
59+
pass
60+
finally:
61+
if token is not None:
62+
RESP_UA.reset(token)
63+
64+
assert "extra_headers" in called_kwargs
65+
assert called_kwargs["extra_headers"]["User-Agent"] == expected_ua

0 commit comments

Comments
 (0)