diff --git a/README.md b/README.md
index 0d286aa4e..aa89c1a3a 100644
--- a/README.md
+++ b/README.md
@@ -60,6 +60,7 @@ All the database client supported
 | oceanbase | `pip install vectordb-bench[oceanbase]` |
 | hologres | `pip install vectordb-bench[hologres]` |
 | tencent_es | `pip install vectordb-bench[tencent_es]` |
+| alisql | `pip install 'vectordb-bench[alisql]'` |
 
 ### Run
 
diff --git a/pyproject.toml b/pyproject.toml
index 404e48616..70e73a407 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -99,6 +99,7 @@ clickhouse = [ "clickhouse-connect" ]
 vespa = [ "pyvespa" ]
 lancedb = [ "lancedb" ]
 oceanbase = [ "mysql-connector-python" ]
+alisql = [ "mysql-connector-python" ]
 
 [project.urls]
 "repository" = "https://github.com/zilliztech/VectorDBBench"
diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py
index 7ddb04383..65a254c1e 100644
--- a/vectordb_bench/backend/clients/__init__.py
+++ b/vectordb_bench/backend/clients/__init__.py
@@ -52,6 +52,7 @@ class DB(Enum):
     S3Vectors = "S3Vectors"
     Hologres = "Alibaba Cloud Hologres"
     TencentElasticsearch = "TencentElasticsearch"
+    AliSQL = "AlibabaCloudRDSMySQL"
 
     @property
     def init_cls(self) -> type[VectorDB]:  # noqa: PLR0911, PLR0912, C901, PLR0915
@@ -206,6 +207,11 @@ def init_cls(self) -> type[VectorDB]:  # noqa: PLR0911, PLR0912, C901, PLR0915
 
             return TencentElasticsearch
 
+        if self == DB.AliSQL:
+            from .alisql.alisql import AliSQL
+
+            return AliSQL
+
         msg = f"Unknown DB: {self.name}"
         raise ValueError(msg)
 
@@ -362,10 +368,15 @@ def config_cls(self) -> type[DBConfig]:  # noqa: PLR0911, PLR0912, C901, PLR0915
 
             return TencentElasticsearchConfig
 
+        if self == DB.AliSQL:
+            from .alisql.config import AliSQLConfig
+
+            return AliSQLConfig
+
         msg = f"Unknown DB: {self.name}"
         raise ValueError(msg)
 
