Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 60 additions & 62 deletions vectordb_bench/backend/clients/chroma/chroma.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import logging
from contextlib import contextmanager
from typing import Any

import chromadb

from ..api import DBCaseConfig, VectorDB
from ..api import VectorDB

log = logging.getLogger(__name__)

Expand All @@ -16,104 +15,103 @@ class ChromaClient(VectorDB):

To change to running in process, modify the HttpClient() in __init__() and init().
"""

def __init__(
self,
dim: int,
db_config: dict,
db_case_config: DBCaseConfig,
db_case_config,
collection_name: str = "VectorDBBenchCollection",
drop_old: bool = False,
**kwargs,
**kwargs
):
self.db_config = db_config
self.case_config = db_case_config
self.collection_name = "example2"
self.collection_name = collection_name

client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"])
client = chromadb.HttpClient(**db_config)
assert client.heartbeat() is not None

if drop_old:
try:
client.reset() # Reset the database
client.reset()
except Exception:
drop_old = False
log.info(f"Chroma client drop_old collection: {self.collection_name}")
log.info("Chroma client drop_old collection: "
+ f"{self.collection_name}")

@contextmanager
def init(self) -> None:
"""create and destory connections to database.

Examples:
>>> with self.init():
>>> self.insert_embeddings()
"""
# create connection
self.client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"])

self.collection = self.client.get_or_create_collection("example2")
yield
self.client = None
self.collection = None

@contextmanager
def init(self):
try:
self.client = chromadb.HttpClient(
host=self.db_config.get("host", "localhost"),
port=self.db_config.get("port", 8000)
)

self.collection = self.client.get_or_create_collection(
name=self.collection_name,
configuration=self.case_config.index_param()
)
yield
self.client = None
self.collection = None
except Exception as e:
log.error(f"Failed to initialize Chroma client: {e}")
raise e

def ready_to_search(self) -> bool:
pass

def optimize(self, data_size: int | None = None):
pass
assert self.collection is not None, "Please call self.init() before"
try:
self.collection.modify(
configuration=self.case_config.search_param()
)
except Exception as e:
log.warning(f"Optimize error: {e}")
raise e

def insert_embeddings(
self,
embeddings: list[list[float]],
metadata: list[int],
**kwargs: Any,
**kwargs,
) -> tuple[int, Exception]:
"""Insert embeddings into the database.

Args:
embeddings(list[list[float]]): list of embeddings
metadata(list[int]): list of metadata
kwargs: other arguments

Returns:
tuple[int, Exception]: number of embeddings inserted and exception if any
"""
ids = [str(i) for i in metadata]
metadata = [{"id": int(i)} for i in metadata]
assert self.collection is not None, "Please call self.init() before"
ids = [f"{idx}" for idx in metadata]
metadata = [{"index": mid} for mid in metadata]
try:
if len(embeddings) > 0:
self.collection.add(embeddings=embeddings, ids=ids, metadatas=metadata)
self.collection.add(
ids=ids,
embeddings=embeddings,
metadatas=metadata
)
except Exception as e:
log.warning(f"Failed to insert data: error: {e!s}")
log.info(f"Failed to insert data: {e}")
return 0, e
return len(embeddings), None

return len(metadata), None

def search_embedding(
self,
query: list[float],
k: int = 100,
filters: dict | None = None,
timeout: int | None = None,
**kwargs: Any,
) -> dict:
"""Search embeddings from the database.
Args:
embedding(list[float]): embedding to search
k(int): number of results to return
kwargs: other arguments

