diff --git a/src/lerobot/datasets/generating_embeddings/README.md b/src/lerobot/datasets/generating_embeddings/README.md new file mode 100644 index 0000000000..a31290b044 --- /dev/null +++ b/src/lerobot/datasets/generating_embeddings/README.md @@ -0,0 +1,146 @@ +# LeRobot Embedding Generation Script + +Generate embeddings for LeRobot datasets to make them more lightweight and efficient for training. + +## Overview + +This script processes v3.0 LeRobot datasets and adds pre-computed embeddings for: + +- **Task embeddings**: Language command embeddings using MiniLM +- **Image embeddings**: Frame embeddings using DinoV2 + +The resulting dataset can be used more efficiently during training by loading pre-computed embeddings instead of running encoders on-the-fly. + +## Supported Encoders + +### Image Encoders (DinoV2) + +DinoV2 is a self-supervised vision transformer that produces high-quality image embeddings: + +- **`dinov2_vits14`**: ViT-S/14 (384-dim) - Fastest, smaller model +- **`dinov2_vitb14`**: ViT-B/14 (768-dim) - **Recommended** - Good balance +- **`dinov2_vitl14`**: ViT-L/14 (1024-dim) - Best quality, slower + +### Language Encoders (MiniLM) + +MiniLM is a lightweight sentence transformer model: + +- **`minilm-l6`**: MiniLM-L6-v2 (384-dim) - Faster +- **`minilm-l12`**: MiniLM-L12-v2 (384-dim) - **Recommended** - Better quality + +## Usage + +### Basic Command + +```bash +python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \ + --repo-id lerobot/utokyo_xarm_bimanual \ + --output-repo-id your-username/utokyo_xarm_bimanual_embeddings \ + --image-encoder dinov2_vitb14 \ + --language-encoder minilm-l12 \ + --push-to-hub +``` + +### Lightweight Version (No Videos) + +Removes video files to significantly reduce storage: + +```bash +python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \ + --repo-id lerobot/utokyo_xarm_bimanual \ + --output-repo-id your-username/utokyo_xarm_bimanual_lightweight \ + --image-encoder dinov2_vitb14 \ + --language-encoder minilm-l12 \ + --remove-videos \ + --push-to-hub +``` + +## Output + +The script adds new features to your dataset: + +### New Features + +1. **`task_embedding`**: Language embedding for each frame + - Shape: `[384]` (MiniLM) + - One embedding per frame based on its task + +2. **`{camera_key}_embedding`**: Image embedding for each camera view + - Shape: `[384]`, `[768]`, or `[1024]` depending on DinoV2 model + - Examples: `observation.images.top_embedding`, `observation.images.wrist_embedding` + +### Using Embeddings in Training + +```python +from lerobot.datasets.lerobot_dataset import LeRobotDataset + +# Load dataset with embeddings +dataset = LeRobotDataset("your-username/utokyo_xarm_bimanual_embeddings") + +# Access embeddings +item = dataset[0] +task_emb = item["task_embedding"] # Shape: [384] +img_emb = item["observation.images.top_embedding"] # Shape: [768] + +# Use in your policy +# Instead of running encoders during training, use pre-computed embeddings +``` + +## Extending with New Encoders + +The script is designed to be easily extensible. To add a new encoder: + +### 1. 
Create Encoder Class + +```python +class MyCustomImageEncoder(ImageEncoder): + """Your custom image encoder.""" + + def __init__(self, device: str = "cuda"): + super().__init__(device) + # Load your model + self.model = load_my_model() + self.model = self.model.to(self.device) + self.model.eval() + + def encode(self, images: list[np.ndarray]) -> np.ndarray: + """Encode a batch of images.""" + # Your encoding logic here + embeddings = [] + for img in images: + emb = self.model(img) + embeddings.append(emb) + return np.array(embeddings) + + @property + def embedding_dim(self) -> int: + """Return embedding dimension.""" + return 512 # Your embedding dimension +``` + +### 2. Add to Factory Function + +```python +def get_image_encoder(encoder_name: str, device: str = "cuda") -> ImageEncoder: + encoders = { + "dinov2_vits14": lambda: DinoV2Encoder(model_name="dinov2_vits14", device=device), + "dinov2_vitb14": lambda: DinoV2Encoder(model_name="dinov2_vitb14", device=device), + "dinov2_vitl14": lambda: DinoV2Encoder(model_name="dinov2_vitl14", device=device), + # Add your encoder + "my_custom": lambda: MyCustomImageEncoder(device=device), + } + # ... rest of function +``` + +## Validating Embeddings + +After generating embeddings, you can validate them using `validate_embeddings.py`: + +```bash +python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \ + --original-repo-id lerobot/utokyo_xarm_bimanual \ + --embeddings-repo-id pepijn223/utokyo_xarm_bimanual_embeddings \ + --image-encoder dinov2_vitb14 \ + --language-encoder minilm-l12 \ + --num-samples 20 +``` diff --git a/src/lerobot/datasets/generating_embeddings/encoders.py b/src/lerobot/datasets/generating_embeddings/encoders.py new file mode 100644 index 0000000000..79e798bdda --- /dev/null +++ b/src/lerobot/datasets/generating_embeddings/encoders.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import numpy as np +import torch +from PIL import Image + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class ImageEncoder: + """Base class for image encoders.""" + + def __init__(self, device: str = "cuda"): + self.device = torch.device(device if torch.cuda.is_available() else "cpu") + + def encode(self, images: list[np.ndarray]) -> np.ndarray: + """Encode a batch of images.""" + raise NotImplementedError + + +class DinoV2Encoder(ImageEncoder): + """DinoV2 image encoder. + + DinoV2 is a self-supervised vision transformer that produces high-quality image embeddings. + Supports multiple model sizes (ViT-S/14, ViT-B/14, ViT-L/14). 
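+ + A minimal usage sketch (illustrative; assumes a CUDA device, falling back to CPU as handled by ImageEncoder): + + encoder = DinoV2Encoder(model_name="dinov2_vitb14") + frames = [np.zeros((224, 224, 3), dtype=np.uint8)] # list of HWC uint8 images + embeddings = encoder.encode(frames) # np.ndarray of shape (1, 768)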
+ """ + + def __init__(self, model_name: str = "dinov2_vitb14", device: str = "cuda", batch_size: int = 32): + super().__init__(device) + self.batch_size = batch_size + self.model_name = model_name + logger.info(f"Loading DinoV2 model: {model_name}") + self.model = torch.hub.load("facebookresearch/dinov2", model_name) # nosec B614 + self.model = self.model.to(self.device) + self.model.eval() + + # DinoV2 preprocessing + from torchvision import transforms + + self.transform = transforms.Compose( + [ + transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + + def encode(self, images: list[np.ndarray]) -> np.ndarray: + """Encode a batch of images.""" + embeddings = [] + + with torch.inference_mode(): + for i in range(0, len(images), self.batch_size): + batch_images = images[i : i + self.batch_size] + # Convert numpy arrays to PIL Images and apply transforms + pil_images = [Image.fromarray(img.astype(np.uint8)) for img in batch_images] + tensors = torch.stack([self.transform(img) for img in pil_images]).to(self.device) + + # Get embeddings + batch_embeddings = self.model(tensors).cpu().numpy() + embeddings.append(batch_embeddings) + + return np.concatenate(embeddings, axis=0) + + @property + def embedding_dim(self) -> int: + """Return the embedding dimension based on model size.""" + if "vits14" in self.model_name: + return 384 # DinoV2 ViT-S/14 + elif "vitb14" in self.model_name: + return 768 # DinoV2 ViT-B/14 + elif "vitl14" in self.model_name: + return 1024 # DinoV2 ViT-L/14 + else: + return 768 # Default to ViT-B/14 + + +class LanguageEncoder: + """Base class for language encoders.""" + + def __init__(self, device: str = "cuda"): + self.device = torch.device(device if torch.cuda.is_available() else "cpu") + + def encode(self, texts: list[str]) -> np.ndarray: + """Encode a batch of texts.""" + raise NotImplementedError + + +class MiniLMEncoder(LanguageEncoder): + """MiniLM language encoder. + + MiniLM is a lightweight sentence transformer model that produces high-quality text embeddings. + Supports L6 and L12 model sizes. 
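+ + A minimal usage sketch (illustrative): + + encoder = MiniLMEncoder(model_name="sentence-transformers/all-MiniLM-L12-v2") + embeddings = encoder.encode(["Pick up the cube."]) # np.ndarray of shape (1, 384)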
+ """ + + def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L12-v2", device: str = "cuda"): + super().__init__(device) + self.model_name = model_name + logger.info(f"Loading MiniLM model: {model_name}") + + from transformers import AutoModel, AutoTokenizer + + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = AutoModel.from_pretrained(model_name).to(self.device) + self.model.eval() + + def _mean_pooling(self, model_output, attention_mask): + """Mean pooling to get sentence embeddings.""" + token_embeddings = model_output[0] + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( + input_mask_expanded.sum(1), min=1e-9 + ) + + def encode(self, texts: list[str]) -> np.ndarray: + """Encode a batch of texts.""" + with torch.inference_mode(): + encoded_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt") + encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()} + + model_output = self.model(**encoded_input) + embeddings = self._mean_pooling(model_output, encoded_input["attention_mask"]) + + return embeddings.cpu().numpy() + + @property + def embedding_dim(self) -> int: + """Return the embedding dimension.""" + return 384 # Both MiniLM-L6 and L12 output 384-dim embeddings diff --git a/src/lerobot/datasets/generating_embeddings/generate_embeddings.py b/src/lerobot/datasets/generating_embeddings/generate_embeddings.py new file mode 100644 index 0000000000..82d47299de --- /dev/null +++ b/src/lerobot/datasets/generating_embeddings/generate_embeddings.py @@ -0,0 +1,409 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Generate embeddings for LeRobot datasets to make them more lightweight and efficient. + +This script: +1. Loads a v3.0 LeRobot dataset from the hub +2. Computes embeddings for tasks (language commands) and frames (images) +3. Stores embeddings as new features in the dataset +4. Optionally removes video files to reduce size +5. Pushes the converted dataset to the hub + +Current supported encoders: +- Image: DinoV2 (dinov2_vits14, dinov2_vitb14, dinov2_vitl14) +- Language: MiniLM (minilm-l6, minilm-l12) + +The architecture is extensible - you can add more encoders by: +1. Creating a new encoder class inheriting from ImageEncoder or LanguageEncoder +2. Implementing the encode() method and embedding_dim property +3. 
Adding it to the get_image_encoder() or get_language_encoder() factory function + +Usage example: + python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \ + --repo-id lerobot/utokyo_xarm_bimanual \ + --output-repo-id lerobot/utokyo_xarm_bimanual_embeddings \ + --image-encoder dinov2_vitb14 \ + --language-encoder minilm-l12 \ + --remove-videos \ + --push-to-hub +""" + +import argparse +import shutil +from pathlib import Path + +import numpy as np +import torch +from tqdm import tqdm + +from lerobot.datasets.generating_embeddings.encoders import ( + DinoV2Encoder, + ImageEncoder, + LanguageEncoder, + MiniLMEncoder, +) +from lerobot.datasets.lerobot_dataset import LeRobotDataset + + +def get_image_encoder(encoder_name: str, device: str = "cuda") -> ImageEncoder: + """Factory function to get image encoder. + + To add a new encoder: + 1. Create a new class inheriting from ImageEncoder + 2. Implement encode() and embedding_dim property + 3. Add it to the encoders dictionary below + """ + encoders = { + "dinov2_vits14": lambda: DinoV2Encoder(model_name="dinov2_vits14", device=device), + "dinov2_vitb14": lambda: DinoV2Encoder(model_name="dinov2_vitb14", device=device), + "dinov2_vitl14": lambda: DinoV2Encoder(model_name="dinov2_vitl14", device=device), + } + + if encoder_name not in encoders: + raise ValueError(f"Unknown image encoder: {encoder_name}. Available options: {list(encoders.keys())}") + + return encoders[encoder_name]() + + +def get_language_encoder(encoder_name: str, device: str = "cuda") -> LanguageEncoder: + """Factory function to get language encoder. + + To add a new encoder: + 1. Create a new class inheriting from LanguageEncoder + 2. Implement encode() and embedding_dim property + 3. Add it to the encoders dictionary below + """ + encoders = { + "minilm-l6": lambda: MiniLMEncoder( + model_name="sentence-transformers/all-MiniLM-L6-v2", device=device + ), + "minilm-l12": lambda: MiniLMEncoder( + model_name="sentence-transformers/all-MiniLM-L12-v2", device=device + ), + } + + if encoder_name not in encoders: + raise ValueError( + f"Unknown language encoder: {encoder_name}. Available options: {list(encoders.keys())}" + ) + + return encoders[encoder_name]() + + +def generate_embeddings_for_dataset( + repo_id: str, + output_repo_id: str, + image_encoder: ImageEncoder, + language_encoder: LanguageEncoder, + remove_videos: bool = False, + local_dir: Path | None = None, + output_local_dir: Path | None = None, + push_to_hub: bool = False, +): + """Generate embeddings for a LeRobot dataset. 
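+ + A minimal programmatic sketch (illustrative; the encoders come from the factory functions above, and the output repo name is a placeholder): + + generate_embeddings_for_dataset( + repo_id="lerobot/utokyo_xarm_bimanual", + output_repo_id="your-username/utokyo_xarm_bimanual_embeddings", + image_encoder=get_image_encoder("dinov2_vitb14"), + language_encoder=get_language_encoder("minilm-l12"), + )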
+ + Args: + repo_id: Source dataset repository ID + output_repo_id: Output dataset repository ID + image_encoder: Image encoder instance + language_encoder: Language encoder instance + remove_videos: Whether to remove video files + local_dir: Local directory for source dataset + output_local_dir: Local directory for output dataset + push_to_hub: Whether to push to hub after conversion + """ + print(f"Loading dataset: {repo_id}") + + dataset = LeRobotDataset(repo_id, root=local_dir, download_videos=not remove_videos) + + print(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames") + + # Create output directory + if output_local_dir is None: + from lerobot.utils.constants import HF_LEROBOT_HOME + + output_local_dir = HF_LEROBOT_HOME / output_repo_id + else: + output_local_dir = Path(output_local_dir) + + # Copy the dataset to the output location + print(f"Copying to {output_local_dir}") + if output_local_dir.exists(): + shutil.rmtree(output_local_dir) + + shutil.copytree(dataset.root, output_local_dir) + + output_dataset = LeRobotDataset(output_repo_id, root=output_local_dir) + + # Define new features for embeddings + img_emb_dim = image_encoder.embedding_dim + lang_emb_dim = language_encoder.embedding_dim + + # Get unique tasks and compute their embeddings + print("Computing task embeddings...") + unique_tasks = dataset.meta.tasks.index.tolist() + task_embeddings = {} + + for task in tqdm(unique_tasks, desc="Encoding tasks"): + # Clean up task text + task_clean = task.strip().capitalize().strip(" .,!?-_") + embedding = language_encoder.encode([task_clean])[0] + task_embeddings[task] = embedding + + print(f"Computed {len(task_embeddings)} task embeddings") + + # Process each episode + print("Processing episodes...") + + # Track task embeddings per frame + all_task_embeddings = [] + all_image_embeddings_dict = {cam_key: [] for cam_key in dataset.meta.camera_keys} + + for ep_idx in tqdm(range(dataset.num_episodes), desc="Processing episodes"): + ep = dataset.meta.episodes[ep_idx] + ep_start = ep["dataset_from_index"] + ep_end = ep["dataset_to_index"] + + # Get all frames for this episode + for frame_idx in range(ep_start, ep_end): + item = dataset.hf_dataset[frame_idx] + + # Get task embedding for this frame + task = dataset.meta.tasks.iloc[item["task_index"].item()].name + task_emb = task_embeddings[task] + all_task_embeddings.append(task_emb) + + for cam_key in dataset.meta.camera_keys: + if cam_key in dataset.meta.video_keys: + # Decode from video + current_ts = item["timestamp"].item() + video_frames = dataset._query_videos({cam_key: [current_ts]}, ep_idx) + img = video_frames[ + cam_key + ] # This returns tensor of shape (T, C, H, W) or might be squeezed + + # Handle the tensor shape + if isinstance(img, torch.Tensor): + if img.ndim == 4: + # Shape: (T, C, H, W) where T=1 for single timestamp + img = img[0] # Now (C, H, W) + elif img.ndim == 3: + # Shape: (C, H, W) + pass + else: + raise ValueError( + f"Unexpected video frame shape {img.shape} for camera {cam_key}. " + f"Expected (T, C, H, W) or (C, H, W). 
Episode {ep_idx}, Frame {frame_idx}" + ) + + # Convert to numpy: (C, H, W) float32 [0, 1] -> (H, W, C) uint8 [0, 255] + img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8) + else: + img_np = np.array(img) + else: + # Load from image file + img = item[cam_key] + # Convert to numpy if needed + if isinstance(img, torch.Tensor): + if img.ndim == 3: + img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8) + else: + raise ValueError(f"Unexpected image shape {img.shape} for camera {cam_key}") + else: + img_np = np.array(img) + + all_image_embeddings_dict[cam_key].append(img_np) + + print("Computing image embeddings...") + image_embeddings_dict = {} + for cam_key, images in all_image_embeddings_dict.items(): + print(f" {cam_key}: {len(images)} images") + embeddings = image_encoder.encode(images) + image_embeddings_dict[cam_key] = embeddings + + print("Adding embeddings to dataset...") + + # Update features in info.json + output_dataset.meta.info["features"]["task_embedding"] = { + "dtype": "float32", + "shape": [lang_emb_dim], + "names": None, + } + + for cam_key in dataset.meta.camera_keys: + output_dataset.meta.info["features"][f"{cam_key}_embedding"] = { + "dtype": "float32", + "shape": [img_emb_dim], + "names": None, + } + + import pandas as pd + + from lerobot.datasets.utils import DEFAULT_DATA_PATH, write_info + + write_info(output_dataset.meta.info, output_dataset.root) + + # Group frames by their parquet file + frames_by_file = {} + for frame_idx in range(output_dataset.num_frames): + item = output_dataset.hf_dataset[frame_idx] + ep_idx = item["episode_index"].item() + ep = output_dataset.meta.episodes[ep_idx] + chunk_idx = ep["data/chunk_index"] + file_idx = ep["data/file_index"] + key = (chunk_idx, file_idx) + if key not in frames_by_file: + frames_by_file[key] = [] + frames_by_file[key].append(frame_idx) + + # Update each parquet file + for (chunk_idx, file_idx), frame_indices in tqdm(frames_by_file.items(), desc="Updating parquet files"): + parquet_path = output_dataset.root / DEFAULT_DATA_PATH.format( + chunk_index=chunk_idx, file_index=file_idx + ) + + # Load the parquet file + df = pd.read_parquet(parquet_path) + + # Add embedding columns + df["task_embedding"] = [all_task_embeddings[idx].tolist() for idx in frame_indices] + + for cam_key in dataset.meta.camera_keys: + df[f"{cam_key}_embedding"] = [ + image_embeddings_dict[cam_key][idx].tolist() for idx in frame_indices + ] + + # Save the updated parquet file + df.to_parquet(parquet_path, index=False) + + # Remove videos if requested + if remove_videos: + print("Removing video files...") + videos_dir = output_dataset.root / "videos" + if videos_dir.exists(): + shutil.rmtree(videos_dir) + + # Update info to reflect no videos + for cam_key in dataset.meta.camera_keys: + if cam_key in dataset.meta.video_keys: + output_dataset.meta.info["features"][cam_key]["dtype"] = "image" + # Remove video-specific info + if "info" in output_dataset.meta.info["features"][cam_key]: + del output_dataset.meta.info["features"][cam_key]["info"] + + output_dataset.meta.info["video_path"] = None + write_info(output_dataset.meta.info, output_dataset.root) + + print(f"Saved to: {output_local_dir}") + + # Push to hub + if push_to_hub: + print(f"Pushing to hub: {output_repo_id}") + output_dataset.push_to_hub(push_videos=not remove_videos) + print("Done!") + + +def main(): + parser = argparse.ArgumentParser( + description="Generate embeddings for LeRobot datasets", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: 
+ # Basic usage with default encoders (DinoV2 ViT-B/14 + MiniLM-L12) + python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \\ + --repo-id lerobot/utokyo_xarm_bimanual \\ + --output-repo-id your-username/utokyo_xarm_bimanual_embeddings \\ + --image-encoder dinov2_vitb14 \\ + --language-encoder minilm-l12 \\ + --push-to-hub + + # Generate embeddings and remove videos + python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \\ + --repo-id lerobot/utokyo_xarm_bimanual \\ + --output-repo-id your-username/utokyo_xarm_bimanual_lightweight \\ + --image-encoder dinov2_vitb14 \\ + --language-encoder minilm-l12 \\ + --remove-videos \\ + --push-to-hub + +Available image encoders: + - dinov2_vits14: DinoV2 ViT-S/14 (384-dim, faster) + - dinov2_vitb14: DinoV2 ViT-B/14 (768-dim, recommended) + - dinov2_vitl14: DinoV2 ViT-L/14 (1024-dim, best quality) + +Available language encoders: + - minilm-l6: MiniLM-L6-v2 (384-dim, faster) + - minilm-l12: MiniLM-L12-v2 (384-dim, recommended) + """, + ) + parser.add_argument("--repo-id", type=str, required=True, help="Source dataset repository ID") + parser.add_argument("--output-repo-id", type=str, required=True, help="Output dataset repository ID") + parser.add_argument( + "--image-encoder", + type=str, + default="dinov2_vitb14", + help="Image encoder to use (default: dinov2_vitb14)", + ) + parser.add_argument( + "--language-encoder", + type=str, + default="minilm-l12", + help="Language encoder to use (default: minilm-l12)", + ) + parser.add_argument( + "--remove-videos", + action="store_true", + help="Remove video files after generating embeddings", + ) + parser.add_argument("--local-dir", type=str, default=None, help="Local directory for source dataset") + parser.add_argument( + "--output-local-dir", type=str, default=None, help="Local directory for output dataset" + ) + parser.add_argument( + "--push-to-hub", + action="store_true", + help="Push the converted dataset to the hub", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to use for encoding (default: cuda)", + ) + + args = parser.parse_args() + + # Load encoders + image_encoder = get_image_encoder(args.image_encoder, device=args.device) + language_encoder = get_language_encoder(args.language_encoder, device=args.device) + + # Generate embeddings + generate_embeddings_for_dataset( + repo_id=args.repo_id, + output_repo_id=args.output_repo_id, + image_encoder=image_encoder, + language_encoder=language_encoder, + remove_videos=args.remove_videos, + local_dir=Path(args.local_dir) if args.local_dir else None, + output_local_dir=Path(args.output_local_dir) if args.output_local_dir else None, + push_to_hub=args.push_to_hub, + ) + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/datasets/generating_embeddings/validate_embeddings.py b/src/lerobot/datasets/generating_embeddings/validate_embeddings.py new file mode 100644 index 0000000000..88d603dc99 --- /dev/null +++ b/src/lerobot/datasets/generating_embeddings/validate_embeddings.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Validate pre-computed embeddings against on-the-fly computed embeddings. + +Usage: + python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \ + --original-repo-id lerobot/utokyo_xarm_bimanual \ + --embeddings-repo-id /utokyo_xarm_bimanual_embeddings \ + --image-encoder dinov2_vitb14 \ + --language-encoder minilm-l12 \ + --num-samples 10 +""" + +import argparse + +import numpy as np +import torch +from tqdm import tqdm + +from lerobot.datasets.generating_embeddings.encoders import ImageEncoder, LanguageEncoder +from lerobot.datasets.generating_embeddings.generate_embeddings import ( + get_image_encoder, + get_language_encoder, +) +from lerobot.datasets.lerobot_dataset import LeRobotDataset + + +def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: + """Compute cosine similarity between two vectors.""" + return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) + + +def validate_embeddings( + original_repo_id: str, + embeddings_repo_id: str, + image_encoder: ImageEncoder, + language_encoder: LanguageEncoder, + num_samples: int = 10, + device: str = "cuda", +): + """Validate pre-computed embeddings against on-the-fly embeddings. + + Args: + original_repo_id: Original dataset repository ID + embeddings_repo_id: Dataset with pre-computed embeddings repository ID + image_encoder: Image encoder instance + language_encoder: Language encoder instance + num_samples: Number of samples to validate + device: Device to use for encoding + """ + # Load both datasets + print("Loading datasets...") + original_dataset = LeRobotDataset(original_repo_id, download_videos=True) + embeddings_dataset = LeRobotDataset(embeddings_repo_id, download_videos=False) + + # Verify both datasets have the same number of frames + assert original_dataset.num_frames == embeddings_dataset.num_frames, ( + f"Frame count mismatch: original={original_dataset.num_frames}, " + f"embeddings={embeddings_dataset.num_frames}" + ) + + camera_keys = original_dataset.meta.camera_keys + + # Check embedding features exist + expected_features = ["task_embedding"] + [f"{cam}_embedding" for cam in camera_keys] + for feat in expected_features: + if feat not in embeddings_dataset.features: + raise ValueError(f"Embedding feature not found: {feat}") + + # Select random sample indices + sample_indices = np.random.choice( + original_dataset.num_frames, size=min(num_samples, original_dataset.num_frames), replace=False + ) + print(f"Validating {len(sample_indices)} samples...") + + # Track statistics + task_similarities = [] + image_similarities = {cam: [] for cam in camera_keys} + + for idx in tqdm(sample_indices, desc="Validating"): + idx = int(idx) + + embeddings_item = embeddings_dataset[idx] + precomputed_task_emb = embeddings_item["task_embedding"].numpy() + precomputed_image_embs = {cam: embeddings_item[f"{cam}_embedding"].numpy() for cam in camera_keys} + + original_item = original_dataset[idx] + + # Get task and compute embedding + task = original_item["task"] + # Clean up task text (same as in generate_embeddings.py) + task_clean = task.strip().capitalize().strip(" .,!?-_") + onthefly_task_emb = 
language_encoder.encode([task_clean])[0] + + # Get images and compute embeddings + onthefly_image_embs = {} + for cam in camera_keys: + img = original_item[cam] + # Convert to numpy if needed + if isinstance(img, torch.Tensor): + if img.ndim == 3: # (C, H, W) + img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8) + else: + raise ValueError(f"Unexpected image shape: {img.shape}") + else: + img_np = np.array(img) + + onthefly_image_embs[cam] = image_encoder.encode([img_np])[0] + + # Task embedding comparison + task_sim = cosine_similarity(precomputed_task_emb, onthefly_task_emb) + task_similarities.append(task_sim) + + # Image embedding comparison + for cam in camera_keys: + img_sim = cosine_similarity(precomputed_image_embs[cam], onthefly_image_embs[cam]) + image_similarities[cam].append(img_sim) + + # Results + print("\nResults:") + task_sim_threshold = 0.99 + img_sim_threshold = 0.99 + + task_mean_sim = np.mean(task_similarities) + task_pass = task_mean_sim >= task_sim_threshold + + print(f" Task: {task_mean_sim:.4f} {'✓' if task_pass else '✗'}") + + for cam in camera_keys: + cam_mean_sim = np.mean(image_similarities[cam]) + cam_pass = cam_mean_sim >= img_sim_threshold + print(f" {cam}: {cam_mean_sim:.4f} {'✓' if cam_pass else '✗'}") + + image_pass = all(np.mean(image_similarities[cam]) >= img_sim_threshold for cam in camera_keys) + + print() + if task_pass and image_pass: + print("✓ PASSED") + else: + print("✗ FAILED") + + +def main(): + parser = argparse.ArgumentParser( + description="Validate and compare pre-computed embeddings with on-the-fly embeddings", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Example: + python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \\ + --original-repo-id lerobot/utokyo_xarm_bimanual \\ + --embeddings-repo-id lerobot/utokyo_xarm_bimanual_embeddings \\ + --image-encoder dinov2_vitb14 \\ + --language-encoder minilm-l12 \\ + --num-samples 20 + """, + ) + parser.add_argument("--original-repo-id", type=str, required=True, help="Original dataset repository ID") + parser.add_argument( + "--embeddings-repo-id", + type=str, + required=True, + help="Dataset with pre-computed embeddings repository ID", + ) + parser.add_argument( + "--image-encoder", + type=str, + default="dinov2_vitb14", + help="Image encoder to use (default: dinov2_vitb14)", + ) + parser.add_argument( + "--language-encoder", + type=str, + default="minilm-l12", + help="Language encoder to use (default: minilm-l12)", + ) + parser.add_argument( + "--num-samples", + type=int, + default=10, + help="Number of samples to validate (default: 10)", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to use for encoding (default: cuda)", + ) + + args = parser.parse_args() + + # Load encoders + image_encoder = get_image_encoder(args.image_encoder, device=args.device) + language_encoder = get_language_encoder(args.language_encoder, device=args.device) + + # Validate embeddings + validate_embeddings( + original_repo_id=args.original_repo_id, + embeddings_repo_id=args.embeddings_repo_id, + image_encoder=image_encoder, + language_encoder=language_encoder, + num_samples=args.num_samples, + device=args.device, + ) + + +if __name__ == "__main__": + main()