diff --git a/client.go b/client.go index 16f4b4b..9b6363a 100644 --- a/client.go +++ b/client.go @@ -3,11 +3,13 @@ package redjet import ( "bufio" "context" + "crypto/tls" "encoding/json" "errors" "fmt" "io" "net" + "net/url" "strconv" "strings" "sync" @@ -58,6 +60,62 @@ func New(addr string) *Client { return c } +func NewFromURL(rawURL string) (*Client, error) { + u, err := url.Parse(rawURL) + if err != nil { + return nil, fmt.Errorf("parse url: %w", err) + } + + client := New(u.Host) + + var ( + addr string + isUnixSocket bool + ) + if u.Host == "" && u.Path != "" { + // Likely using a unix socket. + addr = u.Path + isUnixSocket = true + } else { + addr = u.Host + } + + if !isUnixSocket { + if u.Port() != "" { + addr = net.JoinHostPort(addr, u.Port()) + } else { + addr = net.JoinHostPort(addr, "6379") + } + } + + switch u.Scheme { + case "redis": + client.Dial = func(ctx context.Context) (net.Conn, error) { + var d net.Dialer + proto := "tcp" + if isUnixSocket { + proto = "unix" + } + return d.DialContext(ctx, proto, addr) + } + case "rediss": + client.Dial = func(ctx context.Context) (net.Conn, error) { + var d tls.Dialer + return d.DialContext(ctx, "tcp", addr) + } + default: + return nil, fmt.Errorf("unsupported scheme: %s", u.Scheme) + } + + if u.User != nil { + client.AuthUsername = u.User.Username() + pass, _ := u.User.Password() + client.AuthPassword = pass + } + + return client, nil +} + func (c *Client) initPool() { c.poolMu.Lock() defer c.poolMu.Unlock() @@ -294,6 +352,13 @@ func (c *Client) Pipeline(ctx context.Context, r *Pipeline, cmd string, args ... // // The caller should call Close on the result when finished with it. func (c *Client) Command(ctx context.Context, cmd string, args ...any) *Pipeline { + if isSubscribeCmd(cmd) { + return &Pipeline{ + // Close behavior becomes confusing when combining subscription + // and CloseOnRead. + err: fmt.Errorf("cannot use Command with subscribe command %s, use Pipeline instead", cmd), + } + } r := c.Pipeline(ctx, nil, cmd, args...) r.CloseOnRead = true return r diff --git a/client_test.go b/client_test.go index 0f86dfd..ff1285c 100644 --- a/client_test.go +++ b/client_test.go @@ -19,6 +19,41 @@ import ( "go.uber.org/goleak" ) +func pingPong(t *testing.T, c *redjet.Client) { + ctx := context.Background() + got, err := c.Command(ctx, "PING").String() + require.NoError(t, err) + require.Equal(t, "PONG", got) +} + +func TestNewFromURL(t *testing.T) { + t.Parallel() + + t.Run("Normal", func(t *testing.T) { + t.Parallel() + + addr, _ := redtest.StartRedisServer(t) + + c, err := redjet.NewFromURL("redis://" + addr) + require.NoError(t, err) + defer c.Close() + + pingPong(t, c) + }) + + t.Run("Password", func(t *testing.T) { + t.Parallel() + + addr, _ := redtest.StartRedisServer(t, "--requirepass", "hunter2") + + c, err := redjet.NewFromURL("redis://:hunter2@" + addr) + require.NoError(t, err) + defer c.Close() + + pingPong(t, c) + }) +} + func TestClient_SetGet(t *testing.T) { t.Parallel() @@ -313,33 +348,50 @@ func TestClient_Auth(t *testing.T) { func TestClient_PubSub(t *testing.T) { t.Parallel() - _, client := redtest.StartRedisServer(t) + t.Run("NoCommand", func(t *testing.T) { + _, client := redtest.StartRedisServer(t) - ctx := context.Background() - subCmd := client.Command(ctx, "SUBSCRIBE", "foo") - defer subCmd.Close() + ctx := context.Background() + subCmd := client.Command(ctx, "SUBSCRIBE", "foo") + defer subCmd.Close() - msg, err := subCmd.NextSubMessage() - require.NoError(t, err) + _, err := subCmd.NextSubMessage() + require.Error(t, err) + }) - require.Equal(t, &redjet.SubMessage{ - Channel: "foo", - Type: "subscribe", - Payload: "1", - }, msg) + t.Run("OK", func(t *testing.T) { + _, client := redtest.StartRedisServer(t) - n, err := client.Command(ctx, "PUBLISH", "foo", "bar").Int() - require.NoError(t, err) - require.Equal(t, 1, n) + ctx := context.Background() - msg, err = subCmd.NextSubMessage() - require.NoError(t, err) + subCmd := client.Pipeline(ctx, nil, "SUBSCRIBE", "foo") + defer subCmd.Close() + + msg, err := subCmd.NextSubMessage() + require.NoError(t, err) + + require.Equal(t, &redjet.SubMessage{ + Channel: "foo", + Type: "subscribe", + Payload: "1", + }, msg) + + pubPipe := client.Pipeline(ctx, nil, "PUBLISH", "foo", "bar") + defer pubPipe.Close() - require.Equal(t, &redjet.SubMessage{ - Channel: "foo", - Type: "message", - Payload: "bar", - }, msg) + n, err := pubPipe.Int() + require.NoError(t, err) + require.Equal(t, 1, n) + + msg, err = subCmd.NextSubMessage() + require.NoError(t, err) + + require.Equal(t, &redjet.SubMessage{ + Channel: "foo", + Type: "message", + Payload: "bar", + }, msg) + }) } func TestClient_ConnReuse(t *testing.T) { diff --git a/pubsub.go b/pubsub.go index 287f3d7..6f7da01 100644 --- a/pubsub.go +++ b/pubsub.go @@ -58,7 +58,7 @@ func (r *Pipeline) NextSubMessage() (*SubMessage, error) { func isSubscribeCmd(cmd string) bool { switch cmd { - case "SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "PING", "QUIT", "RESET": + case "SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "QUIT", "RESET": return true default: return false