Skip to content

Add tool call parameters for on_tool_start hook #253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/basic/agent_lifecycle_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import BaseModel

from agents import Agent, AgentHooks, RunContextWrapper, Runner, Tool, function_tool
from agents import Agent, AgentHooks, RunContextWrapper, Runner, Tool, Action, function_tool


class CustomAgentHooks(AgentHooks):
Expand All @@ -28,10 +28,10 @@ async def on_handoff(self, context: RunContextWrapper, agent: Agent, source: Age
f"### ({self.display_name}) {self.event_counter}: Agent {source.name} handed off to {agent.name}"
)

async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: Tool) -> None:
async def on_tool_start(self, context: RunContextWrapper, agent: Agent, action: Action) -> None:
self.event_counter += 1
print(
f"### ({self.display_name}) {self.event_counter}: Agent {agent.name} started tool {tool.name}"
f"### ({self.display_name}) {self.event_counter}: Agent {agent.name} started tool {action.function_tool.name} with arguments {action.tool_call.arguments}"
)

async def on_tool_end(
Expand Down
6 changes: 3 additions & 3 deletions examples/basic/lifecycle_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import BaseModel

from agents import Agent, RunContextWrapper, RunHooks, Runner, Tool, Usage, function_tool
from agents import Agent, RunContextWrapper, RunHooks, Runner, Tool, Usage, function_tool, Action


class ExampleHooks(RunHooks):
Expand All @@ -26,10 +26,10 @@ async def on_agent_end(self, context: RunContextWrapper, agent: Agent, output: A
f"### {self.event_counter}: Agent {agent.name} ended with output {output}. Usage: {self._usage_to_str(context.usage)}"
)

async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: Tool) -> None:
async def on_tool_start(self, context: RunContextWrapper, agent: Agent, action: Action) -> None:
self.event_counter += 1
print(
f"### {self.event_counter}: Tool {tool.name} started. Usage: {self._usage_to_str(context.usage)}"
f"### {self.event_counter}: Tool {action.function_tool.tool.name} started. Usage: {self._usage_to_str(context.usage)}"
)

