diff --git a/daemon.go b/daemon.go index 97602075..6a4cf3c0 100644 --- a/daemon.go +++ b/daemon.go @@ -41,7 +41,6 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" - "google.golang.org/grpc/resolver" "google.golang.org/protobuf/encoding/protojson" ) @@ -433,7 +432,7 @@ func (s *Daemon) Client() (V1Client, error) { conn, err := grpc.DialContext(context.Background(), fmt.Sprintf("static:///%s", s.PeerInfo.GRPCAddress), - grpc.WithResolvers(newStaticBuilder()), + grpc.WithResolvers(NewStaticBuilder()), grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return nil, err @@ -482,41 +481,3 @@ func WaitForConnect(ctx context.Context, addresses []string) error { } return nil } - -type staticBuilder struct{} - -var _ resolver.Builder = (*staticBuilder)(nil) - -func (sb *staticBuilder) Scheme() string { - return "static" -} - -func (sb *staticBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) { - var resolverAddrs []resolver.Address - for _, address := range strings.Split(target.Endpoint(), ",") { - resolverAddrs = append(resolverAddrs, resolver.Address{ - Addr: address, - ServerName: address, - }) - } - if err := cc.UpdateState(resolver.State{Addresses: resolverAddrs}); err != nil { - return nil, err - } - return &staticResolver{cc: cc}, nil -} - -// newStaticBuilder returns a builder which returns a staticResolver that tells GRPC -// to connect a specific peer in the cluster. -func newStaticBuilder() resolver.Builder { - return &staticBuilder{} -} - -type staticResolver struct { - cc resolver.ClientConn -} - -func (sr *staticResolver) ResolveNow(_ resolver.ResolveNowOptions) {} - -func (sr *staticResolver) Close() {} - -var _ resolver.Resolver = (*staticResolver)(nil) diff --git a/functional_test.go b/functional_test.go index b377e86a..246e837e 100644 --- a/functional_test.go +++ b/functional_test.go @@ -21,6 +21,7 @@ import ( "context" "fmt" "io" + "math/rand" "net/http" "os" "strings" @@ -35,6 +36,9 @@ import ( "github.com/prometheus/common/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" json "google.golang.org/protobuf/encoding/protojson" ) @@ -1017,6 +1021,70 @@ func TestGlobalRateLimits(t *testing.T) { sendHit(peers[4].MustClient(), guber.Status_OVER_LIMIT, 1, 0) } +// Ensure global broadcast updates all peers when GetRateLimits is called on +// either owner or non-owner peer. +func TestGlobalRateLimitsWithLoadBalancing(t *testing.T) { + ctx := context.Background() + const name = "test_global" + key := fmt.Sprintf("key:%016x", rand.Int()) + + // Determine owner and non-owner peers. + ownerPeerInfo, err := cluster.FindOwningPeer(name, key) + require.NoError(t, err) + owner := ownerPeerInfo.GRPCAddress + nonOwner := cluster.PeerAt(0).GRPCAddress + if nonOwner == owner { + nonOwner = cluster.PeerAt(1).GRPCAddress + } + require.NotEqual(t, owner, nonOwner) + + // Connect to owner and non-owner peers in round robin. + dialOpts := []grpc.DialOption{ + grpc.WithResolvers(guber.NewStaticBuilder()), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"round_robin":{}}]}`), + } + address := fmt.Sprintf("static:///%s,%s", owner, nonOwner) + conn, err := grpc.DialContext(ctx, address, dialOpts...) + require.NoError(t, err) + client := guber.NewV1Client(conn) + + sendHit := func(status guber.Status, i int) { + ctx, cancel := context.WithTimeout(ctx, 10*clock.Second) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 5, + Hits: 1, + Limit: 2, + }, + }, + }) + require.NoError(t, err, i) + item := resp.Responses[0] + assert.Equal(t, "", item.GetError(), fmt.Sprintf("mismatch error, iteration %d", i)) + assert.Equal(t, status, item.GetStatus(), fmt.Sprintf("mismatch status, iteration %d", i)) + } + + // Send two hits that should be processed by the owner and non-owner and + // deplete the limit consistently. + sendHit(guber.Status_UNDER_LIMIT, 1) + sendHit(guber.Status_UNDER_LIMIT, 2) + + // Sleep to ensure the global broadcast occurs (every 100ms). + time.Sleep(150 * time.Millisecond) + + // All successive hits should return OVER_LIMIT. + for i := 2; i <= 10; i++ { + sendHit(guber.Status_OVER_LIMIT, i) + } +} + func TestGlobalRateLimitsPeerOverLimit(t *testing.T) { const ( name = "test_global_token_limit" diff --git a/staticbuilder.go b/staticbuilder.go new file mode 100644 index 00000000..9bbd8325 --- /dev/null +++ b/staticbuilder.go @@ -0,0 +1,45 @@ +package gubernator + +import ( + "strings" + + "google.golang.org/grpc/resolver" +) + +type staticBuilder struct{} + +var _ resolver.Builder = (*staticBuilder)(nil) + +func (sb *staticBuilder) Scheme() string { + return "static" +} + +func (sb *staticBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) { + var resolverAddrs []resolver.Address + for _, address := range strings.Split(target.Endpoint(), ",") { + resolverAddrs = append(resolverAddrs, resolver.Address{ + Addr: address, + ServerName: address, + }) + } + if err := cc.UpdateState(resolver.State{Addresses: resolverAddrs}); err != nil { + return nil, err + } + return &staticResolver{cc: cc}, nil +} + +// NewStaticBuilder returns a builder which returns a staticResolver that tells GRPC +// to connect a specific peer in the cluster. +func NewStaticBuilder() resolver.Builder { + return &staticBuilder{} +} + +type staticResolver struct { + cc resolver.ClientConn +} + +func (sr *staticResolver) ResolveNow(_ resolver.ResolveNowOptions) {} + +func (sr *staticResolver) Close() {} + +var _ resolver.Resolver = (*staticResolver)(nil)