Skip to content

Commit

Permalink
Simplify Close logic to prevent deadlocks
Browse files Browse the repository at this point in the history
  • Loading branch information
ammario committed Feb 13, 2024
1 parent 7a21b41 commit 7cd9c24
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 23 deletions.
34 changes: 25 additions & 9 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,19 +146,35 @@ func TestClient_Race(t *testing.T) {
func TestClient_IdleDrain(t *testing.T) {
t.Parallel()

_, client := redtest.StartRedisServer(t)
t.Run("Unexpected", func(t *testing.T) {
t.Parallel()
_, client := redtest.StartRedisServer(t)

require.Equal(t, 0, client.PoolStats().FreeConns)
require.Equal(t, 0, client.PoolStats().FreeConns)

err := client.Command(context.Background(), "SET", "foo", "bar").Close()
require.NoError(t, err)
// Close comes before reading result.
err := client.Command(context.Background(), "SET", "foo", "bar").Close()
require.NoError(t, err)

// Connection not returned to pool.
require.Equal(t, 0, client.PoolStats().FreeConns)
})

require.Equal(t, 1, client.PoolStats().FreeConns)
t.Run("Regular", func(t *testing.T) {
t.Parallel()
_, client := redtest.StartRedisServer(t)

// After the idle timeout, the connection should be drained.
require.Eventually(t, func() bool {
return client.PoolStats().FreeConns == 0
}, time.Second, 10*time.Millisecond)
require.Equal(t, 0, client.PoolStats().FreeConns)
err := client.Command(context.Background(), "SET", "foo", "bar").Ok()
require.NoError(t, err)

require.Equal(t, 1, client.PoolStats().FreeConns)

// After the idle timeout, the connection should be drained.
require.Eventually(t, func() bool {
return client.PoolStats().FreeConns == 0
}, time.Second, 10*time.Millisecond)
})
}

func TestClient_ShortRead(t *testing.T) {
Expand Down
15 changes: 2 additions & 13 deletions pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,18 +483,6 @@ func (r *Pipeline) Close() error {
}

func (r *Pipeline) close() error {
for r.hasMore() {
// Read the result into discard so that the connection can be reused.
_, _, err := r.writeTo(io.Discard)
if errors.Is(err, errClosed) {
// Should be impossible to close a result without draining
// it, in which case at == end and we would never get here.
return fmt.Errorf("SEVERE: result closed while iterating")
} else if err != nil {
return fmt.Errorf("drain: %w", err)
}
}

if !atomic.CompareAndSwapInt64(&r.closed, 0, 1) {
// double-close
return nil
Expand All @@ -507,7 +495,8 @@ func (r *Pipeline) close() error {
conn := r.conn
// r.conn is set to nil to prevent accidental reuse.
r.conn = nil
if r.err == nil && !r.subscribeMode {
// Only return conn when it is in a known good state.
if r.err == nil && !r.subscribeMode && !r.hasMore() {
r.client.putConn(conn)
return nil
}
Expand Down
4 changes: 3 additions & 1 deletion redtest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ func (w *testWriter) Write(p []byte) (int, error) {
}

func StartRedisServer(t testing.TB, args ...string) (string, *redjet.Client) {
socket := filepath.Join(t.TempDir(), "redis.sock")
// Use short-hand r.sock instead of redis.sock since redis has a 104
// character limit on unix socket paths.
socket := filepath.Join(t.TempDir(), "r.sock")
serverCmd := exec.Command(
"redis-server", "--unixsocket", socket, "--loglevel", "debug",
"--port", "0",
Expand Down

0 comments on commit 7cd9c24

Please sign in to comment.