From bc0ab8a58b71265a3752688c3223ea8371c55a0f Mon Sep 17 00:00:00 2001 From: zhiqiang-hhhh Date: Tue, 25 Nov 2025 21:02:16 +0800 Subject: [PATCH] sup doris --- README.md | 37 ++ pyproject.toml | 1 + vectordb_bench/backend/clients/__init__.py | 15 + vectordb_bench/backend/clients/doris/cli.py | 199 ++++++++ .../backend/clients/doris/config.py | 82 ++++ vectordb_bench/backend/clients/doris/doris.py | 452 ++++++++++++++++++ vectordb_bench/backend/runner/rate_runner.py | 13 + vectordb_bench/backend/task_runner.py | 22 +- vectordb_bench/cli/vectordbbench.py | 2 + vectordb_bench/frontend/config/styles.py | 2 + 10 files changed, 824 insertions(+), 1 deletion(-) create mode 100644 vectordb_bench/backend/clients/doris/cli.py create mode 100644 vectordb_bench/backend/clients/doris/config.py create mode 100644 vectordb_bench/backend/clients/doris/doris.py diff --git a/README.md b/README.md index aa89c1a3a..578f3ca53 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,7 @@ All the database client supported | hologres | `pip install vectordb-bench[hologres]` | | tencent_es | `pip install vectordb-bench[tencent_es]` | | alisql | `pip install 'vectordb-bench[alisql]'` | +| doris | `pip install vectordb-bench[doris]` | ### Run @@ -321,6 +322,42 @@ Options: --help Show this message and exit. ``` +### Run Doris from command line + +Doris supports ann index with type hnsw from version 4.0.x + +```shell +NUM_PER_BATCH=1000000 vectordbbench doris --http-port=8030 --port=9030 --db-name=vector_test --case-type=Performance768D1M --stream-load-rows-per-batch=500000 +``` + +Using flag `--session-var`, if you want to test doris with some customized session variables. 
For example:
+```shell
+NUM_PER_BATCH=1000000 vectordbbench doris --http-port=8030 --port=9030 --db-name=vector_test --case-type=Performance768D1M --stream-load-rows-per-batch=500000 --session-var enable_profile=True
+```
+
+More options:
+
+```text
+--m INTEGER hnsw m
+--ef-construction INTEGER hnsw ef-construction
+--username TEXT Username [default: root; required]
+--password TEXT Password [default: ""]
+--host TEXT Db host [default: 127.0.0.1; required]
+--port INTEGER Query Port [default: 9030; required]
+--http-port INTEGER Http Port [default: 8030; required]
+--db-name TEXT Db name [default: test; required]
+--ssl / --no-ssl Enable or disable SSL, for Doris Serverless
+ SSL must be enabled [default: no-ssl]
+--index-prop TEXT Extra index PROPERTY as key=value
+ (repeatable)
+--session-var TEXT Session variable key=value applied to each
+ SQL session (repeatable)
+--stream-load-rows-per-batch INTEGER
+ Rows per single stream load request; default
+ uses NUM_PER_BATCH
+--no-index Create table without ANN index
+```
+
 #### Using a configuration file.
 
 The vectordbbench command can optionally read some or all the options from a yaml formatted configuration file. 
