diff --git a/x/blockdb/cache_db.go b/x/blockdb/cache_db.go index 12e4edbc14d3..7d415911fb1c 100644 --- a/x/blockdb/cache_db.go +++ b/x/blockdb/cache_db.go @@ -5,7 +5,7 @@ package blockdb import ( "slices" - "sync" + "sync/atomic" "go.uber.org/zap" @@ -22,11 +22,9 @@ var _ database.HeightIndex = (*cacheDB)(nil) // the cache and database contain different values. This limitation is acceptable // because concurrent writes to the same height are not an intended use case. type cacheDB struct { - db *Database - cache *lru.Cache[BlockHeight, BlockData] - - closeMu sync.RWMutex - closed bool + db *Database + cache *lru.Cache[BlockHeight, BlockData] + closed atomic.Bool } func newCacheDB(db *Database, size uint16) *cacheDB { @@ -37,10 +35,7 @@ func newCacheDB(db *Database, size uint16) *cacheDB { } func (c *cacheDB) Get(height BlockHeight) (BlockData, error) { - c.closeMu.RLock() - defer c.closeMu.RUnlock() - - if c.closed { + if c.closed.Load() { c.db.log.Error("Failed Get: database closed", zap.Uint64("height", height)) return nil, database.ErrClosed } @@ -57,10 +52,7 @@ func (c *cacheDB) Get(height BlockHeight) (BlockData, error) { } func (c *cacheDB) Put(height BlockHeight, data BlockData) error { - c.closeMu.RLock() - defer c.closeMu.RUnlock() - - if c.closed { + if c.closed.Load() { c.db.log.Error("Failed Put: database closed", zap.Uint64("height", height)) return database.ErrClosed } @@ -74,10 +66,7 @@ func (c *cacheDB) Put(height BlockHeight, data BlockData) error { } func (c *cacheDB) Has(height BlockHeight) (bool, error) { - c.closeMu.RLock() - defer c.closeMu.RUnlock() - - if c.closed { + if c.closed.Load() { c.db.log.Error("Failed Has: database closed", zap.Uint64("height", height)) return false, database.ErrClosed } @@ -89,13 +78,9 @@ func (c *cacheDB) Has(height BlockHeight) (bool, error) { } func (c *cacheDB) Close() error { - c.closeMu.Lock() - defer c.closeMu.Unlock() - - if c.closed { + if !c.closed.CompareAndSwap(false, true) { return database.ErrClosed } - c.closed = true c.cache.Flush() return c.db.Close() }