1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515#
16+ import os
1617from typing import Optional , Union
1718
1819import numpy as np
2526
2627
2728class EmbeddingOperator (BaseOperator ):
28- """Create an operator that will apply a torch embedding table to supplied indices.
29- This operator allows the user to supply an id lookup table if the indices supplied
30- via the id_lookup_table.
29+ """Create an operator that will apply an embedding table to supplied indices.
30+
31+ An id lookup table for the embeddings can be supplied with the argument ` id_lookup_table` .
3132
3233 Parameters
3334 ----------
@@ -39,6 +40,12 @@ class EmbeddingOperator(BaseOperator):
3940 name of new column of embeddings, added to output, by default "embeddings"
4041 id_lookup_table : np.array, optional
4142 numpy array of values that represent embedding indices, by default None
43+ mmap : bool, default False
44+ When loading embeddings from a file, specify whether we should memory map the file
45+ This is useful for accessing a large file without reading the entire file into memory.
46+ unknown_value : Union[float, int, np.ndarray]
47+ Value used as the embedding for any lookup key that has no
48+ matching entry in `id_lookup_table`, by default 0.
4249 """
4350
4451 def __init__ (
@@ -47,24 +54,75 @@ def __init__(
4754 lookup_key : str = "id" ,
4855 embedding_name : str = "embeddings" ,
4956 id_lookup_table : Optional [Union [np .ndarray , str ]] = None ,
50- mmap = False ,
57+ mmap : bool = False ,
58+ unknown_value : Union [float , int , np .ndarray ] = 0 ,
5159 ):
52- if mmap :
53- embeddings = np .load (embeddings , mmap_mode = "r" )
54- id_lookup_table = np .load (id_lookup_table ) if id_lookup_table else None
60+ if isinstance (embeddings , (str , os .PathLike )):
61+ mmap_mode = "r" if mmap else None
62+ embeddings = np .load (embeddings , mmap_mode = mmap_mode )
63+ elif isinstance (embeddings , np .ndarray ):
64+ pass
65+ else :
66+ raise ValueError (
67+ f"Unsupported type '{ type (embeddings )} ' passed to argument `embeddings` "
68+ f"of '{ type (self ).__name__ } '. "
69+ "Expected either a numpy.ndarray "
70+ "or a (string or pathlike object corresponding to a numpy file) "
71+ "containing embeddings. "
72+ )
5573 self .embeddings = embeddings
74+
75+ embedding_index_mapping = None
76+ if isinstance (id_lookup_table , (str , os .PathLike )):
77+ _ids = np .load (id_lookup_table )
78+ embedding_index_mapping = self ._get_embedding_index_mapping (_ids )
79+ elif isinstance (id_lookup_table , np .ndarray ):
80+ _ids = id_lookup_table
81+ embedding_index_mapping = self ._get_embedding_index_mapping (_ids )
82+ elif id_lookup_table is None :
83+ pass
84+ else :
85+ raise ValueError (
86+ f"Unsupported type '{ type (id_lookup_table )} ' passed to argument `id_lookup_table` "
87+ f"of '{ type (self ).__name__ } '. "
88+ "Expected either a numpy.ndarray "
89+ "or a (string or pathlike object corresponding to a numpy file) "
90+ "containing the IDs that correspond to the embeddings. "
91+ )
92+ self .embedding_index_mapping = embedding_index_mapping
93+
5694 self .lookup_key = lookup_key
5795 self .embedding_name = embedding_name
58- self .id_lookup_table = id_lookup_table
96+ self .unknown_value = unknown_value
97+
98+ def _get_embedding_index_mapping (self , ids ):
99+ expected_ids_shape = (self .embeddings .shape [0 ],)
100+ assert ids .shape == expected_ids_shape , (
101+ "IDs provided must match the number of embeddings. "
102+ f"Expected IDs with shape { expected_ids_shape } "
103+ f"Received IDs with shape: { ids .shape } "
104+ f"Embeddings shape: { self .embeddings .shape } "
105+ )
106+ id_to_index_mapping = dict (zip (ids , range (len (ids ))))
107+ return id_to_index_mapping
59108
60109 def transform (
61110 self , col_selector : ColumnSelector , transformable : Transformable
62111 ) -> Transformable :
63112 keys = transformable [self .lookup_key ]
64113 indices = keys .cpu ().values
65- if self .id_lookup_table is not None :
66- indices = np .in1d (self .id_lookup_table , indices )
114+
115+ if self .embedding_index_mapping is not None :
116+ indices = np .array ([self .embedding_index_mapping .get (_id , - 1 ) for _id in indices ])
117+
67118 embeddings = self .embeddings [indices ]
119+
120+ # overwrite embeddings of unknown ids (index -1) with `unknown_value`
121+ for idx in np .ndindex (indices .shape ):
122+ embedding_index = indices [idx ]
123+ if embedding_index == - 1 :
124+ embeddings [idx ] = self .unknown_value
125+
68126 embeddings_col = TensorColumn (embeddings , offsets = keys .cpu ().offsets )
69127 transformable [self .embedding_name ] = (
70128 embeddings_col .gpu () if keys .device == Device .GPU else embeddings_col
0 commit comments