Skip to content

Commit 7697f2c

Browse files
committed
updated logic to exclude
1 parent 80df7fa commit 7697f2c

File tree

5 files changed

+324
-84
lines changed

5 files changed

+324
-84
lines changed

fastapi_jsonapi/querystring.py

+47-27
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
List,
99
Optional,
1010
Type,
11-
Union,
1211
)
1312
from urllib.parse import unquote
1413

@@ -26,9 +25,11 @@
2625
from fastapi_jsonapi.api import RoutersJSONAPI
2726
from fastapi_jsonapi.exceptions import (
2827
BadRequest,
28+
InvalidField,
2929
InvalidFilters,
3030
InvalidInclude,
3131
InvalidSort,
32+
InvalidType,
3233
)
3334
from fastapi_jsonapi.schema import (
3435
get_model_field,
@@ -89,30 +90,45 @@ def __init__(self, request: Request) -> None:
8990
self.MAX_INCLUDE_DEPTH: int = self.config.get("MAX_INCLUDE_DEPTH", 3)
9091
self.headers: HeadersQueryStringManager = HeadersQueryStringManager(**dict(self.request.headers))
9192

92-
def _get_key_values(self, name: str) -> Dict[str, Union[List[str], str]]:
93+
def _extract_item_key(self, key: str) -> str:
94+
try:
95+
key_start = key.index("[") + 1
96+
key_end = key.index("]")
97+
return key[key_start:key_end]
98+
except Exception:
99+
msg = "Parse error"
100+
raise BadRequest(msg, parameter=key)
101+
102+
def _get_unique_key_values(self, name: str) -> Dict[str, str]:
93103
"""
94104
Return a dict containing key / values items for a given key, used for items like filters, page, etc.
95105
96106
:param name: name of the querystring parameter
97107
:return: a dict of key / values items
98108
:raises BadRequest: if an error occurred while parsing the querystring.
99109
"""
100-
results = defaultdict(set)
110+
results = {}
101111

102112
for raw_key, value in self.qs.multi_items():
103113
key = unquote(raw_key)
104-
try:
105-
if not key.startswith(name):
106-
continue
114+
if not key.startswith(name):
115+
continue
107116

108-
key_start = key.index("[") + 1
109-
key_end = key.index("]")
110-
item_key = key[key_start:key_end]
117+
item_key = self._extract_item_key(key)
118+
results[item_key] = value
111119

112-
results[item_key].update(value.split(","))
113-
except Exception:
114-
msg = "Parse error"
115-
raise BadRequest(msg, parameter=key)
120+
return results
121+
122+
def _get_multiple_key_values(self, name: str) -> Dict[str, List]:
123+
results = defaultdict(list)
124+
125+
for raw_key, value in self.qs.multi_items():
126+
key = unquote(raw_key)
127+
if not key.startswith(name):
128+
continue
129+
130+
item_key = self._extract_item_key(key)
131+
results[item_key].extend(value.split(","))
116132

117133
return results
118134

@@ -131,7 +147,7 @@ def querystring(self) -> Dict[str, str]:
131147
return {
132148
key: value
133149
for (key, value) in self.qs.multi_items()
134-
if key.startswith(self.managed_keys) or self._get_key_values("filter[")
150+
if key.startswith(self.managed_keys) or self._get_unique_key_values("filter[")
135151
}
136152

137153
@property
@@ -156,8 +172,8 @@ def filters(self) -> List[dict]:
156172
raise InvalidFilters(msg)
157173

158174
results.extend(loaded_filters)
159-
if self._get_key_values("filter["):
160-
results.extend(self._simple_filters(self._get_key_values("filter[")))
175+
if filter_key_values := self._get_unique_key_values("filter["):
176+
results.extend(self._simple_filters(filter_key_values))
161177
return results
162178

163179
@cached_property
@@ -183,7 +199,7 @@ def pagination(self) -> PaginationQueryStringManager:
183199
:raises BadRequest: if the client is not allowed to disable pagination.
184200
"""
185201
# check values type
186-
pagination_data: Dict[str, Union[List[str], str]] = self._get_key_values("page")
202+
pagination_data: Dict[str, str] = self._get_unique_key_values("page")
187203
pagination = PaginationQueryStringManager(**pagination_data)
188204
if pagination_data.get("size") is None:
189205
pagination.size = None
@@ -213,23 +229,27 @@ def fields(self) -> Dict[str, List[str]]:
213229
214230
:raises InvalidField: if result field not in schema.
215231
"""
216-
fields = self._get_key_values("fields")
232+
fields = self._get_multiple_key_values("fields")
217233
for resource_type, field_names in fields.items():
218234
# TODO: we have registry for models (BaseModel)
219235
# TODO: create `type to schemas` registry
220236

221-
# schema: Type[BaseModel] = get_schema_from_type(key, self.app)
237+
if resource_type not in RoutersJSONAPI.all_jsonapi_routers:
238+
msg = f"Application has no resource with type {resource_type!r}"
239+
raise InvalidType(msg)
240+
241+
schema: Type[BaseModel] = RoutersJSONAPI.all_jsonapi_routers[resource_type]._schema
222242
self._get_schema(resource_type)
223243

224-
# for field_name in field_names:
225-
# if field_name not in schema.__fields__:
226-
# msg = "{schema} has no attribute {field}".format(
227-
# schema=schema.__name__,
228-
# field=field_name,
229-
# )
230-
# raise InvalidField(msg)
244+
for field_name in field_names:
245+
if field_name not in schema.__fields__:
246+
msg = "{schema} has no attribute {field}".format(
247+
schema=schema.__name__,
248+
field=field_name,
249+
)
250+
raise InvalidField(msg)
231251

232-
return fields
252+
return {resource_type: set(field_names) for resource_type, field_names in fields.items()}
233253

234254
def _get_schema(self, resource_type: str) -> Type[BaseModel]:
235255
target_router = RoutersJSONAPI.all_jsonapi_routers[resource_type]

fastapi_jsonapi/views/list_view.py

+4-38
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import logging
2-
from typing import TYPE_CHECKING, Any, Dict
2+
from typing import TYPE_CHECKING, Any, Dict, Union
33

44
from fastapi_jsonapi.schema import (
55
BaseJSONAPIItemInSchema,
66
JSONAPIResultDetailSchema,
77
JSONAPIResultListSchema,
88
)
9-
from fastapi_jsonapi.views.utils import get_includes_indexes_by_type
9+
from fastapi_jsonapi.views.utils import handle_fields
1010
from fastapi_jsonapi.views.view_base import ViewBase
1111

1212
if TYPE_CHECKING:
@@ -15,31 +15,6 @@
1515
logger = logging.getLogger(__name__)
1616

1717

18-
def calculate_include_fields(response, query_params, jsonapi) -> Dict:
19-
included = "included" in response.__fields__ and response.included or []
20-
21-
include_params = {
22-
field_name: {*response.__fields__[field_name].type_.__fields__.keys()}
23-
for field_name in response.__fields__
24-
if field_name
25-
}
26-
include_params["included"] = {}
27-
28-
includes_indexes_by_type = get_includes_indexes_by_type(included)
29-
30-
for resource_type, field_names in query_params.fields.items():
31-
if resource_type == jsonapi.type_:
32-
include_params["data"] = {"__all__": {"attributes": field_names, "id": {"id"}, "type": {"type"}}}
33-
continue
34-
35-
target_type_indexes = includes_indexes_by_type.get(resource_type)
36-
37-
if resource_type in includes_indexes_by_type and target_type_indexes:
38-
include_params["included"].update((idx, field_names) for idx in target_type_indexes)
39-
40-
return include_params
41-
42-
4318
class ListViewBase(ViewBase):
4419
def _calculate_total_pages(self, db_items_count: int) -> int:
4520
total_pages = 1
@@ -60,23 +35,14 @@ async def get_data_layer(
6035
) -> "BaseDataLayer":
6136
return await self.get_data_layer_for_list(extra_view_deps)
6237

63-
async def handle_get_resource_list(self, **extra_view_deps) -> JSONAPIResultListSchema:
38+
async def handle_get_resource_list(self, **extra_view_deps) -> Union[JSONAPIResultListSchema, Dict]:
6439
dl: "BaseDataLayer" = await self.get_data_layer(extra_view_deps)
6540
query_params = self.query_params
6641
count, items_from_db = await dl.get_collection(qs=query_params)
6742
total_pages = self._calculate_total_pages(count)
6843

6944
response = self._build_list_response(items_from_db, count, total_pages)
70-
71-
if not query_params.fields:
72-
return response
73-
74-
include_params = calculate_include_fields(response, query_params, self.jsonapi)
75-
76-
if include_params:
77-
return response.dict(include=include_params)
78-
79-
return response
45+
return handle_fields(response, query_params, self.jsonapi)
8046

8147
async def handle_post_resource_list(
8248
self,

fastapi_jsonapi/views/utils.py

+121-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,39 @@
1+
from __future__ import annotations
2+
13
from collections import defaultdict
24
from enum import Enum
35
from functools import cache
4-
from typing import Callable, Coroutine, Dict, List, Optional, Set, Type, Union
6+
from typing import (
7+
TYPE_CHECKING,
8+
Any,
9+
Callable,
10+
Coroutine,
11+
Dict,
12+
Iterable,
13+
List,
14+
Optional,
15+
Set,
16+
Type,
17+
Union,
18+
)
519

620
from pydantic import BaseModel
21+
from pydantic.fields import ModelField
22+
23+
from fastapi_jsonapi.data_typing import TypeSchema
24+
from fastapi_jsonapi.schema import JSONAPIObjectSchema
25+
from fastapi_jsonapi.schema_builder import (
26+
JSONAPIResultDetailSchema,
27+
JSONAPIResultListSchema,
28+
)
29+
30+
if TYPE_CHECKING:
31+
from fastapi_jsonapi.api import RoutersJSONAPI
32+
from fastapi_jsonapi.querystring import QueryStringManager
33+
34+
35+
JSONAPIResponse = Union[JSONAPIResultDetailSchema, JSONAPIResultListSchema]
36+
IGNORE_ALL_FIELDS_LITERAL = ""
737

838

939
class HTTPMethod(Enum):
@@ -30,10 +60,97 @@ def handler(self) -> Optional[Union[Callable, Coroutine]]:
3060
return self.prepare_data_layer_kwargs
3161

3262

33-
def get_includes_indexes_by_type(included: List[Dict]) -> Dict[str, List[int]]:
63+
def _get_includes_indexes_by_type(included: List[JSONAPIObjectSchema]) -> Dict[str, List[int]]:
3464
result = defaultdict(list)
3565

36-
for idx, item in enumerate(included, 1):
37-
result[item["type"]].append(idx)
66+
for idx, item in enumerate(included):
67+
result[item.type].append(idx)
3868

3969
return result
70+
71+
72+
# TODO: move to schema builder?
73+
def _is_relationship_field(field: ModelField) -> bool:
74+
return "relationship" in field.field_info.extra
75+
76+
77+
def _get_schema_field_names(schema: Type[TypeSchema]) -> Set[str]:
78+
"""
79+
Returns all attribute names except relationships
80+
"""
81+
result = set()
82+
83+
for field_name, field in schema.__fields__.items():
84+
if _is_relationship_field(field):
85+
continue
86+
87+
result.add(field_name)
88+
89+
return result
90+
91+
92+
def _get_exclude_fields(
93+
schema: Type[TypeSchema],
94+
include_fields: Iterable[str],
95+
) -> Set[str]:
96+
schema_fields = _get_schema_field_names(schema)
97+
98+
if IGNORE_ALL_FIELDS_LITERAL in include_fields:
99+
return schema_fields
100+
101+
return set(_get_schema_field_names(schema)).difference(include_fields)
102+
103+
104+
def _calculate_exclude_fields(
105+
response: JSONAPIResponse,
106+
query_params: QueryStringManager,
107+
jsonapi: RoutersJSONAPI,
108+
) -> Dict:
109+
included = "included" in response.__fields__ and response.included or []
110+
is_list_response = isinstance(response, JSONAPIResultListSchema)
111+
112+
exclude_params: Dict[str, Any] = {}
113+
114+
includes_indexes_by_type = _get_includes_indexes_by_type(included)
115+
116+
for resource_type, field_names in query_params.fields.items():
117+
schema = jsonapi.all_jsonapi_routers[resource_type]._schema
118+
exclude_fields = _get_exclude_fields(schema, include_fields=field_names)
119+
attributes_exclude = {"attributes": exclude_fields}
120+
121+
if resource_type == jsonapi.type_:
122+
if is_list_response:
123+
exclude_params["data"] = {"__all__": attributes_exclude}
124+
else:
125+
exclude_params["data"] = attributes_exclude
126+
127+
continue
128+
129+
if not included:
130+
continue
131+
132+
target_type_indexes = includes_indexes_by_type.get(resource_type)
133+
134+
if target_type_indexes:
135+
if "included" not in exclude_params:
136+
exclude_params["included"] = {}
137+
138+
exclude_params["included"].update((idx, attributes_exclude) for idx in target_type_indexes)
139+
140+
return exclude_params
141+
142+
143+
def handle_fields(
144+
response: JSONAPIResponse,
145+
query_params: QueryStringManager,
146+
jsonapi: RoutersJSONAPI,
147+
) -> Union[JSONAPIResponse, Dict]:
148+
if not query_params.fields:
149+
return response
150+
151+
exclude_params = _calculate_exclude_fields(response, query_params, jsonapi)
152+
153+
if exclude_params:
154+
return response.dict(exclude=exclude_params)
155+
156+
return response

0 commit comments

Comments
 (0)