@@ -39,11 +39,15 @@ def _call_api(
3939 texts : List [str ],
4040 dense_model : Dict [str , Any ] = None ,
4141 sparse_model : Optional [Dict [str , Any ]] = None ,
42+ input_type : Optional [str ] = None ,
4243 ) -> List [Dict [str , Any ]]:
4344 """Call VikingDB Embedding API"""
4445 path = "/api/vikingdb/embedding"
4546
4647 data_items = [{"text" : text } for text in texts ]
48+ if input_type is not None :
49+ for item in data_items :
50+ item ["input_type" ] = input_type
4751
4852 req_body = {"data" : data_items }
4953 if dense_model :
@@ -115,17 +119,31 @@ def __init__(
115119 dimension : Optional [int ] = None ,
116120 embedding_type : str = "text" ,
117121 config : Optional [Dict [str , Any ]] = None ,
122+ query_param : Optional [str ] = None ,
123+ document_param : Optional [str ] = None ,
118124 ):
119125 DenseEmbedderBase .__init__ (self , model_name , config )
120126 self ._init_vikingdb_client (ak , sk , region , host )
121127 self .model_version = model_version
122128 self .dimension = dimension
123129 self .embedding_type = embedding_type
124130 self .dense_model = {"name" : model_name , "version" : model_version , "dim" : dimension }
131+ self .query_param = query_param
132+ self .document_param = document_param
133+
134+ def _resolve_input_type (self , is_query : bool ) -> Optional [str ]:
135+ """Return the input_type value for query or document side, or None for symmetric mode."""
136+ if is_query and self .query_param is not None :
137+ return self .query_param
138+ if not is_query and self .document_param is not None :
139+ return self .document_param
140+ return None
125141
126142 def embed (self , text : str , is_query : bool = False ) -> EmbedResult :
143+ input_type = self ._resolve_input_type (is_query )
144+
127145 def _call () -> EmbedResult :
128- results = self ._call_api ([text ], dense_model = self .dense_model )
146+ results = self ._call_api ([text ], dense_model = self .dense_model , input_type = input_type )
129147 if not results :
130148 return EmbedResult (dense_vector = [])
131149
@@ -154,9 +172,10 @@ def _call() -> EmbedResult:
154172 def embed_batch (self , texts : List [str ], is_query : bool = False ) -> List [EmbedResult ]:
155173 if not texts :
156174 return []
175+ input_type = self ._resolve_input_type (is_query )
157176
158177 def _call () -> List [EmbedResult ]:
159- raw_results = self ._call_api (texts , dense_model = self .dense_model )
178+ raw_results = self ._call_api (texts , dense_model = self .dense_model , input_type = input_type )
160179 return [
161180 EmbedResult (
162181 dense_vector = self ._truncate_and_normalize (
@@ -277,6 +296,8 @@ def __init__(
277296 dimension : Optional [int ] = None ,
278297 embedding_type : str = "text" ,
279298 config : Optional [Dict [str , Any ]] = None ,
299+ query_param : Optional [str ] = None ,
300+ document_param : Optional [str ] = None ,
280301 ):
281302 HybridEmbedderBase .__init__ (self , model_name , config )
282303 self ._init_vikingdb_client (ak , sk , region , host )
@@ -288,19 +309,31 @@ def __init__(
288309 "name" : model_name ,
289310 "version" : model_version ,
290311 }
312+ self .query_param = query_param
313+ self .document_param = document_param
314+
315+ def _resolve_input_type (self , is_query : bool ) -> Optional [str ]:
316+ """Return the input_type value for query or document side, or None for symmetric mode."""
317+ if is_query and self .query_param is not None :
318+ return self .query_param
319+ if not is_query and self .document_param is not None :
320+ return self .document_param
321+ return None
291322
292323 def embed (self , text : str , is_query : bool = False ) -> EmbedResult :
324+ input_type = self ._resolve_input_type (is_query )
325+
293326 def _call () -> EmbedResult :
294327 results = self ._call_api (
295- [text ], dense_model = self .dense_model , sparse_model = self .sparse_model
328+ [text ], dense_model = self .dense_model , sparse_model = self .sparse_model ,
329+ input_type = input_type ,
296330 )
297331 if not results :
298332 return EmbedResult (dense_vector = [], sparse_vector = {})
299333
300334 item = results [0 ]
301335 dense_vector = []
302336 sparse_vector = {}
303-
304337 if "dense" in item :
305338 dense_vector = self ._truncate_and_normalize (item ["dense" ], self .dimension )
306339 if "sparse" in item :
@@ -326,10 +359,12 @@ def _call() -> EmbedResult:
326359 def embed_batch (self , texts : List [str ], is_query : bool = False ) -> List [EmbedResult ]:
327360 if not texts :
328361 return []
362+ input_type = self ._resolve_input_type (is_query )
329363
330364 def _call () -> List [EmbedResult ]:
331365 raw_results = self ._call_api (
332- texts , dense_model = self .dense_model , sparse_model = self .sparse_model
366+ texts , dense_model = self .dense_model , sparse_model = self .sparse_model ,
367+ input_type = input_type ,
333368 )
334369 results = []
335370 for item in raw_results :
0 commit comments