From 7cd9c24519cc07ed494dc665cd0f866ef1f71268 Mon Sep 17 00:00:00 2001 From: Ammar Bandukwala Date: Tue, 13 Feb 2024 13:12:13 -0600 Subject: [PATCH] Simplify Close logic to prevent deadlocks --- client_test.go | 34 +++++++++++++++++++++++++--------- pipeline.go | 15 ++------------- redtest/server.go | 4 +++- 3 files changed, 30 insertions(+), 23 deletions(-) diff --git a/client_test.go b/client_test.go index 88b0d0c..f7288f6 100644 --- a/client_test.go +++ b/client_test.go @@ -146,19 +146,35 @@ func TestClient_Race(t *testing.T) { func TestClient_IdleDrain(t *testing.T) { t.Parallel() - _, client := redtest.StartRedisServer(t) + t.Run("Unexpected", func(t *testing.T) { + t.Parallel() + _, client := redtest.StartRedisServer(t) - require.Equal(t, 0, client.PoolStats().FreeConns) + require.Equal(t, 0, client.PoolStats().FreeConns) - err := client.Command(context.Background(), "SET", "foo", "bar").Close() - require.NoError(t, err) + // Close comes before reading result. + err := client.Command(context.Background(), "SET", "foo", "bar").Close() + require.NoError(t, err) + + // Connection not returned to pool. + require.Equal(t, 0, client.PoolStats().FreeConns) + }) - require.Equal(t, 1, client.PoolStats().FreeConns) + t.Run("Regular", func(t *testing.T) { + t.Parallel() + _, client := redtest.StartRedisServer(t) - // After the idle timeout, the connection should be drained. - require.Eventually(t, func() bool { - return client.PoolStats().FreeConns == 0 - }, time.Second, 10*time.Millisecond) + require.Equal(t, 0, client.PoolStats().FreeConns) + err := client.Command(context.Background(), "SET", "foo", "bar").Ok() + require.NoError(t, err) + + require.Equal(t, 1, client.PoolStats().FreeConns) + + // After the idle timeout, the connection should be drained. + require.Eventually(t, func() bool { + return client.PoolStats().FreeConns == 0 + }, time.Second, 10*time.Millisecond) + }) } func TestClient_ShortRead(t *testing.T) { diff --git a/pipeline.go b/pipeline.go index 0a88b47..3feb1c7 100644 --- a/pipeline.go +++ b/pipeline.go @@ -483,18 +483,6 @@ func (r *Pipeline) Close() error { } func (r *Pipeline) close() error { - for r.hasMore() { - // Read the result into discard so that the connection can be reused. - _, _, err := r.writeTo(io.Discard) - if errors.Is(err, errClosed) { - // Should be impossible to close a result without draining - // it, in which case at == end and we would never get here. - return fmt.Errorf("SEVERE: result closed while iterating") - } else if err != nil { - return fmt.Errorf("drain: %w", err) - } - } - if !atomic.CompareAndSwapInt64(&r.closed, 0, 1) { // double-close return nil @@ -507,7 +495,8 @@ func (r *Pipeline) close() error { conn := r.conn // r.conn is set to nil to prevent accidental reuse. r.conn = nil - if r.err == nil && !r.subscribeMode { + // Only return conn when it is in a known good state. + if r.err == nil && !r.subscribeMode && !r.hasMore() { r.client.putConn(conn) return nil } diff --git a/redtest/server.go b/redtest/server.go index 5c5a48d..a1fe71e 100644 --- a/redtest/server.go +++ b/redtest/server.go @@ -31,7 +31,9 @@ func (w *testWriter) Write(p []byte) (int, error) { } func StartRedisServer(t testing.TB, args ...string) (string, *redjet.Client) { - socket := filepath.Join(t.TempDir(), "redis.sock") + // Use short-hand r.sock instead of redis.sock since redis has a 104 + // character limit on unix socket paths. + socket := filepath.Join(t.TempDir(), "r.sock") serverCmd := exec.Command( "redis-server", "--unixsocket", socket, "--loglevel", "debug", "--port", "0",