|
1 | 1 | package redis_test
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "context" |
| 5 | + "fmt" |
4 | 6 | "io"
|
5 | 7 | "net"
|
6 | 8 | "sync"
|
@@ -568,5 +570,98 @@ var _ = Describe("PubSub", func() {
|
568 | 570 | Expect(msg.Payload).To(Equal(text))
|
569 | 571 | })
|
570 | 572 |
|
571 |
| - It("should channel client-cache invalidation messages", func() {}) |
| 573 | + It("supports client-cache invalidation messages", func() { |
| 574 | + ch := make(chan []string, 2) |
| 575 | + defer close(ch) |
| 576 | + client := redis.NewClient(getOptsWithTracking(redisOptions(), func(keys []string) error { |
| 577 | + ch <- keys |
| 578 | + return nil |
| 579 | + })) |
| 580 | + defer client.Close() |
| 581 | + |
| 582 | + v1 := client.Get(context.Background(), "foo") |
| 583 | + Expect(v1.Val()).To(Equal("")) |
| 584 | + s1 := client.Set(context.Background(), "foo", "bar", time.Duration(time.Minute)) |
| 585 | + Expect(s1.Val()).To(Equal("OK")) |
| 586 | + v2 := client.Get(context.Background(), "foo") |
| 587 | + Expect(v2.Val()).To(Equal("bar")) |
| 588 | + // sleep a little to all time for the first invalidation message to come through |
| 589 | + time.Sleep(time.Second) |
| 590 | + s2 := client.Set(context.Background(), "foo", "foobar", time.Duration(time.Minute)) |
| 591 | + Expect(s2.Val()).To(Equal("OK")) |
| 592 | + |
| 593 | + for i := 0; i < 2; i++ { |
| 594 | + select { |
| 595 | + case keys := <-ch: |
| 596 | + Expect(keys).ToNot(BeEmpty()) |
| 597 | + Expect(keys[0]).To(Equal("foo")) |
| 598 | + case <-time.After(10 * time.Second): |
| 599 | + // fail on timeouts |
| 600 | + Fail("invalidation message wait timed out") |
| 601 | + } |
| 602 | + } |
| 603 | + }) |
| 604 | + |
572 | 605 | })
|
| 606 | + |
| 607 | +func getOptsWithTracking(opt *redis.Options, processInvalidKeysFunc func([]string) error) *redis.Options { |
| 608 | + var mu sync.Mutex |
| 609 | + invalidateClientID := int64(-1) |
| 610 | + invalidateOpts := *opt |
| 611 | + invalidateOpts.OnConnect = func(ctx context.Context, conn *redis.Conn) (err error) { |
| 612 | + invalidateClientID, err = conn.ClientID(ctx).Result() |
| 613 | + return |
| 614 | + } |
| 615 | + |
| 616 | + startBackgroundInvalidationSubscription := func(ctx context.Context) int64 { |
| 617 | + mu.Lock() |
| 618 | + defer mu.Unlock() |
| 619 | + |
| 620 | + if invalidateClientID != -1 { |
| 621 | + return invalidateClientID |
| 622 | + } |
| 623 | + |
| 624 | + invalidateClient := redis.NewClient(&invalidateOpts) |
| 625 | + invalidations := invalidateClient.Subscribe(ctx, "__redis__:invalidate") |
| 626 | + |
| 627 | + go func() { |
| 628 | + defer func() { |
| 629 | + invalidations.Close() |
| 630 | + invalidateClient.Close() |
| 631 | + |
| 632 | + mu.Lock() |
| 633 | + invalidateClientID = -1 |
| 634 | + mu.Unlock() |
| 635 | + }() |
| 636 | + |
| 637 | + for { |
| 638 | + msg, err := invalidations.ReceiveMessage(context.Background()) |
| 639 | + if err == io.EOF || err == context.Canceled { |
| 640 | + return |
| 641 | + } else if err != nil { |
| 642 | + fmt.Printf("warning: subscription on key invalidations aborted: %s\n", err.Error()) |
| 643 | + // send back empty []string to fail the test |
| 644 | + processInvalidKeysFunc([]string{}) |
| 645 | + return |
| 646 | + } |
| 647 | + |
| 648 | + processInvalidKeysFunc(msg.PayloadSlice) |
| 649 | + } |
| 650 | + }() |
| 651 | + |
| 652 | + return invalidateClientID |
| 653 | + } |
| 654 | + |
| 655 | + opt.OnConnect = func(ctx context.Context, conn *redis.Conn) error { |
| 656 | + invalidateClientID := startBackgroundInvalidationSubscription(ctx) |
| 657 | + return conn.Process( |
| 658 | + ctx, |
| 659 | + redis.NewBoolCmd( |
| 660 | + ctx, |
| 661 | + "CLIENT", "TRACKING", "on", |
| 662 | + "REDIRECT", fmt.Sprintf("%d", invalidateClientID), |
| 663 | + ), |
| 664 | + ) |
| 665 | + } |
| 666 | + return opt |
| 667 | +} |
0 commit comments