Skip to content

Commit 08d1f7f

Browse files
authored
Merge pull request #80 from mts-ai/feature/customizable-request-attributes
Feature/customizable request attributes
2 parents 7015aa1 + ebc59b5 commit 08d1f7f

File tree

10 files changed

+679
-70
lines changed

10 files changed

+679
-70
lines changed

Diff for: fastapi_jsonapi/querystring.py

+53-36
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Helper to deal with querystring parameters according to jsonapi specification."""
2+
from collections import defaultdict
23
from functools import cached_property
34
from typing import (
45
TYPE_CHECKING,
@@ -7,7 +8,6 @@
78
List,
89
Optional,
910
Type,
10-
Union,
1111
)
1212
from urllib.parse import unquote
1313

@@ -22,17 +22,18 @@
2222
)
2323
from starlette.datastructures import QueryParams
2424

25+
from fastapi_jsonapi.api import RoutersJSONAPI
2526
from fastapi_jsonapi.exceptions import (
2627
BadRequest,
2728
InvalidField,
2829
InvalidFilters,
2930
InvalidInclude,
3031
InvalidSort,
32+
InvalidType,
3133
)
3234
from fastapi_jsonapi.schema import (
3335
get_model_field,
3436
get_relationships,
35-
get_schema_from_type,
3637
)
3738
from fastapi_jsonapi.splitter import SPLIT_REL
3839

@@ -89,33 +90,45 @@ def __init__(self, request: Request) -> None:
8990
self.MAX_INCLUDE_DEPTH: int = self.config.get("MAX_INCLUDE_DEPTH", 3)
9091
self.headers: HeadersQueryStringManager = HeadersQueryStringManager(**dict(self.request.headers))
9192

92-
def _get_key_values(self, name: str) -> Dict[str, Union[List[str], str]]:
93+
def _extract_item_key(self, key: str) -> str:
94+
try:
95+
key_start = key.index("[") + 1
96+
key_end = key.index("]")
97+
return key[key_start:key_end]
98+
except Exception:
99+
msg = "Parse error"
100+
raise BadRequest(msg, parameter=key)
101+
102+
def _get_unique_key_values(self, name: str) -> Dict[str, str]:
93103
"""
94104
Return a dict containing key / values items for a given key, used for items like filters, page, etc.
95105
96106
:param name: name of the querystring parameter
97107
:return: a dict of key / values items
98108
:raises BadRequest: if an error occurred while parsing the querystring.
99109
"""
100-
results: Dict[str, Union[List[str], str]] = {}
110+
results = {}
101111

102112
for raw_key, value in self.qs.multi_items():
103113
key = unquote(raw_key)
104-
try:
105-
if not key.startswith(name):
106-
continue
114+
if not key.startswith(name):
115+
continue
107116

108-
key_start = key.index("[") + 1
109-
key_end = key.index("]")
110-
item_key = key[key_start:key_end]
117+
item_key = self._extract_item_key(key)
118+
results[item_key] = value
111119

112-
if "," in value:
113-
results.update({item_key: value.split(",")})
114-
else:
115-
results.update({item_key: value})
116-
except Exception:
117-
msg = "Parse error"
118-
raise BadRequest(msg, parameter=key)
120+
return results
121+
122+
def _get_multiple_key_values(self, name: str) -> Dict[str, List]:
123+
results = defaultdict(list)
124+
125+
for raw_key, value in self.qs.multi_items():
126+
key = unquote(raw_key)
127+
if not key.startswith(name):
128+
continue
129+
130+
item_key = self._extract_item_key(key)
131+
results[item_key].extend(value.split(","))
119132

120133
return results
121134

@@ -134,7 +147,7 @@ def querystring(self) -> Dict[str, str]:
134147
return {
135148
key: value
136149
for (key, value) in self.qs.multi_items()
137-
if key.startswith(self.managed_keys) or self._get_key_values("filter[")
150+
if key.startswith(self.managed_keys) or self._get_unique_key_values("filter[")
138151
}
139152

140153
@property
@@ -159,8 +172,8 @@ def filters(self) -> List[dict]:
159172
raise InvalidFilters(msg)
160173

161174
results.extend(loaded_filters)
162-
if self._get_key_values("filter["):
163-
results.extend(self._simple_filters(self._get_key_values("filter[")))
175+
if filter_key_values := self._get_unique_key_values("filter["):
176+
results.extend(self._simple_filters(filter_key_values))
164177
return results
165178

166179
@cached_property
@@ -186,7 +199,7 @@ def pagination(self) -> PaginationQueryStringManager:
186199
:raises BadRequest: if the client is not allowed to disable pagination.
187200
"""
188201
# check values type
189-
pagination_data: Dict[str, Union[List[str], str]] = self._get_key_values("page")
202+
pagination_data: Dict[str, str] = self._get_unique_key_values("page")
190203
pagination = PaginationQueryStringManager(**pagination_data)
191204
if pagination_data.get("size") is None:
192205
pagination.size = None
@@ -199,8 +212,6 @@ def pagination(self) -> PaginationQueryStringManager:
199212

