Skip to content

Commit

Permalink
Bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
yardenas committed Sep 1, 2024
1 parent 2454f59 commit 5a890e8
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 7 deletions.
2 changes: 1 addition & 1 deletion safe_opax/cem_gp/cem.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def make_objective(
) -> ObjectiveFn:
def objective(candidates):
sample = lambda x: rollout_fn(horizon, initial_state, key, x)
preds = jax.vmap(sample)(candidates)
preds, dist = jax.vmap(sample)(candidates)
assert preds.reward.ndim == 2
return preds.reward.mean(axis=1)

Expand Down
2 changes: 1 addition & 1 deletion safe_opax/cem_gp/cem_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,4 @@ def cartpole_reward(observation):

def cartpole_cost(observation, slider_position_bound):
cart_position = observation[..., 0]
return jnp.where(jnp.abs(cart_position) >= slider_position_bound)
return jnp.where(jnp.abs(cart_position) >= slider_position_bound, 1.0, 0.0)
8 changes: 3 additions & 5 deletions safe_opax/cem_gp/gp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import equinox as eqx
from jmp import get_policy

from safe_opax.common.mixed_precision import apply_dtype, apply_mixed_precision
from safe_opax.common.mixed_precision import apply_mixed_precision
from safe_opax.rl.types import Policy, Prediction, ShiftScale


Expand Down Expand Up @@ -97,21 +97,19 @@ def _pytrees_stack(pytrees, axis=0):
return results


@eqx.filter_jit
@apply_mixed_precision(
policy=get_policy("params=float32,compute=float64,output=float32"),
target_input_names=["x", "y"],
)
def compute_posteriors(x, y, posterior):
posterior_f64 = apply_dtype(posterior, jnp.float64)
posteriors = []
for i in range(y.shape[-1]):
p, _ = gpx.fit_scipy(
model=posterior_f64,
model=posterior,
train_data=gpx.Dataset(x, y[:, i : i + 1]),
objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
max_iters=1,
verbose=False
verbose=False,
)
posteriors.append(p)
return posteriors

0 comments on commit 5a890e8

Please sign in to comment.