Skip to content

Commit 5741835

Browse files
committed
INTPYTHON-752 Integrate pymongo-vectorsearch-utils
1 parent dd088b7 commit 5741835

File tree

13 files changed

+164
-87
lines changed

13 files changed

+164
-87
lines changed

libs/langchain-mongodb/langchain_mongodb/agent_toolkit/database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(
6464
self._sample_docs_in_coll_info = sample_docs_in_collection_info
6565
self._indexes_in_coll_info = indexes_in_collection_info
6666

67-
_append_client_metadata(self._client, DRIVER_METADATA)
67+
_append_client_metadata(self._client)
6868

6969
@classmethod
7070
def from_connection_string(

libs/langchain-mongodb/langchain_mongodb/chat_message_histories.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(
112112
if connection_string:
113113
raise ValueError("Must provide connection_string or client, not both")
114114
self.client = client
115-
_append_client_metadata(self.client, DRIVER_METADATA)
115+
_append_client_metadata(self.client)
116116
elif connection_string:
117117
try:
118118
self.client = MongoClient(

libs/langchain-mongodb/langchain_mongodb/docstores.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(self, collection: Collection, text_key: str = "page_content") -> No
3737
self.collection = collection
3838
self._text_key = text_key
3939

40-
_append_client_metadata(self.collection.database.client, DRIVER_METADATA)
40+
_append_client_metadata(self.collection.database.client)
4141

4242
@classmethod
4343
def from_connection_string(

libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def __init__(
186186
self.collection = collection
187187

188188
# append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions
189-
_append_client_metadata(collection.database.client, DRIVER_METADATA)
189+
_append_client_metadata(collection.database.client)
190190

191191
self.entity_extraction_model = entity_extraction_model
192192
self.entity_prompt = (

libs/langchain-mongodb/langchain_mongodb/index.py

Lines changed: 102 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,6 @@
77
from pymongo.collection import Collection
88
from pymongo.operations import SearchIndexModel
99

10-
# Don't break imports for modules that expect these functions
11-
# to be in this module.
12-
from pymongo_search_utils import ( # noqa: F401
13-
create_vector_search_index,
14-
update_vector_search_index,
15-
)
16-
1710
logger = logging.getLogger(__file__)
1811

1912

@@ -41,6 +34,60 @@ def _vector_search_index_definition(
4134
return definition
4235

4336

37+
def create_vector_search_index(
    collection: Collection,
    index_name: str,
    dimensions: int,
    path: str,
    similarity: str,
    filters: Optional[List[str]] = None,
    *,
    wait_until_complete: Optional[float] = None,
    **kwargs: Any,
) -> None:
    """Experimental helper that creates a vector search index on a collection.

    The collection is created first when its name is not among the
    database's authorized collection names, then a ``vectorSearch`` index
    named ``index_name`` is created on it.

    Args:
        collection (Collection): MongoDB Collection that will own the index.
        index_name (str): Name of the index to create.
        dimensions (int): Number of dimensions in the embedding.
        path (str): Field holding the vector embedding.
        similarity (str): The similarity score used for the index.
        filters (List[str]): Fields/paths to index to allow filtering in
            $vectorSearch.
        wait_until_complete (Optional[float]): If provided (and truthy),
            number of seconds to wait until the search index is ready.
        kwargs: Extra keyword arguments forwarded to
            ``_vector_search_index_definition`` together with the fields
            above when building the index definition.
    """
    logger.info("Creating Search Index %s on %s", index_name, collection.name)

    # Search indexes need an existing collection, so create it on demand.
    database = collection.database
    known_collections = database.list_collection_names(authorizedCollections=True)
    if collection.name not in known_collections:
        database.create_collection(collection.name)

    index_definition = _vector_search_index_definition(
        dimensions=dimensions,
        path=path,
        similarity=similarity,
        filters=filters,
        **kwargs,
    )
    index_model = SearchIndexModel(
        definition=index_definition,
        name=index_name,
        type="vectorSearch",
    )
    result = collection.create_search_index(index_model)

    if wait_until_complete:

        def _index_ready() -> bool:
            return _is_index_ready(collection, index_name)

        _wait_for_predicate(
            predicate=_index_ready,
            err=f"{index_name=} did not complete in {wait_until_complete}!",
            timeout=wait_until_complete,
        )
    logger.info(result)
89+
90+
4491
def drop_vector_search_index(
4592
collection: Collection,
4693
index_name: str,
@@ -68,6 +115,54 @@ def drop_vector_search_index(
68115
logger.info("Vector Search index %s.%s dropped", collection.name, index_name)
69116

70117

118+
def update_vector_search_index(
    collection: Collection,
    index_name: str,
    dimensions: int,
    path: str,
    similarity: str,
    filters: Optional[List[str]] = None,
    *,
    wait_until_complete: Optional[float] = None,
    **kwargs: Any,
) -> None:
    """Replace the definition of an existing vector search index.

    A new definition is built from the arguments and handed to
    ``collection.update_search_index`` under the name ``index_name``.

    Args:
        collection (Collection): MongoDB Collection that owns the index.
        index_name (str): Name of the index to update.
        dimensions (int): Number of dimensions in the embedding.
        path (str): Field holding the vector embedding.
        similarity (str): The similarity score used for the index.
        filters (List[str]): Fields/paths to index to allow filtering in
            $vectorSearch.
        wait_until_complete (Optional[float]): If provided (and truthy),
            number of seconds to wait until the search index is ready.
        kwargs: Extra keyword arguments forwarded to
            ``_vector_search_index_definition`` together with the fields
            above when building the index definition.
    """
    logger.info(
        "Updating Search Index %s from Collection: %s", index_name, collection.name
    )

    replacement_definition = _vector_search_index_definition(
        dimensions=dimensions,
        path=path,
        similarity=similarity,
        filters=filters,
        **kwargs,
    )
    collection.update_search_index(
        name=index_name,
        definition=replacement_definition,
    )

    if wait_until_complete:

        def _index_ready() -> bool:
            return _is_index_ready(collection, index_name)

        _wait_for_predicate(
            predicate=_index_ready,
            err=f"Index {index_name} update did not complete in {wait_until_complete}!",
            timeout=wait_until_complete,
        )
    logger.info("Update succeeded")
164+
165+
71166
def _is_index_ready(collection: Collection, index_name: str) -> bool:
72167
"""Check for the index name in the list of available search indexes to see if the
73168
specified index is of status READY

libs/langchain-mongodb/langchain_mongodb/indexes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, collection: Collection) -> None:
3636
super().__init__(namespace=namespace)
3737
self._collection = collection
3838

39-
_append_client_metadata(self._collection.database.client, DRIVER_METADATA)
39+
_append_client_metadata(self._collection.database.client)
4040

4141
@classmethod
4242
def from_connection_string(

libs/langchain-mongodb/langchain_mongodb/loaders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(
5454
self.include_db_collection_in_metadata = include_db_collection_in_metadata
5555

5656
# append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions
57-
_append_client_metadata(self.db.client, DRIVER_METADATA)
57+
_append_client_metadata(self.db.client)
5858

5959
@classmethod
6060
def from_connection_string(

libs/langchain-mongodb/langchain_mongodb/retrievers/full_text_search.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,7 @@
88
from pymongo.collection import Collection
99

1010
from langchain_mongodb.pipelines import text_search_stage
11-
from langchain_mongodb.utils import (
12-
DRIVER_METADATA,
13-
_append_client_metadata,
14-
make_serializable,
15-
)
11+
from langchain_mongodb.utils import _append_client_metadata, make_serializable
1612

1713

1814
class MongoDBAtlasFullTextSearchRetriever(BaseRetriever):
@@ -68,7 +64,7 @@ def _get_relevant_documents(
6864
)
6965

7066
if not self._added_metadata:
71-
_append_client_metadata(self.collection.database.client, DRIVER_METADATA)
67+
_append_client_metadata(self.collection.database.client)
7268
self._added_metadata = True
7369

7470
# Execution

libs/langchain-mongodb/langchain_mongodb/utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,22 @@
2424
from typing import Any, Dict, List, Union
2525

2626
import numpy as np
27+
from pymongo import MongoClient
2728
from pymongo.driver_info import DriverInfo
2829

29-
# Don't break imports for modules that expect this function
30-
# to be in this module.
31-
from pymongo_search_utils import (
32-
append_client_metadata as _append_client_metadata, # noqa: F401
33-
)
34-
3530
logger = logging.getLogger(__name__)
3631

3732
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
3833

3934
DRIVER_METADATA = DriverInfo(name="Langchain", version=version("langchain-mongodb"))
4035

4136

37+
def _append_client_metadata(client: MongoClient) -> None:
    """Attach ``DRIVER_METADATA`` to ``client`` when the driver supports it.

    This is a best-effort call: when the client's class does not define an
    ``append_metadata`` method, nothing happens.

    Args:
        client: The client whose handshake metadata should be extended.
    """
    # append_metadata was added in PyMongo 4.14.0, but is a valid database
    # name on earlier versions: there, ``client.append_metadata`` resolves to
    # a Database object instead of raising AttributeError.
    #
    # Look the method up on the class rather than on the instance. PyMongo's
    # Database type defines ``__call__`` (it raises TypeError to flag API
    # misuse), so ``callable(client.append_metadata)`` is also true on old
    # drivers and the subsequent call would fail.
    # NOTE(review): confirm against the oldest PyMongo release still supported.
    append_metadata = getattr(type(client), "append_metadata", None)
    if callable(append_metadata):
        client.append_metadata(DRIVER_METADATA)
41+
42+
4243
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
4344
"""Row-wise cosine similarity between two equal-width matrices."""
4445
if len(X) == 0 or len(Y) == 0:

libs/langchain-mongodb/langchain_mongodb/vectorstores.py

Lines changed: 45 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,9 @@
2222
from langchain_core.embeddings import Embeddings
2323
from langchain_core.runnables.config import run_in_executor
2424
from langchain_core.vectorstores import VectorStore
25-
from pymongo import MongoClient
25+
from pymongo import MongoClient, ReplaceOne
2626
from pymongo.collection import Collection
2727
from pymongo.errors import CollectionInvalid
28-
from pymongo_search_utils import bulk_embed_and_insert_texts
2928

3029
from langchain_mongodb.index import (
3130
create_vector_search_index,
@@ -239,7 +238,7 @@ def __init__(
239238
self._relevance_score_fn = relevance_score_fn
240239

241240
# append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions
242-
_append_client_metadata(self._collection.database.client, DRIVER_METADATA)
241+
_append_client_metadata(self._collection.database.client)
243242

244243
if auto_create_index is False:
245244
return
@@ -363,23 +362,12 @@ def add_texts(
363362
metadatas_batch.append(metadata)
364363
if (j + 1) % batch_size == 0 or size >= 47_000_000:
365364
if ids:
366-
batch_res = bulk_embed_and_insert_texts(
367-
embedding_func=self._embedding.embed_documents,
368-
collection=self._collection,
369-
text_key=self._text_key,
370-
embedding_key=self._embedding_key,
371-
texts=texts_batch,
372-
metadatas=metadatas_batch,
373-
ids=ids[i : j + 1],
365+
batch_res = self.bulk_embed_and_insert_texts(
366+
texts_batch, metadatas_batch, ids[i : j + 1]
374367
)
375368
else:
376-
batch_res = bulk_embed_and_insert_texts(
377-
embedding_func=self._embedding.embed_documents,
378-
collection=self._collection,
379-
text_key=self._text_key,
380-
embedding_key=self._embedding_key,
381-
texts=texts_batch,
382-
metadatas=metadatas_batch,
369+
batch_res = self.bulk_embed_and_insert_texts(
370+
texts_batch, metadatas_batch
383371
)
384372
result_ids.extend(batch_res)
385373
texts_batch = []
@@ -388,23 +376,12 @@ def add_texts(
388376
i = j + 1
389377
if texts_batch:
390378
if ids:
391-
batch_res = bulk_embed_and_insert_texts(
392-
embedding_func=self._embedding.embed_documents,
393-
collection=self._collection,
394-
text_key=self._text_key,
395-
embedding_key=self._embedding_key,
396-
texts=texts_batch,
397-
metadatas=metadatas_batch,
398-
ids=ids[i : j + 1],
379+
batch_res = self.bulk_embed_and_insert_texts(
380+
texts_batch, metadatas_batch, ids[i : j + 1]
399381
)
400382
else:
401-
batch_res = bulk_embed_and_insert_texts(
402-
embedding_func=self._embedding.embed_documents,
403-
collection=self._collection,
404-
text_key=self._text_key,
405-
embedding_key=self._embedding_key,
406-
texts=texts_batch,
407-
metadatas=metadatas_batch,
383+
batch_res = self.bulk_embed_and_insert_texts(
384+
texts_batch, metadatas_batch
408385
)
409386
result_ids.extend(batch_res)
410387
return result_ids
@@ -442,6 +419,39 @@ def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
442419
docs.append(Document(page_content=text, id=oid_to_str(_id), metadata=doc))
443420
return docs
444421

422+
def bulk_embed_and_insert_texts(
    self,
    texts: Union[List[str], Iterable[str]],
    metadatas: Union[List[dict], Generator[dict, Any, Any]],
    ids: Optional[List[str]] = None,
) -> List[str]:
    """Bulk insert single batch of texts, embeddings, and optionally ids.

    See add_texts for additional details.

    Args:
        texts: Texts to embed and store. Any iterable is accepted; it is
            consumed exactly once.
        metadatas: One metadata mapping per text. Its keys are merged into
            the stored document after the text and embedding fields.
        ids: Optional ids, one per text. When missing or empty, a fresh
            ``ObjectId`` string is generated for every text.

    Returns:
        The ids reported in ``upserted_ids`` by the bulk write, converted
        with ``oid_to_str``. An empty list when ``texts`` is empty.

    Raises:
        ValueError: If ids, texts, metadatas and embeddings do not all have
            the same length (``zip(..., strict=True)``).
    """
    # Materialize once. A generator is always truthy and can only be read a
    # single time, so testing or re-listing the raw iterable would skip the
    # empty check and leave nothing for id generation and the zip below.
    text_list = list(texts)
    if not text_list:
        return []
    # Compute embedding vectors
    embeddings = self._embedding.embed_documents(text_list)
    if not ids:
        ids = [str(ObjectId()) for _ in text_list]
    docs = [
        {
            "_id": str_to_oid(i),
            self._text_key: t,
            self._embedding_key: embedding,
            **m,
        }
        for i, t, m, embedding in zip(
            ids, text_list, metadatas, embeddings, strict=True
        )
    ]
    # Upsert by _id so re-adding an existing id replaces that document.
    operations = [ReplaceOne({"_id": doc["_id"]}, doc, upsert=True) for doc in docs]
    # insert the documents in MongoDB Atlas
    result = self._collection.bulk_write(operations)
    assert result.upserted_ids is not None
    return [oid_to_str(_id) for _id in result.upserted_ids.values()]
454+
445455
def add_documents(
446456
self,
447457
documents: List[Document],
@@ -474,14 +484,8 @@ def add_documents(
474484
strict=True,
475485
)
476486
result_ids.extend(
477-
bulk_embed_and_insert_texts(
478-
embedding_func=self._embedding.embed_documents,
479-
collection=self._collection,
480-
text_key=self._text_key,
481-
embedding_key=self._embedding_key,
482-
texts=texts,
483-
metadatas=metadatas,
484-
ids=ids[start:end],
487+
self.bulk_embed_and_insert_texts(
488+
texts=texts, metadatas=metadatas, ids=ids[start:end]
485489
)
486490
)
487491
start = end

0 commit comments

Comments
 (0)