|
1 |
| -from functools import partial |
2 | 1 | from typing import NamedTuple
|
3 | 2 |
|
4 | 3 | import distrax as dtx
|
5 | 4 | import equinox as eqx
|
6 | 5 | import jax
|
| 6 | +import jax.flatten_util |
7 | 7 | import jax.nn as jnn
|
8 | 8 | import jax.numpy as jnp
|
9 | 9 |
|
10 |
| -from safe_opax.rl.utils import glorot_uniform, init_linear_weights |
| 10 | +from safe_opax.rl.utils import init_linear_weights_and_biases |
11 | 11 |
|
12 | 12 |
|
13 | 13 | class State(NamedTuple):
|
@@ -127,27 +127,26 @@ def __init__(
|
127 | 127 | embedding_size: int,
|
128 | 128 | action_dim: int,
|
129 | 129 | ensemble_size: int,
|
130 |
| - initialization_scale: float, |
| 130 | + initialization_scale: float | None = None, |
131 | 131 | *,
|
132 | 132 | key: jax.Array,
|
133 | 133 | ):
|
134 | 134 | self.ensemble_size = ensemble_size
|
135 | 135 | prior_key, posterior_key = jax.random.split(key)
|
136 |
| - make_priors = eqx.filter_vmap( |
137 |
| - lambda key: init_linear_weights( |
138 |
| - Prior( |
139 |
| - deterministic_size, |
140 |
| - stochastic_size, |
141 |
| - hidden_size, |
142 |
| - action_dim, |
143 |
| - key, |
144 |
| - ), |
145 |
| - partial(glorot_uniform, scale=initialization_scale), |
146 |
| - key, |
147 |
| - ) |
| 136 | + dummy_prior = Prior( |
| 137 | + deterministic_size, |
| 138 | + stochastic_size, |
| 139 | + hidden_size, |
| 140 | + action_dim, |
| 141 | + key, |
| 142 | + ) |
| 143 | + initialization_scale = ( |
| 144 | + initialization_scale |
| 145 | + if initialization_scale is not None |
| 146 | + else jax.flatten_util.ravel_pytree(dummy_prior)[0].std() |
148 | 147 | )
|
149 |
| - self.priors = make_priors( |
150 |
| - jnp.asarray(jax.random.split(prior_key, ensemble_size)) |
| 148 | + self.priors = jitter_priors( |
| 149 | + dummy_prior, prior_key, initialization_scale, ensemble_size |
151 | 150 | )
|
152 | 151 | self.posterior = Posterior(
|
153 | 152 | deterministic_size,
|
@@ -214,3 +213,14 @@ def _priors_predict(
|
214 | 213 | in_axes=(eqx.if_array(0), prev_state_in_axis, action_in_axis),
|
215 | 214 | )
|
216 | 215 | return priors_fn(priors, prev_state, action)
|
| 216 | + |
| 217 | + |
def jitter_priors(
    prior: Prior, key: jax.Array, scale: float, ensemble_size: int
) -> Prior:
    """Build an ensemble of priors by perturbing a template prior's parameters.

    Each ensemble member is produced by adding zero-mean Gaussian noise
    (standard deviation ``scale``) to the linear weights and biases of
    ``prior``, using an independent PRNG key per member.

    Args:
        prior: Template ``Prior`` whose parameters are jittered.
        key: PRNG key; split into ``ensemble_size`` independent subkeys.
        scale: Standard deviation of the additive Gaussian noise.
        ensemble_size: Number of ensemble members to create.

    Returns:
        A ``Prior`` pytree whose leaves are stacked along a leading
        ensemble axis (one slice per member), as produced by
        ``eqx.filter_vmap``.
    """

    def _add_noise(param: jax.Array, noise_key: jax.Array) -> jax.Array:
        # Zero-mean Gaussian perturbation, matched to the parameter's shape.
        return param + scale * jax.random.normal(noise_key, param.shape)

    @eqx.filter_vmap
    def _make_member(member_key: jax.Array) -> Prior:
        return init_linear_weights_and_biases(prior, _add_noise, member_key)

    member_keys = jnp.asarray(jax.random.split(key, ensemble_size))
    return _make_member(member_keys)
0 commit comments