diff --git a/README.md b/README.md
index 411d385..545d0d2 100644
--- a/README.md
+++ b/README.md
@@ -62,7 +62,39 @@ amp_rsl_rl/
 
 ---
 
-## 📁 Dataset Structure
+## 🔄 Symmetry Augmentation
+
+AMP-RSL-RL now exposes the symmetry-aware data augmentation and mirror-loss hooks from
+[RSL-RL](https://github.com/leggedrobotics/rsl_rl). The implementation follows the design
+described in:
+
+> Mittal, M., Rudin, N., Klemm, V., Allshire, A., & Hutter, M. (2024).
+> *Symmetry Considerations for Learning Task Symmetric Robot Policies*. In IEEE International Conference on Robotics and Automation (ICRA).
+> https://doi.org/10.1109/ICRA57147.2024.10611493
+
+Symmetry augmentation can be enabled through the `symmetry_cfg` section of the algorithm
+configuration, providing both minibatch augmentation and optional mirror-loss regularisation
+for the policy update. AMP-specific components (the discriminator and expert/policy motion
+buffers) are augmented using the same configuration so that style rewards and adversarial
+training remain consistent with their symmetric counterparts.
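+
+A minimal configuration is sketched below (the module path
+`my_project.symmetry:mirror_obs_actions` is a placeholder for a user-defined mirroring
+function; the keys follow the RSL-RL `symmetry_cfg` convention):
+
+```python
+symmetry_cfg = {
+    # a callable, or a "module:function" string resolved with rsl_rl's string_to_callable
+    "data_augmentation_func": "my_project.symmetry:mirror_obs_actions",
+    "use_data_augmentation": True,  # mirror PPO minibatches and AMP transitions
+    "use_mirror_loss": False,       # optionally regularise the policy toward symmetry
+    "mirror_loss_coeff": 0.5,       # weight of the mirror loss when enabled
+}
+```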
+
+---
+
+## 📁 Dataset Structure
 
 The AMP-RSL-RL framework expects motion capture datasets in `.npy` format. Each
 `.npy` file must contain a Python dictionary with the following keys:
diff --git a/amp_rsl_rl/algorithms/amp_ppo.py b/amp_rsl_rl/algorithms/amp_ppo.py
index c67e04c..692789c 100644
--- a/amp_rsl_rl/algorithms/amp_ppo.py
+++ b/amp_rsl_rl/algorithms/amp_ppo.py
@@ -6,6 +6,7 @@
 from __future__ import annotations
 
+import inspect
 from typing import Optional, Tuple, Dict, Any
 
 import torch
@@ -14,6 +15,7 @@
 # External modules providing the actor-critic model, storage utilities, and AMP components.
 from rsl_rl.modules import ActorCritic
+from rsl_rl.utils import string_to_callable
 from rsl_rl.storage import RolloutStorage
 
 from amp_rsl_rl.storage import ReplayBuffer
@@ -67,6 +69,10 @@ class AMP_PPO:
         Target KL divergence for the adaptive learning rate schedule.
     amp_replay_buffer_size : int, default=100000
         Maximum number of policy transitions stored in the replay buffer for AMP training.
+    normalize_advantage_per_mini_batch : bool, default=False
+        Whether to normalize advantages within each mini-batch (instead of over the entire rollout).
+    symmetry_cfg : dict | None, default=None
+        Configuration dictionary enabling symmetry-based data augmentation and mirror loss.
     device : str, default="cpu"
         The device (CPU or GPU) on which the models will be computed.
     """
@@ -93,6 +99,8 @@ def __init__(
         desired_kl: float = 0.01,
         amp_replay_buffer_size: int = 100000,
         use_smooth_ratio_clipping: bool = False,
+        normalize_advantage_per_mini_batch: bool = False,
+        symmetry_cfg: Optional[Dict[str, Any]] = None,
         device: str = "cpu",
     ) -> None:
         # Set device and learning hyperparameters
@@ -100,6 +108,9 @@
         self.desired_kl: float = desired_kl
         self.schedule: str = schedule
         self.learning_rate: float = learning_rate
+        self.normalize_advantage_per_mini_batch: bool = (
+            normalize_advantage_per_mini_batch
+        )
 
         # Set up the discriminator and move it to the appropriate device.
         self.discriminator: Discriminator = discriminator.to(self.device)
@@ -149,6 +160,30 @@ def __init__(
         self.use_clipped_value_loss: bool = use_clipped_value_loss
         self.use_smooth_ratio_clipping: bool = use_smooth_ratio_clipping
 
+        # Symmetry configuration for PPO and AMP augmentation
+        self.symmetry: Optional[Dict[str, Any]] = None
+        if symmetry_cfg is not None:
+            use_symmetry = symmetry_cfg.get("use_data_augmentation", False) or symmetry_cfg.get(
+                "use_mirror_loss", False
+            )
+            if not use_symmetry:
+                print(
+                    "Symmetry configuration provided but neither data augmentation nor mirror loss is enabled."
+                    " Symmetry utilities will only be available for logging."
+                )
+            aug_fn = symmetry_cfg.get("data_augmentation_func")
+            if isinstance(aug_fn, str):
+                symmetry_cfg["data_augmentation_func"] = string_to_callable(aug_fn)
+                aug_fn = symmetry_cfg.get("data_augmentation_func")
+            if aug_fn is not None and not callable(aug_fn):
+                raise ValueError(
+                    "Symmetry configuration exists but the function is not callable: "
+                    f"{aug_fn}"
+                )
+            if getattr(actor_critic, "is_recurrent", False):
+                raise ValueError("Symmetry augmentation is not supported for recurrent policies in AMP_PPO.")
+            self.symmetry = symmetry_cfg
+
     def init_storage(
         self,
         num_envs: int,
@@ -184,6 +219,68 @@
             device=self.device,
         )
 
+    def _augment_batch_size(self, original_size: int, augmented: Optional[torch.Tensor]) -> int:
+        """Compute the augmentation factor from the original and augmented batch sizes."""
+
+        if augmented is None or original_size == 0:
+            return 1
+        if augmented.shape[0] % original_size != 0:
+            raise ValueError(
+                "Symmetry augmentation function returned a batch size incompatible with the original size."
+                f" Original={original_size}, augmented={augmented.shape[0]}"
+            )
+        return augmented.shape[0] // original_size
+
+    def _repeat_along_batch(
+        self, tensor: Optional[torch.Tensor], num_aug: int
+    ) -> Optional[torch.Tensor]:
+        """Repeat a tensor along the first dimension to match the augmentation factor."""
+
+        if tensor is None or num_aug == 1:
+            return tensor
+        repeat_dims = [num_aug] + [1] * (tensor.dim() - 1)
+        return tensor.repeat(*repeat_dims)
+
+    def _apply_symmetry(
+        self,
+        *,
+        obs: Optional[torch.Tensor],
+        actions: Optional[torch.Tensor],
+        obs_type: Optional[str] = None,
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        """Apply the configured symmetry augmentation to observations/actions."""
+
+        if self.symmetry is None:
+            return obs, actions
+
+        aug_fn = self.symmetry.get("data_augmentation_func")
+        if aug_fn is None:
+            return obs, actions
+
+        signature = inspect.signature(aug_fn)
+        kwargs: Dict[str, Any] = {}
+        if "obs" in signature.parameters:
+            kwargs["obs"] = obs
+        if "actions" in signature.parameters:
+            kwargs["actions"] = actions
+        env_ref = self.symmetry.get("_env")
+        if "env" in signature.parameters:
+            kwargs["env"] = env_ref
+        if "cfg" in signature.parameters and env_ref is not None:
+            kwargs["cfg"] = getattr(env_ref, "cfg", env_ref)
+        if obs_type is not None and "obs_type" in signature.parameters:
+            kwargs["obs_type"] = obs_type
+
+        result = aug_fn(**kwargs)
+        if isinstance(result, tuple):
+            aug_obs, aug_actions = result
+        else:
+            aug_obs, aug_actions = result, None
+
+        aug_obs = aug_obs if aug_obs is not None else obs
+        aug_actions = aug_actions if aug_actions is not None else actions
+        return aug_obs, aug_actions
+
     def test_mode(self) -> None:
         """
         Sets the actor-critic model to evaluation mode.
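A note on the conventions these helpers rely on (annotation, not part of the patch): the
augmentation function is expected to stack the untouched minibatch first, i.e. to return
`torch.cat([obs, mirrored_obs], dim=0)`, so that slices like `[:original_batch_size]` recover
the original samples and `tensor.repeat(num_aug, 1, ...)` aligns repeated per-sample data with
its mirrored copies:

```python
import torch

obs = torch.randn(4, 8)                       # original minibatch, B = 4
aug_obs = torch.cat([obs, -obs], dim=0)       # toy "mirror": [original; mirrored], size 2B
num_aug = aug_obs.shape[0] // obs.shape[0]    # 2, as _augment_batch_size computes

adv = torch.randn(4, 1)
adv_rep = adv.repeat(num_aug, 1)              # [adv; adv] lines up with [obs; -obs]
assert torch.equal(adv_rep[:4], adv_rep[4:])  # mirrored samples reuse their advantages
```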
@@ -299,7 +396,7 @@ def compute_returns(self, last_critic_obs: torch.Tensor) -> None:
         last_values = self.actor_critic.evaluate(last_critic_obs).detach()
         self.storage.compute_returns(last_values, self.gamma, self.lam)
 
-    def update(self) -> Tuple[float, float, float, float, float, float, float, float]:
+    def update(self) -> Tuple[float, float, float, float, float, float, float, float, float, float]:
         """
         Performs a single update step for both the actor-critic (PPO) and the AMP discriminator.
         It iterates over mini-batches of data, computes surrogate, value, AMP and gradient penalty losses,
@@ -310,7 +407,8 @@
         tuple
             A tuple containing mean losses and statistics:
             (mean_value_loss, mean_surrogate_loss, mean_amp_loss, mean_grad_pen_loss,
-             mean_policy_pred, mean_expert_pred, mean_accuracy_policy, mean_accuracy_expert)
+             mean_policy_pred, mean_expert_pred, mean_accuracy_policy, mean_accuracy_expert,
+             mean_kl_divergence, mean_symmetry_loss)
         """
         # Initialize mean loss and accuracy statistics.
         mean_value_loss: float = 0.0
@@ -324,6 +422,7 @@
         mean_accuracy_policy_elem: float = 0.0
         mean_accuracy_expert_elem: float = 0.0
         mean_kl_divergence: float = 0.0
+        mean_symmetry_loss: float = 0.0
 
         # Create data generators for mini-batch sampling.
         if self.actor_critic.is_recurrent:
@@ -372,6 +471,48 @@
                 rnd_state_batch,
             ) = sample
 
+            original_batch_size = obs_batch.shape[0]
+
+            if self.normalize_advantage_per_mini_batch:
+                with torch.no_grad():
+                    advantages_batch = (advantages_batch - advantages_batch.mean()) / (
+                        advantages_batch.std() + 1e-8
+                    )
+
+            # Symmetry data augmentation for PPO inputs
+            num_aug = 1
+            if self.symmetry and self.symmetry.get("use_data_augmentation", False):
+                aug_obs, aug_actions = self._apply_symmetry(
+                    obs=obs_batch,
+                    actions=actions_batch,
+                    obs_type="policy",
+                )
+                num_aug = self._augment_batch_size(original_batch_size, aug_obs)
+                obs_batch = aug_obs
+                actions_batch = (
+                    aug_actions
+                    if aug_actions is not None
+                    else self._repeat_along_batch(actions_batch, num_aug)
+                )
+
+                aug_critic_obs, _ = self._apply_symmetry(
+                    obs=critic_obs_batch,
+                    actions=None,
+                    obs_type="critic",
+                )
+                critic_obs_batch = (
+                    aug_critic_obs
+                    if aug_critic_obs is not None
+                    else self._repeat_along_batch(critic_obs_batch, num_aug)
+                )
+
+                old_actions_log_prob_batch = self._repeat_along_batch(
+                    old_actions_log_prob_batch, num_aug
+                )
+                target_values_batch = self._repeat_along_batch(target_values_batch, num_aug)
+                advantages_batch = self._repeat_along_batch(advantages_batch, num_aug)
+                returns_batch = self._repeat_along_batch(returns_batch, num_aug)
+
             # Forward pass through the actor to get current policy outputs.
             self.actor_critic.act(
                 obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0]
@@ -382,9 +523,9 @@
             value_batch = self.actor_critic.evaluate(
                 critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1]
             )
-            mu_batch = self.actor_critic.action_mean
-            sigma_batch = self.actor_critic.action_std
-            entropy_batch = self.actor_critic.entropy
+            mu_batch = self.actor_critic.action_mean[:original_batch_size]
+            sigma_batch = self.actor_critic.action_std[:original_batch_size]
+            entropy_batch = self.actor_critic.entropy[:original_batch_size]
 
             # Adaptive learning rate adjustment based on KL divergence if schedule is "adaptive".
             if self.desired_kl is not None and self.schedule == "adaptive":
@@ -450,10 +591,54 @@
                 - self.entropy_coef * entropy_batch.mean()
             )
 
+            # Mirror loss (if enabled)
+            symmetry_loss_value = torch.zeros(1, device=self.device)
+            if self.symmetry:
+                if not self.symmetry.get("use_data_augmentation", False):
+                    sym_obs_batch, _ = self._apply_symmetry(
+                        obs=obs_batch[:original_batch_size],
+                        actions=None,
+                        obs_type="policy",
+                    )
+                else:
+                    sym_obs_batch = obs_batch
+
+                if sym_obs_batch is not None:
+                    with torch.no_grad():
+                        sym_obs_detached = sym_obs_batch.detach().clone()
+                    mean_actions_batch = self.actor_critic.act_inference(sym_obs_detached)
+                    action_mean_orig = mean_actions_batch[:original_batch_size]
+                    _, sym_actions = self._apply_symmetry(
+                        obs=None,
+                        actions=action_mean_orig,
+                        obs_type="policy",
+                    )
+                    if sym_actions is None:
+                        sym_actions = mean_actions_batch
+                    mse_loss = torch.nn.MSELoss()
+                    symmetry_loss_value = mse_loss(
+                        mean_actions_batch[original_batch_size:],
+                        sym_actions.detach()[original_batch_size:],
+                    )
+                    if self.symmetry.get("use_mirror_loss", False):
+                        coeff = self.symmetry.get("mirror_loss_coeff", 0.0)
+                        ppo_loss = ppo_loss + coeff * symmetry_loss_value
+                    else:
+                        symmetry_loss_value = symmetry_loss_value.detach()
+
             # Process AMP loss by unpacking policy and expert AMP samples.
             policy_state, policy_next_state = sample_amp_policy
             expert_state, expert_next_state = sample_amp_expert
 
+            if self.symmetry and self.symmetry.get("use_data_augmentation", False):
+                policy_state = self.discriminator.apply_symmetry(policy_state, obs_type="amp")
+                policy_next_state = self.discriminator.apply_symmetry(policy_next_state, obs_type="amp")
+                expert_state = self.discriminator.apply_symmetry(expert_state, obs_type="amp")
+                expert_next_state = self.discriminator.apply_symmetry(expert_next_state, obs_type="amp")
+
+            raw_policy_state = policy_state.detach()
+            raw_expert_state = expert_state.detach()
+
             # Normalize AMP observations if a normalizer is provided.
             if self.amp_normalizer is not None:
                 with torch.no_grad():
@@ -493,8 +678,8 @@
 
             # Update the normalizer with current policy and expert AMP observations.
             if self.amp_normalizer is not None:
-                self.amp_normalizer.update(policy_state)
-                self.amp_normalizer.update(expert_state)
+                self.amp_normalizer.update(raw_policy_state)
+                self.amp_normalizer.update(raw_expert_state)
 
             # Compute probabilities from the discriminator logits.
             policy_d_prob = torch.sigmoid(policy_d)
@@ -507,6 +692,7 @@
             mean_grad_pen_loss += grad_pen_loss.item()
             mean_policy_pred += policy_d_prob.mean().item()
             mean_expert_pred += expert_d_prob.mean().item()
+            mean_symmetry_loss += symmetry_loss_value.item()
 
             # Calculate the accuracy of the discriminator.
             mean_accuracy_policy += torch.sum(
@@ -531,6 +717,7 @@
         mean_accuracy_policy /= mean_accuracy_policy_elem
         mean_accuracy_expert /= mean_accuracy_expert_elem
         mean_kl_divergence /= num_updates
+        mean_symmetry_loss /= num_updates
 
         # Clear the storage for the next update cycle.
         self.storage.clear()
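For reference, the mirror-loss block above implements the symmetry regulariser of Mittal et al.,
with the mirrored target detached (written here with stop-gradient `sg`, state mirror `M_s`, and
action mirror `M_a`; this annotation is not part of the patch):

```math
\mathcal{L}_{\text{sym}} = \frac{1}{|B|} \sum_{s \in B} \left\| \pi_\theta\big(M_s(s)\big) - \operatorname{sg}\!\left[ M_a\big(\pi_\theta(s)\big) \right] \right\|_2^2
```

Gradient therefore flows only through the policy's mean action on mirrored observations; when
`use_mirror_loss` is false the term is fully detached, so it is tracked for logging without
affecting the policy update.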
 
@@ -545,4 +732,5 @@
             mean_accuracy_policy,
             mean_accuracy_expert,
             mean_kl_divergence,
+            mean_symmetry_loss,
         )
diff --git a/amp_rsl_rl/networks/discriminator.py b/amp_rsl_rl/networks/discriminator.py
index 6f558d2..e55185e 100644
--- a/amp_rsl_rl/networks/discriminator.py
+++ b/amp_rsl_rl/networks/discriminator.py
@@ -3,6 +3,11 @@
 #
 # SPDX-License-Identifier: BSD-3-Clause
 
+import inspect
+from typing import Any, Dict, Optional
+
+from rsl_rl.utils import string_to_callable
+
 import torch
 import torch.nn as nn
 from torch import autograd
@@ -31,6 +34,7 @@ def __init__(
         device: str = "cpu",
         loss_type: str = "BCEWithLogits",
         eta_wgan: float = 0.3,
+        symmetry_cfg: Optional[Dict[str, Any]] = None,
     ):
         super(Discriminator, self).__init__()
 
@@ -38,6 +42,7 @@
         self.input_dim = input_dim
         self.reward_scale = reward_scale
         self.reward_clamp_epsilon = reward_clamp_epsilon
+        self.symmetry_cfg = symmetry_cfg
 
         layers = []
         curr_in_dim = input_dim
@@ -63,6 +68,43 @@
                 f"Unsupported loss type: {self.loss_type}. Supported types are 'BCEWithLogits' and 'Wasserstein'."
             )
 
+        if self.symmetry_cfg is not None:
+            fn = self.symmetry_cfg.get("data_augmentation_func")
+            if isinstance(fn, str):
+                self.symmetry_cfg["data_augmentation_func"] = string_to_callable(fn)
+
+    def apply_symmetry(self, tensor: torch.Tensor, obs_type: str = "amp") -> torch.Tensor:
+        """Applies the configured symmetry augmentation to the provided tensor."""
+
+        if self.symmetry_cfg is None or not self.symmetry_cfg.get("use_data_augmentation", False):
+            return tensor
+
+        fn = self.symmetry_cfg.get("data_augmentation_func")
+        if fn is None:
+            return tensor
+
+        signature = inspect.signature(fn)
+        kwargs: Dict[str, Any] = {}
+        if "obs" in signature.parameters:
+            kwargs["obs"] = tensor
+        if "actions" in signature.parameters:
+            kwargs["actions"] = None
+        env_ref = self.symmetry_cfg.get("_env")
+        if "env" in signature.parameters:
+            kwargs["env"] = env_ref
+        if "cfg" in signature.parameters and env_ref is not None:
+            kwargs["cfg"] = getattr(env_ref, "cfg", env_ref)
+        if "obs_type" in signature.parameters:
+            kwargs["obs_type"] = obs_type
+
+        result = fn(**kwargs)
+        if isinstance(result, tuple):
+            augmented = result[0]
+        else:
+            augmented = result
+
+        return augmented if augmented is not None else tensor
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass through the discriminator.
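The same signature-based dispatch is used in `AMP_PPO._apply_symmetry`, the discriminator, and
the motion loader: only the keyword arguments a function declares (`obs`, `actions`, `env`,
`cfg`, `obs_type`) are passed to it. A sketch of a compatible function is shown below; the
negation is a toy stand-in for a real per-robot index permutation and sign flip, which would
normally depend on `obs_type` ("policy", "critic", or "amp"):

```python
from typing import Optional, Tuple

import torch


def mirror_obs_actions(
    obs: Optional[torch.Tensor] = None,
    actions: Optional[torch.Tensor] = None,
    env=None,
    obs_type: str = "policy",
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Return [original; mirrored] batches; None inputs pass through as None."""
    obs_aug = torch.cat([obs, -obs], dim=0) if obs is not None else None
    actions_aug = torch.cat([actions, -actions], dim=0) if actions is not None else None
    return obs_aug, actions_aug
```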
diff --git a/amp_rsl_rl/runners/amp_on_policy_runner.py b/amp_rsl_rl/runners/amp_on_policy_runner.py
index 7810f90..17a6f72 100644
--- a/amp_rsl_rl/runners/amp_on_policy_runner.py
+++ b/amp_rsl_rl/runners/amp_on_policy_runner.py
@@ -17,7 +17,12 @@
 import rsl_rl
 import rsl_rl.utils
 from rsl_rl.env import VecEnv
-from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, EmpiricalNormalization
+from rsl_rl.modules import (
+    ActorCritic,
+    ActorCriticRecurrent,
+    EmpiricalNormalization,
+)
+
 from rsl_rl.utils import store_code_state
 
 from amp_rsl_rl.utils import Normalizer
@@ -133,6 +138,11 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
         self.device = device
         self.env = env
 
+        # if using symmetry then pass the environment config object
+        if "symmetry_cfg" in self.alg_cfg and self.alg_cfg["symmetry_cfg"] is not None:
+            # this is used by the symmetry function for handling different observation terms
+            self.alg_cfg["symmetry_cfg"]["_env"] = env
+
         # Get the size of the observation space
         obs, extras = self.env.get_observations()
         num_obs = obs.shape[1]
@@ -161,6 +171,7 @@
             delta_t,
             self.cfg["slow_down_factor"],
             amp_joint_names,
+            symmetry_cfg=self.alg_cfg.get("symmetry_cfg"),
         )  # self.env.unwrapped.scene["robot"].joint_names)
 
@@ -173,6 +184,7 @@
             reward_scale=self.discriminator_cfg["reward_scale"],
             device=self.device,
             loss_type=self.discriminator_cfg["loss_type"],
+            symmetry_cfg=self.alg_cfg.get("symmetry_cfg"),
         ).to(self.device)
 
         # Initialize the PPO algorithm
@@ -404,6 +416,7 @@
             mean_accuracy_policy,
             mean_accuracy_expert,
             mean_kl_divergence,
+            mean_symmetry_loss,
         ) = self.alg.update()
         stop = time.time()
         learn_time = stop - start
@@ -479,6 +492,7 @@
         self.writer.add_scalar(
             "Loss/accuracy_expert", locs["mean_accuracy_expert"], locs["it"]
         )
+        self.writer.add_scalar("Loss/symmetry", locs["mean_symmetry_loss"], locs["it"])
         self.writer.add_scalar("Loss/learning_rate", self.alg.learning_rate, locs["it"])
 
         self.writer.add_scalar(
@@ -529,6 +543,7 @@
               'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
             f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
             f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
+            f"""{'Symmetry loss:':>{pad}} {locs['mean_symmetry_loss']:.4f}\n"""
             f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
             f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n"""
             f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n"""
@@ -543,6 +558,7 @@
               'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
             f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
             f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
+            f"""{'Symmetry loss:':>{pad}} {locs['mean_symmetry_loss']:.4f}\n"""
             f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
         )
             # f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
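Putting it together at the configuration level (a sketch assuming the usual rsl_rl layout in
which the runner reads `self.alg_cfg = train_cfg["algorithm"]`; the runner injects the
environment handle as `symmetry_cfg["_env"]` itself, so it is not set by hand):

```python
train_cfg = {
    # ... "runner", "policy", and "discriminator" sections as before ...
    "algorithm": {
        # ... existing AMP_PPO keyword arguments ...
        "normalize_advantage_per_mini_batch": True,
        "symmetry_cfg": {
            "data_augmentation_func": "my_project.symmetry:mirror_obs_actions",
            "use_data_augmentation": True,
            "use_mirror_loss": True,
            "mirror_loss_coeff": 0.5,
        },
    },
}
```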
diff --git a/amp_rsl_rl/utils/motion_loader.py b/amp_rsl_rl/utils/motion_loader.py
index 739bd83..8b658c2 100644
--- a/amp_rsl_rl/utils/motion_loader.py
+++ b/amp_rsl_rl/utils/motion_loader.py
@@ -3,9 +3,11 @@
 #
 # SPDX-License-Identifier: BSD-3-Clause
 
+import inspect
 from pathlib import Path
-from typing import List, Union, Tuple, Generator
+from typing import List, Union, Tuple, Generator, Optional, Dict, Any
 from dataclasses import dataclass
+from rsl_rl.utils import utils
 
 import torch
 import numpy as np
@@ -196,10 +198,16 @@ def __init__(
         simulation_dt: float,
         slow_down_factor: int,
         expected_joint_names: Union[List[str], None] = None,
+        symmetry_cfg: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.device = device
         if isinstance(dataset_path_root, str):
             dataset_path_root = Path(dataset_path_root)
+        self.symmetry_cfg = symmetry_cfg
+        if self.symmetry_cfg is not None:
+            fn = self.symmetry_cfg.get("data_augmentation_func")
+            if isinstance(fn, str):
+                self.symmetry_cfg["data_augmentation_func"] = utils.string_to_callable(fn)
 
         # ─── Build union of all joint names if not provided ───
         if expected_joint_names is None:
@@ -233,6 +241,7 @@
         # Precompute flat buffers for fast sampling
         obs_list, next_obs_list, reset_states = [], [], []
+        augmented_lengths: List[int] = []
         for data, w in zip(self.motion_data, self.dataset_weights):
             T = len(data)
             idx = torch.arange(T, device=self.device)
@@ -240,8 +249,13 @@
             next_idx = torch.clamp(idx + 1, max=T - 1)
             next_obs = data.get_amp_dataset_obs(next_idx)
 
+            if self.symmetry_cfg and self.symmetry_cfg.get("use_data_augmentation", False):
+                obs = self._apply_symmetry(obs, obs_type="amp")
+                next_obs = self._apply_symmetry(next_obs, obs_type="amp")
+
             obs_list.append(obs)
             next_obs_list.append(next_obs)
+            augmented_lengths.append(obs.shape[0])
 
             quat, jp, jv, blv, bav = data.get_state_for_reset(idx)
             reset_states.append(torch.cat([quat, jp, jv, blv, bav], dim=1))
@@ -251,7 +265,7 @@
         self.all_states = torch.cat(reset_states, dim=0)
 
         # Build per-frame sampling weights: weight_i / length_i
-        lengths = [len(d) for d in self.motion_data]
+        lengths = augmented_lengths
         per_frame = torch.cat(
             [
                 torch.full((L,), w / L, device=self.device)
@@ -260,6 +274,36 @@
         )
         self.per_frame_weights = per_frame / per_frame.sum()
 
+    def _apply_symmetry(self, tensor: torch.Tensor, obs_type: str) -> torch.Tensor:
+        if self.symmetry_cfg is None or not self.symmetry_cfg.get("use_data_augmentation", False):
+            return tensor
+
+        fn = self.symmetry_cfg.get("data_augmentation_func")
+        if fn is None:
+            return tensor
+
+        signature = inspect.signature(fn)
+        kwargs: Dict[str, Any] = {}
+        if "obs" in signature.parameters:
+            kwargs["obs"] = tensor
+        if "actions" in signature.parameters:
+            kwargs["actions"] = None
+        env_ref = self.symmetry_cfg.get("_env")
+        if "env" in signature.parameters:
+            kwargs["env"] = env_ref
+        if "cfg" in signature.parameters and env_ref is not None:
+            kwargs["cfg"] = getattr(env_ref, "cfg", env_ref)
+        if "obs_type" in signature.parameters:
+            kwargs["obs_type"] = obs_type
+
+        result = fn(**kwargs)
+        if isinstance(result, tuple):
+            augmented = result[0]
+        else:
+            augmented = result
+
+        return augmented if augmented is not None else tensor
+
     def _resample_data_Rn(
         self,
         data: List[np.ndarray],
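Finally, a note on the loader change above (annotation, not part of the patch): because expert
clips are mirrored once at load time, each clip's entry in `augmented_lengths` grows by the
augmentation factor, and dividing the clip weight by the augmented length keeps the clip's
total sampling mass unchanged:

```python
# toy numbers: a 100-frame clip with dataset weight 0.5, mirrored once (factor 2)
T, w, num_aug = 100, 0.5, 2
L = T * num_aug               # this clip's entry in augmented_lengths
per_frame = w / L             # weight of each (possibly mirrored) frame
assert abs(per_frame * L - w) < 1e-12  # clip-level mass is preserved
```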