async def on_tool_end(
Expand Down
2 changes: 2 additions & 0 deletions src/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
WebSearchTool,
default_tool_error_function,
function_tool,
Action,
)
from .tracing import (
AgentSpanData,
Expand Down Expand Up @@ -209,6 +210,7 @@ def enable_verbose_stdout_logging():
"Tool",
"WebSearchTool",
"function_tool",
"Action",
"Usage",
"add_trace_processor",
"agent_span",
Expand Down
32 changes: 9 additions & 23 deletions src/agents/_run_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from .models.interface import ModelTracing
from .run_context import RunContextWrapper, TContext
from .stream_events import RunItemStreamEvent, StreamEvent
from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool
from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool, ToolRunFunction, ToolRunComputerAction
from .tracing import (
SpanError,
Trace,
Expand Down Expand Up @@ -99,19 +99,6 @@ class ToolRunHandoff:
handoff: Handoff
tool_call: ResponseFunctionToolCall


@dataclass
class ToolRunFunction:
tool_call: ResponseFunctionToolCall
function_tool: FunctionTool


@dataclass
class ToolRunComputerAction:
tool_call: ResponseComputerToolCall
computer_tool: ComputerTool


@dataclass
class ProcessedResponse:
new_items: list[RunItem]
Expand Down Expand Up @@ -429,17 +416,17 @@ async def execute_function_tool_calls(
context_wrapper: RunContextWrapper[TContext],
config: RunConfig,
) -> list[FunctionToolResult]:
async def run_single_tool(
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
) -> Any:
async def run_single_tool(action: ToolRunFunction) -> Any:
func_tool = action.function_tool
tool_call = action.tool_call
with function_span(func_tool.name) as span_fn:
if config.trace_include_sensitive_data:
span_fn.span_data.input = tool_call.arguments
try:
_, _, result = await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, func_tool),
hooks.on_tool_start(context_wrapper, agent, action),
(
agent.hooks.on_tool_start(context_wrapper, agent, func_tool)
agent.hooks.on_tool_start(context_wrapper, agent, action)
if agent.hooks

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since using only arguments as a parameter is too granular, and adding tool_call.id in the future would lead to parameter inflation at the root level, plus tool_call.id is actually needed, would it be better to pass the entire tool_call object as a parameter instead?"

Copy link
Author

@yanmxa yanmxa Apr 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mini-peanut Thanks for your review! I've updated the hook to use Action, which not only includes the Tool but also contains the instance information for a tool call, including the parameters and tool_call.id.

else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -471,8 +458,7 @@ async def run_single_tool(

tasks = []
for tool_run in tool_runs:
function_tool = tool_run.function_tool
tasks.append(run_single_tool(function_tool, tool_run.tool_call))
tasks.append(run_single_tool(tool_run))

results = await asyncio.gather(*tasks)

Expand Down Expand Up @@ -831,9 +817,9 @@ async def execute(
)

_, _, output = await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, action.computer_tool),
hooks.on_tool_start(context_wrapper, agent, action),
(
agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool)
agent.hooks.on_tool_start(context_wrapper, agent, action)
if agent.hooks
else _coro.noop_coroutine()
),
Expand Down
6 changes: 3 additions & 3 deletions src/agents/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .agent import Agent
from .run_context import RunContextWrapper, TContext
from .tool import Tool
from .tool import Tool, Action


class RunHooks(Generic[TContext]):
Expand Down Expand Up @@ -38,7 +38,7 @@ async def on_tool_start(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
tool: Tool,
action: Action,
) -> None:
"""Called before a tool is invoked."""
pass
Expand Down Expand Up @@ -89,7 +89,7 @@ async def on_tool_start(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
tool: Tool,
action: Action,
) -> None:
"""Called before a tool is invoked."""
pass
Expand Down
14 changes: 14 additions & 0 deletions src/agents/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from openai.types.responses.web_search_tool_param import UserLocation
from pydantic import ValidationError
from typing_extensions import Concatenate, ParamSpec
from openai.types.responses import ResponseComputerToolCall, ResponseFunctionToolCall

from . import _debug
from .computer import AsyncComputer, Computer
Expand Down Expand Up @@ -133,6 +134,19 @@ def name(self):
Tool = Union[FunctionTool, FileSearchTool, WebSearchTool, ComputerTool]
"""A tool that can be used in an agent."""

@dataclass
class ToolRunFunction:
tool_call: ResponseFunctionToolCall
function_tool: FunctionTool


@dataclass
class ToolRunComputerAction:
tool_call: ResponseComputerToolCall
computer_tool: ComputerTool

Action = Union[ToolRunFunction, ToolRunComputerAction]
"""An action that can be performed by an agent. It contains the tool call and the tool"""

def default_tool_error_function(ctx: RunContextWrapper[Any], error: Exception) -> str:
"""The default tool error function, which just returns a generic error message."""
Expand Down
4 changes: 2 additions & 2 deletions tests/test_agent_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
from typing_extensions import TypedDict

from agents.agent import Agent
from agents.agent import Agent, Action
from agents.lifecycle import AgentHooks
from agents.run import Runner
from agents.run_context import RunContextWrapper, TContext
Expand Down Expand Up @@ -53,7 +53,7 @@ async def on_tool_start(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
tool: Tool,
action: Action,
) -> None:
self.events["on_tool_start"] += 1

Expand Down
12 changes: 7 additions & 5 deletions tests/test_computer_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from agents import (
Agent,
Action,
AgentHooks,
AsyncComputer,
Computer,
Expand All @@ -31,7 +32,8 @@
RunContextWrapper,
RunHooks,
)
from agents._run_impl import ComputerAction, ToolRunComputerAction
from agents._run_impl import ComputerAction
from agents.tool import ToolRunComputerAction
from agents.items import ToolCallOutputItem


Expand Down Expand Up @@ -222,9 +224,9 @@ def __init__(self) -> None:
self.ended: list[tuple[Agent[Any], Any, str]] = []

async def on_tool_start(
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any
self, context: RunContextWrapper[Any], agent: Agent[Any], action: Action,
) -> None:
self.started.append((agent, tool))
self.started.append((agent, action.computer_tool))

async def on_tool_end(
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: str
Expand All @@ -241,9 +243,9 @@ def __init__(self) -> None:
self.ended: list[tuple[Agent[Any], Any, str]] = []

async def on_tool_start(
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any
self, context: RunContextWrapper[Any], agent: Agent[Any], action: Action,
) -> None:
self.started.append((agent, tool))
self.started.append((agent, action.computer_tool))

async def on_tool_end(
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: str
Expand Down