Skip to content

Commit d253629

Browse files
committed
more test coverage
1 parent 49c327f commit d253629

File tree

1 file changed

+94
-1
lines changed

1 file changed

+94
-1
lines changed

pubsub_test.go

+94-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package redis_test
22

33
import (
4+
"context"
5+
"fmt"
46
"io"
57
"net"
68
"sync"
@@ -568,5 +570,96 @@ var _ = Describe("PubSub", func() {
568570
Expect(msg.Payload).To(Equal(text))
569571
})
570572

571-
It("should channel client-cache invalidation messages", func() {})
573+
It("supports client-cache invalidation messages", func() {
574+
ch := make(chan []string, 2)
575+
defer close(ch)
576+
client := redis.NewClient(getOptsWithTracking(redisOptions(), func(keys []string) error {
577+
ch <- keys
578+
return nil
579+
}))
580+
defer client.Close()
581+
582+
v1 := client.Get(context.Background(), "foo")
583+
Expect(v1.Val()).To(Equal(""))
584+
s1 := client.Set(context.Background(), "foo", "bar", time.Duration(time.Minute))
585+
Expect(s1.Val()).To(Equal("OK"))
586+
v2 := client.Get(context.Background(), "foo")
587+
Expect(v2.Val()).To(Equal("bar"))
588+
s2 := client.Set(context.Background(), "foo", "foobar", time.Duration(time.Minute))
589+
Expect(s2.Val()).To(Equal("OK"))
590+
591+
for i := 0; i < 2; i++ {
592+
select {
593+
case keys := <-ch:
594+
Expect(keys).ToNot(BeEmpty())
595+
Expect(keys[0]).To(Equal("foo"))
596+
case <-time.After(time.Second):
597+
// fail on timeouts
598+
Expect(1).To(Equal(0))
599+
}
600+
}
601+
})
602+
572603
})
604+
605+
func getOptsWithTracking(opt *redis.Options, processInvalidKeysFunc func([]string) error) *redis.Options {
606+
var mu sync.Mutex
607+
invalidateClientID := int64(-1)
608+
invalidateOpts := *opt
609+
invalidateOpts.OnConnect = func(ctx context.Context, conn *redis.Conn) (err error) {
610+
invalidateClientID, err = conn.ClientID(ctx).Result()
611+
return
612+
}
613+
614+
startBackgroundInvalidationSubscription := func(ctx context.Context) int64 {
615+
mu.Lock()
616+
defer mu.Unlock()
617+
618+
if invalidateClientID != -1 {
619+
return invalidateClientID
620+
}
621+
622+
invalidateClient := redis.NewClient(&invalidateOpts)
623+
invalidations := invalidateClient.Subscribe(ctx, "__redis__:invalidate")
624+
625+
go func() {
626+
defer func() {
627+
invalidations.Close()
628+
invalidateClient.Close()
629+
630+
mu.Lock()
631+
invalidateClientID = -1
632+
mu.Unlock()
633+
}()
634+
635+
for {
636+
msg, err := invalidations.ReceiveMessage(context.Background())
637+
if err == io.EOF || err == context.Canceled {
638+
return
639+
} else if err != nil {
640+
fmt.Printf("warning: subscription on key invalidations aborted: %s\n", err.Error())
641+
// send back empty []string to fail the test
642+
processInvalidKeysFunc([]string{})
643+
return
644+
}
645+
646+
processInvalidKeysFunc(msg.PayloadSlice)
647+
}
648+
}()
649+
650+
return invalidateClientID
651+
}
652+
653+
opt.OnConnect = func(ctx context.Context, conn *redis.Conn) error {
654+
invalidateClientID := startBackgroundInvalidationSubscription(ctx)
655+
return conn.Process(
656+
ctx,
657+
redis.NewBoolCmd(
658+
ctx,
659+
"CLIENT", "TRACKING", "on",
660+
"REDIRECT", fmt.Sprintf("%d", invalidateClientID),
661+
),
662+
)
663+
}
664+
return opt
665+
}

0 commit comments

Comments
 (0)