Skip to content

Commit c9185d8

Browse files
committed
Add temporary testing files
1 parent 50f0640 commit c9185d8

File tree

6 files changed

+271
-0
lines changed

6 files changed

+271
-0
lines changed

config/hardware/2080_rtx.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# @package _global_

# Hardware preset: forwards the GPU request string and the billing account
# to the Hydra launcher through its `additional_parameters` mapping.
hydra:
  launcher:
    additional_parameters:
      "gpus": "rtx_2080_ti:1"
      "account": "ls_krausea"
6+

config/hardware/3090_rtx.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# @package _global_

# Hardware preset: forwards the GPU request string and the billing account
# to the Hydra launcher through its `additional_parameters` mapping.
hydra:
  launcher:
    additional_parameters:
      "gpus": "rtx_3090:1"
      "account": "ls_krausea"
6+

config/hardware/4090_rtx.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# @package _global_

# Hardware preset: forwards the GPU request string and the billing account
# to the Hydra launcher through its `additional_parameters` mapping.
hydra:
  launcher:
    additional_parameters:
      "gpus": "rtx_4090:1"
      "account": "ls_krausea"
6+

config/hydra/launcher/slurm.yaml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Launcher configuration instantiated as the submitit SlurmLauncher.
_target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher

# Job identity and where submitit keeps its per-job files.
name: ${hydra.job.name}
submitit_folder: ${hydra.sweep.dir}/.submitit/%j

# Resource request.
nodes: 1
tasks_per_node: 1
cpus_per_task: 10
mem_gb: null
mem_per_gpu: null
mem_per_cpu: 10240
timeout_min: 30

# Accounting and extra scheduler parameters.
# NOTE(review): "account" is given both as a top-level key and inside
# additional_parameters with the same value — confirm both are needed.
account: ls_krausea
additional_parameters:
  "gpus": "rtx_4090:1"
  "account": "ls_krausea"

# Array / requeue behaviour.
array_parallelism: 256
max_num_timeout: 100
setup:
  - '#SBATCH --requeue'

config/train_brax.yaml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Top-level configuration for train_brax.py.
# NOTE: `log_dir` is interpolated below but not defined in this file, so it
# has to be supplied from elsewhere (e.g. a command-line override).
defaults: [_self_]

hydra:
  run:
    dir: ${log_dir}/${now:%Y-%m-%d}/${now:%H-%M-%S}
  sweep:
    dir: ${log_dir}/${hydra.job.name}
    subdir: ${hydra.job.override_dirname}/seed=${training.seed}
  job:
    config:
      override_dirname:
        # Overrides listed here are left out of `override_dirname`.
        exclude_keys: [log_dir, training.seed, wandb]
    chdir: true

# Keyword arguments forwarded to wandb.init by the training script.
wandb:
  group: null
  notes: null
  name: ${hydra:job.override_dirname}

# When false, the training script runs inside jax.disable_jit().
jit: true

training:
  seed: 0
  render: true

train_brax.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
import functools
2+
import logging
3+
import os
4+
5+
import hydra
6+
import jax
7+
from mujoco_playground import registry, wrapper
8+
from omegaconf import OmegaConf
9+
import jax.numpy as jp
10+
11+
12+
import logging
13+
import os
14+
from typing import Any
15+
16+
import numpy as np
17+
import omegaconf
18+
from numpy import typing as npt
19+
from omegaconf import DictConfig
20+
from omegaconf.errors import InterpolationKeyError
21+
22+
_LOG = logging.getLogger(__name__)
23+
24+
25+
class WeightAndBiasesWriter:
    """Small logging facade over the ``wandb`` module.

    The constructor starts a wandb run from a Hydra/OmegaConf config; the
    remaining methods forward scalars, videos and file artifacts to it.
    """

    def __init__(self, config: DictConfig):
        """Initialise a wandb run in project ``ss2r`` from ``config``.

        Args:
            config: Experiment configuration. Its ``wandb`` node is expanded
                into keyword arguments of ``wandb.init`` and the whole
                (resolved) config is attached to the run.
        """
        # Imported here so that importing this module does not require wandb.
        import wandb

        try:
            name = config.wandb.name
        except InterpolationKeyError:
            # The run name is an interpolation; when its target key cannot be
            # resolved fall back to no explicit name instead of failing.
            name = None
        config.wandb.name = name
        config_dict = omegaconf.OmegaConf.to_container(config, resolve=True)
        assert isinstance(config_dict, dict)
        wandb.init(project="ss2r", resume=True, config=config_dict, **config.wandb)
        self._handle = wandb

    def log(self, summary: dict[str, float], step: int):
        """Log a dictionary of scalar values at ``step``."""
        self._handle.log(summary, step=step)

    def log_video(
        self,
        images: npt.ArrayLike,
        step: int,
        name: str = "policy",
        fps: int | float = 30,
    ):
        """Log ``images`` as a video under key ``name`` at ``step``.

        Args:
            images: Array-like video data; converted to a NumPy array before
                being handed to ``wandb.Video``.
            step: Step the video is logged at.
            name: Log key, also used as the video caption.
            fps: Frame rate; truncated to ``int``.
        """
        self._handle.log(
            {
                name: self._handle.Video(
                    # np.asarray copies only when required.  The previous
                    # ``np.array(..., copy=False)`` raises ValueError under
                    # NumPy 2 whenever a copy cannot be avoided (e.g. lists).
                    np.asarray(images),
                    fps=int(fps),
                    caption=name,
                )
            },
            step=step,
        )

    def log_artifact(
        self,
        path: str,
        type: str,
        name: str | None = None,
        description: str | None = None,
        metadata: dict[str, Any] | None = None,
    ):
        """Upload the file at ``path`` as a wandb artifact.

        Args:
            path: File added to the artifact.
            type: Artifact type passed to ``wandb.Artifact``.
            name: Artifact name; defaults to the current run id.
            description: Optional artifact description.
            metadata: Optional metadata; defaults to a copy of the run config.
        """
        if name is None:
            name = self._handle.run.id
        if metadata is None:
            metadata = dict(self._handle.config)
        artifact = self._handle.Artifact(name, type, description, metadata)
        artifact.add_file(path)
        # The run id is attached as an alias of the logged artifact.
        self._handle.log_artifact(artifact, aliases=[self._handle.run.id])
