Skip to content

Commit

Permalink
Remove try-catch block and apply fixes to enable torch.onnx.dynamo_ex…
Browse files Browse the repository at this point in the history
…port to succeed
  • Loading branch information
a-gardner1 committed May 17, 2024
1 parent ebc1b96 commit 453783f
Showing 1 changed file with 16 additions and 32 deletions.
48 changes: 16 additions & 32 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from __future__ import annotations

import math
import re
from typing import Any, Optional, Sequence, Tuple, Union

from onnxscript import (
Expand Down Expand Up @@ -8360,13 +8359,6 @@ def aten_unique_consecutive(
raise NotImplementedError()


_NOT_IMPLEMENTED_UNIQUE = re.compile(
r"NOT_IMPLEMENTED\s*:\s*Could\s+not\s+find\s+an\s+implementation\s+for\s+Unique"
)
"""
A pattern to detect an unsupported (not implemented) Unique operator
"""

@torch_op("aten::unique", trace_only=True)
def aten_unique(
self: TensorType,
Expand All @@ -8377,18 +8369,10 @@ def aten_unique(
) -> tuple[TensorType, TensorType, TensorType]:
"""unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor?, Tensor?)"""

try:
if dim is None:
unique_values, inverse_indices, counts = aten_unique2(self, sorted, return_inverse, return_counts)
else:
unique_values, inverse_indices, counts = aten_unique_dim(self, dim, sorted, return_inverse, return_counts)
except Exception as e:
# try to provide a more informative error message
if _NOT_IMPLEMENTED_UNIQUE.search(str(e)) is not None:
raise NotImplementedError(
f"'onnxruntime' does not yet support Unique(11) operator with dtype={self.dtype}'"
) from e
raise
if dim is None:
unique_values, inverse_indices, counts = aten_unique2(self, sorted, return_inverse, return_counts)
else:
unique_values, inverse_indices, counts = aten_unique_dim(self, dim, sorted, return_inverse, return_counts)
if return_inverse:
if return_counts:
result = unique_values, inverse_indices, counts
Expand All @@ -8410,7 +8394,11 @@ def aten_unique2(
) -> tuple[TensorType, TensorType, TensorType]:
"""unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""

unique_values, _, inverse_indices, counts = op.Unique(self, axis=None, sorted=sorted)
unique_values, indices, inverse_indices, counts = op.Unique(self, axis=None, sorted=sorted)
# HACK: force indices to be in the graph so that it gets a name during optimization
# Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
indices_size = op.Shape(indices)
counts = op.Reshape(counts, indices_size)
input_size = op.Shape(self)
inverse_indices = op.Reshape(inverse_indices, input_size)
return unique_values, inverse_indices, counts
Expand All @@ -8426,19 +8414,15 @@ def aten_unique_dim(
) -> tuple[TensorType, TensorType, TensorType]:
"""unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""

unique_values, _, inverse_indices, counts = op.Unique(self, axis=dim, sorted=sorted)
unique_values, indices, inverse_indices, counts = op.Unique(self, axis=dim, sorted=sorted)
# HACK: force indices to be in the graph so that it gets a name during optimization
# Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
indices_size = op.Shape(indices)
counts = op.Reshape(counts, indices_size)
input_size = op.Shape(self)
# PyTorch accepts negative dim as reversed counting
input_rank = op.Size(input_size)
dim = input_rank + dim
dim = dim % input_rank
starts = op.Reshape(dim, [-1])
ends = op.Reshape(dim + 1, [-1])
input_dim_size = op.Slice(input_size, starts=starts, ends=ends)
inverse_indices = op.Reshape(inverse_indices, input_dim_size)
inverse_indices = op.Reshape(inverse_indices, op.Reshape(input_size[dim], [-1]))
output_size = op.Shape(unique_values)
output_dim_size = op.Slice(output_size, starts=starts, ends=ends)
counts = op.Reshape(counts, output_dim_size)
counts = op.Reshape(counts, op.Reshape(output_size[dim], [-1]))
return unique_values, inverse_indices, counts


Expand Down

0 comments on commit 453783f

Please sign in to comment.