File tree Expand file tree Collapse file tree 1 file changed +5
-5
lines changed
Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -200,17 +200,17 @@ def main(cfg):
200200 f"\n { OmegaConf .to_yaml (cfg )} "
201201 )
202202 logger = WeightAndBiasesWriter (cfg )
203- if cfg .agent_name == "SAC" :
203+ if cfg .training . agent_name == "SAC" :
204204 train_fn = get_sac_train_fn ()
205- elif cfg .agent_name == "PPO" :
205+ elif cfg .training . agent_name == "PPO" :
206206 train_fn = get_ppo_train_fn ()
207207 else :
208208 raise NotImplementedError
209209 rng = jax .random .PRNGKey (cfg .training .seed )
210210 steps = Counter ()
211- env = registry .load (cfg .task_name )
212- env_cfg = registry .get_default_config (cfg .task_name )
213- eval_env = registry .load (cfg .task_name , config = env_cfg )
211+ env = registry .load (cfg .training . task_name )
212+ env_cfg = registry .get_default_config (cfg .training . task_name )
213+ eval_env = registry .load (cfg .training . task_name , config = env_cfg )
214214 with jax .disable_jit (not cfg .jit ):
215215 make_policy , params , _ = train_fn (
216216 environment = env ,
You can’t perform that action at this time.
0 commit comments