diff --git a/backends/xnnpack/_passes/lift_constant_scalar_operands_pass.py b/backends/xnnpack/_passes/lift_constant_scalar_operands_pass.py new file mode 100644 index 00000000000..e256a31750e --- /dev/null +++ b/backends/xnnpack/_passes/lift_constant_scalar_operands_pass.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from numbers import Number +from typing import Dict, Optional, Union + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, PassResult +from torch._ops import OpOverload + + +ScalarOp = Union[EdgeOpOverload, OpOverload] + + +class LiftConstantScalarOperandsPass(ExportPass): + """ + Lift scalar operands into tensor constants for selected binary ops. + + XNNPACK already supports the tensor overloads for these binary operations. + This pass converts explicitly listed scalar overloads to their tensor + overloads by replacing constant scalar operands with small tensor constants. + The constants are registered as buffers so they do not become portable + ``full`` kernels. Keep the op map narrow until each new scalar overload is + covered by tests. + """ + + default_scalar_to_tensor_ops: Dict[ScalarOp, ScalarOp] = { + exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor, + } + sdpa_passthrough_ops = { + exir_ops.edge.aten.expand_copy.default, + exir_ops.edge.aten.view_copy.default, + } + + def __init__( + self, + scalar_to_tensor_ops: Optional[Dict[ScalarOp, ScalarOp]] = None, + ) -> None: + super().__init__() + self.scalar_to_tensor_ops = ( + scalar_to_tensor_ops + if scalar_to_tensor_ops is not None + else self.default_scalar_to_tensor_ops + ) + + def _create_constant_node( + self, + graph_module: torch.fx.GraphModule, + node: torch.fx.Node, + value: Number, + ) -> torch.fx.Node: + input_node = node.args[0] + if not isinstance(input_node, torch.fx.Node): + raise RuntimeError("Expected scalar op input to be an FX node.") + + input_value = input_node.meta["val"] + tensor = torch.tensor(value, dtype=input_value.dtype, device=input_value.device) + name = self._get_new_attr_name(graph_module) + # Keep constants as module attributes so the portable path can emit them + # without introducing aten.full, while XNNPACK can still read them as params. + graph_module.register_buffer(name, tensor) + + fake_mode = node.meta["val"].fake_mode + with graph_module.graph.inserting_before(node): + constant_node = graph_module.graph.get_attr(name) + constant_node.meta["val"] = fake_mode.from_tensor( + tensor, static_shapes=True + ) + return constant_node + + def _get_new_attr_name(self, graph_module: torch.fx.GraphModule) -> str: + prefix = "_tensor_constant_" + index = 0 + while hasattr(graph_module, f"{prefix}{index}"): + index += 1 + return f"{prefix}{index}" + + def _feeds_sdpa_qk_bmm(self, node: torch.fx.Node) -> bool: + """ + Return true for the scale muls consumed by XNNPACK's SDPA pattern. + + ConvertToSDPAPass recovers the user-specified attention scale from the + pre-QK^T ``aten.mul.Scalar`` nodes. Keep those scalar muls intact so + SDPA conversion can still find the scale before replacing the pattern. + """ + users_to_visit = list(node.users) + visited = set() + while users_to_visit: + user = users_to_visit.pop() + if user in visited: + continue + visited.add(user) + + if ( + user.op == "call_function" + and user.target == exir_ops.edge.aten.bmm.default + ): + return True + + if user.op == "call_function" and user.target in self.sdpa_passthrough_ops: + users_to_visit.extend(user.users) + + return False + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + + for node in list(graph_module.graph.nodes): + if ( + node.op != "call_function" + or node.target not in self.scalar_to_tensor_ops + or len(node.args) != 2 + or not isinstance(node.args[0], torch.fx.Node) + or not isinstance(node.args[1], Number) + ): + continue + + if ( + node.target == exir_ops.edge.aten.mul.Scalar + and self._feeds_sdpa_qk_bmm(node) + ): + continue + + input_value = node.args[0].meta.get("val") + output_value = node.meta.get("val") + if ( + input_value is None + or output_value is None + or input_value.dtype != output_value.dtype + ): + continue + + tensor_arg = self._create_constant_node(graph_module, node, node.args[1]) + node.args = (node.args[0], tensor_arg) + node.target = self.scalar_to_tensor_ops[node.target] + modified = True + + graph_module.graph.eliminate_dead_code() + graph_module.graph.lint() + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/xnnpack/test/ops/test_multiply.py b/backends/xnnpack/test/ops/test_multiply.py index 3315200005d..118136fcd08 100644 --- a/backends/xnnpack/test/ops/test_multiply.py +++ b/backends/xnnpack/test/ops/test_multiply.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -8,6 +9,10 @@ import torch from executorch.backends.xnnpack.test.tester import Tester +from executorch.backends.xnnpack.utils.configs import ( + get_transform_passes, + get_xnnpack_edge_compile_config, +) class TestMul(unittest.TestCase): @@ -29,6 +34,10 @@ def forward(self, x, y): z = torch.mul(x, y) * torch.functional.torch.mul(x, y) return z + class MulScalar(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.mul.Scalar(x, 0.5) + class MulRelu(torch.nn.Module): def forward(self, x, y): z = x * y @@ -58,6 +67,28 @@ def test_fp32_mul(self): inputs = (torch.randn((1, 3)), torch.randn((4, 3))) self._test_mul(inputs) + def test_fp32_mul_scalar(self): + ( + Tester(self.MulScalar(), (torch.randn(2, 3),)) + .export() + .to_edge_transform_and_lower( + transform_passes=get_transform_passes(), + edge_compile_config=get_xnnpack_edge_compile_config( + skip_dim_order=True + ), + ) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .check_not( + [ + "executorch_exir_dialects_edge__ops_aten_mul_Tensor", + "executorch_exir_dialects_edge__ops_aten_mul_Scalar", + ] + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + def test_qs8_mul(self): inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)) ( diff --git a/backends/xnnpack/test/passes/test_lift_constant_scalar_operands_pass.py b/backends/xnnpack/test/passes/test_lift_constant_scalar_operands_pass.py new file mode 100644 index 00000000000..5c61731a786 --- /dev/null +++ b/backends/xnnpack/test/passes/test_lift_constant_scalar_operands_pass.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from copy import deepcopy + +import torch +from executorch.backends.xnnpack._passes.lift_constant_scalar_operands_pass import ( + LiftConstantScalarOperandsPass, +) +from executorch.backends.xnnpack.partition.graphs import sdpa +from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config +from executorch.exir import to_edge +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_manager import ExportedProgramPassManager + + +class TestLiftConstantScalarOperandsPass(unittest.TestCase): + def setUp(self): + torch._dynamo.reset() + + class MulScalar(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.mul.Scalar(x, 0.5) + + class AddScalar(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.add.Scalar(x, 0.5) + + def _to_edge_program_manager(self, module): + return to_edge( + torch.export.export(module, (torch.randn(2, 3),), strict=True), + compile_config=get_xnnpack_edge_compile_config(skip_dim_order=True), + ) + + def _to_edge_graph(self, module): + edge = self._to_edge_program_manager(module) + return ExportedProgramPassManager([LiftConstantScalarOperandsPass()])( + edge.exported_program() + ).exported_program + + def test_lifts_mul_scalar_operand(self): + graph = self._to_edge_graph(self.MulScalar()).graph_module.graph + + self.assertFalse( + any(node.target == exir_ops.edge.aten.mul.Scalar for node in graph.nodes) + ) + self.assertTrue( + any(node.target == exir_ops.edge.aten.mul.Tensor for node in graph.nodes) + ) + self.assertTrue(any(node.op == "get_attr" for node in graph.nodes)) + + def test_lifted_mul_scalar_can_emit_without_delegation(self): + edge = self._to_edge_program_manager(self.MulScalar()).transform( + (LiftConstantScalarOperandsPass(),) + ) + + self.assertIsNotNone(edge.to_executorch()) + + def test_keeps_unmapped_scalar_op(self): + graph = self._to_edge_graph(self.AddScalar()).graph_module.graph + + self.assertTrue( + any(node.target == exir_ops.edge.aten.add.Scalar for node in graph.nodes) + ) + + def test_keeps_sdpa_scale_mul_scalar(self): + graph_module = deepcopy(sdpa.get_graphs()[0]) + + LiftConstantScalarOperandsPass()(graph_module) + + scale_mul_count = 0 + lifted_mul_count = 0 + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + if node.target == exir_ops.edge.aten.mul.Scalar: + scale_mul_count += 1 + if node.target == exir_ops.edge.aten.mul.Tensor: + lifted_mul_count += 1 + + self.assertEqual(scale_mul_count, 2) + self.assertEqual(lifted_mul_count, 0) diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index fc12da231c0..396e149565f 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -24,9 +24,11 @@ QuantizationConfig, ) from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config -from executorch.exir import EdgeCompileConfig +from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower from executorch.exir.backend.partitioner import Partitioner +from executorch.exir.pass_manager import PassType as ExirPassType from torch._export.pass_base import PassType +from torch.export import ExportedProgram from torchao.quantization.pt2e.quantizer import Quantizer @@ -77,6 +79,7 @@ def __init__( self, partitioners: Optional[List[Partitioner]] = None, edge_compile_config: Optional[EdgeCompileConfig] = None, + transform_passes: Optional[List[ExirPassType]] = None, ): super().__init__( default_partitioner_cls=XnnpackPartitioner, @@ -84,6 +87,21 @@ def __init__( edge_compile_config=edge_compile_config or get_xnnpack_edge_compile_config(), ) + self.transform_passes = transform_passes + + def run( + self, + artifact: ExportedProgram, + inputs=None, + generate_etrecord: bool = False, + ) -> None: + self.edge_dialect_program = to_edge_transform_and_lower( + artifact, + transform_passes=self.transform_passes, + compile_config=self.edge_compile_conf, + partitioner=self.partitioners, + generate_etrecord=generate_etrecord, + ) class Partition(BaseStages.Partition): @@ -132,3 +150,37 @@ def __init__( dynamic_shapes=dynamic_shapes, **kwargs, ) + + def to_edge_transform_and_lower( + self, + to_edge_and_transform_stage: Optional[ + BaseStages.ToEdgeTransformAndLower + ] = None, + generate_etrecord: bool = False, + *, + partitioners: Optional[List[Partitioner]] = None, + edge_compile_config: Optional[EdgeCompileConfig] = None, + transform_passes: Optional[List[ExirPassType]] = None, + ): + if to_edge_and_transform_stage is None: + to_edge_and_transform_stage = ToEdgeTransformAndLower( + partitioners=partitioners, + edge_compile_config=edge_compile_config, + transform_passes=transform_passes, + ) + else: + if partitioners is not None: + to_edge_and_transform_stage.partitioners = partitioners + if edge_compile_config is not None: + to_edge_and_transform_stage.edge_compile_conf = edge_compile_config + if transform_passes is not None: + if not isinstance(to_edge_and_transform_stage, ToEdgeTransformAndLower): + raise ValueError( + "transform_passes requires the XNNPACK " + "ToEdgeTransformAndLower stage." + ) + to_edge_and_transform_stage.transform_passes = transform_passes + return super().to_edge_transform_and_lower( + to_edge_and_transform_stage, + generate_etrecord=generate_etrecord, + ) diff --git a/backends/xnnpack/utils/configs.py b/backends/xnnpack/utils/configs.py index 3016e94146b..ec47b81e835 100644 --- a/backends/xnnpack/utils/configs.py +++ b/backends/xnnpack/utils/configs.py @@ -9,6 +9,9 @@ import executorch.exir as exir +from executorch.backends.xnnpack._passes.lift_constant_scalar_operands_pass import ( + LiftConstantScalarOperandsPass, +) from executorch.backends.xnnpack._passes.remove_noop_expand_copy_pass import ( RemoveNoopExpandCopyPass, ) @@ -25,7 +28,7 @@ def get_xnnpack_edge_compile_config( def get_transform_passes(additional_passes=None) -> List[PassType]: - passes = [RemoveNoopExpandCopyPass()] + passes = [RemoveNoopExpandCopyPass(), LiftConstantScalarOperandsPass()] if additional_passes: passes.extend(additional_passes) return passes