diff --git a/pyproject.toml b/pyproject.toml index 70e73a407..5376c5077 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,7 @@ vespa = [ "pyvespa" ] lancedb = [ "lancedb" ] oceanbase = [ "mysql-connector-python" ] alisql = [ "mysql-connector-python" ] +doris = [ "doris-vector-search" ] [project.urls] "repository" = "https://github.com/zilliztech/VectorDBBench" diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py index 65a254c1e..bc8fa28bf 100644 --- a/vectordb_bench/backend/clients/__init__.py +++ b/vectordb_bench/backend/clients/__init__.py @@ -53,6 +53,7 @@ class DB(Enum): Hologres = "Alibaba Cloud Hologres" TencentElasticsearch = "TencentElasticsearch" AliSQL = "AlibabaCloudRDSMySQL" + Doris = "Doris" @property def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915 @@ -177,6 +178,11 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915 return TiDB + if self == DB.Doris: + from .doris.doris import Doris + + return Doris + if self == DB.Test: from .test.test import Test @@ -338,6 +344,11 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901, PLR0915 return TiDBConfig + if self == DB.Doris: + from .doris.config import DorisConfig + + return DorisConfig + if self == DB.Test: from .test.config import TestConfig @@ -508,6 +519,10 @@ def case_config_cls( # noqa: C901, PLR0911, PLR0912, PLR0915 from .alisql.alisql import AliSQLIndexConfig return AliSQLIndexConfig + if self == DB.Doris: + from .doris.config import DorisCaseConfig + + return DorisCaseConfig # DB.Pinecone, DB.Chroma, DB.Redis return EmptyDBCaseConfig diff --git a/vectordb_bench/backend/clients/doris/cli.py b/vectordb_bench/backend/clients/doris/cli.py new file mode 100644 index 000000000..8153b412d --- /dev/null +++ b/vectordb_bench/backend/clients/doris/cli.py @@ -0,0 +1,199 @@ +from typing import Annotated, Unpack + +import click +from pydantic import SecretStr + +from 
vectordb_bench.backend.clients import DB + +from ....cli.cli import ( + CommonTypedDict, + HNSWBaseTypedDict, + cli, + click_parameter_decorators_from_typed_dict, + run, +) + + +def _parse_kv_list(_ctx, _param, values): # noqa: ANN001 + """Parse repeatable or comma-separated key=value items into a dict. + + Accepts any of the following forms (and mixtures thereof): + --index-prop a=1 --index-prop b=2 + --index-prop a=1,b=2 + --index-prop a=1,b=2 --index-prop c=3 + """ + parsed: dict[str, str] = {} + if not values: + return parsed + for item in values: + # allow comma-separated list in a single occurrence + parts = [p.strip() for p in str(item).split(",") if p and p.strip()] + for part in parts: + if "=" not in part: + msg = f"Expect key=value, got: {part}" + raise click.BadParameter(msg) + k, v = part.split("=", 1) + k = k.strip() + v = v.strip() + if not k: + msg = f"Empty key in: {part}" + raise click.BadParameter(msg) + parsed[k] = v + return parsed + + +class DorisTypedDict(CommonTypedDict, HNSWBaseTypedDict): + user_name: Annotated[ + str, + click.option( + "--username", + type=str, + help="Username", + default="root", + show_default=True, + required=True, + ), + ] + password: Annotated[ + str, + click.option( + "--password", + type=str, + default="", + show_default=True, + help="Password", + ), + ] + host: Annotated[ + str, + click.option( + "--host", + type=str, + default="127.0.0.1", + show_default=True, + required=True, + help="Db host", + ), + ] + port: Annotated[ + int, + click.option( + "--port", + type=int, + default=9030, + show_default=True, + required=True, + help="Query Port", + ), + ] + http_port: Annotated[ + int, + click.option( + "--http-port", + type=int, + default=8030, + show_default=True, + required=True, + help="Http Port", + ), + ] + db_name: Annotated[ + str, + click.option( + "--db-name", + type=str, + default="test", + show_default=True, + required=True, + help="Db name", + ), + ] + ssl: Annotated[ + bool, + click.option( + 
"--ssl/--no-ssl", + default=False, + show_default=True, + is_flag=True, + help="Enable or disable SSL, for Doris Serverless SSL must be enabled", + ), + ] + index_prop: Annotated[ + dict, + click.option( + "--index-prop", + type=str, + multiple=True, + help="Extra index PROPERTY as key=value (repeatable or comma-separated, e.g. a=1,b=2)", + callback=_parse_kv_list, + ), + ] + session_var: Annotated[ + dict, + click.option( + "--session-var", + type=str, + multiple=True, + help="Session variable key=value applied to each SQL session (repeatable or comma-separated)", + callback=_parse_kv_list, + ), + ] + stream_load_rows_per_batch: Annotated[ + int | None, + click.option( + "--stream-load-rows-per-batch", + type=int, + required=False, + help="Rows per single stream load request; default uses NUM_PER_BATCH", + ), + ] + no_index: Annotated[ + bool, + click.option( + "--no-index", + is_flag=True, + default=False, + show_default=True, + help="Create table without ANN index", + ), + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(DorisTypedDict) +def Doris( + **parameters: Unpack[DorisTypedDict], +): + from .config import DorisCaseConfig, DorisConfig + + # Merge explicit HNSW params into index properties using Doris naming + index_properties: dict[str, str] = {} + index_properties.update(parameters.get("index_prop", {}) or {}) + if parameters.get("m") is not None: + index_properties.setdefault("max_degree", str(parameters["m"])) + if parameters.get("ef_construction") is not None: + index_properties.setdefault("ef_construction", str(parameters["ef_construction"])) + + session_vars: dict[str, str] = parameters.get("session_var", {}) or {} + + run( + db=DB.Doris, + db_config=DorisConfig( + db_label=parameters["db_label"], + user_name=parameters["username"], + password=SecretStr(parameters["password"]), + host=parameters["host"], + port=parameters["port"], + http_port=parameters["http_port"], + db_name=parameters["db_name"], + ssl=parameters["ssl"], + ), + 
# metric_type should come from the dataset; Assembler will set it on the case config. + db_case_config=DorisCaseConfig( + index_properties=index_properties, + session_vars=session_vars, + stream_load_rows_per_batch=parameters.get("stream_load_rows_per_batch"), + no_index=parameters.get("no_index", False), + ), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/doris/config.py b/vectordb_bench/backend/clients/doris/config.py new file mode 100644 index 000000000..bd7c91c7b --- /dev/null +++ b/vectordb_bench/backend/clients/doris/config.py @@ -0,0 +1,82 @@ +import logging + +from pydantic import BaseModel, SecretStr, validator + +from ..api import DBCaseConfig, DBConfig, MetricType + +log = logging.getLogger(__name__) + + +class DorisConfig(DBConfig): + user_name: str = "root" + password: SecretStr + host: str = "127.0.0.1" + port: int = 9030 + # Doris FE HTTP port for stream load. Default 8030 (8040 for HTTPS if enabled). + http_port: int = 8030 + db_name: str = "test" + ssl: bool = False + + @validator("*") + def not_empty_field(cls, v: any, field: any): + return v + + def to_dict(self) -> dict: + pwd_str = self.password.get_secret_value() + return { + "host": self.host, + "port": self.port, + "http_port": self.http_port, + "user": self.user_name, + "password": pwd_str, + "database": self.db_name, + } + + +class DorisCaseConfig(BaseModel, DBCaseConfig): + metric_type: MetricType | None = None + # Optional explicit HNSW params for convenience + m: int | None = None + ef_construction: int | None = None + # Arbitrary index properties and session variables + index_properties: dict[str, str] | None = None + session_vars: dict[str, str] | None = None + # Control rows per single stream load request + stream_load_rows_per_batch: int | None = None + # Create table without ANN index + no_index: bool = False + + def get_metric_fn(self) -> str: + if self.metric_type == MetricType.L2: + return "l2_distance_approximate" + if self.metric_type == MetricType.IP: + 
return "inner_product_approximate" + if self.metric_type == MetricType.COSINE: + log.debug("Using inner_product_approximate because doris doesn't support cosine as metric type") + return "inner_product_approximate" + msg = f"Unsupported metric type: {self.metric_type}" + raise ValueError(msg) + + def index_param(self) -> dict: + # Use exact metric function name for index creation by removing '_approximate' suffix + metric_fn = self.get_metric_fn() + if metric_fn.endswith("_approximate"): + metric_fn = metric_fn[: -len("_approximate")] + props = {"metric_fn": metric_fn} + # Merge optional HNSW params + if self.m is not None: + props.setdefault("max_degree", str(self.m)) + if self.ef_construction is not None: + props.setdefault("ef_construction", str(self.ef_construction)) + # Merge user provided index_properties + if self.index_properties: + props.update(self.index_properties) + return props + + def search_param(self) -> dict: + return { + "metric_fn": self.get_metric_fn(), + } + + def session_param(self) -> dict: + return self.session_vars or {} diff --git a/vectordb_bench/backend/clients/doris/doris.py b/vectordb_bench/backend/clients/doris/doris.py new file mode 100644 index 000000000..1805a60c9 --- /dev/null +++ b/vectordb_bench/backend/clients/doris/doris.py @@ -0,0 +1,452 @@ +import logging +import os +from contextlib import contextmanager +from typing import Any + +import pandas as pd +from doris_vector_search import AuthOptions, DorisVectorClient, IndexOptions, LoadOptions + +from ..api import MetricType, VectorDB +from .config import DorisCaseConfig + +log = logging.getLogger(__name__) + + +class Doris(VectorDB): + def __init__( + self, + dim: int, + db_config: dict, + db_case_config: DorisCaseConfig, + collection_name: str | None = None, + drop_old: bool = False, + **kwargs, + ): + self.name = "Doris" + self.db_config = db_config + self.case_config = db_case_config + self.dim = dim + self.search_fn = db_case_config.search_param()["metric_fn"] + # Prefer 
provided collection_name; otherwise fallback to a simple default + # e.g. l2_distance128, inner_product128 + self.table_name = collection_name if collection_name else (self.search_fn + str(dim)) + + # Store connection configuration for lazy initialization + self.auth_options = AuthOptions( + host=db_config.get("host", "127.0.0.1"), + query_port=db_config.get("port", 9030), + http_port=db_config.get("http_port", 8030), + user=db_config.get("user", "root"), + password=db_config.get("password", ""), + ) + + # Configure load options + self.load_options = None + if hasattr(db_case_config, "stream_load_rows_per_batch") and db_case_config.stream_load_rows_per_batch: + self.load_options = LoadOptions(batch_size=db_case_config.stream_load_rows_per_batch) + + # Store database name for lazy initialization + self.database_name = db_config.get("database", "test") + + # Initialize client and table as None (lazy initialization) + self.client = None + self.table = None + self._client_pid: int | None = None + + if drop_old: + self._drop_table() + self._create_table() + + def _ensure_client_initialized(self): + """Ensure the client is initialized and bound to the current process. + + Multiprocessing will pickle the DB wrapper. Any existing mysql connection or + table cursor cached from a different PID must be discarded and recreated here. 
+ """ + cur_pid = os.getpid() + + need_new_client = False + if self.client is None: + need_new_client = True + else: + # If client was created in another process or connection is not usable, recreate it + try: + # different process + if self._client_pid is None or self._client_pid != cur_pid: + need_new_client = True + else: + # check connection health if available + conn = getattr(self.client, "connection", None) + if conn is None or not getattr(conn, "is_connected", lambda: False)(): + need_new_client = True + except Exception: + need_new_client = True + + if need_new_client: + # Drop any table cached from another PID (its cursors are not valid across processes) + self.table = None + + # Recreate client and set sessions + self.client = DorisVectorClient( + database=self.database_name, + auth_options=self.auth_options, + load_options=self.load_options, + ) + + if hasattr(self.case_config, "session_vars") and self.case_config.session_vars: + self.client.with_sessions(self.case_config.session_vars) + + self._client_pid = cur_pid + + # Re-open table in this process to ensure fresh cursors + try: + self.table = self.client.open_table(self.table_name) + if hasattr(self.table, "index_options") and self.table.index_options: + self.table.index_options.dim = self.dim + if self.search_fn.startswith("inner_product"): + self.table.index_options.metric_type = "inner_product" + else: + self.table.index_options.metric_type = "l2_distance" + except Exception: + # Table might not exist yet; leave it to ready_to_load + self.table = None + + @contextmanager + def init(self): + try: + self._ensure_client_initialized() + # Open or create the table + if not self.table: + try: + # Try to open existing table + self.table = self.client.open_table(self.table_name) + # Avoid SHOW CREATE TABLE parsing in SDK by setting dim/metric directly + try: + if hasattr(self.table, "index_options") and self.table.index_options: + self.table.index_options.dim = self.dim + # Set metric_type according to 
current case + if self.search_fn.startswith("inner_product"): + self.table.index_options.metric_type = "inner_product" + else: + self.table.index_options.metric_type = "l2_distance" + except Exception: + log.exception("Failed to update index options for table: %s", self.table_name) + except Exception: + # Table doesn't exist, will be created in ready_to_load + self.table = None + yield + finally: + # Clean up if needed + pass + + def _drop_table(self): + try: + self._ensure_client_initialized() + self.client.drop_table(self.table_name) + except Exception: + log.exception("Failed to drop table: %s", self.table_name) + raise + + def _create_table(self): + """Create the table using doris-vector-search library""" + try: + self._ensure_client_initialized() + sample_data = pd.DataFrame([{"id": 1, "embedding": [0.0] * self.dim}]) + + index_options = None + if not getattr(self.case_config, "no_index", False): + index_options = self._build_index_options() + + self._create_table_with_options(sample_data, index_options) + log.info("Successfully created table %s", self.table_name) + + except Exception: + log.exception("Failed to create table: %s", self.table_name) + raise + + def _build_index_options(self) -> IndexOptions | None: + index_param = self.case_config.index_param() + metric_type = index_param.get("metric_fn", "l2_distance") + index_options = IndexOptions( + index_type="hnsw", + metric_type=metric_type, + dim=self.dim, + ) + + extra_props = {k: v for k, v in index_param.items() if k != "metric_fn"} + if extra_props: + applied, stored = {}, {} + for key, value in extra_props.items(): + attr_name = key + if hasattr(index_options, attr_name): + try: + setattr(index_options, attr_name, value) + applied[key] = value + except Exception: + stored[key] = value + else: + stored[key] = value + if stored: + index_options.properties = {**getattr(index_options, "properties", {}), **stored} + log.info( + "Index options prepared: metric=%s applied_props=%s stored_props=%s", + 
metric_type, + applied, + stored, + ) + else: + log.info("Index options prepared: metric=%s (no extra properties)", metric_type) + log.info("Creating table %s with index %s", self.table_name, index_param) + return index_options + + def _create_table_with_options(self, sample_data: pd.DataFrame, index_options: IndexOptions | None) -> None: + create_index = not getattr(self.case_config, "no_index", False) + if not create_index: + log.info("Creating table %s without ANN index", self.table_name) + + self.table = self.client.create_table( + self.table_name, + sample_data, + create_index=create_index, + index_options=index_options, + overwrite=True, + insert_data=False, + ) + + try: + if hasattr(self.table, "index_options") and self.table.index_options: + self.table.index_options.dim = self.dim + if self.search_fn.startswith("inner_product"): + self.table.index_options.metric_type = "inner_product" + else: + self.table.index_options.metric_type = "l2_distance" + if ( + index_options + and hasattr(index_options, "properties") + and isinstance(index_options.properties, dict) + ): + for key, value in index_options.properties.items(): + if hasattr(self.table.index_options, key): + try: + setattr(self.table.index_options, key, value) + except Exception: + log.debug("Skip setting index_options.%s at runtime", key) + except Exception: + log.exception("Failed to adjust index options for table: %s", self.table_name) + + def ready_to_load(self) -> bool: + self._ensure_client_initialized() + if not self.table: + self._create_table() + return True + + def optimize(self, data_size: int | None = None) -> None: + log.info("Optimization completed using doris-vector-search library") + + def need_normalize_cosine(self) -> bool: + """Wheather this database need to normalize dataset to support COSINE""" + if self.case_config.metric_type == MetricType.COSINE: + log.info("cosine dataset need normalize.") + return True + + return False + + def insert_embeddings( + self, + embeddings: 
list[list[float]], + metadata: list[int], + **kwargs: Any, + ) -> tuple[int, Exception | None]: + """Insert embeddings using doris-vector-search library.""" + try: + self._ensure_client_initialized() + # Prepare data in pandas DataFrame format + data = pd.DataFrame([{"id": metadata[i], "embedding": embeddings[i]} for i in range(len(embeddings))]) + + msg = f"Inserting {len(embeddings)} embeddings into table {self.table_name}" + log.info(msg) + + # Add data to the table + self.table.add(data) + + return len(metadata), None + + except Exception as e: + msg = "Failed to insert embeddings" + log.exception(msg) + return 0, e + + def search_embedding( + self, + query: list[float], + k: int = 100, + filters: dict | None = None, + timeout: int | None = None, + **kwargs: Any, + ) -> list[int]: + try: + self._ensure_client_initialized() + # Map metric functions to doris-vector-search metric types + metric_type = "l2_distance" + if self.search_fn.startswith("inner_product"): + metric_type = "inner_product" + elif self.search_fn.startswith("l2_distance"): + metric_type = "l2_distance" + + # Perform search using doris-vector-search + search_query = self.table.search(query, metric_type=metric_type).limit(k).select(["id"]) + + # Apply filters if provided + if filters and "id" in filters: + if self.search_fn.startswith("inner_product"): + where_clause = f"id >= {filters['id']}" + search_query = search_query.where(where_clause) + else: + where_clause = f"id < {filters['id']}" + search_query = search_query.where(where_clause) + + # Execute and get results + results_df = search_query.to_pandas() + return results_df["id"].tolist() + + except Exception: + msg = "Search embedding failed" + log.exception(msg) + return [] + + def search_embedding_range( + self, + query: list[float], + k: int = 100, + filters: dict | None = None, + distance: float | None = None, + timeout: int | None = None, + **kwargs: Any, + ) -> list[int]: + try: + self._ensure_client_initialized() + # Map metric 
functions to doris-vector-search metric types + metric_type = "l2_distance" + if self.search_fn.startswith("inner_product"): + metric_type = "inner_product" + elif self.search_fn.startswith("l2_distance"): + metric_type = "l2_distance" + + # Perform range search using doris-vector-search + search_query = self.table.search(query, metric_type=metric_type).select(["id"]) + + # Apply distance range + if distance is not None: + if self.search_fn.startswith("inner_product"): + adjusted_distance = distance - 0.000001 + search_query = search_query.distance_range(lower_bound=adjusted_distance) + else: + adjusted_distance = distance + 0.000001 + search_query = search_query.distance_range(upper_bound=adjusted_distance) + + # Execute and get results + results_df = search_query.to_pandas() + return results_df["id"].tolist() + + except Exception: + msg = "Range search failed" + log.exception(msg) + return [] + + def search_embedding_compound( + self, + query: list[float], + k: int = 100, + filters: dict | None = None, + distance: float | None = None, + timeout: int | None = None, + **kwargs: Any, + ) -> list[int]: + try: + self._ensure_client_initialized() + # Map metric functions to doris-vector-search metric types + metric_type = "l2_distance" + if self.search_fn.startswith("inner_product"): + metric_type = "inner_product" + elif self.search_fn.startswith("l2_distance"): + metric_type = "l2_distance" + + # Perform compound search using doris-vector-search + search_query = self.table.search(query, metric_type=metric_type).limit(k).select(["id"]) + + # Apply distance range + if distance is not None: + if self.search_fn.startswith("inner_product"): + adjusted_distance = distance - 0.000001 + search_query = search_query.distance_range(lower_bound=adjusted_distance) + else: + adjusted_distance = distance + 0.000001 + search_query = search_query.distance_range(upper_bound=adjusted_distance) + + # Execute and get results + results_df = search_query.to_pandas() + return 
results_df["id"].tolist() + + except Exception: + msg = "Compound search failed" + log.exception(msg) + return [] + + def search_distance(self, query: list[float], target_id: int | None = None): + try: + self._ensure_client_initialized() + metric_type = self.search_fn + if metric_type.endswith("_approximate"): + metric_type = metric_type.replace("_approximate", "") + + search_metric = "inner_product" if metric_type.startswith("inner_product") else "l2_distance" + where_clause = f"id = {target_id}" + search_query = self.table.search(query, metric_type=search_metric).where(where_clause).select(["id"]) + results_df = search_query.to_pandas() + except Exception: + msg = "Distance search failed" + log.exception(msg) + return [] + + # For now, return a placeholder distance + # The exact distance calculation would need custom SQL or library support + return [0.0] if not results_df.empty else [] + + def search_embedding_exact( + self, + query: list[float], + k: int = 100, + filters: dict | None = None, + timeout: int | None = None, + **kwargs: Any, + ) -> list[int]: + try: + self._ensure_client_initialized() + # Use exact search by removing approximate suffix + metric_type = self.search_fn + if metric_type.endswith("_approximate"): + metric_type = metric_type.replace("_approximate", "") + + # Map to doris-vector-search metric types + search_metric = "inner_product" if metric_type.startswith("inner_product") else "l2_distance" + + # Perform exact search + search_query = self.table.search(query, metric_type=search_metric).limit(k).select(["id"]) + + # Apply filters if provided + if filters and "id" in filters: + if metric_type.startswith("inner_product"): + where_clause = f"id >= {filters['id']}" + search_query = search_query.where(where_clause) + else: + where_clause = f"id < {filters['id']}" + search_query = search_query.where(where_clause) + + # Execute and get results + results_df = search_query.to_pandas() + return results_df["id"].tolist() + + except Exception: + msg = 
"Exact search failed" + log.exception(msg) + return [] diff --git a/vectordb_bench/backend/runner/rate_runner.py b/vectordb_bench/backend/runner/rate_runner.py index 163d50689..ecfc53862 100644 --- a/vectordb_bench/backend/runner/rate_runner.py +++ b/vectordb_bench/backend/runner/rate_runner.py @@ -7,6 +7,7 @@ from vectordb_bench import config from vectordb_bench.backend.clients import api +from vectordb_bench.backend.clients.doris.doris import Doris from vectordb_bench.backend.dataset import DataSetIterator from vectordb_bench.backend.utils import time_it @@ -53,6 +54,18 @@ def _insert_embeddings(db: api.VectorDB, emb: list[list[float]], metadata: list[ db_copy = deepcopy(db) with db_copy.init(): _insert_embeddings(db_copy, emb, metadata, retry_idx=0) + elif isinstance(db, Doris): + # DorisVectorClient is not thread-safe. Similar to pgvector, create a per-thread client + # by deep-copying the wrapper and forcing lazy re-init inside the thread. + db_copy = deepcopy(db) + # Ensure a fresh client/table will be created in this thread + try: + db_copy.client = None + db_copy.table = None + except Exception: + log.debug("Failed to reset Doris client or table on thread-local copy", exc_info=True) + with db_copy.init(): + _insert_embeddings(db_copy, emb, metadata, retry_idx=0) else: _insert_embeddings(db, emb, metadata, retry_idx=0) diff --git a/vectordb_bench/backend/task_runner.py b/vectordb_bench/backend/task_runner.py index 38f69244a..4a7ba2b3a 100644 --- a/vectordb_bench/backend/task_runner.py +++ b/vectordb_bench/backend/task_runner.py @@ -1,5 +1,7 @@ import concurrent +import hashlib import logging +import re import traceback from enum import Enum, auto @@ -11,7 +13,7 @@ from ..models import PerformanceTimeoutError, TaskConfig, TaskStage from . 
import utils from .cases import Case, CaseLabel, StreamingPerformanceCase -from .clients import MetricType, api +from .clients import DB, MetricType, api from .data_source import DatasetSource from .runner import MultiProcessingSearchRunner, ReadWriteRunner, SerialInsertRunner, SerialSearchRunner @@ -97,6 +99,23 @@ def normalize(self) -> bool: def init_db(self, drop_old: bool = True) -> None: db_cls = self.config.db.init_cls + # Compose a compact, case-unique collection/table name for Doris to avoid cross-case interference + collection_name = None + try: + if self.config.db == DB.Doris: + # Primary identifier = case-type enum name from CLI (e.g., Performance768D10M) + case_type_name = self.config.case_config.case_id.name + base = f"{case_type_name.lower()}" + # Sanitize to [a-z0-9_] + base = re.sub(r"[^a-z0-9_]+", "_", base).strip("_") + # Cap to 63 chars; add short hash if truncated + if len(base) > 63: + h = hashlib.md5(base.encode(), usedforsecurity=False).hexdigest()[:6] + base = f"{base[:(63-7)]}_{h}" + collection_name = base + except Exception: + # If anything goes wrong, fall back silently; Doris will use its default name logic + collection_name = None self.db = db_cls( dim=self.ca.dataset.data.dim, @@ -104,6 +123,7 @@ def init_db(self, drop_old: bool = True) -> None: db_case_config=self.config.db_case_config, drop_old=drop_old, with_scalar_labels=self.ca.with_scalar_labels, + **({"collection_name": collection_name} if collection_name else {}), ) def _pre_run(self, drop_old: bool = True): diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index fb4ca7dac..6b92ea081 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -2,6 +2,7 @@ from ..backend.clients.alloydb.cli import AlloyDBScaNN from ..backend.clients.aws_opensearch.cli import AWSOpenSearch from ..backend.clients.clickhouse.cli import Clickhouse +from ..backend.clients.doris.cli import Doris from ..backend.clients.hologres.cli 
import HologresHGraph from ..backend.clients.lancedb.cli import LanceDB from ..backend.clients.mariadb.cli import MariaDBHNSW @@ -54,6 +55,7 @@ cli.add_command(S3Vectors) cli.add_command(TencentElasticsearch) cli.add_command(AliSQLHNSW) +cli.add_command(Doris) if __name__ == "__main__": diff --git a/vectordb_bench/frontend/config/styles.py b/vectordb_bench/frontend/config/styles.py index 4723ec7fd..cbc199cdb 100644 --- a/vectordb_bench/frontend/config/styles.py +++ b/vectordb_bench/frontend/config/styles.py @@ -68,6 +68,7 @@ def getPatternShape(i): DB.OceanBase: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAOEAAADhCAMAAAAJbSJIAAABJlBMVEX///8Bgf0HyEb/oAUAe/2oy/+Vvf7/nAAAfv3/ngAAxjkAev0Ad/0Af/3/mgD/zJap6LgAx0AAxTYAg/0Adf0AxC//owD6/vuhxv//8uD/3rLz+f/h8P/7/f8AiP3k8P++3f9DkPSs0P//qyDH4v8zlf1fpf6As/70/fckzVhj2ILA78123JH99ej/5cb/vGX8tFb/y4n/+/T1x4r/qTL/1qH/rjT/4b7/xHTv6t//uFHzrEr/79be29Hv7ur/0J3/tENhne5IkOcjkP5CoP8yhOGGtv50rf7F1uvF3v//uWDs1bilwen/xX7tqlVipPDe5evI093h+OiT4aRO1XQ30WWv7MHS9NxQxG2O4qSt17ht2olZ1Xqyzex+3pfo++6G05rL49EXq/lGAAAGaElEQVR4nO3ba1faSACH8QIJJOSiCIIFEtJAgBjE2lW8rFVbKCzduqsi2lYr2+//JXaCtnJVgYSMnv/vVQsenefMDEnmHF69AgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAoNOa1wNwlbJWXs++83oU7llLvo3LITnp9Tjc8nojLvN8IMCvez0SdyT3AiGSR/D8i1umifIfe5sh/raPkLe8HpGzyjuboZ48exIrXo/JOe+2u3mBAfJrrwfmBGW3vEW23nAeEVrwenSzer29UAnwA2uzd5n+6fUIZ7Gf3MmGRk/dfWHc61FOZ62c3Knwsvxw3e1GpOvObT+5XX6/e5BQht9SDg4OdsvbyeThzlGItI1dl3RvxAX5t0A8Wzk6OqoQ2Wz8/nX5kWU5tEwDXkf1WQjdj6zPJE0D5LLXVb16Ch0TOp5xUAqRymTStzKpVMp+haLCQGiye1MlldPYYL3x4eNHo1pVoz5OEGMx8R75j8D4oqparX36QUXhUx+hlIzG1vNG1A5gOIbjOEmSfA+QOJGlopDfe+zPpjJasFGLdssebBrEUFI4/t40lflLD9bzKmnjuEnSaCrk5fj6yMKMXq/5RJEh8zZFGy2FPC9nt0bc1GT0gsEJM6TRUUjuxzffJgc+SBOpTLNuRB2o87qQzF5lY38gL6MV3qg+h+q8LSSb7/Pg7KX1vFNT520hz4fk0NHg5kvrhk9wts6bQlIXX9/4e613+lK5Lw1X8uZf2M0bWptBlREmu4xTWdhdmpXj8kBeiqxNxq26uRWSNpnf/Od4cOe9S2sN1dW8ORSSC54sZ3eS5JZl8GGmmVejDn9uzrWQ77bJlZ3DkU+3TTJ57tdNWXh7TGEfwtw92/9+yrdf/XWQEa+sLyT3R/8KrTCvvKkKd7eTya2tw431vb2jSjYeD/CBeDybzVaOPu8sHG5tJZPb73cPxj1f
54L/GhJ5OJ2fyQunl2MNQZjw6e7ZFCo5tiaI81uacyxUTk41tm6I3tS5WqikNbZQi8ZE+9xhzgtzxsLmiq7lUqOP6hQlk9Z0NtioqeK0pw5Om7wwGBMEMSZGq4ZRy+fzjUK90GiQf9RqRtUnkjD7rIiGtDtTFN591EuSxHUx3B0v1+J40xc+FyhEIf1QiEL6oRCF9EMhCumHQhTSD4UopB8KUUg/FKKQfihEIf1QiEL6oRCF9EMhCr0kcQwjEHff5Yoxvmg02v2aF/lP93te9ruxZ1nIMYIYixqN4Iqua7nTk5MTJdE/SkVJnZycnjY1nU0/r0LJ/h6eml/Rc9N+u5Dewu6aZIyCPvGcPIdCEhc18kE253acJ4USqcs/dTMlisVOMfH4z9FTyDGcEWyOD0qcWYuLrZvVc9MfiSwt/xKO+E2zfX3Tal1Yl0VqCxny6V8bNXkKKStai63SuX9peSkcDkciEf8w8ip5b2l5kc5CjlGDzfTQekt0Llqlr20zbJeN7BoSprGQE9RGMzX4h4uXrZJJykbP2PMplBiBrM3BvOLlt1WTxE2SRmWhxPiMwpf+rdf5+f3maso6ygoljqvWB654nYsr/3J46jqaCsmdZrUxcCfWsa4jS7PEUVPICZL6ppDr23vkQnftnz2PgkJyJ63m2Uyq77pQtFbb/pmWJi2F5HazOnxZ6Hwzn3ipo7nQfk4QyUNQZuiqZ636w47VeVTIMaLPqK/8GK7rfL9pO7U4PSrkujMXzKVGPr5a5KPF6b65FXL2soz5anU9M/7Z3Fp2PM/1wu5zuSj61E8fGkHtgbhbzk/gdIXCE2aMLEdBjHHVWn1FSz/5ROWrG4mTF66oPjF2d1InML8Jd+d6MdFXreULQfsLbpP+6m90FNqUdDqnabrOsuzKLVbXdU3LpR9diQ+xlqgpdMssJZF+lBa2Jr3Ydw8qbH6zfX1VKpVubFer523T9N++M/kphpvOnn69iNinGGb7qnVhnY07XetcWheLJWu+DY9oP/pZY88a+SHzetG6nPXY0AsXD37WREhb+8ayzg68Huf0iuPnMBL2n7cmP/ykzurok89wpF2ynn+dbfjTNBw2r//7fun1wBzzs38jRpbMmxewMnv1bkSy81ZfyNLs9WsjRsLmeavj9WjccNldpmGz9MIW572iaW++xZeaZystX7/Azder89PrEQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAuOF/FLjti7fsBDsAAAAASUVORK5CYII=", DB.S3Vectors: "https://assets.zilliz.com/s3_vectors_daf370b4e5.png", DB.Hologres: "https://img.alicdn.com/imgextra/i3/O1CN01d9qrry1i6lTNa2BRa_!!6000000004364-2-tps-218-200.png", + DB.Doris: "https://doris.apache.org/images/logo.svg", } # RedisCloud color: #0D6EFD @@ -85,6 +86,7 @@ def getPatternShape(i): DB.OSSOpenSearch.value: "#0DCAF0", DB.TiDB.value: "#0D6EFD", DB.Vespa.value: "#61d790", + DB.Doris.value: "#52CAA3", } COLORS_10 = [