Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
AnnotationViewsWithCount,
AnnotationWithPayloadAndCountView,
)
from lightly_studio.models.dataset import DatasetTable, SampleType
from lightly_studio.models.dataset import DatasetTable
from lightly_studio.resolvers import annotation_resolver, tag_resolver
from lightly_studio.resolvers.annotation_resolver.get_all import (
GetAllAnnotationsResult,
Expand All @@ -45,20 +45,17 @@ class AnnotationQueryParamsModel(BaseModel):
"""Model for all annotation query parameters."""

pagination: PaginatedWithCursor
sample_type: SampleType
annotation_label_ids: list[UUID] | None = None
tag_ids: list[UUID] | None = None


def _get_annotation_query_params(
pagination: Annotated[PaginatedWithCursor, Depends()],
sample_type: SampleType,
annotation_label_ids: Annotated[list[UUID] | None, Query()] = None,
tag_ids: Annotated[list[UUID] | None, Query()] = None,
) -> AnnotationQueryParamsModel:
return AnnotationQueryParamsModel(
pagination=pagination,
sample_type=sample_type,
annotation_label_ids=annotation_label_ids,
tag_ids=tag_ids,
)
Expand Down Expand Up @@ -145,12 +142,12 @@ def read_annotations_with_payload(
offset=params.pagination.offset,
limit=params.pagination.limit,
),
sample_type=params.sample_type,
filters=AnnotationsFilter(
dataset_ids=[dataset_id],
annotation_label_ids=params.annotation_label_ids,
annotation_tag_ids=params.tag_ids,
),
dataset_id=dataset_id,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
SemanticSegmentationAnnotationTable,
SemanticSegmentationAnnotationView,
)
from lightly_studio.models.dataset import SampleType
from lightly_studio.models.sample import SampleTable

if TYPE_CHECKING:
Expand Down Expand Up @@ -202,6 +203,7 @@ class AnnotationWithPayloadView(BaseModel):

model_config = ConfigDict(populate_by_name=True)

sample_type: SampleType
annotation: AnnotationView
parent_sample_data: Union[ImageAnnotationView, VideoFrameAnnotationView]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from typing import Any
from uuid import UUID

from sqlalchemy.orm import aliased, joinedload, load_only
from sqlmodel import Session, col, func, select
Expand All @@ -20,14 +21,15 @@
from lightly_studio.models.image import ImageTable
from lightly_studio.models.sample import SampleTable
from lightly_studio.models.video import VideoFrameTable, VideoTable
from lightly_studio.resolvers import dataset_resolver
from lightly_studio.resolvers.annotations.annotations_filter import (
AnnotationsFilter,
)


