Skip to content
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
508d262
New protocol VideoEmbeddingGenerator + Embedding manager video method
JonasWurst Dec 4, 2025
2e2a9b6
update test
JonasWurst Dec 4, 2025
b744070
add ...
JonasWurst Dec 4, 2025
a457d57
remove video embed form manger
JonasWurst Dec 4, 2025
1f3723a
Add video embed method to manager
JonasWurst Dec 4, 2025
8f9db53
Add tests
JonasWurst Dec 4, 2025
fe5bb81
Adding perception encoder image
JonasWurst Dec 4, 2025
0ad0067
add TODOs
JonasWurst Dec 4, 2025
4191dd3
Merge branch 'main' into jonas-lig-8204-video-embedding-model-impleme…
JonasWurst Dec 4, 2025
951d8e3
Merge commit '966805d2064029240b1a4cff78133b9819f69881' into jonas-li…
JonasWurst Dec 4, 2025
a5652f5
Merge commit 'bd32906b782b733dda02b1f703503435fd434700' into jonas-li…
JonasWurst Dec 4, 2025
357b435
Merge commit 'a5652f594af312cb8c054d7348d0c2867826dfbe' into jonas-li…
JonasWurst Dec 4, 2025
2698f64
Merge commit '357b435afcdd2a08541c09fed2a3f6e20d0ed10e' into jonas-li…
JonasWurst Dec 4, 2025
5a4aeb5
add video embedding and video dataset
JonasWurst Dec 4, 2025
0e6ad03
Tests
JonasWurst Dec 4, 2025
bfcf910
Merge commit '31abae36ef1680ffdb0463a7c699d7af7fab5665' into jonas-li…
JonasWurst Dec 5, 2025
df804ce
switch to pyav
JonasWurst Dec 5, 2025
fe85b9e
add video embedding to dataset
JonasWurst Dec 5, 2025
06bd377
Merge commit 'ad7bb9e6e02fa4ecd410f5292ff7394da0b1afe6' into jonas-li…
JonasWurst Dec 8, 2025
00f8bef
rename
JonasWurst Dec 8, 2025
166576c
Fix docstring
JonasWurst Dec 8, 2025
4b8a239
Fix test attempt
JonasWurst Dec 8, 2025
7b319f4
Fix attempt 2
JonasWurst Dec 8, 2025
19ca995
Fix test
JonasWurst Dec 8, 2025
aed135b
fix 3
JonasWurst Dec 8, 2025
16fca47
review comments
JonasWurst Dec 9, 2025
29b41b0
format
JonasWurst Dec 9, 2025
f6d0ad4
Merge branch 'main' into jonas-lig-8204-video-embedding-model-impleme…
JonasWurst Dec 9, 2025
22790b9
Merge commit 'f6d0ad400b28a8cb10bc6ced9199a36b49244dc1' into jonas-li…
JonasWurst Dec 9, 2025
078f6f0
format
JonasWurst Dec 9, 2025
ea2927c
Merge commit '26b777d326027f236ef0466f616b6cd532bfcff6' into jonas-li…
JonasWurst Dec 9, 2025
ce81e0e
Remove TODOs
JonasWurst Dec 9, 2025
15ea81e
Review fixes
JonasWurst Dec 10, 2025
145b285
Fix format
JonasWurst Dec 10, 2025
067b6b9
format
JonasWurst Dec 10, 2025
1c8d26e
Review changes 1
JonasWurst Dec 10, 2025
8a3c279
Review
JonasWurst Dec 10, 2025
75d3294
Comment
JonasWurst Dec 11, 2025
03239cf
Format
JonasWurst Dec 11, 2025
767d1c4
Merge branch 'main' into jonas-lig-8206-video-embedding-model-add-emb…
JonasWurst Dec 11, 2025
e5df4cd
Review changes
JonasWurst Dec 11, 2025
fcce43c
Merge branch 'main' into jonas-lig-8206-video-embedding-model-add-emb…
JonasWurst Dec 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 47 additions & 16 deletions lightly_studio/src/lightly_studio/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ def load(name: str | None = None) -> Dataset:
raise ValueError(f"Dataset with name '{name}' not found.")
# If we have embeddings in the database enable the FSC and embedding search features.
_enable_embedding_features_if_available(
session=db_manager.persistent_session(), dataset_id=dataset.dataset_id
session=db_manager.persistent_session(),
dataset_id=dataset.dataset_id,
sample_type=dataset.sample_type,
)
return Dataset(dataset=dataset)

