Add unique op #1547

Open
wants to merge 12 commits into base: main
71 changes: 69 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -8372,16 +8372,83 @@
    raise NotImplementedError()


@torch_op("aten::_unique")
def aten__unique(
self: TensorType,
sorted: bool = True, # pylint: disable=unused-argument
return_inverse: bool = False,
) -> tuple[TensorType, TensorType]:
"""_unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)"""

unique_values, _, inverse_indices, _ = op.Unique(self, axis=None, sorted=True)
input_size = op.Shape(self)
if return_inverse:
inverse_indices = op.Reshape(inverse_indices, input_size)
else:
input_numel = op.ReduceProd(input_size, keepdims=False)
if input_numel == 0:
inverse_indices = op.Reshape(inverse_indices, input_size)

        else:
            inverse_indices = op.ConstantOfShape([0], value=[0])
    return unique_values, inverse_indices
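
Note (reviewer sketch, not part of this diff): for reference, the eager-mode behaviour this function mirrors. `aten::_unique` returns the sorted unique values plus an inverse-index tensor shaped like the input, which is what the `Reshape(inverse_indices, input_size)` branch reproduces; the values below are illustrative.

    import torch

    x = torch.tensor([[1, 3], [3, 2]])
    values, inverse = torch.ops.aten._unique(x, sorted=True, return_inverse=True)
    # values  -> tensor([1, 2, 3])          flattened, sorted unique values
    # inverse -> tensor([[0, 2], [2, 1]])   same shape as x, indices into `values`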


@torch_op("aten::_unique2")
def aten__unique2(
self: TensorType,
sorted: bool = True, # pylint: disable=unused-argument
return_inverse: bool = False,
    return_counts: bool = False,
) -> tuple[TensorType, TensorType, TensorType]:
    """_unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""

    unique_values, indices, inverse_indices, counts = op.Unique(self, axis=None, sorted=True)
    input_size = op.Shape(self)
    if return_inverse:
        inverse_indices = op.Reshape(inverse_indices, input_size)
    else:
        input_numel = op.ReduceProd(input_size, keepdims=False)
        if input_numel == 0:
            inverse_indices = op.Reshape(inverse_indices, input_size)

        else:
            inverse_indices = op.ConstantOfShape([0], value=[0])
    if return_counts:
        # HACK: force indices to be in the graph so that it gets a name during optimization
        # Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
        indices_size = op.Shape(indices)
        counts = op.Reshape(counts, indices_size)
    else:
        counts = op.ConstantOfShape([0], value=[0])
    return unique_values, inverse_indices, counts
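
Note (reviewer sketch, not part of this diff): `aten::_unique2` always returns a triple, and the branches above fall back to an empty placeholder built with `ConstantOfShape([0])` when an optional output is not requested. A hedged eager-mode reference with illustrative values:

    import torch

    x = torch.tensor([2, 2, 1, 3, 3, 3])
    values, inverse, counts = torch.ops.aten._unique2(
        x, sorted=True, return_inverse=True, return_counts=True
    )
    # values  -> tensor([1, 2, 3])
    # inverse -> tensor([1, 1, 0, 2, 2, 2])   same shape as x
    # counts  -> tensor([1, 2, 3])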


@torch_op("aten::unique_dim")
def aten_unique_dim(
self: TensorType,
dim: int,
sorted: bool = True,
sorted: bool = True, # pylint: disable=unused-argument
return_inverse: bool = False,
return_counts: bool = False,
is_cuda: bool = False
) -> tuple[TensorType, TensorType, TensorType]:
    """unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""

    unique_values, indices, inverse_indices, counts = op.Unique(self, axis=dim, sorted=True)

    if return_inverse:
        input_size = op.Shape(self)
        inverse_indices = op.Reshape(inverse_indices, op.Reshape(input_size[dim], [-1]))

    else:
        inverse_indices = op.ConstantOfShape([0], value=[0])

    if return_counts:
        # HACK: force indices to be in the graph so that it gets a name during optimization
        # Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
        indices_size = op.Shape(indices)
        counts = op.Reshape(counts, indices_size)
        output_size = op.Shape(unique_values)
        counts = op.Reshape(counts, op.Reshape(output_size[dim], [-1]))

    else:
        counts = op.ConstantOfShape([0], value=[0])
    return unique_values, inverse_indices, counts
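
Note (reviewer sketch, not part of this diff): `unique_dim` deduplicates whole slices along `dim`, so the inverse indices have length `input_size[dim]` and the counts have length `output_size[dim]`, matching the two `Reshape` calls above. Illustrative eager-mode values:

    import torch

    x = torch.tensor([[1, 2], [1, 2], [3, 4]])
    values, inverse, counts = torch.ops.aten.unique_dim(
        x, dim=0, sorted=True, return_inverse=True, return_counts=True
    )
    # values  -> tensor([[1, 2], [3, 4]])   unique rows
    # inverse -> tensor([0, 0, 1])          length == x.shape[0] (size along dim)
    # counts  -> tensor([2, 1])             length == values.shape[0]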



def aten_unique_dim_consecutive(
50 changes: 50 additions & 0 deletions tests/function_libs/torch_lib/extra_opinfo.py
@@ -1826,6 +1826,32 @@ def shape(size, rank, with_batch_channel=True):
)


def sample_inputs__unique(op_info, device, dtype, requires_grad, **kwargs):
    for sample in common_methods_invocations.sample_inputs_unique(
            op_info, device, dtype, requires_grad, **kwargs):
        return_counts = sample.kwargs.pop('return_counts', None)
        dim = sample.kwargs.pop('dim', None)
        # take only those samples that do not ask for counts or a dim
        if not return_counts and dim is None:
            yield sample


def sample_inputs__unique2(op_info, device, dtype, requires_grad, **kwargs):
    for sample in common_methods_invocations.sample_inputs_unique(
            op_info, device, dtype, requires_grad, **kwargs):
        # take only those samples that do not ask for a dim
        if sample.kwargs.pop('dim', None) is None:
            yield sample


def sample_inputs_unique_dim(op_info, device, dtype, requires_grad, **kwargs):
    for sample in common_methods_invocations.sample_inputs_unique(
            op_info, device, dtype, requires_grad, **kwargs):
        # take only those samples that ask for a dim
        if sample.kwargs.get('dim') is not None:
            yield sample
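
Note (illustrative, not part of this diff): the upstream `sample_inputs_unique` generator yields samples whose kwargs may include `return_counts` and `dim` among others, and the three generators above simply partition that stream so each overload only receives configurations it supports. A stand-in sketch of the same filtering idea (hypothetical sample objects, not the real OpInfo machinery):

    from types import SimpleNamespace

    samples = [
        SimpleNamespace(kwargs={"sorted": True}),
        SimpleNamespace(kwargs={"sorted": True, "return_counts": True}),
        SimpleNamespace(kwargs={"sorted": True, "return_counts": True, "dim": 0}),
    ]
    with_dim = [s for s in samples if s.kwargs.get("dim") is not None]  # what sample_inputs_unique_dim keeps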


class _TestParamsMaxPoolEmptyStrideBase:
    # Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203
    def __init__(self):
@@ -2413,4 +2439,28 @@ def __init__(self):
        sample_inputs_func=sample_inputs_non_max_suppression,
        supports_out=False,
    ),
    opinfo_core.OpInfo(
        "ops.aten._unique.default",
        aten_name="_unique.default",
        dtypes=common_dtype.floating_types_and(torch.float16, torch.int64, torch.int8),
        sample_inputs_func=sample_inputs__unique,
        supports_out=False,
        supports_autograd=False,
    ),
    opinfo_core.OpInfo(
        "ops.aten._unique2.default",
        aten_name="_unique2.default",
        dtypes=common_dtype.floating_types_and(torch.float16, torch.int64, torch.int8),
        sample_inputs_func=sample_inputs__unique2,
        supports_out=False,
        supports_autograd=False,
    ),
    opinfo_core.OpInfo(
        "ops.aten.unique_dim.default",
        aten_name="unique_dim.default",
        dtypes=common_dtype.floating_types_and(torch.float16, torch.int64, torch.int8),
        sample_inputs_func=sample_inputs_unique_dim,
        supports_out=False,
        supports_autograd=False,
    ),
]
16 changes: 16 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -2325,6 +2325,22 @@ def _where_input_wrangler(
    TorchLibOpInfo(
        "transpose", core_ops.aten_transpose_complex, trace_only=True, complex=True
    ),
    TorchLibOpInfo(
        "ops.aten._unique.default",
        core_ops.aten__unique,
    ),
    TorchLibOpInfo(
        "ops.aten._unique2.default",
        core_ops.aten__unique2,
    ),
    TorchLibOpInfo(
        "ops.aten.unique_dim.default",
        core_ops.aten_unique_dim,
    ).skip(
        device_type="cpu",
        reason="ops.aten.unique_dim.default returns different shapes for optional outputs on CPU/CUDA."
        " Our implementation is based on that for CUDA",
    ),
    TorchLibOpInfo(
        "var_mean",
        core_ops.aten_var_mean,