Commit f222d7c

add JinaEmbeddings class (#67)

* add JinaEmbeddings class

* fix tests dimensions

---------

Co-authored-by: Joan Fontanals Martinez <[email protected]>
1 parent d64b8f4 commit f222d7c
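
For orientation, a minimal usage sketch of the class this commit introduces. It mirrors the constructor defaults from the diff and the embed() calls exercised in the tests below; the snippet itself is illustrative and not part of the commit:

from fastembed.embedding import JinaEmbedding

# model_name defaults to "jinaai/jina-embeddings-v2-base-en" (768-dim);
# the smaller 512-dim variant is used here. Weights are fetched from
# HuggingFace Hub on first use.
model = JinaEmbedding(model_name="jinaai/jina-embeddings-v2-small-en")

docs = ["hello world", "flag embedding"]
# embed() yields one L2-normalized float32 vector per document
embeddings = list(model.embed(docs, batch_size=256))
print(len(embeddings), embeddings[0].shape)  # 2 (512,)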

File tree

2 files changed: +183 −28 lines changed

fastembed/embedding.py (+161 −10)

@@ -110,7 +110,7 @@ def __init__(
         self.tokenizer = self.load_tokenizer(self.path, max_length=max_length)
         self.model = ort.InferenceSession(str(model_path), providers=onnx_providers, sess_options=so)
 
-    def onnx_embed(self, documents: List[str]) -> np.ndarray:
+    def onnx_embed(self, documents: List[str]) -> Tuple[np.ndarray, np.ndarray]:
         encoded = self.tokenizer.encode_batch(documents)
         input_ids = np.array([e.ids for e in encoded])
         attention_mask = np.array([e.attention_mask for e in encoded])
@@ -126,9 +126,8 @@ def onnx_embed(self, documents: List[str]) -> np.ndarray:
         )
 
         model_output = self.model.run(None, onnx_input)
-        last_hidden_state = model_output[0][:, 0]
-        embeddings = normalize(last_hidden_state).astype(np.float32)
-        return embeddings
+        embeddings = model_output[0]
+        return embeddings, attention_mask
 
 
 class EmbeddingWorker(Worker):
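
With this change the pooling step leaves EmbeddingModel.onnx_embed: it now returns the model's raw last-hidden-state tensor together with the attention mask, so each Embedding subclass can pool differently. A hedged shape sketch of the new contract (dimensions illustrative, assuming a typical transformer ONNX export):

import numpy as np

# Stand-ins for the pair onnx_embed() now returns (the shapes are the point):
#   model_output[0] -> (batch, seq_len, hidden) raw token embeddings
#   attention_mask  -> (batch, seq_len), 1 for real tokens, 0 for padding
embeddings = np.zeros((2, 6, 384), dtype=np.float32)
attention_mask = np.array([[1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0]])
# FlagEmbedding pools with embeddings[:, 0] (CLS token);
# JinaEmbedding mean-pools over the mask (see mean_pooling below).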
@@ -150,8 +149,8 @@ def start(cls, path: Path, model_name: str, max_length: int = 512, **kwargs: Any
 
     def process(self, items: Iterable[Tuple[int, Any]]) -> Iterable[Tuple[int, Any]]:
         for idx, batch in items:
-            embeddings = self.model.onnx_embed(batch)
-            yield idx, embeddings
+            embeddings, attn_mask = self.model.onnx_embed(batch)
+            yield idx, (embeddings, attn_mask)
 
 
 class Embedding(ABC):
@@ -226,6 +225,18 @@ def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]
                 "description": "Multilingual model, e5-large. Recommend using this model for non-English languages",
                 "size_in_GB": 2.24
             },
+            {
+                "model": "jinaai/jina-embeddings-v2-base-en",
+                "dim": 768,
+                "description": " English embedding model supporting 8192 sequence length",
+                "size_in_GB": 0.55
+            },
+            {
+                "model": "jinaai/jina-embeddings-v2-small-en",
+                "dim": 512,
+                "description": " English embedding model supporting 8192 sequence length",
+                "size_in_GB": 0.13
+            }
         ]
 
     @classmethod
@@ -282,6 +293,24 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
             progress_bar.close()
         return output_path
 
+    @classmethod
+    def download_files_from_huggingface(cls, repod_id: str, cache_dir: Optional[str] = None) -> str:
+        """
+        Downloads a model from HuggingFace Hub.
+        Args:
+            repod_id (str): The HF hub id (name) of the model to retrieve.
+            cache_dir (Optional[str]): The path to the cache directory.
+        Raises:
+            ValueError: If the model_name is not in the format <org>/<model> e.g. "jinaai/jina-embeddings-v2-small-en".
+        Returns:
+            Path: The path to the model directory.
+        """
+        from huggingface_hub import snapshot_download
+
+        return snapshot_download(
+            repo_id=repod_id, ignore_patterns=["model.safetensors", "pytorch_model.bin"], cache_dir=cache_dir
+        )
+
     @classmethod
     def decompress_to_cache(cls, targz_path: str, cache_dir: str):
         """
@@ -317,7 +346,7 @@ def decompress_to_cache(cls, targz_path: str, cache_dir: str):
 
         return cache_dir
 
-    def retrieve_model(self, model_name: str, cache_dir: str) -> Path:
+    def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
         """
         Retrieves a model from Google Cloud Storage.
 
@@ -361,6 +390,24 @@ def retrieve_model(self, model_name: str, cache_dir: str) -> Path:
 
         return model_dir
 
+    def retrieve_model_hf(self, model_name: str, cache_dir: str) -> Path:
+        """
+        Retrieves a model from HuggingFace Hub.
+        Args:
+            model_name (str): The name of the model to retrieve.
+            cache_dir (str): The path to the cache directory.
+        Raises:
+            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
+        Returns:
+            Path: The path to the model directory.
+        """
+
+        assert (
+            "/" in model_name
+        ), "model_name must be in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-small-en"
+
+        return Path(self.download_files_from_huggingface(repod_id=model_name, cache_dir=cache_dir))
+
     def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
         """
         Embeds a list of text passages into a list of embeddings.
@@ -425,7 +472,7 @@ def __init__(
             cache_dir.mkdir(parents=True, exist_ok=True)
 
         self._cache_dir = cache_dir
-        self._model_dir = self.retrieve_model(model_name, cache_dir)
+        self._model_dir = self.retrieve_model_gcs(model_name, cache_dir)
         self._max_length = max_length
 
         self.model = EmbeddingModel(self._model_dir, self.model_name, max_length=max_length,
@@ -464,7 +511,8 @@ def embed(
 
         if parallel is None or is_small:
             for batch in iter_batch(documents, batch_size):
-                yield from self.model.onnx_embed(batch)
+                embeddings, _ = self.model.onnx_embed(batch)
+                yield from normalize(embeddings[:, 0]).astype(np.float32)
         else:
             start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
             params = {
@@ -474,7 +522,16 @@ def embed(
             }
             pool = ParallelWorkerPool(parallel, EmbeddingWorker, start_method=start_method)
             for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
-                yield from batch
+                embeddings, _ = batch
+                yield from normalize(embeddings[:, 0]).astype(np.float32)
+
+    @classmethod
+    def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
+        """
+        Lists the supported models.
+        """
+        # jina models are not supported by this class
+        return [model for model in super().list_supported_models() if not model['model'].startswith('jinaai')]
 
 
 class DefaultEmbedding(FlagEmbedding):
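
Since onnx_embed no longer normalizes, FlagEmbedding.embed reproduces the old behavior itself: CLS-token pooling (the first token's hidden state) followed by L2 normalization. A hedged numpy sketch of that step, with a local stand-in for the module's normalize() helper and toy shapes:

import numpy as np

def l2_normalize(x: np.ndarray) -> np.ndarray:
    # stand-in for normalize(): scale each row to unit L2 norm
    return x / np.linalg.norm(x, axis=1, keepdims=True)

last_hidden = np.random.rand(2, 6, 384).astype(np.float32)  # toy (batch, seq, hidden)
cls_vectors = last_hidden[:, 0]                             # embeddings[:, 0] in the diff
flag_embeddings = l2_normalize(cls_vectors).astype(np.float32)
print(np.linalg.norm(flag_embeddings, axis=1))              # ~[1. 1.]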
@@ -505,3 +562,97 @@ def embed(self, texts, batch_size: int = 256, parallel: int = None):
         # Use your OpenAI model to embed the texts
         # return self.model.embed(texts)
         raise NotImplementedError
+
+
+class JinaEmbedding(Embedding):
+    def __init__(
+        self,
+        model_name: str = "jinaai/jina-embeddings-v2-base-en",
+        max_length: int = 512,
+        cache_dir: str = None,
+        threads: int = None,
+    ):
+        """
+        Args:
+            model_name (str): The name of the model to use.
+            max_length (int, optional): The maximum number of tokens. Defaults to 512. Unknown behavior for values > 512.
+            cache_dir (str, optional): The path to the cache directory. Defaults to `local_cache` in the current directory.
+            threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
+        Raises:
+            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
+        """
+        self.model_name = model_name
+
+        if cache_dir is None:
+            cache_dir = Path(".").resolve() / "local_cache"
+            cache_dir.mkdir(parents=True, exist_ok=True)
+
+        self._cache_dir = cache_dir
+        self._model_dir = self.retrieve_model_hf(model_name, cache_dir)
+        self._max_length = max_length
+
+        self.model = EmbeddingModel(self._model_dir, self.model_name, max_length=max_length,
+                                    max_threads=threads)
+
+    def embed(
+        self, documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: int = None
+    ) -> Iterable[np.ndarray]:
+        """
+        Encode a list of documents into list of embeddings.
+        We use mean pooling with attention so that the model can handle variable-length inputs.
+        Args:
+            documents: Iterator of documents or single document to embed
+            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
+            parallel:
+                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
+                If 0, use all available cores.
+                If None, don't use data-parallel processing, use default onnxruntime threading instead.
+        Returns:
+            List of embeddings, one per document
+        """
+        is_small = False
+
+        if isinstance(documents, str):
+            documents = [documents]
+            is_small = True
+
+        if isinstance(documents, list):
+            if len(documents) < batch_size:
+                is_small = True
+
+        if parallel == 0:
+            parallel = os.cpu_count()
+
+        if parallel is None or is_small:
+            for batch in iter_batch(documents, batch_size):
+                embeddings, attn_mask = self.model.onnx_embed(batch)
+                yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
+        else:
+            start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
+            params = {
+                "path": self._model_dir,
+                "model_name": self.model_name,
+                "max_length": self._max_length,
+            }
+            pool = ParallelWorkerPool(parallel, EmbeddingWorker, start_method=start_method)
+            for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
+                embeddings, attn_mask = batch
+                yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
+
+    @classmethod
+    def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
+        """
+        Lists the supported models.
+        """
+        # only jina models are supported by this class
+        return [model for model in Embedding.list_supported_models() if model['model'].startswith('jinaai')]
+
+    @staticmethod
+    def mean_pooling(model_output, attention_mask):
+        token_embeddings = model_output
+        input_mask_expanded = (np.expand_dims(attention_mask, axis=-1)).astype(float)
+
+        sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
+        mask_sum = np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None)
+
+        return sum_embeddings / mask_sum
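
To make the docstring's "mean pooling with attention" concrete, a small hand-checkable numpy example of what mean_pooling computes (toy values, not from the commit):

import numpy as np

# One document, 4 token positions, hidden size 2; the last position is
# padding (mask = 0) and must not affect the mean.
token_embeddings = np.array([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [9.0, 9.0]]])
attention_mask = np.array([[1, 1, 1, 0]])

mask = np.expand_dims(attention_mask, axis=-1).astype(float)    # (1, 4, 1)
summed = np.sum(token_embeddings * mask, axis=1)                # [[ 9., 12.]]
counts = np.clip(np.sum(mask, axis=1), a_min=1e-9, a_max=None)  # [[3.]]
print(summed / counts)  # [[3. 4.]] -- mean over the 3 real tokens only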

tests/test_onnx_embeddings.py (+22 −18)

@@ -1,30 +1,32 @@
 import os
-
+import pytest
 import numpy as np
-from tqdm import tqdm
 
-from fastembed.embedding import DefaultEmbedding, Embedding
+from fastembed.embedding import DefaultEmbedding, JinaEmbedding
 
 CANONICAL_VECTOR_VALUES = {
-    "BAAI/bge-small-en": np.array([-0.0232, -0.0255, 0.0174, -0.0639, -0.0006]),
-    "BAAI/bge-small-en-v1.5": np.array([0.01522374, -0.02271799, 0.00860278, -0.07424029, 0.00386434]),
-    "BAAI/bge-small-zh-v1.5": np.array([-0.01023294, 0.07634465, 0.0691722 , -0.04458365, -0.03160762]),
-    "BAAI/bge-base-en": np.array([0.0115, 0.0372, 0.0295, 0.0121, 0.0346]),
+    "BAAI/bge-small-en": np.array([-0.0232, -0.0255, 0.0174, -0.0639, -0.0006]),
+    "BAAI/bge-small-en-v1.5": np.array([0.01522374, -0.02271799, 0.00860278, -0.07424029, 0.00386434]),
+    "BAAI/bge-small-zh-v1.5": np.array([-0.01023294, 0.07634465, 0.0691722, -0.04458365, -0.03160762]),
+    "BAAI/bge-base-en": np.array([0.0115, 0.0372, 0.0295, 0.0121, 0.0346]),
     "BAAI/bge-base-en-v1.5": np.array([0.01129394, 0.05493144, 0.02615099, 0.00328772, 0.02996045]),
-    "sentence-transformers/all-MiniLM-L6-v2": np.array([0.0259, 0.0058, 0.0114, 0.0380, -0.0233]),
-    "intfloat/multilingual-e5-large": np.array([0.0098, 0.0045, 0.0066, -0.0354, 0.0070]),
+    "sentence-transformers/all-MiniLM-L6-v2": np.array([0.0259, 0.0058, 0.0114, 0.0380, -0.0233]),
+    "intfloat/multilingual-e5-large": np.array([0.0098, 0.0045, 0.0066, -0.0354, 0.0070]),
+    "jinaai/jina-embeddings-v2-small-en": np.array([-0.0455, -0.0428, -0.0122, 0.0613, 0.0015]),
+    "jinaai/jina-embeddings-v2-base-en": np.array([-0.0332, -0.0509, 0.0287, -0.0043, -0.0077]),
 }
 
 
-def test_default_embedding():
+@pytest.mark.parametrize('embedding_class', [DefaultEmbedding, JinaEmbedding])
+def test_embedding(embedding_class):
     is_ubuntu_ci = os.getenv("IS_UBUNTU_CI")
 
-    for model_desc in Embedding.list_supported_models():
+    for model_desc in embedding_class.list_supported_models():
         if is_ubuntu_ci == "false" and model_desc["size_in_GB"] > 1:
             continue
 
         dim = model_desc["dim"]
-        model = DefaultEmbedding(model_name=model_desc["model"])
+        model = embedding_class(model_name=model_desc["model"])
 
         docs = ["hello world", "flag embedding"]
         embeddings = list(model.embed(docs))
@@ -35,18 +37,20 @@ def test_default_embedding():
         assert np.allclose(embeddings[0, :canonical_vector.shape[0]], canonical_vector, atol=1e-3), model_desc["model"]
 
 
-def test_batch_embedding():
-    model = DefaultEmbedding()
+@pytest.mark.parametrize('n_dims,embedding_class', [(384, DefaultEmbedding), (768, JinaEmbedding)])
+def test_batch_embedding(n_dims, embedding_class):
+    model = embedding_class()
 
     docs = ["hello world", "flag embedding"] * 100
     embeddings = list(model.embed(docs, batch_size=10))
     embeddings = np.stack(embeddings, axis=0)
 
-    assert embeddings.shape == (200, 384)
+    assert embeddings.shape == (200, n_dims)
 
 
-def test_parallel_processing():
-    model = DefaultEmbedding()
+@pytest.mark.parametrize('n_dims,embedding_class', [(384, DefaultEmbedding), (768, JinaEmbedding)])
+def test_parallel_processing(n_dims, embedding_class):
+    model = embedding_class()
 
     docs = ["hello world", "flag embedding"] * 100
     embeddings = list(model.embed(docs, batch_size=10, parallel=2))
@@ -58,6 +62,6 @@ def test_parallel_processing():
     embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0))
     embeddings_3 = np.stack(embeddings_3, axis=0)
 
-    assert embeddings.shape == (200, 384)
+    assert embeddings.shape == (200, n_dims)
     assert np.allclose(embeddings, embeddings_2, atol=1e-3)
     assert np.allclose(embeddings, embeddings_3, atol=1e-3)
