
Conversation

@michel-aractingi (Collaborator) commented Oct 1, 2025

Dataset Editing Tools

  • Introduced dataset tools for LeRobotDataset, including functions for:
    • deleting episodes from an existing dataset
    • splitting datasets into several subsets
    • adding/removing features
    • merging a list of datasets into one
  • Added an example script demonstrating the usage of these utilities in examples/dataset/use_dataset_tools.py.
  • Implemented comprehensive tests for all new functionalities to ensure reliability and correctness.
  • New script src/lerobot/scripts/lerobot_edit_dataset.py, which edits a dataset through a simple, configurable CLI (see the API sketch after this list).
  • Added a lerobot-edit-dataset shortcut in pyproject.toml.
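
For reference, a minimal sketch of the underlying Python API (the exact keyword names here are assumptions; see src/lerobot/datasets/dataset_tools.py and examples/dataset/use_dataset_tools.py for the real signatures):

from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.dataset_tools import delete_episodes

# Load an existing dataset and drop three episodes. delete_episodes
# builds and returns a new dataset rather than mutating the original;
# the keyword name below is an assumption, not the confirmed signature.
dataset = LeRobotDataset("lerobot/pusht")
filtered = delete_episodes(dataset, episode_indices=[0, 2, 5])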

Usage examples

Delete episodes 0, 2, and 5 from a dataset:

python -m lerobot.scripts.lerobot_edit_dataset \
        --repo_id lerobot/pusht \
        --operation.type delete_episodes \
        --operation.episode_indices "[0, 2, 5]"

Delete episodes and save to a new dataset:

python -m lerobot.scripts.lerobot_edit_dataset \
        --repo_id lerobot/pusht \
        --new_repo_id lerobot/pusht_filtered \
        --operation.type delete_episodes \
        --operation.episode_indices "[0, 2, 5]"

Split dataset by fractions:

python -m lerobot.scripts.lerobot_edit_dataset \
        --repo_id lerobot/pusht \
        --operation.type split \
        --operation.splits '{"train": 0.8, "val": 0.2}'

Split dataset by episode indices:

python -m lerobot.scripts.lerobot_edit_dataset \
        --repo_id lerobot/pusht \
        --operation.type split \
        --operation.splits '{"train": [0, 1, 2, 3], "val": [4, 5]}'

Split into more than two splits:

python -m lerobot.scripts.lerobot_edit_dataset \
        --repo_id lerobot/pusht \
        --operation.type split \
        --operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}'

Merge multiple datasets:

python -m lerobot.scripts.lerobot_edit_dataset \
        --repo_id lerobot/pusht_merged \
        --operation.type merge \
        --operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"

Remove camera feature:

python -m lerobot.scripts.lerobot_edit_dataset \
        --repo_id lerobot/pusht \
        --operation.type remove_feature \
        --operation.feature_names "['observation.images.top']"

Commits:
- copy unchanged video and parquet files to avoid recreating the entire dataset
- remove hardcoded split names
- add lerobot-edit-dataset shortcut
@michel-aractingi marked this pull request as ready for review October 2, 2025 09:11
Copilot AI review requested due to automatic review settings October 2, 2025 09:11
Copilot AI (Contributor) left a comment
Pull Request Overview

This PR introduces dataset editing tools for LeRobotDataset, enabling users to modify datasets through operations like deleting episodes, splitting datasets, adding/removing features, and merging datasets. The implementation includes comprehensive functionality with CLI support.

  • Comprehensive dataset tools including delete episodes, split, merge, and feature manipulation
  • Command-line interface script with configurable operations and examples
  • Complete test coverage for all new functionalities

Reviewed Changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.

File summary:
  • tests/datasets/test_dataset_tools.py: comprehensive test suite covering all dataset tool operations
  • src/lerobot/scripts/lerobot_edit_dataset.py: CLI script for dataset editing operations with detailed usage examples
  • src/lerobot/datasets/dataset_tools.py: core implementation of the dataset manipulation functions
  • pyproject.toml: added CLI shortcut for the dataset editing script
  • examples/dataset/use_dataset_tools.py: example script demonstrating usage of the dataset tools
Comments suppressed due to low confidence (1)

src/lerobot/datasets/dataset_tools.py:1

  • These imports should be moved to the top of the file rather than being placed inside the function to follow Python import conventions.
#!/usr/bin/env python


@michel-aractingi changed the title from "[WIP] Dataset tools" to "Dataset tools" on Oct 2, 2025
@jackvial (Contributor) commented Oct 4, 2025

Hey Michel, this looks great and will be very useful!

I like the public API, and I plan to update LeRobot Data Studio to use it, which will also bring datasets v3 support to that app and make it easier to maintain across dataset version upgrades.

I'm running into an error when running the delete command. I think this might be a dataset conversion problem where not all of the files are getting correctly copied locally. I encountered a pyav codec error when converting this dataset from v2.1 to v3, so maybe that is related; I opened a PR for that here: #2115

╰➤ python -m lerobot.scripts.lerobot_edit_dataset         --repo_id jackvial/screwdriver_panel_center_080225_16_e5         --operation.type delete_episodes         --operation.episode_indices "[0, 2]"
Generating train split: 244 examples [00:00, 100069.44 examples/s]
Generating train split: 244 examples [00:00, 113084.00 examples/s]
Generating train split: 244 examples [00:00, 113284.28 examples/s]
Generating train split: 244 examples [00:00, 119907.46 examples/s]
Generating train split: 244 examples [00:00, 122067.05 examples/s]
INFO 2025-10-04 15:11:11 _dataset.py:155 Deleting episodes [0, 2] from jackvial/screwdriver_panel_center_080225_16_e5
INFO 2025-10-04 15:11:11 set_tools.py:98 Deleting 2 episodes from dataset
INFO 2025-10-04 15:11:11 et_tools.py:540 Processing videos for observation.images.screwdriver
Processing observation.images.screwdriver video files:   0%|                    | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):                                                                    
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/jack/code/lerobot/src/lerobot/scripts/lerobot_edit_dataset.py", line 277, in <module>
    main()
  File "/home/jack/code/lerobot/src/lerobot/scripts/lerobot_edit_dataset.py", line 273, in main
    edit_dataset()
  File "/home/jack/code/lerobot/src/lerobot/configs/parser.py", line 225, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jack/code/lerobot/src/lerobot/scripts/lerobot_edit_dataset.py", line 260, in edit_dataset
    handle_delete_episodes(cfg)
  File "/home/jack/code/lerobot/src/lerobot/scripts/lerobot_edit_dataset.py", line 156, in handle_delete_episodes
    new_dataset = delete_episodes(
                  ^^^^^^^^^^^^^^^^
  File "/home/jack/code/lerobot/src/lerobot/datasets/dataset_tools.py", line 121, in delete_episodes
    video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jack/code/lerobot/src/lerobot/datasets/dataset_tools.py", line 611, in _copy_and_reindex_videos
    frames = decode_video_frames(
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/jack/code/lerobot/src/lerobot/datasets/video_utils.py", line 69, in decode_video_frames
    return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jack/code/lerobot/src/lerobot/datasets/video_utils.py", line 248, in decode_video_frames_torchcodec
    decoder = decoder_cache.get_decoder(str(video_path))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jack/code/lerobot/src/lerobot/datasets/video_utils.py", line 192, in get_decoder
    file_handle = fsspec.open(video_path).__enter__()
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jack/code/lerobot/venv/lib/python3.12/site-packages/fsspec/core.py", line 105, in __enter__
    f = self.fs.open(self.path, mode=mode)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jack/code/lerobot/venv/lib/python3.12/site-packages/fsspec/spec.py", line 1310, in open
    f = self._open(
        ^^^^^^^^^^^
  File "/home/jack/code/lerobot/venv/lib/python3.12/site-packages/fsspec/implementations/local.py", line 201, in _open
    return LocalFileOpener(path, mode, fs=self, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jack/code/lerobot/venv/lib/python3.12/site-packages/fsspec/implementations/local.py", line 365, in __init__
    self._open()
  File "/home/jack/code/lerobot/venv/lib/python3.12/site-packages/fsspec/implementations/local.py", line 370, in _open
    self.f = open(self.path, mode=self.mode)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
FileNotFoundError: [Errno 2] No such file or directory: '/home/jack/.cache/huggingface/lerobot/jackvial/screwdriver_panel_center_080225_16_e5/videos/observation.images.screwdriver/chunk-000/file-000.mp4'

I tried with a fresh dataset after fixing the dataset conversion bug and got the same error. The dataset looks good on the Hub, but only the metadata is present on the local disk:

╰➤ ls -lah /home/jack/.cache/huggingface/lerobot/jackvial/screwdriver_attach_panel_ls_080125_9_e8/
total 32K
drwxrwxr-x   3 jack jack 4.0K Oct  4 15:19 .
drwxrwxr-x 251 jack jack  20K Oct  4 15:19 ..
drwxrwxr-x   2 jack jack 4.0K Oct  4 15:19 meta
╰➤ python -m lerobot.scripts.lerobot_edit_dataset         --repo_id jackvial/screwdriver_attach_panel_ls_080125_9_e8         --operation.type delete_episodes         --operation.episode_indices "[0, 2]"
stats.json: 7.23kB [00:00, 11.0MB/s]                                            | 0/4 [00:00<?, ?it/s]
info.json: 4.04kB [00:00, 11.1MB/s]rquet:   0%|                           | 0.00/57.7k [00:00<?, ?B/s]
meta/tasks.parquet: 100%|████████████████████████████████████████| 2.92k/2.92k [00:00<00:00, 8.40kB/s]
meta/episodes/chunk-000/file-000.parquet: 100%|███████████████████| 57.7k/57.7k [00:00<00:00, 119kB/s]
Fetching 4 files: 100%|█████████████████████████████████████████████████| 4/4 [00:00<00:00,  5.83it/s]
Generating train split: 8 examples [00:00, 365.09 examples/s]
README.md: 4.53kB [00:00, 23.5MB/s]                                            | 0/10 [00:00<?, ?it/s]
.gitattributes: 2.46kB [00:00, 18.9MB/s]                                  | 0.00/60.3k [00:00<?, ?B/s]
data/chunk-000/file-000.parquet: 100%|████████████████████████████| 60.3k/60.3k [00:00<00:00, 424kB/s]
videos/observation.images.side/chunk-000(…): 100%|███████████████| 15.5M/15.5M [00:00<00:00, 35.7MB/s]
videos/observation.images.top/chunk-000/(…): 100%|███████████████| 22.6M/22.6M [00:00<00:00, 33.5MB/s]
videos/observation.images.screwdriver/ch(…): 100%|███████████████| 31.4M/31.4M [00:00<00:00, 44.6MB/s]
Fetching 10 files: 100%|██████████████████████████████████████████████| 10/10 [00:00<00:00, 11.12it/s]
Generating train split: 1822 examples [00:00, 524503.90 examples/s]31.4M/31.4M [00:00<00:00, 47.0MB/s]
INFO 2025-10-04 15:19:08 _dataset.py:155 Deleting episodes [0, 2] from jackvial/screwdriver_attach_panel_ls_080125_9_e8
INFO 2025-10-04 15:19:08 set_tools.py:98 Deleting 2 episodes from dataset
INFO 2025-10-04 15:19:08 et_tools.py:540 Processing videos for observation.images.screwdriver
Processing observation.images.screwdriver video files:   0%|                    | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):                                                                    
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/jack/code/lerobot/src/lerobot/scripts/lerobot_edit_dataset.py", line 277, in <module>
    main()
  File "/home/jack/code/lerobot/src/lerobot/scripts/lerobot_edit_dataset.py", line 273, in main
    edit_dataset()
  File "/home/jack/code/lerobot/src/lerobot/configs/parser.py", line 225, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jack/code/lerobot/src/lerobot/scripts/lerobot_edit_dataset.py", line 260, in edit_dataset
    handle_delete_episodes(cfg)
  File "/home/jack/code/lerobot/src/lerobot/scripts/lerobot_edit_dataset.py", line 156, in handle_delete_episodes
    new_dataset = delete_episodes(
                  ^^^^^^^^^^^^^^^^
  File "/home/jack/code/lerobot/src/lerobot/datasets/dataset_tools.py", line 121, in delete_episodes
    video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jack/code/lerobot/src/lerobot/datasets/dataset_tools.py", line 611, in _copy_and_reindex_videos
    frames = decode_video_frames(
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/jack/code/lerobot/src/lerobot/datasets/video_utils.py", line 69, in decode_video_frames
    return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jack/code/lerobot/src/lerobot/datasets/video_utils.py", line 248, in decode_video_frames_torchcodec
    decoder = decoder_cache.get_decoder(str(video_path))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jack/code/lerobot/src/lerobot/datasets/video_utils.py", line 192, in get_decoder
    file_handle = fsspec.open(video_path).__enter__()
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jack/code/lerobot/venv/lib/python3.12/site-packages/fsspec/core.py", line 105, in __enter__
    f = self.fs.open(self.path, mode=mode)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jack/code/lerobot/venv/lib/python3.12/site-packages/fsspec/spec.py", line 1310, in open
    f = self._open(
        ^^^^^^^^^^^
  File "/home/jack/code/lerobot/venv/lib/python3.12/site-packages/fsspec/implementations/local.py", line 201, in _open
    return LocalFileOpener(path, mode, fs=self, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jack/code/lerobot/venv/lib/python3.12/site-packages/fsspec/implementations/local.py", line 365, in __init__
    self._open()
  File "/home/jack/code/lerobot/venv/lib/python3.12/site-packages/fsspec/implementations/local.py", line 370, in _open
    self.f = open(self.path, mode=self.mode)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
FileNotFoundError: [Errno 2] No such file or directory: '/home/jack/.cache/huggingface/lerobot/jackvial/screwdriver_attach_panel_ls_080125_9_e8/videos/observation.images.screwdriver/chunk-000/file-000.mp4'

@michel-aractingi (Collaborator, Author) commented:

Thanks @jackvial! This was a bug in the dataset root; fixed in the latest commit 4fd895d.

Let me know if it works now!

@Keith-Luo commented Oct 5, 2025

lerobot-train --policy.path=lerobot/smolvla_base --dataset.repo_id=Keith-Luo/pick_wine_bottle_and_pour --batch_size=64 --steps=20000 --output_dir=outputs/train/pick_wine_bottle_and_pour_smovla --job_name=pick_wine_bottle_and_pour_smolvla --policy.device=cuda --wandb.enable=true --policy.repo_id=Keith-Luo/pick_wine_bottle_and_pour_smolvla

INFO 2025-10-05 20:11:24 ils/utils.py:48 Cuda backend detected, using cuda.
WARNING 2025-10-05 20:11:24 /policies.py:80 Device 'None' is not available. Switching to 'cuda'.
INFO 2025-10-05 20:11:24 ts/train.py:111 {'batch_size': 64,
 'dataset': {'episodes': None,
             'image_transforms': {'enable': False,
                                  'max_num_transforms': 3,
                                  'random_order': False,
                                  'tfs': {'brightness': {'kwargs': {'brightness': [0.8,
                                                                                   1.2]},
                                                         'type': 'ColorJitter',
                                                         'weight': 1.0},
                                          'contrast': {'kwargs': {'contrast': [0.8,
                                                                               1.2]},
                                                       'type': 'ColorJitter',
                                                       'weight': 1.0},
                                          'hue': {'kwargs': {'hue': [-0.05,
                                                                     0.05]},
                                                  'type': 'ColorJitter',
                                                  'weight': 1.0},
                                          'saturation': {'kwargs': {'saturation': [0.5,
                                                                                   1.5]},
                                                         'type': 'ColorJitter',
                                                         'weight': 1.0},
                                          'sharpness': {'kwargs': {'sharpness': [0.5,
                                                                                 1.5]},
                                                        'type': 'SharpnessJitter',
                                                        'weight': 1.0}}},
             'repo_id': 'Keith-Luo/pick_wine_bottle_and_pour',
             'revision': None,
             'root': None,
             'streaming': False,
             'use_imagenet_stats': True,
             'video_backend': 'torchcodec'},
 'env': None,
 'eval': {'batch_size': 50, 'n_episodes': 50, 'use_async_envs': False},
 'eval_freq': 20000,
 'job_name': 'pick_wine_bottle_and_pour_smolvla',
 'log_freq': 200,
 'num_workers': 4,
 'optimizer': {'betas': [0.9, 0.95],
               'eps': 1e-08,
               'grad_clip_norm': 10.0,
               'lr': 0.0001,
               'type': 'adamw',
               'weight_decay': 1e-10},
 'output_dir': 'outputs/train/pick_wine_bottle_and_pour_smovla',
 'policy': {'adapt_to_pi_aloha': False,
            'add_image_special_tokens': False,
            'attention_mode': 'cross_attn',
            'chunk_size': 50,
            'device': 'cuda',
            'empty_cameras': 0,
            'expert_width_multiplier': 0.75,
            'freeze_vision_encoder': True,
            'input_features': {'observation.image': {'shape': [3, 256, 256],
                                                     'type': <FeatureType.VISUAL: 'VISUAL'>},
                               'observation.image2': {'shape': [3, 256, 256],
                                                      'type': <FeatureType.VISUAL: 'VISUAL'>},
                               'observation.image3': {'shape': [3, 256, 256],
                                                      'type': <FeatureType.VISUAL: 'VISUAL'>},
                               'observation.state': {'shape': [6],
                                                     'type': <FeatureType.STATE: 'STATE'>}},
            'license': None,
            'load_vlm_weights': True,
            'max_action_dim': 32,
            'max_period': 4.0,
            'max_state_dim': 32,
            'min_period': 0.004,
            'n_action_steps': 50,
            'n_obs_steps': 1,
            'normalization_mapping': {'ACTION': <NormalizationMode.MEAN_STD: 'MEAN_STD'>,
                                      'STATE': <NormalizationMode.MEAN_STD: 'MEAN_STD'>,
                                      'VISUAL': <NormalizationMode.IDENTITY: 'IDENTITY'>},
            'num_expert_layers': 0,
            'num_steps': 10,
            'num_vlm_layers': 16,
            'optimizer_betas': [0.9, 0.95],
            'optimizer_eps': 1e-08,
            'optimizer_grad_clip_norm': 10.0,
            'optimizer_lr': 0.0001,
            'optimizer_weight_decay': 1e-10,
            'output_features': {'action': {'shape': [6],
                                           'type': <FeatureType.ACTION: 'ACTION'>}},
            'pad_language_to': 'max_length',
            'prefix_length': 0,
            'private': None,
            'push_to_hub': True,
            'repo_id': 'Keith-Luo/pick_wine_bottle_and_pour_smolvla',
            'resize_imgs_with_padding': [512, 512],
            'scheduler_decay_lr': 2.5e-06,
            'scheduler_decay_steps': 30000,
            'scheduler_warmup_steps': 1000,
            'self_attn_every_n_layers': 2,
            'tags': None,
            'tokenizer_max_length': 48,
            'train_expert_only': True,
            'train_state_proj': True,
            'type': 'smolvla',
            'use_amp': False,
            'use_cache': True,
            'use_delta_joint_actions_aloha': False,
            'vlm_model_name': 'HuggingFaceTB/SmolVLM2-500M-Video-Instruct'},
 'resume': False,
 'save_checkpoint': True,
 'save_freq': 20000,
 'scheduler': {'decay_lr': 2.5e-06,
               'num_decay_steps': 30000,
               'num_warmup_steps': 1000,
               'peak_lr': 0.0001,
               'type': 'cosine_decay_with_warmup'},
 'seed': 1000,
 'steps': 20000,
 'use_policy_training_preset': True,
 'wandb': {'disable_artifact': False,
           'enable': True,
           'entity': None,
           'mode': None,
           'notes': None,
           'project': 'lerobot',
           'run_id': None}}
Logs will be synced with wandb.
INFO 2025-10-05 20:11:26 db_utils.py:103 Track this run --> https://wandb.ai/xlerobot_hackathon/lerobot/runs/6dn56hgy
INFO 2025-10-05 20:11:26 ts/train.py:127 Creating dataset
INFO 2025-10-05 20:11:26 ts/train.py:138 Creating policy
Loading  HuggingFaceTB/SmolVLM2-500M-Video-Instruct weights ...
INFO 2025-10-05 20:11:30 odeling.py:1004 We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
Reducing the number of VLM layers to 16 ...
[standardise_state_dict] 'normalize_inputs.buffer_observation_state.mean'  ←  ['normalize_inputs.so100-red_buffer_observation_state.mean', 'normalize_inputs.so100_buffer_observation_state.mean']
[standardise_state_dict] 'normalize_inputs.buffer_observation_state.std'  ←  ['normalize_inputs.so100-red_buffer_observation_state.std', 'normalize_inputs.so100_buffer_observation_state.std']
[standardise_state_dict] 'normalize_targets.buffer_action.mean'  ←  ['normalize_targets.so100-red_buffer_action.mean', 'normalize_targets.so100_buffer_action.mean']
[standardise_state_dict] 'normalize_targets.buffer_action.std'  ←  ['normalize_targets.so100-red_buffer_action.std', 'normalize_targets.so100_buffer_action.std']
[standardise_state_dict] 'unnormalize_outputs.buffer_action.mean'  ←  ['unnormalize_outputs.so100-red_buffer_action.mean', 'unnormalize_outputs.so100_buffer_action.mean']
[standardise_state_dict] 'unnormalize_outputs.buffer_action.std'  ←  ['unnormalize_outputs.so100-red_buffer_action.std', 'unnormalize_outputs.so100_buffer_action.std']
INFO 2025-10-05 20:11:37 ts/train.py:144 Creating optimizer and scheduler
INFO 2025-10-05 20:11:37 ts/train.py:156 Output dir: outputs/train/pick_wine_bottle_and_pour_smovla
INFO 2025-10-05 20:11:37 ts/train.py:159 cfg.steps=20000 (20K)
INFO 2025-10-05 20:11:37 ts/train.py:160 dataset.num_frames=20156 (20K)
INFO 2025-10-05 20:11:37 ts/train.py:161 dataset.num_episodes=21
INFO 2025-10-05 20:11:37 ts/train.py:162 num_learnable_params=99880992 (100M)
INFO 2025-10-05 20:11:37 ts/train.py:163 num_total_params=450046278 (450M)
INFO 2025-10-05 20:11:37 ts/train.py:204 Start offline training on a fixed dataset
Traceback (most recent call last):
  File "/home/robo001/anaconda3/envs/lerobot/bin/lerobot-train", line 7, in <module>
    sys.exit(main())
  File "/home/robo001/bartender/src/lerobot/scripts/train.py", line 296, in main
    train()
  File "/home/robo001/bartender/src/lerobot/configs/parser.py", line 225, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/home/robo001/bartender/src/lerobot/scripts/train.py", line 207, in train
    batch = next(dl_iter)
  File "/home/robo001/bartender/src/lerobot/datasets/utils.py", line 598, in cycle
    yield next(iterator)
  File "/home/robo001/anaconda3/envs/lerobot/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 733, in __next__
    data = self._next_data()
  File "/home/robo001/anaconda3/envs/lerobot/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1515, in _next_data
    return self._process_data(data, worker_id)
  File "/home/robo001/anaconda3/envs/lerobot/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1550, in _process_data
    data.reraise()
  File "/home/robo001/anaconda3/envs/lerobot/lib/python3.10/site-packages/torch/_utils.py", line 750, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/robo001/anaconda3/envs/lerobot/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/home/robo001/anaconda3/envs/lerobot/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/robo001/anaconda3/envs/lerobot/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/robo001/bartender/src/lerobot/datasets/lerobot_dataset.py", line 874, in __getitem__
    video_frames = self._query_videos(query_timestamps, ep_idx)
  File "/home/robo001/bartender/src/lerobot/datasets/lerobot_dataset.py", line 846, in _query_videos
    frames = decode_video_frames(video_path, shifted_query_ts, self.tolerance_s, self.video_backend)
  File "/home/robo001/bartender/src/lerobot/datasets/video_utils.py", line 69, in decode_video_frames
    return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
  File "/home/robo001/bartender/src/lerobot/datasets/video_utils.py", line 259, in decode_video_frames_torchcodec
    frames_batch = decoder.get_frames_at(indices=frame_indices)
  File "/home/robo001/anaconda3/envs/lerobot/lib/python3.10/site-packages/torchcodec/decoders/_video_decoder.py", line 227, in get_frames_at
    data, pts_seconds, duration_seconds = core.get_frames_at_indices(
  File "/home/robo001/anaconda3/envs/lerobot/lib/python3.10/site-packages/torch/_ops.py", line 756, in __call__
    return self._op(*args, **kwargs)
RuntimeError: Invalid frame index=27385 for streamIndex=0; must be less than 20156

Are there still some bugs in the merging script?

The delete script works fine for training.

@michel-aractingi (Collaborator, Author) commented:

Hey @Keith-Luo, I just tried to train with a merged dataset and couldn't reproduce your error. Are you sure this is not a bug in your dataset?

@Keith-Luo replied:

> Hey @Keith-Luo, I just tried to train with a merged dataset and couldn't reproduce your error. Are you sure this is not a bug in your dataset?

Maybe I can upload my dataset so you can try it again? Hold on a second.

@jackvial (Contributor) commented Oct 8, 2025

> Hello @jackvial, can you check again whether the latest changes fix your push_to_hub issue?

Hey @michel-aractingi, yes, I can test again this evening.

@jackvial (Contributor) commented Oct 8, 2025

@michel-aractingi delete episodes and push to hub are looking good now. Here's what I tested:

python -m lerobot.scripts.lerobot_edit_dataset \
        --repo_id jackvial/screwdriver_attach_panel_rs_080125_20_e5 \
        --new_repo_id jackvial/screwdriver_attach_panel_rs_080125_20_e5_edited_4 \
        --operation.type delete_episodes \
        --operation.episode_indices "[0, 2]" \
        --push_to_hub=true

Successfully created new dataset jackvial/screwdriver_attach_panel_rs_080125_20_e5_edited_4

python -m lerobot.scripts.lerobot_edit_dataset \
        --repo_id jackvial/screwdriver_attach_panel_rs_080125_20_e5 \
        --new_repo_id jackvial/screwdriver_attach_panel_rs_080125_20_e5_edited_5 \
        --operation.type delete_episodes \
        --operation.episode_indices "[0, 2, 4]" \
        --push_to_hub=true

Successfully created new dataset jackvial/screwdriver_attach_panel_rs_080125_20_e5_edited_5

@jackvial (Contributor) commented Oct 8, 2025

@michel-aractingi split looks good. Worth emphasizing in the docs/comments that the split names can be anything you like; the examples of train, val, and test might make users think those are special names that must be used.

Split By Percentage

python -m lerobot.scripts.lerobot_edit_dataset \
        --repo_id jackvial/screwdriver_attach_panel_rs_080125_20_e5 \
        --operation.type split \
        --operation.splits '{"train": 0.8, "val": 0.2}' --push_to_hub=true

Split By Episode Selection

python -m lerobot.scripts.lerobot_edit_dataset \
        --repo_id jackvial/screwdriver_attach_panel_rs_080125_20_e5 \
        --operation.type split \
        --operation.splits '{"some_split": [0, 1, 2, 3], "some_other_split": [4]}' \
        --push_to_hub=true

@jackvial (Contributor) commented Oct 8, 2025

@michel-aractingi merge looks good, but maybe rename --repo_id to --new_repo_id and check that it does not point to an existing dataset; as a user, I wouldn't want or expect merge to overwrite an existing dataset, only to create new ones.

python -m lerobot.scripts.lerobot_edit_dataset \
        --repo_id jackvial/screwdriver_attach_panel_rs_080125_20_e5_merged_2 \
        --operation.type merge \
        --operation.repo_ids "['jackvial/screwdriver_attach_panel_rs_080125_20_e5_some_split', 'jackvial/screwdriver_attach_panel_rs_080125_20_e5_some_other_split']" --push_to_hub=true

@jackvial (Contributor) commented Oct 8, 2025

@michel-aractingi remove_feature looks good:

python -m lerobot.scripts.lerobot_edit_dataset \
        --repo_id jackvial/screwdriver_attach_panel_rs_080125_20_e5_split_to_remove_feature_from_0 \
        --operation.type remove_feature \
        --operation.feature_names "['observation.images.top']" --push_to_hub=true

Some considerations for a future version of these dataset tools:

  • rename_feature: being able to rename a feature key, e.g. when merging datasets where the camera names are different and we want to map them to a common name.
  • Support renaming camera feature keys with the merge command, e.g. merge --rename-features {"observation.images.phone": "observation.images.top", "observation.images.android": "observation.images.top"} (see the hypothetical sketch below).
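
A hypothetical invocation for such a flag might look like this (note: --operation.rename_features does not exist in this PR; the option name and repo ids are purely illustrative):

python -m lerobot.scripts.lerobot_edit_dataset \
        --repo_id lerobot/pusht_merged \
        --operation.type merge \
        --operation.repo_ids "['lerobot/pusht_phone', 'lerobot/pusht_android']" \
        --operation.rename_features '{"observation.images.phone": "observation.images.top", "observation.images.android": "observation.images.top"}'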

1. **Delete Episodes** - Remove specific episodes from a dataset
2. **Split Dataset** - Divide a dataset into multiple smaller datasets
3. **Merge Datasets** - Combine multiple datasets into one
4. **Add Features** - Add new features to a dataset
A contributor commented on the operations list above:

Would be good to add an example for add feature

@michel-aractingi (Collaborator, Author) replied:

Agreed. I think this and feature renaming would make nice small PRs for the community, to get them engaged in the dataset tooling.
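
For illustration, a hypothetical sketch of what such an add-feature example could cover (the function name and keyword arguments below are assumptions, not the actual dataset_tools API as merged):

import numpy as np

from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.dataset_tools import add_feature  # hypothetical name

dataset = LeRobotDataset("lerobot/pusht")

# One value per frame for the new feature, e.g. a scalar reward.
values = np.zeros((dataset.num_frames, 1), dtype=np.float32)

# Hypothetical call: the keyword names are illustrative assumptions.
new_dataset = add_feature(
    dataset,
    feature_name="next.reward",
    feature_values=values,
    feature_info={"dtype": "float32", "shape": (1,), "names": None},
)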

CarolinePascal and others added 4 commits October 9, 2025 23:18
…ata info when splitting and aggregating datasets
… merging

There were three critical bugs in aggregate.py that prevented correct dataset merging:

1. Video file indices: Changed from += to = assignment to correctly reference
   merged video files

2. Video timestamps: Implemented per-source-file offset tracking to maintain
   continuous timestamps when merging split datasets (was causing non-monotonic
   timestamp warnings)

3. File rotation offsets: Store timestamp offsets after rotation decision to
   prevent out-of-bounds frame access (was causing "Invalid frame index" errors
   with small file size limits)

Changes:
- Updated update_meta_data() to apply per-source-file timestamp offsets
- Updated aggregate_videos() to track offsets correctly during file rotation
- Added get_video_duration_in_s import for duration calculation
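
A simplified illustration of the offset bookkeeping described above (a sketch of the idea only, not the actual aggregate.py code; the paths and helper stub are placeholders):

# Each source video appended into a merged video file shifts the
# timestamps of later sources by the total duration written so far.

def get_video_duration_in_s(path: str) -> float:
    # Stand-in for the helper imported in aggregate.py (see commit note).
    return 10.0

source_video_files = ["chunk-000/file-000.mp4", "chunk-000/file-001.mp4"]

offsets: dict[str, float] = {}
cumulative_s = 0.0
for src_path in source_video_files:
    # Record the offset BEFORE adding this file's duration: a frame at
    # time t in src_path lands at t + offsets[src_path] in the merged
    # file. Storing the offset after the file-rotation decision keeps
    # frame indices within bounds when a new file is started.
    offsets[src_path] = cumulative_s
    cumulative_s += get_video_duration_in_s(src_path)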
…se that the split size results in zero episodes
@imstevenpmwork self-requested a review October 10, 2025 09:21
@imstevenpmwork self-requested a review October 10, 2025 10:05
@Keith-Luo commented Oct 15, 2025

Hi @michel-aractingi, does it support merging two copies of the same dataset into one? I ran into some problems with this use case.


Labels

dataset (Issues regarding data inputs, processing, or datasets); enhancement (Suggestions for new features or improvements)

7 participants