diff --git a/README.md b/README.md
index 411d385..545d0d2 100644
--- a/README.md
+++ b/README.md
@@ -62,7 +62,25 @@ amp_rsl_rl/
---
-## 📁 Dataset Structure
+## Symmetry Augmentation
+
+AMP-RSL-RL now exposes the symmetry-aware data augmentation and mirror-loss hooks from
+[RSL-RL](https://github.com/leggedrobotics/rsl_rl). The implementation follows the design
+described in:
+
+> Mittal, M., Rudin, N., Klemm, V., Allshire, A., & Hutter, M. (2024).
+> *Symmetry Considerations for Learning Task Symmetric Robot Policies*. In IEEE International Conference on Robotics and Automation (ICRA).
+> https://doi.org/10.1109/ICRA57147.2024.10611493
+
+Symmetry augmentation is enabled through the `symmetry_cfg` section of the algorithm
+configuration and provides both mini-batch data augmentation and an optional mirror-loss
+regularisation term for the policy update. The AMP-specific components (the discriminator
+and the expert/policy motion buffers) are augmented with the same configuration so that
+style rewards and adversarial training remain consistent with their symmetric counterparts.
+
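+A minimal sketch of such a configuration (the keys match those read by `AMP_PPO`; the
+augmentation function path is a placeholder for your own task-specific mirror function):
+
+```python
+symmetry_cfg = {
+    # augment each mini-batch (policy, critic and AMP observations) with mirrored samples
+    "use_data_augmentation": True,
+    # add the mirror-loss regularisation term to the PPO objective
+    "use_mirror_loss": True,
+    "mirror_loss_coeff": 0.5,
+    # either a callable or an import string resolved via `rsl_rl.utils.string_to_callable`
+    "data_augmentation_func": "my_tasks.symmetry:mirror_obs_and_actions",
+}
+```
+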
+---
+
+## 📁 Dataset Structure
The AMP-RSL-RL framework expects motion capture datasets in `.npy` format. Each `.npy` file must contain a Python dictionary with the following keys:
diff --git a/amp_rsl_rl/algorithms/amp_ppo.py b/amp_rsl_rl/algorithms/amp_ppo.py
index c67e04c..692789c 100644
--- a/amp_rsl_rl/algorithms/amp_ppo.py
+++ b/amp_rsl_rl/algorithms/amp_ppo.py
@@ -6,6 +6,7 @@
from __future__ import annotations
+import inspect
from typing import Optional, Tuple, Dict, Any
import torch
@@ -14,6 +15,7 @@
# External modules providing the actor-critic model, storage utilities, and AMP components.
from rsl_rl.modules import ActorCritic
+from rsl_rl.utils import string_to_callable
from rsl_rl.storage import RolloutStorage
from amp_rsl_rl.storage import ReplayBuffer
@@ -67,6 +69,10 @@ class AMP_PPO:
Target KL divergence for the adaptive learning rate schedule.
amp_replay_buffer_size : int, default=100000
Maximum number of policy transitions stored in the replay buffer for AMP training.
+ normalize_advantage_per_mini_batch : bool, default=False
+ Whether to normalize advantages within each mini-batch (instead of the entire rollout).
+ symmetry_cfg : dict | None, default=None
+ Configuration dictionary enabling symmetry-based data augmentation and mirror loss.
device : str, default="cpu"
The device (CPU or GPU) on which the models will be computed.
"""
@@ -93,6 +99,8 @@ def __init__(
desired_kl: float = 0.01,
amp_replay_buffer_size: int = 100000,
use_smooth_ratio_clipping: bool = False,
+ normalize_advantage_per_mini_batch: bool = False,
+ symmetry_cfg: Optional[Dict[str, Any]] = None,
device: str = "cpu",
) -> None:
# Set device and learning hyperparameters
@@ -100,6 +108,9 @@ def __init__(
self.desired_kl: float = desired_kl
self.schedule: str = schedule
self.learning_rate: float = learning_rate
+ self.normalize_advantage_per_mini_batch: bool = (
+ normalize_advantage_per_mini_batch
+ )
# Set up the discriminator and move it to the appropriate device.
self.discriminator: Discriminator = discriminator.to(self.device)
@@ -149,6 +160,30 @@ def __init__(
self.use_clipped_value_loss: bool = use_clipped_value_loss
self.use_smooth_ratio_clipping: bool = use_smooth_ratio_clipping
+ # Symmetry configuration for PPO and AMP augmentation
+ self.symmetry: Optional[Dict[str, Any]] = None
+ if symmetry_cfg is not None:
+ use_symmetry = symmetry_cfg.get("use_data_augmentation", False) or symmetry_cfg.get(
+ "use_mirror_loss", False
+ )
+ if not use_symmetry:
+ print(
+ "Symmetry configuration provided but neither data augmentation nor mirror loss are enabled."
+ " Symmetry utilities will only be available for logging."
+ )
+ aug_fn = symmetry_cfg.get("data_augmentation_func")
+ if isinstance(aug_fn, str):
+ symmetry_cfg["data_augmentation_func"] = string_to_callable(aug_fn)
+ aug_fn = symmetry_cfg.get("data_augmentation_func")
+ if aug_fn is not None and not callable(aug_fn):
+ raise ValueError(
+ "Symmetry configuration exists but the function is not callable: "
+ f"{aug_fn}"
+ )
+ if getattr(actor_critic, "is_recurrent", False):
+ raise ValueError("Symmetry augmentation is not supported for recurrent policies in AMP_PPO.")
+ self.symmetry = symmetry_cfg
+
def init_storage(
self,
num_envs: int,
@@ -184,6 +219,68 @@ def init_storage(
device=self.device,
)
+ def _augment_batch_size(self, original_size: int, augmented: Optional[torch.Tensor]) -> int:
+ """Compute augmentation factor given the original and augmented batch sizes."""
+
+ if augmented is None or original_size == 0:
+ return 1
+ if augmented.shape[0] % original_size != 0:
+ raise ValueError(
+ "Symmetry augmentation function returned a batch size incompatible with the original size."
+ f" Original={original_size}, augmented={augmented.shape[0]}"
+ )
+ return augmented.shape[0] // original_size
+
+ def _repeat_along_batch(
+ self, tensor: Optional[torch.Tensor], num_aug: int
+ ) -> Optional[torch.Tensor]:
+ """Repeat a tensor along the first dimension to match augmentation factor."""
+
+ if tensor is None or num_aug == 1:
+ return tensor
+ repeat_dims = [num_aug] + [1] * (tensor.dim() - 1)
+ return tensor.repeat(*repeat_dims)
+
+ def _apply_symmetry(
+ self,
+ *,
+ obs: Optional[torch.Tensor],
+ actions: Optional[torch.Tensor],
+ obs_type: Optional[str] = None,
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+ """Apply configured symmetry augmentation to observations/actions."""
+
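+ # NOTE: the augmentation callable is user-provided. A hypothetical implementation
+ # (sketch only; `mirror_obs_idx`, `mirror_act_idx` and `act_signs` are placeholders
+ # encoding the robot-specific left/right mapping) could look like:
+ #
+ #     def mirror_fn(env=None, obs=None, actions=None, obs_type="policy"):
+ #         obs_aug = None if obs is None else torch.cat([obs, obs[:, mirror_obs_idx]], dim=0)
+ #         act_aug = None if actions is None else torch.cat(
+ #             [actions, actions[:, mirror_act_idx] * act_signs], dim=0
+ #         )
+ #         return obs_aug, act_aug
+ #
+ # The original samples must come first and the mirrored copies are appended along the
+ # batch dimension; this ordering is assumed by `_augment_batch_size` and by the
+ # mirror-loss slicing in `update()`.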
+ if self.symmetry is None:
+ return obs, actions
+
+ aug_fn = self.symmetry.get("data_augmentation_func")
+ if aug_fn is None:
+ return obs, actions
+
+ signature = inspect.signature(aug_fn)
+ kwargs: Dict[str, Any] = {}
+ if "obs" in signature.parameters:
+ kwargs["obs"] = obs
+ if "actions" in signature.parameters:
+ kwargs["actions"] = actions
+ env_ref = self.symmetry.get("_env")
+ if "env" in signature.parameters:
+ kwargs["env"] = env_ref
+ if "cfg" in signature.parameters and env_ref is not None:
+ kwargs["cfg"] = getattr(env_ref, "cfg", env_ref)
+ if obs_type is not None and "obs_type" in signature.parameters:
+ kwargs["obs_type"] = obs_type
+
+ result = aug_fn(**kwargs)
+ if isinstance(result, tuple):
+ aug_obs, aug_actions = result
+ else:
+ aug_obs, aug_actions = result, None
+
+ aug_obs = aug_obs if aug_obs is not None else obs
+ aug_actions = aug_actions if aug_actions is not None else actions
+ return aug_obs, aug_actions
+
def test_mode(self) -> None:
"""
Sets the actor-critic model to evaluation mode.
@@ -299,7 +396,7 @@ def compute_returns(self, last_critic_obs: torch.Tensor) -> None:
last_values = self.actor_critic.evaluate(last_critic_obs).detach()
self.storage.compute_returns(last_values, self.gamma, self.lam)
- def update(self) -> Tuple[float, float, float, float, float, float, float, float]:
+ def update(self) -> Tuple[float, float, float, float, float, float, float, float, float, float]:
"""
Performs a single update step for both the actor-critic (PPO) and the AMP discriminator.
It iterates over mini-batches of data, computes surrogate, value, AMP and gradient penalty losses,
@@ -310,7 +407,8 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
tuple
A tuple containing mean losses and statistics:
(mean_value_loss, mean_surrogate_loss, mean_amp_loss, mean_grad_pen_loss,
- mean_policy_pred, mean_expert_pred, mean_accuracy_policy, mean_accuracy_expert)
+ mean_policy_pred, mean_expert_pred, mean_accuracy_policy, mean_accuracy_expert,
+ mean_kl_divergence, mean_symmetry_loss)
"""
# Initialize mean loss and accuracy statistics.
mean_value_loss: float = 0.0
@@ -324,6 +422,7 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
mean_accuracy_policy_elem: float = 0.0
mean_accuracy_expert_elem: float = 0.0
mean_kl_divergence: float = 0.0
+ mean_symmetry_loss: float = 0.0
# Create data generators for mini-batch sampling.
if self.actor_critic.is_recurrent:
@@ -372,6 +471,48 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
rnd_state_batch,
) = sample
+ original_batch_size = obs_batch.shape[0]
+
+ if self.normalize_advantage_per_mini_batch:
+ with torch.no_grad():
+ advantages_batch = (advantages_batch - advantages_batch.mean()) / (
+ advantages_batch.std() + 1e-8
+ )
+
+ # Symmetry data augmentation for PPO inputs
+ num_aug = 1
+ if self.symmetry and self.symmetry.get("use_data_augmentation", False):
+ aug_obs, aug_actions = self._apply_symmetry(
+ obs=obs_batch,
+ actions=actions_batch,
+ obs_type="policy",
+ )
+ num_aug = self._augment_batch_size(original_batch_size, aug_obs)
+ obs_batch = aug_obs
+ actions_batch = (
+ aug_actions
+ if aug_actions is not None
+ else self._repeat_along_batch(actions_batch, num_aug)
+ )
+
+ aug_critic_obs, _ = self._apply_symmetry(
+ obs=critic_obs_batch,
+ actions=None,
+ obs_type="critic",
+ )
+ critic_obs_batch = (
+ aug_critic_obs
+ if aug_critic_obs is not None
+ else self._repeat_along_batch(critic_obs_batch, num_aug)
+ )
+
+ old_actions_log_prob_batch = self._repeat_along_batch(
+ old_actions_log_prob_batch, num_aug
+ )
+ target_values_batch = self._repeat_along_batch(target_values_batch, num_aug)
+ advantages_batch = self._repeat_along_batch(advantages_batch, num_aug)
+ returns_batch = self._repeat_along_batch(returns_batch, num_aug)
+
# Forward pass through the actor to get current policy outputs.
self.actor_critic.act(
obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0]
@@ -382,9 +523,9 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
value_batch = self.actor_critic.evaluate(
critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1]
)
- mu_batch = self.actor_critic.action_mean
- sigma_batch = self.actor_critic.action_std
- entropy_batch = self.actor_critic.entropy
+ mu_batch = self.actor_critic.action_mean[:original_batch_size]
+ sigma_batch = self.actor_critic.action_std[:original_batch_size]
+ entropy_batch = self.actor_critic.entropy[:original_batch_size]
# Adaptive learning rate adjustment based on KL divergence if schedule is "adaptive".
if self.desired_kl is not None and self.schedule == "adaptive":
@@ -450,10 +591,54 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
- self.entropy_coef * entropy_batch.mean()
)
+ # Mirror loss (if enabled)
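+ # The mirror loss encourages a symmetric policy:
+ #   L_sym = MSE( pi(mirror(s)), mirror(pi(s)) )
+ # i.e. the action predicted for a mirrored observation should match the mirrored
+ # action predicted for the original observation (the latter is detached as the target).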
+ symmetry_loss_value = torch.zeros(1, device=self.device)
+ if self.symmetry:
+ if not self.symmetry.get("use_data_augmentation", False):
+ sym_obs_batch, _ = self._apply_symmetry(
+ obs=obs_batch[:original_batch_size],
+ actions=None,
+ obs_type="policy",
+ )
+ else:
+ sym_obs_batch = obs_batch
+
+ if sym_obs_batch is not None:
+ with torch.no_grad():
+ sym_obs_detached = sym_obs_batch.detach().clone()
+ mean_actions_batch = self.actor_critic.act_inference(sym_obs_detached)
+ action_mean_orig = mean_actions_batch[:original_batch_size]
+ _, sym_actions = self._apply_symmetry(
+ obs=None,
+ actions=action_mean_orig,
+ obs_type="policy",
+ )
+ if sym_actions is None:
+ sym_actions = mean_actions_batch
+ mse_loss = torch.nn.MSELoss()
+ symmetry_loss_value = mse_loss(
+ mean_actions_batch[original_batch_size:],
+ sym_actions.detach()[original_batch_size:],
+ )
+ if self.symmetry.get("use_mirror_loss", False):
+ coeff = self.symmetry.get("mirror_loss_coeff", 0.0)
+ ppo_loss = ppo_loss + coeff * symmetry_loss_value
+ else:
+ symmetry_loss_value = symmetry_loss_value.detach()
+
# Process AMP loss by unpacking policy and expert AMP samples.
policy_state, policy_next_state = sample_amp_policy
expert_state, expert_next_state = sample_amp_expert
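+ # Mirror the AMP transitions as well, so that the discriminator (and therefore the
+ # style reward) is trained on the same symmetric distribution as the policy.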
+ if self.symmetry and self.symmetry.get("use_data_augmentation", False):
+ policy_state = self.discriminator.apply_symmetry(policy_state, obs_type="amp")
+ policy_next_state = self.discriminator.apply_symmetry(policy_next_state, obs_type="amp")
+ expert_state = self.discriminator.apply_symmetry(expert_state, obs_type="amp")
+ expert_next_state = self.discriminator.apply_symmetry(expert_next_state, obs_type="amp")
+
+ raw_policy_state = policy_state.detach()
+ raw_expert_state = expert_state.detach()
+
# Normalize AMP observations if a normalizer is provided.
if self.amp_normalizer is not None:
with torch.no_grad():
@@ -493,8 +678,8 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
# Update the normalizer with current policy and expert AMP observations.
if self.amp_normalizer is not None:
- self.amp_normalizer.update(policy_state)
- self.amp_normalizer.update(expert_state)
+ self.amp_normalizer.update(raw_policy_state)
+ self.amp_normalizer.update(raw_expert_state)
# Compute probabilities from the discriminator logits.
policy_d_prob = torch.sigmoid(policy_d)
@@ -507,6 +692,7 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
mean_grad_pen_loss += grad_pen_loss.item()
mean_policy_pred += policy_d_prob.mean().item()
mean_expert_pred += expert_d_prob.mean().item()
+ mean_symmetry_loss += symmetry_loss_value.item()
# Calculate the accuracy of the discriminator.
mean_accuracy_policy += torch.sum(
@@ -531,6 +717,7 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
mean_accuracy_policy /= mean_accuracy_policy_elem
mean_accuracy_expert /= mean_accuracy_expert_elem
mean_kl_divergence /= num_updates
+ mean_symmetry_loss /= num_updates
# Clear the storage for the next update cycle.
self.storage.clear()
@@ -545,4 +732,5 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
mean_accuracy_policy,
mean_accuracy_expert,
mean_kl_divergence,
+ mean_symmetry_loss,
)
diff --git a/amp_rsl_rl/networks/discriminator.py b/amp_rsl_rl/networks/discriminator.py
index 6f558d2..e55185e 100644
--- a/amp_rsl_rl/networks/discriminator.py
+++ b/amp_rsl_rl/networks/discriminator.py
@@ -3,6 +3,9 @@
#
# SPDX-License-Identifier: BSD-3-Clause
+import inspect
+from typing import Any, Dict, Optional
+
+from rsl_rl.utils import string_to_callable
+
import torch
import torch.nn as nn
from torch import autograd
@@ -31,6 +34,7 @@ def __init__(
device: str = "cpu",
loss_type: str = "BCEWithLogits",
eta_wgan: float = 0.3,
+ symmetry_cfg: Optional[Dict[str, Any]] = None,
):
super(Discriminator, self).__init__()
@@ -38,6 +42,7 @@ def __init__(
self.input_dim = input_dim
self.reward_scale = reward_scale
self.reward_clamp_epsilon = reward_clamp_epsilon
+ self.symmetry_cfg = symmetry_cfg
layers = []
curr_in_dim = input_dim
@@ -63,6 +68,43 @@ def __init__(
f"Unsupported loss type: {self.loss_type}. Supported types are 'BCEWithLogits' and 'Wasserstein'."
)
+ if self.symmetry_cfg is not None:
+ fn = self.symmetry_cfg.get("data_augmentation_func")
+ if isinstance(fn, str):
+ self.symmetry_cfg["data_augmentation_func"] = utils.string_to_callable(fn)
+
+ def apply_symmetry(self, tensor: torch.Tensor, obs_type: str = "amp") -> torch.Tensor:
+ """Applies the configured symmetry augmentation to the provided tensor."""
+
+ if self.symmetry_cfg is None or not self.symmetry_cfg.get("use_data_augmentation", False):
+ return tensor
+
+ fn = self.symmetry_cfg.get("data_augmentation_func")
+ if fn is None:
+ return tensor
+
+ signature = inspect.signature(fn)
+ kwargs: Dict[str, Any] = {}
+ if "obs" in signature.parameters:
+ kwargs["obs"] = tensor
+ if "actions" in signature.parameters:
+ kwargs["actions"] = None
+ env_ref = self.symmetry_cfg.get("_env")
+ if "env" in signature.parameters:
+ kwargs["env"] = env_ref
+ if "cfg" in signature.parameters and env_ref is not None:
+ kwargs["cfg"] = getattr(env_ref, "cfg", env_ref)
+ if "obs_type" in signature.parameters:
+ kwargs["obs_type"] = obs_type
+
+ result = fn(**kwargs)
+ if isinstance(result, tuple):
+ augmented = result[0]
+ else:
+ augmented = result
+
+ return augmented if augmented is not None else tensor
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through the discriminator.
diff --git a/amp_rsl_rl/runners/amp_on_policy_runner.py b/amp_rsl_rl/runners/amp_on_policy_runner.py
index 7810f90..17a6f72 100644
--- a/amp_rsl_rl/runners/amp_on_policy_runner.py
+++ b/amp_rsl_rl/runners/amp_on_policy_runner.py
@@ -17,7 +17,12 @@
import rsl_rl
import rsl_rl.utils
from rsl_rl.env import VecEnv
-from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, EmpiricalNormalization
+from rsl_rl.modules import (
+ ActorCritic,
+ ActorCriticRecurrent,
+ EmpiricalNormalization,
+)
+
from rsl_rl.utils import store_code_state
from amp_rsl_rl.utils import Normalizer
@@ -133,6 +138,11 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
self.device = device
self.env = env
+ # if symmetry is enabled, hand the environment object to the symmetry configuration
+ if "symmetry_cfg" in self.alg_cfg and self.alg_cfg["symmetry_cfg"] is not None:
+ # the augmentation function can use the environment (or its config) to resolve observation terms
+ self.alg_cfg["symmetry_cfg"]["_env"] = env
+
# Get the size of the observation space
obs, extras = self.env.get_observations()
num_obs = obs.shape[1]
@@ -161,6 +171,7 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
delta_t,
self.cfg["slow_down_factor"],
amp_joint_names,
+ symmetry_cfg=self.alg_cfg.get("symmetry_cfg"),
)
# self.env.unwrapped.scene["robot"].joint_names)
@@ -173,6 +184,7 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
reward_scale=self.discriminator_cfg["reward_scale"],
device=self.device,
loss_type=self.discriminator_cfg["loss_type"],
+ symmetry_cfg=self.alg_cfg.get("symmetry_cfg"),
).to(self.device)
# Initialize the PPO algorithm
@@ -404,6 +416,7 @@ def update_run_name_with_sequence(prefix: str) -> None:
mean_accuracy_policy,
mean_accuracy_expert,
mean_kl_divergence,
+ mean_symmetry_loss,
) = self.alg.update()
stop = time.time()
learn_time = stop - start
@@ -479,6 +492,7 @@ def log(self, locs: dict, width: int = 80, pad: int = 35):
self.writer.add_scalar(
"Loss/accuracy_expert", locs["mean_accuracy_expert"], locs["it"]
)
+ self.writer.add_scalar("Loss/symmetry", locs["mean_symmetry_loss"], locs["it"])
self.writer.add_scalar("Loss/learning_rate", self.alg.learning_rate, locs["it"])
self.writer.add_scalar(
@@ -529,6 +543,7 @@ def log(self, locs: dict, width: int = 80, pad: int = 35):
'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
+ f"""{'Symmetry loss:':>{pad}} {locs['mean_symmetry_loss']:.4f}\n"""
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n"""
f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n"""
@@ -543,6 +558,7 @@ def log(self, locs: dict, width: int = 80, pad: int = 35):
'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
+ f"""{'Symmetry loss:':>{pad}} {locs['mean_symmetry_loss']:.4f}\n"""
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
)
# f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
diff --git a/amp_rsl_rl/utils/motion_loader.py b/amp_rsl_rl/utils/motion_loader.py
index 739bd83..8b658c2 100644
--- a/amp_rsl_rl/utils/motion_loader.py
+++ b/amp_rsl_rl/utils/motion_loader.py
@@ -3,9 +3,11 @@
#
# SPDX-License-Identifier: BSD-3-Clause
+import inspect
from pathlib import Path
-from typing import List, Union, Tuple, Generator
+from typing import List, Union, Tuple, Generator, Optional, Dict, Any
from dataclasses import dataclass
+from rsl_rl.utils import string_to_callable
import torch
import numpy as np
@@ -196,10 +198,16 @@ def __init__(
simulation_dt: float,
slow_down_factor: int,
expected_joint_names: Union[List[str], None] = None,
+ symmetry_cfg: Optional[Dict[str, Any]] = None,
) -> None:
self.device = device
if isinstance(dataset_path_root, str):
dataset_path_root = Path(dataset_path_root)
+ self.symmetry_cfg = symmetry_cfg
+ if self.symmetry_cfg is not None:
+ fn = self.symmetry_cfg.get("data_augmentation_func")
+ if isinstance(fn, str):
+ self.symmetry_cfg["data_augmentation_func"] = utils.string_to_callable(fn)
# ─── Build union of all joint names if not provided ───
if expected_joint_names is None:
@@ -233,6 +241,7 @@ def __init__(
# Precompute flat buffers for fast sampling
obs_list, next_obs_list, reset_states = [], [], []
+ augmented_lengths: List[int] = []
for data, w in zip(self.motion_data, self.dataset_weights):
T = len(data)
idx = torch.arange(T, device=self.device)
@@ -240,8 +249,13 @@ def __init__(
next_idx = torch.clamp(idx + 1, max=T - 1)
next_obs = data.get_amp_dataset_obs(next_idx)
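+ # When enabled, append mirrored copies of the expert frames so the sampled expert
+ # transitions match the augmented policy transitions seen by the discriminator.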
+ if self.symmetry_cfg and self.symmetry_cfg.get("use_data_augmentation", False):
+ obs = self._apply_symmetry(obs, obs_type="amp")
+ next_obs = self._apply_symmetry(next_obs, obs_type="amp")
+
obs_list.append(obs)
next_obs_list.append(next_obs)
+ augmented_lengths.append(obs.shape[0])
quat, jp, jv, blv, bav = data.get_state_for_reset(idx)
reset_states.append(torch.cat([quat, jp, jv, blv, bav], dim=1))
@@ -251,7 +265,7 @@ def __init__(
self.all_states = torch.cat(reset_states, dim=0)
# Build per-frame sampling weights: weight_i / length_i
- lengths = [len(d) for d in self.motion_data]
+ lengths = list(augmented_lengths)
per_frame = torch.cat(
[
torch.full((L,), w / L, device=self.device)
@@ -260,6 +274,36 @@ def __init__(
)
self.per_frame_weights = per_frame / per_frame.sum()
+ def _apply_symmetry(self, tensor: torch.Tensor, obs_type: str) -> torch.Tensor:
+ if self.symmetry_cfg is None or not self.symmetry_cfg.get("use_data_augmentation", False):
+ return tensor
+
+ fn = self.symmetry_cfg.get("data_augmentation_func")
+ if fn is None:
+ return tensor
+
+ signature = inspect.signature(fn)
+ kwargs: Dict[str, Any] = {}
+ if "obs" in signature.parameters:
+ kwargs["obs"] = tensor
+ if "actions" in signature.parameters:
+ kwargs["actions"] = None
+ env_ref = self.symmetry_cfg.get("_env")
+ if "env" in signature.parameters:
+ kwargs["env"] = env_ref
+ if "cfg" in signature.parameters and env_ref is not None:
+ kwargs["cfg"] = getattr(env_ref, "cfg", env_ref)
+ if "obs_type" in signature.parameters:
+ kwargs["obs_type"] = obs_type
+
+ result = fn(**kwargs)
+ if isinstance(result, tuple):
+ augmented = result[0]
+ else:
+ augmented = result
+
+ return augmented if augmented is not None else tensor
+
def _resample_data_Rn(
self,
data: List[np.ndarray],