@@ -110,7 +110,7 @@ def __init__(
110
110
self .tokenizer = self .load_tokenizer (self .path , max_length = max_length )
111
111
self .model = ort .InferenceSession (str (model_path ), providers = onnx_providers , sess_options = so )
112
112
113
- def onnx_embed (self , documents : List [str ]) -> np .ndarray :
113
+ def onnx_embed (self , documents : List [str ]) -> Tuple [ np .ndarray , np . ndarray ] :
114
114
encoded = self .tokenizer .encode_batch (documents )
115
115
input_ids = np .array ([e .ids for e in encoded ])
116
116
attention_mask = np .array ([e .attention_mask for e in encoded ])
@@ -126,9 +126,8 @@ def onnx_embed(self, documents: List[str]) -> np.ndarray:
126
126
)
127
127
128
128
model_output = self .model .run (None , onnx_input )
129
- last_hidden_state = model_output [0 ][:, 0 ]
130
- embeddings = normalize (last_hidden_state ).astype (np .float32 )
131
- return embeddings
129
+ embeddings = model_output [0 ]
130
+ return embeddings , attention_mask
132
131
133
132
134
133
class EmbeddingWorker (Worker ):
@@ -150,8 +149,8 @@ def start(cls, path: Path, model_name: str, max_length: int = 512, **kwargs: Any
150
149
151
150
def process (self , items : Iterable [Tuple [int , Any ]]) -> Iterable [Tuple [int , Any ]]:
152
151
for idx , batch in items :
153
- embeddings = self .model .onnx_embed (batch )
154
- yield idx , embeddings
152
+ embeddings , attn_mask = self .model .onnx_embed (batch )
153
+ yield idx , ( embeddings , attn_mask )
155
154
156
155
157
156
class Embedding (ABC ):
@@ -226,6 +225,18 @@ def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]
226
225
"description" : "Multilingual model, e5-large. Recommend using this model for non-English languages" ,
227
226
"size_in_GB" : 2.24
228
227
},
228
+ {
229
+ "model" : "jinaai/jina-embeddings-v2-base-en" ,
230
+ "dim" : 768 ,
231
+ "description" : " English embedding model supporting 8192 sequence length" ,
232
+ "size_in_GB" : 0.55
233
+ },
234
+ {
235
+ "model" : "jinaai/jina-embeddings-v2-small-en" ,
236
+ "dim" : 512 ,
237
+ "description" : " English embedding model supporting 8192 sequence length" ,
238
+ "size_in_GB" : 0.13
239
+ }
229
240
]
230
241
231
242
@classmethod
@@ -282,6 +293,24 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
282
293
progress_bar .close ()
283
294
return output_path
284
295
296
+ @classmethod
297
+ def download_files_from_huggingface (cls , repod_id : str , cache_dir : Optional [str ] = None ) -> str :
298
+ """
299
+ Downloads a model from HuggingFace Hub.
300
+ Args:
301
+ repod_id (str): The HF hub id (name) of the model to retrieve.
302
+ cache_dir (Optional[str]): The path to the cache directory.
303
+ Raises:
304
+ ValueError: If the model_name is not in the format <org>/<model> e.g. "jinaai/jina-embeddings-v2-small-en".
305
+ Returns:
306
+ Path: The path to the model directory.
307
+ """
308
+ from huggingface_hub import snapshot_download
309
+
310
+ return snapshot_download (
311
+ repo_id = repod_id , ignore_patterns = ["model.safetensors" , "pytorch_model.bin" ], cache_dir = cache_dir
312
+ )
313
+
285
314
@classmethod
286
315
def decompress_to_cache (cls , targz_path : str , cache_dir : str ):
287
316
"""
@@ -317,7 +346,7 @@ def decompress_to_cache(cls, targz_path: str, cache_dir: str):
317
346
318
347
return cache_dir
319
348
320
- def retrieve_model (self , model_name : str , cache_dir : str ) -> Path :
349
+ def retrieve_model_gcs (self , model_name : str , cache_dir : str ) -> Path :
321
350
"""
322
351
Retrieves a model from Google Cloud Storage.
323
352
@@ -361,6 +390,24 @@ def retrieve_model(self, model_name: str, cache_dir: str) -> Path:
361
390
362
391
return model_dir
363
392
393
+ def retrieve_model_hf (self , model_name : str , cache_dir : str ) -> Path :
394
+ """
395
+ Retrieves a model from HuggingFace Hub.
396
+ Args:
397
+ model_name (str): The name of the model to retrieve.
398
+ cache_dir (str): The path to the cache directory.
399
+ Raises:
400
+ ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
401
+ Returns:
402
+ Path: The path to the model directory.
403
+ """
404
+
405
+ assert (
406
+ "/" in model_name
407
+ ), "model_name must be in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-small-en"
408
+
409
+ return Path (self .download_files_from_huggingface (repod_id = model_name , cache_dir = cache_dir ))
410
+
364
411
def passage_embed (self , texts : Iterable [str ], ** kwargs ) -> Iterable [np .ndarray ]:
365
412
"""
366
413
Embeds a list of text passages into a list of embeddings.
@@ -425,7 +472,7 @@ def __init__(
425
472
cache_dir .mkdir (parents = True , exist_ok = True )
426
473
427
474
self ._cache_dir = cache_dir
428
- self ._model_dir = self .retrieve_model (model_name , cache_dir )
475
+ self ._model_dir = self .retrieve_model_gcs (model_name , cache_dir )
429
476
self ._max_length = max_length
430
477
431
478
self .model = EmbeddingModel (self ._model_dir , self .model_name , max_length = max_length ,
@@ -464,7 +511,8 @@ def embed(
464
511
465
512
if parallel is None or is_small :
466
513
for batch in iter_batch (documents , batch_size ):
467
- yield from self .model .onnx_embed (batch )
514
+ embeddings , _ = self .model .onnx_embed (batch )
515
+ yield from normalize (embeddings [:, 0 ]).astype (np .float32 )
468
516
else :
469
517
start_method = "forkserver" if "forkserver" in get_all_start_methods () else "spawn"
470
518
params = {
@@ -474,7 +522,16 @@ def embed(
474
522
}
475
523
pool = ParallelWorkerPool (parallel , EmbeddingWorker , start_method = start_method )
476
524
for batch in pool .ordered_map (iter_batch (documents , batch_size ), ** params ):
477
- yield from batch
525
+ embeddings , _ = batch
526
+ yield from normalize (embeddings [:, 0 ]).astype (np .float32 )
527
+
528
+ @classmethod
529
+ def list_supported_models (cls ) -> List [Dict [str , Union [str , Union [int , float ]]]]:
530
+ """
531
+ Lists the supported models.
532
+ """
533
+ # jina models are not supported by this class
534
+ return [model for model in super ().list_supported_models () if not model ['model' ].startswith ('jinaai' )]
478
535
479
536
480
537
class DefaultEmbedding (FlagEmbedding ):
@@ -505,3 +562,97 @@ def embed(self, texts, batch_size: int = 256, parallel: int = None):
505
562
# Use your OpenAI model to embed the texts
506
563
# return self.model.embed(texts)
507
564
raise NotImplementedError
565
+
566
+
567
+ class JinaEmbedding (Embedding ):
568
+ def __init__ (
569
+ self ,
570
+ model_name : str = "jinaai/jina-embeddings-v2-base-en" ,
571
+ max_length : int = 512 ,
572
+ cache_dir : str = None ,
573
+ threads : int = None ,
574
+ ):
575
+ """
576
+ Args:
577
+ model_name (str): The name of the model to use.
578
+ max_length (int, optional): The maximum number of tokens. Defaults to 512. Unknown behavior for values > 512.
579
+ cache_dir (str, optional): The path to the cache directory. Defaults to `local_cache` in the current directory.
580
+ threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
581
+ Raises:
582
+ ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
583
+ """
584
+ self .model_name = model_name
585
+
586
+ if cache_dir is None :
587
+ cache_dir = Path ("." ).resolve () / "local_cache"
588
+ cache_dir .mkdir (parents = True , exist_ok = True )
589
+
590
+ self ._cache_dir = cache_dir
591
+ self ._model_dir = self .retrieve_model_hf (model_name , cache_dir )
592
+ self ._max_length = max_length
593
+
594
+ self .model = EmbeddingModel (self ._model_dir , self .model_name , max_length = max_length ,
595
+ max_threads = threads )
596
+
597
+ def embed (
598
+ self , documents : Union [str , Iterable [str ]], batch_size : int = 256 , parallel : int = None
599
+ ) -> Iterable [np .ndarray ]:
600
+ """
601
+ Encode a list of documents into list of embeddings.
602
+ We use mean pooling with attention so that the model can handle variable-length inputs.
603
+ Args:
604
+ documents: Iterator of documents or single document to embed
605
+ batch_size: Batch size for encoding -- higher values will use more memory, but be faster
606
+ parallel:
607
+ If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
608
+ If 0, use all available cores.
609
+ If None, don't use data-parallel processing, use default onnxruntime threading instead.
610
+ Returns:
611
+ List of embeddings, one per document
612
+ """
613
+ is_small = False
614
+
615
+ if isinstance (documents , str ):
616
+ documents = [documents ]
617
+ is_small = True
618
+
619
+ if isinstance (documents , list ):
620
+ if len (documents ) < batch_size :
621
+ is_small = True
622
+
623
+ if parallel == 0 :
624
+ parallel = os .cpu_count ()
625
+
626
+ if parallel is None or is_small :
627
+ for batch in iter_batch (documents , batch_size ):
628
+ embeddings , attn_mask = self .model .onnx_embed (batch )
629
+ yield from normalize (self .mean_pooling (embeddings , attn_mask )).astype (np .float32 )
630
+ else :
631
+ start_method = "forkserver" if "forkserver" in get_all_start_methods () else "spawn"
632
+ params = {
633
+ "path" : self ._model_dir ,
634
+ "model_name" : self .model_name ,
635
+ "max_length" : self ._max_length ,
636
+ }
637
+ pool = ParallelWorkerPool (parallel , EmbeddingWorker , start_method = start_method )
638
+ for batch in pool .ordered_map (iter_batch (documents , batch_size ), ** params ):
639
+ embeddings , attn_mask = batch
640
+ yield from normalize (self .mean_pooling (embeddings , attn_mask )).astype (np .float32 )
641
+
642
+ @classmethod
643
+ def list_supported_models (cls ) -> List [Dict [str , Union [str , Union [int , float ]]]]:
644
+ """
645
+ Lists the supported models.
646
+ """
647
+ # only jina models are supported by this class
648
+ return [model for model in Embedding .list_supported_models () if model ['model' ].startswith ('jinaai' )]
649
+
650
+ @staticmethod
651
+ def mean_pooling (model_output , attention_mask ):
652
+ token_embeddings = model_output
653
+ input_mask_expanded = (np .expand_dims (attention_mask , axis = - 1 )).astype (float )
654
+
655
+ sum_embeddings = np .sum (token_embeddings * input_mask_expanded , axis = 1 )
656
+ mask_sum = np .clip (np .sum (input_mask_expanded , axis = 1 ), a_min = 1e-9 , a_max = None )
657
+
658
+ return sum_embeddings / mask_sum
0 commit comments