Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions router/core/graph_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,21 @@ type BuildGraphMuxOptions struct {
RoutingUrlGroupings map[string]map[string]bool
}

func rateLimitOverridesFromConfig(source map[string]config.RateLimitSimpleOverride) map[string]RateLimitOverride {
if len(source) == 0 {
return nil
}
converted := make(map[string]RateLimitOverride, len(source))
for suffix, override := range source {
converted[suffix] = RateLimitOverride{
Rate: override.Rate,
Burst: override.Burst,
Period: override.Period,
}
}
return converted
}

func (b BuildGraphMuxOptions) IsBaseGraph() bool {
return b.FeatureFlagName == ""
}
Expand Down Expand Up @@ -519,12 +534,12 @@ type graphMux struct {
validationCache *ristretto.Cache[uint64, bool]
operationHashCache *ristretto.Cache[uint64, string]

accessLogsFileLogger *logging.BufferedLogger
metricStore rmetric.Store
prometheusCacheMetrics *rmetric.CacheMetrics
otelCacheMetrics *rmetric.CacheMetrics
streamMetricStore rmetric.StreamMetricStore
prometheusMetricsExporter *graphqlmetrics.PrometheusMetricsExporter
accessLogsFileLogger *logging.BufferedLogger
metricStore rmetric.Store
prometheusCacheMetrics *rmetric.CacheMetrics
otelCacheMetrics *rmetric.CacheMetrics
streamMetricStore rmetric.StreamMetricStore
prometheusMetricsExporter *graphqlmetrics.PrometheusMetricsExporter
}

// buildOperationCaches creates the caches for the graph mux.
Expand Down Expand Up @@ -1404,6 +1419,7 @@ func (s *graphServer) buildGraphMux(
RejectStatusCode: s.rateLimit.SimpleStrategy.RejectStatusCode,
KeySuffixExpression: s.rateLimit.KeySuffixExpression,
ExprManager: exprManager,
Overrides: rateLimitOverridesFromConfig(s.rateLimit.SimpleStrategy.Overrides),
})
if err != nil {
return nil, fmt.Errorf("failed to create rate limiter: %w", err)
Expand Down
62 changes: 50 additions & 12 deletions router/core/ratelimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
"fmt"
"io"
"reflect"
"strings"
"sync"
"time"

rd "github.com/wundergraph/cosmo/router/internal/rediscloser"

Expand All @@ -30,15 +32,32 @@ type CosmoRateLimiterOptions struct {

KeySuffixExpression string
ExprManager *expr.Manager
Overrides map[string]RateLimitOverride
}

func NewCosmoRateLimiter(opts *CosmoRateLimiterOptions) (rl *CosmoRateLimiter, err error) {
limiter := redis_rate.NewLimiter(opts.RedisClient)
var overrides map[string]redis_rate.Limit
if len(opts.Overrides) > 0 {
overrides = make(map[string]redis_rate.Limit, len(opts.Overrides))
for rawKey, override := range opts.Overrides {
key := strings.TrimSpace(rawKey)
if key == "" {
continue
}
overrides[key] = redis_rate.Limit{
Rate: override.Rate,
Burst: override.Burst,
Period: override.Period,
}
}
}
rl = &CosmoRateLimiter{
client: opts.RedisClient,
limiter: limiter,
debug: opts.Debug,
rejectStatusCode: opts.RejectStatusCode,
keyOverrides: overrides,
}
if rl.rejectStatusCode == 0 {
rl.rejectStatusCode = 200
Expand All @@ -54,28 +73,35 @@ func NewCosmoRateLimiter(opts *CosmoRateLimiterOptions) (rl *CosmoRateLimiter, e

type CosmoRateLimiter struct {
client rd.RDCloser
limiter *redis_rate.Limiter
limiter redisLimiter
debug bool

rejectStatusCode int

keySuffixProgram *vm.Program
keyOverrides map[string]redis_rate.Limit
}

type redisLimiter interface {
AllowN(ctx context.Context, key string, limit redis_rate.Limit, n int) (*redis_rate.Result, error)
}

type RateLimitOverride struct {
Rate int
Burst int
Period time.Duration
}

func (c *CosmoRateLimiter) RateLimitPreFetch(ctx *resolve.Context, info *resolve.FetchInfo, input json.RawMessage) (result *resolve.RateLimitDeny, err error) {
if c.isIntrospectionQuery(info.RootFields) {
return nil, nil
}
requestRate := c.calculateRate()
limit := redis_rate.Limit{
Rate: ctx.RateLimitOptions.Rate,
Burst: ctx.RateLimitOptions.Burst,
Period: ctx.RateLimitOptions.Period,
}
key, err := c.generateKey(ctx)
key, _, err := c.generateKey(ctx)
if err != nil {
return nil, err
}
limit := c.limitFor(ctx, key)
allow, err := c.limiter.AllowN(ctx.Context(), key, limit, requestRate)
if err != nil {
return nil, err
Expand All @@ -90,23 +116,35 @@ func (c *CosmoRateLimiter) RateLimitPreFetch(ctx *resolve.Context, info *resolve
return &resolve.RateLimitDeny{}, nil
}

func (c *CosmoRateLimiter) generateKey(ctx *resolve.Context) (string, error) {
func (c *CosmoRateLimiter) generateKey(ctx *resolve.Context) (string, string, error) {
if c.keySuffixProgram == nil {
return ctx.RateLimitOptions.RateLimitKey, nil
return ctx.RateLimitOptions.RateLimitKey, "", nil
}
rc := getRequestContext(ctx.Context())
if rc == nil {
return "", errors.New("no request context")
return "", "", errors.New("no request context")
}
str, err := expr.ResolveStringExpression(c.keySuffixProgram, rc.expressionContext)
if err != nil {
return "", fmt.Errorf("failed to resolve key suffix expression: %w", err)
return "", "", fmt.Errorf("failed to resolve key suffix expression: %w", err)
}
buf := bytes.NewBuffer(make([]byte, 0, len(ctx.RateLimitOptions.RateLimitKey)+len(str)+1))
_, _ = buf.WriteString(ctx.RateLimitOptions.RateLimitKey)
_ = buf.WriteByte(':')
_, _ = buf.WriteString(str)
return buf.String(), nil
return buf.String(), str, nil
}

func (c *CosmoRateLimiter) limitFor(ctx *resolve.Context, key string) redis_rate.Limit {
limit := redis_rate.Limit{
Rate: ctx.RateLimitOptions.Rate,
Burst: ctx.RateLimitOptions.Burst,
Period: ctx.RateLimitOptions.Period,
}
if override, ok := c.keyOverrides[key]; ok {
return override
}
return limit
}

func (c *CosmoRateLimiter) RejectStatusCode() int {
Expand Down
Loading