Train only rewards after exploration (#29)
* Train only rewards after exploration

* Learn model steps

* No action costs

* Update unsupervised task

* Safe swingup

* Reset buffer
yardenas authored Aug 8, 2024
1 parent aa49e40 commit 094b78e
Showing 7 changed files with 42 additions and 6 deletions.
8 changes: 6 additions & 2 deletions safe_opax/benchmark_suites/dm_control/__init__.py
@@ -233,7 +233,7 @@ def __getattr__(self, name):
 class CartpoleUnsupervisedWrapper:
     def __init__(self, env: Env):
         self.env = env
-        self._task = self.env.env.env.env.env._env.task
+        self._task = self.env.env.env.env._env.task
         self._reward_fn = self._task._get_reward
 
     def reset(self, *, seed=None, options=None):
@@ -268,7 +268,10 @@ def make_env():
         ]:
             task = "swingup_sparse"
         else:
-            task = task_cfg.task
+            if "safe" in task_cfg.task:
+                task = task_cfg.task.replace("safe_", "")
+            else:
+                task = task_cfg.task
         env = DMCWrapper(domain_name, task)
         if "safe" in task_cfg.task:
             env = ConstraintWrapper(env, task_cfg.slider_position_bound)
@@ -300,6 +303,7 @@ def make_env():
     ("dm_cartpole", "swingup"),
     ("dm_cartpole", "swingup_sparse"),
     ("dm_cartpole", "swingup_sparse_hard"),
+    ("dm_cartpole", "safe_swingup"),
     ("dm_cartpole", "safe_swingup_sparse"),
     ("dm_cartpole", "safe_swingup_sparse_hard"),
     ("dm_humanoid", "stand"),
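The hunk above establishes the naming convention behind the new `safe_swingup` entry: a `safe_` prefix selects the same underlying dm_control task, with the constraint wrapper layered on top. A minimal sketch of that mapping (the helper below is hypothetical, not part of the repository):

```python
def resolve_task(name: str) -> tuple[str, bool]:
    """Strip the "safe_" prefix and report whether a constraint wrapper is needed."""
    is_safe = "safe" in name
    task = name.replace("safe_", "") if is_safe else name
    return task, is_safe


task, is_safe = resolve_task("safe_swingup")
assert (task, is_safe) == ("swingup", True)
# make_env above then builds roughly:
#   env = DMCWrapper(domain_name, task)
#   if is_safe:
#       env = ConstraintWrapper(env, task_cfg.slider_position_bound)
```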
1 change: 1 addition & 0 deletions safe_opax/configs/agent/la_mbda.yaml
@@ -52,6 +52,7 @@ safety_slack: 0.
 evaluate_model: false
 exploration_strategy: uniform
 exploration_steps: 5000
+learn_model_steps: null
 exploration_reward_scale: 10.0
 unsupervised: false
 reward_scale: 1.
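`learn_model_steps: null` is the new default added here; as the agent code further down shows, a null value is read as "no limit" on model learning. A one-line sketch of that convention (variable names are illustrative):

```python
learn_model_steps = None  # parsed from `learn_model_steps: null`
budget = learn_model_steps if learn_model_steps is not None else float("inf")
assert budget == float("inf")  # the Until counter effectively never expires
```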
2 changes: 1 addition & 1 deletion safe_opax/configs/experiment/debug.yaml
@@ -20,4 +20,4 @@ agent:
   replay_buffer:
     batch_size: 4
     sequence_length: 16
-  exploration_steps: 750
+  exploration_steps: 750
3 changes: 2 additions & 1 deletion safe_opax/configs/experiment/unsupervised_cartpole.yaml
@@ -4,7 +4,7 @@ defaults:
 
 environment:
   dm_cartpole:
-    task: safe_swingup_sparse_hard
+    task: safe_swingup
 
 training:
   trainer: unsupervised
@@ -20,6 +20,7 @@ agent:
   exploration_strategy: opax
   exploration_steps: 1000000
   unsupervised: true
+  learn_model_steps: 1000000
   actor:
     init_stddev: 0.025
   sentiment:
9 changes: 9 additions & 0 deletions safe_opax/la_mbda/la_mbda.py
@@ -108,6 +108,12 @@ def __init__(
         self.should_explore = Until(
             config.agent.exploration_steps, environment_steps_per_agent_step
         )
+        learn_model_steps = (
+            config.agent.learn_model_steps
+            if config.agent.learn_model_steps is not None
+            else float("inf")
+        )
+        self.learn_model = Until(learn_model_steps, environment_steps_per_agent_step)
         self.metrics_monitor = MetricsMonitor()
 
     def __call__(
@@ -123,6 +129,7 @@ def __call__(
             else self.actor_critic.actor.act
         )
         self.should_explore.tick()
+        self.learn_model.tick()
         actions, self.state = policy(
             policy_fn,
             self.model,
@@ -180,6 +187,7 @@ def update_model(self, batch: TrajectoryData) -> jax.Array:
         learn_reward = not self.should_explore() or (
             self.should_explore() and not self.config.agent.unsupervised
         )
+        no_dynamics = self.config.agent.unsupervised and not self.learn_model()
         (self.model, self.model_learner.state), (loss, rest) = variational_step(
             features,
             actions,
@@ -191,6 +199,7 @@ def update_model(self, batch: TrajectoryData) -> jax.Array:
             self.config.agent.free_nats,
             self.config.agent.kl_mix,
             learn_reward,
+            no_dynamics,
         )
         self.metrics_monitor["agent/model/loss"] = float(loss.mean())
         self.metrics_monitor["agent/model/reconstruction"] = float(
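Taken together, `learn_reward` and `no_dynamics` decide which parts of the world model receive gradients in each training phase. A hypothetical walk-through (not repository code) of the two expressions from `update_model` above:

```python
def flags(exploring: bool, unsupervised: bool, still_learning_model: bool):
    learn_reward = not exploring or (exploring and not unsupervised)
    no_dynamics = unsupervised and not still_learning_model
    return learn_reward, no_dynamics


# Unsupervised exploration while the model-learning budget lasts: dynamics only.
assert flags(exploring=True, unsupervised=True, still_learning_model=True) == (False, False)
# Unsupervised, budget exhausted: only the reward/cost heads keep training.
assert flags(exploring=False, unsupervised=True, still_learning_model=False) == (True, True)
# Supervised task rewards available: everything trains throughout.
assert flags(exploring=True, unsupervised=False, still_learning_model=True) == (True, False)
```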
22 changes: 20 additions & 2 deletions safe_opax/la_mbda/world_model.py
@@ -251,8 +251,11 @@ def variational_step(
     free_nats: float = 0.0,
     kl_mix: float = 0.8,
     with_reward: bool = True,
+    no_dynamics: bool = False,
 ) -> tuple[tuple[WorldModel, OptState], tuple[jax.Array, TrainingResults]]:
-    def loss_fn(model):
+    def loss_fn(model, static_part=None):
+        if static_part is not None:
+            model = eqx.combine(model, static_part)
         infer_fn = lambda features, actions: model(features, actions, key)
         inference_result: InferenceResult = eqx.filter_vmap(infer_fn)(features, actions)
         batch_ndim = 2
@@ -287,11 +290,26 @@ def loss_fn(model):
         )
         return reconstruction_loss + beta * kl_loss, aux
 
-    (loss, rest), model_grads = eqx.filter_value_and_grad(loss_fn, has_aux=True)(model)
+    if no_dynamics:
+        diff_model, static_model = partition_dynamics_rewards(model)
+        (loss, rest), model_grads = eqx.filter_value_and_grad(loss_fn, has_aux=True)(
+            diff_model, static_model
+        )
+    else:
+        (loss, rest), model_grads = eqx.filter_value_and_grad(loss_fn, has_aux=True)(
+            model
+        )
     new_model, new_opt_state = learner.grad_step(model, model_grads, opt_state)
     return (new_model, new_opt_state), (loss, rest)
 
 
+def partition_dynamics_rewards(model: WorldModel) -> tuple[WorldModel, WorldModel]:
+    filter_spec = jax.tree_map(lambda _: False, model)
+    filter_spec = eqx.tree_at(lambda tree: tree.reward_cost_decoder, filter_spec, True)
+    diff_model, static_model = eqx.partition(model, filter_spec)
+    return diff_model, static_model
+
+
 # https://github.com/danijar/dreamerv2/blob/259e3faa0e01099533e29b0efafdf240adeda4b5/common/nets.py#L130
 def kl_divergence(
     posterior: ShiftScale, prior: ShiftScale, free_nats: float, mix: float
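`partition_dynamics_rewards` uses the standard Equinox freezing recipe: build a filter spec that is `True` only on the reward/cost decoder, partition the model into a differentiable part and a static part, and recombine them inside the loss so gradients flow only through the marked subtree. A standalone illustration with a hypothetical two-field module (not the repository's `WorldModel`):

```python
import equinox as eqx
import jax
import jax.numpy as jnp


class TinyModel(eqx.Module):
    dynamics: eqx.nn.Linear
    reward_head: eqx.nn.Linear

    def __init__(self, key):
        k1, k2 = jax.random.split(key)
        self.dynamics = eqx.nn.Linear(3, 3, key=k1)
        self.reward_head = eqx.nn.Linear(3, 1, key=k2)

    def __call__(self, x):
        return self.reward_head(self.dynamics(x))


model = TinyModel(jax.random.PRNGKey(0))
# Mark only the reward head as differentiable; everything else stays static.
filter_spec = jax.tree_util.tree_map(lambda _: False, model)
filter_spec = eqx.tree_at(lambda m: m.reward_head, filter_spec, True)
diff_part, static_part = eqx.partition(model, filter_spec)


def loss_fn(diff_part, static_part, x):
    model = eqx.combine(diff_part, static_part)  # reassemble inside the loss
    return jnp.sum(model(x) ** 2)


grads = eqx.filter_grad(loss_fn)(diff_part, static_part, jnp.ones(3))
assert grads.dynamics.weight is None          # frozen: no gradient
assert grads.reward_head.weight is not None   # trained: gradient present
```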
3 changes: 3 additions & 0 deletions safe_opax/rl/trainer.py
@@ -245,6 +245,9 @@ def _run_training_epoch(
         ]
         assert self.env is not None
         self.env.reset(options={"task": self.test_tasks})
+        assert self.agent is not None
+        new_agent = self.make_agent()
+        self.agent.replay_buffer = new_agent.replay_buffer
         return outs


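The trainer change implements the "Reset buffer" item from the commit message: instead of clearing the buffer in place, a fresh agent is constructed and its empty replay buffer replaces the old one. A minimal sketch of the pattern with hypothetical stand-in classes:

```python
class ReplayBuffer:
    def __init__(self):
        self.episodes = []


class Agent:
    def __init__(self):
        self.replay_buffer = ReplayBuffer()


agent = Agent()
agent.replay_buffer.episodes.append("old exploration data")

new_agent = Agent()                            # same construction path as make_agent()
agent.replay_buffer = new_agent.replay_buffer  # previously collected data is dropped
assert agent.replay_buffer.episodes == []
```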
