2222from langchain_core .embeddings import Embeddings
2323from langchain_core .runnables .config import run_in_executor
2424from langchain_core .vectorstores import VectorStore
25- from pymongo import MongoClient , ReplaceOne
25+ from pymongo import MongoClient
2626from pymongo .collection import Collection
2727from pymongo .errors import CollectionInvalid
28+ from pymongo_search_utils import bulk_embed_and_insert_texts
2829
2930from langchain_mongodb .index import (
3031 create_vector_search_index ,
@@ -238,7 +239,7 @@ def __init__(
238239 self ._relevance_score_fn = relevance_score_fn
239240
240241 # append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions
241- _append_client_metadata (self ._collection .database .client )
242+ _append_client_metadata (self ._collection .database .client , DRIVER_METADATA )
242243
243244 if auto_create_index is False :
244245 return
@@ -362,12 +363,23 @@ def add_texts(
362363 metadatas_batch .append (metadata )
363364 if (j + 1 ) % batch_size == 0 or size >= 47_000_000 :
364365 if ids :
365- batch_res = self .bulk_embed_and_insert_texts (
366- texts_batch , metadatas_batch , ids [i : j + 1 ]
366+ batch_res = bulk_embed_and_insert_texts (
367+ embedding_func = self ._embedding .embed_documents ,
368+ collection = self ._collection ,
369+ text_key = self ._text_key ,
370+ embedding_key = self ._embedding_key ,
371+ texts = texts_batch ,
372+ metadatas = metadatas_batch ,
373+ ids = ids [i : j + 1 ],
367374 )
368375 else :
369- batch_res = self .bulk_embed_and_insert_texts (
370- texts_batch , metadatas_batch
376+ batch_res = bulk_embed_and_insert_texts (
377+ embedding_func = self ._embedding .embed_documents ,
378+ collection = self ._collection ,
379+ text_key = self ._text_key ,
380+ embedding_key = self ._embedding_key ,
381+ texts = texts_batch ,
382+ metadatas = metadatas_batch ,
371383 )
372384 result_ids .extend (batch_res )
373385 texts_batch = []
@@ -376,12 +388,23 @@ def add_texts(
376388 i = j + 1
377389 if texts_batch :
378390 if ids :
379- batch_res = self .bulk_embed_and_insert_texts (
380- texts_batch , metadatas_batch , ids [i : j + 1 ]
391+ batch_res = bulk_embed_and_insert_texts (
392+ embedding_func = self ._embedding .embed_documents ,
393+ collection = self ._collection ,
394+ text_key = self ._text_key ,
395+ embedding_key = self ._embedding_key ,
396+ texts = texts_batch ,
397+ metadatas = metadatas_batch ,
398+ ids = ids [i : j + 1 ],
381399 )
382400 else :
383- batch_res = self .bulk_embed_and_insert_texts (
384- texts_batch , metadatas_batch
401+ batch_res = bulk_embed_and_insert_texts (
402+ embedding_func = self ._embedding .embed_documents ,
403+ collection = self ._collection ,
404+ text_key = self ._text_key ,
405+ embedding_key = self ._embedding_key ,
406+ texts = texts_batch ,
407+ metadatas = metadatas_batch ,
385408 )
386409 result_ids .extend (batch_res )
387410 return result_ids
@@ -419,39 +442,6 @@ def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
419442 docs .append (Document (page_content = text , id = oid_to_str (_id ), metadata = doc ))
420443 return docs
421444
422- def bulk_embed_and_insert_texts (
423- self ,
424- texts : Union [List [str ], Iterable [str ]],
425- metadatas : Union [List [dict ], Generator [dict , Any , Any ]],
426- ids : Optional [List [str ]] = None ,
427- ) -> List [str ]:
428- """Bulk insert single batch of texts, embeddings, and optionally ids.
429-
430- See add_texts for additional details.
431- """
432- if not texts :
433- return []
434- # Compute embedding vectors
435- embeddings = self ._embedding .embed_documents (list (texts ))
436- if not ids :
437- ids = [str (ObjectId ()) for _ in range (len (list (texts )))]
438- docs = [
439- {
440- "_id" : str_to_oid (i ),
441- self ._text_key : t ,
442- self ._embedding_key : embedding ,
443- ** m ,
444- }
445- for i , t , m , embedding in zip (
446- ids , texts , metadatas , embeddings , strict = True
447- )
448- ]
449- operations = [ReplaceOne ({"_id" : doc ["_id" ]}, doc , upsert = True ) for doc in docs ]
450- # insert the documents in MongoDB Atlas
451- result = self ._collection .bulk_write (operations )
452- assert result .upserted_ids is not None
453- return [oid_to_str (_id ) for _id in result .upserted_ids .values ()]
454-
455445 def add_documents (
456446 self ,
457447 documents : List [Document ],
@@ -484,8 +474,14 @@ def add_documents(
484474 strict = True ,
485475 )
486476 result_ids .extend (
487- self .bulk_embed_and_insert_texts (
488- texts = texts , metadatas = metadatas , ids = ids [start :end ]
477+ bulk_embed_and_insert_texts (
478+ embedding_func = self ._embedding .embed_documents ,
479+ collection = self ._collection ,
480+ text_key = self ._text_key ,
481+ embedding_key = self ._embedding_key ,
482+ texts = texts ,
483+ metadatas = metadatas ,
484+ ids = ids [start :end ],
489485 )
490486 )
491487 start = end
0 commit comments