diff --git a/docs/deferred-tools.md b/docs/deferred-tools.md
index e5e5201163..5fa8c7d8b3 100644
--- a/docs/deferred-tools.md
+++ b/docs/deferred-tools.md
@@ -77,6 +77,7 @@ DeferredToolRequests(
tool_call_id='delete_file',
),
],
+ metadata={},
)
"""
@@ -247,6 +248,7 @@ async def main():
)
],
approvals=[],
+ metadata={},
)
"""
@@ -320,6 +322,151 @@ async def main():
_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_
+## Attaching Metadata to Deferred Tools
+
+Both [`CallDeferred`][pydantic_ai.exceptions.CallDeferred] and [`ApprovalRequired`][pydantic_ai.exceptions.ApprovalRequired] exceptions accept an optional `metadata` parameter that allows you to attach arbitrary context information to deferred tool calls. This metadata is available in [`DeferredToolRequests.metadata`][pydantic_ai.tools.DeferredToolRequests.metadata] keyed by tool call ID.
+
+Common use cases include cost estimates for approval decisions and tracking information for external systems.
+
+```python {title="deferred_tools_with_metadata.py"}
+from dataclasses import dataclass
+
+from pydantic_ai import (
+ Agent,
+ ApprovalRequired,
+ CallDeferred,
+ DeferredToolRequests,
+ DeferredToolResults,
+ RunContext,
+ ToolApproved,
+ ToolDenied,
+)
+
+
+@dataclass
+class User:
+ home_location: str = 'St. Louis, MO'
+
+
+class FlightAPI:
+ COSTS = {
+ ('St. Louis, MO', 'Lisbon, Portugal'): 850,
+ ('St. Louis, MO', 'Santiago, Chile'): 1200,
+ ('St. Louis, MO', 'Los Angeles, CA'): 300,
+ }
+
+ def get_flight_cost(self, origin: str, destination: str) -> int:
+ return self.COSTS.get((origin, destination), 500)
+
+ def get_airline_auth_url(self, airline: str) -> str:
+ # In real code, this might generate a proper OAuth URL
+ return f"https://example.com/auth/{airline.lower().replace(' ', '-')}"
+
+
+@dataclass
+class TravelDeps:
+ user: User
+ flight_api: FlightAPI
+
+
+agent = Agent(
+ 'openai:gpt-5',
+ deps_type=TravelDeps,
+ output_type=[str, DeferredToolRequests],
+)
+
+
+@agent.tool
+def book_flight(ctx: RunContext[TravelDeps], destination: str) -> str:
+ """Book a flight to the destination."""
+ if not ctx.tool_call_approved:
+ # Look up cost based on user's location and destination
+ cost = ctx.deps.flight_api.get_flight_cost(
+ ctx.deps.user.home_location,
+ destination
+ )
+
+ raise ApprovalRequired(
+ metadata={
+ 'origin': ctx.deps.user.home_location,
+ 'destination': destination,
+ 'cost_usd': cost,
+ }
+ )
+
+ return f'Flight booked to {destination}'
+
+
+@agent.tool
+def authenticate_with_airline(ctx: RunContext[TravelDeps], airline: str) -> str:
+ """Authenticate with airline website to link frequent flyer account."""
+ # Generate auth URL that would normally open in browser
+ auth_url = ctx.deps.flight_api.get_airline_auth_url(airline)
+
+ # Cannot complete auth in this process - need user interaction
+ raise CallDeferred(
+ metadata={
+ 'airline': airline,
+ 'auth_url': auth_url,
+ }
+ )
+
+
+# Set up dependencies
+user = User(home_location='St. Louis, MO')
+flight_api = FlightAPI()
+deps = TravelDeps(user=user, flight_api=flight_api)
+
+# Agent calls both tools
+result = agent.run_sync(
+ 'Book a flight to Lisbon, Portugal and link my SkyWay Airlines account',
+ deps=deps,
+)
+messages = result.all_messages()
+
+assert isinstance(result.output, DeferredToolRequests)
+requests = result.output
+
+# Make approval decision using metadata
+results = DeferredToolResults()
+for call in requests.approvals:
+ metadata = requests.metadata.get(call.tool_call_id, {})
+ cost = metadata.get('cost_usd', 0)
+
+ print(f'Approval needed: {call.tool_name}')
+ #> Approval needed: book_flight
+ print(f" {metadata['origin']} → {metadata['destination']}: ${cost}")
+ #> St. Louis, MO → Lisbon, Portugal: $850
+
+ if cost < 1000:
+ results.approvals[call.tool_call_id] = ToolApproved()
+ else:
+ results.approvals[call.tool_call_id] = ToolDenied('Cost exceeds budget')
+
+# Handle deferred calls using metadata
+for call in requests.calls:
+ metadata = requests.metadata.get(call.tool_call_id, {})
+ auth_url = metadata.get('auth_url')
+
+ print(f'Browser auth required: {auth_url}')
+ #> Browser auth required: https://example.com/auth/skyway-airlines
+
+ # In real code: open browser, wait for auth completion
+ # For demo, just mark as completed
+ results.calls[call.tool_call_id] = 'Frequent flyer account linked'
+
+# Continue with results
+result = agent.run_sync(
+ message_history=messages,
+ deferred_tool_results=results,
+ deps=deps,
+)
+print(result.output)
+#> Flight to Lisbon booked successfully and your SkyWay Airlines account is now linked.
+```
+
+_(This example is complete, it can be run "as is")_
+
## See Also
- [Function Tools](tools.md) - Basic tool concepts and registration
diff --git a/docs/toolsets.md b/docs/toolsets.md
index 8d970b8e31..1b041b3baa 100644
--- a/docs/toolsets.md
+++ b/docs/toolsets.md
@@ -362,6 +362,7 @@ DeferredToolRequests(
tool_call_id='pyd_ai_tool_call_id__temperature_fahrenheit',
),
],
+ metadata={},
)
"""
diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index c167521079..7285908123 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -883,6 +883,7 @@ async def process_tool_calls( # noqa: C901
calls_to_run = [call for call in calls_to_run if call.tool_call_id in calls_to_run_results]
deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]] = defaultdict(list)
+ deferred_metadata: dict[str, dict[str, Any]] = {}
if calls_to_run:
async for event in _call_tools(
@@ -894,6 +895,7 @@ async def process_tool_calls( # noqa: C901
usage_limits=ctx.deps.usage_limits,
output_parts=output_parts,
output_deferred_calls=deferred_calls,
+ output_deferred_metadata=deferred_metadata,
):
yield event
@@ -927,6 +929,7 @@ async def process_tool_calls( # noqa: C901
deferred_tool_requests = _output.DeferredToolRequests(
calls=deferred_calls['external'],
approvals=deferred_calls['unapproved'],
+ metadata=deferred_metadata,
)
final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_requests), None, None)
@@ -944,10 +947,12 @@ async def _call_tools(
usage_limits: _usage.UsageLimits,
output_parts: list[_messages.ModelRequestPart],
output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
+ output_deferred_metadata: dict[str, dict[str, Any]],
) -> AsyncIterator[_messages.HandleResponseEvent]:
tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
user_parts_by_index: dict[int, _messages.UserPromptPart] = {}
deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {}
+ deferred_metadata_by_index: dict[int, dict[str, Any] | None] = {}
if usage_limits.tool_calls_limit is not None:
projected_usage = deepcopy(usage)
@@ -982,10 +987,12 @@ async def handle_call_or_result(
tool_part, tool_user_content = (
(await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result()
)
- except exceptions.CallDeferred:
+ except exceptions.CallDeferred as e:
deferred_calls_by_index[index] = 'external'
- except exceptions.ApprovalRequired:
+ deferred_metadata_by_index[index] = e.metadata
+ except exceptions.ApprovalRequired as e:
deferred_calls_by_index[index] = 'unapproved'
+ deferred_metadata_by_index[index] = e.metadata
else:
tool_parts_by_index[index] = tool_part
if tool_user_content:
@@ -1023,8 +1030,25 @@ async def handle_call_or_result(
output_parts.extend([tool_parts_by_index[k] for k in sorted(tool_parts_by_index)])
output_parts.extend([user_parts_by_index[k] for k in sorted(user_parts_by_index)])
+ _populate_deferred_calls(
+ tool_calls, deferred_calls_by_index, deferred_metadata_by_index, output_deferred_calls, output_deferred_metadata
+ )
+
+
+def _populate_deferred_calls(
+ tool_calls: list[_messages.ToolCallPart],
+ deferred_calls_by_index: dict[int, Literal['external', 'unapproved']],
+ deferred_metadata_by_index: dict[int, dict[str, Any] | None],
+ output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
+ output_deferred_metadata: dict[str, dict[str, Any]],
+) -> None:
+ """Populate deferred calls and metadata from indexed mappings."""
for k in sorted(deferred_calls_by_index):
- output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k])
+ call = tool_calls[k]
+ output_deferred_calls[deferred_calls_by_index[k]].append(call)
+ metadata = deferred_metadata_by_index[k]
+ if metadata is not None:
+ output_deferred_metadata[call.tool_call_id] = metadata
async def _call_tool(
diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py
index d4adb4b6a7..a85b35ee4a 100644
--- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py
+++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py
@@ -27,11 +27,13 @@ class CallToolParams:
@dataclass
class _ApprovalRequired:
+ metadata: dict[str, Any] | None = None
kind: Literal['approval_required'] = 'approval_required'
@dataclass
class _CallDeferred:
+ metadata: dict[str, Any] | None = None
kind: Literal['call_deferred'] = 'call_deferred'
@@ -75,10 +77,10 @@ async def _wrap_call_tool_result(self, coro: Awaitable[Any]) -> CallToolResult:
try:
result = await coro
return _ToolReturn(result=result)
- except ApprovalRequired:
- return _ApprovalRequired()
- except CallDeferred:
- return _CallDeferred()
+ except ApprovalRequired as e:
+ return _ApprovalRequired(metadata=e.metadata)
+ except CallDeferred as e:
+ return _CallDeferred(metadata=e.metadata)
except ModelRetry as e:
return _ModelRetry(message=e.message)
@@ -86,9 +88,9 @@ def _unwrap_call_tool_result(self, result: CallToolResult) -> Any:
if isinstance(result, _ToolReturn):
return result.result
elif isinstance(result, _ApprovalRequired):
- raise ApprovalRequired()
+ raise ApprovalRequired(metadata=result.metadata)
elif isinstance(result, _CallDeferred):
- raise CallDeferred()
+ raise CallDeferred(metadata=result.metadata)
elif isinstance(result, _ModelRetry):
raise ModelRetry(result.message)
else:
diff --git a/pydantic_ai_slim/pydantic_ai/exceptions.py b/pydantic_ai_slim/pydantic_ai/exceptions.py
index ae5cce0908..2e8358c2fc 100644
--- a/pydantic_ai_slim/pydantic_ai/exceptions.py
+++ b/pydantic_ai_slim/pydantic_ai/exceptions.py
@@ -67,18 +67,30 @@ class CallDeferred(Exception):
"""Exception to raise when a tool call should be deferred.
See [tools docs](../deferred-tools.md#deferred-tools) for more information.
+
+ Args:
+ metadata: Optional dictionary of metadata to attach to the deferred tool call.
+ This metadata will be available in `DeferredToolRequests.metadata` keyed by `tool_call_id`.
"""
- pass
+ def __init__(self, metadata: dict[str, Any] | None = None):
+ self.metadata = metadata
+ super().__init__()
class ApprovalRequired(Exception):
"""Exception to raise when a tool call requires human-in-the-loop approval.
See [tools docs](../deferred-tools.md#human-in-the-loop-tool-approval) for more information.
+
+ Args:
+ metadata: Optional dictionary of metadata to attach to the deferred tool call.
+ This metadata will be available in `DeferredToolRequests.metadata` keyed by `tool_call_id`.
"""
- pass
+ def __init__(self, metadata: dict[str, Any] | None = None):
+ self.metadata = metadata
+ super().__init__()
class UserError(RuntimeError):
diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py
index da053a5191..beae26661b 100644
--- a/pydantic_ai_slim/pydantic_ai/tools.py
+++ b/pydantic_ai_slim/pydantic_ai/tools.py
@@ -147,6 +147,12 @@ class DeferredToolRequests:
"""Tool calls that require external execution."""
approvals: list[ToolCallPart] = field(default_factory=list)
"""Tool calls that require human-in-the-loop approval."""
+ metadata: dict[str, dict[str, Any]] = field(default_factory=dict)
+ """Metadata for deferred tool calls, keyed by tool_call_id.
+
+ This contains any metadata that was provided when raising [`CallDeferred`][pydantic_ai.exceptions.CallDeferred]
+ or [`ApprovalRequired`][pydantic_ai.exceptions.ApprovalRequired] exceptions.
+ """
@dataclass(kw_only=True)
diff --git a/tests/test_agent.py b/tests/test_agent.py
index 0a6bf1e325..2b89c3f38b 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -4857,9 +4857,13 @@ def call_second():
else:
result = agent.run_sync(user_prompt)
- assert result.output == snapshot(
- DeferredToolRequests(approvals=[ToolCallPart(tool_name='requires_approval', tool_call_id=IsStr())])
- )
+ assert isinstance(result.output, DeferredToolRequests)
+ assert len(result.output.approvals) == 1
+ assert result.output.approvals[0].tool_name == 'requires_approval'
+ # When no metadata is provided, the tool_call_id should not be in metadata dict
+ tool_call_id = result.output.approvals[0].tool_call_id
+ assert tool_call_id not in result.output.metadata
+ assert result.output.metadata == {}
assert integer_holder == 2
diff --git a/tests/test_examples.py b/tests/test_examples.py
index c7c32c340d..dce4bca048 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -708,6 +708,21 @@ async def model_logic( # noqa: C901
TextPart(content='The factorial of 15 is **1,307,674,368,000**.'),
]
)
+ elif m.content == 'Book a flight to Lisbon, Portugal and link my SkyWay Airlines account':
+ return ModelResponse(
+ parts=[
+ ToolCallPart(
+ tool_name='book_flight',
+ args={'destination': 'Lisbon, Portugal'},
+ tool_call_id='pyd_ai_tool_call_id_1',
+ ),
+ ToolCallPart(
+ tool_name='authenticate_with_airline',
+ args={'airline': 'SkyWay Airlines'},
+ tool_call_id='pyd_ai_tool_call_id_2',
+ ),
+ ]
+ )
elif isinstance(m, ToolReturnPart) and m.tool_name == 'roulette_wheel':
win = m.content == 'winner'
@@ -871,10 +886,24 @@ async def model_logic( # noqa: C901
return ModelResponse(
parts=[TextPart('The answer to the ultimate question of life, the universe, and everything is 42.')]
)
- else:
+ elif isinstance(m, ToolReturnPart) and m.tool_name in ('book_flight', 'authenticate_with_airline'):
+ # After deferred tools complete, check if we have all results to provide final response
+ tool_names = {part.tool_name for msg in messages for part in msg.parts if isinstance(part, ToolReturnPart)}
+ if 'book_flight' in tool_names and 'authenticate_with_airline' in tool_names:
+ return ModelResponse(
+ parts=[TextPart('Flight to Lisbon booked successfully and your SkyWay Airlines account is now linked.')]
+ )
+ # If we don't have both results yet, just acknowledge the tool result
+ return ModelResponse(parts=[TextPart(f'Received result from {m.tool_name}')])
+
+ if isinstance(m, ToolReturnPart):
sys.stdout.write(str(debug.format(messages, info)))
raise RuntimeError(f'Unexpected message: {m}')
+ # Fallback for any other message type
+ sys.stdout.write(str(debug.format(messages, info)))
+ raise RuntimeError(f'Unexpected message type: {type(m).__name__}')
+
async def stream_model_logic( # noqa C901
messages: list[ModelMessage], info: AgentInfo
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index 1a126f26dc..d66d5510e0 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -1149,9 +1149,13 @@ def regular_tool(x: int) -> int:
async with agent.run_stream('test early strategy with external tool call') as result:
response = await result.get_output()
- assert response == snapshot(
- DeferredToolRequests(calls=[ToolCallPart(tool_name='deferred_tool', tool_call_id=IsStr())])
- )
+ assert isinstance(response, DeferredToolRequests)
+ assert len(response.calls) == 1
+ assert response.calls[0].tool_name == 'deferred_tool'
+ # When no metadata is provided, the tool_call_id should not be in metadata dict
+ tool_call_id = response.calls[0].tool_call_id
+ assert tool_call_id not in response.metadata
+ assert response.metadata == {}
messages = result.all_messages()
# Verify no tools were called
@@ -1638,9 +1642,7 @@ def my_tool(x: int) -> int:
[DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())])]
)
assert await result.get_output() == snapshot(
- DeferredToolRequests(
- calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())],
- )
+ DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())])
)
responses = [c async for c, _is_last in result.stream_responses(debounce_by=None)]
assert responses == snapshot(
@@ -1682,9 +1684,7 @@ def my_tool(ctx: RunContext[None], x: int) -> int:
messages = result.all_messages()
output = await result.get_output()
assert output == snapshot(
- DeferredToolRequests(
- approvals=[ToolCallPart(tool_name='my_tool', args='{"x": 1}', tool_call_id=IsStr())],
- )
+ DeferredToolRequests(approvals=[ToolCallPart(tool_name='my_tool', args='{"x": 1}', tool_call_id=IsStr())])
)
assert result.is_complete
diff --git a/tests/test_tools.py b/tests/test_tools.py
index ea26d8ac91..a92e831873 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -1318,9 +1318,7 @@ def my_tool(x: int) -> int:
result = agent.run_sync('Hello')
assert result.output == snapshot(
- DeferredToolRequests(
- calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())],
- )
+ DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())])
)
@@ -1394,6 +1392,187 @@ def my_tool(ctx: RunContext[None], x: int) -> int:
assert result.output == snapshot('Done!')
+def test_call_deferred_with_metadata():
+ """Test that CallDeferred exception can carry metadata."""
+ agent = Agent(TestModel(), output_type=[str, DeferredToolRequests])
+
+ @agent.tool_plain
+ def my_tool(x: int) -> int:
+ raise CallDeferred(metadata={'task_id': 'task-123', 'estimated_cost': 25.50})
+
+ result = agent.run_sync('Hello')
+ assert isinstance(result.output, DeferredToolRequests)
+ assert len(result.output.calls) == 1
+
+ tool_call_id = result.output.calls[0].tool_call_id
+ assert tool_call_id in result.output.metadata
+ assert result.output.metadata[tool_call_id] == {'task_id': 'task-123', 'estimated_cost': 25.50}
+
+
+def test_approval_required_with_metadata():
+ """Test that ApprovalRequired exception can carry metadata."""
+
+ def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+ if len(messages) == 1:
+ return ModelResponse(
+ parts=[
+ ToolCallPart('my_tool', {'x': 1}, tool_call_id='my_tool'),
+ ]
+ )
+ else:
+ return ModelResponse(
+ parts=[
+ TextPart('Done!'),
+ ]
+ )
+
+ agent = Agent(FunctionModel(llm), output_type=[str, DeferredToolRequests])
+
+ @agent.tool
+ def my_tool(ctx: RunContext[None], x: int) -> int:
+ if not ctx.tool_call_approved:
+ raise ApprovalRequired(
+ metadata={
+ 'reason': 'High compute cost',
+ 'estimated_time': '5 minutes',
+ 'cost_usd': 100.0,
+ }
+ )
+ return x * 42
+
+ result = agent.run_sync('Hello')
+ assert isinstance(result.output, DeferredToolRequests)
+ assert len(result.output.approvals) == 1
+
+ assert 'my_tool' in result.output.metadata
+ assert result.output.metadata['my_tool'] == {
+ 'reason': 'High compute cost',
+ 'estimated_time': '5 minutes',
+ 'cost_usd': 100.0,
+ }
+
+ # Continue with approval
+ messages = result.all_messages()
+ result = agent.run_sync(
+ message_history=messages,
+ deferred_tool_results=DeferredToolResults(approvals={'my_tool': ToolApproved()}),
+ )
+ assert result.output == 'Done!'
+
+
+def test_call_deferred_without_metadata():
+ """Test backward compatibility: CallDeferred without metadata still works."""
+ agent = Agent(TestModel(), output_type=[str, DeferredToolRequests])
+
+ @agent.tool_plain
+ def my_tool(x: int) -> int:
+ raise CallDeferred # No metadata
+
+ result = agent.run_sync('Hello')
+ assert isinstance(result.output, DeferredToolRequests)
+ assert len(result.output.calls) == 1
+
+ tool_call_id = result.output.calls[0].tool_call_id
+ # Should have an empty metadata dict for this tool
+ assert result.output.metadata.get(tool_call_id, {}) == {}
+
+
+def test_approval_required_without_metadata():
+ """Test backward compatibility: ApprovalRequired without metadata still works."""
+
+ def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+ if len(messages) == 1:
+ return ModelResponse(
+ parts=[
+ ToolCallPart('my_tool', {'x': 1}, tool_call_id='my_tool'),
+ ]
+ )
+ else:
+ return ModelResponse(
+ parts=[
+ TextPart('Done!'),
+ ]
+ )
+
+ agent = Agent(FunctionModel(llm), output_type=[str, DeferredToolRequests])
+
+ @agent.tool
+ def my_tool(ctx: RunContext[None], x: int) -> int:
+ if not ctx.tool_call_approved:
+ raise ApprovalRequired # No metadata
+ return x * 42
+
+ result = agent.run_sync('Hello')
+ assert isinstance(result.output, DeferredToolRequests)
+ assert len(result.output.approvals) == 1
+
+ # Should have an empty metadata dict for this tool
+ assert result.output.metadata.get('my_tool', {}) == {}
+
+ # Continue with approval
+ messages = result.all_messages()
+ result = agent.run_sync(
+ message_history=messages,
+ deferred_tool_results=DeferredToolResults(approvals={'my_tool': ToolApproved()}),
+ )
+ assert result.output == 'Done!'
+
+
+def test_mixed_deferred_tools_with_metadata():
+ """Test multiple deferred tools with different metadata."""
+
+ def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+ if len(messages) == 1:
+ return ModelResponse(
+ parts=[
+ ToolCallPart('tool_a', {'x': 1}, tool_call_id='call_a'),
+ ToolCallPart('tool_b', {'y': 2}, tool_call_id='call_b'),
+ ToolCallPart('tool_c', {'z': 3}, tool_call_id='call_c'),
+ ]
+ )
+ else:
+ return ModelResponse(parts=[TextPart('Done!')])
+
+ agent = Agent(FunctionModel(llm), output_type=[str, DeferredToolRequests])
+
+ @agent.tool
+ def tool_a(ctx: RunContext[None], x: int) -> int:
+ raise CallDeferred(metadata={'type': 'external', 'priority': 'high'})
+
+ @agent.tool
+ def tool_b(ctx: RunContext[None], y: int) -> int:
+ if not ctx.tool_call_approved:
+ raise ApprovalRequired(metadata={'reason': 'Needs approval', 'level': 'manager'})
+ return y * 10
+
+ @agent.tool
+ def tool_c(ctx: RunContext[None], z: int) -> int:
+ raise CallDeferred # No metadata
+
+ result = agent.run_sync('Hello')
+ assert isinstance(result.output, DeferredToolRequests)
+
+ # Check that we have the right tools deferred
+ assert len(result.output.calls) == 2 # tool_a and tool_c
+ assert len(result.output.approvals) == 1 # tool_b
+
+ # Check metadata
+ assert result.output.metadata['call_a'] == {'type': 'external', 'priority': 'high'}
+ assert result.output.metadata['call_b'] == {'reason': 'Needs approval', 'level': 'manager'}
+ assert result.output.metadata.get('call_c', {}) == {}
+
+ # Continue with results for all three tools
+ messages = result.all_messages()
+ result = agent.run_sync(
+ message_history=messages,
+ deferred_tool_results=DeferredToolResults(
+ calls={'call_a': 10, 'call_c': 30},
+ approvals={'call_b': ToolApproved()},
+ ),
+ )
+ assert result.output == 'Done!'
+
+
def test_deferred_tool_with_output_type():
class MyModel(BaseModel):
foo: str
@@ -1583,7 +1762,7 @@ def buy(fruit: str):
ToolCallPart(tool_name='buy', args={'fruit': 'apple'}, tool_call_id='buy_apple'),
ToolCallPart(tool_name='buy', args={'fruit': 'banana'}, tool_call_id='buy_banana'),
ToolCallPart(tool_name='buy', args={'fruit': 'pear'}, tool_call_id='buy_pear'),
- ],
+ ]
)
)
diff --git a/tests/test_ui.py b/tests/test_ui.py
index 38f9950ad5..a497d09389 100644
--- a/tests/test_ui.py
+++ b/tests/test_ui.py
@@ -439,7 +439,7 @@ async def test_run_stream_external_tools():
'',
"{}",
'',
- "DeferredToolRequests(calls=[ToolCallPart(tool_name='external_tool', args={}, tool_call_id='pyd_ai_tool_call_id__external_tool')], approvals=[])",
+ "DeferredToolRequests(calls=[ToolCallPart(tool_name='external_tool', args={}, tool_call_id='pyd_ai_tool_call_id__external_tool')], approvals=[], metadata={})",
'',
]
)