Skip to content

Commit 9f22a9a

Browse files
Validate EF_RUNTIME query parameter (#319)
The EF_RUNTIME parameter is only valid in queries that target vector fields using the HNSW algorithm. If you run the same query against a 'flat' vector field, Redis will return an error. Because we have all the information to prevent you from running the bad query, we can introduce validation that raises a helpful error instead of relying on Redis error handling. --------- Co-authored-by: Tyler Hutcherson <[email protected]>
1 parent 61334ff commit 9f22a9a

File tree

9 files changed

+773
-52
lines changed

9 files changed

+773
-52
lines changed

.cursor/rules/redisvl.mdc

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
---
2+
description:
3+
globs:
4+
alwaysApply: true
5+
---
6+
7+
# Rules for working on RedisVL
8+
- Do not change this line of code unless explicitly asked. It's already correct:
9+
```
10+
token.strip().strip(",").replace("“", "").replace("”", "").lower()
11+
```

redisvl/exceptions.py

+6
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,9 @@ def __init__(self, message, index=None):
3030
if index is not None:
3131
message = f"Validation failed for object at index {index}: {message}"
3232
super().__init__(message)
33+
34+
35+
class QueryValidationError(RedisVLError):
36+
"""Error when validating a query."""
37+
38+
pass

redisvl/index/index.py

+70-5
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Union,
1919
)
2020

21+
from redisvl.query.query import VectorQuery
2122
from redisvl.redis.utils import convert_bytes, make_dict
2223
from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper
2324

@@ -34,6 +35,7 @@
3435
from redis.commands.search.indexDefinition import IndexDefinition
3536

3637
from redisvl.exceptions import (
38+
QueryValidationError,
3739
RedisModuleVersionError,
3840
RedisSearchError,
3941
RedisVLError,
@@ -46,16 +48,18 @@
4648
BaseVectorQuery,
4749
CountQuery,
4850
FilterQuery,
49-
HybridQuery,
5051
)
5152
from redisvl.query.filter import FilterExpression
5253
from redisvl.redis.connection import (
5354
RedisConnectionFactory,
5455
convert_index_info_to_schema,
5556
)
56-
from redisvl.redis.utils import convert_bytes
5757
from redisvl.schema import IndexSchema, StorageType
58-
from redisvl.schema.fields import VECTOR_NORM_MAP, VectorDistanceMetric
58+
from redisvl.schema.fields import (
59+
VECTOR_NORM_MAP,
60+
VectorDistanceMetric,
61+
VectorIndexAlgorithm,
62+
)
5963
from redisvl.utils.log import get_logger
6064

6165
logger = get_logger(__name__)
@@ -194,6 +198,15 @@ def _storage(self) -> BaseStorage:
194198
index_schema=self.schema
195199
)
196200

201+
def _validate_query(self, query: BaseQuery) -> None:
202+
"""Validate a query."""
203+
if isinstance(query, VectorQuery):
204+
field = self.schema.fields[query._vector_field_name]
205+
if query.ef_runtime and field.attrs.algorithm != VectorIndexAlgorithm.HNSW: # type: ignore
206+
raise QueryValidationError(
207+
"Vector field using 'flat' algorithm does not support EF_RUNTIME query parameter."
208+
)
209+
197210
@property
198211
def name(self) -> str:
199212
"""The name of the Redis search index."""
@@ -592,6 +605,27 @@ def drop_keys(self, keys: Union[str, List[str]]) -> int:
592605
else:
593606
return self._redis_client.delete(keys) # type: ignore
594607

608+
def drop_documents(self, ids: Union[str, List[str]]) -> int:
609+
"""Remove documents from the index by their document IDs.
610+
611+
This method converts document IDs to Redis keys automatically by applying
612+
the index's key prefix and separator configuration.
613+
614+
Args:
615+
ids (Union[str, List[str]]): The document ID or IDs to remove from the index.
616+
617+
Returns:
618+
int: Count of documents deleted from Redis.
619+
"""
620+
if isinstance(ids, list):
621+
if not ids:
622+
return 0
623+
keys = [self.key(id) for id in ids]
624+
return self._redis_client.delete(*keys) # type: ignore
625+
else:
626+
key = self.key(ids)
627+
return self._redis_client.delete(key) # type: ignore
628+
595629
def expire_keys(
596630
self, keys: Union[str, List[str]], ttl: int
597631
) -> Union[int, List[int]]:
@@ -816,6 +850,10 @@ def batch_query(
816850

817851
def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
818852
"""Execute a query and process results."""
853+
try:
854+
self._validate_query(query)
855+
except QueryValidationError as e:
856+
raise QueryValidationError(f"Invalid query: {str(e)}") from e
819857
results = self.search(query.query, query_params=query.params)
820858
return process_results(results, query=query, schema=self.schema)
821859

@@ -1236,6 +1274,28 @@ async def drop_keys(self, keys: Union[str, List[str]]) -> int:
12361274
else:
12371275
return await client.delete(keys)
12381276

1277+
async def drop_documents(self, ids: Union[str, List[str]]) -> int:
1278+
"""Remove documents from the index by their document IDs.
1279+
1280+
This method converts document IDs to Redis keys automatically by applying
1281+
the index's key prefix and separator configuration.
1282+
1283+
Args:
1284+
ids (Union[str, List[str]]): The document ID or IDs to remove from the index.
1285+
1286+
Returns:
1287+
int: Count of documents deleted from Redis.
1288+
"""
1289+
client = await self._get_client()
1290+
if isinstance(ids, list):
1291+
if not ids:
1292+
return 0
1293+
keys = [self.key(id) for id in ids]
1294+
return await client.delete(*keys)
1295+
else:
1296+
key = self.key(ids)
1297+
return await client.delete(key)
1298+
12391299
async def expire_keys(
12401300
self, keys: Union[str, List[str]], ttl: int
12411301
) -> Union[int, List[int]]:
@@ -1356,9 +1416,10 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]:
13561416
async def _aggregate(
13571417
self, aggregation_query: AggregationQuery
13581418
) -> List[Dict[str, Any]]:
1359-
"""Execute an aggretation query and processes the results."""
1419+
"""Execute an aggregation query and processes the results."""
13601420
results = await self.aggregate(
1361-
aggregation_query, query_params=aggregation_query.params # type: ignore[attr-defined]
1421+
aggregation_query,
1422+
query_params=aggregation_query.params, # type: ignore[attr-defined]
13621423
)
13631424
return process_aggregate_results(
13641425
results,
@@ -1486,6 +1547,10 @@ async def batch_query(
14861547

14871548
async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
14881549
"""Asynchronously execute a query and process results."""
1550+
try:
1551+
self._validate_query(query)
1552+
except QueryValidationError as e:
1553+
raise QueryValidationError(f"Invalid query: {str(e)}") from e
14891554
results = await self.search(query.query, query_params=query.params)
14901555
return process_results(results, query=query, schema=self.schema)
14911556

0 commit comments

Comments
 (0)