Skip to content

Commit e97117e

Browse files
suluyanasuluyan
andauthored
feat: anthropic llm (#722)
Co-authored-by: suluyan <[email protected]>
1 parent b8a582f commit e97117e

File tree

6 files changed

+446
-5
lines changed

6 files changed

+446
-5
lines changed

ms_agent/agent/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(self,
4343
trust_remote_code: bool = False):
4444
if config_dir_or_id is not None:
4545
self.config: DictConfig = Config.from_task(config_dir_or_id, env)
46-
elif config is not None:
46+
elif config is not None and isinstance(config, DictConfig):
4747
self.config: DictConfig = config
4848
else:
4949
self.config: DictConfig = Config.from_task(DEFAULT_YAML)
@@ -52,7 +52,7 @@ def __init__(self,
5252
self.tag = getattr(config, 'tag', None) or self.DEFAULT_TAG
5353
else:
5454
self.tag = tag
55-
self.config.tag = self.tag
55+
setattr(self.config, 'tag', self.tag)
5656
self.trust_remote_code = trust_remote_code
5757
self.config.trust_remote_code = trust_remote_code
5858
self.handler: Optional[ConfigLifecycleHandler] = None

ms_agent/llm/anthropic_llm.py

Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
import inspect
2+
from typing import Any, Dict, Generator, Iterator, List, Optional, Union
3+
4+
import json5
5+
from ms_agent.llm import LLM
6+
from ms_agent.llm.utils import Message, Tool, ToolCall
7+
from ms_agent.utils import assert_package_exist, get_logger, retry
8+
from omegaconf import DictConfig, OmegaConf
9+
10+
11+
class Anthropic(LLM):
12+
13+
def __init__(
14+
self,
15+
config: DictConfig,
16+
base_url: Optional[str] = None,
17+
api_key: Optional[str] = None,
18+
):
19+
super().__init__(config)
20+
assert_package_exist('anthropic', 'anthropic')
21+
import anthropic
22+
23+
self.model: str = config.llm.model
24+
25+
base_url = base_url or config.llm.get('anthropic_base_url')
26+
api_key = api_key or config.llm.get('anthropic_api_key')
27+
28+
if not api_key:
29+
raise ValueError('Anthropic API key is required.')
30+
31+
self.client = anthropic.Anthropic(
32+
api_key=api_key,
33+
base_url=base_url,
34+
)
35+
36+
self.args: Dict = OmegaConf.to_container(
37+
getattr(config, 'generation_config', DictConfig({})))
38+
39+
def format_tools(self,
40+
tools: Optional[List[Tool]]) -> Optional[List[Dict]]:
41+
if not tools:
42+
return None
43+
44+
formatted_tools = []
45+
for tool in tools:
46+
formatted_tools.append({
47+
'name': tool['tool_name'],
48+
'description': tool.get('description', ''),
49+
'input_schema': {
50+
'type': 'object',
51+
'properties': tool.get('parameters',
52+
{}).get('properties', {}),
53+
'required': tool.get('parameters', {}).get('required', []),
54+
}
55+
})
56+
return formatted_tools
57+
58+
def _format_input_message(self,
59+
messages: List[Message]) -> List[Dict[str, Any]]:
60+
"""Converts a list of Message objects into the format expected by the Anthropic API.
61+
62+
Args:
63+
messages (`List[Message]`): List of Message objects.
64+
65+
Returns:
66+
List[Dict[str, Any]]: List of dictionaries compatible with Anthropic's input format.
67+
"""
68+
formatted_messages = []
69+
for msg in messages:
70+
content = []
71+
72+
if msg.content:
73+
content.append({'type': 'text', 'text': msg.content})
74+
75+
if msg.tool_calls:
76+
for tool_call in msg.tool_calls:
77+
content.append({
78+
'type': 'tool_use',
79+
'id': tool_call['id'],
80+
'name': tool_call['tool_name'],
81+
'input': tool_call.get('arguments', {})
82+
})
83+
84+
if msg.role == 'tool':
85+
formatted_messages.append({
86+
'role':
87+
'user',
88+
'content': [{
89+
'type': 'tool_result',
90+
'tool_use_id': msg.tool_call_id,
91+
'content': msg.content
92+
}]
93+
})
94+
continue
95+
96+
formatted_messages.append({'role': msg.role, 'content': content})
97+
return formatted_messages
98+
99+
def _call_llm(self,
100+
messages: List[Message],
101+
tools: Optional[List[Dict]] = None,
102+
stream: bool = False,
103+
**kwargs) -> Any:
104+
105+
formatted_messages = self._format_input_message(messages)
106+
formatted_messages = [m for m in formatted_messages if m['content']]
107+
108+
system = None
109+
if formatted_messages[0]['role'] == 'system':
110+
system = formatted_messages[0]['content']
111+
formatted_messages = formatted_messages[1:]
112+
params = {
113+
'model': self.model,
114+
'messages': formatted_messages,
115+
'max_tokens': kwargs.pop('max_tokens', 1024),
116+
}
117+
118+
if system:
119+
params['system'] = system
120+
if tools:
121+
params['tools'] = tools
122+
params.update(kwargs)
123+
124+
if stream:
125+
return self.client.messages.stream(**params)
126+
else:
127+
return self.client.messages.create(**params)
128+
129+
@retry(max_attempts=3, delay=1.0)
130+
def generate(self,
131+
messages: List[Message],
132+
tools: Optional[List[Tool]] = None,
133+
max_continue_runs: Optional[int] = None,
134+
**kwargs) -> Union[Message, Generator[Message, None, None]]:
135+
136+
formatted_tools = self.format_tools(tools)
137+
args = self.args.copy()
138+
args.update(kwargs)
139+
stream = args.pop('stream', False)
140+
141+
sig_params = inspect.signature(self.client.messages.create).parameters
142+
filtered_args = {k: v for k, v in args.items() if k in sig_params}
143+
144+
completion = self._call_llm(messages, formatted_tools, stream,
145+
**filtered_args)
146+
147+
if stream:
148+
return self._stream_format_output_message(completion)
149+
else:
150+
return self._format_output_message(completion)
151+
152+
def _stream_format_output_message(self,
153+
stream_manager) -> Iterator[Message]:
154+
current_message = Message(
155+
role='assistant',
156+
content='',
157+
tool_calls=[],
158+
id='',
159+
completion_tokens=0,
160+
prompt_tokens=0,
161+
api_calls=1,
162+
partial=True,
163+
)
164+
tool_call_id_map = {} # index -> tool_call_id (用于去重 yield)
165+
with stream_manager as stream:
166+
for event in stream:
167+
event_type = getattr(event, 'type')
168+
if event_type == 'message_start':
169+
msg = event.message
170+
current_message.id = msg.id
171+
tool_call_id_map = {}
172+
yield current_message
173+
elif event_type == 'text':
174+
current_message.content = event.snapshot
175+
yield current_message
176+
elif event_type == 'message_stop':
177+
final_msg = getattr(event, 'message')
178+
full_content = ''
179+
used_tool_call_ids = set()
180+
for idx, block in enumerate(event.message.content):
181+
if block is None:
182+
continue
183+
if block.type == 'text':
184+
full_content += block.text
185+
elif block.type == 'tool_use':
186+
tool_call_id = tool_call_id_map.get(idx)
187+
tool_call = ToolCall(
188+
id=tool_call_id,
189+
index=len(current_message.tool_calls),
190+
type='function',
191+
tool_name=block.name,
192+
arguments=json5.dumps(block.input),
193+
)
194+
current_message.tool_calls.append(tool_call)
195+
used_tool_call_ids.add(tool_call_id)
196+
current_message.content = full_content
197+
current_message.partial = False
198+
current_message.completion_tokens = getattr(
199+
final_msg.usage, 'output_tokens',
200+
current_message.completion_tokens)
201+
current_message.prompt_tokens = getattr(
202+
final_msg.usage, 'input_tokens',
203+
current_message.prompt_tokens)
204+
205+
yield current_message
206+
207+
@staticmethod
208+
def _format_output_message(completion) -> Message:
209+
"""
210+
Formats the full non-streaming response from Anthropic into a Message object.
211+
212+
Args:
213+
completion: The raw response from the Anthropic API (e.g., a Message object from anthropic SDK).
214+
215+
Returns:
216+
Message: A Message object containing the final response.
217+
"""
218+
# Extract text content
219+
content = ''
220+
tool_calls = []
221+
222+
# Anthropic responses have a list of content blocks
223+
for block in completion.content:
224+
if block.type == 'text':
225+
content += block.text
226+
elif block.type == 'tool_use':
227+
tool_calls.append(
228+
ToolCall(
229+
id=block.id,
230+
index=len(tool_calls), # index based on appearance
231+
type=
232+
'function', # or "tool_use" depending on your schema
233+
arguments=block.input,
234+
tool_name=block.name,
235+
))
236+
237+
# Anthropic does not have a native "reasoning_content" field
238+
reasoning_content = ''
239+
240+
return Message(
241+
role='assistant',
242+
content=content,
243+
reasoning_content=reasoning_content,
244+
tool_calls=tool_calls if tool_calls else None,
245+
id=completion.id,
246+
prompt_tokens=completion.usage.input_tokens,
247+
completion_tokens=completion.usage.output_tokens,
248+
)
249+
250+
251+
if __name__ == '__main__':
252+
import os
253+
config = {
254+
'llm': {
255+
'model': 'Qwen/Qwen2.5-VL-72B-Instruct',
256+
'anthropic_api_key': os.getenv('MODELSCOPE_API_KEY'),
257+
'anthropic_base_url': 'https://api-inference.modelscope.cn'
258+
},
259+
'generation_config': {
260+
'stream': True,
261+
}
262+
}
263+
tools = [{
264+
'tool_name': 'get_weather',
265+
'description': 'Get the current weather in a given location',
266+
'parameters': {
267+
'type': 'object',
268+
'properties': {
269+
'location': {
270+
'type': 'string',
271+
'description': 'City and state'
272+
},
273+
'unit': {
274+
'type': 'string',
275+
'enum': ['celsius', 'fahrenheit']
276+
}
277+
},
278+
'required': ['location']
279+
}
280+
}]
281+
282+
messages = [Message(role='user', content='描述杭州,300字')]
283+
# messages = [Message(role='user', content='去伦敦现在该带什么样的衣服?')]
284+
285+
llm = Anthropic(config=OmegaConf.create(config))
286+
result = llm.generate(messages, tools=tools)
287+
for chunk in result:
288+
print(chunk)

ms_agent/llm/model_mapping.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
from ms_agent.llm.anthropic_llm import Anthropic
23
from ms_agent.llm.modelscope_llm import ModelScope
34
from ms_agent.llm.openai_llm import OpenAI
45

56
all_services_mapping = {
67
'modelscope': ModelScope,
78
'openai': OpenAI,
9+
'anthropic': Anthropic,
810
}

ms_agent/llm/openai_llm.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@
55

66
from ms_agent.llm import LLM
77
from ms_agent.llm.utils import Message, Tool, ToolCall
8-
from ms_agent.utils import assert_package_exist, get_logger, retry
8+
from ms_agent.utils import (MAX_CONTINUE_RUNS, assert_package_exist,
9+
get_logger, retry)
910
from omegaconf import DictConfig, OmegaConf
1011
from openai.types.chat.chat_completion_message_tool_call import (
1112
ChatCompletionMessageToolCall, Function)
1213

1314
logger = get_logger()
1415

15-
MAX_CONTINUE_RUNS = 3
16-
1716

1817
class OpenAI(LLM):
1918
"""Base Class for OpenAI SDK LLMs.

ms_agent/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@
33
from .logger import get_logger
44
from .prompt import get_fact_retrieval_prompt
55
from .utils import assert_package_exist, enhance_error, strtobool
6+
7+
MAX_CONTINUE_RUNS = 3

0 commit comments

Comments
 (0)