Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
508d262
New protocol VideoEmbeddingGenerator + Embedding manager video method
JonasWurst Dec 4, 2025
2e2a9b6
update test
JonasWurst Dec 4, 2025
b744070
add ...
JonasWurst Dec 4, 2025
a457d57
remove video embed form manger
JonasWurst Dec 4, 2025
1f3723a
Add video embed method to manager
JonasWurst Dec 4, 2025
8f9db53
Add tests
JonasWurst Dec 4, 2025
fe5bb81
Adding perception encoder image
JonasWurst Dec 4, 2025
0ad0067
add TODOs
JonasWurst Dec 4, 2025
4191dd3
Merge branch 'main' into jonas-lig-8204-video-embedding-model-impleme…
JonasWurst Dec 4, 2025
951d8e3
Merge commit '966805d2064029240b1a4cff78133b9819f69881' into jonas-li…
JonasWurst Dec 4, 2025
a5652f5
Merge commit 'bd32906b782b733dda02b1f703503435fd434700' into jonas-li…
JonasWurst Dec 4, 2025
357b435
Merge commit 'a5652f594af312cb8c054d7348d0c2867826dfbe' into jonas-li…
JonasWurst Dec 4, 2025
2698f64
Merge commit '357b435afcdd2a08541c09fed2a3f6e20d0ed10e' into jonas-li…
JonasWurst Dec 4, 2025
5a4aeb5
add video embedding and video dataset
JonasWurst Dec 4, 2025
0e6ad03
Tests
JonasWurst Dec 4, 2025
bfcf910
Merge commit '31abae36ef1680ffdb0463a7c699d7af7fab5665' into jonas-li…
JonasWurst Dec 5, 2025
df804ce
switch to pyav
JonasWurst Dec 5, 2025
fe85b9e
add video embedding to dataset
JonasWurst Dec 5, 2025
06bd377
Merge commit 'ad7bb9e6e02fa4ecd410f5292ff7394da0b1afe6' into jonas-li…
JonasWurst Dec 8, 2025
00f8bef
rename
JonasWurst Dec 8, 2025
166576c
Fix docstring
JonasWurst Dec 8, 2025
4b8a239
Fix test attempt
JonasWurst Dec 8, 2025
7b319f4
Fix attempt 2
JonasWurst Dec 8, 2025
19ca995
Fix test
JonasWurst Dec 8, 2025
aed135b
fix 3
JonasWurst Dec 8, 2025
16fca47
review comments
JonasWurst Dec 9, 2025
29b41b0
format
JonasWurst Dec 9, 2025
f6d0ad4
Merge branch 'main' into jonas-lig-8204-video-embedding-model-impleme…
JonasWurst Dec 9, 2025
22790b9
Merge commit 'f6d0ad400b28a8cb10bc6ced9199a36b49244dc1' into jonas-li…
JonasWurst Dec 9, 2025
078f6f0
format
JonasWurst Dec 9, 2025
ea2927c
Merge commit '26b777d326027f236ef0466f616b6cd532bfcff6' into jonas-li…
JonasWurst Dec 9, 2025
ce81e0e
Remove TODOs
JonasWurst Dec 9, 2025
15ea81e
Review fixes
JonasWurst Dec 10, 2025
145b285
Fix format
JonasWurst Dec 10, 2025
067b6b9
format
JonasWurst Dec 10, 2025
1c8d26e
Review changes 1
JonasWurst Dec 10, 2025
8a3c279
Review
JonasWurst Dec 10, 2025
75d3294
Comment
JonasWurst Dec 11, 2025
03239cf
Format
JonasWurst Dec 11, 2025
767d1c4
Merge branch 'main' into jonas-lig-8206-video-embedding-model-add-emb…
JonasWurst Dec 11, 2025
e5df4cd
Review changes
JonasWurst Dec 11, 2025
fcce43c
Merge branch 'main' into jonas-lig-8206-video-embedding-model-add-emb…
JonasWurst Dec 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
EmbeddingManagerProvider,
TextEmbedQuery,
)
from lightly_studio.db_manager import SessionDep
from lightly_studio.resolvers import dataset_resolver

text_embedding_router = APIRouter()
# Define a type alias for the EmbeddingManager dependency
Expand All @@ -27,6 +29,7 @@

