src/lerobot/datasets/generating_embeddings/README.md (new file, 146 additions)
# LeRobot Embedding Generation Script

Generate embeddings for LeRobot datasets to make them more lightweight and efficient for training.

## Overview

This script processes v3.0 LeRobot datasets and adds pre-computed embeddings for:

- **Task embeddings**: Language command embeddings using MiniLM
- **Image embeddings**: Frame embeddings using DinoV2

The resulting dataset can be used more efficiently during training by loading pre-computed embeddings instead of running encoders on-the-fly.

## Supported Encoders

### Image Encoders (DinoV2)

DinoV2 is a self-supervised vision transformer that produces high-quality image embeddings:

- **`dinov2_vits14`**: ViT-S/14 (384-dim) - Fastest, smallest model
- **`dinov2_vitb14`**: ViT-B/14 (768-dim) - **Recommended** - Good balance
- **`dinov2_vitl14`**: ViT-L/14 (1024-dim) - Best quality, slower

### Language Encoders (MiniLM)

MiniLM is a lightweight sentence transformer model:

- **`minilm-l6`**: MiniLM-L6-v2 (384-dim) - Faster
- **`minilm-l12`**: MiniLM-L12-v2 (384-dim) - **Recommended** - Better quality
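
To sanity-check an encoder outside the full script, you can instantiate it directly from `encoders.py`. This is a minimal sketch; the import path assumes the module layout in this PR (`src/lerobot/datasets/generating_embeddings/encoders.py`):

```python
import numpy as np

from lerobot.datasets.generating_embeddings.encoders import DinoV2Encoder, MiniLMEncoder

# Encode one dummy frame; expect shape (1, 768) for dinov2_vitb14.
image_encoder = DinoV2Encoder(model_name="dinov2_vitb14", device="cpu")
print(image_encoder.encode([np.zeros((224, 224, 3), dtype=np.uint8)]).shape)

# Encode one task string; expect shape (1, 384).
language_encoder = MiniLMEncoder(device="cpu")
print(language_encoder.encode(["pick up the cube"]).shape)
```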

## Usage

### Basic Command

```bash
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
    --repo-id lerobot/utokyo_xarm_bimanual \
    --output-repo-id your-username/utokyo_xarm_bimanual_embeddings \
    --image-encoder dinov2_vitb14 \
    --language-encoder minilm-l12 \
    --push-to-hub
```
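
Here `--repo-id` is the source dataset on the Hub, `--output-repo-id` is where the dataset with embeddings is written, and `--push-to-hub` uploads the result after processing.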

### Lightweight Version (No Videos)

Removes video files to significantly reduce storage:

```bash
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
    --repo-id lerobot/utokyo_xarm_bimanual \
    --output-repo-id your-username/utokyo_xarm_bimanual_lightweight \
    --image-encoder dinov2_vitb14 \
    --language-encoder minilm-l12 \
    --remove-videos \
    --push-to-hub
```

## Output

The script adds new features to your dataset:

### New Features

1. **`task_embedding`**: Language embedding for each frame
   - Shape: `[384]` (MiniLM)
   - One embedding per frame, based on its task

2. **`{camera_key}_embedding`**: Image embedding for each camera view
   - Shape: `[384]`, `[768]`, or `[1024]` depending on the DinoV2 model
   - Examples: `observation.images.top_embedding`, `observation.images.wrist_embedding`

### Using Embeddings in Training

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset

# Load dataset with embeddings
dataset = LeRobotDataset("your-username/utokyo_xarm_bimanual_embeddings")

# Access embeddings
item = dataset[0]
task_emb = item["task_embedding"] # Shape: [384]
img_emb = item["observation.images.top_embedding"] # Shape: [768]

# Use in your policy
# Instead of running encoders during training, use pre-computed embeddings
```
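
For instance, a policy head can consume these vectors directly. The sketch below is hypothetical (the class name, hidden size, and `action_dim` are illustrative, not part of this PR); it assumes the 384-dim task embedding and 768-dim `dinov2_vitb14` image embedding from the example above:

```python
import torch
import torch.nn as nn

class EmbeddingPolicyHead(nn.Module):
    """Hypothetical head mapping pre-computed embeddings to actions."""

    def __init__(self, task_dim: int = 384, img_dim: int = 768, action_dim: int = 7):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(task_dim + img_dim, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
        )

    def forward(self, task_emb: torch.Tensor, img_emb: torch.Tensor) -> torch.Tensor:
        # No image or text encoder in the forward pass: the dataset already
        # provides the embeddings, so this is a cheap MLP call.
        return self.mlp(torch.cat([task_emb, img_emb], dim=-1))

head = EmbeddingPolicyHead()
task_emb = torch.randn(384)  # stand-in for item["task_embedding"]
img_emb = torch.randn(768)   # stand-in for item["observation.images.top_embedding"]
action = head(task_emb.unsqueeze(0), img_emb.unsqueeze(0))  # shape: [1, 7]
```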

## Extending with New Encoders

The script is designed to be easily extensible. To add a new encoder:

### 1. Create Encoder Class

```python
class MyCustomImageEncoder(ImageEncoder):
    """Your custom image encoder."""

    def __init__(self, device: str = "cuda"):
        super().__init__(device)
        # Load your model
        self.model = load_my_model()
        self.model = self.model.to(self.device)
        self.model.eval()

    def encode(self, images: list[np.ndarray]) -> np.ndarray:
        """Encode a batch of images."""
        # Your encoding logic here; return shape [len(images), embedding_dim]
        embeddings = []
        for img in images:
            emb = self.model(img)
            embeddings.append(emb)
        return np.stack(embeddings)

    @property
    def embedding_dim(self) -> int:
        """Return embedding dimension."""
        return 512  # Your embedding dimension
```
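
Note that `encode` should return an array of shape `[len(images), embedding_dim]`, matching what `DinoV2Encoder.encode` produces, so downstream code can treat all encoders interchangeably.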

### 2. Add to Factory Function

```python
def get_image_encoder(encoder_name: str, device: str = "cuda") -> ImageEncoder:
    encoders = {
        "dinov2_vits14": lambda: DinoV2Encoder(model_name="dinov2_vits14", device=device),
        "dinov2_vitb14": lambda: DinoV2Encoder(model_name="dinov2_vitb14", device=device),
        "dinov2_vitl14": lambda: DinoV2Encoder(model_name="dinov2_vitl14", device=device),
        # Add your encoder
        "my_custom": lambda: MyCustomImageEncoder(device=device),
    }
    # ... rest of function
```
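
Once registered, the new encoder can be selected from the command line like the built-in ones (assuming the factory above is what the script's `--image-encoder` flag dispatches on):

```bash
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
    --repo-id lerobot/utokyo_xarm_bimanual \
    --output-repo-id your-username/utokyo_xarm_bimanual_custom \
    --image-encoder my_custom \
    --language-encoder minilm-l12
```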

## Validating Embeddings

After generating embeddings, you can validate them using `validate_embeddings.py`:

```bash
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \
    --original-repo-id lerobot/utokyo_xarm_bimanual \
    --embeddings-repo-id pepijn223/utokyo_xarm_bimanual_embeddings \
    --image-encoder dinov2_vitb14 \
    --language-encoder minilm-l12 \
    --num-samples 20
```
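
Conceptually, validation amounts to re-encoding a sample of frames and tasks and comparing the results against the stored vectors. A minimal sketch of that idea, with random data as stand-ins (the helper below is illustrative, not the script's actual API):

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two 1-D embedding vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Stand-in vectors: in practice, compare a freshly re-encoded frame
# against the embedding stored in the dataset.
stored = np.random.rand(768)
recomputed = stored + 1e-6 * np.random.randn(768)
assert cosine_similarity(stored, recomputed) > 0.99
```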
src/lerobot/datasets/generating_embeddings/encoders.py (new file, 147 additions)
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import numpy as np
import torch
from PIL import Image

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ImageEncoder:
    """Base class for image encoders."""

    def __init__(self, device: str = "cuda"):
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")

    def encode(self, images: list[np.ndarray]) -> np.ndarray:
        """Encode a batch of images."""
        raise NotImplementedError


class DinoV2Encoder(ImageEncoder):
    """DinoV2 image encoder.

    DinoV2 is a self-supervised vision transformer that produces high-quality image embeddings.
    Supports multiple model sizes (ViT-S/14, ViT-B/14, ViT-L/14).
    """

    def __init__(self, model_name: str = "dinov2_vitb14", device: str = "cuda", batch_size: int = 32):
        super().__init__(device)
        self.batch_size = batch_size
        self.model_name = model_name
        logger.info(f"Loading DinoV2 model: {model_name}")
        self.model = torch.hub.load("facebookresearch/dinov2", model_name)  # nosec B614
> **Review comment (Collaborator):** Is there a reason why you didn't use `AutoModel` from transformers here also? We do it for the SAC encoder:
>
>     self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
>
> Something like:
>
>     self.model = AutoModel.from_pretrained("facebook/dinov2-base")

        self.model = self.model.to(self.device)
        self.model.eval()

        # DinoV2 preprocessing
        from torchvision import transforms

        self.transform = transforms.Compose(
            [
                transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def encode(self, images: list[np.ndarray]) -> np.ndarray:
        """Encode a batch of images."""
        embeddings = []

        with torch.inference_mode():
            for i in range(0, len(images), self.batch_size):
                batch_images = images[i : i + self.batch_size]
                # Convert numpy arrays to PIL Images and apply transforms
                pil_images = [Image.fromarray(img.astype(np.uint8)) for img in batch_images]
                tensors = torch.stack([self.transform(img) for img in pil_images]).to(self.device)

                # Get embeddings
                batch_embeddings = self.model(tensors).cpu().numpy()
                embeddings.append(batch_embeddings)

        return np.concatenate(embeddings, axis=0)

    @property
    def embedding_dim(self) -> int:
        """Return the embedding dimension based on model size."""
        if "vits14" in self.model_name:
            return 384  # DinoV2 ViT-S/14
        elif "vitb14" in self.model_name:
            return 768  # DinoV2 ViT-B/14
        elif "vitl14" in self.model_name:
            return 1024  # DinoV2 ViT-L/14
        else:
            return 768  # Default to ViT-B/14


class LanguageEncoder:
    """Base class for language encoders."""

    def __init__(self, device: str = "cuda"):
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")

    def encode(self, texts: list[str]) -> np.ndarray:
        """Encode a batch of texts."""
        raise NotImplementedError


class MiniLMEncoder(LanguageEncoder):
    """MiniLM language encoder.

    MiniLM is a lightweight sentence transformer model that produces high-quality text embeddings.
    Supports L6 and L12 model sizes.
    """

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L12-v2", device: str = "cuda"):
        super().__init__(device)
        self.model_name = model_name
        logger.info(f"Loading MiniLM model: {model_name}")

        from transformers import AutoModel, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.model.eval()

    def _mean_pooling(self, model_output, attention_mask):
        """Mean pooling to get sentence embeddings."""
        token_embeddings = model_output[0]  # [batch, seq_len, hidden]
        # Zero out padding positions, then average the remaining token embeddings;
        # the clamp guards against division by zero for all-padding masks.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def encode(self, texts: list[str]) -> np.ndarray:
        """Encode a batch of texts."""
        with torch.inference_mode():
            encoded_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
            encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}

            model_output = self.model(**encoded_input)
            embeddings = self._mean_pooling(model_output, encoded_input["attention_mask"])

        return embeddings.cpu().numpy()

    @property
    def embedding_dim(self) -> int:
        """Return the embedding dimension."""
        return 384  # Both MiniLM-L6 and L12 output 384-dim embeddings