Skip to content

Commit 8e940bf

Browse files
Feat/expand add features (#2202)
* Make add_feature accept multiple features at a time and rename it to add_features. * Add a new function, modify_features, which combines removing features and adding features in a single call. This is important when we want to add one feature and remove another: both changes are applied in one pass, which avoids copying and re-creating the dataset multiple times.
1 parent 6e8be57 commit 8e940bf

File tree

3 files changed

+393
-158
lines changed

3 files changed

+393
-158
lines changed

examples/dataset/use_dataset_tools.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,10 @@
3030
import numpy as np
3131

3232
from lerobot.datasets.dataset_tools import (
33-
add_feature,
33+
add_features,
3434
delete_episodes,
3535
merge_datasets,
36+
modify_features,
3637
remove_feature,
3738
split_dataset,
3839
)
@@ -57,50 +58,56 @@ def main():
5758
print(f"Train split: {splits['train'].meta.total_episodes} episodes")
5859
print(f"Val split: {splits['val'].meta.total_episodes} episodes")
5960

60-
print("\n3. Adding a reward feature...")
61+
print("\n3. Adding features...")
6162

6263
reward_values = np.random.randn(dataset.meta.total_frames).astype(np.float32)
63-
dataset_with_reward = add_feature(
64-
dataset,
65-
feature_name="reward",
66-
feature_values=reward_values,
67-
feature_info={
68-
"dtype": "float32",
69-
"shape": (1,),
70-
"names": None,
71-
},
72-
repo_id="lerobot/pusht_with_reward",
73-
)
7464

7565
def compute_success(row_dict, episode_index, frame_index):
7666
episode_length = 10
7767
return float(frame_index >= episode_length - 10)
7868

79-
dataset_with_success = add_feature(
80-
dataset_with_reward,
81-
feature_name="success",
82-
feature_values=compute_success,
83-
feature_info={
84-
"dtype": "float32",
85-
"shape": (1,),
86-
"names": None,
69+
dataset_with_features = add_features(
70+
dataset,
71+
features={
72+
"reward": (
73+
reward_values,
74+
{"dtype": "float32", "shape": (1,), "names": None},
75+
),
76+
"success": (
77+
compute_success,
78+
{"dtype": "float32", "shape": (1,), "names": None},
79+
),
8780
},
88-
repo_id="lerobot/pusht_with_reward_and_success",
81+
repo_id="lerobot/pusht_with_features",
8982
)
9083

91-
print(f"New features: {list(dataset_with_success.meta.features.keys())}")
84+
print(f"New features: {list(dataset_with_features.meta.features.keys())}")
9285

9386
print("\n4. Removing the success feature...")
9487
dataset_cleaned = remove_feature(
95-
dataset_with_success, feature_names="success", repo_id="lerobot/pusht_cleaned"
88+
dataset_with_features, feature_names="success", repo_id="lerobot/pusht_cleaned"
9689
)
9790
print(f"Features after removal: {list(dataset_cleaned.meta.features.keys())}")
9891

99-
print("\n5. Merging train and val splits back together...")
92+
print("\n5. Using modify_features to add and remove features simultaneously...")
93+
dataset_modified = modify_features(
94+
dataset_with_features,
95+
add_features={
96+
"discount": (
97+
np.ones(dataset.meta.total_frames, dtype=np.float32) * 0.99,
98+
{"dtype": "float32", "shape": (1,), "names": None},
99+
),
100+
},
101+
remove_features="reward",
102+
repo_id="lerobot/pusht_modified",
103+
)
104+
print(f"Modified features: {list(dataset_modified.meta.features.keys())}")
105+
106+
print("\n6. Merging train and val splits back together...")
100107
merged = merge_datasets([splits["train"], splits["val"]], output_repo_id="lerobot/pusht_merged")
101108
print(f"Merged dataset: {merged.meta.total_episodes} episodes")
102109

103-
print("\n6. Complex workflow example...")
110+
print("\n7. Complex workflow example...")
104111

105112
if len(dataset.meta.camera_keys) > 1:
106113
camera_to_remove = dataset.meta.camera_keys[0]

0 commit comments

Comments
 (0)