@@ -297,7 +297,7 @@ def __getattr__(self, name: str) -> "DSLType":
297297
298298 assert isinstance (type_def , (GraphQLObjectType , GraphQLInterfaceType ))
299299
300- return DSLType (type_def )
300+ return DSLType (type_def , self )
301301
302302
303303class DSLSelector (ABC ):
@@ -454,7 +454,27 @@ def is_valid_field(self, field: "DSLSelectable") -> bool:
454454 return operation_name != "SUBSCRIPTION"
455455
456456 elif isinstance (field , DSLField ):
457- return field .parent_type .name .upper () == operation_name
457+
458+ assert field .dsl_type is not None
459+
460+ schema = field .dsl_type ._dsl_schema ._schema
461+
462+ root_type = None
463+
464+ if operation_name == "QUERY" :
465+ root_type = schema .query_type
466+ elif operation_name == "MUTATION" :
467+ root_type = schema .mutation_type
468+ elif operation_name == "SUBSCRIPTION" :
469+ root_type = schema .subscription_type
470+
471+ if root_type is None :
472+ log .error (
473+ f"Root type of type { operation_name } not found in the schema!"
474+ )
475+ return False
476+
477+ return field .parent_type .name == root_type .name
458478
459479 return False
460480
@@ -585,16 +605,22 @@ class DSLType:
585605 instances of :class:`DSLField`
586606 """
587607
588- def __init__ (self , graphql_type : Union [GraphQLObjectType , GraphQLInterfaceType ]):
608+ def __init__ (
609+ self ,
610+ graphql_type : Union [GraphQLObjectType , GraphQLInterfaceType ],
611+ dsl_schema : DSLSchema ,
612+ ):
589613 """Initialize the DSLType with the GraphQL type.
590614
591615 .. warning::
592616 Don't instantiate this class yourself.
593617 Use attributes of the :class:`DSLSchema` instead.
594618
595619 :param graphql_type: the GraphQL type definition from the schema
620+ :param dsl_schema: reference to the DSLSchema which created this type
596621 """
597622 self ._type : Union [GraphQLObjectType , GraphQLInterfaceType ] = graphql_type
623+ self ._dsl_schema = dsl_schema
598624 log .debug (f"Creating { self !r} )" )
599625
600626 def __getattr__ (self , name : str ) -> "DSLField" :
@@ -611,7 +637,7 @@ def __getattr__(self, name: str) -> "DSLField":
611637 f"Field { name } does not exist in type { self ._type .name } ."
612638 )
613639
614- return DSLField (formatted_name , self ._type , field )
640+ return DSLField (formatted_name , self ._type , field , self )
615641
616642 def __repr__ (self ) -> str :
617643 return f"<{ self .__class__ .__name__ } { self ._type !r} >"
@@ -763,6 +789,7 @@ def __init__(
763789 name : str ,
764790 parent_type : Union [GraphQLObjectType , GraphQLInterfaceType ],
765791 field : GraphQLField ,
792+ dsl_type : Optional [DSLType ] = None ,
766793 ):
767794 """Initialize the DSLField.
768795
@@ -774,10 +801,12 @@ def __init__(
774801 :param parent_type: the GraphQL type definition from the schema of the
775802 parent type of the field
776803 :param field: the GraphQL field definition from the schema
804+ :param dsl_type: reference of the DSLType instance which created this field
777805 """
778806 self .parent_type = parent_type
779807 self .field = field
780808 self .ast_field = FieldNode (name = NameNode (value = name ), arguments = ())
809+ self .dsl_type = dsl_type
781810
782811 log .debug (f"Creating { self !r} " )
783812
0 commit comments