Skip to content

Commit 109d355

Browse files
committed
INTPYTHON-752 Integrate pymongo-vectorsearch-utils
1 parent d29d12e commit 109d355

File tree

14 files changed

+82
-164
lines changed

14 files changed

+82
-164
lines changed

libs/langchain-mongodb/langchain_mongodb/agent_toolkit/database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(
6464
self._sample_docs_in_coll_info = sample_docs_in_collection_info
6565
self._indexes_in_coll_info = indexes_in_collection_info
6666

67-
_append_client_metadata(self._client)
67+
_append_client_metadata(self._client, DRIVER_METADATA)
6868

6969
@classmethod
7070
def from_connection_string(

libs/langchain-mongodb/langchain_mongodb/chat_message_histories.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(
112112
if connection_string:
113113
raise ValueError("Must provide connection_string or client, not both")
114114
self.client = client
115-
_append_client_metadata(self.client)
115+
_append_client_metadata(self.client, DRIVER_METADATA)
116116
elif connection_string:
117117
try:
118118
self.client = MongoClient(

libs/langchain-mongodb/langchain_mongodb/docstores.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(self, collection: Collection, text_key: str = "page_content") -> No
3737
self.collection = collection
3838
self._text_key = text_key
3939

40-
_append_client_metadata(self.collection.database.client)
40+
_append_client_metadata(self.collection.database.client, DRIVER_METADATA)
4141

4242
@classmethod
4343
def from_connection_string(

libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def __init__(
186186
self.collection = collection
187187

188188
# append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions
189-
_append_client_metadata(collection.database.client)
189+
_append_client_metadata(collection.database.client, DRIVER_METADATA)
190190

191191
self.entity_extraction_model = entity_extraction_model
192192
self.entity_prompt = (

libs/langchain-mongodb/langchain_mongodb/index.py

Lines changed: 7 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@
77
from pymongo.collection import Collection
88
from pymongo.operations import SearchIndexModel
99

10+
# Don't break imports for modules that expect these functions
11+
# to be in this module.
12+
from pymongo_search_utils import ( # noqa: F401
13+
create_vector_search_index,
14+
update_vector_search_index,
15+
)
16+
1017
logger = logging.getLogger(__file__)
1118

1219

@@ -34,60 +41,6 @@ def _vector_search_index_definition(
3441
return definition
3542

3643

37-
def create_vector_search_index(
38-
collection: Collection,
39-
index_name: str,
40-
dimensions: int,
41-
path: str,
42-
similarity: str,
43-
filters: Optional[List[str]] = None,
44-
*,
45-
wait_until_complete: Optional[float] = None,
46-
**kwargs: Any,
47-
) -> None:
48-
"""Experimental Utility function to create a vector search index
49-
50-
Args:
51-
collection (Collection): MongoDB Collection
52-
index_name (str): Name of Index
53-
dimensions (int): Number of dimensions in embedding
54-
path (str): field with vector embedding
55-
similarity (str): The similarity score used for the index
56-
filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
57-
wait_until_complete (Optional[float]): If provided, number of seconds to wait
58-
until search index is ready.
59-
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
60-
"""
61-
logger.info("Creating Search Index %s on %s", index_name, collection.name)
62-
63-
if collection.name not in collection.database.list_collection_names(
64-
authorizedCollections=True
65-
):
66-
collection.database.create_collection(collection.name)
67-
68-
result = collection.create_search_index(
69-
SearchIndexModel(
70-
definition=_vector_search_index_definition(
71-
dimensions=dimensions,
72-
path=path,
73-
similarity=similarity,
74-
filters=filters,
75-
**kwargs,
76-
),
77-
name=index_name,
78-
type="vectorSearch",
79-
)
80-
)
81-
82-
if wait_until_complete:
83-
_wait_for_predicate(
84-
predicate=lambda: _is_index_ready(collection, index_name),
85-
err=f"{index_name=} did not complete in {wait_until_complete}!",
86-
timeout=wait_until_complete,
87-
)
88-
logger.info(result)
89-
90-
9144
def drop_vector_search_index(
9245
collection: Collection,
9346
index_name: str,
@@ -115,54 +68,6 @@ def drop_vector_search_index(
11568
logger.info("Vector Search index %s.%s dropped", collection.name, index_name)
11669

11770

118-
def update_vector_search_index(
119-
collection: Collection,
120-
index_name: str,
121-
dimensions: int,
122-
path: str,
123-
similarity: str,
124-
filters: Optional[List[str]] = None,
125-
*,
126-
wait_until_complete: Optional[float] = None,
127-
**kwargs: Any,
128-
) -> None:
129-
"""Update a search index.
130-
131-
Replace the existing index definition with the provided definition.
132-
133-
Args:
134-
collection (Collection): MongoDB Collection
135-
index_name (str): Name of Index
136-
dimensions (int): Number of dimensions in embedding
137-
path (str): field with vector embedding
138-
similarity (str): The similarity score used for the index.
139-
filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
140-
wait_until_complete (Optional[float]): If provided, number of seconds to wait
141-
until search index is ready.
142-
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
143-
"""
144-
logger.info(
145-
"Updating Search Index %s from Collection: %s", index_name, collection.name
146-
)
147-
collection.update_search_index(
148-
name=index_name,
149-
definition=_vector_search_index_definition(
150-
dimensions=dimensions,
151-
path=path,
152-
similarity=similarity,
153-
filters=filters,
154-
**kwargs,
155-
),
156-
)
157-
if wait_until_complete:
158-
_wait_for_predicate(
159-
predicate=lambda: _is_index_ready(collection, index_name),
160-
err=f"Index {index_name} update did not complete in {wait_until_complete}!",
161-
timeout=wait_until_complete,
162-
)
163-
logger.info("Update succeeded")
164-
165-
16671
def _is_index_ready(collection: Collection, index_name: str) -> bool:
16772
"""Check for the index name in the list of available search indexes to see if the
16873
specified index is of status READY

libs/langchain-mongodb/langchain_mongodb/indexes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, collection: Collection) -> None:
3636
super().__init__(namespace=namespace)
3737
self._collection = collection
3838

39-
_append_client_metadata(self._collection.database.client)
39+
_append_client_metadata(self._collection.database.client, DRIVER_METADATA)
4040

4141
@classmethod
4242
def from_connection_string(

libs/langchain-mongodb/langchain_mongodb/loaders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(
5454
self.include_db_collection_in_metadata = include_db_collection_in_metadata
5555

5656
# append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions
57-
_append_client_metadata(self.db.client)
57+
_append_client_metadata(self.db.client, DRIVER_METADATA)
5858

5959
@classmethod
6060
def from_connection_string(

libs/langchain-mongodb/langchain_mongodb/retrievers/full_text_search.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
from pymongo.collection import Collection
99

1010
from langchain_mongodb.pipelines import text_search_stage
11-
from langchain_mongodb.utils import _append_client_metadata, make_serializable
11+
from langchain_mongodb.utils import (
12+
DRIVER_METADATA,
13+
_append_client_metadata,
14+
make_serializable,
15+
)
1216

1317

1418
class MongoDBAtlasFullTextSearchRetriever(BaseRetriever):
@@ -64,7 +68,7 @@ def _get_relevant_documents(
6468
)
6569

6670
if not self._added_metadata:
67-
_append_client_metadata(self.collection.database.client)
71+
_append_client_metadata(self.collection.database.client, DRIVER_METADATA)
6872
self._added_metadata = True
6973

7074
# Execution

libs/langchain-mongodb/langchain_mongodb/utils.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,21 @@
2424
from typing import Any, Dict, List, Union
2525

2626
import numpy as np
27-
from pymongo import MongoClient
2827
from pymongo.driver_info import DriverInfo
2928

29+
# Don't break imports for modules that expect this function
30+
# to be in this module.
31+
from pymongo_search_utils import (
32+
append_client_metadata as _append_client_metadata, # noqa: F401
33+
)
34+
3035
logger = logging.getLogger(__name__)
3136

3237
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
3338

3439
DRIVER_METADATA = DriverInfo(name="Langchain", version=version("langchain-mongodb"))
3540

3641

37-
def _append_client_metadata(client: MongoClient) -> None:
38-
# append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions
39-
if callable(client.append_metadata):
40-
client.append_metadata(DRIVER_METADATA)
41-
42-
4342
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
4443
"""Row-wise cosine similarity between two equal-width matrices."""
4544
if len(X) == 0 or len(Y) == 0:

libs/langchain-mongodb/langchain_mongodb/vectorstores.py

Lines changed: 41 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
from langchain_core.embeddings import Embeddings
2323
from langchain_core.runnables.config import run_in_executor
2424
from langchain_core.vectorstores import VectorStore
25-
from pymongo import MongoClient, ReplaceOne
25+
from pymongo import MongoClient
2626
from pymongo.collection import Collection
2727
from pymongo.errors import CollectionInvalid
28+
from pymongo_search_utils import bulk_embed_and_insert_texts
2829

2930
from langchain_mongodb.index import (
3031
create_vector_search_index,
@@ -238,7 +239,7 @@ def __init__(
238239
self._relevance_score_fn = relevance_score_fn
239240

240241
# append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions
241-
_append_client_metadata(self._collection.database.client)
242+
_append_client_metadata(self._collection.database.client, DRIVER_METADATA)
242243

243244
if auto_create_index is False:
244245
return
@@ -362,12 +363,23 @@ def add_texts(
362363
metadatas_batch.append(metadata)
363364
if (j + 1) % batch_size == 0 or size >= 47_000_000:
364365
if ids:
365-
batch_res = self.bulk_embed_and_insert_texts(
366-
texts_batch, metadatas_batch, ids[i : j + 1]
366+
batch_res = bulk_embed_and_insert_texts(
367+
embedding_func=self._embedding.embed_documents,
368+
collection=self._collection,
369+
text_key=self._text_key,
370+
embedding_key=self._embedding_key,
371+
texts=texts_batch,
372+
metadatas=metadatas_batch,
373+
ids=ids[i : j + 1],
367374
)
368375
else:
369-
batch_res = self.bulk_embed_and_insert_texts(
370-
texts_batch, metadatas_batch
376+
batch_res = bulk_embed_and_insert_texts(
377+
embedding_func=self._embedding.embed_documents,
378+
collection=self._collection,
379+
text_key=self._text_key,
380+
embedding_key=self._embedding_key,
381+
texts=texts_batch,
382+
metadatas=metadatas_batch,
371383
)
372384
result_ids.extend(batch_res)
373385
texts_batch = []
@@ -376,12 +388,23 @@ def add_texts(
376388
i = j + 1
377389
if texts_batch:
378390
if ids:
379-
batch_res = self.bulk_embed_and_insert_texts(
380-
texts_batch, metadatas_batch, ids[i : j + 1]
391+
batch_res = bulk_embed_and_insert_texts(
392+
embedding_func=self._embedding.embed_documents,
393+
collection=self._collection,
394+
text_key=self._text_key,
395+
embedding_key=self._embedding_key,
396+
texts=texts_batch,
397+
metadatas=metadatas_batch,
398+
ids=ids[i : j + 1],
381399
)
382400
else:
383-
batch_res = self.bulk_embed_and_insert_texts(
384-
texts_batch, metadatas_batch
401+
batch_res = bulk_embed_and_insert_texts(
402+
embedding_func=self._embedding.embed_documents,
403+
collection=self._collection,
404+
text_key=self._text_key,
405+
embedding_key=self._embedding_key,
406+
texts=texts_batch,
407+
metadatas=metadatas_batch,
385408
)
386409
result_ids.extend(batch_res)
387410
return result_ids
@@ -419,39 +442,6 @@ def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
419442
docs.append(Document(page_content=text, id=oid_to_str(_id), metadata=doc))
420443
return docs
421444

422-
def bulk_embed_and_insert_texts(
423-
self,
424-
texts: Union[List[str], Iterable[str]],
425-
metadatas: Union[List[dict], Generator[dict, Any, Any]],
426-
ids: Optional[List[str]] = None,
427-
) -> List[str]:
428-
"""Bulk insert single batch of texts, embeddings, and optionally ids.
429-
430-
See add_texts for additional details.
431-
"""
432-
if not texts:
433-
return []
434-
# Compute embedding vectors
435-
embeddings = self._embedding.embed_documents(list(texts))
436-
if not ids:
437-
ids = [str(ObjectId()) for _ in range(len(list(texts)))]
438-
docs = [
439-
{
440-
"_id": str_to_oid(i),
441-
self._text_key: t,
442-
self._embedding_key: embedding,
443-
**m,
444-
}
445-
for i, t, m, embedding in zip(
446-
ids, texts, metadatas, embeddings, strict=True
447-
)
448-
]
449-
operations = [ReplaceOne({"_id": doc["_id"]}, doc, upsert=True) for doc in docs]
450-
# insert the documents in MongoDB Atlas
451-
result = self._collection.bulk_write(operations)
452-
assert result.upserted_ids is not None
453-
return [oid_to_str(_id) for _id in result.upserted_ids.values()]
454-
455445
def add_documents(
456446
self,
457447
documents: List[Document],
@@ -484,8 +474,14 @@ def add_documents(
484474
strict=True,
485475
)
486476
result_ids.extend(
487-
self.bulk_embed_and_insert_texts(
488-
texts=texts, metadatas=metadatas, ids=ids[start:end]
477+
bulk_embed_and_insert_texts(
478+
embedding_func=self._embedding.embed_documents,
479+
collection=self._collection,
480+
text_key=self._text_key,
481+
embedding_key=self._embedding_key,
482+
texts=texts,
483+
metadatas=metadatas,
484+
ids=ids[start:end],
489485
)
490486
)
491487
start = end

0 commit comments

Comments
 (0)