Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ All the database client supported
| oceanbase | `pip install vectordb-bench[oceanbase]` |
| hologres | `pip install vectordb-bench[hologres]` |
| tencent_es | `pip install vectordb-bench[tencent_es]` |
| alisql | `pip install vectordb-bench[alisql]` |

### Run

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ clickhouse = [ "clickhouse-connect" ]
vespa = [ "pyvespa" ]
lancedb = [ "lancedb" ]
oceanbase = [ "mysql-connector-python" ]
alisql = [ "mysql-connector-python" ]

[project.urls]
"repository" = "https://github.com/zilliztech/VectorDBBench"
Expand Down
18 changes: 17 additions & 1 deletion vectordb_bench/backend/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class DB(Enum):
S3Vectors = "S3Vectors"
Hologres = "Alibaba Cloud Hologres"
TencentElasticsearch = "TencentElasticsearch"
AliSQL = "AlibabaCloudRDSMySQL"

@property
def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915
Expand Down Expand Up @@ -206,6 +207,11 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915

return TencentElasticsearch

if self == DB.AliSQL:
from .alisql.alisql import AliSQL

return AliSQL

msg = f"Unknown DB: {self.name}"
raise ValueError(msg)

Expand Down Expand Up @@ -362,10 +368,15 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901, PLR0915

return TencentElasticsearchConfig

if self == DB.AliSQL:
from .alisql.config import AliSQLConfig

return AliSQLConfig

msg = f"Unknown DB: {self.name}"
raise ValueError(msg)

