Skip to content

Commit

Permalink
Merge pull request #976 from Fenny/master
Browse files Browse the repository at this point in the history
💼 implement Storage in cache
  • Loading branch information
Fenny authored Oct 28, 2020
2 parents 0d53f94 + 67eed4c commit 58cd7a9
Show file tree
Hide file tree
Showing 11 changed files with 526 additions and 123 deletions.
158 changes: 113 additions & 45 deletions middleware/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package cache
import (
"strconv"
"sync"
"sync/atomic"
"time"

"github.com/gofiber/fiber/v2"
Expand All @@ -26,13 +27,23 @@ type Config struct {
//
// Optional. Default: false
CacheControl bool

// Store is used to store the state of the middleware
//
// Default: an in memory store for this process only
Store fiber.Storage

// Internally used - if true, the simpler method of two maps is used in order to keep
// execution time down.
defaultStore bool
}

// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Expiration: 1 * time.Minute,
CacheControl: false,
defaultStore: true,
}

// cache is the manager to store the cached responses
Expand All @@ -42,14 +53,6 @@ type cache struct {
expiration int64
}

// entry defines the cached response
type entry struct {
body []byte
contentType []byte
statusCode int
expiration int64
}

// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
Expand All @@ -66,32 +69,48 @@ func New(config ...Config) fiber.Handler {
if int(cfg.Expiration.Seconds()) == 0 {
cfg.Expiration = ConfigDefault.Expiration
}
if cfg.Store == nil {
cfg.defaultStore = true
}
}

var (
// Cache settings
timestamp = uint64(time.Now().Unix())
expiration = uint64(cfg.Expiration.Seconds())
mux = &sync.RWMutex{}

// Default store logic (if no Store is provided)
entries = make(map[string]entry)
)

// Update timestamp every second
go func() {
for {
atomic.StoreUint64(&timestamp, uint64(time.Now().Unix()))
time.Sleep(1 * time.Second)
}
}()

// Nothing to cache
if int(cfg.Expiration.Seconds()) < 0 {
return func(c *fiber.Ctx) error {
return c.Next()
}
}

