|
18 | 18 | Union,
|
19 | 19 | )
|
20 | 20 |
|
| 21 | +from redisvl.query.query import VectorQuery |
21 | 22 | from redisvl.redis.utils import convert_bytes, make_dict
|
22 | 23 | from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper
|
23 | 24 |
|
|
34 | 35 | from redis.commands.search.indexDefinition import IndexDefinition
|
35 | 36 |
|
36 | 37 | from redisvl.exceptions import (
|
| 38 | + QueryValidationError, |
37 | 39 | RedisModuleVersionError,
|
38 | 40 | RedisSearchError,
|
39 | 41 | RedisVLError,
|
|
46 | 48 | BaseVectorQuery,
|
47 | 49 | CountQuery,
|
48 | 50 | FilterQuery,
|
49 |
| - HybridQuery, |
50 | 51 | )
|
51 | 52 | from redisvl.query.filter import FilterExpression
|
52 | 53 | from redisvl.redis.connection import (
|
53 | 54 | RedisConnectionFactory,
|
54 | 55 | convert_index_info_to_schema,
|
55 | 56 | )
|
56 |
| -from redisvl.redis.utils import convert_bytes |
57 | 57 | from redisvl.schema import IndexSchema, StorageType
|
58 |
| -from redisvl.schema.fields import VECTOR_NORM_MAP, VectorDistanceMetric |
| 58 | +from redisvl.schema.fields import ( |
| 59 | + VECTOR_NORM_MAP, |
| 60 | + VectorDistanceMetric, |
| 61 | + VectorIndexAlgorithm, |
| 62 | +) |
59 | 63 | from redisvl.utils.log import get_logger
|
60 | 64 |
|
61 | 65 | logger = get_logger(__name__)
|
@@ -194,6 +198,15 @@ def _storage(self) -> BaseStorage:
|
194 | 198 | index_schema=self.schema
|
195 | 199 | )
|
196 | 200 |
|
| 201 | + def _validate_query(self, query: BaseQuery) -> None: |
| 202 | + """Validate a query.""" |
| 203 | + if isinstance(query, VectorQuery): |
| 204 | + field = self.schema.fields[query._vector_field_name] |
| 205 | + if query.ef_runtime and field.attrs.algorithm != VectorIndexAlgorithm.HNSW: # type: ignore |
| 206 | + raise QueryValidationError( |
| 207 | + "Vector field using 'flat' algorithm does not support EF_RUNTIME query parameter." |
| 208 | + ) |
| 209 | + |
197 | 210 | @property
|
198 | 211 | def name(self) -> str:
|
199 | 212 | """The name of the Redis search index."""
|
@@ -592,6 +605,27 @@ def drop_keys(self, keys: Union[str, List[str]]) -> int:
|
592 | 605 | else:
|
593 | 606 | return self._redis_client.delete(keys) # type: ignore
|
594 | 607 |
|
| 608 | + def drop_documents(self, ids: Union[str, List[str]]) -> int: |
| 609 | + """Remove documents from the index by their document IDs. |
| 610 | +
|
| 611 | + This method converts document IDs to Redis keys automatically by applying |
| 612 | + the index's key prefix and separator configuration. |
| 613 | +
|
| 614 | + Args: |
| 615 | + ids (Union[str, List[str]]): The document ID or IDs to remove from the index. |
| 616 | +
|
| 617 | + Returns: |
| 618 | + int: Count of documents deleted from Redis. |
| 619 | + """ |
| 620 | + if isinstance(ids, list): |
| 621 | + if not ids: |
| 622 | + return 0 |
| 623 | + keys = [self.key(id) for id in ids] |
| 624 | + return self._redis_client.delete(*keys) # type: ignore |
| 625 | + else: |
| 626 | + key = self.key(ids) |
| 627 | + return self._redis_client.delete(key) # type: ignore |
| 628 | + |
595 | 629 | def expire_keys(
|
596 | 630 | self, keys: Union[str, List[str]], ttl: int
|
597 | 631 | ) -> Union[int, List[int]]:
|
@@ -816,6 +850,10 @@ def batch_query(
|
816 | 850 |
|
817 | 851 | def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
|
818 | 852 | """Execute a query and process results."""
|
| 853 | + try: |
| 854 | + self._validate_query(query) |
| 855 | + except QueryValidationError as e: |
| 856 | + raise QueryValidationError(f"Invalid query: {str(e)}") from e |
819 | 857 | results = self.search(query.query, query_params=query.params)
|
820 | 858 | return process_results(results, query=query, schema=self.schema)
|
821 | 859 |
|
@@ -1236,6 +1274,28 @@ async def drop_keys(self, keys: Union[str, List[str]]) -> int:
|
1236 | 1274 | else:
|
1237 | 1275 | return await client.delete(keys)
|
1238 | 1276 |
|
| 1277 | + async def drop_documents(self, ids: Union[str, List[str]]) -> int: |
| 1278 | + """Remove documents from the index by their document IDs. |
| 1279 | +
|
| 1280 | + This method converts document IDs to Redis keys automatically by applying |
| 1281 | + the index's key prefix and separator configuration. |
| 1282 | +
|
| 1283 | + Args: |
| 1284 | + ids (Union[str, List[str]]): The document ID or IDs to remove from the index. |
| 1285 | +
|
| 1286 | + Returns: |
| 1287 | + int: Count of documents deleted from Redis. |
| 1288 | + """ |
| 1289 | + client = await self._get_client() |
| 1290 | + if isinstance(ids, list): |
| 1291 | + if not ids: |
| 1292 | + return 0 |
| 1293 | + keys = [self.key(id) for id in ids] |
| 1294 | + return await client.delete(*keys) |
| 1295 | + else: |
| 1296 | + key = self.key(ids) |
| 1297 | + return await client.delete(key) |
| 1298 | + |
1239 | 1299 | async def expire_keys(
|
1240 | 1300 | self, keys: Union[str, List[str]], ttl: int
|
1241 | 1301 | ) -> Union[int, List[int]]:
|
@@ -1356,9 +1416,10 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]:
|
1356 | 1416 | async def _aggregate(
|
1357 | 1417 | self, aggregation_query: AggregationQuery
|
1358 | 1418 | ) -> List[Dict[str, Any]]:
|
1359 |
| - """Execute an aggretation query and processes the results.""" |
| 1419 | + """Execute an aggregation query and processes the results.""" |
1360 | 1420 | results = await self.aggregate(
|
1361 |
| - aggregation_query, query_params=aggregation_query.params # type: ignore[attr-defined] |
| 1421 | + aggregation_query, |
| 1422 | + query_params=aggregation_query.params, # type: ignore[attr-defined] |
1362 | 1423 | )
|
1363 | 1424 | return process_aggregate_results(
|
1364 | 1425 | results,
|
@@ -1486,6 +1547,10 @@ async def batch_query(
|
1486 | 1547 |
|
1487 | 1548 | async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
|
1488 | 1549 | """Asynchronously execute a query and process results."""
|
| 1550 | + try: |
| 1551 | + self._validate_query(query) |
| 1552 | + except QueryValidationError as e: |
| 1553 | + raise QueryValidationError(f"Invalid query: {str(e)}") from e |
1489 | 1554 | results = await self.search(query.query, query_params=query.params)
|
1490 | 1555 | return process_results(results, query=query, schema=self.schema)
|
1491 | 1556 |
|
|
0 commit comments