diff --git a/amp_rsl_rl/runners/amp_on_policy_runner.py b/amp_rsl_rl/runners/amp_on_policy_runner.py index c8d8350..3e7f351 100644 --- a/amp_rsl_rl/runners/amp_on_policy_runner.py +++ b/amp_rsl_rl/runners/amp_on_policy_runner.py @@ -154,13 +154,16 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"): # Initilize all the ingredients required for AMP (discriminator, dataset loader) num_amp_obs = extras["observations"]["amp"].shape[1] amp_data = AMPLoader( - self.device, - self.cfg["amp_data_path"], - self.cfg["dataset_names"], - self.cfg["dataset_weights"], - delta_t, - self.cfg["slow_down_factor"], - actuated_joint_names, + device=self.device, + dataset_path_root=self.cfg.get("amp_data_path", None), + dataset_names=self.cfg["dataset_names"], + dataset_weights=self.cfg["dataset_weights"], + simulation_dt=delta_t, + slow_down_factor=self.cfg["slow_down_factor"], + expected_joint_names=actuated_joint_names, + download_from_hf=self.cfg.get("download_from_hf", False), + robot_folder=self.cfg.get("robot_folder", None), + repo_id=self.cfg.get("repo_id", None), ) # self.env.unwrapped.scene["robot"].joint_names) diff --git a/amp_rsl_rl/utils/motion_loader.py b/amp_rsl_rl/utils/motion_loader.py index 00c7ae0..5116aee 100644 --- a/amp_rsl_rl/utils/motion_loader.py +++ b/amp_rsl_rl/utils/motion_loader.py @@ -2,13 +2,13 @@ # All rights reserved. # # SPDX-License-Identifier: BSD-3-Clause - from pathlib import Path -from typing import List, Union, Tuple, Generator +from typing import List, Union, Tuple, Generator, Optional from dataclasses import dataclass import torch import numpy as np +import platformdirs from scipy.spatial.transform import Rotation, Slerp from scipy.interpolate import interp1d @@ -185,22 +185,53 @@ class AMPLoader: simulation_dt: Timestep used by the simulator slow_down_factor: Integer factor to slow down original data expected_joint_names: (Optional) override for joint ordering + download_from_hf: (Optional) download datasets from Hugging Face + robot_folder: (Optional) folder in the Hugging Face dataset repo + repo_id: (Optional) Hugging Face repository ID """ def __init__( self, device: str, - dataset_path_root: Path, dataset_names: List[str], dataset_weights: List[float], simulation_dt: float, slow_down_factor: int, - expected_joint_names: Union[List[str], None] = None, + dataset_path_root: Optional[Union[str, Path]] = None, + expected_joint_names: Optional[List[str]] = None, + download_from_hf: Optional[bool] = False, + robot_folder: Optional[str] = "ergocub", + repo_id: Optional[str] = None, ) -> None: self.device = device - if isinstance(dataset_path_root, str): + + if download_from_hf: + print("Downloading datasets from Hugging Face.") + if dataset_path_root is None: + print( + "Warning: `dataset_path_root` is None. " + "A cache directory will be used." + ) + + # Create a cache directory for downloading + dataset_path_root = platformdirs.user_cache_path( + appname="amp-rsl-rl", + appauthor="ami-iit", + ensure_exists=True, + ) + print(f"Cache directory: {dataset_path_root}") + + # Convert dataset path to Path object for easier handling dataset_path_root = Path(dataset_path_root) + # Download dataset from Hugging Face + dataset_names = download_amp_dataset_from_hf( + dataset_path_root, + robot_folder=robot_folder or "ergocub", + files=[f"{name}.npy" for name in dataset_names], + repo_id=repo_id or "ami-iit/amp-dataset", + ) + # ─── Build union of all joint names if not provided ─── if expected_joint_names is None: joint_union: List[str] = [] diff --git a/pyproject.toml b/pyproject.toml index 0af7188..1977192 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ classifiers = [ # Added standard PyPI classifiers keywords = ["reinforcement-learning", "robotics", "motion-priors", "ppo"] dependencies = [ "numpy>=1.21.0", + "platformdirs>=4.0.0", "scipy>=1.7.0", "torch>=1.10.0", "rsl-rl-lib>=2.3.0",