From 66a881441b27322a331f1b526cf1eb6b3358a4d8 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Sun, 30 Jun 2024 16:16:23 -0300 Subject: [PATCH] fix(middleware/session): mutex for thread safety (#3050) * chore: Remove extra release and acquire ctx calls in session_test.go * feat: Remove unnecessary session mutex lock in decodeSessionData function * chore: Refactor session benchmark tests * fix(middleware/session): mutex for thread safety * feat: Add session mutex lock for thread safety * chore: Refactor releaseSession mutex --- middleware/session/session.go | 49 +++++- middleware/session/session_test.go | 229 +++++++++++++++++++++++++++++ middleware/session/store.go | 22 +-- 3 files changed, 276 insertions(+), 24 deletions(-) diff --git a/middleware/session/session.go b/middleware/session/session.go index ebe00f6057..b4115faf23 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -14,6 +14,7 @@ import ( ) type Session struct { + mu sync.RWMutex // Mutex to protect non-data fields id string // session id fresh bool // if new session ctx *fiber.Ctx // fiber context @@ -42,6 +43,7 @@ func acquireSession() *Session { } func releaseSession(s *Session) { + s.mu.Lock() s.id = "" s.exp = 0 s.ctx = nil @@ -52,16 +54,21 @@ func releaseSession(s *Session) { if s.byteBuffer != nil { s.byteBuffer.Reset() } + s.mu.Unlock() sessionPool.Put(s) } // Fresh is true if the current session is new func (s *Session) Fresh() bool { + s.mu.RLock() + defer s.mu.RUnlock() return s.fresh } // ID returns the session id func (s *Session) ID() string { + s.mu.RLock() + defer s.mu.RUnlock() return s.id } @@ -102,6 +109,9 @@ func (s *Session) Destroy() error { // Reset local data s.data.Reset() + s.mu.RLock() + defer s.mu.RUnlock() + // Use external Storage if exist if err := s.config.Storage.Delete(s.id); err != nil { return err @@ -114,6 +124,9 @@ func (s *Session) Destroy() error { // Regenerate generates a new session id and delete the old one from Storage func (s *Session) Regenerate() error { + s.mu.Lock() + defer s.mu.Unlock() + // Delete old id from storage if err := s.config.Storage.Delete(s.id); err != nil { return err @@ -131,6 +144,10 @@ func (s *Session) Reset() error { if s.data != nil { s.data.Reset() } + + s.mu.Lock() + defer s.mu.Unlock() + // Reset byte buffer if s.byteBuffer != nil { s.byteBuffer.Reset() @@ -154,20 +171,24 @@ func (s *Session) Reset() error { // refresh generates a new session, and set session.fresh to be true func (s *Session) refresh() { - // Create a new id s.id = s.config.KeyGenerator() - - // We assign a new id to the session, so the session must be fresh s.fresh = true } // Save will update the storage and client cookie +// +// sess.Save() will save the session data to the storage and update the +// client cookie, and it will release the session after saving. +// +// It's not safe to use the session after calling Save(). func (s *Session) Save() error { // Better safe than sorry if s.data == nil { return nil } + s.mu.Lock() + // Check if session has your own expiration, otherwise use default value if s.exp <= 0 { s.exp = s.config.Expiration @@ -177,25 +198,25 @@ func (s *Session) Save() error { s.setSession() // Convert data to bytes - mux.Lock() - defer mux.Unlock() encCache := gob.NewEncoder(s.byteBuffer) err := encCache.Encode(&s.data.Data) if err != nil { return fmt.Errorf("failed to encode data: %w", err) } - // copy the data in buffer + // Copy the data in buffer encodedBytes := make([]byte, s.byteBuffer.Len()) copy(encodedBytes, s.byteBuffer.Bytes()) - // pass copied bytes with session id to provider + // Pass copied bytes with session id to provider if err := s.config.Storage.Set(s.id, encodedBytes, s.exp); err != nil { return err } + s.mu.Unlock() + // Release session - // TODO: It's not safe to use the Session after called Save() + // TODO: It's not safe to use the Session after calling Save() releaseSession(s) return nil @@ -211,6 +232,8 @@ func (s *Session) Keys() []string { // SetExpiry sets a specific expiration for this session func (s *Session) SetExpiry(exp time.Duration) { + s.mu.Lock() + defer s.mu.Unlock() s.exp = exp } @@ -276,3 +299,13 @@ func (s *Session) delSession() { fasthttp.ReleaseCookie(fcookie) } } + +// decodeSessionData decodes the session data from raw bytes. +func (s *Session) decodeSessionData(rawData []byte) error { + _, _ = s.byteBuffer.Write(rawData) //nolint:errcheck // This will never fail + encCache := gob.NewDecoder(s.byteBuffer) + if err := encCache.Decode(&s.data.Data); err != nil { + return fmt.Errorf("failed to decode session data: %w", err) + } + return nil +} diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index 5345393207..ccd3fd0690 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -1,6 +1,8 @@ package session import ( + "errors" + "sync" "testing" "time" @@ -673,3 +675,230 @@ func Benchmark_Session(b *testing.B) { utils.AssertEqual(b, nil, err) }) } + +// go test -v -run=^$ -bench=Benchmark_Session_Parallel -benchmem -count=4 +func Benchmark_Session_Parallel(b *testing.B) { + b.Run("default", func(b *testing.B) { + app, store := fiber.New(), New() + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + c.Request().Header.SetCookie(store.sessionName, "12356789") + + sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark + sess.Set("john", "doe") + _ = sess.Save() //nolint:errcheck // We're inside a benchmark + app.ReleaseCtx(c) + } + }) + }) + + b.Run("storage", func(b *testing.B) { + app := fiber.New() + store := New(Config{ + Storage: memory.New(), + }) + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + c.Request().Header.SetCookie(store.sessionName, "12356789") + + sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark + sess.Set("john", "doe") + _ = sess.Save() //nolint:errcheck // We're inside a benchmark + app.ReleaseCtx(c) + } + }) + }) +} + +// go test -v -run=^$ -bench=Benchmark_Session_Asserted -benchmem -count=4 +func Benchmark_Session_Asserted(b *testing.B) { + b.Run("default", func(b *testing.B) { + app, store := fiber.New(), New() + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + c.Request().Header.SetCookie(store.sessionName, "12356789") + + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + sess, err := store.Get(c) + utils.AssertEqual(b, nil, err) + sess.Set("john", "doe") + err = sess.Save() + utils.AssertEqual(b, nil, err) + } + }) + + b.Run("storage", func(b *testing.B) { + app := fiber.New() + store := New(Config{ + Storage: memory.New(), + }) + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + c.Request().Header.SetCookie(store.sessionName, "12356789") + + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + sess, err := store.Get(c) + utils.AssertEqual(b, nil, err) + sess.Set("john", "doe") + err = sess.Save() + utils.AssertEqual(b, nil, err) + } + }) +} + +// go test -v -run=^$ -bench=Benchmark_Session_Asserted_Parallel -benchmem -count=4 +func Benchmark_Session_Asserted_Parallel(b *testing.B) { + b.Run("default", func(b *testing.B) { + app, store := fiber.New(), New() + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + c.Request().Header.SetCookie(store.sessionName, "12356789") + + sess, err := store.Get(c) + utils.AssertEqual(b, nil, err) + sess.Set("john", "doe") + utils.AssertEqual(b, nil, sess.Save()) + app.ReleaseCtx(c) + } + }) + }) + + b.Run("storage", func(b *testing.B) { + app := fiber.New() + store := New(Config{ + Storage: memory.New(), + }) + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + c.Request().Header.SetCookie(store.sessionName, "12356789") + + sess, err := store.Get(c) + utils.AssertEqual(b, nil, err) + sess.Set("john", "doe") + utils.AssertEqual(b, nil, sess.Save()) + app.ReleaseCtx(c) + } + }) + }) +} + +// go test -v -race -run Test_Session_Concurrency ./... +func Test_Session_Concurrency(t *testing.T) { + t.Parallel() + app := fiber.New() + store := New() + + var wg sync.WaitGroup + errChan := make(chan error, 10) // Buffered channel to collect errors + const numGoroutines = 10 // Number of concurrent goroutines to test + + // Start numGoroutines goroutines + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + localCtx := app.AcquireCtx(&fasthttp.RequestCtx{}) + + sess, err := store.Get(localCtx) + if err != nil { + errChan <- err + return + } + + // Set a value + sess.Set("name", "john") + + // get the session id + id := sess.ID() + + // Check if the session is fresh + if !sess.Fresh() { + errChan <- errors.New("session should be fresh") + return + } + + // Save the session + if err := sess.Save(); err != nil { + errChan <- err + return + } + + // Release the context + app.ReleaseCtx(localCtx) + + // Acquire a new context + localCtx = app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(localCtx) + + // Set the session id in the header + localCtx.Request().Header.SetCookie(store.sessionName, id) + + // Get the session + sess, err = store.Get(localCtx) + if err != nil { + errChan <- err + return + } + + // Get the value + name := sess.Get("name") + if name != "john" { + errChan <- errors.New("name should be john") + return + } + + // Get ID from the session + if sess.ID() != id { + errChan <- errors.New("id should be the same") + return + } + + // Check if the session is fresh + if sess.Fresh() { + errChan <- errors.New("session should not be fresh") + return + } + + // Delete the key + sess.Delete("name") + + // Get the value + name = sess.Get("name") + if name != nil { + errChan <- errors.New("name should be nil") + return + } + + // Destroy the session + if err := sess.Destroy(); err != nil { + errChan <- err + return + } + }() + } + + wg.Wait() // Wait for all goroutines to finish + close(errChan) // Close the channel to signal no more errors will be sent + + // Check for errors sent to errChan + for err := range errChan { + utils.AssertEqual(t, nil, err) + } +} diff --git a/middleware/session/store.go b/middleware/session/store.go index 3b692f4928..65db815d13 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -4,7 +4,6 @@ import ( "encoding/gob" "errors" "fmt" - "sync" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/internal/storage/memory" @@ -14,9 +13,6 @@ import ( // ErrEmptySessionID is an error that occurs when the session ID is empty. var ErrEmptySessionID = errors.New("session id cannot be empty") -// mux is a global mutex for session operations. -var mux sync.Mutex - // sessionIDKey is the local key type used to store and retrieve the session ID in context. type sessionIDKey int @@ -81,6 +77,10 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) { // Create session object sess := acquireSession() + + sess.mu.Lock() + defer sess.mu.Unlock() + sess.ctx = c sess.config = s sess.id = id @@ -88,6 +88,8 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) { // Decode session data if found if rawData != nil { + sess.data.Lock() + defer sess.data.Unlock() if err := sess.decodeSessionData(rawData); err != nil { return nil, fmt.Errorf("failed to decode session data: %w", err) } @@ -132,15 +134,3 @@ func (s *Store) Delete(id string) error { } return s.Storage.Delete(id) } - -// decodeSessionData decodes the session data from raw bytes. -func (s *Session) decodeSessionData(rawData []byte) error { - mux.Lock() - defer mux.Unlock() - _, _ = s.byteBuffer.Write(rawData) //nolint:errcheck // This will never fail - encCache := gob.NewDecoder(s.byteBuffer) - if err := encCache.Decode(&s.data.Data); err != nil { - return fmt.Errorf("failed to decode session data: %w", err) - } - return nil -}