Skip to content

Commit 357e17e

Browse files
committed
Merge branch 'main' into feat/langcache-extension
2 parents b760113 + c72460c commit 357e17e

File tree

7 files changed

+442
-33
lines changed

7 files changed

+442
-33
lines changed

docs/user_guide/index.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ User guides provide helpful resources for using RedisVL and its different compon
1515
01_getting_started
1616
02_hybrid_queries
1717
03_llmcache
18-
10_embeddings_cache
1918
04_vectorizers
2019
05_hash_vs_json
2120
06_rerankers

redisvl/query/aggregate.py

Lines changed: 69 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from typing import Any, Dict, List, Optional, Set, Tuple, Union
22

3-
from pydantic import BaseModel, field_validator
3+
from pydantic import BaseModel, field_validator, model_validator
44
from redis.commands.search.aggregation import AggregateRequest, Desc
5+
from typing_extensions import Self
56

67
from redisvl.query.filter import FilterExpression
78
from redisvl.redis.utils import array_to_buffer
@@ -32,9 +33,16 @@ def validate_dtype(cls, dtype: str) -> str:
3233
raise ValueError(
3334
f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}"
3435
)
35-
3636
return dtype
3737

38+
@model_validator(mode="after")
39+
def validate_vector(self) -> Self:
40+
"""If the vector passed in is an array of float convert it to a byte string."""
41+
if isinstance(self.vector, bytes):
42+
return self
43+
self.vector = array_to_buffer(self.vector, self.dtype)
44+
return self
45+
3846

3947
class AggregationQuery(AggregateRequest):
4048
"""
@@ -94,6 +102,7 @@ def __init__(
94102
return_fields: Optional[List[str]] = None,
95103
stopwords: Optional[Union[str, Set[str]]] = "english",
96104
dialect: int = 2,
105+
text_weights: Optional[Dict[str, float]] = None,
97106
):
98107
"""
99108
Instantiates a HybridQuery object.
@@ -119,6 +128,9 @@ def __init__(
119128
set, or tuple of strings is provided then those will be used as stopwords.
120129
Defaults to "english". if set to "None" then no stopwords will be removed.
121130
dialect (int, optional): The Redis dialect version. Defaults to 2.
131+
text_weights (Optional[Dict[str, float]]): The importance weighting of individual words
132+
within the query text. Defaults to None, in which case no modifications are made to the
133+
text_scorer score.
122134
123135
Raises:
124136
ValueError: If the text string is empty, or if the text string becomes empty after
@@ -138,6 +150,7 @@ def __init__(
138150
self._dtype = dtype
139151
self._num_results = num_results
140152
self._set_stopwords(stopwords)
153+
self._text_weights = self._parse_text_weights(text_weights)
141154

142155
query_string = self._build_query_string()
143156
super().__init__(query_string)
@@ -185,6 +198,7 @@ def _set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
185198
language will be used. if a list, set, or tuple of strings is provided then those
186199
will be used as stopwords. Defaults to "english". if set to "None" then no stopwords
187200
will be removed.
201+
188202
Raises:
189203
TypeError: If the stopwords are not a set, list, or tuple of strings.
190204
"""
@@ -214,6 +228,7 @@ def _tokenize_and_escape_query(self, user_query: str) -> str:
214228
215229
Returns:
216230
str: The tokenized and escaped query string.
231+
217232
Raises:
218233
ValueError: If the text string becomes empty after stopwords are removed.
219234
"""
@@ -225,13 +240,57 @@ def _tokenize_and_escape_query(self, user_query: str) -> str:
225240
)
226241
for token in user_query.split()
227242
]
228-
tokenized = " | ".join(
229-
[token for token in tokens if token and token not in self._stopwords]
230-
)
231243

232-
if not tokenized:
244+
token_list = [
245+
token for token in tokens if token and token not in self._stopwords
246+
]
247+
for i, token in enumerate(token_list):
248+
if token in self._text_weights:
249+
token_list[i] = f"{token}=>{{$weight:{self._text_weights[token]}}}"
250+
251+
if not token_list:
233252
raise ValueError("text string cannot be empty after removing stopwords")
234-
return tokenized
253+
return " | ".join(token_list)
254+
255+
def _parse_text_weights(
256+
self, weights: Optional[Dict[str, float]]
257+
) -> Dict[str, float]:
258+
parsed_weights: Dict[str, float] = {}
259+
if not weights:
260+
return parsed_weights
261+
for word, weight in weights.items():
262+
word = word.strip().lower()
263+
if not word or " " in word:
264+
raise ValueError(
265+
f"Only individual words may be weighted. Got {{ {word}:{weight} }}"
266+
)
267+
if (
268+
not (isinstance(weight, float) or isinstance(weight, int))
269+
or weight < 0.0
270+
):
271+
raise ValueError(
272+
f"Weights must be positive number. Got {{ {word}:{weight} }}"
273+
)
274+
parsed_weights[word] = weight
275+
return parsed_weights
276+
277+
def set_text_weights(self, weights: Dict[str, float]):
278+
"""Set or update the text weights for the query.
279+
280+
Args:
281+
weights: Dictionary of word:weight mappings
282+
"""
283+
self._text_weights = self._parse_text_weights(weights)
284+
self._query = self._build_query_string()
285+
286+
@property
287+
def text_weights(self) -> Dict[str, float]:
288+
"""Get the text weights.
289+
290+
Returns:
291+
Dictionary of word:weight mappings.
292+
"""
293+
return self._text_weights
235294

