Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/lerobot/datasets/dataset_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,10 +1003,18 @@ def _copy_data_with_feature_changes(
df[feature_name] = feature_values
else:
feature_slice = values[frame_idx:end_idx]
if len(feature_slice.shape) > 1 and feature_slice.shape[1] == 1:
if len(feature_slice.shape) == 1:
# 1D array - can assign directly
df[feature_name] = feature_slice
elif len(feature_slice.shape) == 2 and feature_slice.shape[1] == 1:
# 2D array with single column - flatten it
df[feature_name] = feature_slice.flatten()
elif len(feature_slice.shape) == 2:
# 2D array with multiple columns (e.g., embeddings) - convert to list of lists
df[feature_name] = feature_slice.tolist()
else:
df[feature_name] = feature_slice
# Higher dimensional - convert to list
df[feature_name] = [row.tolist() for row in feature_slice]
frame_idx = end_idx

# Write using the preserved chunk_idx and file_idx from source
Expand Down
200 changes: 60 additions & 140 deletions src/lerobot/datasets/generating_embeddings/generate_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,34 +127,13 @@ def generate_embeddings_for_dataset(
output_local_dir: Local directory for output dataset
push_to_hub: Whether to push to hub after conversion
"""
print(f"Loading dataset: {repo_id}")
from lerobot.datasets.dataset_tools import modify_features

dataset = LeRobotDataset(repo_id, root=local_dir, download_videos=not remove_videos)
print(f"Loading dataset: {repo_id}")

dataset = LeRobotDataset(repo_id, root=local_dir, download_videos=True)
print(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")

# Create output directory
if output_local_dir is None:
from lerobot.utils.constants import HF_LEROBOT_HOME

output_local_dir = HF_LEROBOT_HOME / output_repo_id
else:
output_local_dir = Path(output_local_dir)

# Copy the dataset to the output location
print(f"Copying to {output_local_dir}")
if output_local_dir.exists():
shutil.rmtree(output_local_dir)

shutil.copytree(dataset.root, output_local_dir)

output_dataset = LeRobotDataset(output_repo_id, root=output_local_dir)

# Define new features for embeddings
img_emb_dim = image_encoder.embedding_dim
lang_emb_dim = language_encoder.embedding_dim

# Get unique tasks and compute their embeddings
print("Computing task embeddings...")
unique_tasks = dataset.meta.tasks.index.tolist()
task_embeddings = {}
Expand All @@ -167,67 +146,43 @@ def generate_embeddings_for_dataset(

print(f"Computed {len(task_embeddings)} task embeddings")

# Process each episode
print("Processing episodes...")

# Track task embeddings per frame
print("Processing frames and computing embeddings...")
all_task_embeddings = []
all_image_embeddings_dict = {cam_key: [] for cam_key in dataset.meta.camera_keys}

for ep_idx in tqdm(range(dataset.num_episodes), desc="Processing episodes"):
ep = dataset.meta.episodes[ep_idx]
ep_start = ep["dataset_from_index"]
ep_end = ep["dataset_to_index"]

# Get all frames for this episode
for frame_idx in range(ep_start, ep_end):
item = dataset.hf_dataset[frame_idx]

# Get task embedding for this frame
task = dataset.meta.tasks.iloc[item["task_index"].item()].name
task_emb = task_embeddings[task]
all_task_embeddings.append(task_emb)

for cam_key in dataset.meta.camera_keys:
if cam_key in dataset.meta.video_keys:
# Decode from video
current_ts = item["timestamp"].item()
video_frames = dataset._query_videos({cam_key: [current_ts]}, ep_idx)
img = video_frames[
cam_key
] # This returns tensor of shape (T, C, H, W) or might be squeezed

# Handle the tensor shape
if isinstance(img, torch.Tensor):
if img.ndim == 4:
# Shape: (T, C, H, W) where T=1 for single timestamp
img = img[0] # Now (C, H, W)
elif img.ndim == 3:
# Shape: (C, H, W)
pass
else:
raise ValueError(
f"Unexpected video frame shape {img.shape} for camera {cam_key}. "
f"Expected (T, C, H, W) or (C, H, W). Episode {ep_idx}, Frame {frame_idx}"
)

# Convert to numpy: (C, H, W) float32 [0, 1] -> (H, W, C) uint8 [0, 255]
for frame_idx in tqdm(range(dataset.num_frames), desc="Processing frames"):
item = dataset.hf_dataset[frame_idx]
ep_idx = item["episode_index"].item()

task = dataset.meta.tasks.iloc[item["task_index"].item()].name
task_emb = task_embeddings[task]
all_task_embeddings.append(task_emb)

for cam_key in dataset.meta.camera_keys:
if cam_key in dataset.meta.video_keys:
current_ts = item["timestamp"].item()
video_frames = dataset._query_videos({cam_key: [current_ts]}, ep_idx)
img = video_frames[cam_key]

if isinstance(img, torch.Tensor):
if img.ndim == 4:
img = img[0] # (T, C, H, W) -> (C, H, W)
elif img.ndim != 3:
raise ValueError(f"Unexpected video frame shape {img.shape} for camera {cam_key}")
img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
else:
img_np = np.array(img)
else:
img = item[cam_key]
if isinstance(img, torch.Tensor):
if img.ndim == 3:
img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
else:
img_np = np.array(img)
raise ValueError(f"Unexpected image shape {img.shape} for camera {cam_key}")
else:
# Load from image file
img = item[cam_key]
# Convert to numpy if needed
if isinstance(img, torch.Tensor):
if img.ndim == 3:
img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
else:
raise ValueError(f"Unexpected image shape {img.shape} for camera {cam_key}")
else:
img_np = np.array(img)
img_np = np.array(img)

all_image_embeddings_dict[cam_key].append(img_np)
all_image_embeddings_dict[cam_key].append(img_np)

print("Computing image embeddings...")
image_embeddings_dict = {}
Expand All @@ -236,82 +191,47 @@ def generate_embeddings_for_dataset(
embeddings = image_encoder.encode(images)
image_embeddings_dict[cam_key] = embeddings

print("Adding embeddings to dataset...")

# Update features in info.json
output_dataset.meta.info["features"]["task_embedding"] = {
"dtype": "float32",
"shape": [lang_emb_dim],
"names": None,
}

all_task_embeddings = np.array(all_task_embeddings)
for cam_key in dataset.meta.camera_keys:
output_dataset.meta.info["features"][f"{cam_key}_embedding"] = {
"dtype": "float32",
"shape": [img_emb_dim],
"names": None,
}

import pandas as pd
image_embeddings_dict[cam_key] = np.array(image_embeddings_dict[cam_key])

from lerobot.datasets.utils import DEFAULT_DATA_PATH, write_info
img_emb_dim = image_encoder.embedding_dim
lang_emb_dim = language_encoder.embedding_dim

write_info(output_dataset.meta.info, output_dataset.root)
add_features_dict = {
"task_embedding": (
all_task_embeddings,
{"dtype": "float32", "shape": [lang_emb_dim], "names": None},
),
}

# Group frames by their parquet file
frames_by_file = {}
for frame_idx in range(output_dataset.num_frames):
item = output_dataset.hf_dataset[frame_idx]
ep_idx = item["episode_index"].item()
ep = output_dataset.meta.episodes[ep_idx]
chunk_idx = ep["data/chunk_index"]
file_idx = ep["data/file_index"]
key = (chunk_idx, file_idx)
if key not in frames_by_file:
frames_by_file[key] = []
frames_by_file[key].append(frame_idx)

# Update each parquet file
for (chunk_idx, file_idx), frame_indices in tqdm(frames_by_file.items(), desc="Updating parquet files"):
parquet_path = output_dataset.root / DEFAULT_DATA_PATH.format(
chunk_index=chunk_idx, file_index=file_idx
for cam_key in dataset.meta.camera_keys:
add_features_dict[f"{cam_key}_embedding"] = (
image_embeddings_dict[cam_key],
{"dtype": "float32", "shape": [img_emb_dim], "names": None},
)

# Load the parquet file
df = pd.read_parquet(parquet_path)

# Add embedding columns
df["task_embedding"] = [all_task_embeddings[idx].tolist() for idx in frame_indices]

for cam_key in dataset.meta.camera_keys:
df[f"{cam_key}_embedding"] = [
image_embeddings_dict[cam_key][idx].tolist() for idx in frame_indices
]

# Save the updated parquet file
df.to_parquet(parquet_path, index=False)
print("Adding embeddings to dataset...")
remove_features_list = None
if remove_videos:
remove_features_list = dataset.meta.video_keys

output_dataset = modify_features(
dataset=dataset,
add_features=add_features_dict,
remove_features=remove_features_list,
output_dir=output_local_dir,
repo_id=output_repo_id,
)

# Remove videos if requested
if remove_videos:
print("Removing video files...")
videos_dir = output_dataset.root / "videos"
if videos_dir.exists():
shutil.rmtree(videos_dir)

# Update info to reflect no videos
for cam_key in dataset.meta.camera_keys:
if cam_key in dataset.meta.video_keys:
output_dataset.meta.info["features"][cam_key]["dtype"] = "image"
# Remove video-specific info
if "info" in output_dataset.meta.info["features"][cam_key]:
del output_dataset.meta.info["features"][cam_key]["info"]

output_dataset.meta.info["video_path"] = None
write_info(output_dataset.meta.info, output_dataset.root)

print(f"Saved to: {output_local_dir}")
print(f"Saved to: {output_dataset.root}")

# Push to hub
if push_to_hub:
print(f"Pushing to hub: {output_repo_id}")
output_dataset.push_to_hub(push_videos=not remove_videos)
Expand Down