Skip to content

Commit

Permalink
Bump version
Browse files Browse the repository at this point in the history
  • Loading branch information
I8dNLo committed Dec 24, 2024
1 parent c0f6a5b commit 9631970
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions fastembed/text/onnx_embedding.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Iterable, List, Optional, Sequence, Type, Union
from typing import Any, Iterable, Optional, Sequence, Type, Union

import numpy as np

Expand Down Expand Up @@ -168,16 +168,15 @@
"model_file": "onnx/model.onnx",
},
{
"model": "akshayballal/colpali-v1.2-merged",
"dim": 128,
"description": "",
"license": "mit",
"size_in_GB": 6.08,
"model": "jinaai/jina-clip-v1",
"dim": 768,
"description": "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: not necessary, 2024 year",
"license": "apache-2.0",
"size_in_GB": 0.55,
"sources": {
"hf": "akshayballal/colpali-v1.2-merged-onnx",
"hf": "jinaai/jina-clip-v1",
},
"additional_files": ["model.onnx_data"],
"model_file": "model.onnx",
"model_file": "onnx/text_model.onnx",
},
]

Expand All @@ -186,12 +185,12 @@ class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[np.ndarray]):
"""Implementation of the Flag Embedding model."""

@classmethod
def list_supported_models(cls) -> List[Dict[str, Any]]:
def list_supported_models(cls) -> list[dict[str, Any]]:
"""
Lists the supported models.
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the model information.
list[dict[str, Any]]: A list of dictionaries containing the model information.
"""
return supported_onnx_models

Expand All @@ -202,7 +201,7 @@ def __init__(
threads: Optional[int] = None,
providers: Optional[Sequence[OnnxProvider]] = None,
cuda: bool = False,
device_ids: Optional[List[int]] = None,
device_ids: Optional[list[int]] = None,
lazy_load: bool = False,
device_id: Optional[int] = None,
**kwargs,
Expand All @@ -218,7 +217,7 @@ def __init__(
Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
Defaults to False.
device_ids (Optional[List[int]], optional): The list of device ids to use for data parallel processing in
device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
Expand Down Expand Up @@ -291,16 +290,22 @@ def _get_worker_class(cls) -> Type["TextEmbeddingWorker"]:
return OnnxTextEmbeddingWorker

def _preprocess_onnx_input(
self, onnx_input: Dict[str, np.ndarray], **kwargs
) -> Dict[str, np.ndarray]:
self, onnx_input: dict[str, np.ndarray], **kwargs
) -> dict[str, np.ndarray]:
"""
Preprocess the onnx input.
"""
return onnx_input

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
embeddings = output.model_output
return normalize(embeddings[:, 0]).astype(np.float32)
if embeddings.ndim == 3: # (batch_size, seq_len, embedding_dim)
processed_embeddings = embeddings[:, 0]
elif embeddings.ndim == 2: # (batch_size, embedding_dim)
processed_embeddings = embeddings
else:
raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")
return normalize(processed_embeddings).astype(np.float32)

def load_onnx_model(self) -> None:
self._load_onnx_model(
Expand Down

0 comments on commit 9631970

Please sign in to comment.