236295
def _build_query_string(self) -> str:
237296
"""Build the full query string for text search with optional filtering."""
@@ -256,7 +315,7 @@ def __str__(self) -> str:
256315

257316
class MultiVectorQuery(AggregationQuery):
258317
"""
259-
MultiVectorQuery allows for search over multiple vector fields in a document simulateously.
318+
MultiVectorQuery allows for search over multiple vector fields in a document simultaneously.
260319
The final score will be a weighted combination of the individual vector similarity scores
261320
following the formula:
262321
@@ -364,12 +423,8 @@ def params(self) -> Dict[str, Any]:
364423
Dict[str, Any]: The parameters for the aggregation.
365424
"""
366425
params = {}
367-
for i, (vector, dtype) in enumerate(
368-
[(v.vector, v.dtype) for v in self._vectors]
369-
):
370-
if isinstance(vector, list):
371-
vector = array_to_buffer(vector, dtype=dtype) # type: ignore
372-
params[f"vector_{i}"] = vector
426+
for i, v in enumerate(self._vectors):
427+
params[f"vector_{i}"] = v.vector
373428
return params
374429

375430
def _build_query_string(self) -> str:

redisvl/query/query.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,7 @@ def __init__(
10281028
in_order: bool = False,
10291029
params: Optional[Dict[str, Any]] = None,
10301030
stopwords: Optional[Union[str, Set[str]]] = "english",
1031+
text_weights: Optional[Dict[str, float]] = None,
10311032
):
10321033
"""A query for running a full text search, along with an optional filter expression.
10331034
@@ -1064,13 +1065,16 @@ def __init__(
10641065
a default set of stopwords for that language will be used. Users may specify
10651066
their own stop words by providing a List or Set of words. if set to None,
10661067
then no words will be removed. Defaults to 'english'.
1067-
1068+
text_weights (Optional[Dict[str, float]]): The importance weighting of individual words
1069+
within the query text. Defaults to None, in which case no modifications are made to the
1070+
text_scorer score.
10681071
Raises:
10691072
ValueError: if stopwords language string cannot be loaded.
10701073
TypeError: If stopwords is not a valid iterable set of strings.
10711074
"""
10721075
self._text = text
10731076
self._field_weights = self._parse_field_weights(text_field_name)
1077+
self._text_weights = self._parse_text_weights(text_weights)
10741078
self._num_results = num_results
10751079

10761080
self._set_stopwords(stopwords)
@@ -1151,9 +1155,14 @@ def _tokenize_and_escape_query(self, user_query: str) -> str:
11511155
)
11521156
for token in user_query.split()
11531157
]
1154-
return " | ".join(
1155-
[token for token in tokens if token and token not in self._stopwords]
1156-
)
1158+
token_list = [
1159+
token for token in tokens if token and token not in self._stopwords
1160+
]
1161+
for i, token in enumerate(token_list):
1162+
if token in self._text_weights:
1163+
token_list[i] = f"{token}=>{{$weight:{self._text_weights[token]}}}"
1164+
1165+
return " | ".join(token_list)
11571166

