@@ -26,6 +26,8 @@ class ActorEvaluation(NamedTuple):
     constraint: jax.Array
     safe: jax.Array
     priors: ShiftScale
+    reward_stddev: jax.Array
+    cost_stddev: jax.Array
 
 
 class Penalizer(Protocol):
@@ -59,6 +61,7 @@ def __init__(
         key: jax.Array,
         penalizer: Penalizer,
         objective_sentiment: Sentiment,
+        constraint_sentiment: Sentiment,
     ):
         actor_key, critic_key, safety_critic_key = jax.random.split(key, 3)
         self.actor = ContinuousActor(
@@ -84,18 +87,18 @@ def __init__(
         self.lambda_ = lambda_
         self.safety_discount = safety_discount
         self.safety_budget = safety_budget
-        self.update_fn = batched_update_safe_actor_critic
         self.penalizer = penalizer
         self.objective_sentiment = objective_sentiment
+        self.constraint_sentiment = constraint_sentiment
 
     def update(
         self,
         model: Model,
         initial_states: jax.Array,
         key: jax.Array,
     ) -> dict[str, float]:
-        actor_critic_fn = partial(self.update_fn, model.sample)
-        results: SafeActorCriticStepResults = actor_critic_fn(
+        results: SafeActorCriticStepResults = update_safe_actor_critic(
+            model.sample,
             self.horizon,
             initial_states,
             self.actor,
@@ -115,6 +118,7 @@ def update(
             self.penalizer,
             self.penalizer.state,
             self.objective_sentiment,
+            self.constraint_sentiment,
         )
         self.actor = results.new_actor
         self.critic = results.new_critic
@@ -196,6 +200,7 @@ def evaluate_actor(
     lambda_: float,
     safety_budget: float,
     objective_sentiment: Sentiment,
+    constraint_sentiment: Sentiment,
 ) -> ActorEvaluation:
     trajectories, priors = rollout_fn(horizon, initial_states, key, actor.act)
     next_step = lambda x: x[:, 1:]
@@ -207,9 +212,7 @@ def evaluate_actor(
         bootstrap_values, rewards, discount, lambda_
     )
     bootstrap_safety_values = nest_vmap(safety_critic, 2, eqx.filter_vmap)(next_states)
-    # TODO (yarden): make costs use their own sentiments when working
-    # on safety.
-    costs = current_step(trajectories.cost.mean(1))
+    costs = current_step(constraint_sentiment(trajectories.cost))
     safety_lambda_values = eqx.filter_vmap(compute_lambda_values)(
         bootstrap_safety_values,
         costs,
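Note: the `Sentiment` protocol itself is not part of this diff. Judging from the `.mean(1)` it replaces, `constraint_sentiment` appears to be a callable that collapses the ensemble/sample axis (axis 1) of the imagined costs. A minimal sketch of what such callables could look like, with hypothetical names and an assumed axis convention:

```python
import jax

# Hypothetical Sentiment implementations, assuming a Sentiment is a callable
# that reduces the ensemble/sample axis (axis 1), mirroring the `.mean(1)`
# this diff replaces. Illustrative only, not the repository's API.
def neutral_sentiment(x: jax.Array) -> jax.Array:
    # Risk-neutral: plain average over ensemble members.
    return x.mean(1)


def pessimistic_sentiment(x: jax.Array, scale: float = 1.0) -> jax.Array:
    # Risk-averse (natural for costs): inflate the estimate where the
    # ensemble members disagree.
    return x.mean(1) + scale * x.std(1)
```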
@@ -228,9 +231,16 @@ def evaluate_actor(
         constraint,
         jnp.greater(constraint, 0.0),
         priors,
+        rewards.std(1).mean(),
+        costs.std(1).mean(),
     )
 
 
+@eqx.filter_jit
+@apply_mixed_precision(
+    target_module_names=["critic", "safety_critic", "actor", "rollout_fn"],
+    target_input_names=["initial_states"],
+)
 def update_safe_actor_critic(
     rollout_fn: RolloutFn,
     horizon: int,
@@ -252,13 +262,15 @@ def update_safe_actor_critic(
     penalty_fn: Penalizer,
     penalty_state: Any,
     objective_sentiment: Sentiment,
+    constraint_sentiment: Sentiment,
 ) -> SafeActorCriticStepResults:
+    vmapped_rollout_fn = jax.vmap(rollout_fn, (None, 0, None, None))
     actor_grads, new_penalty_state, evaluation, metrics = penalty_fn(
         lambda actor: evaluate_actor(
             actor,
             critic,
             safety_critic,
-            rollout_fn,
+            vmapped_rollout_fn,
             horizon,
             initial_states,
             key,
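The `in_axes=(None, 0, None, None)` above batches the rollout over the initial states only; the horizon, RNG key, and policy arguments are shared across the batch. A small self-contained sketch with a toy stand-in for the rollout function (not the repository's `RolloutFn`):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a rollout function; the real RolloutFn takes the policy in
# the fourth slot, here replaced by a scalar gain for simplicity.
def toy_rollout(horizon, initial_state, key, gain):
    noise = jax.random.normal(key, (horizon,) + initial_state.shape)
    return initial_state + gain * jnp.cumsum(noise, axis=0)  # (horizon, dim)

# Map only over axis 0 of the second argument, the batch of initial states.
batched_rollout = jax.vmap(toy_rollout, in_axes=(None, 0, None, None))

initial_states = jnp.zeros((8, 3))  # 8 initial states of dimension 3
trajectories = batched_rollout(10, initial_states, jax.random.PRNGKey(0), 0.1)
print(trajectories.shape)  # (8, 10, 3): one imagined rollout per initial state
```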
@@ -267,6 +279,7 @@ def update_safe_actor_critic(
             lambda_,
             safety_budget,
             objective_sentiment,
+            constraint_sentiment,
         ),
         penalty_state,
         actor,
@@ -292,9 +305,11 @@ def update_safe_actor_critic(
     new_safety_critic, new_safety_critic_state = safety_critic_learner.grad_step(
         safety_critic, grads, safety_critic_learning_state
     )
-    metrics["agent/epistemic_uncertainty"] = normalized_epistemic_uncertainty(
+    metrics["agent/sentiment/epistemic_uncertainty"] = normalized_epistemic_uncertainty(
         evaluation.priors, 1
     ).mean()
+    metrics["agent/sentiment/reward_stddev"] = evaluation.reward_stddev
+    metrics["agent/sentiment/cost_stddev"] = evaluation.cost_stddev
     return SafeActorCriticStepResults(
         new_actor,
         new_critic,
@@ -313,58 +328,6 @@ def update_safe_actor_critic(
     )
 
 
-@eqx.filter_jit
-@apply_mixed_precision(
-    target_module_names=["critic", "safety_critic", "actor", "rollout_fn"],
-    target_input_names=["initial_states"],
-)
-def batched_update_safe_actor_critic(
-    rollout_fn: RolloutFn,
-    horizon: int,
-    initial_states: jax.Array,
-    actor: ContinuousActor,
-    critic: Critic,
-    safety_critic: Critic,
-    actor_learning_state: OptState,
-    critic_learning_state: OptState,
-    safety_critic_learning_state: OptState,
-    actor_learner: Learner,
-    critic_learner: Learner,
-    safety_critic_learner: Learner,
-    key: jax.Array,
-    discount: float,
-    safety_discount: float,
-    lambda_: float,
-    safety_budget: float,
-    penalty_fn: Penalizer,
-    penalty_state: Any,
-    objective_sentiment: Sentiment,
-) -> SafeActorCriticStepResults:
-    vmapped_rollout_fn = jax.vmap(rollout_fn, (None, 0, None, None))
-    return update_safe_actor_critic(
-        vmapped_rollout_fn,
-        horizon,
-        initial_states,
-        actor,
-        critic,
-        safety_critic,
-        actor_learning_state,
-        critic_learning_state,
-        safety_critic_learning_state,
-        actor_learner,
-        critic_learner,
-        safety_critic_learner,
-        key,
-        discount,
-        safety_discount,
-        lambda_,
-        safety_budget,
-        penalty_fn,
-        penalty_state,
-        objective_sentiment,
-    )
-
-
 def compute_discount(factor, length):
     d = jnp.cumprod(factor * jnp.ones((length - 1,)))
     d = jnp.concatenate([jnp.ones((1,)), d])
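The unchanged `compute_discount` helper visible in the trailing context builds the geometric discount vector `[1, γ, γ², ...]` of the given length. A small worked example (the return statement falls outside this diff's context and is assumed here):

```python
import jax.numpy as jnp

def compute_discount(factor, length):
    # [γ, γ², ..., γ^(length-1)] via a cumulative product of `factor`.
    d = jnp.cumprod(factor * jnp.ones((length - 1,)))
    # Prepend 1 so the first step is undiscounted.
    d = jnp.concatenate([jnp.ones((1,)), d])
    return d  # assumed return value

print(compute_discount(0.9, 4))  # approximately [1.0, 0.9, 0.81, 0.729]
```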