Skip to content

Commit a4bec12

Browse files
committed
Add NewFromURL
Also, force pubsub to go through Pipeline
1 parent fe58e42 commit a4bec12

File tree

3 files changed

+139
-22
lines changed

3 files changed

+139
-22
lines changed

client.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@ package redjet
33
import (
44
"bufio"
55
"context"
6+
"crypto/tls"
67
"encoding/json"
78
"errors"
89
"fmt"
910
"io"
1011
"net"
12+
"net/url"
1113
"strconv"
1214
"strings"
1315
"sync"
@@ -58,6 +60,62 @@ func New(addr string) *Client {
5860
return c
5961
}
6062

63+
func NewFromURL(rawURL string) (*Client, error) {
64+
u, err := url.Parse(rawURL)
65+
if err != nil {
66+
return nil, fmt.Errorf("parse url: %w", err)
67+
}
68+
69+
client := New(u.Host)
70+
71+
var (
72+
addr string
73+
isUnixSocket bool
74+
)
75+
if u.Host == "" && u.Path != "" {
76+
// Likely using a unix socket.
77+
addr = u.Path
78+
isUnixSocket = true
79+
} else {
80+
addr = u.Host
81+
}
82+
83+
if !isUnixSocket {
84+
if u.Port() != "" {
85+
addr = net.JoinHostPort(addr, u.Port())
86+
} else {
87+
addr = net.JoinHostPort(addr, "6379")
88+
}
89+
}
90+
91+
switch u.Scheme {
92+
case "redis":
93+
client.Dial = func(ctx context.Context) (net.Conn, error) {
94+
var d net.Dialer
95+
proto := "tcp"
96+
if isUnixSocket {
97+
proto = "unix"
98+
}
99+
return d.DialContext(ctx, proto, addr)
100+
}
101+
case "rediss":
102+
client.Dial = func(ctx context.Context) (net.Conn, error) {
103+
var d tls.Dialer
104+
return d.DialContext(ctx, "tcp", addr)
105+
}
106+
default:
107+
return nil, fmt.Errorf("unsupported scheme: %s", u.Scheme)
108+
}
109+
110+
if u.User != nil {
111+
client.AuthUsername = u.User.Username()
112+
pass, _ := u.User.Password()
113+
client.AuthPassword = pass
114+
}
115+
116+
return client, nil
117+
}
118+
61119
func (c *Client) initPool() {
62120
c.poolMu.Lock()
63121
defer c.poolMu.Unlock()
@@ -294,6 +352,13 @@ func (c *Client) Pipeline(ctx context.Context, r *Pipeline, cmd string, args ...
294352
//
295353
// The caller should call Close on the result when finished with it.
296354
func (c *Client) Command(ctx context.Context, cmd string, args ...any) *Pipeline {
355+
if isSubscribeCmd(cmd) {
356+
return &Pipeline{
357+
// Close behavior becomes confusing when combining subscription
358+
// and CloseOnRead.
359+
err: fmt.Errorf("cannot use Command with subscribe command %s, use Pipeline instead", cmd),
360+
}
361+
}
297362
r := c.Pipeline(ctx, nil, cmd, args...)
298363
r.CloseOnRead = true
299364
return r

client_test.go

Lines changed: 73 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,41 @@ import (
1919
"go.uber.org/goleak"
2020
)
2121

22+
func pingPong(t *testing.T, c *redjet.Client) {
23+
ctx := context.Background()
24+
got, err := c.Command(ctx, "PING").String()
25+
require.NoError(t, err)
26+
require.Equal(t, "PONG", got)
27+
}
28+
29+
func TestNewFromURL(t *testing.T) {
30+
t.Parallel()
31+
32+
t.Run("Normal", func(t *testing.T) {
33+
t.Parallel()
34+
35+
addr, _ := redtest.StartRedisServer(t)
36+
37+
c, err := redjet.NewFromURL("redis://" + addr)
38+
require.NoError(t, err)
39+
defer c.Close()
40+
41+
pingPong(t, c)
42+
})
43+
44+
t.Run("Password", func(t *testing.T) {
45+
t.Parallel()
46+
47+
addr, _ := redtest.StartRedisServer(t, "--requirepass", "hunter2")
48+
49+
c, err := redjet.NewFromURL("redis://:hunter2@" + addr)
50+
require.NoError(t, err)
51+
defer c.Close()
52+
53+
pingPong(t, c)
54+
})
55+
}
56+
2257
func TestClient_SetGet(t *testing.T) {
2358
t.Parallel()
2459

@@ -313,33 +348,50 @@ func TestClient_Auth(t *testing.T) {
313348
func TestClient_PubSub(t *testing.T) {
314349
t.Parallel()
315350

316-
_, client := redtest.StartRedisServer(t)
351+
t.Run("NoCommand", func(t *testing.T) {
352+
_, client := redtest.StartRedisServer(t)
317353

318-
ctx := context.Background()
319-
subCmd := client.Command(ctx, "SUBSCRIBE", "foo")
320-
defer subCmd.Close()
354+
ctx := context.Background()
355+
subCmd := client.Command(ctx, "SUBSCRIBE", "foo")
356+
defer subCmd.Close()
321357

322-
msg, err := subCmd.NextSubMessage()
323-
require.NoError(t, err)
358+
_, err := subCmd.NextSubMessage()
359+
require.Error(t, err)
360+
})
324361

325-
require.Equal(t, &redjet.SubMessage{
326-
Channel: "foo",
327-
Type: "subscribe",
328-
Payload: "1",
329-
}, msg)
362+
t.Run("OK", func(t *testing.T) {
363+
_, client := redtest.StartRedisServer(t)
330364

331-
n, err := client.Command(ctx, "PUBLISH", "foo", "bar").Int()
332-
require.NoError(t, err)
333-
require.Equal(t, 1, n)
365+
ctx := context.Background()
334366

335-
msg, err = subCmd.NextSubMessage()
336-
require.NoError(t, err)
367+
subCmd := client.Pipeline(ctx, nil, "SUBSCRIBE", "foo")
368+
defer subCmd.Close()
369+
370+
msg, err := subCmd.NextSubMessage()
371+
require.NoError(t, err)
372+
373+
require.Equal(t, &redjet.SubMessage{
374+
Channel: "foo",
375+
Type: "subscribe",
376+
Payload: "1",
377+
}, msg)
378+
379+
pubPipe := client.Pipeline(ctx, nil, "PUBLISH", "foo", "bar")
380+
defer pubPipe.Close()
337381

338-
require.Equal(t, &redjet.SubMessage{
339-
Channel: "foo",
340-
Type: "message",
341-
Payload: "bar",
342-
}, msg)
382+
n, err := pubPipe.Int()
383+
require.NoError(t, err)
384+
require.Equal(t, 1, n)
385+
386+
msg, err = subCmd.NextSubMessage()
387+
require.NoError(t, err)
388+
389+
require.Equal(t, &redjet.SubMessage{
390+
Channel: "foo",
391+
Type: "message",
392+
Payload: "bar",
393+
}, msg)
394+
})
343395
}
344396

345397
func TestClient_ConnReuse(t *testing.T) {

pubsub.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func (r *Pipeline) NextSubMessage() (*SubMessage, error) {
5858

5959
func isSubscribeCmd(cmd string) bool {
6060
switch cmd {
61-
case "SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "PING", "QUIT", "RESET":
61+
case "SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "QUIT", "RESET":
6262
return true
6363
default:
6464
return false

0 commit comments

Comments
 (0)