Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions learning/train_jax_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from absl import app
from absl import flags
from absl import logging
from brax.io import model
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import networks_vision as ppo_networks_vision
from brax.training.agents.ppo import train as ppo
Expand Down Expand Up @@ -73,10 +74,16 @@
_LOAD_CHECKPOINT_PATH = flags.DEFINE_string(
"load_checkpoint_path", None, "Path to load checkpoint from"
)
_SAVE_PARAMS_PATH = flags.DEFINE_string(
"save_params_path", None, "Path to save parameters to"
)
_SUFFIX = flags.DEFINE_string("suffix", None, "Suffix for the experiment name")
_PLAY_ONLY = flags.DEFINE_boolean(
"play_only", False, "If true, only play with the model and do not train"
)
_RENDER_FINAL_POLICY = flags.DEFINE_boolean(
"render_final_policy", True, "If true, render the final policy"
)
_USE_WANDB = flags.DEFINE_boolean(
"use_wandb",
False,
Expand Down Expand Up @@ -361,6 +368,12 @@ def progress(num_steps, metrics):
print(f"Time to JIT compile: {times[1] - times[0]}")
print(f"Time to train: {times[-1] - times[1]}")

if _SAVE_PARAMS_PATH.value is not None:
model.save_params(epath.Path(_SAVE_PARAMS_PATH.value).resolve(), params)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Andrew-Luo1 would really like to not use the pkl stuff, is this absolutely necessary?


if not _RENDER_FINAL_POLICY.value:
return

print("Starting inference...")

# Create inference function
Expand Down
1 change: 1 addition & 0 deletions mujoco_playground/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from mujoco_playground._src.mjx_env import render_array
from mujoco_playground._src.mjx_env import State
from mujoco_playground._src.mjx_env import step

# pylint: enable=g-importing-member

__all__ = [
Expand Down
4 changes: 3 additions & 1 deletion mujoco_playground/_src/dm_control_suite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ def load(
An instance of the environment.
"""
if env_name not in _envs:
raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
raise ValueError(
f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}"
)
config = config or get_default_config(env_name)
return _envs[env_name](config=config, config_overrides=config_overrides)
4 changes: 3 additions & 1 deletion mujoco_playground/_src/locomotion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ def load(
An instance of the environment.
"""
if env_name not in _envs:
raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
raise ValueError(
f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}"
)
config = config or get_default_config(env_name)
return _envs[env_name](config=config, config_overrides=config_overrides)

Expand Down
3 changes: 1 addition & 2 deletions mujoco_playground/_src/locomotion/t1/randomize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from mujoco import mjx
import numpy as np


FLOOR_GEOM_ID = 0
TORSO_BODY_ID = 1
ANKLE_JOINT_IDS = np.array([[21, 22, 27, 28]])
Expand All @@ -30,7 +29,7 @@ def rand_dynamics(rng):
# Floor friction: =U(0.4, 1.0).
rng, key = jax.random.split(rng)
geom_friction = model.geom_friction.at[FLOOR_GEOM_ID, 0].set(
jax.random.uniform(key, minval=0.2, maxval=.6)
jax.random.uniform(key, minval=0.2, maxval=0.6)
)

rng, key = jax.random.split(rng)
Expand Down
15 changes: 14 additions & 1 deletion mujoco_playground/_src/manipulation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from mujoco_playground._src import mjx_env
from mujoco_playground._src.manipulation.aloha import handover as aloha_handover
from mujoco_playground._src.manipulation.aloha import single_peg_insertion as aloha_peg
from mujoco_playground._src.manipulation.aloha.s2r import distillation as aloha_s2r_distillation
from mujoco_playground._src.manipulation.aloha.s2r import peg_insertion as aloha_s2r_peg_insertion
from mujoco_playground._src.manipulation.aloha.s2r import pick as aloha_s2r_pick
from mujoco_playground._src.manipulation.franka_emika_panda import open_cabinet as panda_open_cabinet
from mujoco_playground._src.manipulation.franka_emika_panda import pick as panda_pick
from mujoco_playground._src.manipulation.franka_emika_panda import pick_cartesian as panda_pick_cartesian
Expand All @@ -31,6 +34,9 @@

_envs = {
"AlohaHandOver": aloha_handover.HandOver,
"AlohaS2RPick": aloha_s2r_pick.Pick,
"AlohaS2RPegInsertion": aloha_s2r_peg_insertion.PegInsertion,
"AlohaS2RPegInsertionDistill": aloha_s2r_distillation.DistillPegInsertion,
"AlohaSinglePegInsertion": aloha_peg.SinglePegInsertion,
"PandaPickCube": panda_pick.PandaPickCube,
"PandaPickCubeOrientation": panda_pick.PandaPickCubeOrientation,
Expand All @@ -43,6 +49,9 @@

_cfgs = {
"AlohaHandOver": aloha_handover.default_config,
"AlohaS2RPick": aloha_s2r_pick.default_config,
"AlohaS2RPegInsertion": aloha_s2r_peg_insertion.default_config,
"AlohaS2RPegInsertionDistill": aloha_s2r_distillation.default_config,
"AlohaSinglePegInsertion": aloha_peg.default_config,
"PandaPickCube": panda_pick.default_config,
"PandaPickCubeOrientation": panda_pick.default_config,
Expand All @@ -56,6 +65,8 @@
_randomizer = {
"LeapCubeRotateZAxis": leap_rotate_z.domain_randomize,
"LeapCubeReorient": leap_cube_reorient.domain_randomize,
"AlohaS2RPick": aloha_s2r_pick.domain_randomize,
"AlohaS2RPegInsertionDistill": aloha_s2r_distillation.domain_randomize,
}


Expand Down Expand Up @@ -108,7 +119,9 @@ def load(
An instance of the environment.
"""
if env_name not in _envs:
raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
raise ValueError(
f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}"
)
config = config or get_default_config(env_name)
return _envs[env_name](config=config, config_overrides=config_overrides)

Expand Down
22 changes: 22 additions & 0 deletions mujoco_playground/_src/manipulation/aloha/aloha_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,25 @@
"right/left_finger",
"right/right_finger",
]

LEFT_JOINTS = [
"left/waist",
"left/shoulder",
"left/elbow",
"left/forearm_roll",
"left/wrist_angle",
"left/wrist_rotate",
"left/left_finger",
"left/right_finger",
]

RIGHT_JOINTS = [
"right/waist",
"right/shoulder",
"right/elbow",
"right/forearm_roll",
"right/wrist_angle",
"right/wrist_rotate",
"right/left_finger",
"right/right_finger",
]
3 changes: 3 additions & 0 deletions mujoco_playground/_src/manipulation/aloha/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def get_assets() -> Dict[str, bytes]:
path = mjx_env.ROOT_PATH / "manipulation" / "aloha" / "xmls"
mjx_env.update_assets(assets, path, "*.xml")
mjx_env.update_assets(assets, path / "assets")
path = mjx_env.ROOT_PATH / "manipulation" / "aloha" / "xmls" / "s2r"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no longer needed FWIU

mjx_env.update_assets(assets, path, "*.xml")
mjx_env.update_assets(assets, path / "assets")
return assets


Expand Down
122 changes: 122 additions & 0 deletions mujoco_playground/_src/manipulation/aloha/s2r/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
### Sim-to-Real Transfer of a Bi-Arm RL Policy via Pixel-Based Behaviour Cloning

https://github.com/user-attachments/assets/205fe8b9-1773-4715-8025-5de13490d0da

---

**Distillation**

In this module, we demonstrate policy distillation—a straightforward method for deploying a simulation-trained reinforcement learning policy that initially uses privileged state observations (such as object positions). The process involves two steps:

1. **Teacher Policy Training:** A state-based teacher policy is trained using RL.
2. **Student Policy Distillation:** The teacher is then distilled into a student policy via behaviour cloning (BC), where the student learns to map its observations $o_s(x)$ (e.g., exteroceptive RGBD images) to the teacher’s deterministic actions $\pi_t(o_t(x))$. For example, while both policies observe joint angles, the student uses RGBD images, whereas the teacher directly accesses (noisy) object positions.

The distillation process—where the student uses left and right wrist-mounted RGBD cameras for exteroception—takes about **3 minutes** on an RTX4090. This rapid turnaround is due to three factors:

1. [Very fast rendering](https://github.com/google-deepmind/mujoco_playground/blob/main/mujoco_playground/experimental/madrona_benchmarking/figures/cartpole_benchmark_full.png) provided by Madrona MJX.
2. The sample efficiency of behaviour cloning.
3. The use of low-resolution (32×32) rendering, which is sufficient for precise alignment given the wrist camera placement.

For further details on the teacher policy and RGBD sim-to-real techniques, please refer to the [technical report](https://docs.google.com/presentation/d/1v50Vg-SJdy5HV5JmPHALSwph9mcVI2RSPRdrxYR3Bkg/edit?usp=sharing).

---

**Teacher Policy**

This module includes an alternative state-based peg insertion implementation compared to the [single-file implementation](https://github.com/google-deepmind/mujoco_playground/blob/main/mujoco_playground/_src/manipulation/aloha/single_peg_insertion.py). The two main differences are:

1. **Robustness Enhancements:** We add observation noise and random action delays so that the teacher policy does not act with an unrealistic degree of certainty.
2. **Phased Training:** To better control reward shaping, the training is split into two phases:
- **Phase 1:** Bringing the objects to a pre-insertion position using a single-arm pick policy trained to pick up randomly-sized blocks, which is then deployed on both arms to ensure symmetry.
- **Phase 2:** The actual peg insertion phase.

---

**A Note on Sample Efficiency**

Behaviour cloning (BC) can be orders of magnitude more sample-efficient than reinforcement learning. In our approach, we use an L2 loss defined as:

$|| \pi_s(o_s(x)) - \pi_t(o_t(x)) ||$

In contrast, the policy gradient in RL generally takes the form:

![Equation](https://latex.codecogs.com/svg.latex?\nabla_\theta%20J(\theta)%20=%20\mathbb{E}_{\tau%20\sim%20\theta}%20\left[\sum_t%20\nabla_\theta%20\log%20\pi_\theta(a_t%20|%20s_t)%20R(\tau)\right])

Two key observations highlight why BC’s direct supervision is more efficient:

- **Explicit Loss Signal:** The BC loss compares against the teacher action, giving explicit feedback on how the action should be adjusted. In contrast, the policy gradient only provides directional guidance, instructing the optimizer to increase or decrease an action’s likelihood based solely on its downstream rewards.
- **Per-Dimension Supervision:** While the policy gradient applies a uniform weighting across all action dimensions, BC supplies per-dimension information, making it easier to scale to high-dimensional action spaces.

---

### Training Instructions

**Pre-requisites**

- **Teacher Phases 1 and 2:** Require only the standard Playground setup.
- **Distillation:** Requires the installation of Madrona MJX.
- **Jax-to-ONNX Conversion:** Requires several additional Python packages. The script `aloha_nets_to_onnx` converts all checkpoints into ONNX policies under `experimental/jax2onnx/onnx_policies`.

**Running the Training**

```bash
cd <PATH_TO_YOUR_CLONE>
export PARAMS_PATH=mujoco_playground/_src/manipulation/aloha/s2r/params

# Teacher phase 1 (Note: Domain randomization disables rendering)
python learning/train_jax_ppo.py --env_name AlohaS2RPick --domain_randomization --norender_final_policy --save_params_path $PARAMS_PATH/AlohaS2RPick.prms
sleep 0.5

# Teacher phase 2
python learning/train_jax_ppo.py --env_name AlohaS2RPegInsertion --save_params_path $PARAMS_PATH/AlohaS2RPegInsertion.prms
sleep 0.5

# Distill to student (skip evaluations to save time)
python mujoco_playground/experimental/train_dagger.py --domain-randomization --num-evals 0 --print-loss

# Bonus: Convert to ONNX for easy deployment
python mujoco_playground/experimental/jax2onnx/aloha_nets_to_onnx.py --checkpoint_path <YOUR_DISTILL_CHECKPOINT_DIR>
```

The expected training times on an RTX4090 are as follows, with a total training time of approximately **23.5 minutes**:

| Phase | Time (jit) | Time (run) | Algorithm | Inputs |
|---------------------------|------------|------------|---------------------|--------|
| **Teacher phase 1** | 1.5 min | 8.5 min | PPO | State |
| **Teacher phase 2** | 1.5 min | 8.5 min | PPO | State |
| **Student Distillation** | 24 s | 2 min 55 s | Behaviour Cloning | Pixels |

---

### Pre-trained Parameters

All pre-trained parameters are stored in `_src/manipulation/aloha/s2r/params`.

| File | Description |
|------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------|
| **AlohaS2RPick.prms** | Used during teacher phase 2 training and distillation. A default policy (from `pick.py`) is provided. |
| **AlohaS2RPegInsertion.prms**| Used during distillation. A default policy (from `peg_insertion.py`) is provided. |
| **VisionMLP2ChanCIFAR10.prms** | Based on [NatureCNN](https://github.com/google/brax/blob/241f9bc5bbd003f9cfc9ded7613388e2fe125af6/brax/training/networks.py#L153), also known as AtariCNN. This model is pre-trained on CIFAR10 to achieve over 70% classification accuracy. |

*Note:* The standard CIFAR10 pre-training code is not included. For reference, see [this tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial5/Inception_ResNet_DenseNet.html).

---

### XML Additions

Modifications have been made to the original XML files (located in `_src/manipulation/aloha/xmls/`) to support training and visually match the physical robot. These files are now stored under `_src/manipulation/aloha/xmls/s2r`.

| File | Description |
|-----------------------|-------------------------------------------------------------------------------------------------------------------|
| **mjx_aloha.xml** | Adjusts lighting, adds additional sites, and hides wrist cameras from the batch renderer. |
| **mjx_scene.xml** | Removes certain lights and adjusts the table color. |
| **mjx_peg_insertion.xml** | Adds additional sites for improved simulation accuracy. |
| **mjx_aloha_single.xml** | Removes the right arm from the simulation. |
| **mjx_scene_single.xml** | Loads the single-arm configuration defined in `mjx_aloha_single.xml`. |
| **mjx_pick.xml** | Contains a single block for pick tasks. |

---

### Aloha Deployment Setup

For deployment, the ONNX policy is executed on the Aloha robot using a custom fork of [OpenPI](https://github.com/Physical-Intelligence/openpi) along with the Interbotix Aloha ROS packages. Acknowledgements to Kevin Zakka, Laura Smith and the Levine Lab for robot deployment setup!
1 change: 1 addition & 0 deletions mujoco_playground/_src/manipulation/aloha/s2r/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""S2R (Sim-to-Real) module for ALOHA manipulation tasks."""
Loading