Skip to content

Commit ae7dd89

Browse files
committed
rebase
2 parents 2826107 + e6b369a commit ae7dd89

11 files changed

Lines changed: 667 additions & 7 deletions

File tree

adalflow/adalflow/components/retriever/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,18 @@
2222
OptionalPackages.SQLALCHEMY,
2323
)
2424

25+
QdrantRetriever = LazyImport(
26+
"adalflow.components.retriever.qdrant_retriever.QdrantRetriever",
27+
OptionalPackages.QDRANT,
28+
)
29+
2530
__all__ = [
2631
"BM25Retriever",
2732
"LLMRetriever",
2833
"FAISSRetriever",
2934
"RerankerRetriever",
3035
"PostgresRetriever",
36+
"QdrantRetriever",
3137
"split_text_by_word_fn",
3238
"split_text_by_word_fn_then_lower_tokenized",
3339
]
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
"""Leverage a Qdrant collection to retrieve documents."""
2+
3+
from typing import List, Optional, Any
4+
from qdrant_client import QdrantClient, models
5+
6+
from adalflow.core.retriever import (
7+
Retriever,
8+
)
9+
from adalflow.core.embedder import Embedder
10+
11+
from adalflow.core.types import (
12+
RetrieverOutput,
13+
RetrieverStrQueryType,
14+
RetrieverStrQueriesType,
15+
Document,
16+
)
17+
18+
19+
class QdrantRetriever(Retriever[Any, RetrieverStrQueryType]):
    __doc__ = r"""Use a Qdrant collection to retrieve documents.

    Every query string is embedded with ``embedder`` and sent to Qdrant as one
    request of a single ``query_batch_points`` call. Each returned point is
    converted into a ``Document`` built from the point's payload and vector.

    Args:
        collection_name (str): the collection name in Qdrant.
        client (QdrantClient): An instance of qdrant_client.QdrantClient.
        embedder (Embedder): An instance of Embedder.
        top_k (Optional[int], optional): top k documents to fetch. Defaults to 10.
        vector_name (Optional[str], optional): the name of the vector in the collection.
            Defaults to None, in which case the first vector name found in the
            collection config is used (None for the unnamed default vector).
        text_key (str, optional): the key in the payload that contains the text. Defaults to "text".
        metadata_key (str, optional): the key in the payload that contains the metadata. Defaults to "meta_data".
        filter (Optional[models.Filter], optional): the filter to apply to the query. Defaults to None.

    References:
        [1] Qdrant: https://qdrant.tech/
        [2] Documentation: https://qdrant.tech/documentation/
    """

    def __init__(
        self,
        collection_name: str,
        client: QdrantClient,
        embedder: Embedder,
        top_k: Optional[int] = 10,
        vector_name: Optional[str] = None,
        text_key: str = "text",
        metadata_key: str = "meta_data",
        filter: Optional[models.Filter] = None,
    ):
        super().__init__()
        self._top_k = top_k
        self._collection_name = collection_name
        self._client = client
        self._embedder = embedder
        self._text_key = text_key
        self._metadata_key = metadata_key
        self._filter = filter

        # Must run after ``_client`` and ``_collection_name`` are set: the
        # fallback lookup reads the collection config through the client.
        self._vector_name = vector_name or self._get_first_vector_name()

    def reset_index(self):
        """Delete the configured collection from Qdrant if it exists."""
        if self._client.collection_exists(self._collection_name):
            self._client.delete_collection(self._collection_name)

    def call(
        self,
        input: RetrieverStrQueriesType,
        top_k: Optional[int] = None,
        **kwargs,
    ) -> List[RetrieverOutput]:
        """Retrieve the closest points for one or more query strings.

        Args:
            input: a single query string or a list of query strings.
            top_k: number of points to request per query. A falsy value
                (``None`` or ``0``) falls back to the ``top_k`` given at
                construction time.
            **kwargs: forwarded unchanged to every ``models.QueryRequest``.

        Returns:
            One ``RetrieverOutput`` per response returned by the client, each
            carrying the query string that produced it.
        """
        top_k = top_k or self._top_k
        queries: List[str] = input if isinstance(input, list) else [input]

        queries_embeddings = self._embedder(queries)

        query_requests: List[models.QueryRequest] = [
            models.QueryRequest(
                query=queries_embeddings.data[idx].embedding,
                limit=top_k,
                using=self._vector_name,
                with_payload=True,
                with_vector=True,
                filter=self._filter,
                **kwargs,
            )
            for idx, _ in enumerate(queries)
        ]

        results = self._client.query_batch_points(
            self._collection_name, requests=query_requests
        )

        # Pair each response with the query at the same position. Reusing the
        # loop variable from the request-building step would label every
        # output with the last query of the batch.
        return [
            self._points_to_output(
                result.points,
                query,
                self._text_key,
                self._metadata_key,
                self._vector_name,
            )
            for query, result in zip(queries, results)
        ]

    def _get_first_vector_name(self) -> Optional[str]:
        """Return the first vector name of the collection, or None.

        None is returned when the collection config does not expose a dict of
        named vectors, when that dict is empty, or when the first name is the
        empty string.
        """
        vectors = self._client.get_collection(
            self._collection_name
        ).config.params.vectors

        if not isinstance(vectors, dict):
            # The collection only has the default, unnamed vector
            return None

        # An empty dict yields None instead of raising IndexError.
        first_vector_name = next(iter(vectors), None)

        # The collection has multiple vectors. Could also include the falsy unnamed vector - Empty string("")
        return first_vector_name or None

    @classmethod
    def _points_to_output(
        cls,
        points: List[models.ScoredPoint],
        query: str,
        text_key: str,
        metadata_key: str,
        vector_name: Optional[str],
    ) -> RetrieverOutput:
        """Convert the scored points of one query into a ``RetrieverOutput``.

        Point ids become ``doc_indices``, point scores become ``doc_scores``
        and each point is turned into a ``Document`` via ``_doc_from_point``.
        """
        doc_indices = [point.id for point in points]
        doc_scores = [point.score for point in points]
        documents = [
            cls._doc_from_point(point, text_key, metadata_key, vector_name)
            for point in points
        ]
        return RetrieverOutput(
            doc_indices=doc_indices,
            doc_scores=doc_scores,
            query=query,
            documents=documents,
        )

    @classmethod
    def _doc_from_point(
        cls,
        point: models.ScoredPoint,
        text_key: str,
        metadata_key: str,
        vector_name: Optional[str] = None,
    ) -> Document:
        """Build a ``Document`` from a single scored point.

        Args:
            point: the point returned by Qdrant.
            text_key: payload key holding the document text ("" if absent).
            metadata_key: payload key holding the metadata ({} if absent).
            vector_name: key used to pick the vector when the point carries a
                dict of named vectors.

        Raises:
            KeyError: if the point carries a dict of vectors that has no
                entry for ``vector_name``.
        """
        vector = point.vector
        if isinstance(vector, dict):
            vector = vector[vector_name]

        # A point may come back without a payload; treat that as an empty one
        # instead of failing on ``None``.
        payload = point.payload or {}
        return Document(
            id=point.id,
            text=payload.get(text_key, ""),
            meta_data=payload.get(metadata_key, {}),
            vector=vector,
        )