def case_config_cls( # noqa: C901, PLR0911, PLR0912
def case_config_cls( # noqa: C901, PLR0911, PLR0912, PLR0915
self,
index_type: IndexType | None = None,
) -> type[DBCaseConfig]:
Expand Down Expand Up @@ -493,6 +504,11 @@ def case_config_cls( # noqa: C901, PLR0911, PLR0912

return TencentElasticsearchIndexConfig

if self == DB.AliSQL:
from .alisql.alisql import AliSQLIndexConfig

return AliSQLIndexConfig

# DB.Pinecone, DB.Chroma, DB.Redis
return EmptyDBCaseConfig

Expand Down
215 changes: 215 additions & 0 deletions vectordb_bench/backend/clients/alisql/alisql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
import logging
from contextlib import contextmanager

import mysql.connector as mysql
import numpy as np

from ..api import VectorDB
from .config import AliSQLConfigDict, AliSQLIndexConfig

log = logging.getLogger(__name__)


class AliSQL(VectorDB):
    """VectorDB benchmark client for Alibaba Cloud RDS MySQL (AliSQL).

    Talks to the server via ``mysql-connector-python`` and relies on the
    AliSQL vector extension: the ``VECTOR(n)`` column type, ``ADD VECTOR KEY``
    index DDL, and the ``vec_distance_*`` SQL functions. A fresh connection is
    opened inside ``init()`` for each benchmark phase and torn down when the
    context manager exits.
    """

    def __init__(
        self,
        dim: int,
        db_config: AliSQLConfigDict,
        db_case_config: AliSQLIndexConfig,
        collection_name: str = "vec_collection",
        drop_old: bool = False,
        **kwargs,
    ):
        """Record schema state and optionally recreate the benchmark table.

        Args:
            dim: dimensionality of the vectors to store.
            db_config: connection settings (host, port, user, password).
            db_case_config: index build / search parameters for this case.
            collection_name: name of the table holding the vectors.
            drop_old: when True, drop the benchmark database and recreate
                the vector table from scratch.
        """
        self.name = "AliSQL"
        self.db_config = db_config
        self.case_config = db_case_config
        self.db_name = "vectordbbench"
        self.table_name = collection_name
        self.dim = dim

        # Short-lived connection used only for schema management; the
        # per-phase connection is (re)created in init().
        self.conn, self.cursor = self._create_connection()

        if drop_old:
            self._drop_db()
            self._create_db_table(dim)

        self.cursor.close()
        self.conn.close()
        self.cursor = None
        self.conn = None

    def _create_connection(self):
        """Open a new buffered connection and cursor; return ``(conn, cursor)``."""
        conn = mysql.connect(
            host=self.db_config["host"],
            user=self.db_config["user"],
            port=self.db_config["port"],
            password=self.db_config["password"],
            buffered=True,
        )
        cursor = conn.cursor()

        assert conn is not None, "Connection is not initialized"
        assert cursor is not None, "Cursor is not initialized"

        return conn, cursor

    def _drop_db(self):
        """Drop the whole benchmark database if it exists."""
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"
        log.info(f"{self.name} client drop db : {self.db_name}")

        # flush tables before dropping database to avoid some locking issue
        self.cursor.execute("FLUSH TABLES")
        self.cursor.execute(f"DROP DATABASE IF EXISTS {self.db_name}")
        self.cursor.execute("COMMIT")
        self.cursor.execute("FLUSH TABLES")

    def _create_db_table(self, dim: int):
        """Create the benchmark database and the vector table.

        Args:
            dim: dimensionality used for the ``VECTOR`` column (note: the
                table DDL below uses ``self.dim``, which __init__ sets to
                the same value).

        Raises:
            Exception: re-raised when any DDL statement fails.
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        try:
            log.info(f"{self.name} client create database : {self.db_name}")
            self.cursor.execute(f"CREATE DATABASE {self.db_name}")

            log.info(f"{self.name} client create table : {self.table_name}")
            self.cursor.execute(f"USE {self.db_name}")

            self.cursor.execute(
                f"""
                CREATE TABLE {self.table_name} (
                    id INT PRIMARY KEY,
                    v VECTOR({self.dim}) NOT NULL
                )
                """
            )
            self.cursor.execute("COMMIT")

        except Exception as e:
            log.warning(f"Failed to create table: {self.table_name} error: {e}")
            raise e from None

    @contextmanager
    def init(self):
        """create and destory connections to database.

        Examples:
            >>> with self.init():
            >>>     self.insert_embeddings()
        """
        self.conn, self.cursor = self._create_connection()

        # Fix: the session-setup statements below used to run OUTSIDE the
        # try/finally, so a failing SET GLOBAL leaked the connection. The
        # whole post-connect body is now guarded.
        try:
            index_param = self.case_config.index_param()
            search_param = self.case_config.search_param()

            # maximize allowed package size
            self.cursor.execute("SET GLOBAL max_allowed_packet = 1073741824")

            if index_param["index_type"] == "HNSW":
                if index_param["cache_size"] is not None:
                    self.cursor.execute(f"SET GLOBAL vidx_hnsw_cache_size = {index_param['cache_size']}")
                if search_param["ef_search"] is not None:
                    self.cursor.execute(f"SET GLOBAL vidx_hnsw_ef_search = {search_param['ef_search']}")
            self.cursor.execute("COMMIT")

            # Pre-build the DML so the hot insert/search paths only bind
            # parameters; vector bytes and LIMIT are passed as placeholders.
            self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)"  # noqa: S608
            self.select_sql = (
                f"SELECT id FROM {self.db_name}.{self.table_name} "  # noqa: S608
                f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %s"
            )
            self.select_sql_with_filter = (
                f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %s "  # noqa: S608
                f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %s"
            )

            yield
        finally:
            self.cursor.close()
            self.conn.close()
            self.cursor = None
            self.conn = None

    def ready_to_load(self) -> None:
        """No-op: AliSQL needs no extra preparation before bulk loading.

        NOTE(review): the original annotated ``-> bool`` but always returned
        None; the annotation is corrected to match the actual behavior.
        """

    def optimize(self, data_size: int) -> None:
        """Build the vector index after the bulk load has finished.

        Args:
            data_size: total row count (unused here; part of the VectorDB API).

        Raises:
            Exception: re-raised when index creation fails.
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        index_param = self.case_config.index_param()

        try:
            # DISTANCE is always set; M only applies to HNSW indexes.
            index_options = f"DISTANCE={index_param['metric_type']}"
            if index_param["index_type"] == "HNSW" and index_param["M"] is not None:
                index_options += f" M={index_param['M']}"

            self.cursor.execute(
                f"""
                ALTER TABLE {self.db_name}.{self.table_name}
                ADD VECTOR KEY v(v) {index_options}
                """
            )
            self.cursor.execute("COMMIT")

        except Exception as e:
            log.warning(f"Failed to create index: {self.table_name} error: {e}")
            raise e from None

    @staticmethod
    def vector_to_hex(v):  # noqa: ANN001
        """Serialize a vector to the little-endian float32 byte string AliSQL expects.

        (Despite the name it returns raw bytes, not a hex string.)
        """
        return np.array(v, "float32").tobytes()

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs,
    ) -> tuple[int, Exception | None]:
        """Insert embeddings into the database.
        Should call self.init() first.

        Returns:
            ``(inserted_count, None)`` on success, ``(0, error)`` on failure.
            (Annotation fixed: the success path returns None as the error.)
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        try:
            # Build the parameter batch directly; no need to round-trip the
            # whole payload through intermediate numpy arrays first.
            batch_data = [(int(pk), self.vector_to_hex(vec)) for pk, vec in zip(metadata, embeddings)]

            self.cursor.executemany(self.insert_sql, batch_data)
            self.cursor.execute("COMMIT")
            self.cursor.execute("FLUSH TABLES")

            return len(metadata), None
        except Exception as e:
            log.warning(f"Failed to insert data into Vector table ({self.table_name}), error: {e}")
            return 0, e

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
        **kwargs,
    ) -> list[int]:
        """Return the ids of the k nearest neighbors of ``query``.

        Args:
            query: the query vector.
            k: number of neighbors to return.
            filters: optional scalar filter; ``filters["id"]`` restricts
                results to rows with ``id >= filters["id"]``.
            timeout: unused; part of the VectorDB API.
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        # (removed an unused `search_param` local that was silenced with noqa F841)
        try:
            if filters:
                self.cursor.execute(self.select_sql_with_filter, (filters.get("id"), self.vector_to_hex(query), k))
            else:
                self.cursor.execute(self.select_sql, (self.vector_to_hex(query), k))
            return [row[0] for row in self.cursor.fetchall()]

        except mysql.Error:
            log.exception("Failed to execute search query")
            raise
111 changes: 111 additions & 0 deletions vectordb_bench/backend/clients/alisql/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from typing import Annotated, Unpack

import click
from pydantic import SecretStr

from vectordb_bench.backend.clients import DB

from ....cli.cli import (
CommonTypedDict,
cli,
click_parameter_decorators_from_typed_dict,
run,
)


class AliSQLTypedDict(CommonTypedDict):
    """CLI options shared by all AliSQL commands: credentials and endpoint.

    Each field pairs a type with a ``click.option`` via ``Annotated``;
    ``click_parameter_decorators_from_typed_dict`` turns these into CLI flags.
    """

    # Note: the flag is --username, so the value lands in parameters["username"].
    user_name: Annotated[
        str,
        click.option(
            "--username",
            type=str,
            help="Username",
            required=True,
        ),
    ]
    password: Annotated[
        str,
        click.option(
            "--password",
            type=str,
            help="Password",
            required=True,
        ),
    ]

    host: Annotated[
        str,
        click.option(
            "--host",
            type=str,
            help="Db host",
            default="127.0.0.1",
        ),
    ]

    # Default 3306 matches the standard MySQL port.
    port: Annotated[
        int,
        click.option(
            "--port",
            type=int,
            default=3306,
            help="DB Port",
        ),
    ]


class AliSQLHNSWTypedDict(AliSQLTypedDict):
    """CLI options for the HNSW index variant, extending the base options.

    All three parameters are optional; unset values are passed through as
    None and the server-side defaults apply.
    """

    m: Annotated[
        int | None,
        click.option(
            "--m",
            type=int,
            help="M parameter in HNSW vector indexing",
            required=False,
        ),
    ]

    ef_search: Annotated[
        int | None,
        click.option(
            "--ef-search",
            type=int,
            help="AliSQL system variable vidx_hnsw_ef_search",
            required=False,
        ),
    ]

    cache_size: Annotated[
        int | None,
        click.option(
            "--cache-size",
            type=int,
            help="AliSQL system variable vidx_hnsw_cache_size",
            required=False,
        ),
    ]


@cli.command()
@click_parameter_decorators_from_typed_dict(AliSQLHNSWTypedDict)
def AliSQLHNSW(
    **parameters: Unpack[AliSQLHNSWTypedDict],
):
    """Benchmark AliSQL with an HNSW vector index using the parsed CLI options."""
    # Imported lazily so the CLI module loads even without the alisql extra.
    from .config import AliSQLConfig, AliSQLHNSWConfig

    connection_config = AliSQLConfig(
        db_label=parameters["db_label"],
        user_name=parameters["username"],
        password=SecretStr(parameters["password"]),
        host=parameters["host"],
        port=parameters["port"],
    )
    hnsw_case_config = AliSQLHNSWConfig(
        M=parameters["m"],
        ef_search=parameters["ef_search"],
        cache_size=parameters["cache_size"],
    )

    run(
        db=DB.AliSQL,
        db_config=connection_config,
        db_case_config=hnsw_case_config,
        **parameters,
    )
Loading