Skip to content

Commit 13c3db3

Browse files
committed
more test coverage
1 parent 49c327f commit 13c3db3

File tree

2 files changed

+97
-3
lines changed

2 files changed

+97
-3
lines changed

pubsub.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,6 @@ func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
412412
case "invalidate":
413413
switch payload := reply[1].(type) {
414414
case []interface{}:
415-
_ = payload
416415
s := make([]string, len(payload))
417416
for idx := range payload {
418417
s[idx] = payload[idx].(string)
@@ -422,7 +421,7 @@ func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
422421
PayloadSlice: s,
423422
}, nil
424423
default:
425-
return nil, fmt.Errorf("redis: unsupported invalidate message payload: %q", payload)
424+
return nil, fmt.Errorf("redis: unsupported invalidate message payload: %#v", payload)
426425
}
427426
default:
428427
return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind)

pubsub_test.go

+96-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package redis_test
22

33
import (
4+
"context"
5+
"fmt"
46
"io"
57
"net"
68
"sync"
@@ -568,5 +570,98 @@ var _ = Describe("PubSub", func() {
568570
Expect(msg.Payload).To(Equal(text))
569571
})
570572

571-
It("should channel client-cache invalidation messages", func() {})
573+
It("supports client-cache invalidation messages", func() {
574+
ch := make(chan []string, 2)
575+
defer close(ch)
576+
client := redis.NewClient(getOptsWithTracking(redisOptions(), func(keys []string) error {
577+
ch <- keys
578+
return nil
579+
}))
580+
defer client.Close()
581+
582+
v1 := client.Get(context.Background(), "foo")
583+
Expect(v1.Val()).To(Equal(""))
584+
s1 := client.Set(context.Background(), "foo", "bar", time.Duration(time.Minute))
585+
Expect(s1.Val()).To(Equal("OK"))
586+
v2 := client.Get(context.Background(), "foo")
587+
Expect(v2.Val()).To(Equal("bar"))
588+
// sleep a little to all time for the first invalidation message to come through
589+
time.Sleep(time.Second)
590+
s2 := client.Set(context.Background(), "foo", "foobar", time.Duration(time.Minute))
591+
Expect(s2.Val()).To(Equal("OK"))
592+
593+
for i := 0; i < 2; i++ {
594+
select {
595+
case keys := <-ch:
596+
Expect(keys).ToNot(BeEmpty())
597+
Expect(keys[0]).To(Equal("foo"))
598+
case <-time.After(10 * time.Second):
599+
// fail on timeouts
600+
Fail("invalidation message wait timed out")
601+
}
602+
}
603+
})
604+
572605
})
606+
607+
func getOptsWithTracking(opt *redis.Options, processInvalidKeysFunc func([]string) error) *redis.Options {
608+
var mu sync.Mutex
609+
invalidateClientID := int64(-1)
610+
invalidateOpts := *opt
611+
invalidateOpts.OnConnect = func(ctx context.Context, conn *redis.Conn) (err error) {
612+
invalidateClientID, err = conn.ClientID(ctx).Result()
613+
return
614+
}
615+
616+
startBackgroundInvalidationSubscription := func(ctx context.Context) int64 {
617+
mu.Lock()
618+
defer mu.Unlock()
619+
620+
if invalidateClientID != -1 {
621+
return invalidateClientID
622+
}
623+
624+
invalidateClient := redis.NewClient(&invalidateOpts)
625+
invalidations := invalidateClient.Subscribe(ctx, "__redis__:invalidate")
626+
627+
go func() {
628+
defer func() {
629+
invalidations.Close()
630+
invalidateClient.Close()
631+
632+
mu.Lock()
633+
invalidateClientID = -1
634+
mu.Unlock()
635+
}()
636+
637+
for {
638+
msg, err := invalidations.ReceiveMessage(context.Background())
639+
if err == io.EOF || err == context.Canceled {
640+
return
641+
} else if err != nil {
642+
fmt.Printf("warning: subscription on key invalidations aborted: %s\n", err.Error())
643+
// send back empty []string to fail the test
644+
processInvalidKeysFunc([]string{})
645+
return
646+
}
647+
648+
processInvalidKeysFunc(msg.PayloadSlice)
649+
}
650+
}()
651+
652+
return invalidateClientID
653+
}
654+
655+
opt.OnConnect = func(ctx context.Context, conn *redis.Conn) error {
656+
invalidateClientID := startBackgroundInvalidationSubscription(ctx)
657+
return conn.Process(
658+
ctx,
659+
redis.NewBoolCmd(
660+
ctx,
661+
"CLIENT", "TRACKING", "on",
662+
"REDIRECT", fmt.Sprintf("%d", invalidateClientID),
663+
),
664+
)
665+
}
666+
return opt
667+
}

0 commit comments

Comments
 (0)