3030import numpy as np
3131
3232from lerobot .datasets .dataset_tools import (
33- add_feature ,
33+ add_features ,
3434 delete_episodes ,
3535 merge_datasets ,
36+ modify_features ,
3637 remove_feature ,
3738 split_dataset ,
3839)
@@ -57,50 +58,56 @@ def main():
5758 print (f"Train split: { splits ['train' ].meta .total_episodes } episodes" )
5859 print (f"Val split: { splits ['val' ].meta .total_episodes } episodes" )
5960
60- print ("\n 3. Adding a reward feature ..." )
61+ print ("\n 3. Adding features ..." )
6162
6263 reward_values = np .random .randn (dataset .meta .total_frames ).astype (np .float32 )
63- dataset_with_reward = add_feature (
64- dataset ,
65- feature_name = "reward" ,
66- feature_values = reward_values ,
67- feature_info = {
68- "dtype" : "float32" ,
69- "shape" : (1 ,),
70- "names" : None ,
71- },
72- repo_id = "lerobot/pusht_with_reward" ,
73- )
7464
7565 def compute_success (row_dict , episode_index , frame_index ):
7666 episode_length = 10
7767 return float (frame_index >= episode_length - 10 )
7868
79- dataset_with_success = add_feature (
80- dataset_with_reward ,
81- feature_name = "success" ,
82- feature_values = compute_success ,
83- feature_info = {
84- "dtype" : "float32" ,
85- "shape" : (1 ,),
86- "names" : None ,
69+ dataset_with_features = add_features (
70+ dataset ,
71+ features = {
72+ "reward" : (
73+ reward_values ,
74+ {"dtype" : "float32" , "shape" : (1 ,), "names" : None },
75+ ),
76+ "success" : (
77+ compute_success ,
78+ {"dtype" : "float32" , "shape" : (1 ,), "names" : None },
79+ ),
8780 },
88- repo_id = "lerobot/pusht_with_reward_and_success " ,
81+ repo_id = "lerobot/pusht_with_features " ,
8982 )
9083
91- print (f"New features: { list (dataset_with_success .meta .features .keys ())} " )
84+ print (f"New features: { list (dataset_with_features .meta .features .keys ())} " )
9285
9386 print ("\n 4. Removing the success feature..." )
9487 dataset_cleaned = remove_feature (
95- dataset_with_success , feature_names = "success" , repo_id = "lerobot/pusht_cleaned"
88+ dataset_with_features , feature_names = "success" , repo_id = "lerobot/pusht_cleaned"
9689 )
9790 print (f"Features after removal: { list (dataset_cleaned .meta .features .keys ())} " )
9891
99- print ("\n 5. Merging train and val splits back together..." )
92+ print ("\n 5. Using modify_features to add and remove features simultaneously..." )
93+ dataset_modified = modify_features (
94+ dataset_with_features ,
95+ add_features = {
96+ "discount" : (
97+ np .ones (dataset .meta .total_frames , dtype = np .float32 ) * 0.99 ,
98+ {"dtype" : "float32" , "shape" : (1 ,), "names" : None },
99+ ),
100+ },
101+ remove_features = "reward" ,
102+ repo_id = "lerobot/pusht_modified" ,
103+ )
104+ print (f"Modified features: { list (dataset_modified .meta .features .keys ())} " )
105+
106+ print ("\n 6. Merging train and val splits back together..." )
100107 merged = merge_datasets ([splits ["train" ], splits ["val" ]], output_repo_id = "lerobot/pusht_merged" )
101108 print (f"Merged dataset: { merged .meta .total_episodes } episodes" )
102109
103- print ("\n 6 . Complex workflow example..." )
110+ print ("\n 7 . Complex workflow example..." )
104111
105112 if len (dataset .meta .camera_keys ) > 1 :
106113 camera_to_remove = dataset .meta .camera_keys [0 ]
0 commit comments