Skip to content

Commit a958f00

Browse files
mvanhornclaude
andcommitted
feat(embedding): surface non-symmetric embedding config for VikingDB provider
VikingDB embedders accepted is_query but ignored it. Now VikingDBDenseEmbedder and VikingDBHybridEmbedder accept query_param/document_param and pass input_type to the API when non-symmetric mode is configured. - Add query_param/document_param to VikingDB Dense and Hybrid constructors - Add _resolve_input_type() to select query vs document param - Pass input_type in _call_api data items when set - Wire factory entries to pass config params through - Sparse embedder unchanged (sparse models are symmetric) Closes #655 Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
1 parent 7f05828 commit a958f00

3 files changed

Lines changed: 104 additions & 5 deletions

File tree

openviking/models/embedder/vikingdb_embedders.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,15 @@ def _call_api(
3939
texts: List[str],
4040
dense_model: Dict[str, Any] = None,
4141
sparse_model: Optional[Dict[str, Any]] = None,
42+
input_type: Optional[str] = None,
4243
) -> List[Dict[str, Any]]:
4344
"""Call VikingDB Embedding API"""
4445
path = "/api/vikingdb/embedding"
4546

4647
data_items = [{"text": text} for text in texts]
48+
if input_type is not None:
49+
for item in data_items:
50+
item["input_type"] = input_type
4751

4852
req_body = {"data": data_items}
4953
if dense_model:
@@ -115,17 +119,31 @@ def __init__(
115119
dimension: Optional[int] = None,
116120
embedding_type: str = "text",
117121
config: Optional[Dict[str, Any]] = None,
122+
query_param: Optional[str] = None,
123+
document_param: Optional[str] = None,
118124
):
119125
DenseEmbedderBase.__init__(self, model_name, config)
120126
self._init_vikingdb_client(ak, sk, region, host)
121127
self.model_version = model_version
122128
self.dimension = dimension
123129
self.embedding_type = embedding_type
124130
self.dense_model = {"name": model_name, "version": model_version, "dim": dimension}
131+
self.query_param = query_param
132+
self.document_param = document_param
133+
134+
def _resolve_input_type(self, is_query: bool) -> Optional[str]:
135+
"""Return the input_type value for query or document side, or None for symmetric mode."""
136+
if is_query and self.query_param is not None:
137+
return self.query_param
138+
if not is_query and self.document_param is not None:
139+
return self.document_param
140+
return None
125141

126142
def embed(self, text: str, is_query: bool = False) -> EmbedResult:
143+
input_type = self._resolve_input_type(is_query)
144+
127145
def _call() -> EmbedResult:
128-
results = self._call_api([text], dense_model=self.dense_model)
146+
results = self._call_api([text], dense_model=self.dense_model, input_type=input_type)
129147
if not results:
130148
return EmbedResult(dense_vector=[])
131149

@@ -154,9 +172,10 @@ def _call() -> EmbedResult:
154172
def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
155173
if not texts:
156174
return []
175+
input_type = self._resolve_input_type(is_query)
157176

