diff --git a/README.md b/README.md index 29a8d4a27..3821acd82 100644 --- a/README.md +++ b/README.md @@ -97,6 +97,50 @@ func (r *helloWorldResolver) Hello(ctx context.Context) (string, error) { } ``` +### Separate resolvers for different operations + +The GraphQL specification allows for fields with the same name defined in different query types. For example, the schema below is a valid schema definition: +```graphql +schema { + query: Query + mutation: Mutation +} + +type Query { + hello: String! +} + +type Mutation { + hello: String! +} +``` +The above schema would result in name collision if we use a single resolver struct because fields from both operations correspond to methods in the root resolver (the same Go struct). In order to resolve this issue, the library allows resolvers for query, mutation and subscription operations to be separated using the `Query`, `Mutation` and `Subscription` methods of the root resolver. These special methods are optional and if defined return the resolver for each opeartion. For example, the following is a resolver corresponding to the schema definition above. Note that there is a field named `hello` in both the query and the mutation definitions: + +```go +type RootResolver struct{} +type QueryResolver struct{} +type MutationResolver struct{} + +func(r *RootResolver) Query() *QueryResolver { + return &QueryResolver{} +} + +func(r *RootResolver) Mutation() *MutationResolver { + return &MutationResolver{} +} + +func (*QueryResolver) Hello() string { + return "Hello query!" +} + +func (*MutationResolver) Hello() string { + return "Hello mutation!" +} + +schema := graphql.MustParseSchema(sdl, &RootResolver{}, nil) +... +``` + ### Schema Options - `UseStringDescriptions()` enables the usage of double quoted and triple quoted. When this is not enabled, comments are parsed as descriptions instead. diff --git a/example/starwars/starwars.go b/example/starwars/starwars.go index 07cbb9f40..6dd570cef 100644 --- a/example/starwars/starwars.go +++ b/example/starwars/starwars.go @@ -286,14 +286,20 @@ var reviews = make(map[string][]*review) type Resolver struct{} -func (r *Resolver) Hero(args struct{ Episode string }) *characterResolver { +func (*Resolver) Query() *QueryResolver { + return &QueryResolver{} +} + +type QueryResolver struct{} + +func (r *QueryResolver) Hero(args struct{ Episode string }) *characterResolver { if args.Episode == "EMPIRE" { return &characterResolver{&humanResolver{humanData["1000"]}} } return &characterResolver{&droidResolver{droidData["2001"]}} } -func (r *Resolver) Reviews(args struct{ Episode string }) []*reviewResolver { +func (r *QueryResolver) Reviews(args struct{ Episode string }) []*reviewResolver { var l []*reviewResolver for _, review := range reviews[args.Episode] { l = append(l, &reviewResolver{review}) @@ -301,7 +307,7 @@ func (r *Resolver) Reviews(args struct{ Episode string }) []*reviewResolver { return l } -func (r *Resolver) Search(args struct{ Text string }) []*searchResultResolver { +func (r *QueryResolver) Search(args struct{ Text string }) []*searchResultResolver { var l []*searchResultResolver for _, h := range humans { if strings.Contains(h.Name, args.Text) { @@ -321,7 +327,7 @@ func (r *Resolver) Search(args struct{ Text string }) []*searchResultResolver { return l } -func (r *Resolver) Character(args struct{ ID graphql.ID }) *characterResolver { +func (r *QueryResolver) Character(args struct{ ID graphql.ID }) *characterResolver { if h := humanData[args.ID]; h != nil { return &characterResolver{&humanResolver{h}} } @@ -331,28 +337,34 @@ func (r *Resolver) Character(args struct{ ID graphql.ID }) *characterResolver { return nil } -func (r *Resolver) Human(args struct{ ID graphql.ID }) *humanResolver { +func (r *QueryResolver) Human(args struct{ ID graphql.ID }) *humanResolver { if h := humanData[args.ID]; h != nil { return &humanResolver{h} } return nil } -func (r *Resolver) Droid(args struct{ ID graphql.ID }) *droidResolver { +func (r *QueryResolver) Droid(args struct{ ID graphql.ID }) *droidResolver { if d := droidData[args.ID]; d != nil { return &droidResolver{d} } return nil } -func (r *Resolver) Starship(args struct{ ID graphql.ID }) *starshipResolver { +func (r *QueryResolver) Starship(args struct{ ID graphql.ID }) *starshipResolver { if s := starshipData[args.ID]; s != nil { return &starshipResolver{s} } return nil } -func (r *Resolver) CreateReview(args *struct { +func (*Resolver) Mutation() *MutationResolver { + return &MutationResolver{} +} + +type MutationResolver struct{} + +func (r *MutationResolver) CreateReview(args *struct { Episode string Review *reviewInput }) *reviewResolver { diff --git a/graphql.go b/graphql.go index bbe92e087..f134bc59f 100644 --- a/graphql.go +++ b/graphql.go @@ -247,7 +247,7 @@ func (s *Schema) ValidateWithVariables(queryString string, variables map[string] // without a resolver. If the context get cancelled, no further resolvers will be called and a // the context error will be returned as soon as possible (not immediately). func (s *Schema) Exec(ctx context.Context, queryString string, operationName string, variables map[string]interface{}) *Response { - if !s.res.Resolver.IsValid() { + if !s.res.QueryResolver.IsValid() { panic("schema created without resolver, can not exec") } execF := s.exec diff --git a/graphql_test.go b/graphql_test.go index c655d12be..90787d948 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -2,6 +2,7 @@ package graphql_test import ( "context" + "encoding/json" "errors" "fmt" "github.com/graph-gophers/graphql-go/internal/exec/resolvable" @@ -4858,3 +4859,306 @@ func TestQueryService(t *testing.T) { }, }) } + +type RootResolver struct{} +type QueryResolver struct{} +type MutationResolver struct{} +type SubscriptionResolver struct { + err error + upstream <-chan *helloEventResolver +} + +func (r *RootResolver) Query() *QueryResolver { + return &QueryResolver{} +} + +func (r *RootResolver) Mutation() *MutationResolver { + return &MutationResolver{} +} + +type helloEventResolver struct { + msg string + err error +} + +func (r *helloEventResolver) Msg() (string, error) { + return r.msg, r.err +} + +func closedHelloEventUpstream(rr ...*helloEventResolver) <-chan *helloEventResolver { + c := make(chan *helloEventResolver, len(rr)) + for _, r := range rr { + c <- r + } + close(c) + return c +} + +func (r *RootResolver) Subscription() *SubscriptionResolver { + return &SubscriptionResolver{ + upstream: closedHelloEventUpstream( + &helloEventResolver{msg: "Hello subscription!"}, + &helloEventResolver{err: errors.New("resolver error")}, + &helloEventResolver{msg: "Hello again!"}, + ), + } +} + +func (qr *QueryResolver) Hello() string { + return "Hello query!" +} + +func (mr *MutationResolver) Hello() string { + return "Hello mutation!" +} + +func (sr *SubscriptionResolver) Hello(ctx context.Context) (chan *helloEventResolver, error) { + if sr.err != nil { + return nil, sr.err + } + + c := make(chan *helloEventResolver) + go func() { + for r := range sr.upstream { + select { + case <-ctx.Done(): + close(c) + return + case c <- r: + } + } + close(c) + }() + + return c, nil +} + +type errRootResolver1 struct { + RootResolver +} + +// Query is invalid because it doesn't have a return value. +func (*errRootResolver1) Query() {} + +type errRootResolver2 struct { + RootResolver +} + +// Query is invalid because it has more than 1 return value +func (*errRootResolver2) Query() (*QueryResolver, error) { + return nil, nil +} + +type errRootResolver3 struct { + RootResolver +} + +// Mutation is invalid because it returns nil +func (*errRootResolver3) Mutation() *MutationResolver { + return nil +} + +type errRootResolver4 struct { + RootResolver +} + +// Query is invalid because it doesn't return a pointer. +func (*errRootResolver4) Query() MutationResolver { + return MutationResolver{} +} + +type errRootResolver5 struct { + RootResolver +} + +// Query is invalid because it returns *[]int instead of a resolver. +func (*errRootResolver5) Query() *[]int { + return &[]int{1, 2} +} + +type errRootResolver6 struct { + RootResolver +} + +// Mutation is invalid because it returns a map[string]int instead of a resolver. +func (*errRootResolver6) Mutation() map[string]int { + return map[string]int{"key": 3} +} + +type errRootResolver7 struct { + RootResolver +} + +// Subscription is invalid because it returns an invalid resolver. +func (*errRootResolver7) Subscription() interface{} { + a := struct { + Name string + }{Name: "invalid"} + return &a +} + +type errRootResolver8 struct { + RootResolver +} + +// Query is invalid because it accepts arguments. +func (*errRootResolver8) Query(ctx context.Context) *QueryResolver { + return &QueryResolver{} +} + +// TestSeparateResolvers ensures that a field with the same name is allowed in different operations +func TestSeparateResolvers(t *testing.T) { + helloEverywhere := ` + schema { + query: Query + mutation: Mutation + subscription: Subscription + } + + type Query { + hello: String! + } + + type Mutation { + hello: String! + } + + type Subscription { + hello: HelloEvent! + } + + type HelloEvent { + msg: String! + } + ` + + separateSchema := graphql.MustParseSchema(helloEverywhere, &RootResolver{}) + + gqltesting.RunTests(t, []*gqltesting.Test{ + { + Schema: separateSchema, + Query: ` + query { + hello + } + `, + ExpectedResult: ` + { + "hello": "Hello query!" + } + `, + }, + { + Schema: separateSchema, + Query: ` + mutation { + hello + } + `, + ExpectedResult: ` + { + "hello": "Hello mutation!" + } + `, + }, + }) + + gqltesting.RunSubscribes(t, []*gqltesting.TestSubscription{ + { + Name: "ok", + Schema: separateSchema, + Query: ` + subscription { + hello { + msg + } + } + `, + ExpectedResults: []gqltesting.TestResponse{ + { + Data: json.RawMessage(` + { + "hello": { + "msg": "Hello subscription!" + } + } + `), + }, + { + // null propagates all the way up because msg is non-null + Data: json.RawMessage(`null`), + Errors: []*gqlerrors.QueryError{gqlerrors.Errorf("%s", resolverErr)}, + }, + { + Data: json.RawMessage(` + { + "hello": { + "msg": "Hello again!" + } + } + `), + }, + }, + }, + }) + + // test errors with invalid resolvers + tests := []struct { + name string + resolver interface{} + opts []graphql.SchemaOpt + wantErr string + }{ + { + name: "query_method_has_no_return_val", + resolver: &errRootResolver1{}, + wantErr: "method \"Query\" of *graphql_test.errRootResolver1 must have 1 return value, got 0", + }, + { + name: "query_method_returns_too_many_vals", + resolver: &errRootResolver2{}, + wantErr: "method \"Query\" of *graphql_test.errRootResolver2 must have 1 return value, got 2", + }, + { + name: "mutation_method_returns_nil", + resolver: &errRootResolver3{}, + wantErr: "method \"Mutation\" of *graphql_test.errRootResolver3 must return a non-nil result, got ", + }, + { + name: "query_method_does_not_return_a_pointer", + resolver: &errRootResolver4{}, + wantErr: "method \"Query\" of *graphql_test.errRootResolver4 must return an interface or a pointer, got graphql_test.MutationResolver", + }, + { + name: "query_method_returns_invalid_resolver_type", + resolver: &errRootResolver5{}, + wantErr: "*[]int does not resolve \"Query\": missing method for field \"hello\"", + }, + { + name: "mutation_method_returns_invalid_resolver_type", + resolver: &errRootResolver6{}, + wantErr: "method \"Mutation\" of *graphql_test.errRootResolver6 must return an interface or a pointer, got map[string]int", + }, + { + name: "query_subscription_returns_invalid_resolver_type", + resolver: &errRootResolver7{}, + wantErr: "*struct { Name string } does not resolve \"Subscription\": missing method for field \"hello\"", + }, + { + name: "mutation_method_returns_invalid_resolver_type", + resolver: &errRootResolver8{}, + wantErr: "method \"Query\" of *graphql_test.errRootResolver8 must not accept any arguments, got 1", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := graphql.ParseSchema(helloEverywhere, tt.resolver, tt.opts...) + if err == nil { + t.Fatalf("want err: %q, got: ", tt.wantErr) + } + if err.Error() != tt.wantErr { + t.Fatalf("want err: %q, got: %q", tt.wantErr, err.Error()) + } + }) + } +} diff --git a/internal/exec/exec.go b/internal/exec/exec.go index 55b0a7fea..589129e40 100644 --- a/internal/exec/exec.go +++ b/internal/exec/exec.go @@ -45,7 +45,18 @@ func (r *Request) Execute(ctx context.Context, s *resolvable.Schema, op *types.O func() { defer r.handlePanic(ctx) sels := selected.ApplyOperation(&r.Request, s, op) - r.execSelections(ctx, sels, nil, s, s.Resolver, &out, op.Type == query.Mutation) + var resolver reflect.Value + switch op.Type { + case query.Query: + resolver = s.QueryResolver + case query.Mutation: + resolver = s.MutationResolver + case query.Subscription: + resolver = s.SubscriptionResolver + default: + panic("unknown query operation") + } + r.execSelections(ctx, sels, nil, s, resolver, &out, op.Type == query.Mutation) }() if err := ctx.Err(); err != nil { diff --git a/internal/exec/resolvable/resolvable.go b/internal/exec/resolvable/resolvable.go index 11ce3bab5..644acdea4 100644 --- a/internal/exec/resolvable/resolvable.go +++ b/internal/exec/resolvable/resolvable.go @@ -11,13 +11,21 @@ import ( "github.com/graph-gophers/graphql-go/types" ) +const ( + Query = "Query" + Mutation = "Mutation" + Subscription = "Subscription" +) + type Schema struct { *Meta types.Schema - Query Resolvable - Mutation Resolvable - Subscription Resolvable - Resolver reflect.Value + Query Resolvable + Mutation Resolvable + Subscription Resolvable + QueryResolver reflect.Value + MutationResolver reflect.Value + SubscriptionResolver reflect.Value } type Resolvable interface { @@ -70,20 +78,59 @@ func ApplyResolver(s *types.Schema, resolver interface{}) (*Schema, error) { var query, mutation, subscription Resolvable + resolvers := map[string]interface{}{} + + rv := reflect.ValueOf(resolver) + // use separate resolvers in case Query, Mutation and/or Subscription methods are defined + for _, op := range [...]string{Query, Mutation, Subscription} { + m := rv.MethodByName(op) // operation method + if m.IsValid() { + mt := m.Type() + if mt.NumIn() != 0 { + return nil, fmt.Errorf("method %q of %v must not accept any arguments, got %d", op, rv.Type(), mt.NumIn()) + } + if mt.NumOut() != 1 { + return nil, fmt.Errorf("method %q of %v must have 1 return value, got %d", op, rv.Type(), mt.NumOut()) + } + ot := mt.Out(0) + if ot.Kind() != reflect.Pointer && ot.Kind() != reflect.Interface { + return nil, fmt.Errorf("method %q of %v must return an interface or a pointer, got %+v", op, rv.Type(), ot) + } + out := m.Call(nil) + res := out[0] + if res.IsNil() { + return nil, fmt.Errorf("method %q of %v must return a non-nil result, got %v", op, rv.Type(), res) + } + switch res.Kind() { + case reflect.Pointer: + resolvers[op] = res.Elem().Addr().Interface() + case reflect.Interface: + resolvers[op] = res.Elem().Interface() + default: + panic("ureachable") + } + } + // If a method/field for the given operation is not defined in the root resolver, then share the + // root resolver for all the operations in order to ensure backwards compatibility. + if resolvers[op] == nil { + resolvers[op] = resolver + } + } + if t, ok := s.RootOperationTypes["query"]; ok { - if err := b.assignExec(&query, t, reflect.TypeOf(resolver)); err != nil { + if err := b.assignExec(&query, t, reflect.TypeOf(resolvers[Query])); err != nil { return nil, err } } if t, ok := s.RootOperationTypes["mutation"]; ok { - if err := b.assignExec(&mutation, t, reflect.TypeOf(resolver)); err != nil { + if err := b.assignExec(&mutation, t, reflect.TypeOf(resolvers[Mutation])); err != nil { return nil, err } } if t, ok := s.RootOperationTypes["subscription"]; ok { - if err := b.assignExec(&subscription, t, reflect.TypeOf(resolver)); err != nil { + if err := b.assignExec(&subscription, t, reflect.TypeOf(resolvers[Subscription])); err != nil { return nil, err } } @@ -93,12 +140,14 @@ func ApplyResolver(s *types.Schema, resolver interface{}) (*Schema, error) { } return &Schema{ - Meta: newMeta(s), - Schema: *s, - Resolver: reflect.ValueOf(resolver), - Query: query, - Mutation: mutation, - Subscription: subscription, + Meta: newMeta(s), + Schema: *s, + QueryResolver: reflect.ValueOf(resolvers[Query]), + MutationResolver: reflect.ValueOf(resolvers[Mutation]), + SubscriptionResolver: reflect.ValueOf(resolvers[Subscription]), + Query: query, + Mutation: mutation, + Subscription: subscription, }, nil } diff --git a/internal/exec/subscribe.go b/internal/exec/subscribe.go index 37ebacbc9..110c9541a 100644 --- a/internal/exec/subscribe.go +++ b/internal/exec/subscribe.go @@ -28,7 +28,7 @@ func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *types sels := selected.ApplyOperation(&r.Request, s, op) var fields []*fieldToExec - collectFieldsToResolve(sels, s, s.Resolver, &fields, make(map[string]*fieldToExec)) + collectFieldsToResolve(sels, s, s.SubscriptionResolver, &fields, make(map[string]*fieldToExec)) // TODO: move this check into validation.Validate if len(fields) != 1 { diff --git a/subscriptions.go b/subscriptions.go index cf060cd24..91e84c0d8 100644 --- a/subscriptions.go +++ b/subscriptions.go @@ -20,7 +20,7 @@ import ( // further resolvers will be called. The context error will be returned as soon // as possible (not immediately). func (s *Schema) Subscribe(ctx context.Context, queryString string, operationName string, variables map[string]interface{}) (<-chan interface{}, error) { - if !s.res.Resolver.IsValid() { + if !s.res.SubscriptionResolver.IsValid() { return nil, errors.New("schema created without resolver, can not subscribe") } if _, ok := s.schema.RootOperationTypes["subscription"]; !ok {