Commit: Unsupervised (#25)
* Version update initial implementation unsupervised

* Mask out rewards

* Reload new agent when transitioning from unsupervised

* Take the mean

* Default flags

* Unsupervised in agent

* Change only replay buffer

* Learn reward and policy in unsupervised

* Update configs

* Scale up rewards

* Update exploration steps

* Initialize model weights better
yardenas authored Jul 22, 2024
1 parent 2a176dd commit dde4c59
Showing 10 changed files with 90 additions and 95 deletions.
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default.

4 changes: 3 additions & 1 deletion safe_opax/configs/agent/la_mbda.yaml
@@ -52,4 +52,6 @@ safety_slack: 0.
evaluate_model: false
exploration_strategy: uniform
exploration_steps: 5000
exploration_reward_scale: 10.0
exploration_reward_scale: 10.0
unsupervised: false
reward_scale: 1.
16 changes: 0 additions & 16 deletions safe_opax/configs/experiment/active_exploration.yaml

This file was deleted.

12 changes: 0 additions & 12 deletions safe_opax/configs/experiment/cartpole_sparse_hard.yaml

This file was deleted.

@@ -2,18 +2,19 @@
defaults:
- override /environment: safe_adaptation_gym

environment:
safe_adaptation_gym:
task: go_to_goal_scarce

training:
epochs: 200
epochs: 100
safe: true
action_repeat: 2
episodes_per_epoch: 5

environment:
safe_adaptation_gym:
robot_name: doggo
task: collect

agent:
exploration_strategy: opax
exploration_steps: 1000000
exploration_steps: 850000
actor:
init_stddev: 0.025
sentiment:
model_initialization_scale: 0.05
18 changes: 0 additions & 18 deletions safe_opax/configs/experiment/safety_gym_doggo.yaml

This file was deleted.

9 changes: 8 additions & 1 deletion safe_opax/configs/experiment/unsupervised.yaml
@@ -4,11 +4,18 @@ defaults:

training:
trainer: unsupervised
epochs: 100
epochs: 200
safe: true
action_repeat: 2
episodes_per_epoch: 5
exploration_steps: 1000000
test_task_name: go_to_goal

environment:
safe_adaptation_gym:
robot_name: doggo

agent:
exploration_strategy: opax
exploration_steps: 1000000
unsupervised: true
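
For reference, here is a small sketch, using plain OmegaConf rather than the project's Hydra entry point, of how the values set in this experiment file are read elsewhere in the commit; the trimmed config below is hard-coded for illustration only.

from omegaconf import OmegaConf

# Trimmed stand-in for the composed experiment config above.
cfg = OmegaConf.create(
    """
training:
  trainer: unsupervised
  test_task_name: go_to_goal
agent:
  exploration_strategy: opax
  unsupervised: true
"""
)
assert cfg.agent.unsupervised          # gates reward and actor-critic updates in LaMBDA.update below
print(cfg.training.test_task_name)     # read by the unsupervised trainer instead of sampling a task from the seed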
24 changes: 21 additions & 3 deletions safe_opax/la_mbda/la_mbda.py
@@ -52,8 +52,6 @@ def init(cls, batch_size: int, cell: rssm.RSSM, action_dim: int) -> "AgentState"
return self




class LaMBDA:
def __init__(
self,
@@ -148,21 +146,40 @@ def observe_transition(self, transition: Transition) -> None:
def update(self):
total_steps = self.config.agent.update_steps
for batch in self.replay_buffer.sample(total_steps):
batch = TrajectoryData(
batch.observation,
batch.next_observation,
batch.action,
batch.reward * self.config.agent.reward_scale,
batch.cost,
)
inferred_rssm_states = self.update_model(batch)
initial_states = inferred_rssm_states.reshape(
-1, inferred_rssm_states.shape[-1]
)
outs = self.actor_critic.update(self.model, initial_states, next(self.prng))
if self.should_explore():
if not self.config.agent.unsupervised:
outs = self.actor_critic.update(
self.model, initial_states, next(self.prng)
)
else:
outs = {}
exploration_outs = self.exploration.update(
self.model, initial_states, next(self.prng)
)
outs.update(exploration_outs)
else:
outs = self.actor_critic.update(
self.model, initial_states, next(self.prng)
)
for k, v in outs.items():
self.metrics_monitor[k] = v

def update_model(self, batch: TrajectoryData) -> jax.Array:
features, actions = _prepare_features(batch)
learn_reward = not self.should_explore() or (
self.should_explore() and not self.config.agent.unsupervised
)
(self.model, self.model_learner.state), (loss, rest) = variational_step(
features,
actions,
@@ -173,6 +190,7 @@ def update_model(self, batch: TrajectoryData) -> jax.Array:
self.config.agent.beta,
self.config.agent.free_nats,
self.config.agent.kl_mix,
learn_reward,
)
self.metrics_monitor["agent/model/loss"] = float(loss.mean())
self.metrics_monitor["agent/model/reconstruction"] = float(
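
The new learn_reward predicate in update_model reads as "learn the reward head unless we are in the unsupervised exploration phase". A quick, self-contained check of that equivalent simplification (plain Python, names local to this sketch):

# Truth-table check: the expression used in update_model is equivalent to
# the simpler "not (exploring and unsupervised)".
for should_explore in (False, True):
    for unsupervised in (False, True):
        original = not should_explore or (should_explore and not unsupervised)
        simplified = not (should_explore and unsupervised)
        assert original == simplified
# So the reward head is frozen only while exploring in unsupervised mode.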
30 changes: 19 additions & 11 deletions safe_opax/la_mbda/world_model.py
@@ -250,24 +250,32 @@ def variational_step(
beta: float = 1.0,
free_nats: float = 0.0,
kl_mix: float = 0.8,
with_reward: bool = True,
) -> tuple[tuple[WorldModel, OptState], tuple[jax.Array, TrainingResults]]:
def loss_fn(model):
infer_fn = lambda features, actions: model(features, actions, key)
inference_result: InferenceResult = eqx.filter_vmap(infer_fn)(features, actions)
y = features.observation, jnp.concatenate([features.reward, features.cost], -1)
y_hat = inference_result.image, inference_result.reward_cost
batch_ndim = 2
reconstruction_loss = -sum(
map(
lambda predictions, targets: dtx.Independent(
dtx.Normal(targets, 1.0), targets.ndim - batch_ndim
)
.log_prob(predictions)
.mean(),
y_hat,
y,
logprobs = (
lambda predictions, targets: dtx.Independent(
dtx.Normal(targets, 1.0), targets.ndim - batch_ndim
)
.log_prob(predictions)
.mean()
)
if not with_reward:
reward = jnp.zeros_like(features.reward)
_, pred_cost = jnp.split(inference_result.reward_cost, 2, -1)
reward_cost = jnp.concatenate([reward, pred_cost], -1)
else:
reward = features.reward
reward_cost = inference_result.reward_cost
reward_cost_logprobs = logprobs(
reward_cost,
jnp.concatenate([reward, features.cost], -1),
)
image_logprobs = logprobs(inference_result.image, features.observation)
reconstruction_loss = -reward_cost_logprobs - image_logprobs
kl_loss = kl_divergence(
inference_result.posteriors, inference_result.priors, free_nats, kl_mix
)
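
The with_reward branch above keeps the cost head training while silencing the reward term: with both the predicted and the target reward replaced by zeros, the unit-variance Gaussian log-likelihood of that channel is a constant, so no gradient reaches the reward predictor. A minimal JAX sketch of the same idea, with a toy linear reward head that is not the repo's model:

import jax
import jax.numpy as jnp

def reward_loss(w, features, target_reward, with_reward):
    pred_reward = features @ w  # stand-in for the world model's reward head
    if not with_reward:
        # Mask the reward channel as in the loss above: zero out both the
        # prediction and the target so the term becomes a constant.
        pred_reward = jnp.zeros_like(pred_reward)
        target_reward = jnp.zeros_like(target_reward)
    # Unit-variance Gaussian negative log-likelihood, up to an additive constant.
    return 0.5 * jnp.mean((pred_reward - target_reward) ** 2)

w = jnp.ones((3,))
features = jnp.arange(6.0).reshape(2, 3)
target = jnp.array([1.0, -1.0])
print(jax.grad(reward_loss)(w, features, target, True))   # non-zero gradient
print(jax.grad(reward_loss)(w, features, target, False))  # all zeros: reward head untouched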
47 changes: 26 additions & 21 deletions safe_opax/rl/trainer.py
@@ -12,12 +12,11 @@
from safe_opax.rl import acting, episodic_async_env
from safe_opax.rl.epoch_summary import EpochSummary
from safe_opax.rl.logging import StateWriter, TrainingLogger
from safe_opax.rl.types import Agent, EnvironmentFactory
from safe_opax.rl.types import EnvironmentFactory
from safe_opax.rl.utils import PRNGSequence

from safe_adaptation_gym.benchmark import TASKS
from safe_adaptation_gym.tasks import Task
from safe_opax.benchmark_suites.safe_adaptation_gym import sample_task

_LOG = logging.getLogger(__name__)

@@ -58,7 +57,7 @@ def __init__(
self,
config: DictConfig,
make_env: EnvironmentFactory,
agent: Agent | None = None,
agent: LaMBDA | LaMBDADalal | None = None,
start_epoch: int = 0,
step: int = 0,
seeds: PRNGSequence | None = None,
@@ -86,24 +85,27 @@ def __enter__(self):
if self.seeds is None:
self.seeds = PRNGSequence(self.config.training.seed)
if self.agent is None:
if self.config.agent.name == "lambda":
self.agent = LaMBDA(
self.env.observation_space,
self.env.action_space,
self.config,
)
elif self.config.agent.name == "lambda_dalal":
self.agent = LaMBDADalal(
self.env.observation_space,
self.env.action_space,
self.config,
)
else:
raise NotImplementedError(
f"Unknown agent type: {self.config.agent.name}"
)
self.agent = self.make_agent()
return self

def make_agent(self) -> LaMBDA | LaMBDADalal:
assert self.env is not None
if self.config.agent.name == "lambda":
agent = LaMBDA(
self.env.observation_space,
self.env.action_space,
self.config,
)
elif self.config.agent.name == "lambda_dalal":
agent = LaMBDADalal(
self.env.observation_space,
self.env.action_space,
self.config,
)
else:
raise NotImplementedError(f"Unknown agent type: {self.config.agent.name}")
return agent

def __exit__(self, exc_type, exc_val, exc_tb):
assert self.logger is not None and self.state_writer is not None
self.state_writer.close()
@@ -197,13 +199,13 @@ def __init__(
self,
config: DictConfig,
make_env: EnvironmentFactory,
agent: Agent | None = None,
agent: LaMBDA | LaMBDADalal | None = None,
start_epoch: int = 0,
step: int = 0,
seeds: PRNGSequence | None = None,
):
super().__init__(config, make_env, agent, start_epoch, step, seeds)
self.test_task_name = sample_task(self.config.training.seed)
self.test_task_name = self.config.training.test_task_name
self.test_tasks: list[Task] | None = None

def __enter__(self):
@@ -233,4 +235,7 @@ def _run_training_epoch(
]
assert self.env is not None
self.env.reset(options={"task": self.test_tasks})
assert self.agent is not None
new_agent = self.make_agent()
self.agent.replay_buffer = new_agent.replay_buffer
return outs
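
The epoch hook above rebuilds an agent only to borrow its empty replay buffer: the world model, policies, and optimizer state learned during unsupervised exploration stay in place because self.agent itself is never replaced, while the stale exploration transitions are dropped. A toy illustration of that swap, with hypothetical classes that are not the repo's:

from dataclasses import dataclass, field

@dataclass
class ToyAgent:
    weights: dict
    replay_buffer: list = field(default_factory=list)

trained = ToyAgent(weights={"model": 1.0}, replay_buffer=["exploration transition"] * 100)
fresh = ToyAgent(weights={"model": 0.0})      # newly initialized agent, built only for its buffer
trained.replay_buffer = fresh.replay_buffer   # learned weights untouched, old data discarded
assert trained.weights == {"model": 1.0} and len(trained.replay_buffer) == 0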
