Skip to content

Commit 537fc9c

Browse files
committed
fixture tweak, add util fn, fix test
1 parent fee9ac2 commit 537fc9c

File tree

5 files changed

+43
-19
lines changed

5 files changed

+43
-19
lines changed

redisvl/extensions/router/semantic.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Any, ClassVar, Dict, List, Optional, Type, Union
2+
from typing import Any, Dict, List, Optional, Type, Union
33

44
import redis.commands.search.reducers as reducers
55
import yaml
@@ -23,7 +23,7 @@
2323
from redisvl.redis.connection import RedisConnectionFactory
2424
from redisvl.redis.utils import convert_bytes, hashify, make_dict
2525
from redisvl.utils.log import get_logger
26-
from redisvl.utils.utils import deprecated_argument, model_to_dict
26+
from redisvl.utils.utils import deprecated_argument, model_to_dict, scan_by_pattern
2727
from redisvl.utils.vectorize.base import BaseVectorizer
2828
from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer
2929

@@ -43,7 +43,6 @@ class SemanticRouter(BaseModel):
4343
"""Configuration for routing behavior."""
4444

4545
_index: SearchIndex = PrivateAttr()
46-
_persist_config: bool = PrivateAttr()
4746

4847
model_config = ConfigDict(arbitrary_types_allowed=True)
4948

@@ -170,7 +169,7 @@ def _initialize_index(
170169

171170
if not existed or overwrite:
172171
# write the routes to Redis
173-
self.add_routes(self.routes)
172+
self._add_routes(self.routes)
174173

175174
@property
176175
def route_names(self) -> List[str]:
@@ -213,7 +212,7 @@ def _route_ref_key(index: SearchIndex, route_name: str, reference_hash: str) ->
213212
"""Generate the route reference key."""
214213
return f"{index.prefix}:{route_name}:{reference_hash}"
215214

216-
def add_routes(self, routes: List[Route]):
215+
def _add_routes(self, routes: List[Route]):
217216
"""Add routes to the router and index.
218217
219218
Args:
@@ -719,8 +718,8 @@ def get_route_references(
719718
queries = self._make_filter_queries(reference_ids)
720719
elif route_name:
721720
if not keys:
722-
_, keys = self._index.client.scan( # type: ignore
723-
cursor=0, match=f"{self._index.prefix}:{route_name}:*"
721+
keys = scan_by_pattern(
722+
self._index.client, f"{self._index.prefix}:{route_name}:*" # type: ignore
724723
)
725724

726725
queries = self._make_filter_queries(
@@ -757,10 +756,9 @@ def delete_route_references(
757756
res = self._index.batch_query(queries)
758757
keys = [r[0]["id"] for r in res if len(r) > 0]
759758
elif not keys:
760-
_, keys = self._index.client.scan( # type: ignore
761-
cursor=0, match=f"{self._index.prefix}:{route_name}:*"
759+
keys = scan_by_pattern(
760+
self._index.client, f"{self._index.prefix}:{route_name}:*" # type: ignore
762761
)
763-
keys = convert_bytes(keys)
764762

765763
if not keys:
766764
raise ValueError(f"No references found for route {route_name}")

redisvl/utils/utils.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
from enum import Enum
88
from functools import wraps
99
from time import time
10-
from typing import Any, Callable, Coroutine, Dict, Optional
10+
from typing import Any, Callable, Coroutine, Dict, Optional, Sequence
1111
from warnings import warn
1212

1313
from pydantic import BaseModel
14+
from redis import Redis
1415
from ulid import ULID
1516

1617

@@ -213,3 +214,22 @@ def norm_l2_distance(value: float) -> float:
213214
Normalize the L2 distance.
214215
"""
215216
return 1 / (1 + value)
217+
218+
219+
def scan_by_pattern(
220+
redis_client: Redis,
221+
pattern: str,
222+
) -> Sequence[str]:
223+
"""
224+
Scan the Redis database for keys matching a specific pattern.
225+
226+
Args:
227+
redis (Redis): The Redis client instance.
228+
pattern (str): The pattern to match keys against.
229+
230+
Returns:
231+
List[str]: A dictionary containing the keys and their values.
232+
"""
233+
from redisvl.redis.utils import convert_bytes
234+
235+
return convert_bytes(list(redis_client.scan_iter(match=pattern)))

schemas/semantic_router.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: test-router-01JSEWS2CA00GT0HMABBGVEKRR
1+
name: test-router-01JSHERM5V4G94GN3W68XB45PK
22
routes:
33
- name: greeting
44
references:

tests/integration/test_query.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
Timestamp,
2323
)
2424
from redisvl.redis.utils import array_to_buffer
25+
from redisvl.utils.utils import create_ulid
2526

2627
# TODO expand to multiple schema types and sync + async
2728

@@ -145,11 +146,12 @@ def sorted_range_query():
145146
@pytest.fixture
146147
def index(sample_data, redis_url):
147148
# construct a search index from the schema
149+
idx = f"user_index_{create_ulid()}"
148150
index = SearchIndex.from_dict(
149151
{
150152
"index": {
151-
"name": "user_index",
152-
"prefix": "v1",
153+
"name": idx,
154+
"prefix": idx,
153155
"storage_type": "hash",
154156
},
155157
"fields": [
@@ -190,17 +192,20 @@ def hash_preprocess(item: dict) -> dict:
190192
yield index
191193

192194
# clean up
193-
index.delete(drop=True)
195+
index.clear()
196+
index.delete()
194197

195198

196199
@pytest.fixture
197200
def L2_index(sample_data, redis_url):
198201
# construct a search index from the schema
202+
idx = f"L2_index_{create_ulid()}"
203+
199204
index = SearchIndex.from_dict(
200205
{
201206
"index": {
202-
"name": "L2_index",
203-
"prefix": "L2_index",
207+
"name": idx,
208+
"prefix": idx,
204209
"storage_type": "hash",
205210
},
206211
"fields": [
@@ -240,7 +245,8 @@ def hash_preprocess(item: dict) -> dict:
240245
yield index
241246

242247
# clean up
243-
index.delete(drop=True)
248+
index.clear()
249+
index.delete()
244250

245251

246252
def test_search_and_query(index):

tests/integration/test_semantic_router.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -520,4 +520,4 @@ def test_delete_route_references(semantic_router):
520520
assert deleted == 2
521521

522522
router_dict = semantic_router.to_dict()
523-
assert len(router_dict["references"]) == 0
523+
assert len(router_dict["routes"][0]["references"]) == 0

0 commit comments

Comments
 (0)