Skip to content

Commit eba9a0f

Browse files
committed
client-cache invalidation message parsing code
1 parent 2d8fa02 commit eba9a0f

File tree

2 files changed

+112
-0
lines changed

2 files changed

+112
-0
lines changed

pubsub.go

+14
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,20 @@ func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
409409
return &Pong{
410410
Payload: reply[1].(string),
411411
}, nil
412+
case "invalidate":
413+
switch payload := reply[1].(type) {
414+
case []interface{}:
415+
s := make([]string, len(payload))
416+
for idx := range payload {
417+
s[idx] = payload[idx].(string)
418+
}
419+
return &Message{
420+
Channel: "invalidate",
421+
PayloadSlice: s,
422+
}, nil
423+
default:
424+
return nil, fmt.Errorf("redis: unsupported invalidate message payload: %#v", payload)
425+
}
412426
default:
413427
return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind)
414428
}

pubsub_test.go

+98
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package redis_test
22

33
import (
4+
"context"
5+
"fmt"
46
"io"
57
"net"
68
"sync"
@@ -567,4 +569,100 @@ var _ = Describe("PubSub", func() {
567569
Expect(msg.Channel).To(Equal("mychannel"))
568570
Expect(msg.Payload).To(Equal(text))
569571
})
572+
573+
It("supports client-cache invalidation messages", func() {
574+
ch := make(chan []string, 2)
575+
defer close(ch)
576+
client := redis.NewClient(getOptsWithTracking(redisOptions(), func(keys []string) error {
577+
ch <- keys
578+
return nil
579+
}))
580+
defer client.Close()
581+
582+
v1 := client.Get(context.Background(), "foo")
583+
Expect(v1.Val()).To(Equal(""))
584+
s1 := client.Set(context.Background(), "foo", "bar", time.Duration(time.Minute))
585+
Expect(s1.Val()).To(Equal("OK"))
586+
v2 := client.Get(context.Background(), "foo")
587+
Expect(v2.Val()).To(Equal("bar"))
588+
// sleep a little to allow time for the first invalidation message to come through
589+
590+
time.Sleep(time.Second)
591+
s2 := client.Set(context.Background(), "foo", "foobar", time.Duration(time.Minute))
592+
Expect(s2.Val()).To(Equal("OK"))
593+
594+
for i := 0; i < 2; i++ {
595+
select {
596+
case keys := <-ch:
597+
Expect(keys).ToNot(BeEmpty())
598+
Expect(keys[0]).To(Equal("foo"))
599+
case <-time.After(10 * time.Second):
600+
// fail on timeouts
601+
Fail("invalidation message wait timed out")
602+
}
603+
}
604+
})
605+
570606
})
607+
608+
func getOptsWithTracking(opt *redis.Options, processInvalidKeysFunc func([]string) error) *redis.Options {
609+
var mu sync.Mutex
610+
invalidateClientID := int64(-1)
611+
invalidateOpts := *opt
612+
invalidateOpts.OnConnect = func(ctx context.Context, conn *redis.Conn) (err error) {
613+
invalidateClientID, err = conn.ClientID(ctx).Result()
614+
return
615+
}
616+
617+
startBackgroundInvalidationSubscription := func(ctx context.Context) int64 {
618+
mu.Lock()
619+
defer mu.Unlock()
620+
621+
if invalidateClientID != -1 {
622+
return invalidateClientID
623+
}
624+
625+
invalidateClient := redis.NewClient(&invalidateOpts)
626+
invalidations := invalidateClient.Subscribe(ctx, "__redis__:invalidate")
627+
628+
go func() {
629+
defer func() {
630+
invalidations.Close()
631+
invalidateClient.Close()
632+
633+
mu.Lock()
634+
invalidateClientID = -1
635+
mu.Unlock()
636+
}()
637+
638+
for {
639+
msg, err := invalidations.ReceiveMessage(context.Background())
640+
if err == io.EOF || err == context.Canceled {
641+
return
642+
} else if err != nil {
643+
fmt.Printf("warning: subscription on key invalidations aborted: %s\n", err.Error())
644+
// send back empty []string to fail the test
645+
processInvalidKeysFunc([]string{})
646+
return
647+
}
648+
649+
processInvalidKeysFunc(msg.PayloadSlice)
650+
}
651+
}()
652+
653+
return invalidateClientID
654+
}
655+
656+
opt.OnConnect = func(ctx context.Context, conn *redis.Conn) error {
657+
invalidateClientID := startBackgroundInvalidationSubscription(ctx)
658+
return conn.Process(
659+
ctx,
660+
redis.NewBoolCmd(
661+
ctx,
662+
"CLIENT", "TRACKING", "on",
663+
"REDIRECT", fmt.Sprintf("%d", invalidateClientID),
664+
),
665+
)
666+
}
667+
return opt
668+
}

0 commit comments

Comments
 (0)