diff --git a/algorithms.go b/algorithms.go index 26d3897c..7ea11d95 100644 --- a/algorithms.go +++ b/algorithms.go @@ -292,6 +292,10 @@ func leakyBucket(s Store, c Cache, r *RateLimitReq) (resp *RateLimitResp, err er duration = expire - (n.UnixNano() / 1000000) } + if r.Hits != 0 { + c.UpdateExpiration(r.HashKey(), now+duration) + } + // Calculate how much leaked out of the bucket since the last time we leaked a hit elapsed := now - b.UpdatedAt leak := float64(elapsed) / rate @@ -349,7 +353,6 @@ func leakyBucket(s Store, c Cache, r *RateLimitReq) (resp *RateLimitResp, err er b.Remaining -= float64(r.Hits) rl.Remaining = int64(b.Remaining) rl.ResetTime = now + (rl.Limit-rl.Remaining)*int64(rate) - c.UpdateExpiration(hashKey, now+duration) return rl, nil } diff --git a/functional_test.go b/functional_test.go index 14fce4b6..b7911a3e 100644 --- a/functional_test.go +++ b/functional_test.go @@ -365,6 +365,27 @@ func TestLeakyBucket(t *testing.T) { Hits: 0, Remaining: 10, Status: guber.Status_UNDER_LIMIT, + Sleep: clock.Second * 60, + }, + { + Name: "should use up the limit and wait until 1 second before duration period", + Hits: 10, + Remaining: 0, + Status: guber.Status_UNDER_LIMIT, + Sleep: clock.Second * 29, + }, + { + Name: "should use up all hits one second before duration period", + Hits: 9, + Remaining: 0, + Status: guber.Status_UNDER_LIMIT, + Sleep: clock.Second * 3, + }, + { + Name: "only have 1 hit remaining", + Hits: 1, + Remaining: 0, + Status: guber.Status_UNDER_LIMIT, Sleep: clock.Second, }, } @@ -391,7 +412,7 @@ func TestLeakyBucket(t *testing.T) { assert.Equal(t, test.Status, rl.Status) assert.Equal(t, test.Remaining, rl.Remaining) assert.Equal(t, int64(10), rl.Limit) - assert.Equal(t, clock.Now().Unix() + (rl.Limit - rl.Remaining) * 3, rl.ResetTime/1000) + assert.Equal(t, clock.Now().Unix()+(rl.Limit-rl.Remaining)*3, rl.ResetTime/1000) clock.Advance(test.Sleep) }) } @@ -498,7 +519,7 @@ func TestLeakyBucketWithBurst(t *testing.T) { assert.Equal(t, test.Status, rl.Status) assert.Equal(t, test.Remaining, rl.Remaining) assert.Equal(t, int64(10), rl.Limit) - assert.Equal(t, clock.Now().Unix() + (rl.Limit - rl.Remaining) * 3, rl.ResetTime/1000) + assert.Equal(t, clock.Now().Unix()+(rl.Limit-rl.Remaining)*3, rl.ResetTime/1000) clock.Advance(test.Sleep) }) }