Skip to content

Commit 33f6003

Browse files
authored
supported custom agent (#590)
* supported custom agent; * config for qwen3-32B * add sample file for examples
1 parent 007343c commit 33f6003

File tree

8 files changed

+188
-88
lines changed

8 files changed

+188
-88
lines changed

train/adapter/verl/agent_template.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from aworld.logs.util import logger
1212
from {parser_module} import {parser_name}
1313
14+
{agent_import_str}
15+
{tool_aggregate_func_import_str}
1416
from train.adapter.verl.aworld_agent_loop import AworldAgentLoop
1517
1618
@@ -28,20 +30,29 @@ async def build_agents(self) -> Union[Agent, Swarm]:
2830
"request_id": uuid.uuid4().hex,
2931
"tool_parser": "hermes"
3032
}},
31-
{kv_parameters}
33+
{model_kv_parameters}
3234
),
3335
)
3436
3537
logger.info(f"agent config: ", conf)
3638
mcp_config = {mcp_config}
37-
return Agent(
39+
return {real_agent}(
3840
conf=conf,
3941
name="{agent_name}",
4042
desc="{agent_desc}",
41-
system_prompt="{system_prompt}",
43+
system_prompt='''{system_prompt}''',
44+
tool_names={tool_names},
45+
agent_names={agent_names},
46+
wait_tool_result={wait_tool_result},
47+
feedback_tool_result={feedback_tool_result},
48+
black_tool_actions={black_tool_actions},
49+
skill_configs={skill_configs},
50+
event_handler_name={event_handler_name},
51+
tools_aggregate_func={tools_aggregate_func},
4252
mcp_config=mcp_config,
4353
mcp_servers=list(server_name for server_name in mcp_config.get("mcpServers", {{}}).keys()),
44-
model_output_parser={parser_name}()
54+
model_output_parser={parser_name}(),
55+
{extend_params}
4556
)
4657
4758
"""

train/adapter/verl/verl_trainer.py

Lines changed: 122 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import traceback
66
import yaml
77

8-
from typing import Callable, Union, Any
8+
from typing import Callable, Union, Tuple
99

1010
from datasets import Dataset
1111
from omegaconf import OmegaConf
@@ -14,6 +14,7 @@
1414
from aworld.agents.llm_agent import Agent
1515
from aworld.config import BaseConfig, ConfigDict, load_config
1616
from aworld.core.common import Config
17+
from aworld.logs.util import logger
1718
from train.adapter.verl.agent_template import VERL_TEMPLATE
1819
from train.trainer.trainer_processor import TrainerProcessor
1920

@@ -29,7 +30,9 @@ def train(self):
2930

3031
main(self.config)
3132

32-
def check_dataset(self, dataset: Union[str, Dataset], test_dataset: Union[str, Dataset] = None):
33+
def check_dataset(self, dataset: Union[str, Dataset], test_dataset: Union[str, Dataset] = None) -> Tuple[str, str]:
34+
logger.info("Check dataset...")
35+
3336
if isinstance(dataset, str):
3437
# means dataset path
3538
dataset_path = dataset
@@ -41,6 +44,9 @@ def check_dataset(self, dataset: Union[str, Dataset], test_dataset: Union[str, D
4144
raise ValueError("Train dataset must be a string or a Dataset")
4245
self.train_dataset_path = dataset_path
4346

47+
if not test_dataset:
48+
test_dataset = dataset_path
49+
4450
if isinstance(test_dataset, str):
4551
# means dataset path
4652
test_dataset_path = test_dataset
@@ -51,9 +57,19 @@ def check_dataset(self, dataset: Union[str, Dataset], test_dataset: Union[str, D
5157
test_dataset_path = None
5258
self.test_dataset_path = test_dataset_path
5359

54-
def check_reward(self, reward_func: Union[str, Callable[..., float]]):
60+
logger.info(f"View datasets in file: {self.train_dataset_path} and {self.test_dataset_path}")
61+
return self.train_dataset_path, self.test_dataset_path
62+
63+
def check_reward(self, reward_func: Union[str, Callable[..., float]]) -> Tuple[str, str]:
64+
logger.info("Check reward...")
65+
5566
if isinstance(reward_func, str):
56-
return reward_func, os.path.basename(reward_func).replace(".py", "")
67+
# means reward func file path
68+
name = os.path.basename(reward_func).replace(".py", "")
69+
self.reward_file_path = reward_func
70+
self.reward_func_name = name
71+
logger.info(f"View reward function in file: {reward_func}, name is: {name}")
72+
return reward_func, name
5773

5874
# data_source, solution_str, ground_truth, extra_info=None
5975
sig = inspect.signature(reward_func)
@@ -78,6 +94,7 @@ def check_reward(self, reward_func: Union[str, Callable[..., float]]):
7894
content = inspect.getsource(reward_func)
7995
if 'if __name__' in content and '__main__' in content:
8096
# have __name__ == '__main__', save to function to the new file
97+
# and the func must is a dependency-free function
8198
reward_file_path = f'{self.run_path}/reward_func.py'
8299
with open(reward_file_path, 'w') as writer:
83100
writer.write(content)
@@ -86,110 +103,168 @@ def check_reward(self, reward_func: Union[str, Callable[..., float]]):
86103

87104
self.reward_file_path = reward_file_path
88105
self.reward_func_name = reward_func.__name__
106+
logger.info(f"View reward function in file: {reward_file_path}, name is: {self.reward_file_path}")
89107
return reward_file_path, reward_func.__name__
90108

91-
def check_agent(self, agent: Union[str, Agent]):
109+
def check_agent(self, agent: Union[str, Agent]) -> str:
110+
"""Check single agent instance, and create agent loop dynamically.
111+
112+
NOTE: Single-agent only now, Swarm to be added in the future.
113+
114+
Returns:
115+
Return agent yaml file used to VeRL agent loop.
116+
"""
117+
logger.info("Check agent...")
118+
92119
if isinstance(agent, str):
93120
# means an agent yaml config file path
94121
config_dict = load_config(agent)
95122
agent = Agent(**config_dict)
96123

124+
# model params
97125
model_config: BaseConfig = agent.conf.llm_config
98126
if isinstance(model_config, dict):
99127
model_dict = dict(model_config)
100128
else:
101129
model_dict = dict(model_config.to_dict())
102-
model_dict.pop("llm_provider", None)
103-
model_dict.pop("llm_model_name", None)
104-
model_dict.pop("llm_base_url", None)
105-
model_dict.pop("llm_api_key", None)
106-
model_dict.pop("llm_client_type", None)
107-
model_dict.pop("params", None)
108-
model_dict.pop("model_type", None)
109-
110-
kv_parameters = ",\n".join([f"{k}={v}" for k, v in model_dict.items()])
130+
131+
for key in ["llm_provider", "llm_model_name", "llm_base_url",
132+
"llm_api_key", "llm_client_type", "params", "model_type"]:
133+
model_dict.pop(key, None)
134+
135+
model_kv_parameters = ",\n".join([f"{k}={v}" for k, v in model_dict.items()])
136+
137+
# agent params
138+
func_name = None
139+
func_str = ''
140+
if agent.tools_aggregate_func != agent._tools_aggregate_func:
141+
# special process tools_aggregate_func
142+
if agent.tools_aggregate_func.__module__ == '__main__':
143+
raise ValueError("tools_aggregate_func must be in a independent file")
144+
else:
145+
func_str = f"from {agent.tools_aggregate_func.__module__} import {agent.tools_aggregate_func.__name__}"
146+
func_name = agent.tools_aggregate_func.__name__
147+
148+
if agent.__class__ == Agent:
149+
import_str = ''
150+
extend_params = ''
151+
else:
152+
# custom agent, the custom parameters must be explicitly specified
153+
import_str = f"from {agent.__module__} import {agent.__class__.__name__}"
154+
base_sig = inspect.signature(Agent.__init__)
155+
base_params = base_sig.parameters
156+
157+
sig = inspect.signature(agent.__init__)
158+
kv = []
159+
for k, v in sig.parameters.items():
160+
if k not in base_params:
161+
kv.append(f"{k}={getattr(agent, k)}")
162+
extend_params = ',\n'.join(kv)
163+
164+
# NOTE: If the basic interface of the `Agent` changes, an upgrade is required
111165
con = VERL_TEMPLATE.format(agent_name=agent.name(),
112166
agent_desc=agent.desc(),
113167
system_prompt=agent.system_prompt,
114168
mcp_config=agent.mcp_config,
169+
tool_names=agent.tool_names,
170+
agent_names=agent.handoffs,
171+
wait_tool_result=agent.wait_tool_result,
172+
feedback_tool_result=agent.feedback_tool_result,
173+
black_tool_actions=agent.black_tool_actions,
174+
skill_configs=agent.skill_configs,
175+
event_handler_name=agent.event_handler_name,
176+
tool_aggregate_func_import_str=func_str,
177+
tools_aggregate_func=func_name,
115178
parser_module=type(agent.model_output_parser).__module__,
116179
parser_name=type(agent.model_output_parser).__name__,
117-
kv_parameters=kv_parameters)
180+
model_kv_parameters=model_kv_parameters,
181+
agent_import_str=import_str,
182+
real_agent=agent.__class__.__name__,
183+
extend_params=extend_params)
118184
module = f"{self.run_path}/{agent.name()}"
119185
with open(f"{module}.py", 'w+') as write:
120186
write.writelines(con)
121187

122188
# VeRL agent config file
123189
module = module.replace(os.getcwd(), '').replace('/', '.')
124-
if module[0] == '.':
125-
module = module[1:]
190+
module = module[1:] if module[0] == '.' else module
126191
con = f"""- name: {agent.name()}
127192
_target_: {module}.VerlAgentLoop
128193
"""
129-
with open(f"{self.run_path}/agent.yaml", "w+") as write:
194+
195+
agent_yaml = f"{self.run_path}/agent.yaml"
196+
with open(agent_yaml, "w+") as write:
130197
write.writelines(con)
131-
self.agent_yaml = f"{self.run_path}/agent.yaml"
198+
self.agent_yaml = agent_yaml
199+
logger.info(f"View agent config in file: {agent_yaml}")
132200
return self.agent_yaml
133201

134-
def check_config(self, config: Union[str, Any]):
202+
def check_config(self, config: Union[str, Config]) -> DictConfig:
135203
import verl.trainer.config
136204

137-
file_path = os.path.join(os.path.dirname(verl.trainer.config.__file__), "_generated_ppo_trainer.yaml")
138-
try:
139-
with open(file_path, "r") as file:
140-
yaml_data = yaml.safe_load(file)
141-
except FileNotFoundError:
142-
raise ValueError(f"Can not find the file: {config}")
143-
except Exception:
144-
raise RuntimeError(f"{config} read fail.\n", traceback.format_exc())
205+
logger.info("Check config...")
145206

146-
configs = DictConfig(OmegaConf.to_container(DictConfig(yaml_data), resolve=True))
207+
# custom config or config file
208+
custom_configs = dict()
147209
if isinstance(config, str):
148210
try:
149211
with open(config, "r") as file:
150-
yaml_data = yaml.safe_load(file)
151-
configs.merge_with(yaml_data)
212+
custom_configs = yaml.safe_load(file)
152213
except FileNotFoundError:
153214
raise ValueError(f"Can not find the file: {config}")
154215
except Exception:
155216
raise RuntimeError(f"{config} read fail.\n", traceback.format_exc())
156-
157217
elif isinstance(config, Config):
158218
if isinstance(config, BaseConfig):
159-
config_dict = ConfigDict(config.model_dump())
160-
configs.merge_with(config_dict)
161-
219+
custom_configs = ConfigDict(config.model_dump())
220+
else:
221+
custom_configs = config
162222
else:
163223
raise ValueError("Config must be a string or a Config")
164224

225+
# full config
226+
file_path = os.path.join(os.path.dirname(verl.trainer.config.__file__), "_generated_ppo_trainer.yaml")
227+
try:
228+
with open(file_path, "r") as file:
229+
root_configs = yaml.safe_load(file)
230+
except FileNotFoundError:
231+
raise ValueError(f"Can not find the file: {config}")
232+
except Exception:
233+
raise RuntimeError(f"{config} read fail.\n", traceback.format_exc())
234+
235+
configs = OmegaConf.merge(root_configs, custom_configs)
236+
configs = DictConfig(OmegaConf.to_container(configs, resolve=True))
237+
logger.debug(f"train full configs: {configs}")
238+
165239
self.config = configs
166240
# replace to real value, because the values are dynamically generated
167-
if not self.config['actor_rollout_ref']['rollout']['agent']['agent_loop_config_path']:
241+
if not self.config.actor_rollout_ref.rollout.agent.agent_loop_config_path:
168242
if not hasattr(self, 'agent_yaml'):
169243
raise RuntimeError("Please check agent first before check config")
170-
self.config['actor_rollout_ref']['rollout']['agent']['agent_loop_config_path'] = self.agent_yaml
244+
self.config.actor_rollout_ref.rollout.agent.agent_loop_config_path = self.agent_yaml
171245

172-
if not self.config['custom_reward_function']['name']:
246+
if not self.config.custom_reward_function.name:
173247
if not hasattr(self, 'reward_func_name'):
174248
raise RuntimeError("Please check reward function first before check config")
175-
self.config['custom_reward_function']['name'] = self.reward_func_name
176-
if not self.config['custom_reward_function']['path']:
177-
self.config['custom_reward_function']['path'] = self.reward_file_path
249+
self.config.custom_reward_function.name = self.reward_func_name
250+
if not self.config.custom_reward_function.path:
251+
self.config.custom_reward_function.path = self.reward_file_path
178252

179-
if not self.config['data']['train_files']:
253+
if not self.config.data.train_files:
180254
if not hasattr(self, 'train_dataset_path'):
181255
raise RuntimeError("Please check train dataset first before check config")
182-
self.config['data']['train_files'] = [self.train_dataset_path]
183-
if not self.config['data']['val_files']:
256+
self.config.data.train_files = [self.train_dataset_path]
257+
if not self.config.data.val_files:
184258
if not hasattr(self, 'test_dataset_path'):
185259
raise RuntimeError("Please check test dataset first before check config")
186-
self.config['data']['val_files'] = [self.test_dataset_path]
260+
self.config.data.val_files = [self.test_dataset_path]
187261

188-
if not self.config['trainer']['default_local_dir']:
262+
if not self.config.trainer.default_local_dir:
189263
local_dir = os.path.join(self.run_path, 'checkpoints')
190264
os.makedirs(local_dir, exist_ok=True)
191-
self.config['trainer']['default_local_dir'] = local_dir
265+
self.config.trainer.default_local_dir = local_dir
192266

193267
# for check
194268
yaml.safe_dump(OmegaConf.to_container(self.config), open(f"{self.run_path}/final_trainer.yaml", "w"))
269+
logger.info(f"View final config in file: {self.run_path}/final_trainer.yaml")
195270
return self.config
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)