1313)
1414
1515from ..error import GraphQLError
16- from ..language import ast , OperationType
16+ from ..language import OperationType , ast
1717from ..pyutils import inspect , is_collection , is_description
1818from .definition import (
1919 GraphQLAbstractType ,
20- GraphQLInterfaceType ,
2120 GraphQLInputObjectType ,
21+ GraphQLInputType ,
22+ GraphQLInterfaceType ,
2223 GraphQLNamedType ,
2324 GraphQLObjectType ,
24- GraphQLUnionType ,
2525 GraphQLType ,
26+ GraphQLUnionType ,
2627 GraphQLWrappingType ,
2728 get_named_type ,
2829 is_input_object_type ,
3132 is_union_type ,
3233 is_wrapping_type ,
3334)
34- from .directives import GraphQLDirective , specified_directives , is_directive
35+ from .directives import GraphQLDirective , is_directive , specified_directives
3536from .introspection import introspection_types
3637
3738try :
@@ -310,8 +311,8 @@ def __copy__(self) -> "GraphQLSchema": # pragma: no cover
310311 def __deepcopy__ (self , memo_ : Dict ) -> "GraphQLSchema" :
311312 from ..type import (
312313 is_introspection_type ,
313- is_specified_scalar_type ,
314314 is_specified_directive ,
315+ is_specified_scalar_type ,
315316 )
316317
317318 type_map : TypeMap = {
@@ -326,6 +327,8 @@ def __deepcopy__(self, memo_: Dict) -> "GraphQLSchema":
326327 directive if is_specified_directive (directive ) else copy (directive )
327328 for directive in self .directives
328329 ]
330+ for directive in directives :
331+ remap_directive (directive , type_map )
329332 return self .__class__ (
330333 self .query_type and cast (GraphQLObjectType , type_map [self .query_type .name ]),
331334 self .mutation_type
@@ -461,12 +464,7 @@ def remapped_type(type_: GraphQLType, type_map: TypeMap) -> GraphQLType:
461464
462465def remap_named_type (type_ : GraphQLNamedType , type_map : TypeMap ) -> None :
463466 """Change all references in the given named type to use this type map."""
464- if is_union_type (type_ ):
465- type_ = cast (GraphQLUnionType , type_ )
466- type_ .types = [
467- type_map .get (member_type .name , member_type ) for member_type in type_ .types
468- ]
469- elif is_object_type (type_ ) or is_interface_type (type_ ):
467+ if is_object_type (type_ ) or is_interface_type (type_ ):
470468 type_ = cast (Union [GraphQLObjectType , GraphQLInterfaceType ], type_ )
471469 type_ .interfaces = [
472470 type_map .get (interface_type .name , interface_type )
@@ -482,10 +480,23 @@ def remap_named_type(type_: GraphQLNamedType, type_map: TypeMap) -> None:
482480 arg .type = remapped_type (arg .type , type_map )
483481 args [arg_name ] = arg
484482 fields [field_name ] = field
483+ elif is_union_type (type_ ):
484+ type_ = cast (GraphQLUnionType , type_ )
485+ type_ .types = [
486+ type_map .get (member_type .name , member_type ) for member_type in type_ .types
487+ ]
485488 elif is_input_object_type (type_ ):
486489 type_ = cast (GraphQLInputObjectType , type_ )
487490 fields = type_ .fields
488491 for field_name , field in fields .items ():
489492 field = copy (field )
490493 field .type = remapped_type (field .type , type_map )
491494 fields [field_name ] = field
495+
496+ def remap_directive (directive : GraphQLDirective , type_map : TypeMap ) -> None :
497+ """Change all references in the given directive to use this type map."""
498+ args = directive .args
499+ for arg_name , arg in args .items ():
500+ arg = copy (arg ) # noqa: PLW2901
501+ arg .type = cast (GraphQLInputType , remapped_type (arg .type , type_map ))
502+ args [arg_name ] = arg
0 commit comments