From f51861d4d4cc14de3e33d48f0fa8f23396e11360 Mon Sep 17 00:00:00 2001
From: Shawn Poulson
Date: Mon, 11 Mar 2024 20:09:06 -0400
Subject: [PATCH] Don't call `OnChange()` event from non-owner.

Non-owners shouldn't be persisting rate limit state.
---
 algorithms.go | 36 ++++++++++++++++++------------------
 gubernator.go |  6 +++---
 workers.go    | 10 +++++-----
 3 files changed, 26 insertions(+), 26 deletions(-)

diff --git a/algorithms.go b/algorithms.go
index 4032fa4f..61fa1544 100644
--- a/algorithms.go
+++ b/algorithms.go
@@ -34,7 +34,7 @@ import (
 // with 100 emails and the request will succeed. You can override this default behavior with `DRAIN_OVER_LIMIT`
 
 // Implements token bucket algorithm for rate limiting. https://en.wikipedia.org/wiki/Token_bucket
-func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs RateLimitReqState) (resp *RateLimitResp, err error) {
+func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqState RateLimitReqState) (resp *RateLimitResp, err error) {
 	tokenBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("tokenBucket"))
 	defer tokenBucketTimer.ObserveDuration()
 
@@ -99,7 +99,7 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs Rate
 				s.Remove(ctx, hashKey)
 			}
 
-			return tokenBucketNewItem(ctx, s, c, r, rs)
+			return tokenBucketNewItem(ctx, s, c, r, reqState)
 		}
 
 		// Update the limit if it changed.
@@ -146,7 +146,7 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs Rate
 			rl.ResetTime = expire
 		}
 
-		if s != nil {
+		if s != nil && reqState.IsOwner {
 			defer func() {
 				s.OnChange(ctx, r, item)
 			}()
@@ -161,7 +161,7 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs Rate
 		// If we are already at the limit.
 		if rl.Remaining == 0 && r.Hits > 0 {
 			trace.SpanFromContext(ctx).AddEvent("Already over the limit")
-			if rs.IsOwner {
+			if reqState.IsOwner {
 				metricOverLimitCounter.Add(1)
 			}
 			rl.Status = Status_OVER_LIMIT
@@ -181,7 +181,7 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs Rate
 		// without updating the cache.
 		if r.Hits > t.Remaining {
 			trace.SpanFromContext(ctx).AddEvent("Over the limit")
-			if rs.IsOwner {
+			if reqState.IsOwner {
 				metricOverLimitCounter.Add(1)
 			}
 			rl.Status = Status_OVER_LIMIT
@@ -199,11 +199,11 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs Rate
 	}
 
 	// Item is not found in cache or store, create new.
-	return tokenBucketNewItem(ctx, s, c, r, rs)
+	return tokenBucketNewItem(ctx, s, c, r, reqState)
 }
 
 // Called by tokenBucket() when adding a new item in the store.
-func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs RateLimitReqState) (resp *RateLimitResp, err error) {
+func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqState RateLimitReqState) (resp *RateLimitResp, err error) {
 	requestTime := *r.RequestTime
 	expire := requestTime + r.Duration
 
@@ -239,7 +239,7 @@ func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq,
 	// Client could be requesting that we always return OVER_LIMIT.
 	if r.Hits > r.Limit {
 		trace.SpanFromContext(ctx).AddEvent("Over the limit")
-		if rs.IsOwner {
+		if reqState.IsOwner {
 			metricOverLimitCounter.Add(1)
 		}
 		rl.Status = Status_OVER_LIMIT
@@ -249,7 +249,7 @@ func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq,
 
 	c.Add(item)
 
-	if s != nil {
+	if s != nil && reqState.IsOwner {
 		s.OnChange(ctx, r, item)
 	}
 
@@ -257,7 +257,7 @@ func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq,
 }
 
 // Implements leaky bucket algorithm for rate limiting https://en.wikipedia.org/wiki/Leaky_bucket
-func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs RateLimitReqState) (resp *RateLimitResp, err error) {
+func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqState RateLimitReqState) (resp *RateLimitResp, err error) {
 	leakyBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.getRateLimit_leakyBucket"))
 	defer leakyBucketTimer.ObserveDuration()
 
@@ -314,7 +314,7 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs Rate
 				s.Remove(ctx, hashKey)
 			}
 
-			return leakyBucketNewItem(ctx, s, c, r, rs)
+			return leakyBucketNewItem(ctx, s, c, r, reqState)
 		}
 
 		if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) {
@@ -379,7 +379,7 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs Rate
 
 		// TODO: Feature missing: check for Duration change between item/request.
 
-		if s != nil {
+		if s != nil && reqState.IsOwner {
 			defer func() {
 				s.OnChange(ctx, r, item)
 			}()
@@ -387,7 +387,7 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs Rate
 
 		// If we are already at the limit
 		if int64(b.Remaining) == 0 && r.Hits > 0 {
-			if rs.IsOwner {
+			if reqState.IsOwner {
 				metricOverLimitCounter.Add(1)
 			}
 			rl.Status = Status_OVER_LIMIT
@@ -405,7 +405,7 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs Rate
 		// If requested is more than available, then return over the limit
 		// without updating the bucket, unless `DRAIN_OVER_LIMIT` is set.
 		if r.Hits > int64(b.Remaining) {
-			if rs.IsOwner {
+			if reqState.IsOwner {
 				metricOverLimitCounter.Add(1)
 			}
 			rl.Status = Status_OVER_LIMIT
@@ -430,11 +430,11 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs Rate
 		return rl, nil
 	}
 
-	return leakyBucketNewItem(ctx, s, c, r, rs)
+	return leakyBucketNewItem(ctx, s, c, r, reqState)
 }
 
 // Called by leakyBucket() when adding a new item in the store.
-func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs RateLimitReqState) (resp *RateLimitResp, err error) {
+func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqState RateLimitReqState) (resp *RateLimitResp, err error) {
 	requestTime := *r.RequestTime
 	duration := r.Duration
 	rate := float64(duration) / float64(r.Limit)
@@ -467,7 +467,7 @@ func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq,
 
 	// Client could be requesting that we start with the bucket OVER_LIMIT
 	if r.Hits > r.Burst {
-		if rs.IsOwner {
+		if reqState.IsOwner {
 			metricOverLimitCounter.Add(1)
 		}
 		rl.Status = Status_OVER_LIMIT
@@ -485,7 +485,7 @@ func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq,
 
 	c.Add(item)
 
-	if s != nil {
+	if s != nil && reqState.IsOwner {
 		s.OnChange(ctx, r, item)
 	}
 
diff --git a/gubernator.go b/gubernator.go
index 9fd6aee1..55f3f9f2 100644
--- a/gubernator.go
+++ b/gubernator.go
@@ -585,7 +585,7 @@ func (s *V1Instance) HealthCheck(ctx context.Context, r *HealthCheckReq) (health
 	return health, nil
 }
 
-func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq, rs RateLimitReqState) (_ *RateLimitResp, err error) {
+func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq, reqState RateLimitReqState) (_ *RateLimitResp, err error) {
 	ctx = tracing.StartNamedScope(ctx, "V1Instance.getLocalRateLimit", trace.WithAttributes(
 		attribute.String("ratelimit.key", r.UniqueKey),
 		attribute.String("ratelimit.name", r.Name),
@@ -595,7 +595,7 @@ func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq, rs
 	defer func() { tracing.EndScope(ctx, err) }()
 	defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.getLocalRateLimit")).ObserveDuration()
 
-	resp, err := s.workerPool.GetRateLimit(ctx, r, rs)
+	resp, err := s.workerPool.GetRateLimit(ctx, r, reqState)
 	if err != nil {
 		return nil, errors.Wrap(err, "during workerPool.GetRateLimit")
 	}
@@ -605,7 +605,7 @@ func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq, rs
 		s.global.QueueUpdate(r)
 	}
 
-	if rs.IsOwner {
+	if reqState.IsOwner {
 		metricGetRateLimitCounter.WithLabelValues("local").Inc()
 	}
 	return resp, nil
diff --git a/workers.go b/workers.go
index d62071be..34d99d1d 100644
--- a/workers.go
+++ b/workers.go
@@ -258,7 +258,7 @@ func (p *WorkerPool) dispatch(worker *Worker) {
 }
 
 // GetRateLimit sends a GetRateLimit request to worker pool.
-func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitReq, rs RateLimitReqState) (*RateLimitResp, error) {
+func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitReq, reqState RateLimitReqState) (*RateLimitResp, error) {
 	// Delegate request to assigned channel based on request key.
 	worker := p.getWorker(rlRequest.HashKey())
 	queueGauge := metricWorkerQueue.WithLabelValues("GetRateLimit", worker.name)
@@ -268,7 +268,7 @@ func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitReq,
 		ctx:      ctx,
 		resp:     make(chan *response, 1),
 		request:  rlRequest,
-		reqState: rs,
+		reqState: reqState,
 	}
 
 	// Send request.
@@ -290,14 +290,14 @@ func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitReq,
 }
 
 // Handle request received by worker.
-func (worker *Worker) handleGetRateLimit(ctx context.Context, req *RateLimitReq, rs RateLimitReqState, cache Cache) (*RateLimitResp, error) {
+func (worker *Worker) handleGetRateLimit(ctx context.Context, req *RateLimitReq, reqState RateLimitReqState, cache Cache) (*RateLimitResp, error) {
 	defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("Worker.handleGetRateLimit")).ObserveDuration()
 	var rlResponse *RateLimitResp
 	var err error
 
 	switch req.Algorithm {
 	case Algorithm_TOKEN_BUCKET:
-		rlResponse, err = tokenBucket(ctx, worker.conf.Store, cache, req, rs)
+		rlResponse, err = tokenBucket(ctx, worker.conf.Store, cache, req, reqState)
 		if err != nil {
 			msg := "Error in tokenBucket"
 			countError(err, msg)
@@ -306,7 +306,7 @@ func (worker *Worker) handleGetRateLimit(ctx context.Context, req *RateLimitReq,
 		}
 
 	case Algorithm_LEAKY_BUCKET:
-		rlResponse, err = leakyBucket(ctx, worker.conf.Store, cache, req, rs)
+		rlResponse, err = leakyBucket(ctx, worker.conf.Store, cache, req, reqState)
 		if err != nil {
 			msg := "Error in leakyBucket"
 			countError(err, msg)
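
For context on the behavior this patch enforces: requests for a given rate-limit key are routed to a single owning peer, and non-owner peers may answer from local copies of the bucket, so only the owner should persist mutations through the store's OnChange() hook. The sketch below distills that guard into a minimal, self-contained Go program. The types here (Store, CacheItem, RateLimitReqState) are simplified stand-ins for illustration, not gubernator's real definitions, and applyHit is a hypothetical helper standing in for the algorithm functions in this patch.

	package main

	import (
		"context"
		"fmt"
	)

	// CacheItem is a simplified stand-in for a cached rate-limit bucket.
	type CacheItem struct {
		Key       string
		Remaining int64
	}

	// RateLimitReqState mirrors the parameter threaded through this patch:
	// IsOwner is true only on the peer that owns the rate-limit key.
	type RateLimitReqState struct {
		IsOwner bool
	}

	// Store is a simplified stand-in for a persistent store; OnChange is
	// invoked to persist a bucket after it has been mutated.
	type Store interface {
		OnChange(ctx context.Context, item *CacheItem)
	}

	type printStore struct{}

	func (printStore) OnChange(_ context.Context, item *CacheItem) {
		fmt.Printf("persisted %q remaining=%d\n", item.Key, item.Remaining)
	}

	// applyHit mimics the patched algorithm functions: any peer may debit
	// its cached copy, but after this patch only the owner writes the
	// mutation back to the store.
	func applyHit(ctx context.Context, s Store, reqState RateLimitReqState, item *CacheItem, hits int64) {
		item.Remaining -= hits

		// The guard added by this patch: skip persistence on non-owners.
		if s != nil && reqState.IsOwner {
			s.OnChange(ctx, item)
		}
	}

	func main() {
		ctx := context.Background()
		item := &CacheItem{Key: "account:123", Remaining: 100}

		applyHit(ctx, printStore{}, RateLimitReqState{IsOwner: true}, item, 1)  // owner: persists
		applyHit(ctx, printStore{}, RateLimitReqState{IsOwner: false}, item, 1) // non-owner: cache only
	}

Keeping the owner as the single writer avoids duplicate or stale writes when forwarded requests are processed on non-owner peers, matching how the OVER_LIMIT metric in these functions was already gated on IsOwner.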