Commit f87c636

Add a constraint cartpole environment (#14)
* Add a constraint cartpole environment
* Setup experiment configuration
* Cost rate -> cost return
* Bug fixes with lbsgd
* Scale optimism rewards
* Fix bug in rewards scale
* Fix baking state of exploration into policy bug
* Zero out bias
* Initially safe policy
* Reset metrics completely
* Log lhs rhs
* Fix bugs in rhs computation, tune lrs
* Log safety budget
* Fix learning rate scaling in lbsgd
1 parent 0a8fe5a

18 files changed: +151 -76 lines

poetry.lock

+3-3
Generated file; diff not rendered by default.

safe_opax/benchmark_suites/dm_control/__init__.py

+25-3
@@ -203,23 +203,44 @@ def __init__(self, env: Env, cost_multiplier: float = 0):
         self.cost_multiplier = cost_multiplier

     def step(self, action):
-        action_cost = self.cost_multiplier * (1 - tolerance(action, (-0.1, 0.1), 0.1))[0]
+        action_cost = (
+            self.cost_multiplier * (1 - tolerance(action, (-0.1, 0.1), 0.1))[0]
+        )
         observation, reward, terminal, truncated, info = self.env.step(action)
         return observation, reward - action_cost, terminal, truncated, info

     def __getattr__(self, name):
         return getattr(self.env, name)


+class ConstraintWrapper:
+    def __init__(self, env: Env, slider_position_bound: float):
+        self.env = env
+        self.physics = env.env.physics
+        self.slider_position_bound = slider_position_bound
+
+    def step(self, action):
+        observation, reward, terminal, truncated, info = self.env.step(action)
+        slider_pos = self.physics.cart_position().copy()
+        cost = float(np.abs(slider_pos) >= self.slider_position_bound)
+        info["cost"] = cost
+        return observation, reward, terminal, truncated, info
+
+    def __getattr__(self, name):
+        return getattr(self.env, name)
+
+
 def make(cfg: DictConfig) -> EnvironmentFactory:
     def make_env():
         domain_name, task_cfg = get_domain_and_task(cfg)
-        if task_cfg.task == "swingup_sparse_hard":
+        if task_cfg.task in ["swingup_sparse_hard", "safe_swingup_sparse_hard"]:
             task = "swingup_sparse"
         else:
             task = task_cfg.task
         env = DMCWrapper(domain_name, task)
-        if task_cfg.task == "swingup_sparse_hard":
+        if "safe" in task_cfg.task:
+            env = ConstraintWrapper(env, task_cfg.slider_position_bound)
+        if task_cfg.task in ["swingup_sparse_hard", "safe_swingup_sparse_hard"]:
             env = ActionCostWrapper(env, cost_multiplier=task_cfg.cost_multiplier)
         if task_cfg.image_observation.enabled:
             env = ImageObservation(
@@ -245,6 +266,7 @@ def make_env():
     ("dm_cartpole", "swingup"),
     ("dm_cartpole", "swingup_sparse"),
     ("dm_cartpole", "swingup_sparse_hard"),
+    ("dm_cartpole", "safe_swingup_sparse_hard"),
     ("dm_humanoid", "stand"),
     ("dm_humanoid", "walk"),
     ("dm_manipulator", "bring_ball"),

safe_opax/configs/agent/la_mbda.yaml

+4-2
@@ -1,5 +1,5 @@
 defaults:
-  - penalizer/lagrangian
+  - penalizer: lagrangian

 name: lambda
 replay_buffer:
@@ -18,6 +18,7 @@ actor:
   n_layers: 4
   hidden_size: 400
   init_stddev: 5.
+  initialization_scale: 0.01
 critic:
   n_layers: 3
   hidden_size: 400
@@ -49,4 +50,5 @@ kl_mix: 0.8
 safety_slack: 0.
 evaluate_model: false
 exploration_strategy: uniform
-exploration_steps: 5000
+exploration_steps: 5000
+exploration_reward_scale: 10.0
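
A note on the defaults change above: in Hydra's defaults-list syntax, "- penalizer: lagrangian" selects the lagrangian option of the penalizer config group and registers it as an overridable group default, which is what allows the new experiment file below to swap it out with "override /agent/penalizer: lbsgd"; the older path form "- penalizer/lagrangian" only includes that config file without establishing a group default.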

safe_opax/configs/agent/penalizer/lbsgd.yaml

+2-2
@@ -1,6 +1,6 @@
 name: lbsgd
-m_0: 5e6
-m_1: 5e4
+m_0: 1.2e4
+m_1: 1.2e4
 eta: 0.1
 eta_rate: 8e-6


safe_opax/configs/environment/dm_cartpole.yaml

+1
@@ -6,3 +6,4 @@ dm_cartpole:
   image_format: "channels_first"
   visualize_reward: true
   cost_multiplier: 0.2
+  slider_position_bound: 0.5
New file: experiment configuration (+20)

@@ -0,0 +1,20 @@
+# @package _global_
+defaults:
+  - override /environment: dm_cartpole
+  - override /agent/penalizer: lbsgd
+
+environment:
+  dm_cartpole:
+    task: safe_swingup_sparse_hard
+
+training:
+  epochs: 100
+  safe: true
+  action_repeat: 2
+  safety_budget: 100
+
+agent:
+  exploration_strategy: opax
+  exploration_steps: 1000000
+  actor:
+    init_stddev: 0.001

safe_opax/la_mbda/actor_critic.py

+3-1
@@ -20,6 +20,7 @@ def __init__(
         action_dim: int,
         hidden_size: int,
         init_stddev: float,
+        initialization_scale: float,
         *,
         key: jax.Array,
     ):
@@ -31,7 +32,8 @@ def __init__(
                 n_layers + 1,
                 key=key,
                 activation=jnn.elu,
-            )
+            ),
+            weight_scale=initialization_scale,
         )
         self.init_stddev = init_stddev

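How the new weight_scale argument is consumed is not visible in this diff; the following is only a plausible sketch of the idea it names (shrinking the freshly initialized output layer toward zero so the starting policy is conservative, in line with the "Zero out bias" and "Initially safe policy" notes in the commit message). The helper name is hypothetical and uses equinox directly:

    import equinox as eqx
    import jax.numpy as jnp

    def scale_output_layer(mlp: eqx.nn.MLP, scale: float) -> eqx.nn.MLP:
        # Hypothetical illustration: multiply the last layer's weights by a small
        # factor and zero its bias, so initial policy outputs stay near zero.
        last = mlp.layers[-1]
        mlp = eqx.tree_at(lambda m: m.layers[-1].weight, mlp, last.weight * scale)
        mlp = eqx.tree_at(lambda m: m.layers[-1].bias, mlp, jnp.zeros_like(last.bias))
        return mlp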

safe_opax/la_mbda/exploration.py

+16-15
@@ -4,13 +4,10 @@
 from safe_opax.la_mbda.opax_bridge import OpaxBridge
 from safe_opax.la_mbda.make_actor_critic import make_actor_critic
 from safe_opax.la_mbda.sentiment import identity
-from safe_opax.rl.types import Model
+from safe_opax.rl.types import Model, Policy


 class Exploration:
-    def __call__(self, state: jax.Array, key: jax.Array) -> jax.Array:
-        raise NotImplementedError("Must be implemented by subclass")
-
     def update(
         self,
         model: Model,
@@ -19,6 +16,9 @@ def update(
     ) -> dict[str, float]:
         return {}

+    def get_policy(self) -> Policy:
+        raise NotImplementedError("Must be implemented by subclass")
+

 def make_exploration(
     config: DictConfig, action_dim: int, key: jax.Array
@@ -46,31 +46,32 @@ def __init__(
             key,
             sentiment=identity,
         )
+        self.reward_scale = config.agent.exploration_reward_scale

     def update(
         self,
         model: Model,
         initial_states: jax.Array,
         key: jax.Array,
     ) -> dict[str, float]:
-        model = OpaxBridge(model)
+        model = OpaxBridge(model, self.reward_scale)
         outs = self.actor_critic.update(model, initial_states, key)
+        outs = {f"{_append_opax(k)}": v for k, v in outs.items()}
+        return outs

-        def append_opax(string):
-            parts = string.split("/")
-            parts.insert(2, "opax")
-            return "/".join(parts)
+    def get_policy(self) -> Policy:
+        return self.actor_critic.actor.act

-        outs = {f"{append_opax(k)}": v for k, v in outs.items()}
-        return outs

-    def __call__(self, state: jax.Array, key: jax.Array) -> jax.Array:
-        return self.actor_critic.actor.act(state, key)
+def _append_opax(string):
+    parts = string.split("/")
+    parts.insert(2, "opax")
+    return "/".join(parts)


 class UniformExploration(Exploration):
     def __init__(self, action_dim: int):
         self.action_dim = action_dim

-    def __call__(self, state: jax.Array, key: jax.Array) -> jax.Array:
-        return jax.random.uniform(key, (self.action_dim,))
+    def get_policy(self) -> Policy:
+        return lambda state, key: jax.random.uniform(key, (self.action_dim,))
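
The refactor above replaces __call__ on exploration strategies with get_policy(), which returns a plain (state, key) -> action callable; this appears to be what the "Fix baking state of exploration into policy bug" note in the commit message refers to. A tiny illustrative usage sketch, not code from this commit:

    import jax

    exploration = UniformExploration(action_dim=1)   # class from the diff above
    policy_fn = exploration.get_policy()             # plain (state, key) -> action callable
    action = policy_fn(None, jax.random.PRNGKey(0))  # the uniform policy ignores its state argument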

safe_opax/la_mbda/la_mbda.py

+3-1
@@ -115,7 +115,9 @@ def __call__(
         if train and self.should_train() and not self.replay_buffer.empty:
             self.update()
         policy_fn = (
-            self.exploration if self.should_explore() else self.actor_critic.actor.act
+            self.exploration.get_policy()
+            if self.should_explore()
+            else self.actor_critic.actor.act
         )
         self.should_explore.tick()
         actions, self.state = policy(

safe_opax/la_mbda/lbsgd.py

+50-34
@@ -7,57 +7,56 @@
 import jax.numpy as jnp
 from jaxtyping import PyTree

+from safe_opax.common.mixed_precision import apply_dtype
 from safe_opax.common.pytree_utils import pytrees_unstack
 from safe_opax.la_mbda.actor_critic import ContinuousActor
 from safe_opax.la_mbda.safe_actor_critic import ActorEvaluation

+_EPS = 1e-8
+

 class LBSGDState(NamedTuple):
     eta: jax.Array


-def compute_lr(constraint, loss_grads, constraint_grads, m_0, m_1, eta):
-    constraint_grads, _ = jax.flatten_util.ravel_pytree(constraint_grads)  # type: ignore
-    loss_grads, _ = jax.flatten_util.ravel_pytree(loss_grads)  # type: ignore
-    projection = constraint_grads.dot(loss_grads)
-    lhs = (
-        constraint
-        / (
-            2.0 * jnp.abs(projection) / jnp.linalg.norm(loss_grads)
-            + jnp.sqrt(constraint * m_1 + 1e-8)
-        )
-        / (jnp.linalg.norm(loss_grads) + 1e-8)
-    )
+def compute_lr(alpha_1, g, grad_f_1, m_0, m_1, eta):
+    grad_f_1, _ = jax.flatten_util.ravel_pytree(grad_f_1)
+    g, _ = jax.flatten_util.ravel_pytree(g)
+    theta_1 = grad_f_1.dot(g / (jnp.linalg.norm(g) + _EPS))
+    lhs = alpha_1 / (2.0 * jnp.abs(theta_1) + jnp.sqrt(alpha_1 * m_1 + _EPS))
     m_2 = (
         m_0
-        + 10.0 * eta * (m_1 / (constraint + 1e-8))
-        + 8.0
-        * eta
-        * jnp.linalg.norm(projection) ** 2
-        / ((jnp.linalg.norm(loss_grads) * constraint) ** 2)
+        + 10.0 * eta * (m_1 / (alpha_1 + _EPS))
+        + 8.0 * eta * (theta_1 / alpha_1 + _EPS) ** 2
     )
     rhs = 1.0 / m_2
-    return jnp.minimum(lhs, rhs)
+    return jnp.minimum(lhs, rhs), (lhs, rhs)


 def lbsgd_update(
-    state: LBSGDState, updates: PyTree, eta_rate: float, m_0: float, m_1: float
-) -> tuple[PyTree, LBSGDState]:
+    state: LBSGDState,
+    updates: PyTree,
+    eta_rate: float,
+    m_0: float,
+    m_1: float,
+    base_lr: float,
+    backup_lr: float,
+) -> tuple[PyTree, LBSGDState, tuple[float, ...]]:
     def happy_case():
-        lr = compute_lr(constraint, loss_grads, constraints_grads, m_0, m_1, eta_t)
+        lr, (lhs, rhs) = compute_lr(alpha_1, g, grad_f_1, m_0, m_1, eta_t)
         new_eta = eta_t / eta_rate
-        updates = jax.tree_map(lambda x: x * lr, loss_grads)
-        return updates, LBSGDState(new_eta)
+        updates = jax.tree_map(lambda x: x * lr / base_lr, g)
+        return updates, LBSGDState(new_eta), (lr, lhs, rhs)

     def fallback():
         # Taking the negative gradient of the constraints to minimize the costs
-        updates = jax.tree_map(lambda x: x * -1.0, constraints_grads)
-        return updates, LBSGDState(eta_t)
+        updates = jax.tree_map(lambda x: x * backup_lr, grad_f_1)
+        return updates, LBSGDState(eta_t), (0.0, 0.0, 0.0)

-    loss_grads, constraints_grads, constraint = updates
+    g, grad_f_1, alpha_1 = updates
     eta_t = state.eta
     return jax.lax.cond(
-        jnp.greater(constraint, 0.0),
+        jnp.greater(alpha_1, _EPS),
         happy_case,
         fallback,
     )
@@ -66,17 +65,27 @@ def fallback():
 def jacrev(f, has_aux=False):
     def jacfn(x):
         y, vjp_fn, aux = eqx.filter_vjp(f, x, has_aux=has_aux)  # type: ignore
-        (J,) = eqx.filter_vmap(vjp_fn, in_axes=0)(jnp.eye(len(y)))
+        (J,) = eqx.filter_vmap(vjp_fn, in_axes=eqx.if_array(0))(jnp.eye(len(y)))
         return J, aux

     return jacfn


 class LBSGDPenalizer:
-    def __init__(self, m_0, m_1, eta, eta_rate) -> None:
+    def __init__(
+        self,
+        m_0: float,
+        m_1: float,
+        eta: float,
+        eta_rate: float,
+        base_lr: float,
+        backup_lr: float = 1e-2,
+    ) -> None:
         self.m_0 = m_0
         self.m_1 = m_1
         self.eta_rate = eta_rate + 1.0
+        self.base_lr = base_lr
+        self.backup_lr = backup_lr
         self.state = LBSGDState(eta)

     def __call__(
@@ -87,19 +96,26 @@ def __call__(
     ) -> tuple[PyTree, Any, ActorEvaluation, dict[str, jax.Array]]:
         def evaluate_helper(actor):
             evaluation = evaluate(actor)
-            outs = jnp.stack([evaluation.loss, evaluation.constraint])
+            loss = evaluation.loss - state.eta * jnp.log(evaluation.constraint)
+            outs = jnp.stack([loss, -evaluation.constraint])
             return outs, evaluation

         jacobian, rest = jacrev(evaluate_helper, has_aux=True)(actor)
-        loss_grads, constraint_grads = pytrees_unstack(jacobian)
-        updates, state = lbsgd_update(
+        g, grad_f_1 = pytrees_unstack(jacobian)
+        alpha = rest.constraint
+        updates, state, (lr, lhs, rhs) = lbsgd_update(
             state,
-            (loss_grads, constraint_grads, rest.constraint),
+            apply_dtype((g, grad_f_1, alpha), jnp.float32),
             self.eta_rate,
             self.m_0,
             self.m_1,
+            self.base_lr,
+            self.backup_lr,
         )
         metrics = {
-            "agent/lbsgd/eta": state.eta,
+            "agent/lbsgd/eta": jnp.asarray(state.eta),
+            "agent/lbsgd/lr": jnp.asarray(lr),
+            "agent/lbsgd/lhs": jnp.asarray(lhs),
+            "agent/lbsgd/rhs": jnp.asarray(rhs),
         }
         return updates, state, rest, metrics
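
Reading the new compute_lr and the barrier loss in evaluate_helper off the diff above (a summary in the diff's own notation, with the small _EPS terms omitted): the actor now differentiates the log-barrier objective f_0(pi) - eta * log(alpha_1(pi)), where alpha_1 is the constraint slack, g is the gradient of that objective, and grad_f_1 is the gradient of -alpha_1. The step size is

$$
\theta_1 = \nabla f_1^{\top} \frac{g}{\lVert g \rVert}, \qquad
M_2 = m_0 + \frac{10\,\eta\, m_1}{\alpha_1} + 8\,\eta \left(\frac{\theta_1}{\alpha_1}\right)^2, \qquad
\gamma = \min\!\left(\frac{\alpha_1}{2\lvert\theta_1\rvert + \sqrt{\alpha_1 m_1}},\; \frac{1}{M_2}\right).
$$

When alpha_1 is positive (greater than _EPS) the applied update is g scaled by gamma / base_lr, presumably so that a downstream optimizer multiplying by base_lr yields a net step of gamma, matching the "Fix learning rate scaling in lbsgd" note; otherwise a backup step of size backup_lr along grad_f_1 pushes the policy back toward feasibility, and the logged lr/lhs/rhs metrics are zeroed.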
