diff --git a/lightly_studio/src/lightly_studio/api/routes/api/annotation.py b/lightly_studio/src/lightly_studio/api/routes/api/annotation.py index afa323c75..a926a0191 100644 --- a/lightly_studio/src/lightly_studio/api/routes/api/annotation.py +++ b/lightly_studio/src/lightly_studio/api/routes/api/annotation.py @@ -23,7 +23,7 @@ AnnotationViewsWithCount, AnnotationWithPayloadAndCountView, ) -from lightly_studio.models.dataset import DatasetTable, SampleType +from lightly_studio.models.dataset import DatasetTable from lightly_studio.resolvers import annotation_resolver, tag_resolver from lightly_studio.resolvers.annotation_resolver.get_all import ( GetAllAnnotationsResult, @@ -45,20 +45,17 @@ class AnnotationQueryParamsModel(BaseModel): """Model for all annotation query parameters.""" pagination: PaginatedWithCursor - sample_type: SampleType annotation_label_ids: list[UUID] | None = None tag_ids: list[UUID] | None = None def _get_annotation_query_params( pagination: Annotated[PaginatedWithCursor, Depends()], - sample_type: SampleType, annotation_label_ids: Annotated[list[UUID] | None, Query()] = None, tag_ids: Annotated[list[UUID] | None, Query()] = None, ) -> AnnotationQueryParamsModel: return AnnotationQueryParamsModel( pagination=pagination, - sample_type=sample_type, annotation_label_ids=annotation_label_ids, tag_ids=tag_ids, ) @@ -145,12 +142,12 @@ def read_annotations_with_payload( offset=params.pagination.offset, limit=params.pagination.limit, ), - sample_type=params.sample_type, filters=AnnotationsFilter( dataset_ids=[dataset_id], annotation_label_ids=params.annotation_label_ids, annotation_tag_ids=params.tag_ids, ), + dataset_id=dataset_id, ) diff --git a/lightly_studio/src/lightly_studio/models/annotation/annotation_base.py b/lightly_studio/src/lightly_studio/models/annotation/annotation_base.py index fb6a75ba3..179d88f54 100644 --- a/lightly_studio/src/lightly_studio/models/annotation/annotation_base.py +++ b/lightly_studio/src/lightly_studio/models/annotation/annotation_base.py @@ -23,6 +23,7 @@ SemanticSegmentationAnnotationTable, SemanticSegmentationAnnotationView, ) +from lightly_studio.models.dataset import SampleType from lightly_studio.models.sample import SampleTable if TYPE_CHECKING: @@ -202,6 +203,7 @@ class AnnotationWithPayloadView(BaseModel): model_config = ConfigDict(populate_by_name=True) + parent_sample_type: SampleType annotation: AnnotationView parent_sample_data: Union[ImageAnnotationView, VideoFrameAnnotationView] diff --git a/lightly_studio/src/lightly_studio/resolvers/annotation_resolver/get_all_with_payload.py b/lightly_studio/src/lightly_studio/resolvers/annotation_resolver/get_all_with_payload.py index 6a16a486b..b1b8022af 100644 --- a/lightly_studio/src/lightly_studio/resolvers/annotation_resolver/get_all_with_payload.py +++ b/lightly_studio/src/lightly_studio/resolvers/annotation_resolver/get_all_with_payload.py @@ -3,6 +3,7 @@ from __future__ import annotations from typing import Any +from uuid import UUID from sqlalchemy.orm import aliased, joinedload, load_only from sqlmodel import Session, col, func, select @@ -12,6 +13,7 @@ from lightly_studio.models.annotation.annotation_base import ( AnnotationBaseTable, AnnotationWithPayloadAndCountView, + AnnotationWithPayloadView, ImageAnnotationView, SampleAnnotationView, VideoFrameAnnotationView, @@ -20,6 +22,7 @@ from lightly_studio.models.image import ImageTable from lightly_studio.models.sample import SampleTable from lightly_studio.models.video import VideoFrameTable, VideoTable +from lightly_studio.resolvers import dataset_resolver from lightly_studio.resolvers.annotations.annotations_filter import ( AnnotationsFilter, ) @@ -27,7 +30,7 @@ def get_all_with_payload( session: Session, - sample_type: SampleType, + dataset_id: UUID, pagination: Paginated | None = None, filters: AnnotationsFilter | None = None, ) -> AnnotationWithPayloadAndCountView: @@ -35,13 +38,20 @@ def get_all_with_payload( Args: session: Database session - sample_type: Sample type to filter by pagination: Optional pagination parameters filters: Optional filters to apply to the query + dataset_id: ID of the dataset to get annotations for Returns: List of annotations matching the filters with payload """ + parent_dataset = dataset_resolver.get_parent_dataset_id(session=session, dataset_id=dataset_id) + + if parent_dataset is None: + raise ValueError(f"Dataset with id {dataset_id} does not have a parent dataset.") + + sample_type = parent_dataset.sample_type + base_query = _build_base_query(sample_type=sample_type) if filters: @@ -69,7 +79,11 @@ def get_all_with_payload( total_count=total_count, next_cursor=next_cursor, annotations=[ - {"annotation": annotation, "parent_sample_data": _serialize_annotation_payload(payload)} + AnnotationWithPayloadView( + parent_sample_type=sample_type, + annotation=annotation, + parent_sample_data=_serialize_annotation_payload(payload), + ) for annotation, payload in rows ], ) @@ -102,7 +116,7 @@ def _build_base_query( ) ) - if sample_type == SampleType.VIDEO_FRAME: + if sample_type in (SampleType.VIDEO_FRAME, SampleType.VIDEO): return ( select(AnnotationBaseTable, VideoFrameTable) .join( @@ -130,7 +144,7 @@ def _extra_order_by(sample_type: SampleType) -> list[Any]: col(ImageTable.file_path_abs).asc(), ] - if sample_type == SampleType.VIDEO_FRAME: + if sample_type in (SampleType.VIDEO_FRAME, SampleType.VIDEO): return [ col(VideoTable.file_path_abs).asc(), ] diff --git a/lightly_studio/src/lightly_studio/resolvers/dataset_resolver/__init__.py b/lightly_studio/src/lightly_studio/resolvers/dataset_resolver/__init__.py index 73c0879c8..571e6c722 100644 --- a/lightly_studio/src/lightly_studio/resolvers/dataset_resolver/__init__.py +++ b/lightly_studio/src/lightly_studio/resolvers/dataset_resolver/__init__.py @@ -21,6 +21,9 @@ from lightly_studio.resolvers.dataset_resolver.get_or_create_child_dataset import ( get_or_create_child_dataset, ) +from lightly_studio.resolvers.dataset_resolver.get_parent_dataset_id import ( + get_parent_dataset_id, +) from lightly_studio.resolvers.dataset_resolver.get_root_dataset import ( get_root_dataset, ) @@ -41,6 +44,7 @@ "get_filtered_samples_count", "get_hierarchy", "get_or_create_child_dataset", + "get_parent_dataset_id", "get_root_dataset", "get_root_datasets_overview", "update", diff --git a/lightly_studio/src/lightly_studio/resolvers/dataset_resolver/get_parent_dataset_id.py b/lightly_studio/src/lightly_studio/resolvers/dataset_resolver/get_parent_dataset_id.py new file mode 100644 index 000000000..6dd0b5a1d --- /dev/null +++ b/lightly_studio/src/lightly_studio/resolvers/dataset_resolver/get_parent_dataset_id.py @@ -0,0 +1,22 @@ +"""Retrieve the parent dataset ID for a given dataset ID.""" + +from __future__ import annotations + +from uuid import UUID + +from sqlalchemy.orm import aliased +from sqlmodel import Session, col, select + +from lightly_studio.models.dataset import DatasetTable + +ParentDataset = aliased(DatasetTable) +ChildDataset = aliased(DatasetTable) + + +def get_parent_dataset_id(session: Session, dataset_id: UUID) -> DatasetTable | None: + """Retrieve the parent dataset for a given dataset ID.""" + return session.exec( + select(ParentDataset) + .join(ChildDataset, col(ChildDataset.parent_dataset_id) == col(ParentDataset.dataset_id)) + .where(ChildDataset.dataset_id == dataset_id) + ).one_or_none() diff --git a/lightly_studio/tests/api/routes/api/test_annotation.py b/lightly_studio/tests/api/routes/api/test_annotation.py index cf641fe5b..ae5be1da8 100644 --- a/lightly_studio/tests/api/routes/api/test_annotation.py +++ b/lightly_studio/tests/api/routes/api/test_annotation.py @@ -199,7 +199,6 @@ def test_read_annotations_with_payload( params={ "offset": 0, "limit": 1, - "sample_type": "image", }, ) diff --git a/lightly_studio/tests/resolvers/annotations/test_get_all_with_payload.py b/lightly_studio/tests/resolvers/annotations/test_get_all_with_payload.py index fdbb8156e..fbf76ded0 100644 --- a/lightly_studio/tests/resolvers/annotations/test_get_all_with_payload.py +++ b/lightly_studio/tests/resolvers/annotations/test_get_all_with_payload.py @@ -49,7 +49,7 @@ def test_get_all_with_payload__with_pagination( ) # Create annotations - create_annotation( + annotation = create_annotation( session=test_db, sample_id=image_1.sample_id, annotation_label_id=car_label.annotation_label_id, @@ -64,8 +64,8 @@ def test_get_all_with_payload__with_pagination( annotations_page = annotation_resolver.get_all_with_payload( session=test_db, - sample_type=SampleType.IMAGE, pagination=Paginated(limit=1, offset=0), + dataset_id=annotation.sample.dataset_id, ) assert annotations_page.total_count == 2 @@ -105,7 +105,7 @@ def test_get_all_with_payload__with_image( ) # Create annotations - create_annotation( + annotation = create_annotation( session=test_db, sample_id=image_1.sample_id, annotation_label_id=car_label.annotation_label_id, @@ -120,13 +120,14 @@ def test_get_all_with_payload__with_image( annotations_page = annotation_resolver.get_all_with_payload( session=test_db, - sample_type=SampleType.IMAGE, + dataset_id=annotation.sample.dataset_id, ) assert annotations_page.total_count == 2 assert len(annotations_page.annotations) == 2 assert isinstance(annotations_page.annotations[0].parent_sample_data, ImageAnnotationView) + assert annotations_page.annotations[0].parent_sample_type == SampleType.IMAGE assert ( annotations_page.annotations[0].annotation.annotation_label.annotation_label_name == airplane_label.annotation_label_name @@ -135,6 +136,7 @@ def test_get_all_with_payload__with_image( assert annotations_page.annotations[0].parent_sample_data.sample.dataset_id == dataset_id assert isinstance(annotations_page.annotations[1].parent_sample_data, ImageAnnotationView) + assert annotations_page.annotations[0].parent_sample_type == SampleType.IMAGE assert ( annotations_page.annotations[1].annotation.annotation_label.annotation_label_name == car_label.annotation_label_name @@ -164,7 +166,7 @@ def test_get_all_with_payload__with_video_frame(test_db: Session) -> None: ) # Create annotations - create_annotation( + annotation = create_annotation( session=test_db, sample_id=video_frame_data.frame_sample_ids[0], annotation_label_id=car_label.annotation_label_id, @@ -179,13 +181,14 @@ def test_get_all_with_payload__with_video_frame(test_db: Session) -> None: annotations_page = annotation_resolver.get_all_with_payload( session=test_db, - sample_type=SampleType.VIDEO_FRAME, + dataset_id=annotation.sample.dataset_id, ) assert annotations_page.total_count == 2 assert len(annotations_page.annotations) == 2 assert isinstance(annotations_page.annotations[0].parent_sample_data, VideoFrameAnnotationView) + assert annotations_page.annotations[0].parent_sample_type == SampleType.VIDEO assert ( annotations_page.annotations[0].parent_sample_data.video.file_path_abs == "/path/to/sample1.mp4" @@ -196,6 +199,7 @@ def test_get_all_with_payload__with_video_frame(test_db: Session) -> None: ) assert isinstance(annotations_page.annotations[1].parent_sample_data, VideoFrameAnnotationView) + assert annotations_page.annotations[0].parent_sample_type == SampleType.VIDEO assert ( annotations_page.annotations[1].parent_sample_data.video.file_path_abs == "/path/to/sample1.mp4" @@ -206,14 +210,16 @@ def test_get_all_with_payload__with_video_frame(test_db: Session) -> None: ) -def test_get_all_with_payload__with_unsupported_sample_type( +def test_get_all_with_payload__with_unsupported_dataset( test_db: Session, ) -> None: - with pytest.raises(NotImplementedError) as exc: + dataset = create_dataset(session=test_db, sample_type=SampleType.VIDEO) + + with pytest.raises( + ValueError, match=f"Dataset with id {dataset.dataset_id} does not have a parent dataset." + ): annotation_resolver.get_all_with_payload( session=test_db, - sample_type=SampleType.CAPTION, pagination=Paginated(limit=1, offset=0), + dataset_id=dataset.dataset_id, ) - - assert "Unsupported sample type" in str(exc.value) diff --git a/lightly_studio/tests/resolvers/datasets_resolver/test_get_parent_dataset_id.py b/lightly_studio/tests/resolvers/datasets_resolver/test_get_parent_dataset_id.py new file mode 100644 index 000000000..08941fa98 --- /dev/null +++ b/lightly_studio/tests/resolvers/datasets_resolver/test_get_parent_dataset_id.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from sqlmodel import Session + +from lightly_studio.resolvers import dataset_resolver +from tests.helpers_resolvers import ( + create_annotation, + create_annotation_label, + create_dataset, + create_image, +) + + +def test_get_parent_dataset_id__from_parent_dataset(test_db: Session) -> None: + dataset = create_dataset(session=test_db) + + image_1 = create_image( + session=test_db, + dataset_id=dataset.dataset_id, + file_path_abs="/path/to/sample2.png", + ) + car_label = create_annotation_label( + session=test_db, + annotation_label_name="car", + ) + annotation = create_annotation( + session=test_db, + sample_id=image_1.sample_id, + annotation_label_id=car_label.annotation_label_id, + dataset_id=dataset.dataset_id, + ) + + parent_dataset = dataset_resolver.get_parent_dataset_id( + session=test_db, dataset_id=annotation.sample.dataset_id + ) + assert parent_dataset is not None + assert parent_dataset.dataset_id == dataset.dataset_id + + +def test_get_parent_dataset_id__from_root_dataset(test_db: Session) -> None: + dataset = create_dataset(session=test_db) + + parent_dataset = dataset_resolver.get_parent_dataset_id( + session=test_db, dataset_id=dataset.dataset_id + ) + assert parent_dataset is None diff --git a/lightly_studio_view/src/lib/components/AnnotationsGrid/AnnotationsGrid.svelte b/lightly_studio_view/src/lib/components/AnnotationsGrid/AnnotationsGrid.svelte index cb009e3cc..80745bd18 100644 --- a/lightly_studio_view/src/lib/components/AnnotationsGrid/AnnotationsGrid.svelte +++ b/lightly_studio_view/src/lib/components/AnnotationsGrid/AnnotationsGrid.svelte @@ -15,11 +15,7 @@ import { useScrollRestoration } from '$lib/hooks/useScrollRestoration/useScrollRestoration'; import { addAnnotationLabelChangeToUndoStack } from '$lib/services/addAnnotationLabelChangeToUndoStack'; import { useUpdateAnnotationsMutation } from '$lib/hooks/useUpdateAnnotationsMutation/useUpdateAnnotationsMutation'; - import { - AnnotationType, - SampleType, - type AnnotationWithPayloadView - } from '$lib/api/lightly_studio_local'; + import { AnnotationType, type AnnotationWithPayloadView } from '$lib/api/lightly_studio_local'; type AnnotationsProps = { dataset_id: string; @@ -76,8 +72,7 @@ query: { annotation_label_ids: $selectedAnnotationFilterIds.length > 0 ? $selectedAnnotationFilterIds : undefined, - tag_ids: $tagsSelected.size > 0 ? Array.from($tagsSelected) : undefined, - sample_type: SampleType.IMAGE + tag_ids: $tagsSelected.size > 0 ? Array.from($tagsSelected) : undefined } }); diff --git a/lightly_studio_view/src/lib/schema.d.ts b/lightly_studio_view/src/lib/schema.d.ts index 033a3daca..3de7d4b8c 100644 --- a/lightly_studio_view/src/lib/schema.d.ts +++ b/lightly_studio_view/src/lib/schema.d.ts @@ -1812,6 +1812,7 @@ export interface components { * @description Response model for annotation with payload. */ AnnotationWithPayloadView: { + parent_sample_type: components["schemas"]["SampleType"]; annotation: components["schemas"]["AnnotationView"]; /** Parent Sample Data */ parent_sample_data: components["schemas"]["ImageAnnotationView"] | components["schemas"]["VideoFrameAnnotationView"]; @@ -4189,8 +4190,7 @@ export interface operations { }; read_annotations_with_payload: { parameters: { - query: { - sample_type: components["schemas"]["SampleType"]; + query?: { annotation_label_ids?: string[] | null; tag_ids?: string[] | null; cursor?: number;