Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/dependencies/pagination.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Annotated
from typing import Annotated, Sequence, TypeVar, Generic

from fastapi import Query, Depends
from pydantic import BaseModel
from pydantic.generics import GenericModel
from sqlmodel import Field


Expand All @@ -25,3 +26,11 @@ class Pagination(BaseModel):


PaginationParams = Annotated[Pagination, Depends(Pagination)]

T = TypeVar("T")

class PaginatedResponse(GenericModel, Generic[T]):
offset: int
limit: int
total_count: int
data: Sequence[T]
18 changes: 11 additions & 7 deletions src/routers/bookmark_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from fastapi import APIRouter, Depends, HTTPException
from typing import List, cast
from sqlmodel import Session, select, Field, SQLModel
from sqlalchemy import func

from authentication import KeycloakUser, get_user_or_raise
from database.session import get_session
Expand All @@ -12,7 +13,7 @@
from http import HTTPStatus
from datetime import datetime

from dependencies.pagination import PaginationParams
from dependencies.pagination import PaginationParams, PaginatedResponse
from database.model.helper_functions import get_asset_type_by_abbreviation
from versioning import Version

Expand All @@ -38,21 +39,24 @@ def create(url_prefix: str = "", version: Version = Version.LATEST) -> APIRouter
path,
tags=["User"],
description="Return all your bookmarks.",
response_model=List[BookmarkRead],
response_model=PaginatedResponse[BookmarkRead],
)
def list_bookmarks(
pagination: PaginationParams,
user: KeycloakUser = Depends(get_user_or_raise),
session: Session = Depends(get_session),
) -> List[BookmarkRead]:
) -> PaginatedResponse[BookmarkRead]:
base_stmt = select(Bookmark).where(Bookmark.user_identifier == user._subject_identifier)
total_count = session.scalar(select(func.count()).select_from(base_stmt.subquery())) or 0
stmt = (
select(Bookmark)
.where(Bookmark.user_identifier == user._subject_identifier)
.order_by(Bookmark.created_at)
base_stmt.order_by(Bookmark.created_at)
.offset(pagination.offset)
.limit(pagination.limit)
)
return session.exec(stmt).all()
data = session.exec(stmt).all()
return PaginatedResponse(
offset=pagination.offset, limit=pagination.limit, total_count=total_count, data=data
)

@router.post(
path,
Expand Down
44 changes: 39 additions & 5 deletions src/routers/resource_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from database.review import Submission, SubmissionCreateV2, AssetReview
from database.session import DbSession
from dependencies.filtering import ResourceFilters, ResourceFiltersParams
from dependencies.pagination import Pagination, PaginationParams
from dependencies.pagination import Pagination, PaginationParams, PaginatedResponse
from dependencies.sorting import SortingParams, Sorting, SortDirection
from error_handling import as_http_exception
from database.model.ai_asset.distribution import Distribution
Expand Down Expand Up @@ -112,9 +112,7 @@ def create(self, url_prefix: str, version: Version = Version.LATEST) -> APIRoute
}
available_schemas: list[Type] = [c.to_class for c in self.schema_converters.values()]
response_model = Union[self.resource_class_read, *available_schemas] # type:ignore
response_model_plural = Union[ # type:ignore
list[self.resource_class_read], *[list[s] for s in available_schemas] # type:ignore
]
response_model_plural = PaginatedResponse[response_model] # type:ignore