@text_embedding_router.get("/text_embedding/embed_text", response_model=List[float])
def embed_text(
session: SessionDep,
embedding_manager: EmbeddingManagerDep,
query_text: str = Query(..., description="The text to embed."),
embedding_model_id: Annotated[
Expand All @@ -35,9 +38,16 @@ def embed_text(
] = None,
) -> list[float]:
"""Retrieve embeddings for the input text."""
# TODO(Jonas, 12/2025): Remove this hack after dataset_id is provided from frontend
# This is a hack, since at the moment, no valid embedding_model_id is passed from the frontend.
# so we fetch the root_dataset_id, which will be used inside embed_text to get the default model
# for this dataset.
root_dataset = dataset_resolver.get_root_dataset(session=session)
dataset_id = root_dataset.dataset_id
try:
text_embeddings = embedding_manager.embed_text(
TextEmbedQuery(query_text, embedding_model_id)
dataset_id=dataset_id,
text_query=TextEmbedQuery(text=query_text, embedding_model_id=embedding_model_id),
)
except ValueError as exc:
raise HTTPException(
Expand Down
63 changes: 55 additions & 8 deletions lightly_studio/src/lightly_studio/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def add_videos_from_path(
path: PathLike,
allowed_extensions: Iterable[str] | None = None,
num_decode_threads: int | None = None,
embed: bool = True,
) -> None:
"""Adding video frames from the specified path to the dataset.

Expand All @@ -281,6 +282,7 @@ def add_videos_from_path(
uses default VIDEO_EXTENSIONS.
num_decode_threads: Optional override for the number of FFmpeg decode threads.
If omitted, the available CPU cores - 1 (max 16) are used.
embed: If True, generate embeddings for the newly added videos.
"""
# Collect video file paths.
if allowed_extensions:
Expand All @@ -295,13 +297,20 @@ def add_videos_from_path(
logger.info(f"Found {len(video_paths)} videos in {path}.")

# Process videos.
add_videos.load_into_dataset_from_paths(
created_sample_ids, _ = add_videos.load_into_dataset_from_paths(
session=self.session,
dataset_id=self.dataset_id,
video_paths=video_paths,
num_decode_threads=num_decode_threads,
)

if embed:
_generate_embeddings_video(
session=self.session,
dataset_id=self.dataset_id,
sample_ids=created_sample_ids,
)

def add_images_from_path(
self,
path: PathLike,
Expand Down Expand Up @@ -354,7 +363,7 @@ def add_images_from_path(
)

if embed:
_generate_embeddings(
_generate_embeddings_image(
session=self.session, dataset_id=self.dataset_id, sample_ids=created_sample_ids
)

Expand Down Expand Up @@ -383,7 +392,7 @@ def add_samples_from_labelformat(
)

if embed:
_generate_embeddings(
_generate_embeddings_image(
session=self.session, dataset_id=self.dataset_id, sample_ids=created_sample_ids
)

Expand Down Expand Up @@ -446,7 +455,7 @@ def add_samples_from_yolo(

# Generate embeddings for all samples at once
if embed:
_generate_embeddings(
_generate_embeddings_image(
session=self.session, dataset_id=self.dataset_id, sample_ids=all_created_sample_ids
)

Expand Down Expand Up @@ -512,7 +521,7 @@ def add_samples_from_coco(
)

if embed:
_generate_embeddings(
_generate_embeddings_image(
session=self.session, dataset_id=self.dataset_id, sample_ids=created_sample_ids
)

Expand Down Expand Up @@ -564,7 +573,7 @@ def add_samples_from_coco_caption(
)

if embed:
_generate_embeddings(
_generate_embeddings_image(
session=self.session, dataset_id=self.dataset_id, sample_ids=created_sample_ids
)

Expand Down Expand Up @@ -635,7 +644,11 @@ def compute_similarity_metadata(
)


def _generate_embeddings(session: Session, dataset_id: UUID, sample_ids: list[UUID]) -> None:
def _generate_embeddings_video(
session: Session,
dataset_id: UUID,
sample_ids: list[UUID],
) -> None:
"""Generate and store embeddings for samples.

Args:
Expand All @@ -647,20 +660,54 @@ def _generate_embeddings(session: Session, dataset_id: UUID, sample_ids: list[UU
return

embedding_manager = EmbeddingManagerProvider.get_embedding_manager()
model_id = embedding_manager.load_or_get_default_model(
model_id = embedding_manager.load_or_get_default_model(session=session, dataset_id=dataset_id)
if model_id is None:
logger.warning("No embedding model loaded. Skipping embedding generation.")
return

embedding_manager.embed_videos(
session=session,
dataset_id=dataset_id,
sample_ids=sample_ids,
embedding_model_id=model_id,
)

_mark_embedding_features_enabled()


def _generate_embeddings_image(
session: Session,
dataset_id: UUID,
sample_ids: list[UUID],
) -> None:
"""Generate and store embeddings for samples.

Args:
session: Database session for resolver operations.
dataset_id: The ID of the dataset to associate with the embedding model.
sample_ids: List of sample IDs to generate embeddings for.
sample_type: The sample_type to generate embeddings for.
"""
if not sample_ids:
return

embedding_manager = EmbeddingManagerProvider.get_embedding_manager()
model_id = embedding_manager.load_or_get_default_model(session=session, dataset_id=dataset_id)
if model_id is None:
logger.warning("No embedding model loaded. Skipping embedding generation.")
return

embedding_manager.embed_images(
session=session,
dataset_id=dataset_id,
sample_ids=sample_ids,
embedding_model_id=model_id,
)

_mark_embedding_features_enabled()


def _mark_embedding_features_enabled() -> None:
# Mark the embedding search feature as enabled.
if "embeddingSearchEnabled" not in features.lightly_studio_active_features:
features.lightly_studio_active_features.append("embeddingSearchEnabled")
Expand Down
Loading