200213
return pagination
201214

202-
# TODO: finally use this! upgrade Sqlachemy Data Layer
203-
# and add to all views (get list/detail, create, patch)
204215
@property
205216
def fields(self) -> Dict[str, List[str]]:
206217
"""
@@ -216,26 +227,32 @@ def fields(self) -> Dict[str, List[str]]:
216227
217228
:raises InvalidField: if result field not in schema.
218229
"""
219-
if self.request.method != "GET":
220-
msg = "attribute 'fields' allowed only for GET-method"
221-
raise InvalidField(msg)
222-
fields = self._get_key_values("fields")
223-
for key, value in fields.items():
224-
if not isinstance(value, list):
225-
value = [value] # noqa: PLW2901
226-
fields[key] = value
230+
fields = self._get_multiple_key_values("fields")
231+
for resource_type, field_names in fields.items():
227232
# TODO: we have registry for models (BaseModel)
228233
# TODO: create `type to schemas` registry
229-
schema: Type[BaseModel] = get_schema_from_type(key, self.app)
230-
for field in value:
231-
if field not in schema.__fields__:
234+
235+
if resource_type not in RoutersJSONAPI.all_jsonapi_routers:
236+
msg = f"Application has no resource with type {resource_type!r}"
237+
raise InvalidType(msg)
238+
239+
schema: Type[BaseModel] = self._get_schema(resource_type)
240+
241+
for field_name in field_names:
242+
if field_name == "":
243+
continue
244+
245+
if field_name not in schema.__fields__:
232246
msg = "{schema} has no attribute {field}".format(
233247
schema=schema.__name__,
234-
field=field,
248+
field=field_name,
235249
)
236250
raise InvalidField(msg)
237251

238-
return fields
252+
return {resource_type: set(field_names) for resource_type, field_names in fields.items()}
253+
254+
def _get_schema(self, resource_type: str) -> Type[BaseModel]:
255+
return RoutersJSONAPI.all_jsonapi_routers[resource_type]._schema
239256

