diff --git a/examples/basic/agent_lifecycle_example.py b/examples/basic/agent_lifecycle_example.py index 29bb18c9..62662f8e 100644 --- a/examples/basic/agent_lifecycle_example.py +++ b/examples/basic/agent_lifecycle_example.py @@ -4,7 +4,7 @@ from pydantic import BaseModel -from agents import Agent, AgentHooks, RunContextWrapper, Runner, Tool, function_tool +from agents import Agent, AgentHooks, RunContextWrapper, Runner, Tool, Action, function_tool class CustomAgentHooks(AgentHooks): @@ -28,10 +28,10 @@ async def on_handoff(self, context: RunContextWrapper, agent: Agent, source: Age f"### ({self.display_name}) {self.event_counter}: Agent {source.name} handed off to {agent.name}" ) - async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: Tool) -> None: + async def on_tool_start(self, context: RunContextWrapper, agent: Agent, action: Action) -> None: self.event_counter += 1 print( - f"### ({self.display_name}) {self.event_counter}: Agent {agent.name} started tool {tool.name}" + f"### ({self.display_name}) {self.event_counter}: Agent {agent.name} started tool {action.function_tool.name} with arguments {action.tool_call.arguments}" ) async def on_tool_end( diff --git a/examples/basic/lifecycle_example.py b/examples/basic/lifecycle_example.py index 285bfecd..163322ed 100644 --- a/examples/basic/lifecycle_example.py +++ b/examples/basic/lifecycle_example.py @@ -4,7 +4,7 @@ from pydantic import BaseModel -from agents import Agent, RunContextWrapper, RunHooks, Runner, Tool, Usage, function_tool +from agents import Agent, RunContextWrapper, RunHooks, Runner, Tool, Usage, function_tool, Action class ExampleHooks(RunHooks): @@ -26,10 +26,10 @@ async def on_agent_end(self, context: RunContextWrapper, agent: Agent, output: A f"### {self.event_counter}: Agent {agent.name} ended with output {output}. Usage: {self._usage_to_str(context.usage)}" ) - async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: Tool) -> None: + async def on_tool_start(self, context: RunContextWrapper, agent: Agent, action: Action) -> None: self.event_counter += 1 print( - f"### {self.event_counter}: Tool {tool.name} started. Usage: {self._usage_to_str(context.usage)}" + f"### {self.event_counter}: Tool {action.function_tool.tool.name} started. Usage: {self._usage_to_str(context.usage)}" ) async def on_tool_end( diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 6d7c90b4..0014c9c5 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -62,6 +62,7 @@ WebSearchTool, default_tool_error_function, function_tool, + Action, ) from .tracing import ( AgentSpanData, @@ -209,6 +210,7 @@ def enable_verbose_stdout_logging(): "Tool", "WebSearchTool", "function_tool", + "Action", "Usage", "add_trace_processor", "agent_span", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index b5a83685..b4f0991d 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -52,7 +52,7 @@ from .models.interface import ModelTracing from .run_context import RunContextWrapper, TContext from .stream_events import RunItemStreamEvent, StreamEvent -from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool +from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool, ToolRunFunction, ToolRunComputerAction from .tracing import ( SpanError, Trace, @@ -99,19 +99,6 @@ class ToolRunHandoff: handoff: Handoff tool_call: ResponseFunctionToolCall - -@dataclass -class ToolRunFunction: - tool_call: ResponseFunctionToolCall - function_tool: FunctionTool - - -@dataclass -class ToolRunComputerAction: - tool_call: ResponseComputerToolCall - computer_tool: ComputerTool - - @dataclass class ProcessedResponse: new_items: list[RunItem] @@ -429,17 +416,17 @@ async def execute_function_tool_calls( context_wrapper: RunContextWrapper[TContext], config: RunConfig, ) -> list[FunctionToolResult]: - async def run_single_tool( - func_tool: FunctionTool, tool_call: ResponseFunctionToolCall - ) -> Any: + async def run_single_tool(action: ToolRunFunction) -> Any: + func_tool = action.function_tool + tool_call = action.tool_call with function_span(func_tool.name) as span_fn: if config.trace_include_sensitive_data: span_fn.span_data.input = tool_call.arguments try: _, _, result = await asyncio.gather( - hooks.on_tool_start(context_wrapper, agent, func_tool), + hooks.on_tool_start(context_wrapper, agent, action), ( - agent.hooks.on_tool_start(context_wrapper, agent, func_tool) + agent.hooks.on_tool_start(context_wrapper, agent, action) if agent.hooks else _coro.noop_coroutine() ), @@ -471,8 +458,7 @@ async def run_single_tool( tasks = [] for tool_run in tool_runs: - function_tool = tool_run.function_tool - tasks.append(run_single_tool(function_tool, tool_run.tool_call)) + tasks.append(run_single_tool(tool_run)) results = await asyncio.gather(*tasks) @@ -831,9 +817,9 @@ async def execute( ) _, _, output = await asyncio.gather( - hooks.on_tool_start(context_wrapper, agent, action.computer_tool), + hooks.on_tool_start(context_wrapper, agent, action), ( - agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool) + agent.hooks.on_tool_start(context_wrapper, agent, action) if agent.hooks else _coro.noop_coroutine() ), diff --git a/src/agents/lifecycle.py b/src/agents/lifecycle.py index 8643248b..fa102bc6 100644 --- a/src/agents/lifecycle.py +++ b/src/agents/lifecycle.py @@ -2,7 +2,7 @@ from .agent import Agent from .run_context import RunContextWrapper, TContext -from .tool import Tool +from .tool import Tool, Action class RunHooks(Generic[TContext]): @@ -38,7 +38,7 @@ async def on_tool_start( self, context: RunContextWrapper[TContext], agent: Agent[TContext], - tool: Tool, + action: Action, ) -> None: """Called before a tool is invoked.""" pass @@ -89,7 +89,7 @@ async def on_tool_start( self, context: RunContextWrapper[TContext], agent: Agent[TContext], - tool: Tool, + action: Action, ) -> None: """Called before a tool is invoked.""" pass diff --git a/src/agents/tool.py b/src/agents/tool.py index c1c16242..2012198b 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -10,6 +10,7 @@ from openai.types.responses.web_search_tool_param import UserLocation from pydantic import ValidationError from typing_extensions import Concatenate, ParamSpec +from openai.types.responses import ResponseComputerToolCall, ResponseFunctionToolCall from . import _debug from .computer import AsyncComputer, Computer @@ -133,6 +134,19 @@ def name(self): Tool = Union[FunctionTool, FileSearchTool, WebSearchTool, ComputerTool] """A tool that can be used in an agent.""" +@dataclass +class ToolRunFunction: + tool_call: ResponseFunctionToolCall + function_tool: FunctionTool + + +@dataclass +class ToolRunComputerAction: + tool_call: ResponseComputerToolCall + computer_tool: ComputerTool + +Action = Union[ToolRunFunction, ToolRunComputerAction] +"""An action that can be performed by an agent. It contains the tool call and the tool""" def default_tool_error_function(ctx: RunContextWrapper[Any], error: Exception) -> str: """The default tool error function, which just returns a generic error message.""" diff --git a/tests/test_agent_hooks.py b/tests/test_agent_hooks.py index a6c302dc..63e0177c 100644 --- a/tests/test_agent_hooks.py +++ b/tests/test_agent_hooks.py @@ -7,7 +7,7 @@ import pytest from typing_extensions import TypedDict -from agents.agent import Agent +from agents.agent import Agent, Action from agents.lifecycle import AgentHooks from agents.run import Runner from agents.run_context import RunContextWrapper, TContext @@ -53,7 +53,7 @@ async def on_tool_start( self, context: RunContextWrapper[TContext], agent: Agent[TContext], - tool: Tool, + action: Action, ) -> None: self.events["on_tool_start"] += 1 diff --git a/tests/test_computer_action.py b/tests/test_computer_action.py index 70dcabd5..f161f0d1 100644 --- a/tests/test_computer_action.py +++ b/tests/test_computer_action.py @@ -23,6 +23,7 @@ from agents import ( Agent, + Action, AgentHooks, AsyncComputer, Computer, @@ -31,7 +32,8 @@ RunContextWrapper, RunHooks, ) -from agents._run_impl import ComputerAction, ToolRunComputerAction +from agents._run_impl import ComputerAction +from agents.tool import ToolRunComputerAction from agents.items import ToolCallOutputItem @@ -222,9 +224,9 @@ def __init__(self) -> None: self.ended: list[tuple[Agent[Any], Any, str]] = [] async def on_tool_start( - self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any + self, context: RunContextWrapper[Any], agent: Agent[Any], action: Action, ) -> None: - self.started.append((agent, tool)) + self.started.append((agent, action.computer_tool)) async def on_tool_end( self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: str @@ -241,9 +243,9 @@ def __init__(self) -> None: self.ended: list[tuple[Agent[Any], Any, str]] = [] async def on_tool_start( - self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any + self, context: RunContextWrapper[Any], agent: Agent[Any], action: Action, ) -> None: - self.started.append((agent, tool)) + self.started.append((agent, action.computer_tool)) async def on_tool_end( self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: str