Skip to content
59 changes: 33 additions & 26 deletions examples/dataset/use_dataset_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@
import numpy as np

from lerobot.datasets.dataset_tools import (
add_feature,
add_features,
delete_episodes,
merge_datasets,
modify_features,
remove_feature,
split_dataset,
)
Expand All @@ -57,50 +58,56 @@ def main():
print(f"Train split: {splits['train'].meta.total_episodes} episodes")
print(f"Val split: {splits['val'].meta.total_episodes} episodes")

print("\n3. Adding a reward feature...")
print("\n3. Adding features...")

reward_values = np.random.randn(dataset.meta.total_frames).astype(np.float32)
dataset_with_reward = add_feature(
dataset,
feature_name="reward",
feature_values=reward_values,
feature_info={
"dtype": "float32",
"shape": (1,),
"names": None,
},
repo_id="lerobot/pusht_with_reward",
)

def compute_success(row_dict, episode_index, frame_index):
episode_length = 10
return float(frame_index >= episode_length - 10)

dataset_with_success = add_feature(
dataset_with_reward,
feature_name="success",
feature_values=compute_success,
feature_info={
"dtype": "float32",
"shape": (1,),
"names": None,
dataset_with_features = add_features(
dataset,
features={
"reward": (
reward_values,
{"dtype": "float32", "shape": (1,), "names": None},
),
"success": (
compute_success,
{"dtype": "float32", "shape": (1,), "names": None},
),
},
repo_id="lerobot/pusht_with_reward_and_success",
repo_id="lerobot/pusht_with_features",
)

print(f"New features: {list(dataset_with_success.meta.features.keys())}")
print(f"New features: {list(dataset_with_features.meta.features.keys())}")

print("\n4. Removing the success feature...")
dataset_cleaned = remove_feature(
dataset_with_success, feature_names="success", repo_id="lerobot/pusht_cleaned"
dataset_with_features, feature_names="success", repo_id="lerobot/pusht_cleaned"
)
print(f"Features after removal: {list(dataset_cleaned.meta.features.keys())}")

print("\n5. Merging train and val splits back together...")
print("\n5. Using modify_features to add and remove features simultaneously...")
dataset_modified = modify_features(
dataset_with_features,
add_features={
"discount": (
np.ones(dataset.meta.total_frames, dtype=np.float32) * 0.99,
{"dtype": "float32", "shape": (1,), "names": None},
),
},
remove_features="reward",
repo_id="lerobot/pusht_modified",
)
print(f"Modified features: {list(dataset_modified.meta.features.keys())}")

print("\n6. Merging train and val splits back together...")
merged = merge_datasets([splits["train"], splits["val"]], output_repo_id="lerobot/pusht_merged")
print(f"Merged dataset: {merged.meta.total_episodes} episodes")

print("\n6. Complex workflow example...")
print("\n7. Complex workflow example...")

if len(dataset.meta.camera_keys) > 1:
camera_to_remove = dataset.meta.camera_keys[0]
Expand Down
Loading