Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
AnnotationViewsWithCount,
AnnotationWithPayloadAndCountView,
)
from lightly_studio.models.dataset import DatasetTable, SampleType
from lightly_studio.models.dataset import DatasetTable
from lightly_studio.resolvers import annotation_resolver, tag_resolver
from lightly_studio.resolvers.annotation_resolver.get_all import (
GetAllAnnotationsResult,
Expand All @@ -45,20 +45,17 @@ class AnnotationQueryParamsModel(BaseModel):
"""Model for all annotation query parameters."""

pagination: PaginatedWithCursor
sample_type: SampleType
annotation_label_ids: list[UUID] | None = None
tag_ids: list[UUID] | None = None


def _get_annotation_query_params(
pagination: Annotated[PaginatedWithCursor, Depends()],
sample_type: SampleType,
annotation_label_ids: Annotated[list[UUID] | None, Query()] = None,
tag_ids: Annotated[list[UUID] | None, Query()] = None,
) -> AnnotationQueryParamsModel:
return AnnotationQueryParamsModel(
pagination=pagination,
sample_type=sample_type,
annotation_label_ids=annotation_label_ids,
tag_ids=tag_ids,
)
Expand Down Expand Up @@ -145,12 +142,12 @@ def read_annotations_with_payload(
offset=params.pagination.offset,
limit=params.pagination.limit,
),
sample_type=params.sample_type,
filters=AnnotationsFilter(
dataset_ids=[dataset_id],
annotation_label_ids=params.annotation_label_ids,
annotation_tag_ids=params.tag_ids,
),
dataset_id=dataset_id,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
SemanticSegmentationAnnotationTable,
SemanticSegmentationAnnotationView,
)
from lightly_studio.models.dataset import SampleType
from lightly_studio.models.sample import SampleTable

if TYPE_CHECKING:
Expand Down Expand Up @@ -202,6 +203,7 @@ class AnnotationWithPayloadView(BaseModel):

model_config = ConfigDict(populate_by_name=True)

parent_sample_type: SampleType
annotation: AnnotationView
parent_sample_data: Union[ImageAnnotationView, VideoFrameAnnotationView]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from typing import Any
from uuid import UUID

from sqlalchemy.orm import aliased, joinedload, load_only
from sqlmodel import Session, col, func, select
Expand All @@ -20,14 +21,15 @@
from lightly_studio.models.image import ImageTable
from lightly_studio.models.sample import SampleTable
from lightly_studio.models.video import VideoFrameTable, VideoTable
from lightly_studio.resolvers import dataset_resolver
from lightly_studio.resolvers.annotations.annotations_filter import (
AnnotationsFilter,
)


