Skip to content

Commit

Permalink
Add NewFromURL
Browse files Browse the repository at this point in the history
Also, force pubsub to go through Pipeline
  • Loading branch information
ammario committed Feb 12, 2024
1 parent fe58e42 commit a4bec12
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 22 deletions.
65 changes: 65 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package redjet
import (
"bufio"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/url"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -58,6 +60,62 @@ func New(addr string) *Client {
return c
}

func NewFromURL(rawURL string) (*Client, error) {
u, err := url.Parse(rawURL)
if err != nil {
return nil, fmt.Errorf("parse url: %w", err)
}

client := New(u.Host)

var (
addr string
isUnixSocket bool
)
if u.Host == "" && u.Path != "" {
// Likely using a unix socket.
addr = u.Path
isUnixSocket = true
} else {
addr = u.Host
}

if !isUnixSocket {
if u.Port() != "" {
addr = net.JoinHostPort(addr, u.Port())
} else {
addr = net.JoinHostPort(addr, "6379")
}
}

switch u.Scheme {
case "redis":
client.Dial = func(ctx context.Context) (net.Conn, error) {
var d net.Dialer
proto := "tcp"
if isUnixSocket {
proto = "unix"
}
return d.DialContext(ctx, proto, addr)
}
case "rediss":
client.Dial = func(ctx context.Context) (net.Conn, error) {
var d tls.Dialer
return d.DialContext(ctx, "tcp", addr)
}
default:
return nil, fmt.Errorf("unsupported scheme: %s", u.Scheme)
}

if u.User != nil {
client.AuthUsername = u.User.Username()
pass, _ := u.User.Password()
client.AuthPassword = pass
}

return client, nil
}

func (c *Client) initPool() {
c.poolMu.Lock()
defer c.poolMu.Unlock()
Expand Down Expand Up @@ -294,6 +352,13 @@ func (c *Client) Pipeline(ctx context.Context, r *Pipeline, cmd string, args ...
//
// The caller should call Close on the result when finished with it.
func (c *Client) Command(ctx context.Context, cmd string, args ...any) *Pipeline {
if isSubscribeCmd(cmd) {
return &Pipeline{
// Close behavior becomes confusing when combining subscription
// and CloseOnRead.
err: fmt.Errorf("cannot use Command with subscribe command %s, use Pipeline instead", cmd),
}
}
r := c.Pipeline(ctx, nil, cmd, args...)
r.CloseOnRead = true
return r
Expand Down
94 changes: 73 additions & 21 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,41 @@ import (
"go.uber.org/goleak"
)

func pingPong(t *testing.T, c *redjet.Client) {
ctx := context.Background()
got, err := c.Command(ctx, "PING").String()
require.NoError(t, err)
require.Equal(t, "PONG", got)
}

func TestNewFromURL(t *testing.T) {
t.Parallel()

t.Run("Normal", func(t *testing.T) {
t.Parallel()

addr, _ := redtest.StartRedisServer(t)

c, err := redjet.NewFromURL("redis://" + addr)
require.NoError(t, err)
defer c.Close()

pingPong(t, c)
})

t.Run("Password", func(t *testing.T) {
t.Parallel()

addr, _ := redtest.StartRedisServer(t, "--requirepass", "hunter2")

c, err := redjet.NewFromURL("redis://:hunter2@" + addr)
require.NoError(t, err)
defer c.Close()

pingPong(t, c)
})
}

func TestClient_SetGet(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -313,33 +348,50 @@ func TestClient_Auth(t *testing.T) {
func TestClient_PubSub(t *testing.T) {
t.Parallel()

_, client := redtest.StartRedisServer(t)
t.Run("NoCommand", func(t *testing.T) {
_, client := redtest.StartRedisServer(t)

ctx := context.Background()
subCmd := client.Command(ctx, "SUBSCRIBE", "foo")
defer subCmd.Close()
ctx := context.Background()
subCmd := client.Command(ctx, "SUBSCRIBE", "foo")
defer subCmd.Close()

msg, err := subCmd.NextSubMessage()
require.NoError(t, err)
_, err := subCmd.NextSubMessage()
require.Error(t, err)
})

require.Equal(t, &redjet.SubMessage{
Channel: "foo",
Type: "subscribe",
Payload: "1",
}, msg)
t.Run("OK", func(t *testing.T) {
_, client := redtest.StartRedisServer(t)

n, err := client.Command(ctx, "PUBLISH", "foo", "bar").Int()
require.NoError(t, err)
require.Equal(t, 1, n)
ctx := context.Background()

msg, err = subCmd.NextSubMessage()
require.NoError(t, err)
subCmd := client.Pipeline(ctx, nil, "SUBSCRIBE", "foo")
defer subCmd.Close()

msg, err := subCmd.NextSubMessage()
require.NoError(t, err)

require.Equal(t, &redjet.SubMessage{
Channel: "foo",
Type: "subscribe",
Payload: "1",
}, msg)

pubPipe := client.Pipeline(ctx, nil, "PUBLISH", "foo", "bar")
defer pubPipe.Close()

require.Equal(t, &redjet.SubMessage{
Channel: "foo",
Type: "message",
Payload: "bar",
}, msg)
n, err := pubPipe.Int()
require.NoError(t, err)
require.Equal(t, 1, n)

msg, err = subCmd.NextSubMessage()
require.NoError(t, err)

require.Equal(t, &redjet.SubMessage{
Channel: "foo",
Type: "message",
Payload: "bar",
}, msg)
})
}

func TestClient_ConnReuse(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (r *Pipeline) NextSubMessage() (*SubMessage, error) {

func isSubscribeCmd(cmd string) bool {
switch cmd {
case "SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "PING", "QUIT", "RESET":
case "SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "QUIT", "RESET":
return true
default:
return false
Expand Down

0 comments on commit a4bec12

Please sign in to comment.