@@ -271,6 +271,7 @@ def add_videos_from_path(
271271 path : PathLike ,
272272 allowed_extensions : Iterable [str ] | None = None ,
273273 num_decode_threads : int | None = None ,
274+ embed : bool = True ,
274275 ) -> None :
275276 """Adding video frames from the specified path to the dataset.
276277
@@ -281,6 +282,7 @@ def add_videos_from_path(
281282 uses default VIDEO_EXTENSIONS.
282283 num_decode_threads: Optional override for the number of FFmpeg decode threads.
283284 If omitted, the available CPU cores - 1 (max 16) are used.
285+ embed: If True, generate embeddings for the newly added videos.
284286 """
285287 # Collect video file paths.
286288 if allowed_extensions :
@@ -295,13 +297,20 @@ def add_videos_from_path(
295297 logger .info (f"Found { len (video_paths )} videos in { path } ." )
296298
297299 # Process videos.
298- add_videos .load_into_dataset_from_paths (
300+ created_sample_ids , _ = add_videos .load_into_dataset_from_paths (
299301 session = self .session ,
300302 dataset_id = self .dataset_id ,
301303 video_paths = video_paths ,
302304 num_decode_threads = num_decode_threads ,
303305 )
304306
307+ if embed :
308+ _generate_embeddings_video (
309+ session = self .session ,
310+ dataset_id = self .dataset_id ,
311+ sample_ids = created_sample_ids ,
312+ )
313+
305314 def add_images_from_path (
306315 self ,
307316 path : PathLike ,
@@ -354,7 +363,7 @@ def add_images_from_path(
354363 )
355364
356365 if embed :
357- _generate_embeddings (
366+ _generate_embeddings_image (
358367 session = self .session , dataset_id = self .dataset_id , sample_ids = created_sample_ids
359368 )
360369
@@ -383,7 +392,7 @@ def add_samples_from_labelformat(
383392 )
384393
385394 if embed :
386- _generate_embeddings (
395+ _generate_embeddings_image (
387396 session = self .session , dataset_id = self .dataset_id , sample_ids = created_sample_ids
388397 )
389398
@@ -446,7 +455,7 @@ def add_samples_from_yolo(
446455
447456 # Generate embeddings for all samples at once
448457 if embed :
449- _generate_embeddings (
458+ _generate_embeddings_image (
450459 session = self .session , dataset_id = self .dataset_id , sample_ids = all_created_sample_ids
451460 )
452461
@@ -512,7 +521,7 @@ def add_samples_from_coco(
512521 )
513522
514523 if embed :
515- _generate_embeddings (
524+ _generate_embeddings_image (
516525 session = self .session , dataset_id = self .dataset_id , sample_ids = created_sample_ids
517526 )
518527
@@ -564,7 +573,7 @@ def add_samples_from_coco_caption(
564573 )
565574
566575 if embed :
567- _generate_embeddings (
576+ _generate_embeddings_image (
568577 session = self .session , dataset_id = self .dataset_id , sample_ids = created_sample_ids
569578 )
570579
@@ -635,7 +644,11 @@ def compute_similarity_metadata(
635644 )
636645
637646
638- def _generate_embeddings (session : Session , dataset_id : UUID , sample_ids : list [UUID ]) -> None :
647+ def _generate_embeddings_video (
648+ session : Session ,
649+ dataset_id : UUID ,
650+ sample_ids : list [UUID ],
651+ ) -> None :
639652 """Generate and store embeddings for samples.
640653
641654 Args:
@@ -647,20 +660,54 @@ def _generate_embeddings(session: Session, dataset_id: UUID, sample_ids: list[UU
647660 return
648661
649662 embedding_manager = EmbeddingManagerProvider .get_embedding_manager ()
650- model_id = embedding_manager .load_or_get_default_model (
663+ model_id = embedding_manager .load_or_get_default_model (session = session , dataset_id = dataset_id )
664+ if model_id is None :
665+ logger .warning ("No embedding model loaded. Skipping embedding generation." )
666+ return
667+
668+ embedding_manager .embed_videos (
651669 session = session ,
652670 dataset_id = dataset_id ,
671+ sample_ids = sample_ids ,
672+ embedding_model_id = model_id ,
653673 )
674+
675+ _mark_embedding_features_enabled ()
676+
677+
678+ def _generate_embeddings_image (
679+ session : Session ,
680+ dataset_id : UUID ,
681+ sample_ids : list [UUID ],
682+ ) -> None :
683+ """Generate and store embeddings for samples.
684+
685+ Args:
686+ session: Database session for resolver operations.
687+ dataset_id: The ID of the dataset to associate with the embedding model.
688+ sample_ids: List of sample IDs to generate embeddings for.
689+ sample_type: The sample_type to generate embeddings for.
690+ """
691+ if not sample_ids :
692+ return
693+
694+ embedding_manager = EmbeddingManagerProvider .get_embedding_manager ()
695+ model_id = embedding_manager .load_or_get_default_model (session = session , dataset_id = dataset_id )
654696 if model_id is None :
655697 logger .warning ("No embedding model loaded. Skipping embedding generation." )
656698 return
657699
658700 embedding_manager .embed_images (
659701 session = session ,
702+ dataset_id = dataset_id ,
660703 sample_ids = sample_ids ,
661704 embedding_model_id = model_id ,
662705 )
663706
707+ _mark_embedding_features_enabled ()
708+
709+
710+ def _mark_embedding_features_enabled () -> None :
664711 # Mark the embedding search feature as enabled.
665712 if "embeddingSearchEnabled" not in features .lightly_studio_active_features :
666713 features .lightly_studio_active_features .append ("embeddingSearchEnabled" )
0 commit comments