11581167
def _parse_field_weights(
11591168
self, field_spec: Union[str, Dict[str, float]]
@@ -1220,6 +1229,46 @@ def text_field_name(self) -> Union[str, Dict[str, float]]:
12201229
return field
12211230
return self._field_weights.copy()
12221231

1232+
def _parse_text_weights(
1233+
self, weights: Optional[Dict[str, float]]
1234+
) -> Dict[str, float]:
1235+
parsed_weights: Dict[str, float] = {}
1236+
if not weights:
1237+
return parsed_weights
1238+
for word, weight in weights.items():
1239+
word = word.strip().lower()
1240+
if not word or " " in word:
1241+
raise ValueError(
1242+
f"Only individual words may be weighted. Got {{ {word}:{weight} }}"
1243+
)
1244+
if (
1245+
not (isinstance(weight, float) or isinstance(weight, int))
1246+
or weight < 0.0
1247+
):
1248+
raise ValueError(
1249+
f"Weights must be positive number. Got {{ {word}:{weight} }}"
1250+
)
1251+
parsed_weights[word] = weight
1252+
return parsed_weights
1253+
1254+
def set_text_weights(self, weights: Dict[str, float]):
1255+
"""Set or update the text weights for the query.
1256+
1257+
Args:
1258+
weights: Dictionary of word:weight mappings
1259+
"""
1260+
self._text_weights = self._parse_text_weights(weights)
1261+
self._built_query_string = None
1262+
1263+
@property
1264+
def text_weights(self) -> Dict[str, float]:
1265+
"""Get the text weights.
1266+
1267+
Returns:
1268+
Dictionary of word:weight mappings.
1269+
"""
1270+
return self._text_weights
1271+
12231272
def _build_query_string(self) -> str:
12241273
"""Build the full query string for text search with optional filtering."""
12251274
filter_expression = self._filter_expression

tests/integration/test_aggregation.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,82 @@ def test_hybrid_query_with_text_filter(index):
317317
assert "research" not in result[text_field].lower()
318318

319319

320+
@pytest.mark.parametrize("scorer", ["BM25", "BM25STD", "TFIDF", "TFIDF.DOCNORM"])
321+
def test_hybrid_query_word_weights(index, scorer):
322+
skip_if_redis_version_below(index.client, "7.2.0")
323+
324+
text = "a medical professional with expertise in lung cancers"
325+
text_field = "description"
326+
vector = [0.1, 0.1, 0.5]
327+
vector_field = "user_embedding"
328+
return_fields = ["description"]
329+
330+
weights = {"medical": 3.4, "cancers": 5}
331+
332+
# test we can run a query with text weights
333+
weighted_query = HybridQuery(
334+
text=text,
335+
text_field_name=text_field,
336+
vector=vector,
337+
vector_field_name=vector_field,
338+
return_fields=return_fields,
339+
text_scorer=scorer,
340+
text_weights=weights,
341+
)
342+
343+
weighted_results = index.query(weighted_query)
344+
assert len(weighted_results) == 7
345+
346+
# test that weights do change the scores on results
347+
unweighted_query = HybridQuery(
348+
text=text,
349+
text_field_name=text_field,
350+
vector=vector,
351+
vector_field_name=vector_field,
352+
return_fields=return_fields,
353+
text_scorer=scorer,
354+
text_weights={},
355+
)
356+
357+
unweighted_results = index.query(unweighted_query)
358+
359+
for weighted, unweighted in zip(weighted_results, unweighted_results):
360+
for word in weights:
361+
if word in weighted["description"] or word in unweighted["description"]:
362+
assert float(weighted["text_score"]) > float(unweighted["text_score"])
363+
364+
# test that weights do change the document score and order of results
365+
weights = {"medical": 5, "cancers": 3.4} # switch the weights
366+
weighted_query = HybridQuery(
367+
text=text,
368+
text_field_name=text_field,
369+
vector=vector,
370+
vector_field_name=vector_field,
371+
return_fields=return_fields,
372+
text_scorer=scorer,
373+
text_weights=weights,
374+
)
375+
376+
weighted_results = index.query(weighted_query)
377+
assert weighted_results != unweighted_results
378+
379+
# test assigning weights on construction is equivalent to setting them on the query object
380+
new_query = HybridQuery(
381+
text=text,
382+
text_field_name=text_field,
383+
vector=vector,
384+
vector_field_name=vector_field,
385+
return_fields=return_fields,
386+
text_scorer=scorer,
387+
text_weights=None,
388+
)
389+
390+
new_query.set_text_weights(weights)
391+
392+
new_weighted_results = index.query(new_query)
393+
assert new_weighted_results == weighted_results
394+
395+
320396
def test_multivector_query(index):
321397
skip_if_redis_version_below(index.client, "7.2.0")
322398

@@ -365,6 +441,31 @@ def test_multivector_query(index):
365441
)
366442

367443

444+
def test_multivector_query_accepts_bytes(index):
445+
skip_if_redis_version_below(index.client, "7.2.0")
446+
447+
vector_bytes = [
448+
array_to_buffer([0.1, 0.1, 0.5], "float32"),
449+
array_to_buffer([0.3, 0.4, 0.7, 0.2, -0.3, 0.25], "float64"),
450+
]
451+
vector_fields = ["user_embedding", "audio_embedding"]
452+
dtypes = ["float32", "float64"]
453+
vectors = []
454+
for vector, field, dtype in zip(vector_bytes, vector_fields, dtypes):
455+
vectors.append(Vector(vector=vector, field_name=field, dtype=dtype))
456+
457+
return_fields = ["user", "credit_score", "age", "job", "location", "description"]
458+
459+
multi_query = MultiVectorQuery(
460+
vectors=vectors,
461+
return_fields=return_fields,
462+
)
463+
464+
results = index.query(multi_query)
465+
assert isinstance(results, list)
466+
assert len(results) == 7
467+
468+
368469
def test_multivector_query_with_filter(index):
369470
skip_if_redis_version_below(index.client, "7.2.0")
370471

0 commit comments

Comments
 (0)