Skip to content

Commit 5192b20

Browse files
authored
Merge pull request #14 from arthhhhh23/fix/ttl-map-orphaned-elements
fix: map ttl reinsert head or tail
2 parents b63990f + 43700e4 commit 5192b20

File tree

4 files changed

+202
-8
lines changed

4 files changed

+202
-8
lines changed

locker_test.go

+8
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ func TestLockerRPanics(t *testing.T) {
9393
t.Errorf("expected panic (Delete on RLocked)")
9494
}
9595

96+
if !panics(func() { tx.SetIfPresent(1, 1) }) {
97+
t.Errorf("expected panic (SetIfPresent on RLocked)")
98+
}
99+
96100
tx.Unlock()
97101
if !panics(func() { tx.Unlock() }) {
98102
t.Errorf("expected panic (Unlock on already unlocked)")
@@ -117,6 +121,10 @@ func TestLockerRPanics(t *testing.T) {
117121
if !panics(func() { tx.Snapshot() }) {
118122
t.Errorf("expected panic (Snapshot on already unlocked)")
119123
}
124+
125+
if !panics(func() { tx.SetIfPresent(1, 1) }) {
126+
t.Errorf("expected panic (SetIfPresent on already unlocked)")
127+
}
120128
}
121129

122130
func TestLockerRWPanics(t *testing.T) {

map_ttl.go

+25-8
Original file line numberDiff line numberDiff line change
@@ -190,29 +190,46 @@ func (c *MapTTLCache[K, V]) Len() int {
190190
}
191191

192192
func (c *MapTTLCache[K, V]) set(key K, value V) {
193+
ts := c.now()
193194
val := ttlRec[K, V]{
194195
value: value,
195196
prev: c.tail,
196-
timestamp: c.now(),
197+
timestamp: ts,
197198
}
198199

199200
if c.head == c.zero {
200201
c.head = key
201202
c.tail = key
202-
val.prev = c.zero
203203
c.data[key] = val
204204
return
205205
}
206206

207+
// If it's already the tail, we only need to update the value and timestamp
208+
if c.tail == key {
209+
rec := c.data[c.tail]
210+
rec.timestamp = ts
211+
rec.value = value
212+
c.data[c.tail] = rec
213+
return
214+
}
215+
207216
// If the record for this key already exists
208-
// and is somewhere in the middle of the list
217+
// and is not already the tail of the list,
209218
// removing it before adding to the tail.
210-
if rec, ok := c.data[key]; ok && key != c.tail {
211-
prev := c.data[rec.prev]
219+
if rec, ok := c.data[key]; ok {
212220
next := c.data[rec.next]
213-
prev.next = rec.next
214-
next.prev = rec.prev
215-
c.data[rec.prev] = prev
221+
222+
// edge case: the current head becomes the new tail
223+
if key == c.head {
224+
c.head = rec.next
225+
next.prev = c.zero
226+
} else {
227+
prev := c.data[rec.prev]
228+
prev.next = rec.next
229+
c.data[rec.prev] = prev
230+
next.prev = rec.prev
231+
}
232+
216233
c.data[rec.next] = next
217234
}
218235

ttl_test.go

+128
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package geche
22

33
import (
44
"context"
5+
"math/rand"
56
"strconv"
7+
"sync"
68
"testing"
79
"time"
810
)
@@ -174,6 +176,132 @@ func TestTTLScenario(t *testing.T) {
174176
}
175177
}
176178

179+
func TestHeadTailLogicConcurrent(t *testing.T) {
180+
ctx, cancel := context.WithCancel(context.Background())
181+
defer cancel()
182+
183+
m := NewMapTTLCache[string, string](ctx, time.Millisecond, time.Hour)
184+
185+
pool := make([]string, 50)
186+
for i := range pool {
187+
pool[i] = randomString(10)
188+
}
189+
190+
wg := sync.WaitGroup{}
191+
for i := 0; i < 1000; i++ {
192+
idx := rand.Intn(len(pool))
193+
wg.Add(1)
194+
195+
go func() {
196+
defer wg.Done()
197+
198+
_, err := m.Get(pool[idx])
199+
if err != nil {
200+
m.Set(pool[idx], pool[idx])
201+
}
202+
}()
203+
}
204+
205+
wg.Wait()
206+
time.Sleep(time.Millisecond * 3)
207+
208+
var expectedHeadKey, expectedTailKey string
209+
var expectedHeadValue, expectedTailValue ttlRec[string, string]
210+
211+
keys := map[string]struct{}{}
212+
213+
for k, v := range m.data {
214+
keys[k] = struct{}{}
215+
if expectedTailKey == "" {
216+
expectedTailKey = k
217+
expectedHeadKey = k
218+
expectedHeadValue = v
219+
expectedTailValue = v
220+
} else if expectedTailValue.timestamp.Before(v.timestamp) {
221+
expectedTailKey = k
222+
expectedTailValue = v
223+
} else if expectedHeadValue.timestamp.After(v.timestamp) {
224+
expectedHeadKey = k
225+
expectedHeadValue = v
226+
}
227+
}
228+
229+
for k, v := range m.data {
230+
if _, ok := keys[v.next]; k != m.tail && !ok {
231+
t.Errorf("expected key %q not found in data", v.next)
232+
}
233+
234+
if _, ok := keys[v.prev]; k != m.head && !ok {
235+
t.Errorf("expected key %q not found in data", v.prev)
236+
}
237+
}
238+
239+
if m.head != expectedHeadKey {
240+
t.Errorf("expected head key %q, but got %v", expectedHeadKey, m.head)
241+
}
242+
243+
if m.tail != expectedTailKey {
244+
t.Errorf("expected tail key %q, but got %v", expectedTailKey, m.tail)
245+
}
246+
247+
if err := m.cleanup(); err != nil {
248+
t.Errorf("unexpected error in cleanup: %v", err)
249+
}
250+
251+
if m.Len() != 0 {
252+
t.Errorf("expected clean to have %d elements, but got %d", 0, m.Len())
253+
}
254+
255+
if m.tail != m.zero {
256+
t.Errorf("expected tail to be zero, but got %v", m.tail)
257+
}
258+
259+
if m.head != m.zero {
260+
t.Errorf("expected head to be zero, but got %v", m.head)
261+
}
262+
}
263+
264+
func TestReinsertHead(t *testing.T) {
265+
c := NewMapTTLCache[string, string](context.Background(), time.Millisecond, time.Second)
266+
c.Set("k1", "v1")
267+
c.Set("k2", "v2")
268+
c.Set("k3", "v3")
269+
c.Set("k1", "v2")
270+
time.Sleep(2 * time.Millisecond)
271+
if err := c.cleanup(); err != nil {
272+
t.Errorf("unexpected cleanup error: %v", err)
273+
}
274+
275+
if c.Len() != 0 {
276+
t.Errorf("expected cache data len to be 0 but got %d", c.Len())
277+
}
278+
}
279+
280+
func TestReinsertTail(t *testing.T) {
281+
c := NewMapTTLCache[string, string](context.Background(), time.Millisecond, time.Second)
282+
c.Set("k1", "v1")
283+
c.Set("k2", "v2")
284+
c.Set("k3", "v3")
285+
c.Set("k3", "v4")
286+
time.Sleep(2 * time.Millisecond)
287+
288+
if c.data["k3"].next != "" {
289+
t.Errorf("expected tail next to be zero")
290+
}
291+
292+
if c.data["k3"].prev != "k2" {
293+
t.Errorf("expected tail prev to be k2, but got %s", c.data["k3"].prev)
294+
}
295+
296+
if err := c.cleanup(); err != nil {
297+
t.Errorf("unexpected cleanup error: %v", err)
298+
}
299+
300+
if c.Len() != 0 {
301+
t.Errorf("expected cache data len to be 0 but got %d", c.Len())
302+
}
303+
}
304+
177305
func TestSetIfPresentResetsTTL(t *testing.T) {
178306
ctx, cancel := context.WithCancel(context.Background())
179307
defer cancel()

updater_test.go

+41
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,47 @@ func TestUpdaterScenario(t *testing.T) {
103103
}
104104
}
105105

106+
func TestUpdaterSetIfPresent(t *testing.T) {
107+
u := NewCacheUpdater[string, string](
108+
NewMapCache[string, string](),
109+
updateFn,
110+
2,
111+
)
112+
113+
if u.Len() != 0 {
114+
t.Errorf("expected length to be 0, but got %d", u.Len())
115+
}
116+
117+
s, inserted := u.SetIfPresent("test", "test")
118+
if inserted {
119+
t.Error("expected not to insert the value")
120+
}
121+
122+
if s != "" {
123+
t.Errorf("expected to get empty string, but got %q", s)
124+
}
125+
126+
u.Set("test", "test")
127+
128+
s, inserted = u.SetIfPresent("test", "test2")
129+
if !inserted {
130+
t.Error("expected to insert the value")
131+
}
132+
133+
if s != "test" {
134+
t.Errorf("expected to get %q, but got %q", "test", s)
135+
}
136+
137+
v, err := u.Get("test")
138+
if err != nil {
139+
t.Errorf("unexpected error in Get: %v", err)
140+
}
141+
142+
if v != "test2" {
143+
t.Errorf("expected to get %q, but got %q", "test2", v)
144+
}
145+
}
146+
106147
func TestUpdaterErr(t *testing.T) {
107148
u := NewCacheUpdater[string, string](
108149
NewMapCache[string, string](),

0 commit comments

Comments
 (0)