Add script to generate embedding for dataset #2138
base: main
Conversation
```python
    if isinstance(img, torch.Tensor):
        if img.ndim == 4:
            # Shape: (T, C, H, W) where T=1 for single timestamp
            img = img[0]  # Now (C, H, W)
        elif img.ndim == 3:
            # Shape: (C, H, W)
            pass
        else:
            raise ValueError(
                f"Unexpected video frame shape {img.shape} for camera {cam_key}. "
                f"Expected (T, C, H, W) or (C, H, W). Episode {ep_idx}, Frame {frame_idx}"
            )

        # Convert to numpy: (C, H, W) float32 [0, 1] -> (H, W, C) uint8 [0, 255]
        img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    else:
        img_np = np.array(img)
else:
    # Load from image file
    img = item[cam_key]
    # Convert to numpy if needed
    if isinstance(img, torch.Tensor):
        if img.ndim == 3:
            img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        else:
            raise ValueError(f"Unexpected image shape {img.shape} for camera {cam_key}")
    else:
        img_np = np.array(img)
```
Should we add testing for this part? This logic could be buggy if we have image datasets instead of videos. Better if we had tests to verify it.
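A minimal pytest sketch of such a test, assuming the conversion logic above is factored into a helper (the name `frame_to_numpy` is hypothetical, not from this PR):

```python
import numpy as np
import pytest
import torch


def frame_to_numpy(img: torch.Tensor) -> np.ndarray:
    """Hypothetical helper wrapping the conversion logic from the diff above."""
    if img.ndim == 4:  # (T, C, H, W) video frame, T=1
        img = img[0]
    elif img.ndim != 3:  # anything other than (C, H, W) is unsupported
        raise ValueError(f"Unexpected shape {img.shape}")
    return (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)


@pytest.mark.parametrize("shape", [(1, 3, 48, 64), (3, 48, 64)])
def test_frame_to_numpy_accepts_video_and_image_tensors(shape):
    img = torch.rand(shape)  # float32 in [0, 1], as the dataset produces
    out = frame_to_numpy(img)
    assert out.shape == (48, 64, 3)
    assert out.dtype == np.uint8


def test_frame_to_numpy_rejects_bad_shapes():
    with pytest.raises(ValueError):
        frame_to_numpy(torch.rand(48, 64))  # 2D input is not supported
```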
```python
self.batch_size = batch_size
self.model_name = model_name
logger.info(f"Loading DinoV2 model: {model_name}")
self.model = torch.hub.load("facebookresearch/dinov2", model_name)  # nosec B614
```
Is there a reason why you didn't use AutoModel from transformers here as well? We do it for the SAC encoder:

```python
self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
```

Something like:

```python
self.model = AutoModel.from_pretrained("facebook/dinov2-base")
```
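For reference, a minimal sketch of the transformers route (note the Hub id uses a hyphen, `facebook/dinov2-base`; `img_np` stands in for an (H, W, C) uint8 frame as produced above):

```python
from transformers import AutoImageProcessor, AutoModel

processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
model = AutoModel.from_pretrained("facebook/dinov2-base")

inputs = processor(images=img_np, return_tensors="pt")  # resizes + normalizes
outputs = model(**inputs)
embedding = outputs.last_hidden_state[:, 0]  # CLS token as the image embedding
```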
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
```
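The body is elided in the diff view; a standard implementation (not necessarily the one in this PR) would be:

```python
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Dot product of the vectors divided by the product of their norms
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
```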
nice
```python
for (chunk_idx, file_idx), frame_indices in tqdm(frames_by_file.items(), desc="Updating parquet files"):
    parquet_path = output_dataset.root / DEFAULT_DATA_PATH.format(
        chunk_index=chunk_idx, file_index=file_idx
    )

    # Load the parquet file
    df = pd.read_parquet(parquet_path)

    # Add embedding columns
    df["task_embedding"] = [all_task_embeddings[idx].tolist() for idx in frame_indices]

    for cam_key in dataset.meta.camera_keys:
        df[f"{cam_key}_embedding"] = [
            image_embeddings_dict[cam_key][idx].tolist() for idx in frame_indices
        ]

    # Save the updated parquet file
    df.to_parquet(parquet_path, index=False)
```
For this chunk I suggest doing it in a more efficient way, e.g. by using dataset_tools; I'll push a PR. Reading and writing the whole parquet file to disk will be slow and result in memory explosion for huge datasets.
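Until that PR lands, one lower-memory option is to stream the file in row-group batches with pyarrow instead of doing a full pandas round-trip. A minimal sketch, with a hypothetical helper that assumes the per-file embeddings are already ordered to match the file's rows:

```python
from pathlib import Path

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq


def append_embedding_column(parquet_path: Path, column_name: str, embeddings: np.ndarray) -> None:
    """Append one embedding column, streaming batches instead of loading the whole file."""
    source = pq.ParquetFile(parquet_path)
    tmp_path = parquet_path.with_suffix(".tmp.parquet")
    writer = None
    offset = 0
    for batch in source.iter_batches():
        table = pa.Table.from_batches([batch])
        # Slice the embeddings aligned with this batch's rows
        col = pa.array(embeddings[offset : offset + batch.num_rows].tolist())
        table = table.append_column(column_name, col)
        if writer is None:
            writer = pq.ParquetWriter(tmp_path, table.schema)
        writer.write_table(table)
        offset += batch.num_rows
    if writer is not None:
        writer.close()
        tmp_path.replace(parquet_path)  # atomic swap into place
```

This keeps only one batch (plus its embedding slice) in memory at a time, at the cost of rewriting the file once per added column unless the columns are appended together.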
This PR introduces a way to generate image and text embeddings for a dataset, to make training on a dataset for multiple epochs more efficient. For example, for learning a general reward we combine a specific dataset with OXE to improve generalization. In order to not recompute the image and text embeddings each time we finetune on OXE, we can use this script to add the embeddings to the dataset. We can additionally remove the videos in the dataset to save space.
Testing:
Both the generate and validate scripts were tested on this dataset: lerobot/utokyo_xarm_bimanual. The generated dataset can be found here: pepijn223/utokyo_xarm_bimanual_embeddings.