diff --git a/examples/dataset/use_dataset_tools.py b/examples/dataset/use_dataset_tools.py
index 2442598723..bd7c389bc4 100644
--- a/examples/dataset/use_dataset_tools.py
+++ b/examples/dataset/use_dataset_tools.py
@@ -30,9 +30,10 @@
 import numpy as np
 
 from lerobot.datasets.dataset_tools import (
-    add_feature,
+    add_features,
     delete_episodes,
     merge_datasets,
+    modify_features,
     remove_feature,
     split_dataset,
 )
@@ -57,50 +58,56 @@ def main():
     print(f"Train split: {splits['train'].meta.total_episodes} episodes")
     print(f"Val split: {splits['val'].meta.total_episodes} episodes")
 
-    print("\n3. Adding a reward feature...")
+    print("\n3. Adding features...")
     reward_values = np.random.randn(dataset.meta.total_frames).astype(np.float32)
-    dataset_with_reward = add_feature(
-        dataset,
-        feature_name="reward",
-        feature_values=reward_values,
-        feature_info={
-            "dtype": "float32",
-            "shape": (1,),
-            "names": None,
-        },
-        repo_id="lerobot/pusht_with_reward",
-    )
 
     def compute_success(row_dict, episode_index, frame_index):
         episode_length = 10
         return float(frame_index >= episode_length - 10)
 
