1
1
import logging
2
- from typing import Callable , Dict , Iterable , List , Literal , Optional , Tuple , Union
2
+ from typing import Any , Callable , Dict , Iterable , List , Literal , Optional , Tuple , Union
3
3
4
4
import faiss
5
5
import numpy as np
@@ -60,6 +60,9 @@ class AutoModelForEmbedding(nn.Module):
60
60
from the Hugging Face Hub with that name.
61
61
"""
62
62
63
+ encode_kwargs : Dict [str , Any ] = dict ()
64
+ show_progress : bool = False
65
+
63
66
def __init__ (
64
67
self ,
65
68
model_name_or_path : str ,
@@ -184,7 +187,17 @@ def forward_from_loader(self, inputs):
184
187
return embeddings
185
188
186
189
def forward_from_text(self, texts):
    """Tokenize ``texts``, append EOS to every sequence, and embed the batch.

    Args:
        texts: Input handed to ``self.tokenizer``. The EOS step iterates
            ``input_ids`` as one id list per text, so a batch (list) of
            strings is expected.

    Returns:
        Whatever ``self.forward_from_loader`` returns for the padded batch.
    """
    # Reserve one position for the EOS token appended below so that the
    # final sequence length stays within ``self.max_length``.
    token_budget = self.max_length
    if token_budget is not None:
        token_budget -= 1

    # Tokenize without padding or attention masks; both are produced by the
    # ``pad`` call once the EOS token is in place.
    batch_dict = self.tokenizer(
        texts,
        max_length=token_budget,
        return_attention_mask=False,
        padding=False,
        truncation=True,
    )

    eos_id = self.tokenizer.eos_token_id
    batch_dict["input_ids"] = [ids + [eos_id] for ids in batch_dict["input_ids"]]

    batch_dict = self.tokenizer.pad(
        batch_dict,
        padding=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    # ``token_type_ids`` may be absent from the tokenizer output; drop it
    # only when present instead of raising KeyError.
    batch_dict.pop("token_type_ids", None)
    return self.forward_from_loader(batch_dict)
188
201
189
202
def encode (
190
203
self ,
@@ -197,7 +210,7 @@ def encode(
197
210
device : str = None ,
198
211
normalize_embeddings : bool = False ,
199
212
):
200
- if isinstance (inputs , DataLoader ):
213
+ if isinstance (inputs , ( BatchEncoding , Dict ) ):
201
214
return self .encode_from_loader (
202
215
loader = inputs ,
203
216
batch_size = batch_size ,
@@ -208,7 +221,7 @@ def encode(
208
221
device = device ,
209
222
normalize_embeddings = normalize_embeddings ,
210
223
)
211
- elif isinstance (inputs , (str , Iterable )):
224
+ elif isinstance (inputs , (str , List , Tuple )):
212
225
return self .encode_from_text (
213
226
sentences = inputs ,
214
227
batch_size = batch_size ,
@@ -219,6 +232,17 @@ def encode(
219
232
device = device ,
220
233
normalize_embeddings = normalize_embeddings ,
221
234
)
235
+ else :
236
+ raise ValueError
237
+
238
def embed_documents(self, texts: List[str]) -> List[List[float]]:
    """Compute document embeddings and return them as plain Python lists.

    Args:
        texts: Documents to embed; forwarded to ``self.encode``.

    Returns:
        The result of ``self.encode`` converted with ``.tolist()``.
    """
    # Merge into a single dict so a ``show_progress_bar`` entry in
    # ``encode_kwargs`` overrides ``show_progress`` instead of triggering a
    # "multiple values for keyword argument" TypeError.
    options = {"show_progress_bar": self.show_progress, **self.encode_kwargs}
    embeddings = self.encode(texts, **options)
    return embeddings.tolist()
242
+
243
def embed_query(self, text: str) -> List[float]:
    """Compute the embedding of a single query string.

    The query is wrapped in a one-element list, sent through
    ``self.embed_documents``, and the only resulting row is returned.
    """
    single_batch = [text]
    rows = self.embed_documents(single_batch)
    return rows[0]
222
246
223
247
def encode_from_loader (
224
248
self ,
0 commit comments