
Commit 976b978

Updates from deployable-rl (#19)
1 parent 5f2e34e commit 976b978

17 files changed: +216 -74 lines

safe_opax/configs/agent/la_mbda.yaml  (+3 -1)

@@ -8,8 +8,10 @@ replay_buffer:
   capacity: 1000
 sentiment:
   ensemble_size: 5
-  model_initialization_scale: 1.
+  model_initialization_scale: 0.5
   critics_initialization_scale: 0.167
+  constraint_pessimism: null
+  objective_optimism: null
 model:
   hidden_size: 200
   stochastic_size: 60
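
(Note: the two new keys default to null; the make_sentiment helper added in safe_opax/la_mbda/la_mbda.py below maps null or 0.0 to the plain Bayes ensemble mean and a positive value to an upper confidence bound, so experiments opt in by overriding them, as safe_opax/configs/experiment/safe_sparse_cartpole.yaml does.)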

safe_opax/configs/agent/penalizer/lbsgd.yaml  (+2 -1)

@@ -3,7 +3,8 @@ m_0: 1.2e4
 m_1: 1.2e4
 eta: 0.1
 eta_rate: 8e-6
-
+backup_lr: 1e-2
+

safe_opax/configs/experiment/safe_sparse_cartpole.yaml  (+4 -1)

@@ -16,4 +16,7 @@ agent:
   exploration_strategy: opax
   exploration_steps: 1000000
   actor:
-    init_stddev: 0.001
+    init_stddev: 0.025
+  sentiment:
+    objective_optimism: 1.0
+    constraint_pessimism: 1.0

(new experiment config; filename not captured in this view)  (+11)

@@ -0,0 +1,11 @@
+# @package _global_
+defaults:
+  - override /environment: safe_adaptation_gym
+
+training:
+  epochs: 100
+  safe: true
+  action_repeat: 2
+
+agent:
+  exploration_steps: 0

safe_opax/la_mbda/exploration.py  (+1 -1)

@@ -44,7 +44,7 @@ def __init__(
             config.agent.model.stochastic_size + config.agent.model.deterministic_size,
             action_dim,
             key,
-            sentiment=identity,
+            objective_sentiment=identity,
         )
         self.reward_scale = config.agent.exploration_reward_scale

safe_opax/la_mbda/la_mbda.py  (+16 -1)

@@ -12,10 +12,11 @@
 from safe_opax.la_mbda.exploration import make_exploration
 from safe_opax.la_mbda.make_actor_critic import make_actor_critic
 from safe_opax.la_mbda.replay_buffer import ReplayBuffer
+from safe_opax.la_mbda.sentiment import Sentiment, UpperConfidenceBound, bayes
 from safe_opax.la_mbda.world_model import WorldModel, evaluate_model, variational_step
 from safe_opax.rl.epoch_summary import EpochSummary
 from safe_opax.rl.metrics import MetricsMonitor
-from safe_opax.rl.trajectory import TrajectoryData
+from safe_opax.rl.trajectory import TrajectoryData, Transition
 from safe_opax.rl.types import FloatArray, Report
 from safe_opax.rl.utils import Count, PRNGSequence, Until, add_to_buffer

@@ -51,6 +52,15 @@ def init(cls, batch_size: int, cell: rssm.RSSM, action_dim: int) -> "AgentState"
         return self


+def make_sentiment(alpha) -> Sentiment:
+    if alpha is None or alpha == 0.0:
+        return bayes
+    elif alpha > 0.0:
+        return UpperConfidenceBound(alpha)
+    else:
+        raise ValueError(f"Invalid alpha: {alpha}")
+
+
 class LaMBDA:
     def __init__(
         self,
@@ -86,6 +96,8 @@ def __init__(
             config.agent.model.stochastic_size + config.agent.model.deterministic_size,
             action_dim,
             next(self.prng),
+            make_sentiment(self.config.agent.sentiment.objective_optimism),
+            make_sentiment(self.config.agent.sentiment.constraint_pessimism),
         )
         self.exploration = make_exploration(
             config,
@@ -137,6 +149,9 @@ def observe(self, trajectory: TrajectoryData) -> None:
         )
         self.state = jax.tree_map(lambda x: jnp.zeros_like(x), self.state)

+    def observe_transition(self, transition: Transition) -> None:
+        pass
+
     def update(self):
         total_steps = self.config.agent.update_steps
         for batch in self.replay_buffer.sample(total_steps):
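
For reference, a minimal usage sketch of the new make_sentiment helper; this is a hypothetical snippet (it assumes the safe_opax package is importable) and only exercises the branches shown in the hunk above:

# Hypothetical usage sketch of make_sentiment (mirrors the hunk above).
from safe_opax.la_mbda.la_mbda import make_sentiment
from safe_opax.la_mbda.sentiment import UpperConfidenceBound, bayes

assert make_sentiment(None) is bayes   # null in the yaml -> plain Bayes ensemble mean
assert make_sentiment(0.0) is bayes    # 0.0 is treated like null
ucb = make_sentiment(1.0)              # positive alpha -> upper confidence bound
assert isinstance(ucb, UpperConfidenceBound) and ucb.alpha == 1.0
# make_sentiment(-1.0) raises ValueError("Invalid alpha: -1.0")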

safe_opax/la_mbda/lbsgd.py  (+1 -1)

@@ -27,7 +27,7 @@ def compute_lr(alpha_1, g, grad_f_1, m_0, m_1, eta):
     m_2 = (
         m_0
         + 10.0 * eta * (m_1 / (alpha_1 + _EPS))
-        + 8.0 * eta * (theta_1 / alpha_1 + _EPS) ** 2
+        + 8.0 * eta * (theta_1 / (alpha_1 + _EPS)) ** 2
     )
     rhs = 1.0 / m_2
     return jnp.minimum(lhs, rhs), (lhs, rhs)
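
(Note: this is an operator-precedence fix; _EPS now regularizes the denominator, so the last term of m_2 is 8.0 * eta * (theta_1 / (alpha_1 + _EPS)) ** 2 rather than _EPS being added to the quotient before squaring.)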

safe_opax/la_mbda/make_actor_critic.py  (+12 -2)

@@ -10,7 +10,15 @@
 _LOG = logging.getLogger(__name__)


-def make_actor_critic(cfg, safe, state_dim, action_dim, key, sentiment=bayes):
+def make_actor_critic(
+    cfg,
+    safe,
+    state_dim,
+    action_dim,
+    key,
+    objective_sentiment=bayes,
+    constraint_sentiment=bayes,
+):
     # Account for the the discount factor in the budget.
     episode_safety_budget = (
         (
@@ -29,6 +37,7 @@ def make_actor_critic(cfg, safe, state_dim, action_dim, key, sentiment=bayes):
             cfg.agent.penalizer.eta,
             cfg.agent.penalizer.eta_rate,
             cfg.agent.actor_optimizer.lr,
+            cfg.agent.penalizer.backup_lr,
         )
     elif cfg.agent.penalizer.name == "lagrangian":
         penalizer = AugmentedLagrangianPenalizer(
@@ -56,5 +65,6 @@ def make_actor_critic(cfg, safe, state_dim, action_dim, key, sentiment=bayes):
         safety_budget=episode_safety_budget,
         penalizer=penalizer,
         key=key,
-        objective_sentiment=sentiment,
+        objective_sentiment=objective_sentiment,
+        constraint_sentiment=constraint_sentiment,
     )

safe_opax/la_mbda/safe_actor_critic.py  (+23 -60)

@@ -26,6 +26,8 @@ class ActorEvaluation(NamedTuple):
     constraint: jax.Array
     safe: jax.Array
     priors: ShiftScale
+    reward_stddev: jax.Array
+    cost_stddev: jax.Array


 class Penalizer(Protocol):
@@ -59,6 +61,7 @@ def __init__(
         key: jax.Array,
         penalizer: Penalizer,
         objective_sentiment: Sentiment,
+        constraint_sentiment: Sentiment,
     ):
         actor_key, critic_key, safety_critic_key = jax.random.split(key, 3)
         self.actor = ContinuousActor(
@@ -84,18 +87,18 @@
         self.lambda_ = lambda_
         self.safety_discount = safety_discount
         self.safety_budget = safety_budget
-        self.update_fn = batched_update_safe_actor_critic
         self.penalizer = penalizer
         self.objective_sentiment = objective_sentiment
+        self.constraint_sentiment = constraint_sentiment

     def update(
         self,
         model: Model,
         initial_states: jax.Array,
         key: jax.Array,
     ) -> dict[str, float]:
-        actor_critic_fn = partial(self.update_fn, model.sample)
-        results: SafeActorCriticStepResults = actor_critic_fn(
+        results: SafeActorCriticStepResults = update_safe_actor_critic(
+            model.sample,
             self.horizon,
             initial_states,
             self.actor,
@@ -115,6 +118,7 @@ def update(
             self.penalizer,
             self.penalizer.state,
             self.objective_sentiment,
+            self.constraint_sentiment,
         )
         self.actor = results.new_actor
         self.critic = results.new_critic
@@ -196,6 +200,7 @@ def evaluate_actor(
     lambda_: float,
     safety_budget: float,
     objective_sentiment: Sentiment,
+    constraint_sentiment: Sentiment,
 ) -> ActorEvaluation:
     trajectories, priors = rollout_fn(horizon, initial_states, key, actor.act)
     next_step = lambda x: x[:, 1:]
@@ -207,9 +212,7 @@
         bootstrap_values, rewards, discount, lambda_
     )
     bootstrap_safety_values = nest_vmap(safety_critic, 2, eqx.filter_vmap)(next_states)
-    # TODO (yarden): make costs use their own sentiments when working
-    # on safety.
-    costs = current_step(trajectories.cost.mean(1))
+    costs = current_step(constraint_sentiment(trajectories.cost))
     safety_lambda_values = eqx.filter_vmap(compute_lambda_values)(
         bootstrap_safety_values,
         costs,
@@ -228,9 +231,16 @@
         constraint,
         jnp.greater(constraint, 0.0),
         priors,
+        rewards.std(1).mean(),
+        costs.std(1).mean(),
     )


+@eqx.filter_jit
+@apply_mixed_precision(
+    target_module_names=["critic", "safety_critic", "actor", "rollout_fn"],
+    target_input_names=["initial_states"],
+)
 def update_safe_actor_critic(
     rollout_fn: RolloutFn,
     horizon: int,
@@ -252,13 +262,15 @@
     penalty_fn: Penalizer,
     penalty_state: Any,
     objective_sentiment: Sentiment,
+    constraint_sentiment: Sentiment,
 ) -> SafeActorCriticStepResults:
+    vmapped_rollout_fn = jax.vmap(rollout_fn, (None, 0, None, None))
     actor_grads, new_penalty_state, evaluation, metrics = penalty_fn(
         lambda actor: evaluate_actor(
             actor,
             critic,
             safety_critic,
-            rollout_fn,
+            vmapped_rollout_fn,
             horizon,
             initial_states,
             key,
@@ -267,6 +279,7 @@
             lambda_,
             safety_budget,
             objective_sentiment,
+            constraint_sentiment,
         ),
         penalty_state,
         actor,
@@ -292,9 +305,11 @@
     new_safety_critic, new_safety_critic_state = safety_critic_learner.grad_step(
         safety_critic, grads, safety_critic_learning_state
     )
-    metrics["agent/epistemic_uncertainty"] = normalized_epistemic_uncertainty(
+    metrics["agent/sentiment/epistemic_uncertainty"] = normalized_epistemic_uncertainty(
        evaluation.priors, 1
     ).mean()
+    metrics["agent/sentiment/reward_stddev"] = evaluation.reward_stddev
+    metrics["agent/sentiment/cost_stddev"] = evaluation.cost_stddev
     return SafeActorCriticStepResults(
         new_actor,
         new_critic,
@@ -313,58 +328,6 @@
     )


-@eqx.filter_jit
-@apply_mixed_precision(
-    target_module_names=["critic", "safety_critic", "actor", "rollout_fn"],
-    target_input_names=["initial_states"],
-)
-def batched_update_safe_actor_critic(
-    rollout_fn: RolloutFn,
-    horizon: int,
-    initial_states: jax.Array,
-    actor: ContinuousActor,
-    critic: Critic,
-    safety_critic: Critic,
-    actor_learning_state: OptState,
-    critic_learning_state: OptState,
-    safety_critic_learning_state: OptState,
-    actor_learner: Learner,
-    critic_learner: Learner,
-    safety_critic_learner: Learner,
-    key: jax.Array,
-    discount: float,
-    safety_discount: float,
-    lambda_: float,
-    safety_budget: float,
-    penalty_fn: Penalizer,
-    penalty_state: Any,
-    objective_sentiment: Sentiment,
-) -> SafeActorCriticStepResults:
-    vmapped_rollout_fn = jax.vmap(rollout_fn, (None, 0, None, None))
-    return update_safe_actor_critic(
-        vmapped_rollout_fn,
-        horizon,
-        initial_states,
-        actor,
-        critic,
-        safety_critic,
-        actor_learning_state,
-        critic_learning_state,
-        safety_critic_learning_state,
-        actor_learner,
-        critic_learner,
-        safety_critic_learner,
-        key,
-        discount,
-        safety_discount,
-        lambda_,
-        safety_budget,
-        penalty_fn,
-        penalty_state,
-        objective_sentiment,
-    )
-
-
 def compute_discount(factor, length):
     d = jnp.cumprod(factor * jnp.ones((length - 1,)))
     d = jnp.concatenate([jnp.ones((1,)), d])
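
The jax.vmap call that replaced the deleted batched_update_safe_actor_critic wrapper maps only over initial_states and broadcasts horizon, key, and policy. A self-contained toy sketch of the same in_axes layout (the rollout function and shapes below are illustrative, not the repo's RolloutFn):

import jax
import jax.numpy as jnp

def toy_rollout(horizon, initial_state, key, policy):
    # Hypothetical stand-in for a model rollout: apply the policy once and
    # repeat the result `horizon` times.
    del key  # unused in this sketch
    return jnp.stack([policy(initial_state)] * horizon)

# Same in_axes pattern as above: batch over the second argument only.
batched_rollout = jax.vmap(toy_rollout, (None, 0, None, None))
initial_states = jnp.zeros((8, 4))  # 8 initial states, state_dim = 4
out = batched_rollout(3, initial_states, jax.random.PRNGKey(0), lambda s: s + 1.0)
print(out.shape)  # (8, 3, 4)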

safe_opax/la_mbda/sentiment.py  (+17)

@@ -16,6 +16,23 @@ def bayes(values: jax.Array) -> jax.Array:
     return values.mean(1)


+class UpperConfidenceBound(Sentiment):
+    def __init__(self, alpha: float = 1.0):
+        self.alpha = alpha
+
+    def __call__(self, values: jax.Array) -> jax.Array:
+        return upper_confidence_bound(values, self.alpha)
+
+
+def upper_confidence_bound(
+    values: jax.Array, alpha: float, stop_gradient: bool = True
+) -> jax.Array:
+    stddev = jnp.std(values, axis=1)
+    if stop_gradient:
+        stddev = jax.lax.stop_gradient(stddev)
+    return jnp.mean(values, axis=1) + alpha * stddev
+
+
 def _emprirical_estimate(
     values: jax.Array, reduce_fn: Callable[[jax.Array], jax.Array]
 ) -> jax.Array:
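
A small numeric check of the new upper_confidence_bound; toy values only, assuming the package is importable:

import jax.numpy as jnp
from safe_opax.la_mbda.sentiment import UpperConfidenceBound, upper_confidence_bound

# Shape (batch, ensemble): two rollouts, each scored by a 3-member ensemble.
values = jnp.array([[1.0, 2.0, 3.0],
                    [0.0, 0.0, 0.0]])
print(upper_confidence_bound(values, alpha=1.0))
# mean over axis 1 plus alpha * stddev over axis 1: [2.0 + 0.816..., 0.0]
print(UpperConfidenceBound(alpha=1.0)(values))  # same result via the class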

safe_opax/lambda_dalal/__init__.py  (whitespace-only changes)

safe_opax/lambda_dalal/cost_model.py  (whitespace-only changes)

0 commit comments