diff --git a/transformers4rec/torch/utils/padding.py b/transformers4rec/torch/utils/padding.py index b34333ad2..5d11dbfb3 100644 --- a/transformers4rec/torch/utils/padding.py +++ b/transformers4rec/torch/utils/padding.py @@ -25,7 +25,7 @@ def _pad_dense_tensor(t: torch.Tensor, length: int) -> torch.Tensor: return F.pad(input=t, pad=(0, pad_diff, 0, 0)) elif len(t.shape) == 3: pad_diff = length - t.shape[1] - return F.pad(input=t, pad=(0, pad_diff, 0, 0, 0, 0)) + return F.pad(input=t, (0, 0, 0, pad_diff, 0, 0)) else: return t