def get_all_with_payload(
session: Session,
sample_type: SampleType,
dataset_id: UUID,
pagination: Paginated | None = None,
filters: AnnotationsFilter | None = None,
) -> AnnotationWithPayloadAndCountView:
Expand All @@ -38,10 +40,18 @@ def get_all_with_payload(
sample_type: Sample type to filter by
pagination: Optional pagination parameters
filters: Optional filters to apply to the query
dataset_id: ID of the dataset to get annotations for

Returns:
List of annotations matching the filters with payload
"""
parent_dataset = dataset_resolver.get_parent_dataset_id(session=session, dataset_id=dataset_id)

if parent_dataset is None:
raise NotImplementedError(f"Dataset with id {dataset_id} does not exist.")

sample_type = parent_dataset.sample_type

base_query = _build_base_query(sample_type=sample_type)

if filters:
Expand Down Expand Up @@ -69,7 +79,11 @@ def get_all_with_payload(
total_count=total_count,
next_cursor=next_cursor,
annotations=[
{"annotation": annotation, "parent_sample_data": _serialize_annotation_payload(payload)}
{
"sample_type": sample_type,
"annotation": annotation,
"parent_sample_data": _serialize_annotation_payload(payload),
}
for annotation, payload in rows
],
)
Expand Down Expand Up @@ -102,7 +116,7 @@ def _build_base_query(
)
)

if sample_type == SampleType.VIDEO_FRAME:
if sample_type in (SampleType.VIDEO_FRAME, SampleType.VIDEO):
return (
select(AnnotationBaseTable, VideoFrameTable)
.join(
Expand Down Expand Up @@ -130,7 +144,7 @@ def _extra_order_by(sample_type: SampleType) -> list[Any]:
col(ImageTable.file_path_abs).asc(),
]

if sample_type == SampleType.VIDEO_FRAME:
if sample_type in (SampleType.VIDEO_FRAME, SampleType.VIDEO):
return [
col(VideoTable.file_path_abs).asc(),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from lightly_studio.resolvers.dataset_resolver.get_or_create_child_dataset import (
get_or_create_child_dataset,
)
from lightly_studio.resolvers.dataset_resolver.get_parent_dataset_id import (
get_parent_dataset_id,
)
from lightly_studio.resolvers.dataset_resolver.get_root_dataset import (
get_root_dataset,
)
Expand All @@ -41,6 +44,7 @@
"get_filtered_samples_count",
"get_hierarchy",
"get_or_create_child_dataset",
"get_parent_dataset_id",
"get_root_dataset",
"get_root_datasets_overview",
"update",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Retrieve the parent dataset ID for a given dataset ID."""

from __future__ import annotations

from uuid import UUID

from sqlalchemy.orm import aliased
from sqlmodel import Session, col, select

from lightly_studio.models.dataset import DatasetTable

ParentDataset = aliased(DatasetTable)
ChildDataset = aliased(DatasetTable)


def get_parent_dataset_id(session: Session, dataset_id: UUID) -> DatasetTable | None:
"""Retrieve the parent dataset for a given dataset ID."""
result = session.exec(
select(ParentDataset, ChildDataset)
.join(ChildDataset, col(ChildDataset.parent_dataset_id) == col(ParentDataset.dataset_id))
.where(ChildDataset.dataset_id == dataset_id)
).one_or_none()

return result[0] if result else None
1 change: 0 additions & 1 deletion lightly_studio/tests/api/routes/api/test_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ def test_read_annotations_with_payload(
params={
"offset": 0,
"limit": 1,
"sample_type": "image",
},
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_get_all_with_payload__with_pagination(
)

# Create annotations
create_annotation(
annotation = create_annotation(
session=test_db,
sample_id=image_1.sample_id,
annotation_label_id=car_label.annotation_label_id,
Expand All @@ -64,8 +64,8 @@ def test_get_all_with_payload__with_pagination(

annotations_page = annotation_resolver.get_all_with_payload(
session=test_db,
sample_type=SampleType.IMAGE,
pagination=Paginated(limit=1, offset=0),
dataset_id=annotation.sample.dataset_id,
)

assert annotations_page.total_count == 2
Expand Down Expand Up @@ -105,7 +105,7 @@ def test_get_all_with_payload__with_image(
)

# Create annotations
create_annotation(
annotation = create_annotation(
session=test_db,
sample_id=image_1.sample_id,
annotation_label_id=car_label.annotation_label_id,
Expand All @@ -120,13 +120,14 @@ def test_get_all_with_payload__with_image(

annotations_page = annotation_resolver.get_all_with_payload(
session=test_db,
sample_type=SampleType.IMAGE,
dataset_id=annotation.sample.dataset_id,
)

assert annotations_page.total_count == 2
assert len(annotations_page.annotations) == 2

assert isinstance(annotations_page.annotations[0].parent_sample_data, ImageAnnotationView)
assert annotations_page.annotations[0].sample_type == SampleType.IMAGE
assert (
annotations_page.annotations[0].annotation.annotation_label.annotation_label_name
== airplane_label.annotation_label_name
Expand All @@ -135,6 +136,7 @@ def test_get_all_with_payload__with_image(
assert annotations_page.annotations[0].parent_sample_data.sample.dataset_id == dataset_id

assert isinstance(annotations_page.annotations[1].parent_sample_data, ImageAnnotationView)
assert annotations_page.annotations[0].sample_type == SampleType.IMAGE
assert (
annotations_page.annotations[1].annotation.annotation_label.annotation_label_name
== car_label.annotation_label_name
Expand Down Expand Up @@ -164,7 +166,7 @@ def test_get_all_with_payload__with_video_frame(test_db: Session) -> None:
)

# Create annotations
create_annotation(
annotation = create_annotation(
session=test_db,
sample_id=video_frame_data.frame_sample_ids[0],
annotation_label_id=car_label.annotation_label_id,
Expand All @@ -179,13 +181,14 @@ def test_get_all_with_payload__with_video_frame(test_db: Session) -> None:

annotations_page = annotation_resolver.get_all_with_payload(
session=test_db,
sample_type=SampleType.VIDEO_FRAME,
dataset_id=annotation.sample.dataset_id,
)

assert annotations_page.total_count == 2
assert len(annotations_page.annotations) == 2

assert isinstance(annotations_page.annotations[0].parent_sample_data, VideoFrameAnnotationView)
assert annotations_page.annotations[0].sample_type == SampleType.VIDEO
assert (
annotations_page.annotations[0].parent_sample_data.video.file_path_abs
== "/path/to/sample1.mp4"
Expand All @@ -196,6 +199,7 @@ def test_get_all_with_payload__with_video_frame(test_db: Session) -> None:
)

assert isinstance(annotations_page.annotations[1].parent_sample_data, VideoFrameAnnotationView)
assert annotations_page.annotations[0].sample_type == SampleType.VIDEO
assert (
annotations_page.annotations[1].parent_sample_data.video.file_path_abs
== "/path/to/sample1.mp4"
Expand All @@ -206,14 +210,16 @@ def test_get_all_with_payload__with_video_frame(test_db: Session) -> None:
)


def test_get_all_with_payload__with_unsupported_sample_type(
def test_get_all_with_payload__with_unsupported_dataset(
test_db: Session,
) -> None:
dataset = create_dataset(session=test_db, sample_type=SampleType.VIDEO)

with pytest.raises(NotImplementedError) as exc:
annotation_resolver.get_all_with_payload(
session=test_db,
sample_type=SampleType.CAPTION,
pagination=Paginated(limit=1, offset=0),
dataset_id=dataset.dataset_id,
)

assert "Unsupported sample type" in str(exc.value)
assert f"Dataset with id {dataset.dataset_id} does not exist" in str(exc.value)
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

from sqlmodel import Session

from lightly_studio.resolvers import dataset_resolver
from tests.helpers_resolvers import (
create_annotation,
create_annotation_label,
create_dataset,
create_image,
)


def test_get_parent_dataset_id(test_db: Session) -> None:
# Create two datasets
dataset = create_dataset(session=test_db)

image_1 = create_image(
session=test_db,
dataset_id=dataset.dataset_id,
file_path_abs="/path/to/sample2.png",
)
car_label = create_annotation_label(
session=test_db,
annotation_label_name="car",
)
annotation = create_annotation(
session=test_db,
sample_id=image_1.sample_id,
annotation_label_id=car_label.annotation_label_id,
dataset_id=dataset.dataset_id,
)

parent_dataset = dataset_resolver.get_parent_dataset_id(
session=test_db, dataset_id=annotation.sample.dataset_id
)
assert parent_dataset is not None
assert parent_dataset.dataset_id == dataset.dataset_id
Loading