Skip to content

Commit 02b9f7f

Browse files
authored
Add embedding to video dataset (#274)
1 parent f7c6055 commit 02b9f7f

File tree

8 files changed

+305
-75
lines changed

8 files changed

+305
-75
lines changed

lightly_studio/src/lightly_studio/api/routes/api/text_embedding.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
EmbeddingManagerProvider,
1717
TextEmbedQuery,
1818
)
19+
from lightly_studio.db_manager import SessionDep
20+
from lightly_studio.resolvers import dataset_resolver
1921

2022
text_embedding_router = APIRouter()
2123
# Define a type alias for the EmbeddingManager dependency
@@ -27,6 +29,7 @@
2729

2830
@text_embedding_router.get("/text_embedding/embed_text", response_model=List[float])
2931
def embed_text(
32+
session: SessionDep,
3033
embedding_manager: EmbeddingManagerDep,
3134
query_text: str = Query(..., description="The text to embed."),
3235
embedding_model_id: Annotated[
@@ -35,9 +38,16 @@ def embed_text(
3538
] = None,
3639
) -> list[float]:
3740
"""Retrieve embeddings for the input text."""
41+
# TODO(Jonas, 12/2025): Remove this hack after dataset_id is provided from frontend
42+
# This is a hack: at the moment, no valid embedding_model_id is passed from the frontend,
43+
# so we fetch the root_dataset_id, which will be used inside embed_text to get the default model
44+
# for this dataset.
45+
root_dataset = dataset_resolver.get_root_dataset(session=session)
46+
dataset_id = root_dataset.dataset_id
3847
try:
3948
text_embeddings = embedding_manager.embed_text(
40-
TextEmbedQuery(query_text, embedding_model_id)
49+
dataset_id=dataset_id,
50+
text_query=TextEmbedQuery(text=query_text, embedding_model_id=embedding_model_id),
4151
)
4252
except ValueError as exc:
4353
raise HTTPException(

lightly_studio/src/lightly_studio/core/dataset.py

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ def add_videos_from_path(
271271
path: PathLike,
272272
allowed_extensions: Iterable[str] | None = None,
273273
num_decode_threads: int | None = None,
274+
embed: bool = True,
274275
) -> None:
275276
"""Add video frames from the specified path to the dataset.
276277
@@ -281,6 +282,7 @@ def add_videos_from_path(
281282
uses default VIDEO_EXTENSIONS.
282283
num_decode_threads: Optional override for the number of FFmpeg decode threads.
283284
If omitted, the available CPU cores - 1 (max 16) are used.
285+
embed: If True, generate embeddings for the newly added videos.
284286
"""
285287
# Collect video file paths.
286288
if allowed_extensions:
@@ -295,13 +297,20 @@ def add_videos_from_path(
295297
logger.info(f"Found {len(video_paths)} videos in {path}.")
296298

297299
# Process videos.
298-
add_videos.load_into_dataset_from_paths(
300+
created_sample_ids, _ = add_videos.load_into_dataset_from_paths(
299301
session=self.session,
300302
dataset_id=self.dataset_id,
301303
video_paths=video_paths,
302304
num_decode_threads=num_decode_threads,
303305
)
304306

307+
if embed:
308+
_generate_embeddings_video(
309+
session=self.session,
310+
dataset_id=self.dataset_id,
311+
sample_ids=created_sample_ids,
312+
)
313+
305314
def add_images_from_path(
306315
self,
307316
path: PathLike,
@@ -354,7 +363,7 @@ def add_images_from_path(
354363
)
355364

356365
if embed:
357-
_generate_embeddings(
366+
_generate_embeddings_image(
358367
session=self.session, dataset_id=self.dataset_id, sample_ids=created_sample_ids
359368
)
360369

@@ -383,7 +392,7 @@ def add_samples_from_labelformat(
383392
)
384393

385394
if embed:
386-
_generate_embeddings(
395+
_generate_embeddings_image(
387396
session=self.session, dataset_id=self.dataset_id, sample_ids=created_sample_ids
388397
)
389398

@@ -446,7 +455,7 @@ def add_samples_from_yolo(
446455

447456
# Generate embeddings for all samples at once
448457
if embed:
449-
_generate_embeddings(
458+
_generate_embeddings_image(
450459
session=self.session, dataset_id=self.dataset_id, sample_ids=all_created_sample_ids
451460
)
452461

@@ -512,7 +521,7 @@ def add_samples_from_coco(
512521
)
513522

514523
if embed:
515-
_generate_embeddings(
524+
_generate_embeddings_image(
516525
session=self.session, dataset_id=self.dataset_id, sample_ids=created_sample_ids
517526
)
518527

@@ -564,7 +573,7 @@ def add_samples_from_coco_caption(
564573
)
565574

566575
if embed:
567-
_generate_embeddings(
576+
_generate_embeddings_image(
568577
session=self.session, dataset_id=self.dataset_id, sample_ids=created_sample_ids
569578
)
570579

@@ -635,7 +644,11 @@ def compute_similarity_metadata(
635644
)
636645

637646

638-
def _generate_embeddings(session: Session, dataset_id: UUID, sample_ids: list[UUID]) -> None:
647+
def _generate_embeddings_video(
648+
session: Session,
649+
dataset_id: UUID,
650+
sample_ids: list[UUID],
651+
) -> None:
639652
"""Generate and store embeddings for video samples.
640653
641654
Args:
@@ -647,20 +660,54 @@ def _generate_embeddings(session: Session, dataset_id: UUID, sample_ids: list[UU
647660
return
648661

649662
embedding_manager = EmbeddingManagerProvider.get_embedding_manager()
650-
model_id = embedding_manager.load_or_get_default_model(
663+
model_id = embedding_manager.load_or_get_default_model(session=session, dataset_id=dataset_id)
664+
if model_id is None:
665+
logger.warning("No embedding model loaded. Skipping embedding generation.")
666+
return
667+
668+
embedding_manager.embed_videos(
651669
session=session,
652670
dataset_id=dataset_id,
671+
sample_ids=sample_ids,
672+
embedding_model_id=model_id,
653673
)
674+
675+
_mark_embedding_features_enabled()
676+
677+
678+
def _generate_embeddings_image(
679+
session: Session,
680+
dataset_id: UUID,
681+
sample_ids: list[UUID],
682+
) -> None:
683+
"""Generate and store embeddings for image samples.
684+
685+
Args:
686+
session: Database session for resolver operations.
687+
dataset_id: The ID of the dataset to associate with the embedding model.
688+
sample_ids: List of sample IDs to generate embeddings for.
689+
sample_type: The sample_type to generate embeddings for.
690+
"""
691+
if not sample_ids:
692+
return
693+
694+
embedding_manager = EmbeddingManagerProvider.get_embedding_manager()
695+
model_id = embedding_manager.load_or_get_default_model(session=session, dataset_id=dataset_id)
654696
if model_id is None:
655697
logger.warning("No embedding model loaded. Skipping embedding generation.")
656698
return
657699

658700
embedding_manager.embed_images(
659701
session=session,
702+
dataset_id=dataset_id,
660703
sample_ids=sample_ids,
661704
embedding_model_id=model_id,
662705
)
663706

707+
_mark_embedding_features_enabled()
708+
709+
710+
def _mark_embedding_features_enabled() -> None:
664711
# Mark the embedding search feature as enabled.
665712
if "embeddingSearchEnabled" not in features.lightly_studio_active_features:
666713
features.lightly_studio_active_features.append("embeddingSearchEnabled")

0 commit comments

Comments
 (0)