1111
1212
1313class Padding (BaseOperator ):
14- """Create an operator that will apply a embedding table to supplied indices.
15- This operator allows the user to supply an id lookup table if the indices supplied
16- via the id_lookup_table.
14+ """Create an operator that will apply right padding to a given sequence.
15+ This operator pads the sequence with a specified padding value up to a specified padding size.
16+ If the sequence is longer than the padding size,
17+ it is truncated to the first `padding_size` elements.
1718
1819 Parameters
1920 ----------
20- embeddings : np.ndarray
21- numpy ndarray representing embedding values
22- lookup_key : str, optional
23- the name of the column that will be used as indices, by default "id"
24- embedding_name : str, optional
25- name of new column of embeddings, added to output, by default "embeddings"
26- id_lookup_table : np.array, optional
27- numpy array of values that represent embedding indices, by default None
21+ padding_size : int
22+ The target size for the padded sequence
23+ padding_value : Union[int, float], optional
24+ The value to be used for padding the sequence, by default 0
2825 """
2926
30- def __init__ (self , padding_size : int , padding_value : Union [int , float ]):
27+ def __init__ (self , padding_size : int , padding_value : Union [int , float ] = 0 ):
3128 self .padding_size = padding_size
3229 self .padding_value = padding_value
3330
@@ -76,7 +73,7 @@ def pad_put_zeros(column, padding_size, padding_val):
7673 # account for zero prepend
7774 array_lib = cupy if column .device == Device .GPU else np
7875 num_rows = len (column .offsets ) - 1
79- zeros = array_lib .zeros ((num_rows , padding_size )).flatten ()
76+ zeros = array_lib .zeros ((num_rows , padding_size )).flatten () + padding_val
8077 row_lengths = column .offsets [1 :] - column .offsets [:- 1 ]
8178 row_ranges = []
8279 starts = array_lib .arange (num_rows ) * padding_size
@@ -85,4 +82,5 @@ def pad_put_zeros(column, padding_size, padding_val):
8582 row_ranges .extend (array_lib .arange (int (starts [idx ]), int (ends [idx ])))
8683 array_lib .put (zeros , row_ranges , column .values )
8784 zeros = array_lib .reshape (zeros , (num_rows , padding_size ))
85+ zeros = zeros .astype (column .dtype .element_type .value )
8886 return zeros
0 commit comments