From f1ffc2282404385daa3ed76b9a8f9a62f1cafcb4 Mon Sep 17 00:00:00 2001 From: fukua95 Date: Tue, 29 Apr 2025 22:31:50 +0800 Subject: [PATCH] feat: add ParseFailoverURL * add ParseFailoverURL for FailoverOptions * fix 2 warning --- example_test.go | 4 +- sentinel.go | 142 +++++++++++++++++++++++++ sentinel_test.go | 267 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 411 insertions(+), 2 deletions(-) diff --git a/example_test.go b/example_test.go index 28d14b65a..c20e8390a 100644 --- a/example_test.go +++ b/example_test.go @@ -359,7 +359,7 @@ func ExampleMapStringStringCmd_Scan() { // Get the map. The same approach works for HmGet(). res := rdb.HGetAll(ctx, "map") if res.Err() != nil { - panic(err) + panic(res.Err()) } type data struct { @@ -392,7 +392,7 @@ func ExampleSliceCmd_Scan() { res := rdb.MGet(ctx, "name", "count", "empty", "correct") if res.Err() != nil { - panic(err) + panic(res.Err()) } type data struct { diff --git a/sentinel.go b/sentinel.go index cfc848cf0..314bde1ef 100644 --- a/sentinel.go +++ b/sentinel.go @@ -6,6 +6,8 @@ import ( "errors" "fmt" "net" + "net/url" + "strconv" "strings" "sync" "time" @@ -220,6 +222,146 @@ func (opt *FailoverOptions) clusterOptions() *ClusterOptions { } } +// ParseFailoverURL parses a URL into FailoverOptions that can be used to connect to Redis. +// The URL must be in the form: +// +// redis://:@:/ +// or +// rediss://:@:/ +// +// To add additional addresses, specify the query parameter, "addr" one or more times. e.g: +// +// redis://:@:/?addr=:&addr=: +// or +// rediss://:@:/?addr=:&addr=: +// +// Most Option fields can be set using query parameters, with the following restrictions: +// - field names are mapped using snake-case conversion: to set MaxRetries, use max_retries +// - only scalar type fields are supported (bool, int, time.Duration) +// - for time.Duration fields, values must be a valid input for time.ParseDuration(); +// additionally a plain integer as value (i.e. without unit) is interpreted as seconds +// - to disable a duration field, use value less than or equal to 0; to use the default +// value, leave the value blank or remove the parameter +// - only the last value is interpreted if a parameter is given multiple times +// - fields "network", "addr", "sentinel_username" and "sentinel_password" can only be set using other +// URL attributes (scheme, host, userinfo, resp.), query parameters using these +// names will be treated as unknown parameters +// - unknown parameter names will result in an error +// +// Example: +// +// redis://user:password@localhost:6789?master_name=mymaster&dial_timeout=3&read_timeout=6s&addr=localhost:6790&addr=localhost:6791 +// is equivalent to: +// &FailoverOptions{ +// MasterName: "mymaster", +// Addr: ["localhost:6789", "localhost:6790", "localhost:6791"] +// DialTimeout: 3 * time.Second, // no time unit = seconds +// ReadTimeout: 6 * time.Second, +// } +func ParseFailoverURL(redisURL string) (*FailoverOptions, error) { + u, err := url.Parse(redisURL) + if err != nil { + return nil, err + } + return setupFailoverConn(u) +} + +func setupFailoverConn(u *url.URL) (*FailoverOptions, error) { + o := &FailoverOptions{} + + o.SentinelUsername, o.SentinelPassword = getUserPassword(u) + + h, p := getHostPortWithDefaults(u) + o.SentinelAddrs = append(o.SentinelAddrs, net.JoinHostPort(h, p)) + + switch u.Scheme { + case "rediss": + o.TLSConfig = &tls.Config{ServerName: h, MinVersion: tls.VersionTLS12} + case "redis": + o.TLSConfig = nil + default: + return nil, fmt.Errorf("redis: invalid URL scheme: %s", u.Scheme) + } + + f := strings.FieldsFunc(u.Path, func(r rune) bool { + return r == '/' + }) + switch len(f) { + case 0: + o.DB = 0 + case 1: + var err error + if o.DB, err = strconv.Atoi(f[0]); err != nil { + return nil, fmt.Errorf("redis: invalid database number: %q", f[0]) + } + default: + return nil, fmt.Errorf("redis: invalid URL path: %s", u.Path) + } + + return setupFailoverConnParams(u, o) +} + +func setupFailoverConnParams(u *url.URL, o *FailoverOptions) (*FailoverOptions, error) { + q := queryOptions{q: u.Query()} + + o.MasterName = q.string("master_name") + o.ClientName = q.string("client_name") + o.RouteByLatency = q.bool("route_by_latency") + o.RouteRandomly = q.bool("route_randomly") + o.ReplicaOnly = q.bool("replica_only") + o.UseDisconnectedReplicas = q.bool("use_disconnected_replicas") + o.Protocol = q.int("protocol") + o.Username = q.string("username") + o.Password = q.string("password") + o.MaxRetries = q.int("max_retries") + o.MinRetryBackoff = q.duration("min_retry_backoff") + o.MaxRetryBackoff = q.duration("max_retry_backoff") + o.DialTimeout = q.duration("dial_timeout") + o.ReadTimeout = q.duration("read_timeout") + o.WriteTimeout = q.duration("write_timeout") + o.ContextTimeoutEnabled = q.bool("context_timeout_enabled") + o.PoolFIFO = q.bool("pool_fifo") + o.PoolSize = q.int("pool_size") + o.MinIdleConns = q.int("min_idle_conns") + o.MaxIdleConns = q.int("max_idle_conns") + o.MaxActiveConns = q.int("max_active_conns") + o.ConnMaxLifetime = q.duration("conn_max_lifetime") + o.ConnMaxIdleTime = q.duration("conn_max_idle_time") + o.PoolTimeout = q.duration("pool_timeout") + o.DisableIdentity = q.bool("disableIdentity") + o.IdentitySuffix = q.string("identitySuffix") + o.UnstableResp3 = q.bool("unstable_resp3") + + if q.err != nil { + return nil, q.err + } + + if tmp := q.string("db"); tmp != "" { + db, err := strconv.Atoi(tmp) + if err != nil { + return nil, fmt.Errorf("redis: invalid database number: %w", err) + } + o.DB = db + } + + addrs := q.strings("addr") + for _, addr := range addrs { + h, p, err := net.SplitHostPort(addr) + if err != nil || h == "" || p == "" { + return nil, fmt.Errorf("redis: unable to parse addr param: %s", addr) + } + + o.SentinelAddrs = append(o.SentinelAddrs, net.JoinHostPort(h, p)) + } + + // any parameters left? + if r := q.remaining(); len(r) > 0 { + return nil, fmt.Errorf("redis: unexpected option: %s", strings.Join(r, ", ")) + } + + return o, nil +} + // NewFailoverClient returns a Redis client that uses Redis Sentinel // for automatic failover. It's safe for concurrent use by multiple // goroutines. diff --git a/sentinel_test.go b/sentinel_test.go index 2d481d5fc..436895ff2 100644 --- a/sentinel_test.go +++ b/sentinel_test.go @@ -2,7 +2,11 @@ package redis_test import ( "context" + "crypto/tls" + "errors" "net" + "sort" + "testing" "time" . "github.com/bsm/ginkgo/v2" @@ -405,3 +409,266 @@ var _ = Describe("SentinelAclAuth", func() { Expect(val).To(Equal("acl-auth")) }) }) + +func TestParseFailoverURL(t *testing.T) { + cases := []struct { + url string + o *redis.FailoverOptions + err error + }{ + { + url: "redis://localhost:6379?master_name=test", + o: &redis.FailoverOptions{SentinelAddrs: []string{"localhost:6379"}, MasterName: "test"}, + }, + { + url: "redis://localhost:6379/5?master_name=test", + o: &redis.FailoverOptions{SentinelAddrs: []string{"localhost:6379"}, MasterName: "test", DB: 5}, + }, + { + url: "rediss://localhost:6379/5?master_name=test", + o: &redis.FailoverOptions{SentinelAddrs: []string{"localhost:6379"}, MasterName: "test", DB: 5, + TLSConfig: &tls.Config{ + ServerName: "localhost", + }}, + }, + { + url: "redis://localhost:6379/5?master_name=test&db=2", + o: &redis.FailoverOptions{SentinelAddrs: []string{"localhost:6379"}, MasterName: "test", DB: 2}, + }, + { + url: "redis://localhost:6379/5?addr=localhost:6380&addr=localhost:6381", + o: &redis.FailoverOptions{SentinelAddrs: []string{"localhost:6380", "localhost:6379", "localhost:6381"}, DB: 5}, + }, + { + url: "redis://foo:bar@localhost:6379/5?addr=localhost:6380", + o: &redis.FailoverOptions{SentinelAddrs: []string{"localhost:6380", "localhost:6379"}, + SentinelUsername: "foo", SentinelPassword: "bar", DB: 5}, + }, + { + url: "redis://:bar@localhost:6379/5?addr=localhost:6380", + o: &redis.FailoverOptions{SentinelAddrs: []string{"localhost:6380", "localhost:6379"}, + SentinelUsername: "", SentinelPassword: "bar", DB: 5}, + }, + { + url: "redis://foo@localhost:6379/5?addr=localhost:6380", + o: &redis.FailoverOptions{SentinelAddrs: []string{"localhost:6380", "localhost:6379"}, + SentinelUsername: "foo", SentinelPassword: "", DB: 5}, + }, + { + url: "redis://foo:bar@localhost:6379/5?addr=localhost:6380&dial_timeout=3", + o: &redis.FailoverOptions{SentinelAddrs: []string{"localhost:6380", "localhost:6379"}, + SentinelUsername: "foo", SentinelPassword: "bar", DB: 5, DialTimeout: 3 * time.Second}, + }, + { + url: "redis://foo:bar@localhost:6379/5?addr=localhost:6380&dial_timeout=3s", + o: &redis.FailoverOptions{SentinelAddrs: []string{"localhost:6380", "localhost:6379"}, + SentinelUsername: "foo", SentinelPassword: "bar", DB: 5, DialTimeout: 3 * time.Second}, + }, + { + url: "redis://foo:bar@localhost:6379/5?addr=localhost:6380&dial_timeout=3ms", + o: &redis.FailoverOptions{SentinelAddrs: []string{"localhost:6380", "localhost:6379"}, + SentinelUsername: "foo", SentinelPassword: "bar", DB: 5, DialTimeout: 3 * time.Millisecond}, + }, + { + url: "redis://foo:bar@localhost:6379/5?addr=localhost:6380&dial_timeout=3&pool_fifo=true", + o: &redis.FailoverOptions{SentinelAddrs: []string{"localhost:6380", "localhost:6379"}, + SentinelUsername: "foo", SentinelPassword: "bar", DB: 5, DialTimeout: 3 * time.Second, PoolFIFO: true}, + }, + { + url: "redis://localhost:6379/5?addr=localhost:6380&dial_timeout=3&pool_fifo=false", + o: &redis.FailoverOptions{SentinelAddrs: []string{"localhost:6380", "localhost:6379"}, + DB: 5, DialTimeout: 3 * time.Second, PoolFIFO: false}, + }, + { + url: "redis://localhost:6379/5?addr=localhost:6380&dial_timeout=3&pool_fifo", + o: &redis.FailoverOptions{SentinelAddrs: []string{"localhost:6380", "localhost:6379"}, + DB: 5, DialTimeout: 3 * time.Second, PoolFIFO: false}, + }, + { + url: "redis://localhost:6379/5?addr=localhost:6380&dial_timeout", + o: &redis.FailoverOptions{SentinelAddrs: []string{"localhost:6380", "localhost:6379"}, + DB: 5, DialTimeout: 0}, + }, + { + url: "redis://localhost:6379/5?addr=localhost:6380&dial_timeout=0", + o: &redis.FailoverOptions{SentinelAddrs: []string{"localhost:6380", "localhost:6379"}, + DB: 5, DialTimeout: -1}, + }, + { + url: "redis://localhost:6379/5?addr=localhost:6380&dial_timeout=-1", + o: &redis.FailoverOptions{SentinelAddrs: []string{"localhost:6380", "localhost:6379"}, + DB: 5, DialTimeout: -1}, + }, + { + url: "redis://localhost:6379/5?addr=localhost:6380&dial_timeout=-2", + o: &redis.FailoverOptions{SentinelAddrs: []string{"localhost:6380", "localhost:6379"}, + DB: 5, DialTimeout: -1}, + }, + { + url: "redis://localhost:6379/5?addr=localhost:6380&dial_timeout=", + o: &redis.FailoverOptions{SentinelAddrs: []string{"localhost:6380", "localhost:6379"}, + DB: 5, DialTimeout: 0}, + }, + { + url: "redis://localhost:6379/5?addr=localhost:6380&dial_timeout=0&abc=5", + o: &redis.FailoverOptions{SentinelAddrs: []string{"localhost:6380", "localhost:6379"}, + DB: 5, DialTimeout: -1}, + err: errors.New("redis: unexpected option: abc"), + }, + { + url: "http://google.com", + err: errors.New("redis: invalid URL scheme: http"), + }, + { + url: "redis://localhost/1/2/3/4", + err: errors.New("redis: invalid URL path: /1/2/3/4"), + }, + { + url: "12345", + err: errors.New("redis: invalid URL scheme: "), + }, + { + url: "redis://localhost/database", + err: errors.New(`redis: invalid database number: "database"`), + }, + } + + for i := range cases { + tc := cases[i] + t.Run(tc.url, func(t *testing.T) { + t.Parallel() + + actual, err := redis.ParseFailoverURL(tc.url) + if tc.err == nil && err != nil { + t.Fatalf("unexpected error: %q", err) + return + } + if tc.err != nil && err == nil { + t.Fatalf("got nil, expected %q", tc.err) + return + } + if tc.err != nil && err != nil { + if tc.err.Error() != err.Error() { + t.Fatalf("got %q, expected %q", err, tc.err) + } + return + } + compareFailoverOptions(t, actual, tc.o) + }) + } +} + +func compareFailoverOptions(t *testing.T, a, e *redis.FailoverOptions) { + if a.MasterName != e.MasterName { + t.Errorf("MasterName got %q, want %q", a.MasterName, e.MasterName) + } + compareSlices(t, a.SentinelAddrs, e.SentinelAddrs, "SentinelAddrs") + if a.ClientName != e.ClientName { + t.Errorf("ClientName got %q, want %q", a.ClientName, e.ClientName) + } + if a.SentinelUsername != e.SentinelUsername { + t.Errorf("SentinelUsername got %q, want %q", a.SentinelUsername, e.SentinelUsername) + } + if a.SentinelPassword != e.SentinelPassword { + t.Errorf("SentinelPassword got %q, want %q", a.SentinelPassword, e.SentinelPassword) + } + if a.RouteByLatency != e.RouteByLatency { + t.Errorf("RouteByLatency got %v, want %v", a.RouteByLatency, e.RouteByLatency) + } + if a.RouteRandomly != e.RouteRandomly { + t.Errorf("RouteRandomly got %v, want %v", a.RouteRandomly, e.RouteRandomly) + } + if a.ReplicaOnly != e.ReplicaOnly { + t.Errorf("ReplicaOnly got %v, want %v", a.ReplicaOnly, e.ReplicaOnly) + } + if a.UseDisconnectedReplicas != e.UseDisconnectedReplicas { + t.Errorf("UseDisconnectedReplicas got %v, want %v", a.UseDisconnectedReplicas, e.UseDisconnectedReplicas) + } + if a.Protocol != e.Protocol { + t.Errorf("Protocol got %v, want %v", a.Protocol, e.Protocol) + } + if a.Username != e.Username { + t.Errorf("Username got %q, want %q", a.Username, e.Username) + } + if a.Password != e.Password { + t.Errorf("Password got %q, want %q", a.Password, e.Password) + } + if a.DB != e.DB { + t.Errorf("DB got %v, want %v", a.DB, e.DB) + } + if a.MaxRetries != e.MaxRetries { + t.Errorf("MaxRetries got %v, want %v", a.MaxRetries, e.MaxRetries) + } + if a.MinRetryBackoff != e.MinRetryBackoff { + t.Errorf("MinRetryBackoff got %v, want %v", a.MinRetryBackoff, e.MinRetryBackoff) + } + if a.MaxRetryBackoff != e.MaxRetryBackoff { + t.Errorf("MaxRetryBackoff got %v, want %v", a.MaxRetryBackoff, e.MaxRetryBackoff) + } + if a.DialTimeout != e.DialTimeout { + t.Errorf("DialTimeout got %v, want %v", a.DialTimeout, e.DialTimeout) + } + if a.ReadTimeout != e.ReadTimeout { + t.Errorf("ReadTimeout got %v, want %v", a.ReadTimeout, e.ReadTimeout) + } + if a.WriteTimeout != e.WriteTimeout { + t.Errorf("WriteTimeout got %v, want %v", a.WriteTimeout, e.WriteTimeout) + } + if a.ContextTimeoutEnabled != e.ContextTimeoutEnabled { + t.Errorf("ContentTimeoutEnabled got %v, want %v", a.ContextTimeoutEnabled, e.ContextTimeoutEnabled) + } + if a.PoolFIFO != e.PoolFIFO { + t.Errorf("PoolFIFO got %v, want %v", a.PoolFIFO, e.PoolFIFO) + } + if a.PoolSize != e.PoolSize { + t.Errorf("PoolSize got %v, want %v", a.PoolSize, e.PoolSize) + } + if a.PoolTimeout != e.PoolTimeout { + t.Errorf("PoolTimeout got %v, want %v", a.PoolTimeout, e.PoolTimeout) + } + if a.MinIdleConns != e.MinIdleConns { + t.Errorf("MinIdleConns got %v, want %v", a.MinIdleConns, e.MinIdleConns) + } + if a.MaxIdleConns != e.MaxIdleConns { + t.Errorf("MaxIdleConns got %v, want %v", a.MaxIdleConns, e.MaxIdleConns) + } + if a.MaxActiveConns != e.MaxActiveConns { + t.Errorf("MaxActiveConns got %v, want %v", a.MaxActiveConns, e.MaxActiveConns) + } + if a.ConnMaxIdleTime != e.ConnMaxIdleTime { + t.Errorf("ConnMaxIdleTime got %v, want %v", a.ConnMaxIdleTime, e.ConnMaxIdleTime) + } + if a.ConnMaxLifetime != e.ConnMaxLifetime { + t.Errorf("ConnMaxLifeTime got %v, want %v", a.ConnMaxLifetime, e.ConnMaxLifetime) + } + if a.DisableIdentity != e.DisableIdentity { + t.Errorf("DisableIdentity got %v, want %v", a.DisableIdentity, e.DisableIdentity) + } + if a.IdentitySuffix != e.IdentitySuffix { + t.Errorf("IdentitySuffix got %v, want %v", a.IdentitySuffix, e.IdentitySuffix) + } + if a.UnstableResp3 != e.UnstableResp3 { + t.Errorf("UnstableResp3 got %v, want %v", a.UnstableResp3, e.UnstableResp3) + } + if (a.TLSConfig == nil && e.TLSConfig != nil) || (a.TLSConfig != nil && e.TLSConfig == nil) { + t.Errorf("TLSConfig error") + } + if a.TLSConfig != nil && e.TLSConfig != nil { + if a.TLSConfig.ServerName != e.TLSConfig.ServerName { + t.Errorf("TLSConfig.ServerName got %q, want %q", a.TLSConfig.ServerName, e.TLSConfig.ServerName) + } + } +} + +func compareSlices(t *testing.T, a, b []string, name string) { + sort.Slice(a, func(i, j int) bool { return a[i] < a[j] }) + sort.Slice(b, func(i, j int) bool { return b[i] < b[j] }) + if len(a) != len(b) { + t.Errorf("%s got %q, want %q", name, a, b) + } + for i := range a { + if a[i] != b[i] { + t.Errorf("%s got %q, want %q", name, a, b) + } + } +}