diff --git a/docs/deferred-tools.md b/docs/deferred-tools.md index e5e5201163..5fa8c7d8b3 100644 --- a/docs/deferred-tools.md +++ b/docs/deferred-tools.md @@ -77,6 +77,7 @@ DeferredToolRequests( tool_call_id='delete_file', ), ], + metadata={}, ) """ @@ -247,6 +248,7 @@ async def main(): ) ], approvals=[], + metadata={}, ) """ @@ -320,6 +322,151 @@ async def main(): _(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_ +## Attaching Metadata to Deferred Tools + +Both [`CallDeferred`][pydantic_ai.exceptions.CallDeferred] and [`ApprovalRequired`][pydantic_ai.exceptions.ApprovalRequired] exceptions accept an optional `metadata` parameter that allows you to attach arbitrary context information to deferred tool calls. This metadata is available in [`DeferredToolRequests.metadata`][pydantic_ai.tools.DeferredToolRequests.metadata] keyed by tool call ID. + +Common use cases include cost estimates for approval decisions and tracking information for external systems. + +```python {title="deferred_tools_with_metadata.py"} +from dataclasses import dataclass + +from pydantic_ai import ( + Agent, + ApprovalRequired, + CallDeferred, + DeferredToolRequests, + DeferredToolResults, + RunContext, + ToolApproved, + ToolDenied, +) + + +@dataclass +class User: + home_location: str = 'St. Louis, MO' + + +class FlightAPI: + COSTS = { + ('St. Louis, MO', 'Lisbon, Portugal'): 850, + ('St. Louis, MO', 'Santiago, Chile'): 1200, + ('St. Louis, MO', 'Los Angeles, CA'): 300, + } + + def get_flight_cost(self, origin: str, destination: str) -> int: + return self.COSTS.get((origin, destination), 500) + + def get_airline_auth_url(self, airline: str) -> str: + # In real code, this might generate a proper OAuth URL + return f"https://example.com/auth/{airline.lower().replace(' ', '-')}" + + +@dataclass +class TravelDeps: + user: User + flight_api: FlightAPI + + +agent = Agent( + 'openai:gpt-5', + deps_type=TravelDeps, + output_type=[str, DeferredToolRequests], +) + + +@agent.tool +def book_flight(ctx: RunContext[TravelDeps], destination: str) -> str: + """Book a flight to the destination.""" + if not ctx.tool_call_approved: + # Look up cost based on user's location and destination + cost = ctx.deps.flight_api.get_flight_cost( + ctx.deps.user.home_location, + destination + ) + + raise ApprovalRequired( + metadata={ + 'origin': ctx.deps.user.home_location, + 'destination': destination, + 'cost_usd': cost, + } + ) + + return f'Flight booked to {destination}' + + +@agent.tool +def authenticate_with_airline(ctx: RunContext[TravelDeps], airline: str) -> str: + """Authenticate with airline website to link frequent flyer account.""" + # Generate auth URL that would normally open in browser + auth_url = ctx.deps.flight_api.get_airline_auth_url(airline) + + # Cannot complete auth in this process - need user interaction + raise CallDeferred( + metadata={ + 'airline': airline, + 'auth_url': auth_url, + } + ) + + +# Set up dependencies +user = User(home_location='St. Louis, MO') +flight_api = FlightAPI() +deps = TravelDeps(user=user, flight_api=flight_api) + +# Agent calls both tools +result = agent.run_sync( + 'Book a flight to Lisbon, Portugal and link my SkyWay Airlines account', + deps=deps, +) +messages = result.all_messages() + +assert isinstance(result.output, DeferredToolRequests) +requests = result.output + +# Make approval decision using metadata +results = DeferredToolResults() +for call in requests.approvals: + metadata = requests.metadata.get(call.tool_call_id, {}) + cost = metadata.get('cost_usd', 0) + + print(f'Approval needed: {call.tool_name}') + #> Approval needed: book_flight + print(f" {metadata['origin']} → {metadata['destination']}: ${cost}") + #> St. Louis, MO → Lisbon, Portugal: $850 + + if cost < 1000: + results.approvals[call.tool_call_id] = ToolApproved() + else: + results.approvals[call.tool_call_id] = ToolDenied('Cost exceeds budget') + +# Handle deferred calls using metadata +for call in requests.calls: + metadata = requests.metadata.get(call.tool_call_id, {}) + auth_url = metadata.get('auth_url') + + print(f'Browser auth required: {auth_url}') + #> Browser auth required: https://example.com/auth/skyway-airlines + + # In real code: open browser, wait for auth completion + # For demo, just mark as completed + results.calls[call.tool_call_id] = 'Frequent flyer account linked' + +# Continue with results +result = agent.run_sync( + message_history=messages, + deferred_tool_results=results, + deps=deps, +) +print(result.output) +#> Flight to Lisbon booked successfully and your SkyWay Airlines account is now linked. +``` + +_(This example is complete, it can be run "as is")_ + ## See Also - [Function Tools](tools.md) - Basic tool concepts and registration diff --git a/docs/toolsets.md b/docs/toolsets.md index 8d970b8e31..1b041b3baa 100644 --- a/docs/toolsets.md +++ b/docs/toolsets.md @@ -362,6 +362,7 @@ DeferredToolRequests( tool_call_id='pyd_ai_tool_call_id__temperature_fahrenheit', ), ], + metadata={}, ) """ diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index c167521079..7285908123 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -883,6 +883,7 @@ async def process_tool_calls( # noqa: C901 calls_to_run = [call for call in calls_to_run if call.tool_call_id in calls_to_run_results] deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]] = defaultdict(list) + deferred_metadata: dict[str, dict[str, Any]] = {} if calls_to_run: async for event in _call_tools( @@ -894,6 +895,7 @@ async def process_tool_calls( # noqa: C901 usage_limits=ctx.deps.usage_limits, output_parts=output_parts, output_deferred_calls=deferred_calls, + output_deferred_metadata=deferred_metadata, ): yield event @@ -927,6 +929,7 @@ async def process_tool_calls( # noqa: C901 deferred_tool_requests = _output.DeferredToolRequests( calls=deferred_calls['external'], approvals=deferred_calls['unapproved'], + metadata=deferred_metadata, ) final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_requests), None, None) @@ -944,10 +947,12 @@ async def _call_tools( usage_limits: _usage.UsageLimits, output_parts: list[_messages.ModelRequestPart], output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]], + output_deferred_metadata: dict[str, dict[str, Any]], ) -> AsyncIterator[_messages.HandleResponseEvent]: tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {} user_parts_by_index: dict[int, _messages.UserPromptPart] = {} deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {} + deferred_metadata_by_index: dict[int, dict[str, Any] | None] = {} if usage_limits.tool_calls_limit is not None: projected_usage = deepcopy(usage) @@ -982,10 +987,12 @@ async def handle_call_or_result( tool_part, tool_user_content = ( (await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result() ) - except exceptions.CallDeferred: + except exceptions.CallDeferred as e: deferred_calls_by_index[index] = 'external' - except exceptions.ApprovalRequired: + deferred_metadata_by_index[index] = e.metadata + except exceptions.ApprovalRequired as e: deferred_calls_by_index[index] = 'unapproved' + deferred_metadata_by_index[index] = e.metadata else: tool_parts_by_index[index] = tool_part if tool_user_content: @@ -1023,8 +1030,25 @@ async def handle_call_or_result( output_parts.extend([tool_parts_by_index[k] for k in sorted(tool_parts_by_index)]) output_parts.extend([user_parts_by_index[k] for k in sorted(user_parts_by_index)]) + _populate_deferred_calls( + tool_calls, deferred_calls_by_index, deferred_metadata_by_index, output_deferred_calls, output_deferred_metadata + ) + + +def _populate_deferred_calls( + tool_calls: list[_messages.ToolCallPart], + deferred_calls_by_index: dict[int, Literal['external', 'unapproved']], + deferred_metadata_by_index: dict[int, dict[str, Any] | None], + output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]], + output_deferred_metadata: dict[str, dict[str, Any]], +) -> None: + """Populate deferred calls and metadata from indexed mappings.""" for k in sorted(deferred_calls_by_index): - output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k]) + call = tool_calls[k] + output_deferred_calls[deferred_calls_by_index[k]].append(call) + metadata = deferred_metadata_by_index[k] + if metadata is not None: + output_deferred_metadata[call.tool_call_id] = metadata async def _call_tool( diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py index d4adb4b6a7..a85b35ee4a 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py @@ -27,11 +27,13 @@ class CallToolParams: @dataclass class _ApprovalRequired: + metadata: dict[str, Any] | None = None kind: Literal['approval_required'] = 'approval_required' @dataclass class _CallDeferred: + metadata: dict[str, Any] | None = None kind: Literal['call_deferred'] = 'call_deferred' @@ -75,10 +77,10 @@ async def _wrap_call_tool_result(self, coro: Awaitable[Any]) -> CallToolResult: try: result = await coro return _ToolReturn(result=result) - except ApprovalRequired: - return _ApprovalRequired() - except CallDeferred: - return _CallDeferred() + except ApprovalRequired as e: + return _ApprovalRequired(metadata=e.metadata) + except CallDeferred as e: + return _CallDeferred(metadata=e.metadata) except ModelRetry as e: return _ModelRetry(message=e.message) @@ -86,9 +88,9 @@ def _unwrap_call_tool_result(self, result: CallToolResult) -> Any: if isinstance(result, _ToolReturn): return result.result elif isinstance(result, _ApprovalRequired): - raise ApprovalRequired() + raise ApprovalRequired(metadata=result.metadata) elif isinstance(result, _CallDeferred): - raise CallDeferred() + raise CallDeferred(metadata=result.metadata) elif isinstance(result, _ModelRetry): raise ModelRetry(result.message) else: diff --git a/pydantic_ai_slim/pydantic_ai/exceptions.py b/pydantic_ai_slim/pydantic_ai/exceptions.py index ae5cce0908..2e8358c2fc 100644 --- a/pydantic_ai_slim/pydantic_ai/exceptions.py +++ b/pydantic_ai_slim/pydantic_ai/exceptions.py @@ -67,18 +67,30 @@ class CallDeferred(Exception): """Exception to raise when a tool call should be deferred. See [tools docs](../deferred-tools.md#deferred-tools) for more information. + + Args: + metadata: Optional dictionary of metadata to attach to the deferred tool call. + This metadata will be available in `DeferredToolRequests.metadata` keyed by `tool_call_id`. """ - pass + def __init__(self, metadata: dict[str, Any] | None = None): + self.metadata = metadata + super().__init__() class ApprovalRequired(Exception): """Exception to raise when a tool call requires human-in-the-loop approval. See [tools docs](../deferred-tools.md#human-in-the-loop-tool-approval) for more information. + + Args: + metadata: Optional dictionary of metadata to attach to the deferred tool call. + This metadata will be available in `DeferredToolRequests.metadata` keyed by `tool_call_id`. """ - pass + def __init__(self, metadata: dict[str, Any] | None = None): + self.metadata = metadata + super().__init__() class UserError(RuntimeError): diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index da053a5191..beae26661b 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -147,6 +147,12 @@ class DeferredToolRequests: """Tool calls that require external execution.""" approvals: list[ToolCallPart] = field(default_factory=list) """Tool calls that require human-in-the-loop approval.""" + metadata: dict[str, dict[str, Any]] = field(default_factory=dict) + """Metadata for deferred tool calls, keyed by tool_call_id. + + This contains any metadata that was provided when raising [`CallDeferred`][pydantic_ai.exceptions.CallDeferred] + or [`ApprovalRequired`][pydantic_ai.exceptions.ApprovalRequired] exceptions. + """ @dataclass(kw_only=True) diff --git a/tests/test_agent.py b/tests/test_agent.py index 0a6bf1e325..2b89c3f38b 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -4857,9 +4857,13 @@ def call_second(): else: result = agent.run_sync(user_prompt) - assert result.output == snapshot( - DeferredToolRequests(approvals=[ToolCallPart(tool_name='requires_approval', tool_call_id=IsStr())]) - ) + assert isinstance(result.output, DeferredToolRequests) + assert len(result.output.approvals) == 1 + assert result.output.approvals[0].tool_name == 'requires_approval' + # When no metadata is provided, the tool_call_id should not be in metadata dict + tool_call_id = result.output.approvals[0].tool_call_id + assert tool_call_id not in result.output.metadata + assert result.output.metadata == {} assert integer_holder == 2 diff --git a/tests/test_examples.py b/tests/test_examples.py index c7c32c340d..dce4bca048 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -708,6 +708,21 @@ async def model_logic( # noqa: C901 TextPart(content='The factorial of 15 is **1,307,674,368,000**.'), ] ) + elif m.content == 'Book a flight to Lisbon, Portugal and link my SkyWay Airlines account': + return ModelResponse( + parts=[ + ToolCallPart( + tool_name='book_flight', + args={'destination': 'Lisbon, Portugal'}, + tool_call_id='pyd_ai_tool_call_id_1', + ), + ToolCallPart( + tool_name='authenticate_with_airline', + args={'airline': 'SkyWay Airlines'}, + tool_call_id='pyd_ai_tool_call_id_2', + ), + ] + ) elif isinstance(m, ToolReturnPart) and m.tool_name == 'roulette_wheel': win = m.content == 'winner' @@ -871,10 +886,24 @@ async def model_logic( # noqa: C901 return ModelResponse( parts=[TextPart('The answer to the ultimate question of life, the universe, and everything is 42.')] ) - else: + elif isinstance(m, ToolReturnPart) and m.tool_name in ('book_flight', 'authenticate_with_airline'): + # After deferred tools complete, check if we have all results to provide final response + tool_names = {part.tool_name for msg in messages for part in msg.parts if isinstance(part, ToolReturnPart)} + if 'book_flight' in tool_names and 'authenticate_with_airline' in tool_names: + return ModelResponse( + parts=[TextPart('Flight to Lisbon booked successfully and your SkyWay Airlines account is now linked.')] + ) + # If we don't have both results yet, just acknowledge the tool result + return ModelResponse(parts=[TextPart(f'Received result from {m.tool_name}')]) + + if isinstance(m, ToolReturnPart): sys.stdout.write(str(debug.format(messages, info))) raise RuntimeError(f'Unexpected message: {m}') + # Fallback for any other message type + sys.stdout.write(str(debug.format(messages, info))) + raise RuntimeError(f'Unexpected message type: {type(m).__name__}') + async def stream_model_logic( # noqa C901 messages: list[ModelMessage], info: AgentInfo diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 1a126f26dc..d66d5510e0 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -1149,9 +1149,13 @@ def regular_tool(x: int) -> int: async with agent.run_stream('test early strategy with external tool call') as result: response = await result.get_output() - assert response == snapshot( - DeferredToolRequests(calls=[ToolCallPart(tool_name='deferred_tool', tool_call_id=IsStr())]) - ) + assert isinstance(response, DeferredToolRequests) + assert len(response.calls) == 1 + assert response.calls[0].tool_name == 'deferred_tool' + # When no metadata is provided, the tool_call_id should not be in metadata dict + tool_call_id = response.calls[0].tool_call_id + assert tool_call_id not in response.metadata + assert response.metadata == {} messages = result.all_messages() # Verify no tools were called @@ -1638,9 +1642,7 @@ def my_tool(x: int) -> int: [DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())])] ) assert await result.get_output() == snapshot( - DeferredToolRequests( - calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())], - ) + DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())]) ) responses = [c async for c, _is_last in result.stream_responses(debounce_by=None)] assert responses == snapshot( @@ -1682,9 +1684,7 @@ def my_tool(ctx: RunContext[None], x: int) -> int: messages = result.all_messages() output = await result.get_output() assert output == snapshot( - DeferredToolRequests( - approvals=[ToolCallPart(tool_name='my_tool', args='{"x": 1}', tool_call_id=IsStr())], - ) + DeferredToolRequests(approvals=[ToolCallPart(tool_name='my_tool', args='{"x": 1}', tool_call_id=IsStr())]) ) assert result.is_complete diff --git a/tests/test_tools.py b/tests/test_tools.py index ea26d8ac91..a92e831873 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1318,9 +1318,7 @@ def my_tool(x: int) -> int: result = agent.run_sync('Hello') assert result.output == snapshot( - DeferredToolRequests( - calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())], - ) + DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())]) ) @@ -1394,6 +1392,187 @@ def my_tool(ctx: RunContext[None], x: int) -> int: assert result.output == snapshot('Done!') +def test_call_deferred_with_metadata(): + """Test that CallDeferred exception can carry metadata.""" + agent = Agent(TestModel(), output_type=[str, DeferredToolRequests]) + + @agent.tool_plain + def my_tool(x: int) -> int: + raise CallDeferred(metadata={'task_id': 'task-123', 'estimated_cost': 25.50}) + + result = agent.run_sync('Hello') + assert isinstance(result.output, DeferredToolRequests) + assert len(result.output.calls) == 1 + + tool_call_id = result.output.calls[0].tool_call_id + assert tool_call_id in result.output.metadata + assert result.output.metadata[tool_call_id] == {'task_id': 'task-123', 'estimated_cost': 25.50} + + +def test_approval_required_with_metadata(): + """Test that ApprovalRequired exception can carry metadata.""" + + def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + return ModelResponse( + parts=[ + ToolCallPart('my_tool', {'x': 1}, tool_call_id='my_tool'), + ] + ) + else: + return ModelResponse( + parts=[ + TextPart('Done!'), + ] + ) + + agent = Agent(FunctionModel(llm), output_type=[str, DeferredToolRequests]) + + @agent.tool + def my_tool(ctx: RunContext[None], x: int) -> int: + if not ctx.tool_call_approved: + raise ApprovalRequired( + metadata={ + 'reason': 'High compute cost', + 'estimated_time': '5 minutes', + 'cost_usd': 100.0, + } + ) + return x * 42 + + result = agent.run_sync('Hello') + assert isinstance(result.output, DeferredToolRequests) + assert len(result.output.approvals) == 1 + + assert 'my_tool' in result.output.metadata + assert result.output.metadata['my_tool'] == { + 'reason': 'High compute cost', + 'estimated_time': '5 minutes', + 'cost_usd': 100.0, + } + + # Continue with approval + messages = result.all_messages() + result = agent.run_sync( + message_history=messages, + deferred_tool_results=DeferredToolResults(approvals={'my_tool': ToolApproved()}), + ) + assert result.output == 'Done!' + + +def test_call_deferred_without_metadata(): + """Test backward compatibility: CallDeferred without metadata still works.""" + agent = Agent(TestModel(), output_type=[str, DeferredToolRequests]) + + @agent.tool_plain + def my_tool(x: int) -> int: + raise CallDeferred # No metadata + + result = agent.run_sync('Hello') + assert isinstance(result.output, DeferredToolRequests) + assert len(result.output.calls) == 1 + + tool_call_id = result.output.calls[0].tool_call_id + # Should have an empty metadata dict for this tool + assert result.output.metadata.get(tool_call_id, {}) == {} + + +def test_approval_required_without_metadata(): + """Test backward compatibility: ApprovalRequired without metadata still works.""" + + def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + return ModelResponse( + parts=[ + ToolCallPart('my_tool', {'x': 1}, tool_call_id='my_tool'), + ] + ) + else: + return ModelResponse( + parts=[ + TextPart('Done!'), + ] + ) + + agent = Agent(FunctionModel(llm), output_type=[str, DeferredToolRequests]) + + @agent.tool + def my_tool(ctx: RunContext[None], x: int) -> int: + if not ctx.tool_call_approved: + raise ApprovalRequired # No metadata + return x * 42 + + result = agent.run_sync('Hello') + assert isinstance(result.output, DeferredToolRequests) + assert len(result.output.approvals) == 1 + + # Should have an empty metadata dict for this tool + assert result.output.metadata.get('my_tool', {}) == {} + + # Continue with approval + messages = result.all_messages() + result = agent.run_sync( + message_history=messages, + deferred_tool_results=DeferredToolResults(approvals={'my_tool': ToolApproved()}), + ) + assert result.output == 'Done!' + + +def test_mixed_deferred_tools_with_metadata(): + """Test multiple deferred tools with different metadata.""" + + def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + return ModelResponse( + parts=[ + ToolCallPart('tool_a', {'x': 1}, tool_call_id='call_a'), + ToolCallPart('tool_b', {'y': 2}, tool_call_id='call_b'), + ToolCallPart('tool_c', {'z': 3}, tool_call_id='call_c'), + ] + ) + else: + return ModelResponse(parts=[TextPart('Done!')]) + + agent = Agent(FunctionModel(llm), output_type=[str, DeferredToolRequests]) + + @agent.tool + def tool_a(ctx: RunContext[None], x: int) -> int: + raise CallDeferred(metadata={'type': 'external', 'priority': 'high'}) + + @agent.tool + def tool_b(ctx: RunContext[None], y: int) -> int: + if not ctx.tool_call_approved: + raise ApprovalRequired(metadata={'reason': 'Needs approval', 'level': 'manager'}) + return y * 10 + + @agent.tool + def tool_c(ctx: RunContext[None], z: int) -> int: + raise CallDeferred # No metadata + + result = agent.run_sync('Hello') + assert isinstance(result.output, DeferredToolRequests) + + # Check that we have the right tools deferred + assert len(result.output.calls) == 2 # tool_a and tool_c + assert len(result.output.approvals) == 1 # tool_b + + # Check metadata + assert result.output.metadata['call_a'] == {'type': 'external', 'priority': 'high'} + assert result.output.metadata['call_b'] == {'reason': 'Needs approval', 'level': 'manager'} + assert result.output.metadata.get('call_c', {}) == {} + + # Continue with results for all three tools + messages = result.all_messages() + result = agent.run_sync( + message_history=messages, + deferred_tool_results=DeferredToolResults( + calls={'call_a': 10, 'call_c': 30}, + approvals={'call_b': ToolApproved()}, + ), + ) + assert result.output == 'Done!' + + def test_deferred_tool_with_output_type(): class MyModel(BaseModel): foo: str @@ -1583,7 +1762,7 @@ def buy(fruit: str): ToolCallPart(tool_name='buy', args={'fruit': 'apple'}, tool_call_id='buy_apple'), ToolCallPart(tool_name='buy', args={'fruit': 'banana'}, tool_call_id='buy_banana'), ToolCallPart(tool_name='buy', args={'fruit': 'pear'}, tool_call_id='buy_pear'), - ], + ] ) ) diff --git a/tests/test_ui.py b/tests/test_ui.py index 38f9950ad5..a497d09389 100644 --- a/tests/test_ui.py +++ b/tests/test_ui.py @@ -439,7 +439,7 @@ async def test_run_stream_external_tools(): '', "{}", '', - "DeferredToolRequests(calls=[ToolCallPart(tool_name='external_tool', args={}, tool_call_id='pyd_ai_tool_call_id__external_tool')], approvals=[])", + "DeferredToolRequests(calls=[ToolCallPart(tool_name='external_tool', args={}, tool_call_id='pyd_ai_tool_call_id__external_tool')], approvals=[], metadata={})", '', ] )