Skip to content

Commit

Permalink
Add the add_tool(), remove_tool() and remove_all_tools() method…
Browse files Browse the repository at this point in the history
…s for `AssistantAgent`
  • Loading branch information
Jean-Marc Le Roux committed Dec 17, 2024
1 parent 7eaffa8 commit 9b9ad51
Show file tree
Hide file tree
Showing 2 changed files with 249 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -231,24 +231,10 @@ def __init__(
else:
self._system_messages = [SystemMessage(content=system_message)]
self._tools: List[Tool] = []
if tools is not None:
if model_client.capabilities["function_calling"] is False:
raise ValueError("The model does not support function calling.")
for tool in tools:
if isinstance(tool, Tool):
self._tools.append(tool)
elif callable(tool):
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
description = tool.__doc__
else:
description = ""
self._tools.append(FunctionTool(tool, description=description))
else:
raise ValueError(f"Unsupported tool type: {type(tool)}")
# Check if tool names are unique.
tool_names = [tool.name for tool in self._tools]
if len(tool_names) != len(set(tool_names)):
raise ValueError(f"Tool names must be unique: {tool_names}")
self._model_context: List[LLMMessage] = []
self._reflect_on_tool_use = reflect_on_tool_use
self._tool_call_summary_format = tool_call_summary_format
self._is_running = False
# Handoff tools.
self._handoff_tools: List[Tool] = []
self._handoffs: Dict[str, HandoffBase] = {}
Expand All @@ -258,24 +244,184 @@ def __init__(
for handoff in handoffs:
if isinstance(handoff, str):
handoff = HandoffBase(target=handoff)
if handoff.name in self._handoffs:
raise ValueError(f"Handoff name {handoff.name} already exists.")
if isinstance(handoff, HandoffBase):
self._handoff_tools.append(handoff.handoff_tool)
self._handoffs[handoff.name] = handoff
else:
raise ValueError(f"Unsupported handoff type: {type(handoff)}")
# Check if handoff tool names are unique.
handoff_tool_names = [tool.name for tool in self._handoff_tools]
if len(handoff_tool_names) != len(set(handoff_tool_names)):
raise ValueError(f"Handoff names must be unique: {handoff_tool_names}")
if tools is not None:
for tool in tools:
self.add_tool(tool)

def add_tool(self, tool: Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]) -> None:
"""
Adds a new tool to the assistant agent.
The tool can be either an instance of the `Tool` class, or a callable function. If the tool is a callable
function, a `FunctionTool` instance will be created with the function and its docstring as the description.
The tool name must be unique among all the tools and handoffs added to the agent. If the model does not support
function calling, an error will be raised.
Args:
tool (Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]): The tool to add.
Raises:
ValueError: If the tool name is not unique.
ValueError: If the tool name is already used by a handoff.
ValueError: If the tool has an unsupported type.
ValueError: If the model does not support function calling.
Examples:
.. code-block:: python
import asyncio
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import TextMessage
from autogen_agentchat.ui import Console
from autogen_core import CancellationToken
async def get_current_time() -> str:
return "The current time is 12:00 PM."
async def main() -> None:
model_client = OpenAIChatCompletionClient(
model="gpt-4o",
# api_key = "your_openai_api_key"
)
agent = AssistantAgent(name="assistant", model_client=model_client)
agent.add_tool(get_current_time)
await Console(
agent.on_messages_stream(
[TextMessage(content="What is the current time?", source="user")], CancellationToken()
)
)
asyncio.run(main())
"""
new_tool = None
if self._model_client.capabilities["function_calling"] is False:
raise ValueError("The model does not support function calling.")
if isinstance(tool, Tool):
new_tool = tool
elif callable(tool):
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
description = tool.__doc__
else:
description = ""
new_tool = FunctionTool(tool, description=description)
else:
raise ValueError(f"Unsupported tool type: {type(tool)}")
# Check if tool names are unique.
if any(tool.name == new_tool.name for tool in self._tools):
raise ValueError(f"Tool names must be unique: {new_tool.name}")
# Check if handoff tool names not in tool names.
if any(name in tool_names for name in handoff_tool_names):
handoff_tool_names = [handoff.name for handoff in self._handoffs.values()]
if new_tool.name in handoff_tool_names:
raise ValueError(
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; "
f"tool names: {new_tool.name}"
)
self._model_context: List[LLMMessage] = []
self._reflect_on_tool_use = reflect_on_tool_use
self._tool_call_summary_format = tool_call_summary_format
self._is_running = False
self._tools.append(new_tool)

def remove_all_tools(self) -> None:
"""
Remove all tools.
Examples:
.. code-block:: python
import asyncio
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import TextMessage
from autogen_agentchat.ui import Console
from autogen_core import CancellationToken
async def get_current_time() -> str:
return "The current time is 12:00 PM."
async def main() -> None:
model_client = OpenAIChatCompletionClient(
model="gpt-4o",
# api_key = "your_openai_api_key"
)
agent = AssistantAgent(name="assistant", model_client=model_client)
agent.add_tool(get_current_time)
agent.remove_all_tools()
await Console(
agent.on_messages_stream(
[TextMessage(content="What is the current time?", source="user")], CancellationToken()
)
)
asyncio.run(main())
"""
self._tools.clear()

def remove_tool(self, tool_name: str) -> None:
"""
Remove a tool by name.
Args:
tool_name (str): The name of the tool to remove.
Raises:
ValueError: If the tool name is not found.
Examples:
.. code-block:: python
import asyncio
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import TextMessage
from autogen_agentchat.ui import Console
from autogen_core import CancellationToken
async def get_current_time() -> str:
return "The current time is 12:00 PM."
async def main() -> None:
model_client = OpenAIChatCompletionClient(
model="gpt-4o",
# api_key = "your_openai_api_key"
)
agent = AssistantAgent(name="assistant", model_client=model_client)
agent.add_tool(get_current_time)
agent.remove_tool("get_current_time")
await Console(
agent.on_messages_stream(
[TextMessage(content="What is the current time?", source="user")], CancellationToken()
)
)
asyncio.run(main())
"""
for tool in self._tools:
if tool.name == tool_name:
self._tools.remove(tool)
return
raise ValueError(f"Tool {tool_name} not found.")

@property
def produced_message_types(self) -> List[type[ChatMessage]]:
Expand Down
75 changes: 75 additions & 0 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,3 +467,78 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
else:
assert message == result.messages[index]
index += 1


def test_tool_management():
model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="")
agent = AssistantAgent(name="test_assistant", model_client=model_client)

