Skip to content

Commit bb9ac8f

Browse files
Pinank SolankiReneWerner87
Pinank Solanki
andauthored
🐛 Fix expiration time in cache middleware (#1881)
* 🐛 Fix: Expiration time in cache middleware * Custom expiration time using ExpirationGenerator is also functional now instead of default Expiration only * 🚨 Improve Test_CustomExpiration * - stabilization of the tests - speed up the cache tests - fix race conditions in client and client tests Co-authored-by: wernerr <[email protected]>
1 parent 2326297 commit bb9ac8f

File tree

3 files changed

+73
-17
lines changed

3 files changed

+73
-17
lines changed

client.go

+9-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package fiber
33
import (
44
"bytes"
55
"crypto/tls"
6+
"encoding/json"
67
"encoding/xml"
78
"fmt"
89
"io"
@@ -16,8 +17,6 @@ import (
1617
"sync"
1718
"time"
1819

19-
"encoding/json"
20-
2120
"github.com/gofiber/fiber/v2/utils"
2221
"github.com/valyala/fasthttp"
2322
)
@@ -60,6 +59,7 @@ var defaultClient Client
6059
//
6160
// It is safe calling Client methods from concurrently running goroutines.
6261
type Client struct {
62+
mutex sync.RWMutex
6363
// UserAgent is used in User-Agent request header.
6464
UserAgent string
6565

@@ -133,10 +133,15 @@ func (c *Client) createAgent(method, url string) *Agent {
133133
a.req.Header.SetMethod(method)
134134
a.req.SetRequestURI(url)
135135

136+
c.mutex.RLock()
136137
a.Name = c.UserAgent
137138
a.NoDefaultUserAgentHeader = c.NoDefaultUserAgentHeader
138139
a.jsonDecoder = c.JSONDecoder
139140
a.jsonEncoder = c.JSONEncoder
141+
if a.jsonDecoder == nil {
142+
a.jsonDecoder = json.Unmarshal
143+
}
144+
c.mutex.RUnlock()
140145

141146
if err := a.Parse(); err != nil {
142147
a.errs = append(a.errs, err)
@@ -810,10 +815,6 @@ func (a *Agent) String() (int, string, []error) {
810815
// Struct returns the status code, bytes body and errors of url.
811816
// And bytes body will be unmarshalled to given v.
812817
func (a *Agent) Struct(v interface{}) (code int, body []byte, errs []error) {
813-
if a.jsonDecoder == nil {
814-
a.jsonDecoder = json.Unmarshal
815-
}
816-
817818
if code, body, errs = a.Bytes(); len(errs) > 0 {
818819
return
819820
}
@@ -886,6 +887,8 @@ func AcquireClient() *Client {
886887
func ReleaseClient(c *Client) {
887888
c.UserAgent = ""
888889
c.NoDefaultUserAgentHeader = false
890+
c.JSONEncoder = nil
891+
c.JSONDecoder = nil
889892

890893
clientPool.Put(c)
891894
}

middleware/cache/cache.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -168,23 +168,23 @@ func New(config ...Config) fiber.Handler {
168168
}
169169

170170
// default cache expiration
171-
expiration := uint64(cfg.Expiration.Seconds())
171+
expiration := cfg.Expiration
172172
// Calculate expiration by response header or other setting
173173
if cfg.ExpirationGenerator != nil {
174-
expiration = uint64(cfg.ExpirationGenerator(c, &cfg).Seconds())
174+
expiration = cfg.ExpirationGenerator(c, &cfg)
175175
}
176-
e.exp = ts + expiration
176+
e.exp = ts + uint64(expiration.Seconds())
177177

178178
// For external Storage we store raw body separated
179179
if cfg.Storage != nil {
180-
manager.setRaw(key+"_body", e.body, cfg.Expiration)
180+
manager.setRaw(key+"_body", e.body, expiration)
181181
// avoid body msgp encoding
182182
e.body = nil
183-
manager.set(key, e, cfg.Expiration)
183+
manager.set(key, e, expiration)
184184
manager.release(e)
185185
} else {
186186
// Store entry in memory
187-
manager.set(key, e, cfg.Expiration)
187+
manager.set(key, e, expiration)
188188
}
189189

190190
c.Set(cfg.CacheHeader, cacheMiss)

middleware/cache/cache_test.go

+58-5
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ import (
1919
)
2020

2121
func Test_Cache_CacheControl(t *testing.T) {
22+
t.Parallel()
23+
2224
app := fiber.New()
2325

2426
app.Use(New(Config{
@@ -77,6 +79,8 @@ func Test_Cache_Expired(t *testing.T) {
7779
}
7880

7981
func Test_Cache(t *testing.T) {
82+
t.Parallel()
83+
8084
app := fiber.New()
8185
app.Use(New())
8286

@@ -102,6 +106,8 @@ func Test_Cache(t *testing.T) {
102106
}
103107

104108
func Test_Cache_WithSeveralRequests(t *testing.T) {
109+
t.Parallel()
110+
105111
app := fiber.New()
106112

107113
app.Use(New(Config{
@@ -135,6 +141,8 @@ func Test_Cache_WithSeveralRequests(t *testing.T) {
135141
}
136142

137143
func Test_Cache_Invalid_Expiration(t *testing.T) {
144+
t.Parallel()
145+
138146
app := fiber.New()
139147
cache := New(Config{Expiration: 0 * time.Second})
140148
app.Use(cache)
@@ -161,6 +169,8 @@ func Test_Cache_Invalid_Expiration(t *testing.T) {
161169
}
162170

163171
func Test_Cache_Invalid_Method(t *testing.T) {
172+
t.Parallel()
173+
164174
app := fiber.New()
165175

166176
app.Use(New())
@@ -199,6 +209,8 @@ func Test_Cache_Invalid_Method(t *testing.T) {
199209
}
200210

201211
func Test_Cache_NothingToCache(t *testing.T) {
212+
t.Parallel()
213+
202214
app := fiber.New()
203215

204216
app.Use(New(Config{Expiration: -(time.Second * 1)}))
@@ -225,6 +237,8 @@ func Test_Cache_NothingToCache(t *testing.T) {
225237
}
226238

227239
func Test_Cache_CustomNext(t *testing.T) {
240+
t.Parallel()
241+
228242
app := fiber.New()
229243

230244
app.Use(New(Config{
@@ -263,6 +277,8 @@ func Test_Cache_CustomNext(t *testing.T) {
263277
}
264278

265279
func Test_CustomKey(t *testing.T) {
280+
t.Parallel()
281+
266282
app := fiber.New()
267283
var called bool
268284
app.Use(New(Config{KeyGenerator: func(c *fiber.Ctx) string {
@@ -281,6 +297,8 @@ func Test_CustomKey(t *testing.T) {
281297
}
282298

283299
func Test_CustomExpiration(t *testing.T) {
300+
t.Parallel()
301+
284302
app := fiber.New()
285303
var called bool
286304
var newCacheTime int
@@ -291,18 +309,45 @@ func Test_CustomExpiration(t *testing.T) {
291309
}}))
292310

293311
app.Get("/", func(c *fiber.Ctx) error {
294-
c.Response().Header.Add("Cache-Time", "6000")
295-
return c.SendString("hi")
312+
c.Response().Header.Add("Cache-Time", "1")
313+
now := fmt.Sprintf("%d", time.Now().UnixNano())
314+
return c.SendString(now)
296315
})
297316

298-
req := httptest.NewRequest("GET", "/", nil)
299-
_, err := app.Test(req)
317+
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
300318
utils.AssertEqual(t, nil, err)
301319
utils.AssertEqual(t, true, called)
302-
utils.AssertEqual(t, 6000, newCacheTime)
320+
utils.AssertEqual(t, 1, newCacheTime)
321+
322+
// Sleep until the cache is expired
323+
time.Sleep(1 * time.Second)
324+
325+
cachedResp, err := app.Test(httptest.NewRequest("GET", "/", nil))
326+
utils.AssertEqual(t, nil, err)
327+
328+
body, err := ioutil.ReadAll(resp.Body)
329+
utils.AssertEqual(t, nil, err)
330+
cachedBody, err := ioutil.ReadAll(cachedResp.Body)
331+
utils.AssertEqual(t, nil, err)
332+
333+
if bytes.Equal(body, cachedBody) {
334+
t.Errorf("Cache should have expired: %s, %s", body, cachedBody)
335+
}
336+
337+
// Next response should be cached
338+
cachedRespNextRound, err := app.Test(httptest.NewRequest("GET", "/", nil))
339+
utils.AssertEqual(t, nil, err)
340+
cachedBodyNextRound, err := ioutil.ReadAll(cachedRespNextRound.Body)
341+
utils.AssertEqual(t, nil, err)
342+
343+
if !bytes.Equal(cachedBodyNextRound, cachedBody) {
344+
t.Errorf("Cache should not have expired: %s, %s", cachedBodyNextRound, cachedBody)
345+
}
303346
}
304347

305348
func Test_AdditionalE2EResponseHeaders(t *testing.T) {
349+
t.Parallel()
350+
306351
app := fiber.New()
307352
app.Use(New(Config{
308353
StoreResponseHeaders: true,
@@ -325,6 +370,8 @@ func Test_AdditionalE2EResponseHeaders(t *testing.T) {
325370
}
326371

327372
func Test_CacheHeader(t *testing.T) {
373+
t.Parallel()
374+
328375
app := fiber.New()
329376

330377
app.Use(New(Config{
@@ -364,6 +411,8 @@ func Test_CacheHeader(t *testing.T) {
364411
}
365412

366413
func Test_Cache_WithHead(t *testing.T) {
414+
t.Parallel()
415+
367416
app := fiber.New()
368417
app.Use(New())
369418

@@ -389,6 +438,8 @@ func Test_Cache_WithHead(t *testing.T) {
389438
}
390439

391440
func Test_Cache_WithHeadThenGet(t *testing.T) {
441+
t.Parallel()
442+
392443
app := fiber.New()
393444
app.Use(New())
394445
app.Get("/", func(c *fiber.Ctx) error {
@@ -425,6 +476,8 @@ func Test_Cache_WithHeadThenGet(t *testing.T) {
425476
}
426477

427478
func Test_CustomCacheHeader(t *testing.T) {
479+
t.Parallel()
480+
428481
app := fiber.New()
429482

430483
app.Use(New(Config{

0 commit comments

Comments
 (0)