Skip to content
This repository has been archived by the owner on Apr 19, 2024. It is now read-only.

Commit

Permalink
PIP-1490: Add context to Store methods.
Browse files Browse the repository at this point in the history
  • Loading branch information
Baliedge committed Jan 19, 2022
1 parent 7d5c909 commit 9a0f98b
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 37 deletions.
18 changes: 9 additions & 9 deletions algorithms.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *
if s != nil && !ok {
// Cache miss.
// Check our store for the item.
if item, ok = s.Get(r); ok {
if item, ok = s.Get(ctx, r); ok {
tracing.LogInfo(span, "Check store for rate limit")
c.Add(item)
tracing.LogInfo(span, "c.Add()")
Expand Down Expand Up @@ -79,7 +79,7 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *
tracing.LogInfo(span, "c.Remove()")

if s != nil {
s.Remove(hashKey)
s.Remove(ctx, hashKey)
tracing.LogInfo(span, "s.Remove()")
}
return &RateLimitResp{
Expand All @@ -104,7 +104,7 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *
tracing.LogInfo(span, "c.Remove()")

if s != nil {
s.Remove(hashKey)
s.Remove(ctx, hashKey)
tracing.LogInfo(span, "s.Remove()")
}

Expand Down Expand Up @@ -157,7 +157,7 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *

if s != nil {
defer func() {
s.OnChange(r, item)
s.OnChange(ctx, r, item)
tracing.LogInfo(span, "defer s.OnChange()")
}()
}
Expand Down Expand Up @@ -255,7 +255,7 @@ func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq)
tracing.LogInfo(span, "c.Add()")

if s != nil {
s.OnChange(r, item)
s.OnChange(ctx, r, item)
tracing.LogInfo(span, "s.OnChange()")
}

Expand Down Expand Up @@ -283,7 +283,7 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *
if s != nil && !ok {
// Cache miss.
// Check our store for the item.
if item, ok = s.Get(r); ok {
if item, ok = s.Get(ctx, r); ok {
tracing.LogInfo(span, "Check store for rate limit")
c.Add(item)
tracing.LogInfo(span, "c.Add()")
Expand Down Expand Up @@ -324,7 +324,7 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *
tracing.LogInfo(span, "c.Remove()")

if s != nil {
s.Remove(hashKey)
s.Remove(ctx, hashKey)
tracing.LogInfo(span, "s.Remove()")
}

Expand Down Expand Up @@ -395,7 +395,7 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *

if s != nil {
defer func() {
s.OnChange(r, item)
s.OnChange(ctx, r, item)
tracing.LogInfo(span, "s.OnChange()")
}()
}
Expand Down Expand Up @@ -492,7 +492,7 @@ func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq)
tracing.LogInfo(span, "c.Add()")

if s != nil {
s.OnChange(r, item)
s.OnChange(ctx, r, item)
tracing.LogInfo(span, "s.OnChange()")
}

Expand Down
14 changes: 8 additions & 6 deletions mock_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package gubernator_test
// Mock implementation of Store.

import (
"context"

guber "github.com/mailgun/gubernator/v2"
"github.com/stretchr/testify/mock"
)
Expand All @@ -29,19 +31,19 @@ type MockStore2 struct {

var _ guber.Store = &MockStore2{}

func (m *MockStore2) OnChange(r *guber.RateLimitReq, item *guber.CacheItem) {
m.Called(r, item)
func (m *MockStore2) OnChange(ctx context.Context, r *guber.RateLimitReq, item *guber.CacheItem) {
m.Called(ctx, r, item)
}

func (m *MockStore2) Get(r *guber.RateLimitReq) (*guber.CacheItem, bool) {
args := m.Called(r)
func (m *MockStore2) Get(ctx context.Context, r *guber.RateLimitReq) (*guber.CacheItem, bool) {
args := m.Called(ctx, r)
var retval *guber.CacheItem
if retval2, ok := args.Get(0).(*guber.CacheItem); ok {
retval = retval2
}
return retval, args.Bool(1)
}

func (m *MockStore2) Remove(key string) {
m.Called(key)
func (m *MockStore2) Remove(ctx context.Context, key string) {
m.Called(ctx, key)
}
14 changes: 8 additions & 6 deletions store.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ limitations under the License.

package gubernator

import "context"

// PERSISTENT STORE DETAILS

// The storage interfaces defined here allows the implementor flexibility in storage options. Depending on the
Expand Down Expand Up @@ -49,17 +51,17 @@ type Store interface {
// decide if this rate limit item should be persisted in the store. It's up to the
// store to expire old rate limit items. The CacheItem represents the current state of
// the rate limit item *after* the RateLimitReq has been applied.
OnChange(r *RateLimitReq, item *CacheItem)
OnChange(ctx context.Context, r *RateLimitReq, item *CacheItem)

// Called by gubernator when a rate limit is missing from the cache. It's up to the store
// to decide if this request is fulfilled. Should return true if the request is fulfilled
// and false if the request is not fulfilled or doesn't exist in the store.
Get(r *RateLimitReq) (*CacheItem, bool)
Get(ctx context.Context, r *RateLimitReq) (*CacheItem, bool)

// Called by gubernator when an existing rate limit should be removed from the store.
// NOTE: This is NOT called when an rate limit expires from the cache, store implementors
// must expire rate limits in the store.
Remove(key string)
Remove(ctx context.Context, key string)
}

// Loader interface allows implementors to store all or a subset of ratelimits into a persistent
Expand Down Expand Up @@ -93,18 +95,18 @@ type MockStore struct {

var _ Store = &MockStore{}

func (ms *MockStore) OnChange(r *RateLimitReq, item *CacheItem) {
func (ms *MockStore) OnChange(ctx context.Context, r *RateLimitReq, item *CacheItem) {
ms.Called["OnChange()"] += 1
ms.CacheItems[item.Key] = item
}

func (ms *MockStore) Get(r *RateLimitReq) (*CacheItem, bool) {
func (ms *MockStore) Get(ctx context.Context, r *RateLimitReq) (*CacheItem, bool) {
ms.Called["Get()"] += 1
item, ok := ms.CacheItems[r.HashKey()]
return item, ok
}

func (ms *MockStore) Remove(key string) {
func (ms *MockStore) Remove(ctx context.Context, key string) {
ms.Called["Remove()"] += 1
delete(ms.CacheItems, key)
}
Expand Down
35 changes: 19 additions & 16 deletions store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ func TestLoader(t *testing.T) {
}

func TestStore(t *testing.T) {
ctx := context.Background()
setup := func() (*MockStore2, *v1Server, gubernator.V1Client) {
store := &MockStore2{}

Expand Down Expand Up @@ -240,11 +241,11 @@ func TestStore(t *testing.T) {
}

// Setup mocks.
store.On("Get", matchReq(req)).Once().Return(nil, false)
store.On("OnChange", matchReq(req), matchItem(req)).Once()
store.On("Get", mock.Anything, matchReq(req)).Once().Return(nil, false)
store.On("OnChange", mock.Anything, matchReq(req), matchItem(req)).Once()

// Call code.
resp, err := client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{
resp, err := client.GetRateLimits(ctx, &gubernator.GetRateLimitsReq{
Requests: []*gubernator.RateLimitReq{req},
})
require.NoError(t, err)
Expand All @@ -256,10 +257,10 @@ func TestStore(t *testing.T) {

t.Run("Second rate check pulls from cache", func(t *testing.T) {
// Setup mocks.
store.On("OnChange", matchReq(req), matchItem(req)).Once()
store.On("OnChange", mock.Anything, matchReq(req), matchItem(req)).Once()

// Call code.
resp, err := client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{
resp, err := client.GetRateLimits(ctx, &gubernator.GetRateLimitsReq{
Requests: []*gubernator.RateLimitReq{req},
})
require.NoError(t, err)
Expand Down Expand Up @@ -294,11 +295,11 @@ func TestStore(t *testing.T) {
Value: createBucketItem(req),
}

store.On("Get", matchReq(req)).Once().Return(storedItem, true)
store.On("OnChange", matchReq(req), matchItem(req)).Once()
store.On("Get", mock.Anything, matchReq(req)).Once().Return(storedItem, true)
store.On("OnChange", mock.Anything, matchReq(req), matchItem(req)).Once()

// Call code.
resp, err := client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{
resp, err := client.GetRateLimits(ctx, &gubernator.GetRateLimitsReq{
Requests: []*gubernator.RateLimitReq{req},
})
require.NoError(t, err)
Expand Down Expand Up @@ -333,12 +334,12 @@ func TestStore(t *testing.T) {
Value: &struct{}{},
}

store.On("Get", matchReq(req)).Once().Return(storedItem, true)
store.On("Remove", req.HashKey()).Once()
store.On("OnChange", matchReq(req), matchItem(req)).Once()
store.On("Get", mock.Anything, matchReq(req)).Once().Return(storedItem, true)
store.On("Remove", mock.Anything, req.HashKey()).Once()
store.On("OnChange", mock.Anything, matchReq(req), matchItem(req)).Once()

// Call code.
resp, err := client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{
resp, err := client.GetRateLimits(ctx, &gubernator.GetRateLimitsReq{
Requests: []*gubernator.RateLimitReq{req},
})
require.NoError(t, err)
Expand Down Expand Up @@ -388,9 +389,10 @@ func TestStore(t *testing.T) {
Value: bucketItem,
}

store.On("Get", matchReq(req)).Once().Return(storedItem, true)
store.On("Get", mock.Anything, matchReq(req)).Once().Return(storedItem, true)

store.On("OnChange",
mock.Anything,
matchReq(req),
mock.MatchedBy(func(item *gubernator.CacheItem) bool {
switch req.Algorithm {
Expand Down Expand Up @@ -427,7 +429,7 @@ func TestStore(t *testing.T) {
Once()

// Call code.
resp, err := client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{
resp, err := client.GetRateLimits(ctx, &gubernator.GetRateLimitsReq{
Requests: []*gubernator.RateLimitReq{req},
})
require.NoError(t, err)
Expand Down Expand Up @@ -477,9 +479,10 @@ func TestStore(t *testing.T) {
Value: bucketItem,
}

store.On("Get", matchReq(req)).Once().Return(storedItem, true)
store.On("Get", mock.Anything, matchReq(req)).Once().Return(storedItem, true)

store.On("OnChange",
mock.Anything,
matchReq(req),
mock.MatchedBy(func(item *gubernator.CacheItem) bool {
switch req.Algorithm {
Expand Down Expand Up @@ -516,7 +519,7 @@ func TestStore(t *testing.T) {
Once()

// Call code.
resp, err := client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{
resp, err := client.GetRateLimits(ctx, &gubernator.GetRateLimitsReq{
Requests: []*gubernator.RateLimitReq{req},
})
require.NoError(t, err)
Expand Down

0 comments on commit 9a0f98b

Please sign in to comment.