Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion redisvl/index/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@

from redis import __version__ as redis_version
from redis.client import NEVER_DECODE
from redis.commands.helpers import get_protocol_version # type: ignore

from redisvl.utils.redis_protocol import get_protocol_version

# Redis 5.x compatibility (6 fixed the import path)
if redis_version.startswith("5"):
Expand Down
3 changes: 2 additions & 1 deletion redisvl/redis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from redis import __version__ as redis_version
from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster
from redis.client import NEVER_DECODE, Pipeline
from redis.commands.helpers import get_protocol_version
from redis.commands.search import AsyncSearch, Search
from redis.commands.search.commands import (
CREATE_CMD,
Expand All @@ -23,6 +22,8 @@
)
from redis.commands.search.field import Field

from redisvl.utils.redis_protocol import get_protocol_version

# Redis 5.x compatibility (6 fixed the import path)
if redis_version.startswith("5"):
from redis.commands.search.indexDefinition import ( # type: ignore[import-untyped]
Expand Down
95 changes: 95 additions & 0 deletions redisvl/utils/redis_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
Utilities for handling Redis protocol version detection safely across different client types.

This module provides safe wrappers around redis-py's get_protocol_version function
to handle edge cases with Redis Cluster pipelines.
"""

from typing import Union

from redis.asyncio.cluster import ClusterPipeline as AsyncClusterPipeline
from redis.cluster import ClusterPipeline
from redis.commands.helpers import get_protocol_version as redis_get_protocol_version

from redisvl.utils.log import get_logger

logger = get_logger(__name__)


def get_protocol_version(client) -> str:
"""
Wrapper for redis-py's get_protocol_version that handles ClusterPipeline.

ClusterPipeline doesn't have nodes_manager attribute, so we need to
handle this case specially to avoid AttributeError.

Args:
client: Redis client, pipeline, or cluster pipeline object

Returns:
str: Protocol version ("2" or "3")

Note:
This function addresses issue #365 where get_protocol_version() fails
with ClusterPipeline objects due to missing nodes_manager attribute.
"""
# Handle sync ClusterPipeline
if isinstance(client, ClusterPipeline):
try:
# Try to get protocol from the underlying cluster client
if hasattr(client, "_redis_cluster") and client._redis_cluster:
try:
result = redis_get_protocol_version(client._redis_cluster)
if result is not None:
return result
except (AttributeError, Exception):
# If anything fails, fall back to default
pass

logger.debug(
"ClusterPipeline without valid _redis_cluster, defaulting to protocol 3"
)
return "3"
except AttributeError as e:
logger.debug(
f"Failed to get protocol version from ClusterPipeline: {e}, defaulting to protocol 3"
)
return "3"

# Handle async ClusterPipeline
if isinstance(client, AsyncClusterPipeline):
try:
# Try to get protocol from the underlying cluster client
if hasattr(client, "_redis_cluster") and client._redis_cluster:
try:
result = redis_get_protocol_version(client._redis_cluster)
if result is not None:
return result
except (AttributeError, Exception):
# If anything fails, fall back to default
pass

logger.debug(
"AsyncClusterPipeline without valid _redis_cluster, defaulting to protocol 3"
)
return "3"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bsbodden could you help me unpack this fix a bit?

So we were failing before because the objects was trying to use an attribute nodes_manager that didn't exist. The fix implemented here checks the protocol version and defaults to 3 in the case of not knowing.

Was the fail happening because it was previously defaulting to "2" when in cluster mode vs with other connections types?

I haven't done a ton with protocols in redis so trying to understand some of those key differences.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well shit, it should be 2, not 3 - and I was just able to get a real integration test working with my cluster setup and this won't cut it. Good catch.

except AttributeError as e:
logger.debug(
f"Failed to get protocol version from AsyncClusterPipeline: {e}, defaulting to protocol 3"
)
return "3"

# For all other client types, use the standard function
try:
result = redis_get_protocol_version(client)
if result is None:
logger.warning(
f"get_protocol_version returned None for client {type(client)}, defaulting to protocol 3"
)
return "3"
return result
except AttributeError as e:
logger.warning(
f"Failed to get protocol version from client {type(client)}: {e}, defaulting to protocol 3"
)
return "3"
142 changes: 142 additions & 0 deletions tests/integration/test_cluster_pipelining.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""
Integration test for issue #365: ClusterPipeline AttributeError fix
https://github.com/redis/redis-vl-python/issues/365

This test verifies that the safe_get_protocol_version fix prevents the
AttributeError: 'ClusterPipeline' object has no attribute 'nodes_manager'
"""

from unittest.mock import Mock

import pytest
from redis.asyncio.cluster import ClusterPipeline as AsyncClusterPipeline
from redis.cluster import ClusterPipeline

from redisvl.index import SearchIndex
from redisvl.query import FilterQuery
from redisvl.schema import IndexSchema
from redisvl.utils.redis_protocol import get_protocol_version


def test_pipeline_operations_no_nodes_manager_error(redis_url):
"""
Test that pipeline operations don't fail with nodes_manager AttributeError.

Before the fix, operations that use get_protocol_version() internally would fail
with AttributeError when using ClusterPipeline. This test ensures those operations
now work without that specific error.
"""
# Create a simple schema
schema_dict = {
"index": {"name": "test-365-fix", "prefix": "doc", "storage_type": "hash"},
"fields": [{"name": "id", "type": "tag"}, {"name": "text", "type": "text"}],
}

schema = IndexSchema.from_dict(schema_dict)
index = SearchIndex(schema, redis_url=redis_url)

# Create the index
index.create(overwrite=True)

try:
# Test 1: Load with batching (uses pipelines internally)
test_data = [{"id": f"item{i}", "text": f"Document {i}"} for i in range(10)]

# This would fail with AttributeError before the fix
keys = index.load(
data=test_data,
id_field="id",
batch_size=3, # Force multiple pipeline operations
)

assert len(keys) == 10

# Test 2: Batch search (uses safe_get_protocol_version internally)
queries = [FilterQuery(filter_expression=f"@id:{{item{i}}}") for i in range(3)]

try:
# The critical test: no AttributeError about nodes_manager
results = index.batch_search(queries, batch_size=2)
assert len(results) == 3
except Exception as e:
# If there's an error, it must NOT be the nodes_manager AttributeError
assert "nodes_manager" not in str(
e
), f"Got nodes_manager error which indicates fix isn't working: {e}"

# Test 3: TTL operations
try:
index.expire_keys(keys[:3], 3600)
except Exception as e:
# Again, ensure no nodes_manager error
assert "nodes_manager" not in str(e)

finally:
index.delete()


def test_json_storage_no_error(redis_url):
"""Test with JSON storage type."""
schema_dict = {
"index": {"name": "test-365-json", "prefix": "json", "storage_type": "json"},
"fields": [{"name": "id", "type": "tag"}, {"name": "data", "type": "text"}],
}

schema = IndexSchema.from_dict(schema_dict)
index = SearchIndex(schema, redis_url=redis_url)

index.create(overwrite=True)

try:
# Load test data
test_data = [{"id": f"doc{i}", "data": f"Document {i}"} for i in range(5)]

# Should work without nodes_manager AttributeError
keys = index.load(data=test_data, id_field="id", batch_size=2)

assert len(keys) == 5

finally:
index.delete()


def test_clusterpipeline_with_valid_redis_cluster_attribute():
"""
Test get_protocol_version when ClusterPipeline has _redis_cluster attribute.
"""
# Create mock ClusterPipeline with _redis_cluster attribute
mock_pipeline = Mock(spec=ClusterPipeline)
mock_cluster = Mock()
mock_cluster.nodes_manager.connection_kwargs.get.return_value = "3"
mock_pipeline._redis_cluster = mock_cluster

# Should successfully get protocol from _redis_cluster
result = get_protocol_version(mock_pipeline)
assert result == "3"


def test_clusterpipeline_with_none_redis_cluster():
"""
Test get_protocol_version when _redis_cluster is None.
"""
mock_pipeline = Mock(spec=ClusterPipeline)
mock_pipeline._redis_cluster = None

# Should fallback to "3"
result = get_protocol_version(mock_pipeline)
assert result == "3"


def test_async_clusterpipeline_without_nodes_manager():
"""
Test get_protocol_version with AsyncClusterPipeline missing nodes_manager.
"""
mock_pipeline = Mock(spec=AsyncClusterPipeline)
# Ensure no nodes_manager attribute
if hasattr(mock_pipeline, "nodes_manager"):
delattr(mock_pipeline, "nodes_manager")
mock_pipeline._redis_cluster = None

# Should fallback to "3" without error
result = get_protocol_version(mock_pipeline)
assert result == "3"
Loading
Loading