Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class GraphAgentState:
usage: _usage.RunUsage = dataclasses.field(default_factory=_usage.RunUsage)
retries: int = 0
run_step: int = 0
run_id: str | None = None

def increment_retries(
self,
Expand Down Expand Up @@ -469,6 +470,7 @@ async def _make_request(
async def _prepare_request(
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
) -> tuple[ModelSettings | None, models.ModelRequestParameters, list[_messages.ModelMessage], RunContext[DepsT]]:
self.request.run_id = self.request.run_id or ctx.state.run_id
ctx.state.message_history.append(self.request)

ctx.state.run_step += 1
Expand Down Expand Up @@ -510,6 +512,7 @@ def _finish_handling(
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
response: _messages.ModelResponse,
) -> CallToolsNode[DepsT, NodeRunEndT]:
response.run_id = response.run_id or ctx.state.run_id
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may also need to do this in StreamedRunResult._marked_completed to get this to work right when using agent.run_stream:

if message is not None:
self._all_messages.append(message)

You'll want to test if it currently works, and if not add it.

We also definitely need to set the run_id here:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I made it part of RunContext and passed it through; I added it for the other two as well.

# Update usage
ctx.state.usage.incr(response.usage)
if ctx.deps.usage_limits: # pragma: no branch
Expand Down
2 changes: 2 additions & 0 deletions pydantic_ai_slim/pydantic_ai/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dataclasses
import inspect
import json
import uuid
import warnings
from asyncio import Lock
from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
Expand Down Expand Up @@ -572,6 +573,7 @@ async def main():
usage=usage,
retries=0,
run_step=0,
run_id=str(uuid.uuid4()),
)

# Merge model settings in order of precedence: run > agent > model
Expand Down
6 changes: 6 additions & 0 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,9 @@ class ModelRequest:
kind: Literal['request'] = 'request'
"""Message type identifier, this is available on all parts as a discriminator."""

run_id: str | None = None
"""A unique identifier to identify the run."""

@classmethod
def user_text_prompt(cls, user_prompt: str, *, instructions: str | None = None) -> ModelRequest:
"""Create a `ModelRequest` with a single user prompt as text."""
Expand Down Expand Up @@ -1188,6 +1191,9 @@ class ModelResponse:
finish_reason: FinishReason | None = None
"""Reason the model finished generating the response, normalized to OpenTelemetry values."""

run_id: str | None = None
"""A unique identifier to identify the run."""

@property
def text(self) -> str | None:
"""Get the text in the response."""
Expand Down
11 changes: 11 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2506,6 +2506,17 @@ def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
)


def test_agent_message_history_includes_run_id() -> None:
    """Every message produced by a single run should carry one shared, non-empty run_id."""
    model = TestModel(custom_output_text='testing run_id')
    agent = Agent(model)

    messages = agent.run_sync('Hello').all_messages()
    collected_ids = [message.run_id for message in messages]

    # Both the request and the response carry a string run_id...
    assert collected_ids == snapshot([IsStr(), IsStr()])
    # ...and it is the same id on every message of the run.
    assert len({*collected_ids}) == snapshot(1)


def test_unknown_tool():
def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
return ModelResponse(parts=[ToolCallPart('foobar', '{}')])
Expand Down
18 changes: 18 additions & 0 deletions tests/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,24 @@ def test_file_part_serialization_roundtrip():
assert deserialized == messages


def test_model_messages_type_adapter_preserves_run_id():
    """Round-tripping messages through ModelMessagesTypeAdapter must keep run_id intact."""
    history: list[ModelMessage] = [
        ModelRequest(
            parts=[UserPromptPart(content='Hi there', timestamp=datetime.now(tz=timezone.utc))],
            run_id='run-123',
        ),
        ModelResponse(parts=[TextPart(content='Hello!')], run_id='run-123'),
    ]

    # Serialize then re-validate in one pass; run_id must survive the round trip.
    round_tripped = ModelMessagesTypeAdapter.validate_python(
        ModelMessagesTypeAdapter.dump_python(history, mode='python')
    )

    assert [message.run_id for message in round_tripped] == snapshot(['run-123', 'run-123'])


def test_model_response_convenience_methods():
response = ModelResponse(parts=[])
assert response.text == snapshot(None)
Expand Down
Loading