Returns:
Dict {ids: list[list[int]],
embedding: list[list[float]]
distance: list[list[float]]}
"""
timeout: int | None = None
) -> list[int]:
assert self.client is not None, "Please call self.init() before"
if filters:
# assumes benchmark test filters of format: {'metadata': '>=10000', 'id': 10000}
id_value = filters.get("id")
results = self.collection.query(
query_embeddings=query,
query_embeddings=[query],
n_results=k,
where={"id": {"$gt": id_value}},
where={"id": {"$gt": filters.get("id")}}
)
else:
results = self.collection.query(
query_embeddings=[query],
n_results=k
)
# return list of id's in results
return [int(i) for i in results.get("ids")[0]]
results = self.collection.query(query_embeddings=query, n_results=k)
return [int(i) for i in results.get("ids")[0]]
return [int(idx) for idx in results['ids'][0]]
55 changes: 55 additions & 0 deletions vectordb_bench/backend/clients/chroma/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Annotated, Unpack

import click
from pydantic import SecretStr

from vectordb_bench.backend.clients import DB
from vectordb_bench.cli.cli import (
CommonTypedDict,
cli,
click_parameter_decorators_from_typed_dict,
run,
)


DBTYPE = DB.Chroma


class ChromaTypeDict(CommonTypedDict):
host: Annotated[
str,
click.option("--host", type=str, help="Chroma host")
]
port: Annotated[
int,
click.option("--port", type=int, help="Chroma port", default=8000)
]
m: Annotated[
int,
click.option("--m", type=int, help="HNSW Maximum Neighbors", default=16)
]
ef_construct: Annotated[
int,
click.option("--ef-construct", type=int, help="HNSW efConstruct", default=100)
]
ef_search: Annotated[
int,
click.option("--ef-search", type=int, help="HNSW efSearch", default=100)
]


@cli.command()
@click_parameter_decorators_from_typed_dict(ChromaTypeDict)
def Chroma(**parameters: Unpack[ChromaTypeDict]):
from .config import ChromaConfig, ChromaIndexConfig
run(
db=DBTYPE,
db_config=ChromaConfig(host=SecretStr(parameters["host"]),
port=parameters["port"]),
db_case_config=ChromaIndexConfig(
m=parameters["m"],
ef_construct=parameters["ef_construct"],
ef_search=parameters["ef_search"],
),
**parameters,
)
42 changes: 36 additions & 6 deletions vectordb_bench/backend/clients/chroma/config.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,46 @@
from pydantic import SecretStr

from ..api import DBConfig
from ..api import DBConfig, DBCaseConfig, MetricType


class ChromaConfig(DBConfig):
password: SecretStr
host: SecretStr
port: int
host: SecretStr = "localhost"
port: int = 8000

def to_dict(self) -> dict:
return {
"host": self.host.get_secret_value(),
"port": self.port,
"password": self.password.get_secret_value(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just double-check, removing "password" from connection info is as expected?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"port": self.port
}


class ChromaIndexConfig(ChromaConfig, DBCaseConfig):
metric_type: MetricType = "cosine"
m: int = 16
ef_construct: int = 100
ef_search: int | None = 100

def parse_metric(self) -> str:
if self.metric_type == MetricType.L2:
return "l2"
elif self.metric_type == MetricType.IP:
return "ip"
elif self.metric_type == MetricType.COSINE:
return "cosine"
else:
raise ValueError(f"Unsupported metric type: {self.metric_type}")

def index_param(self):
return {
"hnsw": {
"space": self.parse_metric(),
"max_neighbors": self.m,
"ef_construction": self.ef_construct,
"ef_search": self.search_param().get("ef_search", 100),
}
}

def search_param(self) -> dict:
return {
"ef_search": self.ef_search
}
3 changes: 3 additions & 0 deletions vectordb_bench/cli/vectordbbench.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ..backend.clients.alloydb.cli import AlloyDBScaNN
from ..backend.clients.aws_opensearch.cli import AWSOpenSearch
from ..backend.clients.chroma.cli import Chroma
from ..backend.clients.clickhouse.cli import Clickhouse
from ..backend.clients.hologres.cli import HologresHGraph
from ..backend.clients.lancedb.cli import LanceDB
Expand All @@ -24,6 +25,7 @@
from .batch_cli import BatchCli
from .cli import cli


cli.add_command(PgVectorHNSW)
cli.add_command(PgVectoRSHNSW)
cli.add_command(PgVectoRSIVFFlat)
Expand All @@ -50,6 +52,7 @@
cli.add_command(QdrantLocal)
cli.add_command(BatchCli)
cli.add_command(S3Vectors)
cli.add_command(Chroma)


if __name__ == "__main__":
Expand Down
Loading