def get_all_with_payload(
session: Session,
sample_type: SampleType,
dataset_id: UUID,
pagination: Paginated | None = None,
filters: AnnotationsFilter | None = None,
) -> AnnotationWithPayloadAndCountView:
Expand All @@ -38,10 +40,18 @@ def get_all_with_payload(
sample_type: Sample type to filter by
pagination: Optional pagination parameters
filters: Optional filters to apply to the query
dataset_id: ID of the dataset to get annotations for

Returns:
List of annotations matching the filters with payload
"""
parent_dataset = dataset_resolver.get_parent_dataset_id(session=session, dataset_id=dataset_id)

if parent_dataset is None:
raise ValueError(f"Dataset with id {dataset_id} does not have a parent dataset.")

sample_type = parent_dataset.sample_type

base_query = _build_base_query(sample_type=sample_type)

if filters:
Expand Down Expand Up @@ -69,7 +79,11 @@ def get_all_with_payload(
total_count=total_count,
next_cursor=next_cursor,
annotations=[
{"annotation": annotation, "parent_sample_data": _serialize_annotation_payload(payload)}
{
"parent_sample_type": sample_type,
"annotation": annotation,
"parent_sample_data": _serialize_annotation_payload(payload),
}
for annotation, payload in rows
],
)
Expand Down Expand Up @@ -102,7 +116,7 @@ def _build_base_query(
)
)

if sample_type == SampleType.VIDEO_FRAME:
if sample_type in (SampleType.VIDEO_FRAME, SampleType.VIDEO):
return (
select(AnnotationBaseTable, VideoFrameTable)
.join(
Expand Down Expand Up @@ -130,7 +144,7 @@ def _extra_order_by(sample_type: SampleType) -> list[Any]:
col(ImageTable.file_path_abs).asc(),
]

if sample_type == SampleType.VIDEO_FRAME:
if sample_type in (SampleType.VIDEO_FRAME, SampleType.VIDEO):
return [
col(VideoTable.file_path_abs).asc(),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from lightly_studio.resolvers.dataset_resolver.get_or_create_child_dataset import (
get_or_create_child_dataset,
)
from lightly_studio.resolvers.dataset_resolver.get_parent_dataset_id import (
get_parent_dataset_id,
)
from lightly_studio.resolvers.dataset_resolver.get_root_dataset import (
get_root_dataset,
)
Expand All @@ -41,6 +44,7 @@
"get_filtered_samples_count",
"get_hierarchy",
"get_or_create_child_dataset",
"get_parent_dataset_id",
"get_root_dataset",
"get_root_datasets_overview",
"update",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Retrieve the parent dataset ID for a given dataset ID."""

from __future__ import annotations

from uuid import UUID

from sqlalchemy.orm import aliased
from sqlmodel import Session, col, select

from lightly_studio.models.dataset import DatasetTable

ParentDataset = aliased(DatasetTable)
ChildDataset = aliased(DatasetTable)


def get_parent_dataset_id(session: Session, dataset_id: UUID) -> DatasetTable | None:
"""Retrieve the parent dataset for a given dataset ID."""
return session.exec(
select(ParentDataset)
.join(ChildDataset, col(ChildDataset.parent_dataset_id) == col(ParentDataset.dataset_id))
.where(ChildDataset.dataset_id == dataset_id)
).one_or_none()
1 change: 0 additions & 1 deletion lightly_studio/tests/api/routes/api/test_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ def test_read_annotations_with_payload(
params={
"offset": 0,
"limit": 1,
"sample_type": "image",
},
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_get_all_with_payload__with_pagination(
)

# Create annotations
create_annotation(
annotation = create_annotation(
session=test_db,
sample_id=image_1.sample_id,
annotation_label_id=car_label.annotation_label_id,
Expand All @@ -64,8 +64,8 @@ def test_get_all_with_payload__with_pagination(

annotations_page = annotation_resolver.get_all_with_payload(
session=test_db,
sample_type=SampleType.IMAGE,
pagination=Paginated(limit=1, offset=0),
dataset_id=annotation.sample.dataset_id,
)

assert annotations_page.total_count == 2
Expand Down Expand Up @@ -105,7 +105,7 @@ def test_get_all_with_payload__with_image(
)

# Create annotations
create_annotation(
annotation = create_annotation(
session=test_db,
sample_id=image_1.sample_id,
annotation_label_id=car_label.annotation_label_id,
Expand All @@ -120,13 +120,14 @@ def test_get_all_with_payload__with_image(

annotations_page = annotation_resolver.get_all_with_payload(
session=test_db,
sample_type=SampleType.IMAGE,
dataset_id=annotation.sample.dataset_id,
)

assert annotations_page.total_count == 2
assert len(annotations_page.annotations) == 2

assert isinstance(annotations_page.annotations[0].parent_sample_data, ImageAnnotationView)
assert annotations_page.annotations[0].parent_sample_type == SampleType.IMAGE
assert (
annotations_page.annotations[0].annotation.annotation_label.annotation_label_name
== airplane_label.annotation_label_name
Expand All @@ -135,6 +136,7 @@ def test_get_all_with_payload__with_image(
assert annotations_page.annotations[0].parent_sample_data.sample.dataset_id == dataset_id

assert isinstance(annotations_page.annotations[1].parent_sample_data, ImageAnnotationView)
assert annotations_page.annotations[0].parent_sample_type == SampleType.IMAGE
assert (
annotations_page.annotations[1].annotation.annotation_label.annotation_label_name
== car_label.annotation_label_name
Expand Down Expand Up @@ -164,7 +166,7 @@ def test_get_all_with_payload__with_video_frame(test_db: Session) -> None:
)

# Create annotations
create_annotation(
annotation = create_annotation(
session=test_db,
sample_id=video_frame_data.frame_sample_ids[0],
annotation_label_id=car_label.annotation_label_id,
Expand All @@ -179,13 +181,14 @@ def test_get_all_with_payload__with_video_frame(test_db: Session) -> None:

annotations_page = annotation_resolver.get_all_with_payload(
session=test_db,
sample_type=SampleType.VIDEO_FRAME,
dataset_id=annotation.sample.dataset_id,
)

assert annotations_page.total_count == 2
assert len(annotations_page.annotations) == 2

assert isinstance(annotations_page.annotations[0].parent_sample_data, VideoFrameAnnotationView)
assert annotations_page.annotations[0].parent_sample_type == SampleType.VIDEO
assert (
annotations_page.annotations[0].parent_sample_data.video.file_path_abs
== "/path/to/sample1.mp4"
Expand All @@ -196,6 +199,7 @@ def test_get_all_with_payload__with_video_frame(test_db: Session) -> None:
)

assert isinstance(annotations_page.annotations[1].parent_sample_data, VideoFrameAnnotationView)
assert annotations_page.annotations[0].parent_sample_type == SampleType.VIDEO
assert (
annotations_page.annotations[1].parent_sample_data.video.file_path_abs
== "/path/to/sample1.mp4"
Expand All @@ -206,14 +210,16 @@ def test_get_all_with_payload__with_video_frame(test_db: Session) -> None:
)


def test_get_all_with_payload__with_unsupported_sample_type(
def test_get_all_with_payload__with_unsupported_dataset(
test_db: Session,
) -> None:
with pytest.raises(NotImplementedError) as exc:
dataset = create_dataset(session=test_db, sample_type=SampleType.VIDEO)

with pytest.raises(
ValueError, match=f"Dataset with id {dataset.dataset_id} does not have a parent dataset."
):
annotation_resolver.get_all_with_payload(
session=test_db,
sample_type=SampleType.CAPTION,
pagination=Paginated(limit=1, offset=0),
dataset_id=dataset.dataset_id,
)

assert "Unsupported sample type" in str(exc.value)
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations

from sqlmodel import Session

from lightly_studio.resolvers import dataset_resolver
from tests.helpers_resolvers import (
create_annotation,
create_annotation_label,
create_dataset,
create_image,
)


def test_get_parent_dataset_id__from_parent_dataset(test_db: Session) -> None:
dataset = create_dataset(session=test_db)

image_1 = create_image(
session=test_db,
dataset_id=dataset.dataset_id,
file_path_abs="/path/to/sample2.png",
)
car_label = create_annotation_label(
session=test_db,
annotation_label_name="car",
)
annotation = create_annotation(
session=test_db,
sample_id=image_1.sample_id,
annotation_label_id=car_label.annotation_label_id,
dataset_id=dataset.dataset_id,
)

parent_dataset = dataset_resolver.get_parent_dataset_id(
session=test_db, dataset_id=annotation.sample.dataset_id
)
assert parent_dataset is not None
assert parent_dataset.dataset_id == dataset.dataset_id


def test_get_parent_dataset_id__from_root_dataset(test_db: Session) -> None:
dataset = create_dataset(session=test_db)

parent_dataset = dataset_resolver.get_parent_dataset_id(
session=test_db, dataset_id=dataset.dataset_id
)
assert parent_dataset is None
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@
import { useScrollRestoration } from '$lib/hooks/useScrollRestoration/useScrollRestoration';
import { addAnnotationLabelChangeToUndoStack } from '$lib/services/addAnnotationLabelChangeToUndoStack';
import { useUpdateAnnotationsMutation } from '$lib/hooks/useUpdateAnnotationsMutation/useUpdateAnnotationsMutation';
import {
AnnotationType,
SampleType,
type AnnotationWithPayloadView
} from '$lib/api/lightly_studio_local';
import { AnnotationType, type AnnotationWithPayloadView } from '$lib/api/lightly_studio_local';

type AnnotationsProps = {
dataset_id: string;
Expand Down Expand Up @@ -76,8 +72,7 @@
query: {
annotation_label_ids:
$selectedAnnotationFilterIds.length > 0 ? $selectedAnnotationFilterIds : undefined,
tag_ids: $tagsSelected.size > 0 ? Array.from($tagsSelected) : undefined,
sample_type: SampleType.IMAGE
tag_ids: $tagsSelected.size > 0 ? Array.from($tagsSelected) : undefined
}
});

Expand Down
8 changes: 3 additions & 5 deletions lightly_studio_view/src/lib/schema.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1812,6 +1812,7 @@ export interface components {
* @description Response model for annotation with payload.
*/
AnnotationWithPayloadView: {
parent_sample_type: components["schemas"]["SampleType"];
annotation: components["schemas"]["AnnotationView"];
/** Parent Sample Data */
parent_sample_data: components["schemas"]["ImageAnnotationView"] | components["schemas"]["VideoFrameAnnotationView"];
Expand Down Expand Up @@ -2119,9 +2120,7 @@ export interface components {
*/
ExecuteOperatorRequest: {
/** Parameters */
parameters: {
[key: string]: unknown;
};
parameters: Record<string, never>;
};
/**
* ExportBody
Expand Down Expand Up @@ -4191,8 +4190,7 @@ export interface operations {
};
read_annotations_with_payload: {
parameters: {
query: {
sample_type: components["schemas"]["SampleType"];
query?: {
annotation_label_ids?: string[] | null;
tag_ids?: string[] | null;
cursor?: number;
Expand Down