adalflow/adalflow/core/generator.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -810,7 +810,12 @@ async def acall(
810810
output = GeneratorOutput(raw_response=str(completion), error=str(e))
811811

812812
log.info(f"output: {output}")
813-
self._run_callbacks(output, input=api_kwargs)
813+
self._run_callbacks(
814+
output,
815+
input=api_kwargs,
816+
prompt_kwargs=prompt_kwargs,
817+
model_kwargs=model_kwargs,
818+
)
814819
return output
815820

816821
def __call__(self, *args, **kwargs) -> Union[GeneratorOutputType, Any]:

adalflow/adalflow/utils/lazy_import.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ class OptionalPackages(Enum):
5151
"datasets",
5252
"Please install datasets with: pip install datasets",
5353
)
54+
QDRANT = (
55+
"qdrant-client",
56+
"Please install qdrant-client with: pip install qdrant-client",
57+
)
5458

5559
def __init__(self, package_name, error_message):
5660
self.package_name = package_name
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import pytest
2+
from unittest.mock import MagicMock
3+
from adalflow.components.retriever import QdrantRetriever
4+
from adalflow.core.types import (
5+
RetrieverOutput,
6+
Document,
7+
)
8+
from adalflow.core.embedder import Embedder
9+
10+
qdrant_client = pytest.importorskip(
11+
"qdrant_client", reason="qdrant_client not installed"
12+
)
13+
14+
COLLECTION_NAME = "test_collection"
15+
16+
17+
@pytest.fixture
def mock_qdrant_client():
    """Provide a MagicMock restricted to the QdrantClient interface."""
    client = MagicMock(spec=qdrant_client.QdrantClient)
    return client
20+
21+
22+
@pytest.fixture
def qdrant_retriever(mock_qdrant_client):
    """Build a QdrantRetriever wired to the mocked client and a mocked Embedder."""
    embedder = MagicMock(spec=Embedder)
    retriever = QdrantRetriever(
        collection_name=COLLECTION_NAME,
        client=mock_qdrant_client,
        embedder=embedder,
        top_k=5,
    )
    return retriever
30+
31+
32+
def test_reset_index(qdrant_retriever, mock_qdrant_client):
    """reset_index deletes the collection when the client reports it exists."""
    client = mock_qdrant_client
    client.collection_exists.return_value = True

    qdrant_retriever.reset_index()

    client.delete_collection.assert_called_once_with(COLLECTION_NAME)
