diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py
index de8cd93ff..d19628b60 100644
--- a/src/agents/models/openai_responses.py
+++ b/src/agents/models/openai_responses.py
@@ -7,6 +7,7 @@
 from typing import TYPE_CHECKING, Any, Literal, cast, overload
 
 from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
+from openai._models import add_request_id
 from openai.types import ChatModel
 from openai.types.responses import (
     Response,
@@ -180,6 +181,13 @@ async def stream_response(
                     prompt=prompt,
                 )
 
+                request_id = None
+                stream_response = getattr(stream, "response", None)
+                if stream_response is not None:
+                    headers = getattr(stream_response, "headers", None)
+                    if headers is not None:
+                        request_id = headers.get("x-request-id")
+
                 final_response: Response | None = None
 
                 async for chunk in stream:
@@ -187,9 +195,12 @@
                         final_response = chunk.response
                     yield chunk
 
-                if final_response and tracing.include_data():
-                    span_response.span_data.response = final_response
-                    span_response.span_data.input = input
+                if final_response:
+                    if request_id:
+                        add_request_id(final_response, request_id)
+                    if tracing.include_data():
+                        span_response.span_data.response = final_response
+                        span_response.span_data.input = input
 
             except Exception as e:
                 span_response.set_error(
diff --git a/tests/test_agent_tracing.py b/tests/test_agent_tracing.py
index bb16cab26..4ad6194f8 100644
--- a/tests/test_agent_tracing.py
+++ b/tests/test_agent_tracing.py
@@ -1,15 +1,22 @@
 from __future__ import annotations
 
 import asyncio
+from types import SimpleNamespace
 
 import pytest
 from inline_snapshot import snapshot
+from openai.types.responses import ResponseCompletedEvent
 
-from agents import Agent, RunConfig, Runner, trace
+from agents import Agent, OpenAIResponsesModel, RunConfig, Runner, trace
+from agents.tracing import ResponseSpanData
 
-from .fake_model import FakeModel
+from .fake_model import FakeModel, get_response_obj
 from .test_responses import get_text_message
-from .testing_processor import assert_no_traces, fetch_normalized_spans
+from .testing_processor import (
+    assert_no_traces,
+    fetch_normalized_spans,
+    fetch_ordered_spans,
+)
 
 
 @pytest.mark.asyncio
@@ -292,6 +299,58 @@ async def test_streaming_single_run_is_single_trace():
     )
 
 
+@pytest.mark.asyncio
+@pytest.mark.allow_call_model_methods
+async def test_streamed_response_request_id_recorded():
+    request_id = "req_test_123"
+
+    class DummyStream:
+        def __init__(self) -> None:
+            self.response = SimpleNamespace(headers={"x-request-id": request_id})
+
+        def __aiter__(self):
+            async def gen():
+                yield ResponseCompletedEvent(
+                    type="response.completed",
+                    response=get_response_obj([get_text_message("first_test")]),
+                    sequence_number=0,
+                )
+
+            return gen()
+
+    class DummyResponses:
+        async def create(self, **kwargs):
+            assert kwargs.get("stream") is True
+            return DummyStream()
+
+    class DummyResponsesClient:
+        def __init__(self) -> None:
+            self.responses = DummyResponses()
+
+    model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyResponsesClient())  # type: ignore[arg-type]
+
+    agent = Agent(
+        name="test_agent",
+        model=model,
+    )
+
+    result = Runner.run_streamed(agent, input="first_test")
+    async for _ in result.stream_events():
+        pass
+
+    response_spans = [
+        span
+        for span in fetch_ordered_spans()
+        if isinstance(span.span_data, ResponseSpanData) and span.span_data.response is not None
+    ]
+
+    assert response_spans
+    assert any(
+        getattr(span.span_data.response, "_request_id", None) == request_id
+        for span in response_spans
+    )
+
+
 @pytest.mark.asyncio
 async def test_multiple_streamed_runs_are_multiple_traces():
     model = FakeModel()