55import traceback
66import yaml
77
8- from typing import Callable , Union , Any
8+ from typing import Callable , Union , Tuple
99
1010from datasets import Dataset
1111from omegaconf import OmegaConf
1414from aworld .agents .llm_agent import Agent
1515from aworld .config import BaseConfig , ConfigDict , load_config
1616from aworld .core .common import Config
17+ from aworld .logs .util import logger
1718from train .adapter .verl .agent_template import VERL_TEMPLATE
1819from train .trainer .trainer_processor import TrainerProcessor
1920
@@ -29,7 +30,9 @@ def train(self):
2930
3031 main (self .config )
3132
32- def check_dataset (self , dataset : Union [str , Dataset ], test_dataset : Union [str , Dataset ] = None ):
33+ def check_dataset (self , dataset : Union [str , Dataset ], test_dataset : Union [str , Dataset ] = None ) -> Tuple [str , str ]:
34+ logger .info ("Check dataset..." )
35+
3336 if isinstance (dataset , str ):
3437 # means dataset path
3538 dataset_path = dataset
@@ -41,6 +44,9 @@ def check_dataset(self, dataset: Union[str, Dataset], test_dataset: Union[str, D
4144 raise ValueError ("Train dataset must be a string or a Dataset" )
4245 self .train_dataset_path = dataset_path
4346
47+ if not test_dataset :
48+ test_dataset = dataset_path
49+
4450 if isinstance (test_dataset , str ):
4551 # means dataset path
4652 test_dataset_path = test_dataset
@@ -51,9 +57,19 @@ def check_dataset(self, dataset: Union[str, Dataset], test_dataset: Union[str, D
5157 test_dataset_path = None
5258 self .test_dataset_path = test_dataset_path
5359
54- def check_reward (self , reward_func : Union [str , Callable [..., float ]]):
60+ logger .info (f"View datasets in file: { self .train_dataset_path } and { self .test_dataset_path } " )
61+ return self .train_dataset_path , self .test_dataset_path
62+
63+ def check_reward (self , reward_func : Union [str , Callable [..., float ]]) -> Tuple [str , str ]:
64+ logger .info ("Check reward..." )
65+
5566 if isinstance (reward_func , str ):
56- return reward_func , os .path .basename (reward_func ).replace (".py" , "" )
67+ # means reward func file path
68+ name = os .path .basename (reward_func ).replace (".py" , "" )
69+ self .reward_file_path = reward_func
70+ self .reward_func_name = name
71+ logger .info (f"View reward function in file: { reward_func } , name is: { name } " )
72+ return reward_func , name
5773
5874 # data_source, solution_str, ground_truth, extra_info=None
5975 sig = inspect .signature (reward_func )
@@ -78,6 +94,7 @@ def check_reward(self, reward_func: Union[str, Callable[..., float]]):
7894 content = inspect .getsource (reward_func )
7995 if 'if __name__' in content and '__main__' in content :
8096 # have __name__ == '__main__', save to function to the new file
97+ # and the func must is a dependency-free function
8198 reward_file_path = f'{ self .run_path } /reward_func.py'
8299 with open (reward_file_path , 'w' ) as writer :
83100 writer .write (content )
@@ -86,110 +103,168 @@ def check_reward(self, reward_func: Union[str, Callable[..., float]]):
86103
87104 self .reward_file_path = reward_file_path
88105 self .reward_func_name = reward_func .__name__
106+ logger .info (f"View reward function in file: { reward_file_path } , name is: { self .reward_file_path } " )
89107 return reward_file_path , reward_func .__name__
90108
91- def check_agent (self , agent : Union [str , Agent ]):
109+ def check_agent (self , agent : Union [str , Agent ]) -> str :
110+ """Check single agent instance, and create agent loop dynamically.
111+
112+ NOTE: Single-agent only now, Swarm to be added in the future.
113+
114+ Returns:
115+ Return agent yaml file used to VeRL agent loop.
116+ """
117+ logger .info ("Check agent..." )
118+
92119 if isinstance (agent , str ):
93120 # means an agent yaml config file path
94121 config_dict = load_config (agent )
95122 agent = Agent (** config_dict )
96123
124+ # model params
97125 model_config : BaseConfig = agent .conf .llm_config
98126 if isinstance (model_config , dict ):
99127 model_dict = dict (model_config )
100128 else :
101129 model_dict = dict (model_config .to_dict ())
102- model_dict .pop ("llm_provider" , None )
103- model_dict .pop ("llm_model_name" , None )
104- model_dict .pop ("llm_base_url" , None )
105- model_dict .pop ("llm_api_key" , None )
106- model_dict .pop ("llm_client_type" , None )
107- model_dict .pop ("params" , None )
108- model_dict .pop ("model_type" , None )
109-
110- kv_parameters = ",\n " .join ([f"{ k } ={ v } " for k , v in model_dict .items ()])
130+
131+ for key in ["llm_provider" , "llm_model_name" , "llm_base_url" ,
132+ "llm_api_key" , "llm_client_type" , "params" , "model_type" ]:
133+ model_dict .pop (key , None )
134+
135+ model_kv_parameters = ",\n " .join ([f"{ k } ={ v } " for k , v in model_dict .items ()])
136+
137+ # agent params
138+ func_name = None
139+ func_str = ''
140+ if agent .tools_aggregate_func != agent ._tools_aggregate_func :
141+ # special process tools_aggregate_func
142+ if agent .tools_aggregate_func .__module__ == '__main__' :
143+ raise ValueError ("tools_aggregate_func must be in a independent file" )
144+ else :
145+ func_str = f"from { agent .tools_aggregate_func .__module__ } import { agent .tools_aggregate_func .__name__ } "
146+ func_name = agent .tools_aggregate_func .__name__
147+
148+ if agent .__class__ == Agent :
149+ import_str = ''
150+ extend_params = ''
151+ else :
152+ # custom agent, the custom parameters must be explicitly specified
153+ import_str = f"from { agent .__module__ } import { agent .__class__ .__name__ } "
154+ base_sig = inspect .signature (Agent .__init__ )
155+ base_params = base_sig .parameters
156+
157+ sig = inspect .signature (agent .__init__ )
158+ kv = []
159+ for k , v in sig .parameters .items ():
160+ if k not in base_params :
161+ kv .append (f"{ k } ={ getattr (agent , k )} " )
162+ extend_params = ',\n ' .join (kv )
163+
164+ # NOTE: If the basic interface of the `Agent` changes, an upgrade is required
111165 con = VERL_TEMPLATE .format (agent_name = agent .name (),
112166 agent_desc = agent .desc (),
113167 system_prompt = agent .system_prompt ,
114168 mcp_config = agent .mcp_config ,
169+ tool_names = agent .tool_names ,
170+ agent_names = agent .handoffs ,
171+ wait_tool_result = agent .wait_tool_result ,
172+ feedback_tool_result = agent .feedback_tool_result ,
173+ black_tool_actions = agent .black_tool_actions ,
174+ skill_configs = agent .skill_configs ,
175+ event_handler_name = agent .event_handler_name ,
176+ tool_aggregate_func_import_str = func_str ,
177+ tools_aggregate_func = func_name ,
115178 parser_module = type (agent .model_output_parser ).__module__ ,
116179 parser_name = type (agent .model_output_parser ).__name__ ,
117- kv_parameters = kv_parameters )
180+ model_kv_parameters = model_kv_parameters ,
181+ agent_import_str = import_str ,
182+ real_agent = agent .__class__ .__name__ ,
183+ extend_params = extend_params )
118184 module = f"{ self .run_path } /{ agent .name ()} "
119185 with open (f"{ module } .py" , 'w+' ) as write :
120186 write .writelines (con )
121187
122188 # VeRL agent config file
123189 module = module .replace (os .getcwd (), '' ).replace ('/' , '.' )
124- if module [0 ] == '.' :
125- module = module [1 :]
190+ module = module [1 :] if module [0 ] == '.' else module
126191 con = f"""- name: { agent .name ()}
127192 _target_: { module } .VerlAgentLoop
128193 """
129- with open (f"{ self .run_path } /agent.yaml" , "w+" ) as write :
194+
195+ agent_yaml = f"{ self .run_path } /agent.yaml"
196+ with open (agent_yaml , "w+" ) as write :
130197 write .writelines (con )
131- self .agent_yaml = f"{ self .run_path } /agent.yaml"
198+ self .agent_yaml = agent_yaml
199+ logger .info (f"View agent config in file: { agent_yaml } " )
132200 return self .agent_yaml
133201
134- def check_config (self , config : Union [str , Any ]) :
202+ def check_config (self , config : Union [str , Config ]) -> DictConfig :
135203 import verl .trainer .config
136204
137- file_path = os .path .join (os .path .dirname (verl .trainer .config .__file__ ), "_generated_ppo_trainer.yaml" )
138- try :
139- with open (file_path , "r" ) as file :
140- yaml_data = yaml .safe_load (file )
141- except FileNotFoundError :
142- raise ValueError (f"Can not find the file: { config } " )
143- except Exception :
144- raise RuntimeError (f"{ config } read fail.\n " , traceback .format_exc ())
205+ logger .info ("Check config..." )
145206
146- configs = DictConfig (OmegaConf .to_container (DictConfig (yaml_data ), resolve = True ))
207+ # custom config or config file
208+ custom_configs = dict ()
147209 if isinstance (config , str ):
148210 try :
149211 with open (config , "r" ) as file :
150- yaml_data = yaml .safe_load (file )
151- configs .merge_with (yaml_data )
212+ custom_configs = yaml .safe_load (file )
152213 except FileNotFoundError :
153214 raise ValueError (f"Can not find the file: { config } " )
154215 except Exception :
155216 raise RuntimeError (f"{ config } read fail.\n " , traceback .format_exc ())
156-
157217 elif isinstance (config , Config ):
158218 if isinstance (config , BaseConfig ):
159- config_dict = ConfigDict (config .model_dump ())
160- configs . merge_with ( config_dict )
161-
219+ custom_configs = ConfigDict (config .model_dump ())
220+ else :
221+ custom_configs = config
162222 else :
163223 raise ValueError ("Config must be a string or a Config" )
164224
225+ # full config
226+ file_path = os .path .join (os .path .dirname (verl .trainer .config .__file__ ), "_generated_ppo_trainer.yaml" )
227+ try :
228+ with open (file_path , "r" ) as file :
229+ root_configs = yaml .safe_load (file )
230+ except FileNotFoundError :
231+ raise ValueError (f"Can not find the file: { config } " )
232+ except Exception :
233+ raise RuntimeError (f"{ config } read fail.\n " , traceback .format_exc ())
234+
235+ configs = OmegaConf .merge (root_configs , custom_configs )
236+ configs = DictConfig (OmegaConf .to_container (configs , resolve = True ))
237+ logger .debug (f"train full configs: { configs } " )
238+
165239 self .config = configs
166240 # replace to real value, because the values are dynamically generated
167- if not self .config [ ' actor_rollout_ref' ][ ' rollout' ][ ' agent' ][ ' agent_loop_config_path' ] :
241+ if not self .config . actor_rollout_ref . rollout . agent . agent_loop_config_path :
168242 if not hasattr (self , 'agent_yaml' ):
169243 raise RuntimeError ("Please check agent first before check config" )
170- self .config [ ' actor_rollout_ref' ][ ' rollout' ][ ' agent' ][ ' agent_loop_config_path' ] = self .agent_yaml
244+ self .config . actor_rollout_ref . rollout . agent . agent_loop_config_path = self .agent_yaml
171245
172- if not self .config [ ' custom_reward_function' ][ ' name' ] :
246+ if not self .config . custom_reward_function . name :
173247 if not hasattr (self , 'reward_func_name' ):
174248 raise RuntimeError ("Please check reward function first before check config" )
175- self .config [ ' custom_reward_function' ][ ' name' ] = self .reward_func_name
176- if not self .config [ ' custom_reward_function' ][ ' path' ] :
177- self .config [ ' custom_reward_function' ][ ' path' ] = self .reward_file_path
249+ self .config . custom_reward_function . name = self .reward_func_name
250+ if not self .config . custom_reward_function . path :
251+ self .config . custom_reward_function . path = self .reward_file_path
178252
179- if not self .config [ ' data' ][ ' train_files' ] :
253+ if not self .config . data . train_files :
180254 if not hasattr (self , 'train_dataset_path' ):
181255 raise RuntimeError ("Please check train dataset first before check config" )
182- self .config [ ' data' ][ ' train_files' ] = [self .train_dataset_path ]
183- if not self .config [ ' data' ][ ' val_files' ] :
256+ self .config . data . train_files = [self .train_dataset_path ]
257+ if not self .config . data . val_files :
184258 if not hasattr (self , 'test_dataset_path' ):
185259 raise RuntimeError ("Please check test dataset first before check config" )
186- self .config [ ' data' ][ ' val_files' ] = [self .test_dataset_path ]
260+ self .config . data . val_files = [self .test_dataset_path ]
187261
188- if not self .config [ ' trainer' ][ ' default_local_dir' ] :
262+ if not self .config . trainer . default_local_dir :
189263 local_dir = os .path .join (self .run_path , 'checkpoints' )
190264 os .makedirs (local_dir , exist_ok = True )
191- self .config [ ' trainer' ][ ' default_local_dir' ] = local_dir
265+ self .config . trainer . default_local_dir = local_dir
192266
193267 # for check
194268 yaml .safe_dump (OmegaConf .to_container (self .config ), open (f"{ self .run_path } /final_trainer.yaml" , "w" ))
269+ logger .info (f"View final config in file: { self .run_path } /final_trainer.yaml" )
195270 return self .config
0 commit comments