36+
37+
38+
def test_call_single_query(qdrant_retriever, mock_qdrant_client):
    """A single string query yields one RetrieverOutput built from the returned point."""
    query = "test query"

    # Stand-in for a scored point coming back from Qdrant.
    point = MagicMock()
    point.id = 1
    point.score = 0.9
    point.payload = {"text": "retrieved text", "meta_data": {"key": "value"}}
    point.vector = [0.1, 0.2, 0.3]

    response = MagicMock()
    response.points = [point]
    mock_qdrant_client.query_batch_points.return_value = [response]

    outputs = qdrant_retriever.call(query)

    assert isinstance(outputs, list)
    assert len(outputs) == 1

    output = outputs[0]
    assert isinstance(output, RetrieverOutput)
    assert output.query == query
    assert len(output.doc_indices) == 1
    assert output.doc_indices[0] == 1
    assert len(output.doc_scores) == 1
    assert output.doc_scores[0] == 0.9
    assert len(output.documents) == 1

    document = output.documents[0]
    assert isinstance(document, Document)
    assert document.text == "retrieved text"
    assert document.meta_data == {"key": "value"}
66+
67+
68+
def test_get_first_vector_name(qdrant_retriever, mock_qdrant_client):
    """_get_first_vector_name returns None for an unnamed vector, else the first name."""
    # Check single unnamed vector
    unnamed_info = MagicMock()
    unnamed_info.config.params.vectors = qdrant_client.models.VectorParams(
        size=1, distance=qdrant_client.models.Distance.COSINE
    )
    mock_qdrant_client.get_collection.return_value = unnamed_info
    assert qdrant_retriever._get_first_vector_name() is None

    # Check a collection exposing several named vectors
    named_info = MagicMock()
    named_info.config.params.vectors = {"vector1": "details", "vector2": "details"}
    mock_qdrant_client.get_collection.return_value = named_info
    assert qdrant_retriever._get_first_vector_name() == "vector1"
89+
90+
91+
def test_points_to_output():
    """_points_to_output maps ids, scores and documents of the points into one output."""
    # Prepare mocked ScoredPoint
    scored_point = MagicMock(
        id=1,
        score=0.9,
        payload={"text": "sample text", "meta_data": {"key": "value"}},
        vector=[0.1, 0.2, 0.3],
    )
    query = "test query"

    output = QdrantRetriever._points_to_output(
        [scored_point], query, "text", "meta_data", "vector_name"
    )

    assert isinstance(output, RetrieverOutput)
    assert output.query == query
    assert output.doc_indices == [1]
    assert output.doc_scores == [0.9]
    assert len(output.documents) == 1

    document = output.documents[0]
    assert isinstance(document, Document)
    assert document.text == "sample text"
    assert document.meta_data == {"key": "value"}
    assert document.vector == [0.1, 0.2, 0.3]
118+
119+
120+
def test_doc_from_point():
    """_doc_from_point honours custom payload keys and passes a plain vector through."""
    point = MagicMock(
        id=1,
        payload={"content": "sample text", "some_meta": {"key": "value"}},
        vector=[0.1, 0.2, 0.3],
    )

    doc = QdrantRetriever._doc_from_point(point, "content", "some_meta", None)

    assert isinstance(doc, Document)
    assert doc.id == 1
    assert doc.text == "sample text"
    assert doc.meta_data == {"key": "value"}
    assert doc.vector == [0.1, 0.2, 0.3]
139+
140+
141+
def test_doc_from_point_with_vector_name():
    """_doc_from_point selects the named entry when the point carries a dict of vectors."""
    point = MagicMock(
        id=1,
        payload={"text": "sample text", "meta_data": {"key": "value"}},
        vector={"vector_name": [0.4, 0.5, 0.6]},
    )

    doc = QdrantRetriever._doc_from_point(point, "text", "meta_data", "vector_name")

    assert isinstance(doc, Document)
    assert doc.id == 1
    assert doc.text == "sample text"
    assert doc.meta_data == {"key": "value"}
    assert doc.vector == [0.4, 0.5, 0.6]
160+
161+
162+
def test_call_with_custom_limit(qdrant_retriever, mock_qdrant_client):
    """A per-call ``top_k`` overrides the retriever default on every batched request."""
    query = "test query"
    queries = [query, query, query]
    # Deliberately different from the fixture's top_k=5: with an equal value
    # the assertions below would pass even if the override were ignored.
    custom_limit = 3

    mock_point = MagicMock()
    mock_point.id = 1
    mock_point.score = 0.9
    mock_point.payload = {"text": "retrieved text", "meta_data": {"key": "value"}}
    mock_point.vector = [0.1, 0.2, 0.3]

    mock_query_response = MagicMock(spec=qdrant_client.models.QueryResponse)
    mock_query_response.points = [mock_point]

    # One response per request, mirroring the shape of a batch query result.
    mock_qdrant_client.query_batch_points.return_value = [
        mock_query_response
    ] * len(queries)

    outputs = qdrant_retriever.call(queries, top_k=custom_limit)

    mock_qdrant_client.query_batch_points.assert_called_once()

    call_args = mock_qdrant_client.query_batch_points.call_args
    assert call_args[0] == (COLLECTION_NAME,)

    requests = call_args[1]["requests"]
    assert len(requests) == len(queries)
    for request in requests:
        assert request.limit == custom_limit

    assert len(outputs) == len(queries)

0 commit comments

Comments
 (0)