158177
def _call() -> List[EmbedResult]:
159-
raw_results = self._call_api(texts, dense_model=self.dense_model)
178+
raw_results = self._call_api(texts, dense_model=self.dense_model, input_type=input_type)
160179
return [
161180
EmbedResult(
162181
dense_vector=self._truncate_and_normalize(
@@ -277,6 +296,8 @@ def __init__(
277296
dimension: Optional[int] = None,
278297
embedding_type: str = "text",
279298
config: Optional[Dict[str, Any]] = None,
299+
query_param: Optional[str] = None,
300+
document_param: Optional[str] = None,
280301
):
281302
HybridEmbedderBase.__init__(self, model_name, config)
282303
self._init_vikingdb_client(ak, sk, region, host)
@@ -288,19 +309,31 @@ def __init__(
288309
"name": model_name,
289310
"version": model_version,
290311
}
312+
self.query_param = query_param
313+
self.document_param = document_param
314+
315+
def _resolve_input_type(self, is_query: bool) -> Optional[str]:
316+
"""Return the input_type value for query or document side, or None for symmetric mode."""
317+
if is_query and self.query_param is not None:
318+
return self.query_param
319+
if not is_query and self.document_param is not None:
320+
return self.document_param
321+
return None
291322

292323
def embed(self, text: str, is_query: bool = False) -> EmbedResult:
324+
input_type = self._resolve_input_type(is_query)
325+
293326
def _call() -> EmbedResult:
294327
results = self._call_api(
295-
[text], dense_model=self.dense_model, sparse_model=self.sparse_model
328+
[text], dense_model=self.dense_model, sparse_model=self.sparse_model,
329+
input_type=input_type,
296330
)
297331
if not results:
298332
return EmbedResult(dense_vector=[], sparse_vector={})
299333

300334
item = results[0]
301335
dense_vector = []
302336
sparse_vector = {}
303-
304337
if "dense" in item:
305338
dense_vector = self._truncate_and_normalize(item["dense"], self.dimension)
306339
if "sparse" in item:
@@ -326,10 +359,12 @@ def _call() -> EmbedResult:
326359
def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
327360
if not texts:
328361
return []
362+
input_type = self._resolve_input_type(is_query)
329363

330364
def _call() -> List[EmbedResult]:
331365
raw_results = self._call_api(
332-
texts, dense_model=self.dense_model, sparse_model=self.sparse_model
366+
texts, dense_model=self.dense_model, sparse_model=self.sparse_model,
367+
input_type=input_type,
333368
)
334369
results = []
335370
for item in raw_results:

openviking_cli/utils/config/embedding_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,8 @@ def _create_embedder(
408408
"dimension": cfg.dimension,
409409
"input_type": cfg.input,
410410
"config": {"max_retries": self.max_retries},
411+
**({"query_param": cfg.query_param} if cfg.query_param else {}),
412+
**({"document_param": cfg.document_param} if cfg.document_param else {}),
411413
},
412414
),
413415
("vikingdb", "sparse"): (
@@ -434,6 +436,8 @@ def _create_embedder(
434436
"dimension": cfg.dimension,
435437
"input_type": cfg.input,
436438
"config": {"max_retries": self.max_retries},
439+
**({"query_param": cfg.query_param} if cfg.query_param else {}),
440+
**({"document_param": cfg.document_param} if cfg.document_param else {}),
437441
},
438442
),
439443
("jina", "dense"): (
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
2+
# SPDX-License-Identifier: AGPL-3.0
3+
"""Tests for VikingDB non-symmetric embedding support."""
4+
5+
from unittest.mock import patch
6+
7+
import pytest
8+
9+
from openviking.models.embedder.vikingdb_embedders import (
10+
VikingDBDenseEmbedder,
11+
VikingDBHybridEmbedder,
12+
)
13+
14+
15+
@pytest.fixture
16+
def mock_vikingdb_client():
17+
"""Patch VikingDB client initialization."""
18+
with patch.object(
19+
VikingDBDenseEmbedder, "_init_vikingdb_client", return_value=None
20+
) as mock_init:
21+
mock_init.side_effect = lambda *args, **kwargs: None
22+
yield mock_init
23+
24+
25+
def test_dense_resolve_input_type_symmetric():
26+
"""When no query_param/document_param, input_type is None (symmetric)."""
27+
embedder = VikingDBDenseEmbedder.__new__(VikingDBDenseEmbedder)
28+
embedder.query_param = None
29+
embedder.document_param = None
30+
assert embedder._resolve_input_type(is_query=True) is None
31+
assert embedder._resolve_input_type(is_query=False) is None
32+
33+
34+
def test_dense_resolve_input_type_nonsymmetric():
35+
"""When query_param/document_param set, return correct value for is_query."""
36+
embedder = VikingDBDenseEmbedder.__new__(VikingDBDenseEmbedder)
37+
embedder.query_param = "query"
38+
embedder.document_param = "passage"
39+
assert embedder._resolve_input_type(is_query=True) == "query"
40+
assert embedder._resolve_input_type(is_query=False) == "passage"
41+
42+
43+
def test_hybrid_resolve_input_type_nonsymmetric():
44+
"""Hybrid embedder also resolves input_type correctly."""
45+
embedder = VikingDBHybridEmbedder.__new__(VikingDBHybridEmbedder)
46+
embedder.query_param = "search_query"
47+
embedder.document_param = "search_document"
48+
assert embedder._resolve_input_type(is_query=True) == "search_query"
49+
assert embedder._resolve_input_type(is_query=False) == "search_document"
50+
51+
52+
def test_dense_backward_compat_no_params():
53+
"""VikingDBDenseEmbedder without query_param/document_param works."""
54+
embedder = VikingDBDenseEmbedder.__new__(VikingDBDenseEmbedder)
55+
embedder.query_param = None
56+
embedder.document_param = None
57+
embedder.model_name = "test"
58+
embedder.dimension = 1024
59+
# Should not raise
60+
assert embedder._resolve_input_type(is_query=True) is None

0 commit comments

Comments
 (0)