8
8
List ,
9
9
Optional ,
10
10
Type ,
11
- Union ,
12
11
)
13
12
from urllib .parse import unquote
14
13
26
25
from fastapi_jsonapi .api import RoutersJSONAPI
27
26
from fastapi_jsonapi .exceptions import (
28
27
BadRequest ,
28
+ InvalidField ,
29
29
InvalidFilters ,
30
30
InvalidInclude ,
31
31
InvalidSort ,
32
+ InvalidType ,
32
33
)
33
34
from fastapi_jsonapi .schema import (
34
35
get_model_field ,
@@ -89,30 +90,45 @@ def __init__(self, request: Request) -> None:
89
90
self .MAX_INCLUDE_DEPTH : int = self .config .get ("MAX_INCLUDE_DEPTH" , 3 )
90
91
self .headers : HeadersQueryStringManager = HeadersQueryStringManager (** dict (self .request .headers ))
91
92
92
- def _get_key_values (self , name : str ) -> Dict [str , Union [List [str ], str ]]:
93
+ def _extract_item_key (self , key : str ) -> str :
94
+ try :
95
+ key_start = key .index ("[" ) + 1
96
+ key_end = key .index ("]" )
97
+ return key [key_start :key_end ]
98
+ except Exception :
99
+ msg = "Parse error"
100
+ raise BadRequest (msg , parameter = key )
101
+
102
+ def _get_unique_key_values (self , name : str ) -> Dict [str , str ]:
93
103
"""
94
104
Return a dict containing key / values items for a given key, used for items like filters, page, etc.
95
105
96
106
:param name: name of the querystring parameter
97
107
:return: a dict of key / values items
98
108
:raises BadRequest: if an error occurred while parsing the querystring.
99
109
"""
100
- results = defaultdict ( set )
110
+ results = {}
101
111
102
112
for raw_key , value in self .qs .multi_items ():
103
113
key = unquote (raw_key )
104
- try :
105
- if not key .startswith (name ):
106
- continue
114
+ if not key .startswith (name ):
115
+ continue
107
116
108
- key_start = key .index ("[" ) + 1
109
- key_end = key .index ("]" )
110
- item_key = key [key_start :key_end ]
117
+ item_key = self ._extract_item_key (key )
118
+ results [item_key ] = value
111
119
112
- results [item_key ].update (value .split ("," ))
113
- except Exception :
114
- msg = "Parse error"
115
- raise BadRequest (msg , parameter = key )
120
+ return results
121
+
122
+ def _get_multiple_key_values (self , name : str ) -> Dict [str , List ]:
123
+ results = defaultdict (list )
124
+
125
+ for raw_key , value in self .qs .multi_items ():
126
+ key = unquote (raw_key )
127
+ if not key .startswith (name ):
128
+ continue
129
+
130
+ item_key = self ._extract_item_key (key )
131
+ results [item_key ].extend (value .split ("," ))
116
132
117
133
return results
118
134
@@ -131,7 +147,7 @@ def querystring(self) -> Dict[str, str]:
131
147
return {
132
148
key : value
133
149
for (key , value ) in self .qs .multi_items ()
134
- if key .startswith (self .managed_keys ) or self ._get_key_values ("filter[" )
150
+ if key .startswith (self .managed_keys ) or self ._get_unique_key_values ("filter[" )
135
151
}
136
152
137
153
@property
@@ -156,8 +172,8 @@ def filters(self) -> List[dict]:
156
172
raise InvalidFilters (msg )
157
173
158
174
results .extend (loaded_filters )
159
- if self ._get_key_values ("filter[" ):
160
- results .extend (self ._simple_filters (self . _get_key_values ( "filter[" ) ))
175
+ if filter_key_values := self ._get_unique_key_values ("filter[" ):
176
+ results .extend (self ._simple_filters (filter_key_values ))
161
177
return results
162
178
163
179
@cached_property
@@ -183,7 +199,7 @@ def pagination(self) -> PaginationQueryStringManager:
183
199
:raises BadRequest: if the client is not allowed to disable pagination.
184
200
"""
185
201
# check values type
186
- pagination_data : Dict [str , Union [ List [ str ], str ]] = self ._get_key_values ("page" )
202
+ pagination_data : Dict [str , str ] = self ._get_unique_key_values ("page" )
187
203
pagination = PaginationQueryStringManager (** pagination_data )
188
204
if pagination_data .get ("size" ) is None :
189
205
pagination .size = None
@@ -213,23 +229,27 @@ def fields(self) -> Dict[str, List[str]]:
213
229
214
230
:raises InvalidField: if result field not in schema.
215
231
"""
216
- fields = self ._get_key_values ("fields" )
232
+ fields = self ._get_multiple_key_values ("fields" )
217
233
for resource_type , field_names in fields .items ():
218
234
# TODO: we have registry for models (BaseModel)
219
235
# TODO: create `type to schemas` registry
220
236
221
- # schema: Type[BaseModel] = get_schema_from_type(key, self.app)
237
+ if resource_type not in RoutersJSONAPI .all_jsonapi_routers :
238
+ msg = f"Application has no resource with type { resource_type !r} "
239
+ raise InvalidType (msg )
240
+
241
+ schema : Type [BaseModel ] = RoutersJSONAPI .all_jsonapi_routers [resource_type ]._schema
222
242
self ._get_schema (resource_type )
223
243
224
- # for field_name in field_names:
225
- # if field_name not in schema.__fields__:
226
- # msg = "{schema} has no attribute {field}".format(
227
- # schema=schema.__name__,
228
- # field=field_name,
229
- # )
230
- # raise InvalidField(msg)
244
+ for field_name in field_names :
245
+ if field_name not in schema .__fields__ :
246
+ msg = "{schema} has no attribute {field}" .format (
247
+ schema = schema .__name__ ,
248
+ field = field_name ,
249
+ )
250
+ raise InvalidField (msg )
231
251
232
- return fields
252
+ return { resource_type : set ( field_names ) for resource_type , field_names in fields . items ()}
233
253
234
254
def _get_schema (self , resource_type : str ) -> Type [BaseModel ]:
235
255
target_router = RoutersJSONAPI .all_jsonapi_routers [resource_type ]
0 commit comments