Skip to content

Commit a2438e4

Browse files
authored
Arm backend: Add TOSA binary op visitors (#20479)
Run ExirToTosaPass after int64 cleanup so dtype fixes and placeholder fusion happen before TOSA dialect rewriting. Update bool promotion to leave bitwise ops available for logical rewrites, and handle tied-parameter cleanup in FuseConstantArgsPass. Reject int32 comparisons for FP-only TOSA specs and update model partition expectations to match the stricter support check. Signed-off-by: Saoirse Stewart <saoirse.stewart@arm.com>
1 parent 23e9bec commit a2438e4

25 files changed

Lines changed: 319 additions & 180 deletions

backends/arm/_passes/arm_pass_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -627,7 +627,6 @@ def _tosa_pipeline(
627627
DecomposePermuteForU55Pass(),
628628
RewriteSlicePass(),
629629
InsertConstShapesPass(),
630-
ExirToTosaPass(exported_program),
631630
]
632631
)
633632

@@ -636,6 +635,7 @@ def _tosa_pipeline(
636635
[
637636
CastInt64BuffersToInt32Pass(exported_program),
638637
FuseEqualPlaceholdersPass(exported_program),
638+
ExirToTosaPass(exported_program),
639639
SymbolicToTosaShapesPass(),
640640
InsertDynamicPaddingPass(),
641641
FuseConsecutiveConcatShapesPass(),

backends/arm/_passes/aten_to_tosa_tensor_operators.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,47 @@ def rewrite_argmax(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec:
2424
(input_node, dim),
2525
{},
2626
)
27+
28+
29+
def rewrite_binary_operator(
    node: Node, pass_: AtenToDialectPass
) -> DialectNodeSpec | None:
    """Map an edge-dialect binary operator node onto its TOSA dialect op.

    Args:
        node: Node whose ``target`` is looked up among the supported edge
            operators.
        pass_: The dialect pass requesting the rewrite; not used here.

    Returns:
        A ``DialectNodeSpec`` holding the TOSA target, the node's positional
        args and a copy of its kwargs, or ``None`` when the node's target is
        not one of the operators listed below.
    """
    aten = exir_ops.edge.aten
    # Edge operator -> attribute name of the op in the TOSA backend namespace.
    tosa_op_names = {
        aten.add.Tensor: "ADD",
        aten.bitwise_and.Tensor: "BITWISE_AND",
        aten.bitwise_left_shift.Tensor: "LOGICAL_LEFT_SHIFT",
        aten.bitwise_or.Tensor: "BITWISE_OR",
        aten.bitwise_right_shift.Tensor: "ARITHMETIC_RIGHT_SHIFT",
        aten.bitwise_xor.Tensor: "BITWISE_XOR",
        aten.eq.Tensor: "EQUAL",
        aten.ge.Tensor: "GREATER_EQUAL",
        aten.gt.Tensor: "GREATER",
        aten.logical_and.default: "LOGICAL_AND",
        aten.logical_or.default: "LOGICAL_OR",
        aten.logical_xor.default: "LOGICAL_XOR",
        aten.maximum.default: "MAXIMUM",
        aten.minimum.default: "MINIMUM",
        aten.mul.Tensor: "MUL",
        aten.pow.Tensor_Tensor: "POW",
        aten.sub.Tensor: "SUB",
    }

    tosa_name = tosa_op_names.get(node.target)
    if tosa_name is None:
        return None

    # Resolve the TOSA op only once a matching edge operator has been found.
    tosa_target = getattr(exir_ops.backend.tosa, tosa_name).default
    return DialectNodeSpec(tosa_target, node.args, dict(node.kwargs))

backends/arm/_passes/exir_to_tosa_pass.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,24 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from collections.abc import Callable
7+
68
import executorch.backends.arm.tosa.dialect # noqa: F401
79
from executorch.backends.arm._passes.aten_to_tosa_activation_functions import (
810
get_activation_replacement,
911
)
10-
from executorch.backends.arm._passes.aten_to_tosa_tensor_operators import rewrite_argmax
12+
from executorch.backends.arm._passes.aten_to_tosa_tensor_operators import (
13+
rewrite_argmax,
14+
rewrite_binary_operator,
15+
)
1116
from executorch.backends.transforms.aten_to_dialect_pass import (
1217
AtenToDialectPass,
1318
DialectNodeSpec,
19+
SubstitutionFn,
1420
)
1521
from executorch.exir.dialects._ops import ops as exir_ops
1622
from torch.fx import Node
23+
from torch.fx.node import Target
1724

1825

1926
class ExirToTosaPass(AtenToDialectPass):
@@ -25,17 +32,55 @@ class ExirToTosaPass(AtenToDialectPass):
2532
"""
2633

2734

35+
def register_dialect_substitutions(
    *targets: Target,
) -> Callable[[SubstitutionFn], SubstitutionFn]:
    """Build a decorator that registers one function for several targets.

    Args:
        *targets: Operator targets to register with ``ExirToTosaPass``.

    Returns:
        A decorator that calls
        ``ExirToTosaPass.register_dialect_substitution`` once per target, in
        the order given, and hands back the decorated function unchanged.
    """

    def register_all(substitution: SubstitutionFn) -> SubstitutionFn:
        for op_target in targets:
            bind = ExirToTosaPass.register_dialect_substitution(op_target)
            bind(substitution)
        return substitution

    return register_all
44+
45+
2846
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.argmax.default)
def _get_tensor_operators_replacement(
    node: Node,
    pass_: AtenToDialectPass,
) -> DialectNodeSpec:
    """Return the replacement spec for an edge ``argmax`` node.

    Forwards both arguments to ``rewrite_argmax`` and returns its result.
    """
    return rewrite_argmax(node, pass_)
3351

3452

35-
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.clamp.default)
36-
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.erf.default)
37-
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.sigmoid.default)
38-
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.tanh.default)
53+
@register_dialect_substitutions(
    exir_ops.edge.aten.add.Tensor,
    exir_ops.edge.aten.bitwise_and.Tensor,
    exir_ops.edge.aten.bitwise_left_shift.Tensor,
    exir_ops.edge.aten.bitwise_or.Tensor,
    exir_ops.edge.aten.bitwise_right_shift.Tensor,
    exir_ops.edge.aten.bitwise_xor.Tensor,
    exir_ops.edge.aten.eq.Tensor,
    exir_ops.edge.aten.ge.Tensor,
    exir_ops.edge.aten.gt.Tensor,
    exir_ops.edge.aten.logical_and.default,
    exir_ops.edge.aten.logical_or.default,
    exir_ops.edge.aten.logical_xor.default,
    exir_ops.edge.aten.maximum.default,
    exir_ops.edge.aten.minimum.default,
    exir_ops.edge.aten.mul.Tensor,
    exir_ops.edge.aten.pow.Tensor_Tensor,
    exir_ops.edge.aten.sub.Tensor,
)
def _get_binary_operator_replacement(
    node: Node,
    pass_: AtenToDialectPass,
) -> DialectNodeSpec | None:
    """Return the replacement spec for a registered binary edge operator.

    Forwards both arguments to ``rewrite_binary_operator`` and returns its
    result, which may be ``None``.
    """
    return rewrite_binary_operator(node, pass_)
76+
77+
78+
@register_dialect_substitutions(
79+
exir_ops.edge.aten.clamp.default,
80+
exir_ops.edge.aten.erf.default,
81+
exir_ops.edge.aten.sigmoid.default,
82+
exir_ops.edge.aten.tanh.default,
83+
)
3984
def _get_activation_replacement(
4085
node: Node, pass_: AtenToDialectPass
4186
) -> DialectNodeSpec | None:

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from executorch.exir import ExportedProgram
2929
from executorch.exir.dialects._ops import ops as exir_ops
3030
from executorch.exir.pass_base import ExportPass, PassResult
31-
from torch.export.graph_signature import InputKind
31+
from torch.export.graph_signature import ExportGraphSignature, InputKind
3232

3333
logger = logging.getLogger(__name__)
3434

@@ -55,6 +55,35 @@ def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None:
5555
super().__init__(*args, **kwargs)
5656
self.exported_program = exported_program
5757

58+
def _delete_placeholder_input(self, node: torch.fx.Node) -> None:
    """Remove an unused placeholder from the graph signature and the graph.

    The exported program's graph signature is rebuilt without any input spec
    whose argument name equals ``node.name``; the node is then erased from
    its graph.

    Args:
        node: Placeholder node to delete. It must have no users.

    Raises:
        RuntimeError: If ``node`` still has users in the graph.
    """
    if node.users:
        raise RuntimeError(
            f"Cannot delete input node {node.name} since it has users in the graph."
        )

    signature = self.exported_program.graph_signature
    remaining_specs = [
        input_spec
        for input_spec in signature.input_specs
        if input_spec.arg.name != node.name
    ]
    self.exported_program._graph_signature = ExportGraphSignature(
        remaining_specs, signature.output_specs
    )
    node.graph.erase_node(node)
73+
74+
def _delete_constant_placeholder(self, node: torch.fx.Node) -> None:
    """Delete a dead constant placeholder, tolerating tied parameters.

    Args:
        node: Placeholder node to delete.
    """
    signature = self.exported_program.graph_signature
    parameter_target = signature.inputs_to_parameters.get(node.name)
    already_removed = (
        parameter_target is not None
        and parameter_target not in self.exported_program.state_dict
    )
    if not already_removed:
        delete_constant_placeholder(self.exported_program, node)
        return

    # Tied parameters can share a state_dict entry across placeholders;
    # another dead placeholder may have already removed the tensor, so
    # only remove this placeholder from the graph signature.
    self._delete_placeholder_input(node)
86+
5887
@staticmethod
5988
def _is_tosa_dialect_op(target) -> bool:
6089
target_str = str(target)
@@ -214,7 +243,7 @@ def call(self, graph_module):
214243
graph_module.graph.eliminate_dead_code()
215244
for input_node in input_nodes_to_maybe_delete:
216245
if len(input_node.users) == 0:
217-
delete_constant_placeholder(self.exported_program, input_node)
246+
self._delete_constant_placeholder(input_node)
218247

219248
graph_module = super().call(graph_module).graph_module
220249

backends/arm/_passes/promote_bool_operands_pass.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR don't handle bool inputs.
7-
# When a targeted op receives boolean tensors, we promote them to an integer type before
8-
# invocation and cast the result back to the expected dtype afterwards.
6+
# Some TOSA ops don't handle bool inputs. When a targeted op receives boolean
7+
# tensors, we promote them to an integer type before invocation and cast the
8+
# result back to the expected dtype afterwards.
99

1010
from typing import Set, Type
1111

@@ -23,10 +23,9 @@ class PromoteBoolOperandsPass(ArmOpTargetedPass):
2323

2424
_passes_required_after: Set[Type[ExportPass]] = set()
2525

26+
# Bool bitwise ops are handled by RewriteBoolBitwiseToLogicalPass. Promoting
27+
# them here would hide the bool dtype and prevent that rewrite.
2628
target_ops = {
27-
exir_ops.edge.aten.bitwise_and.Tensor,
28-
exir_ops.edge.aten.bitwise_or.Tensor,
29-
exir_ops.edge.aten.bitwise_xor.Tensor,
3029
exir_ops.edge.aten.mul.Tensor,
3130
}
3231

@@ -41,14 +40,9 @@ def call_operator(self, op, args, kwargs, meta):
4140
# select the first non-bool dtype, or None if all bool
4241
promoted_dtype = next((dt for dt in original_dtypes if dt != torch.bool), None)
4342

44-
# if we don't have a dtype specified by the op, promote to default choice for the op
43+
# If all operands are bool, promote mul to int32.
4544
if promoted_dtype is None:
46-
if op == exir_ops.edge.aten.mul.Tensor:
47-
# mul as int32
48-
promoted_dtype = torch.int32
49-
else:
50-
# bitwise ops can be int8
51-
promoted_dtype = torch.int8
45+
promoted_dtype = torch.int32
5246

5347
target_dtypes = []
5448
for dt in original_dtypes:

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,15 @@ def _is_quantized_constant(node: torch.fx.Node) -> bool:
173173
return len(users) > 0
174174

175175

176+
def _floating_profile_negative_checks(
    tosa_spec: TosaSpecification, reporter: WhyNoPartitionReporter
) -> list[OperatorSupportBase]:
    """Collect the negative checks used when the spec supports float.

    Args:
        tosa_spec: Specification queried through ``support_integer()``.
        reporter: Reporter handed to every check that is created.

    Returns:
        A list that always starts with ``CheckMixedFloatingInputs`` and is
        followed by ``CheckInt32ComparisonInputs`` only when the spec does
        not support integer.
    """
    mixed_inputs_check: OperatorSupportBase = CheckMixedFloatingInputs(reporter)
    if tosa_spec.support_integer():
        return [mixed_inputs_check]
    return [mixed_inputs_check, CheckInt32ComparisonInputs(reporter)]
183+
184+
176185
def is_quantized(node: torch.fx.Node) -> bool:
177186
"""Checks if the node is quantized.
178187
@@ -341,7 +350,7 @@ def _negative_checks(
341350
checks.extend(_wrapped_additional_checks(additional_checks, reporter))
342351

343352
if tosa_spec.support_float():
344-
checks.append(CheckMixedFloatingInputs(reporter))
353+
checks.extend(_floating_profile_negative_checks(tosa_spec, reporter))
345354
else:
346355
checks.append(CheckArmQuantized(reporter))
347356
checks.append(CheckProperQuantization(reporter))
@@ -995,6 +1004,47 @@ def is_node_supported(
9951004
return True
9961005

9971006

1007+
class CheckInt32ComparisonInputs(OperatorSupportBase):
    """Reject int32 comparisons under the FP profile."""

    # Comparison operators inspected by this check; any other target is
    # reported as supported without looking at its inputs.
    target_ops = {
        exir_ops.edge.aten.eq.Tensor,
        exir_ops.edge.aten.eq.Scalar,
        exir_ops.edge.aten.ge.Tensor,
        exir_ops.edge.aten.ge.Scalar,
        exir_ops.edge.aten.gt.Tensor,
        exir_ops.edge.aten.gt.Scalar,
        exir_ops.edge.aten.le.Tensor,
        exir_ops.edge.aten.le.Scalar,
        exir_ops.edge.aten.lt.Tensor,
        exir_ops.edge.aten.lt.Scalar,
    }

    def __init__(self, reporter: WhyNoPartitionReporter) -> None:
        """Store the reporter used to record rejections."""
        self.reporter = reporter
        super().__init__()

    def is_node_supported(
        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
    ) -> bool:
        """Return ``False`` for a targeted comparison with an int32 input.

        Input nodes whose ``op`` is ``"get_attr"`` are not inspected. A
        rejection is reported once, for the first int32 input found.
        """
        if node.target not in self.target_ops:
            return True

        has_int32_input = any(
            get_first_fake_tensor(producer).dtype == torch.int32
            for producer in node.all_input_nodes
            if producer.op != "get_attr"
        )
        if has_int32_input:
            self.reporter.report_reject(
                node,
                "FP profile does not support int32 comparison inputs.",
            )
        return not has_int32_input
1046+
1047+
9981048
class RankCheck(OperatorSupportBase):
9991049
"""Reject nodes with rank greater than ``max_rank``."""
10001050

backends/arm/operators/__init__.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from . import ( # noqa
1313
node_visitor,
1414
op_abs,
15-
op_add,
1615
op_amax,
1716
op_amin,
1817
op_any,
@@ -21,56 +20,57 @@
2120
op_ceil,
2221
op_cond_if,
2322
op_cos,
24-
op_eq,
2523
op_exp,
2624
op_floor,
27-
op_ge,
28-
op_gt,
2925
op_log,
3026
op_logical_not,
31-
op_maximum,
32-
op_minimum,
33-
op_mul,
3427
op_neg,
3528
op_permute,
36-
op_pow,
3729
op_reciprocal,
3830
op_repeat,
39-
op_rshift_tensor,
4031
op_rsqrt,
4132
op_sin,
42-
op_sub,
4333
op_sum,
4434
op_to_dim_order_copy,
35+
op_tosa_add,
4536
op_tosa_argmax,
4637
op_tosa_avg_pool2d,
4738
op_tosa_avg_pool2d_adaptive,
39+
op_tosa_binary_ops,
4840
op_tosa_cast_to_block_scaled,
4941
op_tosa_clamp,
5042
op_tosa_conv2d,
5143
op_tosa_conv2d_block_scaled,
5244
op_tosa_conv3d,
5345
op_tosa_custom,
5446
op_tosa_depthwise_conv2d,
47+
op_tosa_eq,
5548
op_tosa_erf,
5649
op_tosa_gather,
50+
op_tosa_ge,
51+
op_tosa_gt,
5752
op_tosa_identity,
5853
op_tosa_matmul,
5954
op_tosa_matmul_t_block_scaled,
6055
op_tosa_max_pool2d,
6156
op_tosa_max_pool2d_adaptive,
57+
op_tosa_maximum,
58+
op_tosa_minimum,
59+
op_tosa_mul,
6260
op_tosa_pad,
61+
op_tosa_pow,
6362
op_tosa_rescale,
6463
op_tosa_resize,
64+
op_tosa_rshift_tensor,
6565
op_tosa_scatter,
6666
op_tosa_shapes,
6767
op_tosa_sigmoid,
6868
op_tosa_slice,
69+
op_tosa_sub,
6970
op_tosa_table,
7071
op_tosa_tanh,
7172
op_tosa_transpose_conv2d,
7273
op_view,
7374
op_where,
7475
op_while,
75-
ops_binary,
7676
)

0 commit comments

Comments
 (0)