diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS deleted file mode 100644 index b83c5cdd26..0000000000 --- a/.github/CODEOWNERS +++ /dev/null @@ -1,4 +0,0 @@ -* @gofiber/maintainers - -middleware/filesystem/* @arsmn -middleware/cache/* @codemicro \ No newline at end of file diff --git a/app.go b/app.go index be4f71833a..74ddaf2b19 100644 --- a/app.go +++ b/app.go @@ -31,7 +31,7 @@ import ( ) // Version of current fiber package -const Version = "2.1.4" +const Version = "2.2.0" // Handler defines a function to serve HTTP requests. type Handler = func(*Ctx) error @@ -258,7 +258,7 @@ type Config struct { // Default: false ReduceMemoryUsage bool `json:"reduce_memory_usage"` - // FEATURE: v2.2.x + // FEATURE: v2.3.x // The router executes the same handler by default if StrictRouting or CaseSensitive is disabled. // Enabling RedirectFixedPath will change this behaviour into a client redirect to the original route path. // Using the status code 301 for GET requests and 308 for all other request methods. @@ -286,6 +286,12 @@ type Static struct { // Optional. Default value "index.html". Index string `json:"index"` + // Expiration duration for inactive file handlers. + // Use a negative time.Duration to disable it. + // + // Optional. Default value 10 * time.Second. + CacheDuration time.Duration `json:"cache_duration"` + // The value for the Cache-Control HTTP-header // that is set on the file response. MaxAge is defined in seconds. // diff --git a/ctx_test.go b/ctx_test.go index c0f1c49741..cd22f51b4c 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -934,6 +934,33 @@ func Test_Ctx_MultipartForm(t *testing.T) { utils.AssertEqual(t, StatusOK, resp.StatusCode, "Status code") } +// go test -v -run=^$ -bench=Benchmark_Ctx_MultipartForm -benchmem -count=4 +func Benchmark_Ctx_MultipartForm(b *testing.B) { + app := New() + + app.Post("/", func(c *Ctx) error { + _, _ = c.MultipartForm() + return nil + }) + + c := &fasthttp.RequestCtx{} + + body := []byte("--b\r\nContent-Disposition: form-data; name=\"name\"\r\n\r\njohn\r\n--b--") + c.Request.SetBody(body) + c.Request.Header.SetContentType(MIMEMultipartForm + `;boundary="b"`) + c.Request.Header.SetContentLength(len(body)) + + h := app.Handler() + + b.ReportAllocs() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + h(c) + } + +} + // go test -run Test_Ctx_OriginalURL func Test_Ctx_OriginalURL(t *testing.T) { t.Parallel() diff --git a/go.mod b/go.mod index fac2a50da4..6aa61db766 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/gofiber/fiber/v2 go 1.14 require ( + github.com/klauspost/compress v1.11.0 // indirect github.com/valyala/fasthttp v1.17.0 golang.org/x/sys v0.0.0-20201101102859-da207088b7d1 ) diff --git a/go.sum b/go.sum index 80e70e25d0..437f32b44c 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/andybalholm/brotli v1.0.0 h1:7UCwP93aiSfvWpapti8g88vVVGp2qqtGyePsSuDa github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y= github.com/klauspost/compress v1.10.7 h1:7rix8v8GpI3ZBb0nSozFRgbtXKv+hOe+qfEpZqybrAg= github.com/klauspost/compress v1.10.7/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= +github.com/klauspost/compress v1.11.0 h1:wJbzvpYMVGG9iTI9VxpnNZfd4DzMPoCWze3GgSqz8yg= +github.com/klauspost/compress v1.11.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.17.0 h1:P8/koH4aSnJ4xbd0cUUFEGQs3jQqIxoDDyRQrUiAkqg= diff --git a/helpers_test.go b/helpers_test.go index 5b19869492..9fe4e6f603 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -8,6 +8,7 @@ import ( "crypto/tls" "fmt" "net" + "strings" "testing" "time" @@ -15,30 +16,68 @@ import ( "github.com/valyala/fasthttp" ) -// go test -v -run=^$ -bench=Benchmark_Utils_RemoveNewLines -benchmem -count=4 -func Benchmark_Utils_RemoveNewLines(b *testing.B) { +// go test -v -run=^$ -bench=Benchmark_RemoveNewLines -benchmem -count=4 +func Benchmark_RemoveNewLines(b *testing.B) { withNL := "foo\r\nSet-Cookie:%20SESSIONID=MaliciousValue\r\n" withoutNL := "foo Set-Cookie:%20SESSIONID=MaliciousValue " expected := utils.SafeString(withoutNL) var res string - b.Run("withNewlines", func(b *testing.B) { + b.Run("withoutNL", func(b *testing.B) { b.ReportAllocs() b.ResetTimer() for n := 0; n < b.N; n++ { - res = removeNewLines(withNL) + res = removeNewLines(withoutNL) } utils.AssertEqual(b, expected, res) }) - b.Run("withoutNewlines", func(b *testing.B) { + b.Run("withNL", func(b *testing.B) { b.ReportAllocs() b.ResetTimer() for n := 0; n < b.N; n++ { - res = removeNewLines(withoutNL) + res = removeNewLines(withNL) } utils.AssertEqual(b, expected, res) }) +} +// go test -v -run=RemoveNewLines_Bytes -count=3 +func Test_RemoveNewLines_Bytes(t *testing.T) { + app := New() + t.Run("Not Status OK", func(t *testing.T) { + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + c.SendString("Hello, World!") + c.Status(201) + setETag(c, false) + utils.AssertEqual(t, "", string(c.Response().Header.Peek(HeaderETag))) + }) + + t.Run("No Body", func(t *testing.T) { + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + setETag(c, false) + utils.AssertEqual(t, "", string(c.Response().Header.Peek(HeaderETag))) + }) + + t.Run("Has HeaderIfNoneMatch", func(t *testing.T) { + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + c.SendString("Hello, World!") + c.Request().Header.Set(HeaderIfNoneMatch, `"13-1831710635"`) + setETag(c, false) + utils.AssertEqual(t, 304, c.Response().StatusCode()) + utils.AssertEqual(t, "", string(c.Response().Header.Peek(HeaderETag))) + utils.AssertEqual(t, "", string(c.Response().Body())) + }) + + t.Run("No HeaderIfNoneMatch", func(t *testing.T) { + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + c.SendString("Hello, World!") + setETag(c, false) + utils.AssertEqual(t, `"13-1831710635"`, string(c.Response().Header.Peek(HeaderETag))) + }) } // go test -v -run=Test_Utils_ -count=3 @@ -320,3 +359,49 @@ func Test_Utils_lnMetadata(t *testing.T) { utils.AssertEqual(t, true, config != nil) }) } + +// go test -v -run=^$ -bench=Benchmark_SlashRecognition -benchmem -count=4 +func Benchmark_SlashRecognition(b *testing.B) { + search := "wtf/1234" + var result bool + b.Run("indexBytes", func(b *testing.B) { + result = false + for i := 0; i < b.N; i++ { + if strings.IndexByte(search, slashDelimiter) != -1 { + result = true + } + } + utils.AssertEqual(b, true, result) + }) + b.Run("forEach", func(b *testing.B) { + result = false + c := int32(slashDelimiter) + for i := 0; i < b.N; i++ { + for _, b := range search { + if b == c { + result = true + break + } + } + } + utils.AssertEqual(b, true, result) + }) + b.Run("IndexRune", func(b *testing.B) { + result = false + c := int32(slashDelimiter) + for i := 0; i < b.N; i++ { + result = IndexRune(search, c) + + } + utils.AssertEqual(b, true, result) + }) +} + +func IndexRune(str string, needle int32) bool { + for _, b := range str { + if b == needle { + return true + } + } + return false +} diff --git a/internal/storage/memory/config.go b/internal/storage/memory/config.go new file mode 100644 index 0000000000..07d13edb5b --- /dev/null +++ b/internal/storage/memory/config.go @@ -0,0 +1,33 @@ +package memory + +import "time" + +// Config defines the config for storage. +type Config struct { + // Time before deleting expired keys + // + // Default is 10 * time.Second + GCInterval time.Duration +} + +// ConfigDefault is the default config +var ConfigDefault = Config{ + GCInterval: 10 * time.Second, +} + +// configDefault is a helper function to set default values +func configDefault(config ...Config) Config { + // Return default config if nothing provided + if len(config) < 1 { + return ConfigDefault + } + + // Override default config + cfg := config[0] + + // Set default values + if int(cfg.GCInterval.Seconds()) <= 0 { + cfg.GCInterval = ConfigDefault.GCInterval + } + return cfg +} diff --git a/internal/storage/memory/memory.go b/internal/storage/memory/memory.go new file mode 100644 index 0000000000..8490a52855 --- /dev/null +++ b/internal/storage/memory/memory.go @@ -0,0 +1,122 @@ +package memory + +import ( + "errors" + "sync" + "time" +) + +// Storage interface that is implemented by storage providers +type Storage struct { + mux sync.RWMutex + db map[string]entry + gcInterval time.Duration + done chan struct{} +} + +// Common storage errors +var ErrNotExist = errors.New("key does not exist") + +type entry struct { + data []byte + expiry int64 +} + +// New creates a new memory storage +func New(config ...Config) *Storage { + // Set default config + cfg := configDefault(config...) + + // Create storage + store := &Storage{ + db: make(map[string]entry), + gcInterval: cfg.GCInterval, + done: make(chan struct{}), + } + + // Start garbage collector + go store.gc() + + return store +} + +// Get value by key +func (s *Storage) Get(key string) ([]byte, error) { + if len(key) <= 0 { + return nil, ErrNotExist + } + s.mux.RLock() + v, ok := s.db[key] + s.mux.RUnlock() + if !ok || v.expiry != 0 && v.expiry <= time.Now().Unix() { + return nil, ErrNotExist + } + + return v.data, nil +} + +// Set key with value +// Set key with value +func (s *Storage) Set(key string, val []byte, exp time.Duration) error { + // Ain't Nobody Got Time For That + if len(key) <= 0 || len(val) <= 0 { + return nil + } + + var expire int64 + if exp != 0 { + expire = time.Now().Add(exp).Unix() + } + + s.mux.Lock() + s.db[key] = entry{val, expire} + s.mux.Unlock() + return nil +} + +// Delete key by key +func (s *Storage) Delete(key string) error { + // Ain't Nobody Got Time For That + if len(key) <= 0 { + return nil + } + s.mux.Lock() + delete(s.db, key) + s.mux.Unlock() + return nil +} + +// Reset all keys +func (s *Storage) Reset() error { + s.mux.Lock() + s.db = make(map[string]entry) + s.mux.Unlock() + return nil +} + +// Close the memory storage +func (s *Storage) Close() error { + s.done <- struct{}{} + return nil +} + +func (s *Storage) gc() { + ticker := time.NewTicker(s.gcInterval) + defer ticker.Stop() + + for { + select { + case <-s.done: + return + case t := <-ticker.C: + now := t.Unix() + s.mux.Lock() + for id, v := range s.db { + if v.expiry != 0 && v.expiry < now { + delete(s.db, id) + } + } + s.mux.Unlock() + } + } +} diff --git a/middleware/basicauth/basicauth.go b/middleware/basicauth/basicauth.go index c0116544db..3017be09b9 100644 --- a/middleware/basicauth/basicauth.go +++ b/middleware/basicauth/basicauth.go @@ -8,97 +8,10 @@ import ( "github.com/gofiber/fiber/v2/utils" ) -// Config defines the config for middleware. -type Config struct { - // Next defines a function to skip this middleware when returned true. - // - // Optional. Default: nil - Next func(c *fiber.Ctx) bool - - // Users defines the allowed credentials - // - // Required. Default: map[string]string{} - Users map[string]string - - // Realm is a string to define realm attribute of BasicAuth. - // the realm identifies the system to authenticate against - // and can be used by clients to save credentials - // - // Optional. Default: "Restricted". - Realm string - - // Authorizer defines a function you can pass - // to check the credentials however you want. - // It will be called with a username and password - // and is expected to return true or false to indicate - // that the credentials were approved or not. - // - // Optional. Default: nil. - Authorizer func(string, string) bool - - // Unauthorized defines the response body for unauthorized responses. - // By default it will return with a 401 Unauthorized and the correct WWW-Auth header - // - // Optional. Default: nil - Unauthorized fiber.Handler - - // ContextUser is the key to store the username in Locals - // - // Optional. Default: "username" - ContextUsername string - - // ContextPass is the key to store the password in Locals - // - // Optional. Default: "password" - ContextPassword string -} - -// ConfigDefault is the default config -var ConfigDefault = Config{ - Next: nil, - Users: map[string]string{}, - Realm: "Restricted", - Authorizer: nil, - Unauthorized: nil, - ContextUsername: "username", - ContextPassword: "password", -} - // New creates a new middleware handler func New(config Config) fiber.Handler { - cfg := config - - // Set default values - if cfg.Next == nil { - cfg.Next = ConfigDefault.Next - } - if cfg.Users == nil { - cfg.Users = ConfigDefault.Users - } - if cfg.Realm == "" { - cfg.Realm = ConfigDefault.Realm - } - if cfg.Authorizer == nil { - cfg.Authorizer = func(user, pass string) bool { - user, exist := cfg.Users[user] - if !exist { - return false - } - return user == pass - } - } - if cfg.Unauthorized == nil { - cfg.Unauthorized = func(c *fiber.Ctx) error { - c.Set(fiber.HeaderWWWAuthenticate, "basic realm="+cfg.Realm) - return c.SendStatus(fiber.StatusUnauthorized) - } - } - if cfg.ContextUsername == "" { - cfg.ContextUsername = ConfigDefault.ContextUsername - } - if cfg.ContextPassword == "" { - cfg.ContextPassword = ConfigDefault.ContextPassword - } + // Set default config + cfg := configDefault(config) // Return new handler return func(c *fiber.Ctx) error { diff --git a/middleware/basicauth/config.go b/middleware/basicauth/config.go new file mode 100644 index 0000000000..1a4d1fce63 --- /dev/null +++ b/middleware/basicauth/config.go @@ -0,0 +1,105 @@ +package basicauth + +import ( + "github.com/gofiber/fiber/v2" +) + +// Config defines the config for middleware. +type Config struct { + // Next defines a function to skip this middleware when returned true. + // + // Optional. Default: nil + Next func(c *fiber.Ctx) bool + + // Users defines the allowed credentials + // + // Required. Default: map[string]string{} + Users map[string]string + + // Realm is a string to define realm attribute of BasicAuth. + // the realm identifies the system to authenticate against + // and can be used by clients to save credentials + // + // Optional. Default: "Restricted". + Realm string + + // Authorizer defines a function you can pass + // to check the credentials however you want. + // It will be called with a username and password + // and is expected to return true or false to indicate + // that the credentials were approved or not. + // + // Optional. Default: nil. + Authorizer func(string, string) bool + + // Unauthorized defines the response body for unauthorized responses. + // By default it will return with a 401 Unauthorized and the correct WWW-Auth header + // + // Optional. Default: nil + Unauthorized fiber.Handler + + // ContextUser is the key to store the username in Locals + // + // Optional. Default: "username" + ContextUsername string + + // ContextPass is the key to store the password in Locals + // + // Optional. Default: "password" + ContextPassword string +} + +// ConfigDefault is the default config +var ConfigDefault = Config{ + Next: nil, + Users: map[string]string{}, + Realm: "Restricted", + Authorizer: nil, + Unauthorized: nil, + ContextUsername: "username", + ContextPassword: "password", +} + +// Helper function to set default values +func configDefault(config ...Config) Config { + // Return default config if nothing provided + if len(config) < 1 { + return ConfigDefault + } + + // Override default config + cfg := config[0] + + // Set default values + if cfg.Next == nil { + cfg.Next = ConfigDefault.Next + } + if cfg.Users == nil { + cfg.Users = ConfigDefault.Users + } + if cfg.Realm == "" { + cfg.Realm = ConfigDefault.Realm + } + if cfg.Authorizer == nil { + cfg.Authorizer = func(user, pass string) bool { + user, exist := cfg.Users[user] + if !exist { + return false + } + return user == pass + } + } + if cfg.Unauthorized == nil { + cfg.Unauthorized = func(c *fiber.Ctx) error { + c.Set(fiber.HeaderWWWAuthenticate, "basic realm="+cfg.Realm) + return c.SendStatus(fiber.StatusUnauthorized) + } + } + if cfg.ContextUsername == "" { + cfg.ContextUsername = ConfigDefault.ContextUsername + } + if cfg.ContextPassword == "" { + cfg.ContextPassword = ConfigDefault.ContextPassword + } + return cfg +} diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index f374a67ba0..8300498a50 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -12,74 +12,10 @@ import ( "github.com/gofiber/fiber/v2/utils" ) -// Config defines the config for middleware. -type Config struct { - // Next defines a function to skip this middleware when returned true. - // - // Optional. Default: nil - Next func(c *fiber.Ctx) bool - - // Expiration is the time that an cached response will live - // - // Optional. Default: 1 * time.Minute - Expiration time.Duration - - // CacheControl enables client side caching if set to true - // - // Optional. Default: false - CacheControl bool - - // Key allows you to generate custom keys, by default c.Path() is used - // - // Default: func(c *fiber.Ctx) string { - // return c.Path() - // } - Key func(*fiber.Ctx) string - - // Store is used to store the state of the middleware - // - // Default: an in memory store for this process only - Store fiber.Storage - - // Internally used - if true, the simpler method of two maps is used in order to keep - // execution time down. - defaultStore bool -} - -// ConfigDefault is the default config -var ConfigDefault = Config{ - Next: nil, - Expiration: 1 * time.Minute, - CacheControl: false, - Key: func(c *fiber.Ctx) string { - return c.Path() - }, - defaultStore: true, -} - // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config - cfg := ConfigDefault - - // Override config if provided - if len(config) > 0 { - cfg = config[0] - - // Set default values - if cfg.Next == nil { - cfg.Next = ConfigDefault.Next - } - if int(cfg.Expiration.Seconds()) == 0 { - cfg.Expiration = ConfigDefault.Expiration - } - if cfg.Key == nil { - cfg.Key = ConfigDefault.Key - } - if cfg.Store == nil { - cfg.defaultStore = true - } - } + cfg := configDefault(config...) var ( // Cache settings @@ -152,7 +88,7 @@ func New(config ...Config) fiber.Handler { } else { // Load data from store - storeEntry, err := cfg.Store.Get(key) + storeEntry, err := cfg.Storage.Get(key) if err != nil { return err } @@ -165,7 +101,7 @@ func New(config ...Config) fiber.Handler { } } - if entryBody, err = cfg.Store.Get(key + "_body"); err != nil { + if entryBody, err = cfg.Storage.Get(key + "_body"); err != nil { return err } } @@ -183,10 +119,10 @@ func New(config ...Config) fiber.Handler { if cfg.defaultStore { delete(entries, key) } else { // Use custom storage - if err := cfg.Store.Delete(key); err != nil { + if err := cfg.Storage.Delete(key); err != nil { return err } - if err := cfg.Store.Delete(key + "_body"); err != nil { + if err := cfg.Storage.Delete(key + "_body"); err != nil { return err } } @@ -234,12 +170,12 @@ func New(config ...Config) fiber.Handler { } // Pass bytes to Storage - if err = cfg.Store.Set(key, data, cfg.Expiration); err != nil { + if err = cfg.Storage.Set(key, data, cfg.Expiration); err != nil { return err } // Pass bytes to Storage - if err = cfg.Store.Set(key+"_body", entryBody, cfg.Expiration); err != nil { + if err = cfg.Storage.Set(key+"_body", entryBody, cfg.Expiration); err != nil { return err } } diff --git a/middleware/cache/config.go b/middleware/cache/config.go new file mode 100644 index 0000000000..10a3017ba6 --- /dev/null +++ b/middleware/cache/config.go @@ -0,0 +1,87 @@ +package cache + +import ( + "fmt" + "time" + + "github.com/gofiber/fiber/v2" +) + +// Config defines the config for middleware. +type Config struct { + // Next defines a function to skip this middleware when returned true. + // + // Optional. Default: nil + Next func(c *fiber.Ctx) bool + + // Expiration is the time that an cached response will live + // + // Optional. Default: 1 * time.Minute + Expiration time.Duration + + // CacheControl enables client side caching if set to true + // + // Optional. Default: false + CacheControl bool + + // Key allows you to generate custom keys, by default c.Path() is used + // + // Default: func(c *fiber.Ctx) string { + // return c.Path() + // } + Key func(*fiber.Ctx) string + + // Deprecated, use Storage instead + Store fiber.Storage + + // Store is used to store the state of the middleware + // + // Default: an in memory store for this process only + Storage fiber.Storage + + // Internally used - if true, the simpler method of two maps is used in order to keep + // execution time down. + defaultStore bool +} + +// ConfigDefault is the default config +var ConfigDefault = Config{ + Next: nil, + Expiration: 1 * time.Minute, + CacheControl: false, + Key: func(c *fiber.Ctx) string { + return c.Path() + }, + defaultStore: true, +} + +// Helper function to set default values +func configDefault(config ...Config) Config { + // Return default config if nothing provided + if len(config) < 1 { + return ConfigDefault + } + + // Override default config + cfg := config[0] + + // Set default values + if cfg.Next == nil { + cfg.Next = ConfigDefault.Next + } + if int(cfg.Expiration.Seconds()) == 0 { + cfg.Expiration = ConfigDefault.Expiration + } + if cfg.Key == nil { + cfg.Key = ConfigDefault.Key + } + if cfg.Storage == nil && cfg.Store == nil { + cfg.defaultStore = true + } + if cfg.Store != nil { + fmt.Println("cache: `Store` is deprecated, use `Storage` instead") + cfg.Storage = cfg.Store + cfg.defaultStore = true + } + return cfg +} diff --git a/middleware/compress/compress.go b/middleware/compress/compress.go index 80919ab0a2..e65d78558b 100644 --- a/middleware/compress/compress.go +++ b/middleware/compress/compress.go @@ -5,54 +5,10 @@ import ( "github.com/valyala/fasthttp" ) -// Config defines the config for middleware. -type Config struct { - // Next defines a function to skip this middleware when returned true. - // - // Optional. Default: nil - Next func(c *fiber.Ctx) bool - - // Level determines the compression algorithm - // - // Optional. Default: LevelDefault - // LevelDisabled: -1 - // LevelDefault: 0 - // LevelBestSpeed: 1 - // LevelBestCompression: 2 - Level Level -} - -// Level is numeric representation of compression level -type Level int - -// Represents compression level that will be used in the middleware -const ( - LevelDisabled Level = -1 - LevelDefault Level = 0 - LevelBestSpeed Level = 1 - LevelBestCompression Level = 2 -) - -// ConfigDefault is the default config -var ConfigDefault = Config{ - Next: nil, - Level: LevelDefault, -} - // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config - cfg := ConfigDefault - - // Override config if provided - if len(config) > 0 { - cfg = config[0] - - // Set default values - if cfg.Level < LevelDisabled || cfg.Level > LevelBestCompression { - cfg.Level = ConfigDefault.Level - } - } + cfg := configDefault(config...) // Setup request handlers var ( diff --git a/middleware/compress/config.go b/middleware/compress/config.go new file mode 100644 index 0000000000..5495ad4c42 --- /dev/null +++ b/middleware/compress/config.go @@ -0,0 +1,56 @@ +package compress + +import ( + "github.com/gofiber/fiber/v2" +) + +// Config defines the config for middleware. +type Config struct { + // Next defines a function to skip this middleware when returned true. + // + // Optional. Default: nil + Next func(c *fiber.Ctx) bool + + // Level determines the compression algorithm + // + // Optional. Default: LevelDefault + // LevelDisabled: -1 + // LevelDefault: 0 + // LevelBestSpeed: 1 + // LevelBestCompression: 2 + Level Level +} + +// Level is numeric representation of compression level +type Level int + +// Represents compression level that will be used in the middleware +const ( + LevelDisabled Level = -1 + LevelDefault Level = 0 + LevelBestSpeed Level = 1 + LevelBestCompression Level = 2 +) + +// ConfigDefault is the default config +var ConfigDefault = Config{ + Next: nil, + Level: LevelDefault, +} + +// Helper function to set default values +func configDefault(config ...Config) Config { + // Return default config if nothing provided + if len(config) < 1 { + return ConfigDefault + } + + // Override default config + cfg := config[0] + + // Set default values + if cfg.Level < LevelDisabled || cfg.Level > LevelBestCompression { + cfg.Level = ConfigDefault.Level + } + return cfg +} diff --git a/middleware/csrf/README.md b/middleware/csrf/README.md index 611da44f5d..abb8d203ac 100644 --- a/middleware/csrf/README.md +++ b/middleware/csrf/README.md @@ -1,5 +1,6 @@ # CSRF -CSRF middleware for [Fiber](https://github.com/gofiber/fiber) that provides [Cross-site request forgery](https://en.wikipedia.org/wiki/Cross-site_request_forgery) protection by passing a csrf token via cookies. This cookie value will be used to compare against the client csrf token in POST requests. When the csrf token is invalid, this middleware will return the `fiber.ErrForbidden` error. +CSRF middleware for [Fiber](https://github.com/gofiber/fiber) that provides [Cross-site request forgery](https://en.wikipedia.org/wiki/Cross-site_request_forgery) protection by passing a csrf token via cookies. This cookie value will be used to compare against the client csrf token in POST requests. When the csrf token is invalid, this middleware will delete the `_csrf` cookie and return the `fiber.ErrForbidden` error. +CSRF Tokens are generated on GET requests. ### Table of Contents - [Signatures](#signatures) @@ -29,12 +30,11 @@ app.Use(csrf.New()) // Or extend your config for customization app.Use(csrf.New(csrf.Config{ - TokenLookup: "header:X-CSRF-Token", - ContextKey: "csrf", - Cookie: &fiber.Cookie{ - Name: "_csrf", - }, - Expiration: 24 * time.Hour, + KeyLookup: "header:X-Csrf-Token", + CookieName: "csrf_", + CookieSameSite: "Strict", + Expiration: 1 * time.Hour, + KeyGenerator: utils.UUID, })) ``` @@ -47,44 +47,72 @@ type Config struct { // Optional. Default: nil Next func(c *fiber.Ctx) bool - // TokenLookup is a string in the form of ":" that is used + // KeyLookup is a string in the form of ":" that is used // to extract token from the request. - // - // Optional. Default value "header:X-CSRF-Token". // Possible values: // - "header:" // - "query:" // - "param:" // - "form:" - TokenLookup string - - // Cookie + // - "cookie:" // - // Optional. - Cookie *fiber.Cookie + // Optional. Default: "header:X-CSRF-Token" + KeyLookup string + + // Name of the session cookie. This cookie will store session key. + // Optional. Default value "_csrf". + CookieName string + + // Domain of the CSRF cookie. + // Optional. Default value "". + CookieDomain string + + // Path of the CSRF cookie. + // Optional. Default value "". + CookiePath string + + // Indicates if CSRF cookie is secure. + // Optional. Default value false. + CookieSecure bool + + // Indicates if CSRF cookie is HTTP only. + // Optional. Default value false. + CookieHTTPOnly bool + + // Indicates if CSRF cookie is HTTP only. + // Optional. Default value "Strict". + CookieSameSite string // Expiration is the duration before csrf token will expire // - // Optional. Default: 24 * time.Hour + // Optional. Default: 1 * time.Hour Expiration time.Duration + // Store is used to store the state of the middleware + // + // Optional. Default: memory.New() + Storage fiber.Storage + // Context key to store generated CSRF token into context. + // If left empty, token will not be stored in context. // - // Optional. Default value "csrf". + // Optional. Default: "" ContextKey string + + // KeyGenerator creates a new CSRF token + // + // Optional. Default: utils.UUID + KeyGenerator func() string } ``` ### Default Config ```go var ConfigDefault = Config{ - Next: nil, - TokenLookup: "header:X-CSRF-Token", - ContextKey: "csrf", - Cookie: &fiber.Cookie{ - Name: "_csrf", - SameSite: "Strict", - }, - Expiration: 24 * time.Hour, + KeyLookup: "header:X-Csrf-Token", + CookieName: "csrf_", + CookieSameSite: "Strict", + Expiration: 1 * time.Hour, + KeyGenerator: utils.UUID, } ``` diff --git a/middleware/csrf/config.go b/middleware/csrf/config.go new file mode 100644 index 0000000000..907e896cc9 --- /dev/null +++ b/middleware/csrf/config.go @@ -0,0 +1,147 @@ +package csrf + +import ( + "fmt" + "time" + + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/utils" +) + +// Config defines the config for middleware. +type Config struct { + // Next defines a function to skip this middleware when returned true. + // + // Optional. Default: nil + Next func(c *fiber.Ctx) bool + + // KeyLookup is a string in the form of ":" that is used + // to extract token from the request. + // Possible values: + // - "header:" + // - "query:" + // - "param:" + // - "form:" + // - "cookie:" + // + // Optional. Default: "header:X-CSRF-Token" + KeyLookup string + + // Name of the session cookie. This cookie will store session key. + // Optional. Default value "_csrf". + CookieName string + + // Domain of the CSRF cookie. + // Optional. Default value "". + CookieDomain string + + // Path of the CSRF cookie. + // Optional. Default value "". + CookiePath string + + // Indicates if CSRF cookie is secure. + // Optional. Default value false. + CookieSecure bool + + // Indicates if CSRF cookie is HTTP only. + // Optional. Default value false. + CookieHTTPOnly bool + + // Indicates if CSRF cookie is HTTP only. + // Optional. Default value "Strict". + CookieSameSite string + + // Expiration is the duration before csrf token will expire + // + // Optional. Default: 1 * time.Hour + Expiration time.Duration + + // Store is used to store the state of the middleware + // + // Optional. Default: memory.New() + Storage fiber.Storage + + // Context key to store generated CSRF token into context. + // If left empty, token will not be stored in context. + // + // Optional. Default: "" + ContextKey string + + // KeyGenerator creates a new CSRF token + // + // Optional. Default: utils.UUID + KeyGenerator func() string + + // Deprecated, please use Expiration + CookieExpires time.Duration + + // Deprecated, please use Cookie* related fields + Cookie *fiber.Cookie + + // Deprecated, please use KeyLookup + TokenLookup string +} + +// ConfigDefault is the default config +var ConfigDefault = Config{ + KeyLookup: "header:X-Csrf-Token", + CookieName: "csrf_", + CookieSameSite: "Strict", + Expiration: 1 * time.Hour, + KeyGenerator: utils.UUID, +} + +// Helper function to set default values +func configDefault(config ...Config) Config { + // Return default config if nothing provided + if len(config) < 1 { + return ConfigDefault + } + + // Override default config + cfg := config[0] + + // Set default values + if cfg.TokenLookup != "" { + fmt.Println("[CSRF] TokenLookup is deprecated, please use KeyLookup") + cfg.KeyLookup = cfg.TokenLookup + } + if int(cfg.CookieExpires.Seconds()) > 0 { + fmt.Println("[CSRF] CookieExpires is deprecated, please use Expiration") + cfg.Expiration = cfg.CookieExpires + } + if cfg.Cookie != nil { + fmt.Println("[CSRF] Cookie is deprecated, please use Cookie* related fields") + if cfg.Cookie.Name != "" { + cfg.CookieName = cfg.Cookie.Name + } + if cfg.Cookie.Domain != "" { + cfg.CookieDomain = cfg.Cookie.Domain + } + if cfg.Cookie.Path != "" { + cfg.CookiePath = cfg.Cookie.Path + } + cfg.CookieSecure = cfg.Cookie.Secure + cfg.CookieHTTPOnly = cfg.Cookie.HTTPOnly + if cfg.Cookie.SameSite != "" { + cfg.CookieSameSite = cfg.Cookie.SameSite + } + } + if cfg.KeyLookup == "" { + cfg.KeyLookup = ConfigDefault.KeyLookup + } + if int(cfg.Expiration.Seconds()) <= 0 { + cfg.Expiration = ConfigDefault.Expiration + } + if cfg.CookieName == "" { + cfg.CookieName = ConfigDefault.CookieName + } + if cfg.CookieSameSite == "" { + cfg.CookieSameSite = ConfigDefault.CookieSameSite + } + if cfg.KeyGenerator == nil { + cfg.KeyGenerator = ConfigDefault.KeyGenerator + } + + return cfg +} diff --git a/middleware/csrf/csrf.go b/middleware/csrf/csrf.go index c99297a9d4..681dd20bd8 100644 --- a/middleware/csrf/csrf.go +++ b/middleware/csrf/csrf.go @@ -3,115 +3,33 @@ package csrf import ( "errors" "fmt" + "net/textproto" "strings" - "sync" "time" "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/utils" + "github.com/gofiber/fiber/v2/internal/storage/memory" ) -// Config defines the config for middleware. -type Config struct { - // Next defines a function to skip this middleware when returned true. - // - // Optional. Default: nil - Next func(c *fiber.Ctx) bool - - // TokenLookup is a string in the form of ":" that is used - // to extract token from the request. - // - // Optional. Default value "header:X-CSRF-Token". - // Possible values: - // - "header:" - // - "query:" - // - "param:" - // - "form:" - // - "cookie:" - TokenLookup string - - // Cookie - // - // Optional. - Cookie *fiber.Cookie - - // Deprecated, please use Expiration - CookieExpires time.Duration - - // Expiration is the duration before csrf token will expire - // - // Optional. Default: 24 * time.Hour - Expiration time.Duration - - // Context key to store generated CSRF token into context. - // - // Optional. Default value "csrf". - ContextKey string -} - -// ConfigDefault is the default config -var ConfigDefault = Config{ - Next: nil, - TokenLookup: "header:X-CSRF-Token", - ContextKey: "csrf", - Cookie: &fiber.Cookie{ - Name: "_csrf", - SameSite: "Strict", - }, - Expiration: 24 * time.Hour, - CookieExpires: 24 * time.Hour, // deprecated -} - -type storage struct { - sync.RWMutex - tokens map[string]int64 -} - // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config - cfg := ConfigDefault - - // Override config if provided - if len(config) > 0 { - cfg = config[0] + cfg := configDefault(config...) - // Set default values - if cfg.TokenLookup == "" { - cfg.TokenLookup = ConfigDefault.TokenLookup - } - if cfg.ContextKey == "" { - cfg.ContextKey = ConfigDefault.ContextKey - } - if cfg.CookieExpires != 0 { - fmt.Println("[CSRF] CookieExpires is deprecated, please use Expiration") - cfg.CookieExpires = ConfigDefault.Expiration - } - if cfg.Expiration == 0 { - cfg.Expiration = ConfigDefault.Expiration - } - if cfg.Cookie != nil { - if cfg.Cookie.Name == "" { - cfg.Cookie.Name = ConfigDefault.Cookie.Name - } - if cfg.Cookie.SameSite == "" { - cfg.Cookie.SameSite = ConfigDefault.Cookie.SameSite - } - } else { - cfg.Cookie = ConfigDefault.Cookie - } + // Set default values + if cfg.Storage == nil { + cfg.Storage = memory.New() } - expiration := int64(cfg.Expiration.Seconds()) // Generate the correct extractor to get the token from the correct location - selectors := strings.Split(cfg.TokenLookup, ":") + selectors := strings.Split(cfg.KeyLookup, ":") if len(selectors) != 2 { - panic("csrf: Token lookup must in the form of :") + panic("[CSRF] KeyLookup must in the form of :") } // By default we extract from a header - extractor := csrfFromHeader(selectors[1]) + extractor := csrfFromHeader(textproto.CanonicalMIMEHeaderKey(selectors[1])) switch selectors[0] { case "form": @@ -121,108 +39,110 @@ func New(config ...Config) fiber.Handler { case "param": extractor = csrfFromParam(selectors[1]) case "cookie": - if selectors[1] == cfg.Cookie.Name { - panic(fmt.Sprintf("TokenLookup key %s can't be the same as Cookie.Name %s", selectors[1], cfg.Cookie.Name)) + if selectors[1] == cfg.CookieName { + panic(fmt.Sprintf("KeyLookup key %s can't be the same as CookieName %s", selectors[1], cfg.CookieName)) } extractor = csrfFromCookie(selectors[1]) } - // create new db - db := storage{ - tokens: make(map[string]int64), - } - // Remove expired entries - go func() { - for { - // GC the tokens every 10 seconds to avoid - time.Sleep(10 * time.Second) - db.Lock() - for t := range db.tokens { - if time.Now().Unix() >= db.tokens[t] { - delete(db.tokens, t) - } - } - db.Unlock() - } - }() + // We only use Keys in Storage, so we need a dummy value + dummyVal := []byte{'+'} // Return new handler - return func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) (err error) { // Don't execute middleware if Next returns true - if (cfg.Next != nil && cfg.Next(c)) || - // Or non GET/POST method - (c.Method() != fiber.MethodGet && c.Method() != fiber.MethodPost) { + if cfg.Next != nil && cfg.Next(c) { return c.Next() } - // Declare empty token and try to get previous generated CSRF from cookie - token, key := "", c.Cookies(cfg.Cookie.Name) + var token string - // Check if the cookie had a CSRF token - if key == "" { - // Create a new CSRF token - token = utils.UUID() - // Add token with timestamp expiration - db.Lock() - db.tokens[token] = time.Now().Unix() + expiration - db.Unlock() - } else { - // Use the server generated token previously to compare - // To the extracted token later on - token = key - } + // Action depends on the HTTP method + switch c.Method() { + case fiber.MethodGet: + // Declare empty token and try to get existing CSRF from cookie + token = c.Cookies(cfg.CookieName) + + // Generate CSRF token if not exist + if token == "" { + // Generate new CSRF token + token = cfg.KeyGenerator() + + // Add token to Storage + if err = cfg.Storage.Set(token, dummyVal, cfg.Expiration); err != nil { + fmt.Println("[CSRF]", err.Error()) + } + } + + // Create cookie to pass token to client + cookie := &fiber.Cookie{ + Name: cfg.CookieName, + Value: token, + Domain: cfg.CookieDomain, + Path: cfg.CookiePath, + Expires: time.Now().Add(cfg.Expiration), + Secure: cfg.CookieSecure, + HTTPOnly: cfg.CookieHTTPOnly, + SameSite: cfg.CookieSameSite, + } - // Verify CSRF token on POST requests - if c.Method() == fiber.MethodPost { - // Extract token from client request i.e. header, query, param or form - csrf, err := extractor(c) + // Set cookie to response + c.Cookie(cookie) + case fiber.MethodPost, fiber.MethodDelete, fiber.MethodPatch, fiber.MethodPut: + // Verify CSRF token + // Extract token from client request i.e. header, query, param, form or cookie + token, err = extractor(c) if err != nil { - // We have a problem extracting the csrf token return fiber.ErrForbidden } - - // Get token from DB - db.RLock() - t, ok := db.tokens[csrf] - db.RUnlock() - // Check if token exist or expired - if !ok || time.Now().Unix() >= t { + // We have a problem extracting the csrf token from Storage + if _, err = cfg.Storage.Get(token); err != nil { + // The token is invalid, let client generate a new one + if err = cfg.Storage.Delete(token); err != nil { + fmt.Println("[CSRF]", err.Error()) + } + // Expire cookie + c.Cookie(&fiber.Cookie{ + Name: cfg.CookieName, + Domain: cfg.CookieDomain, + Path: cfg.CookiePath, + Expires: time.Now().Add(-1 * time.Minute), + Secure: cfg.CookieSecure, + HTTPOnly: cfg.CookieHTTPOnly, + SameSite: cfg.CookieSameSite, + }) return fiber.ErrForbidden } } - // Create new cookie to send new CSRF token - cookie := &fiber.Cookie{ - Name: cfg.Cookie.Name, - Value: token, - Domain: cfg.Cookie.Domain, - Path: cfg.Cookie.Path, - Expires: time.Now().Add(cfg.Expiration), - Secure: cfg.Cookie.Secure, - HTTPOnly: cfg.Cookie.HTTPOnly, - SameSite: cfg.Cookie.SameSite, - } - - // Set cookie to response - c.Cookie(cookie) - // Store token in context - c.Locals(cfg.ContextKey, token) - // Protect clients from caching the response by telling the browser // a new header value is generated c.Vary(fiber.HeaderCookie) + // Store token in context if set + if cfg.ContextKey != "" { + c.Locals(cfg.ContextKey, token) + } + // Continue stack return c.Next() } } +var ( + errMissingHeader = errors.New("missing csrf token in header") + errMissingQuery = errors.New("missing csrf token in query") + errMissingParam = errors.New("missing csrf token in param") + errMissingForm = errors.New("missing csrf token in form") + errMissingCookie = errors.New("missing csrf token in cookie") +) + // csrfFromHeader returns a function that extracts token from the request header. func csrfFromHeader(param string) func(c *fiber.Ctx) (string, error) { return func(c *fiber.Ctx) (string, error) { token := c.Get(param) if token == "" { - return "", errors.New("missing csrf token in header") + return "", errMissingHeader } return token, nil } @@ -233,7 +153,7 @@ func csrfFromQuery(param string) func(c *fiber.Ctx) (string, error) { return func(c *fiber.Ctx) (string, error) { token := c.Query(param) if token == "" { - return "", errors.New("missing csrf token in query string") + return "", errMissingQuery } return token, nil } @@ -244,18 +164,18 @@ func csrfFromParam(param string) func(c *fiber.Ctx) (string, error) { return func(c *fiber.Ctx) (string, error) { token := c.Params(param) if token == "" { - return "", errors.New("missing csrf token in url parameter") + return "", errMissingParam } return token, nil } } -// csrfFromParam returns a function that extracts a token from a multipart-form. +// csrfFromForm returns a function that extracts a token from a multipart-form. func csrfFromForm(param string) func(c *fiber.Ctx) (string, error) { return func(c *fiber.Ctx) (string, error) { token := c.FormValue(param) if token == "" { - return "", errors.New("missing csrf token in form parameter") + return "", errMissingForm } return token, nil } @@ -266,7 +186,7 @@ func csrfFromCookie(param string) func(c *fiber.Ctx) (string, error) { return func(c *fiber.Ctx) (string, error) { token := c.Cookies(param) if token == "" { - return "", errors.New("missing csrf token in cookie") + return "", errMissingCookie } return token, nil } diff --git a/middleware/csrf/csrf_test.go b/middleware/csrf/csrf_test.go index ce36439802..e227f659f8 100644 --- a/middleware/csrf/csrf_test.go +++ b/middleware/csrf/csrf_test.go @@ -48,6 +48,8 @@ func Test_CSRF(t *testing.T) { ctx.Response.Reset() ctx.Request.Header.SetMethod("GET") h(ctx) + token = string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + token = strings.Split(strings.Split(token, ";")[0], "=")[1] ctx.Request.Reset() ctx.Response.Reset() @@ -71,13 +73,13 @@ func Test_CSRF_Next(t *testing.T) { utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) } -func Test_CSRF_Invalid_TokenLookup(t *testing.T) { +func Test_CSRF_Invalid_KeyLookup(t *testing.T) { defer func() { - utils.AssertEqual(t, "csrf: Token lookup must in the form of :", recover()) + utils.AssertEqual(t, "[CSRF] KeyLookup must in the form of :", recover()) }() app := fiber.New() - app.Use(New(Config{TokenLookup: "I:am:invalid"})) + app.Use(New(Config{KeyLookup: "I:am:invalid"})) app.Post("/", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) @@ -92,7 +94,7 @@ func Test_CSRF_Invalid_TokenLookup(t *testing.T) { func Test_CSRF_From_Form(t *testing.T) { app := fiber.New() - app.Use(New(Config{TokenLookup: "form:_csrf"})) + app.Use(New(Config{KeyLookup: "form:_csrf"})) app.Post("/", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) @@ -125,7 +127,7 @@ func Test_CSRF_From_Form(t *testing.T) { func Test_CSRF_From_Query(t *testing.T) { app := fiber.New() - app.Use(New(Config{TokenLookup: "query:_csrf"})) + app.Use(New(Config{KeyLookup: "query:_csrf"})) app.Post("/", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) @@ -161,7 +163,7 @@ func Test_CSRF_From_Query(t *testing.T) { func Test_CSRF_From_Param(t *testing.T) { app := fiber.New() - csrfGroup := app.Group("/:csrf", New(Config{TokenLookup: "param:csrf"})) + csrfGroup := app.Group("/:csrf", New(Config{KeyLookup: "param:csrf"})) csrfGroup.Post("/", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) @@ -197,7 +199,7 @@ func Test_CSRF_From_Param(t *testing.T) { func Test_CSRF_From_Cookie(t *testing.T) { app := fiber.New() - csrfGroup := app.Group("/", New(Config{TokenLookup: "cookie:csrf"})) + csrfGroup := app.Group("/", New(Config{KeyLookup: "cookie:csrf"})) csrfGroup.Post("/", func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) diff --git a/middleware/limiter/README.md b/middleware/limiter/README.md index 6b9baceaeb..39f9792dc9 100644 --- a/middleware/limiter/README.md +++ b/middleware/limiter/README.md @@ -30,13 +30,12 @@ After you initiate your Fiber app, you can use the following possibilities: app.Use(limiter.New()) // Or extend your config for customization - app.Use(limiter.New(limiter.Config{ Next: func(c *fiber.Ctx) bool { return c.IP() == "127.0.0.1" }, Max: 20, - Expiration: 30 * time.Second, + Duration: 30 * time.Second, Key: func(c *fiber.Ctx) string { return c.Get("x-forwarded-for") }, @@ -61,17 +60,17 @@ type Config struct { // Default: 5 Max int - // Expiration is the time on how long to keep records of requests in memory - // - // Default: time.Minute - Expiration time.Duration - - // Key allows you to generate custom keys, by default c.IP() is used + // KeyGenerator allows you to generate custom keys, by default c.IP() is used // // Default: func(c *fiber.Ctx) string { // return c.IP() // } - Key func(*fiber.Ctx) string + KeyGenerator func(*fiber.Ctx) string + + // Expiration is the time on how long to keep records of requests in memory + // + // Default: 1 * time.Minute + Expiration time.Duration // LimitReached is called when a request hits the limit // @@ -80,12 +79,10 @@ type Config struct { // } LimitReached fiber.Handler - // Store is used to store the state of the middleware. - // If no store is supplied, an in-memory store is used. If a store is supplied, - // it must implement the `Storage` interface. + // Store is used to store the state of the middleware // - // Default: in memory - Store Storage + // Default: an in memory store for this process only + Storage fiber.Storage } ``` @@ -94,14 +91,13 @@ A custom store can be used if it implements the `Storage` interface - more detai ### Default Config ```go var ConfigDefault = Config{ - Next: nil, - Max: 5, - Duration: time.Minute, - Key: func(c *fiber.Ctx) string { + Max: 5, + Expiration: 1 * time.Minute, + KeyGenerator: func(c *fiber.Ctx) string { return c.IP() }, LimitReached: func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusTooManyRequests) }, } -``` +``` \ No newline at end of file diff --git a/middleware/limiter/config.go b/middleware/limiter/config.go new file mode 100644 index 0000000000..a08759a400 --- /dev/null +++ b/middleware/limiter/config.go @@ -0,0 +1,107 @@ +package limiter + +import ( + "fmt" + "time" + + "github.com/gofiber/fiber/v2" +) + +// Config defines the config for middleware. +type Config struct { + // Next defines a function to skip this middleware when returned true. + // + // Optional. Default: nil + Next func(c *fiber.Ctx) bool + + // Max number of recent connections during `Duration` seconds before sending a 429 response + // + // Default: 5 + Max int + + // KeyGenerator allows you to generate custom keys, by default c.IP() is used + // + // Default: func(c *fiber.Ctx) string { + // return c.IP() + // } + KeyGenerator func(*fiber.Ctx) string + + // Expiration is the time on how long to keep records of requests in memory + // + // Default: 1 * time.Minute + Expiration time.Duration + + // LimitReached is called when a request hits the limit + // + // Default: func(c *fiber.Ctx) error { + // return c.SendStatus(fiber.StatusTooManyRequests) + // } + LimitReached fiber.Handler + + // Store is used to store the state of the middleware + // + // Default: an in memory store for this process only + Storage fiber.Storage + + // DEPRECATED: Use Expiration instead + Duration time.Duration + + // DEPRECATED, use Storage instead + Store fiber.Storage + + // DEPRECATED, use KeyGenerator instead + Key func(*fiber.Ctx) string +} + +// ConfigDefault is the default config +var ConfigDefault = Config{ + Max: 5, + Expiration: 1 * time.Minute, + KeyGenerator: func(c *fiber.Ctx) string { + return c.IP() + }, + LimitReached: func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusTooManyRequests) + }, +} + +// Helper function to set default values +func configDefault(config ...Config) Config { + // Return default config if nothing provided + if len(config) < 1 { + return ConfigDefault + } + + // Override default config + cfg := config[0] + + // Set default values + if int(cfg.Duration.Seconds()) > 0 { + fmt.Println("[LIMITER] Duration is deprecated, please use Expiration") + cfg.Expiration = cfg.Duration + } + if cfg.Key != nil { + fmt.Println("[LIMITER] Key is deprecated, please us KeyGenerator") + cfg.KeyGenerator = cfg.Key + } + if cfg.Store != nil { + fmt.Println("[LIMITER] Store is deprecated, please use Storage") + cfg.Storage = cfg.Store + } + if cfg.Next == nil { + cfg.Next = ConfigDefault.Next + } + if cfg.Max <= 0 { + cfg.Max = ConfigDefault.Max + } + if int(cfg.Expiration.Seconds()) <= 0 { + cfg.Expiration = ConfigDefault.Expiration + } + if cfg.KeyGenerator == nil { + cfg.KeyGenerator = ConfigDefault.KeyGenerator + } + if cfg.LimitReached == nil { + cfg.LimitReached = ConfigDefault.LimitReached + } + return cfg +} diff --git a/middleware/limiter/limiter.go b/middleware/limiter/limiter.go index 42ab980155..dc717be862 100644 --- a/middleware/limiter/limiter.go +++ b/middleware/limiter/limiter.go @@ -10,69 +10,11 @@ import ( "github.com/gofiber/fiber/v2" ) -//go:generate msgp -unexported -//msgp:ignore Config - -// Config defines the config for middleware. -type Config struct { - // Next defines a function to skip this middleware when returned true. - // - // Optional. Default: nil - Next func(c *fiber.Ctx) bool - - // Max number of recent connections during `Duration` seconds before sending a 429 response - // - // Default: 5 - Max int - - // DEPRECATED: Use Expiration instead - Duration time.Duration - - // Expiration is the time on how long to keep records of requests in memory - // - // Default: 1 * time.Minute - Expiration time.Duration - - // Key allows you to generate custom keys, by default c.IP() is used - // - // Default: func(c *fiber.Ctx) string { - // return c.IP() - // } - Key func(*fiber.Ctx) string - - // LimitReached is called when a request hits the limit - // - // Default: func(c *fiber.Ctx) error { - // return c.SendStatus(fiber.StatusTooManyRequests) - // } - LimitReached fiber.Handler - - // Store is used to store the state of the middleware - // - // Default: an in memory store for this process only - Store fiber.Storage - - // Internally used - if true, the simpler method of two maps is used in order to keep - // execution time down. - defaultStore bool -} - -// ConfigDefault is the default config -var ConfigDefault = Config{ - Next: nil, - Max: 5, - Expiration: 1 * time.Minute, - Key: func(c *fiber.Ctx) string { - return c.IP() - }, - LimitReached: func(c *fiber.Ctx) error { - return c.SendStatus(fiber.StatusTooManyRequests) - }, - defaultStore: true, -} - -// X-RateLimit-* headers const ( + // Storage ErrNotExist + errNotExist = "key does not exist" + + // X-RateLimit-* headers xRateLimitLimit = "X-RateLimit-Limit" xRateLimitRemaining = "X-RateLimit-Remaining" xRateLimitReset = "X-RateLimit-Reset" @@ -81,38 +23,7 @@ const ( // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config - cfg := ConfigDefault - - // Override config if provided - if len(config) > 0 { - cfg = config[0] - - // Set default values - if cfg.Next == nil { - cfg.Next = ConfigDefault.Next - } - if cfg.Max <= 0 { - cfg.Max = ConfigDefault.Max - } - if int(cfg.Duration.Seconds()) <= 0 && int(cfg.Expiration.Seconds()) <= 0 { - cfg.Expiration = ConfigDefault.Expiration - } - if int(cfg.Duration.Seconds()) > 0 { - fmt.Println("[LIMITER] Duration is deprecated, please use Expiration") - if cfg.Expiration != ConfigDefault.Expiration { - cfg.Expiration = cfg.Duration - } - } - if cfg.Key == nil { - cfg.Key = ConfigDefault.Key - } - if cfg.LimitReached == nil { - cfg.LimitReached = ConfigDefault.LimitReached - } - if cfg.Store == nil { - cfg.defaultStore = true - } - } + cfg := configDefault(config...) var ( // Limiter settings @@ -141,7 +52,7 @@ func New(config ...Config) fiber.Handler { } // Get key from request - key := cfg.Key(c) + key := cfg.KeyGenerator(c) // Create new entry entry := entry{} @@ -150,21 +61,19 @@ func New(config ...Config) fiber.Handler { mux.Lock() defer mux.Unlock() - // Use default memory storage - if cfg.defaultStore { - entry = entries[key] - } else { // Use custom storage - storeEntry, err := cfg.Store.Get(key) - if err != nil { - return err - } - // Only decode if we found an entry - if storeEntry != nil { - // Decode bytes using msgp - if _, err := entry.UnmarshalMsg(storeEntry); err != nil { + // Use Storage if provided + if cfg.Storage != nil { + val, err := cfg.Storage.Get(key) + if val != nil && len(val) > 0 { + if _, err := entry.UnmarshalMsg(val); err != nil { return err } } + if err != nil && err.Error() != errNotExist { + fmt.Println("[LIMITER]", err.Error()) + } + } else { + entry = entries[key] } // Get timestamp @@ -183,19 +92,20 @@ func New(config ...Config) fiber.Handler { // Increment hits entry.hits++ - // Use default memory storage - if cfg.defaultStore { - entries[key] = entry - } else { // Use custom storage - data, err := entry.MarshalMsg(nil) + // Use Storage if provided + if cfg.Storage != nil { + // Marshal entry to bytes + val, err := entry.MarshalMsg(nil) if err != nil { return err } - // Pass bytes to Storage - if err = cfg.Store.Set(key, data, cfg.Expiration); err != nil { + // Pass value to Storage + if err = cfg.Storage.Set(key, val, cfg.Expiration); err != nil { return err } + } else { + entries[key] = entry } // Calculate when it resets in seconds @@ -223,10 +133,3 @@ func New(config ...Config) fiber.Handler { return c.Next() } } - -// replacer for strconv.FormatUint -// func appendInt(buf *bytebufferpool.ByteBuffer, v int) (int, error) { -// old := len(buf.B) -// buf.B = fasthttp.AppendUint(buf.B, v) -// return len(buf.B) - old, nil -// } diff --git a/middleware/limiter/limiter_test.go b/middleware/limiter/limiter_test.go index f2daf9b23a..1ffae44920 100644 --- a/middleware/limiter/limiter_test.go +++ b/middleware/limiter/limiter_test.go @@ -9,9 +9,9 @@ import ( "testing" "time" - "github.com/gofiber/fiber/v2/utils" - "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/internal/storage/memory" + "github.com/gofiber/fiber/v2/utils" "github.com/valyala/fasthttp" ) @@ -24,7 +24,7 @@ func Test_Limiter_Concurrency_Store(t *testing.T) { app.Use(New(Config{ Max: 50, Expiration: 2 * time.Second, - Store: testStore{stmap: map[string][]byte{}, mutex: new(sync.Mutex)}, + Storage: memory.New(), })) app.Get("/", func(c *fiber.Ctx) error { @@ -108,32 +108,6 @@ func Test_Limiter_Concurrency(t *testing.T) { } -// go test -v -run=^$ -bench=Benchmark_Limiter -benchmem -count=4 -func Benchmark_Limiter(b *testing.B) { - app := fiber.New() - - app.Use(New(Config{ - Max: 100, - Expiration: 60 * time.Second, - })) - - app.Get("/", func(c *fiber.Ctx) error { - return c.SendString("Hello, World!") - }) - - h := app.Handler() - - fctx := &fasthttp.RequestCtx{} - fctx.Request.Header.SetMethod("GET") - fctx.Request.SetRequestURI("/") - - b.ResetTimer() - - for n := 0; n < b.N; n++ { - h(fctx) - } -} - // go test -v -run=^$ -bench=Benchmark_Limiter_Custom_Store -benchmem -count=4 func Benchmark_Limiter_Custom_Store(b *testing.B) { app := fiber.New() @@ -141,7 +115,7 @@ func Benchmark_Limiter_Custom_Store(b *testing.B) { app.Use(New(Config{ Max: 100, Expiration: 60 * time.Second, - Store: testStore{stmap: map[string][]byte{}, mutex: new(sync.Mutex)}, + Storage: memory.New(), })) app.Get("/", func(c *fiber.Ctx) error { @@ -203,39 +177,28 @@ func Test_Limiter_Headers(t *testing.T) { } } -// testStore is used for testing custom stores -type testStore struct { - stmap map[string][]byte - mutex *sync.Mutex -} +// go test -v -run=^$ -bench=Benchmark_Limiter -benchmem -count=4 +func Benchmark_Limiter(b *testing.B) { + app := fiber.New() -func (s testStore) Get(id string) ([]byte, error) { - s.mutex.Lock() - val, ok := s.stmap[id] - s.mutex.Unlock() - if !ok { - return nil, nil - } else { - return val, nil - } -} + app.Use(New(Config{ + Max: 100, + Expiration: 60 * time.Second, + })) -func (s testStore) Set(id string, val []byte, _ time.Duration) error { - s.mutex.Lock() - s.stmap[id] = val - s.mutex.Unlock() + app.Get("/", func(c *fiber.Ctx) error { + return c.SendString("Hello, World!") + }) - return nil -} + h := app.Handler() -func (s testStore) Reset() error { - return nil -} + fctx := &fasthttp.RequestCtx{} + fctx.Request.Header.SetMethod("GET") + fctx.Request.SetRequestURI("/") -func (s testStore) Delete(id string) error { - return nil -} + b.ResetTimer() -func (s testStore) Close() error { - return nil + for n := 0; n < b.N; n++ { + h(fctx) + } } diff --git a/middleware/logger/config.go b/middleware/logger/config.go new file mode 100644 index 0000000000..c5dd48a751 --- /dev/null +++ b/middleware/logger/config.go @@ -0,0 +1,94 @@ +package logger + +import ( + "io" + "os" + "time" + + "github.com/gofiber/fiber/v2" +) + +// Config defines the config for middleware. +type Config struct { + // Next defines a function to skip this middleware when returned true. + // + // Optional. Default: nil + Next func(c *fiber.Ctx) bool + + // Format defines the logging tags + // + // Optional. Default: [${time}] ${status} - ${latency} ${method} ${path}\n + Format string + + // TimeFormat https://programming.guide/go/format-parse-string-time-date-example.html + // + // Optional. Default: 15:04:05 + TimeFormat string + + // TimeZone can be specified, such as "UTC" and "America/New_York" and "Asia/Chongqing", etc + // + // Optional. Default: "Local" + TimeZone string + + // TimeInterval is the delay before the timestamp is updated + // + // Optional. Default: 500 * time.Millisecond + TimeInterval time.Duration + + // Output is a writter where logs are written + // + // Default: os.Stderr + Output io.Writer + + enableColors bool + enableLatency bool + timeZoneLocation *time.Location +} + +// ConfigDefault is the default config +var ConfigDefault = Config{ + Next: nil, + Format: "[${time}] ${status} - ${latency} ${method} ${path}\n", + TimeFormat: "15:04:05", + TimeZone: "Local", + TimeInterval: 500 * time.Millisecond, + Output: os.Stderr, + enableColors: true, +} + +// Helper function to set default values +func configDefault(config ...Config) Config { + // Return default config if nothing provided + if len(config) < 1 { + return ConfigDefault + } + + // Override default config + cfg := config[0] + + // Enable colors if no custom format or output is given + if cfg.Format == "" && cfg.Output == nil { + cfg.enableColors = true + } + + // Set default values + if cfg.Next == nil { + cfg.Next = ConfigDefault.Next + } + if cfg.Format == "" { + cfg.Format = ConfigDefault.Format + } + if cfg.TimeZone == "" { + cfg.TimeZone = ConfigDefault.TimeZone + } + if cfg.TimeFormat == "" { + cfg.TimeFormat = ConfigDefault.TimeFormat + } + if int(cfg.TimeInterval) <= 0 { + cfg.TimeInterval = ConfigDefault.TimeInterval + } + if cfg.Output == nil { + cfg.Output = ConfigDefault.Output + } + return cfg +} diff --git a/middleware/logger/logger.go b/middleware/logger/logger.go index d16b631e7e..48e07dc9aa 100644 --- a/middleware/logger/logger.go +++ b/middleware/logger/logger.go @@ -18,53 +18,6 @@ import ( "github.com/valyala/fasthttp" ) -// Config defines the config for middleware. -type Config struct { - // Next defines a function to skip this middleware when returned true. - // - // Optional. Default: nil - Next func(c *fiber.Ctx) bool - - // Format defines the logging tags - // - // Optional. Default: [${time}] ${status} - ${latency} ${method} ${path}\n - Format string - - // TimeFormat https://programming.guide/go/format-parse-string-time-date-example.html - // - // Optional. Default: 15:04:05 - TimeFormat string - - // TimeZone can be specified, such as "UTC" and "America/New_York" and "Asia/Chongqing", etc - // - // Optional. Default: "Local" - TimeZone string - - // TimeInterval is the delay before the timestamp is updated - // - // Optional. Default: 500 * time.Millisecond - TimeInterval time.Duration - - // Output is a writter where logs are written - // - // Default: os.Stderr - Output io.Writer - - enableColors bool - enableLatency bool - timeZoneLocation *time.Location -} - -// ConfigDefault is the default config -var ConfigDefault = Config{ - Next: nil, - Format: "[${time}] ${status} - ${latency} ${method} ${path}\n", - TimeFormat: "15:04:05", - TimeZone: "Local", - TimeInterval: 500 * time.Millisecond, - Output: os.Stderr, -} - // Logger variables const ( TagPid = "pid" @@ -117,39 +70,7 @@ const ( // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config - cfg := ConfigDefault - - // Override config if provided - if len(config) > 0 { - cfg = config[0] - - // Enable colors if no custom format or output is given - if cfg.Format == "" && cfg.Output == nil { - cfg.enableColors = true - } - - // Set default values - if cfg.Next == nil { - cfg.Next = ConfigDefault.Next - } - if cfg.Format == "" { - cfg.Format = ConfigDefault.Format - } - if cfg.TimeZone == "" { - cfg.TimeZone = ConfigDefault.TimeZone - } - if cfg.TimeFormat == "" { - cfg.TimeFormat = ConfigDefault.TimeFormat - } - if int(cfg.TimeInterval) <= 0 { - cfg.TimeInterval = ConfigDefault.TimeInterval - } - if cfg.Output == nil { - cfg.Output = ConfigDefault.Output - } - } else { - cfg.enableColors = true - } + cfg := configDefault(config...) // Get timezone location tz, err := time.LoadLocation(cfg.TimeZone) @@ -186,6 +107,7 @@ func New(config ...Config) fiber.Handler { var ( start, stop time.Time once sync.Once + mu sync.Mutex errHandler fiber.ErrorHandler ) @@ -362,6 +284,7 @@ func New(config ...Config) fiber.Handler { if err != nil { _, _ = buf.WriteString(err.Error()) } + mu.Lock() // Write buffer to output if _, err := cfg.Output.Write(buf.Bytes()); err != nil { // Write error to output @@ -370,6 +293,7 @@ func New(config ...Config) fiber.Handler { // TODO: What should we do here? } } + mu.Unlock() // Put buffer back to pool bytebufferpool.Put(buf) diff --git a/middleware/proxy/config.go b/middleware/proxy/config.go new file mode 100644 index 0000000000..9062f21799 --- /dev/null +++ b/middleware/proxy/config.go @@ -0,0 +1,55 @@ +package proxy + +import ( + "github.com/gofiber/fiber/v2" +) + +// Config defines the config for middleware. +type Config struct { + // Next defines a function to skip this middleware when returned true. + // + // Optional. Default: nil + Next func(c *fiber.Ctx) bool + + // Servers defines a list of :// HTTP servers, + // + // which are used in a round-robin manner. + // i.e.: "https://foobar.com, http://www.foobar.com" + // + // Required + Servers []string + + // ModifyRequest allows you to alter the request + // + // Optional. Default: nil + ModifyRequest fiber.Handler + + // ModifyResponse allows you to alter the response + // + // Optional. Default: nil + ModifyResponse fiber.Handler +} + +// ConfigDefault is the default config +var ConfigDefault = Config{ + Next: nil, + ModifyRequest: nil, + ModifyResponse: nil, +} + +// Helper function to set default values +func configDefault(config ...Config) Config { + // Return default config if nothing provided + if len(config) < 1 { + return ConfigDefault + } + + // Override default config + cfg := config[0] + + // Set default values + if len(cfg.Servers) == 0 { + panic("Servers cannot be empty") + } + return cfg +} diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index 8722b98f89..7724c2ed23 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -9,39 +9,6 @@ import ( "github.com/valyala/fasthttp" ) -// Config defines the config for middleware. -type Config struct { - // Next defines a function to skip this middleware when returned true. - // - // Optional. Default: nil - Next func(c *fiber.Ctx) bool - - // Servers defines a list of :// HTTP servers, - // - // which are used in a round-robin manner. - // i.e.: "https://foobar.com, http://www.foobar.com" - // - // Required - Servers []string - - // ModifyRequest allows you to alter the request - // - // Optional. Default: nil - ModifyRequest fiber.Handler - - // ModifyResponse allows you to alter the response - // - // Optional. Default: nil - ModifyResponse fiber.Handler -} - -// ConfigDefault is the default config -var ConfigDefault = Config{ - Next: nil, - ModifyRequest: nil, - ModifyResponse: nil, -} - // New is deprecated func New(config Config) fiber.Handler { fmt.Println("proxy.New is deprecated, please use proxy.Balancer instead") @@ -50,16 +17,8 @@ func New(config Config) fiber.Handler { // Balancer creates a load balancer among multiple upstream servers func Balancer(config Config) fiber.Handler { - // Override config if provided - cfg := config - - // Set default values - if cfg.Next == nil { - cfg.Next = ConfigDefault.Next - } - if len(cfg.Servers) == 0 { - panic("Servers cannot be empty") - } + // Set default config + cfg := configDefault(config) client := fasthttp.Client{ NoDefaultUserAgentHeader: true, diff --git a/middleware/recover/config.go b/middleware/recover/config.go new file mode 100644 index 0000000000..550fed9865 --- /dev/null +++ b/middleware/recover/config.go @@ -0,0 +1,31 @@ +package recover + +import ( + "github.com/gofiber/fiber/v2" +) + +// Config defines the config for middleware. +type Config struct { + // Next defines a function to skip this middleware when returned true. + // + // Optional. Default: nil + Next func(c *fiber.Ctx) bool +} + +// ConfigDefault is the default config +var ConfigDefault = Config{ + Next: nil, +} + +// Helper function to set default values +func configDefault(config ...Config) Config { + // Return default config if nothing provided + if len(config) < 1 { + return ConfigDefault + } + + // Override default config + cfg := config[0] + + return cfg +} diff --git a/middleware/recover/recover.go b/middleware/recover/recover.go index 8ee4b30f2b..89941c5a7b 100644 --- a/middleware/recover/recover.go +++ b/middleware/recover/recover.go @@ -6,28 +6,10 @@ import ( "github.com/gofiber/fiber/v2" ) -// Config defines the config for middleware. -type Config struct { - // Next defines a function to skip this middleware when returned true. - // - // Optional. Default: nil - Next func(c *fiber.Ctx) bool -} - -// ConfigDefault is the default config -var ConfigDefault = Config{ - Next: nil, -} - // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config - cfg := ConfigDefault - - // Override config if provided - if len(config) > 0 { - cfg = config[0] - } + cfg := configDefault(config...) // Return new handler return func(c *fiber.Ctx) (err error) { diff --git a/middleware/requestid/config.go b/middleware/requestid/config.go new file mode 100644 index 0000000000..ace51ad4a7 --- /dev/null +++ b/middleware/requestid/config.go @@ -0,0 +1,61 @@ +package requestid + +import ( + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/utils" +) + +// Config defines the config for middleware. +type Config struct { + // Next defines a function to skip this middleware when returned true. + // + // Optional. Default: nil + Next func(c *fiber.Ctx) bool + + // Header is the header key where to get/set the unique request ID + // + // Optional. Default: "X-Request-ID" + Header string + + // Generator defines a function to generate the unique identifier. + // + // Optional. Default: utils.UUID + Generator func() string + + // ContextKey defines the key used when storing the request ID in + // the locals for a specific request. + // + // Optional. Default: requestid + ContextKey string +} + +// ConfigDefault is the default config +var ConfigDefault = Config{ + Next: nil, + Header: fiber.HeaderXRequestID, + Generator: utils.UUID, + ContextKey: "requestid", +} + +// Helper function to set default values +func configDefault(config ...Config) Config { + // Return default config if nothing provided + if len(config) < 1 { + return ConfigDefault + } + + // Override default config + cfg := config[0] + + // Set default values + if cfg.Header == "" { + cfg.Header = ConfigDefault.Header + } + if cfg.Generator == nil { + cfg.Generator = ConfigDefault.Generator + } + if cfg.ContextKey == "" { + cfg.ContextKey = ConfigDefault.ContextKey + } + return cfg +} diff --git a/middleware/requestid/requestid.go b/middleware/requestid/requestid.go index b0a4512846..cc77803815 100644 --- a/middleware/requestid/requestid.go +++ b/middleware/requestid/requestid.go @@ -2,61 +2,12 @@ package requestid import ( "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/utils" ) -// Config defines the config for middleware. -type Config struct { - // Next defines a function to skip this middleware when returned true. - // - // Optional. Default: nil - Next func(c *fiber.Ctx) bool - - // Header is the header key where to get/set the unique request ID - // - // Optional. Default: "X-Request-ID" - Header string - - // Generator defines a function to generate the unique identifier. - // - // Optional. Default: utils.UUID - Generator func() string - - // ContextKey defines the key used when storing the request ID in - // the locals for a specific request. - // - // Optional. Default: requestid - ContextKey string -} - -// ConfigDefault is the default config -var ConfigDefault = Config{ - Next: nil, - Header: fiber.HeaderXRequestID, - Generator: utils.UUID, - ContextKey: "requestid", -} - // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config - cfg := ConfigDefault - - // Override config if provided - if len(config) > 0 { - cfg = config[0] - - // Set default values - if cfg.Header == "" { - cfg.Header = ConfigDefault.Header - } - if cfg.Generator == nil { - cfg.Generator = ConfigDefault.Generator - } - if cfg.ContextKey == "" { - cfg.ContextKey = ConfigDefault.ContextKey - } - } + cfg := configDefault(config...) // Return new handler return func(c *fiber.Ctx) error { diff --git a/middleware/session/README.md b/middleware/session/README.md new file mode 100644 index 0000000000..9ecaca51ba --- /dev/null +++ b/middleware/session/README.md @@ -0,0 +1,108 @@ +# Session +Session middleware for [Fiber](https://github.com/gofiber/fiber) + +### Table of Contents +- [Signatures](#signatures) +- [Examples](#examples) +- [Config](#config) +- [Default Config](#default-config) + + +### Signatures +```go +func New(config ...Config) fiber.Handler +``` + +### Examples +Import the middleware package that is part of the Fiber web framework +```go +import ( + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/session" +) +``` + +After you initiate your Fiber app, you can use the following possibilities: +```go +// Default middleware config +store := session.New() + +// This panic will be catch by the middleware +app.Get("/", func(c *fiber.Ctx) error { + // get session from storage + sess, err := store.Get(c) + if err != nil { + panic(err) + } + + // save session + defer sess.Save() + + // Get value + name := sess.Get("name") + + // Set key/value + sess.Set("name", "john") + + // Delete key + sess.Delete("name") + + // Destry session + if err := sess.Destroy(); err != nil { + panic(err) + } + + return fmt.Fprintf(ctx, "Welcome %v", name) +}) +``` + +### Config +```go +// Config defines the config for middleware. +type Config struct { + // Allowed session duration + // Optional. Default value 24 * time.Hour + Expiration time.Duration + + // Storage interface to store the session data + // Optional. Default value memory.New() + Storage fiber.Storage + + // Name of the session cookie. This cookie will store session key. + // Optional. Default value "session_id". + CookieName string + + // Domain of the CSRF cookie. + // Optional. Default value "". + CookieDomain string + + // Path of the CSRF cookie. + // Optional. Default value "". + CookiePath string + + // Indicates if CSRF cookie is secure. + // Optional. Default value false. + CookieSecure bool + + // Indicates if CSRF cookie is HTTP only. + // Optional. Default value false. + CookieHTTPOnly bool + + // Indicates if CSRF cookie is HTTP only. + // Optional. Default value false. + CookieSameSite string + + // KeyGenerator generates the session key. + // Optional. Default value utils.UUID + KeyGenerator func() string +} +``` + +### Default Config +```go +var ConfigDefault = Config{ + Expiration: 24 * time.Hour, + CookieName: "session_id", + KeyGenerator: utils.UUID, +} +``` \ No newline at end of file diff --git a/middleware/session/config.go b/middleware/session/config.go new file mode 100644 index 0000000000..788783d9f5 --- /dev/null +++ b/middleware/session/config.go @@ -0,0 +1,74 @@ +package session + +import ( + "time" + + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/utils" +) + +// Config defines the config for middleware. +type Config struct { + // Allowed session duration + // Optional. Default value 24 * time.Hour + Expiration time.Duration + + // Storage interface to store the session data + // Optional. Default value memory.New() + Storage fiber.Storage + + // Name of the session cookie. This cookie will store session key. + // Optional. Default value "session_id". + CookieName string + + // Domain of the CSRF cookie. + // Optional. Default value "". + CookieDomain string + + // Path of the CSRF cookie. + // Optional. Default value "". + CookiePath string + + // Indicates if CSRF cookie is secure. + // Optional. Default value false. + CookieSecure bool + + // Indicates if CSRF cookie is HTTP only. + // Optional. Default value false. + CookieHTTPOnly bool + + // Indicates if CSRF cookie is HTTP only. + // Optional. Default value false. + CookieSameSite string + + // KeyGenerator generates the session key. + // Optional. Default value utils.UUID + KeyGenerator func() string +} + +// ConfigDefault is the default config +var ConfigDefault = Config{ + Expiration: 24 * time.Hour, + CookieName: "session_id", + KeyGenerator: utils.UUID, +} + +// Helper function to set default values +func configDefault(config ...Config) Config { + // Return default config if nothing provided + if len(config) < 1 { + return ConfigDefault + } + + // Override default config + cfg := config[0] + + // Set default values + if int(cfg.Expiration.Seconds()) <= 0 { + cfg.Expiration = ConfigDefault.Expiration + } + if cfg.KeyGenerator == nil { + cfg.KeyGenerator = ConfigDefault.KeyGenerator + } + return cfg +} diff --git a/middleware/session/db.go b/middleware/session/db.go new file mode 100644 index 0000000000..7e0fbbee9f --- /dev/null +++ b/middleware/session/db.go @@ -0,0 +1,84 @@ +package session + +// go:generate msgp +// msgp -file="db.go" -o="db_msgp.go" -tests=false -unexported +// don't forget to replace the msgp import path to: +// "github.com/gofiber/fiber/v2/internal/msgp" +type db struct { + d []kv +} + +// go:generate msgp +type kv struct { + k string + v interface{} +} + +func (d *db) Reset() { + d.d = d.d[:0] +} + +func (d *db) Get(key string) interface{} { + idx := d.indexOf(key) + if idx > -1 { + return d.d[idx].v + } + return nil +} + +func (d *db) Set(key string, value interface{}) { + idx := d.indexOf(key) + if idx > -1 { + kv := &d.d[idx] + kv.v = value + } else { + d.append(key, value) + } +} + +func (d *db) Delete(key string) { + idx := d.indexOf(key) + if idx > -1 { + n := len(d.d) - 1 + d.swap(idx, n) + d.d = d.d[:n] + } +} + +func (d *db) Len() int { + return len(d.d) +} + +func (d *db) swap(i, j int) { + iKey, iValue := d.d[i].k, d.d[i].v + jKey, jValue := d.d[j].k, d.d[j].v + + d.d[i].k, d.d[i].v = jKey, jValue + d.d[j].k, d.d[j].v = iKey, iValue +} + +func (d *db) allocPage() *kv { + n := len(d.d) + if cap(d.d) > n { + d.d = d.d[:n+1] + } else { + d.d = append(d.d, kv{}) + } + return &d.d[n] +} + +func (d *db) append(key string, value interface{}) { + kv := d.allocPage() + kv.k = key + kv.v = value +} + +func (d *db) indexOf(key string) int { + n := len(d.d) + for i := 0; i < n; i++ { + if d.d[i].k == key { + return i + } + } + return -1 +} diff --git a/middleware/session/db_msgp.go b/middleware/session/db_msgp.go new file mode 100644 index 0000000000..ffec03973c --- /dev/null +++ b/middleware/session/db_msgp.go @@ -0,0 +1,365 @@ +package session + +// Code generated by github.com/tinylib/msgp DO NOT EDIT. + +import ( + "github.com/gofiber/fiber/v2/internal/msgp" +) + +// DecodeMsg implements msgp.Decodable +func (z *db) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "d": + var zb0002 uint32 + zb0002, err = dc.ReadArrayHeader() + if err != nil { + err = msgp.WrapError(err, "d") + return + } + if cap(z.d) >= int(zb0002) { + z.d = (z.d)[:zb0002] + } else { + z.d = make([]kv, zb0002) + } + for za0001 := range z.d { + var zb0003 uint32 + zb0003, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err, "d", za0001) + return + } + for zb0003 > 0 { + zb0003-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err, "d", za0001) + return + } + switch msgp.UnsafeString(field) { + case "k": + z.d[za0001].k, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "d", za0001, "k") + return + } + case "v": + z.d[za0001].v, err = dc.ReadIntf() + if err != nil { + err = msgp.WrapError(err, "d", za0001, "v") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err, "d", za0001) + return + } + } + } + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z *db) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 1 + // write "d" + err = en.Append(0x81, 0xa1, 0x64) + if err != nil { + return + } + err = en.WriteArrayHeader(uint32(len(z.d))) + if err != nil { + err = msgp.WrapError(err, "d") + return + } + for za0001 := range z.d { + // map header, size 2 + // write "k" + err = en.Append(0x82, 0xa1, 0x6b) + if err != nil { + return + } + err = en.WriteString(z.d[za0001].k) + if err != nil { + err = msgp.WrapError(err, "d", za0001, "k") + return + } + // write "v" + err = en.Append(0xa1, 0x76) + if err != nil { + return + } + err = en.WriteIntf(z.d[za0001].v) + if err != nil { + err = msgp.WrapError(err, "d", za0001, "v") + return + } + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z *db) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 1 + // string "d" + o = append(o, 0x81, 0xa1, 0x64) + o = msgp.AppendArrayHeader(o, uint32(len(z.d))) + for za0001 := range z.d { + // map header, size 2 + // string "k" + o = append(o, 0x82, 0xa1, 0x6b) + o = msgp.AppendString(o, z.d[za0001].k) + // string "v" + o = append(o, 0xa1, 0x76) + o, err = msgp.AppendIntf(o, z.d[za0001].v) + if err != nil { + err = msgp.WrapError(err, "d", za0001, "v") + return + } + } + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *db) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "d": + var zb0002 uint32 + zb0002, bts, err = msgp.ReadArrayHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err, "d") + return + } + if cap(z.d) >= int(zb0002) { + z.d = (z.d)[:zb0002] + } else { + z.d = make([]kv, zb0002) + } + for za0001 := range z.d { + var zb0003 uint32 + zb0003, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err, "d", za0001) + return + } + for zb0003 > 0 { + zb0003-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err, "d", za0001) + return + } + switch msgp.UnsafeString(field) { + case "k": + z.d[za0001].k, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "d", za0001, "k") + return + } + case "v": + z.d[za0001].v, bts, err = msgp.ReadIntfBytes(bts) + if err != nil { + err = msgp.WrapError(err, "d", za0001, "v") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err, "d", za0001) + return + } + } + } + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z *db) Msgsize() (s int) { + s = 1 + 2 + msgp.ArrayHeaderSize + for za0001 := range z.d { + s += 1 + 2 + msgp.StringPrefixSize + len(z.d[za0001].k) + 2 + msgp.GuessSize(z.d[za0001].v) + } + return +} + +// DecodeMsg implements msgp.Decodable +func (z *kv) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "k": + z.k, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "k") + return + } + case "v": + z.v, err = dc.ReadIntf() + if err != nil { + err = msgp.WrapError(err, "v") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z kv) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 2 + // write "k" + err = en.Append(0x82, 0xa1, 0x6b) + if err != nil { + return + } + err = en.WriteString(z.k) + if err != nil { + err = msgp.WrapError(err, "k") + return + } + // write "v" + err = en.Append(0xa1, 0x76) + if err != nil { + return + } + err = en.WriteIntf(z.v) + if err != nil { + err = msgp.WrapError(err, "v") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z kv) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 2 + // string "k" + o = append(o, 0x82, 0xa1, 0x6b) + o = msgp.AppendString(o, z.k) + // string "v" + o = append(o, 0xa1, 0x76) + o, err = msgp.AppendIntf(o, z.v) + if err != nil { + err = msgp.WrapError(err, "v") + return + } + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *kv) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "k": + z.k, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "k") + return + } + case "v": + z.v, bts, err = msgp.ReadIntfBytes(bts) + if err != nil { + err = msgp.WrapError(err, "v") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z kv) Msgsize() (s int) { + s = 1 + 2 + msgp.StringPrefixSize + len(z.k) + 2 + msgp.GuessSize(z.v) + return +} diff --git a/middleware/session/session.go b/middleware/session/session.go new file mode 100644 index 0000000000..8d34dd22ca --- /dev/null +++ b/middleware/session/session.go @@ -0,0 +1,172 @@ +package session + +import ( + "sync" + "time" + + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/utils" + "github.com/valyala/fasthttp" +) + +type Session struct { + ctx *fiber.Ctx + config *Store + db *db + id string + fresh bool +} + +var sessionPool = sync.Pool{ + New: func() interface{} { + return new(Session) + }, +} + +func acquireSession() *Session { + s := sessionPool.Get().(*Session) + s.db = new(db) + s.fresh = true + return s +} + +func releaseSession(s *Session) { + s.ctx = nil + s.config = nil + if s.db != nil { + s.db.Reset() + } + s.id = "" + s.fresh = true + sessionPool.Put(s) +} + +// Fresh is true if the current session is new +func (s *Session) Fresh() bool { + return s.fresh +} + +// ID returns the session id +func (s *Session) ID() string { + return s.id +} + +// Get will return the value +func (s *Session) Get(key string) interface{} { + return s.db.Get(key) +} + +// Set will update or create a new key value +func (s *Session) Set(key string, val interface{}) { + s.db.Set(key, val) +} + +// Delete will delete the value +func (s *Session) Delete(key string) { + s.db.Delete(key) +} + +// Destroy will delete the session from Storage and expire session cookie +func (s *Session) Destroy() error { + // Reset local data + s.db.Reset() + + // Delete data from storage + if err := s.config.Storage.Delete(s.id); err != nil { + return err + } + + // Expire cookie + s.delCookie() + return nil +} + +// Regenerate generates a new session id and delete the old one from Storage +func (s *Session) Regenerate() error { + + // Delete old id from storage + if err := s.config.Storage.Delete(s.id); err != nil { + return err + } + // Create new ID + s.id = s.config.KeyGenerator() + + return nil +} + +// Save will update the storage and client cookie +func (s *Session) Save() error { + // Don't save to Storage if no data is available + if s.db.Len() <= 0 { + return nil + } + + // Convert book to bytes + data, err := s.db.MarshalMsg(nil) + if err != nil { + return err + } + + // pass raw bytes with session id to provider + if err := s.config.Storage.Set(s.id, data, s.config.Expiration); err != nil { + return err + } + + // Create cookie with the session ID + s.setCookie() + + // release session to pool to be re-used on next request + releaseSession(s) + + return nil +} + +func (s *Session) setCookie() { + fcookie := fasthttp.AcquireCookie() + fcookie.SetKey(s.config.CookieName) + fcookie.SetValue(s.id) + fcookie.SetPath(s.config.CookiePath) + fcookie.SetDomain(s.config.CookieDomain) + fcookie.SetMaxAge(int(s.config.Expiration.Seconds())) + fcookie.SetExpire(time.Now().Add(s.config.Expiration)) + fcookie.SetSecure(s.config.CookieSecure) + fcookie.SetHTTPOnly(s.config.CookieHTTPOnly) + + switch utils.ToLower(s.config.CookieSameSite) { + case "strict": + fcookie.SetSameSite(fasthttp.CookieSameSiteStrictMode) + case "none": + fcookie.SetSameSite(fasthttp.CookieSameSiteNoneMode) + default: + fcookie.SetSameSite(fasthttp.CookieSameSiteLaxMode) + } + + s.ctx.Response().Header.SetCookie(fcookie) + fasthttp.ReleaseCookie(fcookie) +} + +func (s *Session) delCookie() { + s.ctx.Request().Header.DelCookie(s.config.CookieName) + s.ctx.Response().Header.DelCookie(s.config.CookieName) + + fcookie := fasthttp.AcquireCookie() + fcookie.SetKey(s.config.CookieName) + fcookie.SetPath(s.config.CookiePath) + fcookie.SetDomain(s.config.CookieDomain) + fcookie.SetMaxAge(-1) + fcookie.SetExpire(time.Now().Add(-1 * time.Minute)) + fcookie.SetSecure(s.config.CookieSecure) + fcookie.SetHTTPOnly(s.config.CookieHTTPOnly) + + switch utils.ToLower(s.config.CookieSameSite) { + case "strict": + fcookie.SetSameSite(fasthttp.CookieSameSiteStrictMode) + case "none": + fcookie.SetSameSite(fasthttp.CookieSameSiteNoneMode) + default: + fcookie.SetSameSite(fasthttp.CookieSameSiteLaxMode) + } + + s.ctx.Response().Header.SetCookie(fcookie) + fasthttp.ReleaseCookie(fcookie) +} diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go new file mode 100644 index 0000000000..26b0c2795f --- /dev/null +++ b/middleware/session/session_test.go @@ -0,0 +1,172 @@ +package session + +import ( + "testing" + "time" + + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/utils" + "github.com/valyala/fasthttp" +) + +// go test -run Test_Session +func Test_Session(t *testing.T) { + t.Parallel() + + // session store + store := New() + + // fiber instance + app := fiber.New() + + // fiber context + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + + // set cookie + ctx.Request().Header.SetCookie(store.CookieName, "123") + + // get session + sess, err := store.Get(ctx) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, true, sess.Fresh()) + + // get value + name := sess.Get("name") + utils.AssertEqual(t, nil, name) + + // set value + sess.Set("name", "john") + + // get value + name = sess.Get("name") + utils.AssertEqual(t, "john", name) + + // delete key + sess.Delete("name") + + // get value + name = sess.Get("name") + utils.AssertEqual(t, nil, name) + + // get id + id := sess.ID() + utils.AssertEqual(t, "123", id) + + // delete cookie + ctx.Request().Header.Del(fiber.HeaderCookie) + + // get session + sess, err = store.Get(ctx) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, true, sess.Fresh()) + + // get id + id = sess.ID() + utils.AssertEqual(t, 36, len(id)) +} + +// go test -run Test_Session_Store_Reset +func Test_Session_Store_Reset(t *testing.T) { + t.Parallel() + // session store + store := New() + // fiber instance + app := fiber.New() + // fiber context + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + + // get session + sess, _ := store.Get(ctx) + // make sure its new + utils.AssertEqual(t, true, sess.Fresh()) + // set value & save + sess.Set("hello", "world") + ctx.Request().Header.SetCookie(store.CookieName, sess.ID()) + sess.Save() + + // reset store + store.Reset() + + // make sure the session is recreated + sess, _ = store.Get(ctx) + utils.AssertEqual(t, true, sess.Fresh()) + utils.AssertEqual(t, nil, sess.Get("hello")) +} + +// go test -run Test_Session_Save +func Test_Session_Save(t *testing.T) { + t.Parallel() + + // session store + store := New() + + // fiber instance + app := fiber.New() + + // fiber context + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + + // get store + sess, _ := store.Get(ctx) + + // set value + sess.Set("name", "john") + + // save session + err := sess.Save() + utils.AssertEqual(t, nil, err) + +} + +// go test -run Test_Session_Reset +func Test_Session_Reset(t *testing.T) { + t.Parallel() + // session store + store := New() + // fiber instance + app := fiber.New() + // fiber context + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + // get session + sess, _ := store.Get(ctx) + + sess.Set("name", "fenny") + sess.Destroy() + name := sess.Get("name") + utils.AssertEqual(t, nil, name) +} + +// go test -run Test_Session_Custom_Config +func Test_Session_Custom_Config(t *testing.T) { + t.Parallel() + + store := New(Config{Expiration: time.Hour, KeyGenerator: func() string { return "very random" }}) + utils.AssertEqual(t, time.Hour, store.Expiration) + utils.AssertEqual(t, "very random", store.KeyGenerator()) + + store = New(Config{Expiration: 0}) + utils.AssertEqual(t, ConfigDefault.Expiration, store.Expiration) +} + +// go test -run Test_Session_Cookie +func Test_Session_Cookie(t *testing.T) { + t.Parallel() + // session store + store := New() + // fiber instance + app := fiber.New() + // fiber context + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + + // get session + sess, _ := store.Get(ctx) + sess.Save() + + // cookie should not be set if empty data + utils.AssertEqual(t, 0, len(ctx.Response().Header.PeekCookie(store.CookieName))) +} diff --git a/middleware/session/store.go b/middleware/session/store.go new file mode 100644 index 0000000000..2b12e919af --- /dev/null +++ b/middleware/session/store.go @@ -0,0 +1,67 @@ +package session + +import ( + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/internal/storage/memory" +) + +type Store struct { + Config +} + +// Storage ErrNotExist +var errNotExist = "key does not exist" + +func New(config ...Config) *Store { + // Set default config + cfg := configDefault(config...) + + if cfg.Storage == nil { + cfg.Storage = memory.New() + } + + return &Store{ + cfg, + } +} + +func (s *Store) Get(c *fiber.Ctx) (*Session, error) { + var fresh bool + + // Get key from cookie + id := c.Cookies(s.CookieName) + + // If no key exist, create new one + if len(id) == 0 { + id = s.KeyGenerator() + fresh = true + } + + // Create session object + sess := acquireSession() + sess.ctx = c + sess.config = s + sess.id = id + + // Fetch existing data + if !fresh { + raw, err := s.Storage.Get(id) + // Unmashal if we found data + if err == nil { + if _, err = sess.db.UnmarshalMsg(raw); err != nil { + return nil, err + } + sess.fresh = false + } else if err.Error() != errNotExist { + // Only return error if it's not ErrNotExist + return nil, err + } + } + + return sess, nil +} + +// Reset will delete all session from the storage +func (s *Store) Reset() error { + return s.Storage.Reset() +} diff --git a/router.go b/router.go index 081c6c67d7..909bd11301 100644 --- a/router.go +++ b/router.go @@ -349,7 +349,9 @@ func (app *App) registerStatic(prefix, root string, config ...Static) Router { if maxAge > 0 { cacheControlValue = "public, max-age=" + strconv.Itoa(maxAge) } - + if config[0].CacheDuration != 0 { + fs.CacheDuration = config[0].CacheDuration + } fs.Compress = config[0].Compress fs.AcceptByteRange = config[0].ByteRange fs.GenerateIndexPages = config[0].Browse