1
1
"""Helper to deal with querystring parameters according to jsonapi specification."""
2
+ from collections import defaultdict
2
3
from functools import cached_property
3
4
from typing import (
4
5
TYPE_CHECKING ,
7
8
List ,
8
9
Optional ,
9
10
Type ,
10
- Union ,
11
11
)
12
12
from urllib .parse import unquote
13
13
22
22
)
23
23
from starlette .datastructures import QueryParams
24
24
25
+ from fastapi_jsonapi .api import RoutersJSONAPI
25
26
from fastapi_jsonapi .exceptions import (
26
27
BadRequest ,
27
28
InvalidField ,
28
29
InvalidFilters ,
29
30
InvalidInclude ,
30
31
InvalidSort ,
32
+ InvalidType ,
31
33
)
32
34
from fastapi_jsonapi .schema import (
33
35
get_model_field ,
34
36
get_relationships ,
35
- get_schema_from_type ,
36
37
)
37
38
from fastapi_jsonapi .splitter import SPLIT_REL
38
39
@@ -89,33 +90,45 @@ def __init__(self, request: Request) -> None:
89
90
self .MAX_INCLUDE_DEPTH : int = self .config .get ("MAX_INCLUDE_DEPTH" , 3 )
90
91
self .headers : HeadersQueryStringManager = HeadersQueryStringManager (** dict (self .request .headers ))
91
92
92
- def _get_key_values (self , name : str ) -> Dict [str , Union [List [str ], str ]]:
93
+ def _extract_item_key (self , key : str ) -> str :
94
+ try :
95
+ key_start = key .index ("[" ) + 1
96
+ key_end = key .index ("]" )
97
+ return key [key_start :key_end ]
98
+ except Exception :
99
+ msg = "Parse error"
100
+ raise BadRequest (msg , parameter = key )
101
+
102
+ def _get_unique_key_values (self , name : str ) -> Dict [str , str ]:
93
103
"""
94
104
Return a dict containing key / values items for a given key, used for items like filters, page, etc.
95
105
96
106
:param name: name of the querystring parameter
97
107
:return: a dict of key / values items
98
108
:raises BadRequest: if an error occurred while parsing the querystring.
99
109
"""
100
- results : Dict [ str , Union [ List [ str ], str ]] = {}
110
+ results = {}
101
111
102
112
for raw_key , value in self .qs .multi_items ():
103
113
key = unquote (raw_key )
104
- try :
105
- if not key .startswith (name ):
106
- continue
114
+ if not key .startswith (name ):
115
+ continue
107
116
108
- key_start = key .index ("[" ) + 1
109
- key_end = key .index ("]" )
110
- item_key = key [key_start :key_end ]
117
+ item_key = self ._extract_item_key (key )
118
+ results [item_key ] = value
111
119
112
- if "," in value :
113
- results .update ({item_key : value .split ("," )})
114
- else :
115
- results .update ({item_key : value })
116
- except Exception :
117
- msg = "Parse error"
118
- raise BadRequest (msg , parameter = key )
120
+ return results
121
+
122
+ def _get_multiple_key_values (self , name : str ) -> Dict [str , List ]:
123
+ results = defaultdict (list )
124
+
125
+ for raw_key , value in self .qs .multi_items ():
126
+ key = unquote (raw_key )
127
+ if not key .startswith (name ):
128
+ continue
129
+
130
+ item_key = self ._extract_item_key (key )
131
+ results [item_key ].extend (value .split ("," ))
119
132
120
133
return results
121
134
@@ -134,7 +147,7 @@ def querystring(self) -> Dict[str, str]:
134
147
return {
135
148
key : value
136
149
for (key , value ) in self .qs .multi_items ()
137
- if key .startswith (self .managed_keys ) or self ._get_key_values ("filter[" )
150
+ if key .startswith (self .managed_keys ) or self ._get_unique_key_values ("filter[" )
138
151
}
139
152
140
153
@property
@@ -159,8 +172,8 @@ def filters(self) -> List[dict]:
159
172
raise InvalidFilters (msg )
160
173
161
174
results .extend (loaded_filters )
162
- if self ._get_key_values ("filter[" ):
163
- results .extend (self ._simple_filters (self . _get_key_values ( "filter[" ) ))
175
+ if filter_key_values := self ._get_unique_key_values ("filter[" ):
176
+ results .extend (self ._simple_filters (filter_key_values ))
164
177
return results
165
178
166
179
@cached_property
@@ -186,7 +199,7 @@ def pagination(self) -> PaginationQueryStringManager:
186
199
:raises BadRequest: if the client is not allowed to disable pagination.
187
200
"""
188
201
# check values type
189
- pagination_data : Dict [str , Union [ List [ str ], str ]] = self ._get_key_values ("page" )
202
+ pagination_data : Dict [str , str ] = self ._get_unique_key_values ("page" )
190
203
pagination = PaginationQueryStringManager (** pagination_data )
191
204
if pagination_data .get ("size" ) is None :
192
205
pagination .size = None
@@ -199,8 +212,6 @@ def pagination(self) -> PaginationQueryStringManager:
199
212
200
213
return pagination
201
214
202
- # TODO: finally use this! upgrade Sqlachemy Data Layer
203
- # and add to all views (get list/detail, create, patch)
204
215
@property
205
216
def fields (self ) -> Dict [str , List [str ]]:
206
217
"""
@@ -216,26 +227,32 @@ def fields(self) -> Dict[str, List[str]]:
216
227
217
228
:raises InvalidField: if result field not in schema.
218
229
"""
219
- if self .request .method != "GET" :
220
- msg = "attribute 'fields' allowed only for GET-method"
221
- raise InvalidField (msg )
222
- fields = self ._get_key_values ("fields" )
223
- for key , value in fields .items ():
224
- if not isinstance (value , list ):
225
- value = [value ] # noqa: PLW2901
226
- fields [key ] = value
230
+ fields = self ._get_multiple_key_values ("fields" )
231
+ for resource_type , field_names in fields .items ():
227
232
# TODO: we have registry for models (BaseModel)
228
233
# TODO: create `type to schemas` registry
229
- schema : Type [BaseModel ] = get_schema_from_type (key , self .app )
230
- for field in value :
231
- if field not in schema .__fields__ :
234
+
235
+ if resource_type not in RoutersJSONAPI .all_jsonapi_routers :
236
+ msg = f"Application has no resource with type { resource_type !r} "
237
+ raise InvalidType (msg )
238
+
239
+ schema : Type [BaseModel ] = self ._get_schema (resource_type )
240
+
241
+ for field_name in field_names :
242
+ if field_name == "" :
243
+ continue
244
+
245
+ if field_name not in schema .__fields__ :
232
246
msg = "{schema} has no attribute {field}" .format (
233
247
schema = schema .__name__ ,
234
- field = field ,
248
+ field = field_name ,
235
249
)
236
250
raise InvalidField (msg )
237
251
238
- return fields
252
+ return {resource_type : set (field_names ) for resource_type , field_names in fields .items ()}
253
+
254
+ def _get_schema (self , resource_type : str ) -> Type [BaseModel ]:
255
+ return RoutersJSONAPI .all_jsonapi_routers [resource_type ]._schema
239
256
240
257
def get_sorts (self , schema : Type ["TypeSchema" ]) -> List [Dict [str , str ]]:
241
258
"""
0 commit comments