75+
76+
77+
def get_state_path() -> str:
    """Return the current working directory of the process."""
    return os.getcwd()
80+
81+
82+
# Module-level experiment setup: the environments are created when this
# module is imported and are shared by the functions below.
env_name = "QuadrupedRun"

# Training environment, loaded without an explicit config.
env = registry.load(env_name)

# Evaluation environment, loaded with the registry's default config.
env_cfg = registry.get_default_config(env_name)
eval_env = registry.load(env_name, config=env_cfg)

agent_name = "PPO"
87+
88+
89+
def get_ppo_train_fn():
    """Build the PPO training function for the module-level ``env_name``.

    The locomotion PPO parameters for ``env_name`` are bound onto brax's
    ``ppo.train``.  If those parameters carry a ``network_factory`` entry it
    is not forwarded as-is; instead its contents become keyword arguments of
    ``make_ppo_networks``.

    Returns:
        A ``functools.partial`` of ``ppo.train``.
    """
    from brax.training.agents.ppo import networks as ppo_networks
    from brax.training.agents.ppo import train as ppo
    from mujoco_playground.config import locomotion_params

    ppo_params = locomotion_params.brax_ppo_config(env_name)
    train_kwargs = dict(ppo_params)
    if "network_factory" in ppo_params:
        # The raw entry must not reach ppo.train; it only parameterises the
        # network constructor.
        train_kwargs.pop("network_factory")
        make_networks = functools.partial(
            ppo_networks.make_ppo_networks, **ppo_params.network_factory
        )
    else:
        make_networks = ppo_networks.make_ppo_networks
    return functools.partial(
        ppo.train,
        network_factory=make_networks,
        **train_kwargs,
    )
108+
109+
110+
class Counter:
    """Mutable holder for a single integer ``count``, starting at zero."""

    def __init__(self):
        self.count: int = 0
113+
114+
115+
def report(logger, step, num_steps, metrics):
    """Forward training metrics to ``logger`` and record the step reached.

    Args:
        logger: Object exposing ``log(summary, step)``.
        step: Object whose ``count`` attribute is overwritten with
            ``num_steps``.
        num_steps: Step the metrics belong to.
        metrics: Mapping of metric name to a value convertible to ``float``.
    """
    scalars = {}
    for key, value in metrics.items():
        scalars[key] = float(value)
    logger.log(scalars, num_steps)
    step.count = num_steps
119+
120+
121+
@functools.partial(jax.jit, static_argnames=("env", "policy", "steps"))
def rollout(env, policy, steps, rng, state):
    """Step ``env`` with ``policy`` for ``steps`` iterations under ``lax.scan``.

    ``env``, ``policy`` and ``steps`` are static arguments of the jitted
    function.

    Args:
        env: Object exposing ``step(state, action)``.
        policy: Callable ``(obs, key) -> (action, extras)``.
        steps: Number of scan iterations.
        rng: PRNG key; split once per iteration.
        state: Initial environment state; must expose ``obs``.

    Returns:
        ``(final_state, data)`` where ``data`` stacks the state produced by
        every iteration along a new leading axis.
    """

    def _advance(carry, _unused):
        prev_state, key = carry
        # One half of the split drives the policy now, the other is carried
        # to the next iteration.
        key, carried_key = jax.random.split(key)
        action, _extras = policy(prev_state.obs, key)
        next_state = env.step(prev_state, action)
        return (next_state, carried_key), next_state

    init = (state, rng)
    (final_state, _), data = jax.lax.scan(_advance, init, (), length=steps)
    return final_state, data
