Skip to content

Commit 8374427

Browse files
committed
Distinguish between redis nil values and empty strings
1 parent 69956ca commit 8374427

File tree

3 files changed

+89
-48
lines changed

3 files changed

+89
-48
lines changed

client.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ func (c *Client) Pipeline(ctx context.Context, p *Pipeline, cmd string, args ...
297297
p, err = c.getConn(ctx)
298298
if err != nil {
299299
return &Pipeline{
300-
err: fmt.Errorf("get conn: %w", err),
300+
protoErr: fmt.Errorf("get conn: %w", err),
301301
}
302302
}
303303

@@ -365,7 +365,7 @@ func (c *Client) Command(ctx context.Context, cmd string, args ...any) *Pipeline
365365
return &Pipeline{
366366
// Close behavior becomes confusing when combining subscription
367367
// and CloseOnRead.
368-
err: fmt.Errorf("cannot use Command with subscribe command %s, use Pipeline instead", cmd),
368+
protoErr: fmt.Errorf("cannot use Command with subscribe command %s, use Pipeline instead", cmd),
369369
}
370370
}
371371
r := c.Pipeline(ctx, nil, cmd, args...)

client_test.go

+45-16
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ func TestClient_NotFound(t *testing.T) {
8484
ctx := context.Background()
8585

8686
got, err := client.Command(ctx, "GET", "brah").String()
87-
require.NoError(t, err)
87+
require.ErrorIs(t, err, redjet.ErrNil)
8888
require.Equal(t, "", got)
8989
}
9090

