Skip to content

Commit f8ec40b

Browse files
committed
added simple cache
1 parent b9cfe93 commit f8ec40b

File tree

4 files changed

+273
-8
lines changed

4 files changed

+273
-8
lines changed

fastapi_jsonapi/api.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def __init__(
7878
pagination_default_offset: Optional[int] = None,
7979
pagination_default_limit: Optional[int] = None,
8080
methods: Iterable[str] = (),
81+
max_cache_size: int = 0,
8182
) -> None:
8283
"""
8384
Initialize router items.
@@ -127,7 +128,7 @@ def __init__(
127128
self.pagination_default_number: Optional[int] = pagination_default_number
128129
self.pagination_default_offset: Optional[int] = pagination_default_offset
129130
self.pagination_default_limit: Optional[int] = pagination_default_limit
130-
self.schema_builder = SchemaBuilder(resource_type=resource_type)
131+
self.schema_builder = SchemaBuilder(resource_type=resource_type, max_cache_size=max_cache_size)
131132

132133
dto = self.schema_builder.create_schemas(
133134
schema=schema,

fastapi_jsonapi/schema_builder.py

+45-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""JSON API schemas builder class."""
22
from dataclasses import dataclass
3+
from functools import lru_cache
34
from typing import (
45
Any,
56
Callable,
@@ -119,11 +120,15 @@ class SchemaBuilder:
119120
relationship_schema_cache: ClassVar = {}
120121
base_jsonapi_object_schemas_cache: ClassVar = {}
121122

122-
def __init__(
123-
self,
124-
resource_type: str,
125-
):
123+
def __init__(self, resource_type: str, max_cache_size: int = 0):
126124
self._resource_type = resource_type
125+
self._init_cache(max_cache_size)
126+
127+
def _init_cache(self, max_cache_size: int):
128+
# TODO: remove crutch
129+
self._get_info_from_schema_for_building_cached = lru_cache(maxsize=max_cache_size)(
130+
self._get_info_from_schema_for_building_cached,
131+
)
127132

128133
def _create_schemas_objects_list(self, schema: Type[BaseModel]) -> Type[JSONAPIResultListSchema]:
129134
object_jsonapi_list_schema, list_jsonapi_schema = self.build_list_schemas(schema)
@@ -187,7 +192,7 @@ def build_schema_in(
187192
) -> Tuple[Type[BaseJSONAPIDataInSchema], Type[BaseJSONAPIItemInSchema]]:
188193
base_schema_name = schema_in.__name__.removesuffix("Schema") + schema_name_suffix
189194

190-
dto = self._get_info_from_schema_for_building(
195+
dto = self._get_info_from_schema_for_building_wrapper(
191196
base_name=base_schema_name,
192197
schema=schema_in,
193198
non_optional_relationships=non_optional_relationships,
@@ -258,6 +263,40 @@ def build_list_schemas(
258263
includes=includes,
259264
)
260265

266+
def _get_info_from_schema_for_building_cached(
267+
self,
268+
base_name: str,
269+
schema: Type[BaseModel],
270+
includes: Iterable[str],
271+
non_optional_relationships: bool,
272+
):
273+
return self._get_info_from_schema_for_building(
274+
base_name=base_name,
275+
schema=schema,
276+
includes=includes,
277+
non_optional_relationships=non_optional_relationships,
278+
)
279+
280+
def _get_info_from_schema_for_building_wrapper(
281+
self,
282+
base_name: str,
283+
schema: Type[BaseModel],
284+
includes: Iterable[str] = not_passed,
285+
non_optional_relationships: bool = False,
286+
):
287+
"""
288+
Wrapper function for return cached schema result
289+
"""
290+
if includes is not not_passed:
291+
includes = tuple(includes)
292+
293+
return self._get_info_from_schema_for_building_cached(
294+
base_name=base_name,
295+
schema=schema,
296+
includes=includes,
297+
non_optional_relationships=non_optional_relationships,
298+
)
299+
261300
def _get_info_from_schema_for_building(
262301
self,
263302
base_name: str,
@@ -494,7 +533,7 @@ def create_jsonapi_object_schemas(
494533
if includes is not not_passed:
495534
includes = set(includes)
496535

497-
dto = self._get_info_from_schema_for_building(
536+
dto = self._get_info_from_schema_for_building_wrapper(
498537
base_name=base_name,
499538
schema=schema,
500539
includes=includes,

tests/fixtures/app.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -232,10 +232,11 @@ def build_app_custom(
232232
resource_type: str = "misc",
233233
class_list: Type[ListViewBase] = ListViewBaseGeneric,
234234
class_detail: Type[DetailViewBase] = DetailViewBaseGeneric,
235+
max_cache_size: int = 0,
235236
) -> FastAPI:
236237
router: APIRouter = APIRouter()
237238

238-
RoutersJSONAPI(
239+
jsonapi_routers = RoutersJSONAPI(
239240
router=router,
240241
path=path,
241242
tags=["Misc"],
@@ -246,6 +247,7 @@ def build_app_custom(
246247
schema_in_patch=schema_in_patch,
247248
schema_in_post=schema_in_post,
248249
model=model,
250+
max_cache_size=max_cache_size,
249251
)
250252

251253
app = build_app_plain()
@@ -254,6 +256,9 @@ def build_app_custom(
254256
atomic = AtomicOperations()
255257
app.include_router(atomic.router, prefix="")
256258
init(app)
259+
260+
app.jsonapi_routers = jsonapi_routers
261+
257262
return app
258263

259264

tests/test_api/test_api_sqla_with_includes.py

+220
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from itertools import chain, zip_longest
77
from json import dumps, loads
88
from typing import Dict, List, Literal, Set, Tuple
9+
from unittest.mock import call, patch
910
from uuid import UUID, uuid4
1011

1112
import pytest
@@ -20,6 +21,7 @@
2021
from starlette.datastructures import QueryParams
2122

2223
from fastapi_jsonapi.api import RoutersJSONAPI
24+
from fastapi_jsonapi.schema_builder import SchemaBuilder
2325
from fastapi_jsonapi.views.view_base import ViewBase
2426
from tests.common import is_postgres_tests
2527
from tests.fixtures.app import build_alphabet_app, build_app_custom
@@ -52,6 +54,8 @@
5254
CustomUUIDItemAttributesSchema,
5355
PostAttributesBaseSchema,
5456
PostCommentAttributesBaseSchema,
57+
PostCommentSchema,
58+
PostSchema,
5559
SelfRelationshipAttributesSchema,
5660
SelfRelationshipSchema,
5761
UserAttributesBaseSchema,
@@ -360,6 +364,215 @@ async def test_select_custom_fields_for_includes_without_requesting_includes(
360364
"meta": {"count": 1, "totalPages": 1},
361365
}
362366

367+
def _get_clear_mock_calls(self, mock_obj) -> list[call]:
368+
mock_calls = mock_obj.mock_calls
369+
return [call_ for call_ in mock_calls if call_ not in [call.__len__(), call.__str__()]]
370+
371+
def _prepare_info_schema_calls_to_assert(self, mock_calls) -> list[call]:
372+
calls_to_check = []
373+
for wrapper_call in mock_calls:
374+
kwargs = wrapper_call.kwargs
375+
kwargs["includes"] = sorted(kwargs["includes"], key=lambda x: x)
376+
377+
calls_to_check.append(
378+
call(
379+
*wrapper_call.args,
380+
**kwargs,
381+
),
382+
)
383+
384+
return sorted(
385+
calls_to_check,
386+
key=lambda x: (x.kwargs["base_name"], x.kwargs["includes"]),
387+
)
388+
389+
async def test_check_get_info_schema_cache(
390+
self,
391+
user_1: User,
392+
):
393+
resource_type = "user_with_cache"
394+
with suppress(KeyError):
395+
RoutersJSONAPI.all_jsonapi_routers.pop(resource_type)
396+
397+
app_with_cache = build_app_custom(
398+
model=User,
399+
schema=UserSchema,
400+
schema_in_post=UserInSchemaAllowIdOnPost,
401+
schema_in_patch=UserPatchSchema,
402+
resource_type=resource_type,
403+
# set cache size to enable caching
404+
max_cache_size=128,
405+
)
406+
407+
target_func_name = "_get_info_from_schema_for_building"
408+
url = app_with_cache.url_path_for(f"get_{resource_type}_list")
409+
params = {
410+
"include": "posts,posts.comments",
411+
}
412+
413+
expected_len_with_cache = 6
414+
expected_len_without_cache = 10
415+
416+
with patch.object(
417+
SchemaBuilder,
418+
target_func_name,
419+
wraps=app_with_cache.jsonapi_routers.schema_builder._get_info_from_schema_for_building,
420+
) as wrapped_func:
421+
async with AsyncClient(app=app_with_cache, base_url="http://test") as client:
422+
response = await client.get(url, params=params)
423+
assert response.status_code == status.HTTP_200_OK, response.text
424+
425+
calls_to_check = self._prepare_info_schema_calls_to_assert(self._get_clear_mock_calls(wrapped_func))
426+
427+
# there are no duplicated calls
428+
assert calls_to_check == sorted(
429+
[
430+
call(
431+
base_name="UserSchema",
432+
schema=UserSchema,
433+
includes=["posts"],
434+
non_optional_relationships=False,
435+
),
436+
call(
437+
base_name="UserSchema",
438+
schema=UserSchema,
439+
includes=["posts", "posts.comments"],
440+
non_optional_relationships=False,
441+
),
442+
call(
443+
base_name="PostSchema",
444+
schema=PostSchema,
445+
includes=[],
446+
non_optional_relationships=False,
447+
),
448+
call(
449+
base_name="PostSchema",
450+
schema=PostSchema,
451+
includes=["comments"],
452+
non_optional_relationships=False,
453+
),
454+
call(
455+
base_name="PostCommentSchema",
456+
schema=PostCommentSchema,
457+
includes=[],
458+
non_optional_relationships=False,
459+
),
460+
call(
461+
base_name="PostCommentSchema",
462+
schema=PostCommentSchema,
463+
includes=["posts"],
464+
non_optional_relationships=False,
465+
),
466+
],
467+
key=lambda x: (x.kwargs["base_name"], x.kwargs["includes"]),
468+
)
469+
assert wrapped_func.call_count == expected_len_with_cache
470+
471+
response = await client.get(url, params=params)
472+
assert response.status_code == status.HTTP_200_OK, response.text
473+
474+
# there are no new calls
475+
assert wrapped_func.call_count == expected_len_with_cache
476+
477+
resource_type = "user_without_cache"
478+
with suppress(KeyError):
479+
RoutersJSONAPI.all_jsonapi_routers.pop(resource_type)
480+
481+
app_without_cache = build_app_custom(
482+
model=User,
483+
schema=UserSchema,
484+
schema_in_post=UserInSchemaAllowIdOnPost,
485+
schema_in_patch=UserPatchSchema,
486+
resource_type=resource_type,
487+
max_cache_size=0,
488+
)
489+
490+
with patch.object(
491+
SchemaBuilder,
492+
target_func_name,
493+
wraps=app_without_cache.jsonapi_routers.schema_builder._get_info_from_schema_for_building,
494+
) as wrapped_func:
495+
async with AsyncClient(app=app_without_cache, base_url="http://test") as client:
496+
response = await client.get(url, params=params)
497+
assert response.status_code == status.HTTP_200_OK, response.text
498+
499+
calls_to_check = self._prepare_info_schema_calls_to_assert(self._get_clear_mock_calls(wrapped_func))
500+
501+
# there are duplicated calls
502+
assert calls_to_check == sorted(
503+
[
504+
call(
505+
base_name="UserSchema",
506+
schema=UserSchema,
507+
includes=["posts"],
508+
non_optional_relationships=False,
509+
),
510+
call(
511+
base_name="UserSchema",
512+
schema=UserSchema,
513+
includes=["posts"],
514+
non_optional_relationships=False,
515+
), # duplicate
516+
call(
517+
base_name="UserSchema",
518+
schema=UserSchema,
519+
includes=["posts", "posts.comments"],
520+
non_optional_relationships=False,
521+
),
522+
call(
523+
base_name="PostSchema",
524+
schema=PostSchema,
525+
includes=[],
526+
non_optional_relationships=False,
527+
),
528+
call(
529+
base_name="PostSchema",
530+
schema=PostSchema,
531+
includes=[],
532+
non_optional_relationships=False,
533+
), # duplicate
534+
call(
535+
base_name="PostSchema",
536+
schema=PostSchema,
537+
includes=[],
538+
non_optional_relationships=False,
539+
), # duplicate
540+
call(
541+
base_name="PostSchema",
542+
schema=PostSchema,
543+
includes=["comments"],
544+
non_optional_relationships=False,
545+
),
546+
call(
547+
base_name="PostSchema",
548+
schema=PostSchema,
549+
includes=["comments"],
550+
non_optional_relationships=False,
551+
), # duplicate
552+
call(
553+
base_name="PostCommentSchema",
554+
schema=PostCommentSchema,
555+
includes=[],
556+
non_optional_relationships=False,
557+
),
558+
call(
559+
base_name="PostCommentSchema",
560+
schema=PostCommentSchema,
561+
includes=["posts"],
562+
non_optional_relationships=False,
563+
), # duplicate
564+
],
565+
key=lambda x: (x.kwargs["base_name"], x.kwargs["includes"]),
566+
)
567+
568+
assert wrapped_func.call_count == expected_len_without_cache
569+
570+
response = await client.get(url, params=params)
571+
assert response.status_code == status.HTTP_200_OK, response.text
572+
573+
# there are new calls
574+
assert wrapped_func.call_count == expected_len_without_cache * 2
575+
363576

364577
class TestCreatePostAndComments:
365578
async def test_get_posts_with_users(
@@ -371,6 +584,13 @@ async def test_get_posts_with_users(
371584
user_1_posts: List[Post],
372585
user_2_posts: List[Post],
373586
):
587+
call(
588+
base_name="UserSchema",
589+
schema=UserSchema,
590+
includes=["posts"],
591+
non_optional_relationships=False,
592+
on_optional_relationships=False,
593+
)
374594
url = app.url_path_for("get_post_list")
375595
url = f"{url}?include=user"
376596
response = await client.get(url)

0 commit comments

Comments
 (0)