-    dataset_with_success = add_feature(
-        dataset_with_reward,
-        feature_name="success",
-        feature_values=compute_success,
-        feature_info={
-            "dtype": "float32",
-            "shape": (1,),
-            "names": None,
+    dataset_with_features = add_features(
+        dataset,
+        features={
+            "reward": (
+                reward_values,
+                {"dtype": "float32", "shape": (1,), "names": None},
+            ),
+            "success": (
+                compute_success,
+                {"dtype": "float32", "shape": (1,), "names": None},
+            ),
         },
-        repo_id="lerobot/pusht_with_reward_and_success",
+        repo_id="lerobot/pusht_with_features",
     )
-    print(f"New features: {list(dataset_with_success.meta.features.keys())}")
+    print(f"New features: {list(dataset_with_features.meta.features.keys())}")
 
     print("\n4. Removing the success feature...")
     dataset_cleaned = remove_feature(
-        dataset_with_success, feature_names="success", repo_id="lerobot/pusht_cleaned"
+        dataset_with_features, feature_names="success", repo_id="lerobot/pusht_cleaned"
     )
     print(f"Features after removal: {list(dataset_cleaned.meta.features.keys())}")
 
-    print("\n5. Merging train and val splits back together...")
+    print("\n5. Using modify_features to add and remove features simultaneously...")
+    dataset_modified = modify_features(
+        dataset_with_features,
+        add_features={
+            "discount": (
+                np.ones(dataset.meta.total_frames, dtype=np.float32) * 0.99,
+                {"dtype": "float32", "shape": (1,), "names": None},
+            ),
+        },
+        remove_features="reward",
+        repo_id="lerobot/pusht_modified",
+    )
+    print(f"Modified features: {list(dataset_modified.meta.features.keys())}")
+
+    print("\n6. Merging train and val splits back together...")
     merged = merge_datasets([splits["train"], splits["val"]], output_repo_id="lerobot/pusht_merged")
     print(f"Merged dataset: {merged.meta.total_episodes} episodes")
 
-    print("\n6. Complex workflow example...")
+    print("\n7. Complex workflow example...")
     if len(dataset.meta.camera_keys) > 1:
         camera_to_remove = dataset.meta.camera_keys[0]
diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py
index 8ebc4a59de..2735ba0a05 100644
--- a/src/lerobot/datasets/dataset_tools.py
+++ b/src/lerobot/datasets/dataset_tools.py
@@ -28,8 +28,10 @@
 from collections.abc import Callable
 from pathlib import Path
 
+import datasets
 import numpy as np
 import pandas as pd
+import pyarrow.parquet as pq
 import torch
 from tqdm import tqdm
 
@@ -43,7 +45,6 @@
     DEFAULT_EPISODES_PATH,
     get_parquet_file_size_in_mb,
     load_episodes,
-    to_parquet_with_hf_images,
     update_chunk_file_indices,
     write_info,
     write_stats,
@@ -268,39 +269,79 @@
     return merged_dataset
 
 
-def add_feature(
+def modify_features(
     dataset: LeRobotDataset,
-    feature_name: str,
-    feature_values: np.ndarray | torch.Tensor | Callable,
-    feature_info: dict,
+    add_features: dict[str, tuple[np.ndarray | torch.Tensor | Callable, dict]] | None = None,
+    remove_features: str | list[str] | None = None,
     output_dir: str | Path | None = None,
     repo_id: str | None = None,
 ) -> LeRobotDataset:
-    """Add a new feature to a LeRobotDataset.
+    """Modify a LeRobotDataset by adding and/or removing features in a single pass.
+
+    This is the most efficient way to modify features, as it only copies the dataset once
+    regardless of how many features are being added or removed.
 
     Args:
         dataset: The source LeRobotDataset.
-        feature_name: Name of the new feature.
-        feature_values: Either:
-            - Array/tensor of shape (num_frames, ...) with values for each frame
-            - Callable that takes (frame_dict, episode_index, frame_index) and returns feature value
-        feature_info: Dictionary with feature metadata (dtype, shape, names).
+        add_features: Optional dict mapping feature names to (feature_values, feature_info) tuples.
+        remove_features: Optional feature name(s) to remove. Can be a single string or list.
         output_dir: Directory to save the new dataset. If None, uses default location.
         repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
+
+    Returns:
+        New dataset with features modified.
+
+    Example:
+        new_dataset = modify_features(
+            dataset,
+            add_features={
+                "reward": (reward_array, {"dtype": "float32", "shape": [1], "names": None}),
+            },
+            remove_features=["old_feature"],
+            output_dir="./output",
+        )
     """
-    if feature_name in dataset.meta.features:
-        raise ValueError(f"Feature '{feature_name}' already exists in dataset")
+    if add_features is None and remove_features is None:
+        raise ValueError("Must specify at least one of add_features or remove_features")
+
+    remove_features_list: list[str] = []
+    if remove_features is not None:
+        remove_features_list = [remove_features] if isinstance(remove_features, str) else remove_features
+
+    if add_features:
+        required_keys = {"dtype", "shape"}
+        for feature_name, (_, feature_info) in add_features.items():
+            if feature_name in dataset.meta.features:
+                raise ValueError(f"Feature '{feature_name}' already exists in dataset")
+
+            if not required_keys.issubset(feature_info.keys()):
+                raise ValueError(f"feature_info for '{feature_name}' must contain keys: {required_keys}")
+
+    if remove_features_list:
+        for name in remove_features_list:
+            if name not in dataset.meta.features:
+                raise ValueError(f"Feature '{name}' not found in dataset")
+
+        required_features = {"timestamp", "frame_index", "episode_index", "index", "task_index"}
+        if any(name in required_features for name in remove_features_list):
+            raise ValueError(f"Cannot remove required features: {required_features}")
 
     if repo_id is None:
         repo_id = f"{dataset.repo_id}_modified"
     output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
 
-    required_keys = {"dtype", "shape"}
-    if not required_keys.issubset(feature_info.keys()):
-        raise ValueError(f"feature_info must contain keys: {required_keys}")
-
     new_features = dataset.meta.features.copy()
-    new_features[feature_name] = feature_info
+
+    if remove_features_list:
+        for name in remove_features_list:
+            new_features.pop(name, None)
+
+    if add_features:
+        for feature_name, (_, feature_info) in add_features.items():
+            new_features[feature_name] = feature_info
+
+    video_keys_to_remove = [name for name in remove_features_list if name in dataset.meta.video_keys]
+    remaining_video_keys = [k for k in dataset.meta.video_keys if k not in video_keys_to_remove]
 
     new_meta = LeRobotDatasetMetadata.create(
         repo_id=repo_id,
@@ -308,17 +349,18 @@
         features=new_features,
         robot_type=dataset.meta.robot_type,
         root=output_dir,
-        use_videos=len(dataset.meta.video_keys) > 0,
+        use_videos=len(remaining_video_keys) > 0,
     )
 
     _copy_data_with_feature_changes(
         dataset=dataset,
         new_meta=new_meta,
-        add_features={feature_name: (feature_values, feature_info)},
+        add_features=add_features,
+        remove_features=remove_features_list if remove_features_list else None,
     )
 
-    if dataset.meta.video_keys:
-        _copy_videos(dataset, new_meta)
+    if new_meta.video_keys:
+        _copy_videos(dataset, new_meta, exclude_keys=video_keys_to_remove if video_keys_to_remove else None)
 
     new_dataset = LeRobotDataset(
         repo_id=repo_id,
@@ -331,70 +373,71 @@
     return new_dataset
 
 
-def remove_feature(
+def add_features(
     dataset: LeRobotDataset,
-    feature_names: str | list[str],
+    features: dict[str, tuple[np.ndarray | torch.Tensor | Callable, dict]],
     output_dir: str | Path | None = None,
     repo_id: str | None = None,
 ) -> LeRobotDataset:
-    """Remove features from a LeRobotDataset.
+    """Add multiple features to a LeRobotDataset in a single pass.
+
+    This is more efficient than calling add_feature() multiple times, as it only
+    copies the dataset once regardless of how many features are being added.
 
     Args:
         dataset: The source LeRobotDataset.
-        feature_names: Name(s) of features to remove. Can be a single string or list.
+        features: Dictionary mapping feature names to (feature_values, feature_info) tuples.
         output_dir: Directory to save the new dataset. If None, uses default location.
         repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
-    """
-    if isinstance(feature_names, str):
-        feature_names = [feature_names]
-
-    for name in feature_names:
-        if name not in dataset.meta.features:
-            raise ValueError(f"Feature '{name}' not found in dataset")
-
-    required_features = {"timestamp", "frame_index", "episode_index", "index", "task_index"}
-    if any(name in required_features for name in feature_names):
-        raise ValueError(f"Cannot remove required features: {required_features}")
+    Returns:
+        New dataset with all features added.
 
-    if repo_id is None:
-        repo_id = f"{dataset.repo_id}_modified"
-    output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
+    Example:
+        features = {
+            "task_embedding": (task_emb_array, {"dtype": "float32", "shape": [384], "names": None}),
+            "cam1_embedding": (cam1_emb_array, {"dtype": "float32", "shape": [768], "names": None}),
+            "cam2_embedding": (cam2_emb_array, {"dtype": "float32", "shape": [768], "names": None}),
+        }
+        new_dataset = add_features(dataset, features, output_dir="./output", repo_id="my_dataset")
+    """
+    if not features:
+        raise ValueError("No features provided")
 
-    new_features = {k: v for k, v in dataset.meta.features.items() if k not in feature_names}
+    return modify_features(
+        dataset=dataset,
+        add_features=features,
+        remove_features=None,
+        output_dir=output_dir,
+        repo_id=repo_id,
+    )
 
-    video_keys_to_remove = [name for name in feature_names if name in dataset.meta.video_keys]
-    remaining_video_keys = [k for k in dataset.meta.video_keys if k not in video_keys_to_remove]
 
+def remove_feature(
+    dataset: LeRobotDataset,
+    feature_names: str | list[str],
+    output_dir: str | Path | None = None,
+    repo_id: str | None = None,
+) -> LeRobotDataset:
+    """Remove features from a LeRobotDataset.
 
-    new_meta = LeRobotDatasetMetadata.create(
-        repo_id=repo_id,
-        fps=dataset.meta.fps,
-        features=new_features,
-        robot_type=dataset.meta.robot_type,
-        root=output_dir,
-        use_videos=len(remaining_video_keys) > 0,
-    )
+    Args:
+        dataset: The source LeRobotDataset.
+        feature_names: Name(s) of features to remove. Can be a single string or list.
+        output_dir: Directory to save the new dataset. If None, uses default location.
+        repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
 
-    _copy_data_with_feature_changes(
+    Returns:
+        New dataset with features removed.
+    """
+    return modify_features(
         dataset=dataset,
-        new_meta=new_meta,
+        add_features=None,
         remove_features=feature_names,
-    )
-
-    if new_meta.video_keys:
-        _copy_videos(dataset, new_meta, exclude_keys=video_keys_to_remove)
-
-    new_dataset = LeRobotDataset(
+        output_dir=output_dir,
         repo_id=repo_id,
-        root=output_dir,
-        image_transforms=dataset.image_transforms,
-        delta_timestamps=dataset.delta_timestamps,
-        tolerance_s=dataset.tolerance_s,
     )
 
-    return new_dataset
-
 
 def _fractions_to_episode_indices(
     total_episodes: int,
@@ -501,10 +544,7 @@
         dst_path = dst_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
         dst_path.parent.mkdir(parents=True, exist_ok=True)
 
-        if len(dst_meta.image_keys) > 0:
-            to_parquet_with_hf_images(df, dst_path)
-        else:
-            df.to_parquet(dst_path, index=False)
+        _write_parquet(df, dst_path, dst_meta)
 
     for ep_old_idx in episodes_to_keep:
         ep_new_idx = episode_mapping[ep_old_idx]
@@ -862,6 +902,25 @@ def _copy_and_reindex_episodes_metadata(
     write_stats(filtered_stats, dst_meta.root)
 
 
+def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -> None:
+    """Write DataFrame to parquet
+
+    This ensures images are properly embedded and the file can be loaded correctly by HF datasets.
+    """
+    from lerobot.datasets.utils import embed_images, get_hf_features_from_features
+
+    hf_features = get_hf_features_from_features(meta.features)
+    ep_dataset = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=hf_features, split="train")
+
+    if len(meta.image_keys) > 0:
+        ep_dataset = embed_images(ep_dataset)
+
+    table = ep_dataset.with_format("arrow")[:]
+    writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)
+    writer.write_table(table)
+    writer.close()
+
+
 def _save_data_chunk(
     df: pd.DataFrame,
     meta: LeRobotDatasetMetadata,
@@ -877,10 +936,7 @@ def _save_data_chunk(
     path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
     path.parent.mkdir(parents=True, exist_ok=True)
 
-    if len(meta.image_keys) > 0:
-        to_parquet_with_hf_images(df, path)
-    else:
-        df.to_parquet(path, index=False)
+    _write_parquet(df, path, meta)
 
     episode_metadata = {}
     for ep_idx in df["episode_index"].unique():
@@ -906,19 +962,34 @@ def _copy_data_with_feature_changes(
     remove_features: list[str] | None = None,
 ) -> None:
     """Copy data while adding or removing features."""
-    file_paths = set()
+    if dataset.meta.episodes is None:
+        dataset.meta.episodes = load_episodes(dataset.meta.root)
+
+    # Map file paths to episode indices to extract chunk/file indices
+    file_to_episodes: dict[Path, set[int]] = {}
     for ep_idx in range(dataset.meta.total_episodes):
-        file_paths.add(dataset.meta.get_data_file_path(ep_idx))
+        file_path = dataset.meta.get_data_file_path(ep_idx)
+        if file_path not in file_to_episodes:
+            file_to_episodes[file_path] = set()
+        file_to_episodes[file_path].add(ep_idx)
 
     frame_idx = 0
-    for src_path in tqdm(sorted(file_paths), desc="Processing data files"):
+    for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"):
         df = pd.read_parquet(dataset.root / src_path).reset_index(drop=True)
 
+        # Get chunk_idx and file_idx from the source file's first episode
+        episodes_in_file = file_to_episodes[src_path]
+        first_ep_idx = min(episodes_in_file)
+        src_ep = dataset.meta.episodes[first_ep_idx]
+        chunk_idx = src_ep["data/chunk_index"]
+        file_idx = src_ep["data/file_index"]
+
         if remove_features:
            df = df.drop(columns=remove_features, errors="ignore")

         if add_features:
+            end_idx = frame_idx + len(df)
             for feature_name, (values, _) in add_features.items():
                 if callable(values):
                     feature_values = []
@@ -931,15 +1002,18 @@ def _copy_data_with_feature_changes(
                         feature_values.append(value)
                     df[feature_name] = feature_values
                 else:
-                    end_idx = frame_idx + len(df)
                     feature_slice = values[frame_idx:end_idx]
 
                     if len(feature_slice.shape) > 1 and feature_slice.shape[1] == 1:
                         df[feature_name] = feature_slice.flatten()
                     else:
                         df[feature_name] = feature_slice
 
-            frame_idx = end_idx
+        frame_idx = end_idx
+
+        # Write using the preserved chunk_idx and file_idx from source
+        dst_path = new_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
+        dst_path.parent.mkdir(parents=True, exist_ok=True)
 
-        _save_data_chunk(df, new_meta)
+        _write_parquet(df, dst_path, new_meta)
 
     _copy_episodes_metadata_and_stats(dataset, new_meta)
diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py
index a9c04d6f24..8bc1dbf6b9 100644
--- a/tests/datasets/test_dataset_tools.py
+++ b/tests/datasets/test_dataset_tools.py
@@ -22,9 +22,10 @@
 import torch
 
 from lerobot.datasets.dataset_tools import (
-    add_feature,
+    add_features,
     delete_episodes,
     merge_datasets,
+    modify_features,
     remove_feature,
     split_dataset,
 )
@@ -292,7 +293,7 @@ def test_merge_empty_list(tmp_path):
         merge_datasets([], output_repo_id="merged", output_dir=tmp_path)
 
 
-def test_add_feature_with_values(sample_dataset, tmp_path):
+def test_add_features_with_values(sample_dataset, tmp_path):
     """Test adding a feature with pre-computed values."""
     num_frames = sample_dataset.meta.total_frames
     reward_values = np.random.randn(num_frames, 1).astype(np.float32)
@@ -302,6 +303,9 @@
         "shape": (1,),
         "names": None,
     }
+    features = {
+        "reward": (reward_values, feature_info),
+    }
 
     with (
         patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
@@ -310,11 +314,9 @@
         mock_get_safe_version.return_value = "v3.0"
         mock_snapshot_download.return_value = str(tmp_path / "with_reward")
 
-        new_dataset = add_feature(
-            sample_dataset,
-            feature_name="reward",
-            feature_values=reward_values,
-            feature_info=feature_info,
+        new_dataset = add_features(
+            dataset=sample_dataset,
+            features=features,
             output_dir=tmp_path / "with_reward",
         )
 
@@ -327,7 +329,7 @@
     assert isinstance(sample_item["reward"], torch.Tensor)
 
 
-def test_add_feature_with_callable(sample_dataset, tmp_path):
+def test_add_features_with_callable(sample_dataset, tmp_path):
     """Test adding a feature with a callable."""
 
     def compute_reward(frame_dict, episode_idx, frame_idx):
@@ -338,7 +340,9 @@
         "shape": (1,),
         "names": None,
     }
-
+    features = {
+        "reward": (compute_reward, feature_info),
+    }
     with (
         patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
         patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
@@ -346,11 +350,9 @@
         mock_get_safe_version.return_value = "v3.0"
         mock_snapshot_download.return_value = str(tmp_path / "with_reward")
 
-        new_dataset = add_feature(
-            sample_dataset,
-            feature_name="reward",
-            feature_values=compute_reward,
-            feature_info=feature_info,
+        new_dataset = add_features(
+            dataset=sample_dataset,
+            features=features,
             output_dir=tmp_path / "with_reward",
         )
 
@@ -368,31 +370,88 @@ def compute_reward(frame_dict, episode_idx, frame_idx):
 def test_add_existing_feature(sample_dataset, tmp_path):
     """Test error when adding an existing feature."""
     feature_info = {"dtype": "float32", "shape": (1,)}
+    features = {
+        "action": (np.zeros(50), feature_info),
+    }
 
     with pytest.raises(ValueError, match="Feature 'action' already exists"):
-        add_feature(
-            sample_dataset,
-            feature_name="action",
-            feature_values=np.zeros(50),
-            feature_info=feature_info,
+        add_features(
+            dataset=sample_dataset,
+            features=features,
             output_dir=tmp_path / "modified",
         )
 
 
 def test_add_feature_invalid_info(sample_dataset, tmp_path):
     """Test error with invalid feature info."""
-    with pytest.raises(ValueError, match="feature_info must contain keys"):
-        add_feature(
+    with pytest.raises(ValueError, match="feature_info for 'reward' must contain keys"):
+        add_features(
+            dataset=sample_dataset,
+            features={
+                "reward": (np.zeros(50), {"dtype": "float32"}),
+            },
+            output_dir=tmp_path / "modified",
+        )
+
+
+def test_modify_features_add_and_remove(sample_dataset, tmp_path):
+    """Test modifying features by adding and removing simultaneously."""
+    feature_info = {"dtype": "float32", "shape": (1,), "names": None}
+
+    with (
+        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+    ):
+        mock_get_safe_version.return_value = "v3.0"
+        mock_snapshot_download.return_value = str(tmp_path / "modified")
+
+        # First add a feature we'll later remove
+        dataset_with_reward = add_features(
             sample_dataset,
-            feature_name="reward",
-            feature_values=np.zeros(50),
-            feature_info={"dtype": "float32"},
+            features={"reward": (np.random.randn(50, 1).astype(np.float32), feature_info)},
+            output_dir=tmp_path / "with_reward",
+        )
+
+        # Now use modify_features to add "success" and remove "reward" in one pass
+        modified_dataset = modify_features(
+            dataset_with_reward,
+            add_features={
+                "success": (np.random.randn(50, 1).astype(np.float32), feature_info),
+            },
+            remove_features="reward",
             output_dir=tmp_path / "modified",
         )
 
+        assert "success" in modified_dataset.meta.features
+        assert "reward" not in modified_dataset.meta.features
+        assert len(modified_dataset) == 50
 
-def test_remove_single_feature(sample_dataset, tmp_path):
-    """Test removing a single feature."""
+
+def test_modify_features_only_add(sample_dataset, tmp_path):
+    """Test that modify_features works with only add_features."""
+    feature_info = {"dtype": "float32", "shape": (1,), "names": None}
+
+    with (
+        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+    ):
+        mock_get_safe_version.return_value = "v3.0"
+        mock_snapshot_download.return_value = str(tmp_path / "modified")
+
+        modified_dataset = modify_features(
+            sample_dataset,
+            add_features={
+                "reward": (np.random.randn(50, 1).astype(np.float32), feature_info),
+            },
+            output_dir=tmp_path / "modified",
+        )
+
+        assert "reward" in modified_dataset.meta.features
+        assert len(modified_dataset) == 50
+
+
+def test_modify_features_only_remove(sample_dataset, tmp_path):
+    """Test that modify_features works with only remove_features."""
     feature_info = {"dtype": "float32", "shape": (1,), "names": None}
 
     with (
@@ -402,11 +461,46 @@
         mock_get_safe_version.return_value = "v3.0"
         mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
 
-        dataset_with_reward = add_feature(
+        dataset_with_reward = add_features(
+            sample_dataset,
+            features={"reward": (np.random.randn(50, 1).astype(np.float32), feature_info)},
+            output_dir=tmp_path / "with_reward",
+        )
+
+        modified_dataset = modify_features(
+            dataset_with_reward,
+            remove_features="reward",
+            output_dir=tmp_path / "modified",
+        )
+
+        assert "reward" not in modified_dataset.meta.features
+
+
+def test_modify_features_no_changes(sample_dataset, tmp_path):
+    """Test error when modify_features is called with no changes."""
+    with pytest.raises(ValueError, match="Must specify at least one of add_features or remove_features"):
+        modify_features(
             sample_dataset,
-            feature_name="reward",
-            feature_values=np.random.randn(50, 1).astype(np.float32),
-            feature_info=feature_info,
+            output_dir=tmp_path / "modified",
+        )
+
+
+def test_remove_single_feature(sample_dataset, tmp_path):
+    """Test removing a single feature."""
+    feature_info = {"dtype": "float32", "shape": (1,), "names": None}
+    features = {
+        "reward": (np.random.randn(50, 1).astype(np.float32), feature_info),
+    }
+    with (
+        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+    ):
+        mock_get_safe_version.return_value = "v3.0"
+        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
+
+        dataset_with_reward = add_features(
+            dataset=sample_dataset,
+            features=features,
             output_dir=tmp_path / "with_reward",
         )
 
@@ -432,20 +526,19 @@ def test_remove_multiple_features(sample_dataset, tmp_path):
         mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
 
         dataset = sample_dataset
+        features = {}
         for feature_name in ["reward", "success"]:
             feature_info = {"dtype": "float32", "shape": (1,), "names": None}
-            dataset = add_feature(
-                dataset,
-                feature_name=feature_name,
-                feature_values=np.random.randn(dataset.meta.total_frames, 1).astype(np.float32),
-                feature_info=feature_info,
-                output_dir=tmp_path / f"with_{feature_name}",
+            features[feature_name] = (
+                np.random.randn(dataset.meta.total_frames, 1).astype(np.float32),
+                feature_info,
             )
+        dataset_with_features = add_features(
+            dataset, features=features, output_dir=tmp_path / "with_features"
+        )
 
         dataset_clean = remove_feature(
-            dataset,
-            feature_names=["reward", "success"],
-            output_dir=tmp_path / "clean",
+            dataset_with_features, feature_names=["reward", "success"], output_dir=tmp_path / "clean"
         )
 
         assert "reward" not in dataset_clean.meta.features
@@ -509,11 +602,14 @@ def test_complex_workflow_integration(sample_dataset, tmp_path):
         mock_get_safe_version.return_value = "v3.0"
         mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
 
-        dataset = add_feature(
+        dataset = add_features(
             sample_dataset,
-            feature_name="reward",
-            feature_values=np.random.randn(50, 1).astype(np.float32),
-            feature_info={"dtype": "float32", "shape": (1,), "names": None},
+            features={
+                "reward": (
+                    np.random.randn(50, 1).astype(np.float32),
+                    {"dtype": "float32", "shape": (1,), "names": None},
+                )
+            },
             output_dir=tmp_path / "step1",
         )
 
@@ -753,7 +849,7 @@
     assert "std" in merged.meta.stats[feature]
 
 
-def test_add_feature_preserves_existing_stats(sample_dataset, tmp_path):
+def test_add_features_preserves_existing_stats(sample_dataset, tmp_path):
     """Test that adding a feature preserves existing stats."""
     num_frames = sample_dataset.meta.total_frames
     reward_values = np.random.randn(num_frames, 1).astype(np.float32)
@@ -763,6 +859,9 @@
         "shape": (1,),
         "names": None,
     }
+    features = {
+        "reward": (reward_values, feature_info),
+    }
 
     with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
@@ -771,11 +870,9 @@
         mock_get_safe_version.return_value = "v3.0"
         mock_snapshot_download.return_value = str(tmp_path / "with_reward")
 
-        new_dataset = add_feature(
-            sample_dataset,
-            feature_name="reward",
-            feature_values=reward_values,
-            feature_info=feature_info,
+        new_dataset = add_features(
+            dataset=sample_dataset,
+            features=features,
             output_dir=tmp_path / "with_reward",
         )
 
@@ -797,11 +894,11 @@ def test_remove_feature_updates_stats(sample_dataset, tmp_path):
         mock_get_safe_version.return_value = "v3.0"
         mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
 
-        dataset_with_reward = add_feature(
+        dataset_with_reward = add_features(
             sample_dataset,
-            feature_name="reward",
-            feature_values=np.random.randn(50, 1).astype(np.float32),
-            feature_info=feature_info,
+            features={
+                "reward": (np.random.randn(50, 1).astype(np.float32), feature_info),
+            },
             output_dir=tmp_path / "with_reward",
         )
 
@@ -893,3 +990,60 @@ def mock_snapshot(repo_id, **kwargs):
     total_episodes = sum(ds.meta.total_episodes for ds in result.values())
 
     assert total_episodes == sample_dataset.meta.total_episodes
+
+
+def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
+    """Test that modifying features preserves chunk_idx and file_idx from source dataset."""
+    feature_info = {"dtype": "float32", "shape": (1,), "names": None}
+
+    with (
+        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+    ):
+        mock_get_safe_version.return_value = "v3.0"
+
+        def mock_snapshot(repo_id, **kwargs):
+            return str(kwargs.get("local_dir", tmp_path / repo_id.split("/")[-1]))
+
+        mock_snapshot_download.side_effect = mock_snapshot
+
+        # First split the dataset to create a non-zero starting chunk/file structure
+        splits = split_dataset(
+            sample_dataset,
+            splits={"train": [0, 1, 2], "val": [3, 4]},
+            output_dir=tmp_path / "splits",
+        )
+
+        train_dataset = splits["train"]
+
+        # Get original chunk/file indices from first episode
+        if train_dataset.meta.episodes is None:
+            from lerobot.datasets.utils import load_episodes
+
+            train_dataset.meta.episodes = load_episodes(train_dataset.meta.root)
+        original_chunk_indices = [ep["data/chunk_index"] for ep in train_dataset.meta.episodes]
+        original_file_indices = [ep["data/file_index"] for ep in train_dataset.meta.episodes]
+
+        # Now add a feature to the split dataset
+        modified_dataset = add_features(
+            train_dataset,
+            features={
+                "reward": (
+                    np.random.randn(train_dataset.meta.total_frames, 1).astype(np.float32),
+                    feature_info,
+                ),
+            },
+            output_dir=tmp_path / "modified",
+        )
+
+        # Check that chunk/file indices are preserved
+        if modified_dataset.meta.episodes is None:
+            from lerobot.datasets.utils import load_episodes
+
+            modified_dataset.meta.episodes = load_episodes(modified_dataset.meta.root)
+        new_chunk_indices = [ep["data/chunk_index"] for ep in modified_dataset.meta.episodes]
+        new_file_indices = [ep["data/file_index"] for ep in modified_dataset.meta.episodes]
+
+        assert new_chunk_indices == original_chunk_indices, "Chunk indices should be preserved"
+        assert new_file_indices == original_file_indices, "File indices should be preserved"
+        assert "reward" in modified_dataset.meta.features