2222from langchain_core .embeddings import Embeddings
2323from langchain_core .runnables .config import run_in_executor
2424from langchain_core .vectorstores import VectorStore
25- from pymongo import MongoClient
25+ from pymongo import MongoClient , ReplaceOne
2626from pymongo .collection import Collection
2727from pymongo .errors import CollectionInvalid
28- from pymongo_search_utils import bulk_embed_and_insert_texts
2928
3029from langchain_mongodb .index import (
3130 create_vector_search_index ,
@@ -239,7 +238,7 @@ def __init__(
239238 self ._relevance_score_fn = relevance_score_fn
240239
241240 # append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions
242- _append_client_metadata (self ._collection .database .client , DRIVER_METADATA )
241+ _append_client_metadata (self ._collection .database .client )
243242
244243 if auto_create_index is False :
245244 return
@@ -363,23 +362,12 @@ def add_texts(
363362 metadatas_batch .append (metadata )
364363 if (j + 1 ) % batch_size == 0 or size >= 47_000_000 :
365364 if ids :
366- batch_res = bulk_embed_and_insert_texts (
367- embedding_func = self ._embedding .embed_documents ,
368- collection = self ._collection ,
369- text_key = self ._text_key ,
370- embedding_key = self ._embedding_key ,
371- texts = texts_batch ,
372- metadatas = metadatas_batch ,
373- ids = ids [i : j + 1 ],
365+ batch_res = self .bulk_embed_and_insert_texts (
366+ texts_batch , metadatas_batch , ids [i : j + 1 ]
374367 )
375368 else :
376- batch_res = bulk_embed_and_insert_texts (
377- embedding_func = self ._embedding .embed_documents ,
378- collection = self ._collection ,
379- text_key = self ._text_key ,
380- embedding_key = self ._embedding_key ,
381- texts = texts_batch ,
382- metadatas = metadatas_batch ,
369+ batch_res = self .bulk_embed_and_insert_texts (
370+ texts_batch , metadatas_batch
383371 )
384372 result_ids .extend (batch_res )
385373 texts_batch = []
@@ -388,23 +376,12 @@ def add_texts(
388376 i = j + 1
389377 if texts_batch :
390378 if ids :
391- batch_res = bulk_embed_and_insert_texts (
392- embedding_func = self ._embedding .embed_documents ,
393- collection = self ._collection ,
394- text_key = self ._text_key ,
395- embedding_key = self ._embedding_key ,
396- texts = texts_batch ,
397- metadatas = metadatas_batch ,
398- ids = ids [i : j + 1 ],
379+ batch_res = self .bulk_embed_and_insert_texts (
380+ texts_batch , metadatas_batch , ids [i : j + 1 ]
399381 )
400382 else :
401- batch_res = bulk_embed_and_insert_texts (
402- embedding_func = self ._embedding .embed_documents ,
403- collection = self ._collection ,
404- text_key = self ._text_key ,
405- embedding_key = self ._embedding_key ,
406- texts = texts_batch ,
407- metadatas = metadatas_batch ,
383+ batch_res = self .bulk_embed_and_insert_texts (
384+ texts_batch , metadatas_batch
408385 )
409386 result_ids .extend (batch_res )
410387 return result_ids
@@ -442,6 +419,39 @@ def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
442419 docs .append (Document (page_content = text , id = oid_to_str (_id ), metadata = doc ))
443420 return docs
444421
def bulk_embed_and_insert_texts(
    self,
    texts: Union[List[str], Iterable[str]],
    metadatas: Union[List[dict], Generator[dict, Any, Any]],
    ids: Optional[List[str]] = None,
) -> List[str]:
    """Embed one batch of texts and upsert them into the collection.

    Each text is stored as a document holding the text under
    ``self._text_key``, its embedding under ``self._embedding_key`` and the
    matching metadata entries as additional top-level fields. Documents are
    written with ``ReplaceOne(..., upsert=True)`` keyed on ``_id``.

    See add_texts for additional details.

    Args:
        texts: The texts of this batch. May be a one-shot iterable; it is
            materialized exactly once.
        metadatas: One metadata mapping per text, in the same order.
        ids: Optional ids, one per text. When omitted (or empty), a fresh
            ``ObjectId`` string is generated for every text.

    Returns:
        The ids reported in ``upserted_ids`` of the bulk-write result,
        converted with ``oid_to_str``. An empty list when ``texts`` is empty.

    Raises:
        ValueError: If ``ids``, ``texts``, ``metadatas`` and the computed
            embeddings do not all have the same length (``zip(strict=True)``).
    """
    # `texts` may be a generator: materialize it once so that the emptiness
    # check, the embedding call, id generation and the final zip all see the
    # same items instead of an already-exhausted iterator.
    text_list = list(texts)
    if not text_list:
        return []
    # Compute embedding vectors
    embeddings = self._embedding.embed_documents(text_list)
    if not ids:
        ids = [str(ObjectId()) for _ in text_list]
    docs = [
        {
            "_id": str_to_oid(doc_id),
            self._text_key: text,
            self._embedding_key: embedding,
            **metadata,
        }
        for doc_id, text, metadata, embedding in zip(
            ids, text_list, metadatas, embeddings, strict=True
        )
    ]
    operations = [ReplaceOne({"_id": doc["_id"]}, doc, upsert=True) for doc in docs]
    # insert the documents in MongoDB Atlas
    result = self._collection.bulk_write(operations)
    # Narrow the Optional for type checkers before iterating.
    assert result.upserted_ids is not None
    return [oid_to_str(_id) for _id in result.upserted_ids.values()]
454+
445455 def add_documents (
446456 self ,
447457 documents : List [Document ],
@@ -474,14 +484,8 @@ def add_documents(
474484 strict = True ,
475485 )
476486 result_ids .extend (
477- bulk_embed_and_insert_texts (
478- embedding_func = self ._embedding .embed_documents ,
479- collection = self ._collection ,
480- text_key = self ._text_key ,
481- embedding_key = self ._embedding_key ,
482- texts = texts ,
483- metadatas = metadatas ,
484- ids = ids [start :end ],
487+ self .bulk_embed_and_insert_texts (
488+ texts = texts , metadatas = metadatas , ids = ids [start :end ]
485489 )
486490 )
487491 start = end
0 commit comments