Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,6 @@ def _tosa_pipeline(
DecomposePermuteForU55Pass(),
RewriteSlicePass(),
InsertConstShapesPass(),
ExirToTosaPass(exported_program),
]
)

Expand All @@ -634,6 +633,7 @@ def _tosa_pipeline(
[
CastInt64BuffersToInt32Pass(exported_program),
FuseEqualPlaceholdersPass(exported_program),
ExirToTosaPass(exported_program),
SymbolicToTosaShapesPass(),
InsertDynamicPaddingPass(),
FuseConsecutiveConcatShapesPass(),
Expand Down
44 changes: 44 additions & 0 deletions backends/arm/_passes/aten_to_tosa_tensor_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,47 @@ def rewrite_argmax(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec:
(input_node, dim),
{},
)


def rewrite_binary_operator(
node: Node, pass_: AtenToDialectPass
) -> DialectNodeSpec | None:
match node.target:
case exir_ops.edge.aten.add.Tensor:
target = exir_ops.backend.tosa.ADD.default
case exir_ops.edge.aten.bitwise_and.Tensor:
target = exir_ops.backend.tosa.BITWISE_AND.default
case exir_ops.edge.aten.bitwise_left_shift.Tensor:
target = exir_ops.backend.tosa.LOGICAL_LEFT_SHIFT.default
case exir_ops.edge.aten.bitwise_or.Tensor:
target = exir_ops.backend.tosa.BITWISE_OR.default
case exir_ops.edge.aten.bitwise_right_shift.Tensor:
target = exir_ops.backend.tosa.ARITHMETIC_RIGHT_SHIFT.default
case exir_ops.edge.aten.bitwise_xor.Tensor:
target = exir_ops.backend.tosa.BITWISE_XOR.default
case exir_ops.edge.aten.eq.Tensor:
target = exir_ops.backend.tosa.EQUAL.default
case exir_ops.edge.aten.ge.Tensor:
target = exir_ops.backend.tosa.GREATER_EQUAL.default
case exir_ops.edge.aten.gt.Tensor:
target = exir_ops.backend.tosa.GREATER.default
case exir_ops.edge.aten.logical_and.default:
target = exir_ops.backend.tosa.LOGICAL_AND.default
case exir_ops.edge.aten.logical_or.default:
target = exir_ops.backend.tosa.LOGICAL_OR.default
case exir_ops.edge.aten.logical_xor.default:
target = exir_ops.backend.tosa.LOGICAL_XOR.default
case exir_ops.edge.aten.maximum.default:
target = exir_ops.backend.tosa.MAXIMUM.default
case exir_ops.edge.aten.minimum.default:
target = exir_ops.backend.tosa.MINIMUM.default
case exir_ops.edge.aten.mul.Tensor:
target = exir_ops.backend.tosa.MUL.default
case exir_ops.edge.aten.pow.Tensor_Tensor:
target = exir_ops.backend.tosa.POW.default
case exir_ops.edge.aten.sub.Tensor:
target = exir_ops.backend.tosa.SUB.default
case _:
return None

return DialectNodeSpec(target, node.args, dict(node.kwargs))
55 changes: 50 additions & 5 deletions backends/arm/_passes/exir_to_tosa_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,24 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections.abc import Callable

import executorch.backends.arm.tosa.dialect # noqa: F401
from executorch.backends.arm._passes.aten_to_tosa_activation_functions import (
get_activation_replacement,
)
from executorch.backends.arm._passes.aten_to_tosa_tensor_operators import rewrite_argmax
from executorch.backends.arm._passes.aten_to_tosa_tensor_operators import (
rewrite_argmax,
rewrite_binary_operator,
)
from executorch.backends.transforms.aten_to_dialect_pass import (
AtenToDialectPass,
DialectNodeSpec,
SubstitutionFn,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx import Node
from torch.fx.node import Target


class ExirToTosaPass(AtenToDialectPass):
Expand All @@ -25,17 +32,55 @@ class ExirToTosaPass(AtenToDialectPass):
"""


def register_dialect_substitutions(
*targets: Target,
) -> Callable[[SubstitutionFn], SubstitutionFn]:
def decorator(func: SubstitutionFn) -> SubstitutionFn:
for target in targets:
ExirToTosaPass.register_dialect_substitution(target)(func)
return func

return decorator


@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.argmax.default)
def _get_tensor_operators_replacement(
node: Node, pass_: AtenToDialectPass
) -> DialectNodeSpec:
return rewrite_argmax(node, pass_)


@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.clamp.default)
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.erf.default)
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.sigmoid.default)
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.tanh.default)
@register_dialect_substitutions(
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.bitwise_and.Tensor,
exir_ops.edge.aten.bitwise_left_shift.Tensor,
exir_ops.edge.aten.bitwise_or.Tensor,
exir_ops.edge.aten.bitwise_right_shift.Tensor,
exir_ops.edge.aten.bitwise_xor.Tensor,
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.ge.Tensor,
exir_ops.edge.aten.gt.Tensor,
exir_ops.edge.aten.logical_and.default,
exir_ops.edge.aten.logical_or.default,
exir_ops.edge.aten.logical_xor.default,
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.pow.Tensor_Tensor,
exir_ops.edge.aten.sub.Tensor,
)
def _get_binary_operator_replacement(
node: Node, pass_: AtenToDialectPass
) -> DialectNodeSpec | None:
return rewrite_binary_operator(node, pass_)


@register_dialect_substitutions(
exir_ops.edge.aten.clamp.default,
exir_ops.edge.aten.erf.default,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.tanh.default,
)
def _get_activation_replacement(
node: Node, pass_: AtenToDialectPass
) -> DialectNodeSpec | None:
Expand Down
20 changes: 7 additions & 13 deletions backends/arm/_passes/promote_bool_operands_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR don't handle bool inputs.
# When a targeted op receives boolean tensors, we promote them to an integer type before
# invocation and cast the result back to the expected dtype afterwards.
# Some TOSA ops don't handle bool inputs. When a targeted op receives boolean
# tensors, we promote them to an integer type before invocation and cast the
# result back to the expected dtype afterwards.

from typing import Set, Type

Expand All @@ -23,10 +23,9 @@ class PromoteBoolOperandsPass(ArmOpTargetedPass):

_passes_required_after: Set[Type[ExportPass]] = set()

# Bool bitwise ops are handled by RewriteBoolBitwiseToLogicalPass. Promoting
# them here would hide the bool dtype and prevent that rewrite.
target_ops = {
exir_ops.edge.aten.bitwise_and.Tensor,
exir_ops.edge.aten.bitwise_or.Tensor,
exir_ops.edge.aten.bitwise_xor.Tensor,
exir_ops.edge.aten.mul.Tensor,
}

Expand All @@ -41,14 +40,9 @@ def call_operator(self, op, args, kwargs, meta):
# select the first non-bool dtype, or None if all bool
promoted_dtype = next((dt for dt in original_dtypes if dt != torch.bool), None)

# if we don't have a dtype specified by the op, promote to default choice for the op
# If all operands are bool, promote mul to int32.
if promoted_dtype is None:
if op == exir_ops.edge.aten.mul.Tensor:
# mul as int32
promoted_dtype = torch.int32
else:
# bitwise ops can be int8
promoted_dtype = torch.int8
promoted_dtype = torch.int32

target_dtypes = []
for dt in original_dtypes:
Expand Down
52 changes: 51 additions & 1 deletion backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,15 @@ def _is_quantized_constant(node: torch.fx.Node) -> bool:
return len(users) > 0


def _floating_profile_negative_checks(
tosa_spec: TosaSpecification, reporter: WhyNoPartitionReporter
) -> list[OperatorSupportBase]:
checks: list[OperatorSupportBase] = [CheckMixedFloatingInputs(reporter)]
if not tosa_spec.support_integer():
checks.append(CheckInt32ComparisonInputs(reporter))
return checks


def is_quantized(node: torch.fx.Node) -> bool:
"""Checks if the node is quantized.

Expand Down Expand Up @@ -341,7 +350,7 @@ def _negative_checks(
checks.extend(_wrapped_additional_checks(additional_checks, reporter))

if tosa_spec.support_float():
checks.append(CheckMixedFloatingInputs(reporter))
checks.extend(_floating_profile_negative_checks(tosa_spec, reporter))
else:
checks.append(CheckArmQuantized(reporter))
checks.append(CheckProperQuantization(reporter))
Expand Down Expand Up @@ -995,6 +1004,47 @@ def is_node_supported(
return True


class CheckInt32ComparisonInputs(OperatorSupportBase):
"""Reject int32 comparisons under the FP profile."""

target_ops = {
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.eq.Scalar,
exir_ops.edge.aten.ge.Tensor,
exir_ops.edge.aten.ge.Scalar,
exir_ops.edge.aten.gt.Tensor,
exir_ops.edge.aten.gt.Scalar,
exir_ops.edge.aten.le.Tensor,
exir_ops.edge.aten.le.Scalar,
exir_ops.edge.aten.lt.Tensor,
exir_ops.edge.aten.lt.Scalar,
}

def __init__(self, reporter: WhyNoPartitionReporter) -> None:
self.reporter = reporter
super().__init__()

def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:
if node.target not in self.target_ops:
return True

for input_node in (
input_node
for input_node in node.all_input_nodes
if input_node.op != "get_attr"
):
if get_first_fake_tensor(input_node).dtype == torch.int32:
self.reporter.report_reject(
node,
"FP profile does not support int32 comparison inputs.",
)
return False

return True


class RankCheck(OperatorSupportBase):
"""Reject nodes with rank greater than ``max_rank``."""

Expand Down
22 changes: 11 additions & 11 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from . import ( # noqa
node_visitor,
op_abs,
op_add,
op_amax,
op_amin,
op_any,
Expand All @@ -21,56 +20,57 @@
op_ceil,
op_cond_if,
op_cos,
op_eq,
op_exp,
op_floor,
op_ge,
op_gt,
op_log,
op_logical_not,
op_maximum,
op_minimum,
op_mul,
op_neg,
op_permute,
op_pow,
op_reciprocal,
op_repeat,
op_rshift_tensor,
op_rsqrt,
op_sin,
op_sub,
op_sum,
op_to_dim_order_copy,
op_tosa_add,
op_tosa_argmax,
op_tosa_avg_pool2d,
op_tosa_avg_pool2d_adaptive,
op_tosa_binary_ops,
op_tosa_cast_to_block_scaled,
op_tosa_clamp,
op_tosa_conv2d,
op_tosa_conv2d_block_scaled,
op_tosa_conv3d,
op_tosa_custom,
op_tosa_depthwise_conv2d,
op_tosa_eq,
op_tosa_erf,
op_tosa_gather,
op_tosa_ge,
op_tosa_gt,
op_tosa_identity,
op_tosa_matmul,
op_tosa_matmul_t_block_scaled,
op_tosa_max_pool2d,
op_tosa_max_pool2d_adaptive,
op_tosa_maximum,
op_tosa_minimum,
op_tosa_mul,
op_tosa_pad,
op_tosa_pow,
op_tosa_rescale,
op_tosa_resize,
op_tosa_rshift_tensor,
op_tosa_scatter,
op_tosa_shapes,
op_tosa_sigmoid,
op_tosa_slice,
op_tosa_sub,
op_tosa_table,
op_tosa_tanh,
op_tosa_transpose_conv2d,
op_view,
op_where,
op_while,
ops_binary,
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

@register_node_visitor
class AddVisitor(SimpleNodeVisitor):
target = "aten.add.Tensor"
target = "tosa.ADD.default"

@classmethod
def get_config(cls) -> SimpleNodeVisitorConfig:
Expand Down
Loading
Loading