diff --git a/core/tests/test_tool_registry.py b/core/tests/test_tool_registry.py index cebf2ee05a..43992ec28e 100644 --- a/core/tests/test_tool_registry.py +++ b/core/tests/test_tool_registry.py @@ -393,3 +393,407 @@ def failing_executor(inputs: dict) -> None: log_messages = [record.message for record in caplog.records] full_log = " ".join(log_messages) assert "...(truncated)" in full_log + + +# --------------------------------------------------------------------------- +# register_function — type inference and required/optional parameters +# --------------------------------------------------------------------------- + + +def test_register_function_infers_type_hints(): + """register_function should map Python type annotations to JSON schema types.""" + registry = ToolRegistry() + + def my_func(a: int, b: float, c: bool, d: dict, e: list, f: str = "x") -> None: + pass + + registry.register_function(my_func) + + tool = registry.get_tools()["my_func"] + props = tool.parameters["properties"] + assert props["a"]["type"] == "integer" + assert props["b"]["type"] == "number" + assert props["c"]["type"] == "boolean" + assert props["d"]["type"] == "object" + assert props["e"]["type"] == "array" + assert props["f"]["type"] == "string" + + +def test_register_function_required_vs_optional(): + """Parameters without defaults should appear in 'required'.""" + registry = ToolRegistry() + + def my_func(required_param: str, optional_param: int = 5) -> None: + pass + + registry.register_function(my_func) + + tool = registry.get_tools()["my_func"] + required = tool.parameters["required"] + assert "required_param" in required + assert "optional_param" not in required + + +def test_register_function_custom_name_and_description(): + """register_function should accept explicit name and description overrides.""" + registry = ToolRegistry() + + def original_name() -> None: + """Original docstring.""" + pass + + registry.register_function(original_name, name="custom_name", description="Custom desc") + tools = registry.get_tools() + assert "custom_name" in tools + assert "original_name" not in tools + assert tools["custom_name"].description == "Custom desc" + + +def test_register_function_falls_back_to_docstring(): + """register_function should use the docstring if no description is given.""" + registry = ToolRegistry() + + def my_tool() -> None: + """My docstring.""" + pass + + registry.register_function(my_tool) + tool = registry.get_tools()["my_tool"] + assert tool.description == "My docstring." + + +def test_register_function_executor_calls_function(): + """The executor created by register_function should call the underlying function.""" + registry = ToolRegistry() + calls = [] + + def multiply(x: int, y: int) -> int: + calls.append((x, y)) + return x * y + + registry.register_function(multiply) + tool_use = ToolUse(id="call_1", name="multiply", input={"x": 3, "y": 4}) + executor = registry.get_executor() + result = executor(tool_use) + + assert calls == [(3, 4)] + assert "12" in result.content + + +# --------------------------------------------------------------------------- +# @tool decorator discovery via discover_from_module +# --------------------------------------------------------------------------- + + +def test_discover_from_module_finds_tool_decorated_functions(tmp_path): + """discover_from_module should pick up functions decorated with @tool.""" + module_src = """ + from framework.runner.tool_registry import tool + + @tool(description="Say hello") + def greet(name: str) -> str: + return f"Hello {name}" + """ + module_path = tmp_path / "agent_tools.py" + module_path.write_text(textwrap.dedent(module_src)) + + registry = ToolRegistry() + count = registry.discover_from_module(module_path) + assert count == 1 + assert "greet" in registry.get_tools() + + +def test_discover_from_module_returns_zero_for_missing_file(tmp_path): + """discover_from_module should return 0 when the file does not exist.""" + registry = ToolRegistry() + count = registry.discover_from_module(tmp_path / "nonexistent.py") + assert count == 0 + + +def test_discover_from_module_registers_mock_executor_without_tool_executor(tmp_path): + """When TOOLS dict exists but no tool_executor, a mock executor is used.""" + module_src = """ + from framework.llm.provider import Tool + + TOOLS = { + "mock_tool": Tool( + name="mock_tool", + description="Has no executor", + parameters={"type": "object", "properties": {}}, + ), + } + """ + module_path = tmp_path / "agent_tools.py" + module_path.write_text(textwrap.dedent(module_src)) + + registry = ToolRegistry() + count = registry.discover_from_module(module_path) + assert count == 1 + + registered = registry._tools["mock_tool"] # noqa: SLF001 + result = registered.executor({"foo": "bar"}) + assert result == {"mock": True, "inputs": {"foo": "bar"}} + + +# --------------------------------------------------------------------------- +# has_tool / get_registered_names +# --------------------------------------------------------------------------- + + +def test_has_tool_returns_true_for_registered_tool(): + registry = ToolRegistry() + tool = Tool(name="t", description="d", parameters={"type": "object", "properties": {}}) + registry.register("t", tool, lambda inputs: inputs) + assert registry.has_tool("t") is True + + +def test_has_tool_returns_false_for_missing_tool(): + registry = ToolRegistry() + assert registry.has_tool("not_there") is False + + +def test_get_registered_names_lists_all_tools(): + registry = ToolRegistry() + for name in ("alpha", "beta", "gamma"): + t = Tool(name=name, description="d", parameters={"type": "object", "properties": {}}) + registry.register(name, t, lambda inputs: inputs) + assert set(registry.get_registered_names()) == {"alpha", "beta", "gamma"} + + +# --------------------------------------------------------------------------- +# Session context injection +# --------------------------------------------------------------------------- + + +def test_session_context_is_injected_into_mcp_tool_call(monkeypatch): + """Context params in session context should be forwarded to MCP tool calls.""" + registry = ToolRegistry() + registry.set_session_context(workspace_id="ws-123", agent_id="agent-99") + + received: list[dict] = [] + + class FakeClient: + def __init__(self, config): + self.config = config + + def connect(self): + pass + + def disconnect(self): + pass + + def list_tools(self): + return [ + SimpleNamespace( + name="ctx_tool", + description="context tool", + input_schema={ + "type": "object", + "properties": { + "workspace_id": {"type": "string"}, + "agent_id": {"type": "string"}, + }, + "required": [], + }, + ) + ] + + def call_tool(self, tool_name, arguments): + received.append(dict(arguments)) + return {"result": "ok"} + + monkeypatch.setattr("framework.runner.mcp_client.MCPClient", FakeClient) + + registry.register_mcp_server( + {"name": "ctx-server", "transport": "stdio", "command": "echo"}, + use_connection_manager=False, + ) + + tool_use = ToolUse(id="c1", name="ctx_tool", input={}) + executor = registry.get_executor() + executor(tool_use) + + assert received, "call_tool was never called" + assert received[0].get("workspace_id") == "ws-123" + assert received[0].get("agent_id") == "agent-99" + + +# --------------------------------------------------------------------------- +# Execution context (contextvars isolation) +# --------------------------------------------------------------------------- + + +def test_execution_context_overrides_session_context(monkeypatch): + """Execution context values should win over session context for the same key.""" + registry = ToolRegistry() + registry.set_session_context(workspace_id="session-ws") + received: list[dict] = [] + + class FakeClient: + def __init__(self, config): + self.config = config + + def connect(self): + pass + + def disconnect(self): + pass + + def list_tools(self): + return [ + SimpleNamespace( + name="exec_tool", + description="execution context tool", + input_schema={ + "type": "object", + "properties": {"workspace_id": {"type": "string"}}, + "required": [], + }, + ) + ] + + def call_tool(self, tool_name, arguments): + received.append(dict(arguments)) + return {"result": "ok"} + + monkeypatch.setattr("framework.runner.mcp_client.MCPClient", FakeClient) + registry.register_mcp_server( + {"name": "exec-server", "transport": "stdio", "command": "echo"}, + use_connection_manager=False, + ) + + token = ToolRegistry.set_execution_context(workspace_id="exec-ws") + try: + tool_use = ToolUse(id="c2", name="exec_tool", input={}) + executor = registry.get_executor() + executor(tool_use) + finally: + ToolRegistry.reset_execution_context(token) + + assert received, "call_tool was never called" + assert received[0]["workspace_id"] == "exec-ws" + + +# --------------------------------------------------------------------------- +# _convert_mcp_tool_to_framework_tool — CONTEXT_PARAMS stripped +# --------------------------------------------------------------------------- + + +def test_convert_mcp_tool_strips_context_params(): + """CONTEXT_PARAMS should be removed from the LLM-facing tool schema.""" + registry = ToolRegistry() + mcp_tool = SimpleNamespace( + name="some_tool", + description="a tool", + input_schema={ + "type": "object", + "properties": { + "workspace_id": {"type": "string"}, # context param → stripped + "agent_id": {"type": "string"}, # context param → stripped + "query": {"type": "string"}, # regular param → kept + }, + "required": ["workspace_id", "query"], + }, + ) + tool = registry._convert_mcp_tool_to_framework_tool(mcp_tool) # noqa: SLF001 + props = tool.parameters["properties"] + assert "workspace_id" not in props + assert "agent_id" not in props + assert "query" in props + # workspace_id should also be stripped from required + assert "workspace_id" not in tool.parameters["required"] + assert "query" in tool.parameters["required"] + + +# --------------------------------------------------------------------------- +# load_mcp_config — both JSON config formats +# --------------------------------------------------------------------------- + + +def test_load_mcp_config_list_format(tmp_path, monkeypatch): + """load_mcp_config should accept the {\"servers\": [...]} list format.""" + config_file = tmp_path / "mcp_servers.json" + config_file.write_text( + '{"servers": [{"name": "s1", "transport": "http", "url": "http://localhost:9000"}]}' + ) + + called_with = [] + + def fake_load_registry(server_list, **kwargs): + called_with.extend(server_list) + return [] + + registry = ToolRegistry() + monkeypatch.setattr(registry, "load_registry_servers", fake_load_registry) + registry.load_mcp_config(config_file) + + assert len(called_with) == 1 + assert called_with[0]["name"] == "s1" + + +def test_load_mcp_config_dict_format(tmp_path, monkeypatch): + """load_mcp_config should accept the {\"server-name\": {...}} dict format.""" + config_file = tmp_path / "mcp_servers.json" + config_file.write_text('{"my-server": {"transport": "http", "url": "http://localhost:9001"}}') + + called_with = [] + + def fake_load_registry(server_list, **kwargs): + called_with.extend(server_list) + return [] + + registry = ToolRegistry() + monkeypatch.setattr(registry, "load_registry_servers", fake_load_registry) + registry.load_mcp_config(config_file) + + assert len(called_with) == 1 + assert called_with[0]["name"] == "my-server" + + +def test_load_mcp_config_handles_invalid_json(tmp_path, caplog): + """load_mcp_config should log a warning and return gracefully on bad JSON.""" + bad_file = tmp_path / "bad.json" + bad_file.write_text("{not valid json") + + registry = ToolRegistry() + with caplog.at_level(logging.WARNING): + registry.load_mcp_config(bad_file) + + assert any("Failed to load MCP config" in r.message for r in caplog.records) + + +# --------------------------------------------------------------------------- +# resync_mcp_servers_if_needed — no-op when nothing changed +# --------------------------------------------------------------------------- + + +def test_resync_returns_false_when_no_clients(): + """resync_mcp_servers_if_needed should return False immediately with no clients.""" + registry = ToolRegistry() + assert registry.resync_mcp_servers_if_needed() is False + + +def test_resync_returns_false_when_credentials_unchanged(tmp_path, monkeypatch): + """Resync should return False when neither credentials nor ADEN_API_KEY changed.""" + config_file = tmp_path / "mcp_servers.json" + config_file.write_text('{"servers": []}') + + registry = ToolRegistry() + # Simulate that MCP was loaded (need at least one client and a config path) + registry._mcp_config_path = config_file # noqa: SLF001 + + class _FakeClient: + config = SimpleNamespace(name="stub") + + def disconnect(self): + pass + + registry._mcp_clients.append(_FakeClient()) # noqa: SLF001 + registry._mcp_cred_snapshot = set() # noqa: SLF001 + registry._mcp_aden_key_snapshot = None # noqa: SLF001 + + # No credentials on disk and env var not set → nothing changed + monkeypatch.delenv("ADEN_API_KEY", raising=False) + monkeypatch.setattr(registry, "_snapshot_credentials", lambda: set()) + + assert registry.resync_mcp_servers_if_needed() is False