router.add_api_route(
path=f"/{self.resource_name_plural}",
Expand Down Expand Up @@ -231,12 +229,20 @@ def get_resources(
resources: Any = self._retrieve_resources_and_post_process(
session, pagination, sorting, resource_filters, user, platform
)
total_count = self._retrieve_resources_count(session, resource_filters, platform)

for resource in resources:
if not get_image and hasattr(resource, "media"):
for media_obj in resource.media:
media_obj.binary_blob = None

return [convert_schema(resource) for resource in resources]
data = [convert_schema(resource) for resource in resources]
return PaginatedResponse(
offset=pagination.offset,
limit=pagination.limit,
total_count=total_count,
data=data,
)
except Exception as e:
raise as_http_exception(e)

Expand Down Expand Up @@ -767,6 +773,34 @@ def _retrieve_resources(
resources: Sequence = session.scalars(query).all()
return resources

def _retrieve_resources_count(
self,
session: Session,
resource_filters: ResourceFilters,
platform: str | None = None,
) -> int:
"""
Retrieve the total count of published resources from the database based on the
provided platform and resource filters (if applicable).
"""
where_clause = and_(
is_(self.resource_class.date_deleted, None),
(self.resource_class.platform == platform) if platform is not None else True,
AIoDEntryORM.date_modified >= resource_filters.date_modified_after
if resource_filters.date_modified_after is not None
else True,
AIoDEntryORM.date_modified < resource_filters.date_modified_before
if resource_filters.date_modified_before is not None
else True,
AIoDEntryORM.status == EntryStatus.PUBLISHED,
)
query = (
select(func.count(self.resource_class.identifier))
.join(self.resource_class.aiod_entry, isouter=True)
.where(where_clause)
)
return session.scalar(query) or 0

def _retrieve_resource_and_post_process(
self,
session: Session,
Expand Down
14 changes: 11 additions & 3 deletions src/routers/resource_routers/platform_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@

from fastapi import Depends, HTTPException, status, APIRouter
from sqlmodel import SQLModel, Session, select
from sqlalchemy import func

from authentication import KeycloakUser, get_user_or_raise
from database.model.platform.platform import Platform
from database.model.resource_read_and_create import resource_create, resource_read
from database.model.serializers import deserialize_resource_relationships
from database.session import DbSession
from dependencies.pagination import Pagination, PaginationParams
from dependencies.pagination import Pagination, PaginationParams, PaginatedResponse
from error_handling import as_http_exception
from versioning import Version

Expand Down Expand Up @@ -55,7 +56,7 @@ def create(self, url_prefix: str, version: Version) -> APIRouter:
"tags": [self.resource_name_plural],
}
response_model = self.resource_class_read # type:ignore
response_model_plural = list[self.resource_class_read] # type:ignore
response_model_plural = PaginatedResponse[self.resource_class_read] # type:ignore

router.add_api_route(
path=f"/{self.resource_name_plural}",
Expand Down Expand Up @@ -116,8 +117,15 @@ def get_resources(self, pagination: Pagination):
"""Fetch all resources."""
with DbSession(autoflush=False) as session:
try:
total_count = session.query(self.resource_class).count()
resources: Any = self._retrieve_resources(session, pagination)
return [self.resource_class_read.model_validate(resource) for resource in resources]
data = [self.resource_class_read.model_validate(resource) for resource in resources]
return PaginatedResponse(
offset=pagination.offset,
limit=pagination.limit,
total_count=total_count,
data=data,
)
except Exception as e:
raise as_http_exception(e)

Expand Down
6 changes: 3 additions & 3 deletions src/tests/routers/generic/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_get_all_unauthenticated(
"""You don't need authentication for GET"""
response = client_test_resource.get("/test_resources")
assert response.status_code == 200, response.json()
assert len(response.json()) == 1
assert len(response.json()["data"]) == 1


def test_get_unauthenticated(client_test_resource: TestClient, auto_publish: None, engine_test_resource_filled: Engine):
Expand All @@ -27,7 +27,7 @@ def test_platform_get_all_unauthenticated(
"""You don't need authentication for GET"""
response = client_test_resource.get("/platforms/example/test_resources")
assert response.status_code == 200, response.json()
assert len(response.json()) == 1
assert len(response.json()["data"]) == 1


def test_platform_get_unauthenticated(
Expand All @@ -36,7 +36,7 @@ def test_platform_get_unauthenticated(
"""You don't need authentication for GET"""
response = client_test_resource.get("/platforms/example/test_resources")
assert response.status_code == 200, response.json()
assert len(response.json()) == 1
assert len(response.json()["data"]) == 1


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion src/tests/routers/generic/test_router_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_happy_path(
assert response.status_code == 200, response.json()
response = client_test_resource.get("/test_resources/")
assert response.status_code == 200, response.json()
response_json = response.json()
response_json = response.json()["data"]
assert len(response_json) == 1
assert {r["identifier"] for r in response_json} == set(identifiers) - {identifier}

Expand Down
2 changes: 1 addition & 1 deletion src/tests/routers/generic/test_router_get_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_get_all_happy_path(client_test_resource: TestClient):
session.commit()
response = client_test_resource.get("/test_resources?direction=asc")
assert response.status_code == 200, response.json()
response_json = response.json()
response_json = response.json()["data"]

assert len(response_json) == 2, "Expecting only two published assets"
response_1, response_2 = response_json
Expand Down
2 changes: 1 addition & 1 deletion src/tests/routers/generic/test_router_platform_get_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_get_all_happy_path(client_test_resource: TestClient, auto_publish):
session.commit()
response = client_test_resource.get("/platforms/example/test_resources?direction=asc")
assert response.status_code == 200, response.json()
response_json = response.json()
response_json = response.json()["data"]

assert len(response_json) == 2
response_1, response_2 = response_json
Expand Down
4 changes: 2 additions & 2 deletions src/tests/routers/generic/test_router_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def test_get_happy_path(test_objects: list[TestObject], client_with_testobject:
def test_get_all_happy_path(client_with_testobject: TestClient):
response = client_with_testobject.get("/test_resources?direction=asc")
assert response.status_code == 200, response.json()
response_json = response.json()
response_json = response.json()["data"]
assert "deprecated" not in response.headers

assert len(response_json) == 4
Expand Down Expand Up @@ -213,7 +213,7 @@ def test_post_happy_path(client_with_testobject: TestClient, auto_publish: None)
headers={"Authorization": "Fake token"},
)
assert response.status_code == 200, response.json()
objects = client_with_testobject.get("/test_resources?direction=asc").json()
objects = client_with_testobject.get("/test_resources?direction=asc").json()["data"]
obj = objects[-1]
assert obj["title"] == "title"
assert obj["named_string"] == "named_string1"
Expand Down
4 changes: 2 additions & 2 deletions src/tests/routers/generic/test_router_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_resources_aiod(
for client in [client_test_resource_other_schema, client_test_resource]:
response = client.get("/test_resources" + schema_string)
assert response.status_code == 200, response.json()
json_ = response.json()
json_ = response.json()["data"]
assert len(json_) == 1
assert json_[0]["title"] == "A title"
assert "title_with_alternative_name" not in json_[0]
Expand Down Expand Up @@ -89,7 +89,7 @@ def test_resources_other_schema(
):
response = client_test_resource_other_schema.get("/test_resources?schema=other-schema")
assert response.status_code == 200, response.json()
json_ = response.json()
json_ = response.json()["data"]
assert len(json_) == 1
assert json_[0]["title_with_alternative_name"] == "A title"
assert "title" not in json_[0]
Expand Down
7 changes: 4 additions & 3 deletions src/tests/routers/resource_routers/test_resource_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def test_happy_path_with_filters(
assert response.status_code == 200, response.json()

response_json = response.json()
assert isinstance(response_json, list)
assert len(response_json) == expected_count
assert "data" in response_json
assert len(response_json["data"]) == expected_count


@pytest.mark.parametrize(
Expand Down Expand Up @@ -104,13 +104,14 @@ def test_happy_path_with_sorting(
)

for sort, direction in itertools.product(list(Sort), list(SortDirection)):
resources = client.get(
resources_json = client.get(
f"/{resource_type}",
params={
"direction": str(direction),
"sort": str(sort),
},
).json()
resources = resources_json["data"]

match sort, direction:
case Sort.DATE_MODIFIED, SortDirection.ASC:
Expand Down
21 changes: 15 additions & 6 deletions src/tests/routers/resource_routers/test_router_contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ def test_email_mask_for_not_authenticated_user(
guest_response = client.get(endpoint)
assert guest_response.status_code == 200, guest_response.json()
guest_response_json = guest_response.json()
if isinstance(guest_response_json, dict) and "data" in guest_response_json:
guest_response_json = guest_response_json["data"]
if not isinstance(guest_response_json, list):
guest_response_json = [guest_response_json]

Expand Down Expand Up @@ -164,9 +166,10 @@ def test_email_mask_for_authenticated_user(
response = client.get("/contacts?direction=asc", headers=headers)
response_json = response.json()
assert response.status_code == 200, response_json
assert len(response_json) == 2, response_json
assert response_json[0]["email"] == ["[email protected]"]
assert set(response_json[1]["email"]) == {"[email protected]", "[email protected]"}
data = response_json["data"] if isinstance(response_json, dict) else response_json
assert len(data) == 2, response_json
assert data[0]["email"] == ["[email protected]"]
assert set(data[1]["email"]) == {"[email protected]", "[email protected]"}

response = client.get(f"/contacts/{contact2.identifier}", headers=headers)
assert response.status_code == 200, response.json()
Expand All @@ -177,9 +180,10 @@ def test_email_mask_for_authenticated_user(
response_json = response.json()
assert response.status_code == 200, response_json

assert len(response_json) == 2, response_json
assert response_json[0]["email"] == ["[email protected]"]
assert set(response_json[1]["email"]) == {"[email protected]", "[email protected]"}
data = response_json["data"] if isinstance(response_json, dict) else response_json
assert len(data) == 2, response_json
assert data[0]["email"] == ["[email protected]"]
assert set(data[1]["email"]) == {"[email protected]", "[email protected]"}

response = client.get("/platforms/aiod/contacts/fake:100", headers=headers)
response_json = response.json()
Expand Down Expand Up @@ -220,6 +224,8 @@ def test_email_privacy_for_ai4europe_cms(
endpoint = endpoint.replace("/1", f"/{contact.identifier}")
response = client.get(endpoint, headers=headers)
response_json = response.json()
if isinstance(response_json, dict) and "data" in response_json:
response_json = response_json["data"]
if isinstance(response_json, list):
response_json = response_json[0]

Expand All @@ -229,8 +235,11 @@ def test_email_privacy_for_ai4europe_cms(

keycloak_openid.introspect = AI4EUROPE_CMS_TOKEN

endpoint = endpoint.replace("/1", f"/{contact.identifier}")
response = client.get(endpoint, headers=headers)
response_json = response.json()
if isinstance(response_json, dict) and "data" in response_json:
response_json = response_json["data"]
if isinstance(response_json, list):
response_json = response_json[0]

Expand Down
4 changes: 4 additions & 0 deletions src/tests/routers/resource_routers/test_router_person.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def test_privacy_for_ai4europe_cms(
endpoint = endpoint.replace("/1", f"/{person.identifier}")
response = client.get(endpoint, headers=headers)
response_json = response.json()
if isinstance(response_json, dict) and "data" in response_json:
response_json = response_json["data"]
response_json = [response_json] if isinstance(response_json, dict) else response_json
assert response.status_code == 200, response_json
for person_dict in response_json:
Expand All @@ -111,6 +113,8 @@ def test_privacy_for_ai4europe_cms(
keycloak_openid.introspect = AI4EUROPE_CMS_TOKEN
response = client.get(endpoint, headers=headers)
response_json = response.json()
if isinstance(response_json, dict) and "data" in response_json:
response_json = response_json["data"]
response_json = [response_json] if isinstance(response_json, dict) else response_json
assert response.status_code == 200, response_json
for person_dict in response_json:
Expand Down
2 changes: 1 addition & 1 deletion src/tests/routers/resource_routers/test_router_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_happy_path(mocked_token: Mock, client: TestClient, auto_publish: None):
assert response.status_code == 200, response.json()
response = client.get("/platforms")
assert response.status_code == 200, response.json()
platforms = {p["name"] for p in response.json()}
platforms = {p["name"] for p in response.json()["data"]}
assert platforms == {p.name for p in PlatformName}.union(["my_favourite_platform"])


Expand Down
Loading