Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the add_tool(), remove_tool() and remove_all_tool() me… #4545

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
ekzhu marked this conversation as resolved.
Show resolved Hide resolved
ekzhu marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)

from autogen_core import CancellationToken, FunctionCall
from autogen_core.components.models import LLMMessage
from autogen_core.model_context import (
ChatCompletionContext,
UnboundedChatCompletionContext,
Expand Down Expand Up @@ -246,24 +247,11 @@ def __init__(
else:
self._system_messages = [SystemMessage(content=system_message)]
self._tools: List[Tool] = []
if tools is not None:
if model_client.capabilities["function_calling"] is False:
raise ValueError("The model does not support function calling.")
for tool in tools:
if isinstance(tool, Tool):
self._tools.append(tool)
elif callable(tool):
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
description = tool.__doc__
else:
description = ""
self._tools.append(FunctionTool(tool, description=description))
else:
raise ValueError(f"Unsupported tool type: {type(tool)}")
# Check if tool names are unique.
tool_names = [tool.name for tool in self._tools]
if len(tool_names) != len(set(tool_names)):
raise ValueError(f"Tool names must be unique: {tool_names}")
self._model_context: List[LLMMessage] = []
self._reflect_on_tool_use = reflect_on_tool_use
self._tool_call_summary_format = tool_call_summary_format
self._is_running = False

# Handoff tools.
self._handoff_tools: List[Tool] = []
self._handoffs: Dict[str, HandoffBase] = {}
Expand All @@ -273,26 +261,191 @@ def __init__(
for handoff in handoffs:
if isinstance(handoff, str):
handoff = HandoffBase(target=handoff)
if handoff.name in self._handoffs:
raise ValueError(f"Handoff name {handoff.name} already exists.")
if isinstance(handoff, HandoffBase):
self._handoff_tools.append(handoff.handoff_tool)
self._handoffs[handoff.name] = handoff
else:
raise ValueError(f"Unsupported handoff type: {type(handoff)}")
# Check if handoff tool names are unique.
handoff_tool_names = [tool.name for tool in self._handoff_tools]
if len(handoff_tool_names) != len(set(handoff_tool_names)):
raise ValueError(f"Handoff names must be unique: {handoff_tool_names}")
# Check if handoff tool names not in tool names.
if any(name in tool_names for name in handoff_tool_names):
raise ValueError(
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
)
if tools is not None:
for tool in tools:
self.add_tool(tool)

if not model_context:
self._model_context = UnboundedChatCompletionContext()
self._reflect_on_tool_use = reflect_on_tool_use
self._tool_call_summary_format = tool_call_summary_format
self._is_running = False

def add_tool(self, tool: Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]) -> None:
"""
Adds a new tool to the assistant agent.

The tool can be either an instance of the `Tool` class, or a callable function. If the tool is a callable
function, a `FunctionTool` instance will be created with the function and its docstring as the description.
ekzhu marked this conversation as resolved.
Show resolved Hide resolved

The tool name must be unique among all the tools and handoffs added to the agent. If the model does not support
function calling, an error will be raised.

Args:
tool (Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]): The tool to add.

Raises:
ValueError: If the tool name is not unique.
ValueError: If the tool name is already used by a handoff.
ValueError: If the tool has an unsupported type.
ValueError: If the model does not support function calling.

Examples:
.. code-block:: python

import asyncio
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import TextMessage
from autogen_agentchat.ui import Console
from autogen_core import CancellationToken


async def get_current_time() -> str:
return "The current time is 12:00 PM."


async def main() -> None:
model_client = OpenAIChatCompletionClient(
model="gpt-4o",
# api_key = "your_openai_api_key"
)
agent = AssistantAgent(name="assistant", model_client=model_client)

agent.add_tool(get_current_time)

await Console(
agent.on_messages_stream(
[TextMessage(content="What is the current time?", source="user")], CancellationToken()
)
)


asyncio.run(main())
"""
new_tool = None
if self._model_client.capabilities["function_calling"] is False:
ekzhu marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError("The model does not support function calling.")
if isinstance(tool, Tool):
new_tool = tool
elif callable(tool):
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
description = tool.__doc__
else:
description = ""
new_tool = FunctionTool(tool, description=description)
else:
raise ValueError(f"Unsupported tool type: {type(tool)}")
# Check if tool names are unique.
if any(tool.name == new_tool.name for tool in self._tools):
raise ValueError(f"Tool names must be unique: {new_tool.name}")
# Check if handoff tool names not in tool names.
handoff_tool_names = [handoff.name for handoff in self._handoffs.values()]
if new_tool.name in handoff_tool_names:
raise ValueError(
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; "
f"tool names: {new_tool.name}"
)
self._tools.append(new_tool)

def remove_all_tools(self) -> None:
"""
Remove all tools.

Examples:
.. code-block:: python

import asyncio
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import TextMessage
from autogen_agentchat.ui import Console
from autogen_core import CancellationToken


async def get_current_time() -> str:
return "The current time is 12:00 PM."


async def main() -> None:
model_client = OpenAIChatCompletionClient(
model="gpt-4o",
# api_key = "your_openai_api_key"
)
agent = AssistantAgent(name="assistant", model_client=model_client)

agent.add_tool(get_current_time)
agent.remove_all_tools()

await Console(
agent.on_messages_stream(
[TextMessage(content="What is the current time?", source="user")], CancellationToken()
)
)


asyncio.run(main())

"""
self._tools.clear()

def remove_tool(self, tool_name: str) -> None:
"""
Remove a tool by name.

Args:
tool_name (str): The name of the tool to remove.

Raises:
ValueError: If the tool name is not found.

Examples:
.. code-block:: python

import asyncio
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import TextMessage
from autogen_agentchat.ui import Console
from autogen_core import CancellationToken


async def get_current_time() -> str:
return "The current time is 12:00 PM."


async def main() -> None:
model_client = OpenAIChatCompletionClient(
model="gpt-4o",
# api_key = "your_openai_api_key"
)
agent = AssistantAgent(name="assistant", model_client=model_client)

agent.add_tool(get_current_time)
agent.remove_tool("get_current_time")

await Console(
agent.on_messages_stream(
[TextMessage(content="What is the current time?", source="user")], CancellationToken()
)
)


asyncio.run(main())
"""
for tool in self._tools:
if tool.name == tool_name:
self._tools.remove(tool)
return
raise ValueError(f"Tool {tool_name} not found.")

@property
def produced_message_types(self) -> List[type[ChatMessage]]:
"""The types of messages that the assistant agent produces."""
Expand Down
75 changes: 75 additions & 0 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,3 +468,78 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
else:
assert message == result.messages[index]
index += 1


def test_tool_management():
model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="")
agent = AssistantAgent(name="test_assistant", model_client=model_client)

# Test function to be used as a tool
def sample_tool() -> str:
return "sample result"

# Test adding a tool
tool = FunctionTool(sample_tool, description="Sample tool")
agent.add_tool(tool)
assert len(agent._tools) == 1

# Test adding duplicate tool
with pytest.raises(ValueError, match="Tool names must be unique"):
agent.add_tool(tool)

# Test tool collision with handoff
agent_with_handoff = AssistantAgent(
name="test_assistant", model_client=model_client, handoffs=[Handoff(target="other_agent")]
)

conflicting_tool = FunctionTool(sample_tool, name="transfer_to_other_agent", description="Sample tool")
with pytest.raises(ValueError, match="Handoff names must be unique from tool names"):
agent_with_handoff.add_tool(conflicting_tool)

# Test removing a tool
agent.remove_tool(tool.name)
assert len(agent._tools) == 0

# Test removing non-existent tool
with pytest.raises(ValueError, match="Tool non_existent_tool not found"):
agent.remove_tool("non_existent_tool")

# Test removing all tools
agent.add_tool(tool)
assert len(agent._tools) == 1
agent.remove_all_tools()
assert len(agent._tools) == 0

# Test idempotency of remove_all_tools
agent.remove_all_tools()
assert len(agent._tools) == 0


def test_callable_tool_addition():
model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="")
agent = AssistantAgent(name="test_assistant", model_client=model_client)

# Test adding a callable directly
def documented_tool() -> str:
"""This is a documented tool"""
return "result"

agent.add_tool(documented_tool)
assert len(agent._tools) == 1
assert agent._tools[0].description == "This is a documented tool"

# Test adding async callable
async def async_tool() -> str:
return "async result"

agent.add_tool(async_tool)
assert len(agent._tools) == 2


def test_invalid_tool_addition():
model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="")
agent = AssistantAgent(name="test_assistant", model_client=model_client)

# Test adding invalid tool type
with pytest.raises(ValueError, match="Unsupported tool type"):
agent.add_tool("not a tool")
Loading