diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index ab992e0580..6eb9fb4cbb 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3101,27 +3101,40 @@ def aten_embedding_bag_padding_idx( sparse: bool = False, per_sample_weights: Optional[TFloat] = None, include_last_offset: bool = False, - padding_idx: int = -1, + padding_idx: Optional[int] = None, ) -> Tuple[TFloat, TFloat, TFloat, TFloat]: """embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor) We add default values for the attributes to accommodate _embedding_bag as well: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) """ - assert padding_idx is not None, ( - "padding_idx must not be None. This is likely a dispatcher error" - ) if per_sample_weights is None: per_sample_weights = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(indices)) per_sample_weights = op.CastLike(per_sample_weights, weight) - # Change padding_idx to positive value, -1 means the last index - if padding_idx < 0: - padding_idx = weight.shape[0] + padding_idx + if padding_idx is not None: + # Call the existing function for handling padding_idx + result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx( + weight, + indices, + offsets, + mode, + per_sample_weights, + include_last_offset, + padding_idx, + ) - result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx( - weight, indices, offsets, mode, per_sample_weights, include_last_offset, padding_idx + return result, offset2bag, bag_size, max_indices + + # When padding_idx is None, use the standard embedding_bag implementation + result, offset2bag, bag_size, max_indices = _aten_embedding_bag_onnx( + weight, + indices, + offsets, + mode, + per_sample_weights, + include_last_offset, ) return result, offset2bag, bag_size, max_indices diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index ca80cf5172..3d81896187 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -2210,6 +2210,44 @@ def __init__(self): sample_inputs_func=sample_inputs_embedding_bag_padding_idx, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.embedding_bag.padding_idx_none", + op=torch.nn.functional.embedding_bag, + dtypes=common_dtype.floating_types_and_half(), + sample_inputs_func=lambda op_info, device, dtype, requires_grad: [ + opinfo_core.SampleInput( + torch.tensor( + [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]], + dtype=dtype, + device=device, + ), + args=( + torch.tensor([0, 1, 2, 3], dtype=torch.int64, device=device), + torch.tensor([0, 2], dtype=torch.int64, device=device), + ), + kwargs={"padding_idx": None}, + ) + ], + ), + opinfo_core.OpInfo( + "ops.aten.embedding_bag.padding_idx_int", + op=torch.nn.functional.embedding_bag, + dtypes=common_dtype.floating_types_and_half(), + sample_inputs_func=lambda op_info, device, dtype, requires_grad: [ + opinfo_core.SampleInput( + torch.tensor( + [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]], + dtype=dtype, + device=device, + ), + args=( + torch.tensor([0, 1, 2], dtype=torch.int64, device=device), + torch.tensor([0, 2], dtype=torch.int64, device=device), + ), + kwargs={"padding_idx": 0}, + ) + ], + ), opinfo_core.OpInfo( "ops.aten.embedding_renorm", aten_name="embedding_renorm", diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 7af7413185..183b23cc4c 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -185,6 +185,25 @@ def xfail( # Modify this section ########################################################## +def _embedding_bag_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + # ONNX attributes cannot be None; omit padding_idx if it's None. + if "padding_idx" in kwargs: + padding_idx = kwargs.pop("padding_idx") + if padding_idx is not None: + kwargs["padding_idx"] = int(padding_idx) + + # Ensure indices/offsets are int64 (positional: weight, indices, offsets, ...) + if len(args) >= 3: + if isinstance(args[1], torch.Tensor): + args[1] = args[1].to(torch.long) + if isinstance(args[2], torch.Tensor): + args[2] = args[2].to(torch.long) + + return args, kwargs + + def _amin_amax_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -1035,15 +1054,27 @@ def _where_input_wrangler( core_ops.aten_embedding_bag, tolerance={torch.float32: (1e-4, 5e-4)}, compare_shape_only_for_output=(1, 2, 3), + input_wrangler=_embedding_bag_input_wrangler, ).skip( dtypes=(torch.float16,), reason="fixme: results mismatch in torch nightly.", ), + TorchLibOpInfo( + "ops.aten.embedding_bag.padding_idx_none", + core_ops.aten_embedding_bag, + input_wrangler=_embedding_bag_input_wrangler, + ), + TorchLibOpInfo( + "ops.aten.embedding_bag.padding_idx_int", + core_ops.aten_embedding_bag, + input_wrangler=_embedding_bag_input_wrangler, + ), TorchLibOpInfo( "ops.aten.embedding_bag.padding_idx", core_ops.aten_embedding_bag_padding_idx, tolerance={torch.float16: (1e-2, 1e-2)}, compare_shape_only_for_output=(1, 2, 3), + input_wrangler=_embedding_bag_input_wrangler, ), TorchLibOpInfo( "ops.aten.embedding_renorm",