Skip to content

Commit a2cc7ee

Browse files
authored
Fix the output type of Padding op (#145)
test padding+embedding op
1 parent d9e97b4 commit a2cc7ee

File tree

2 files changed

+40
-13
lines changed

2 files changed

+40
-13
lines changed

merlin/dataloader/ops/padding.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,20 @@
1111

1212

1313
class Padding(BaseOperator):
14-
"""Create an operator that will apply a embedding table to supplied indices.
15-
This operator allows the user to supply an id lookup table if the indices supplied
16-
via the id_lookup_table.
14+
"""Create an operator that will apply right padding to a given sequence.
15+
This operator pads the sequence with a specified padding value up to a specified padding size.
16+
If the sequence is longer than the padding size,
17+
it is truncated to the first `padding_size` elements.
1718
1819
Parameters
1920
----------
20-
embeddings : np.ndarray
21-
numpy ndarray representing embedding values
22-
lookup_key : str, optional
23-
the name of the column that will be used as indices, by default "id"
24-
embedding_name : str, optional
25-
name of new column of embeddings, added to output, by default "embeddings"
26-
id_lookup_table : np.array, optional
27-
numpy array of values that represent embedding indices, by default None
21+
padding_size : int
22+
The target size for the padded sequence
23+
padding_value : Union[int, float], optional
24+
The value to be used for padding the sequence, by default 0
2825
"""
2926

30-
def __init__(self, padding_size: int, padding_value: Union[int, float]):
27+
def __init__(self, padding_size: int, padding_value: Union[int, float] = 0):
3128
self.padding_size = padding_size
3229
self.padding_value = padding_value
3330

@@ -76,7 +73,7 @@ def pad_put_zeros(column, padding_size, padding_val):
7673
# account for zero prepend
7774
array_lib = cupy if column.device == Device.GPU else np
7875
num_rows = len(column.offsets) - 1
79-
zeros = array_lib.zeros((num_rows, padding_size)).flatten()
76+
zeros = array_lib.zeros((num_rows, padding_size)).flatten() + padding_val
8077
row_lengths = column.offsets[1:] - column.offsets[:-1]
8178
row_ranges = []
8279
starts = array_lib.arange(num_rows) * padding_size
@@ -85,4 +82,5 @@ def pad_put_zeros(column, padding_size, padding_val):
8582
row_ranges.extend(array_lib.arange(int(starts[idx]), int(ends[idx])))
8683
array_lib.put(zeros, row_ranges, column.values)
8784
zeros = array_lib.reshape(zeros, (num_rows, padding_size))
85+
zeros = zeros.astype(column.dtype.element_type.value)
8886
return zeros

tests/unit/dataloader/test_embeddings.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from merlin.core.dispatch import HAS_GPU
2323
from merlin.dataloader.loader_base import LoaderBase as Loader # noqa
2424
from merlin.dataloader.ops.embeddings import EmbeddingOperator
25+
from merlin.dataloader.ops.padding import Padding
2526
from merlin.io import Dataset
2627
from merlin.schema import Tags
2728
from merlin.table import TensorColumn, TensorTable
@@ -245,3 +246,31 @@ def test_embedding_np_dl_with_lookup_ragged(
245246
assert (embeddings_offs == id_offsets).all()
246247
full_len += int(batch[0]["embeddings"].shape[0])
247248
assert full_len == offsets.shape[0] - 1
249+
250+
251+
def test_embedding_with_padding():
252+
max_length = 10
253+
batch_size = 3
254+
id_embeddings = np.random.rand(1000, 16)
255+
df = pd.DataFrame(
256+
{
257+
"id": [[0, 1, 2], [3, 4], [5, 6, 7, 8]],
258+
}
259+
)
260+
261+
dataset = Dataset(df)
262+
transform = (
263+
["id"]
264+
>> Padding(padding_size=max_length, padding_value=0)
265+
>> EmbeddingOperator(id_embeddings, lookup_key="id", embedding_name="id_embedding")
266+
)
267+
data_loader = Loader(
268+
dataset,
269+
batch_size=batch_size,
270+
transforms=transform,
271+
shuffle=False,
272+
)
273+
x, _ = data_loader.peek()
274+
assert x["id"].values.shape == (batch_size, max_length)
275+
assert x["id_embedding"].values.shape == (batch_size, max_length, 16)
276+
assert data_loader.output_schema.column_names == ["id", "id_embedding"]

0 commit comments

Comments
 (0)