
Add unique op #1547

Open. Wants to merge 12 commits into base: main.
Conversation

a-gardner1

Add support for exporting torch.unique following the conclusion of pytorch/pytorch#113118.


codecov bot commented May 15, 2024

Codecov Report

Attention: Patch coverage is 57.14286% with 18 lines in your changes missing coverage. Please review.

Project coverage is 77.50%. Comparing base (69ae7f4) to head (f9885f1).
Report is 171 commits behind head on main.

Files with missing lines Patch % Lines
onnxscript/function_libs/torch_lib/ops/core.py 57.14% 16 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1547      +/-   ##
==========================================
- Coverage   77.56%   77.50%   -0.07%     
==========================================
  Files         214      216       +2     
  Lines       23186    23381     +195     
  Branches     3975     4033      +58     
==========================================
+ Hits        17984    18121     +137     
- Misses       4433     4477      +44     
- Partials      769      783      +14     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@a-gardner1 a-gardner1 marked this pull request as draft May 15, 2024 22:27
Collaborator

@justinchuby left a comment

Thanks for your contribution! Could you follow the CLA bot's instruction to get that cleared?

Comment on lines 8385 to 8390
except Exception as e:
    # try to provide a more informative error message
    if _NOT_IMPLEMENTED_UNIQUE.search(str(e)) is not None:
        raise NotImplementedError(
            f"'onnxruntime' does not yet support the Unique(11) operator with dtype={self.dtype}"
        ) from e
Collaborator

@justinchuby May 15, 2024

I would remove this try-catch since the function here is symbolic; we don't expect it to raise any errors.

Author

Addressed in b528a6a

@justinchuby added the "topic: torch_lib" label (Related to the torch/aten function lib in development) May 15, 2024
@a-gardner1
Author

Thanks for your contribution! Could you follow the CLA bot's instruction to get that cleared?

Yea, I may have jumped the gun a bit. Working on officially getting permission from my employer.

@a-gardner1
Author

a-gardner1 commented May 16, 2024

@a-gardner1 please read the following Contributor License Agreement(CLA). If you agree with the CLA, please reply with the following information.

@microsoft-github-policy-service agree [company="{your company}"]

Options:

  • (default - no company specified) I have sole ownership of intellectual property rights to my Submissions and I am not making Submissions in the course of work for my employer.
@microsoft-github-policy-service agree
  • (when company given) I am making Submissions in the course of work for my employer (or my employer has intellectual property rights in my Submissions by contract or applicable law). I have permission from my employer to make Submissions and enter into this Agreement on behalf of my employer. By signing below, the defined term “You” includes me and my employer.
@microsoft-github-policy-service agree company="Microsoft"

Contributor License Agreement

@microsoft-github-policy-service agree [company="Radiance Technologies"]

@microsoft-github-policy-service agree company="Radiance Technologies"

@a-gardner1
Author

@microsoft-github-policy-service agree company="Radiance Technologies"

@@ -438,6 +438,34 @@ def _where_input_wrangler(
return args, kwargs


def _unique_unsorted_xfail_matcher(
Author

@justinchuby I'm not sure what the preferred behavior is here. Should we match torch.unique and ignore the sorted argument (i.e., always sort in aten_unique) or respect the argument and deviate in accordance with this matcher?

Collaborator

I wonder if the argument leads to different behavior on cuda/cpu etc.? I assume sorted=False means it can be sorted but doesn't need to be, and there is some potential performance gain from turning it off. If that's the interpretation, I would keep the argument. Otherwise, ignoring the argument and matching behavior would also be nice.

Author

@a-gardner1 May 21, 2024

I am investigating differences in behavior between cuda/cpu and have found at least one already (unique_dim on CPU ignores the return_inverse and return_counts arguments whereas the CUDA impl does not). How should these differences be handled? Can the op registration be conditioned by the device somehow, or should I favor CUDA over CPU?

Collaborator

Matching CUDA for now is preferable. Thanks!
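As background for the sorted-vs-unsorted discussion above, the four outputs of the ONNX Unique op (unique values, first-occurrence indices, inverse indices, counts) can be emulated in numpy. This is an illustrative sketch only, not the onnxscript implementation, and `onnx_unique` is a hypothetical helper name:

```python
import numpy as np

def onnx_unique(x, sorted=True):
    # Emulate the four outputs of ONNX Unique: unique values, indices of
    # first occurrences, inverse indices mapping x back to values, counts.
    values, indices, inverse, counts = np.unique(
        x, return_index=True, return_inverse=True, return_counts=True
    )
    if not sorted:
        # np.unique always sorts; reorder by first occurrence instead,
        # matching the ONNX spec's sorted=0 behavior.
        order = np.argsort(indices)
        values, indices, counts = values[order], indices[order], counts[order]
        # Remap each old unique index to its new position.
        inverse = np.argsort(order)[inverse]
    return values, indices, inverse, counts
```

Comparing the sorted and unsorted results of this sketch on the same input shows exactly which outputs change when the flag flips, which is the behavior difference the matcher has to account for.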

@a-gardner1 a-gardner1 force-pushed the wip-113118-add-unique-ops branch from 453783f to b528a6a Compare May 17, 2024 20:35
@a-gardner1 a-gardner1 marked this pull request as ready for review May 17, 2024 20:35
# HACK: force indices to be in the graph so that it gets a name during optimization
# Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
indices_size = op.Shape(indices)
counts = op.Reshape(counts, indices_size)
Author

I want to note that the way that this function was written in 1d74d59 is functionally equivalent but yields an error in onnxscript.Scope.lookup_or_create because it causes modified to be True in onnxscript.optimizer.optimize, thus causing a second loop of optimization that crashes in the first call to inline_simple_functions.

This seems indicative of a potential bug to me, but I am not knowledgeable enough about the codebase to suggest a cause or fix.

@justinchuby
Collaborator

Thanks for completing the CLA. I will take a look next week

@justinchuby justinchuby self-assigned this May 18, 2024
    result = unique_values, counts
else:
    result = unique_values
return result
Collaborator

I think we need to always return the same number of values. Consider returning None when they are not available?

Author

Doing so deviates from the behavior of torch.unique and causes this assertion in the unit tests to fail:

assert len(flattened_torch_outputs) == len(flattened_function_outputs)

Please advise on how to address this.
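For context, the variable-arity return of torch.unique (the behavior under discussion) can be sketched as a small helper. `pack_unique_outputs` is a hypothetical illustration, not code from this PR:

```python
def pack_unique_outputs(values, inverse, counts, return_inverse, return_counts):
    # torch.unique returns only the requested tensors: a bare value when
    # neither flag is set, otherwise a tuple that grows with each flag.
    result = (values,)
    if return_inverse:
        result = result + (inverse,)
    if return_counts:
        result = result + (counts,)
    return result[0] if len(result) == 1 else result
```

Always returning three outputs instead would change the shape of this result for every flag combination, which is why the unit-test assertion on output counts fails.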

Collaborator

Does torch.ops.aten.unique exhibit the same behavior? If it always returns three variables, consider creating a new OpInfo for torch.ops.aten.unique similar to

opinfo_core.OpInfo(
    "ops.aten._native_batch_norm_legit.no_stats",
    aten_name="_native_batch_norm_legit.no_stats",
    dtypes=common_dtype.floating_types_and(torch.bfloat16),
    dtypesIfCUDA=common_dtype.floating_types_and(torch.float16, torch.bfloat16),
    supports_forward_ad=True,
    supports_fwgrad_bwgrad=True,
    assert_jit_shape_analysis=True,
    sample_inputs_func=sample_inputs__native_batch_norm_legit_no_stats,
),
With the custom OpInfo, you may also simply remove the xfail cases.

You may adapt the sample function from https://github.com/pytorch/pytorch/blob/b948b1ad7a9cf61c9692506c60c295fd40e00f43/torch/testing/_internal/common_methods_invocations.py#L3346-L3372

Author

Thanks for the pointer to extra_opinfo. It turns out torch.ops.aten.unique does not exist, but torch.ops.aten._unique does. Added OpInfo for it, _unique2, and unique_dim in 14d03b5

"""unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""

unique_values, indices, inverse_indices, counts = op.Unique(self, axis=None, sorted=sorted)
# HACK: force indices to be in the graph so that it gets a name during optimization
Collaborator

@justinchuby May 20, 2024

Is this a bug we should fix elsewhere? (Saw the comment below.)

Author

I think this could possibly be considered a different bug. The other one is a side-effect of onnxscript.optimizer.constant_folding.fold_constants, whereas this one is a side-effect of the function linked below, which converts the names of unused outputs to empty strings but only removes them if they are trailing. Since inverse_indices and counts are used, it leads to an error being raised in onnxscript.Scope.lookup_or_create due to the empty string name given to indices.

def remove_unused_optional_outputs(
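To illustrate the behavior described above, here is a rough sketch (hypothetical helper name, not the actual onnxscript source): unused outputs are renamed to empty strings, but only trailing ones are removed, so an unused output in the middle of the list keeps an empty-string name:

```python
def drop_trailing_unused_outputs(output_names, used):
    # Unused outputs are renamed to the empty string...
    names = [name if name in used else "" for name in output_names]
    # ...but only trailing empty names are actually dropped.
    while names and names[-1] == "":
        names.pop()
    return names
```

With outputs (values, indices, inverse_indices, counts) where only indices is unused, indices survives with an empty name because inverse_indices and counts follow it, which matches the lookup error described above.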

@a-gardner1 a-gardner1 force-pushed the wip-113118-add-unique-ops branch from 56c06cf to 7e6d906 Compare May 20, 2024 17:24
@@ -8380,8 +8380,21 @@ def aten__unique(
 ) -> tuple[TensorType, TensorType]:
     """_unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)"""

-    unique_values, _, inverse_indices, _ = op.Unique(self, axis=None, sorted=True)
+    unique_values, indices, inverse_indices, _ = op.Unique(self, axis=None, sorted=True)
+    # HACK: force indices to be in the graph so that it gets a name during optimization
Collaborator

I suggest removing all hacks. I will go fix what's necessary where the bug is. We are also moving to prefer trace_only=True for new functions, so if you can include the flag in @torch_op, that would be awesome.

Author

That would be awesome. The hacks are definitely getting out of hand. I'll wait for that fix so that I can continue to test with this locally.

Collaborator

Do you have a short script handy that will reproduce the error?

Author

@a-gardner1 May 21, 2024

if __name__ == '__main__':
    import logging
    from typing import Any  # needed for the Any return annotation below

    import torch
    import numpy as np
    import onnx
    import onnxruntime as ort
    for i in range(16):
        sorted = bool(i & 1)
        return_inverse = bool((i & 2) > 1)
        return_counts = bool((i & 4) > 1)
        dim = 0 if bool((i & 8) > 1) else None

        print(
            f"Testing sorted={sorted}, return_inverse={return_inverse}, return_counts={return_counts}, dim={dim}"
        )

        def test_function(
                x: torch.Tensor,
                s: bool = sorted,
                ri: bool = return_inverse,
                rc: bool = return_counts,
                d: int | None = dim) -> Any:
            result = torch.unique(
                x,
                sorted=s,
                return_inverse=ri,
                return_counts=rc,
                dim=d)
            return result

        onnx_program = torch.onnx.dynamo_export(
            test_function,
            torch.arange(10),
            export_options=torch.onnx.ExportOptions(
                dynamic_shapes=True,
                diagnostic_options=torch.onnx.DiagnosticOptions(
                    verbosity_level=logging.DEBUG)))
        onnx_program.save("torch_unique.onnx")
        onnx_inputs = onnx_program.adapt_torch_inputs_to_onnx(torch.arange(10))
        onnx_outputs = onnx_program(*onnx_inputs)
        loaded_onnx_program = onnx.load("torch_unique.onnx")
        onnx.checker.check_model(loaded_onnx_program)
        ort_session = ort.InferenceSession("torch_unique.onnx")
        inputs = np.random.randint(0, 10, 10)
        print(f"Inputs: {inputs}")
        outputs = ort_session.run(None,
                                  {"l_x_": inputs})
        print(f"Outputs: {outputs}")
    print("Success")

Author

Oh, you should also test using the nightly release of PyTorch with the changes in pytorch/pytorch#126561.

Author

Is trace_only=True expected to require significant changes to the way one implements an op? It appears that enabling the flag breaks passing a value to op.ConstantOfShape and also breaks indexing a shape.

For example, op.ConstantOfShape([0], value=[0]) must become op.Cast(op.ConstantOfShape([0]), to=INT64.dtype), and output_size[dim] must become op.Slice(output_size, [dim], [dim+1]).

Collaborator

Your observation is correct. This is likely because of gaps in our implementation. Bridging those gaps is on our roadmap but is not the highest priority for the team.

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jun 1, 2024
Follow-up to #113118 and #124306.

Developed in coordination with the solution to microsoft/onnxscript#1547

This PR adds the missing fake tensor implementation for `aten.unique_dim`, thus enabling tracing and compilation of `torch.unique` when `dim` is not None.

Local testing has proceeded with the following simple script (provided that one has checked out the changes in microsoft/onnxscript#1547):

```python
    import onnx
    import onnxruntime as ort
    import logging
    import numpy as np
    import torch  # missing from the original snippet but used below
    onnx_program = torch.onnx.dynamo_export(
        lambda x: torch.unique(x,
                               dim=0,
                               return_inverse=True),
        torch.arange(10),
        export_options=torch.onnx.ExportOptions(
            dynamic_shapes=True,
            diagnostic_options=torch.onnx.DiagnosticOptions(
                verbosity_level=logging.DEBUG)))
    onnx_program.save("torch_unique.onnx")
    onnx_inputs = onnx_program.adapt_torch_inputs_to_onnx(torch.arange(10))
    onnx_outputs = onnx_program(*onnx_inputs)
    loaded_onnx_program = onnx.load("torch_unique.onnx")
    onnx.checker.check_model(loaded_onnx_program)
    ort_session = ort.InferenceSession("torch_unique.onnx")
    inputs = np.random.randint(0, 10, 10)
    print(f"Inputs: {inputs}")
    outputs = ort_session.run(None,
                              {
                                  "l_x_": inputs
                              })
    print(f"Outputs: {outputs}")
    print("Success")
```

Co-authored-by: Edward Z. Yang <[email protected]>
Pull Request resolved: #126561
Approved by: https://github.com/ezyang
petrex pushed a commit to petrex/pytorch that referenced this pull request Jun 5, 2024
@a-gardner1
Author

Circling back around to this @justinchuby. At the time, I had been waiting for you to resolve the bug that required a hacky workaround, but I realize that might not be clear.

There were also a couple of other potential unresolved bugs outside the scope of this PR, e.g., this comment.

How would you like to proceed?

@justinchuby
Collaborator

Sorry for missing the clarity. I would suggest that you remove all the hacks so that the code is at its desirable state. If tests fail because of that, that’s ok. I will then go ahead to fix what’s needed. (After I’m back from vacation)

@a-gardner1
Author

Sorry for missing the clarity. I would suggest that you remove all the hacks so that the code is at its desirable state. If tests fail because of that, that’s ok. I will then go ahead to fix what’s needed. (After I’m back from vacation)

Sounds good. FYI, the hacks were removed in b8b4cb1. As a reminder, the unit tests within onnxscript pass(ed) without the hacks, but the full export from torch to ONNX via Dynamo fails. This script should reproduce the errors with torch==2.4 or later (any release that includes pytorch/pytorch#126561).

If I get a chance, I'll try to rebase this PR and resolve conflicts first.

@kabyanil

I am implementing a CTC decoder class in pytorch -

from typing import List

import torch


class GreedyCTCDecoder(torch.nn.Module):
    def __init__(self, labels, blank=0):
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor) -> List[str]:
        """Given a sequence emission over labels, get the best path
        Args:
          emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
          List[str]: The resulting transcript
        """
        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
        indices = torch.unique_consecutive(indices, dim=-1)
        indices = [i for i in indices if i != self.blank]
        joined = "".join([self.labels[i] for i in indices])
        return joined.replace("|", " ").strip().split()


greedy_decoder = GreedyCTCDecoder(tokens)

I'm not able to export this class to onnx. Here is my error -

/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/_exporter_legacy.py:116: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
  warnings.warn(
---------------------------------------------------------------------------
DynamicOutputShapeException               Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py in run_node(tracer, node, args, kwargs, nnmodule)
   2131             if op == "call_function":
-> 2132                 return node.target(*args, **kwargs)
   2133             elif op == "call_method":

53 frames
DynamicOutputShapeException: aten.unique_consecutive.default

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
RuntimeError: Failed running call_function <function boolean_dispatch.<locals>.fn at 0x7ba8c8f9dd80>(*(FakeTensor(..., size=(s0,), dtype=torch.int64),), **{'dim': -1}):
aten.unique_consecutive.default

During handling of the above exception, another exception occurred:

Unsupported                               Traceback (most recent call last)
Unsupported: dynamic shape operator: aten.unique_consecutive.default; Operator does not have a meta kernel that supports dynamic output shapes, please report an issue to PyTorch

from user code:
   File "<ipython-input-43-c1d03e7f78a6>", line 16, in forward
    indices = torch.unique_consecutive(indices, dim=-1)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


The above exception was the direct cause of the following exception:

OnnxExporterError                         Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/_exporter_legacy.py in dynamo_export(model, export_options, *model_args, **model_kwargs)
   1231             f"Please report a bug on PyTorch Github: {_PYTORCH_GITHUB_ISSUES_URL}"
   1232         )
-> 1233         raise errors.OnnxExporterError(message) from e
   1234 
   1235 

OnnxExporterError: Failed to export the model to ONNX. Generating SARIF report at 'report_dynamo_export.sarif'. SARIF is a standard format for the output of static analysis tools. SARIF logs can be loaded in VS Code SARIF viewer extension, or SARIF web viewer (https://microsoft.github.io/sarif-web-component/). Please report a bug on PyTorch Github: https://github.com/pytorch/pytorch/issues

How can I resolve this?
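For illustration, the consecutive-deduplication step itself can be expressed with a keep-mask. Below is a 1-D numpy sketch of the idea (an export-friendly torch version would use the same mask, though the data-dependent output shape may still require dynamic-shape support in the exporter):

```python
import numpy as np

def unique_consecutive_1d(x):
    # Keep each element that differs from its immediate predecessor.
    x = np.asarray(x)
    if x.size == 0:
        return x
    keep = np.empty(x.shape, dtype=bool)
    keep[0] = True
    keep[1:] = x[1:] != x[:-1]
    return x[keep]
```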