# Test function to be used as a tool
def sample_tool() -> str:
return "sample result"

# Test adding a tool
tool = FunctionTool(sample_tool, description="Sample tool")
agent.add_tool(tool)
assert len(agent._tools) == 1

# Test adding duplicate tool
with pytest.raises(ValueError, match="Tool names must be unique"):
agent.add_tool(tool)

# Test tool collision with handoff
agent_with_handoff = AssistantAgent(
name="test_assistant", model_client=model_client, handoffs=[Handoff(target="other_agent")]
)

conflicting_tool = FunctionTool(sample_tool, name="transfer_to_other_agent", description="Sample tool")
with pytest.raises(ValueError, match="Handoff names must be unique from tool names"):
agent_with_handoff.add_tool(conflicting_tool)

# Test removing a tool
agent.remove_tool(tool.name)
assert len(agent._tools) == 0

# Test removing non-existent tool
with pytest.raises(ValueError, match="Tool non_existent_tool not found"):
agent.remove_tool("non_existent_tool")

# Test removing all tools
agent.add_tool(tool)
assert len(agent._tools) == 1
agent.remove_all_tools()
assert len(agent._tools) == 0

# Test idempotency of remove_all_tools
agent.remove_all_tools()
assert len(agent._tools) == 0


def test_callable_tool_addition():
model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="")
agent = AssistantAgent(name="test_assistant", model_client=model_client)

# Test adding a callable directly
def documented_tool() -> str:
"""This is a documented tool"""
return "result"

agent.add_tool(documented_tool)
assert len(agent._tools) == 1
assert agent._tools[0].description == "This is a documented tool"

# Test adding async callable
async def async_tool() -> str:
return "async result"

agent.add_tool(async_tool)
assert len(agent._tools) == 2


def test_invalid_tool_addition():
model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="")
agent = AssistantAgent(name="test_assistant", model_client=model_client)

# Test adding invalid tool type
with pytest.raises(ValueError, match="Unsupported tool type"):
agent.add_tool("not a tool")

0 comments on commit 9b9ad51

Please sign in to comment.