Skip to content

Commit

Permalink
Add rate limiting.
Browse files Browse the repository at this point in the history
Signed-off-by: Jakub Martin <[email protected]>
  • Loading branch information
cube2222 committed Jun 6, 2022
1 parent 5a1c172 commit 9c27f9c
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 3 deletions.
31 changes: 28 additions & 3 deletions graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"

"github.com/graph-gophers/graphql-go/errors"
Expand All @@ -16,6 +17,8 @@ import (
"github.com/graph-gophers/graphql-go/internal/validation"
"github.com/graph-gophers/graphql-go/introspection"
"github.com/graph-gophers/graphql-go/log"
"github.com/graph-gophers/graphql-go/ratelimit"
noopratelimit "github.com/graph-gophers/graphql-go/ratelimit/noop"
"github.com/graph-gophers/graphql-go/trace/noop"
"github.com/graph-gophers/graphql-go/trace/tracer"
"github.com/graph-gophers/graphql-go/types"
Expand All @@ -28,6 +31,7 @@ func ParseSchema(schemaString string, resolver interface{}, opts ...SchemaOpt) (
s := &Schema{
schema: schema.New(),
maxParallelism: 10,
rateLimiter: &noopratelimit.RateLimiter{},
tracer: noop.Tracer{},
logger: &log.DefaultLogger{},
panicHandler: &errors.DefaultPanicHandler{},
Expand Down Expand Up @@ -76,6 +80,7 @@ type Schema struct {

maxDepth int
maxParallelism int
rateLimiter ratelimit.RateLimiter
tracer tracer.Tracer
validationTracer tracer.ValidationTracer
logger log.Logger
Expand Down Expand Up @@ -123,6 +128,13 @@ func MaxParallelism(n int) SchemaOpt {
}
}

// RateLimiter is used to rate limit queries.
func RateLimiter(r ratelimit.RateLimiter) SchemaOpt {
return func(s *Schema) {
s.rateLimiter = r
}
}

// Tracer is used to trace queries and fields. It defaults to tracer.Noop.
func Tracer(t tracer.Tracer) SchemaOpt {
return func(s *Schema) {
Expand Down Expand Up @@ -176,6 +188,8 @@ type Response struct {
Errors []*errors.QueryError `json:"errors,omitempty"`
Data json.RawMessage `json:"data,omitempty"`
Extensions map[string]interface{} `json:"extensions,omitempty"`
// Optional forced StatusCode.
StatusCode *int `json:"-,omitempty"`
}

// Validate validates the given query with the schema.
Expand Down Expand Up @@ -268,12 +282,23 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str
varTypes[v.Name.Name] = introspection.WrapType(t)
}
traceCtx, finish := s.tracer.TraceQuery(ctx, queryString, operationName, variables, varTypes)
data, errs := r.Execute(traceCtx, res, op)

var data []byte
var statusCode *int
if !s.rateLimiter.LimitQuery(ctx, queryString, operationName, variables, varTypes) {
data, errs = r.Execute(traceCtx, res, op)
} else {
errs = []*errors.QueryError{{Message: "rate limit exceeded"}}
code := http.StatusTooManyRequests
statusCode = &code
}

finish(errs)

return &Response{
Data: data,
Errors: errs,
Data: data,
Errors: errs,
StatusCode: statusCode,
}
}

Expand Down
14 changes: 14 additions & 0 deletions ratelimit/noop/ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package noop

import (
"context"

"github.com/graph-gophers/graphql-go/introspection"
)

// RateLimiter is a no-op rate limiter that does nothing.
type RateLimiter struct{}

func (r *RateLimiter) LimitQuery(ctx context.Context, queryString string, operationName string, variables map[string]interface{}, varTypes map[string]*introspection.Type) bool {
return false
}
11 changes: 11 additions & 0 deletions ratelimit/ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package ratelimit

import (
"context"

"github.com/graph-gophers/graphql-go/introspection"
)

type RateLimiter interface {
LimitQuery(ctx context.Context, queryString string, operationName string, variables map[string]interface{}, varTypes map[string]*introspection.Type) bool
}
3 changes: 3 additions & 0 deletions relay/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if response.StatusCode != nil {
w.WriteHeader(*response.StatusCode)
}

w.Header().Set("Content-Type", "application/json")
w.Write(responseJSON)
Expand Down

0 comments on commit 9c27f9c

Please sign in to comment.