32 changes: 32 additions & 0 deletions README.md
@@ -105,6 +105,38 @@ For a ready-to-use motion capture dataset, you can use the [AMP Dataset on Huggi

---

## ♻️ Symmetry Augmentation

You can mirror demonstrations and policy rollouts by declaring the robot's kinematic symmetry in the training configuration. Provide a `symmetry_cfg` dictionary when building the runner:

```python
symmetry_cfg = {
"joint_pairs": [
["left_hip_yaw", "right_hip_yaw"],
["left_knee", "right_knee"],
# ... add all mirrored joint names here
],
"center_joints": ["torso_yaw"],
"joint_sign_overrides": {
"left_hip_roll": -1.0,
"right_hip_roll": -1.0,
},
"base_linear_sign": [1.0, -1.0, 1.0],
"base_angular_sign": [1.0, -1.0, -1.0],
}
```
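The runner looks `symmetry_cfg` up in the same configuration dictionary that provides keys such as `num_steps_per_env` and `slow_down_factor`. A minimal sketch of the wiring (placeholder values; the runner class name and the environment object come from your own setup):

```python
train_cfg = {
    "num_steps_per_env": 24,       # placeholder value
    "slow_down_factor": 1,         # placeholder value
    "symmetry_cfg": symmetry_cfg,  # the dictionary defined above
    # ... remaining runner / algorithm / policy / discriminator settings ...
}

# The runner consumes it at construction time, e.g.
# runner = AMPOnPolicyRunner(env, train_cfg, log_dir=..., device="cuda:0")
```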

Key fields:

- **`joint_pairs`** – pairs of mirrored joint names, listed as `(left, right)` tuples.
- **`center_joints`** – joints that lie on the mirror plane and are left unchanged.
- **`joint_sign_overrides`** – optional per-joint multipliers (±1) for joints whose axis reverses direction under mirroring (e.g., roll joints).
- **`base_linear_sign` / `base_angular_sign`** – sign flips applied to the base linear and angular velocities when reflecting across the sagittal plane.

When supplied, the dataset loader automatically augments motion clips with mirrored copies. The PPO agent mirrors policy-generated AMP observations as well, so the discriminator receives both original and reflected trajectories without extra code.
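For intuition, the joint-space part of the mirroring reduces to an index permutation (swap each left/right pair) plus a per-joint sign vector. The snippet below is an illustrative, self-contained sketch of that idea with hypothetical joint names; it is not the library's internal implementation, and it omits the base-velocity and quaternion handling the real transform also performs:

```python
import torch

# Hypothetical joint ordering and symmetry configuration.
joint_names = ["left_hip_yaw", "right_hip_yaw", "left_hip_roll", "right_hip_roll", "torso_yaw"]
joint_pairs = [("left_hip_yaw", "right_hip_yaw"), ("left_hip_roll", "right_hip_roll")]
sign_overrides = {"left_hip_roll": -1.0, "right_hip_roll": -1.0}

# Permutation that swaps each left/right pair; center joints map to themselves.
index = {name: i for i, name in enumerate(joint_names)}
perm = list(range(len(joint_names)))
for left, right in joint_pairs:
    perm[index[left]], perm[index[right]] = index[right], index[left]

# Per-joint sign multipliers (+1 unless overridden).
signs = torch.tensor([sign_overrides.get(name, 1.0) for name in joint_names])


def mirror_joints(q: torch.Tensor) -> torch.Tensor:
    """Mirror a (..., num_joints) tensor of joint positions or velocities."""
    return q[..., perm] * signs


q = torch.tensor([0.3, -0.1, 0.2, -0.4, 0.05])
print(mirror_joints(q))  # left/right entries swapped, hip-roll entries negated
```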

---

## 🧑‍💻 Authors

- **Giulio Romualdi** – [@GiulioRomualdi](https://github.com/GiulioRomualdi)
17 changes: 16 additions & 1 deletion amp_rsl_rl/algorithms/amp_ppo.py
@@ -18,7 +18,7 @@

from amp_rsl_rl.storage import ReplayBuffer
from amp_rsl_rl.networks import Discriminator
from amp_rsl_rl.utils import AMPLoader
from amp_rsl_rl.utils import AMPLoader, SymmetryTransform, mirror_amp_transition


class AMP_PPO:
@@ -69,6 +69,10 @@ class AMP_PPO:
Maximum number of policy transitions stored in the replay buffer for AMP training.
device : str, default="cpu"
The device (CPU or GPU) on which the models will be computed.
symmetry_transform : Optional[SymmetryTransform]
Optional symmetry description used to mirror policy AMP observations for
data augmentation. When provided, every policy transition is mirrored and
stored alongside the original one.
"""

actor_critic: ActorCritic
@@ -94,6 +98,7 @@ def __init__(
amp_replay_buffer_size: int = 100000,
use_smooth_ratio_clipping: bool = False,
device: str = "cpu",
symmetry_transform: Optional[SymmetryTransform] = None,
) -> None:
# Set device and learning hyperparameters
self.device: str = device
@@ -112,6 +117,9 @@ def __init__(
)
self.amp_data: AMPLoader = amp_data
self.amp_normalizer: Optional[Any] = amp_normalizer
self.symmetry_transform: Optional[SymmetryTransform] = (
symmetry_transform.to(device) if symmetry_transform is not None else None
)

# Set up the actor-critic (policy) and move it to the device.
self.actor_critic = actor_critic
@@ -285,6 +293,13 @@ def process_amp_step(self, amp_obs: torch.Tensor) -> None:
The new AMP observation (from expert data or policy update).
"""
self.amp_storage.insert(self.amp_transition.observations, amp_obs)
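# Symmetry augmentation: also store the mirrored (state, next_state) pair so the
# discriminator is trained on reflected policy transitions as well.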
if self.symmetry_transform is not None:
mirrored_state, mirrored_next = mirror_amp_transition(
self.amp_transition.observations,
amp_obs,
self.symmetry_transform,
)
self.amp_storage.insert(mirrored_state, mirrored_next)
self.amp_transition.clear()

def compute_returns(self, last_critic_obs: torch.Tensor) -> None:
28 changes: 25 additions & 3 deletions amp_rsl_rl/runners/amp_on_policy_runner.py
@@ -21,7 +21,7 @@
from rsl_rl.utils import store_code_state

from amp_rsl_rl.utils import Normalizer
from amp_rsl_rl.utils import AMPLoader
from amp_rsl_rl.utils import AMPLoader, SymmetrySpec
from amp_rsl_rl.algorithms import AMP_PPO
from amp_rsl_rl.networks import Discriminator, ActorCriticMoE
from amp_rsl_rl.utils import export_policy_as_onnx
@@ -147,7 +147,26 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
).to(self.device)
)
# NOTE: to use this we need to configure the observations in the env coherently with amp observation. Tested with Manager Based envs in Isaaclab
amp_joint_names = self.env.cfg.observations.amp.joint_pos.params['asset_cfg'].joint_names
amp_joint_names = self.env.cfg.observations.amp.joint_pos.params[
"asset_cfg"
].joint_names

symmetry_cfg = self.cfg.get("symmetry_cfg")
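# Translate the user-provided dictionary into a SymmetrySpec shared by the
# dataset loader and the PPO agent.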
symmetry_spec = None
if symmetry_cfg:
joint_pairs = [tuple(pair) for pair in symmetry_cfg.get("joint_pairs", [])]
symmetry_spec = SymmetrySpec(
joint_pairs=joint_pairs,
center_joints=tuple(symmetry_cfg.get("center_joints", ())),
joint_sign_overrides=symmetry_cfg.get("joint_sign_overrides", {}),
base_linear_sign=tuple(
symmetry_cfg.get("base_linear_sign", (1.0, -1.0, 1.0))
),
base_angular_sign=tuple(
symmetry_cfg.get("base_angular_sign", (1.0, -1.0, -1.0))
),
allow_unmapped=symmetry_cfg.get("allow_unmapped", True),
)

delta_t = self.env.cfg.sim.dt * self.env.cfg.decimation

@@ -161,14 +180,16 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
delta_t,
self.cfg["slow_down_factor"],
amp_joint_names,
symmetry_spec=symmetry_spec,
)

# self.env.unwrapped.scene["robot"].joint_names)

# amp_data = AMPLoader(num_amp_obs, self.device)
self.amp_normalizer = Normalizer(num_amp_obs, device=self.device)
self.discriminator = Discriminator(
input_dim=num_amp_obs* 2, # the discriminator takes in the concatenation of the current and next observation
input_dim=num_amp_obs
* 2, # the discriminator takes in the concatenation of the current and next observation
hidden_layer_sizes=self.discriminator_cfg["hidden_dims"],
reward_scale=self.discriminator_cfg["reward_scale"],
device=self.device,
@@ -192,6 +213,7 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
amp_data=amp_data,
amp_normalizer=self.amp_normalizer,
device=self.device,
symmetry_transform=amp_data.symmetry_transform,
**self.alg_cfg,
)
self.num_steps_per_env = self.cfg["num_steps_per_env"]
14 changes: 14 additions & 0 deletions amp_rsl_rl/utils/__init__.py
@@ -9,11 +9,25 @@
from .utils import Normalizer, RunningMeanStd
from .motion_loader import AMPLoader, download_amp_dataset_from_hf
from .exporter import export_policy_as_onnx
from .symmetry import (
SymmetrySpec,
SymmetryTransform,
apply_base_symmetry,
apply_joint_symmetry,
mirror_amp_observation,
mirror_amp_transition,
)

__all__ = [
"Normalizer",
"RunningMeanStd",
"AMPLoader",
"download_amp_dataset_from_hf",
"export_policy_as_onnx",
"SymmetrySpec",
"SymmetryTransform",
"apply_joint_symmetry",
"apply_base_symmetry",
"mirror_amp_observation",
"mirror_amp_transition",
]
82 changes: 80 additions & 2 deletions amp_rsl_rl/utils/motion_loader.py
@@ -4,7 +4,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from pathlib import Path
from typing import List, Union, Tuple, Generator
from typing import List, Union, Tuple, Generator, Optional
from dataclasses import dataclass

import torch
@@ -13,6 +13,41 @@
from scipy.interpolate import interp1d


def _mirror_quaternion(
quat_wxyz: torch.Tensor, linear_sign: torch.Tensor
) -> torch.Tensor:
"""Mirror orientation quaternions using a diagonal reflection defined by ``linear_sign``."""

if quat_wxyz.numel() == 0:
return quat_wxyz

device = quat_wxyz.device
dtype = quat_wxyz.dtype
diag = linear_sign.detach().to(torch.float64).cpu().numpy()

quat_xyzw = torch.cat((quat_wxyz[..., 1:], quat_wxyz[..., :1]), dim=-1)
quat_np = quat_xyzw.detach().to(torch.float64).cpu().numpy()
rot = Rotation.from_quat(quat_np)

mat = rot.as_matrix()
left = diag.reshape(1, 3, 1)
right = diag.reshape(1, 1, 3)
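# Scaling rows and columns of R by the reflection diagonal D computes D @ R @ D,
# i.e. the rotation matrix conjugated by the reflection (the mirrored rotation).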
mirrored_mat = mat * left * right

mirrored_rot = Rotation.from_matrix(mirrored_mat)
mirrored_xyzw = torch.tensor(mirrored_rot.as_quat(), dtype=dtype, device=device)
# Convert back to wxyz convention
return torch.cat((mirrored_xyzw[..., 3:], mirrored_xyzw[..., :3]), dim=-1)


from .symmetry import (
SymmetrySpec,
SymmetryTransform,
apply_base_symmetry,
apply_joint_symmetry,
)


def download_amp_dataset_from_hf(
destination_dir: Path,
robot_folder: str,
@@ -157,6 +192,29 @@ def get_random_sample_for_reset(self, items: int = 1) -> Tuple[torch.Tensor, ...
indices = torch.randint(0, len(self), (items,), device=self.device)
return self.get_state_for_reset(indices)

def mirrored(self, transform: SymmetryTransform) -> "MotionData":
"""Return a mirrored copy of the motion according to ``transform``."""

joint_positions = apply_joint_symmetry(self.joint_positions, transform)
joint_velocities = apply_joint_symmetry(self.joint_velocities, transform)
base_lin_vel_mixed = self.base_lin_velocities_mixed * transform.base_lin_sign
base_ang_vel_mixed = self.base_ang_velocities_mixed * transform.base_ang_sign

base_lin_vel_local, base_ang_vel_local = apply_base_symmetry(
self.base_lin_velocities_local, self.base_ang_velocities_local, transform
)
base_quat = _mirror_quaternion(self.base_quat, transform.base_lin_sign)
return MotionData(
joint_positions=joint_positions,
joint_velocities=joint_velocities,
base_lin_velocities_mixed=base_lin_vel_mixed,
base_ang_velocities_mixed=base_ang_vel_mixed,
base_lin_velocities_local=base_lin_vel_local,
base_ang_velocities_local=base_ang_vel_local,
base_quat=base_quat,
device=self.device,
)


class AMPLoader:
"""
@@ -185,6 +243,8 @@ class AMPLoader:
simulation_dt: Timestep used by the simulator
slow_down_factor: Integer factor to slow down original data
expected_joint_names: (Optional) override for joint ordering
symmetry_spec: (Optional) symmetry description used to mirror motions and
augment the dataset on the fly.
"""

def __init__(
Expand All @@ -196,6 +256,7 @@ def __init__(
simulation_dt: float,
slow_down_factor: int,
expected_joint_names: Union[List[str], None] = None,
symmetry_spec: Optional[SymmetrySpec] = None,
) -> None:
self.device = device
if isinstance(dataset_path_root, str):
@@ -217,6 +278,12 @@ def __init__(

# Load and process each dataset into MotionData
self.motion_data: List[MotionData] = []
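# Resolve the symmetry specification (if any) into a concrete joint
# permutation / sign transform expressed in the expected joint order.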
self.symmetry_spec = symmetry_spec
self.symmetry_transform: Optional[SymmetryTransform] = None
if self.symmetry_spec is not None:
self.symmetry_transform = self.symmetry_spec.build_transform(
expected_joint_names
).to(self.device)
for dataset_name in dataset_names:
dataset_path = dataset_path_root / f"{dataset_name}.npy"
md = self.load_data(
@@ -226,9 +293,20 @@ def __init__(
expected_joint_names,
)
self.motion_data.append(md)
if self.symmetry_transform is not None:
mirrored_md = md.mirrored(self.symmetry_transform)
self.motion_data.append(mirrored_md)

# Normalize dataset-level sampling weights
weights = torch.tensor(dataset_weights, dtype=torch.float32, device=self.device)
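# Duplicate each dataset weight when a mirrored copy of the clip was added,
# so mirrored clips are sampled as often as their source clips.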
augmented_weights: List[float] = []
for weight in dataset_weights:
augmented_weights.append(weight)
if self.symmetry_transform is not None:
augmented_weights.append(weight)

weights = torch.tensor(
augmented_weights, dtype=torch.float32, device=self.device
)
self.dataset_weights = weights / weights.sum()

# Precompute flat buffers for fast sampling