import jax.numpy as jnp
from jaxtyping import PyTree

+from safe_opax.common.mixed_precision import apply_dtype
from safe_opax.common.pytree_utils import pytrees_unstack
from safe_opax.la_mbda.actor_critic import ContinuousActor
from safe_opax.la_mbda.safe_actor_critic import ActorEvaluation

+_EPS = 1e-8
+

class LBSGDState(NamedTuple):
    eta: jax.Array


-def compute_lr(constraint, loss_grads, constraint_grads, m_0, m_1, eta):
-    constraint_grads, _ = jax.flatten_util.ravel_pytree(constraint_grads)  # type: ignore
-    loss_grads, _ = jax.flatten_util.ravel_pytree(loss_grads)  # type: ignore
-    projection = constraint_grads.dot(loss_grads)
-    lhs = (
-        constraint
-        / (
-            2.0 * jnp.abs(projection) / jnp.linalg.norm(loss_grads)
-            + jnp.sqrt(constraint * m_1 + 1e-8)
-        )
-        / (jnp.linalg.norm(loss_grads) + 1e-8)
-    )
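+# Step size for the log-barrier update: "lhs" is a feasibility bound that shrinks with the
+# constraint slack alpha_1, "rhs" = 1/m_2 is an inverse-smoothness bound; both are returned
+# for logging alongside their minimum.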
+def compute_lr(alpha_1, g, grad_f_1, m_0, m_1, eta):
+    grad_f_1, _ = jax.flatten_util.ravel_pytree(grad_f_1)
+    g, _ = jax.flatten_util.ravel_pytree(g)
+    theta_1 = grad_f_1.dot(g / (jnp.linalg.norm(g) + _EPS))
+    lhs = alpha_1 / (2.0 * jnp.abs(theta_1) + jnp.sqrt(alpha_1 * m_1 + _EPS))
    m_2 = (
        m_0
-        + 10.0 * eta * (m_1 / (constraint + 1e-8))
-        + 8.0
-        * eta
-        * jnp.linalg.norm(projection) ** 2
-        / ((jnp.linalg.norm(loss_grads) * constraint) ** 2)
+        + 10.0 * eta * (m_1 / (alpha_1 + _EPS))
+        + 8.0 * eta * (theta_1 / (alpha_1 + _EPS)) ** 2
    )
    rhs = 1.0 / m_2
-    return jnp.minimum(lhs, rhs)
+    return jnp.minimum(lhs, rhs), (lhs, rhs)


def lbsgd_update(
-    state: LBSGDState, updates: PyTree, eta_rate: float, m_0: float, m_1: float
-) -> tuple[PyTree, LBSGDState]:
+    state: LBSGDState,
+    updates: PyTree,
+    eta_rate: float,
+    m_0: float,
+    m_1: float,
+    base_lr: float,
+    backup_lr: float,
+) -> tuple[PyTree, LBSGDState, tuple[float, ...]]:
    def happy_case():
-        lr = compute_lr(constraint, loss_grads, constraints_grads, m_0, m_1, eta_t)
+        lr, (lhs, rhs) = compute_lr(alpha_1, g, grad_f_1, m_0, m_1, eta_t)
        new_eta = eta_t / eta_rate
-        updates = jax.tree_map(lambda x: x * lr, loss_grads)
-        return updates, LBSGDState(new_eta)
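+        # Scale the loss gradient by lr; dividing by base_lr presumably cancels the base
+        # learning rate applied by the optimizer downstream, so the effective step is lr.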
+        updates = jax.tree_map(lambda x: x * lr / base_lr, g)
+        return updates, LBSGDState(new_eta), (lr, lhs, rhs)

    def fallback():
        # Taking the negative gradient of the constraints to minimize the costs
-        updates = jax.tree_map(lambda x: x * -1.0, constraints_grads)
-        return updates, LBSGDState(eta_t)
+        updates = jax.tree_map(lambda x: x * backup_lr, grad_f_1)
+        return updates, LBSGDState(eta_t), (0.0, 0.0, 0.0)

-    loss_grads, constraints_grads, constraint = updates
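+    # updates packs (gradient of the barrier-augmented loss g, gradient of the negated
+    # constraint grad_f_1, constraint value alpha_1).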
+    g, grad_f_1, alpha_1 = updates
    eta_t = state.eta
    return jax.lax.cond(
-        jnp.greater(constraint, 0.0),
+        jnp.greater(alpha_1, _EPS),
        happy_case,
        fallback,
    )
@@ -66,17 +65,27 @@ def fallback():
def jacrev(f, has_aux=False):
    def jacfn(x):
        y, vjp_fn, aux = eqx.filter_vjp(f, x, has_aux=has_aux)  # type: ignore
-        (J,) = eqx.filter_vmap(vjp_fn, in_axes=0)(jnp.eye(len(y)))
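+        # in_axes=eqx.if_array(0) batches only the array leaves of the cotangents over axis 0.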
+        (J,) = eqx.filter_vmap(vjp_fn, in_axes=eqx.if_array(0))(jnp.eye(len(y)))
        return J, aux

    return jacfn


class LBSGDPenalizer:
-    def __init__(self, m_0, m_1, eta, eta_rate) -> None:
+    def __init__(
+        self,
+        m_0: float,
+        m_1: float,
+        eta: float,
+        eta_rate: float,
+        base_lr: float,
+        backup_lr: float = 1e-2,
+    ) -> None:
        self.m_0 = m_0
        self.m_1 = m_1
        self.eta_rate = eta_rate + 1.0
+        self.base_lr = base_lr
+        self.backup_lr = backup_lr
        self.state = LBSGDState(eta)

    def __call__(
@@ -87,19 +96,26 @@ def __call__(
    ) -> tuple[PyTree, Any, ActorEvaluation, dict[str, jax.Array]]:
        def evaluate_helper(actor):
            evaluation = evaluate(actor)
-            outs = jnp.stack([evaluation.loss, evaluation.constraint])
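+            # Log-barrier objective: subtract eta * log(constraint) from the loss; the second
+            # output is the negated constraint, so its Jacobian row is the gradient of -constraint.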
+            loss = evaluation.loss - state.eta * jnp.log(evaluation.constraint)
+            outs = jnp.stack([loss, -evaluation.constraint])
            return outs, evaluation

        jacobian, rest = jacrev(evaluate_helper, has_aux=True)(actor)
-        loss_grads, constraint_grads = pytrees_unstack(jacobian)
-        updates, state = lbsgd_update(
+        g, grad_f_1 = pytrees_unstack(jacobian)
+        alpha = rest.constraint
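+        # Inputs are cast to float32 (apply_dtype) so the step-size computation runs in full precision.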
+        updates, state, (lr, lhs, rhs) = lbsgd_update(
            state,
-            (loss_grads, constraint_grads, rest.constraint),
+            apply_dtype((g, grad_f_1, alpha), jnp.float32),
            self.eta_rate,
            self.m_0,
            self.m_1,
+            self.base_lr,
+            self.backup_lr,
        )
        metrics = {
-            "agent/lbsgd/eta": state.eta,
+            "agent/lbsgd/eta": jnp.asarray(state.eta),
+            "agent/lbsgd/lr": jnp.asarray(lr),
+            "agent/lbsgd/lhs": jnp.asarray(lhs),
+            "agent/lbsgd/rhs": jnp.asarray(rhs),
        }
        return updates, state, rest, metrics