From dde4c59ca74f586869158d868611d715d7285d6f Mon Sep 17 00:00:00 2001
From: Yarden As
Date: Mon, 22 Jul 2024 07:59:40 +0200
Subject: [PATCH] Unsupervised (#25)

* Version update initial implementation unsupervised
* Mask out rewards
* Reload new agent when transitioning from unsupervised
* Take the mean
* Default flags
* Unsupervised in agent
* Change only replay buffer
* Learn reward and policy in unsupervised
* Update configs
* Scale up rewards
* Update exploration steps
* Initialize model weights better
---
 poetry.lock                                   |  8 ++--
 safe_opax/configs/agent/la_mbda.yaml          |  4 +-
 .../experiment/active_exploration.yaml        | 16 -------
 .../experiment/cartpole_sparse_hard.yaml      | 12 -----
 ...ggo_explore.yaml => safe_sparse_goal.yaml} | 17 +++----
 .../configs/experiment/safety_gym_doggo.yaml  | 18 -------
 .../configs/experiment/unsupervised.yaml      |  9 +++-
 safe_opax/la_mbda/la_mbda.py                  | 24 ++++++++--
 safe_opax/la_mbda/world_model.py              | 30 +++++++-----
 safe_opax/rl/trainer.py                       | 47 ++++++++++---------
 10 files changed, 90 insertions(+), 95 deletions(-)
 delete mode 100644 safe_opax/configs/experiment/active_exploration.yaml
 delete mode 100644 safe_opax/configs/experiment/cartpole_sparse_hard.yaml
 rename safe_opax/configs/experiment/{safety_gym_doggo_explore.yaml => safe_sparse_goal.yaml} (56%)
 delete mode 100644 safe_opax/configs/experiment/safety_gym_doggo.yaml

diff --git a/poetry.lock b/poetry.lock
index f9232269..a30984f9 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2870,7 +2870,7 @@ dev = ["Pillow", "matplotlib", "pytest (>=4.4.0)"]
 type = "git"
 url = "https://git@github.com/lasgroup/safe-adaptation-gym"
 reference = "HEAD"
-resolved_reference = "c91e05d26e0ca9416696cf460f1cf5a39006097b"
+resolved_reference = "2c9a0d586b3b22b9d042ab82694c2907befe4799"

 [[package]]
 name = "scipy"
@@ -3407,13 +3407,13 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess

 [[package]]
 name = "wandb"
-version = "0.16.4"
+version = "0.16.6"
 description = "A CLI and library for interacting with the Weights & Biases API."
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "wandb-0.16.4-py3-none-any.whl", hash = "sha256:bb9eb5aa2c2c85e11c76040c4271366f54d4975167aa6320ba86c3f2d97fe5fa"},
-    {file = "wandb-0.16.4.tar.gz", hash = "sha256:8752c67d1347a4c29777e64dc1e1a742a66c5ecde03aebadf2b0d62183fa307c"},
+    {file = "wandb-0.16.6-py3-none-any.whl", hash = "sha256:5810019a3b981c796e98ea58557a7c380f18834e0c6bdaed15df115522e5616e"},
+    {file = "wandb-0.16.6.tar.gz", hash = "sha256:86f491e3012d715e0d7d7421a4d6de41abef643b7403046261f962f3e512fe1c"},
 ]

 [package.dependencies]
diff --git a/safe_opax/configs/agent/la_mbda.yaml b/safe_opax/configs/agent/la_mbda.yaml
index 2320d9d1..0c05c63a 100644
--- a/safe_opax/configs/agent/la_mbda.yaml
+++ b/safe_opax/configs/agent/la_mbda.yaml
@@ -52,4 +52,6 @@ safety_slack: 0.
 evaluate_model: false
 exploration_strategy: uniform
 exploration_steps: 5000
-exploration_reward_scale: 10.0
\ No newline at end of file
+exploration_reward_scale: 10.0
+unsupervised: false
+reward_scale: 1.
\ No newline at end of file
diff --git a/safe_opax/configs/experiment/active_exploration.yaml b/safe_opax/configs/experiment/active_exploration.yaml
deleted file mode 100644
index 44548d37..00000000
--- a/safe_opax/configs/experiment/active_exploration.yaml
+++ /dev/null
@@ -1,16 +0,0 @@
-# @package _global_
-defaults:
-  - override /environment: dm_cartpole
-
-environment:
-  dm_cartpole:
-    task: swingup_sparse_hard
-
-training:
-  epochs: 100
-  safe: false
-  action_repeat: 2
-
-agent:
-  exploration_strategy: opax
-  exploration_steps: 100000
diff --git a/safe_opax/configs/experiment/cartpole_sparse_hard.yaml b/safe_opax/configs/experiment/cartpole_sparse_hard.yaml
deleted file mode 100644
index c1766439..00000000
--- a/safe_opax/configs/experiment/cartpole_sparse_hard.yaml
+++ /dev/null
@@ -1,12 +0,0 @@
-# @package _global_
-defaults:
-  - override /environment: dm_cartpole
-
-environment:
-  dm_cartpole:
-    task: swingup_sparse_hard
-
-training:
-  epochs: 100
-  safe: false
-  action_repeat: 2
diff --git a/safe_opax/configs/experiment/safety_gym_doggo_explore.yaml b/safe_opax/configs/experiment/safe_sparse_goal.yaml
similarity index 56%
rename from safe_opax/configs/experiment/safety_gym_doggo_explore.yaml
rename to safe_opax/configs/experiment/safe_sparse_goal.yaml
index 665a773c..073ac4bb 100644
--- a/safe_opax/configs/experiment/safety_gym_doggo_explore.yaml
+++ b/safe_opax/configs/experiment/safe_sparse_goal.yaml
@@ -2,18 +2,19 @@
 defaults:
   - override /environment: safe_adaptation_gym
 
+environment:
+  safe_adaptation_gym:
+    task: go_to_goal_scarce
 
 training:
-  epochs: 200
+  epochs: 100
   safe: true
   action_repeat: 2
-  episodes_per_epoch: 5
-
-environment:
-  safe_adaptation_gym:
-    robot_name: doggo
-    task: collect
 
 agent:
   exploration_strategy: opax
-  exploration_steps: 1000000
\ No newline at end of file
+  exploration_steps: 850000
+  actor:
+    init_stddev: 0.025
+  sentiment:
+    model_initialization_scale: 0.05
\ No newline at end of file
diff --git a/safe_opax/configs/experiment/safety_gym_doggo.yaml b/safe_opax/configs/experiment/safety_gym_doggo.yaml
deleted file mode 100644
index c8303269..00000000
--- a/safe_opax/configs/experiment/safety_gym_doggo.yaml
+++ /dev/null
@@ -1,18 +0,0 @@
-# @package _global_
-defaults:
-  - override /environment: safe_adaptation_gym
-
-training:
-  epochs: 200
-  safe: true
-  action_repeat: 2
-  episodes_per_epoch: 5
-
-environment:
-  safe_adaptation_gym:
-    robot_name: doggo
-
-agent:
-  exploration_steps: 0
-  actor:
-    initialization_scale: 1.
\ No newline at end of file
diff --git a/safe_opax/configs/experiment/unsupervised.yaml b/safe_opax/configs/experiment/unsupervised.yaml
index 73926667..e8746172 100644
--- a/safe_opax/configs/experiment/unsupervised.yaml
+++ b/safe_opax/configs/experiment/unsupervised.yaml
@@ -4,11 +4,18 @@ defaults:
 
 training:
   trainer: unsupervised
-  epochs: 100
+  epochs: 200
   safe: true
   action_repeat: 2
+  episodes_per_epoch: 5
   exploration_steps: 1000000
+  test_task_name: go_to_goal
+
+environment:
+  safe_adaptation_gym:
+    robot_name: doggo
 
 agent:
   exploration_strategy: opax
   exploration_steps: 1000000
+  unsupervised: true
\ No newline at end of file
diff --git a/safe_opax/la_mbda/la_mbda.py b/safe_opax/la_mbda/la_mbda.py
index 0b1cfdf8..6ee30f93 100644
--- a/safe_opax/la_mbda/la_mbda.py
+++ b/safe_opax/la_mbda/la_mbda.py
@@ -52,8 +52,6 @@ def init(cls, batch_size: int, cell: rssm.RSSM, action_dim: int) -> "AgentState"
         return self
 
 
-
-
 class LaMBDA:
     def __init__(
         self,
@@ -148,21 +146,40 @@ def observe_transition(self, transition: Transition) -> None:
     def update(self):
         total_steps = self.config.agent.update_steps
         for batch in self.replay_buffer.sample(total_steps):
+            batch = TrajectoryData(
+                batch.observation,
+                batch.next_observation,
+                batch.action,
+                batch.reward * self.config.agent.reward_scale,
+                batch.cost,
+            )
             inferred_rssm_states = self.update_model(batch)
             initial_states = inferred_rssm_states.reshape(
                 -1, inferred_rssm_states.shape[-1]
             )
-            outs = self.actor_critic.update(self.model, initial_states, next(self.prng))
             if self.should_explore():
+                if not self.config.agent.unsupervised:
+                    outs = self.actor_critic.update(
+                        self.model, initial_states, next(self.prng)
+                    )
+                else:
+                    outs = {}
                 exploration_outs = self.exploration.update(
                     self.model, initial_states, next(self.prng)
                 )
                 outs.update(exploration_outs)
+            else:
+                outs = self.actor_critic.update(
+                    self.model, initial_states, next(self.prng)
+                )
             for k, v in outs.items():
                 self.metrics_monitor[k] = v
 
     def update_model(self, batch: TrajectoryData) -> jax.Array:
         features, actions = _prepare_features(batch)
+        learn_reward = not self.should_explore() or (
+            self.should_explore() and not self.config.agent.unsupervised
+        )
         (self.model, self.model_learner.state), (loss, rest) = variational_step(
             features,
             actions,
@@ -173,6 +190,7 @@
             self.config.agent.beta,
             self.config.agent.free_nats,
             self.config.agent.kl_mix,
+            learn_reward,
         )
         self.metrics_monitor["agent/model/loss"] = float(loss.mean())
         self.metrics_monitor["agent/model/reconstruction"] = float(
diff --git a/safe_opax/la_mbda/world_model.py b/safe_opax/la_mbda/world_model.py
index a433a0c4..eda80ec0 100644
--- a/safe_opax/la_mbda/world_model.py
+++ b/safe_opax/la_mbda/world_model.py
@@ -250,24 +250,32 @@ def variational_step(
     beta: float = 1.0,
     free_nats: float = 0.0,
     kl_mix: float = 0.8,
+    with_reward: bool = True,
 ) -> tuple[tuple[WorldModel, OptState], tuple[jax.Array, TrainingResults]]:
     def loss_fn(model):
         infer_fn = lambda features, actions: model(features, actions, key)
         inference_result: InferenceResult = eqx.filter_vmap(infer_fn)(features, actions)
-        y = features.observation, jnp.concatenate([features.reward, features.cost], -1)
-        y_hat = inference_result.image, inference_result.reward_cost
         batch_ndim = 2
-        reconstruction_loss = -sum(
-            map(
-                lambda predictions, targets: dtx.Independent(
-                    dtx.Normal(targets, 1.0), targets.ndim - batch_ndim
-                )
-                .log_prob(predictions)
-                .mean(),
-                y_hat,
-                y,
+        logprobs = (
+            lambda predictions, targets: dtx.Independent(
+                dtx.Normal(targets, 1.0), targets.ndim - batch_ndim
             )
+            .log_prob(predictions)
+            .mean()
+        )
+        if not with_reward:
+            reward = jnp.zeros_like(features.reward)
+            _, pred_cost = jnp.split(inference_result.reward_cost, 2, -1)
+            reward_cost = jnp.concatenate([reward, pred_cost], -1)
+        else:
+            reward = features.reward
+            reward_cost = inference_result.reward_cost
+        reward_cost_logprobs = logprobs(
+            reward_cost,
+            jnp.concatenate([reward, features.cost], -1),
         )
+        image_logprobs = logprobs(inference_result.image, features.observation)
+        reconstruction_loss = -reward_cost_logprobs - image_logprobs
         kl_loss = kl_divergence(
             inference_result.posteriors, inference_result.priors, free_nats, kl_mix
         )
diff --git a/safe_opax/rl/trainer.py b/safe_opax/rl/trainer.py
index cec6fd38..eed6b117 100644
--- a/safe_opax/rl/trainer.py
+++ b/safe_opax/rl/trainer.py
@@ -12,12 +12,11 @@
 from safe_opax.rl import acting, episodic_async_env
 from safe_opax.rl.epoch_summary import EpochSummary
 from safe_opax.rl.logging import StateWriter, TrainingLogger
-from safe_opax.rl.types import Agent, EnvironmentFactory
+from safe_opax.rl.types import EnvironmentFactory
 from safe_opax.rl.utils import PRNGSequence
 from safe_adaptation_gym.benchmark import TASKS
 from safe_adaptation_gym.tasks import Task
-from safe_opax.benchmark_suites.safe_adaptation_gym import sample_task
 
 
 _LOG = logging.getLogger(__name__)
 
@@ -58,7 +57,7 @@ def __init__(
         self,
         config: DictConfig,
         make_env: EnvironmentFactory,
-        agent: Agent | None = None,
+        agent: LaMBDA | LaMBDADalal | None = None,
         start_epoch: int = 0,
         step: int = 0,
         seeds: PRNGSequence | None = None,
@@ -86,24 +85,27 @@ def __enter__(self):
         if self.seeds is None:
             self.seeds = PRNGSequence(self.config.training.seed)
         if self.agent is None:
-            if self.config.agent.name == "lambda":
-                self.agent = LaMBDA(
-                    self.env.observation_space,
-                    self.env.action_space,
-                    self.config,
-                )
-            elif self.config.agent.name == "lambda_dalal":
-                self.agent = LaMBDADalal(
-                    self.env.observation_space,
-                    self.env.action_space,
-                    self.config,
-                )
-            else:
-                raise NotImplementedError(
-                    f"Unknown agent type: {self.config.agent.name}"
-                )
+            self.agent = self.make_agent()
         return self
 
+    def make_agent(self) -> LaMBDA | LaMBDADalal:
+        assert self.env is not None
+        if self.config.agent.name == "lambda":
+            agent = LaMBDA(
+                self.env.observation_space,
+                self.env.action_space,
+                self.config,
+            )
+        elif self.config.agent.name == "lambda_dalal":
+            agent = LaMBDADalal(
+                self.env.observation_space,
+                self.env.action_space,
+                self.config,
+            )
+        else:
+            raise NotImplementedError(f"Unknown agent type: {self.config.agent.name}")
+        return agent
+
     def __exit__(self, exc_type, exc_val, exc_tb):
         assert self.logger is not None and self.state_writer is not None
         self.state_writer.close()
@@ -197,13 +199,13 @@ def __init__(
         self,
         config: DictConfig,
         make_env: EnvironmentFactory,
-        agent: Agent | None = None,
+        agent: LaMBDA | LaMBDADalal | None = None,
         start_epoch: int = 0,
         step: int = 0,
         seeds: PRNGSequence | None = None,
     ):
         super().__init__(config, make_env, agent, start_epoch, step, seeds)
-        self.test_task_name = sample_task(self.config.training.seed)
+        self.test_task_name = self.config.training.test_task_name
         self.test_tasks: list[Task] | None = None
 
     def __enter__(self):
@@ -233,4 +235,7 @@ def _run_training_epoch(
         ]
         assert self.env is not None
         self.env.reset(options={"task": self.test_tasks})
+        assert self.agent is not None
+        new_agent = self.make_agent()
+        self.agent.replay_buffer = new_agent.replay_buffer
         return outs
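
Note on the reward masking in `variational_step`: the mechanism is easy to misread inside the hunk above, so here is a minimal standalone sketch of that one idea. This is an illustration, not code from this repository: `reward_cost_logprob` and the toy shapes are invented for the example, and distrax's `dtx.Independent(dtx.Normal(targets, 1.0), ...)` likelihood is replaced by an explicit unit-variance Gaussian log-density, which agrees up to an additive constant. The point it demonstrates: when `with_reward` is false, both the reward target and the predicted reward are zeroed, so the reward head receives exactly zero gradient while the cost head keeps training.

import jax
import jax.numpy as jnp


def reward_cost_logprob(pred_reward_cost, reward, cost, with_reward):
    # Mirror of the patch's masking: with_reward=False zeroes the reward
    # target and swaps the predicted reward for zeros, so only the cost
    # component contributes a trainable term to the likelihood.
    if not with_reward:
        reward = jnp.zeros_like(reward)
        _, pred_cost = jnp.split(pred_reward_cost, 2, -1)
        pred_reward_cost = jnp.concatenate([reward, pred_cost], -1)
    target = jnp.concatenate([reward, cost], -1)
    # Unit-variance Gaussian log-likelihood, up to an additive constant.
    return -0.5 * jnp.sum((pred_reward_cost - target) ** 2)


pred = jnp.ones((4, 2))    # packed [reward, cost] predictions
reward = jnp.ones((4, 1))  # real rewards, ignored while masked
cost = jnp.zeros((4, 1))
grads = jax.grad(reward_cost_logprob)(pred, reward, cost, False)
# grads[:, 0] is all zeros: no gradient reaches the reward predictions,
# while grads[:, 1] still carries signal for the cost predictions.

In the patch itself this flag arrives as `learn_reward` from `LaMBDA.update_model`, which simplifies to `not (self.should_explore() and self.config.agent.unsupervised)`: rewards are masked only while the agent is both exploring and configured as unsupervised. When `UnsupervisedTrainer` later switches to the test tasks, it keeps the trained agent but swaps in a freshly constructed agent's empty replay buffer, discarding the exploration-time transitions whose rewards were masked.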