diff --git a/ms_agent/agent/base.py b/ms_agent/agent/base.py
index cbb9db911..421d30be4 100644
--- a/ms_agent/agent/base.py
+++ b/ms_agent/agent/base.py
@@ -43,7 +43,7 @@ def __init__(self,
                  trust_remote_code: bool = False):
         if config_dir_or_id is not None:
             self.config: DictConfig = Config.from_task(config_dir_or_id, env)
-        elif config is not None:
+        elif config is not None and isinstance(config, DictConfig):
             self.config: DictConfig = config
         else:
             self.config: DictConfig = Config.from_task(DEFAULT_YAML)
@@ -52,7 +52,7 @@ def __init__(self,
             self.tag = getattr(config, 'tag', None) or self.DEFAULT_TAG
         else:
             self.tag = tag
-        self.config.tag = self.tag
+        setattr(self.config, 'tag', self.tag)
         self.trust_remote_code = trust_remote_code
         self.config.trust_remote_code = trust_remote_code
         self.handler: Optional[ConfigLifecycleHandler] = None
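Note (not part of the diff): the new isinstance guard means a plain dict passed as config no longer reaches self.config directly; it falls through to the default YAML instead. A minimal sketch of the patched branch, where resolve_config and the default stand-in are hypothetical names:

    from omegaconf import DictConfig, OmegaConf

    def resolve_config(config):
        # Mirrors the patched elif: only DictConfig instances are accepted.
        if config is not None and isinstance(config, DictConfig):
            return config
        # Everything else (including a plain dict) falls back to the default
        # task config, standing in here for Config.from_task(DEFAULT_YAML).
        return OmegaConf.create({'tag': 'default'})

    assert resolve_config(OmegaConf.create({'tag': 'demo'})).tag == 'demo'
    assert resolve_config({'tag': 'demo'}).tag == 'default'  # plain dict is ignored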
+ """ + formatted_messages = [] + for msg in messages: + content = [] + + if msg.content: + content.append({'type': 'text', 'text': msg.content}) + + if msg.tool_calls: + for tool_call in msg.tool_calls: + content.append({ + 'type': 'tool_use', + 'id': tool_call['id'], + 'name': tool_call['tool_name'], + 'input': tool_call.get('arguments', {}) + }) + + if msg.role == 'tool': + formatted_messages.append({ + 'role': + 'user', + 'content': [{ + 'type': 'tool_result', + 'tool_use_id': msg.tool_call_id, + 'content': msg.content + }] + }) + continue + + formatted_messages.append({'role': msg.role, 'content': content}) + return formatted_messages + + def _call_llm(self, + messages: List[Message], + tools: Optional[List[Dict]] = None, + stream: bool = False, + **kwargs) -> Any: + + formatted_messages = self._format_input_message(messages) + formatted_messages = [m for m in formatted_messages if m['content']] + + system = None + if formatted_messages[0]['role'] == 'system': + system = formatted_messages[0]['content'] + formatted_messages = formatted_messages[1:] + params = { + 'model': self.model, + 'messages': formatted_messages, + 'max_tokens': kwargs.pop('max_tokens', 1024), + } + + if system: + params['system'] = system + if tools: + params['tools'] = tools + params.update(kwargs) + + if stream: + return self.client.messages.stream(**params) + else: + return self.client.messages.create(**params) + + @retry(max_attempts=3, delay=1.0) + def generate(self, + messages: List[Message], + tools: Optional[List[Tool]] = None, + max_continue_runs: Optional[int] = None, + **kwargs) -> Union[Message, Generator[Message, None, None]]: + + formatted_tools = self.format_tools(tools) + args = self.args.copy() + args.update(kwargs) + stream = args.pop('stream', False) + + sig_params = inspect.signature(self.client.messages.create).parameters + filtered_args = {k: v for k, v in args.items() if k in sig_params} + + completion = self._call_llm(messages, formatted_tools, stream, + **filtered_args) + + if stream: + return self._stream_format_output_message(completion) + else: + return self._format_output_message(completion) + + def _stream_format_output_message(self, + stream_manager) -> Iterator[Message]: + current_message = Message( + role='assistant', + content='', + tool_calls=[], + id='', + completion_tokens=0, + prompt_tokens=0, + api_calls=1, + partial=True, + ) + tool_call_id_map = {} # index -> tool_call_id (用于去重 yield) + with stream_manager as stream: + for event in stream: + event_type = getattr(event, 'type') + if event_type == 'message_start': + msg = event.message + current_message.id = msg.id + tool_call_id_map = {} + yield current_message + elif event_type == 'text': + current_message.content = event.snapshot + yield current_message + elif event_type == 'message_stop': + final_msg = getattr(event, 'message') + full_content = '' + used_tool_call_ids = set() + for idx, block in enumerate(event.message.content): + if block is None: + continue + if block.type == 'text': + full_content += block.text + elif block.type == 'tool_use': + tool_call_id = tool_call_id_map.get(idx) + tool_call = ToolCall( + id=tool_call_id, + index=len(current_message.tool_calls), + type='function', + tool_name=block.name, + arguments=json5.dumps(block.input), + ) + current_message.tool_calls.append(tool_call) + used_tool_call_ids.add(tool_call_id) + current_message.content = full_content + current_message.partial = False + current_message.completion_tokens = getattr( + final_msg.usage, 'output_tokens', + 
diff --git a/ms_agent/llm/model_mapping.py b/ms_agent/llm/model_mapping.py
index e152591c5..97bba8a3b 100644
--- a/ms_agent/llm/model_mapping.py
+++ b/ms_agent/llm/model_mapping.py
@@ -1,8 +1,10 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+from ms_agent.llm.anthropic_llm import Anthropic
 from ms_agent.llm.modelscope_llm import ModelScope
 from ms_agent.llm.openai_llm import OpenAI
 
 all_services_mapping = {
     'modelscope': ModelScope,
     'openai': OpenAI,
+    'anthropic': Anthropic,
 }
diff --git a/ms_agent/llm/openai_llm.py b/ms_agent/llm/openai_llm.py
index 4a483af69..20a33cf1e 100644
--- a/ms_agent/llm/openai_llm.py
+++ b/ms_agent/llm/openai_llm.py
@@ -5,15 +5,14 @@
 
 from ms_agent.llm import LLM
 from ms_agent.llm.utils import Message, Tool, ToolCall
-from ms_agent.utils import assert_package_exist, get_logger, retry
+from ms_agent.utils import (MAX_CONTINUE_RUNS, assert_package_exist,
+                            get_logger, retry)
 from omegaconf import DictConfig, OmegaConf
 from openai.types.chat.chat_completion_message_tool_call import (
     ChatCompletionMessageToolCall, Function)
 
 logger = get_logger()
 
-MAX_CONTINUE_RUNS = 3
-
 
 class OpenAI(LLM):
     """Base Class for OpenAI SDK LLMs.
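With the new mapping entry, a config that sets llm.service to 'anthropic' can be routed to the new backend (the test file below sets exactly this key). A sketch of the lookup, assuming the caller dispatches on config.llm.service and that MODELSCOPE_API_KEY is exported:

    import os

    from ms_agent.llm.model_mapping import all_services_mapping
    from omegaconf import OmegaConf

    config = OmegaConf.create({
        'llm': {
            'service': 'anthropic',
            'model': 'Qwen/Qwen2.5-VL-72B-Instruct',
            'anthropic_api_key': os.getenv('MODELSCOPE_API_KEY'),
            'anthropic_base_url': 'https://api-inference.modelscope.cn',
        }
    })
    llm_cls = all_services_mapping[config.llm.service]  # -> Anthropic
    llm = llm_cls(config)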
diff --git a/ms_agent/utils/__init__.py b/ms_agent/utils/__init__.py
index 927fac93c..648137072 100644
--- a/ms_agent/utils/__init__.py
+++ b/ms_agent/utils/__init__.py
@@ -2,3 +2,5 @@
 from .llm_utils import async_retry, retry
 from .logger import get_logger
 from .utils import assert_package_exist, enhance_error, strtobool
+
+MAX_CONTINUE_RUNS = 3
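MAX_CONTINUE_RUNS now lives in ms_agent.utils, so any backend can import the same continuation cap instead of the module-level copy openai_llm.py used to define:

    from ms_agent.utils import MAX_CONTINUE_RUNS

    assert MAX_CONTINUE_RUNS == 3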
diff --git a/tests/llm/test_anthropic.py b/tests/llm/test_anthropic.py
new file mode 100644
index 000000000..61f0372fe
--- /dev/null
+++ b/tests/llm/test_anthropic.py
@@ -0,0 +1,150 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import unittest
+
+from ms_agent.agent.llm_agent import LLMAgent
+from ms_agent.llm.anthropic_llm import Anthropic
+from ms_agent.llm.utils import Message, Tool
+from omegaconf import DictConfig, OmegaConf
+
+from modelscope.utils.test_utils import test_level
+
+API_CALL_MAX_TOKEN = 50
+
+
+class AnthropicLLM(unittest.TestCase):
+    conf: DictConfig = OmegaConf.create({
+        'llm': {
+            'model': 'Qwen/Qwen2.5-VL-72B-Instruct',
+            'anthropic_api_key': os.getenv('MODELSCOPE_API_KEY'),
+            'anthropic_base_url': 'https://api-inference.modelscope.cn',
+            'service': 'anthropic'
+        },
+        'generation_config': {
+            'stream': False,
+            'extra_body': {
+                'enable_thinking': False
+            },
+            'max_tokens': API_CALL_MAX_TOKEN
+        }
+    })
+    messages = [
+        Message(role='assistant', content='You are a helpful assistant.'),
+        Message(role='user', content='浙江的省会是哪里?'),
+    ]
+    tool_messages = [
+        Message(role='assistant', content='You are a helpful assistant.'),
+        Message(role='user', content='经度:116.4074,纬度:39.9042是什么地方'),
+    ]
+    continue_messages = [
+        Message(role='assistant', content='You are a helpful assistant.'),
+        Message(role='user', content='写一篇介绍杭州的短文,200字左右。'),
+    ]
+
+    tools = [
+        Tool(
+            server_name='amap-maps',
+            tool_name='maps_regeocode',
+            description='将一个高德经纬度坐标转换为行政区划地址信息',
+            parameters={
+                'type': 'object',
+                'properties': {
+                    'location': {
+                        'type': 'string',
+                        'description': '经纬度'
+                    }
+                },
+                'required': ['location']
+            }),
+        Tool(
+            tool_name='mkdir',
+            description='在文件系统创建目录',
+            parameters={
+                'type': 'object',
+                'properties': {
+                    'dir_name': {
+                        'type': 'string',
+                        'description': '目录名'
+                    }
+                },
+                'required': ['dir_name']
+            })
+    ]
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_call_no_stream(self):
+        llm = Anthropic(self.conf)
+        res = llm.generate(messages=self.messages, tools=None)
+        print(res)
+        assert (res.content)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_call_stream(self):
+        llm = Anthropic(self.conf)
+        res = llm.generate(messages=self.messages, tools=None, stream=True)
+        for chunk in res:
+            print(chunk)
+        assert (len(chunk.content))
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_tool_stream(self):
+        llm = Anthropic(self.conf)
+        res = llm.generate(
+            messages=self.tool_messages, tools=self.tools, stream=True)
+        for chunk in res:
+            print(chunk)
+        assert (len(chunk.tool_calls))
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_tool_no_stream(self):
+        llm = Anthropic(self.conf)
+        res = llm.generate(messages=self.tool_messages, tools=self.tools)
+        print(res)
+        assert (len(res.tool_calls))
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_agent_multi_round(self):
+        import asyncio
+
+        async def main():
+            mcp_config = {
+                'mcpServers': {
+                    'fetch': {
+                        'type': 'sse',
+                        'url': os.getenv('MCP_SERVER_FETCH_URL'),
+                    }
+                }
+            }
+            agent = LLMAgent(config=self.conf, mcp_config=mcp_config)
+            res = await agent.run('访问www.baidu.com')
+            print(res)
+            assert ('robots.txt' in res[-1].content)
+
+        asyncio.run(main())
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_stream_agent_multi_round(self):
+        import asyncio
+        from copy import deepcopy
+
+        async def main():
+            mcp_config = {
+                'mcpServers': {
+                    'fetch': {
+                        'type': 'sse',
+                        'url': os.getenv('MCP_SERVER_FETCH_URL'),
+                    }
+                }
+            }
+            conf2 = deepcopy(self.conf)
+            conf2.generation_config.stream = True
+            agent = LLMAgent(config=conf2, mcp_config=mcp_config)
+            res = await agent.run('访问www.baidu.com')
+            print('res:', res)
+            assert ('robots.txt' in res[-1].content)
+
+        asyncio.run(main())
+
+
+if __name__ == '__main__':
+    unittest.main()
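A quick non-streaming usage sketch for the new backend, mirroring the test fixture above (assumes MODELSCOPE_API_KEY is exported):

    import os

    from ms_agent.llm.anthropic_llm import Anthropic
    from ms_agent.llm.utils import Message
    from omegaconf import OmegaConf

    conf = OmegaConf.create({
        'llm': {
            'model': 'Qwen/Qwen2.5-VL-72B-Instruct',
            'anthropic_api_key': os.getenv('MODELSCOPE_API_KEY'),
            'anthropic_base_url': 'https://api-inference.modelscope.cn',
        },
        'generation_config': {'stream': False, 'max_tokens': 128},
    })
    llm = Anthropic(conf)
    reply = llm.generate(
        [Message(role='user', content='Which city is the capital of Zhejiang?')])
    # With stream=False generate returns a single Message; with stream=True it
    # returns an iterator of partial Messages instead.
    print(reply.content)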