Commit 67aa99d

Sgym unsupervised (#33)
* Update SAG
* No need to really reset tasks for now
* Improve config file
* Add Bhavi scale
* Back to circle, use Bhavi scale
* Update hparams
* Set default epistemic scale to 1.
* Fix tests
1 parent 6ace84a commit 67aa99d

9 files changed, +29 −15 lines changed

poetry.lock

+1 −1
Some generated files are not rendered by default.

safe_opax/configs/agent/la_mbda.yaml

+1
@@ -54,5 +54,6 @@ exploration_strategy: uniform
 exploration_steps: 5000
 learn_model_steps: null
 exploration_reward_scale: 10.0
+exploration_epistemic_scale: 1.
 unsupervised: false
 reward_scale: 1.

safe_opax/configs/experiment/debug_unsupervised.yaml

+1 −1

@@ -31,4 +31,4 @@ agent:
   batch_size: 4
   sequence_length: 16
   exploration_steps: 750
-  unsupervised: true
+  unsupervised: true

safe_opax/configs/experiment/unsupervised_safety_gym.yaml

+6 −4

@@ -4,19 +4,21 @@ defaults:

 training:
   trainer: unsupervised
-  epochs: 200
+  epochs: 100
   safe: true
   action_repeat: 2
-  episodes_per_epoch: 5
   exploration_steps: 1000000
   train_task_name: unsupervised
   test_task_name: go_to_goal

 environment:
   safe_adaptation_gym:
-    robot_name: doggo
+    robot_name: point

 agent:
   exploration_strategy: opax
   exploration_steps: 1000000
-  unsupervised: true
+  unsupervised: true
+  learn_model_steps: 1000000
+  exploration_epistemic_scale: 15.0
+  exploration_reward_scale: 25.0
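
For intuition on how the two new agent scales in this config interact (see the safe_opax/opax.py change further down): exploration_epistemic_scale weights the epistemic-to-aleatoric variance ratio inside the log, while exploration_reward_scale multiplies the resulting intrinsic reward. Below is a minimal numeric sketch using the values from this file; the helper name, the eps value, and the example variances are illustrative, not from the repo.

import math

def intrinsic_reward(epistemic_var, aleatoric_var,
                     reward_scale=25.0, epistemic_scale=15.0, eps=1e-8):
    # Same shape as modify_reward / normalized_epistemic_uncertainty in safe_opax/opax.py:
    # reward_scale * 0.5 * log(1 + epistemic_scale * epistemic / (aleatoric + eps))
    return reward_scale * 0.5 * math.log(
        1.0 + epistemic_scale * epistemic_var / (aleatoric_var + eps)
    )

# With equal epistemic and aleatoric variance this gives roughly 34.7, versus roughly 8.7
# with the default epistemic_scale of 1.0, so the bonus rises much faster with model
# disagreement under these settings.
print(intrinsic_reward(0.1, 0.1))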

safe_opax/la_mbda/exploration.py

+2 −2

@@ -50,14 +50,15 @@ def __init__(
             ),
         )
         self.reward_scale = config.agent.exploration_reward_scale
+        self.epistemic_scale = config.agent.exploration_epistemic_scale

     def update(
         self,
         model: Model,
         initial_states: jax.Array,
         key: jax.Array,
     ) -> dict[str, float]:
-        model = OpaxBridge(model, self.reward_scale)
+        model = OpaxBridge(model, self.reward_scale, self.epistemic_scale)
         outs = self.actor_critic.update(model, initial_states, key)
         outs = {f"{_append_opax(k)}": v for k, v in outs.items()}
         return outs
@@ -77,6 +78,5 @@ def __init__(self, action_dim: int):
         self.action_dim = action_dim
         self.policy = lambda _, key: jax.random.uniform(key, (self.action_dim,))

-
     def get_policy(self) -> Policy:
         return self.policy

safe_opax/la_mbda/opax_bridge.py

+4 −1

@@ -9,6 +9,7 @@
 class OpaxBridge(eqx.Module):
     model: WorldModel
     reward_scale: float = eqx.field(static=True)
+    reward_epistemic_scale: float = eqx.field(static=True)

     def sample(
         self,
@@ -21,4 +22,6 @@ def sample(
             horizon, initial_state, key, policy
         )
         trajectory, distributions = samples
-        return opax.modify_reward(trajectory, distributions, self.reward_scale)
+        return opax.modify_reward(
+            trajectory, distributions, self.reward_scale, self.reward_epistemic_scale
+        )
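
A side note on the new field: marking reward_epistemic_scale with eqx.field(static=True), like reward_scale, keeps it out of the module's JAX pytree leaves, so it is treated as compile-time metadata rather than a traced array, and Equinox's generated __init__ takes the fields positionally in declaration order, which is why the OpaxBridge(model, self.reward_scale, self.epistemic_scale) call in exploration.py above works. A toy sketch of that Equinox behaviour (not the repo's OpaxBridge):

import equinox as eqx
import jax
import jax.numpy as jnp

class Scaled(eqx.Module):
    weight: jax.Array                     # dynamic: shows up as a pytree leaf
    gain: float = eqx.field(static=True)  # static: stored in the treedef instead

    def __call__(self, x):
        return self.gain * self.weight * x

module = Scaled(jnp.ones(3), 2.0)         # positional, in field declaration order
print(jax.tree_util.tree_leaves(module))  # only `weight` is a leaf; `gain` is static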

safe_opax/opax.py

+11 −3

@@ -10,9 +10,12 @@ def modify_reward(
     trajectory: Prediction,
     distributions: ShiftScale,
     scale: float = 1.0,
+    epistemic_scale: float = 1.0,
     stop_grad: bool = True,
 ) -> tuple[Prediction, ShiftScale]:
-    new_rewards = normalized_epistemic_uncertainty(distributions) * scale
+    new_rewards = (
+        normalized_epistemic_uncertainty(distributions, scale=epistemic_scale) * scale
+    )
     if stop_grad:
         new_rewards = jax.lax.stop_gradient(new_rewards)
     return Prediction(
@@ -23,10 +26,15 @@ def modify_reward(


 def normalized_epistemic_uncertainty(
-    distributions: ShiftScale, axis: int = 0
+    distributions: ShiftScale, axis: int = 0, scale: float = 1.0
 ) -> jnp.ndarray:
     epistemic_uncertainty = distributions.shift.var(axis)
     aleatoric_uncertainty = (distributions.scale**2).mean(axis)
     return 0.5 * jnp.log(
-        1.0 + (epistemic_uncertainty.mean(-1) / (aleatoric_uncertainty.mean(-1) + _EPS))
+        1.0
+        + (
+            scale
+            * epistemic_uncertainty.mean(-1)
+            / (aleatoric_uncertainty.mean(-1) + _EPS)
+        )
     )
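
To see the change end to end, here is a self-contained sketch of the updated uncertainty computation on toy data. The ShiftScale layout of (ensemble, batch, features) and the _EPS value are assumptions inferred from the var(axis=0) and .mean(-1) reductions; the formula itself is the one in the diff above.

from typing import NamedTuple

import jax
import jax.numpy as jnp

_EPS = 1e-8  # assumed; the real constant is defined in safe_opax/opax.py

class ShiftScale(NamedTuple):
    shift: jnp.ndarray  # predicted means,   assumed shape (ensemble, batch, features)
    scale: jnp.ndarray  # predicted stddevs, assumed shape (ensemble, batch, features)

def normalized_epistemic_uncertainty(distributions, axis=0, scale=1.0):
    # Disagreement between the member means acts as epistemic variance.
    epistemic = distributions.shift.var(axis)
    # The average predicted variance acts as aleatoric variance.
    aleatoric = (distributions.scale**2).mean(axis)
    # `scale` (the new epistemic_scale) weights the ratio inside the log.
    return 0.5 * jnp.log(1.0 + scale * epistemic.mean(-1) / (aleatoric.mean(-1) + _EPS))

dists = ShiftScale(
    shift=jax.random.normal(jax.random.PRNGKey(0), (5, 3, 8)),
    scale=jnp.ones((5, 3, 8)),
)
print(normalized_epistemic_uncertainty(dists, scale=15.0))  # one value per batch element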

safe_opax/rl/trainer.py

+2 −2

@@ -213,6 +213,8 @@ def __init__(
         super().__init__(config, make_env, agent, start_epoch, step, seeds)
         self.test_task_name = self.config.training.test_task_name
         self.train_task_name = self.config.training.train_task_name
+        # After a few iterations, we realized `test_tasks` are not useful, as we just use multiple rewards.
+        # just ignore this.
         self.test_tasks: list[Task] | None = None

     def __enter__(self):
@@ -233,7 +235,6 @@ def __enter__(self):
             get_task(self.test_task_name)
             for _ in range(self.config.training.parallel_envs)
         ]
-        self.env.reset(options={"task": self.test_tasks})
         return self

     def _run_training_epoch(
@@ -250,7 +251,6 @@ def _run_training_epoch(
             for _ in range(self.config.training.parallel_envs)
         ]
         assert self.env is not None
-        self.env.reset(options={"task": self.test_tasks})
         assert self.agent is not None
         return outs

tests/test_unsupervised_trainer.py

+1 −1

@@ -43,4 +43,4 @@ def test_epoch(trainer):
     with trainer as trainer:
         with patch.object(trainer.env, "reset", wraps=trainer.env.reset) as mock:
             trainer.train(1)
-            assert mock.call_count == 4
+            assert mock.call_count == 3
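
On the test change: patch.object(..., wraps=trainer.env.reset) counts calls while still delegating to the real reset, so with the explicit env.reset(options={"task": ...}) calls dropped from the trainer above, the expected count goes from 4 to 3. A standalone illustration of the mocking pattern (toy env, not the real trainer fixture):

from unittest.mock import patch

class FakeEnv:
    def reset(self, options=None):
        return "obs"

env = FakeEnv()
with patch.object(env, "reset", wraps=env.reset) as mock:
    env.reset()
    env.reset(options={"task": None})

# wraps= keeps the real behaviour while recording every call.
assert mock.call_count == 2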
