diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 1c325b5bb..c921c93b4 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -72,6 +72,7 @@ def __init__(self, self.mcp_server_file = kwargs.get('mcp_server_file', None) self.mcp_config: Dict[str, Any] = self._parse_mcp_servers( kwargs.get('mcp_config', {})) + self.mcp_client = kwargs.get('mcp_client', None) self._task_begin() def register_callback(self, callback: Callback): @@ -187,7 +188,8 @@ async def _parallel_tool_call(self, async def _prepare_tools(self): """Initialize and connect the tool manager.""" - self.tool_manager = ToolManager(self.config, self.mcp_config) + self.tool_manager = ToolManager(self.config, self.mcp_config, + self.mcp_client) await self.tool_manager.connect() async def _cleanup_tools(self): diff --git a/ms_agent/tools/mcp_client.py b/ms_agent/tools/mcp_client.py index a28a52432..332f6da4e 100644 --- a/ms_agent/tools/mcp_client.py +++ b/ms_agent/tools/mcp_client.py @@ -3,6 +3,8 @@ import os from contextlib import AsyncExitStack from datetime import timedelta +from os import environb +from types import TracebackType from typing import Any, Dict, Literal, Optional from mcp import ClientSession, ListToolsResult, StdioServerParameters @@ -13,7 +15,7 @@ from ms_agent.llm.utils import Tool from ms_agent.tools.base import ToolBase from ms_agent.utils import enhance_error, get_logger -from omegaconf import DictConfig +from omegaconf import DictConfig, ListConfig logger = get_logger() @@ -40,17 +42,23 @@ class MCPClient(ToolBase): mcp_config(`Optional[Dict[str, Any]]`): Extra mcp servers in json format. """ - def __init__(self, - config: DictConfig, - mcp_config: Optional[Dict[str, Any]] = None): + def __init__( + self, + mcp_config: Optional[Dict[str, Any]] = None, + config: Optional[DictConfig] = None, + ): super().__init__(config) self.sessions: Dict[str, ClientSession] = {} self.exit_stack = AsyncExitStack() - self.mcp_config: Dict[str, Dict[ - str, Any]] = Config.convert_mcp_servers_to_json(config) + self.mcp_config: Dict[str, Dict[str, Any]] = {'mcpServers': {}} + if config is not None: + config_from_file = Config.convert_mcp_servers_to_json(config) + self.mcp_config['mcpServers'].update( + config_from_file.get('mcpServers', {})) self._exclude_functions = {} if mcp_config is not None: - self.mcp_config.update(mcp_config) + self.mcp_config['mcpServers'].update( + mcp_config.get('mcpServers', {})) async def call_tool(self, server_name: str, tool_name: str, tool_args: dict): @@ -114,7 +122,9 @@ def print_tools(server_name: str, tools: ListToolsResult): logger.info(f'\nConnected to server "{server_name}" ' f'with tools: \n{sep.join(tools)}.') - async def connect_to_server(self, server_name: str, timeout: int, + async def connect_to_server(self, + server_name: str, + timeout: int = CONNECTION_TIMEOUT, **kwargs): logger.info(f'connect to {server_name}') # transport: stdio, sse, streamable_http, websocket @@ -237,6 +247,44 @@ async def connect(self, timeout: int = CONNECTION_TIMEOUT): new_eg = enhance_error(e, f'Connect `{name}` failed, details:') raise new_eg from e + async def add_mcp_config(self, mcp_config: Dict[str, Dict[str, Any]]): + if mcp_config is None: + return + new_mcp_config = mcp_config.get('mcpServers', {}) + servers = self.mcp_config.setdefault('mcpServers', {}) + envs = Env.load_env() + for name, server in new_mcp_config.items(): + if name in servers and servers[name] == server: + continue + else: + servers[name] = server + env_dict = server.pop('env', {}) + env_dict = { + key: value if value else envs.get(key, '') + for key, value in env_dict.items() + } + if 'exclude' in server: + self._exclude_functions[name] = server.pop('exclude') + await self.connect_to_server( + server_name=name, env=env_dict, **server) + self.mcp_config['mcpServers'].update(new_mcp_config) + async def cleanup(self): """Clean up resources""" await self.exit_stack.aclose() + + async def __aenter__(self) -> 'MCPClient': + try: + await self.connect() + return self + except Exception: + await self.exit_stack.aclose() + raise + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.exit_stack.aclose() diff --git a/ms_agent/tools/tool_manager.py b/ms_agent/tools/tool_manager.py index 1c8973f37..9314042a9 100644 --- a/ms_agent/tools/tool_manager.py +++ b/ms_agent/tools/tool_manager.py @@ -2,6 +2,7 @@ import asyncio import os from copy import copy +from types import TracebackType from typing import Any, Dict, List, Optional import json @@ -21,9 +22,12 @@ class ToolManager: TOOL_SPLITER = '---' - def __init__(self, config, mcp_config: Optional[Dict[str, Any]] = None): + def __init__(self, + config, + mcp_config: Optional[Dict[str, Any]] = None, + mcp_client: Optional[MCPClient] = None): self.config = config - self.servers = MCPClient(config, mcp_config) + self.extra_tools: List[ToolBase] = [] self.has_split_task_tool = False if hasattr(config, 'tools') and hasattr(config.tools, 'split_task'): @@ -34,17 +38,30 @@ def __init__(self, config, mcp_config: Optional[Dict[str, Any]] = None): TOOL_CALL_TIMEOUT) self._tool_index = {} + # Used temporarily during async initialization; the actual client is managed in self.servers + self.mcp_client = mcp_client + self.mcp_config = mcp_config + self._managed_client = mcp_client is None + def register_tool(self, tool: ToolBase): self.extra_tools.append(tool) async def connect(self): - await self.servers.connect() + if self.mcp_client and isinstance(self.mcp_client, MCPClient): + self.servers = self.mcp_client + await self.servers.add_mcp_config(self.mcp_config) + self.mcp_config = self.servers.mcp_config + else: + self.servers = MCPClient(self.mcp_config, self.config) + await self.servers.connect() for tool in self.extra_tools: await tool.connect() await self.reindex_tool() async def cleanup(self): - await self.servers.cleanup() + if self._managed_client and self.servers: + await self.servers.cleanup() + self.servers = None for tool in self.extra_tools: await tool.cleanup() @@ -101,3 +118,15 @@ async def parallel_call_tool(self, tool_list: List[ToolCall]): tasks = [self.single_call_tool(tool) for tool in tool_list] result = await asyncio.gather(*tasks) return result + + async def __aenter__(self) -> 'ToolManager': + + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + pass diff --git a/tests/tools/test_mcp_client.py b/tests/tools/test_mcp_client.py new file mode 100644 index 000000000..4a987f5a9 --- /dev/null +++ b/tests/tools/test_mcp_client.py @@ -0,0 +1,82 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import asyncio +import os +import unittest + +from ms_agent.tools.mcp_client import MCPClient + +from modelscope.utils.test_utils import test_level + + +class TestMCPClient(unittest.TestCase): + mcp_config = { + 'mcpServers': { + 'fetch': { + 'type': 'sse', + 'url': os.getenv('MCP_SERVER_FETCH_URL'), + } + } + } + mcp_config2 = { + 'mcpServers': { + 'time': { + 'type': 'sse', + 'url': os.getenv('MCP_SERVER_TIME_URL'), + } + } + } + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_outside_init(self): + + async def main(): + async with MCPClient(self.mcp_config) as mcp_client: + mcps = await mcp_client.get_tools() + assert ('fetch' in mcps) + + res = await mcp_client.call_tool( + server_name='fetch', + tool_name='fetch', + tool_args={'url': 'http://www.baidu.com'}) + assert ('baidu' in res) + + asyncio.run(main()) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_aenter(self): + + async def main(): + mcp_client = MCPClient(self.mcp_config) + await mcp_client.__aenter__() + mcps = await mcp_client.get_tools() + assert ('fetch' in mcps) + await mcp_client.__aexit__(None, None, None) + + asyncio.run(main()) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_normal_connect(self): + + async def main(): + mcp_client = MCPClient(self.mcp_config) + await mcp_client.connect() + mcps = await mcp_client.get_tools() + assert ('fetch' in mcps) + await mcp_client.cleanup() + + asyncio.run(main()) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_add_config(self): + + async def main(): + async with MCPClient(self.mcp_config) as mcp_client: + await mcp_client.add_mcp_config(self.mcp_config2) + mcps = await mcp_client.get_tools() + assert ('fetch' in mcps and 'time' in mcps) + + asyncio.run(main()) + + +if __name__ == '__main__': + unittest.main()