diff --git a/lightly_studio/src/lightly_studio/api/routes/api/text_embedding.py b/lightly_studio/src/lightly_studio/api/routes/api/text_embedding.py index fb5325119..a65fed18d 100644 --- a/lightly_studio/src/lightly_studio/api/routes/api/text_embedding.py +++ b/lightly_studio/src/lightly_studio/api/routes/api/text_embedding.py @@ -16,6 +16,8 @@ EmbeddingManagerProvider, TextEmbedQuery, ) +from lightly_studio.db_manager import SessionDep +from lightly_studio.resolvers import dataset_resolver text_embedding_router = APIRouter() # Define a type alias for the EmbeddingManager dependency @@ -27,6 +29,7 @@ @text_embedding_router.get("/text_embedding/embed_text", response_model=List[float]) def embed_text( + session: SessionDep, embedding_manager: EmbeddingManagerDep, query_text: str = Query(..., description="The text to embed."), embedding_model_id: Annotated[ @@ -35,9 +38,16 @@ def embed_text( ] = None, ) -> list[float]: """Retrieve embeddings for the input text.""" + # TODO(Jonas, 12/2025): Remove this hack after dataset_id is provided from frontend + # This is a hack, since at the moment, no valid embedding_model_id is passed from the frontend. + # so we fetch the root_dataset_id, which will be used inside embed_text to get the default model + # for this dataset. + root_dataset = dataset_resolver.get_root_dataset(session=session) + dataset_id = root_dataset.dataset_id try: text_embeddings = embedding_manager.embed_text( - TextEmbedQuery(query_text, embedding_model_id) + dataset_id=dataset_id, + text_query=TextEmbedQuery(text=query_text, embedding_model_id=embedding_model_id), ) except ValueError as exc: raise HTTPException( diff --git a/lightly_studio/src/lightly_studio/core/dataset.py b/lightly_studio/src/lightly_studio/core/dataset.py index e7a1f8469..7dff4d470 100644 --- a/lightly_studio/src/lightly_studio/core/dataset.py +++ b/lightly_studio/src/lightly_studio/core/dataset.py @@ -271,6 +271,7 @@ def add_videos_from_path( path: PathLike, allowed_extensions: Iterable[str] | None = None, num_decode_threads: int | None = None, + embed: bool = True, ) -> None: """Adding video frames from the specified path to the dataset. @@ -281,6 +282,7 @@ def add_videos_from_path( uses default VIDEO_EXTENSIONS. num_decode_threads: Optional override for the number of FFmpeg decode threads. If omitted, the available CPU cores - 1 (max 16) are used. + embed: If True, generate embeddings for the newly added videos. """ # Collect video file paths. if allowed_extensions: @@ -295,13 +297,20 @@ def add_videos_from_path( logger.info(f"Found {len(video_paths)} videos in {path}.") # Process videos. - add_videos.load_into_dataset_from_paths( + created_sample_ids, _ = add_videos.load_into_dataset_from_paths( session=self.session, dataset_id=self.dataset_id, video_paths=video_paths, num_decode_threads=num_decode_threads, ) + if embed: + _generate_embeddings_video( + session=self.session, + dataset_id=self.dataset_id, + sample_ids=created_sample_ids, + ) + def add_images_from_path( self, path: PathLike, @@ -354,7 +363,7 @@ def add_images_from_path( ) if embed: - _generate_embeddings( + _generate_embeddings_image( session=self.session, dataset_id=self.dataset_id, sample_ids=created_sample_ids ) @@ -383,7 +392,7 @@ def add_samples_from_labelformat( ) if embed: - _generate_embeddings( + _generate_embeddings_image( session=self.session, dataset_id=self.dataset_id, sample_ids=created_sample_ids ) @@ -446,7 +455,7 @@ def add_samples_from_yolo( # Generate embeddings for all samples at once if embed: - _generate_embeddings( + _generate_embeddings_image( session=self.session, dataset_id=self.dataset_id, sample_ids=all_created_sample_ids ) @@ -512,7 +521,7 @@ def add_samples_from_coco( ) if embed: - _generate_embeddings( + _generate_embeddings_image( session=self.session, dataset_id=self.dataset_id, sample_ids=created_sample_ids ) @@ -564,7 +573,7 @@ def add_samples_from_coco_caption( ) if embed: - _generate_embeddings( + _generate_embeddings_image( session=self.session, dataset_id=self.dataset_id, sample_ids=created_sample_ids ) @@ -635,7 +644,11 @@ def compute_similarity_metadata( ) -def _generate_embeddings(session: Session, dataset_id: UUID, sample_ids: list[UUID]) -> None: +def _generate_embeddings_video( + session: Session, + dataset_id: UUID, + sample_ids: list[UUID], +) -> None: """Generate and store embeddings for samples. Args: @@ -647,20 +660,54 @@ def _generate_embeddings(session: Session, dataset_id: UUID, sample_ids: list[UU return embedding_manager = EmbeddingManagerProvider.get_embedding_manager() - model_id = embedding_manager.load_or_get_default_model( + model_id = embedding_manager.load_or_get_default_model(session=session, dataset_id=dataset_id) + if model_id is None: + logger.warning("No embedding model loaded. Skipping embedding generation.") + return + + embedding_manager.embed_videos( session=session, dataset_id=dataset_id, + sample_ids=sample_ids, + embedding_model_id=model_id, ) + + _mark_embedding_features_enabled() + + +def _generate_embeddings_image( + session: Session, + dataset_id: UUID, + sample_ids: list[UUID], +) -> None: + """Generate and store embeddings for samples. + + Args: + session: Database session for resolver operations. + dataset_id: The ID of the dataset to associate with the embedding model. + sample_ids: List of sample IDs to generate embeddings for. + sample_type: The sample_type to generate embeddings for. + """ + if not sample_ids: + return + + embedding_manager = EmbeddingManagerProvider.get_embedding_manager() + model_id = embedding_manager.load_or_get_default_model(session=session, dataset_id=dataset_id) if model_id is None: logger.warning("No embedding model loaded. Skipping embedding generation.") return embedding_manager.embed_images( session=session, + dataset_id=dataset_id, sample_ids=sample_ids, embedding_model_id=model_id, ) + _mark_embedding_features_enabled() + + +def _mark_embedding_features_enabled() -> None: # Mark the embedding search feature as enabled. if "embeddingSearchEnabled" not in features.lightly_studio_active_features: features.lightly_studio_active_features.append("embeddingSearchEnabled") diff --git a/lightly_studio/src/lightly_studio/dataset/embedding_manager.py b/lightly_studio/src/lightly_studio/dataset/embedding_manager.py index b9f5a993e..ff8d1b25d 100644 --- a/lightly_studio/src/lightly_studio/dataset/embedding_manager.py +++ b/lightly_studio/src/lightly_studio/dataset/embedding_manager.py @@ -14,9 +14,11 @@ ImageEmbeddingGenerator, VideoEmbeddingGenerator, ) +from lightly_studio.models.dataset import SampleType from lightly_studio.models.embedding_model import EmbeddingModelTable from lightly_studio.models.sample_embedding import SampleEmbeddingCreate from lightly_studio.resolvers import ( + dataset_resolver, embedding_model_resolver, image_resolver, sample_embedding_resolver, @@ -60,7 +62,7 @@ class EmbeddingManager: def __init__(self) -> None: """Initialize the embedding manager.""" self._models: dict[UUID, EmbeddingGenerator] = {} - self._default_model_id: UUID | None = None + self._dataset_id_to_default_model_id: dict[UUID, UUID] = {} def register_embedding_model( self, @@ -77,6 +79,7 @@ def register_embedding_model( Args: session: Database session for resolver operations. dataset_id: The ID of the dataset to associate with the model. + And to register as default, if requested. embedding_generator: The model implementation used for embeddings. set_as_default: Whether to set this model as the default. @@ -94,33 +97,34 @@ def register_embedding_model( self._models[model_id] = embedding_generator # Set as default if requested or if it's the first model - if set_as_default or self._default_model_id is None: - self._default_model_id = model_id + if set_as_default or dataset_id not in self._dataset_id_to_default_model_id: + self._dataset_id_to_default_model_id[dataset_id] = model_id return db_model - def embed_text(self, text_query: TextEmbedQuery) -> list[float]: + def embed_text(self, dataset_id: UUID, text_query: TextEmbedQuery) -> list[float]: """Generate an embedding for a text sample. Args: + dataset_id: The ID of the dataset to determine the registered default model. + It is used if embedding_model_id is not valid. text_query: Text embedding query containing text and model ID. Returns: A list of floats representing the generated embedding. """ - model_id = text_query.embedding_model_id or self._default_model_id - if model_id is None: - raise ValueError("No embedding model specified and no default model set.") + model_id = self._get_default_or_validate( + dataset_id=dataset_id, embedding_model_id=text_query.embedding_model_id + ) - model = self._models.get(model_id) - if model is None: - raise ValueError(f"Embedding model with ID {model_id} not found.") + model = self._models[model_id] return model.embed_text(text_query.text) def embed_images( self, session: Session, + dataset_id: UUID, sample_ids: list[UUID], embedding_model_id: UUID | None = None, ) -> None: @@ -128,6 +132,8 @@ def embed_images( Args: session: Database session for resolver operations. + dataset_id: The ID of the dataset to determine the registered default model. + It is used if embedding_model_id is not valid. sample_ids: List of sample IDs to generate embeddings for. embedding_model_id: ID of the model to use. Uses default if None. @@ -135,7 +141,9 @@ def embed_images( ValueError: If no embedding model is registered, provided model ID doesn't exist or if the embedding model does not support images. """ - model_id = self._get_default_or_validate(embedding_model_id) + model_id = self._get_default_or_validate( + dataset_id=dataset_id, embedding_model_id=embedding_model_id + ) model = self._models[model_id] if not isinstance(model, ImageEmbeddingGenerator): @@ -172,6 +180,7 @@ def embed_images( def embed_videos( self, session: Session, + dataset_id: UUID, sample_ids: list[UUID], embedding_model_id: UUID | None = None, ) -> None: @@ -179,6 +188,8 @@ def embed_videos( Args: session: Database session for resolver operations. + dataset_id: The ID of the dataset to determine the registered default model. + It is used if embedding_model_id is not valid. sample_ids: List of sample IDs to generate embeddings for. embedding_model_id: ID of the model to use. Uses default if None. @@ -186,7 +197,9 @@ def embed_videos( ValueError: If no embedding model is registered, provided model ID doesn't exist or if the embedding model does not support videos. """ - model_id = self._get_default_or_validate(embedding_model_id) + model_id = self._get_default_or_validate( + dataset_id=dataset_id, embedding_model_id=embedding_model_id + ) model = self._models[model_id] if not isinstance(model, VideoEmbeddingGenerator): @@ -218,7 +231,6 @@ def embed_videos( # Store the embeddings in the database. sample_embedding_resolver.create_many(session=session, sample_embeddings=sample_embeddings) - # TODO (Jonas 12/2025): We need to introduce default models per type def load_or_get_default_model( self, session: Session, @@ -234,13 +246,16 @@ def load_or_get_default_model( UUID of the default embedding model or None if the model cannot be loaded. """ # Return the existing default model ID if available. - # TODO(Michal, 09/2025): We do not check if the model belongs to the dataset. - # The design of EmbeddingManager needs to change to support multiple datasets. - if self._default_model_id is not None: - return self._default_model_id - # Load the embedding generator based on configuration. - embedding_generator = _load_embedding_generator_from_env() + if dataset_id in self._dataset_id_to_default_model_id: + return self._dataset_id_to_default_model_id[dataset_id] + + # Load the embedding generator based on sample_type from the env var. + dataset = dataset_resolver.get_by_id(session=session, dataset_id=dataset_id) + if dataset is None: + raise ValueError("Provided dataset_id could not be found.") + + embedding_generator = _load_embedding_generator_from_env(sample_type=dataset.sample_type) if embedding_generator is None: return None @@ -254,47 +269,81 @@ def load_or_get_default_model( return embedding_model.embedding_model_id - # TODO (Jonas 12/2025): We need to introduce default models per type - def _get_default_or_validate(self, embedding_model_id: UUID | None) -> UUID: - if embedding_model_id is None and self._default_model_id is None: + def _get_default_or_validate(self, dataset_id: UUID, embedding_model_id: UUID | None) -> UUID: + """Get a valid model_id or raise error of non available. + + If embedding_model_id is not provided, returns the default model for dataset_id. + If embedding_model_id is provided, validates that the model has been loaded and returns it. + """ + default_model_id = self._dataset_id_to_default_model_id.get(dataset_id, None) + if embedding_model_id is None and default_model_id is None: raise ValueError( "No embedding_model_id provided and no default embedding model registered." ) - if embedding_model_id is None and self._default_model_id is not None: - return self._default_model_id + if embedding_model_id is None and default_model_id is not None: + return default_model_id if embedding_model_id not in self._models: raise ValueError(f"No embedding model found with ID {embedding_model_id}") return embedding_model_id -# TODO(Michal, 09/2025): Write tests for this function. -def _load_embedding_generator_from_env() -> EmbeddingGenerator | None: +def _load_embedding_generator_from_env(sample_type: SampleType) -> EmbeddingGenerator | None: """Load the embedding generator based on environment variable configuration.""" + if sample_type == SampleType.IMAGE: + return _load_image_embedding_generator_from_env() + if sample_type == SampleType.VIDEO: + return _load_video_embedding_generator() + return None + + +# TODO(Michal, 09/2025): Write tests for this function. +def _load_image_embedding_generator_from_env() -> ImageEmbeddingGenerator | None: if env.LIGHTLY_STUDIO_EMBEDDINGS_MODEL_TYPE == "EDGE": try: from lightly_studio.dataset.edge_embedding_generator import ( EdgeSDKEmbeddingGenerator, ) - logger.info("Using LightlyEdge embedding generator.") + logger.info("Using LightlyEdge embedding generator for images.") return EdgeSDKEmbeddingGenerator(model_path=env.LIGHTLY_STUDIO_EDGE_MODEL_FILE_PATH) except ImportError: logger.warning("Embedding functionality is disabled.") - return None elif env.LIGHTLY_STUDIO_EMBEDDINGS_MODEL_TYPE == "MOBILE_CLIP": try: from lightly_studio.dataset.mobileclip_embedding_generator import ( MobileCLIPEmbeddingGenerator, ) - logger.info("Using MobileCLIP embedding generator.") + logger.info("Using MobileCLIP embedding generator for images.") return MobileCLIPEmbeddingGenerator() except ImportError: logger.warning("Embedding functionality is disabled.") - return None + elif env.LIGHTLY_STUDIO_EMBEDDINGS_MODEL_TYPE == "PE": + try: + from lightly_studio.dataset.perception_encoder_embedding_generator import ( + PerceptionEncoderEmbeddingGenerator, + ) - logger.warning(f"Unsupported model type: '{env.LIGHTLY_STUDIO_EMBEDDINGS_MODEL_TYPE}'") - logger.warning("Embedding functionality is disabled.") + logger.info("Using PerceptionEncoder embedding generator for images.") + return PerceptionEncoderEmbeddingGenerator() + except ImportError: + logger.warning("Embedding functionality is disabled.") + else: + logger.warning(f"Unsupported model type: '{env.LIGHTLY_STUDIO_EMBEDDINGS_MODEL_TYPE}'") + logger.warning("Embedding functionality is disabled.") return None + + +def _load_video_embedding_generator() -> VideoEmbeddingGenerator | None: + try: + from lightly_studio.dataset.perception_encoder_embedding_generator import ( + PerceptionEncoderEmbeddingGenerator, + ) + + logger.info("Using PerceptionEncoder embedding generator for videos.") + return PerceptionEncoderEmbeddingGenerator() + except ImportError: + logger.warning("Embedding functionality is disabled.") + return None diff --git a/lightly_studio/src/lightly_studio/resolvers/embedding_model_resolver.py b/lightly_studio/src/lightly_studio/resolvers/embedding_model_resolver.py index f60b7b198..1e823d639 100644 --- a/lightly_studio/src/lightly_studio/resolvers/embedding_model_resolver.py +++ b/lightly_studio/src/lightly_studio/resolvers/embedding_model_resolver.py @@ -34,8 +34,6 @@ def get_or_create(session: Session, embedding_model: EmbeddingModelCreate) -> Em db_model.name != embedding_model.name or db_model.parameter_count_in_mb != embedding_model.parameter_count_in_mb or db_model.embedding_dimension != embedding_model.embedding_dimension - # TODO(Michal, 09/2025): Allow same model for different datasets. - or db_model.dataset_id != embedding_model.dataset_id ): raise ValueError( "An embedding model with the same hash but different parameters already exists." diff --git a/lightly_studio/tests/api/routes/api/test_text_embedding.py b/lightly_studio/tests/api/routes/api/test_text_embedding.py index e239c9aaa..3e5250d80 100644 --- a/lightly_studio/tests/api/routes/api/test_text_embedding.py +++ b/lightly_studio/tests/api/routes/api/test_text_embedding.py @@ -3,6 +3,7 @@ from fastapi.testclient import TestClient from pytest_mock import MockerFixture +from sqlmodel import Session from lightly_studio.api.routes.api.status import ( HTTP_STATUS_OK, @@ -11,9 +12,13 @@ EmbeddingManager, EmbeddingManagerProvider, ) +from tests import helpers_resolvers -def test_embed_text(mocker: MockerFixture, test_client: TestClient) -> None: +def test_embed_text(db_session: Session, mocker: MockerFixture, test_client: TestClient) -> None: + # Create a db as the text_embeddings defaults to root_dataset + helpers_resolvers.create_dataset(session=db_session) + # Initialize the embedding_manager with a mock variant so it does not update # the singleton. mocker.patch.object( @@ -40,10 +45,14 @@ def test_embed_text(mocker: MockerFixture, test_client: TestClient) -> None: def test_embed_text_embedding_invalid_model_id( + db_session: Session, mocker: MockerFixture, test_client: TestClient, ) -> None: # Make the request to the `/samples` endpoint + # Create a db as the text_embeddings defaults to root_dataset + helpers_resolvers.create_dataset(session=db_session) + mocker.patch.object( EmbeddingManagerProvider, "get_embedding_manager", @@ -59,4 +68,4 @@ def test_embed_text_embedding_invalid_model_id( }, ) assert response.status_code == 500 - assert response.json() == {"detail": f"Embedding model with ID {test_uuid} not found."} + assert response.json() == {"detail": f"No embedding model found with ID {test_uuid}"} diff --git a/lightly_studio/tests/core/test_dataset.py b/lightly_studio/tests/core/test_dataset.py index 7b8bf7c2a..bf65d8818 100644 --- a/lightly_studio/tests/core/test_dataset.py +++ b/lightly_studio/tests/core/test_dataset.py @@ -409,7 +409,7 @@ def test_generate_embeddings( assert "embeddingSearchEnabled" not in features.lightly_studio_active_features assert "fewShotClassifierEnabled" not in features.lightly_studio_active_features - dataset_module._generate_embeddings( + dataset_module._generate_embeddings_image( session=session, dataset_id=dataset.dataset_id, sample_ids=[image1.sample_id], @@ -436,7 +436,7 @@ def test_generate_embeddings__no_generator( assert "embeddingSearchEnabled" not in features.lightly_studio_active_features assert "fewShotClassifierEnabled" not in features.lightly_studio_active_features - dataset_module._generate_embeddings( + dataset_module._generate_embeddings_image( session=session, dataset_id=dataset.dataset_id, sample_ids=[image1.sample_id], @@ -455,7 +455,7 @@ def test_generate_embeddings__empty_sample_ids( session = db_manager.persistent_session() dataset = create_dataset(session=session) - dataset_module._generate_embeddings( + dataset_module._generate_embeddings_image( session=session, dataset_id=dataset.dataset_id, sample_ids=[], @@ -482,7 +482,7 @@ def test_are_embeddings_available( is False ) - dataset_module._generate_embeddings( + dataset_module._generate_embeddings_image( session=session, dataset_id=dataset.dataset_id, sample_ids=[image1.sample_id], @@ -507,7 +507,7 @@ def test_enable_few_shot_classifier_on_load( assert "embeddingSearchEnabled" not in features.lightly_studio_active_features assert "fewShotClassifierEnabled" not in features.lightly_studio_active_features - dataset_module._generate_embeddings( + dataset_module._generate_embeddings_image( session=session, dataset_id=dataset.dataset_id, sample_ids=[image1.sample_id], @@ -542,7 +542,7 @@ def test_enable_few_shot_classifier_on_load_or_create( session = db_manager.persistent_session() dataset = create_dataset(session=session, dataset_name="test_dataset") image1 = create_image(session=session, dataset_id=dataset.dataset_id) - dataset_module._generate_embeddings( + dataset_module._generate_embeddings_image( session=session, dataset_id=dataset.dataset_id, sample_ids=[image1.sample_id], diff --git a/lightly_studio/tests/core/test_dataset__video.py b/lightly_studio/tests/core/test_dataset__video.py new file mode 100644 index 000000000..aba7d1093 --- /dev/null +++ b/lightly_studio/tests/core/test_dataset__video.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +from pathlib import Path + +from lightly_studio import Dataset +from lightly_studio.dataset.embedding_manager import EmbeddingManagerProvider +from lightly_studio.models.dataset import SampleType +from lightly_studio.resolvers import sample_embedding_resolver, video_resolver +from tests.core.test_add_videos import _create_temp_video + + +class TestDataset: + def test_dataset_add_videos_from_path__valid( + self, + patch_dataset: None, # noqa: ARG002 + tmp_path: Path, + ) -> None: + _create_temp_video( + output_path=tmp_path / "test_video_1.mp4", + width=640, + height=480, + num_frames=30, + fps=2, + ) + _create_temp_video( + output_path=tmp_path / "test_video_0.mp4", + width=640, + height=480, + num_frames=30, + fps=2, + ) + + dataset = Dataset.create(name="test_dataset", sample_type=SampleType.VIDEO) + dataset.add_videos_from_path(path=tmp_path) + + # Verify frames are in the database + videos = video_resolver.get_all_by_dataset_id( + session=dataset.session, + dataset_id=dataset.dataset_id, + ).samples + assert len(videos) == 2 + assert {s.file_name for s in videos} == { + "test_video_1.mp4", + "test_video_0.mp4", + } + # Check that embeddings were created + embedding_manager = EmbeddingManagerProvider.get_embedding_manager() + model_id = embedding_manager.load_or_get_default_model( + session=dataset.session, + dataset_id=dataset.dataset_id, + ) + assert model_id is not None + embeddings = sample_embedding_resolver.get_all_by_dataset_id( + session=dataset.session, dataset_id=dataset.dataset_id, embedding_model_id=model_id + ) + assert len(embeddings) == 2 + + def test_dataset_add_videos_from_path__dont_embed( + self, + patch_dataset: None, # noqa: ARG002 + tmp_path: Path, + ) -> None: + _create_temp_video( + output_path=tmp_path / "test_video_1.mp4", + width=640, + height=480, + num_frames=30, + fps=2, + ) + _create_temp_video( + output_path=tmp_path / "test_video_0.mp4", + width=640, + height=480, + num_frames=30, + fps=2, + ) + + dataset = Dataset.create(name="test_dataset", sample_type=SampleType.VIDEO) + dataset.add_videos_from_path(path=tmp_path, embed=False) + + # Verify frames are in the database + videos = video_resolver.get_all_by_dataset_id( + session=dataset.session, + dataset_id=dataset.dataset_id, + ).samples + assert len(videos) == 2 + assert {s.file_name for s in videos} == { + "test_video_1.mp4", + "test_video_0.mp4", + } + # Check that embeddings were created + embedding_manager = EmbeddingManagerProvider.get_embedding_manager() + model_id = embedding_manager.load_or_get_default_model( + session=dataset.session, + dataset_id=dataset.dataset_id, + ) + assert model_id is not None + embeddings = sample_embedding_resolver.get_all_by_dataset_id( + session=dataset.session, dataset_id=dataset.dataset_id, embedding_model_id=model_id + ) + assert len(embeddings) == 0 diff --git a/lightly_studio/tests/dataset/test_embedding_manager.py b/lightly_studio/tests/dataset/test_embedding_manager.py index 4eca43c01..4fe8317bb 100644 --- a/lightly_studio/tests/dataset/test_embedding_manager.py +++ b/lightly_studio/tests/dataset/test_embedding_manager.py @@ -46,7 +46,7 @@ def test_register_embedding_model( # Check that the model was registered in memory. assert model_id in embedding_manager._models assert embedding_manager._models[model_id] == random_model - assert embedding_manager._default_model_id == model_id + assert embedding_manager._dataset_id_to_default_model_id[dataset.dataset_id] == model_id # Check that the model was stored in the database. stored_model = db_session.exec( @@ -98,7 +98,7 @@ def embed_images(self, filepaths: list[str]) -> NDArray[np.float32]: # Check that both models were registered in memory assert model_id1 in embedding_manager._models assert model_id2 in embedding_manager._models - assert embedding_manager._default_model_id == model_id1 + assert embedding_manager._dataset_id_to_default_model_id[dataset.dataset_id] == model_id1 # Check that both models were stored in the database stored_models = db_session.exec(select(EmbeddingModelTable)).all() @@ -125,7 +125,7 @@ def test_embed_text_with_default_model( # Generate embedding. query = TextEmbedQuery(text="test text") - embedding = embedding_manager.embed_text(query) + embedding = embedding_manager.embed_text(dataset_id=dataset.dataset_id, text_query=query) # Check embedding. assert len(embedding) == 3 @@ -147,7 +147,7 @@ def test_embed_text_with_specific_model( # Generate embedding with specific model. query = TextEmbedQuery(text="test text", embedding_model_id=model_id) - embedding = embedding_manager.embed_text(query) + embedding = embedding_manager.embed_text(dataset_id=dataset.dataset_id, text_query=query) # Check embedding. assert len(embedding) == 3 @@ -157,8 +157,8 @@ def test_embed_text_without_model() -> None: """Test generating text embeddings without registered model.""" embedding_manager = EmbeddingManager() query = TextEmbedQuery(text="test text") - with pytest.raises(ValueError, match="No embedding model specified"): - embedding_manager.embed_text(query) + with pytest.raises(ValueError, match="No embedding_model_id provided and no default embedding"): + embedding_manager.embed_text(dataset_id=uuid4(), text_query=query) def test_embed_text_with_invalid_model( @@ -178,9 +178,9 @@ def test_embed_text_with_invalid_model( query = TextEmbedQuery(text="test text", embedding_model_id=invalid_model_id) with pytest.raises( ValueError, - match=f"Embedding model with ID {invalid_model_id} not found.", + match=f"No embedding model found with ID {invalid_model_id}", ): - embedding_manager.embed_text(query) + embedding_manager.embed_text(dataset_id=dataset.dataset_id, text_query=query) def test_embed_images( @@ -200,7 +200,9 @@ def test_embed_images( # Generate embeddings for samples sample_ids = [sample.sample_id for sample in samples] - embedding_manager.embed_images(session=db_session, sample_ids=sample_ids) + embedding_manager.embed_images( + session=db_session, dataset_id=dataset.dataset_id, sample_ids=sample_ids + ) # Verify embeddings were stored in the database stored_embeddings = db_session.exec( @@ -226,7 +228,9 @@ def test_embed_images_with_incompatible_generator( ) with pytest.raises(ValueError, match=r"Embedding model not compatible with images."): - manager.embed_images(session=db_session, sample_ids=[uuid4()]) + manager.embed_images( + session=db_session, dataset_id=dataset.dataset_id, sample_ids=[uuid4()] + ) def test_get_valid_model_id_without_default_model() -> None: @@ -236,7 +240,7 @@ def test_get_valid_model_id_without_default_model() -> None: ValueError, match=r"No embedding_model_id provided and no default embedding model registered.", ): - manager._get_default_or_validate(embedding_model_id=None) + manager._get_default_or_validate(dataset_id=uuid4(), embedding_model_id=None) def test_get_valid_model_id_with_invalid_requested_model( @@ -256,7 +260,9 @@ def test_get_valid_model_id_with_invalid_requested_model( ValueError, match=f"No embedding model found with ID {missing_model_id}", ): - manager._get_default_or_validate(embedding_model_id=missing_model_id) + manager._get_default_or_validate( + dataset_id=dataset.dataset_id, embedding_model_id=missing_model_id + ) def test_get_valid_model_id_with_default_and_explicit_id( @@ -271,7 +277,10 @@ def test_get_valid_model_id_with_default_and_explicit_id( dataset_id=dataset.dataset_id, set_as_default=True, ).embedding_model_id - assert manager._get_default_or_validate(embedding_model_id=None) == default_model_id + assert ( + manager._get_default_or_validate(dataset_id=dataset.dataset_id, embedding_model_id=None) + == default_model_id + ) other_model_id = manager.register_embedding_model( session=db_session, @@ -279,7 +288,12 @@ def test_get_valid_model_id_with_default_and_explicit_id( dataset_id=dataset.dataset_id, set_as_default=False, ).embedding_model_id - assert manager._get_default_or_validate(embedding_model_id=other_model_id) == other_model_id + assert ( + manager._get_default_or_validate( + dataset_id=dataset.dataset_id, embedding_model_id=other_model_id + ) + == other_model_id + ) def test_load_or_get_default_model( @@ -305,7 +319,7 @@ def test_load_or_get_default_model( assert model_id is not None # Verify we got back the random model. - mock_load.assert_called_once_with() + mock_load.assert_called_once_with(sample_type=SampleType.IMAGE) model = embedding_model_resolver.get_by_id(session=db_session, embedding_model_id=model_id) assert model is not None assert model.name == "Random" @@ -316,7 +330,7 @@ def test_load_or_get_default_model( dataset_id=dataset.dataset_id, ) assert model_id == second_id - mock_load.assert_called_once_with() # still only one call + mock_load.assert_called_once_with(sample_type=SampleType.IMAGE) # still only one call def test_load_or_get_default_model__cant_load( @@ -338,7 +352,7 @@ def test_load_or_get_default_model__cant_load( dataset_id=dataset.dataset_id, ) - mock_load.assert_called_once_with() + mock_load.assert_called_once_with(sample_type=SampleType.IMAGE) assert model_id is None @@ -355,7 +369,7 @@ def test_default_model( set_as_default=False, ).embedding_model_id # The first model is always set as default. - assert embedding_manager._default_model_id == first_model_id + assert embedding_manager._dataset_id_to_default_model_id[dataset.dataset_id] == first_model_id # Override default model with set_as_default=True. second_model_id = embedding_manager.register_embedding_model( @@ -365,7 +379,7 @@ def test_default_model( set_as_default=True, ).embedding_model_id - assert embedding_manager._default_model_id == second_model_id + assert embedding_manager._dataset_id_to_default_model_id[dataset.dataset_id] == second_model_id def test_embed_videos( @@ -373,9 +387,10 @@ def test_embed_videos( ) -> None: """Test generating embeddings for video samples.""" video_dataset = create_dataset(session=db_session, sample_type=SampleType.VIDEO) + dataset_id = video_dataset.dataset_id video_ids = create_videos( session=db_session, - dataset_id=video_dataset.dataset_id, + dataset_id=dataset_id, videos=[ VideoStub(path=f"/videos/video_{idx}.mp4", duration_s=1.0 + idx, fps=24.0) for idx in range(3) @@ -385,11 +400,11 @@ def test_embed_videos( model_id = manager.register_embedding_model( session=db_session, embedding_generator=RandomEmbeddingGenerator(), - dataset_id=video_dataset.dataset_id, + dataset_id=dataset_id, set_as_default=True, ).embedding_model_id - manager.embed_videos(session=db_session, sample_ids=video_ids) + manager.embed_videos(session=db_session, dataset_id=dataset_id, sample_ids=video_ids) stored_embeddings = db_session.exec( select(SampleEmbeddingTable).where(SampleEmbeddingTable.embedding_model_id == model_id) @@ -403,16 +418,17 @@ def test_embed_videos( def test_embed_videos_with_incompatible_generator(db_session: Session) -> None: """Ensure we raise when the default lacks video support.""" video_dataset = create_dataset(session=db_session, sample_type=SampleType.VIDEO) + dataset_id = video_dataset.dataset_id manager = EmbeddingManager() manager.register_embedding_model( session=db_session, embedding_generator=TextOnlyEmbeddingGenerator(), - dataset_id=video_dataset.dataset_id, + dataset_id=dataset_id, set_as_default=True, ) with pytest.raises(ValueError, match=r"Embedding model not compatible with videos."): - manager.embed_videos(session=db_session, sample_ids=[uuid4()]) + manager.embed_videos(session=db_session, dataset_id=dataset_id, sample_ids=[uuid4()]) class TextOnlyEmbeddingGenerator: