diff --git a/learning/train_jax_ppo.py b/learning/train_jax_ppo.py index 8ac0260eb..53d5c7cf2 100644 --- a/learning/train_jax_ppo.py +++ b/learning/train_jax_ppo.py @@ -73,10 +73,16 @@ _LOAD_CHECKPOINT_PATH = flags.DEFINE_string( "load_checkpoint_path", None, "Path to load checkpoint from" ) +_SAVE_PARAMS_PATH = flags.DEFINE_string( + "save_params_path", None, "Path to save parameters to" +) _SUFFIX = flags.DEFINE_string("suffix", None, "Suffix for the experiment name") _PLAY_ONLY = flags.DEFINE_boolean( "play_only", False, "If true, only play with the model and do not train" ) +_RENDER_FINAL_POLICY = flags.DEFINE_boolean( + "render_final_policy", True, "If true, render the final policy" +) _USE_WANDB = flags.DEFINE_boolean( "use_wandb", False, @@ -132,6 +138,33 @@ "policy_obs_key", "state", "Policy obs key" ) _VALUE_OBS_KEY = flags.DEFINE_string("value_obs_key", "state", "Value obs key") +_RSCOPE_ENVS = flags.DEFINE_integer( + "rscope_envs", + None, + "Number of parallel environment rollouts to save for the rscope viewer", +) +_DETERMINISTIC_RSCOPE = flags.DEFINE_boolean( + "deterministic_rscope", + True, + "Run deterministic rollouts for the rscope viewer", +) +_RUN_EVALS = flags.DEFINE_boolean( + "run_evals", + True, + "Run evaluation rollouts between policy updates.", +) +_LOG_TRAINING_METRICS = flags.DEFINE_boolean( + "log_training_metrics", + False, + "Whether to log training metrics and callback to progress_fn. Significantly" + " slows down training if too frequent.", +) +_TRAINING_METRICS_STEPS = flags.DEFINE_integer( + "training_metrics_steps", + 1_000_000, + "Number of steps between logging training metrics. Increase if training" + " experiences slowdown.", +) def get_rl_config(env_name: str) -> config_dict.ConfigDict: @@ -151,6 +184,24 @@ def get_rl_config(env_name: str) -> config_dict.ConfigDict: raise ValueError(f"Env {env_name} not found in {registry.ALL_ENVS}.") +def rscope_fn(full_states, obs, rew, done): + """ + All arrays are of shape (unroll_length, rscope_envs, ...) 
+ full_states: dict with keys 'qpos', 'qvel', 'time', 'metrics' + obs: nd.array or dict obs based on env configuration + rew: nd.array rewards + done: nd.array done flags + """ + # Calculate cumulative rewards per episode, stopping at first done flag + done_mask = jp.cumsum(done, axis=0) + valid_rewards = rew * (done_mask == 0) + episode_rewards = jp.sum(valid_rewards, axis=0) + print( + "Collected rscope rollouts with reward" + f" {episode_rewards.mean():.3f} +- {episode_rewards.std():.3f}" + ) + + def main(argv): """Run training and evaluation for the specified environment.""" @@ -209,11 +260,16 @@ def main(argv): ppo_params.network_factory.policy_obs_key = _POLICY_OBS_KEY.value if _VALUE_OBS_KEY.present: ppo_params.network_factory.value_obs_key = _VALUE_OBS_KEY.value - if _VISION.value: env_cfg.vision = True env_cfg.vision_config.render_batch_size = ppo_params.num_envs env = registry.load(_ENV_NAME.value, config=env_cfg) + if _RUN_EVALS.present: + ppo_params.run_evals = _RUN_EVALS.value + if _LOG_TRAINING_METRICS.present: + ppo_params.log_training_metrics = _LOG_TRAINING_METRICS.value + if _TRAINING_METRICS_STEPS.present: + ppo_params.training_metrics_steps = _TRAINING_METRICS_STEPS.value print(f"Environment Config:\n{env_cfg}") print(f"PPO Training Parameters:\n{ppo_params}") @@ -260,7 +316,10 @@ def main(argv): restore_checkpoint_path = None # Set up checkpoint directory - ckpt_path = logdir / "checkpoints" + if _SAVE_PARAMS_PATH.value is not None: + ckpt_path = epath.Path(_SAVE_PARAMS_PATH.value).resolve() / "checkpoints" + else: + ckpt_path = logdir / "checkpoints" ckpt_path.mkdir(parents=True, exist_ok=True) print(f"Checkpoint path: {ckpt_path}") @@ -268,13 +327,6 @@ def main(argv): with open(ckpt_path / "config.json", "w", encoding="utf-8") as fp: json.dump(env_cfg.to_dict(), fp, indent=4) - # Define policy parameters function for saving checkpoints - def policy_params_fn(current_step, make_policy, params): # pylint: disable=unused-argument - orbax_checkpointer = ocp.PyTreeCheckpointer() - save_args = orbax_utils.save_args_from_target(params) - path = ckpt_path / f"{current_step}" - orbax_checkpointer.save(path, params, force=True, save_args=save_args) - training_params = dict(ppo_params) if "network_factory" in training_params: del training_params["network_factory"] @@ -319,9 +371,9 @@ def policy_params_fn(current_step, make_policy, params): # pylint: disable=unus ppo.train, **training_params, network_factory=network_factory, - policy_params_fn=policy_params_fn, seed=_SEED.value, restore_checkpoint_path=restore_checkpoint_path, + save_checkpoint_path=ckpt_path, wrap_env_fn=None if _VISION.value else wrapper.wrap_for_brax_training, num_eval_envs=num_eval_envs, ) @@ -341,18 +393,55 @@ def progress(num_steps, metrics): for key, value in metrics.items(): writer.add_scalar(key, value, num_steps) writer.flush() - - print(f"{num_steps}: reward={metrics['eval/episode_reward']:.3f}") + if _RUN_EVALS.value: + print(f"{num_steps}: reward={metrics['eval/episode_reward']:.3f}") + if _LOG_TRAINING_METRICS.value: + if "episode/sum_reward" in metrics: + print( + f"{num_steps}: mean episode" + f" reward={metrics['episode/sum_reward']:.3f}" + ) # Load evaluation environment eval_env = ( None if _VISION.value else registry.load(_ENV_NAME.value, config=env_cfg) ) + policy_params_fn = lambda *args: None + if _RSCOPE_ENVS.value: + # Interactive visualisation of policy checkpoints + from rscope import brax as rscope_utils + + if not _VISION.value: + rscope_env = registry.load(_ENV_NAME.value, 
config=env_cfg) + rscope_env = wrapper.wrap_for_brax_training( + rscope_env, + episode_length=ppo_params.episode_length, + action_repeat=ppo_params.action_repeat, + randomization_fn=training_params.get("randomization_fn"), + ) + else: + rscope_env = env + + rscope_handle = rscope_utils.BraxRolloutSaver( + rscope_env, + ppo_params, + _VISION.value, + _RSCOPE_ENVS.value, + _DETERMINISTIC_RSCOPE.value, + jax.random.PRNGKey(_SEED.value), + rscope_fn, + ) + + def policy_params_fn(current_step, make_policy, params): # pylint: disable=unused-argument + rscope_handle.set_make_policy(make_policy) + rscope_handle.dump_rollout(params) + # Train or load the model make_inference_fn, params, _ = train_fn( # pylint: disable=no-value-for-parameter environment=env, progress_fn=progress, + policy_params_fn=policy_params_fn, eval_env=None if _VISION.value else eval_env, ) @@ -361,6 +450,9 @@ def progress(num_steps, metrics): print(f"Time to JIT compile: {times[1] - times[0]}") print(f"Time to train: {times[-1] - times[1]}") + if not _RENDER_FINAL_POLICY.value: + return + print("Starting inference...") # Create inference function @@ -368,6 +460,9 @@ def progress(num_steps, metrics): jit_inference_fn = jax.jit(inference_fn) # Prepare for evaluation + eval_env = ( + None if _VISION.value else registry.load(_ENV_NAME.value, config=env_cfg) + ) num_envs = 1 if _VISION.value: eval_env = env diff --git a/mujoco_playground/__init__.py b/mujoco_playground/__init__.py index 04865c457..c76b9daa4 100644 --- a/mujoco_playground/__init__.py +++ b/mujoco_playground/__init__.py @@ -25,6 +25,7 @@ from mujoco_playground._src.mjx_env import render_array from mujoco_playground._src.mjx_env import State from mujoco_playground._src.mjx_env import step + # pylint: enable=g-importing-member __all__ = [ diff --git a/mujoco_playground/_src/dm_control_suite/__init__.py b/mujoco_playground/_src/dm_control_suite/__init__.py index 572670ada..a45f2948f 100644 --- a/mujoco_playground/_src/dm_control_suite/__init__.py +++ b/mujoco_playground/_src/dm_control_suite/__init__.py @@ -150,6 +150,8 @@ def load( An instance of the environment. """ if env_name not in _envs: - raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}") + raise ValueError( + f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}" + ) config = config or get_default_config(env_name) return _envs[env_name](config=config, config_overrides=config_overrides) diff --git a/mujoco_playground/_src/locomotion/__init__.py b/mujoco_playground/_src/locomotion/__init__.py index ce5d98287..227a38cb5 100644 --- a/mujoco_playground/_src/locomotion/__init__.py +++ b/mujoco_playground/_src/locomotion/__init__.py @@ -174,7 +174,9 @@ def load( An instance of the environment. """ if env_name not in _envs: - raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}") + raise ValueError( + f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}" + ) config = config or get_default_config(env_name) return _envs[env_name](config=config, config_overrides=config_overrides) diff --git a/mujoco_playground/_src/locomotion/t1/randomize.py b/mujoco_playground/_src/locomotion/t1/randomize.py index 85e0dc112..ea132d793 100644 --- a/mujoco_playground/_src/locomotion/t1/randomize.py +++ b/mujoco_playground/_src/locomotion/t1/randomize.py @@ -18,7 +18,6 @@ from mujoco import mjx import numpy as np - FLOOR_GEOM_ID = 0 TORSO_BODY_ID = 1 ANKLE_JOINT_IDS = np.array([[21, 22, 27, 28]]) @@ -30,7 +29,7 @@ def rand_dynamics(rng): # Floor friction: =U(0.4, 1.0). 
rng, key = jax.random.split(rng) geom_friction = model.geom_friction.at[FLOOR_GEOM_ID, 0].set( - jax.random.uniform(key, minval=0.2, maxval=.6) + jax.random.uniform(key, minval=0.2, maxval=0.6) ) rng, key = jax.random.split(rng) diff --git a/mujoco_playground/_src/manipulation/__init__.py b/mujoco_playground/_src/manipulation/__init__.py index 1d481dc02..69dbd2f83 100644 --- a/mujoco_playground/_src/manipulation/__init__.py +++ b/mujoco_playground/_src/manipulation/__init__.py @@ -20,8 +20,10 @@ from mujoco import mjx from mujoco_playground._src import mjx_env +from mujoco_playground._src.manipulation.aloha import distillation as aloha_distillation from mujoco_playground._src.manipulation.aloha import handover as aloha_handover -from mujoco_playground._src.manipulation.aloha import single_peg_insertion as aloha_peg +from mujoco_playground._src.manipulation.aloha import peg_insertion as aloha_peg_insertion +from mujoco_playground._src.manipulation.aloha import pick as aloha_pick from mujoco_playground._src.manipulation.franka_emika_panda import open_cabinet as panda_open_cabinet from mujoco_playground._src.manipulation.franka_emika_panda import pick as panda_pick from mujoco_playground._src.manipulation.franka_emika_panda import pick_cartesian as panda_pick_cartesian @@ -31,7 +33,9 @@ _envs = { "AlohaHandOver": aloha_handover.HandOver, - "AlohaSinglePegInsertion": aloha_peg.SinglePegInsertion, + "AlohaPick": aloha_pick.Pick, + "AlohaPegInsertion": aloha_peg_insertion.SinglePegInsertion, + "AlohaPegInsertionDistill": aloha_distillation.DistillPegInsertion, "PandaPickCube": panda_pick.PandaPickCube, "PandaPickCubeOrientation": panda_pick.PandaPickCubeOrientation, "PandaPickCubeCartesian": panda_pick_cartesian.PandaPickCubeCartesian, @@ -43,7 +47,9 @@ _cfgs = { "AlohaHandOver": aloha_handover.default_config, - "AlohaSinglePegInsertion": aloha_peg.default_config, + "AlohaPick": aloha_pick.default_config, + "AlohaPegInsertion": aloha_peg_insertion.default_config, + "AlohaPegInsertionDistill": aloha_distillation.default_config, "PandaPickCube": panda_pick.default_config, "PandaPickCubeOrientation": panda_pick.default_config, "PandaPickCubeCartesian": panda_pick_cartesian.default_config, @@ -56,6 +62,8 @@ _randomizer = { "LeapCubeRotateZAxis": leap_rotate_z.domain_randomize, "LeapCubeReorient": leap_cube_reorient.domain_randomize, + "AlohaPick": aloha_pick.domain_randomize, + "AlohaPegInsertionDistill": aloha_distillation.domain_randomize, } @@ -108,7 +116,9 @@ def load( An instance of the environment. """ if env_name not in _envs: - raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}") + raise ValueError( + f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}" + ) config = config or get_default_config(env_name) return _envs[env_name](config=config, config_overrides=config_overrides) diff --git a/mujoco_playground/_src/manipulation/aloha/README.md b/mujoco_playground/_src/manipulation/aloha/README.md new file mode 100644 index 000000000..7caa59f8f --- /dev/null +++ b/mujoco_playground/_src/manipulation/aloha/README.md @@ -0,0 +1,84 @@ +### Quickstart + + +**Pre-requisites** + +- *Handover, Pick, Peg Insertion:* The standard Playground setup +- *Behaviour Cloning for Peg Insertion:* Madrona MJX +- *Jax-to-ONNX Conversion:* Onnx, Tensorflow, tf2onnx + +```bash +# Train Aloha Handover. 
Documentation at https://github.com/google-deepmind/mujoco_playground/pull/29 +python learning/train_jax_ppo.py --env_name AlohaHandOver +``` + +```bash +# Plots for pick and peg-insertion at https://github.com/google-deepmind/mujoco_playground/pull/76 +cd +export PARAMS_PATH=mujoco_playground/_src/manipulation/aloha/params + +# Train a single arm to pick up a cube. +python learning/train_jax_ppo.py --env_name AlohaPick --domain_randomization --norender_final_policy --save_params_path $PARAMS_PATH/AlohaPick +sleep 0.5 + +# Train a biarm to insert a peg into a socket. Requires above policy. +python learning/train_jax_ppo.py --env_name AlohaPegInsertion --save_params_path $PARAMS_PATH/AlohaPegInsertion +sleep 0.5 + +# Train a student policy to insert a peg into a socket using *pixel inputs*. Requires above policy. +python mujoco_playground/experimental/bc_peg_insertion.py --domain-randomization --num-evals 0 --print-loss + +# Convert checkpoints from the above run to ONNX for easy robot deployment. +# ONNX policies are written to `experimental/jax2onnx/onnx_policies`. +python mujoco_playground/experimental/jax2onnx/aloha_nets_to_onnx.py --checkpoint_path +``` + +### Sim-to-Real Transfer of a Bi-Arm RL Policy via Pixel-Based Behaviour Cloning + +https://github.com/user-attachments/assets/205fe8b9-1773-4715-8025-5de13490d0da + +--- + +**Distillation** + +In this module, we demonstrate policy distillation: a straightforward method for deploying a simulation-trained reinforcement learning policy that initially uses privileged state observations (such as object positions). The process involves two steps: + +1. **Teacher Policy Training:** A state-based teacher policy is trained using RL. +2. **Student Policy Distillation:** The teacher is then distilled into a student policy via behaviour cloning (BC), where the student learns to map its observations $o_s(x)$ (e.g., exteroceptive RGBD images) to the teacher’s deterministic actions $\pi_t(o_t(x))$. For example, while both policies observe joint angles, the student uses RGBD images, whereas the teacher directly accesses (noisy) object positions. + +The distillation process—where the student uses left and right wrist-mounted RGBD cameras for exteroception—takes about **3 minutes** on an RTX4090. This rapid turnaround is due to three factors: + +1. [Very fast rendering](https://github.com/google-deepmind/mujoco_playground/blob/main/mujoco_playground/experimental/madrona_benchmarking/figures/cartpole_benchmark_full.png) provided by Madrona MJX. +2. The sample efficiency of behaviour cloning. +3. The use of low-resolution (32×32) rendering, which is sufficient for precise alignment given the wrist camera placement. + +For further details on the teacher policy and RGBD sim-to-real techniques, please refer to the [technical report](https://docs.google.com/presentation/d/1v50Vg-SJdy5HV5JmPHALSwph9mcVI2RSPRdrxYR3Bkg/edit?usp=sharing). + +--- + +**A Note on Sample Efficiency** + +Behaviour cloning (BC) can be orders of magnitude more sample-efficient than reinforcement learning. 
In our approach, we use an L2 loss defined as: + +$|| \pi_s(o_s(x)) - \pi_t(o_t(x)) ||$ + +In contrast, the policy gradient in RL generally takes the form: + +![Equation](https://latex.codecogs.com/svg.latex?\nabla_\theta%20J(\theta)%20=%20\mathbb{E}_{\tau%20\sim%20\theta}%20\left[\sum_t%20\nabla_\theta%20\log%20\pi_\theta(a_t%20|%20s_t)%20R(\tau)\right]) + +Two key observations highlight why BC’s direct supervision is more efficient: + +- **Explicit Loss Signal:** The BC loss compares against the teacher action, giving explicit feedback on how the action should be adjusted. In contrast, the policy gradient only provides directional guidance, instructing the optimizer to increase or decrease an action’s likelihood based solely on its downstream rewards. +- **Per-Dimension Supervision:** While the policy gradient applies a uniform weighting across all action dimensions, BC supplies per-dimension information, making it easier to scale to high-dimensional action spaces. + +--- + +**Frozen Encoders** + +*VisionMLP2ChanCIFAR10_OCP* is an Orbax checkpoint of [NatureCNN](https://github.com/google/brax/blob/241f9bc5bbd003f9cfc9ded7613388e2fe125af6/brax/training/networks.py#L153) (AtariCNN) pre-trained on CIFAR10 to achieve over 70% classification accuracy. We omit the supervised training code, see [this tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial5/Inception_ResNet_DenseNet.html) for reference. + +--- + +**Aloha Deployment Setup** + +For deployment, the ONNX policy is executed on the Aloha robot using a custom fork of [OpenPI](https://github.com/Physical-Intelligence/openpi) along with the Interbotix Aloha ROS packages. Acknowledgements to Kevin Zakka, Laura Smith and the Levine Lab for robot deployment setup! diff --git a/mujoco_playground/_src/manipulation/aloha/aloha_constants.py b/mujoco_playground/_src/manipulation/aloha/aloha_constants.py index 1697fcbb0..3fd4386d1 100644 --- a/mujoco_playground/_src/manipulation/aloha/aloha_constants.py +++ b/mujoco_playground/_src/manipulation/aloha/aloha_constants.py @@ -50,3 +50,25 @@ "right/left_finger", "right/right_finger", ] + +LEFT_JOINTS = [ + "left/waist", + "left/shoulder", + "left/elbow", + "left/forearm_roll", + "left/wrist_angle", + "left/wrist_rotate", + "left/left_finger", + "left/right_finger", +] + +RIGHT_JOINTS = [ + "right/waist", + "right/shoulder", + "right/elbow", + "right/forearm_roll", + "right/wrist_angle", + "right/wrist_rotate", + "right/left_finger", + "right/right_finger", +] diff --git a/mujoco_playground/_src/manipulation/aloha/depth_noise.py b/mujoco_playground/_src/manipulation/aloha/depth_noise.py new file mode 100644 index 000000000..5c4dcfd5b --- /dev/null +++ b/mujoco_playground/_src/manipulation/aloha/depth_noise.py @@ -0,0 +1,267 @@ +# Copyright 2025 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Utilities for depth noise.""" + +import jax +import jax.numpy as jp +import numpy as np + + +def _bilinear_interpolate(image, y, x): + """ + Bilinearly interpolate a 2D image at floating-point (y,x) locations, + using 'nearest' mode behavior for out-of-bound coordinates. + + Parameters: + image : jp.ndarray of shape (H, W) + y : array of y coordinates (any shape) + x : array of x coordinates (same shape as y) + + Returns: + Interpolated values at the provided coordinates. + """ + height, width = image.shape + + # Clamp coordinates to the valid range. + y_clamped = jp.clip(y, 0.0, height - 1.0) + x_clamped = jp.clip(x, 0.0, width - 1.0) + + # Get the integer parts. + y0 = jp.floor(y_clamped).astype(jp.int32) + x0 = jp.floor(x_clamped).astype(jp.int32) + + # For the "upper" indices, if we're at the boundary, stay at the same index. + y1 = jp.where(y0 < height - 1, y0 + 1, y0) + x1 = jp.where(x0 < width - 1, x0 + 1, x0) + + # Compute the fractional parts. + dy = y_clamped - y0.astype(y_clamped.dtype) + dx = x_clamped - x0.astype(x_clamped.dtype) + + # Gather pixel values at the four corners. + val_tl = image[y0, x0] # top-left + val_tr = image[y0, x1] # top-right + val_bl = image[y1, x0] # bottom-left + val_br = image[y1, x1] # bottom-right + + # Compute the bilinear interpolated result. + # Need to be careful to avoid dead pixels in image edges. + return ( + val_tl * (1.0 - dx) * (1.0 - dy) + + val_tr * dx * (1.0 - dy) + + val_bl * (1.0 - dx) * dy + + val_br * dx * dy + ) + + +def kinect_noise(key, depth, *, sigma_s=0.5, sigma_d=1 / 6, baseline=35130): + """ + Apply noise based on the Kinect. Increased noise with distance. + + Parameters: + depth : 2D numpy array of ground truth depth values. + sigma_s : Std. dev. of spatial shift (in pixels). + sigma_d : Std. dev. of the Gaussian noise added in disparity. + baseline : Constant used for converting depth to disparity. + + Returns: + noisy_depth: 2D numpy array with noisy depth. + """ + if depth.ndim == 3: + depth = depth[..., 0] + + height, width = depth.shape + + # Create a meshgrid for pixel coordinates. + grid_y, grid_x = jp.mgrid[0:height, 0:width].astype(jp.float32) + + # Random shifts in x and y (sampled from Gaussian). + key, key_shift = jax.random.split(key) + shift_x, shift_y = jax.random.normal(key_shift, (2, height, width)) * sigma_s + + # Shifted coordinates. + shifted_x = grid_x + shift_x + shifted_y = grid_y + shift_y + + # Bilinearly interpolate depth at the shifted locations. + shifted_depth = _bilinear_interpolate( + depth, shifted_y.ravel(), shifted_x.ravel() + ).reshape(height, width) + + # Convert depth to disparity. + eps = 1e-6 # small epsilon to avoid division by zero + disparity = baseline / (shifted_depth + eps) + + # Add IID Gaussian noise to disparity. + key, key_noise = jax.random.split(key) + disparity_noisy = ( + disparity + jax.random.normal(key_noise, (height, width)) * sigma_d + ) + + # Quantise disparity (round to nearest integer). + disparity_quantized = jp.round(disparity_noisy) + + # Convert quantised disparity back to depth. + noisy_depth = baseline / (disparity_quantized + eps) + + if depth.ndim == 3: + noisy_depth = jp.expand_dims(noisy_depth, axis=-1) + + return noisy_depth + + +def edge_noise(key, depth, *, grad_threshold=0.05, noise_multiplier=10): + """ + Depth cameras are expected to occasionally lose pixels at edges. 
+ When the spatial gradient of the depth is greater than threshold, theres a + chance the pixels are dropped. + Then, randomly jitter those dropped pixels. + Note that the proper way to do this requires the surface normals of everything + in the scene. + + Args: + grad_threshold: below this, no dropout. + noise_multiplier: higher values mean more dropout. + """ + # Compute gradients along the x and y directions. + # gradient returns [gradient_along_axis0, gradient_along_axis1]. + grad_y, grad_x = jp.gradient(depth) # each is (H, W) + + # Compute the magnitude of the depth gradient. + grad_mag = jp.sqrt(grad_x**2 + grad_y**2) + + # Probability that you lose that pixel. + p_lost = jp.arctan(noise_multiplier * grad_mag) + + p_lost = p_lost * (p_lost > grad_threshold).astype(jp.float32) + + # Sample a mask. + key_dropout, key = jax.random.split(key) + mask = ( + jax.random.uniform(key_dropout, depth.shape) < p_lost + ) # if true, then drop. + + # Scatter the mask. + height, width = depth.shape + grid_y, grid_x = jp.mgrid[0:height, 0:width].astype(jp.int32) + + # Random coordinate shifts in x and y, uniformly 0, 1. + key_shift, key = jax.random.split(key) + shift_x, shift_y = jax.random.randint( + key_shift, (2, height, width), minval=0, maxval=2 + ) + + # Shifted coordinates. + shifted_x = grid_x + shift_x + shifted_y = grid_y + shift_y + + # Ensure the shifted coordinates are within bounds. + shifted_x = jp.clip(shifted_x, 0, width - 1) + shifted_y = jp.clip(shifted_y, 0, height - 1) + + # Fancy indexing. + mask_shifted = mask[shifted_y, shifted_x] + + # Set those values to 0. + depth_noisy = depth * (1 - mask_shifted).astype(jp.float32) + return depth_noisy + + +def random_dropout(key, depth_image, *, p=0.006): + key_dropout, key = jax.random.split(key) + mask = jax.random.bernoulli(key_dropout, p, depth_image.shape) + depth_noisy = depth_image * (1 - mask).astype(jp.float32) + return depth_noisy + + +def _np_bresenham_line(x0, y0, x1, y1): + """ + Compute the list of pixels along a line from (x0,y0) to (x1,y1) + using Bresenham's algorithm. + Returns a list of (x, y) tuples. + """ + points = [] + dx = abs(x1 - x0) + dy = abs(y1 - y0) + x, y = x0, y0 + sx = 1 if x0 < x1 else -1 + sy = 1 if y0 < y1 else -1 + if dx > dy: + err = dx / 2.0 + while x != x1: + points.append((x, y)) + err -= dy + if err < 0: + y += sy + err += dx + x += sx + else: + err = dy / 2.0 + while y != y1: + points.append((x, y)) + err -= dx + if err < 0: + x += sx + err += dy + y += sy + points.append((x1, y1)) + return points + + +def _np_draw_line(img, start, end, color): + """ + Draw a line of thickness 1. + Start, end are (x, y) tuples. + """ + height, width = img.shape[:2] + for x, y in _np_bresenham_line(*start, *end): + if 0 <= x < width and 0 <= y < height: + img[y, x] = color + return img + + +def np_get_line_bank(height, width, bank_size=100, color_range=None): + """ + Get a bank of random lines. Not jax-compatible. + Returns a bank of size: + (bank_size, H, W) + where each element is a white image with up to max_lines lines randomly + drawn on it. 
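+ The bank is pre-sampled with NumPy on the host; at runtime a random entry
+ is drawn and composited over a depth image via apply_line_noise.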
+ """ + if color_range is None: + color_range = [0, 0.4] + + max_lines = 16 + bank = [] + for _ in range(bank_size): + img = np.zeros((height, width), dtype=np.float32) + num_lines = np.random.randint(1, max_lines + 1) + for _ in range(num_lines): + start = np.random.randint(width), np.random.randint(height) + theta = np.random.uniform(0, 2 * np.pi) + length = np.random.randint(2, 6) + end = ( + start[0] + length * np.cos(theta), + start[1] + length * np.sin(theta), + ) + end = int(end[0]), int(end[1]) + color = np.random.uniform(color_range[0], color_range[1]) + img = _np_draw_line(img, start, end, color) + bank.append(img) + return np.stack(bank) + + +def apply_line_noise(img, line_noise): + return jp.where(line_noise != 0, line_noise, img) diff --git a/mujoco_playground/_src/manipulation/aloha/distillation.py b/mujoco_playground/_src/manipulation/aloha/distillation.py new file mode 100644 index 000000000..569a8390c --- /dev/null +++ b/mujoco_playground/_src/manipulation/aloha/distillation.py @@ -0,0 +1,594 @@ +# Copyright 2025 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Distillation module for sim-to-real +transfer of ALOHA peg insertion.""" + +import pathlib +from typing import Any, Dict, Optional, Union + +from brax.training import networks +from brax.training.agents.ppo import train as ppo_train +import jax +import jax.numpy as jp +from ml_collections import config_dict +from mujoco import mjx +import numpy as np +from orbax import checkpoint as ocp + +from mujoco_playground._src.manipulation.aloha import depth_noise +from mujoco_playground._src.manipulation.aloha import peg_insertion +from mujoco_playground._src.manipulation.aloha import pick_base +from mujoco_playground._src.manipulation.franka_emika_panda.randomize_vision import perturb_orientation + + +def default_vision_config() -> config_dict.ConfigDict: + return config_dict.create( + gpu_id=0, + render_batch_size=1024, + randomization_fn=None, + render_width=64, + render_height=64, + enabled_geom_groups=[1, 2, 5], + use_rasterizer=False, + enabled_cameras=[4, 5], + ) + + +def default_config() -> ( + config_dict.ConfigDict +): # TODO :Clean up. Or just import? 
+ """Returns the default config for bring_to_target tasks.""" + config = config_dict.create( + ctrl_dt=0.05, + sim_dt=0.005, + episode_length=160, + action_repeat=1, + action_scale=0.02, + action_history_length=4, + max_obs_delay=4, + reset_buffer_size=10, + vision=True, + vision_config=default_vision_config(), + obs_noise=config_dict.create( + depth=True, + brightness=[0.5, 2.5], + grad_threshold=0.05, + noise_multiplier=10, + obj_pos=0.015, # meters + obj_vel=0.015, # meters/s + obj_angvel=0.2, + gripper_box=0.015, # meters + obj_angle=7.5, # degrees + robot_qpos=0.1, # radians + robot_qvel=0.1, # radians/s + eef_pos=0.02, # meters + eef_angle=5.0, # degrees + ), + reward_config=config_dict.create( + scales=config_dict.create(peg_insertion=8, obj_rot=0.5), + sparse=config_dict.create(success=0, drop=-10, final_grasp=10), + reg=config_dict.create( + robot_target_qpos=1, joint_vel=1, grip_pos=0.5 # no sliding! + ), + ), + ) + return config + + +def adjust_brightness(img, scale): + """Adjusts the brightness of an image by scaling the pixel values.""" + return jp.clip(img * scale, 0, 1) + + +def load_frozen_encoder_params(): + vision_mlp = networks.VisionMLP(layer_sizes=(0,), policy_head=False) + fpath = pathlib.Path(__file__).parent / 'params' / 'VisionMLP2ChanCIFAR10_OCP' + orbax_checkpointer = ocp.PyTreeCheckpointer() + sample_obs = { + 'pixels/view_0': jp.ones((1, 32, 32, 3)), + 'pixels/view_1': jp.ones((1, 32, 32, 3)), + } + target = vision_mlp.init(jax.random.PRNGKey(0), sample_obs) + return orbax_checkpointer.restore(fpath, item=target) + + +def get_frozen_encoder_fn(): + """Returns a function that encodes observations using a frozen vision MLP.""" + vision_mlp = networks.VisionMLP(layer_sizes=(0,), policy_head=False) + params = load_frozen_encoder_params() + + def encoder_fn(obs: Dict): + stacked = {} + for i in range(2): + stacked[f'pixels/view_{i}'] = obs[f'pixels/view_{i}'][None, ...] + return vision_mlp.apply(params, stacked)[0] # unbatch + + return encoder_fn + + +class DistillPegInsertion(peg_insertion.SinglePegInsertion): + """Distillation environment for peg insertion task with vision capabilities. + + This class extends the PegInsertion environment to support policy distillation + with vision-based observations, including depth and RGB camera inputs. 
+ """ + + def __init__( + self, + config: config_dict.ConfigDict = default_config(), + config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None, + ): + super().__init__(config, config_overrides, distill=True) + self._vision = config.vision + self.encoder_fn = get_frozen_encoder_fn() + if self._vision: + # Import here to avoid dependency issues when vision is disabled + # pylint: disable=import-outside-toplevel + from madrona_mjx.renderer import BatchRenderer + + render_height = self._config.vision_config.render_height + render_width = self._config.vision_config.render_width + + self.renderer = BatchRenderer( + m=self._mjx_model, + gpu_id=self._config.vision_config.gpu_id, + num_worlds=self._config.vision_config.render_batch_size, + batch_render_view_height=render_height, + batch_render_view_width=render_width, + enabled_geom_groups=np.asarray( + self._config.vision_config.enabled_geom_groups + ), + enabled_cameras=np.asarray( + self._config.vision_config.enabled_cameras + ), + add_cam_debug_geo=False, + use_rasterizer=self._config.vision_config.use_rasterizer, + viz_gpu_hdls=None, + ) + self.max_depth = {'pixels/view_0': 0.4, 'pixels/view_1': 0.4} + + if self._config.obs_noise.depth: + # color range based on max_depth values. + # Pre-sample random lines for simplicity. + max_depth = self.max_depth['pixels/view_0'] + self.line_bank = jp.array( + depth_noise.np_get_line_bank( + render_height, + render_width, + bank_size=16384, + color_range=[max_depth * 0.2, max_depth * 0.85], + ) + ) + + def reset_color_noise(self, info): + info['rng'], rng_brightness = jax.random.split(info['rng']) + + info['brightness'] = jax.random.uniform( + rng_brightness, + (), + minval=self._config.obs_noise.brightness[0], + maxval=self._config.obs_noise.brightness[1], + ) + + info['color_noise'] = {} + info['shade_noise'] = {} # Darkness of the colored object. + + color_noise_scales = {0: 0.3, 2: 0.05} + shade_noise_mins = {0: 0.5, 2: 0.9} + shade_noise_maxes = {0: 1.0, 2: 2.0} + + def generate_noise(chan): + info['rng'], rng_noise, rng_shade = jax.random.split(info['rng'], 3) + noise = jax.random.uniform( + rng_noise, (1, 3), minval=0, maxval=color_noise_scales[chan] + ) + noise = noise.at[0, chan].set(0) + info['color_noise'][chan] = noise + info['shade_noise'][chan] = jax.random.uniform( + rng_shade, + (), + minval=shade_noise_mins[chan], + maxval=shade_noise_maxes[chan], + ) + + for chan in [0, 2]: + generate_noise(chan) + + def _get_obs_distill(self, data, info, init=False): + obs_pick = self._get_obs_pick(data, info) + obs_insertion = jp.concatenate([obs_pick, self._get_obs_dist(data, info)]) + if not self._vision: + state_wt = jp.concatenate([ + obs_insertion, + (info['_steps'] / self._config.episode_length).reshape(1), + ]) + return {'state_with_time': state_wt} + if init: + info['render_token'], rgb, depth = self.renderer.init( + data, self._mjx_model + ) + else: + _, rgb, depth = self.renderer.render(info['render_token'], data) + # Process depth. 
+ info['rng'], rng_l, rng_r = jax.random.split(info['rng'], 3) + dmap_l = self.process_depth(depth, 0, 'pixels/view_0', rng_l) + r_dmap_l = jax.image.resize(dmap_l, (8, 8, 1), method='nearest') + dmap_r = self.process_depth(depth, 1, 'pixels/view_1', rng_r) + r_dmap_r = jax.image.resize(dmap_r, (8, 8, 1), method='nearest') + + rgb_l = jp.asarray(rgb[0][..., :3], dtype=jp.float32) / 255.0 + rgb_r = jp.asarray(rgb[1][..., :3], dtype=jp.float32) / 255.0 + + info['rng'], rng_noise1, rng_noise2 = jax.random.split(info['rng'], 3) + rgb_l = adjust_brightness( + self.rgb_noise(rng_noise1, rgb_l, info), info['brightness'] + ) + rgb_r = adjust_brightness( + self.rgb_noise(rng_noise2, rgb_r, info), info['brightness'] + ) + latent_rgb_l, latent_rgb_r = jp.split( + self.encoder_fn({'pixels/view_0': rgb_l, 'pixels/view_1': rgb_r}), + 2, + axis=-1, + ) + + # Required for supervision to stay still. + socket_pos = data.xpos[self._socket_body] + dist_from_hidden = jp.linalg.norm(socket_pos[:2] - jp.array([-0.4, 0.33])) + socket_hidden = jp.where(dist_from_hidden < 3e-2, 1.0, 0.0).reshape(1) + + peg_pos = data.xpos[self._peg_body] + dist_from_hidden = jp.linalg.norm(peg_pos[:2] - jp.array([0.4, 0.33])) + peg_hidden = jp.where(dist_from_hidden < 3e-2, 1.0, 0.0).reshape(1) + + obs = { + 'proprio': self._get_proprio(data, info), + 'pixels/view_0': dmap_l, # view_i for debugging only + 'pixels/view_1': dmap_r, + 'pixels/view_2': rgb_l, + 'pixels/view_3': rgb_r, + 'latent_0': latent_rgb_l, # actual policy inputs + 'latent_1': latent_rgb_r, + 'latent_2': r_dmap_l.ravel(), + 'latent_3': r_dmap_r.ravel(), + 'socket_hidden': socket_hidden, + 'peg_hidden': peg_hidden, + } + return obs + + def _get_proprio(self, data: mjx.Data, info: Dict) -> jax.Array: + """Get the proprio observations for the real sim2real.""" + info['rng'], rng = jax.random.split(info['rng']) + # qpos_noise = jax.random.uniform(rng, data.qpos.shape) - 0.5 + qpos_noise = jax.random.uniform( + rng, (16,), minval=0, maxval=self._config.obs_noise.robot_qpos + ) + qpos_noise = qpos_noise * jp.array(pick_base.QPOS_NOISE_MASK_SINGLE * 2) + qpos = data.qpos[:16] + qpos_noise + l_posobs = qpos[self._left_qposadr] + r_posobs = qpos[self._right_qposadr] + + def dupll(arr): + # increases size of array by 1 by dupLicating its Last element. 
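+ # Each arm observes 8 joint positions (two finger joints) but has only
+ # 7 motor targets (a single gripper command), so the last target is repeated.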
+ return jp.concatenate([arr, arr[-1:]]) + + assert info['motor_targets'].shape == (14,), print( + info['motor_targets'].shape + ) + + l_velobs = l_posobs - dupll(info['motor_targets'][:7]) + r_velobs = r_posobs - dupll(info['motor_targets'][7:]) + proprio_list = [l_posobs, r_posobs, l_velobs, r_velobs] + + switcher = [info['has_switched'].astype(float).reshape(1)] + + proprio = jp.concat(proprio_list + switcher) + return proprio + + def add_depth_noise(self, key, img: jp.ndarray): + """Add realistic depth sensor noise to the depth image.""" + render_width = self._config.vision_config.render_width + render_height = self._config.vision_config.render_height + assert img.shape == (render_height, render_width, 1) + # squeeze + img = img.squeeze(-1) + grad_threshold = self._config.obs_noise.grad_threshold + noise_multiplier = self._config.obs_noise.noise_multiplier + + key_edge_noise, key = jax.random.split(key) + img = depth_noise.edge_noise( + key_edge_noise, + img, + grad_threshold=grad_threshold, + noise_multiplier=noise_multiplier, + ) + key_kinect, key = jax.random.split(key) + img = depth_noise.kinect_noise(key_kinect, img) + key_dropout, key = jax.random.split(key) + img = depth_noise.random_dropout(key_dropout, img) + key_line, key = jax.random.split(key) + noise_idx = jax.random.randint(key_line, (), 0, len(self.line_bank)) + img = depth_noise.apply_line_noise(img, self.line_bank[noise_idx]) + + # With a low probability, return an all-black image. + p_blackout = 0.02 # once per 2.5 sec. + key_blackout, key = jax.random.split(key) + blackout = jax.random.bernoulli(key_blackout, p=p_blackout) + img = jp.where(blackout, 0.0, img) + + return img[..., None] + + def process_depth( + self, + depth, + chan: int, + view_name: str, + key: Optional[jp.ndarray] = None, + ): + """Process depth image with normalization and optional noise.""" + img_size = self._config.vision_config.render_width + num_cams = len(self._config.vision_config.enabled_cameras) + assert depth.shape == (num_cams, img_size, img_size, 1) + depth = depth[chan] + max_depth = self.max_depth[view_name] + # max_depth = info['max_depth'] + too_big = jp.where(depth > max_depth, 0, 1) + depth = depth * too_big + if self._config.obs_noise.depth and key is not None: + depth = self.add_depth_noise(key, depth) + return depth / max_depth # Normalize + + def rgb_noise(self, key, img, info): + """Apply domain randomization noise to RGB images.""" + # Assumes images are already normalized. + pixel_noise = 0.03 + + # Add noise to all channels and clip + key_noise, key = jax.random.split(key) + noise = jax.random.uniform( + key_noise, img.shape, minval=0, maxval=pixel_noise + ) + img += noise + img = jp.clip(img, 0, 1) + + return img + + @property + def observation_size(self): + """Return the observation space dimensions for each observation type.""" + # Manually set observation size; default method breaks madrona MJX. 
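+ # Keys and shapes must stay in sync with what _get_obs_distill and the
+ # parent environment's observation functions return.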
+ ret = { + 'has_switched': (1,), + 'proprio': (33,), + 'state': (109,), + 'state_pickup': (106,), + 'peg_hidden': (1,), + 'socket_hidden': (1,), + 'privileged': (110,), + } + if self._vision: + ret.update({ + 'pixels/view_0': (8, 8, 1), + 'pixels/view_1': (8, 8, 1), + 'pixels/view_2': (32, 32, 3), + 'pixels/view_3': (32, 32, 3), + 'latent_0': (64,), + 'latent_1': (64,), + 'latent_2': (64,), + 'latent_3': (64,), + }) + else: + ret['state_with_time'] = (110,) + return ret + + +def make_teacher_policy(): + """Create a teacher policy for distillation from pre-trained models.""" + env = peg_insertion.SinglePegInsertion() + f_pick_teacher = ( + pathlib.Path(__file__).parent / 'params' / 'AlohaPick' / 'checkpoints' + ) + f_pick_teacher = peg_insertion.get_latest_checkpoint(f_pick_teacher) + f_insert_teacher = ( + pathlib.Path(__file__).parent + / 'params' + / 'AlohaPegInsertion' + / 'checkpoints' + ) + f_insert_teacher = peg_insertion.get_latest_checkpoint(f_insert_teacher) + + teacher_pick_policy = peg_insertion.load_brax_policy( + f_pick_teacher.as_posix(), + 'AlohaPick', + distill=True, + config_fname='ppo_network_config.json', + ) + + teacher_insert_policy = peg_insertion.load_brax_policy( + f_insert_teacher.as_posix(), + 'AlohaPegInsertion', + distill=True, + config_fname='ppo_network_config.json', + ) + obs_keys = ppo_train._remove_pixels(env.observation_size.keys()) + + @jax.jit + def teacher_inference_fn(obs, rng): + l_obs, r_obs = jp.split(obs['state_pickup'], 2, axis=-1) + l_act, l_extras = teacher_pick_policy({'state': l_obs}, None) + r_act, r_extras = teacher_pick_policy({'state': r_obs}, None) + + if 'socket_hidden' in obs: + l_act = jp.where(obs['socket_hidden'], jp.zeros_like(l_act), l_act) + l_extras = jax.tree_util.tree_map( + lambda x: jp.where(obs['socket_hidden'], jp.zeros_like(x), x), + l_extras, + ) + r_act = jp.where(obs['peg_hidden'], jp.zeros_like(r_act), r_act) + r_extras = jax.tree_util.tree_map( + lambda x: jp.where(obs['peg_hidden'], jp.zeros_like(x), x), r_extras + ) + act_1 = jp.concatenate([l_act, r_act], axis=-1) + act_extras_1 = jax.tree_util.tree_map( + lambda x, y: jp.concatenate([x, y], axis=-1), l_extras, r_extras + ) + obs_2 = {k: obs[k] for k in obs_keys} + act_2, act_extras_2 = teacher_insert_policy(obs_2, None) + + # Select a pair based on condition. + c = obs['has_switched'].reshape(-1, 1) # 0 for policy 1; 1 for policy 2 + act, extras = jax.tree_util.tree_map( + lambda x, y: (1 - c) * x + c * y, + (act_1, act_extras_1), + (act_2, act_extras_2), + ) + return act, extras + + return teacher_inference_fn + + +def domain_randomize(model: mjx.Model, rng: jax.Array): + """Apply domain randomization to camera positions, lights, and materials.""" + cam_ids = default_vision_config().enabled_cameras + mj_model = DistillPegInsertion(config_overrides={'vision': False}).mj_model + table_geom_id = mj_model.geom('table').id + b_ids = [ + mj_model.geom(f'socket-{wall}').id for wall in ['B', 'T', 'L', 'R'] + ] # blue geoms + r_ids = [ + mj_model.geom('red_peg').id, + mj_model.geom('socket-W').id, + ] # red geoms + + @jax.vmap + def rand(rng): + # Geom RGBA + geom_rgba = model.geom_rgba + + # MatID needs to change to enable RGBA randomization. + geom_matid = model.geom_matid.at[:].set(-1) + for id in b_ids: + rng_obj, rng = jax.random.split(rng) + obj_hue = jax.random.uniform(rng_obj, (), minval=0.5, maxval=1.0) + geom_rgba = geom_rgba.at[id, 2].set(obj_hue) # randomize blue dim. + # Add some noise to the other two dims. 
+ rng_color, rng = jax.random.split(rng) # Doesn't work. + color_noise = jax.random.uniform(rng_color, (2,), minval=0, maxval=0.12) + geom_rgba = geom_rgba.at[id, :2].set(geom_rgba[id, :2] + color_noise) + geom_matid = geom_matid.at[id].set(-2) + for id in r_ids: + rng_obj, rng = jax.random.split(rng) + obj_hue = jax.random.uniform(rng_obj, (), minval=0.0, maxval=1.0) + geom_rgba = geom_rgba.at[id, 0].set(obj_hue) + rng_color, rng = jax.random.split(rng) + color_noise = jax.random.uniform(rng_color, (2,), minval=0, maxval=0.07) + geom_rgba = geom_rgba.at[id, 1:3].set(geom_rgba[id, 1:3] + color_noise) + geom_matid = geom_matid.at[id].set(-2) + + # Set the floor to a random gray-ish color. + gray_value = jax.random.uniform(rng, (), minval=0.0, maxval=0.1) + floor_rgba = ( + geom_rgba[table_geom_id] + .at[:3] + .set(jp.array([gray_value, gray_value, gray_value])) + ) + geom_rgba = geom_rgba.at[table_geom_id].set(floor_rgba) + + # geom_matid = geom_matid.at[peg_geom_id].set(-2) + # geom_matid = geom_matid.at[table_geom_id].set(-2) + + # Cameras + cam_pos = model.cam_pos + cam_quat = model.cam_quat + for cur_idx in cam_ids: + rng, rng_pos, rng_ori = jax.random.split(rng, 3) + offset_scales = jp.array([0.0125, 0.005, 0.005]) + cam_offset = ( + jax.random.uniform(rng_pos, (3,), minval=-1, maxval=1) * offset_scales + ) + cam_pos = cam_pos.at[cur_idx].set(cam_pos[cur_idx] + cam_offset) + cam_quat = cam_quat.at[cur_idx].set( + perturb_orientation(rng_ori, cam_quat[cur_idx], 5) + ) + + n_lights = model.light_pos.shape[0] # full: (n_lights, 3) + + # Light position + rng, rng_pos = jax.random.split(rng) + offset_scales = 10 * jp.array([0.1, 0.1, 0.1]).reshape(1, 3) + light_offset = ( + jax.random.uniform(rng_pos, model.light_pos.shape, minval=-1, maxval=1) + * offset_scales + ) + light_pos = model.light_pos + light_offset + + assert model.light_dir.shape == (n_lights, 3) + # Perturb the light direction + light_dir = model.light_dir + for i_light in range(n_lights): + rng, rng_ldir = jax.random.split(rng) + nom_dir = model.light_dir[i_light] + light_dir = light_dir.at[i_light].set( + perturb_orientation(rng_ldir, nom_dir, 10) + ) + + # Cast shadows + rng, rng_lsha = jax.random.split(rng) + light_castshadow = jax.random.bernoulli( + rng_lsha, 0.75, shape=(n_lights,) + ).astype(jp.float32) + + return ( + cam_pos, + cam_quat, + geom_rgba, + geom_matid, + light_pos, + light_dir, + light_castshadow, + ) + + ( + cam_pos, + cam_quat, + geom_rgba, + geom_matid, + light_pos, + light_dir, + light_castshadow, + ) = rand(rng) + in_axes = jax.tree_util.tree_map(lambda x: None, model) + in_axes = in_axes.tree_replace({ + 'cam_pos': 0, + 'cam_quat': 0, + 'geom_rgba': 0, + 'geom_matid': 0, + 'light_pos': 0, + 'light_dir': 0, + 'light_castshadow': 0, + }) + + model = model.tree_replace({ + 'cam_pos': cam_pos, + 'cam_quat': cam_quat, + 'geom_rgba': geom_rgba, + 'geom_matid': geom_matid, + 'light_pos': light_pos, + 'light_dir': light_dir, + 'light_castshadow': light_castshadow, + }) + + return model, in_axes diff --git a/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/_CHECKPOINT_METADATA b/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/_CHECKPOINT_METADATA new file mode 100644 index 000000000..fd68a4b65 --- /dev/null +++ b/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/_CHECKPOINT_METADATA @@ -0,0 +1 @@ +{"init_timestamp_nsecs": 1744640712443995440, "commit_timestamp_nsecs": 1744640712503807685} \ No newline at end of file diff --git 
a/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/_METADATA b/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/_METADATA new file mode 100644 index 000000000..5179c7670 --- /dev/null +++ b/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/_METADATA @@ -0,0 +1 @@ +{"tree_metadata": {"('params', 'CNN_0', 'Conv_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "CNN_0", "key_type": 2}, {"key": "Conv_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'CNN_0', 'Conv_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "CNN_0", "key_type": 2}, {"key": "Conv_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'CNN_0', 'Conv_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "CNN_0", "key_type": 2}, {"key": "Conv_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'CNN_1', 'Conv_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "CNN_1", "key_type": 2}, {"key": "Conv_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'CNN_1', 'Conv_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "CNN_1", "key_type": 2}, {"key": "Conv_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}, "('params', 'CNN_1', 'Conv_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "CNN_1", "key_type": 2}, {"key": "Conv_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false}}}, "use_zarr3": false, "store_array_data_equal_to_fill_value": true} \ No newline at end of file diff --git a/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/_sharding b/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/_sharding new file mode 100644 index 000000000..6e2dab9f2 --- /dev/null +++ b/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/_sharding @@ -0,0 +1 @@ +{"cGFyYW1zLkNOTl8wLkNvbnZfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLkNOTl8wLkNvbnZfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLkNOTl8wLkNvbnZfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLkNOTl8xLkNvbnZfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLkNOTl8xLkNvbnZfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLkNOTl8xLkNvbnZfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}"} \ No newline at end of file diff --git a/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/d/d1f00de212d2b0f1dcc7e09a84bd7409 b/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/d/d1f00de212d2b0f1dcc7e09a84bd7409 new file mode 100644 index 000000000..865df911a Binary files /dev/null and 
b/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/d/d1f00de212d2b0f1dcc7e09a84bd7409 differ diff --git a/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/manifest.ocdbt b/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/manifest.ocdbt new file mode 100644 index 000000000..5f58fbaeb Binary files /dev/null and b/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/manifest.ocdbt differ diff --git a/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/ocdbt.process_0/d/31e6971d2d3cce5dc906d2d18f697f8f b/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/ocdbt.process_0/d/31e6971d2d3cce5dc906d2d18f697f8f new file mode 100644 index 000000000..30eb51a52 Binary files /dev/null and b/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/ocdbt.process_0/d/31e6971d2d3cce5dc906d2d18f697f8f differ diff --git a/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/ocdbt.process_0/d/ca46a10ced5fb385297e83bf9ed74f12 b/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/ocdbt.process_0/d/ca46a10ced5fb385297e83bf9ed74f12 new file mode 100644 index 000000000..29edda73a Binary files /dev/null and b/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/ocdbt.process_0/d/ca46a10ced5fb385297e83bf9ed74f12 differ diff --git a/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/ocdbt.process_0/d/dc2124f9b1811e062867aab6a6ebbc1f b/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/ocdbt.process_0/d/dc2124f9b1811e062867aab6a6ebbc1f new file mode 100644 index 000000000..2b8082917 Binary files /dev/null and b/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/ocdbt.process_0/d/dc2124f9b1811e062867aab6a6ebbc1f differ diff --git a/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/ocdbt.process_0/d/fb854afc1c01e908601ed5f2525e8eb7 b/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/ocdbt.process_0/d/fb854afc1c01e908601ed5f2525e8eb7 new file mode 100644 index 000000000..4a3644893 Binary files /dev/null and b/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/ocdbt.process_0/d/fb854afc1c01e908601ed5f2525e8eb7 differ diff --git a/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/ocdbt.process_0/manifest.ocdbt b/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/ocdbt.process_0/manifest.ocdbt new file mode 100644 index 000000000..bde4ebe93 Binary files /dev/null and b/mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP/ocdbt.process_0/manifest.ocdbt differ diff --git a/mujoco_playground/_src/manipulation/aloha/peg_insertion.py b/mujoco_playground/_src/manipulation/aloha/peg_insertion.py new file mode 100644 index 000000000..a691329df --- /dev/null +++ b/mujoco_playground/_src/manipulation/aloha/peg_insertion.py @@ -0,0 +1,637 @@ +# Copyright 2025 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Module for the PegInsertion task in the Aloha manipulation environment.""" + +import functools +import pathlib +from typing import Any, Dict, Optional, Tuple, Union + +from brax.training.acme import running_statistics +# from brax.training.agents.bc import networks as bc_networks +from brax.training.agents.bc import checkpoint as bc_checkpoint +from brax.training.agents.ppo import checkpoint as ppo_checkpoint +from brax.training.agents.ppo import networks as ppo_networks +import flax +import jax +import jax.numpy as jp +from ml_collections import config_dict +from mujoco import mjx +from mujoco.mjx._src import math +import numpy as np + +from mujoco_playground._src import mjx_env +from mujoco_playground._src import reward as reward_util +from mujoco_playground._src.manipulation.aloha import pick_base +from mujoco_playground._src.mjx_env import State # pylint: disable=g-importing-member +from mujoco_playground.config import manipulation_params + + +def default_config() -> config_dict.ConfigDict: # TODO :Clean up. + """Returns the default config for bring_to_target tasks.""" + config = config_dict.create( + ctrl_dt=0.05, + sim_dt=0.005, + episode_length=160, + action_repeat=1, + action_scale=0.02, + action_history_length=4, + max_obs_delay=4, + reset_buffer_size=10, + obs_noise=config_dict.create( + depth=True, + brightness=[1.0, 3.0], + grad_threshold=0.05, + noise_multiplier=10, + obj_pos=0.015, # meters + obj_vel=0.015, # meters/s + obj_angvel=0.2, + gripper_box=0.015, # meters + obj_angle=7.5, # degrees + robot_qpos=0.1, # radians + robot_qvel=0.1, # radians/s + eef_pos=0.02, # meters + eef_angle=5.0, # degrees + ), + reward_config=config_dict.create( + scales=config_dict.create(peg_insertion=8, obj_rot=0.5), + sparse=config_dict.create(success=0, drop=-10, final_grasp=10), + reg=config_dict.create( + robot_target_qpos=1, joint_vel=1, grip_pos=0.5 # no sliding! + ), + ), + ) + return config + + +def get_latest_checkpoint(path: Union[str, pathlib.Path]): + """ + Get the latest checkpoint from a directory. Assumes checkpoints names are + left-padded ascending. For example, 000005079040, 000010158080, ... + """ + path = pathlib.Path(path) + # ignore anything ending in .json + checkpoints = [p for p in path.glob("*") if not p.name.endswith(".json")] + # sort by name + checkpoints.sort() + return checkpoints[-1] + + +def load_brax_policy( + path: Union[str, pathlib.Path], + env_name, + distill: bool = False, + config_fname: Optional[Union[str, pathlib.Path]] = None, +): + """ + Load a policy from a Brax checkpoint file. Assumes network parameters + match manipulation_params.py. + """ + ppo_params = manipulation_params.brax_ppo_config(env_name) + network_factory = functools.partial( + ppo_networks.make_ppo_networks, **ppo_params.network_factory + ) + if distill: + # override config_fname to allow loading a PPO policy into a BC inference function. 
+ print(f"Loading PPO policy from {path} with config_fname {config_fname}") + return bc_checkpoint.load_policy( + path, network_factory, deterministic=True, config_fname=config_fname + ) + return ppo_checkpoint.load_policy(path, network_factory, deterministic=True) + + +def load_pick_policy(path, env_name): + raw_policy = load_brax_policy(path, env_name) + + def single2biarm_inference_fn(obs: jp.ndarray): + l_obs, r_obs = jp.split(obs, 2, axis=-1) + l_act, _ = raw_policy({"state": l_obs}, None) + r_act, _ = raw_policy({"state": r_obs}, None) + return jp.concatenate([l_act, r_act], axis=-1) + + return jax.jit(single2biarm_inference_fn) + + +class SinglePegInsertion(pick_base.PickBase): + """ + Phase 2 of the peg insertion task. From a pre-insertion position, + brings the peg into the socket. + """ + + def __init__( + self, + config: Optional[config_dict.ConfigDict] = default_config(), + config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None, + *, + # If true, this class just provides methods + # for the downstream distill class to use. + distill: bool = False, + ): + self._distill = distill + xml_path = ( + mjx_env.ROOT_PATH + / "manipulation" + / "aloha" + / "xmls" + / "mjx_single_peg_insertion.xml" + ) + super().__init__( + xml_path=xml_path, + config=config, + config_overrides=config_overrides, + ) + + if distill: + self.pick_policy = lambda x: jp.zeros(self.action_size) + else: + pick_path = ( + pathlib.Path(__file__).parent / "params" / "AlohaPick" / "checkpoints" + ) + if not pick_path.exists(): + raise FileNotFoundError( + f"Pick policy file not found: {pick_path}, please train one." + ) + pick_path = get_latest_checkpoint(pick_path) + self.pick_policy = load_pick_policy( + pick_path, + "AlohaPick", + ) + + self.obj_names = ["socket", "peg"] + self.hands = ["left", "right"] + self.target_positions = jp.array( + [[-0.10, 0.0, 0.25], [0.10, 0.0, 0.25]], dtype=float + ) + self.target_quats = [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]] + self._switch_height = 0.25 - 0.08 # Used for height-based switching. + self.default_pre_insertion_qpos = self._mj_model.keyframe("preinsert").qpos + self.default_pre_insertion_ctrl = self._mj_model.keyframe("preinsert").ctrl + self._socket_end_site = self._mj_model.site("socket_rear").id + self._peg_end_site = self._mj_model.site("peg_end1").id + self._socket_entrance_site = self._mj_model.site("socket_entrance").id + # hardcode because there's no mj_model.jnt_qveladr option. 
+ self._socket_qveladr = 16 + self._peg_qveladr = 22 + + self.noise_config = config_dict.create( + _peg_target_pos=np.array([0.00, 0.00, 0.00]), + _socket_target_pos=np.array([0.00, 0.00, 0.00]), + _left_waist_init_pos=np.array(0.1), + _right_waist_init_pos=np.array(0.1), + ) + + fov_cam = 58 + self.noise_config["_peg_init_pos"] = config_dict.create( + radius_min=0.27, radius_max=0.42, angle=jp.deg2rad(45 * fov_cam / 90) + ) + self.noise_config["_socket_init_pos"] = config_dict.create( + radius_min=0.27, radius_max=0.42, angle=jp.deg2rad(45 * fov_cam / 90) + ) + + self._finger_ctrladr = np.array([6, 13], dtype=int) + self._skip_prob = 0.8 if not distill else 0.0 + self._post_init(keyframe="home") + + def reset(self, rng: jax.Array) -> State: + data, info = self.init_objects(rng) + metrics = { + **{k: 0.0 for k in self._config.reward_config.scales.keys()}, + **{k: 0.0 for k in self._config.reward_config.sparse.keys()}, + **{k: 0.0 for k in self._config.reward_config.reg.keys()}, + } + + info["has_switched"] = jp.array(0, dtype=int) + info["preinsertion_buffer_qpos"] = jp.tile( + self.default_pre_insertion_qpos, (self._config.reset_buffer_size, 1) + ) + info["preinsertion_buffer_ctrl"] = jp.tile( + self.default_pre_insertion_ctrl, (self._config.reset_buffer_size, 1) + ) + info["time_of_switch"] = jp.array(0, dtype=int) + metrics["peg_end2_dist_to_line"] = jp.array(0.0, dtype=float) + + obs = self._get_obs_insertion(data, info) + if self._distill: + self.reset_color_noise(info) + obs = {**obs, **self._get_obs_distill(data, info, init=True)} + + # Random obs delay. + actor_obs, _ = flax.core.pop(obs, "has_switched") + actor_obs, _ = flax.core.pop(actor_obs, "privileged") + info["obs_history"] = pick_base.init_obs_history( + actor_obs, self._config.max_obs_delay + ) + + # Assert that obs_history has same keys as obs minus popped keys + obs_keys = set(obs.keys()) + history_keys = set(info["obs_history"].keys()) + expected_keys = obs_keys - {"privileged", "has_switched"} + assert history_keys == expected_keys, ( + f"Mismatch between obs keys {expected_keys} and history keys" + f" {history_keys}" + ) + + reward, done = jp.zeros(2) + state = State(data, obs, reward, done, metrics, info) + return state + + def step(self, state: State, action: jax.Array) -> State: + newly_reset = state.info["_steps"] == 0 + self._reset_info_if_needed(state, newly_reset) + + action = self._action_mux(action, state) + state = self._state_mux(state) + data = self._step(state, action) + + # Calculate rewards + sparse_rewards, success, dropped, final_grasp = ( + self._calculate_sparse_rewards(state, data) + ) + dense_rewards = self._calculate_dense_rewards(data, state.metrics) + reg_rewards = self._calculate_reg_rewards(data, state.metrics) + + # Calculate total reward + total_reward = self._calculate_total_reward( + dense_rewards, sparse_rewards, reg_rewards, state + ) + + # Check done conditions + done = self._check_done_conditions(data, state, dropped) + + # Update state + state = self._update_state(state, data, total_reward, done) + state.metrics.update({ + "success": jp.array(success, dtype=float), + "drop": jp.array(dropped, dtype=float), + "final_grasp": final_grasp, + }) + return state + + def _reset_info_if_needed(self, state: State, newly_reset: bool): + state.info["has_switched"] = jp.where( + newly_reset, 0, state.info["has_switched"] + ) + state.info["time_of_switch"] = jp.where( + newly_reset, 0, state.info["time_of_switch"] + ) + + def _calculate_sparse_rewards( + self, state: State, data: mjx.Data + ) -> 
Tuple[Dict[str, float], bool, bool, float]: + # Calculate success + peg_insertion_dist = jp.linalg.norm( + data.site_xpos[self._mj_model.site("socket_end1").id] + - data.site_xpos[self._peg_end_site] + ) + success = peg_insertion_dist < 0.01 + + # Calculate dropped + peg_height = data.xpos[self._peg_body][2] + socket_height = data.xpos[self._socket_body][2] + thresh = self._switch_height - 0.1 + + dropped = jp.array(False) + if not self._distill: + dropped = (peg_height < thresh) | (socket_height < thresh) + dropped = dropped & state.info["has_switched"].astype(bool) + + # Calculate final grasp + l_grasped = ( + jp.linalg.norm(self.gripping_error(data, "left", "socket")) < 0.03 + ) + r_grasped = jp.linalg.norm(self.gripping_error(data, "right", "peg")) < 0.03 + + final_grasp = 0.0 + grasped = l_grasped.astype(float) * r_grasped.astype(float) + last_step = state.info["_steps"] >= ( + self._config.episode_length - self._config.action_repeat + ) + final_grasp = grasped * last_step.astype(float) + + raw_sparse_rewards = { + "success": success.astype(float), + "drop": dropped.astype(float), + "final_grasp": final_grasp, + } + + sparse_rewards = { + k: v * self._config.reward_config.sparse[k] + for k, v in raw_sparse_rewards.items() + } + + return sparse_rewards, success, dropped, final_grasp + + def _calculate_dense_rewards( + self, data: mjx.Data, metrics: Dict[str, float] + ) -> Dict[str, float]: + socket_entrance_pos = data.site_xpos[self._socket_entrance_site] + socket_rear_pos = data.site_xpos[self._socket_end_site] + peg_end2_pos = data.site_xpos[self._peg_end_site] + + # Insertion reward: if peg end2 is aligned with hole entrance, then reward + # distance from peg end to socket interior. + socket_ab = socket_entrance_pos - socket_rear_pos + socket_t = jp.dot(peg_end2_pos - socket_rear_pos, socket_ab) + socket_t /= jp.dot(socket_ab, socket_ab) + 1e-6 + nearest_pt = data.site_xpos[self._socket_end_site] + socket_t * socket_ab + peg_end2_dist_to_line = jp.linalg.norm(peg_end2_pos - nearest_pt) + + objects_aligned = peg_end2_dist_to_line < 0.01 + metrics["peg_end2_dist_to_line"] = peg_end2_dist_to_line + + peg_insertion_dist = jp.linalg.norm( + data.site_xpos[self._mj_model.site("socket_end1").id] + - data.site_xpos[self._peg_end_site] + ) + + peg_insertion_reward = reward_util.tolerance( + peg_insertion_dist, (0, 0.001), margin=0.2, sigmoid="linear" + ) * objects_aligned.astype(float) + + # Dense rotation reward + rot_rewards = {} + for obj, target in zip(self.obj_names, self.target_quats): + obj_mat = data.xmat[getattr(self, f"_{obj}_body")] + obj_target = math.quat_to_mat(jp.array(target)) + rot_err = jp.linalg.norm(obj_target.ravel()[:6] - obj_mat.ravel()[:6]) + rot_rewards[f"{obj}_rot"] = 1 - jp.tanh(5 * rot_err) + + raw_dense_rewards = {"peg_insertion": peg_insertion_reward, **rot_rewards} + + metrics.update({"peg_insertion": peg_insertion_reward}) + + return { + k: v * self._config.reward_config.scales.get( + k, self._config.reward_config.scales.obj_rot + ) + for k, v in raw_dense_rewards.items() + } + + def _calculate_reg_rewards( + self, + data: mjx.Data, + metrics: Dict[str, float], + ) -> Dict[str, float]: + robot_target_qpos = self._robot_target_qpos(data) + + # Joint velocity regularization + joint_vel_rewards = {} + for side in self.hands: + joint_vel_mse = jp.linalg.norm( + data.qvel[getattr(self, f"_{side}_qposadr")] + ) + joint_vel_rewards[f"{side}_joint_vel"] = reward_util.tolerance( + joint_vel_mse, (0, 0.5), margin=2.0, sigmoid="reciprocal" + ) + + # Grip regularization + 
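# The grip terms below use a saturating shaping r = 1 - tanh(5 * ||e||): close
# to 1 when the gripping error is near zero, decaying smoothly with distance.
# A rough numeric check (illustrative only):
#
#   err = 0.1                    # 10 cm gripping error
#   r = 1 - jp.tanh(5 * err)     # ~= 0.54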
e_l_grip = self.gripping_error(data, "left", "socket") + e_r_grip = self.gripping_error(data, "right", "peg") + r_l_grip = 1 - jp.tanh(5 * jp.linalg.norm(e_l_grip)) + r_r_grip = 1 - jp.tanh(5 * jp.linalg.norm(e_r_grip)) + + raw_reg_rewards = { + "robot_target_qpos": robot_target_qpos, + "left_grip_pos": r_l_grip, + "right_grip_pos": r_r_grip, + **joint_vel_rewards, + } + + metrics.update({"robot_target_qpos": jp.array(robot_target_qpos)}) + + reg_rewards = {} + for k, v in raw_reg_rewards.items(): + if k == "robot_target_qpos": + reg_rewards[k] = v * self._config.reward_config.reg.robot_target_qpos + elif k.endswith("_joint_vel"): + reg_rewards[k] = ( + v * self._config.reward_config.reg.joint_vel / len(self.hands) + ) + elif k.endswith("_grip_pos"): + reg_rewards[k] = v * self._config.reward_config.reg.grip_pos / 2 + + return reg_rewards + + def _calculate_total_reward( + self, + dense_rewards: Dict[str, float], + sparse_rewards: Dict[str, float], + reg_rewards: Dict[str, float], + state: State, + ) -> float: + total_reward = ( + sum(dense_rewards.values()) + + sum(sparse_rewards.values()) + + sum(reg_rewards.values()) + ) + + # Zero reward for when the other policy's taking action + total_reward = jp.where(state.info["has_switched"], total_reward, 0.0) + + return total_reward + + def _check_done_conditions( + self, data: mjx.Data, state: State, dropped: bool + ) -> bool: + # Check if out of bounds + out_of_bounds = jp.any(jp.abs(data.xpos[self._socket_body]) > 1.0) + out_of_bounds |= jp.any(jp.abs(data.xpos[self._peg_body]) > 1.0) + + # Check if end of insertion + end_of_insertion = jp.array(False) + if not self._distill: + end_of_insertion = ( + state.info["_steps"] - state.info["time_of_switch"] >= 60 + ) + end_of_insertion = end_of_insertion & state.info["has_switched"].astype( + bool + ) + + # Check if rotated + peg_mat = data.xmat[self._peg_body] + z_axis = jp.array([0, 0, 1]) + peg_z = peg_mat[:3, 2] # Z axis of peg is just last column. + peg_z = peg_z / jp.linalg.norm(peg_z) + angle = jp.arccos(jp.dot(z_axis, peg_z)) + rotated = angle > jp.deg2rad(80) + + socket_mat = data.xmat[self._socket_body] + socket_z = socket_mat[:3, 2] + socket_z = socket_z / jp.linalg.norm(socket_z) + angle = jp.arccos(jp.dot(z_axis, socket_z)) + rotated |= angle > jp.deg2rad(80) + + # Combine all done conditions + done = jp.isnan(data.qpos).any() | jp.isnan(data.qvel).any() | dropped + done = done | out_of_bounds | end_of_insertion + done = done | rotated + + return done + + def _update_state( + self, state: State, data: mjx.Data, total_reward: float, done: bool + ) -> State: + # Get observations + obs = self._get_obs_insertion(data, state.info) + if self._distill: + obs = { + **obs, + **self._get_obs_distill(data, state.info, init=False), + } + + # Update observation history + state.info["rng"], key_obs = jax.random.split(state.info["rng"]) + pick_base.use_obs_history(key_obs, state.info["obs_history"], obs) + + # Update step counter + state.info["_steps"] += self._config.action_repeat + state.info["_steps"] = jp.where( + done | (state.info["_steps"] >= self._config.episode_length), + 0, + state.info["_steps"], + ) + + # Return updated state + return State( + data, + obs, + jp.array(total_reward), + jp.array(done, dtype=float), + state.metrics, + state.info, + ) + + def _action_mux(self, action: jp.array, state: mjx_env.State): + """ + Chooses which policy to apply. If you've already toggled + switched this round, always use the external policy. 
+ """ + + data = state.data + + left_gripper_tip = self.gripper_tip_pos(data, "left") + right_gripper_tip = self.gripper_tip_pos(data, "right") + switch = (left_gripper_tip[2] > self._switch_height) & ( + right_gripper_tip[2] > self._switch_height + ) + + first_switch = jp.logical_and(state.info["has_switched"] == 0, switch) + + state.info["time_of_switch"] = jp.where( + first_switch, state.info["_steps"], state.info["time_of_switch"] + ) + + #### Exploration Manager #### + # If it's the first switch of the run, save the data to + # the buffer of states you can skip to at autoreset. + def update_first_value(buf, val): + buf = jp.roll(buf, 1, axis=0) + buf = buf.at[0].set(val) + return buf + + new_qpos_buf = update_first_value( + state.info["preinsertion_buffer_qpos"], data.qpos + ) + new_ctrl_buf = update_first_value( + state.info["preinsertion_buffer_ctrl"], data.ctrl + ) + + state.info["preinsertion_buffer_qpos"] = jp.where( + first_switch, new_qpos_buf, state.info["preinsertion_buffer_qpos"] + ) + state.info["preinsertion_buffer_ctrl"] = jp.where( + first_switch, new_ctrl_buf, state.info["preinsertion_buffer_ctrl"] + ) + #### End Exploration Manager #### + + state.info["has_switched"] = jp.where(switch, 1, state.info["has_switched"]) + use_input = state.info["has_switched"].astype(bool) | self._distill + return jp.where( + use_input, action, self.pick_policy(state.obs["state_pickup"]) + ) + + def _state_mux(self, state: mjx_env.State) -> mjx_env.State: + + state.info["rng"], key_skip, key_skip_index = jax.random.split( + state.info["rng"], 3 + ) + i_buf = jax.random.randint( + key_skip_index, (), minval=0, maxval=self._config.reset_buffer_size + ) + preinsert_qpos = state.info["preinsertion_buffer_qpos"][i_buf] + preinsert_ctrl = state.info["preinsertion_buffer_ctrl"][i_buf] + + preinsert_data = state.data.replace( + qpos=preinsert_qpos, ctrl=preinsert_ctrl + ) + + newly_reset = state.info["_steps"] == 0 + to_skip = newly_reset * jax.random.bernoulli(key_skip, self._skip_prob) + + # The pre insert buffer is initialized with the home position, + # in which case you can't skip. 
+ to_skip = jp.logical_and(to_skip, jp.any(preinsert_qpos != self._init_q)) + state.info["has_switched"] = jp.where( + to_skip, 1, state.info["has_switched"] + ) + data = jax.tree_util.tree_map( + lambda x, y: (1 - to_skip) * x + to_skip * y, + state.data, + preinsert_data, + ) + + #### Randomly hide #### + qpos = data.qpos + if self._distill: + for obj in ["socket", "peg"]: + state.info["rng"], key_hide = jax.random.split(state.info["rng"]) + hide = newly_reset * jax.random.bernoulli(key_hide, 0.07) + obj_idx = getattr(self, f"_{obj}_qposadr") + hidden_pos = jp.array([0.4, 0.33]) + if obj == "socket": + hidden_pos = hidden_pos.at[0].set(hidden_pos[0] * -1) + obj_hidden = qpos.at[obj_idx : obj_idx + 2].set(hidden_pos) + qpos = jp.where(hide, obj_hidden, qpos) + data = data.replace(qpos=qpos) + #### + + return state.replace(data=data) + + def _get_obs_insertion(self, data: mjx.Data, info: dict) -> jax.Array: + obs_pick = self._get_obs_pick(data, info) + obs_insertion = jp.concatenate([obs_pick, self._get_obs_dist(data, info)]) + obs = { + "state_pickup": obs_pick, + "state": obs_insertion, + "privileged": jp.concat([ + obs_insertion, + (info["_steps"] / self._config.episode_length).reshape(1), + ]), + "has_switched": info["has_switched"].astype(float).reshape(1), + } + return obs + + def _get_obs_dist(self, data: mjx.Data, info: dict) -> jax.Array: + delta = ( + data.site_xpos[self._socket_end_site] + - data.site_xpos[self._peg_end_site] + ) + info["rng"], key = jax.random.split(info["rng"]) + noise = jax.random.uniform( + key, + (2, 3), + minval=-self._config.obs_noise.obj_pos, + maxval=self._config.obs_noise.obj_pos, + ) + return delta + (noise[1] - noise[0]) diff --git a/mujoco_playground/_src/manipulation/aloha/pick.py b/mujoco_playground/_src/manipulation/aloha/pick.py new file mode 100644 index 000000000..c492d5a4b --- /dev/null +++ b/mujoco_playground/_src/manipulation/aloha/pick.py @@ -0,0 +1,432 @@ +# Copyright 2025 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Implementation of the ALOHA pick task for sim-to-real transfer.""" + +from typing import Any, Dict, Optional, Tuple, Union + +import flax +import jax +import jax.numpy as jp +from ml_collections import config_dict +from mujoco import mjx +from mujoco.mjx._src import math +import numpy as np + +from mujoco_playground._src import collision +from mujoco_playground._src import mjx_env +from mujoco_playground._src import reward as reward_util +from mujoco_playground._src.manipulation.aloha import pick_base +from mujoco_playground._src.mjx_env import State # pylint: disable=g-importing-member + + +def default_config() -> config_dict.ConfigDict: + """Returns the default config for bring_to_target tasks.""" + config = config_dict.create( + ctrl_dt=0.05, + sim_dt=0.005, + episode_length=160, + action_repeat=1, + action_scale=0.02, + action_history_length=4, + max_obs_delay=4, + vision=False, + dense_rot_weight=0.23, + obs_noise=config_dict.create( + depth=False, + grad_threshold=0.05, + noise_multiplier=10, + obj_pos=0.015, # meters + obj_vel=0.015, # meters/s + obj_angvel=0.2, + gripper_box=0.015, # meters + obj_angle=5.0, # degrees + robot_qpos=0.1, # radians + robot_qvel=0.1, # radians/s + eef_pos=0.02, # meters + eef_angle=5.0, # degrees + ), + reward_config=config_dict.create( + scales=config_dict.create( + # Gripper goes to the box. + gripper_box=4.0, + # Box goes to the target mocap. + box_target=16.0, + ), + sparse=config_dict.create( + lift=0.5, grasped=0.5, success=0.5, success_time=0.1 + ), + reg=config_dict.create( + finger_force=0.007, + # Do not collide the gripper with the floor. + no_floor_collision=0.005, + joint_vel=0.005, + # Arm stays close to target pose. + robot_target_qpos=0.001, + ), + ), + ) + return config + + +class Pick(pick_base.PickBase): + """Phase 1 of the peg insertion task. + Picks a block and brings it to the target.""" + + def __init__( + self, + config: Optional[config_dict.ConfigDict] = default_config(), + config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None, + ): + xml_path = ( + mjx_env.ROOT_PATH / "manipulation" / "aloha" / "xmls" / "mjx_pick.xml" + ) + super().__init__(xml_path, config, config_overrides) + self.obj_names = ["box"] + self.hands = ["left"] + + self.noise_config = config_dict.create( + _box_target_pos=np.array([0.1, 0.1, 0.1]), + _left_waist_init_pos=np.array(0.2), + ) + self.noise_config["_box_init_pos"] = config_dict.create( + radius_min=0.2, + radius_max=0.45, # Issue: overlapping paths. + angle=jp.deg2rad(45), + ) + + self.target_positions = [jp.array([-0.1, 0.0, 0.25], dtype=float)] + self.target_quats = [[1.0, 0.0, 0.0, 0.0]] + self._finger_ctrladr = np.array([6], dtype=int) + self._box_qveladr = 8 + + self._post_init(keyframe="home") + + def reset(self, rng: jax.Array) -> State: + data, info = self.init_objects(rng) + metrics = { + **{k: 0.0 for k in self._config.reward_config.scales.keys()}, + **{k: 0.0 for k in self._config.reward_config.sparse.keys()}, + **{k: 0.0 for k in self._config.reward_config.reg.keys()}, + } + obs = self._get_obs_pick(data, info) + metrics["score"] = jp.array(0.0, dtype=float) + info["score"] = jp.array(0, dtype=int) + info["reached_box"] = jp.array(0.0, dtype=float) + info["prev_reward"] = jp.array(0.0, dtype=float) + info["success_time"] = jp.array(0, dtype=int) + obs = {"state": obs, "privileged": obs} + actor_obs, _ = flax.core.pop( + obs, "privileged" + ) # Privileged obs is not randomly shifted. 
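# "state" and "privileged" feed an asymmetric actor-critic setup: the AlohaPick
# PPO config in manipulation_params.py (this change) sets
# policy_obs_key="state" and value_obs_key="privileged", so only the actor
# observation is passed through the delay/noise history below while the critic
# sees the current, un-delayed copy. A sketch of that wiring, assuming the brax
# network factory accepts these keys as the config implies:
#
#   network_factory = functools.partial(
#       ppo_networks.make_ppo_networks,
#       policy_hidden_layer_sizes=(32, 32, 32, 32),
#       policy_obs_key="state",
#       value_obs_key="privileged",
#   )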
+ info["obs_history"] = pick_base.init_obs_history( + actor_obs, self._config.max_obs_delay + ) + reward, done = jp.zeros(2) + state = State(data, obs, reward, done, metrics, info) + return state + + def step(self, state: State, action: jax.Array) -> State: + newly_reset = state.info["_steps"] == 0 + self._reset_info_if_needed(state, newly_reset) + + data = self._step(state, action) + + # Calculate rewards + sparse_rewards, success = self._calculate_sparse_rewards(state, data) + dense_rewards = self._calculate_dense_rewards( + data, state.info, state.metrics + ) + reg_rewards = self._calculate_reg_rewards(data, state.metrics) + + # Calculate total reward + total_reward = self._calculate_total_reward( + dense_rewards, sparse_rewards, reg_rewards, state + ) + + # Check done conditions + done, crossed_line = self._check_done_conditions(data) + total_reward += jp.where(crossed_line, -1.0, 0.0) + + # Update score + self._update_score(state, success, done) + + # Update state + state = self._update_state(state, data, total_reward, done) + + return state + + def _reset_info_if_needed(self, state: State, newly_reset: bool): + state.info["reached_box"] = jp.where( + newly_reset, 0.0, state.info["reached_box"] + ) + state.info["prev_reward"] = jp.where( + newly_reset, 0.0, state.info["prev_reward"] + ) + state.info["success_time"] = jp.where( + newly_reset, 0, state.info["success_time"] + ) + + def _calculate_sparse_rewards( + self, state: State, data: mjx.Data + ) -> Tuple[Dict[str, float], bool]: + grasped = self.is_grasped(data, "left") + gripping_error = jp.linalg.norm(self.gripping_error(data, "left", "box")) + grasped_correct = gripping_error < pick_base.GRASP_THRESH + grasped = grasped * grasped_correct.astype(float) + + box_pos = data.xpos[self._box_body] + init_box_height = self._init_q[self._box_qposadr + 2] + lifted = (box_pos[2] > (init_box_height + 0.05)).astype(float) + + success, success_time = self._calculate_success(state, box_pos, data) + + raw_sparse_rewards = { + "grasped": grasped, + "lift": lifted, + "success": jp.array(success, dtype=float), + "success_time": jp.array(success, dtype=float) * success_time, + } + state.metrics.update(**raw_sparse_rewards) + sparse_rewards = { + k: v * self._config.reward_config.sparse[k] + for k, v in raw_sparse_rewards.items() + } + return sparse_rewards, success + + def _calculate_dense_rewards( + self, data: mjx.Data, info: Dict[str, Any], metrics: Dict[str, float] + ) -> Dict[str, float]: + raw_rewards = self._get_dense_pick(data, info) + metrics.update(**raw_rewards) + return { + k: v * self._config.reward_config.scales[k] + for k, v in raw_rewards.items() + } + + def _calculate_reg_rewards( + self, + data: mjx.Data, + metrics: Dict[str, float], + ) -> Dict[str, float]: + raw_reg_rewards = self._get_reg_pick(data) + f_lfing = self.get_finger_force(data, "left", "left") + f_rfing = self.get_finger_force(data, "left", "right") + f_fing = jp.mean(jp.linalg.norm(f_lfing) + jp.linalg.norm(f_rfing)) + max_f_fing = 7.0 + n_f_fing = jp.clip(f_fing, min=None, max=max_f_fing) / max_f_fing + raw_reg_rewards.update( + {"finger_force": n_f_fing * self.is_grasped(data, "left")} + ) + metrics.update(**raw_reg_rewards) + return { + k: v * self._config.reward_config.reg[k] + for k, v in raw_reg_rewards.items() + } + + def _calculate_total_reward( + self, + dense_rewards: Dict[str, float], + sparse_rewards: Dict[str, float], + reg_rewards: Dict[str, float], + state: State, + ) -> float: + total_reward = jp.clip(sum(dense_rewards.values()), -1e4, 1e4) + 
total_reward += jp.clip(sum(sparse_rewards.values()), -1e4, 1e4) + reward = jp.maximum( + total_reward - state.info["prev_reward"], jp.zeros_like(total_reward) + ) + state.info["prev_reward"] = jp.maximum( + total_reward, state.info["prev_reward"] + ) + reward = jp.where(state.info["_steps"] == 0, 0.0, reward) + reward += jp.clip(sum(reg_rewards.values()), -1e4, 1e4) + return reward + + def _check_done_conditions( + self, + data: mjx.Data, + ) -> Tuple[bool, bool]: + id_far_end = self.mj_model.site("box_end_2").id + box_far_end = data.site_xpos[id_far_end] + crossed_line = box_far_end[0] > (0.0 + 0.048 + 0.025) + + box_pos = data.xpos[self._box_body] + out_of_bounds = jp.any(jp.abs(box_pos) > 1.0) + done = out_of_bounds | jp.isnan(data.qpos).any() | jp.isnan(data.qvel).any() + done = done | crossed_line + + return done, crossed_line + + def _update_score(self, state: State, success: bool, done: bool): + last_step = ( + state.info["_steps"] + self._config.action_repeat + ) >= self._config.episode_length + state.info["score"] += jp.array(success, dtype=int) * last_step + state.info["score"] = jp.clip(state.info["score"], min=0, max=4) + state.metrics["score"] = state.info["score"] * 1.0 + state.info["_steps"] += self._config.action_repeat + state.info["_steps"] = jp.where( + done | (state.info["_steps"] >= self._config.episode_length), + 0, + state.info["_steps"], + ) + + def _update_state( + self, state: State, data: mjx.Data, total_reward: float, done: bool + ) -> State: + obs = self._get_obs_pick(data, state.info) + obs = {"state": obs, "privileged": obs} + state.info["rng"], key_obs = jax.random.split(state.info["rng"]) + pick_base.use_obs_history(key_obs, state.info["obs_history"], obs) + return State( + data, + obs, + total_reward, + jp.array(done, dtype=float), + state.metrics, + state.info, + ) + + def _calculate_thresholds(self, score: int) -> Tuple[float, float]: + def map_to_range(val: int, a: float, b: float, num_vals=5): + step = (b - a) / (num_vals - 1) # Step size for the range + index = jp.minimum(val // 1, num_vals - 1) + return a + step * index + + pos_thresh = map_to_range(score, 0.04, 0.005) + rot_thresh = map_to_range(score, 15, 2.5) + return pos_thresh, rot_thresh + + def _calculate_success( + self, state: State, box_pos: jax.Array, data: mjx.Data + ) -> Tuple[bool, int]: + pos_thresh, rot_thresh = self._calculate_thresholds(state.info["score"]) + success = ( + jp.linalg.norm(box_pos - state.info["_box_target_pos"]) < pos_thresh + ) + box_mat = data.xmat[self._box_body] + target_mat = math.quat_to_mat(data.mocap_quat[self._box_mocap_target]) + rot_err = jp.linalg.norm(target_mat.ravel()[:6] - box_mat.ravel()[:6]) + rot_success = rot_err < jp.deg2rad(rot_thresh) + success = success & rot_success + + state.info["success_time"] = jp.where( + success & (state.info["success_time"] == 0), + state.info["_steps"], + state.info["success_time"], + ) + + success_time = state.info["_steps"] - state.info["success_time"] + return success, success_time + + def _get_dense_pick( + self, data: mjx.Data, info: Dict[str, Any] + ) -> Dict[str, Any]: + target_pos = info["_box_target_pos"] + box_pos = data.xpos[self._box_body] + pos_err = jp.linalg.norm(target_pos - box_pos) + box_mat = data.xmat[self._box_body] + target_mat = math.quat_to_mat(data.mocap_quat[self._box_mocap_target]) + rot_err = jp.linalg.norm(target_mat.ravel()[:6] - box_mat.ravel()[:6]) + w_r = self._config.dense_rot_weight + box_target = 1 - jp.tanh(5 * ((1 - w_r) * pos_err + w_r * rot_err)) + gripping_error = 
jp.linalg.norm(self.gripping_error(data, "left", "box")) + gripper_box = 1 - jp.tanh(5 * gripping_error) + + info["reached_box"] = 1.0 * jp.maximum( + info["reached_box"], + ( + jp.linalg.norm(self.gripping_error(data, "left", "box")) + < pick_base.GRASP_THRESH + ), + ) + + rewards = { + "gripper_box": gripper_box, + "box_target": box_target * info["reached_box"], + } + return rewards + + def _get_reg_pick(self, data: mjx.Data) -> Dict[str, Any]: + rewards = { + "robot_target_qpos": self._robot_target_qpos(data), + } + + joint_vel_mse = jp.linalg.norm(data.qvel[self._left_qposadr]) + joint_vel = reward_util.tolerance( + joint_vel_mse, (0, 0.5), margin=2.0, sigmoid="reciprocal" + ) + rewards["joint_vel"] = joint_vel + + left_id = self._left_left_finger_geom_bottom + right_id = self._left_right_finger_geom_bottom + + hand_floor_collision = [ + collision.geoms_colliding(data, getattr(self, "_floor_geom"), g) + for g in [ + left_id, + right_id, + self._left_hand_geom, + ] + ] + floor_collision = sum(hand_floor_collision) > 0 + no_floor_collision = (1 - floor_collision).astype(float) + rewards["no_floor_collision"] = no_floor_collision + + return rewards + + +def domain_randomize(model: mjx.Model, rng: jax.Array): + """Randomize domain parameters for sim-to-real transfer. + + Args: + model: The MuJoCo model to randomize + rng: JAX random number generator key + + Returns: + Tuple of (randomized model, in_axes for vmap) + """ + mj_model = Pick().mj_model + obj_id = mj_model.geom("box").id + obj_body_id = mj_model.body("box").id + + @jax.vmap + def rand(rng): + key_size, key_mass = jax.random.split(rng) + # geom size + geom_size_sides = jax.random.uniform(key_size, (), minval=0.01, maxval=0.03) + geom_size = model.geom_size.at[obj_id, 1:3].set(geom_size_sides) + + # mass + mass = jax.random.uniform(key_mass, (), minval=0.03, maxval=0.1) + mass = model.body_mass.at[obj_body_id].set(mass) + + return geom_size, mass + + geom_size, mass = rand(rng) + + in_axes = jax.tree_util.tree_map(lambda x: None, model) + in_axes = in_axes.tree_replace({ + "geom_size": 0, + "body_mass": 0, + }) + + model = model.tree_replace({ + "geom_size": geom_size, + "body_mass": mass, + }) + + return model, in_axes diff --git a/mujoco_playground/_src/manipulation/aloha/pick_base.py b/mujoco_playground/_src/manipulation/aloha/pick_base.py new file mode 100644 index 000000000..26f4398da --- /dev/null +++ b/mujoco_playground/_src/manipulation/aloha/pick_base.py @@ -0,0 +1,668 @@ +# Copyright 2025 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +This module contains the base class for the two phases of peg insertion. +It includes functionalities for initializing objects, calculating observations, +and managing the state of the environment during robotic manipulation tasks. 
+""" + +import functools +from typing import Any, Dict, Optional, Tuple, Union + +import jax +import jax.numpy as jp +from ml_collections import config_dict +import mujoco +from mujoco import mjx +from mujoco.mjx._src import math +from mujoco.mjx._src.support import contact_force +import numpy as np + +from mujoco_playground._src import mjx_env +from mujoco_playground._src.manipulation.aloha import aloha_constants +from mujoco_playground._src.manipulation.aloha import base +from mujoco_playground._src.mjx_env import State # pylint: disable=g-importing-member + +QPOS_NOISE_MASK_SINGLE = [1] * 6 + [0] * 2 # 6 joints, 2 fingers. +ZIPF_S3 = [ + 0.83, + 0.104, + 0.02, + 0.014, + 0.008, + 0.002, +] # heavy-tailed zipf pmf evaluated at x=1, ..., 6 with s=3. +GRASP_THRESH = 0.015 + + +def get_rand_dir(rng: jax.Array) -> jax.Array: + key1, key2 = jax.random.split(rng) + theta = jax.random.normal(key1) * 2 * jp.pi + phi = jax.random.normal(key2) * jp.pi + x = jp.sin(phi) * jp.cos(theta) + y = jp.sin(phi) * jp.sin(theta) + z = jp.cos(phi) + return jp.array([x, y, z]) + + +def init_obs_history(init_obs: Dict, history_len: int) -> Dict: + """Initialize observation history dictionary. + + For each entry in init_obs, creates a history initialized to the same value. + + Args: + init_obs: Initial observation dictionary + history_len: Length of history to maintain + + Returns: + Dictionary containing observation histories + """ + obs_history = {} + for k, v in init_obs.items(): + # for state and pixel obs + obs_axes = (history_len,) + (1,) * len(v.shape) + obs_history[k] = jp.tile(v, obs_axes) + return obs_history + + +def use_obs_history(key, obs_history: Dict, obs: Dict) -> Tuple[Dict, Dict]: + """Purely in-place. + 1. update obs history. + 2. update obs with value sampled from buffer. + """ + key, key_sample = jax.random.split(key) # all sub-obs share the same jitter. 
+ # Update obs history + for k, v in obs_history.items(): + shifted = jp.roll(v, 1, axis=0) + obs_history[k] = shifted.at[0].set(obs[k]) + # Sample + logits = jp.log(jp.array(ZIPF_S3[: len(v)])) + obs_idx = jax.random.categorical(key_sample, logits) + obs[k] = obs_history[k][obs_idx] + return obs_history, obs + + +class PickBase(base.AlohaEnv): + """Base class for Pick and downstream tasks.""" + + def __init__( + self, + xml_path, + config: config_dict.ConfigDict, + config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None, + ): + super().__init__(xml_path, config, config_overrides) + self.base_info = { + "left": {"pos": jp.array([-0.469, -0.019, 0.02])}, + "right": {"pos": jp.array([0.469, -0.019, 0.02])}, + } + self._action_scale = config.action_scale + + def _post_init(self, keyframe: str): + #### GLOBALS + self._init_ctrl = self._mj_model.keyframe(keyframe).ctrl + + self._lowers, self._uppers = self._mj_model.actuator_ctrlrange.T + self._init_q = self._mj_model.keyframe(keyframe).qpos + + #### PER OBJECT + for obj in self.obj_names: + setattr( + self, + f"_{obj}_qposadr", + self._mj_model.jnt_qposadr[self._mj_model.body(obj).jntadr[0]], + ) + setattr(self, f"_{obj}_body", self._mj_model.body(obj).id) + setattr( + self, + f"_{obj}_init_pos", + jp.array( + self._init_q[ + getattr(self, f"_{obj}_qposadr") : getattr( + self, f"_{obj}_qposadr" + ) + + 3 + ], + dtype=jp.float32, + ), + ) + setattr( + self, f"_{obj}_grip_site", self._mj_model.site(f"{obj}_grip_here").id + ) + setattr( + self, + f"_{obj}_mocap_target", + self._mj_model.body(f"{obj}_mocap_target").mocapid, + ) + + #### PER HAND + for hand in self.hands: + setattr( + self, + f"_{hand}_left_finger_geom_bottom", + self._mj_model.geom(f"{hand}/left_finger_bottom").id, + ) + setattr( + self, + f"_{hand}_right_finger_geom_bottom", + self._mj_model.geom(f"{hand}/right_finger_bottom").id, + ) + # Fingertips + setattr( + self, + f"_{hand}_left_fingertip", + self._mj_model.site(f"{hand}/left_fingertip").id, + ) + setattr( + self, + f"_{hand}_right_fingertip", + self._mj_model.site(f"{hand}/right_fingertip").id, + ) + setattr( + self, + f"_{hand}_hand_geom", + self._mj_model.geom(f"{hand}/gripper_base").id, + ) + setattr( + self, + f"_{hand}_gripper_site", + self._mj_model.site(f"{hand}/gripper").id, + ) + setattr( + self, + f"_{hand}_base_link", + self._mj_model.body(f"{hand}/base_link").id, + ) + setattr( + self, + f"_{hand}_qposadr", + np.array([ + self._mj_model.jnt_qposadr[self._mj_model.joint(j).id] + for j in getattr(aloha_constants, f"{hand.upper()}_JOINTS") + ]), + ) + + match hand: + case "left": + self._left_ctrladr = np.linspace(0, 6, 7, dtype=int) + case "right": + self._right_ctrladr = np.linspace(7, 13, 7, dtype=int) + case _: + raise ValueError(f"Invalid hand: {hand}") + self._floor_geom = self._mj_model.geom("table").id + + def _robot_target_qpos(self, data: mjx.Data) -> float: + robot_target_qpos = 0.0 + for hand in self.hands: + hand_ids = getattr(self, f"_{hand}_qposadr") + robot_target_qpos += 1 - jp.tanh( + jp.linalg.norm(data.qpos[hand_ids] - self._init_q[hand_ids]) + ) + return robot_target_qpos / len(self.hands) + + def sample_fan(self, rng: jax.Array, obj: str) -> Tuple[jax.Array, jax.Array]: + """Sample a perturbation position and quaternion in a fan pattern. 
+ + Args: + rng: Random number generator key + obj: Object name + + Returns: + Tuple of position perturbation and quaternion + """ + rng, rng_r, rng_angle = jax.random.split(rng, 3) + r = jax.random.uniform( + rng_r, + shape=(), + minval=self.noise_config[f"_{obj}_init_pos"].radius_min, + maxval=self.noise_config[f"_{obj}_init_pos"].radius_max, + ) + par = self.noise_config[f"_{obj}_init_pos"].angle + # Can't be a symmetric fan or depth cameras can't distinguish objects + angle = jax.random.uniform( + rng_angle, + shape=(), + minval=-par / 2, + maxval=par / 2, + ) + dx = r * jp.cos(angle) + dy = r * jp.sin(angle) + # Jitter the angle so the object isn't perfectly aligned. + angle_noise = jp.deg2rad(5) + rng, rng_noise = jax.random.split(rng) + angle += jax.random.uniform( + rng_noise, shape=(), minval=-angle_noise, maxval=angle_noise + ) + quat = jp.array([jp.cos(angle / 2), 0.0, 0.0, jp.sin(angle / 2)]) + return jp.array([dx, dy, 0.0]), quat + + def init_objects(self, rng: jax.Array) -> Tuple[mjx.Data, dict[str, Any]]: + """Initialize object positions and targets. + + Args: + rng: Random number generator key + + Returns: + Tuple of MJX data and info dictionary + """ + info = {} + init_q = jp.array(self._init_q) + + for obj, targ, side in zip( + self.obj_names, self.target_positions, self.hands + ): # Defined by child class. + + obj_idx = getattr(self, f"_{obj}_qposadr") + + # Object Position. + rng, rng_offset = jax.random.split(rng) + offset, quat_offset = self.sample_fan(rng_offset, obj) + t = self.base_info[side]["pos"] + idx_offset = 8 if side == "right" else 0 + base_angle = self._init_q[idx_offset] + if side == "right": + base_angle += np.deg2rad(180) + base_quat = jp.array( + [jp.cos(base_angle / 2), 0.0, 0.0, jp.sin(base_angle / 2)] + ) + # rotation_matrix = self.base_info[side]["xmat"] + rotation_matrix = math.quat_to_mat(base_quat) + obj_pos = self.point2global(offset, rotation_matrix.T, t) + init_q = init_q.at[obj_idx : obj_idx + 3].set(obj_pos) + + # Convert quat to mat + obj_quat = math.quat_mul(base_quat, quat_offset) + init_q = init_q.at[obj_idx + 3 : obj_idx + 7].set(obj_quat) + + # Target Position. + rng, rng_target = jax.random.split(rng) + range_val = self.noise_config[f"_{obj}_target_pos"] + info[f"_{obj}_target_pos"] = targ + jax.random.uniform( + rng_target, (3,), minval=-range_val, maxval=range_val + ) + + # Waist init. + for hand in self.hands: + rng, rng_waist = jax.random.split(rng) + range_val = self.noise_config[f"_{hand}_waist_init_pos"] + first_idx = getattr(self, f"_{hand}_qposadr")[0] + # fan is assymmetrical. TODO: False? + rand_setpoint = self._init_q[first_idx] + jax.random.uniform( + rng_waist, (), minval=-range_val, maxval=range_val + ) + init_q = init_q.at[first_idx].set(rand_setpoint) + # Change for ctrl as well. 
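# The randomized waist setpoint is written both into qpos (init_q above) and,
# just below, into the matching actuator target in init_ctrl, so the controller
# holds the randomized joint angle at reset instead of pulling the arm back
# toward the keyframe value.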
+ first_idx_ctrl = getattr(self, f"_{hand}_ctrladr")[0] + init_ctrl = ( + jp.array(self._init_ctrl).at[first_idx_ctrl].set(rand_setpoint) + ) + + data = mjx_env.init( + self._mjx_model, + init_q, + jp.zeros(self._mjx_model.nv, dtype=float), + ctrl=init_ctrl, + ) + + for i, obj in enumerate(self.obj_names): + target_quat = jp.array(self.target_quats[i]) + mocap_target = getattr(self, f"_{obj}_mocap_target") + data = data.replace( + mocap_pos=data.mocap_pos.at[mocap_target, :].set( + info[f"_{obj}_target_pos"] + ), + mocap_quat=data.mocap_quat.at[mocap_target, :].set(target_quat), + ) + + info.update({ + "_steps": jp.array(0, dtype=int), + "rng": rng, + "action_history": jp.zeros( + (self._config.action_history_length, self.action_size), + dtype=jp.float32, + ), + "motor_targets": init_ctrl, + "init_ctrl": init_ctrl, + }) + + return data, info + + def _step(self, state: State, action: jax.Array) -> mjx.Data: + """ + Implements action scaling and random gripper delay. + """ + # Reset if needed. + newly_reset = state.info["_steps"] == 0 + state.info["action_history"] = jp.where( + newly_reset, + jp.zeros( + (self._config.action_history_length, self.action_size), + dtype=jp.float32, + ), + state.info["action_history"], + ) + + action_history = ( + jp.roll(state.info["action_history"], 1, axis=0).at[0].set(action) + ) + state.info["action_history"] = action_history + + # Add action delay for all joints + state.info["rng"], key_joints = jax.random.split(state.info["rng"]) + logits = jp.log(jp.array(ZIPF_S3[: self._config.action_history_length])) + action_idx = jax.random.categorical(key_joints, logits) + action = state.info["action_history"][action_idx] + + # Stronger noise to the grippers + state.info["rng"], key_fingers = jax.random.split(state.info["rng"]) + action_idx = jax.random.randint( + key_fingers, (), minval=0, maxval=self._config.action_history_length + ) + action = action.at[self._finger_ctrladr].set( + state.info["action_history"][action_idx][self._finger_ctrladr] + ) + + delta = action * self._action_scale + ctrl = state.data.ctrl + delta + ctrl = jp.clip(ctrl, self._lowers, self._uppers) + data = mjx_env.step(self._mjx_model, state.data, ctrl, self.n_substeps) + state.info["motor_targets"] = ctrl + return data + + def gripping_error(self, data, hand, obj) -> float: + """Calculate the error between gripper and object grip site. + + Args: + data: MJX data + hand: Hand name ('left' or 'right') + obj: Object name + + Returns: + Vector from gripper to grip site in local coordinates + """ + rotation_matrix = data.xmat[getattr(self, f"_{hand}_base_link")] + t = data.xpos[getattr(self, f"_{hand}_base_link")] + point2local = functools.partial( + self.point2local, rotation_matrix=rotation_matrix, t=t + ) + p_lfing = data.site_xpos[getattr(self, f"_{hand}_left_fingertip")] + p_rfing = data.site_xpos[getattr(self, f"_{hand}_right_fingertip")] + p_mid = (p_lfing + p_rfing) / 2 + grip_here = data.site_xpos[getattr(self, f"_{obj}_grip_site")] + gripper_obj = point2local(p_mid) - point2local(grip_here) + return gripper_obj + + def gripper_tip_pos(self, data, hand) -> jax.Array: + p_lfing = data.site_xpos[getattr(self, f"_{hand}_left_fingertip")] + p_rfing = data.site_xpos[getattr(self, f"_{hand}_right_fingertip")] + return (p_lfing + p_rfing) / 2 + + def _get_obs_pick_helper( + self, data: mjx.Data, info: dict[str, Any], side: str, obj: str + ) -> jax.Array: + """Calculate observations for pickup task between robot and object. 
+ + Coordinates from Forward Kinematics are with respect to the side's base. + + Args: + data: MJX data + info: Info dictionary + side: Robot side ('left' or 'right') + obj: Object name + + Returns: + Observation array + """ + # Robot minimal coords + i_rob_qpos = getattr(self, f"_{side}_qposadr") + rob_qpos = data.qpos[i_rob_qpos] + rob_qvel = data.qvel[i_rob_qpos] + + # Object minimal coords + i_obj_qvel = getattr(self, f"_{obj}_qveladr") + i_obj_qvel = np.linspace(i_obj_qvel, i_obj_qvel + 5, 6, dtype=int) + g_obj_v, g_obj_angv = jp.split(data.qvel[i_obj_qvel], 2, axis=-1) + + # Derived quantities + # g_gripper_pos = data.site_xpos[getattr(self, f"_{side}_gripper_site")] + g_obj_pos = data.xpos[getattr(self, f"_{obj}_body")] + g_target_pos = info[f"_{obj}_target_pos"] + g_gripper_mat = data.site_xmat[getattr(self, f"_{side}_gripper_site")] + g_obj_mat = data.xmat[getattr(self, f"_{obj}_body")] + g_target_mat = math.quat_to_mat( + data.mocap_quat[getattr(self, f"_{obj}_mocap_target")] + ) + rotation_matrix = data.xmat[ + getattr(self, f"_{side}_base_link") + ] # world to local. Orientation. + t = data.xpos[ + getattr(self, f"_{side}_base_link") + ] # world to local. Translation. + + frame2local = functools.partial( + self.frame2local, rotation_matrix=rotation_matrix + ) + point2local = functools.partial( + self.point2local, rotation_matrix=rotation_matrix, t=t + ) + + obj_v, obj_angv = frame2local(g_obj_v), frame2local(g_obj_angv) + # gripper_pos = point2local(g_gripper_pos) + gripper_mat = frame2local(g_gripper_mat) + obj_mat = frame2local(g_obj_mat) + gripper_box = self.gripping_error(data, side, obj) # local + target_pos = point2local(g_target_pos) + obj_pos = point2local(g_obj_pos) + target_mat = frame2local(g_target_mat) + + #### ADD NOISE #### + # QPOS, QVEL + info["rng"], key_qpos, key_qvel = jax.random.split(info["rng"], 3) + noise = jax.random.uniform( + key_qpos, + rob_qpos.shape, + minval=-self._config.obs_noise.robot_qpos, + maxval=self._config.obs_noise.robot_qpos, + ) * jp.array(QPOS_NOISE_MASK_SINGLE) + n_rob_qpos = rob_qpos + noise + + noise = jax.random.uniform( + key_qvel, + rob_qvel.shape, + minval=-self._config.obs_noise.robot_qvel, + maxval=self._config.obs_noise.robot_qvel, + ) * jp.array(QPOS_NOISE_MASK_SINGLE) + n_rob_qvel = rob_qvel + noise + + # OBJ V, ANGV + info["rng"], key_obj_v, key_obj_angv = jax.random.split(info["rng"], 3) + n_obj_v = obj_v + jax.random.uniform( + key_obj_v, + obj_v.shape, + minval=-self._config.obs_noise.obj_vel, + maxval=self._config.obs_noise.obj_vel, + ) + n_obj_angv = obj_angv + jax.random.uniform( + key_obj_angv, + obj_angv.shape, + minval=-self._config.obs_noise.obj_angvel, + maxval=self._config.obs_noise.obj_angvel, + ) + # GRIPPER, OBJ MAT + info["rng"], key1, key2 = jax.random.split(info["rng"], 3) + angle = jax.random.uniform( + key1, + minval=0, + maxval=self._config.obs_noise.eef_angle * jp.pi / 180, + ) + rand_quat = math.axis_angle_to_quat(get_rand_dir(key2), angle) + rand_mat = math.quat_to_mat(rand_quat) + n_gripper_mat = rand_mat @ gripper_mat + + info["rng"], key1, key2 = jax.random.split(info["rng"], 3) + angle = jax.random.uniform( + key1, + minval=0, + maxval=self._config.obs_noise.obj_angle * jp.pi / 180, + ) + rand_quat = math.axis_angle_to_quat(get_rand_dir(key2), angle) + rand_mat = math.quat_to_mat(rand_quat) + n_obj_mat = rand_mat @ obj_mat + + # GRIPPER BOX + info["rng"], key_gripper_box = jax.random.split(info["rng"]) + noise_val = jax.random.uniform( + key_gripper_box, + (2, 3), + 
minval=-self._config.obs_noise.gripper_box, + maxval=self._config.obs_noise.gripper_box, + ) + # Triangle distribution + n_gripper_box = gripper_box + (noise_val[1] - noise_val[0]) + + # OBJ POS + info["rng"], key_obj = jax.random.split(info["rng"]) + n_obj_pos = obj_pos + jax.random.uniform( + key_obj, + obj_pos.shape, + minval=-self._config.obs_noise.obj_pos, + maxval=self._config.obs_noise.obj_pos, + ) + + #### DONE ADDING NOISE #### + + return jp.concatenate([ + n_rob_qpos, # 0:8 + n_rob_qvel, # 8:16 + n_obj_v, # 16:19 + n_obj_angv, # 19:22 + n_gripper_mat.ravel()[3:], # 25:31 OLD + n_obj_mat.ravel()[3:], # 31:37 + n_gripper_box, # 37:40 + target_pos - n_obj_pos, # 40:43 + target_mat.ravel()[:6] - n_obj_mat.ravel()[:6], # 43:49 + data.ctrl[getattr(self, f"_{side}_ctrladr")] - n_rob_qpos[:-1], # 49:56 + ]) + + def _get_obs_pick(self, data: mjx.Data, info: dict[str, Any]) -> jax.Array: + """ + Calculate the observations in local coordinates + allowing the gripper to pick up an object. + Returns left and right-hand observations. + """ + all_obs = [] + for side, obj in zip(self.hands, self.obj_names): + all_obs.append(self._get_obs_pick_helper(data, info, side, obj)) + return jp.concatenate(all_obs) + + def is_grasped(self, data, hand) -> float: + # Grasped if both fingers have applied forces > 5. + t_f = 2.5 # min force. Don't need to squeeze so hard! + + # 3D vec; top and bottom collision bodies + f_lfing = self.get_finger_force(data, hand, "left") + f_rfing = self.get_finger_force(data, hand, "right") + + d_lfing = self.get_finger_dir(data, hand, "left") + d_rfing = -1 * d_lfing + + l_d_flag = self.check_dir(f_lfing, d_lfing) + l_f_flag = (jp.linalg.norm(f_lfing) > t_f).astype(float) + r_d_flag = self.check_dir(f_rfing, d_rfing) + r_f_flag = (jp.linalg.norm(f_rfing) > t_f).astype(float) + + grasped = jp.all(jp.array([l_d_flag, l_f_flag, r_d_flag, r_f_flag])).astype( + float + ) + + return grasped + + def get_finger_force(self, data, hand, finger): + """ + Sum up the 3D force vectors across bottom and top collision primitives + """ + ids = jp.array([ + self._mj_model.geom(f"{hand}/{finger}_finger_{pos}").id + for pos in ["top", "bottom"] + ]) # 2 + contact_forces = [ + contact_force(self._mjx_model, data, i, True)[None, :3] # 1, 3 + for i in np.arange(data.ncon) + ] + contact_forces = jp.concat(contact_forces, axis=0) # ncon, 3 + matches = jp.isin(data.contact.geom, ids).any(axis=1) # ncon + dist_mask = data.contact.dist < 0 # ncon + + # Sum + return jp.sum(contact_forces * (matches * dist_mask)[:, None], axis=0) + + def get_finger_dir(self, data, hand, finger): + """ + A vector pointing from `finger` to the other finger. + """ + other = "left" if finger == "right" else "right" + + site_fing = mujoco.mj_name2id( + self.mj_model, mujoco.mjtObj.mjOBJ_SITE.value, f"{hand}/{finger}_finger" + ) + site_ofing = mujoco.mj_name2id( + self.mj_model, mujoco.mjtObj.mjOBJ_SITE.value, f"{hand}/{other}_finger" + ) + + v = data.site_xpos[site_ofing] - data.site_xpos[site_fing] + + return v / (jp.linalg.norm(v) + 1e-7) + + def check_dir(self, v1, v2, t_align=jp.deg2rad(75)) -> float: + m = jp.linalg.norm(v1) * jp.linalg.norm(v2) + return (jp.arccos(jp.dot(v1, v2) / (m + 1e-7)) < t_align).astype(float) + + def frame2local(self, frame, rotation_matrix): + """Convert frame from global to local coordinates. 
+ + Args: + frame: Frame in global coordinates + rotation_matrix: Rotation matrix from global to local + + Returns: + Frame in local coordinates + """ + return rotation_matrix @ frame + + def point2local(self, point, rotation_matrix, t): + """Convert point from global to local coordinates. + + Args: + point: Point in global coordinates + rotation_matrix: Rotation matrix from global to local + t: Translation vector + + Returns: + Point in local coordinates + """ + return self.frame2local(point - t, rotation_matrix) + + def point2global(self, point, rotation_matrix, t): + """Convert point from local to global coordinates. + + Args: + point: Point in local coordinates + rotation_matrix: Rotation matrix from global to local + t: Translation vector + + Returns: + Point in global coordinates + """ + return rotation_matrix.T @ point + t diff --git a/mujoco_playground/_src/manipulation/aloha/single_peg_insertion.py b/mujoco_playground/_src/manipulation/aloha/single_peg_insertion.py deleted file mode 100644 index 5230d7c0b..000000000 --- a/mujoco_playground/_src/manipulation/aloha/single_peg_insertion.py +++ /dev/null @@ -1,254 +0,0 @@ -# Copyright 2025 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Peg insertion task for ALOHA.""" - -from typing import Any, Dict, Optional, Union - -import jax -from jax import numpy as jp -from ml_collections import config_dict -from mujoco import mjx - -from mujoco_playground._src import mjx_env -from mujoco_playground._src import reward as reward_util -from mujoco_playground._src.manipulation.aloha import aloha_constants as consts -from mujoco_playground._src.manipulation.aloha import base as aloha_base - - -def default_config() -> config_dict.ConfigDict: - return config_dict.create( - ctrl_dt=0.0025, - sim_dt=0.0025, - episode_length=1000, - action_repeat=2, - action_scale=0.005, - reward_config=config_dict.create( - scales=config_dict.create( - left_reward=1, - right_reward=1, - left_target_qpos=0.3, - right_target_qpos=0.3, - no_table_collision=0.3, - socket_z_up=0.5, - peg_z_up=0.5, - socket_entrance_reward=4, - peg_end2_reward=4, - peg_insertion_reward=8, - ) - ), - ) - - -class SinglePegInsertion(aloha_base.AlohaEnv): - """Single peg insertion task for ALOHA.""" - - def __init__( - self, - config: config_dict.ConfigDict = default_config(), - config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None, - ): - super().__init__( - xml_path=(consts.XML_PATH / "mjx_single_peg_insertion.xml").as_posix(), - config=config, - config_overrides=config_overrides, - ) - self._post_init() - - def _post_init(self): - self._post_init_aloha(keyframe="home") - self._socket_entrance_site = self._mj_model.site("socket_entrance").id - self._socket_rear_site = self._mj_model.site("socket_rear").id - self._peg_end2_site = self._mj_model.site("peg_end2").id - self._socket_body = self._mj_model.body("socket").id - self._peg_body = 
self._mj_model.body("peg").id - - self._socket_qadr = self._mj_model.jnt_qposadr[ - self._mj_model.body_jntadr[self._socket_body] - ] - self._peg_qadr = self._mj_model.jnt_qposadr[ - self._mj_model.body_jntadr[self._peg_body] - ] - - # Lift goal: both in the air. - self._socket_entrance_goal_pos = jp.array([-0.05, 0, 0.15]) - self._peg_end2_goal_pos = jp.array([0.05, 0, 0.15]) - - def reset(self, rng: jax.Array) -> mjx_env.State: - rng, rng_peg, rng_socket = jax.random.split(rng, 3) - - peg_xy = jax.random.uniform(rng_peg, (2,), minval=-0.1, maxval=0.1) - socket_xy = jax.random.uniform(rng_socket, (2,), minval=-0.1, maxval=0.1) - init_q = self._init_q.at[self._peg_qadr : self._peg_qadr + 2].add(peg_xy) - init_q = init_q.at[self._socket_qadr : self._socket_qadr + 2].add(socket_xy) - - data = mjx_env.init( - self._mjx_model, - init_q, - jp.zeros(self._mjx_model.nv, dtype=float), - ctrl=self._init_ctrl, - ) - - info = {"rng": rng} - obs = self._get_obs(data) - reward, done = jp.zeros(2) - metrics = { - "out_of_bounds": jp.array(0.0, dtype=float), - "peg_end2_dist_to_line": jp.array(0.0, dtype=float), - **{k: 0.0 for k in self._config.reward_config.scales.keys()}, - } - - return mjx_env.State(data, obs, reward, done, metrics, info) - - def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State: - delta = action * self._config.action_scale - ctrl = state.data.ctrl + delta - ctrl = jp.clip(ctrl, self._lowers, self._uppers) - - data = mjx_env.step(self._mjx_model, state.data, ctrl, self.n_substeps) - - socket_entrance_pos = data.site_xpos[self._socket_entrance_site] - socket_rear_pos = data.site_xpos[self._socket_rear_site] - peg_end2_pos = data.site_xpos[self._peg_end2_site] - # Insertion reward: if peg end2 is aligned with hole entrance, then reward - # distance from peg end to socket interior. 
- socket_ab = socket_entrance_pos - socket_rear_pos - socket_t = jp.dot(peg_end2_pos - socket_rear_pos, socket_ab) - socket_t /= jp.dot(socket_ab, socket_ab) + 1e-6 - nearest_pt = data.site_xpos[self._socket_rear_site] + socket_t * socket_ab - peg_end2_dist_to_line = jp.linalg.norm(peg_end2_pos - nearest_pt) - - out_of_bounds = jp.any(jp.abs(data.xpos[self._socket_body]) > 1.0) - out_of_bounds |= jp.any(jp.abs(data.xpos[self._peg_body]) > 1.0) - - raw_rewards = self._get_reward( - data, use_peg_insertion_reward=(peg_end2_dist_to_line < 0.005) - ) - rewards = { - k: v * self._config.reward_config.scales[k] - for k, v in raw_rewards.items() - } - reward = sum(rewards.values()) / sum( - self._config.reward_config.scales.values() - ) - - done = out_of_bounds | jp.isnan(data.qpos).any() | jp.isnan(data.qvel).any() - done = done.astype(float) - state.metrics.update( - **rewards, - peg_end2_dist_to_line=peg_end2_dist_to_line, - out_of_bounds=out_of_bounds.astype(float), - ) - obs = self._get_obs(data) - return mjx_env.State(data, obs, reward, done, state.metrics, state.info) - - def _get_obs(self, data: mjx.Data) -> jax.Array: - left_gripper_pos = data.site_xpos[self._left_gripper_site] - socket_pos = data.xpos[self._socket_body] - right_gripper_pos = data.site_xpos[self._right_gripper_site] - peg_pos = data.xpos[self._peg_body] - socket_entrance_pos = data.site_xpos[self._socket_entrance_site] - peg_end2_pos = data.site_xpos[self._peg_end2_site] - socket_z = data.xmat[self._socket_body].ravel()[6:] - peg_z = data.xmat[self._peg_body].ravel()[6:] - - obs = jp.concatenate([ - data.qpos, - data.qvel, - left_gripper_pos, - socket_pos, - right_gripper_pos, - peg_pos, - socket_entrance_pos, - peg_end2_pos, - socket_z, - peg_z, - ]) - - return obs - - def _get_reward( - self, data: mjx.Data, use_peg_insertion_reward: bool - ) -> Dict[str, jax.Array]: - left_socket_dist = jp.linalg.norm( - data.xpos[self._socket_body] - data.site_xpos[self._left_gripper_site] - ) - left_reward = reward_util.tolerance( - left_socket_dist, (0, 0.001), margin=0.3, sigmoid="linear" - ) - right_peg_dist = jp.linalg.norm( - data.xpos[self._peg_body] - data.site_xpos[self._right_gripper_site] - ) - right_reward = reward_util.tolerance( - right_peg_dist, (0, 0.001), margin=0.3, sigmoid="linear" - ) - - robot_qpos_diff = data.qpos[self._arm_qadr] - self._init_q[self._arm_qadr] - left_pose = jp.linalg.norm(robot_qpos_diff[:6]) - left_pose = reward_util.tolerance(left_pose, (0, 0.01), margin=2.0) - right_pose = jp.linalg.norm(robot_qpos_diff[6:]) - right_pose = reward_util.tolerance(right_pose, (0, 0.01), margin=2.0) - - socket_dist = jp.linalg.norm( - self._socket_entrance_goal_pos - data.xpos[self._socket_body] - ) - socket_lift = reward_util.tolerance( - socket_dist, (0, 0.01), margin=0.15, sigmoid="linear" - ) - - peg_dist = jp.linalg.norm( - self._peg_end2_goal_pos - data.xpos[self._peg_body] - ) - peg_lift = reward_util.tolerance( - peg_dist, (0, 0.01), margin=0.15, sigmoid="linear" - ) - - table_collision = self.hand_table_collision(data) - - socket_orientation = jp.dot( - data.xmat[self._socket_body][2], jp.array([0.0, 0.0, 1.0]) - ) - socket_orientation = reward_util.tolerance( - socket_orientation, (0.99, 1.0), margin=0.03, sigmoid="linear" - ) - peg_orientation = jp.dot( - data.xmat[self._peg_body][2], jp.array([0.0, 0.0, 1.0]) - ) - peg_orientation = reward_util.tolerance( - peg_orientation, (0.99, 1.0), margin=0.03, sigmoid="linear" - ) - - peg_insertion_dist = jp.linalg.norm( - data.site_xpos[self._peg_end2_site] - - 
data.site_xpos[self._socket_rear_site] - ) - peg_insertion_reward = ( - reward_util.tolerance( - peg_insertion_dist, (0, 0.001), margin=0.1, sigmoid="linear" - ) - * use_peg_insertion_reward - ) - - return { - "left_reward": left_reward, - "right_reward": right_reward, - "left_target_qpos": left_pose * left_reward * right_reward, - "right_target_qpos": right_pose * left_reward * right_reward, - "no_table_collision": 1 - table_collision, - "socket_entrance_reward": socket_lift, - "peg_end2_reward": peg_lift, - "socket_z_up": socket_orientation * socket_lift, - "peg_z_up": peg_orientation * peg_lift, - "peg_insertion_reward": peg_insertion_reward, - } diff --git a/mujoco_playground/_src/manipulation/aloha/xmls/mjx_aloha.xml b/mujoco_playground/_src/manipulation/aloha/xmls/mjx_aloha.xml index fb2b674bd..be68bb6d5 100644 --- a/mujoco_playground/_src/manipulation/aloha/xmls/mjx_aloha.xml +++ b/mujoco_playground/_src/manipulation/aloha/xmls/mjx_aloha.xml @@ -88,6 +88,9 @@ + + + @@ -105,8 +108,8 @@ - - + + @@ -157,13 +160,12 @@ - + - + - + @@ -175,6 +177,7 @@ + + @@ -244,13 +248,12 @@ - + - + - + @@ -262,6 +265,7 @@ + + diff --git a/mujoco_playground/_src/manipulation/aloha/xmls/mjx_half_aloha.xml b/mujoco_playground/_src/manipulation/aloha/xmls/mjx_half_aloha.xml new file mode 100644 index 000000000..c5597ad83 --- /dev/null +++ b/mujoco_playground/_src/manipulation/aloha/xmls/mjx_half_aloha.xml @@ -0,0 +1,218 @@ + + + + \ No newline at end of file diff --git a/mujoco_playground/_src/manipulation/aloha/xmls/mjx_half_scene.xml b/mujoco_playground/_src/manipulation/aloha/xmls/mjx_half_scene.xml new file mode 100644 index 000000000..3dae4488e --- /dev/null +++ b/mujoco_playground/_src/manipulation/aloha/xmls/mjx_half_scene.xml @@ -0,0 +1,98 @@ + + + + + + + + + + diff --git a/mujoco_playground/_src/manipulation/aloha/xmls/mjx_pick.xml b/mujoco_playground/_src/manipulation/aloha/xmls/mjx_pick.xml new file mode 100644 index 000000000..2c3d790f2 --- /dev/null +++ b/mujoco_playground/_src/manipulation/aloha/xmls/mjx_pick.xml @@ -0,0 +1,34 @@ + + + + diff --git a/mujoco_playground/_src/manipulation/aloha/xmls/mjx_scene.xml b/mujoco_playground/_src/manipulation/aloha/xmls/mjx_scene.xml index d12ff032e..c454385cf 100644 --- a/mujoco_playground/_src/manipulation/aloha/xmls/mjx_scene.xml +++ b/mujoco_playground/_src/manipulation/aloha/xmls/mjx_scene.xml @@ -56,12 +56,12 @@ - + - - + diff --git a/mujoco_playground/_src/manipulation/aloha/xmls/mjx_single_peg_insertion.xml b/mujoco_playground/_src/manipulation/aloha/xmls/mjx_single_peg_insertion.xml index f21ac52f0..ea01ca7a7 100644 --- a/mujoco_playground/_src/manipulation/aloha/xmls/mjx_single_peg_insertion.xml +++ b/mujoco_playground/_src/manipulation/aloha/xmls/mjx_single_peg_insertion.xml @@ -1,8 +1,4 @@ - - - - - + + - - - - - - + + + + + + + + + + + - + + - - + + + + + + + + + + - - + + + + + + + + + + + + + + + diff --git a/mujoco_playground/_src/registry.py b/mujoco_playground/_src/registry.py index e7f0294d3..a2d988279 100644 --- a/mujoco_playground/_src/registry.py +++ b/mujoco_playground/_src/registry.py @@ -31,9 +31,7 @@ # A tuple containing all available environment names across all suites. 
ALL_ENVS = ( - dm_control_suite.ALL_ENVS - + locomotion.ALL_ENVS - + manipulation.ALL_ENVS + dm_control_suite.ALL_ENVS + locomotion.ALL_ENVS + manipulation.ALL_ENVS ) diff --git a/mujoco_playground/_src/wrapper.py b/mujoco_playground/_src/wrapper.py index 62e2823b8..5a4ccbd95 100644 --- a/mujoco_playground/_src/wrapper.py +++ b/mujoco_playground/_src/wrapper.py @@ -237,7 +237,11 @@ def _supplement_vision_randomization_fn( for field in required_fields: if getattr(in_axes, field) is None: in_axes = in_axes.tree_replace({field: 0}) - val = -1 if field == 'geom_matid' else getattr(mjx_model, field) + val = ( + jp.repeat(-1, mjx_model.geom_matid.shape[0], 0) + if field == 'geom_matid' + else getattr(mjx_model, field) + ) mjx_model = mjx_model.tree_replace({ field: jp.repeat(jp.expand_dims(val, 0), num_worlds, axis=0), }) diff --git a/mujoco_playground/config/manipulation_params.py b/mujoco_playground/config/manipulation_params.py index 1e5c4e442..e7ddff016 100644 --- a/mujoco_playground/config/manipulation_params.py +++ b/mujoco_playground/config/manipulation_params.py @@ -49,18 +49,39 @@ def brax_ppo_config(env_name: str) -> config_dict.ConfigDict: rl_config.batch_size = 512 rl_config.max_grad_norm = 1.0 rl_config.network_factory.policy_hidden_layer_sizes = (256, 256, 256) - elif env_name == "AlohaSinglePegInsertion": - rl_config.num_timesteps = 150_000_000 - rl_config.num_evals = 10 - rl_config.unroll_length = 40 + elif env_name == "AlohaPick": + rl_config.num_timesteps = 20_000_000 + rl_config.num_evals = int(rl_config.num_timesteps / 4_000_000) + rl_config.unroll_length = 10 rl_config.num_minibatches = 32 rl_config.num_updates_per_batch = 8 rl_config.discounting = 0.97 - rl_config.learning_rate = 3e-4 - rl_config.entropy_cost = 1e-2 - rl_config.num_envs = 1024 + rl_config.learning_rate = 1e-3 + rl_config.entropy_cost = 2e-2 + rl_config.num_envs = 2048 + rl_config.num_eval_envs = 128 rl_config.batch_size = 512 - rl_config.network_factory.policy_hidden_layer_sizes = (256, 256, 256, 256) + rl_config.network_factory.policy_hidden_layer_sizes = (32, 32, 32, 32) + rl_config.network_factory.policy_obs_key = "state" + rl_config.network_factory.value_obs_key = "privileged" + elif env_name == "AlohaPegInsertion": + rl_config.num_timesteps = 10_000_000 + rl_config.num_evals = max(int(rl_config.num_timesteps / 2_000_000), 5) + rl_config.unroll_length = 10 + rl_config.num_minibatches = 32 + rl_config.num_updates_per_batch = 8 + rl_config.num_eval_envs = 128 + rl_config.discounting = 0.97 + rl_config.learning_rate = 1e-3 + rl_config.entropy_cost = 2e-2 + rl_config.num_envs = 2048 + rl_config.batch_size = 512 + rl_config.network_factory = config_dict.create( + policy_hidden_layer_sizes=(32, 32, 32, 32), + value_hidden_layer_sizes=(32, 32, 32, 32), + policy_obs_key="state", + value_obs_key="privileged", + ) elif env_name == "PandaOpenCabinet": rl_config.num_timesteps = 40_000_000 rl_config.num_evals = 4 diff --git a/mujoco_playground/experimental/bc_peg_insertion.py b/mujoco_playground/experimental/bc_peg_insertion.py new file mode 100644 index 000000000..1ebf38ff6 --- /dev/null +++ b/mujoco_playground/experimental/bc_peg_insertion.py @@ -0,0 +1,212 @@ +# Copyright 2025 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: skip-file + +"""Train a policy for the Aloha peg insertion task using Behavior Cloning.""" + +import os + +os.environ[ + "XLA_PYTHON_CLIENT_PREALLOCATE" +] = ( # Ensure that Madrona gets the chance to pre-allocate memory before Jax + "false" +) + + +from datetime import datetime +import functools + +from brax.training.agents.bc import networks as bc_networks +from brax.training.agents.bc import train as bc_fast +from etils import epath +import jax +from jax import numpy as jp +import typer + +from mujoco_playground import manipulation +from mujoco_playground import registry +from mujoco_playground import wrapper +from mujoco_playground._src.manipulation.aloha import distillation + +app = typer.Typer(pretty_exceptions_enable=False) + + +@app.command() +def main( + seed: int = typer.Option(0, help="Random seed"), + print_reward: bool = typer.Option( + False, help="Prints the per-step reward in the data collection phase" + ), + print_loss: bool = typer.Option( + False, help="Prints the actor loss in the student-fitting phase" + ), + domain_randomization: bool = typer.Option( + False, help="Use domain randomization" + ), + vision: bool = typer.Option(True, help="Use vision"), + num_envs: int = typer.Option(1024, help="Number of parallel environments"), + episode_length: int = typer.Option(160, help="Length of each episode"), + dagger_steps: int = typer.Option(400, help="Number of DAgger steps"), + num_evals: int = typer.Option(5, help="Number of evaluation episodes"), + demo_length: int = typer.Option(6, help="Length of demonstrations"), +): + env_name = "AlohaPegInsertionDistill" + env_cfg = manipulation.get_default_config(env_name) + + config_overrides = { + "episode_length": episode_length, + "vision": vision, + "vision_config.enabled_geom_groups": [ + 1, + 2, + 5, + ], # Disable mocaps on group 0. + "vision_config.use_rasterizer": False, + "vision_config.render_batch_size": num_envs, + "vision_config.render_width": 32, + "vision_config.render_height": 32, + } + + env = manipulation.load( + env_name, config=env_cfg, config_overrides=config_overrides + ) + + randomizer = None + if domain_randomization: + randomizer = registry.get_domain_randomizer(env_name) + # Randomizer expected to only require mjx model input. + key_rng = jax.random.PRNGKey(seed) + randomizer = functools.partial( + randomizer, rng=jax.random.split(key_rng, num_envs) + ) + + env = wrapper.wrap_for_brax_training( + env, + vision=vision, + num_vision_envs=num_envs, + episode_length=episode_length, + action_repeat=1, + randomization_fn=randomizer, + ) + + network_factory = functools.partial( # Student network factory. + bc_networks.make_bc_networks, + policy_hidden_layer_sizes=(256,) * 3, + policy_obs_key=("proprio" if vision else "state_with_time"), + latent_vision=True, + ) + + teacher_inference_fn = distillation.make_teacher_policy() + + # Generate unique experiment name. 
+ now = datetime.now() + timestamp = now.strftime("%Y%m%d-%H%M%S") + exp_name = f"{env_name}-{timestamp}" + + ckpt_path = epath.Path("logs").resolve() / exp_name + + epochs = 4 + augment_pixels = True + dagger_beta_fn = lambda step: jp.where(step == 0, 1.0, 0.0) + + def get_num_dagger_steps(num_evals, target): + """Round down to the nearest multiple of num_evals - 1.""" + dagger_steps = target // (num_evals - 1) * (num_evals - 1) + print("Dagger steps:", dagger_steps) + print( + "Episodes per environment:", + (dagger_steps * demo_length / episode_length), + ) + print( + "Total episodes across all environments:", + num_envs * (dagger_steps * demo_length / episode_length), + ) + print( + "Total steps across all environments:", + (dagger_steps * demo_length * num_envs), + ) + return dagger_steps + + dagger_steps = get_num_dagger_steps(num_evals, dagger_steps) + + def progress(epoch, metrics: dict): + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + if epoch == 0 and num_evals > 0: + print( + f"""[{timestamp}] Dagger Step {epoch}: + Eval Reward: {metrics['eval/episode_reward']:.4f} ± {metrics['eval/episode_reward_std']:.4f} \n\n""" + ) + return + + actor_loss = jp.mean(metrics["actor_loss"], axis=(-1, -2)) + + if print_loss: + actor_loss = jp.ravel(actor_loss) + for loss in actor_loss: + print(f"SGD Actor Loss: {loss:.4f}") + + if print_reward: + r_means = metrics["reward_mean"].ravel() + # Ensure divisibility by 30. + r_means = r_means[: len(r_means) // 30 * 30] + r_means = r_means.reshape(-1, 30).mean(axis=1) # Downsample. + for r_mean in r_means: + print(f"Rewards: {r_mean:.4f}") + + print( + f"""[{timestamp}] Dagger Step {epoch}: + Actor Loss: {jp.mean(actor_loss):.4f}, SPS: {metrics['SPS']:.4f}, Walltime: {metrics['walltime']:.4f} s""" + ) + suffix = ( + "\n\n" + if num_evals == 0 + else ( + f"\t\tEval Reward: {metrics['eval/episode_reward']:.4f} ±" + f" {metrics['eval/episode_reward_std']:.4f}\n\n" + ) + ) + print(suffix) + + train_fn = functools.partial( + bc_fast.train, + dagger_steps=dagger_steps, + demo_length=demo_length, + tanh_squash=True, + teacher_inference_fn=teacher_inference_fn, + normalize_observations=True, + epochs=epochs, + scramble_time=episode_length, + dagger_beta_fn=dagger_beta_fn, + batch_size=256, + env=env, + num_envs=num_envs, + num_eval_envs=num_envs, + num_evals=num_evals, + eval_length=episode_length * 1.15, + network_factory=network_factory, + progress_fn=progress, + madrona_backend=True, + seed=seed, + learning_rate=4e-4, + augment_pixels=augment_pixels, + save_checkpoint_path=ckpt_path, + ) + print(f"Training start: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + _, params, _ = train_fn() + print(f"Training done: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + + +if __name__ == "__main__": + app() diff --git a/mujoco_playground/experimental/jax2onnx/aloha_nets_to_onnx.py b/mujoco_playground/experimental/jax2onnx/aloha_nets_to_onnx.py new file mode 100644 index 000000000..0508cf38b --- /dev/null +++ b/mujoco_playground/experimental/jax2onnx/aloha_nets_to_onnx.py @@ -0,0 +1,313 @@ +# Copyright 2025 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: skip-file +import os + + +def limit_jax_mem(limit): + os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = f'{limit:.2f}' + + +limit_jax_mem(0.1) + +# Tell XLA to use Triton GEMM +xla_flags = os.environ.get('XLA_FLAGS', '') +xla_flags += ' --xla_gpu_triton_gemm_any=True' +os.environ['XLA_FLAGS'] = xla_flags +os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' + +import jax + +jax.config.update('jax_compilation_cache_dir', '/tmp/jax_cache') +jax.config.update('jax_persistent_cache_min_entry_size_bytes', -1) +jax.config.update('jax_persistent_cache_min_compile_time_secs', 0) + +import argparse +import functools +from pathlib import Path + +from brax.training.acme import running_statistics +from brax.training.acme import specs +from brax.training.agents.bc import networks as bc_networks +from brax.training.agents.ppo import train as ppo_train +from flax import linen +import jax.numpy as jp +import numpy as np +import onnxruntime as rt +from orbax import checkpoint as ocp +import tensorflow as tf +import tf2onnx + +from mujoco_playground._src import mjx_env +import mujoco_playground._src.manipulation.aloha.distillation as distillation +from mujoco_playground.experimental.jax2onnx.aloha_nets_utils import TFVisionMLP +from mujoco_playground.experimental.jax2onnx.aloha_nets_utils import transfer_jax_params_to_tf + +TEST_SCALE = 0.001 +action_size = 14 + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Script to convert from brax to Onnx network.' + ) + parser.add_argument( + '--checkpoint_path', type=str, help='Path to the checkpoint file' + ) + return parser.parse_args() + + +args = parse_args() + +print(f'Checkpoint path: {args.checkpoint_path}') +# Make sure it doesn't end with 'checkpoints'. +if args.checkpoint_path.endswith('checkpoints'): + raise ValueError("Don't end with 'checkpoints'") + +ckpt_path = args.checkpoint_path +if ckpt_path.endswith('/'): + ckpt_path = ckpt_path[:-1] +onnx_foldername = ckpt_path.split('/')[-1] # foldername is envname-date-time + +onnx_base_path = Path(__file__).parent / 'onnx_policies' +# Make sure it exists. +onnx_base_path.mkdir(parents=True, exist_ok=True) +onnx_base_path = onnx_base_path / onnx_foldername + +# Define the observation sizes +pix_shape = (32, 32, 3) + +""" +May be different than the obs you get from the env. +The env obs already pre-applies the frozen encoder +But we need to test the frozen encoder, so we give +the obs as an internal state in the env. (only revelant to pixels/) +""" +obs_shape = { + 'has_switched': (1,), + 'pixels/rgb_l': pix_shape, + 'pixels/rgb_r': pix_shape, + 'latent_0': ( + 64, + ), # For simplicity, also include the 4 latents as in the obs for checkpoint-related param init. 
+ 'latent_1': (64,), + 'latent_2': (64,), + 'latent_3': (64,), + 'privileged': (110,), + 'proprio': (33,), + 'state': (109,), + 'state_pickup': (106,), + 'socket_hidden': (1,), + 'peg_hidden': (1,), +} + +# For tf +tf_pix_shape = (1, 3, 32, 32) +state_dim = (1, 33) + + +# Trace / compile +class TFPolicyWrapper(tf.keras.Model): + + def __init__(self, policy_net): + super().__init__() + self.policy_net = policy_net + + def call(self, inputs): + # Unpack the dictionary into pixel streams and state vector. + pixel_streams = [inputs['pixels/view_0'], inputs['pixels/view_1']] + state = inputs['proprio'] + latents = [inputs['latent_0'], inputs['latent_1']] + return self.policy_net( + pixel_streams, states=state, depth_latent_streams=latents + ) + + +policy_hidden_layer_sizes = (256,) * 3 + +network_factory = functools.partial( + bc_networks.make_bc_networks, + policy_hidden_layer_sizes=policy_hidden_layer_sizes, + activation=linen.relu, + policy_obs_key='proprio', + latent_vision=True, +) + +bc_network = network_factory( + obs_shape, + action_size, + preprocess_observations_fn=running_statistics.normalize, +) +make_inference_fn = bc_networks.make_inference_fn(bc_network) +encoder_fn = distillation.get_frozen_encoder_fn() + +# Load encoder params for transfering to TF. +encoder_params = distillation.load_frozen_encoder_params() + +# Initialize param structure for loading with orbax checkpointer. +dummy_obs = {k: TEST_SCALE * jp.ones(v) for k, v in obs_shape.items()} +specs_obs_shape = jax.tree_util.tree_map( + lambda x: specs.Array(x.shape[-1:], jp.dtype('float32')), dummy_obs +) +spec_obs_shape = ppo_train._remove_pixels(specs_obs_shape) + +normalizer_params = running_statistics.init_state(spec_obs_shape) + +init_params = ( + normalizer_params, + bc_network.policy_network.init(jax.random.PRNGKey(0)), +) + + +def make_inference_fn_wrapper(params): + inference_fn = make_inference_fn(params, deterministic=True, tanh_squash=True) + base_inference_fn = inference_fn + + def inference_fn(jax_input, _): + # encoder_fn adds a batch dim. + encoder_inputs = {} + encoder_inputs['pixels/view_0'] = jax_input['pixels/rgb_l'][0] + encoder_inputs['pixels/view_1'] = jax_input['pixels/rgb_r'][0] + + latents = encoder_fn(encoder_inputs) + latents = jp.split(latents[None, ...], 2, axis=-1) + + p_ins = { + 'proprio': jax_input['proprio'], + 'latent_0': latents[0], # RGB latents + 'latent_1': latents[1], + 'latent_2': jax_input['latent_2'], # Depth latents + 'latent_3': jax_input['latent_3'], + } + + return base_inference_fn(p_ins, _) + + return inference_fn + + +def jax_params_to_onnx(params, output_path): + + onnx_input = { + 'pixels/view_0': TEST_SCALE * np.ones(tf_pix_shape, dtype=np.float32), + 'pixels/view_1': TEST_SCALE * np.ones(tf_pix_shape, dtype=np.float32), + 'proprio': TEST_SCALE * np.ones((state_dim), dtype=np.float32), + 'latent_0': TEST_SCALE * np.ones((1, 64), dtype=np.float32), + 'latent_1': TEST_SCALE * np.ones((1, 64), dtype=np.float32), + } + + mean = params[0].mean['proprio'] + std = params[0].std['proprio'] + mean_std = (tf.convert_to_tensor(mean), tf.convert_to_tensor(std)) + tf_policy_network = TFVisionMLP( + layer_sizes=policy_hidden_layer_sizes + (2 * action_size,), + normalise_channels=False, + layer_norm=False, + num_pixel_streams=2, + state_mean_std=mean_std, + action_size=action_size, + latent_dense_size=64, + ) + + tf_policy_network = TFPolicyWrapper(tf_policy_network) + tf_policy_network(onnx_input)[0] # Initialize. 
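# The dummy call above matters: Keras creates layer variables lazily on the
# first call, and set_weights() on an unbuilt layer raises an error. A minimal,
# self-contained sketch of the same build-then-load pattern (the shapes and
# values here are made up for illustration; this mirrors, but is not, the
# transfer helper used below):
import numpy as np
import tensorflow as tf

layer = tf.keras.layers.Dense(4)       # no variables created yet
_ = layer(tf.zeros((1, 3)))            # builds kernel (3, 4) and bias (4,)
layer.set_weights([
    np.ones((3, 4), dtype=np.float32),   # kernel, e.g. pulled from a JAX param tree
    np.zeros((4,), dtype=np.float32),    # bias
])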
+ transfer_jax_params_to_tf( + (encoder_params['params'], params[1]['params']), tf_policy_network + ) + + # --- Convert to ONNX --- + # Define the input signature for conversion. + input_signature = [{ + 'pixels/view_0': tf.TensorSpec( + shape=tf_pix_shape, dtype=tf.float32, name='pixels/view_0' + ), + 'pixels/view_1': tf.TensorSpec( + shape=tf_pix_shape, dtype=tf.float32, name='pixels/view_1' + ), + 'proprio': tf.TensorSpec( + shape=state_dim, dtype=tf.float32, name='proprio' + ), + 'latent_0': tf.TensorSpec( + shape=(1, 64), dtype=tf.float32, name='latent_0' + ), + 'latent_1': tf.TensorSpec( + shape=(1, 64), dtype=tf.float32, name='latent_1' + ), + }] + + tf_policy_network.output_names = ['continuous_actions'] + + # Convert the model to ONNX + # Convert using tf2onnx (opset 11 is used in this example). + tf2onnx.convert.from_keras( + tf_policy_network, + input_signature=input_signature, + opset=11, + output_path=output_path, + ) + + # Run inference with ONNX Runtime + output_names = ['continuous_actions'] + providers = ['CPUExecutionProvider'] + m = rt.InferenceSession(output_path, providers=providers) + + # Prepare inputs for ONNX Runtime + onnx_pred = jp.array(m.run(output_names, onnx_input)[0][0]) + + jax_input = { + key: TEST_SCALE * np.ones((1,) + shape) + for key, shape in obs_shape.items() + } + inference_fn = make_inference_fn_wrapper(params) + jax_pred = inference_fn(jax_input, jax.random.PRNGKey(0))[0][ + 0 + ] # Also returns action extras. Unbatch. + + try: + # Assert that they're close + assert np.allclose(jax_pred, onnx_pred, atol=1e-3) + print('\n\n===============================') + print(' Predictions match! ') + print('===============================\n\n') + except AssertionError as e: + print('Predictions do not match:', e) + print('JAX prediction:', jax_pred) + print('ONNX prediction:', onnx_pred) + # exit + raise e + + +experiment = Path(ckpt_path) +ckpts = [ + ckpt + for ckpt in experiment.glob('*') + if ckpt.name != 'bc_network_config.json' +] +ckpts.sort(key=lambda x: int(x.name)) +assert ckpts, 'No checkpoints found' +orbax_checkpointer = ocp.PyTreeCheckpointer() + + +for restore_checkpoint_path in ckpts: + checkpoint_name = restore_checkpoint_path.name + print('######### CONVERTING CHECKPOINT #########') + print( + f"{' ' * ((40 - len(checkpoint_name)) // 2)}{checkpoint_name}{' ' * ((40 - len(checkpoint_name)) // 2)}" + ) # Print centered. + print('#########################################') + + params = orbax_checkpointer.restore(restore_checkpoint_path, item=init_params) + jax_params_to_onnx( + params, onnx_base_path / f'{restore_checkpoint_path.name}.onnx' + ) diff --git a/mujoco_playground/experimental/jax2onnx/aloha_nets_utils.py b/mujoco_playground/experimental/jax2onnx/aloha_nets_utils.py new file mode 100644 index 000000000..c0782b314 --- /dev/null +++ b/mujoco_playground/experimental/jax2onnx/aloha_nets_utils.py @@ -0,0 +1,372 @@ +# Copyright 2025 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# pylint: skip-file +import numpy as np +import tensorflow as tf +from tensorflow.keras import layers + + +class TFVisionMLP(tf.keras.Model): + """ + Applies multiple CNN backbones (one per pixel stream) named NatureCNN_i, + each containing a single sub-block CNN_0 with Conv_0..Conv_2. + After the CNN output is pooled and projected via Dense_i, + we combine everything and pass it to MLP_0. + + When aux_loss is True, an additional auxiliary branch is computed from + the concatenation of the raw (pooled) CNN outputs (before the Dense projection) + via a Dense(6) layer. The aux output is then concatenated with the policy head. + """ + + def __init__( + self, + layer_sizes, + activation=tf.nn.relu, + kernel_initializer="lecun_uniform", + activate_final=False, + layer_norm=False, + normalise_channels=False, + num_pixel_streams=2, + state_mean_std=None, # (mean, std) for normalizing the state vector + action_size=2, # for demonstration + name="tf_vision_mlp", + latent_dense_size=128, + ): + super().__init__(name=name) + self.layer_sizes = layer_sizes + self.activation = activation + self.kernel_initializer = kernel_initializer + self.activate_final = activate_final + self.layer_norm = layer_norm + self.normalise_channels = normalise_channels + self.num_pixel_streams = num_pixel_streams + self.action_size = action_size + + # Store mean and std for state normalization if provided + if state_mean_std is not None: + self.mean_state = tf.Variable( + state_mean_std[0], trainable=False, dtype=tf.float32 + ) + self.std_state = tf.Variable( + state_mean_std[1], trainable=False, dtype=tf.float32 + ) + else: + self.mean_state = None + self.std_state = None + + # --------------------------------------------------------------------- + # 1) Build the CNN blocks, each named "NatureCNN_i". + # Inside each "NatureCNN_i", create a sub-block "CNN_0" + # with 3 Conv layers named "Conv_0", "Conv_1", "Conv_2". + # --------------------------------------------------------------------- + self.cnn_blocks = [] + self.downstream_denses = [] + for i in range(self.num_pixel_streams): + nature_cnn_block = tf.keras.Sequential(name=f"CNN_{i}") + nature_cnn_block.add( + layers.Conv2D( + 32, + (8, 8), + strides=(4, 4), + activation=self.activation, + use_bias=False, + kernel_initializer=self.kernel_initializer, + name="Conv_0", + padding="same", + ) + ) + nature_cnn_block.add( + layers.Conv2D( + 64, + (4, 4), + strides=(2, 2), + activation=self.activation, + use_bias=False, + kernel_initializer=self.kernel_initializer, + name="Conv_1", + padding="same", + ) + ) + nature_cnn_block.add( + layers.Conv2D( + 64, + (3, 3), + strides=(1, 1), + activation=self.activation, + use_bias=False, + kernel_initializer=self.kernel_initializer, + name="Conv_2", + padding="same", + ) + ) + + # Add the sub-block to the "NatureCNN_i" block + self.cnn_blocks.append(nature_cnn_block) + + # 2) Each CNN output is projected to 128 dims via a Dense_i + proj_dense = layers.Dense( + latent_dense_size, + activation=self.activation, + kernel_initializer=self.kernel_initializer, + name=f"Dense_{i}", + ) + + self.downstream_denses.append(proj_dense) + + # Append two more blocks to downstream_denses. + for j in range(1, 3): # TODO: less sketchy way to ensure correct names. 
+ proj_dense = layers.Dense( + latent_dense_size, + activation=self.activation, + kernel_initializer=self.kernel_initializer, + name=f"Dense_{i+j}", + ) + self.downstream_denses.append(proj_dense) + + # --------------------------------------------------------------------- + # 3) Build the MLP block, named "MLP_0", containing hidden_0..hidden_n. + # --------------------------------------------------------------------- + self.mlp_block = tf.keras.Sequential(name="MLP_0") + for i, size in enumerate(self.layer_sizes): + dense_layer = layers.Dense( + size, + activation=self.activation, + kernel_initializer=self.kernel_initializer, + name=f"hidden_{i}", + ) + self.mlp_block.add(dense_layer) + + # Optionally add layer normalization after each hidden + if self.layer_norm: + self.mlp_block.add(layers.LayerNormalization(name=f"layer_norm_{i}")) + + # Remove activation from the final Dense if activate_final is False + if not self.activate_final and len(self.mlp_block.layers) > 0: + last_layer = self.mlp_block.layers[-1] + if hasattr(last_layer, "activation") and ( + last_layer.activation is not None + ): + last_layer.activation = None + + def normalise_image_channels(self, x): + """ + Apply per-channel normalization over the spatial dimensions (H, W). + Matches JAX linen.LayerNorm(reduction_axes=(-1, -2)) for NHWC format: + compute mean/var over H,W for each N,C. + """ + # x shape: [N, H, W, C] + mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True) + x = (x - mean) / tf.sqrt(variance + 1e-5) + return x + + def call( + self, + pixel_streams, + states=None, + depth_latent_streams=None, + training=False, + ): + """ + Forward pass through the model. + + Args: + pixel_streams: A list (or single Tensor) of pixel inputs, + each shape [N, C, H, W]. + states: A Tensor of shape [N, k], representing additional state info. + """ + # If a single tensor is provided, wrap it in a list + if isinstance(pixel_streams, tf.Tensor): + pixel_streams = [pixel_streams] + + if len(pixel_streams) != self.num_pixel_streams: + raise ValueError( + f"Expected {self.num_pixel_streams} pixel streams, " + f"but got {len(pixel_streams)}" + ) + + # Per-channel normalization if enabled: convert to NHWC format first. + if self.normalise_channels: + pixel_streams = [ + self.normalise_image_channels(tf.keras.layers.Permute((2, 3, 1))(x)) + for x in pixel_streams + ] + else: + # Just permute to [N, H, W, C] if not normalizing + pixel_streams = [ + tf.keras.layers.Permute((2, 3, 1))(x) for x in pixel_streams + ] + + projected_cnn_outs = [] # For the policy head + for i, x in enumerate(pixel_streams): + # Pass through "NatureCNN_i" + x = self.cnn_blocks[i](x, training=training) + # Global average pool over H,W; shape becomes [N, filters] + x_pooled = tf.reduce_mean(x, axis=[1, 2]) + x_proj = self.downstream_denses[i](x_pooled) + projected_cnn_outs.append(x_proj) + + ##### ORDER #### + # RGB | DEPTH | Proprio + ################ + if depth_latent_streams is not None: + assert ( + len(depth_latent_streams) == 2 + ), "depth_latent_streams must have 2 streams." + for j, x in enumerate(depth_latent_streams): + x_proj = self.downstream_denses[self.num_pixel_streams + j](x) + projected_cnn_outs.append(x_proj) + + # Optionally normalize states and append them. 
+ if states is not None: + if (self.mean_state is not None) and (self.std_state is not None): + proprio_size = (states.shape[0], 33) + assert states.shape == proprio_size + assert self.mean_state.shape == (33,) + assert self.std_state.shape == (33,) + states = (states - self.mean_state) / (self.std_state) + projected_cnn_outs.append(states) + + # Concatenate projected CNN outputs (and state if provided) + hidden = tf.concat(projected_cnn_outs, axis=-1) + hidden = self.mlp_block(hidden) + + # Compute the policy head output. + # (Here we assume the MLP produces an output of size 2*action_size, + # and we use the first half (after tanh) as the policy output.) + # Adjust this postprocessing as needed. + if hidden.shape[-1] != 2 * self.action_size: + raise ValueError( + f"Expected hidden dim to be {2*self.action_size}, got" + f" {hidden.shape[-1]}" + ) + + return tf.math.tanh(hidden[:, : self.action_size]) + + +def transfer_jax_params_to_tf( + jax_params: tuple[dict, dict], _tf_model: tf.keras.Model +): + tf_model = _tf_model.get_layer(name="tf_vision_mlp") + + encoder_params, policy_params = jax_params + # First, transfer the encoder params. + for ( + block_name, + block_content, + ) in encoder_params.items(): # keys = {Nature_CNN_i} + try: + # e.g. block_name = "NatureCNN_0" or "Dense_0" or "MLP_0" + tf_block = tf_model.get_layer(name=block_name) + except ValueError: + print(f"[Warning] No layer named '{block_name}' found in TF model.") + # Available layers: + print("Available layers:", [l.name for l in tf_model.layers]) + continue + + # --------------------------------------------------------------------- + # CASE 1: A top-level NatureCNN_i block, which is a tf.keras.Sequential + # containing a sub-block named "CNN_0". + # --------------------------------------------------------------------- + if block_name.startswith("NatureCNN_") and isinstance( + tf_block, tf.keras.Sequential + ): + for sub_block_name, sub_block_layers in block_content.items(): + # sub_block_name should be "CNN_0" + try: + tf_sub_block = tf_block.get_layer(name=sub_block_name) + except ValueError: + print( + f" [Warning] No sub-layer '{sub_block_name}' in '{block_name}'." + ) + continue + + # Now sub_block_layers might be {"Conv_0": {kernel/bias}, "Conv_1": ..., ...} + for layer_name, layer_params in sub_block_layers.items(): + try: + tf_layer = tf_sub_block.get_layer(name=layer_name) + except ValueError: + print( + f" [Warning] No layer '{layer_name}' in sub-block" + f" '{sub_block_name}'." + ) + continue + + # e.g. layer_name = "Conv_0" + if isinstance(tf_layer, tf.keras.layers.Conv2D): + kernel = np.array(layer_params["kernel"]) + # Some Conv2D might have bias if use_bias=True, + # but your example used use_bias=False, so we skip it here. + tf_layer.set_weights([kernel]) + print( + "Transferred Conv2D weights to" + f" {block_name}/{sub_block_name}/{layer_name}" + ) + else: + print( + " [Warning] Unhandled layer type in" + f" {block_name}/{sub_block_name}/{layer_name}" + ) + # Then, transfer the policy params. 
+ for ( + block_name, + block_content, + ) in policy_params.items(): # keys = {Dense_i, hidden_i} + try: + if block_name.startswith("Dense_"): + tf_block = tf_model.get_layer(name=block_name) + elif block_name == "MLP_0": + tf_block = tf_model.get_layer(name="MLP_0") + else: + # Must be a hidden layer within MLP_0 + continue + except ValueError: + print(f"[Warning] No layer named '{block_name}' found in TF model.") + print("Available layers:", [l.name for l in tf_model.layers]) + continue + + # --------------------------------------------------------------------- + # CASE A: A top-level Dense_i block, which is a single Dense layer. + # --------------------------------------------------------------------- + if isinstance(tf_block, tf.keras.layers.Dense): + # block_content should have "kernel" and "bias" + kernel = np.array(block_content["kernel"]) + bias = np.array(block_content["bias"]) + tf_block.set_weights([kernel, bias]) + print(f"Transferred Dense weights to {block_name}") + + # --------------------------------------------------------------------- + # CASE B: The MLP_0 block, which contains multiple hidden Dense layers. + # --------------------------------------------------------------------- + elif isinstance(tf_block, tf.keras.Sequential): + for layer_name, layer_params in block_content.items(): + try: + tf_layer = tf_block.get_layer(name=layer_name) + except ValueError: + print(f" [Warning] No layer '{layer_name}' in MLP_0 block.") + continue + + if isinstance(tf_layer, tf.keras.layers.Dense): + kernel = np.array(layer_params["kernel"]) + bias = np.array(layer_params["bias"]) + tf_layer.set_weights([kernel, bias]) + print(f"Transferred Dense weights to MLP_0/{layer_name}") + else: + print(f" [Warning] Unhandled layer type in MLP_0/{layer_name}") + + else: + raise ValueError( + f"Unhandled block '{block_name}' of type {type(tf_block)}" + ) + print("Weight transfer complete.") diff --git a/mujoco_playground/experimental/brax_network_to_onnx.ipynb b/mujoco_playground/experimental/jax2onnx/brax_mlp_to_onnx.ipynb similarity index 100% rename from mujoco_playground/experimental/brax_network_to_onnx.ipynb rename to mujoco_playground/experimental/jax2onnx/brax_mlp_to_onnx.ipynb
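The jax2onnx scripts above follow a single convert-and-verify pattern: export the TF wrapper with tf2onnx, reload the result with ONNX Runtime, and compare against a reference forward pass. A minimal, self-contained sketch of that pattern on a toy two-layer policy head is shown below; the model, output path, tolerance, and input name are illustrative only (not the Aloha network shapes), and it assumes the same TF/tf2onnx/onnxruntime environment the script targets.

import numpy as np
import onnxruntime as rt
import tensorflow as tf
import tf2onnx

# Toy stand-in for a policy network: 3 inputs -> tanh-squashed 2-D action.
inp = tf.keras.Input(shape=(3,), dtype=tf.float32, name="proprio")
hidden = tf.keras.layers.Dense(8, activation="relu")(inp)
out = tf.keras.layers.Dense(2, activation="tanh", name="continuous_actions")(hidden)
model = tf.keras.Model(inp, out)

# Export with an explicit input signature so the ONNX graph gets stable names.
spec = [tf.TensorSpec((1, 3), tf.float32, name="proprio")]
tf2onnx.convert.from_keras(
    model, input_signature=spec, opset=11, output_path="/tmp/toy_policy.onnx"
)

# Reload with ONNX Runtime and check it matches the TF forward pass.
sess = rt.InferenceSession(
    "/tmp/toy_policy.onnx", providers=["CPUExecutionProvider"]
)
x = 0.001 * np.ones((1, 3), dtype=np.float32)
onnx_pred = sess.run(None, {sess.get_inputs()[0].name: x})[0]
tf_pred = model(x).numpy()
assert np.allclose(tf_pred, onnx_pred, atol=1e-3), (tf_pred, onnx_pred)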