diff --git a/README.md b/README.md index 411d385..bee9fee 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,38 @@ For a ready-to-use motion capture dataset, you can use the [AMP Dataset on Huggi --- +## ♻️ Symmetry Augmentation + +You can mirror demonstrations and policy rollouts by declaring the robot's kinematic symmetry in the training configuration. Provide a `symmetry_cfg` dictionary when building the runner: + +```python +symmetry_cfg = { + "joint_pairs": [ + ["left_hip_yaw", "right_hip_yaw"], + ["left_knee", "right_knee"], + # ... add all mirrored joint names here + ], + "center_joints": ["torso_yaw"], + "joint_sign_overrides": { + "left_hip_roll": -1.0, + "right_hip_roll": -1.0, + }, + "base_linear_sign": [1.0, -1.0, 1.0], + "base_angular_sign": [-1.0, 1.0, -1.0], +} +``` + +Key fields: + +- **`joint_pairs`** – pairs of mirrored joint names, given as (left, right). +- **`center_joints`** – joints lying on the mirror plane (they keep their index when mirrored). +- **`joint_sign_overrides`** – optional per-joint multipliers (±1) for axes that reverse direction (e.g., roll joints). +- **`base_linear_sign` / `base_angular_sign`** – per-axis sign flips for the base linear / angular velocities when reflected across the sagittal plane. Angular velocity is a pseudovector, so the roll and yaw rates change sign while the pitch rate is preserved. + +When supplied, the dataset loader automatically augments motion clips with mirrored copies. The PPO agent mirrors policy-generated AMP observations as well, so the discriminator receives both original and reflected trajectories without extra code. 
+ +--- + ## 🧑‍💻 Authors - **Giulio Romualdi** – [@GiulioRomualdi](https://github.com/GiulioRomualdi) diff --git a/amp_rsl_rl/algorithms/amp_ppo.py b/amp_rsl_rl/algorithms/amp_ppo.py index c67e04c..2c9a328 100644 --- a/amp_rsl_rl/algorithms/amp_ppo.py +++ b/amp_rsl_rl/algorithms/amp_ppo.py @@ -18,7 +18,7 @@ from amp_rsl_rl.storage import ReplayBuffer from amp_rsl_rl.networks import Discriminator -from amp_rsl_rl.utils import AMPLoader +from amp_rsl_rl.utils import AMPLoader, SymmetryTransform, mirror_amp_transition class AMP_PPO: @@ -69,6 +69,10 @@ class AMP_PPO: Maximum number of policy transitions stored in the replay buffer for AMP training. device : str, default="cpu" The device (CPU or GPU) on which the models will be computed. + symmetry_transform : Optional[SymmetryTransform] + Optional symmetry description used to mirror policy AMP observations for + data augmentation. When provided, every policy transition is mirrored and + stored alongside the original one. """ actor_critic: ActorCritic @@ -94,6 +98,7 @@ def __init__( amp_replay_buffer_size: int = 100000, use_smooth_ratio_clipping: bool = False, device: str = "cpu", + symmetry_transform: Optional[SymmetryTransform] = None, ) -> None: # Set device and learning hyperparameters self.device: str = device @@ -112,6 +117,9 @@ def __init__( ) self.amp_data: AMPLoader = amp_data self.amp_normalizer: Optional[Any] = amp_normalizer + self.symmetry_transform: Optional[SymmetryTransform] = ( + symmetry_transform.to(device) if symmetry_transform is not None else None + ) # Set up the actor-critic (policy) and move it to the device. self.actor_critic = actor_critic @@ -285,6 +293,13 @@ def process_amp_step(self, amp_obs: torch.Tensor) -> None: The new AMP observation (from expert data or policy update). 
""" self.amp_storage.insert(self.amp_transition.observations, amp_obs) + if self.symmetry_transform is not None: + mirrored_state, mirrored_next = mirror_amp_transition( + self.amp_transition.observations, + amp_obs, + self.symmetry_transform, + ) + self.amp_storage.insert(mirrored_state, mirrored_next) self.amp_transition.clear() def compute_returns(self, last_critic_obs: torch.Tensor) -> None: diff --git a/amp_rsl_rl/runners/amp_on_policy_runner.py b/amp_rsl_rl/runners/amp_on_policy_runner.py index 7810f90..52b4868 100644 --- a/amp_rsl_rl/runners/amp_on_policy_runner.py +++ b/amp_rsl_rl/runners/amp_on_policy_runner.py @@ -21,7 +21,7 @@ from rsl_rl.utils import store_code_state from amp_rsl_rl.utils import Normalizer -from amp_rsl_rl.utils import AMPLoader +from amp_rsl_rl.utils import AMPLoader, SymmetrySpec from amp_rsl_rl.algorithms import AMP_PPO from amp_rsl_rl.networks import Discriminator, ActorCriticMoE from amp_rsl_rl.utils import export_policy_as_onnx @@ -147,7 +147,26 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"): ).to(self.device) ) # NOTE: to use this we need to configure the observations in the env coherently with amp observation. 
Tested with Manager Based envs in Isaaclab - amp_joint_names = self.env.cfg.observations.amp.joint_pos.params['asset_cfg'].joint_names + amp_joint_names = self.env.cfg.observations.amp.joint_pos.params[ + "asset_cfg" + ].joint_names + + symmetry_cfg = self.cfg.get("symmetry_cfg") + symmetry_spec = None + if symmetry_cfg: + joint_pairs = [tuple(pair) for pair in symmetry_cfg.get("joint_pairs", [])] + symmetry_spec = SymmetrySpec( + joint_pairs=joint_pairs, + center_joints=tuple(symmetry_cfg.get("center_joints", ())), + joint_sign_overrides=symmetry_cfg.get("joint_sign_overrides", {}), + base_linear_sign=tuple( + symmetry_cfg.get("base_linear_sign", (1.0, -1.0, 1.0)) + ), + base_angular_sign=tuple( + symmetry_cfg.get("base_angular_sign", (1.0, -1.0, -1.0)) + ), + allow_unmapped=symmetry_cfg.get("allow_unmapped", True), + ) delta_t = self.env.cfg.sim.dt * self.env.cfg.decimation @@ -161,6 +180,7 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"): delta_t, self.cfg["slow_down_factor"], amp_joint_names, + symmetry_spec=symmetry_spec, ) # self.env.unwrapped.scene["robot"].joint_names) @@ -168,7 +188,8 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"): # amp_data = AMPLoader(num_amp_obs, self.device) self.amp_normalizer = Normalizer(num_amp_obs, device=self.device) self.discriminator = Discriminator( - input_dim=num_amp_obs* 2, # the discriminator takes in the concatenation of the current and next observation + input_dim=num_amp_obs + * 2, # the discriminator takes in the concatenation of the current and next observation hidden_layer_sizes=self.discriminator_cfg["hidden_dims"], reward_scale=self.discriminator_cfg["reward_scale"], device=self.device, @@ -192,6 +213,7 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"): amp_data=amp_data, amp_normalizer=self.amp_normalizer, device=self.device, + symmetry_transform=amp_data.symmetry_transform, **self.alg_cfg, ) self.num_steps_per_env = 
self.cfg["num_steps_per_env"] diff --git a/amp_rsl_rl/utils/__init__.py b/amp_rsl_rl/utils/__init__.py index 0ede374..8215da5 100644 --- a/amp_rsl_rl/utils/__init__.py +++ b/amp_rsl_rl/utils/__init__.py @@ -9,6 +9,14 @@ from .utils import Normalizer, RunningMeanStd from .motion_loader import AMPLoader, download_amp_dataset_from_hf from .exporter import export_policy_as_onnx +from .symmetry import ( + SymmetrySpec, + SymmetryTransform, + apply_base_symmetry, + apply_joint_symmetry, + mirror_amp_observation, + mirror_amp_transition, +) __all__ = [ "Normalizer", @@ -16,4 +24,10 @@ "AMPLoader", "download_amp_dataset_from_hf", "export_policy_as_onnx", + "SymmetrySpec", + "SymmetryTransform", + "apply_joint_symmetry", + "apply_base_symmetry", + "mirror_amp_observation", + "mirror_amp_transition", ] diff --git a/amp_rsl_rl/utils/motion_loader.py b/amp_rsl_rl/utils/motion_loader.py index 739bd83..c3b9c3a 100644 --- a/amp_rsl_rl/utils/motion_loader.py +++ b/amp_rsl_rl/utils/motion_loader.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause from pathlib import Path -from typing import List, Union, Tuple, Generator +from typing import List, Union, Tuple, Generator, Optional from dataclasses import dataclass import torch @@ -13,6 +13,41 @@ from scipy.interpolate import interp1d +def _mirror_quaternion( + quat_wxyz: torch.Tensor, linear_sign: torch.Tensor +) -> torch.Tensor: + """Mirror orientation quaternions using a diagonal reflection defined by ``linear_sign``.""" + + if quat_wxyz.numel() == 0: + return quat_wxyz + + device = quat_wxyz.device + dtype = quat_wxyz.dtype + diag = linear_sign.detach().to(torch.float64).cpu().numpy() + + quat_xyzw = torch.cat((quat_wxyz[..., 1:], quat_wxyz[..., :1]), dim=-1) + quat_np = quat_xyzw.detach().to(torch.float64).cpu().numpy() + rot = Rotation.from_quat(quat_np) + + mat = rot.as_matrix() + left = diag.reshape(1, 3, 1) + right = diag.reshape(1, 1, 3) + mirrored_mat = mat * left * right + + mirrored_rot = 
Rotation.from_matrix(mirrored_mat) + mirrored_xyzw = torch.tensor(mirrored_rot.as_quat(), dtype=dtype, device=device) + # Convert back to wxyz convention + return torch.cat((mirrored_xyzw[..., 3:], mirrored_xyzw[..., :3]), dim=-1) + + +from .symmetry import ( + SymmetrySpec, + SymmetryTransform, + apply_base_symmetry, + apply_joint_symmetry, +) + + def download_amp_dataset_from_hf( destination_dir: Path, robot_folder: str, @@ -157,6 +192,29 @@ def get_random_sample_for_reset(self, items: int = 1) -> Tuple[torch.Tensor, ... indices = torch.randint(0, len(self), (items,), device=self.device) return self.get_state_for_reset(indices) + def mirrored(self, transform: SymmetryTransform) -> "MotionData": + """Return a mirrored copy of the motion according to ``transform``.""" + + joint_positions = apply_joint_symmetry(self.joint_positions, transform) + joint_velocities = apply_joint_symmetry(self.joint_velocities, transform) + base_lin_vel_mixed = self.base_lin_velocities_mixed * transform.base_lin_sign + base_ang_vel_mixed = self.base_ang_velocities_mixed * transform.base_ang_sign + + base_lin_vel_local, base_ang_vel_local = apply_base_symmetry( + self.base_lin_velocities_local, self.base_ang_velocities_local, transform + ) + base_quat = _mirror_quaternion(self.base_quat, transform.base_lin_sign) + return MotionData( + joint_positions=joint_positions, + joint_velocities=joint_velocities, + base_lin_velocities_mixed=base_lin_vel_mixed, + base_ang_velocities_mixed=base_ang_vel_mixed, + base_lin_velocities_local=base_lin_vel_local, + base_ang_velocities_local=base_ang_vel_local, + base_quat=base_quat, + device=self.device, + ) + class AMPLoader: """ @@ -185,6 +243,8 @@ class AMPLoader: simulation_dt: Timestep used by the simulator slow_down_factor: Integer factor to slow down original data expected_joint_names: (Optional) override for joint ordering + symmetry_spec: (Optional) symmetry description used to mirror motions and + augment the dataset on the fly. 
""" def __init__( @@ -196,6 +256,7 @@ def __init__( simulation_dt: float, slow_down_factor: int, expected_joint_names: Union[List[str], None] = None, + symmetry_spec: Optional[SymmetrySpec] = None, ) -> None: self.device = device if isinstance(dataset_path_root, str): @@ -217,6 +278,12 @@ def __init__( # Load and process each dataset into MotionData self.motion_data: List[MotionData] = [] + self.symmetry_spec = symmetry_spec + self.symmetry_transform: Optional[SymmetryTransform] = None + if self.symmetry_spec is not None: + self.symmetry_transform = self.symmetry_spec.build_transform( + expected_joint_names + ).to(self.device) for dataset_name in dataset_names: dataset_path = dataset_path_root / f"{dataset_name}.npy" md = self.load_data( @@ -226,9 +293,20 @@ def __init__( expected_joint_names, ) self.motion_data.append(md) + if self.symmetry_transform is not None: + mirrored_md = md.mirrored(self.symmetry_transform) + self.motion_data.append(mirrored_md) # Normalize dataset-level sampling weights - weights = torch.tensor(dataset_weights, dtype=torch.float32, device=self.device) + augmented_weights: List[float] = [] + for weight in dataset_weights: + augmented_weights.append(weight) + if self.symmetry_transform is not None: + augmented_weights.append(weight) + + weights = torch.tensor( + augmented_weights, dtype=torch.float32, device=self.device + ) self.dataset_weights = weights / weights.sum() # Precompute flat buffers for fast sampling diff --git a/amp_rsl_rl/utils/symmetry.py b/amp_rsl_rl/utils/symmetry.py new file mode 100644 index 0000000..5c1a647 --- /dev/null +++ b/amp_rsl_rl/utils/symmetry.py @@ -0,0 +1,307 @@ +# Copyright (c) 2025, Istituto Italiano di Tecnologia +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Utilities to describe and apply kinematic symmetries. + +The symmetry helpers exposed here allow users to describe how left/right (or any +other mirrored) joints of a robot relate to each other. 
@dataclass(frozen=True)
class SymmetryTransform:
    """Permutation and sign tensors used to mirror joint-space and base quantities.

    Attributes
    ----------
    joint_permutation:
        Long tensor of shape ``(num_joints,)``. Slot ``i`` of a mirrored vector
        is read from slot ``joint_permutation[i]`` of the original vector.
    joint_sign:
        Float tensor of shape ``(num_joints,)`` multiplied element-wise after the
        permutation. ``+1`` keeps the value, ``-1`` changes its sign.
    base_lin_sign:
        Float tensor of shape ``(3,)`` with sign flips for the (x, y, z) base
        linear velocity components.
    base_ang_sign:
        Float tensor of shape ``(3,)`` with sign flips for the (roll, pitch, yaw)
        base angular velocity components.
    """

    joint_permutation: torch.Tensor
    joint_sign: torch.Tensor
    base_lin_sign: torch.Tensor
    base_ang_sign: torch.Tensor

    def to(self, device: "torch.device | str") -> "SymmetryTransform":
        """Return a copy of the transform with every tensor moved to ``device``."""

        dev = torch.device(device)
        return SymmetryTransform(
            joint_permutation=self.joint_permutation.to(dev),
            joint_sign=self.joint_sign.to(dev),
            base_lin_sign=self.base_lin_sign.to(dev),
            base_ang_sign=self.base_ang_sign.to(dev),
        )


@dataclass(frozen=True)
class SymmetrySpec:
    """Declarative description of the robot's mirrored joints.

    Parameters
    ----------
    joint_pairs:
        Sequence of ``(left, right)`` tuples specifying pairs of mirrored joints.
        Each joint name must appear at most once across all tuples.
    center_joints:
        Optional sequence of joints that lie on the mirror plane (e.g., torso
        joints). These joints keep their index during mirroring and must not
        also appear in ``joint_pairs``. Defaults to an empty tuple.
    joint_sign_overrides:
        Optional mapping ``joint_name -> sign`` forcing a ``+1`` or ``-1``
        multiplier for the slot of that joint in the mirrored vector. Useful for
        revolute joints whose axes reverse direction when swapped between sides.
    base_linear_sign:
        Three multipliers for the mirrored base linear velocity components. A
        sagittal-plane mirror keeps ``v_x`` and ``v_z`` and flips ``v_y``
        (default: ``(1.0, -1.0, 1.0)``).
    base_angular_sign:
        Three multipliers for the mirrored base angular velocity components.
        Angular velocity is a pseudovector, so a sagittal-plane mirror flips the
        roll and yaw rates and keeps the pitch rate
        (default: ``(-1.0, 1.0, -1.0)``).
    allow_unmapped:
        When ``True`` (default), joints that belong neither to ``joint_pairs``
        nor to ``center_joints`` keep their original index, and configured names
        that are absent from the joint list are ignored. When ``False`` both
        situations raise ``ValueError``.
    """

    joint_pairs: Sequence[Tuple[str, str]]
    center_joints: Sequence[str] = field(default_factory=tuple)
    joint_sign_overrides: Mapping[str, float] = field(default_factory=dict)
    base_linear_sign: Tuple[float, float, float] = (1.0, -1.0, 1.0)
    base_angular_sign: Tuple[float, float, float] = (-1.0, 1.0, -1.0)
    allow_unmapped: bool = True

    def build_transform(self, joint_names: Sequence[str]) -> SymmetryTransform:
        """Create a :class:`SymmetryTransform` for ``joint_names``.

        Parameters
        ----------
        joint_names:
            Ordering of the joints that the tensors to be mirrored follow.

        Returns
        -------
        SymmetryTransform
            CPU permutation and sign tensors for vectors ordered like
            ``joint_names``.

        Raises
        ------
        ValueError
            If the specification is inconsistent: a pair with identical entries,
            a joint used in several pairs, duplicated center joints, a joint that
            is both paired and centered, a sign override different from ``±1``, a
            base sign tuple whose length is not 3, or (only when
            ``allow_unmapped`` is ``False``) names missing from ``joint_names``
            and joints left without a role.
        """

        for label, signs in (
            ("base_linear_sign", self.base_linear_sign),
            ("base_angular_sign", self.base_angular_sign),
        ):
            # A wrong length would otherwise broadcast silently later on.
            if len(signs) != 3:
                raise ValueError(
                    f"`{label}` must contain exactly 3 entries, got {len(signs)}."
                )

        joint_to_index = {name: i for i, name in enumerate(joint_names)}
        permutation = torch.arange(len(joint_names), dtype=torch.long)
        joint_sign = torch.ones(len(joint_names), dtype=torch.float32)

        paired = set()
        for left, right in self.joint_pairs:
            if left == right:
                raise ValueError(
                    f"Joint pair ({left}, {right}) cannot contain identical entries."
                )
            for name in (left, right):
                if name in paired:
                    raise ValueError(
                        f"Joint '{name}' appears in multiple symmetry pairs."
                    )
            paired.update((left, right))

            absent = next(
                (name for name in (left, right) if name not in joint_to_index), None
            )
            if absent is not None:
                if self.allow_unmapped:
                    # Skip the whole pair: the present side keeps its own index.
                    continue
                raise ValueError(
                    f"Joint '{absent}' from symmetry pairs missing in provided joint list."
                )

            left_idx = joint_to_index[left]
            right_idx = joint_to_index[right]
            permutation[left_idx] = right_idx
            permutation[right_idx] = left_idx

        center_joints = tuple(self.center_joints)
        center_set = set(center_joints)
        if len(center_set) != len(center_joints):
            raise ValueError("`center_joints` must not contain duplicates.")

        # A joint that is both swapped and pinned would leave its partner pointing
        # at it, so the result would no longer be a permutation.
        conflicting = center_set & paired
        if conflicting:
            raise ValueError(
                "Joints cannot be both paired and centered: "
                + ", ".join(sorted(conflicting))
            )

        if not self.allow_unmapped:
            for joint in center_joints:
                if joint not in joint_to_index:
                    raise ValueError(
                        f"Center joint '{joint}' missing in provided joint list."
                    )

        # Center joints are never touched by the pair loop (checked above), so
        # they already map to themselves in ``permutation``.

        for joint_name, sign in self.joint_sign_overrides.items():
            if abs(sign) != 1.0:
                raise ValueError(
                    f"Sign override for joint '{joint_name}' must be either +1 or -1, got {sign}."
                )
            if joint_name not in joint_to_index:
                if self.allow_unmapped:
                    continue
                raise ValueError(
                    f"Joint '{joint_name}' lacks an index in the provided joint list."
                )
            joint_sign[joint_to_index[joint_name]] = float(sign)

        if not self.allow_unmapped:
            unassigned = set(joint_names) - paired - center_set
            if unassigned:
                raise ValueError(
                    "The following joints are neither paired nor centered: "
                    + ", ".join(sorted(unassigned))
                )

        return SymmetryTransform(
            joint_permutation=permutation,
            joint_sign=joint_sign,
            base_lin_sign=torch.tensor(self.base_linear_sign, dtype=torch.float32),
            base_ang_sign=torch.tensor(self.base_angular_sign, dtype=torch.float32),
        )


def apply_joint_symmetry(
    joint_tensor: torch.Tensor,
    transform: SymmetryTransform,
) -> torch.Tensor:
    """Mirror ``joint_tensor`` along its last dimension using ``transform``.

    Parameters
    ----------
    joint_tensor:
        Tensor of shape ``(..., num_joints)`` following the joint ordering used
        to build ``transform``.
    transform:
        :class:`SymmetryTransform` built for that joint ordering; its tensors
        must live on the same device as ``joint_tensor``.

    Returns
    -------
    torch.Tensor
        Mirrored tensor with the same shape as ``joint_tensor``.

    Raises
    ------
    ValueError
        If the last dimension of ``joint_tensor`` differs from the number of
        joints described by ``transform``.
    """

    num_joints = transform.joint_permutation.numel()
    if joint_tensor.dim() == 0 or joint_tensor.shape[-1] != num_joints:
        # Without this check a larger tensor would be silently truncated.
        raise ValueError(
            f"Joint tensor last dimension must be {num_joints}, "
            f"got shape {tuple(joint_tensor.shape)}"
        )
    permuted = torch.index_select(
        joint_tensor, dim=-1, index=transform.joint_permutation
    )
    return permuted * transform.joint_sign


def apply_base_symmetry(
    base_linear: torch.Tensor,
    base_angular: torch.Tensor,
    transform: SymmetryTransform,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply the base velocity sign flips stored in ``transform``.

    Parameters
    ----------
    base_linear:
        Tensor with shape ``(..., 3)`` representing linear velocity components.
    base_angular:
        Tensor with shape ``(..., 3)`` representing angular velocity components.
    transform:
        Symmetry transform returned by :meth:`SymmetrySpec.build_transform`.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        Mirrored base linear and angular velocities, in that order.
    """

    return (
        base_linear * transform.base_lin_sign,
        base_angular * transform.base_ang_sign,
    )


def mirror_amp_observation(
    observation: torch.Tensor, transform: SymmetryTransform
) -> torch.Tensor:
    """Mirror an AMP observation vector.

    The last dimension of ``observation`` must be laid out as
    ``[joint_pos, joint_vel, base_lin_vel, base_ang_vel]``, i.e. two blocks of
    ``num_joints`` entries followed by two blocks of three entries.

    Parameters
    ----------
    observation:
        Tensor whose final dimension follows the layout above.
    transform:
        Symmetry transform returned by :meth:`SymmetrySpec.build_transform`.

    Returns
    -------
    torch.Tensor
        Mirrored observation with the same shape as the input.

    Raises
    ------
    ValueError
        If the last dimension is not ``2 * num_joints + 6``.
    """

    joint_dim = transform.joint_permutation.numel()
    expected_dim = 2 * joint_dim + 6
    if observation.shape[-1] != expected_dim:
        raise ValueError(
            f"AMP observation last dimension must be {expected_dim}, got {observation.shape[-1]}"
        )

    joint_pos, joint_vel, base_lin, base_ang = torch.split(
        observation, [joint_dim, joint_dim, 3, 3], dim=-1
    )
    mirrored_lin, mirrored_ang = apply_base_symmetry(base_lin, base_ang, transform)
    return torch.cat(
        (
            apply_joint_symmetry(joint_pos, transform),
            apply_joint_symmetry(joint_vel, transform),
            mirrored_lin,
            mirrored_ang,
        ),
        dim=-1,
    )


def mirror_amp_transition(
    state: torch.Tensor, next_state: torch.Tensor, transform: SymmetryTransform
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Mirror two consecutive AMP observations with the same ``transform``."""

    return (
        mirror_amp_observation(state, transform),
        mirror_amp_observation(next_state, transform),
    )