forked from mts-ai/FastAPI-JSONAPI
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathquerystring.py
311 lines (249 loc) · 10.3 KB
/
querystring.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
"""Helper to deal with querystring parameters according to jsonapi specification."""
from collections import defaultdict
from functools import cached_property
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Type,
)
from urllib.parse import unquote
import simplejson as json
from fastapi import (
FastAPI,
Request,
)
from pydantic import (
BaseModel,
Field,
)
from starlette.datastructures import QueryParams
from fastapi_jsonapi.api import RoutersJSONAPI
from fastapi_jsonapi.exceptions import (
BadRequest,
InvalidField,
InvalidFilters,
InvalidInclude,
InvalidSort,
InvalidType,
)
from fastapi_jsonapi.schema import (
get_model_field,
get_relationships,
)
from fastapi_jsonapi.splitter import SPLIT_REL
if TYPE_CHECKING:
from fastapi_jsonapi.data_typing import TypeSchema
class PaginationQueryStringManager(BaseModel):
    """
    Pagination parameters of a request.

    Groups the offset, size, number and limit values that describe a
    paginated query.
    """

    # Offset of the query; unset unless supplied.
    offset: Optional[int] = Field(default=None)
    # Page size; falls back to 25 when the model is built without one.
    size: Optional[int] = Field(default=25)
    # Page number; falls back to 1.
    number: int = Field(default=1)
    # Limit of the query; unset unless supplied.
    limit: Optional[int] = Field(default=None)
class HeadersQueryStringManager(BaseModel):
    """
    Request headers exposed as a model.

    Every header is optional. Header names containing a hyphen are mapped
    onto their underscore attribute through a field alias.
    """

    host: Optional[str] = Field(default=None)
    connection: Optional[str] = Field(default=None)
    accept: Optional[str] = Field(default=None)
    user_agent: Optional[str] = Field(default=None, alias="user-agent")
    referer: Optional[str] = Field(default=None)
    accept_encoding: Optional[str] = Field(default=None, alias="accept-encoding")
    accept_language: Optional[str] = Field(default=None, alias="accept-language")