141+
142+
143+
def pytrees_unstack(pytree):
    """Split a pytree of stacked arrays into a list of per-index pytrees.

    Every leaf is indexed along its leading axis; the length of that axis is
    read from the first leaf.

    Args:
        pytree: Pytree whose leaves all share the same leading dimension.

    Returns:
        A list with one pytree per leading index, each having the structure
        of ``pytree``.  An input without leaves yields an empty list.
    """
    # ``jax.tree_flatten`` was a top-level alias that recent JAX releases no
    # longer provide; ``jax.tree_util`` is the stable home of the function.
    leaves, treedef = jax.tree_util.tree_flatten(pytree)
    if not leaves:
        return []
    n_trees = leaves[0].shape[0]
    return [
        treedef.unflatten([leaf[i] for leaf in leaves]) for i in range(n_trees)
    ]
152+
153+
154+
def render(env, policy, steps, rng, camera=None, num_episodes=5):
    """Roll out ``policy`` and render the first ``num_episodes`` episodes.

    Args:
        env: Environment exposing ``reset``, ``render`` and ``_mjx_model``;
            if it also has ``_randomized_models`` (with ``_in_axes``) those
            are temporarily cut down to ``num_episodes`` entries.
        policy: Callable passed on to ``rollout``.
        steps: Number of environment steps per episode.
        rng: Batch of PRNG keys; all of them go to ``env.reset`` and
            ``rng[0]`` seeds the rollout.
        camera: Forwarded to ``env.render``.
        num_episodes: How many episodes are kept and rendered (default 5,
            the value that used to be hard-coded).

    Returns:
        NumPy array of the rendered videos with axes reordered by
        ``transpose(0, 1, 4, 2, 3)`` — presumably from
        (episode, time, height, width, channel) to channel-first; confirm
        against ``env.render``.
    """
    # ``jax.tree_map`` is no longer available in recent JAX releases.
    tree_map = jax.tree_util.tree_map

    state = env.reset(rng)
    state = tree_map(lambda x: x[:num_episodes], state)
    orig_model = env._mjx_model
    has_randomized = hasattr(env, "_randomized_models")
    if has_randomized:
        orig_randomized = env._randomized_models
        # Keep only the first ``num_episodes`` randomized models so their
        # batch size matches the sliced state.
        env._randomized_models = tree_map(
            lambda x, ax: (
                jp.take(x, jp.arange(num_episodes), axis=ax)
                if ax is not None
                else x
            ),
            orig_randomized,
            env._in_axes,
        )
    try:
        _, trajectory = rollout(env, policy, steps, rng[0], state)
    finally:
        # Leave the caller's environment as it was found; previously the
        # sliced randomized models stayed on ``env`` after rendering.
        env._mjx_model = orig_model
        if has_randomized:
            env._randomized_models = orig_randomized

    videos = []
    for i in range(num_episodes):
        # Select episode ``i`` (axis 1), then split along time (axis 0).
        ep_trajectory = tree_map(lambda x: x[:, i], trajectory)
        ep_trajectory = pytrees_unstack(ep_trajectory)
        videos.append(env.render(ep_trajectory, camera=camera))
    return np.asarray(videos).transpose(0, 1, 4, 2, 3)
177+
178+
179+
# NOTE(review): config_path points at "ss2r/configs" while this commit adds
# train_brax.yaml under "config/" — confirm which directory is intended.
@hydra.main(version_base=None, config_path="ss2r/configs", config_name="train_brax")
def main(cfg):
    """Train PPO on the module-level environment and optionally log a video.

    Args:
        cfg: Hydra configuration; reads ``jit``, ``training.seed``,
            ``training.render`` and (through the logger) ``wandb``.
    """
    _LOG.info(
        "Setting up experiment with the following configuration: \n%s",
        OmegaConf.to_yaml(cfg),
    )
    logger = WeightAndBiasesWriter(cfg)
    train_fn = get_ppo_train_fn()
    steps = Counter()
    with jax.disable_jit(not cfg.jit):
        make_policy, params, _ = train_fn(
            environment=env,
            eval_env=eval_env,
            wrap_env_fn=wrapper.wrap_for_brax_training,
            progress_fn=functools.partial(report, logger, steps),
            # Sweep directories are keyed by ``training.seed``; pass it on so
            # that different seeds actually produce different training runs.
            seed=cfg.training.seed,
        )
        if cfg.training.render:
            rng = jax.random.split(jax.random.PRNGKey(cfg.training.seed), 128)
            video = render(
                eval_env,
                make_policy(params, deterministic=True),
                1000,
                rng,
            )
            # Logged at the last step reported by the training progress hook.
            logger.log_video(video, steps.count, "eval/video")
    _LOG.info("Done training.")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)