Skip to content

Commit 28c21be

Browse files
authored
Rssm init (#21)
* Extend initialization tricks * Add parameters for doggo * Fix bug with rendering
1 parent 0e95470 commit 28c21be

File tree

9 files changed

+71
-24
lines changed

9 files changed

+71
-24
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
name = "safe-opax"
33
version = "0.1.0"
44
description = ""
5-
authors = ["Yarden <[email protected]>"]
5+
authors = ["Yarden As"]
66
readme = "README.md"
77

88
[tool.poetry.dependencies]

safe_opax/benchmark_suites/safe_adaptation_gym/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def make_env():
4444
task_name=task,
4545
seed=cfg.training.seed,
4646
rgb_observation=task_cfg.image_observation.enabled,
47-
render_lidar_and_collision=not task_cfg.image_observation.enabled,
47+
render_lidar_and_collision=False,
4848
)
4949
env = SafeAdaptationEnvCompatibility(env)
5050
if (

safe_opax/common/pytree_utils.py

+7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import jax
2+
import jax.numpy as jnp
3+
24

35
def pytrees_unstack(pytree):
46
leaves, treedef = jax.tree_flatten(pytree)
@@ -9,3 +11,8 @@ def pytrees_unstack(pytree):
911
new_leaves[i].append(leaf[i])
1012
new_trees = [treedef.unflatten(leaf) for leaf in new_leaves]
1113
return new_trees
14+
15+
16+
def pytrees_stack(pytrees, axis=0):
17+
results = jax.tree_map(lambda *values: jnp.stack(values, axis=axis), *pytrees)
18+
return results

safe_opax/configs/agent/la_mbda.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ replay_buffer:
88
capacity: 1000
99
sentiment:
1010
ensemble_size: 5
11-
model_initialization_scale: 0.5
11+
model_initialization_scale: null
1212
constraint_pessimism: null
1313
objective_optimism: null
1414
model:

safe_opax/configs/config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ training:
4242
episodes_per_epoch: 5
4343
epochs: 200
4444
action_repeat: 1
45-
render_episodes: 1
45+
render_episodes: 0
4646
parallel_envs: 10
4747
scale_reward: 1.
4848
exploration_steps: 5000

safe_opax/configs/experiment/safe_sparse_cartpole.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ training:
1414

1515
agent:
1616
exploration_strategy: opax
17-
exploration_steps: 1000000
17+
exploration_steps: 0
1818
actor:
1919
init_stddev: 0.025
2020
sentiment:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# @package _global_
2+
defaults:
3+
- override /environment: safe_adaptation_gym
4+
5+
training:
6+
epochs: 100
7+
safe: true
8+
action_repeat: 2
9+
10+
environment:
11+
safe_adaptation_gym:
12+
robot_name: doggo
13+
14+
agent:
15+
exploration_steps: 0
16+
actor:
17+
initialization_scale: 1.

safe_opax/la_mbda/rssm.py

+27-17
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
from functools import partial
21
from typing import NamedTuple
32

43
import distrax as dtx
54
import equinox as eqx
65
import jax
6+
import jax.flatten_util
77
import jax.nn as jnn
88
import jax.numpy as jnp
99

10-
from safe_opax.rl.utils import glorot_uniform, init_linear_weights
10+
from safe_opax.rl.utils import init_linear_weights_and_biases
1111

1212

1313
class State(NamedTuple):
@@ -127,27 +127,26 @@ def __init__(
127127
embedding_size: int,
128128
action_dim: int,
129129
ensemble_size: int,
130-
initialization_scale: float,
130+
initialization_scale: float | None = None,
131131
*,
132132
key: jax.Array,
133133
):
134134
self.ensemble_size = ensemble_size
135135
prior_key, posterior_key = jax.random.split(key)
136-
make_priors = eqx.filter_vmap(
137-
lambda key: init_linear_weights(
138-
Prior(
139-
deterministic_size,
140-
stochastic_size,
141-
hidden_size,
142-
action_dim,
143-
key,
144-
),
145-
partial(glorot_uniform, scale=initialization_scale),
146-
key,
147-
)
136+
dummy_prior = Prior(
137+
deterministic_size,
138+
stochastic_size,
139+
hidden_size,
140+
action_dim,
141+
key,
142+
)
143+
initialization_scale = (
144+
initialization_scale
145+
if initialization_scale is not None
146+
else jax.flatten_util.ravel_pytree(dummy_prior)[0].std()
148147
)
149-
self.priors = make_priors(
150-
jnp.asarray(jax.random.split(prior_key, ensemble_size))
148+
self.priors = jitter_priors(
149+
dummy_prior, prior_key, initialization_scale, ensemble_size
151150
)
152151
self.posterior = Posterior(
153152
deterministic_size,
@@ -214,3 +213,14 @@ def _priors_predict(
214213
in_axes=(eqx.if_array(0), prev_state_in_axis, action_in_axis),
215214
)
216215
return priors_fn(priors, prev_state, action)
216+
217+
218+
def jitter_priors(
219+
prior: Prior, key: jax.Array, scale: float, ensemble_size: int
220+
) -> Prior:
221+
make_priors = eqx.filter_vmap(
222+
lambda key: init_linear_weights_and_biases(
223+
prior, lambda x, subkey: x + scale * jax.random.normal(subkey, x.shape), key
224+
)
225+
)
226+
return make_priors(jnp.asarray(jax.random.split(key, ensemble_size)))

safe_opax/rl/utils.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -84,22 +84,35 @@ def rl_initialize_weights_trick(model, bias_shift=0.0, weight_scale=0.01):
8484
model.layers[-1].weight * weight_scale,
8585
)
8686
model = eqx.tree_at(
87-
lambda model: model.layers[-1].bias, model, model.layers[-1].bias * 0. + bias_shift
87+
lambda model: model.layers[-1].bias,
88+
model,
89+
model.layers[-1].bias * 0.0 + bias_shift,
8890
)
8991
return model
9092

9193

92-
def init_linear_weights(model, init_fn, key):
94+
def init_linear_weights_and_biases(model, init_fn, key):
9395
is_linear = lambda x: isinstance(x, eqx.nn.Linear)
9496
get_weights = lambda m: [
9597
x.weight
9698
for x in jax.tree_util.tree_leaves(m, is_leaf=is_linear)
9799
if is_linear(x)
98100
]
101+
get_biases = lambda m: [
102+
x.bias
103+
for x in jax.tree_util.tree_leaves(m, is_leaf=is_linear)
104+
if is_linear(x) and x.bias is not None
105+
]
99106
weights = get_weights(model)
107+
biases = get_biases(model)
100108
new_weights = [
101109
init_fn(weight, subkey)
102110
for weight, subkey in zip(weights, jax.random.split(key, len(weights)))
103111
]
112+
new_biases = [
113+
init_fn(bias, subkey)
114+
for bias, subkey in zip(biases, jax.random.split(key, len(biases)))
115+
]
104116
new_model = eqx.tree_at(get_weights, model, new_weights)
117+
new_model = eqx.tree_at(get_biases, new_model, new_biases)
105118
return new_model

0 commit comments

Comments (0)