240257
def get_sorts(self, schema: Type["TypeSchema"]) -> List[Dict[str, str]]:
241258
"""

Diff for: fastapi_jsonapi/views/detail_view.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
BaseJSONAPIItemInSchema,
1313
JSONAPIResultDetailSchema,
1414
)
15+
from fastapi_jsonapi.views.utils import handle_jsonapi_fields
1516
from fastapi_jsonapi.views.view_base import ViewBase
1617

1718
if TYPE_CHECKING:
@@ -34,22 +35,24 @@ async def handle_get_resource_detail(
3435
self,
3536
object_id: Union[int, str],
3637
**extra_view_deps,
37-
):
38+
) -> Union[JSONAPIResultDetailSchema, Dict]:
3839
dl: "BaseDataLayer" = await self.get_data_layer(extra_view_deps)
3940

4041
view_kwargs = {dl.url_id_field: object_id}
4142
db_object = await dl.get_object(view_kwargs=view_kwargs, qs=self.query_params)
4243

43-
return self._build_detail_response(db_object)
44+
response = self._build_detail_response(db_object)
45+
return handle_jsonapi_fields(response, self.query_params, self.jsonapi)
4446

4547
async def handle_update_resource(
4648
self,
4749
obj_id: str,
4850
data_update: BaseJSONAPIItemInSchema,
4951
**extra_view_deps,
50-
) -> JSONAPIResultDetailSchema:
52+
) -> Union[JSONAPIResultDetailSchema, Dict]:
5153
dl: "BaseDataLayer" = await self.get_data_layer(extra_view_deps)
52-
return await self.process_update_object(dl=dl, obj_id=obj_id, data_update=data_update)
54+
response = await self.process_update_object(dl=dl, obj_id=obj_id, data_update=data_update)
55+
return handle_jsonapi_fields(response, self.query_params, self.jsonapi)
5356

5457
async def process_update_object(
5558
self,

Diff for: fastapi_jsonapi/views/list_view.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import logging
2-
from typing import TYPE_CHECKING, Any, Dict
2+
from typing import TYPE_CHECKING, Any, Dict, Union
33

44
from fastapi_jsonapi.schema import (
55
BaseJSONAPIItemInSchema,
66
JSONAPIResultDetailSchema,
77
JSONAPIResultListSchema,
88
)
9+
from fastapi_jsonapi.views.utils import handle_jsonapi_fields
910
from fastapi_jsonapi.views.view_base import ViewBase
1011

1112
if TYPE_CHECKING:
@@ -34,21 +35,23 @@ async def get_data_layer(
3435
) -> "BaseDataLayer":
3536
return await self.get_data_layer_for_list(extra_view_deps)
3637

37-
async def handle_get_resource_list(self, **extra_view_deps) -> JSONAPIResultListSchema:
38+
async def handle_get_resource_list(self, **extra_view_deps) -> Union[JSONAPIResultListSchema, Dict]:
3839
dl: "BaseDataLayer" = await self.get_data_layer(extra_view_deps)
3940
query_params = self.query_params
4041
count, items_from_db = await dl.get_collection(qs=query_params)
4142
total_pages = self._calculate_total_pages(count)
4243

43-
return self._build_list_response(items_from_db, count, total_pages)
44+
response = self._build_list_response(items_from_db, count, total_pages)
45+
return handle_jsonapi_fields(response, query_params, self.jsonapi)
4446

4547
async def handle_post_resource_list(
4648
self,
4749
data_create: BaseJSONAPIItemInSchema,
4850
**extra_view_deps,
49-
) -> JSONAPIResultDetailSchema:
51+
) -> Union[JSONAPIResultDetailSchema, Dict]:
5052
dl: "BaseDataLayer" = await self.get_data_layer(extra_view_deps)
51-
return await self.process_create_object(dl=dl, data_create=data_create)
53+
response = await self.process_create_object(dl=dl, data_create=data_create)
54+
return handle_jsonapi_fields(response, self.query_params, self.jsonapi)
5255

5356
async def process_create_object(self, dl: "BaseDataLayer", data_create: BaseJSONAPIItemInSchema):
5457
created_object = await dl.create_object(data_create=data_create, view_kwargs={})
@@ -68,4 +71,5 @@ async def handle_delete_resource_list(self, **extra_view_deps) -> JSONAPIResultL
6871

6972
await dl.delete_objects(items_from_db, {})
7073

71-
return self._build_list_response(items_from_db, count, total_pages)
74+
response = self._build_list_response(items_from_db, count, total_pages)
75+
return handle_jsonapi_fields(response, self.query_params, self.jsonapi)

0 commit comments

Comments
 (0)