Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from .decompose_remainder import DecomposeRemainder
from .decompose_roll import DecomposeRoll
from .decompose_silu import DecomposeSilu
from .decompose_tan import DecomposeTan
from .decompose_threshold import DecomposeThreshold
from .decompose_triu import DecomposeTriu
from .decompose_trunc import DecomposeTrunc
Expand Down Expand Up @@ -88,6 +89,7 @@
DecomposeRemainder,
DecomposeRoll,
DecomposeSilu,
DecomposeTan,
DecomposeThreshold,
DecomposeTriu,
DecomposeTrunc,
Expand Down
71 changes: 71 additions & 0 deletions backends/qualcomm/_passes/decompose_tan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import copy_meta


class DecomposeTan(ExportPass):
"""
Decompose tan(x) = sin(x) / cos(x)
"""

def __init__(self):
super(DecomposeTan, self).__init__()
self.targets = {
torch.ops.aten.tan.default,
exir_ops.edge.aten.tan.default,
}

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph = graph_module.graph

for node in list(graph.nodes):
if node.op == "call_function" and node.target in self.targets:
is_edge = isinstance(node.target, EdgeOpOverload)

sin_op = (
exir_ops.edge.aten.sin.default
if is_edge
else torch.ops.aten.sin.default
)
cos_op = (
exir_ops.edge.aten.cos.default
if is_edge
else torch.ops.aten.cos.default
)
div_op = (
exir_ops.edge.aten.div.Tensor
if is_edge
else torch.ops.aten.div.Tensor
)

with graph.inserting_before(node):
sin_node = graph.create_node(
"call_function", sin_op, (node.args[0],)
)
sin_node.meta = copy_meta(node.meta)

cos_node = graph.create_node(
"call_function", cos_op, (node.args[0],)
)
cos_node.meta = copy_meta(node.meta)

div_node = graph.create_node(
"call_function", div_op, (sin_node, cos_node)
)
div_node.meta = copy_meta(node.meta)

for user in node.users.copy():
user.replace_input_with(node, div_node)

graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
3 changes: 3 additions & 0 deletions backends/qualcomm/_passes/qnn_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
DecomposeRemainder,
DecomposeRoll,
DecomposeSilu,
DecomposeTan,
DecomposeThreshold,
DecomposeTriu,
DecomposeTrunc,
Expand Down Expand Up @@ -112,6 +113,7 @@ def get_capture_program_passes():
(DecomposeMinMaxDim, True),
(DecomposePad, True),
(DecomposeRemainder, True),
(DecomposeTan, True),
(DecomposeTrunc, True),
(ExpandBroadcastTensorShape, True),
(FixedLinearKeepDim, True),
Expand Down Expand Up @@ -236,6 +238,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(DecomposeScaledDotProductAttention())
self.add_pass(DecomposeRoll())
self.add_pass(DecomposeSilu())
self.add_pass(DecomposeTan())
self.add_pass(DecomposeThreshold())
self.add_pass(DecomposeTriu())
self.add_pass(DecomposeTrunc())
Expand Down
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def get_passes_dependency_for_capture_program():
DecomposeMaxPool3d,
DecomposePad,
DecomposeRemainder,
DecomposeTan,
DecomposeTrunc,
ExpandBroadcastTensorShape,
FixedLinearKeepDim,
Expand Down Expand Up @@ -107,6 +108,7 @@ def get_passes_dependency_for_capture_program():
DecomposeMaxPool3d: [RemoveRedundancy],
DecomposePad: [RemoveRedundancy],
DecomposeRemainder: [RemoveRedundancy],
DecomposeTan: [RemoveRedundancy],
DecomposeTrunc: [RemoveRedundancy],
ExpandBroadcastTensorShape: [FoldQDQ],
FixedLinearKeepDim: [FoldQDQ],
Expand Down
5 changes: 3 additions & 2 deletions backends/qualcomm/builders/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ Please help update following table if you are contributing new operators:
+ 🚫 = Deprecated, supported with other QNN Ops


| Operators | HTP - 98/119 Enabled |
| Operators | HTP - 99/119 Enabled |
|-----------|---------|
| Argmax | ✓ |
| Argmin | ✓ |
Expand Down Expand Up @@ -472,7 +472,7 @@ Please help update following table if you are contributing new operators:
| ResizeNearestNeighbor | ✓ |
| RoiAlign | ✗ |
| RmsNorm | ✓ |
| ScatterElements | ✗ |
| ScatterElements | ✓ |
| ScatterNd | ✓ |
| Sigmoid | ✓ |
| Softmax | ✓ |
Expand Down Expand Up @@ -517,6 +517,7 @@ The following PyTorch operators are supported through decomposition or annotatio
| `aten.remainder.Scalar`, `aten.remainder.Tensor` | `DecomposeRemainder` |
| `aten.roll` | `DecomposeRoll` |
| `aten.silu` | `DecomposeSilu` |
| `aten.tan` | `DecomposeTan` |
| `aten.threshold` | `DecomposeThreshold` |
| `aten.triu` | `DecomposeTriu` |
| `aten.trunc` | `DecomposeTrunc` |
Expand Down
8 changes: 8 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2425,6 +2425,14 @@ def forward(self, x):
return torch.swapaxes(x, axis0=self.axis0, axis1=self.axis1)


class Tan(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.tan(x)


class Tanh(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
11 changes: 11 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2063,6 +2063,11 @@ def test_qnn_backend_swapaxes(self):
sample_input = (torch.randn([1, 2, 3, 4]),)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_tan(self):
module = Tan() # noqa: F405
sample_input = (torch.rand(2, 5, 1, 3) * 2 - 1,)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_tanh(self):
module = Tanh() # noqa: F405
sample_input = (torch.randn(2, 5, 1, 3),)
Expand Down Expand Up @@ -4667,6 +4672,12 @@ def test_qnn_backend_swapaxes(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_tan(self):
module = Tan() # noqa: F405
sample_input = (torch.rand(2, 5, 1, 3) * 2 - 1,)
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_tanh(self):
module = Tanh() # noqa: F405
sample_input = (torch.randn(2, 5, 1, 3),)
Expand Down
Loading