class QueryStringManager:
    """Querystring parser according to jsonapi reference."""

    # Prefixes of the query parameters this manager understands.
    managed_keys = ("filter", "page", "fields", "sort", "include", "q")

    def __init__(self, request: Request) -> None:
        """
        Initialize instance.

        Reads the query parameters and headers from the request and the limits
        (``ALLOW_DISABLE_PAGINATION``, ``MAX_PAGE_SIZE``, ``MAX_INCLUDE_DEPTH``)
        from the application's ``config`` attribute, when it has one.

        :param request: incoming request to parse.
        """
        self.request: Request = request
        self.app: FastAPI = request.app
        self.qs: QueryParams = request.query_params
        # The application may not define `config`; fall back to built-in defaults.
        self.config: Dict[str, Any] = getattr(self.app, "config", {})
        self.ALLOW_DISABLE_PAGINATION: bool = self.config.get("ALLOW_DISABLE_PAGINATION", True)
        self.MAX_PAGE_SIZE: int = self.config.get("MAX_PAGE_SIZE", 10000)
        self.MAX_INCLUDE_DEPTH: int = self.config.get("MAX_INCLUDE_DEPTH", 3)
        self.headers: HeadersQueryStringManager = HeadersQueryStringManager(**dict(self.request.headers))

    def _extract_item_key(self, key: str) -> str:
        """
        Return the text enclosed in the first pair of square brackets of ``key``.

        :param key: querystring key such as ``page[size]``.
        :return: the bracketed part, e.g. ``size``.
        :raises BadRequest: if the key has no ``[`` or no ``]`` after it.
        """
        try:
            key_start = key.index("[") + 1
            # Look for the closing bracket only after the opening one, so a
            # stray leading "]" cannot produce a silently empty item key.
            key_end = key.index("]", key_start)
        except ValueError:
            msg = "Parse error"
            raise BadRequest(msg, parameter=key) from None
        return key[key_start:key_end]

    def _get_unique_key_values(self, name: str) -> Dict[str, str]:
        """
        Return a dict containing key / values items for a given key, used for items like filters, page, etc.

        When the same item key occurs several times, the last value wins.

        :param name: name (prefix) of the querystring parameter
        :return: a dict of key / values items
        :raises BadRequest: if an error occurred while parsing the querystring.
        """
        results = {}
        for raw_key, value in self.qs.multi_items():
            key = unquote(raw_key)
            if not key.startswith(name):
                continue
            results[self._extract_item_key(key)] = value
        return results

    def _get_multiple_key_values(self, name: str) -> Dict[str, List]:
        """
        Return a dict of item key / list of values for a given key, used for items like fields.

        Each value is split on commas, and values of repeated item keys are
        accumulated into a single list.

        :param name: name (prefix) of the querystring parameter
        :return: a dict mapping every item key to the list of its values
        :raises BadRequest: if an error occurred while parsing the querystring.
        """
        results = defaultdict(list)
        for raw_key, value in self.qs.multi_items():
            key = unquote(raw_key)
            if not key.startswith(name):
                continue
            results[self._extract_item_key(key)].extend(value.split(","))
        return results

    @classmethod
    def _simple_filters(cls, dict_: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Build equality filters from a mapping.

        :param dict_: mapping of field name to expected value.
        :return: one ``{"name": ..., "op": "eq", "val": ...}`` dict per item.
        """
        return [{"name": key, "op": "eq", "val": value} for (key, value) in dict_.items()]

    @property
    def querystring(self) -> Dict[str, str]:
        """
        Return original querystring but containing only managed keys.

        A key is kept when it starts with one of ``managed_keys``; this already
        covers ``filter[...]`` parameters. For repeated keys the last value is
        kept.

        :return: dict of managed querystring parameter
        """
        return {key: value for (key, value) in self.qs.multi_items() if key.startswith(self.managed_keys)}

    @property
    def filters(self) -> List[dict]:
        """
        Return filters from query string.

        Filters come from two sources, in this order: the JSON list passed as
        ``filter`` and the simple ``filter[name]=value`` parameters, which are
        turned into equality filters.

        :return: filter information
        :raises InvalidFilters: if filter loading from json has failed or the
            loaded value is not a list.
        """
        results = []
        filters = self.qs.get("filter")
        if filters is not None:
            try:
                loaded_filters = json.loads(filters)
            except (ValueError, TypeError) as err:
                msg = "Parse error"
                raise InvalidFilters(msg) from err

            if not isinstance(loaded_filters, list):
                msg = f"Incorrect filters format, expected list of conditions but got {type(loaded_filters).__name__}"
                raise InvalidFilters(msg)

            results.extend(loaded_filters)

        if filter_key_values := self._get_unique_key_values("filter["):
            results.extend(self._simple_filters(filter_key_values))

        return results

    @cached_property
    def pagination(self) -> PaginationQueryStringManager:
        """
        Return all page parameters as a pagination model.

        To allow multiples strategies, all parameters starting with `page` are
        passed to the model. Example with number strategy::

            query_string = {'page[number]': '25', 'page[size]': '10'}
            parsed_query.pagination
            PaginationQueryStringManager(offset=None, size=10, number=25, limit=None)

        When ``page[size]`` is absent the resulting ``size`` is ``None``; a size
        above ``MAX_PAGE_SIZE`` is clamped to that maximum.

        :return: pagination information.
        :raises BadRequest: if the client is not allowed to disable pagination.
        """
        pagination_data: Dict[str, str] = self._get_unique_key_values("page")
        pagination = PaginationQueryStringManager(**pagination_data)
        if pagination_data.get("size") is None:
            # No explicit page[size]: discard the model's default size.
            pagination.size = None

        size = pagination.size
        # Compare against None (not truthiness) so that an explicit size of 0
        # actually reaches the "disable pagination" check below.
        if size is not None:
            if self.ALLOW_DISABLE_PAGINATION is False and size == 0:
                msg = "You are not allowed to disable pagination"
                raise BadRequest(msg, parameter="page[size]")
            if self.MAX_PAGE_SIZE and size > self.MAX_PAGE_SIZE:
                pagination.size = self.MAX_PAGE_SIZE
        return pagination

    @property
    def fields(self) -> Dict[str, List[str]]:
        """
        Return fields wanted by client.

        Return value is a dict containing all fields by resource; the field
        names of each resource are collected into a set, for example::

            {
                "user": {'name', 'email'},
            }

        Empty field names are not validated against the schema but are kept in
        the result.

        :return: a dict of sparse fieldsets information
        :raises InvalidType: if the application has no resource of that type.
        :raises InvalidField: if result field not in schema.
        """
        fields = self._get_multiple_key_values("fields")
        for resource_type, field_names in fields.items():
            # TODO: we have registry for models (BaseModel)
            # TODO: create `type to schemas` registry
            if resource_type not in RoutersJSONAPI.all_jsonapi_routers:
                msg = f"Application has no resource with type {resource_type!r}"
                raise InvalidType(msg)

            schema: Type[BaseModel] = self._get_schema(resource_type)
            for field_name in field_names:
                if field_name == "":
                    continue
                if field_name not in schema.__fields__:
                    msg = f"{schema.__name__} has no attribute {field_name}"
                    raise InvalidField(msg)

        return {resource_type: set(field_names) for resource_type, field_names in fields.items()}

    def _get_schema(self, resource_type: str) -> Type[BaseModel]:
        """
        Return the schema registered for a resource type.

        :param resource_type: key of the router in ``RoutersJSONAPI.all_jsonapi_routers``.
        :return: the ``_schema`` attribute of that router.
        """
        return RoutersJSONAPI.all_jsonapi_routers[resource_type]._schema

    def get_sorts(self, schema: Type["TypeSchema"]) -> List[Dict[str, str]]:
        """
        Return fields to sort by including sort name for SQLAlchemy and row sort parameter for other ORMs.

        A leading ``-`` on a sort field requests descending order. Fields that
        contain the relationship separator are passed through without being
        checked against ``schema``.

        Example of return value::

            [
                {'field': 'created_at', 'order': 'desc'},
            ]

        :param schema: schema the plain sort fields are validated against.
        :return: a list of sorting information
        :raises InvalidSort: if sort field wrong.
        """
        sort_q = self.qs.get("sort")
        if not sort_q:
            return []

        sorting_results = []
        for sort_field in sort_q.split(","):
            descending = sort_field.startswith("-")
            # Strip only the leading direction marker; hyphens inside the
            # field name belong to the name itself.
            field = sort_field[1:] if descending else sort_field
            if SPLIT_REL not in field:
                if field not in schema.__fields__:
                    msg = f"{schema.__name__} has no attribute {field}"
                    raise InvalidSort(msg)
                if field in get_relationships(schema):
                    msg = f"You can't sort on {field} because it is a relationship field"
                    raise InvalidSort(msg)
                field = get_model_field(schema, field)
            order = "desc" if descending else "asc"
            sorting_results.append({"field": field, "order": order})
        return sorting_results

    @property
    def include(self) -> List[str]:
        """
        Return fields to include.

        The ``include`` parameter is split on commas; each path is then checked
        against ``MAX_INCLUDE_DEPTH`` unless that limit is ``None``.

        :return: a list of include information.
        :raises InvalidInclude: if nesting is more than MAX_INCLUDE_DEPTH.
        """
        include_param: Optional[str] = self.qs.get("include")
        includes = include_param.split(",") if include_param and isinstance(include_param, str) else []

        if self.MAX_INCLUDE_DEPTH is not None:
            for include_path in includes:
                if len(include_path.split(SPLIT_REL)) > self.MAX_INCLUDE_DEPTH:
                    msg = f"You can't use include through more than {self.MAX_INCLUDE_DEPTH} relationships"
                    raise InvalidInclude(msg)

        return includes