
Conversation

@pkooij pkooij commented Oct 8, 2025

This PR introduces a way to generate image and text embeddings and store them in a dataset, which makes training on that dataset for multiple epochs more efficient. For example, for learning a general reward we combine a specific dataset with OXE to improve generalization. To avoid recomputing the image and text embeddings each time we finetune on OXE, we can use this script to add the embeddings to the dataset. We can additionally remove the videos from the dataset to save space.

Testing:
Both the generate and validate scripts were tested on this dataset: lerobot/utokyo_xarm_bimanual. The generated dataset can be found here: pepijn223/utokyo_xarm_bimanual_embeddings.
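
For context, a minimal sketch of how the enriched dataset could be consumed downstream, assuming the task_embedding and per-camera {camera_key}_embedding columns added by the script; whether these surface directly through __getitem__ depends on the dataset's feature metadata, so treat this as an assumption rather than part of the PR:

    from lerobot.datasets.lerobot_dataset import LeRobotDataset

    dataset = LeRobotDataset("pepijn223/utokyo_xarm_bimanual_embeddings")
    item = dataset[0]
    # Embeddings are precomputed per frame, so no DINOv2 / text-encoder forward pass is needed at training time.
    task_embedding = item["task_embedding"]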

@pkooij pkooij self-assigned this Oct 8, 2025
@pkooij pkooij added the dataset, performance, and policies labels Oct 8, 2025
@pkooij pkooij marked this pull request as ready for review October 8, 2025 09:39
Comment on lines +201 to +228
    if isinstance(img, torch.Tensor):
        if img.ndim == 4:
            # Shape: (T, C, H, W) where T=1 for single timestamp
            img = img[0]  # Now (C, H, W)
        elif img.ndim == 3:
            # Shape: (C, H, W)
            pass
        else:
            raise ValueError(
                f"Unexpected video frame shape {img.shape} for camera {cam_key}. "
                f"Expected (T, C, H, W) or (C, H, W). Episode {ep_idx}, Frame {frame_idx}"
            )

        # Convert to numpy: (C, H, W) float32 [0, 1] -> (H, W, C) uint8 [0, 255]
        img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    else:
        img_np = np.array(img)
else:
    # Load from image file
    img = item[cam_key]
    # Convert to numpy if needed
    if isinstance(img, torch.Tensor):
        if img.ndim == 3:
            img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        else:
            raise ValueError(f"Unexpected image shape {img.shape} for camera {cam_key}")
    else:
        img_np = np.array(img)
Collaborator

Should we add testing for this part? This logic could be buggy if we have image datasets instead of videos. It would be better if we had tests to verify it.
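
A minimal sketch of such a test, assuming the conversion logic is factored into a helper (frame_to_numpy is a hypothetical name; the PR currently inlines this logic):

    import numpy as np
    import pytest
    import torch

    # Hypothetical helper wrapping the conversion logic above.
    def frame_to_numpy(img):
        if isinstance(img, torch.Tensor):
            if img.ndim == 4:  # (T, C, H, W) with T=1
                img = img[0]
            if img.ndim != 3:
                raise ValueError(f"Unexpected shape {tuple(img.shape)}")
            return (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        return np.array(img)

    @pytest.mark.parametrize("shape", [(1, 3, 64, 48), (3, 64, 48)])
    def test_tensor_frame_becomes_hwc_uint8(shape):
        out = frame_to_numpy(torch.rand(*shape))
        assert out.shape == (64, 48, 3)
        assert out.dtype == np.uint8

    def test_unexpected_shape_raises():
        with pytest.raises(ValueError):
            frame_to_numpy(torch.rand(2, 2, 3, 4, 5))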

self.batch_size = batch_size
self.model_name = model_name
logger.info(f"Loading DinoV2 model: {model_name}")
self.model = torch.hub.load("facebookresearch/dinov2", model_name) # nosec B614
Collaborator

Is there a reason why you didn't use AutoModel from transformers here also? We do it for the SAC encoder.

self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)

Something like:

        self.model = AutoModel.from_pretrained("facebook/dinov2-base")
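
A fuller sketch of what the transformers-based loading could look like, assuming the facebook/dinov2-base checkpoint on the Hub and using the CLS token of last_hidden_state as the image embedding (the pooling choice is an assumption, not something this PR specifies):

    import torch
    from transformers import AutoImageProcessor, AutoModel

    processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
    model = AutoModel.from_pretrained("facebook/dinov2-base").eval()

    @torch.no_grad()
    def embed_images(images):  # list of (H, W, C) uint8 numpy arrays
        inputs = processor(images=images, return_tensors="pt")
        outputs = model(**inputs)
        # CLS token as a global per-image embedding; other poolings (e.g. mean over patch tokens) also work.
        return outputs.last_hidden_state[:, 0]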

from lerobot.datasets.lerobot_dataset import LeRobotDataset


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
Collaborator

nice
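
For reference, a minimal body for this helper could look like the following (the PR's exact implementation isn't shown in this excerpt, so treat this as a sketch):

    import numpy as np

    def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
        # Small epsilon guards against division by zero for all-zero embeddings.
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))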

Comment on lines +275 to +292
for (chunk_idx, file_idx), frame_indices in tqdm(frames_by_file.items(), desc="Updating parquet files"):
    parquet_path = output_dataset.root / DEFAULT_DATA_PATH.format(
        chunk_index=chunk_idx, file_index=file_idx
    )

    # Load the parquet file
    df = pd.read_parquet(parquet_path)

    # Add embedding columns
    df["task_embedding"] = [all_task_embeddings[idx].tolist() for idx in frame_indices]

    for cam_key in dataset.meta.camera_keys:
        df[f"{cam_key}_embedding"] = [
            image_embeddings_dict[cam_key][idx].tolist() for idx in frame_indices
        ]

    # Save the updated parquet file
    df.to_parquet(parquet_path, index=False)
Collaborator

For this chunk I suggest doing it in a more efficient way, e.g. by using dataset_tools. I'll push a PR.

Reading and writing the parquet files to disk will be slow and can result in a memory explosion for huge datasets.
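
In the meantime, a sketch of a lower-overhead variant that stays in Arrow instead of round-tripping through pandas (this is only an illustration, assuming pyarrow; the dataset_tools approach mentioned above may look different):

    import pyarrow as pa
    import pyarrow.parquet as pq

    for (chunk_idx, file_idx), frame_indices in frames_by_file.items():
        parquet_path = output_dataset.root / DEFAULT_DATA_PATH.format(
            chunk_index=chunk_idx, file_index=file_idx
        )
        table = pq.read_table(parquet_path)

        # Append embedding columns as Arrow list arrays, avoiding DataFrame copies.
        task_col = pa.array([all_task_embeddings[idx].tolist() for idx in frame_indices])
        table = table.append_column("task_embedding", task_col)
        for cam_key in dataset.meta.camera_keys:
            cam_col = pa.array([image_embeddings_dict[cam_key][idx].tolist() for idx in frame_indices])
            table = table.append_column(f"{cam_key}_embedding", cam_col)

        pq.write_table(table, parquet_path)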