// Initialize db
db := &cache{
entries: make(map[string]entry),
expiration: int64(cfg.Expiration.Seconds()),
}
// Remove expired entries
go func() {
for {
// GC the entries every 10 seconds to avoid
// GC the entries every 10 seconds
time.Sleep(10 * time.Second)
db.Lock()
for k := range db.entries {
if time.Now().Unix() >= db.entries[k].expiration {
delete(db.entries, k)
mux.Lock()
for k := range entries {
if atomic.LoadUint64(&timestamp) >= entries[k].exp {
delete(entries, k)
}
}
db.Unlock()
mux.Unlock()
}
}()

Expand All @@ -110,28 +129,65 @@ func New(config ...Config) fiber.Handler {
// Get key from request
key := c.Path()

// Find cached entry
db.RLock()
resp, ok := db.entries[key]
db.RUnlock()
if ok {
// Create new entry
entry := entry{}

// Lock entry
mux.Lock()
defer mux.Unlock()

// Check if we need to use the default in-memory storage
if cfg.defaultStore {
entry = entries[key]

} else {
// Load data from store
storeEntry, err := cfg.Store.Get(key)
if err != nil {
return err
}

// Only decode if we found an entry
if len(storeEntry) > 0 {
// Decode bytes using msgp
if _, err := entry.UnmarshalMsg(storeEntry); err != nil {
return err
}
}
}

// Get timestamp
ts := atomic.LoadUint64(&timestamp)

// Set expiration if entry does not exist
if entry.exp == 0 {
entry.exp = ts + expiration

} else if ts >= entry.exp {
// Check if entry is expired
if time.Now().Unix() >= resp.expiration {
db.Lock()
delete(db.entries, key)
db.Unlock()
} else {
// Set response headers from cache
c.Response().SetBodyRaw(resp.body)
c.Response().SetStatusCode(resp.statusCode)
c.Response().Header.SetContentTypeBytes(resp.contentType)
// Set Cache-Control header if enabled
if cfg.CacheControl {
maxAge := strconv.FormatInt(resp.expiration-time.Now().Unix(), 10)
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
// Use default memory storage
if cfg.defaultStore {
delete(entries, key)
} else { // Use custom storage
if err := cfg.Store.Delete(key); err != nil {
return err
}
return nil
}

} else {
// Set response headers from cache
c.Response().SetBodyRaw(entry.body)
c.Response().SetStatusCode(entry.status)
c.Response().Header.SetContentTypeBytes(entry.cType)

// Set Cache-Control header if enabled
if cfg.CacheControl {
maxAge := strconv.FormatUint(entry.exp-ts, 10)
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
}

// Return response
return nil
}

// Continue stack, return err to Fiber if exist
Expand All @@ -140,14 +196,26 @@ func New(config ...Config) fiber.Handler {
}

// Cache response
db.Lock()
db.entries[key] = entry{
body: c.Response().Body(),
statusCode: c.Response().StatusCode(),
contentType: c.Response().Header.ContentType(),
expiration: time.Now().Unix() + db.expiration,
entry.body = c.Response().Body()
entry.status = c.Response().StatusCode()
entry.cType = c.Response().Header.ContentType()

// Use default memory storage
if cfg.defaultStore {
entries[key] = entry

} else {
// Use custom storage
data, err := entry.MarshalMsg(nil)
if err != nil {
return err
}

// Pass bytes to Storage
if err = cfg.Store.Set(key, data, cfg.Expiration); err != nil {
return err
}
}
db.Unlock()

// Finish response
return nil
Expand Down
84 changes: 84 additions & 0 deletions middleware/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -91,6 +93,55 @@ func Test_Cache(t *testing.T) {
utils.AssertEqual(t, cachedBody, body)
}

// go test -run Test_Cache_Concurrency_Store -race -v
func Test_Cache_Concurrency_Store(t *testing.T) {
// Test concurrency using a custom store

app := fiber.New()

app.Use(New(Config{
Store: testStore{stmap: map[string][]byte{}, mutex: new(sync.Mutex)},
}))

app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello tester!")
})

var wg sync.WaitGroup
singleRequest := func(wg *sync.WaitGroup) {
defer wg.Done()
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)

body, err := ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "Hello tester!", string(body))
}

for i := 0; i <= 49; i++ {
wg.Add(1)
go singleRequest(&wg)
}

wg.Wait()

req := httptest.NewRequest("GET", "/", nil)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)

cachedReq := httptest.NewRequest("GET", "/", nil)
cachedResp, err := app.Test(cachedReq)
utils.AssertEqual(t, nil, err)

body, err := ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
cachedBody, err := ioutil.ReadAll(cachedResp.Body)
utils.AssertEqual(t, nil, err)

utils.AssertEqual(t, cachedBody, body)
}

func Test_Cache_Invalid_Expiration(t *testing.T) {
app := fiber.New()
cache := New(Config{Expiration: 0 * time.Second})
Expand Down Expand Up @@ -208,3 +259,36 @@ func Benchmark_Cache(b *testing.B) {

utils.AssertEqual(b, fiber.StatusOK, fctx.Response.Header.StatusCode())
}

// testStore is used for testing custom stores
type testStore struct {
stmap map[string][]byte
mutex *sync.Mutex
}

func (s testStore) Get(id string) ([]byte, error) {
s.mutex.Lock()
val, ok := s.stmap[id]
s.mutex.Unlock()
if !ok {
return []byte{}, nil
} else {
return val, nil
}
}

func (s testStore) Set(id string, val []byte, _ time.Duration) error {
s.mutex.Lock()
s.stmap[id] = val
s.mutex.Unlock()

return nil
}

func (s testStore) Clear() error {
return nil
}

func (s testStore) Delete(id string) error {
return nil
}
12 changes: 12 additions & 0 deletions middleware/cache/store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package cache

// go:generate msgp
// msgp -file="store.go" -o="store_msgp.go" -tests=false -unexported
// don't forget to replace the msgp import path to:
// "github.com/gofiber/fiber/v2/internal/msgp"
type entry struct {
body []byte `msg:"body"`
cType []byte `msg:"cType"`
status int `msg:"status"`
exp uint64 `msg:"exp"`
}
Loading

0 comments on commit 58cd7a9

Please sign in to comment.