diff --git a/src/dependencies/pagination.py b/src/dependencies/pagination.py index 9d88b1c1..68cb6487 100644 --- a/src/dependencies/pagination.py +++ b/src/dependencies/pagination.py @@ -1,7 +1,8 @@ -from typing import Annotated +from typing import Annotated, Sequence, TypeVar, Generic from fastapi import Query, Depends from pydantic import BaseModel +from pydantic.generics import GenericModel from sqlmodel import Field @@ -25,3 +26,11 @@ class Pagination(BaseModel): PaginationParams = Annotated[Pagination, Depends(Pagination)] + +T = TypeVar("T") + +class PaginatedResponse(GenericModel, Generic[T]): + offset: int + limit: int + total_count: int + data: Sequence[T] diff --git a/src/routers/bookmark_router.py b/src/routers/bookmark_router.py index 98094d3d..2e1cce48 100644 --- a/src/routers/bookmark_router.py +++ b/src/routers/bookmark_router.py @@ -4,6 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException from typing import List, cast from sqlmodel import Session, select, Field, SQLModel +from sqlalchemy import func from authentication import KeycloakUser, get_user_or_raise from database.session import get_session @@ -12,7 +13,7 @@ from http import HTTPStatus from datetime import datetime -from dependencies.pagination import PaginationParams +from dependencies.pagination import PaginationParams, PaginatedResponse from database.model.helper_functions import get_asset_type_by_abbreviation from versioning import Version @@ -38,21 +39,24 @@ def create(url_prefix: str = "", version: Version = Version.LATEST) -> APIRouter path, tags=["User"], description="Return all your bookmarks.", - response_model=List[BookmarkRead], + response_model=PaginatedResponse[BookmarkRead], ) def list_bookmarks( pagination: PaginationParams, user: KeycloakUser = Depends(get_user_or_raise), session: Session = Depends(get_session), - ) -> List[BookmarkRead]: + ) -> PaginatedResponse[BookmarkRead]: + base_stmt = select(Bookmark).where(Bookmark.user_identifier == user._subject_identifier) + total_count = session.scalar(select(func.count()).select_from(base_stmt.subquery())) or 0 stmt = ( - select(Bookmark) - .where(Bookmark.user_identifier == user._subject_identifier) - .order_by(Bookmark.created_at) + base_stmt.order_by(Bookmark.created_at) .offset(pagination.offset) .limit(pagination.limit) ) - return session.exec(stmt).all() + data = session.exec(stmt).all() + return PaginatedResponse( + offset=pagination.offset, limit=pagination.limit, total_count=total_count, data=data + ) @router.post( path, diff --git a/src/routers/resource_router.py b/src/routers/resource_router.py index b6ec00a2..34e90a7a 100644 --- a/src/routers/resource_router.py +++ b/src/routers/resource_router.py @@ -27,7 +27,7 @@ from database.review import Submission, SubmissionCreateV2, AssetReview from database.session import DbSession from dependencies.filtering import ResourceFilters, ResourceFiltersParams -from dependencies.pagination import Pagination, PaginationParams +from dependencies.pagination import Pagination, PaginationParams, PaginatedResponse from dependencies.sorting import SortingParams, Sorting, SortDirection from error_handling import as_http_exception from database.model.ai_asset.distribution import Distribution @@ -112,9 +112,7 @@ def create(self, url_prefix: str, version: Version = Version.LATEST) -> APIRoute } available_schemas: list[Type] = [c.to_class for c in self.schema_converters.values()] response_model = Union[self.resource_class_read, *available_schemas] # type:ignore - response_model_plural = Union[ # type:ignore - list[self.resource_class_read], *[list[s] for s in available_schemas] # type:ignore - ] + response_model_plural = PaginatedResponse[response_model] # type:ignore router.add_api_route( path=f"/{self.resource_name_plural}", @@ -231,12 +229,20 @@ def get_resources( resources: Any = self._retrieve_resources_and_post_process( session, pagination, sorting, resource_filters, user, platform ) + total_count = self._retrieve_resources_count(session, resource_filters, platform) + for resource in resources: if not get_image and hasattr(resource, "media"): for media_obj in resource.media: media_obj.binary_blob = None - return [convert_schema(resource) for resource in resources] + data = [convert_schema(resource) for resource in resources] + return PaginatedResponse( + offset=pagination.offset, + limit=pagination.limit, + total_count=total_count, + data=data, + ) except Exception as e: raise as_http_exception(e) @@ -767,6 +773,34 @@ def _retrieve_resources( resources: Sequence = session.scalars(query).all() return resources + def _retrieve_resources_count( + self, + session: Session, + resource_filters: ResourceFilters, + platform: str | None = None, + ) -> int: + """ + Retrieve the total count of published resources from the database based on the + provided platform and resource filters (if applicable). + """ + where_clause = and_( + is_(self.resource_class.date_deleted, None), + (self.resource_class.platform == platform) if platform is not None else True, + AIoDEntryORM.date_modified >= resource_filters.date_modified_after + if resource_filters.date_modified_after is not None + else True, + AIoDEntryORM.date_modified < resource_filters.date_modified_before + if resource_filters.date_modified_before is not None + else True, + AIoDEntryORM.status == EntryStatus.PUBLISHED, + ) + query = ( + select(func.count(self.resource_class.identifier)) + .join(self.resource_class.aiod_entry, isouter=True) + .where(where_clause) + ) + return session.scalar(query) or 0 + def _retrieve_resource_and_post_process( self, session: Session, diff --git a/src/routers/resource_routers/platform_router.py b/src/routers/resource_routers/platform_router.py index bd3f1d93..cde1deaf 100644 --- a/src/routers/resource_routers/platform_router.py +++ b/src/routers/resource_routers/platform_router.py @@ -3,13 +3,14 @@ from fastapi import Depends, HTTPException, status, APIRouter from sqlmodel import SQLModel, Session, select +from sqlalchemy import func from authentication import KeycloakUser, get_user_or_raise from database.model.platform.platform import Platform from database.model.resource_read_and_create import resource_create, resource_read from database.model.serializers import deserialize_resource_relationships from database.session import DbSession -from dependencies.pagination import Pagination, PaginationParams +from dependencies.pagination import Pagination, PaginationParams, PaginatedResponse from error_handling import as_http_exception from versioning import Version @@ -55,7 +56,7 @@ def create(self, url_prefix: str, version: Version) -> APIRouter: "tags": [self.resource_name_plural], } response_model = self.resource_class_read # type:ignore - response_model_plural = list[self.resource_class_read] # type:ignore + response_model_plural = PaginatedResponse[self.resource_class_read] # type:ignore router.add_api_route( path=f"/{self.resource_name_plural}", @@ -116,8 +117,15 @@ def get_resources(self, pagination: Pagination): """Fetch all resources.""" with DbSession(autoflush=False) as session: try: + total_count = session.query(self.resource_class).count() resources: Any = self._retrieve_resources(session, pagination) - return [self.resource_class_read.model_validate(resource) for resource in resources] + data = [self.resource_class_read.model_validate(resource) for resource in resources] + return PaginatedResponse( + offset=pagination.offset, + limit=pagination.limit, + total_count=total_count, + data=data, + ) except Exception as e: raise as_http_exception(e) diff --git a/src/tests/routers/generic/test_authentication.py b/src/tests/routers/generic/test_authentication.py index 089c137d..3380b916 100644 --- a/src/tests/routers/generic/test_authentication.py +++ b/src/tests/routers/generic/test_authentication.py @@ -12,7 +12,7 @@ def test_get_all_unauthenticated( """You don't need authentication for GET""" response = client_test_resource.get("/test_resources") assert response.status_code == 200, response.json() - assert len(response.json()) == 1 + assert len(response.json()["data"]) == 1 def test_get_unauthenticated(client_test_resource: TestClient, auto_publish: None, engine_test_resource_filled: Engine): @@ -27,7 +27,7 @@ def test_platform_get_all_unauthenticated( """You don't need authentication for GET""" response = client_test_resource.get("/platforms/example/test_resources") assert response.status_code == 200, response.json() - assert len(response.json()) == 1 + assert len(response.json()["data"]) == 1 def test_platform_get_unauthenticated( @@ -36,7 +36,7 @@ def test_platform_get_unauthenticated( """You don't need authentication for GET""" response = client_test_resource.get("/platforms/example/test_resources") assert response.status_code == 200, response.json() - assert len(response.json()) == 1 + assert len(response.json()["data"]) == 1 @pytest.mark.parametrize( diff --git a/src/tests/routers/generic/test_router_delete.py b/src/tests/routers/generic/test_router_delete.py index f6f24b7d..a7b47bc0 100644 --- a/src/tests/routers/generic/test_router_delete.py +++ b/src/tests/routers/generic/test_router_delete.py @@ -37,7 +37,7 @@ def test_happy_path( assert response.status_code == 200, response.json() response = client_test_resource.get("/test_resources/") assert response.status_code == 200, response.json() - response_json = response.json() + response_json = response.json()["data"] assert len(response_json) == 1 assert {r["identifier"] for r in response_json} == set(identifiers) - {identifier} diff --git a/src/tests/routers/generic/test_router_get_all.py b/src/tests/routers/generic/test_router_get_all.py index 547bdb93..844b7c59 100644 --- a/src/tests/routers/generic/test_router_get_all.py +++ b/src/tests/routers/generic/test_router_get_all.py @@ -21,7 +21,7 @@ def test_get_all_happy_path(client_test_resource: TestClient): session.commit() response = client_test_resource.get("/test_resources?direction=asc") assert response.status_code == 200, response.json() - response_json = response.json() + response_json = response.json()["data"] assert len(response_json) == 2, "Expecting only two published assets" response_1, response_2 = response_json diff --git a/src/tests/routers/generic/test_router_platform_get_all.py b/src/tests/routers/generic/test_router_platform_get_all.py index 8479fc64..3bfed590 100644 --- a/src/tests/routers/generic/test_router_platform_get_all.py +++ b/src/tests/routers/generic/test_router_platform_get_all.py @@ -24,7 +24,7 @@ def test_get_all_happy_path(client_test_resource: TestClient, auto_publish): session.commit() response = client_test_resource.get("/platforms/example/test_resources?direction=asc") assert response.status_code == 200, response.json() - response_json = response.json() + response_json = response.json()["data"] assert len(response_json) == 2 response_1, response_2 = response_json diff --git a/src/tests/routers/generic/test_router_relations.py b/src/tests/routers/generic/test_router_relations.py index d3ca3c64..6c15dc3c 100644 --- a/src/tests/routers/generic/test_router_relations.py +++ b/src/tests/routers/generic/test_router_relations.py @@ -183,7 +183,7 @@ def test_get_happy_path(test_objects: list[TestObject], client_with_testobject: def test_get_all_happy_path(client_with_testobject: TestClient): response = client_with_testobject.get("/test_resources?direction=asc") assert response.status_code == 200, response.json() - response_json = response.json() + response_json = response.json()["data"] assert "deprecated" not in response.headers assert len(response_json) == 4 @@ -213,7 +213,7 @@ def test_post_happy_path(client_with_testobject: TestClient, auto_publish: None) headers={"Authorization": "Fake token"}, ) assert response.status_code == 200, response.json() - objects = client_with_testobject.get("/test_resources?direction=asc").json() + objects = client_with_testobject.get("/test_resources?direction=asc").json()["data"] obj = objects[-1] assert obj["title"] == "title" assert obj["named_string"] == "named_string1" diff --git a/src/tests/routers/generic/test_router_scheme.py b/src/tests/routers/generic/test_router_scheme.py index c743c67b..1f5c6386 100644 --- a/src/tests/routers/generic/test_router_scheme.py +++ b/src/tests/routers/generic/test_router_scheme.py @@ -56,7 +56,7 @@ def test_resources_aiod( for client in [client_test_resource_other_schema, client_test_resource]: response = client.get("/test_resources" + schema_string) assert response.status_code == 200, response.json() - json_ = response.json() + json_ = response.json()["data"] assert len(json_) == 1 assert json_[0]["title"] == "A title" assert "title_with_alternative_name" not in json_[0] @@ -89,7 +89,7 @@ def test_resources_other_schema( ): response = client_test_resource_other_schema.get("/test_resources?schema=other-schema") assert response.status_code == 200, response.json() - json_ = response.json() + json_ = response.json()["data"] assert len(json_) == 1 assert json_[0]["title_with_alternative_name"] == "A title" assert "title" not in json_[0] diff --git a/src/tests/routers/resource_routers/test_resource_router.py b/src/tests/routers/resource_routers/test_resource_router.py index f42706ff..a30fb0a7 100644 --- a/src/tests/routers/resource_routers/test_resource_router.py +++ b/src/tests/routers/resource_routers/test_resource_router.py @@ -70,8 +70,8 @@ def test_happy_path_with_filters( assert response.status_code == 200, response.json() response_json = response.json() - assert isinstance(response_json, list) - assert len(response_json) == expected_count + assert "data" in response_json + assert len(response_json["data"]) == expected_count @pytest.mark.parametrize( @@ -104,13 +104,14 @@ def test_happy_path_with_sorting( ) for sort, direction in itertools.product(list(Sort), list(SortDirection)): - resources = client.get( + resources_json = client.get( f"/{resource_type}", params={ "direction": str(direction), "sort": str(sort), }, ).json() + resources = resources_json["data"] match sort, direction: case Sort.DATE_MODIFIED, SortDirection.ASC: diff --git a/src/tests/routers/resource_routers/test_router_contact.py b/src/tests/routers/resource_routers/test_router_contact.py index 5990e1b8..46870125 100644 --- a/src/tests/routers/resource_routers/test_router_contact.py +++ b/src/tests/routers/resource_routers/test_router_contact.py @@ -135,6 +135,8 @@ def test_email_mask_for_not_authenticated_user( guest_response = client.get(endpoint) assert guest_response.status_code == 200, guest_response.json() guest_response_json = guest_response.json() + if isinstance(guest_response_json, dict) and "data" in guest_response_json: + guest_response_json = guest_response_json["data"] if not isinstance(guest_response_json, list): guest_response_json = [guest_response_json] @@ -164,9 +166,10 @@ def test_email_mask_for_authenticated_user( response = client.get("/contacts?direction=asc", headers=headers) response_json = response.json() assert response.status_code == 200, response_json - assert len(response_json) == 2, response_json - assert response_json[0]["email"] == ["a@b.com"] - assert set(response_json[1]["email"]) == {"fake2@email.com", "fake@email.com"} + data = response_json["data"] if isinstance(response_json, dict) else response_json + assert len(data) == 2, response_json + assert data[0]["email"] == ["a@b.com"] + assert set(data[1]["email"]) == {"fake2@email.com", "fake@email.com"} response = client.get(f"/contacts/{contact2.identifier}", headers=headers) assert response.status_code == 200, response.json() @@ -177,9 +180,10 @@ def test_email_mask_for_authenticated_user( response_json = response.json() assert response.status_code == 200, response_json - assert len(response_json) == 2, response_json - assert response_json[0]["email"] == ["a@b.com"] - assert set(response_json[1]["email"]) == {"fake2@email.com", "fake@email.com"} + data = response_json["data"] if isinstance(response_json, dict) else response_json + assert len(data) == 2, response_json + assert data[0]["email"] == ["a@b.com"] + assert set(data[1]["email"]) == {"fake2@email.com", "fake@email.com"} response = client.get("/platforms/aiod/contacts/fake:100", headers=headers) response_json = response.json() @@ -220,6 +224,8 @@ def test_email_privacy_for_ai4europe_cms( endpoint = endpoint.replace("/1", f"/{contact.identifier}") response = client.get(endpoint, headers=headers) response_json = response.json() + if isinstance(response_json, dict) and "data" in response_json: + response_json = response_json["data"] if isinstance(response_json, list): response_json = response_json[0] @@ -229,8 +235,11 @@ def test_email_privacy_for_ai4europe_cms( keycloak_openid.introspect = AI4EUROPE_CMS_TOKEN + endpoint = endpoint.replace("/1", f"/{contact.identifier}") response = client.get(endpoint, headers=headers) response_json = response.json() + if isinstance(response_json, dict) and "data" in response_json: + response_json = response_json["data"] if isinstance(response_json, list): response_json = response_json[0] diff --git a/src/tests/routers/resource_routers/test_router_person.py b/src/tests/routers/resource_routers/test_router_person.py index 58292987..3778f49f 100644 --- a/src/tests/routers/resource_routers/test_router_person.py +++ b/src/tests/routers/resource_routers/test_router_person.py @@ -101,6 +101,8 @@ def test_privacy_for_ai4europe_cms( endpoint = endpoint.replace("/1", f"/{person.identifier}") response = client.get(endpoint, headers=headers) response_json = response.json() + if isinstance(response_json, dict) and "data" in response_json: + response_json = response_json["data"] response_json = [response_json] if isinstance(response_json, dict) else response_json assert response.status_code == 200, response_json for person_dict in response_json: @@ -111,6 +113,8 @@ def test_privacy_for_ai4europe_cms( keycloak_openid.introspect = AI4EUROPE_CMS_TOKEN response = client.get(endpoint, headers=headers) response_json = response.json() + if isinstance(response_json, dict) and "data" in response_json: + response_json = response_json["data"] response_json = [response_json] if isinstance(response_json, dict) else response_json assert response.status_code == 200, response_json for person_dict in response_json: diff --git a/src/tests/routers/resource_routers/test_router_platform.py b/src/tests/routers/resource_routers/test_router_platform.py index 148ea1a1..53f5993a 100644 --- a/src/tests/routers/resource_routers/test_router_platform.py +++ b/src/tests/routers/resource_routers/test_router_platform.py @@ -15,7 +15,7 @@ def test_happy_path(mocked_token: Mock, client: TestClient, auto_publish: None): assert response.status_code == 200, response.json() response = client.get("/platforms") assert response.status_code == 200, response.json() - platforms = {p["name"] for p in response.json()} + platforms = {p["name"] for p in response.json()["data"]} assert platforms == {p.name for p in PlatformName}.union(["my_favourite_platform"]) diff --git a/src/tests/test_bookmark_endpoints.py b/src/tests/test_bookmark_endpoints.py index fd6b563d..d09f6f9b 100644 --- a/src/tests/test_bookmark_endpoints.py +++ b/src/tests/test_bookmark_endpoints.py @@ -102,7 +102,7 @@ def test_get_bookmarks(client: TestClient, person: Person, contact: Contact) -> ) assert response.status_code == HTTPStatus.OK - bookmarks = response.json() + bookmarks = response.json()["data"] assert len(bookmarks) == 2 @@ -140,7 +140,7 @@ def test_delete_bookmark( headers={"Authorization": "fake token"}, ) assert response.status_code == HTTPStatus.OK - assert all(b["resource_identifier"] != identifier for b in response.json()) + assert all(b["resource_identifier"] != identifier for b in response.json().get("data", [])) @pytest.mark.versions(Version.V2, Version.LATEST) @@ -161,25 +161,25 @@ def test_get_bookmark_pagination(client: TestClient, publication_factory) -> Non headers={"Authorization": "fake token"}, ) assert response.status_code == HTTPStatus.OK - assert len(response.json()) == PAGINATION_DEFAULT_LIMIT + assert len(response.json()["data"]) == PAGINATION_DEFAULT_LIMIT response = client.get( "/bookmarks?limit=100", headers={"Authorization": "fake token"}, ) assert response.status_code == HTTPStatus.OK - assert len(response.json()) == PAGINATION_DEFAULT_LIMIT + 1 + assert len(response.json()["data"]) == PAGINATION_DEFAULT_LIMIT + 1 response = client.get( "/bookmarks?offset=10", headers={"Authorization": "fake token"}, ) assert response.status_code == HTTPStatus.OK - assert len(response.json()) == 1 + assert len(response.json()["data"]) == 1 response = client.get( "/bookmarks?offset=8&limit=2", headers={"Authorization": "fake token"}, ) assert response.status_code == HTTPStatus.OK - assert len(response.json()) == 2 + assert len(response.json()["data"]) == 2