diff --git a/amp_rsl_rl/runners/amp_on_policy_runner.py b/amp_rsl_rl/runners/amp_on_policy_runner.py
index 7810f90..b3ae519 100644
--- a/amp_rsl_rl/runners/amp_on_policy_runner.py
+++ b/amp_rsl_rl/runners/amp_on_policy_runner.py
@@ -161,6 +161,7 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
             delta_t,
             self.cfg["slow_down_factor"],
             amp_joint_names,
+            6.0,
         )  # self.env.unwrapped.scene["robot"].joint_names)
diff --git a/amp_rsl_rl/utils/motion_loader.py b/amp_rsl_rl/utils/motion_loader.py
index 739bd83..8b6765b 100644
--- a/amp_rsl_rl/utils/motion_loader.py
+++ b/amp_rsl_rl/utils/motion_loader.py
@@ -4,13 +4,14 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 from pathlib import Path
-from typing import List, Union, Tuple, Generator
+from typing import List, Union, Tuple, Generator, Optional
 from dataclasses import dataclass
 
 import torch
 import numpy as np
 from scipy.spatial.transform import Rotation, Slerp
 from scipy.interpolate import interp1d
+from scipy.signal import butter, filtfilt  # ← added for acausal low-pass filtering
 
 
 def download_amp_dataset_from_hf(
@@ -88,8 +89,15 @@ class MotionData:
 
     def __post_init__(self) -> None:
         # Convert numpy arrays (or SciPy Rotations) to torch tensors
+
         def to_tensor(x):
-            return torch.tensor(x, device=self.device, dtype=torch.float32)
+            # Ensure positive strides / contiguous memory before converting to torch
+            if not isinstance(x, np.ndarray):
+                x = np.array(x, dtype=np.float32)
+            # Make contiguous and float32 without copying if already OK
+            x = np.ascontiguousarray(x, dtype=np.float32)
+            return torch.from_numpy(x).to(self.device)
+
         if isinstance(self.joint_positions, np.ndarray):
             self.joint_positions = to_tensor(self.joint_positions)
@@ -185,6 +193,8 @@ class AMPLoader:
         simulation_dt: Timestep used by the simulator
         slow_down_factor: Integer factor to slow down original data
         expected_joint_names: (Optional) override for joint ordering
+        vel_filter_cutoff_hz: (Optional) if set, apply a 2nd-order acausal (zero-phase)
+            low-pass Butterworth filter with this cutoff (Hz) to joint and base velocities.
     """
 
     def __init__(
@@ -196,8 +206,10 @@ def __init__(
         simulation_dt: float,
         slow_down_factor: int,
         expected_joint_names: Union[List[str], None] = None,
+        vel_filter_cutoff_hz: Optional[float] = None,
     ) -> None:
         self.device = device
+        self.vel_filter_cutoff_hz = vel_filter_cutoff_hz
 
         if isinstance(dataset_path_root, str):
             dataset_path_root = Path(dataset_path_root)
@@ -302,10 +314,64 @@ def _compute_ang_vel(
         return np.vstack((rotvec, rotvec[-1]))
 
-    def _compute_raw_derivative(self, data: np.ndarray, dt: float) -> np.ndarray:
-        d = (data[1:] - data[:-1]) / dt
+    def _compute_raw_derivative(
+        self, data: np.ndarray, dt: float, angular: bool = False
+    ) -> np.ndarray:
+        """
+        Finite-difference derivative with optional angle-wrap handling.
+
+        If `angular` is True, the difference is computed as the minimal wrapped
+        angle between consecutive samples using atan2(sin Δ, cos Δ). This
+        properly handles jumps across ±π.
+        """
+        if data.shape[0] < 2:
+            return np.zeros_like(data)
+
+        if angular:
+            # Minimal angular difference in [-π, π] per element
+            delta = np.arctan2(
+                np.sin(data[1:] - data[:-1]),
+                np.cos(data[1:] - data[:-1]),
+            )
+        else:
+            delta = data[1:] - data[:-1]
+
+        d = delta / dt
         return np.vstack([d, d[-1:]])
 
+    def _lowpass_acausal(
+        self, data: np.ndarray, dt: float, cutoff_hz: Optional[float], order: int = 2
+    ) -> np.ndarray:
+        """
+        Zero-phase (acausal) low-pass filter using Butterworth + filtfilt.
+
+        Args:
+            data: array shaped (T, D) or (T,)
+            dt: sampling interval
+            cutoff_hz: cutoff frequency in Hz. If None or invalid, returns data unchanged.
+            order: filter order (default 2)
+
+        Returns:
+            Filtered data with same shape.
+        """
+        if cutoff_hz is None or cutoff_hz <= 0:
+            return data
+        fs = 1.0 / dt
+        nyq = 0.5 * fs
+        Wn = cutoff_hz / nyq
+        # If cutoff is above Nyquist, skip filtering to avoid warnings/instability
+        if Wn >= 1.0:
+            return data
+
+        b, a = butter(N=order, Wn=Wn, btype="low", analog=False)
+        # Ensure 2D for consistent axis handling
+        data_2d = data if data.ndim == 2 else data[:, None]
+        # filtfilt is along axis=0 (time)
+        filtered = filtfilt(b, a, data_2d, axis=0, method="pad")
+        filtered = np.ascontiguousarray(filtered)  # ensure positive strides
+        return filtered if data.ndim == 2 else filtered[:, 0]
+
+
     def load_data(
         self,
         dataset_path: Path,
@@ -346,8 +412,9 @@ def load_data(
         t_new = np.linspace(0, T * dt, T_new)
 
         resampled_joint_positions = self._resample_data_Rn(jp_list, t_orig, t_new)
+        # ── Joint velocities with wrapped-angle difference to handle ±π jumps ──
         resampled_joint_velocities = self._compute_raw_derivative(
-            resampled_joint_positions, simulation_dt
+            resampled_joint_positions, simulation_dt, angular=True
        )
 
         resampled_base_positions = self._resample_data_Rn(
@@ -358,7 +425,7 @@
         )
 
         resampled_base_lin_vel_mixed = self._compute_raw_derivative(
-            resampled_base_positions, simulation_dt
+            resampled_base_positions, simulation_dt, angular=False
         )
 
         resampled_base_ang_vel_mixed = self._compute_ang_vel(
@@ -377,6 +444,25 @@
             resampled_base_orientations, simulation_dt, local=True
         )
 
+        # ── Optional 2nd-order acausal low-pass filtering of velocities ──
+        if self.vel_filter_cutoff_hz is not None:
+            c = self.vel_filter_cutoff_hz
+            resampled_joint_velocities = self._lowpass_acausal(
+                resampled_joint_velocities, simulation_dt, c, order=2
+            )
+            resampled_base_lin_vel_mixed = self._lowpass_acausal(
+                resampled_base_lin_vel_mixed, simulation_dt, c, order=2
+            )
+            resampled_base_ang_vel_mixed = self._lowpass_acausal(
+                resampled_base_ang_vel_mixed, simulation_dt, c, order=2
+            )
+            resampled_base_lin_vel_local = self._lowpass_acausal(
+                resampled_base_lin_vel_local, simulation_dt, c, order=2
+            )
+            resampled_base_ang_vel_local = self._lowpass_acausal(
+                resampled_base_ang_vel_local, simulation_dt, c, order=2
+            )
+
         return MotionData(
             joint_positions=resampled_joint_positions,
             joint_velocities=resampled_joint_velocities,
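
The wrapped-angle difference is the most subtle part of this change, so here is a minimal standalone sketch (not repo code) of the `atan2(sin Δ, cos Δ)` trick that `_compute_raw_derivative(..., angular=True)` now applies; the timestep and the two joint samples are made-up illustration values:

```python
# Minimal standalone sketch (not repo code): why the wrapped difference matters.
# The timestep and the two joint samples below are assumed illustration values.
import numpy as np

dt = 0.02                                  # assumed simulation timestep [s]
theta = np.array([[3.10], [-3.10]])        # (T, D) angles crossing the ±π seam

naive = (theta[1:] - theta[:-1]) / dt      # ≈ -310 rad/s: spurious spike
delta = np.arctan2(np.sin(theta[1:] - theta[:-1]),
                   np.cos(theta[1:] - theta[:-1]))
wrapped = delta / dt                       # ≈ +4.2 rad/s: the actual motion

print(naive.ravel(), wrapped.ravel())
```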
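Similarly, a minimal sketch of the zero-phase filtering path in `_lowpass_acausal`, using the 6.0 Hz cutoff the runner now passes; the sampling rate, signal frequencies, and noise amplitude are assumed purely for illustration:

```python
# Minimal standalone sketch (not repo code) of the butter + filtfilt step.
# Sampling rate, signal frequencies, and noise amplitude are assumed values.
import numpy as np
from scipy.signal import butter, filtfilt

dt = 1.0 / 200.0                                    # assumed 200 Hz sampling
t = np.arange(0.0, 2.0, dt)
clean = np.sin(2 * np.pi * 1.0 * t)                 # slow "velocity" component
noisy = clean + 0.3 * np.sin(2 * np.pi * 40.0 * t)  # high-frequency noise

cutoff_hz = 6.0                                     # same cutoff the runner passes
Wn = cutoff_hz / (0.5 / dt)                         # normalize by the Nyquist rate
b, a = butter(N=2, Wn=Wn, btype="low")
filtered = filtfilt(b, a, noisy)                    # forward-backward: zero phase lag

print(float(np.max(np.abs(filtered - clean))))      # small residual, no time shift
```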