Expand Down Expand Up @@ -165,7 +167,9 @@ def load_or_create(

# If we have embeddings in the database enable the FSC and embedding search features.
_enable_embedding_features_if_available(
session=db_manager.persistent_session(), dataset_id=dataset.dataset_id
session=db_manager.persistent_session(),
dataset_id=dataset.dataset_id,
sample_type=dataset.sample_type,
)
return Dataset(dataset=dataset)

Expand Down Expand Up @@ -271,6 +275,7 @@ def add_videos_from_path(
path: PathLike,
allowed_extensions: Iterable[str] | None = None,
num_decode_threads: int | None = None,
embed_videos: bool = True,
) -> None:
"""Adding video frames from the specified path to the dataset.

Expand All @@ -281,6 +286,7 @@ def add_videos_from_path(
uses default VIDEO_EXTENSIONS.
num_decode_threads: Optional override for the number of FFmpeg decode threads.
If omitted, the available CPU cores - 1 (max 16) are used.
embed_videos: If True, generate embeddings for the newly added videos.
"""
# Collect video file paths.
if allowed_extensions:
Expand All @@ -295,13 +301,21 @@ def add_videos_from_path(
logger.info(f"Found {len(video_paths)} videos in {path}.")

# Process videos.
add_videos.load_into_dataset_from_paths(
created_video_sample_ids, _ = add_videos.load_into_dataset_from_paths(
session=self.session,
dataset_id=self.dataset_id,
video_paths=video_paths,
num_decode_threads=num_decode_threads,
)

if embed_videos:
_generate_embeddings(
session=self.session,
dataset_id=self.dataset_id,
sample_ids=created_video_sample_ids,
sample_type=SampleType.VIDEO,
)

def add_images_from_path(
self,
path: PathLike,
Expand Down Expand Up @@ -635,31 +649,43 @@ def compute_similarity_metadata(
)


def _generate_embeddings(session: Session, dataset_id: UUID, sample_ids: list[UUID]) -> None:
def _generate_embeddings(
session: Session,
dataset_id: UUID,
sample_ids: list[UUID],
sample_type: SampleType = SampleType.IMAGE,
) -> None:
"""Generate and store embeddings for samples.

Args:
session: Database session for resolver operations.
dataset_id: The ID of the dataset to associate with the embedding model.
sample_ids: List of sample IDs to generate embeddings for.
sample_type: The sample_type to generate embeddings for.
"""
if not sample_ids:
return

embedding_manager = EmbeddingManagerProvider.get_embedding_manager()
model_id = embedding_manager.load_or_get_default_model(
session=session,
dataset_id=dataset_id,
session=session, dataset_id=dataset_id, sample_type=sample_type
)
if model_id is None:
logger.warning("No embedding model loaded. Skipping embedding generation.")
return

embedding_manager.embed_images(
session=session,
sample_ids=sample_ids,
embedding_model_id=model_id,
)
if sample_type == SampleType.IMAGE:
embedding_manager.embed_images(
session=session,
sample_ids=sample_ids,
embedding_model_id=model_id,
)
elif sample_type == SampleType.VIDEO:
embedding_manager.embed_videos(
session=session,
sample_ids=sample_ids,
embedding_model_id=model_id,
)

# Mark the embedding search feature as enabled.
if "embeddingSearchEnabled" not in features.lightly_studio_active_features:
Expand Down Expand Up @@ -689,20 +715,22 @@ def _resolve_yolo_splits(data_yaml: Path, input_split: str | None) -> list[str]:
return splits


def _are_embeddings_available(session: Session, dataset_id: UUID) -> bool:
def _are_embeddings_available(
session: Session, dataset_id: UUID, sample_type: SampleType = SampleType.IMAGE
) -> bool:
"""Check if there are any embeddings available for the given dataset.

Args:
session: Database session for resolver operations.
dataset_id: The ID of the dataset to check for embeddings.
sample_type: the sample_type of the dataset.

Returns:
True if embeddings exist for the dataset, False otherwise.
"""
embedding_manager = EmbeddingManagerProvider.get_embedding_manager()
model_id = embedding_manager.load_or_get_default_model(
session=session,
dataset_id=dataset_id,
session=session, dataset_id=dataset_id, sample_type=sample_type
)
if model_id is None:
# No default embedding model loaded for this dataset.
Expand All @@ -718,14 +746,17 @@ def _are_embeddings_available(session: Session, dataset_id: UUID) -> bool:
)


def _enable_embedding_features_if_available(session: Session, dataset_id: UUID) -> None:
def _enable_embedding_features_if_available(
session: Session, dataset_id: UUID, sample_type: SampleType
) -> None:
"""Enable embedding-related features if embeddings are available in the DB.

Args:
session: Database session for resolver operations.
dataset_id: The ID of the dataset to check for embeddings.
sample_type: The sample_type of the dataset.
"""
if _are_embeddings_available(session=session, dataset_id=dataset_id):
if _are_embeddings_available(session=session, dataset_id=dataset_id, sample_type=sample_type):
if "embeddingSearchEnabled" not in features.lightly_studio_active_features:
features.lightly_studio_active_features.append("embeddingSearchEnabled")
if "fewShotClassifierEnabled" not in features.lightly_studio_active_features:
Expand Down
86 changes: 65 additions & 21 deletions lightly_studio/src/lightly_studio/dataset/embedding_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ImageEmbeddingGenerator,
VideoEmbeddingGenerator,
)
from lightly_studio.models.dataset import SampleType
from lightly_studio.models.embedding_model import EmbeddingModelTable
from lightly_studio.models.sample_embedding import SampleEmbeddingCreate
from lightly_studio.resolvers import (
Expand Down Expand Up @@ -60,13 +61,14 @@ class EmbeddingManager:
def __init__(self) -> None:
"""Initialize the embedding manager."""
self._models: dict[UUID, EmbeddingGenerator] = {}
self._default_model_id: UUID | None = None
self._default_model_id: dict[SampleType, UUID | None] = {}

def register_embedding_model(
self,
session: Session,
dataset_id: UUID,
embedding_generator: EmbeddingGenerator,
sample_type: SampleType = SampleType.IMAGE,
set_as_default: bool = False,
) -> EmbeddingModelTable:
"""Register an embedding model in the database.
Expand All @@ -78,6 +80,7 @@ def register_embedding_model(
session: Database session for resolver operations.
dataset_id: The ID of the dataset to associate with the model.
embedding_generator: The model implementation used for embeddings.
sample_type: The SampleType the model is compatible with.
set_as_default: Whether to set this model as the default.

Returns:
Expand All @@ -94,21 +97,24 @@ def register_embedding_model(
self._models[model_id] = embedding_generator

# Set as default if requested or if it's the first model
if set_as_default or self._default_model_id is None:
self._default_model_id = model_id
if set_as_default or self._default_model_id.get(sample_type, None) is None:
self._default_model_id[sample_type] = model_id

return db_model

def embed_text(self, text_query: TextEmbedQuery) -> list[float]:
def embed_text(
self, text_query: TextEmbedQuery, sample_type: SampleType = SampleType.IMAGE
) -> list[float]:
"""Generate an embedding for a text sample.

Args:
text_query: Text embedding query containing text and model ID.
sample_type: The sample_type the default model is registered for.

Returns:
A list of floats representing the generated embedding.
"""
model_id = text_query.embedding_model_id or self._default_model_id
model_id = text_query.embedding_model_id or self._default_model_id.get(sample_type, None)
if model_id is None:
raise ValueError("No embedding model specified and no default model set.")

Expand All @@ -135,7 +141,7 @@ def embed_images(
ValueError: If no embedding model is registered, provided model
ID doesn't exist or if the embedding model does not support images.
"""
model_id = self._get_default_or_validate(embedding_model_id)
model_id = self._get_default_or_validate(embedding_model_id, sample_type=SampleType.IMAGE)

model = self._models[model_id]
if not isinstance(model, ImageEmbeddingGenerator):
Expand Down Expand Up @@ -186,7 +192,7 @@ def embed_videos(
ValueError: If no embedding model is registered, provided model
ID doesn't exist or if the embedding model does not support videos.
"""
model_id = self._get_default_or_validate(embedding_model_id)
model_id = self._get_default_or_validate(embedding_model_id, sample_type=SampleType.VIDEO)

model = self._models[model_id]
if not isinstance(model, VideoEmbeddingGenerator):
Expand Down Expand Up @@ -218,29 +224,30 @@ def embed_videos(
# Store the embeddings in the database.
sample_embedding_resolver.create_many(session=session, sample_embeddings=sample_embeddings)

# TODO (Jonas 12/2025): We need to introduce default models per type
def load_or_get_default_model(
self,
session: Session,
dataset_id: UUID,
sample_type: SampleType = SampleType.IMAGE,
) -> UUID | None:
"""Ensure a default embedding model exists and return its ID.

Args:
session: Database session for resolver operations.
dataset_id: Dataset identifier the model should belong to.
sample_type: SampleType the model should be compatible with.

Returns:
UUID of the default embedding model or None if the model cannot be loaded.
"""
# Return the existing default model ID if available.
# TODO(Michal, 09/2025): We do not check if the model belongs to the dataset.
# The design of EmbeddingManager needs to change to support multiple datasets.
if self._default_model_id is not None:
return self._default_model_id
if self._default_model_id.get(sample_type, None) is not None:
return self._default_model_id[sample_type]

# Load the embedding generator based on configuration.
embedding_generator = _load_embedding_generator_from_env()
embedding_generator = _load_embedding_generator_from_env(sample_type=sample_type)
if embedding_generator is None:
return None

Expand All @@ -250,51 +257,88 @@ def load_or_get_default_model(
dataset_id=dataset_id,
embedding_generator=embedding_generator,
set_as_default=True,
sample_type=sample_type,
)

return embedding_model.embedding_model_id

# TODO (Jonas 12/2025): We need to introduce default models per type
def _get_default_or_validate(self, embedding_model_id: UUID | None) -> UUID:
if embedding_model_id is None and self._default_model_id is None:
def _get_default_or_validate(
self, embedding_model_id: UUID | None, sample_type: SampleType = SampleType.IMAGE
) -> UUID:
default_model_id = self._default_model_id.get(sample_type, None)
if embedding_model_id is None and default_model_id is None:
raise ValueError(
"No embedding_model_id provided and no default embedding model registered."
)

if embedding_model_id is None and self._default_model_id is not None:
return self._default_model_id
if embedding_model_id is None and default_model_id is not None:
return default_model_id

if embedding_model_id not in self._models:
raise ValueError(f"No embedding model found with ID {embedding_model_id}")
return embedding_model_id


# TODO(Michal, 09/2025): Write tests for this function.
def _load_embedding_generator_from_env() -> EmbeddingGenerator | None:
def _load_embedding_generator_from_env(sample_type: SampleType) -> EmbeddingGenerator | None:
"""Load the embedding generator based on environment variable configuration."""
if sample_type == SampleType.IMAGE:
return _load_image_embedding_generator_from_env()
if sample_type == SampleType.VIDEO:
return _load_video_embedding_generator_from_env()
return None


def _load_image_embedding_generator_from_env() -> EmbeddingGenerator | None:
if env.LIGHTLY_STUDIO_EMBEDDINGS_MODEL_TYPE == "EDGE":
try:
from lightly_studio.dataset.edge_embedding_generator import (
EdgeSDKEmbeddingGenerator,
)

logger.info("Using LightlyEdge embedding generator.")
logger.info("Using LightlyEdge embedding generator for images.")
return EdgeSDKEmbeddingGenerator(model_path=env.LIGHTLY_STUDIO_EDGE_MODEL_FILE_PATH)
except ImportError:
logger.warning("Embedding functionality is disabled.")
return None
elif env.LIGHTLY_STUDIO_EMBEDDINGS_MODEL_TYPE == "MOBILE_CLIP":
try:
from lightly_studio.dataset.mobileclip_embedding_generator import (
MobileCLIPEmbeddingGenerator,
)

logger.info("Using MobileCLIP embedding generator.")
logger.info("Using MobileCLIP embedding generator for images.")
return MobileCLIPEmbeddingGenerator()
except ImportError:
logger.warning("Embedding functionality is disabled.")
elif env.LIGHTLY_STUDIO_EMBEDDINGS_MODEL_TYPE == "PE":
try:
from lightly_studio.dataset.perception_encoder_embedding_generator import (
PerceptionEncoderEmbeddingGenerator,
)

logger.info("Using PerceptionEncoder embedding generator for images.")
return PerceptionEncoderEmbeddingGenerator()
except ImportError:
logger.warning("Embedding functionality is disabled.")
else:
logger.warning(f"Unsupported model type: '{env.LIGHTLY_STUDIO_EMBEDDINGS_MODEL_TYPE}'")
logger.warning("Embedding functionality is disabled.")
return None


def _load_video_embedding_generator_from_env() -> EmbeddingGenerator | None:
if env.LIGHTLY_STUDIO_VIDEO_EMBEDDINGS_MODEL_TYPE == "PE":
try:
from lightly_studio.dataset.perception_encoder_embedding_generator import (
PerceptionEncoderEmbeddingGenerator,
)

logger.info("Using PerceptionEncoder embedding generator for videos.")
return PerceptionEncoderEmbeddingGenerator()
except ImportError:
logger.warning("Embedding functionality is disabled.")
return None

logger.warning(f"Unsupported model type: '{env.LIGHTLY_STUDIO_EMBEDDINGS_MODEL_TYPE}'")
logger.warning(f"Unsupported model type: '{env.LIGHTLY_STUDIO_VIDEO_EMBEDDINGS_MODEL_TYPE}'")
logger.warning("Embedding functionality is disabled.")
return None
3 changes: 3 additions & 0 deletions lightly_studio/src/lightly_studio/dataset/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
LIGHTLY_STUDIO_EMBEDDINGS_MODEL_TYPE: str = env.str(
"LIGHTLY_STUDIO_EMBEDDINGS_MODEL_TYPE", "MOBILE_CLIP"
)
LIGHTLY_STUDIO_VIDEO_EMBEDDINGS_MODEL_TYPE: str = env.str(
"LIGHTLY_STUDIO_VIDEO_EMBEDDINGS_MODEL_TYPE", "PE"
)
LIGHTLY_STUDIO_EDGE_MODEL_FILE_PATH: str = env.str("EDGE_MODEL_PATH", "./lightly_model.tar")
LIGHTLY_STUDIO_PROTOCOL: str = env.str("LIGHTLY_STUDIO_PROTOCOL", "http")
LIGHTLY_STUDIO_PORT: int = env.int("LIGHTLY_STUDIO_PORT", 8001)
Expand Down
Loading