26
26
from fastapi_jsonapi .data_typing import TypeModel , TypeSchema
27
27
from fastapi_jsonapi .exceptions import InvalidFilters , InvalidType
28
28
from fastapi_jsonapi .exceptions .json_api import HTTPException
29
- from fastapi_jsonapi .schema import get_model_field , get_relationships
29
+ from fastapi_jsonapi .schema import JSONAPISchemaIntrospectionError , get_model_field , get_relationships
30
30
31
31
log = logging .getLogger (__name__ )
32
32
@@ -44,7 +44,7 @@ class RelationshipFilteringInfo(BaseModel):
44
44
target_schema : Type [TypeSchema ]
45
45
model : Type [TypeModel ]
46
46
aliased_model : AliasedClass
47
- column : InstrumentedAttribute
47
+ join_column : InstrumentedAttribute
48
48
49
49
class Config :
50
50
arbitrary_types_allowed = True
@@ -288,7 +288,10 @@ def get_model_column(
288
288
schema : Type [TypeSchema ],
289
289
field_name : str ,
290
290
) -> InstrumentedAttribute :
291
- model_field = get_model_field (schema , field_name )
291
+ try :
292
+ model_field = get_model_field (schema , field_name )
293
+ except JSONAPISchemaIntrospectionError as e :
294
+ raise InvalidFilters (str (e ))
292
295
293
296
try :
294
297
return getattr (model , model_field )
@@ -327,8 +330,9 @@ def gather_relationships_info(
327
330
model : Type [TypeModel ],
328
331
schema : Type [TypeSchema ],
329
332
relationship_path : List [str ],
330
- collected_info : dict ,
333
+ collected_info : dict [ RelationshipPath , RelationshipFilteringInfo ] ,
331
334
target_relationship_idx : int = 0 ,
335
+ prev_aliased_model : Optional [Any ] = None ,
332
336
) -> dict [RelationshipPath , RelationshipFilteringInfo ]:
333
337
is_last_relationship = target_relationship_idx == len (relationship_path ) - 1
334
338
target_relationship_path = RELATIONSHIP_SPLITTER .join (
@@ -342,25 +346,36 @@ def gather_relationships_info(
342
346
343
347
target_schema = schema .__fields__ [target_relationship_name ].type_
344
348
target_model = getattr (model , target_relationship_name ).property .mapper .class_
345
- target_column = get_model_column (
346
- model ,
347
- schema ,
348
- target_relationship_name ,
349
- )
349
+
350
+ if prev_aliased_model :
351
+ join_column = get_model_column (
352
+ model = prev_aliased_model ,
353
+ schema = schema ,
354
+ field_name = target_relationship_name ,
355
+ )
356
+ else :
357
+ join_column = get_model_column (
358
+ model ,
359
+ schema ,
360
+ target_relationship_name ,
361
+ )
362
+
363
+ aliased_model = aliased (target_model )
350
364
collected_info [target_relationship_path ] = RelationshipFilteringInfo (
351
365
target_schema = target_schema ,
352
366
model = target_model ,
353
- aliased_model = aliased ( target_model ) ,
354
- column = target_column ,
367
+ aliased_model = aliased_model ,
368
+ join_column = join_column ,
355
369
)
356
370
357
371
if not is_last_relationship :
358
372
return gather_relationships_info (
359
- target_model ,
360
- target_schema ,
361
- relationship_path ,
362
- collected_info ,
363
- target_relationship_idx + 1 ,
373
+ model = target_model ,
374
+ schema = target_schema ,
375
+ relationship_path = relationship_path ,
376
+ collected_info = collected_info ,
377
+ target_relationship_idx = target_relationship_idx + 1 ,
378
+ prev_aliased_model = aliased_model ,
364
379
)
365
380
366
381
return collected_info
@@ -553,5 +568,5 @@ def create_filters_and_joins(
553
568
target_schema = schema ,
554
569
relationships_info = relationships_info ,
555
570
)
556
- joins = [(info .aliased_model , info .column ) for info in relationships_info .values ()]
571
+ joins = [(info .aliased_model , info .join_column ) for info in relationships_info .values ()]
557
572
return expressions , joins
0 commit comments