diff --git a/adala/memories/__init__.py b/adala/memories/__init__.py
index 4e8c8242..08254075 100644
--- a/adala/memories/__init__.py
+++ b/adala/memories/__init__.py
@@ -1,3 +1,4 @@
 from .file_memory import FileMemory
 from .vectordb import VectorDBMemory
+from .qdrant_memory import QdrantMemory
 from .base import Memory
diff --git a/adala/memories/qdrant_memory.py b/adala/memories/qdrant_memory.py
new file mode 100644
index 00000000..7fcac4ee
--- /dev/null
+++ b/adala/memories/qdrant_memory.py
@@ -0,0 +1,263 @@
+"""Qdrant-backed implementation of the Adala ``Memory`` interface."""
+
+from typing import Any, List, Dict, Optional
+import uuid
+from pydantic import Field, model_validator
+
+from .base import Memory
+
+try:
+    from qdrant_client import QdrantClient
+    from qdrant_client.models import Distance, VectorParams, PointStruct
+    import openai
+
+    QDRANT_AVAILABLE = True
+except ImportError:
+    # The class body below uses these names, so bind placeholders to keep this
+    # module (and ``adala.memories``) importable without the dependencies.
+    # ``QdrantMemory.init_database`` then raises a descriptive ImportError.
+    QdrantClient = Any
+    openai = None
+    QDRANT_AVAILABLE = False
+
+# Largest number of inputs sent to the OpenAI embeddings endpoint per request.
+_MAX_EMBEDDING_BATCH = 2048
+
+# Payload keys of a stored point: the observation text and the caller's data.
+_TEXT_KEY = "text"
+_METADATA_KEY = "metadata"
+
+
+class QdrantMemory(Memory):
+    """
+    Memory backed by [Qdrant](https://qdrant.tech/).
+
+    Observations are embedded with an OpenAI embedding model and stored as
+    points in one Qdrant collection, which is created during validation if it
+    does not exist. Point ids are derived deterministically from the
+    observation text, so remembering the same observation again overwrites
+    the data stored for it.
+
+    Connect either with a pre-configured `qdrant_client` or with `qdrant_url`
+    (optionally plus `qdrant_api_key`), but not both.
+
+    `dimension` must match the vector size produced by
+    `openai_embedding_model`.
+    """
+
+    model_config = {"arbitrary_types_allowed": True}
+
+    collection_name: str = Field(..., description="Name of the Qdrant collection")
+    openai_api_key: str = Field(
+        ..., description="OpenAI API key for embeddings", repr=False
+    )
+    openai_embedding_model: str = Field(
+        default="text-embedding-3-small", description="OpenAI embedding model"
+    )
+    qdrant_url: Optional[str] = Field(
+        default=None, description="Qdrant server URL"
+    )
+    qdrant_api_key: Optional[str] = Field(
+        default=None, description="Qdrant API key for remote instances", repr=False
+    )
+    qdrant_client: Optional[QdrantClient] = Field(
+        default=None, description="Pre-configured QdrantClient instance"
+    )
+    dimension: int = Field(default=1536, description="Vector dimension size")
+    distance_metric: str = Field(
+        default="Cosine", description="Distance metric for similarity search"
+    )
+
+    _client: Optional[QdrantClient] = None
+    # Annotated as Any because ``openai`` may be missing at import time.
+    _openai_client: Any = None
+
+    @model_validator(mode="after")
+    def init_database(self):
+        """
+        Validate the configuration, create the Qdrant and OpenAI clients, and
+        create the collection if it does not exist yet.
+
+        Raises:
+            ImportError: If `qdrant-client` or `openai` is not installed.
+            ValueError: If the Qdrant settings are missing or conflicting,
+                `distance_metric` is unknown, or the OpenAI API key is empty.
+        """
+        if not QDRANT_AVAILABLE:
+            raise ImportError(
+                "Qdrant dependencies not available. "
+                "Please install with: pip install qdrant-client openai"
+            )
+
+        if self.qdrant_client is not None and (
+            self.qdrant_url is not None or self.qdrant_api_key is not None
+        ):
+            raise ValueError(
+                "Cannot specify both 'qdrant_client' and 'qdrant_url'/'qdrant_api_key'. "
+                "Use either a pre-configured QdrantClient or URL-based configuration, not both."
+            )
+
+        # Reject an unknown metric up front, even if the collection already exists.
+        self._get_distance_metric()
+
+        if self.qdrant_client is not None:
+            self._client = self.qdrant_client
+        elif self.qdrant_url:
+            self._client = QdrantClient(
+                url=self.qdrant_url, api_key=self.qdrant_api_key
+            )
+        else:
+            raise ValueError(
+                "No Qdrant configuration provided. Please specify either 'qdrant_client' "
+                "or 'qdrant_url' to configure the Qdrant connection."
+            )
+
+        if not self.openai_api_key:
+            raise ValueError("OpenAI API key is required but not provided")
+        self._openai_client = openai.OpenAI(api_key=self.openai_api_key)
+
+        if not self._client.collection_exists(self.collection_name):
+            self._create_collection()
+
+        return self
+
+    def _create_collection(self) -> None:
+        """Create the collection with the configured vector size and metric."""
+        self._client.create_collection(
+            collection_name=self.collection_name,
+            vectors_config=VectorParams(
+                size=self.dimension, distance=self._get_distance_metric()
+            ),
+        )
+
+    def _generate_uuid(self, string: str) -> str:
+        """Return a deterministic point id (UUID hex) derived from `string`."""
+        return uuid.uuid5(uuid.NAMESPACE_URL, string).hex
+
+    def _get_distance_metric(self) -> "Distance":
+        """
+        Map `distance_metric` (case-insensitive) to a Qdrant `Distance`.
+
+        Raises:
+            ValueError: If the metric name is not supported.
+        """
+        distance_map = {
+            "cosine": Distance.COSINE,
+            "dot": Distance.DOT,
+            "euclid": Distance.EUCLID,
+            "euclidean": Distance.EUCLID,
+            "manhattan": Distance.MANHATTAN,
+        }
+        key = self.distance_metric.strip().lower()
+        if key not in distance_map:
+            raise ValueError(
+                f"Unsupported distance metric '{self.distance_metric}'. "
+                "Expected one of: Cosine, Dot, Euclidean, Manhattan."
+            )
+        return distance_map[key]
+
+    def _get_embedding(self, text: str) -> List[float]:
+        """Return the embedding vector of a single text."""
+        return self._get_embeddings([text])[0]
+
+    def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """
+        Return one embedding per text, in input order.
+
+        Requests are split into chunks of `_MAX_EMBEDDING_BATCH` inputs; an
+        empty `texts` list returns `[]` without calling the API.
+        """
+        embeddings: List[List[float]] = []
+        for start in range(0, len(texts), _MAX_EMBEDDING_BATCH):
+            response = self._openai_client.embeddings.create(
+                model=self.openai_embedding_model,
+                input=texts[start : start + _MAX_EMBEDDING_BATCH],
+            )
+            # Each item carries the index of its input; sort so that vectors
+            # line up with `texts` regardless of the response order.
+            ordered = sorted(response.data, key=lambda item: item.index)
+            embeddings.extend(item.embedding for item in ordered)
+        return embeddings
+
+    def remember(self, observation: str, data: Any):
+        """Store a single observation with its associated data."""
+        self.remember_many([observation], [data])
+
+    def remember_many(self, observations: List[str], data: List[Dict]):
+        """
+        Store multiple observations with their associated data.
+
+        `data[i]` belongs to `observations[i]`. Entries whose value is `None`
+        are dropped, and a `None` item is stored as an empty dict. The data is
+        kept under its own payload key so that it cannot clash with the stored
+        observation text (for example when the data has a "text" field).
+
+        Raises:
+            ValueError: If `observations` and `data` differ in length.
+        """
+        if len(observations) != len(data):
+            raise ValueError(
+                f"Got {len(observations)} observations but {len(data)} data items"
+            )
+        if not observations:
+            return
+
+        embeddings = self._get_embeddings(observations)
+
+        points = [
+            PointStruct(
+                id=self._generate_uuid(obs),
+                vector=embedding,
+                payload={
+                    _TEXT_KEY: obs,
+                    _METADATA_KEY: {
+                        k: v for k, v in (item or {}).items() if v is not None
+                    },
+                },
+            )
+            for obs, embedding, item in zip(observations, embeddings, data)
+        ]
+
+        self._client.upsert(collection_name=self.collection_name, points=points)
+
+    def retrieve_many(self, observations: List[str], num_results: int = 1) -> List[Any]:
+        """
+        Retrieve similar observations for multiple queries.
+
+        Returns one list per query, holding the data dicts of up to
+        `num_results` nearest stored observations.
+        """
+        if not observations:
+            return []
+
+        # Embed all queries at once instead of one API call per query.
+        query_embeddings = self._get_embeddings(observations)
+
+        results = []
+        for query_embedding in query_embeddings:
+            search_results = self._client.query_points(
+                collection_name=self.collection_name,
+                query=query_embedding,
+                limit=num_results,
+                with_payload=True,
+            ).points
+
+            results.append(
+                [
+                    dict((result.payload or {}).get(_METADATA_KEY) or {})
+                    for result in search_results
+                ]
+            )
+
+        return results
+
+    def retrieve(self, observation: str, num_results: int = 1) -> Any:
+        """Retrieve similar observations for a single query."""
+        return self.retrieve_many([observation], num_results=num_results)[0]
+
+    def clear(self):
+        """Clear all data by dropping and recreating the collection."""
+        if self._client.collection_exists(self.collection_name):
+            self._client.delete_collection(self.collection_name)
+
+        self._create_collection()
diff --git a/poetry.lock b/poetry.lock
index 4944a5bc..91bd3474 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically
@generated by Poetry 2.2.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -1843,7 +1843,6 @@ files = [ {file = "fastuuid-0.12.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b9b31dd488d0778c36f8279b306dc92a42f16904cba54acca71e107d65b60b0c"}, {file = "fastuuid-0.12.0-cp313-cp313-manylinux_2_34_x86_64.whl", hash = "sha256:b19361ee649365eefc717ec08005972d3d1eb9ee39908022d98e3bfa9da59e37"}, {file = "fastuuid-0.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:8fc66b11423e6f3e1937385f655bedd67aebe56a3dcec0cb835351cfe7d358c9"}, - {file = "fastuuid-0.12.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:2925f67b88d47cb16aa3eb1ab20fdcf21b94d74490e0818c91ea41434b987493"}, {file = "fastuuid-0.12.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7b15c54d300279ab20a9cc0579ada9c9f80d1bc92997fc61fb7bf3103d7cb26b"}, {file = "fastuuid-0.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:458f1bc3ebbd76fdb89ad83e6b81ccd3b2a99fa6707cd3650b27606745cfb170"}, {file = "fastuuid-0.12.0-cp38-cp38-manylinux_2_34_x86_64.whl", hash = "sha256:a8f0f83fbba6dc44271a11b22e15838641b8c45612cdf541b4822a5930f6893c"}, @@ -2406,6 +2405,22 @@ files = [ {file = "h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1"}, ] +[[package]] +name = "h2" +version = "4.3.0" +description = "Pure-Python HTTP/2 protocol implementation" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "h2-4.3.0-py3-none-any.whl", hash = "sha256:c438f029a25f7945c69e0ccf0fb951dc3f73a5f6412981daee861431b70e2bdd"}, + {file = "h2-4.3.0.tar.gz", hash = "sha256:6c59efe4323fa18b47a632221a1888bd7fde6249819beda254aeca909f221bf1"}, +] + +[package.dependencies] +hpack = ">=4.1,<5" +hyperframe = ">=6.1,<7" + [[package]] name = "hiredis" version = "3.1.0" @@ -2525,6 +2540,18 @@ files = [ 
{file = "hiredis-3.1.0.tar.gz", hash = "sha256:51d40ac3611091020d7dea6b05ed62cb152bff595fa4f931e7b6479d777acf7c"}, ] +[[package]] +name = "hpack" +version = "4.1.0" +description = "Pure-Python HPACK header encoding" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "hpack-4.1.0-py3-none-any.whl", hash = "sha256:157ac792668d995c657d93111f46b4535ed114f0c9c8d672271bbec7eae1b496"}, + {file = "hpack-4.1.0.tar.gz", hash = "sha256:ec5eca154f7056aa06f196a557655c5b009b382873ac8d1e66e79e87535f1dca"}, +] + [[package]] name = "httpcore" version = "1.0.9" @@ -2618,6 +2645,7 @@ files = [ [package.dependencies] anyio = "*" certifi = "*" +h2 = {version = ">=3,<5", optional = true, markers = "extra == \"http2\""} httpcore = "==1.*" idna = "*" @@ -2694,6 +2722,18 @@ files = [ [package.extras] tests = ["freezegun", "pytest", "pytest-cov"] +[[package]] +name = "hyperframe" +version = "6.1.0" +description = "Pure-Python HTTP/2 framing" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5"}, + {file = "hyperframe-6.1.0.tar.gz", hash = "sha256:f630908a00854a7adeabd6382b43923a4c4cd4b821fcb527e6ab9e15382a3b08"}, +] + [[package]] name = "idna" version = "3.10" @@ -5692,6 +5732,26 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "portalocker" +version = "3.2.0" +description = "Wraps the portalocker recipe for easy usage" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "portalocker-3.2.0-py3-none-any.whl", hash = "sha256:3cdc5f565312224bc570c49337bd21428bba0ef363bbcf58b9ef4a9f11779968"}, + {file = "portalocker-3.2.0.tar.gz", hash = "sha256:1f3002956a54a8c3730586c5c77bf18fae4149e07eaf1c29fc3faf4d5a3f89ac"}, +] + +[package.dependencies] +pywin32 = {version = ">=226", markers = "platform_system == \"Windows\""} + 
+[package.extras] +docs = ["portalocker[tests]"] +redis = ["redis"] +tests = ["coverage-conditional-plugin (>=0.9.0)", "portalocker[redis]", "pytest (>=5.4.1)", "pytest-cov (>=2.8.1)", "pytest-mypy (>=0.8.0)", "pytest-rerunfailures (>=15.0)", "pytest-timeout (>=2.1.0)", "sphinx (>=6.0.0)", "types-pywin32 (>=310.0.0.20250429)", "types-redis"] + [[package]] name = "posthog" version = "4.0.0" @@ -6582,8 +6642,7 @@ version = "310" description = "Python for Window Extensions" optional = false python-versions = "*" -groups = ["dev"] -markers = "sys_platform == \"win32\"" +groups = ["main", "dev"] files = [ {file = "pywin32-310-cp310-cp310-win32.whl", hash = "sha256:6dd97011efc8bf51d6793a82292419eba2c71cf8e7250cfac03bba284454abc1"}, {file = "pywin32-310-cp310-cp310-win_amd64.whl", hash = "sha256:c3e78706e4229b915a0821941a84e7ef420bf2b77e08c9dae3c76fd03fd2ae3d"}, @@ -6602,6 +6661,7 @@ files = [ {file = "pywin32-310-cp39-cp39-win32.whl", hash = "sha256:851c8d927af0d879221e616ae1f66145253537bbdd321a77e8ef701b443a9a1a"}, {file = "pywin32-310-cp39-cp39-win_amd64.whl", hash = "sha256:96867217335559ac619f00ad70e513c0fcf84b8a3af9fc2bba3b59b97da70475"}, ] +markers = {main = "platform_system == \"Windows\"", dev = "sys_platform == \"win32\""} [[package]] name = "pywinpty" @@ -6805,6 +6865,35 @@ files = [ [package.dependencies] cffi = {version = "*", markers = "implementation_name == \"pypy\""} +[[package]] +name = "qdrant-client" +version = "1.15.1" +description = "Client library for the Qdrant vector search engine" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "qdrant_client-1.15.1-py3-none-any.whl", hash = "sha256:2b975099b378382f6ca1cfb43f0d59e541be6e16a5892f282a4b8de7eff5cb63"}, + {file = "qdrant_client-1.15.1.tar.gz", hash = "sha256:631f1f3caebfad0fd0c1fba98f41be81d9962b7bf3ca653bed3b727c0e0cbe0e"}, +] + +[package.dependencies] +grpcio = ">=1.41.0" +httpx = {version = ">=0.20.0", extras = ["http2"]} +numpy = [ + {version = ">=1.26", 
markers = "python_version == \"3.12\""}, + {version = ">=2.1.0", markers = "python_version >= \"3.13\""}, + {version = ">=1.21", markers = "python_version >= \"3.10\" and python_version < \"3.12\""}, +] +portalocker = ">=2.7.0,<4.0" +protobuf = ">=3.20.0" +pydantic = ">=1.10.8,<2.0.dev0 || >2.2.0" +urllib3 = ">=1.26.14,<3" + +[package.extras] +fastembed = ["fastembed (>=0.7,<0.8)"] +fastembed-gpu = ["fastembed-gpu (>=0.7,<0.8)"] + [[package]] name = "redis" version = "5.2.1" @@ -8837,4 +8926,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4" -content-hash = "30d0b0339cb30f6bd6488afcd45cd29c857c10484fd107d66fb78598be96c7ac" +content-hash = "67bb520752d4d2619d411f97a1d2e20e01b45f658186764d3e4de20b9153657a" diff --git a/pyproject.toml b/pyproject.toml index 9200e583..09a1d32f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,8 @@ dependencies = [ "pandarallel (>=1.6.5,<2.0.0)", "instructor (==1.4.3)", "async-lru (>=2.0.5,<3.0.0)", - "jinja2 (>=3.1.6,<4.0)" + "jinja2 (>=3.1.6,<4.0)", + "qdrant-client (>=1.15.1,<2.0.0)" ] [project.urls]