11from typing import Any , Dict , List , Optional , Set , Tuple , Union
22
3- from pydantic import BaseModel , field_validator
3+ from pydantic import BaseModel , field_validator , model_validator
44from redis .commands .search .aggregation import AggregateRequest , Desc
5+ from typing_extensions import Self
56
67from redisvl .query .filter import FilterExpression
78from redisvl .redis .utils import array_to_buffer
@@ -32,9 +33,16 @@ def validate_dtype(cls, dtype: str) -> str:
3233 raise ValueError (
3334 f"Invalid data type: { dtype } . Supported types are: { [t .lower () for t in VectorDataType ]} "
3435 )
35-
3636 return dtype
3737
38+ @model_validator (mode = "after" )
39+ def validate_vector (self ) -> Self :
40+ """If the vector passed in is an array of float convert it to a byte string."""
41+ if isinstance (self .vector , bytes ):
42+ return self
43+ self .vector = array_to_buffer (self .vector , self .dtype )
44+ return self
45+
3846
3947class AggregationQuery (AggregateRequest ):
4048 """
@@ -94,6 +102,7 @@ def __init__(
94102 return_fields : Optional [List [str ]] = None ,
95103 stopwords : Optional [Union [str , Set [str ]]] = "english" ,
96104 dialect : int = 2 ,
105+ text_weights : Optional [Dict [str , float ]] = None ,
97106 ):
98107 """
99108 Instantiates a HybridQuery object.
@@ -119,6 +128,9 @@ def __init__(
119128 set, or tuple of strings is provided then those will be used as stopwords.
120129 Defaults to "english". if set to "None" then no stopwords will be removed.
121130 dialect (int, optional): The Redis dialect version. Defaults to 2.
131+ text_weights (Optional[Dict[str, float]]): The importance weighting of individual words
132+ within the query text. Defaults to None, as no modifications will be made to the
133+ text_scorer score.
122134
123135 Raises:
124136 ValueError: If the text string is empty, or if the text string becomes empty after
@@ -138,6 +150,7 @@ def __init__(
138150 self ._dtype = dtype
139151 self ._num_results = num_results
140152 self ._set_stopwords (stopwords )
153+ self ._text_weights = self ._parse_text_weights (text_weights )
141154
142155 query_string = self ._build_query_string ()
143156 super ().__init__ (query_string )
@@ -185,6 +198,7 @@ def _set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
185198 language will be used. if a list, set, or tuple of strings is provided then those
186199 will be used as stopwords. Defaults to "english". if set to "None" then no stopwords
187200 will be removed.
201+
188202 Raises:
189203 TypeError: If the stopwords are not a set, list, or tuple of strings.
190204 """
@@ -214,6 +228,7 @@ def _tokenize_and_escape_query(self, user_query: str) -> str:
214228
215229 Returns:
216230 str: The tokenized and escaped query string.
231+
217232 Raises:
218233 ValueError: If the text string becomes empty after stopwords are removed.
219234 """
@@ -225,13 +240,57 @@ def _tokenize_and_escape_query(self, user_query: str) -> str:
225240 )
226241 for token in user_query .split ()
227242 ]
228- tokenized = " | " .join (
229- [token for token in tokens if token and token not in self ._stopwords ]
230- )
231243
232- if not tokenized :
244+ token_list = [
245+ token for token in tokens if token and token not in self ._stopwords
246+ ]
247+ for i , token in enumerate (token_list ):
248+ if token in self ._text_weights :
249+ token_list [i ] = f"{ token } =>{{$weight:{ self ._text_weights [token ]} }}"
250+
251+ if not token_list :
233252 raise ValueError ("text string cannot be empty after removing stopwords" )
234- return tokenized
253+ return " | " .join (token_list )
254+
255+ def _parse_text_weights (
256+ self , weights : Optional [Dict [str , float ]]
257+ ) -> Dict [str , float ]:
258+ parsed_weights : Dict [str , float ] = {}
259+ if not weights :
260+ return parsed_weights
261+ for word , weight in weights .items ():
262+ word = word .strip ().lower ()
263+ if not word or " " in word :
264+ raise ValueError (
265+ f"Only individual words may be weighted. Got {{ { word } :{ weight } }}"
266+ )
267+ if (
268+ not (isinstance (weight , float ) or isinstance (weight , int ))
269+ or weight < 0.0
270+ ):
271+ raise ValueError (
272+ f"Weights must be positive number. Got {{ { word } :{ weight } }}"
273+ )
274+ parsed_weights [word ] = weight
275+ return parsed_weights
276+
277+ def set_text_weights (self , weights : Dict [str , float ]):
278+ """Set or update the text weights for the query.
279+
280+ Args:
281+ text_weights: Dictionary of word:weight mappings
282+ """
283+ self ._text_weights = self ._parse_text_weights (weights )
284+ self ._query = self ._build_query_string ()
285+
286+ @property
287+ def text_weights (self ) -> Dict [str , float ]:
288+ """Get the text weights.
289+
290+ Returns:
291+ Dictionary of word:weight mappings.
292+ """
293+ return self ._text_weights
235294
236295 def _build_query_string (self ) -> str :
237296 """Build the full query string for text search with optional filtering."""
@@ -256,7 +315,7 @@ def __str__(self) -> str:
256315
257316class MultiVectorQuery (AggregationQuery ):
258317 """
259- MultiVectorQuery allows for search over multiple vector fields in a document simulateously .
318+ MultiVectorQuery allows for search over multiple vector fields in a document simultaneously .
260319 The final score will be a weighted combination of the individual vector similarity scores
261320 following the formula:
262321
@@ -364,12 +423,8 @@ def params(self) -> Dict[str, Any]:
364423 Dict[str, Any]: The parameters for the aggregation.
365424 """
366425 params = {}
367- for i , (vector , dtype ) in enumerate (
368- [(v .vector , v .dtype ) for v in self ._vectors ]
369- ):
370- if isinstance (vector , list ):
371- vector = array_to_buffer (vector , dtype = dtype ) # type: ignore
372- params [f"vector_{ i } " ] = vector
426+ for i , v in enumerate (self ._vectors ):
427+ params [f"vector_{ i } " ] = v .vector
373428 return params
374429
375430 def _build_query_string (self ) -> str :
0 commit comments