Skip to content

Commit 7a21b41

Browse files
committed
Refactor authentication into Setup function
1 parent 20dc921 commit 7a21b41

File tree

3 files changed

+44
-36
lines changed

3 files changed

+44
-36
lines changed

README.md

-3
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,6 @@ for r.Next() {
128128
p.Close() // allow the underlying connection to be reused.
129129
```
130130

131-
Fun fact: authentication happens over a pipeline, so it doesn't incur a round-trip.
132-
133-
134131
## PubSub
135132

136133
redjet suports PubSub via the `NextSubMessage` method. For example:

client.go

+40-32
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,11 @@ type Client struct {
2929
// Dial is the function used to create new connections.
3030
Dial func(ctx context.Context) (net.Conn, error)
3131

32-
// AuthUsername is the username used for authentication.
32+
// Setup is called after a new connection is established, but before any
33+
// commands are sent. It is useful for selecting a database or authenticating.
3334
//
34-
// If set, AuthPassword must also be set. If not using Redis ACLs, just
35-
// set AuthPassword.
36-
//
37-
// See more: https://redis.io/commands/auth/
38-
AuthUsername string
39-
// AuthPassword is the password used for authentication.
40-
// Authentication must be set before any other commands are sent, and
41-
// must not change during the lifetime of the client.
42-
//
43-
// See more: https://redis.io/commands/auth/
44-
AuthPassword string
35+
// See SetupAuth for authenticating with a username and password.
36+
Setup func(ctx context.Context, client *Client, pipe *Pipeline) error
4537

4638
poolMu sync.Mutex
4739
pool *connPool
@@ -104,9 +96,8 @@ func NewFromURL(rawURL string) (*Client, error) {
10496
}
10597

10698
if u.User != nil {
107-
client.AuthUsername = u.User.Username()
10899
pass, _ := u.User.Password()
109-
client.AuthPassword = pass
100+
client.Setup = SetupAuth(u.User.Username(), pass)
110101
}
111102

112103
return client, nil
@@ -164,28 +155,45 @@ func (c *Client) getConn(ctx context.Context) (*Pipeline, error) {
164155

165156
r := c.newResult(conn)
166157

167-
if c.AuthUsername == "" && c.AuthPassword == "" {
168-
return r, nil
169-
}
170-
171-
if c.AuthUsername != "" && c.AuthPassword == "" {
172-
nc.Close()
173-
return nil, errors.New("auth username set but password not set")
158+
if c.Setup != nil {
159+
err = c.Setup(ctx, c, r)
160+
if err != nil {
161+
nc.Close()
162+
return nil, fmt.Errorf("setup: %w", err)
163+
}
174164
}
175165

176-
if c.AuthUsername != "" {
177-
r = c.Pipeline(ctx, r, "AUTH", c.AuthUsername, c.AuthPassword)
178-
} else {
179-
r = c.Pipeline(ctx, r, "AUTH", c.AuthPassword)
180-
}
166+
return r, nil
167+
}
181168

182-
err = r.Ok()
183-
if err != nil {
184-
nc.Close()
185-
return nil, fmt.Errorf("auth: %w", err)
169+
// SetupAuth returns a Setup function that authenticates with the given username and password.
170+
//
171+
// AuthUsername is the username used for authentication.
172+
//
173+
// If set, AuthPassword must also be set. If not using Redis ACLs, just
174+
// set AuthPassword.
175+
//
176+
// See more: https://redis.io/commands/auth/
177+
// AuthPassword is the password used for authentication.
178+
// Authentication must be set before any other commands are sent, and
179+
// must not change during the lifetime of the client.
180+
//
181+
// See more: https://redis.io/commands/auth/
182+
func SetupAuth(
183+
username string,
184+
password string,
185+
) func(ctx context.Context, client *Client, pipe *Pipeline) error {
186+
return func(ctx context.Context, client *Client, pipe *Pipeline) error {
187+
switch {
188+
case username != "" && password != "":
189+
pipe = client.Pipeline(ctx, pipe, "AUTH", username, password)
190+
case password != "":
191+
pipe = client.Pipeline(ctx, pipe, "AUTH", password)
192+
default:
193+
return fmt.Errorf("username is set but password is not")
194+
}
195+
return pipe.Ok()
186196
}
187-
188-
return r, nil
189197
}
190198

191199
func (c *Client) putConn(conn *conn) {

client_test.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,10 @@ func TestClient_Auth(t *testing.T) {
332332

333333
_, client := redtest.StartRedisServer(t, "--requirepass", password)
334334
ctx := context.Background()
335-
client.AuthPassword = password
335+
client.Setup = redjet.SetupAuth(
336+
"",
337+
password,
338+
)
336339

337340
// It's imperative to test both SET and GET because the response
338341
// of SET matches the response of AUTH.

0 commit comments

Comments
 (0)