@@ -284,25 +284,37 @@ func TestClient_Stringer(t *testing.T) {
284284
func TestClient_JSON(t *testing.T) {
285285
t.Parallel()
286286

287-
_, client := redtest.StartRedisServer(t)
287+
t.Run("GetSet", func(t *testing.T) {
288+
t.Parallel()
289+
_, client := redtest.StartRedisServer(t)
288290

289-
var v struct {
290-
Foo string
291-
Bar int
292-
}
291+
var v struct {
292+
Foo string
293+
Bar int
294+
}
293295

294-
v.Foo = "bar"
295-
v.Bar = 123
296+
v.Foo = "bar"
297+
v.Bar = 123
296298

297-
ctx := context.Background()
298-
err := client.Command(ctx, "SET", "foo", v).Ok()
299-
require.NoError(t, err)
299+
ctx := context.Background()
300+
err := client.Command(ctx, "SET", "foo", v).Ok()
301+
require.NoError(t, err)
300302

301-
resp := make(map[string]interface{})
302-
err = client.Command(ctx, "GET", "foo").JSON(&resp)
303-
require.NoError(t, err)
304-
require.Equal(t, "bar", resp["Foo"])
305-
require.Equal(t, float64(123), resp["Bar"])
303+
resp := make(map[string]interface{})
304+
err = client.Command(ctx, "GET", "foo").JSON(&resp)
305+
require.NoError(t, err)
306+
require.Equal(t, "bar", resp["Foo"])
307+
require.Equal(t, float64(123), resp["Bar"])
308+
})
309+
t.Run("NotFound", func(t *testing.T) {
310+
t.Parallel()
311+
_, client := redtest.StartRedisServer(t)
312+
313+
var resp map[string]interface{}
314+
ctx := context.Background()
315+
err := client.Command(ctx, "GET", "foo").JSON(&resp)
316+
require.ErrorIs(t, err, redjet.ErrNil)
317+
})
306318
}
307319

308320
func TestClient_MGet(t *testing.T) {
@@ -328,6 +340,23 @@ func TestClient_MGet(t *testing.T) {
328340
require.Equal(t, []string{"antelope", "bat", "cat"}, got)
329341
}
330342

343+
func TestClient_MGet_Nil(t *testing.T) {
344+
t.Parallel()
345+
346+
_, client := redtest.StartRedisServer(t)
347+
348+
ctx := context.Background()
349+
// Only set first and last keys.
350+
err := client.Command(ctx, "MSET", "a", "antelope", "c", "cat").Ok()
351+
require.NoError(t, err)
352+
353+
// As a special case of handling nil, we return empty strings for
354+
// missing keys.
355+
got, err := client.Command(ctx, "MGET", "a", "b", "c").Strings()
356+
assert.NoError(t, err)
357+
assert.Equal(t, []string{"antelope", "", "cat"}, got)
358+
}
359+
331360
func TestClient_Auth(t *testing.T) {
332361
t.Parallel()
333362
const password = "hunt12"

pipeline.go

+42-30
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ type Pipeline struct {
6262
closeCh chan struct{}
6363
closed int64
6464

65-
err error
65+
// protoErr is set to non-nil if there is an unrecoverable protocol error.
66+
protoErr error
6667

6768
conn *conn
6869
client *Client
@@ -77,10 +78,10 @@ type Pipeline struct {
7778
func (r *Pipeline) Error() string {
7879
r.mu.Lock()
7980
defer r.mu.Unlock()
80-
if r.err == nil {
81+
if r.protoErr == nil {
8182
return ""
8283
}
83-
return r.err.Error()
84+
return r.protoErr.Error()
8485
}
8586

8687
// readUntilNewline reads until a newline, returning the bytes without the newline.
@@ -110,6 +111,10 @@ var (
110111
_ grower = (*strings.Builder)(nil)
111112
)
112113

114+
// ErrNil is a nil value. For example, it is returned for missing keys in
115+
// GET and MGET.
116+
var ErrNil = errors.New("(nil)")
117+
113118
func readBulkString(w io.Writer, rd *bufio.Reader, copyBuf []byte) (int, error) {
114119
newlineBuf, err := readUntilNewline(rd, copyBuf)
115120
if err != nil {
@@ -123,7 +128,7 @@ func readBulkString(w io.Writer, rd *bufio.Reader, copyBuf []byte) (int, error)
123128

124129
// n == -1 signals a nil value.
125130
if stringSize <= 0 {
126-
return 0, nil
131+
return 0, ErrNil
127132
}
128133

129134
if g, ok := w.(grower); ok {
@@ -231,7 +236,7 @@ func (r *Pipeline) Strings() ([]string, error) {
231236
var ss []string
232237
for i := 0; i < ln; i++ {
233238
s, err := r.String()
234-
if err != nil {
239+
if err != nil && !errors.Is(err, ErrNil) {
235240
return ss, fmt.Errorf("read string %d: %w", i, err)
236241
}
237242
ss = append(ss, s)
@@ -258,25 +263,25 @@ func (r *Pipeline) writeTo(w io.Writer) (int64, replyType, error) {
258263
return 0, 0, err
259264
}
260265

261-
if r.err != nil {
262-
return 0, 0, r.err
266+
if r.protoErr != nil {
267+
return 0, 0, r.protoErr
263268
}
264269

265270
if r.pipeline.at == r.pipeline.end && len(r.arrayStack) == 0 && !r.subscribeMode {
266271
return 0, 0, fmt.Errorf("no more results")
267272
}
268273

269-
r.err = r.conn.wr.Flush()
270-
if r.err != nil {
271-
r.err = fmt.Errorf("flush: %w", r.err)
272-
return 0, 0, r.err
274+
r.protoErr = r.conn.wr.Flush()
275+
if r.protoErr != nil {
276+
r.protoErr = fmt.Errorf("flush: %w", r.protoErr)
277+
return 0, 0, r.protoErr
273278
}
274279

275280
var typByte byte
276-
typByte, r.err = r.conn.rd.ReadByte()
277-
if r.err != nil {
278-
r.err = fmt.Errorf("read type: %w", r.err)
279-
return 0, 0, r.err
281+
typByte, r.protoErr = r.conn.rd.ReadByte()
282+
if r.protoErr != nil {
283+
r.protoErr = fmt.Errorf("read type: %w", r.protoErr)
284+
return 0, 0, r.protoErr
280285
}
281286
typ := replyType(typByte)
282287

@@ -306,15 +311,15 @@ func (r *Pipeline) writeTo(w io.Writer) (int64, replyType, error) {
306311
switch typ {
307312
case replyTypeSimpleString, replyTypeInteger, replyTypeArray:
308313
// Simple string or integer
309-
s, r.err = readUntilNewline(r.conn.rd, r.conn.miscBuf)
310-
if r.err != nil {
311-
return 0, typ, r.err
314+
s, r.protoErr = readUntilNewline(r.conn.rd, r.conn.miscBuf)
315+
if r.protoErr != nil {
316+
return 0, typ, r.protoErr
312317
}
313318

314319
isNewArray := typ == '*'
315320

316321
var n int
317-
n, r.err = w.Write(s)
322+
n, r.protoErr = w.Write(s)
318323
incrRead(isNewArray)
319324
var newArraySize int
320325
if isNewArray {
@@ -328,26 +333,33 @@ func (r *Pipeline) writeTo(w io.Writer) (int64, replyType, error) {
328333
r.arrayStack = append(r.arrayStack, newArraySize)
329334
}
330335
}
331-
return int64(n), typ, r.err
336+
return int64(n), typ, r.protoErr
332337
case replyTypeBulkString:
333338
// Bulk string
334-
var n int
335-
n, r.err = readBulkString(w, r.conn.rd, r.conn.miscBuf)
339+
var (
340+
n int
341+
err error
342+
)
343+
n, err = readBulkString(w, r.conn.rd, r.conn.miscBuf)
336344
incrRead(false)
337-
return int64(n), typ, r.err
345+
// A nil is highly recoverable.
346+
if !errors.Is(err, ErrNil) {
347+
r.protoErr = err
348+
}
349+
return int64(n), typ, err
338350
case replyTypeError:
339351
// Error
340-
s, r.err = readUntilNewline(r.conn.rd, r.conn.miscBuf)
341-
if r.err != nil {
342-
return 0, typ, r.err
352+
s, r.protoErr = readUntilNewline(r.conn.rd, r.conn.miscBuf)
353+
if r.protoErr != nil {
354+
return 0, typ, r.protoErr
343355
}
344356
incrRead(false)
345357
return 0, typ, &Error{
346358
raw: string(s),
347359
}
348360
default:
349-
r.err = fmt.Errorf("unknown type %q", typ)
350-
return 0, typ, r.err
361+
r.protoErr = fmt.Errorf("unknown type %q", typ)
362+
return 0, typ, r.protoErr
351363
}
352364
}
353365

@@ -456,7 +468,7 @@ func (r *Pipeline) Next() bool {
456468

457469
// HasMore returns true if there are more results to read.
458470
func (r *Pipeline) HasMore() bool {
459-
if r.err != nil {
471+
if r.protoErr != nil {
460472
return false
461473
}
462474

@@ -495,7 +507,7 @@ func (r *Pipeline) close() error {
495507
// r.conn is set to nil to prevent accidental reuse.
496508
r.conn = nil
497509
// Only return conn when it is in a known good state.
498-
if r.err == nil && !r.subscribeMode && !r.HasMore() {
510+
if r.protoErr == nil && !r.subscribeMode && !r.HasMore() {
499511
r.client.putConn(conn)
500512
return nil
501513
}

0 commit comments

Comments
 (0)