Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions amp_rsl_rl/runners/amp_on_policy_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,16 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
# Initilize all the ingredients required for AMP (discriminator, dataset loader)
num_amp_obs = extras["observations"]["amp"].shape[1]
amp_data = AMPLoader(
self.device,
self.cfg["amp_data_path"],
self.cfg["dataset_names"],
self.cfg["dataset_weights"],
delta_t,
self.cfg["slow_down_factor"],
actuated_joint_names,
device=self.device,
dataset_path_root=self.cfg.get("amp_data_path", None),
dataset_names=self.cfg["dataset_names"],
dataset_weights=self.cfg["dataset_weights"],
simulation_dt=delta_t,
slow_down_factor=self.cfg["slow_down_factor"],
expected_joint_names=actuated_joint_names,
download_from_hf=self.cfg.get("download_from_hf", False),
robot_folder=self.cfg.get("robot_folder", None),
repo_id=self.cfg.get("repo_id", None),
)

# self.env.unwrapped.scene["robot"].joint_names)
Expand Down
41 changes: 36 additions & 5 deletions amp_rsl_rl/utils/motion_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from pathlib import Path
from typing import List, Union, Tuple, Generator
from typing import List, Union, Tuple, Generator, Optional
from dataclasses import dataclass

import torch
import numpy as np
import platformdirs
from scipy.spatial.transform import Rotation, Slerp
from scipy.interpolate import interp1d

Expand Down Expand Up @@ -185,22 +185,53 @@ class AMPLoader:
simulation_dt: Timestep used by the simulator
slow_down_factor: Integer factor to slow down original data
expected_joint_names: (Optional) override for joint ordering
download_from_hf: (Optional) download datasets from Hugging Face
robot_folder: (Optional) folder in the Hugging Face dataset repo
repo_id: (Optional) Hugging Face repository ID
"""

def __init__(
self,
device: str,
dataset_path_root: Path,
dataset_names: List[str],
dataset_weights: List[float],
simulation_dt: float,
slow_down_factor: int,
expected_joint_names: Union[List[str], None] = None,
dataset_path_root: Optional[Union[str, Path]] = None,
expected_joint_names: Optional[List[str]] = None,
download_from_hf: Optional[bool] = False,
robot_folder: Optional[str] = "ergocub",
repo_id: Optional[str] = None,
) -> None:
self.device = device
if isinstance(dataset_path_root, str):

if download_from_hf:
print("Downloading datasets from Hugging Face.")
if dataset_path_root is None:
print(
"Warning: `dataset_path_root` is None. "
"A cache directory will be used."
)

# Create a cache directory for downloading
dataset_path_root = platformdirs.user_cache_path(
appname="amp-rsl-rl",
appauthor="ami-iit",
ensure_exists=True,
)
print(f"Cache directory: {dataset_path_root}")

# Convert dataset path to Path object for easier handling
dataset_path_root = Path(dataset_path_root)

# Download dataset from Hugging Face
dataset_names = download_amp_dataset_from_hf(
dataset_path_root,
robot_folder=robot_folder or "ergocub",
files=[f"{name}.npy" for name in dataset_names],
repo_id=repo_id or "ami-iit/amp-dataset",
)

# ─── Build union of all joint names if not provided ───
if expected_joint_names is None:
joint_union: List[str] = []
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ classifiers = [ # Added standard PyPI classifiers
keywords = ["reinforcement-learning", "robotics", "motion-priors", "ppo"]
dependencies = [
"numpy>=1.21.0",
"platformdirs>=4.0.0",
"scipy>=1.7.0",
"torch>=1.10.0",
"rsl-rl-lib>=2.3.0",
Expand Down