examples/dataset/use_dataset_tools.py: 33 additions & 26 deletions
@@ -30,9 +30,10 @@
 import numpy as np
 
 from lerobot.datasets.dataset_tools import (
-    add_feature,
+    add_features,
     delete_episodes,
     merge_datasets,
+    modify_features,
     remove_feature,
     split_dataset,
 )
@@ -57,50 +58,56 @@ def main():
     print(f"Train split: {splits['train'].meta.total_episodes} episodes")
     print(f"Val split: {splits['val'].meta.total_episodes} episodes")
 
-    print("\n3. Adding a reward feature...")
+    print("\n3. Adding features...")
 
     reward_values = np.random.randn(dataset.meta.total_frames).astype(np.float32)
-    dataset_with_reward = add_feature(
-        dataset,
-        feature_name="reward",
-        feature_values=reward_values,
-        feature_info={
-            "dtype": "float32",
-            "shape": (1,),
-            "names": None,
-        },
-        repo_id="lerobot/pusht_with_reward",
-    )
 
     def compute_success(row_dict, episode_index, frame_index):
         episode_length = 10
         return float(frame_index >= episode_length - 10)
 
-    dataset_with_success = add_feature(
-        dataset_with_reward,
-        feature_name="success",
-        feature_values=compute_success,
-        feature_info={
-            "dtype": "float32",
-            "shape": (1,),
-            "names": None,
+    dataset_with_features = add_features(
+        dataset,
+        features={
+            "reward": (
+                reward_values,
+                {"dtype": "float32", "shape": (1,), "names": None},
+            ),
+            "success": (
+                compute_success,
+                {"dtype": "float32", "shape": (1,), "names": None},
+            ),
         },
-        repo_id="lerobot/pusht_with_reward_and_success",
+        repo_id="lerobot/pusht_with_features",
     )
 
-    print(f"New features: {list(dataset_with_success.meta.features.keys())}")
+    print(f"New features: {list(dataset_with_features.meta.features.keys())}")
 
     print("\n4. Removing the success feature...")
     dataset_cleaned = remove_feature(
-        dataset_with_success, feature_names="success", repo_id="lerobot/pusht_cleaned"
+        dataset_with_features, feature_names="success", repo_id="lerobot/pusht_cleaned"
     )
     print(f"Features after removal: {list(dataset_cleaned.meta.features.keys())}")
 
-    print("\n5. Merging train and val splits back together...")
+    print("\n5. Using modify_features to add and remove features simultaneously...")
+    dataset_modified = modify_features(
+        dataset_with_features,
+        add_features={
+            "discount": (
+                np.ones(dataset.meta.total_frames, dtype=np.float32) * 0.99,
+                {"dtype": "float32", "shape": (1,), "names": None},
+            ),
+        },
+        remove_features="reward",
+        repo_id="lerobot/pusht_modified",
+    )
+    print(f"Modified features: {list(dataset_modified.meta.features.keys())}")
+
+    print("\n6. Merging train and val splits back together...")
     merged = merge_datasets([splits["train"], splits["val"]], output_repo_id="lerobot/pusht_merged")
     print(f"Merged dataset: {merged.meta.total_episodes} episodes")
 
-    print("\n6. Complex workflow example...")
+    print("\n7. Complex workflow example...")
 
     if len(dataset.meta.camera_keys) > 1:
         camera_to_remove = dataset.meta.camera_keys[0]
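Reviewer note: `compute_success` in the updated example is a toy. With `episode_length = 10`, the test `frame_index >= episode_length - 10` is true for every frame, so the `success` flag is always 1.0. The callable interface itself takes `(frame_dict, episode_index, frame_index)`, per the `feature_values` docstring in dataset_tools.py below. A sketch of a version that only flags the tail of each episode, assuming per-episode lengths are precomputed (the `episode_lengths` dict is a stand-in, not a LeRobot API):

    # Stand-in mapping episode_index -> frame count; build it however your
    # dataset exposes episode lengths (LeRobot does not provide this dict).
    episode_lengths: dict[int, int] = {0: 120, 1: 95}

    def compute_success(frame_dict, episode_index, frame_index):
        # Flag only the last 10 frames of each episode as successful.
        length = episode_lengths[episode_index]
        return float(frame_index >= length - 10)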
src/lerobot/datasets/dataset_tools.py: 111 additions & 67 deletions
@@ -268,57 +268,98 @@ def merge_datasets(
     return merged_dataset
 
 
-def add_feature(
+def modify_features(
     dataset: LeRobotDataset,
-    feature_name: str,
-    feature_values: np.ndarray | torch.Tensor | Callable,
-    feature_info: dict,
+    add_features: dict[str, tuple[np.ndarray | torch.Tensor | Callable, dict]] | None = None,
+    remove_features: str | list[str] | None = None,
    output_dir: str | Path | None = None,
    repo_id: str | None = None,
 ) -> LeRobotDataset:
-    """Add a new feature to a LeRobotDataset.
+    """Modify a LeRobotDataset by adding and/or removing features in a single pass.
+
+    This is the most efficient way to modify features, as it only copies the dataset once
+    regardless of how many features are being added or removed.
 
     Args:
         dataset: The source LeRobotDataset.
-        feature_name: Name of the new feature.
-        feature_values: Either:
-            - Array/tensor of shape (num_frames, ...) with values for each frame
-            - Callable that takes (frame_dict, episode_index, frame_index) and returns feature value
-        feature_info: Dictionary with feature metadata (dtype, shape, names).
+        add_features: Optional dict mapping feature names to (feature_values, feature_info) tuples.
+        remove_features: Optional feature name(s) to remove. Can be a single string or list.
         output_dir: Directory to save the new dataset. If None, uses default location.
         repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
 
+    Returns:
+        New dataset with features modified.
+
+    Example:
+        new_dataset = modify_features(
+            dataset,
+            add_features={
+                "reward": (reward_array, {"dtype": "float32", "shape": [1], "names": None}),
+            },
+            remove_features=["old_feature"],
+            output_dir="./output",
+        )
     """
-    if feature_name in dataset.meta.features:
-        raise ValueError(f"Feature '{feature_name}' already exists in dataset")
+    if add_features is None and remove_features is None:
+        raise ValueError("Must specify at least one of add_features or remove_features")
 
+    remove_features_list: list[str] = []
+    if remove_features is not None:
+        remove_features_list = [remove_features] if isinstance(remove_features, str) else remove_features
+
+    if add_features:
+        required_keys = {"dtype", "shape"}
+        for feature_name, (_, feature_info) in add_features.items():
+            if feature_name in dataset.meta.features:
+                raise ValueError(f"Feature '{feature_name}' already exists in dataset")
+
+            if not required_keys.issubset(feature_info.keys()):
+                raise ValueError(f"feature_info for '{feature_name}' must contain keys: {required_keys}")
+
+    if remove_features_list:
+        for name in remove_features_list:
+            if name not in dataset.meta.features:
+                raise ValueError(f"Feature '{name}' not found in dataset")
+
+        required_features = {"timestamp", "frame_index", "episode_index", "index", "task_index"}
+        if any(name in required_features for name in remove_features_list):
+            raise ValueError(f"Cannot remove required features: {required_features}")
+
     if repo_id is None:
         repo_id = f"{dataset.repo_id}_modified"
     output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
 
-    required_keys = {"dtype", "shape"}
-    if not required_keys.issubset(feature_info.keys()):
-        raise ValueError(f"feature_info must contain keys: {required_keys}")
-
     new_features = dataset.meta.features.copy()
-    new_features[feature_name] = feature_info
 
+    if remove_features_list:
+        for name in remove_features_list:
+            new_features.pop(name, None)
+
+    if add_features:
+        for feature_name, (_, feature_info) in add_features.items():
+            new_features[feature_name] = feature_info
+
+    video_keys_to_remove = [name for name in remove_features_list if name in dataset.meta.video_keys]
+    remaining_video_keys = [k for k in dataset.meta.video_keys if k not in video_keys_to_remove]
+
     new_meta = LeRobotDatasetMetadata.create(
         repo_id=repo_id,
         fps=dataset.meta.fps,
         features=new_features,
         robot_type=dataset.meta.robot_type,
         root=output_dir,
-        use_videos=len(dataset.meta.video_keys) > 0,
+        use_videos=len(remaining_video_keys) > 0,
     )
 
     _copy_data_with_feature_changes(
         dataset=dataset,
         new_meta=new_meta,
-        add_features={feature_name: (feature_values, feature_info)},
+        add_features=add_features,
+        remove_features=remove_features_list if remove_features_list else None,
     )
 
-    if dataset.meta.video_keys:
-        _copy_videos(dataset, new_meta)
+    if new_meta.video_keys:
+        _copy_videos(dataset, new_meta, exclude_keys=video_keys_to_remove if video_keys_to_remove else None)
 
     new_dataset = LeRobotDataset(
         repo_id=repo_id,
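Reviewer note: a consequence of the `video_keys_to_remove` bookkeeping above is that dropping a video feature also flips `use_videos` appropriately and excludes that key's files from `_copy_videos`. A minimal sketch under the new signature (the repo ID is illustrative, and the camera is assumed to be stored as video, i.e. present in `meta.video_keys`):

    # Drop one camera stream; its video files are skipped during the copy
    # because the key lands in video_keys_to_remove inside modify_features.
    cam_key = dataset.meta.camera_keys[0]
    trimmed = modify_features(
        dataset,
        remove_features=cam_key,  # a bare string is normalized to a one-item list
        repo_id="user/pusht_no_cam",  # illustrative
    )
    assert cam_key not in trimmed.meta.video_keys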
@@ -331,70 +372,71 @@ def add_feature(
     return new_dataset
 
 
-def remove_feature(
+def add_features(
     dataset: LeRobotDataset,
-    feature_names: str | list[str],
+    features: dict[str, tuple[np.ndarray | torch.Tensor | Callable, dict]],
     output_dir: str | Path | None = None,
     repo_id: str | None = None,
 ) -> LeRobotDataset:
-    """Remove features from a LeRobotDataset.
+    """Add multiple features to a LeRobotDataset in a single pass.
+
+    This is more efficient than calling add_feature() multiple times, as it only
+    copies the dataset once regardless of how many features are being added.
 
     Args:
         dataset: The source LeRobotDataset.
-        feature_names: Name(s) of features to remove. Can be a single string or list.
+        features: Dictionary mapping feature names to (feature_values, feature_info) tuples.
         output_dir: Directory to save the new dataset. If None, uses default location.
         repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
 
-    """
-    if isinstance(feature_names, str):
-        feature_names = [feature_names]
-
-    for name in feature_names:
-        if name not in dataset.meta.features:
-            raise ValueError(f"Feature '{name}' not found in dataset")
-
-    required_features = {"timestamp", "frame_index", "episode_index", "index", "task_index"}
-    if any(name in required_features for name in feature_names):
-        raise ValueError(f"Cannot remove required features: {required_features}")
+    Returns:
+        New dataset with all features added.
 
-    if repo_id is None:
-        repo_id = f"{dataset.repo_id}_modified"
-    output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
+    Example:
+        features = {
+            "task_embedding": (task_emb_array, {"dtype": "float32", "shape": [384], "names": None}),
+            "cam1_embedding": (cam1_emb_array, {"dtype": "float32", "shape": [768], "names": None}),
+            "cam2_embedding": (cam2_emb_array, {"dtype": "float32", "shape": [768], "names": None}),
+        }
+        new_dataset = add_features(dataset, features, output_dir="./output", repo_id="my_dataset")
+    """
+    if not features:
+        raise ValueError("No features provided")
 
-    new_features = {k: v for k, v in dataset.meta.features.items() if k not in feature_names}
+    return modify_features(
+        dataset=dataset,
+        add_features=features,
+        remove_features=None,
+        output_dir=output_dir,
+        repo_id=repo_id,
+    )
 
-    video_keys_to_remove = [name for name in feature_names if name in dataset.meta.video_keys]
-
-    remaining_video_keys = [k for k in dataset.meta.video_keys if k not in video_keys_to_remove]
+
+def remove_feature(
+    dataset: LeRobotDataset,
+    feature_names: str | list[str],
+    output_dir: str | Path | None = None,
+    repo_id: str | None = None,
+) -> LeRobotDataset:
+    """Remove features from a LeRobotDataset.
 
-    new_meta = LeRobotDatasetMetadata.create(
-        repo_id=repo_id,
-        fps=dataset.meta.fps,
-        features=new_features,
-        robot_type=dataset.meta.robot_type,
-        root=output_dir,
-        use_videos=len(remaining_video_keys) > 0,
-    )
+    Args:
+        dataset: The source LeRobotDataset.
+        feature_names: Name(s) of features to remove. Can be a single string or list.
+        output_dir: Directory to save the new dataset. If None, uses default location.
+        repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
 
-    _copy_data_with_feature_changes(
+    Returns:
+        New dataset with features removed.
+    """
+    return modify_features(
         dataset=dataset,
-        new_meta=new_meta,
+        add_features=None,
         remove_features=feature_names,
-    )
-
-    if new_meta.video_keys:
-        _copy_videos(dataset, new_meta, exclude_keys=video_keys_to_remove)
-
-    new_dataset = LeRobotDataset(
+        output_dir=output_dir,
         repo_id=repo_id,
-        root=output_dir,
-        image_transforms=dataset.image_transforms,
-        delta_timestamps=dataset.delta_timestamps,
-        tolerance_s=dataset.tolerance_s,
     )
 
-    return new_dataset
-
 
 def _fractions_to_episode_indices(
     total_episodes: int,
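Reviewer note: with both public helpers reduced to thin delegations, `add_features` and `remove_feature` should now behave identically to the corresponding `modify_features` calls. A quick equivalence sketch (the array, info dict, and repo IDs are illustrative):

    values = np.zeros(dataset.meta.total_frames, dtype=np.float32)
    info = {"dtype": "float32", "shape": (1,), "names": None}

    via_wrapper = add_features(dataset, {"reward": (values, info)}, repo_id="user/a")
    via_direct = modify_features(dataset, add_features={"reward": (values, info)}, repo_id="user/b")
    assert list(via_wrapper.meta.features) == list(via_direct.meta.features)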
@@ -911,6 +953,8 @@ def _copy_data_with_feature_changes(
             file_paths.add(dataset.meta.get_data_file_path(ep_idx))
 
     frame_idx = 0
+    chunk_idx = 0
+    file_idx = 0
 
     for src_path in tqdm(sorted(file_paths), desc="Processing data files"):
         df = pd.read_parquet(dataset.root / src_path).reset_index(drop=True)
@@ -919,6 +963,7 @@
         df = df.drop(columns=remove_features, errors="ignore")
 
         if add_features:
+            end_idx = frame_idx + len(df)
             for feature_name, (values, _) in add_features.items():
                 if callable(values):
                     feature_values = []
@@ -931,15 +976,14 @@
                         feature_values.append(value)
                     df[feature_name] = feature_values
                 else:
-                    end_idx = frame_idx + len(df)
                     feature_slice = values[frame_idx:end_idx]
                     if len(feature_slice.shape) > 1 and feature_slice.shape[1] == 1:
                         df[feature_name] = feature_slice.flatten()
                     else:
                         df[feature_name] = feature_slice
-                    frame_idx = end_idx
+            frame_idx = end_idx
 
-        _save_data_chunk(df, new_meta)
+        chunk_idx, file_idx, _ = _save_data_chunk(df, new_meta, chunk_idx, file_idx)
 
     _copy_episodes_metadata_and_stats(dataset, new_meta)
 
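Reviewer note: the last hunk fixes a real slicing bug. Previously `end_idx` and `frame_idx` were updated inside the per-feature loop (array branch only), so a second array feature in the same file would read a window shifted by one file length. Now `end_idx` is computed once per file, `frame_idx` advances once after all features are written, and `_save_data_chunk` receives and returns chunk/file counters so they persist across files. A schematic of the intended invariant (`dfs` and `array_features` are stand-ins for the parquet files and the (values, info) pairs handled by the real function):

    frame_idx = 0
    for df in dfs:  # one DataFrame per data file, in sorted order
        end_idx = frame_idx + len(df)
        for name, values in array_features.items():
            # Every feature slices the same window for this file.
            df[name] = values[frame_idx:end_idx]
        frame_idx = end_idx  # advance once per file, not once per feature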