Skip to content

Commit 5e28ae5

Browse files
committed
fix bugs
1 parent e2d016f commit 5e28ae5

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

train_brax.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -200,17 +200,17 @@ def main(cfg):
200200
f"\n{OmegaConf.to_yaml(cfg)}"
201201
)
202202
logger = WeightAndBiasesWriter(cfg)
203-
if cfg.agent_name == "SAC":
203+
if cfg.training.agent_name == "SAC":
204204
train_fn = get_sac_train_fn()
205-
elif cfg.agent_name == "PPO":
205+
elif cfg.training.agent_name == "PPO":
206206
train_fn = get_ppo_train_fn()
207207
else:
208208
raise NotImplementedError
209209
rng = jax.random.PRNGKey(cfg.training.seed)
210210
steps = Counter()
211-
env = registry.load(cfg.task_name)
212-
env_cfg = registry.get_default_config(cfg.task_name)
213-
eval_env = registry.load(cfg.task_name, config=env_cfg)
211+
env = registry.load(cfg.training.task_name)
212+
env_cfg = registry.get_default_config(cfg.training.task_name)
213+
eval_env = registry.load(cfg.training.task_name, config=env_cfg)
214214
with jax.disable_jit(not cfg.jit):
215215
make_policy, params, _ = train_fn(
216216
environment=env,

0 commit comments

Comments
 (0)