Implementation of gp data loading
yardenas committed Aug 22, 2024
1 parent 6ace84a commit 36c8edb
Showing 9 changed files with 727 additions and 18 deletions.
419 changes: 412 additions & 7 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -11,7 +11,7 @@ equinox = "^0.11.1"
jaxlib = { url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.23+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl", extras = [
"cuda11_pip",
] }
optax = "^0.1.7"
optax = "0.2.1"
hydra-core = "^1.3.2"
jax = "^0.4.23"
hydra-submitit-launcher = "^1.2.0"
@@ -29,6 +29,7 @@ jmp = { git = "https://github.com/deepmind/jmp" }
tensorboard = "^2.16.2"
torch = { version = "^2.4.0+cpu", source = "pytorch_cpu" }
torchvision = { version = "^0.19.0+cpu", source = "pytorch_cpu" }
gpjax = "^0.9.0"


[[tool.poetry.source]]
Empty file added safe_opax/cem_gp/__init__.py
Empty file.
88 changes: 88 additions & 0 deletions safe_opax/cem_gp/cem.py
@@ -0,0 +1,88 @@
from typing import Callable, TypedDict

import jax
import jax.numpy as jnp

from safe_opax.rl.types import RolloutFn

ObjectiveFn = Callable[[jax.Array], jax.Array]


def make_objective(
rollout_fn: RolloutFn,
horizon: int,
initial_state: jax.Array,
key: jax.random.KeyArray,
) -> ObjectiveFn:
def objective(candidates):
sample = lambda x: rollout_fn(horizon, initial_state, key, x)
preds = jax.vmap(sample)(candidates)
assert preds.reward.ndim == 2
return preds.reward.mean(axis=1)

return objective


def solve(
objective_fn: ObjectiveFn,
initial_guess: jax.Array,
key: jax.random.KeyArray,
num_particles: int,
num_iters: int,
num_elite: int,
stop_cond: float = 0.1,
initial_stddev: float = 1.0,
) -> jax.Array:
mu = initial_guess
stddev = jnp.ones_like(initial_guess) * initial_stddev

def cond(val):
_, iters, _, stddev, *_ = val
return (stddev.mean() > stop_cond) & (iters < num_iters)

def body(val):
key, iter, mu, stddev, best_score, best = val
key, subkey = jax.random.split(key)
eps = jax.random.normal(subkey, (num_particles,) + mu.shape)
sample = eps * stddev[None] + mu[None]
scores = objective_fn(sample)
elite_ids = jnp.argsort(scores)[-num_elite:]
best = jnp.where(
scores[elite_ids[-1]] > best_score, sample[elite_ids[-1]], best
)
best_score = jnp.maximum(best_score, scores[elite_ids[-1]])
elite = sample[elite_ids]
# Moment matching on the `particles` axis
mu, stddev = elite.mean(0), elite.std(0)
return key, iter + 1, mu, stddev, best_score, best

*_, best = jax.lax.while_loop(
cond, body, (key, 0, mu, stddev, -jnp.inf, initial_guess)
)
return best


class CEMConfig(TypedDict):
num_particles: int
num_iters: int
num_elite: int
stop_cond: float
initial_stddev: float


def policy(
observation: jax.Array,
rollout_fn: RolloutFn,
horizon: int,
init_guess: jax.Array,
key: jax.random.KeyArray,
cem_config: CEMConfig,
) -> jax.Array:
objective = make_objective(rollout_fn, horizon, observation, key)
action = solve(
objective,
init_guess,
key,
**cem_config,
)
return action
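
Note (usage sketch, not part of the commit): a minimal example of driving the CEM solver above with a toy rollout function. The dynamics, dimensions, and reward are made up for illustration; only the `cem.policy` and `CEMConfig` interfaces come from cem.py, and the rollout has to return an object with a per-step `reward` field, as `make_objective` expects.

from typing import NamedTuple

import jax
import jax.numpy as jnp

from safe_opax.cem_gp import cem


class ToyPrediction(NamedTuple):
    state: jax.Array
    reward: jax.Array


def toy_rollout(horizon, initial_state, key, actions):
    # A trivial linear system; the reward prefers staying near the origin.
    def step(state, action):
        next_state = state + 0.1 * action
        return next_state, (next_state, -jnp.sum(next_state**2))

    _, (states, rewards) = jax.lax.scan(step, initial_state, actions)
    return ToyPrediction(states, rewards)


config = cem.CEMConfig(
    num_particles=100, num_iters=10, num_elite=10, stop_cond=0.1, initial_stddev=1.0
)
horizon, action_dim = 5, 2
plan = cem.policy(
    jnp.zeros(2),  # observation / initial state of the toy system
    toy_rollout,
    horizon,
    jnp.zeros((horizon, action_dim)),  # initial guess for the action sequence
    jax.random.PRNGKey(0),
    config,
)
first_action = plan[0]  # plan has shape (horizon, action_dim)
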
98 changes: 98 additions & 0 deletions safe_opax/cem_gp/cem_gp.py
@@ -0,0 +1,98 @@
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from gymnasium.spaces import Box
from omegaconf import DictConfig
import gpjax as gpx

from safe_opax.rl import metrics as m
from safe_opax.cem_gp import cem
from safe_opax.rl.trajectory import TrajectoryData
from safe_opax.rl.types import FloatArray, RolloutFn
from safe_opax.rl.utils import normalize


@eqx.filter_jit
def policy(
observation: jax.Array,
sample: RolloutFn,
horizon: int,
init_guess: jax.Array,
key: jax.random.KeyArray,
cem_config: cem.CEMConfig,
):
# vmap over batches of observations (e.g., solve cem separately for
# each individual environment)
cem_per_env = jax.vmap(
lambda o: cem.policy(o, sample, horizon, init_guess, key, cem_config)
)
return cem_per_env(observation)


class CEMGP:
def __init__(
self,
observation_space: Box,
action_space: Box,
config: DictConfig,
):
self.obs_normalizer = m.MetricsAccumulator()
self.data = None
self.model = None
self.metrics_monitor = m.MetricsMonitor()
self.plan = np.zeros(
(config.training.parallel_envs, config.agent.plan_horizon)
+ action_space.shape
)
self.config = config
self.action_space = action_space

def __call__(self, observation: FloatArray, train: bool = False) -> FloatArray:
normalized_obs = normalize(
observation,
self.obs_normalizer.result.mean,
self.obs_normalizer.result.std,
)
horizon = self.config.agent.plan_horizon
if self.model is not None:
init_guess = jnp.zeros((horizon, self.action_space.shape[-1]))
action = policy(
normalized_obs,
self.model.sample,
horizon,
init_guess,
next(self.prng),
self.config.agent.cem,
)
self.plan = np.asarray(action)
else:
# TODO (yarden): make this nicer (uniform with scale as parameter)
return np.repeat(
self.action_space.sample()[None], self.config.training.parallel_envs, axis=0
) * self.config.agent.initial_action_scale
return self.plan[:, 0]

def observe(self, trajectory: TrajectoryData) -> None:
self.obs_normalizer.update_state(
np.concatenate(
[trajectory.observation, trajectory.next_observation[:, -1:]],
axis=1,
),
axis=(0, 1),
)
new_data = _prepare_data(trajectory, self.obs_normalizer)
if self.data is None:
self.data = gpx.Dataset(*new_data)
else:
self.data += gpx.Dataset(*new_data)


def _prepare_data(trajectory: TrajectoryData, normalizer):
results = normalizer.result
normalize_fn = lambda x: normalize(x, results.mean, results.std)
normalized_obs = normalize_fn(trajectory.observation)
normalized_next_obs = normalize_fn(trajectory.next_observation)
x = np.concatenate([normalized_obs, trajectory.action], axis=-1)
y = normalized_next_obs
return x, y
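
Note (usage sketch, not part of the commit): how observe() turns a trajectory batch into GP training data. The stand-in normalizer below only mimics the `.result.mean` / `.result.std` interface that `_prepare_data` reads (the agent uses safe_opax.rl.metrics.MetricsAccumulator), and the TrajectoryData field order (observation, next_observation, action, reward, cost) is assumed from rl/utils.py. The reshape to a 2-D design matrix is for illustration only; observe() itself passes the batched arrays straight to gpx.Dataset.

from types import SimpleNamespace

import numpy as np
import gpjax as gpx

from safe_opax.cem_gp.cem_gp import _prepare_data
from safe_opax.rl.trajectory import TrajectoryData

num_envs, steps, obs_dim, act_dim = 2, 10, 3, 1
rng = np.random.default_rng(0)
trajectory = TrajectoryData(
    rng.normal(size=(num_envs, steps, obs_dim)),  # observation
    rng.normal(size=(num_envs, steps, obs_dim)),  # next_observation
    rng.normal(size=(num_envs, steps, act_dim)),  # action
    np.zeros((num_envs, steps)),  # reward (unused by _prepare_data)
    np.zeros((num_envs, steps)),  # cost (unused by _prepare_data)
)
normalizer = SimpleNamespace(
    result=SimpleNamespace(mean=np.zeros(obs_dim), std=np.ones(obs_dim))
)

# x pairs each normalized observation with its action; y is the normalized
# next observation.
x, y = _prepare_data(trajectory, normalizer)
data = gpx.Dataset(x.reshape(-1, obs_dim + act_dim), y.reshape(-1, obs_dim))
# Later batches are appended the same way observe() does it.
data += gpx.Dataset(x.reshape(-1, obs_dim + act_dim), y.reshape(-1, obs_dim))
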
102 changes: 102 additions & 0 deletions safe_opax/cem_gp/gp_model.py
@@ -0,0 +1,102 @@
from typing import Callable
import jax
import jax.numpy as jnp
import gpjax as gpx
import equinox as eqx

from safe_opax.rl.types import Policy, Prediction, ShiftScale


class GPModel(eqx.Module):
x: jax.Array
y: jax.Array
posteriors: list[gpx.gps.ConjugatePosterior] = eqx.field(static=True)
reward_fn: Callable[[jax.Array], jax.Array] = eqx.field(static=True)
cost_fn: Callable[[jax.Array], jax.Array] = eqx.field(static=True)

def __init__(
self,
x: jax.Array,
y: jax.Array,
reward_fn: Callable[[jax.Array], jax.Array],
cost_fn: Callable[[jax.Array], jax.Array],
):
# FIXME (yarden): this should be over multiple dimensions
self.x, self.y = x, y
prior = gpx.gps.Prior(
mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
)
likelihood = gpx.likelihoods.Gaussian(num_datapoints=x.shape[0])
posterior = prior * likelihood
posteriors = []
for i in range(y.shape[-1]):
p, _ = gpx.fit_scipy(
model=posterior,
train_data=gpx.Dataset(x, y[:, i : i + 1]),
objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
)
posteriors.append(p)
self.posteriors = posteriors
self.reward_fn = reward_fn
self.cost_fn = cost_fn

def sample(
self,
horizon: int,
initial_state: jax.Array,
key: jax.Array,
policy: Policy,
) -> tuple[Prediction, ShiftScale]:
def f(carry, inputs):
prev_state = carry
if callable(policy):
key = inputs
key, p_key = jax.random.split(key)
action = policy(jax.lax.stop_gradient(prev_state.flatten()), p_key)
else:
action, key = inputs
x = jnp.concatenate([prev_state, action], axis=-1)
key, prior_key = jax.random.split(key)
predictive_distribution = _multioutput_predict(
self.x, self.y, x, self.posteriors
)
state = predictive_distribution.sample(seed=prior_key).squeeze()
return state, (
state,
(
predictive_distribution.mean().squeeze(),
predictive_distribution.stddev().squeeze(),
),
)

if isinstance(policy, jax.Array):
inputs: tuple[jax.Array, jax.Array] | jax.Array = (
policy,
jax.random.split(key, policy.shape[0]),
)
assert policy.shape[0] <= horizon
elif callable(policy):
inputs = jax.random.split(key, horizon)
_, (trajectory, priors) = jax.lax.scan(f, initial_state, inputs)
reward = self.reward_fn(trajectory)
cost = self.cost_fn(trajectory)
out = Prediction(trajectory, reward, cost)
return out, priors


def _multioutput_predict(x_train, y_train, x_test, posteriors):
# TODO (yarden): Can technically stack trees then vmap, but not important now.
distributions = []
for i, posterior in enumerate(posteriors):
distribution = posterior.predict(
test_inputs=x_test[None],
train_data=gpx.Dataset(x_train, y_train[:, i : i + 1]),
)
predictive_distribution = posterior.likelihood(distribution)
distributions.append(predictive_distribution)
return _pytrees_stack(distributions)


def _pytrees_stack(pytrees, axis=0):
results = jax.tree_map(lambda *values: jnp.stack(values, axis=axis), *pytrees)
return results
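
Note (usage sketch, not part of the commit): fitting a GPModel on a few toy transitions and sampling an open-loop rollout by passing an action sequence. The linear toy dynamics, quadratic reward, and shapes are assumptions for illustration; the constructor fits one GP per output dimension with gpx.fit_scipy, so this runs a small actual optimization.

import jax
import jax.numpy as jnp

from safe_opax.cem_gp.gp_model import GPModel

obs_dim, act_dim, num_points = 2, 1, 30
x_key, noise_key, rollout_key = jax.random.split(jax.random.PRNGKey(0), 3)

# Training inputs are (state, action) pairs; targets are the next states of a
# noisy linear toy system.
x = jax.random.normal(x_key, (num_points, obs_dim + act_dim))
y = (
    0.9 * x[:, :obs_dim]
    + 0.1 * x[:, obs_dim:]
    + 0.01 * jax.random.normal(noise_key, (num_points, obs_dim))
)

model = GPModel(
    x,
    y,
    reward_fn=lambda trajectory: -jnp.sum(trajectory**2, axis=-1),
    cost_fn=lambda trajectory: jnp.zeros(trajectory.shape[0]),
)

horizon = 5
actions = jnp.zeros((horizon, act_dim))
prediction, (means, stddevs) = model.sample(
    horizon, jnp.zeros(obs_dim), rollout_key, actions
)
# prediction carries the sampled state trajectory with its rewards and costs;
# means/stddevs are the per-step predictive moments of the GP.
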
2 changes: 2 additions & 0 deletions safe_opax/configs/agent/cem_gp.yaml
@@ -0,0 +1,2 @@
name: cem_gp
initial_action_scale: 0.1
2 changes: 1 addition & 1 deletion safe_opax/la_mbda/la_mbda.py
@@ -147,7 +147,7 @@ def observe(self, trajectory: TrajectoryData) -> None:
add_to_buffer(
self.replay_buffer,
trajectory,
self.config.training.scale_reward,
reward_scale=self.config.training.scale_reward,
)
self.state = jax.tree_map(lambda x: jnp.zeros_like(x), self.state)

31 changes: 22 additions & 9 deletions safe_opax/rl/utils.py
@@ -22,16 +22,29 @@ def take_n(self, n):
return keys[1:]


def add_to_buffer(buffer, trajectory, reward_scale):
buffer.add(
TrajectoryData(
trajectory.observation,
trajectory.next_observation,
trajectory.action,
trajectory.reward * reward_scale,
trajectory.cost,
def add_to_buffer(buffer, trajectory, *, normalizer=None, reward_scale=1):
if normalizer is not None:
results = normalizer.result
normalize_fn = lambda x: normalize(x, results.mean, results.std)
buffer.add(
TrajectoryData(
normalize_fn(trajectory.observation),
normalize_fn(trajectory.next_observation),
trajectory.action,
trajectory.reward * reward_scale,
trajectory.cost,
)
)
else:
buffer.add(
TrajectoryData(
trajectory.observation,
trajectory.next_observation,
trajectory.action,
trajectory.reward * reward_scale,
trajectory.cost,
)
)
)


def normalize(observation, mean, std):
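
Note (usage sketch, not part of the commit): the reworked add_to_buffer is now called with keyword arguments and can optionally normalize observations before they are stored. The stand-in buffer and normalizer below only mimic the attributes the function touches (`buffer.add`, `normalizer.result.mean/.std`); real callers pass a replay buffer and a MetricsAccumulator, and the TrajectoryData field order is assumed as above.

from types import SimpleNamespace

import numpy as np

from safe_opax.rl.trajectory import TrajectoryData
from safe_opax.rl.utils import add_to_buffer

buffer = SimpleNamespace(add=lambda item: None)
normalizer = SimpleNamespace(result=SimpleNamespace(mean=np.zeros(3), std=np.ones(3)))

trajectory = TrajectoryData(
    np.zeros((2, 5, 3)),  # observation
    np.zeros((2, 5, 3)),  # next_observation
    np.zeros((2, 5, 1)),  # action
    np.zeros((2, 5)),  # reward
    np.zeros((2, 5)),  # cost
)

# Reward scaling is now passed by keyword (see the la_mbda.py change above) ...
add_to_buffer(buffer, trajectory, reward_scale=0.1)
# ... and observations can optionally be normalized before entering the buffer.
add_to_buffer(buffer, trajectory, normalizer=normalizer, reward_scale=0.1)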
