From 9b2b8eca9a91b6c24a4d01a269831d677e76aeb0 Mon Sep 17 00:00:00 2001 From: Shawn Poulson Date: Fri, 14 Jan 2022 16:19:15 -0500 Subject: [PATCH] PIP-1490: Refactor unit tests on `tokenBucket()` in `store_test.go`. TODO: Need to build equivelent tests for `leakyBucket()`. --- algorithms.go | 81 +++---- gubernator_pool_test.go | 2 +- mock_store_test.go | 47 ++++ store_test.go | 479 ++++++++++++++++++++++++++++------------ 4 files changed, 432 insertions(+), 177 deletions(-) create mode 100644 mock_store_test.go diff --git a/algorithms.go b/algorithms.go index b5353bc2..3bc0c3ef 100644 --- a/algorithms.go +++ b/algorithms.go @@ -71,6 +71,7 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp * } if ok { + // Item found in cache or store. tracing.LogInfo(span, "Update existing rate limit") if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) { @@ -107,8 +108,7 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp * tracing.LogInfo(span, "s.Remove()") } - // FIXME: Eliminate recursion. - return tokenBucket(ctx, s, c, r) + return tokenBucketNewItem(ctx, s, c, r) } if s != nil { @@ -118,10 +118,10 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp * }() } - // Update the limit if it changed + // Update the limit if it changed. tracing.LogInfo(span, "Update the limit if changed") if t.Limit != r.Limit { - // Add difference to remaining + // Add difference to remaining. t.Remaining += r.Limit - t.Limit if t.Remaining < 0 { t.Remaining = 0 @@ -136,9 +136,9 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp * ResetTime: item.ExpireAt, } - // If the duration config changed, update the new ExpireAt + // If the duration config changed, update the new ExpireAt. if t.Duration != r.Duration { - tracing.LogInfo(span, "Duration config changed") + tracing.LogInfo(span, "Duration changed") expire := t.CreatedAt + r.Duration if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { expire, err = GregorianExpiration(clock.Now(), r.Duration) @@ -147,31 +147,29 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp * } } - // If our new duration means we are currently expired - if expire <= MillisecondNow() { - // Update this so s.OnChange() will get the new expire change + // If our new duration means we are currently expired. + now := MillisecondNow() + if expire <= now { + // Renew item. tracing.LogInfo(span, "Limit has expired") - item.ExpireAt = expire - - c.Remove(item.Key) - tracing.LogInfo(span, "c.Remove()") - - // FIXME: tokenBucketNewItem creates a new item. But, we want to - // preserve this item for its CreatedAt timestamp. - return tokenBucketNewItem(ctx, s, c, r) + expire = now + r.Duration + t.CreatedAt = now + t.Remaining = t.Limit } item.ExpireAt = expire + t.Duration = r.Duration rl.ResetTime = expire } - // Client is only interested in retrieving the current status or updating the rate limit config + // Client is only interested in retrieving the current status or + // updating the rate limit config. if r.Hits == 0 { tracing.LogInfo(span, "Return current status, apply no change") return rl, nil } - // If we are already at the limit + // If we are already at the limit. if rl.Remaining == 0 { tracing.LogInfo(span, "Already over the limit") overLimitCounter.Add(1) @@ -180,7 +178,7 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp * return rl, nil } - // If requested hits takes the remainder + // If requested hits takes the remainder. if t.Remaining == r.Hits { tracing.LogInfo(span, "At the limit") t.Remaining = 0 @@ -188,7 +186,8 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp * return rl, nil } - // If requested is more than available, then return over the limit without updating the cache. + // If requested is more than available, then return over the limit + // without updating the cache. if r.Hits > t.Remaining { tracing.LogInfo(span, "Over the limit") overLimitCounter.Add(1) @@ -202,18 +201,32 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp * return rl, nil } + // Item is not found in cache or store, create new. return tokenBucketNewItem(ctx, s, c, r) } -// Called by tokenBucket() when the requested item is not found in cache or store. +// Called by tokenBucket() when adding a new item in the store. func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *RateLimitResp, err error) { span, ctx := tracing.StartSpan(ctx) defer span.Finish() - // Add a new rate limit to the cache - tracing.LogInfo(span, "Add a new rate limit to the cache") now := MillisecondNow() expire := now + r.Duration + + item := CacheItem{ + Algorithm: r.Algorithm, + Key: r.HashKey(), + Value: &TokenBucketItem{ + Limit: r.Limit, + Duration: r.Duration, + Remaining: r.Limit - r.Hits, + CreatedAt: now, + }, + ExpireAt: expire, + } + + // Add a new rate limit to the cache. + tracing.LogInfo(span, "Add a new rate limit to the cache") if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { expire, err = GregorianExpiration(clock.Now(), r.Duration) if err != nil { @@ -221,13 +234,7 @@ func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq) } } - t := &TokenBucketItem{ - Limit: r.Limit, - Duration: r.Duration, - Remaining: r.Limit - r.Hits, - CreatedAt: now, - } - + t := item.Value.(*TokenBucketItem) rl := &RateLimitResp{ Status: Status_UNDER_LIMIT, Limit: r.Limit, @@ -235,7 +242,7 @@ func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq) ResetTime: expire, } - // Client could be requesting that we always return OVER_LIMIT + // Client could be requesting that we always return OVER_LIMIT. if r.Hits > r.Limit { tracing.LogInfo(span, "Over the limit") overLimitCounter.Add(1) @@ -244,13 +251,6 @@ func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq) t.Remaining = r.Limit } - item := CacheItem{ - Algorithm: r.Algorithm, - Key: r.HashKey(), - Value: t, - ExpireAt: expire, - } - c.Add(item) tracing.LogInfo(span, "c.Add()") @@ -258,6 +258,7 @@ func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq) s.OnChange(r, item) tracing.LogInfo(span, "s.OnChange()") } + return rl, nil } @@ -358,6 +359,8 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp * ResetTime: now + (b.Limit-int64(b.Remaining))*int64(rate), } + // TODO: Feature missing: check for Duration change between item/request. + if s != nil { defer func() { s.OnChange(r, item) diff --git a/gubernator_pool_test.go b/gubernator_pool_test.go index ace4328b..4169829f 100644 --- a/gubernator_pool_test.go +++ b/gubernator_pool_test.go @@ -91,7 +91,7 @@ func TestGubernatorPool(t *testing.T) { CacheFactory: func() guber.Cache { return mockCache }, - Loader: mockLoader, + Loader: mockLoader, } conf.SetDefaults() chp := guber.NewGubernatorPool(conf, testCase.concurrency) diff --git a/mock_store_test.go b/mock_store_test.go new file mode 100644 index 00000000..ba7fa82c --- /dev/null +++ b/mock_store_test.go @@ -0,0 +1,47 @@ +/* +Copyright 2018-2022 Mailgun Technologies Inc + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package gubernator_test + +// Mock implementation of Store. + +import ( + guber "github.com/mailgun/gubernator/v2" + "github.com/stretchr/testify/mock" +) + +type MockStore2 struct { + mock.Mock +} + +var _ guber.Store = &MockStore2{} + +func (m *MockStore2) OnChange(r *guber.RateLimitReq, item guber.CacheItem) { + m.Called(r, item) +} + +func (m *MockStore2) Get(r *guber.RateLimitReq) (guber.CacheItem, bool) { + args := m.Called(r) + var retval guber.CacheItem + if retval2, ok := args.Get(0).(guber.CacheItem); ok { + retval = retval2 + } + return retval, args.Bool(1) +} + +func (m *MockStore2) Remove(key string) { + m.Called(key) +} diff --git a/store_test.go b/store_test.go index 4a618eef..c9599566 100644 --- a/store_test.go +++ b/store_test.go @@ -25,6 +25,7 @@ import ( "github.com/mailgun/gubernator/v2" "github.com/mailgun/holster/v4/clock" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "google.golang.org/grpc" ) @@ -125,168 +126,372 @@ func TestLoader(t *testing.T) { } func TestStore(t *testing.T) { - tests := []struct { - name string - firstRemaining int64 - firstStatus gubernator.Status - secondRemaining int64 - secondStatus gubernator.Status - algorithm gubernator.Algorithm - switchAlgorithm gubernator.Algorithm - testCase func(*gubernator.RateLimitReq, *gubernator.MockStore) - }{ - { - name: "Given there are no token bucket limits in the store", - firstRemaining: int64(9), - firstStatus: gubernator.Status_UNDER_LIMIT, - secondRemaining: int64(8), - secondStatus: gubernator.Status_UNDER_LIMIT, - algorithm: gubernator.Algorithm_TOKEN_BUCKET, - switchAlgorithm: gubernator.Algorithm_LEAKY_BUCKET, - testCase: func(req *gubernator.RateLimitReq, store *gubernator.MockStore) {}, - }, - { - name: "Given the store contains a token bucket rate limit not in the guber cache", - firstRemaining: int64(0), - firstStatus: gubernator.Status_UNDER_LIMIT, - secondRemaining: int64(0), - secondStatus: gubernator.Status_OVER_LIMIT, - algorithm: gubernator.Algorithm_TOKEN_BUCKET, - switchAlgorithm: gubernator.Algorithm_LEAKY_BUCKET, - testCase: func(req *gubernator.RateLimitReq, store *gubernator.MockStore) { - now := gubernator.MillisecondNow() - // Expire 1 second from now - expire := now + gubernator.Second - store.CacheItems[req.HashKey()] = gubernator.CacheItem{ - Algorithm: gubernator.Algorithm_TOKEN_BUCKET, - ExpireAt: expire, - Key: req.HashKey(), - Value: &gubernator.TokenBucketItem{ - Limit: req.Limit, - Duration: req.Duration, - CreatedAt: now, - Remaining: 1, - }, - } - }, - }, - { - name: "Given there are no leaky bucket limits in the store", - firstRemaining: int64(9), - firstStatus: gubernator.Status_UNDER_LIMIT, - secondRemaining: int64(8), - secondStatus: gubernator.Status_UNDER_LIMIT, - algorithm: gubernator.Algorithm_LEAKY_BUCKET, - switchAlgorithm: gubernator.Algorithm_TOKEN_BUCKET, - testCase: func(req *gubernator.RateLimitReq, store *gubernator.MockStore) {}, - }, - { - name: "Given the store contains a leaky bucket rate limit not in the guber cache", - firstRemaining: int64(0), - firstStatus: gubernator.Status_UNDER_LIMIT, - secondRemaining: int64(0), - secondStatus: gubernator.Status_OVER_LIMIT, - algorithm: gubernator.Algorithm_LEAKY_BUCKET, - switchAlgorithm: gubernator.Algorithm_TOKEN_BUCKET, - testCase: func(req *gubernator.RateLimitReq, store *gubernator.MockStore) { - // Expire 1 second from now - expire := gubernator.MillisecondNow() + gubernator.Second - store.CacheItems[req.HashKey()] = gubernator.CacheItem{ - Algorithm: gubernator.Algorithm_LEAKY_BUCKET, - ExpireAt: expire, - Key: req.HashKey(), - Value: &gubernator.LeakyBucketItem{ - UpdatedAt: gubernator.MillisecondNow(), - Duration: req.Duration, - Limit: req.Limit, - Remaining: 1, - Burst: req.Limit, - }, - } + setup := func() (*MockStore2, *v1Server, gubernator.V1Client) { + store := &MockStore2{} + + srv := newV1Server(t, "", gubernator.Config{ + Behaviors: gubernator.BehaviorConfig{ + GlobalSyncWait: clock.Millisecond * 50, // Suitable for testing but not production + GlobalTimeout: clock.Second, }, - }, + Store: store, + }) + + client, err := gubernator.DialV1Server(srv.listener.Addr().String(), nil) + require.NoError(t, err) + + return store, srv, client } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := gubernator.NewMockStore() - - srv := newV1Server(t, "", gubernator.Config{ - Behaviors: gubernator.BehaviorConfig{ - GlobalSyncWait: clock.Millisecond * 50, // Suitable for testing but not production - GlobalTimeout: clock.Second, - }, - Store: store, - }) - // No calls to store - assert.Equal(t, 0, store.Called["OnChange()"]) - assert.Equal(t, 0, store.Called["Get()"]) + tearDown := func(srv *v1Server) { + err := srv.Close() + require.NoError(t, err) + } + + // Create a mock argument matcher for a request by name/key. + matchReq := func(req *gubernator.RateLimitReq) interface{} { + return mock.MatchedBy(func(req2 *gubernator.RateLimitReq) bool { + return req2.Name == req.Name && + req2.UniqueKey == req.UniqueKey + }) + } + + // TODO: Build equivalent tests for Leaky Bucket. + t.Run("Token bucket", func(t *testing.T) { + t.Run("First rate check pulls from store", func(t *testing.T) { + store, srv, client := setup() + defer tearDown(srv) + + req := &gubernator.RateLimitReq{ + Name: "test_over_limit", + UniqueKey: "account:1234", + Algorithm: gubernator.Algorithm_TOKEN_BUCKET, + Duration: gubernator.Second, + Limit: 10, + Hits: 1, + } + + // Setup mocks. + expectedRemaining := req.Limit - req.Hits + store.On("Get", matchReq(req)).Once().Return(gubernator.CacheItem{}, false) + + store.On("OnChange", + matchReq(req), + mock.MatchedBy(func(item gubernator.CacheItem) bool { + titem, ok := item.Value.(*gubernator.TokenBucketItem) + if !ok { + return false + } + + return item.Algorithm == req.Algorithm && + item.Key == req.HashKey() && + titem.Limit == req.Limit && + titem.Duration == req.Duration && + titem.Remaining == expectedRemaining + }), + ). + Once() + + // Call code. + resp, err := client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{ + Requests: []*gubernator.RateLimitReq{req}, + }) + require.NoError(t, err) + require.NotNil(t, resp) + assert.Len(t, resp.Responses, 1) + assert.Equal(t, expectedRemaining, resp.Responses[0].Remaining) + assert.Equal(t, req.Limit, resp.Responses[0].Limit) + assert.Equal(t, gubernator.Status_UNDER_LIMIT, resp.Responses[0].Status) + store.AssertExpectations(t) + + t.Run("Second rate check pulls from cache", func(t *testing.T) { + // Setup mocks. + expectedRemaining2 := req.Limit - (req.Hits * 2) + + store.On("OnChange", + matchReq(req), + mock.MatchedBy(func(item gubernator.CacheItem) bool { + titem, ok := item.Value.(*gubernator.TokenBucketItem) + if !ok { + return false + } + + return item.Algorithm == req.Algorithm && + item.Key == req.HashKey() && + titem.Limit == req.Limit && + titem.Duration == req.Duration && + titem.Remaining == expectedRemaining2 + }), + ). + Once() + + // Call code. + resp, err := client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{ + Requests: []*gubernator.RateLimitReq{req}, + }) + require.NoError(t, err) + require.NotNil(t, resp) + assert.Len(t, resp.Responses, 1) + assert.Equal(t, expectedRemaining2, resp.Responses[0].Remaining) + assert.Equal(t, req.Limit, resp.Responses[0].Limit) + assert.Equal(t, gubernator.Status_UNDER_LIMIT, resp.Responses[0].Status) + store.AssertExpectations(t) + }) + }) - client, err := gubernator.DialV1Server(srv.listener.Addr().String(), nil) - assert.Nil(t, err) + t.Run("Found in store after cache miss", func(t *testing.T) { + store, srv, client := setup() + defer tearDown(srv) - req := gubernator.RateLimitReq{ + req := &gubernator.RateLimitReq{ Name: "test_over_limit", UniqueKey: "account:1234", - Algorithm: tt.algorithm, + Algorithm: gubernator.Algorithm_TOKEN_BUCKET, Duration: gubernator.Second, Limit: 10, Hits: 1, } - tt.testCase(&req, store) + // Setup mocks. + expectedRemaining := req.Limit - req.Hits + now := gubernator.MillisecondNow() + expire := now + req.Duration + storedItem := gubernator.CacheItem{ + Algorithm: req.Algorithm, + ExpireAt: expire, + Key: req.HashKey(), + Value: &gubernator.TokenBucketItem{ + Limit: req.Limit, + Duration: req.Duration, + CreatedAt: now, + Remaining: req.Limit, + }, + } - // This request for the rate limit should ask the store via Get() and then - // tell the store about the change to the rate limit by calling OnChange() + store.On("Get", matchReq(req)).Once().Return(storedItem, true) + + store.On("OnChange", + matchReq(req), + mock.MatchedBy(func(item gubernator.CacheItem) bool { + titem, ok := item.Value.(*gubernator.TokenBucketItem) + if !ok { + return false + } + + return item.Algorithm == req.Algorithm && + item.Key == req.HashKey() && + titem.Limit == req.Limit && + titem.Duration == req.Duration && + titem.Remaining == expectedRemaining + }), + ). + Once() + + // Call code. resp, err := client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{ - Requests: []*gubernator.RateLimitReq{&req}, + Requests: []*gubernator.RateLimitReq{req}, }) - require.Nil(t, err) + require.NoError(t, err) require.NotNil(t, resp) - require.Equal(t, 1, len(resp.Responses)) - require.Equal(t, "", resp.Responses[0].Error) - assert.Equal(t, tt.firstRemaining, resp.Responses[0].Remaining) - assert.Equal(t, int64(10), resp.Responses[0].Limit) - assert.Equal(t, tt.firstStatus, resp.Responses[0].Status) - - // Should have called OnChange() and Get() - assert.Equal(t, 1, store.Called["OnChange()"]) - assert.Equal(t, 1, store.Called["Get()"]) - - // Should have updated the store - assert.Equal(t, tt.firstRemaining, getRemaining(store.CacheItems[req.HashKey()])) - - // Next call should not call `Get()` but only `OnChange()` - resp, err = client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{ - Requests: []*gubernator.RateLimitReq{&req}, + assert.Len(t, resp.Responses, 1) + assert.Equal(t, expectedRemaining, resp.Responses[0].Remaining) + assert.Equal(t, req.Limit, resp.Responses[0].Limit) + assert.Equal(t, gubernator.Status_UNDER_LIMIT, resp.Responses[0].Status) + store.AssertExpectations(t) + }) + + t.Run("Algorithm changed", func(t *testing.T) { + // Removes stored item, then creates new. + store, srv, client := setup() + defer tearDown(srv) + + req := &gubernator.RateLimitReq{ + Name: "test_over_limit", + UniqueKey: "account:1234", + Algorithm: gubernator.Algorithm_TOKEN_BUCKET, + Duration: gubernator.Second, + Limit: 10, + Hits: 1, + } + + // Setup mocks. + expectedRemaining := req.Limit - req.Hits + now := gubernator.MillisecondNow() + expire := now + req.Duration + storedItem := gubernator.CacheItem{ + Algorithm: req.Algorithm, + ExpireAt: expire, + Key: req.HashKey(), + Value: &gubernator.LeakyBucketItem{}, + } + + store.On("Get", matchReq(req)).Once().Return(storedItem, true) + store.On("Remove", req.HashKey()).Once() + + store.On("OnChange", + matchReq(req), + mock.MatchedBy(func(item gubernator.CacheItem) bool { + titem, ok := item.Value.(*gubernator.TokenBucketItem) + if !ok { + return false + } + + return item.Algorithm == req.Algorithm && + item.Key == req.HashKey() && + titem.Limit == req.Limit && + titem.Duration == req.Duration && + titem.Remaining == expectedRemaining + }), + ). + Once() + + // Call code. + resp, err := client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{ + Requests: []*gubernator.RateLimitReq{req}, }) - require.Nil(t, err) + require.NoError(t, err) require.NotNil(t, resp) - assert.Equal(t, tt.secondRemaining, resp.Responses[0].Remaining) - assert.Equal(t, int64(10), resp.Responses[0].Limit) - assert.Equal(t, tt.secondStatus, resp.Responses[0].Status) + assert.Len(t, resp.Responses, 1) + assert.Equal(t, expectedRemaining, resp.Responses[0].Remaining) + assert.Equal(t, req.Limit, resp.Responses[0].Limit) + assert.Equal(t, gubernator.Status_UNDER_LIMIT, resp.Responses[0].Status) + store.AssertExpectations(t) + }) - // Should have called OnChange() not Get() since rate limit is in the cache - assert.Equal(t, 2, store.Called["OnChange()"]) - assert.Equal(t, 1, store.Called["Get()"]) + t.Run("Duration changed", func(t *testing.T) { + // Updates expiration timestamp in store. + store, srv, client := setup() + defer tearDown(srv) - // Should have updated the store - assert.Equal(t, tt.secondRemaining, getRemaining(store.CacheItems[req.HashKey()])) + oldDuration := int64(5000) + newDuration := int64(8000) + req := &gubernator.RateLimitReq{ + Name: "test_over_limit", + UniqueKey: "account:1234", + Algorithm: gubernator.Algorithm_TOKEN_BUCKET, + Duration: newDuration, + Limit: 10, + Hits: 1, + } + + // Setup mocks. + expectedRemaining := req.Limit - req.Hits + now := gubernator.MillisecondNow() + oldExpire := now + oldDuration + storedItem := gubernator.CacheItem{ + Algorithm: req.Algorithm, + ExpireAt: oldExpire, + Key: req.HashKey(), + Value: &gubernator.TokenBucketItem{ + Limit: req.Limit, + Duration: oldDuration, + CreatedAt: now, + Remaining: req.Limit, + }, + } - // Should have called `Remove()` when algorithm changed - req.Algorithm = tt.switchAlgorithm - resp, err = client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{ - Requests: []*gubernator.RateLimitReq{&req}, + store.On("Get", matchReq(req)).Once().Return(storedItem, true) + + store.On("OnChange", + matchReq(req), + mock.MatchedBy(func(item gubernator.CacheItem) bool { + titem, ok := item.Value.(*gubernator.TokenBucketItem) + if !ok { + return false + } + + return item.Algorithm == req.Algorithm && + item.Key == req.HashKey() && + item.ExpireAt == titem.CreatedAt+newDuration && + titem.Limit == req.Limit && + titem.Duration == req.Duration && + titem.Remaining == expectedRemaining + }), + ). + Once() + + // Call code. + resp, err := client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{ + Requests: []*gubernator.RateLimitReq{req}, }) - require.Nil(t, err) + require.NoError(t, err) require.NotNil(t, resp) - assert.Equal(t, 1, store.Called["Remove()"]) - assert.Equal(t, 3, store.Called["OnChange()"]) - assert.Equal(t, 2, store.Called["Get()"]) + assert.Len(t, resp.Responses, 1) + assert.Equal(t, expectedRemaining, resp.Responses[0].Remaining) + assert.Equal(t, req.Limit, resp.Responses[0].Limit) + assert.Equal(t, gubernator.Status_UNDER_LIMIT, resp.Responses[0].Status) + store.AssertExpectations(t) + }) + + t.Run("Duration changed and immediately expired", func(t *testing.T) { + // Occurs when new duration is shorter and is immediately expired + // because CreatedAt + NewDuration < Now. + // Stores new item with renewed expiration and resets remaining. + store, srv, client := setup() + defer tearDown(srv) + + oldDuration := int64(500000) + newDuration := int64(8000) + req := &gubernator.RateLimitReq{ + Name: "test_over_limit", + UniqueKey: "account:1234", + Algorithm: gubernator.Algorithm_TOKEN_BUCKET, + Duration: newDuration, + Limit: 10, + Hits: 1, + } - assert.Equal(t, tt.switchAlgorithm, store.CacheItems[req.HashKey()].Algorithm) + // Setup mocks. + expectedRemaining := req.Limit - req.Hits + now := gubernator.MillisecondNow() + longTimeAgo := now - 100000 + oldExpire := longTimeAgo + oldDuration + oldRemaining := int64(1) + storedItem := gubernator.CacheItem{ + Algorithm: req.Algorithm, + ExpireAt: oldExpire, + Key: req.HashKey(), + Value: &gubernator.TokenBucketItem{ + Limit: req.Limit, + Duration: oldDuration, + CreatedAt: longTimeAgo, + Remaining: oldRemaining, + }, + } + + store.On("Get", matchReq(req)).Once().Return(storedItem, true) + + store.On("OnChange", + matchReq(req), + mock.MatchedBy(func(item gubernator.CacheItem) bool { + titem, ok := item.Value.(*gubernator.TokenBucketItem) + if !ok { + return false + } + + return item.Algorithm == req.Algorithm && + item.Key == req.HashKey() && + item.ExpireAt == titem.CreatedAt+newDuration && + titem.Limit == req.Limit && + titem.Duration == req.Duration && + titem.Remaining == expectedRemaining + }), + ). + Once() + + // Call code. + resp, err := client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{ + Requests: []*gubernator.RateLimitReq{req}, + }) + require.NoError(t, err) + require.NotNil(t, resp) + assert.Len(t, resp.Responses, 1) + assert.Equal(t, expectedRemaining, resp.Responses[0].Remaining) + assert.Equal(t, req.Limit, resp.Responses[0].Limit) + assert.Equal(t, gubernator.Status_UNDER_LIMIT, resp.Responses[0].Status) + store.AssertExpectations(t) }) - } + }) } func getRemaining(item gubernator.CacheItem) int64 {