31 changes: 22 additions & 9 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -3101,27 +3101,40 @@ def aten_embedding_bag_padding_idx(
     sparse: bool = False,
     per_sample_weights: Optional[TFloat] = None,
     include_last_offset: bool = False,
-    padding_idx: int = -1,
+    padding_idx: Optional[int] = None,
 ) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
     """embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)
 
     We add default values for the attributes to accommodate _embedding_bag as well:
     _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1)
     """
-    assert padding_idx is not None, (
-        "padding_idx must not be None. This is likely a dispatcher error"
-    )
 
     if per_sample_weights is None:
         per_sample_weights = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(indices))
         per_sample_weights = op.CastLike(per_sample_weights, weight)
 
-    # Change padding_idx to positive value, -1 means the last index
-    if padding_idx < 0:
-        padding_idx = weight.shape[0] + padding_idx
+    if padding_idx is not None:
+        # Call the existing function for handling padding_idx
+        result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx(
+            weight,
+            indices,
+            offsets,
+            mode,
+            per_sample_weights,
+            include_last_offset,
+            padding_idx,
+        )
 
-    result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx(
-        weight, indices, offsets, mode, per_sample_weights, include_last_offset, padding_idx
+        return result, offset2bag, bag_size, max_indices
+
+    # When padding_idx is None, use the standard embedding_bag implementation
+    result, offset2bag, bag_size, max_indices = _aten_embedding_bag_onnx(
+        weight,
+        indices,
+        offsets,
+        mode,
+        per_sample_weights,
+        include_last_offset,
     )
 
     return result, offset2bag, bag_size, max_indices
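
For reference, here is a minimal eager-mode sketch of the behavior the two branches above have to reproduce (illustrative only, not part of the diff; "sum" mode is assumed so the expected bags are easy to read off):

import torch

weight = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
indices = torch.tensor([0, 1, 2])
offsets = torch.tensor([0, 2])

# padding_idx=None: every lookup contributes to its bag.
out_none = torch.nn.functional.embedding_bag(indices, weight, offsets, mode="sum")
# bag 0 = weight[0] + weight[1] = [3., 3.]; bag 1 = weight[2] = [3., 3.]

# padding_idx=0: lookups of index 0 are excluded from the reduction.
out_pad = torch.nn.functional.embedding_bag(
    indices, weight, offsets, mode="sum", padding_idx=0
)
# bag 0 = weight[1] = [2., 2.]; bag 1 = weight[2] = [3., 3.]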
38 changes: 38 additions & 0 deletions tests/function_libs/torch_lib/extra_opinfo.py
@@ -2210,6 +2210,44 @@ def __init__(self):
         sample_inputs_func=sample_inputs_embedding_bag_padding_idx,
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten.embedding_bag.padding_idx_none",
+        op=torch.nn.functional.embedding_bag,
+        dtypes=common_dtype.floating_types_and_half(),
+        sample_inputs_func=lambda op_info, device, dtype, requires_grad: [
+            opinfo_core.SampleInput(
+                torch.tensor(
+                    [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]],
+                    dtype=dtype,
+                    device=device,
+                ),
+                args=(
+                    torch.tensor([0, 1, 2, 3], dtype=torch.int64, device=device),
+                    torch.tensor([0, 2], dtype=torch.int64, device=device),
+                ),
+                kwargs={"padding_idx": None},
+            )
+        ],
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten.embedding_bag.padding_idx_int",
+        op=torch.nn.functional.embedding_bag,
+        dtypes=common_dtype.floating_types_and_half(),
+        sample_inputs_func=lambda op_info, device, dtype, requires_grad: [
+            opinfo_core.SampleInput(
+                torch.tensor(
+                    [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]],
+                    dtype=dtype,
+                    device=device,
+                ),
+                args=(
+                    torch.tensor([0, 1, 2], dtype=torch.int64, device=device),
+                    torch.tensor([0, 2], dtype=torch.int64, device=device),
+                ),
+                kwargs={"padding_idx": 0},
+            )
+        ],
+    ),
     opinfo_core.OpInfo(
         "ops.aten.embedding_renorm",
         aten_name="embedding_renorm",
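
A quick hand check of what these two samples should produce in eager mode (not part of the diff; torch.nn.functional.embedding_bag's default "mean" reduction applies because the samples set no mode):

import torch

# padding_idx_none sample: 4x3 weight, indices [0, 1, 2, 3], offsets [0, 2].
weight = torch.tensor(
    [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]]
)
out = torch.nn.functional.embedding_bag(
    torch.tensor([0, 1, 2, 3]), weight, torch.tensor([0, 2]), padding_idx=None
)
# bag 0 = mean(rows 0, 1) = [1.5, 1.5, 1.5]; bag 1 = mean(rows 2, 3) = [3.5, 3.5, 3.5]

# padding_idx_int sample: 3x3 weight, indices [0, 1, 2], offsets [0, 2], padding_idx=0.
weight2 = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
out2 = torch.nn.functional.embedding_bag(
    torch.tensor([0, 1, 2]), weight2, torch.tensor([0, 2]), padding_idx=0
)
# Index 0 is padding and is excluded from the reduction, so bag 0 reduces over
# row 1 only -> [2.0, 2.0, 2.0]; bag 1 = row 2 -> [3.0, 3.0, 3.0]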
31 changes: 31 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -185,6 +185,25 @@ def xfail(
 # Modify this section ##########################################################
 
 
+def _embedding_bag_input_wrangler(
+    args: list[Any], kwargs: dict[str, Any]
+) -> tuple[list[Any], dict[str, Any]]:
+    # ONNX attributes cannot be None; omit padding_idx if it's None.
+    if "padding_idx" in kwargs:
+        padding_idx = kwargs.pop("padding_idx")
+        if padding_idx is not None:
+            kwargs["padding_idx"] = int(padding_idx)
+
+    # Ensure indices/offsets are int64 (positional: weight, indices, offsets, ...)
+    if len(args) >= 3:
+        if isinstance(args[1], torch.Tensor):
+            args[1] = args[1].to(torch.long)
+        if isinstance(args[2], torch.Tensor):
+            args[2] = args[2].to(torch.long)
+
+    return args, kwargs
+
+
 def _amin_amax_input_wrangler(
     args: list[Any], kwargs: dict[str, Any]
 ) -> tuple[list[Any], dict[str, Any]]:
@@ -1035,15 +1054,27 @@ def _where_input_wrangler(
         core_ops.aten_embedding_bag,
         tolerance={torch.float32: (1e-4, 5e-4)},
         compare_shape_only_for_output=(1, 2, 3),
+        input_wrangler=_embedding_bag_input_wrangler,
     ).skip(
         dtypes=(torch.float16,),
         reason="fixme: results mismatch in torch nightly.",
     ),
+    TorchLibOpInfo(
+        "ops.aten.embedding_bag.padding_idx_none",
+        core_ops.aten_embedding_bag,
+        input_wrangler=_embedding_bag_input_wrangler,
+    ),
+    TorchLibOpInfo(
+        "ops.aten.embedding_bag.padding_idx_int",
+        core_ops.aten_embedding_bag,
+        input_wrangler=_embedding_bag_input_wrangler,
+    ),
     TorchLibOpInfo(
         "ops.aten.embedding_bag.padding_idx",
         core_ops.aten_embedding_bag_padding_idx,
         tolerance={torch.float16: (1e-2, 1e-2)},
         compare_shape_only_for_output=(1, 2, 3),
+        input_wrangler=_embedding_bag_input_wrangler,
     ),
     TorchLibOpInfo(
         "ops.aten.embedding_renorm",