@@ -77,13 +77,6 @@ def get_state_path() -> str:
7777 return log_path
7878
7979
80- env_name = "QuadrupedRun"
81- env = registry .load (env_name )
82- env_cfg = registry .get_default_config (env_name )
83- eval_env = registry .load (env_name , config = env_cfg )
84- agent_name = "PPO"
85-
86-
8780def get_ppo_train_fn ():
8881 from brax .training .agents .ppo import networks as ppo_networks
8982 from brax .training .agents .ppo import train as ppo
@@ -207,14 +200,17 @@ def main(cfg):
207200 f"\n { OmegaConf .to_yaml (cfg )} "
208201 )
209202 logger = WeightAndBiasesWriter (cfg )
210- if agent_name == "SAC" :
203+ if cfg . agent_name == "SAC" :
211204 train_fn = get_sac_train_fn ()
212- elif agent_name == "PPO" :
205+ elif cfg . agent_name == "PPO" :
213206 train_fn = get_ppo_train_fn ()
214207 else :
215208 raise NotImplementedError
216209 rng = jax .random .PRNGKey (cfg .training .seed )
217210 steps = Counter ()
211+ env = registry .load (cfg .task_name )
212+ env_cfg = registry .get_default_config (cfg .task_name )
213+ eval_env = registry .load (cfg .task_name , config = env_cfg )
218214 with jax .disable_jit (not cfg .jit ):
219215 make_policy , params , _ = train_fn (
220216 environment = env ,
0 commit comments