Skip to content

Commit 66884b4

Browse files
authored
Qualcomm AI Engine Direct - Adding QNN backend support for the addmm coreATen op (#20355)
### Summary Added full support for the `aten.addmm` core ATen op via a two-pass decomposition strategy: 1. `AddmmToLinearTransform` (ExecuTorch shared pass): Converts the common `nn.Linear` decomposition pattern (`addmm(bias, input, weight.T)`) back to `aten.linear`, mapping to QNN's fused `FullyConnected` op for optimal performance. 2. `DecomposeAddmm` (new pass): Handles remaining standalone `addmm` nodes by decomposing them into `mm + add`. Supports non-unit `alpha`/`beta` scalars via additional `mul` nodes. `AddmmToLinearTransform` alone is not sufficient because it only handles the subset of `addmm` nodes that match the `nn.Linear` decomposition pattern, specifically where `args[2]` is a transposed weight (`t_copy` or `permute_copy`). Standalone `addmm(bias, A, B)` calls where `B` is not transposed are explicitly skipped by that pass. `DecomposeAddmm` serves as the fallback for these cases. Also made some small improvements to the `new_op_development` skill based on recent learnings. ### Test plan ``` python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNQuantizedOperator.test_qnn_backend_addmm --model SM8750 --host aisw-vm15-labsd --device 545ee4aa --build_folder build-android python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNFloatingPointOperator.test_qnn_backend_addmm --model SM8750 --host aisw-vm15-labsd --device 545ee4aa --build_folder build-android ```
1 parent 23f9021 commit 66884b4

9 files changed

Lines changed: 203 additions & 5 deletions

File tree

.claude/skills/qualcomm/new_op_development.md

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,17 @@ class DecomposeMyOp(ExportPass):
217217

218218
### Registration (all decompose passes)
219219
1. `_passes/__init__.py` — import + `__all__`
220-
2. `_passes/qnn_pass_manager.py` — import + `transform_for_annotation_pipeline` + `transform_for_export_pipeline` + `get_capture_program_passes`
221-
3. `_passes/utils.py` — add to `get_passes_dependency_for_capture_program()` with `[RemoveRedundancy]` dependency
220+
2. `_passes/qnn_pass_manager.py` — The pass manager uses classmethods for pipeline definitions:
221+
- **Import** — add to the import block at top of file
222+
- **`get_annotation_passes()`** — add pass class to the returned list (runs before quantizer, ATen IR)
223+
- **`get_export_passes()`** — add pass class if needed for float-only path (runs after quantization, before to-edge)
224+
- **`get_default_pass_activations()`** — add `(PassClass, True)` ONLY if the pass also needs to run in the to-edge pipeline
225+
- **`get_passes_dependency_for_capture_program()`** — add `PassClass: [RemoveRedundancy]` dependency ONLY if also in `get_default_pass_activations`
226+
227+
**When to add to which pipeline:**
228+
- **Annotation only** (most common for decompose passes): `get_annotation_passes()` — pass decomposes the op before the quantizer sees it
229+
- **Export pipeline** too: if the float-only test fails without it (op doesn't get handled by PyTorch's built-in decomposition during to-edge)
230+
- **Capture program** (to-edge) too: if the op can appear in edge dialect and needs decomposition there (e.g., `DecomposeVar`, `DecomposeCDist`, `DecomposeDiagonal`)
222231

223232
---
224233

@@ -255,4 +264,4 @@ class DecomposeMyOp(ExportPass):
255264

256265
**Native QNN Op:** `qnn_constants.py``op_my_op.py``builders/__init__.py``htp_rules.py``lpai_rules.py``layout_transform.py``tests/models.py``test_qnn_delegate.py``partition/utils.py` (skip decomp) → `common_defs.py` (remove to_be_implemented) → `builders/README.md`
257266

258-
**Decompose Pass:** `_passes/decompose_my_op.py``_passes/__init__.py``qnn_pass_manager.py` (annotation + export + capture) → `_passes/utils.py` (dependency) → `tests/models.py``test_qnn_delegate.py``common_defs.py``builders/README.md`
267+
**Decompose Pass:** `_passes/decompose_my_op.py``_passes/__init__.py``qnn_pass_manager.py` (`get_annotation_passes` + optionally `get_export_passes`; if also needed in to-edge: `get_default_pass_activations` + `get_passes_dependency_for_capture_program`) → `tests/models.py``test_qnn_delegate.py``common_defs.py``builders/README.md`

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .convert_mha_to_sha import ConvertMhaToSha
1515
from .convert_square_to_pow import ConvertSquareToPow
1616
from .decompose_acos import DecomposeAcos
17+
from .decompose_addmm import DecomposeAddmm
1718
from .decompose_any import DecomposeAny
1819
from .decompose_atan2 import DecomposeAtan2
1920
from .decompose_binary_alpha import DecomposeBinaryAlpha
@@ -76,6 +77,7 @@
7677
ConvertMhaToSha,
7778
ConvertSquareToPow,
7879
DecomposeAcos,
80+
DecomposeAddmm,
7981
DecomposeAny,
8082
DecomposeAtan2,
8183
DecomposeBinaryAlpha,
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
10+
from executorch.exir.pass_base import ExportPass, PassResult
11+
12+
from .utils import copy_meta, get_const_node
13+
14+
15+
class DecomposeAddmm(ExportPass):
16+
"""
17+
Decompose addmm into mm + add (with optional mul for non-unit alpha/beta).
18+
addmm(bias, input, mat2, beta=1, alpha=1) = beta * bias + alpha * (input @ mat2)
19+
20+
For the common case (alpha=1, beta=1): addmm(bias, input, mat2) = mm(input, mat2) + bias
21+
22+
Note: This pass serves as a fallback for standalone addmm nodes that are NOT
23+
handled by the ExecuTorch-provided pass AddmmToLinearTransform.
24+
Any remaining addmm nodes (e.g., with non-transposed mat2) are decomposed here into mm + add.
25+
"""
26+
27+
def __init__(self):
28+
super().__init__()
29+
self.addmm_targets = {
30+
torch.ops.aten.addmm.default,
31+
exir_ops.edge.aten.addmm.default,
32+
}
33+
34+
def call(self, graph_module: torch.fx.GraphModule):
35+
graph = graph_module.graph
36+
37+
for node in list(graph.nodes):
38+
if node.op == "call_function" and node.target in self.addmm_targets:
39+
is_edge = isinstance(node.target, EdgeOpOverload)
40+
bias_node = node.args[0]
41+
input_node = node.args[1]
42+
mat2_node = node.args[2]
43+
# kwargs beta and alpha default to 1
44+
beta = node.kwargs.get("beta", 1)
45+
alpha = node.kwargs.get("alpha", 1)
46+
47+
mm_op = (
48+
exir_ops.edge.aten.mm.default
49+
if is_edge
50+
else torch.ops.aten.mm.default
51+
)
52+
add_op = (
53+
exir_ops.edge.aten.add.Tensor
54+
if is_edge
55+
else torch.ops.aten.add.Tensor
56+
)
57+
mul_op = (
58+
exir_ops.edge.aten.mul.Tensor
59+
if is_edge
60+
else torch.ops.aten.mul.Tensor
61+
)
62+
63+
meta = node.meta
64+
65+
with graph.inserting_before(node):
66+
# mm_result = input @ mat2
67+
mm_node = graph.create_node(
68+
"call_function", mm_op, (input_node, mat2_node)
69+
)
70+
mm_node.meta = copy_meta(meta)
71+
72+
if alpha != 1:
73+
alpha_node = get_const_node(
74+
graph,
75+
graph_module,
76+
f"{node.name}_alpha",
77+
alpha,
78+
mm_node,
79+
)
80+
mm_scaled = graph.create_node(
81+
"call_function", mul_op, (mm_node, alpha_node)
82+
)
83+
mm_scaled.meta = copy_meta(meta)
84+
mm_result = mm_scaled
85+
else:
86+
mm_result = mm_node
87+
88+
if beta != 1:
89+
beta_const = get_const_node(
90+
graph,
91+
graph_module,
92+
f"{node.name}_beta",
93+
beta,
94+
bias_node,
95+
)
96+
bias_scaled = graph.create_node(
97+
"call_function", mul_op, (bias_node, beta_const)
98+
)
99+
bias_scaled.meta = copy_meta(meta)
100+
bias_result = bias_scaled
101+
else:
102+
bias_result = bias_node
103+
104+
# result = mm_result + bias
105+
add_node = graph.create_node(
106+
"call_function", add_op, (mm_result, bias_result)
107+
)
108+
add_node.meta = copy_meta(meta)
109+
110+
for user in node.users.copy():
111+
user.replace_input_with(node, add_node)
112+
113+
graph.eliminate_dead_code()
114+
graph_module.recompile()
115+
return PassResult(graph_module, True)

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
ConvertMhaToSha,
2121
ConvertSquareToPow,
2222
DecomposeAcos,
23+
DecomposeAddmm,
2324
DecomposeAny,
2425
DecomposeAtan2,
2526
DecomposeBinaryAlpha,
@@ -122,6 +123,7 @@ def get_default_pass_activations(cls):
122123
(AnnotateUnbind, True),
123124
(ConvertBmmToMatmul, False),
124125
(DecomposeAcos, True),
126+
(DecomposeAddmm, True),
125127
(DecomposeAny, True),
126128
(DecomposeAtan2, True),
127129
(DecomposeColIm, True),
@@ -160,6 +162,7 @@ def get_annotation_passes(cls):
160162
RecomposeRmsNorm,
161163
ReplaceArangeArgs,
162164
DecomposeAcos,
165+
DecomposeAddmm,
163166
DecomposeAtan2,
164167
DecomposeBinaryAlpha,
165168
DecomposeCDist,
@@ -275,6 +278,7 @@ def get_passes_dependency_for_capture_program(cls):
275278
AnnotateUnbind: [RemoveRedundancy],
276279
ConvertBmmToMatmul: [RecomposePixelUnshuffle],
277280
DecomposeAcos: [RemoveRedundancy],
281+
DecomposeAddmm: [RemoveRedundancy],
278282
DecomposeAny: [RemoveRedundancy],
279283
DecomposeAtan2: [RemoveRedundancy],
280284
DecomposeColIm: [FoldQDQ],

backends/qualcomm/builders/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,7 @@ The following PyTorch operators are supported through decomposition or annotatio
498498
| PyTorch Op | Decomposition Pass |
499499
|---|---|
500500
| `aten.acos` | `DecomposeAcos` |
501+
| `aten.addmm` | `DecomposeAddmm` |
501502
| `aten.adaptive_avg_pool1d`, `aten.avg_pool1d` | `AnnotateAvgPool1D` |
502503
| `aten.any` | `DecomposeAny` |
503504
| `aten.atan2.default`, `aten.atan2.out` | `DecomposeAtan2` |

backends/qualcomm/quantizer/annotators/htp_rules.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1077,7 +1077,11 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
10771077

10781078

10791079
@register_annotator(
1080-
[torch.ops.aten.bmm.default, torch.ops.aten.matmul.default],
1080+
[
1081+
torch.ops.aten.bmm.default,
1082+
torch.ops.aten.matmul.default,
1083+
torch.ops.aten.mm.default,
1084+
],
10811085
QnnConstants.OpMatMul.op_name,
10821086
)
10831087
class MatMul(GeneralOpDef):

backends/qualcomm/quantizer/annotators/lpai_rules.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,11 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
601601

602602

603603
@register_annotator(
604-
[torch.ops.aten.bmm.default, torch.ops.aten.matmul.default],
604+
[
605+
torch.ops.aten.bmm.default,
606+
torch.ops.aten.matmul.default,
607+
torch.ops.aten.mm.default,
608+
],
605609
QnnConstants.OpMatMul.op_name,
606610
)
607611
class MatMul(GeneralOpDef):

backends/qualcomm/tests/models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,16 @@ def forward(self, x):
144144
return 10 + x
145145

146146

147+
class AddMM(torch.nn.Module):
148+
def __init__(self, alpha=1, beta=1):
149+
super().__init__()
150+
self.alpha = alpha
151+
self.beta = beta
152+
153+
def forward(self, bias, input, mat2):
154+
return torch.addmm(bias, input, mat2, alpha=self.alpha, beta=self.beta)
155+
156+
147157
class Any(torch.nn.Module):
148158
def __init__(self, dim=None, keepdim=False):
149159
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,30 @@ def test_qnn_backend_adaptive_max_pool2d(self):
190190
with self.subTest(i=i):
191191
self.lower_module_and_test_output(module, sample_input)
192192

193+
def test_qnn_backend_addmm(self):
194+
test_comb = [
195+
{
196+
QCOM_MODULE: [AddMM()], # noqa: F405
197+
QCOM_SAMPLE_INPUTS: [
198+
(torch.randn(8), torch.randn(4, 3), torch.randn(3, 8)),
199+
],
200+
},
201+
{
202+
QCOM_MODULE: [AddMM(alpha=2, beta=3)], # noqa: F405
203+
QCOM_SAMPLE_INPUTS: [
204+
(torch.randn(8), torch.randn(4, 3), torch.randn(3, 8)),
205+
],
206+
},
207+
]
208+
209+
index = 0
210+
for comb in test_comb:
211+
for module in comb[QCOM_MODULE]:
212+
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
213+
with self.subTest(i=index):
214+
index += 1
215+
self.lower_module_and_test_output(module, sample_input)
216+
193217
def test_qnn_backend_alias(self):
194218
module = Alias() # noqa: F405
195219
sample_input = (torch.randn(1, 10),)
@@ -2969,6 +2993,31 @@ def test_qnn_backend_adaptive_max_pool2d(self):
29692993
module_one = self.get_qdq_module(module, sample_input)
29702994
self.lower_module_and_test_output(module_one, sample_input)
29712995

2996+
def test_qnn_backend_addmm(self):
2997+
test_comb = [
2998+
{
2999+
QCOM_MODULE: [AddMM()], # noqa: F405
3000+
QCOM_SAMPLE_INPUTS: [
3001+
(torch.randn(8), torch.randn(4, 3), torch.randn(3, 8)),
3002+
],
3003+
},
3004+
{
3005+
QCOM_MODULE: [AddMM(alpha=2, beta=3)], # noqa: F405
3006+
QCOM_SAMPLE_INPUTS: [
3007+
(torch.randn(8), torch.randn(4, 3), torch.randn(3, 8)),
3008+
],
3009+
},
3010+
]
3011+
3012+
index = 0
3013+
for comb in test_comb:
3014+
for module in comb[QCOM_MODULE]:
3015+
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
3016+
with self.subTest(i=index):
3017+
index += 1
3018+
qdq_module = self.get_qdq_module(module, sample_input)
3019+
self.lower_module_and_test_output(qdq_module, sample_input)
3020+
29723021
def test_qnn_backend_alias(self):
29733022
module = Alias() # noqa: F405
29743023
sample_input = (torch.randn(1, 10),)

0 commit comments

Comments
 (0)