Skip to content

chore: dedupe function call try block in _output.py #2217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 28 additions & 25 deletions pydantic_ai_slim/pydantic_ai/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ async def execute_function_with_span(
args: dict[str, Any] | Any,
call: _messages.ToolCallPart,
include_tool_call_id: bool = True,
wrap_validation_errors: bool = True,
) -> Any:
"""Execute a function call within a traced span, automatically recording the response."""
# Set up span attributes
Expand All @@ -113,7 +114,20 @@ async def execute_function_with_span(

# Execute function within span
with self.tracer.start_as_current_span('running output function', attributes=attributes) as span:
output = await function_schema.call(args, run_context)
try:
output = await function_schema.call(args, run_context)
except ModelRetry as r:
if wrap_validation_errors:
m = _messages.RetryPromptPart(
content=r.message,
)
if call:
m.tool_name = call.tool_name
if include_tool_call_id and call:
m.tool_call_id = call.tool_call_id
raise ToolRetryError(m) from r
else:
raise

# Record response if content inclusion is enabled
if self.include_content and span.is_recording():
Expand Down Expand Up @@ -760,18 +774,10 @@ async def process(
function_name = getattr(self._function_schema.function, '__name__', 'output_function')
call = _messages.ToolCallPart(tool_name=function_name, args=data)
include_tool_call_id = False
try:
output = await trace_context.execute_function_with_span(
self._function_schema, run_context, output, call, include_tool_call_id
)
except ModelRetry as r:
if wrap_validation_errors:
m = _messages.RetryPromptPart(
content=r.message,
)
raise ToolRetryError(m) from r
else:
raise

output = await trace_context.execute_function_with_span(
self._function_schema, run_context, output, call, include_tool_call_id, wrap_validation_errors
)

return output

Expand Down Expand Up @@ -938,18 +944,15 @@ async def process(
# so we don't have tool call attributes like gen_ai.tool.name or gen_ai.tool.call.id
function_name = getattr(self._function_schema.function, '__name__', 'text_output_function')
call = _messages.ToolCallPart(tool_name=function_name, args=args)
try:
output = await trace_context.execute_function_with_span(
self._function_schema, run_context, args, call, include_tool_call_id=False
)
except ModelRetry as r:
if wrap_validation_errors:
m = _messages.RetryPromptPart(
content=r.message,
)
raise ToolRetryError(m) from r
else:
raise # pragma: no cover

output = await trace_context.execute_function_with_span(
self._function_schema,
run_context,
args,
call,
include_tool_call_id=False,
wrap_validation_errors=wrap_validation_errors,
)

return cast(OutputDataT, output)

Expand Down
6 changes: 4 additions & 2 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,14 +923,15 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
parts=[
RetryPromptPart(
content='City not found, I only know Mexico City',
tool_name='get_weather',
tool_call_id=IsStr(),
timestamp=IsDatetime(),
)
]
),
ModelResponse(
parts=[TextPart(content='Mexico City')],
usage=Usage(requests=1, request_tokens=70, response_tokens=5, total_tokens=75),
usage=Usage(requests=1, request_tokens=68, response_tokens=5, total_tokens=73),
model_name='function:call_tool:',
timestamp=IsDatetime(),
),
Expand Down Expand Up @@ -1648,14 +1649,15 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
parts=[
RetryPromptPart(
content='City not found, I only know Mexico City',
tool_name='get_weather',
tool_call_id=IsStr(),
timestamp=IsDatetime(),
)
]
),
ModelResponse(
parts=[TextPart(content='{"city": "Mexico City"}')],
usage=Usage(requests=1, request_tokens=70, response_tokens=11, total_tokens=81),
usage=Usage(requests=1, request_tokens=68, response_tokens=11, total_tokens=79),
model_name='function:call_tool:',
timestamp=IsDatetime(),
),
Expand Down
Loading