20 changes: 19 additions & 1 deletion README.md
@@ -62,7 +62,25 @@ amp_rsl_rl/

---

## 📁 Dataset Structure
## Symmetry Augmentation

AMP-RSL-RL now exposes the symmetry-aware data augmentation and mirror-loss hooks from
[RSL-RL](https://github.com/leggedrobotics/rsl_rl). The implementation follows the design
described in:

> Mittal, M., Rudin, N., Klemm, V., Allshire, A., & Hutter, M. (2024).<br>
> *Symmetry Considerations for Learning Task Symmetric Robot Policies*. In IEEE International Conference on Robotics and Automation (ICRA).<br>
> https://doi.org/10.1109/ICRA57147.2024.10611493

Symmetry augmentation is enabled through the `symmetry_cfg` section of the algorithm
configuration, which provides both mini-batch data augmentation and an optional mirror-loss
regularisation term for the policy update. The AMP-specific components (the discriminator and the
expert/policy motion buffers) are augmented with the same configuration, so that style rewards and
adversarial training remain consistent with their symmetric counterparts.
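
A minimal, illustrative configuration sketch is shown below. The permutation/sign tables and the
`mirror_policy_data` helper are placeholders invented for this example (they are not shipped with
the library) and must be replaced with the mirroring rules of the actual robot; a complete
implementation would typically also branch on `obs_type` to handle policy, critic, and AMP
observations separately.

```python
import torch

# Illustrative left/right mirroring tables for a toy 4-D observation / 2-D action layout.
# Replace these with the index permutations and sign flips of the real robot.
OBS_PERM = torch.tensor([1, 0, 3, 2])
OBS_SIGN = torch.tensor([1.0, 1.0, -1.0, -1.0])
ACT_PERM = torch.tensor([1, 0])
ACT_SIGN = torch.tensor([1.0, 1.0])


def mirror_policy_data(obs=None, actions=None, env=None, obs_type="policy"):
    """Concatenate each batch with its mirrored copy (the original samples come first)."""
    aug_obs, aug_actions = None, None
    if obs is not None:
        mirrored_obs = obs[..., OBS_PERM.to(obs.device)] * OBS_SIGN.to(obs.device)
        aug_obs = torch.cat([obs, mirrored_obs], dim=0)
    if actions is not None:
        mirrored_act = actions[..., ACT_PERM.to(actions.device)] * ACT_SIGN.to(actions.device)
        aug_actions = torch.cat([actions, mirrored_act], dim=0)
    return aug_obs, aug_actions


symmetry_cfg = {
    "use_data_augmentation": True,   # mirror every PPO and AMP mini-batch
    "use_mirror_loss": False,        # set True to add an explicit mirror-loss term to the PPO loss
    "mirror_loss_coeff": 0.5,        # only used when use_mirror_loss is True
    "data_augmentation_func": mirror_policy_data,  # an import string is also accepted
}

# The dictionary is then passed to the algorithm constructor, e.g.
# AMP_PPO(..., symmetry_cfg=symmetry_cfg, device="cuda")
```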

---

## 📁 Dataset Structure

The AMP-RSL-RL framework expects motion capture datasets in `.npy` format. Each `.npy` file must contain a Python dictionary with the following keys:

202 changes: 195 additions & 7 deletions amp_rsl_rl/algorithms/amp_ppo.py
@@ -6,6 +6,7 @@

from __future__ import annotations

import inspect
from typing import Optional, Tuple, Dict, Any

import torch
@@ -14,6 +15,7 @@

# External modules providing the actor-critic model, storage utilities, and AMP components.
from rsl_rl.modules import ActorCritic
from rsl_rl.utils import string_to_callable
from rsl_rl.storage import RolloutStorage

from amp_rsl_rl.storage import ReplayBuffer
@@ -67,6 +69,10 @@ class AMP_PPO:
Target KL divergence for the adaptive learning rate schedule.
amp_replay_buffer_size : int, default=100000
Maximum number of policy transitions stored in the replay buffer for AMP training.
normalize_advantage_per_mini_batch : bool, default=False
Whether to normalize advantages within each mini-batch (instead of the entire rollout).
symmetry_cfg : dict | None, default=None
Configuration dictionary enabling symmetry-based data augmentation and mirror loss.
device : str, default="cpu"
The device (CPU or GPU) on which the models will be computed.
"""
@@ -93,13 +99,18 @@ def __init__(
desired_kl: float = 0.01,
amp_replay_buffer_size: int = 100000,
use_smooth_ratio_clipping: bool = False,
normalize_advantage_per_mini_batch: bool = False,
symmetry_cfg: Optional[Dict[str, Any]] = None,
device: str = "cpu",
) -> None:
# Set device and learning hyperparameters
self.device: str = device
self.desired_kl: float = desired_kl
self.schedule: str = schedule
self.learning_rate: float = learning_rate
self.normalize_advantage_per_mini_batch: bool = (
normalize_advantage_per_mini_batch
)

# Set up the discriminator and move it to the appropriate device.
self.discriminator: Discriminator = discriminator.to(self.device)
@@ -149,6 +160,30 @@ def __init__(
self.use_clipped_value_loss: bool = use_clipped_value_loss
self.use_smooth_ratio_clipping: bool = use_smooth_ratio_clipping

# Symmetry configuration for PPO and AMP augmentation
self.symmetry: Optional[Dict[str, Any]] = None
if symmetry_cfg is not None:
use_symmetry = symmetry_cfg.get("use_data_augmentation", False) or symmetry_cfg.get(
"use_mirror_loss", False
)
if not use_symmetry:
print(
"Symmetry configuration provided but neither data augmentation nor mirror loss are enabled."
" Symmetry utilities will only be available for logging."
)
aug_fn = symmetry_cfg.get("data_augmentation_func")
if isinstance(aug_fn, str):
symmetry_cfg["data_augmentation_func"] = string_to_callable(aug_fn)
aug_fn = symmetry_cfg.get("data_augmentation_func")
if aug_fn is not None and not callable(aug_fn):
raise ValueError(
"Symmetry configuration exists but the function is not callable: "
f"{aug_fn}"
)
if getattr(actor_critic, "is_recurrent", False):
raise ValueError("Symmetry augmentation is not supported for recurrent policies in AMP_PPO.")
self.symmetry = symmetry_cfg

def init_storage(
self,
num_envs: int,
@@ -184,6 +219,68 @@ def init_storage(
device=self.device,
)

def _augment_batch_size(self, original_size: int, augmented: Optional[torch.Tensor]) -> int:
"""Compute augmentation factor given the original and augmented batch sizes."""

if augmented is None or original_size == 0:
return 1
if augmented.shape[0] % original_size != 0:
raise ValueError(
"Symmetry augmentation function returned a batch size incompatible with the original size."
f" Original={original_size}, augmented={augmented.shape[0]}"
)
return augmented.shape[0] // original_size

def _repeat_along_batch(
self, tensor: Optional[torch.Tensor], num_aug: int
) -> Optional[torch.Tensor]:
"""Repeat a tensor along the first dimension to match augmentation factor."""

if tensor is None or num_aug == 1:
return tensor
repeat_dims = [num_aug] + [1] * (tensor.dim() - 1)
return tensor.repeat(*repeat_dims)

def _apply_symmetry(
self,
*,
obs: Optional[torch.Tensor],
actions: Optional[torch.Tensor],
obs_type: Optional[str] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Apply configured symmetry augmentation to observations/actions."""

if self.symmetry is None:
return obs, actions

aug_fn = self.symmetry.get("data_augmentation_func")
if aug_fn is None:
return obs, actions

signature = inspect.signature(aug_fn)
kwargs: Dict[str, Any] = {}
if "obs" in signature.parameters:
kwargs["obs"] = obs
if "actions" in signature.parameters:
kwargs["actions"] = actions
env_ref = self.symmetry.get("_env")
if "env" in signature.parameters:
kwargs["env"] = env_ref
if "cfg" in signature.parameters and env_ref is not None:
kwargs["cfg"] = getattr(env_ref, "cfg", env_ref)
if obs_type is not None and "obs_type" in signature.parameters:
kwargs["obs_type"] = obs_type

result = aug_fn(**kwargs)
if isinstance(result, tuple):
aug_obs, aug_actions = result
else:
aug_obs, aug_actions = result, None

aug_obs = aug_obs if aug_obs is not None else obs
aug_actions = aug_actions if aug_actions is not None else actions
return aug_obs, aug_actions

def test_mode(self) -> None:
"""
Sets the actor-critic model to evaluation mode.
@@ -299,7 +396,7 @@ def compute_returns(self, last_critic_obs: torch.Tensor) -> None:
last_values = self.actor_critic.evaluate(last_critic_obs).detach()
self.storage.compute_returns(last_values, self.gamma, self.lam)

def update(self) -> Tuple[float, float, float, float, float, float, float, float]:
def update(self) -> Tuple[float, float, float, float, float, float, float, float, float, float]:
"""
Performs a single update step for both the actor-critic (PPO) and the AMP discriminator.
It iterates over mini-batches of data, computes surrogate, value, AMP and gradient penalty losses,
@@ -310,7 +407,8 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
tuple
A tuple containing mean losses and statistics:
(mean_value_loss, mean_surrogate_loss, mean_amp_loss, mean_grad_pen_loss,
mean_policy_pred, mean_expert_pred, mean_accuracy_policy, mean_accuracy_expert)
mean_policy_pred, mean_expert_pred, mean_accuracy_policy, mean_accuracy_expert,
mean_kl_divergence, mean_symmetry_loss)
"""
# Initialize mean loss and accuracy statistics.
mean_value_loss: float = 0.0
@@ -324,6 +422,7 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
mean_accuracy_policy_elem: float = 0.0
mean_accuracy_expert_elem: float = 0.0
mean_kl_divergence: float = 0.0
mean_symmetry_loss: float = 0.0

# Create data generators for mini-batch sampling.
if self.actor_critic.is_recurrent:
@@ -372,6 +471,48 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
rnd_state_batch,
) = sample

original_batch_size = obs_batch.shape[0]

if self.normalize_advantage_per_mini_batch:
with torch.no_grad():
advantages_batch = (advantages_batch - advantages_batch.mean()) / (
advantages_batch.std() + 1e-8
)

# Symmetry data augmentation for PPO inputs
num_aug = 1
if self.symmetry and self.symmetry.get("use_data_augmentation", False):
aug_obs, aug_actions = self._apply_symmetry(
obs=obs_batch,
actions=actions_batch,
obs_type="policy",
)
num_aug = self._augment_batch_size(original_batch_size, aug_obs)
obs_batch = aug_obs
actions_batch = (
aug_actions
if aug_actions is not None
else self._repeat_along_batch(actions_batch, num_aug)
)

aug_critic_obs, _ = self._apply_symmetry(
obs=critic_obs_batch,
actions=None,
obs_type="critic",
)
critic_obs_batch = (
aug_critic_obs
if aug_critic_obs is not None
else self._repeat_along_batch(critic_obs_batch, num_aug)
)

old_actions_log_prob_batch = self._repeat_along_batch(
old_actions_log_prob_batch, num_aug
)
target_values_batch = self._repeat_along_batch(target_values_batch, num_aug)
advantages_batch = self._repeat_along_batch(advantages_batch, num_aug)
returns_batch = self._repeat_along_batch(returns_batch, num_aug)

# Forward pass through the actor to get current policy outputs.
self.actor_critic.act(
obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0]
@@ -382,9 +523,9 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
value_batch = self.actor_critic.evaluate(
critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1]
)
mu_batch = self.actor_critic.action_mean
sigma_batch = self.actor_critic.action_std
entropy_batch = self.actor_critic.entropy
mu_batch = self.actor_critic.action_mean[:original_batch_size]
sigma_batch = self.actor_critic.action_std[:original_batch_size]
entropy_batch = self.actor_critic.entropy[:original_batch_size]

# Adaptive learning rate adjustment based on KL divergence if schedule is "adaptive".
if self.desired_kl is not None and self.schedule == "adaptive":
@@ -450,10 +591,54 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
- self.entropy_coef * entropy_batch.mean()
)

# Mirror loss (if enabled)
symmetry_loss_value = torch.zeros(1, device=self.device)
if self.symmetry:
if not self.symmetry.get("use_data_augmentation", False):
sym_obs_batch, _ = self._apply_symmetry(
obs=obs_batch[:original_batch_size],
actions=None,
obs_type="policy",
)
else:
sym_obs_batch = obs_batch

# Skip the mirror loss when no mirrored samples were produced (e.g. the symmetry
# config is used only for logging and provides no augmentation function).
if sym_obs_batch is not None and sym_obs_batch.shape[0] > original_batch_size:
with torch.no_grad():
sym_obs_detached = sym_obs_batch.detach().clone()
mean_actions_batch = self.actor_critic.act_inference(sym_obs_detached)
action_mean_orig = mean_actions_batch[:original_batch_size]
_, sym_actions = self._apply_symmetry(
obs=None,
actions=action_mean_orig,
obs_type="policy",
)
if sym_actions is None:
sym_actions = mean_actions_batch
mse_loss = torch.nn.MSELoss()
symmetry_loss_value = mse_loss(
mean_actions_batch[original_batch_size:],
sym_actions.detach()[original_batch_size:],
)
if self.symmetry.get("use_mirror_loss", False):
coeff = self.symmetry.get("mirror_loss_coeff", 0.0)
ppo_loss = ppo_loss + coeff * symmetry_loss_value
else:
symmetry_loss_value = symmetry_loss_value.detach()

# Process AMP loss by unpacking policy and expert AMP samples.
policy_state, policy_next_state = sample_amp_policy
expert_state, expert_next_state = sample_amp_expert

if self.symmetry and self.symmetry.get("use_data_augmentation", False):
policy_state = self.discriminator.apply_symmetry(policy_state, obs_type="amp")
policy_next_state = self.discriminator.apply_symmetry(policy_next_state, obs_type="amp")
expert_state = self.discriminator.apply_symmetry(expert_state, obs_type="amp")
expert_next_state = self.discriminator.apply_symmetry(expert_next_state, obs_type="amp")

raw_policy_state = policy_state.detach()
raw_expert_state = expert_state.detach()

# Normalize AMP observations if a normalizer is provided.
if self.amp_normalizer is not None:
with torch.no_grad():
@@ -493,8 +678,8 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float

# Update the normalizer with current policy and expert AMP observations.
if self.amp_normalizer is not None:
self.amp_normalizer.update(policy_state)
self.amp_normalizer.update(expert_state)
self.amp_normalizer.update(raw_policy_state)
self.amp_normalizer.update(raw_expert_state)

# Compute probabilities from the discriminator logits.
policy_d_prob = torch.sigmoid(policy_d)
@@ -507,6 +692,7 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
mean_grad_pen_loss += grad_pen_loss.item()
mean_policy_pred += policy_d_prob.mean().item()
mean_expert_pred += expert_d_prob.mean().item()
mean_symmetry_loss += symmetry_loss_value.item()

# Calculate the accuracy of the discriminator.
mean_accuracy_policy += torch.sum(
@@ -531,6 +717,7 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
mean_accuracy_policy /= mean_accuracy_policy_elem
mean_accuracy_expert /= mean_accuracy_expert_elem
mean_kl_divergence /= num_updates
mean_symmetry_loss /= num_updates

# Clear the storage for the next update cycle.
self.storage.clear()
@@ -545,4 +732,5 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
mean_accuracy_policy,
mean_accuracy_expert,
mean_kl_divergence,
mean_symmetry_loss,
)