-    def case_config_cls(  # noqa: C901, PLR0911, PLR0912
+    def case_config_cls(  # noqa: C901, PLR0911, PLR0912, PLR0915
         self,
         index_type: IndexType | None = None,
     ) -> type[DBCaseConfig]:
@@ -493,6 +504,11 @@ def case_config_cls(  # noqa: C901, PLR0911, PLR0912
 
             return TencentElasticsearchIndexConfig
 
+        if self == DB.AliSQL:
+            from .alisql.alisql import AliSQLIndexConfig
+
+            return AliSQLIndexConfig
+
         # DB.Pinecone, DB.Chroma, DB.Redis
         return EmptyDBCaseConfig
 
diff --git a/vectordb_bench/backend/clients/alisql/alisql.py b/vectordb_bench/backend/clients/alisql/alisql.py
new file mode 100644
index 000000000..c3c2fd953
--- /dev/null
+++ b/vectordb_bench/backend/clients/alisql/alisql.py
@@ -0,0 +1,215 @@
+import logging
+from contextlib import contextmanager
+
+import mysql.connector as mysql
+import numpy as np
+
+from ..api import VectorDB
+from .config import AliSQLConfigDict, AliSQLIndexConfig
+
+log = logging.getLogger(__name__)
+
+
+class AliSQL(VectorDB):
+    def __init__(
+        self,
+        dim: int,
+        db_config: AliSQLConfigDict,
+        db_case_config: AliSQLIndexConfig,
+        collection_name: str = "vec_collection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.name = "AliSQL"
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.db_name = "vectordbbench"
+        self.table_name = collection_name
+        self.dim = dim
+
+        # construct basic units
+        self.conn, self.cursor = self._create_connection()
+
+        if drop_old:
+            self._drop_db()
+            self._create_db_table(dim)
+
+        self.cursor.close()
+        self.conn.close()
+        self.cursor = None
+        self.conn = None
+
+    def _create_connection(self):
+        conn = mysql.connect(
+            host=self.db_config["host"],
+            user=self.db_config["user"],
+            port=self.db_config["port"],
+            password=self.db_config["password"],
+            buffered=True,
+        )
+        cursor = conn.cursor()
+
+        assert conn is not None, "Connection is not initialized"
+        assert cursor is not None, "Cursor is not initialized"
+
+        return conn, cursor
+
+    def _drop_db(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client drop db : {self.db_name}")
+
+        # flush tables before dropping database to avoid some locking issue
+        self.cursor.execute("FLUSH TABLES")
+        self.cursor.execute(f"DROP DATABASE IF EXISTS {self.db_name}")
+        self.cursor.execute("COMMIT")
+        self.cursor.execute("FLUSH TABLES")
+
+    def _create_db_table(self, dim: int):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            log.info(f"{self.name} client create database : {self.db_name}")
+            self.cursor.execute(f"CREATE DATABASE {self.db_name}")
+
+            log.info(f"{self.name} client create table : {self.table_name}")
+            self.cursor.execute(f"USE {self.db_name}")
+
+            self.cursor.execute(
+                f"""
+                CREATE TABLE {self.table_name} (
+                    id INT PRIMARY KEY,
+                    v VECTOR({self.dim}) NOT NULL
+                )
+                """
+            )
+            self.cursor.execute("COMMIT")
+
+        except Exception as e:
+            log.warning(f"Failed to create table: {self.table_name} error: {e}")
+            raise e from None
+
+    @contextmanager
+    def init(self):
+        """create and destroy connections to database.
+
+        Examples:
+            >>> with self.init():
+            >>>     self.insert_embeddings()
+        """
+        self.conn, self.cursor = self._create_connection()
+
+        index_param = self.case_config.index_param()
+        search_param = self.case_config.search_param()
+
+        # maximize allowed package size
+        self.cursor.execute("SET GLOBAL max_allowed_packet = 1073741824")
+
+        if index_param["index_type"] == "HNSW":
+            if index_param["cache_size"] is not None:
+                self.cursor.execute(f"SET GLOBAL vidx_hnsw_cache_size = {index_param['cache_size']}")
+            if search_param["ef_search"] is not None:
+                self.cursor.execute(f"SET GLOBAL vidx_hnsw_ef_search = {search_param['ef_search']}")
+        self.cursor.execute("COMMIT")
+
+        self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)"  # noqa: S608
+        self.select_sql = (
+            f"SELECT id FROM {self.db_name}.{self.table_name} "  # noqa: S608
+            f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %s"
+        )
+        self.select_sql_with_filter = (
+            f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %s "  # noqa: S608
+            f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %s"
+        )
+
+        try:
+            yield
+        finally:
+            self.cursor.close()
+            self.conn.close()
+            self.cursor = None
+            self.conn = None
+
+    def ready_to_load(self) -> bool:
+        pass
+
+    def optimize(self, data_size: int) -> None:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        index_param = self.case_config.index_param()
+
+        try:
+            index_options = f"DISTANCE={index_param['metric_type']}"
+            if index_param["index_type"] == "HNSW" and index_param["M"] is not None:
+                index_options += f" M={index_param['M']}"
+
+            self.cursor.execute(
+                f"""
+                ALTER TABLE {self.db_name}.{self.table_name}
+                ADD VECTOR KEY v(v) {index_options}
+                """
+            )
+            self.cursor.execute("COMMIT")
+
+        except Exception as e:
+            log.warning(f"Failed to create index: {self.table_name} error: {e}")
+            raise e from None
+
+    @staticmethod
+    def vector_to_hex(v):  # noqa: ANN001
+        return np.array(v, "float32").tobytes()
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs,
+    ) -> tuple[int, Exception]:
+        """Insert embeddings into the database.
+        Should call self.init() first.
+        """
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            metadata_arr = np.array(metadata)
+            embeddings_arr = np.array(embeddings)
+
+            batch_data = []
+            for i, row in enumerate(metadata_arr):
+                batch_data.append((int(row), self.vector_to_hex(embeddings_arr[i])))
+
+            self.cursor.executemany(self.insert_sql, batch_data)
+            self.cursor.execute("COMMIT")
+            self.cursor.execute("FLUSH TABLES")
+
+            return len(metadata), None
+        except Exception as e:
+            log.warning(f"Failed to insert data into Vector table ({self.table_name}), error: {e}")
+            return 0, e
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+        **kwargs,
+    ) -> list[int]:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        search_param = self.case_config.search_param()  # noqa: F841
+
+        try:
+            if filters:
+                self.cursor.execute(self.select_sql_with_filter, (filters.get("id"), self.vector_to_hex(query), k))
+            else:
+                self.cursor.execute(self.select_sql, (self.vector_to_hex(query), k))
+            return [row[0] for row in self.cursor.fetchall()]
+
+        except mysql.Error:
+            log.exception("Failed to execute search query")
+            raise
diff --git a/vectordb_bench/backend/clients/alisql/cli.py b/vectordb_bench/backend/clients/alisql/cli.py
new file mode 100644
index 000000000..dbbc8513d
--- /dev/null
+++ b/vectordb_bench/backend/clients/alisql/cli.py
@@ -0,0 +1,111 @@
+from typing import Annotated, Unpack
+
+import click
+from pydantic import SecretStr
+
+from vectordb_bench.backend.clients import DB
+
+from ....cli.cli import (
+    CommonTypedDict,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+
+
+class AliSQLTypedDict(CommonTypedDict):
+    user_name: Annotated[
+        str,
+        click.option(
+            "--username",
+            type=str,
+            help="Username",
+            required=True,
+        ),
+    ]
+    password: Annotated[
+        str,
+        click.option(
+            "--password",
+            type=str,
+            help="Password",
+            required=True,
+        ),
+    ]
+
+    host: Annotated[
+        str,
+        click.option(
+            "--host",
+            type=str,
+            help="Db host",
+            default="127.0.0.1",
+        ),
+    ]
+
+    port: Annotated[
+        int,
+        click.option(
+            "--port",
+            type=int,
+            default=3306,
+            help="DB Port",
+        ),
+    ]
+
+
+class AliSQLHNSWTypedDict(AliSQLTypedDict):
+    m: Annotated[
+        int | None,
+        click.option(
+            "--m",
+            type=int,
+            help="M parameter in HNSW vector indexing",
+            required=False,
+        ),
+    ]
+
+    ef_search: Annotated[
+        int | None,
+        click.option(
+            "--ef-search",
+            type=int,
+            help="AliSQL system variable vidx_hnsw_ef_search",
+            required=False,
+        ),
+    ]
+
+    cache_size: Annotated[
+        int | None,
+        click.option(
+            "--cache-size",
+            type=int,
+            help="AliSQL system variable vidx_hnsw_cache_size",
+            required=False,
+        ),
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(AliSQLHNSWTypedDict)
+def AliSQLHNSW(
+    **parameters: Unpack[AliSQLHNSWTypedDict],
+):
+    from .config import AliSQLConfig, AliSQLHNSWConfig
+
+    run(
+        db=DB.AliSQL,
+        db_config=AliSQLConfig(
+            db_label=parameters["db_label"],
+            user_name=parameters["username"],
+            password=SecretStr(parameters["password"]),
+            host=parameters["host"],
+            port=parameters["port"],
+        ),
+        db_case_config=AliSQLHNSWConfig(
+            M=parameters["m"],
+            ef_search=parameters["ef_search"],
+            cache_size=parameters["cache_size"],
+        ),
+        **parameters,
+    )
diff --git a/vectordb_bench/backend/clients/alisql/config.py b/vectordb_bench/backend/clients/alisql/config.py
new file mode 100644
index 000000000..eddf19e98
--- /dev/null
+++ b/vectordb_bench/backend/clients/alisql/config.py
@@ -0,0 +1,71 @@
+from typing import TypedDict
+
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
+
+
+class AliSQLConfigDict(TypedDict):
+    """These keys will be directly used as kwargs in alisql connection string,
+    so the names must match exactly alisql API"""
+
+    user: str
+    password: str
+    host: str
+    port: int
+
+
+class AliSQLConfig(DBConfig):
+    user_name: str = "root"
+    password: SecretStr
+    host: str = "127.0.0.1"
+    port: int = 3306
+
+    def to_dict(self) -> AliSQLConfigDict:
+        pwd_str = self.password.get_secret_value()
+        return {
+            "host": self.host,
+            "port": self.port,
+            "user": self.user_name,
+            "password": pwd_str,
+        }
+
+
+class AliSQLIndexConfig(BaseModel):
+    """Base config for AliSQL"""
+
+    metric_type: MetricType | None = None
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return "euclidean"
+        if self.metric_type == MetricType.COSINE:
+            return "cosine"
+        msg = f"Metric type {self.metric_type} is not supported!"
+        raise ValueError(msg)
+
+
+class AliSQLHNSWConfig(AliSQLIndexConfig, DBCaseConfig):
+    M: int | None
+    ef_search: int | None
+    index: IndexType = IndexType.HNSW
+    cache_size: int | None
+
+    def index_param(self) -> dict:
+        return {
+            "metric_type": self.parse_metric(),
+            "index_type": self.index.value,
+            "M": self.M,
+            "cache_size": self.cache_size,
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "metric_type": self.parse_metric(),
+            "ef_search": self.ef_search,
+        }
+
+
+_alisql_case_config = {
+    IndexType.HNSW: AliSQLHNSWConfig,
+}
diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py
index 3d8309fa6..fb4ca7dac 100644
--- a/vectordb_bench/cli/vectordbbench.py
+++ b/vectordb_bench/cli/vectordbbench.py
@@ -1,3 +1,4 @@
+from ..backend.clients.alisql.cli import AliSQLHNSW
 from ..backend.clients.alloydb.cli import AlloyDBScaNN
 from ..backend.clients.aws_opensearch.cli import AWSOpenSearch
 from ..backend.clients.clickhouse.cli import Clickhouse
@@ -52,6 +53,7 @@
 cli.add_command(BatchCli)
 cli.add_command(S3Vectors)
 cli.add_command(TencentElasticsearch)
+cli.add_command(AliSQLHNSW)
 
 
 if __name__ == "__main__":
diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py
index 4ddc07b0a..003c84fe6 100644
--- a/vectordb_bench/frontend/config/dbCaseConfigs.py
+++ b/vectordb_bench/frontend/config/dbCaseConfigs.py
@@ -1550,6 +1550,53 @@ class CaseConfigInput(BaseModel):
     },
 )
 
+CaseConfigParamInput_M_AliSQL = CaseConfigInput(
+    label=CaseConfigParamType.M,
+    inputHelp="M parameter in HNSW vector indexing",
+    inputType=InputType.Number,
+    inputConfig={
+        "min": 3,
+        "max": 200,
+        "value": 6,
+    },
+    isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value,
+)
+
+CaseConfigParamInput_EFSearch_AliSQL = CaseConfigInput(
+    label=CaseConfigParamType.ef_search,
+    inputHelp="vidx_hnsw_ef_search",
+    inputType=InputType.Number,
+    inputConfig={
+        "min": 1,
+        "max": 10000,
+        "value": 20,
+    },
+    isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value,
+)
+
+CaseConfigParamInput_CacheSize_AliSQL = CaseConfigInput(
+    label=CaseConfigParamType.cache_size,
+    inputHelp="vidx_hnsw_cache_size",
+    inputType=InputType.Number,
+    inputConfig={
+        "min": 1048576,
+        "max": (1 << 53) - 1,
+        "value": 16 * 1024**3,
+    },
+    isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value,
+)
+
+CaseConfigParamInput_IndexType_AliSQL = CaseConfigInput(
+    label=CaseConfigParamType.IndexType,
+    inputHelp="Select Index Type",
+    inputType=InputType.Option,
+    inputConfig={
+        "options": [
+            IndexType.HNSW.value,
+        ],
+    },
+)
+
 CaseConfigParamInput_M_MariaDB = CaseConfigInput(
     label=CaseConfigParamType.M,
     inputHelp="M parameter in MHNSW vector indexing",
@@ -2026,6 +2073,18 @@
 ]
 VespaPerformanceConfig = VespaLoadingConfig
 
+AliSQLLoadingConfig = [
+    CaseConfigParamInput_IndexType_AliSQL,
+    CaseConfigParamInput_M_AliSQL,
+    CaseConfigParamInput_CacheSize_AliSQL,
+]
+AliSQLPerformanceConfig = [
+    CaseConfigParamInput_IndexType_AliSQL,
+    CaseConfigParamInput_M_AliSQL,
+    CaseConfigParamInput_CacheSize_AliSQL,
+    CaseConfigParamInput_EFSearch_AliSQL,
+]
+
 CaseConfigParamInput_IndexType_LanceDB = CaseConfigInput(
     label=CaseConfigParamType.IndexType,
     inputHelp="AUTOINDEX = IVFPQ with default parameters",
@@ -2250,6 +2309,10 @@
         CaseLabel.Load: TencentElasticsearchLoadingConfig,
         CaseLabel.Performance: TencentElasticsearchPerformanceConfig,
     },
+    DB.AliSQL: {
+        CaseLabel.Load: AliSQLLoadingConfig,
+        CaseLabel.Performance: AliSQLPerformanceConfig,
+    },
 }
 
 