Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ms_agent/agent/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(self,
self.mcp_server_file = kwargs.get('mcp_server_file', None)
self.mcp_config: Dict[str, Any] = self._parse_mcp_servers(
kwargs.get('mcp_config', {}))
self.mcp_client = kwargs.get('mcp_client', None)
self._task_begin()

def register_callback(self, callback: Callback):
Expand Down Expand Up @@ -187,7 +188,8 @@ async def _parallel_tool_call(self,

async def _prepare_tools(self):
"""Initialize and connect the tool manager."""
self.tool_manager = ToolManager(self.config, self.mcp_config)
self.tool_manager = ToolManager(self.config, self.mcp_config,
self.mcp_client)
await self.tool_manager.connect()

async def _cleanup_tools(self):
Expand Down
64 changes: 56 additions & 8 deletions ms_agent/tools/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os
from contextlib import AsyncExitStack
from datetime import timedelta
from os import environb
from types import TracebackType
from typing import Any, Dict, Literal, Optional

from mcp import ClientSession, ListToolsResult, StdioServerParameters
Expand All @@ -13,7 +15,7 @@
from ms_agent.llm.utils import Tool
from ms_agent.tools.base import ToolBase
from ms_agent.utils import enhance_error, get_logger
from omegaconf import DictConfig
from omegaconf import DictConfig, ListConfig

logger = get_logger()

Expand All @@ -40,17 +42,23 @@ class MCPClient(ToolBase):
mcp_config(`Optional[Dict[str, Any]]`): Extra mcp servers in json format.
"""

def __init__(self,
config: DictConfig,
mcp_config: Optional[Dict[str, Any]] = None):
def __init__(
self,
mcp_config: Optional[Dict[str, Any]] = None,
config: Optional[DictConfig] = None,
):
super().__init__(config)
self.sessions: Dict[str, ClientSession] = {}
self.exit_stack = AsyncExitStack()
self.mcp_config: Dict[str, Dict[
str, Any]] = Config.convert_mcp_servers_to_json(config)
self.mcp_config: Dict[str, Dict[str, Any]] = {'mcpServers': {}}
if config is not None:
config_from_file = Config.convert_mcp_servers_to_json(config)
self.mcp_config['mcpServers'].update(
config_from_file.get('mcpServers', {}))
self._exclude_functions = {}
if mcp_config is not None:
self.mcp_config.update(mcp_config)
self.mcp_config['mcpServers'].update(
mcp_config.get('mcpServers', {}))

async def call_tool(self, server_name: str, tool_name: str,
tool_args: dict):
Expand Down Expand Up @@ -114,7 +122,9 @@ def print_tools(server_name: str, tools: ListToolsResult):
logger.info(f'\nConnected to server "{server_name}" '
f'with tools: \n{sep.join(tools)}.')

async def connect_to_server(self, server_name: str, timeout: int,
async def connect_to_server(self,
server_name: str,
timeout: int = CONNECTION_TIMEOUT,
**kwargs):
logger.info(f'connect to {server_name}')
# transport: stdio, sse, streamable_http, websocket
Expand Down Expand Up @@ -237,6 +247,44 @@ async def connect(self, timeout: int = CONNECTION_TIMEOUT):
new_eg = enhance_error(e, f'Connect `{name}` failed, details:')
raise new_eg from e

async def add_mcp_config(self, mcp_config: Dict[str, Dict[str, Any]]):
if mcp_config is None:
return
new_mcp_config = mcp_config.get('mcpServers', {})
servers = self.mcp_config.setdefault('mcpServers', {})
envs = Env.load_env()
for name, server in new_mcp_config.items():
if name in servers and servers[name] == server:
continue
else:
servers[name] = server
env_dict = server.pop('env', {})
env_dict = {
key: value if value else envs.get(key, '')
for key, value in env_dict.items()
}
if 'exclude' in server:
self._exclude_functions[name] = server.pop('exclude')
await self.connect_to_server(
server_name=name, env=env_dict, **server)
self.mcp_config['mcpServers'].update(new_mcp_config)

async def cleanup(self):
"""Clean up resources"""
await self.exit_stack.aclose()

async def __aenter__(self) -> 'MCPClient':
try:
await self.connect()
return self
except Exception:
await self.exit_stack.aclose()
raise

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.exit_stack.aclose()
37 changes: 33 additions & 4 deletions ms_agent/tools/tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import asyncio
import os
from copy import copy
from types import TracebackType
from typing import Any, Dict, List, Optional

import json
Expand All @@ -21,9 +22,12 @@ class ToolManager:

TOOL_SPLITER = '---'

def __init__(self, config, mcp_config: Optional[Dict[str, Any]] = None):
def __init__(self,
config,
mcp_config: Optional[Dict[str, Any]] = None,
mcp_client: Optional[MCPClient] = None):
self.config = config
self.servers = MCPClient(config, mcp_config)

self.extra_tools: List[ToolBase] = []
self.has_split_task_tool = False
if hasattr(config, 'tools') and hasattr(config.tools, 'split_task'):
Expand All @@ -34,17 +38,30 @@ def __init__(self, config, mcp_config: Optional[Dict[str, Any]] = None):
TOOL_CALL_TIMEOUT)
self._tool_index = {}

# Used temporarily during async initialization; the actual client is managed in self.servers
self.mcp_client = mcp_client
self.mcp_config = mcp_config
self._managed_client = mcp_client is None

def register_tool(self, tool: ToolBase):
self.extra_tools.append(tool)

async def connect(self):
await self.servers.connect()
if self.mcp_client and isinstance(self.mcp_client, MCPClient):
self.servers = self.mcp_client
await self.servers.add_mcp_config(self.mcp_config)
self.mcp_config = self.servers.mcp_config
else:
self.servers = MCPClient(self.mcp_config, self.config)
await self.servers.connect()
for tool in self.extra_tools:
await tool.connect()
await self.reindex_tool()

async def cleanup(self):
await self.servers.cleanup()
if self._managed_client and self.servers:
await self.servers.cleanup()
self.servers = None
for tool in self.extra_tools:
await tool.cleanup()

Expand Down Expand Up @@ -101,3 +118,15 @@ async def parallel_call_tool(self, tool_list: List[ToolCall]):
tasks = [self.single_call_tool(tool) for tool in tool_list]
result = await asyncio.gather(*tasks)
return result

async def __aenter__(self) -> 'ToolManager':

return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
pass
82 changes: 82 additions & 0 deletions tests/tools/test_mcp_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import asyncio
import os
import unittest

from ms_agent.tools.mcp_client import MCPClient

from modelscope.utils.test_utils import test_level


class TestMCPClient(unittest.TestCase):
mcp_config = {
'mcpServers': {
'fetch': {
'type': 'sse',
'url': os.getenv('MCP_SERVER_FETCH_URL'),
}
}
}
mcp_config2 = {
'mcpServers': {
'time': {
'type': 'sse',
'url': os.getenv('MCP_SERVER_TIME_URL'),
}
}
}

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_outside_init(self):

async def main():
async with MCPClient(self.mcp_config) as mcp_client:
mcps = await mcp_client.get_tools()
assert ('fetch' in mcps)

res = await mcp_client.call_tool(
server_name='fetch',
tool_name='fetch',
tool_args={'url': 'http://www.baidu.com'})
assert ('baidu' in res)

asyncio.run(main())

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_aenter(self):

async def main():
mcp_client = MCPClient(self.mcp_config)
await mcp_client.__aenter__()
mcps = await mcp_client.get_tools()
assert ('fetch' in mcps)
await mcp_client.__aexit__(None, None, None)

asyncio.run(main())

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_normal_connect(self):

async def main():
mcp_client = MCPClient(self.mcp_config)
await mcp_client.connect()
mcps = await mcp_client.get_tools()
assert ('fetch' in mcps)
await mcp_client.cleanup()

asyncio.run(main())

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_add_config(self):

async def main():
async with MCPClient(self.mcp_config) as mcp_client:
await mcp_client.add_mcp_config(self.mcp_config2)
mcps = await mcp_client.get_tools()
assert ('fetch' in mcps and 'time' in mcps)

asyncio.run(main())


if __name__ == '__main__':
unittest.main()
Loading