Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions backends/xnnpack/_passes/lift_constant_scalar_operands_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from numbers import Number
from typing import Dict, Optional, Union

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, PassResult
from torch._ops import OpOverload


ScalarOp = Union[EdgeOpOverload, OpOverload]


class LiftConstantScalarOperandsPass(ExportPass):
    """
    Lift scalar operands into tensor constants for selected binary ops.

    XNNPACK already supports the tensor overloads for these binary operations.
    This pass converts explicitly listed scalar overloads to their tensor
    overloads by replacing constant scalar operands with small tensor constants.
    The constants are registered as buffers so they do not become portable
    ``full`` kernels. Keep the op map narrow until each new scalar overload is
    covered by tests.
    """

    # Scalar overload -> tensor overload used when no custom map is supplied.
    default_scalar_to_tensor_ops: Dict[ScalarOp, ScalarOp] = {
        exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor,
    }
    # Ops that ``_feeds_sdpa_qk_bmm`` looks through while searching for a bmm.
    sdpa_passthrough_ops = {
        exir_ops.edge.aten.expand_copy.default,
        exir_ops.edge.aten.view_copy.default,
    }

    def __init__(
        self,
        scalar_to_tensor_ops: Optional[Dict[ScalarOp, ScalarOp]] = None,
    ) -> None:
        """
        Args:
            scalar_to_tensor_ops: Optional mapping from a scalar overload to
                the tensor overload that replaces it. When ``None``, a copy of
                ``default_scalar_to_tensor_ops`` is used.
        """
        super().__init__()
        # Copy the class-level default so that mutating one instance's map can
        # never leak into the shared default seen by other instances. A map
        # supplied by the caller is kept by reference, as before.
        self.scalar_to_tensor_ops = (
            scalar_to_tensor_ops
            if scalar_to_tensor_ops is not None
            else dict(self.default_scalar_to_tensor_ops)
        )

    def _create_constant_node(
        self,
        graph_module: torch.fx.GraphModule,
        node: torch.fx.Node,
        value: Number,
    ) -> torch.fx.Node:
        """
        Register ``value`` as a tensor buffer and return a ``get_attr`` node for it.

        Args:
            graph_module: Module that receives the new buffer and graph node.
            node: The scalar-op node being rewritten; the new ``get_attr`` node
                is inserted immediately before it.
            value: The Python scalar to materialize as a tensor.

        Returns:
            The newly created ``get_attr`` node, with ``meta["val"]`` populated.

        Raises:
            RuntimeError: If the first argument of ``node`` is not an FX node.
        """
        input_node = node.args[0]
        if not isinstance(input_node, torch.fx.Node):
            raise RuntimeError("Expected scalar op input to be an FX node.")

        # The constant takes the dtype and device recorded for the tensor input.
        input_value = input_node.meta["val"]
        tensor = torch.tensor(value, dtype=input_value.dtype, device=input_value.device)
        name = self._get_new_attr_name(graph_module)
        # Keep constants as module attributes so the portable path can emit them
        # without introducing aten.full, while XNNPACK can still read them as params.
        graph_module.register_buffer(name, tensor)

        # Build the metadata with the fake mode recorded on the rewritten node's
        # own value.
        fake_mode = node.meta["val"].fake_mode
        with graph_module.graph.inserting_before(node):
            constant_node = graph_module.graph.get_attr(name)
            constant_node.meta["val"] = fake_mode.from_tensor(
                tensor, static_shapes=True
            )
        return constant_node

    def _get_new_attr_name(self, graph_module: torch.fx.GraphModule) -> str:
        """Return the first ``_tensor_constant_<N>`` name not yet set on the module."""
        prefix = "_tensor_constant_"
        index = 0
        while hasattr(graph_module, f"{prefix}{index}"):
            index += 1
        return f"{prefix}{index}"

    def _feeds_sdpa_qk_bmm(self, node: torch.fx.Node) -> bool:
        """
        Return true for the scale muls consumed by XNNPACK's SDPA pattern.

        ConvertToSDPAPass recovers the user-specified attention scale from the
        pre-QK^T ``aten.mul.Scalar`` nodes. Keep those scalar muls intact so
        SDPA conversion can still find the scale before replacing the pattern.

        The search follows the users of ``node`` transitively, looking through
        the ops in ``sdpa_passthrough_ops`` only, and stops at the first
        ``bmm`` user it reaches.
        """
        users_to_visit = list(node.users)
        visited = set()
        while users_to_visit:
            user = users_to_visit.pop()
            if user in visited:
                continue
            visited.add(user)

            if (
                user.op == "call_function"
                and user.target == exir_ops.edge.aten.bmm.default
            ):
                return True

            # Only keep walking through the pass-through ops; any other user
            # ends that branch of the search.
            if user.op == "call_function" and user.target in self.sdpa_passthrough_ops:
                users_to_visit.extend(user.users)

        return False

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Rewrite eligible scalar-overload nodes in ``graph_module`` in place.

        A node is rewritten when it is a ``call_function`` whose target is in
        ``scalar_to_tensor_ops``, it has exactly two positional args (an FX
        node and a Python number), it is not a ``mul.Scalar`` feeding a bmm,
        and its input and output ``meta["val"]`` dtypes match.

        Returns:
            A ``PassResult`` whose ``modified`` flag is true only if at least
            one node was rewritten.
        """
        modified = False

        # Iterate over a snapshot because new get_attr nodes are inserted.
        for node in list(graph_module.graph.nodes):
            if (
                node.op != "call_function"
                or node.target not in self.scalar_to_tensor_ops
                or len(node.args) != 2
                or not isinstance(node.args[0], torch.fx.Node)
                or not isinstance(node.args[1], Number)
            ):
                continue

            if (
                node.target == exir_ops.edge.aten.mul.Scalar
                and self._feeds_sdpa_qk_bmm(node)
            ):
                continue

            # The constant is created in the input's dtype, so only lift when
            # the op's recorded output dtype is that same dtype.
            input_value = node.args[0].meta.get("val")
            output_value = node.meta.get("val")
            if (
                input_value is None
                or output_value is None
                or input_value.dtype != output_value.dtype
            ):
                continue

            tensor_arg = self._create_constant_node(graph_module, node, node.args[1])
            node.args = (node.args[0], tensor_arg)
            node.target = self.scalar_to_tensor_ops[node.target]
            modified = True

        # Only clean up and recompile when something was rewritten. Otherwise
        # dead-code elimination could alter the graph while this pass still
        # reports ``modified=False``.
        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.graph.lint()
            graph_module.recompile()

        return PassResult(graph_module, modified)
31 changes: 31 additions & 0 deletions backends/xnnpack/test/ops/test_multiply.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -8,6 +9,10 @@

import torch
from executorch.backends.xnnpack.test.tester import Tester
from executorch.backends.xnnpack.utils.configs import (
get_transform_passes,
get_xnnpack_edge_compile_config,
)


class TestMul(unittest.TestCase):
Expand All @@ -29,6 +34,10 @@ def forward(self, x, y):
z = torch.mul(x, y) * torch.functional.torch.mul(x, y)
return z

class MulScalar(torch.nn.Module):
    """Module whose forward multiplies the input by 0.5 via ``aten.mul.Scalar``."""

    def forward(self, x):
        scaled = torch.ops.aten.mul.Scalar(x, 0.5)
        return scaled

class MulRelu(torch.nn.Module):
def forward(self, x, y):
z = x * y
Expand Down Expand Up @@ -58,6 +67,28 @@ def test_fp32_mul(self):
inputs = (torch.randn((1, 3)), torch.randn((4, 3)))
self._test_mul(inputs)

def test_fp32_mul_scalar(self):
    """Lower a float32 ``mul.Scalar`` module with the transform passes and run it.

    After lowering, the graph is checked for exactly one delegate call and for
    the absence of the edge ``mul.Tensor`` / ``mul.Scalar`` op names; the
    program is then serialized and its outputs are compared.
    """
    sample_inputs = (torch.randn(2, 3),)
    lowered = (
        Tester(self.MulScalar(), sample_inputs)
        .export()
        .to_edge_transform_and_lower(
            transform_passes=get_transform_passes(),
            edge_compile_config=get_xnnpack_edge_compile_config(
                skip_dim_order=True
            ),
        )
    )

    # Names that must not appear once the scalar mul has been lowered.
    remaining_mul_ops = [
        "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
        "executorch_exir_dialects_edge__ops_aten_mul_Scalar",
    ]
    checked = lowered.check_count(
        {"torch.ops.higher_order.executorch_call_delegate": 1}
    ).check_not(remaining_mul_ops)

    checked.to_executorch().serialize().run_method_and_compare_outputs()

def test_qs8_mul(self):
inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1))
(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from copy import deepcopy

import torch
from executorch.backends.xnnpack._passes.lift_constant_scalar_operands_pass import (
LiftConstantScalarOperandsPass,
)
from executorch.backends.xnnpack.partition.graphs import sdpa
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.exir import to_edge
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_manager import ExportedProgramPassManager


class TestLiftConstantScalarOperandsPass(unittest.TestCase):
    """Unit tests for ``LiftConstantScalarOperandsPass`` on edge-dialect graphs."""

    def setUp(self):
        torch._dynamo.reset()

    class MulScalar(torch.nn.Module):
        """Multiplies the input by 0.5 through the scalar overload."""

        def forward(self, x):
            return torch.ops.aten.mul.Scalar(x, 0.5)

    class AddScalar(torch.nn.Module):
        """Adds 0.5 to the input through the scalar overload."""

        def forward(self, x):
            return torch.ops.aten.add.Scalar(x, 0.5)

    def _to_edge_program_manager(self, module):
        """Export ``module`` on a (2, 3) random input and convert it to edge."""
        exported = torch.export.export(module, (torch.randn(2, 3),), strict=True)
        return to_edge(
            exported,
            compile_config=get_xnnpack_edge_compile_config(skip_dim_order=True),
        )

    def _to_edge_graph(self, module):
        """Return the exported program after running only the lifting pass."""
        edge = self._to_edge_program_manager(module)
        pass_manager = ExportedProgramPassManager([LiftConstantScalarOperandsPass()])
        result = pass_manager(edge.exported_program())
        return result.exported_program

    def test_lifts_mul_scalar_operand(self):
        graph = self._to_edge_graph(self.MulScalar()).graph_module.graph
        nodes = list(graph.nodes)
        targets = [n.target for n in nodes]
        node_kinds = [n.op for n in nodes]

        # The scalar overload is gone, replaced by the tensor overload fed by
        # a get_attr constant.
        self.assertFalse(exir_ops.edge.aten.mul.Scalar in targets)
        self.assertTrue(exir_ops.edge.aten.mul.Tensor in targets)
        self.assertTrue("get_attr" in node_kinds)

    def test_lifted_mul_scalar_can_emit_without_delegation(self):
        manager = self._to_edge_program_manager(self.MulScalar())
        edge = manager.transform((LiftConstantScalarOperandsPass(),))

        self.assertIsNotNone(edge.to_executorch())

    def test_keeps_unmapped_scalar_op(self):
        graph = self._to_edge_graph(self.AddScalar()).graph_module.graph
        targets = [n.target for n in graph.nodes]

        # add.Scalar is not in the default op map, so it must survive.
        self.assertTrue(exir_ops.edge.aten.add.Scalar in targets)

    def test_keeps_sdpa_scale_mul_scalar(self):
        graph_module = deepcopy(sdpa.get_graphs()[0])

        LiftConstantScalarOperandsPass()(graph_module)

        call_targets = [
            n.target for n in graph_module.graph.nodes if n.op == "call_function"
        ]
        scale_mul_count = sum(
            1 for target in call_targets if target == exir_ops.edge.aten.mul.Scalar
        )
        lifted_mul_count = sum(
            1 for target in call_targets if target == exir_ops.edge.aten.mul.Tensor
        )

        self.assertEqual(scale_mul_count, 2)
        self.assertEqual(lifted_mul_count, 0)
56 changes: 54 additions & 2 deletions backends/xnnpack/test/tester/tester.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -24,9 +24,11 @@
QuantizationConfig,
)
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.exir import EdgeCompileConfig
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.pass_manager import PassType as ExirPassType
from torch._export.pass_base import PassType
from torch.export import ExportedProgram
from torchao.quantization.pt2e.quantizer import Quantizer
Comment thread
mansnils marked this conversation as resolved.


Expand Down Expand Up @@ -77,13 +79,29 @@ def __init__(
self,
partitioners: Optional[List[Partitioner]] = None,
edge_compile_config: Optional[EdgeCompileConfig] = None,
transform_passes: Optional[List[ExirPassType]] = None,
):
super().__init__(
default_partitioner_cls=XnnpackPartitioner,
partitioners=partitioners,
edge_compile_config=edge_compile_config
or get_xnnpack_edge_compile_config(),
)
self.transform_passes = transform_passes

def run(
    self,
    artifact: ExportedProgram,
    inputs=None,
    generate_etrecord: bool = False,
) -> None:
    """Lower ``artifact`` and store the result on ``self.edge_dialect_program``.

    The stage's ``transform_passes``, ``edge_compile_conf`` and
    ``partitioners`` are forwarded to ``to_edge_transform_and_lower``.
    ``inputs`` is accepted but not used here.
    """
    lowering_options = {
        "transform_passes": self.transform_passes,
        "compile_config": self.edge_compile_conf,
        "partitioner": self.partitioners,
        "generate_etrecord": generate_etrecord,
    }
    lowered_program = to_edge_transform_and_lower(artifact, **lowering_options)
    self.edge_dialect_program = lowered_program


class Partition(BaseStages.Partition):
Expand Down Expand Up @@ -132,3 +150,37 @@ def __init__(
dynamic_shapes=dynamic_shapes,
**kwargs,
)

def to_edge_transform_and_lower(
    self,
    to_edge_and_transform_stage: Optional[
        BaseStages.ToEdgeTransformAndLower
    ] = None,
    generate_etrecord: bool = False,
    *,
    partitioners: Optional[List[Partitioner]] = None,
    edge_compile_config: Optional[EdgeCompileConfig] = None,
    transform_passes: Optional[List[ExirPassType]] = None,
):
    """Run the lowering stage, optionally overriding its configuration.

    Args:
        to_edge_and_transform_stage: Stage to run. When ``None``, a new
            ``ToEdgeTransformAndLower`` is built from the keyword overrides.
        generate_etrecord: Forwarded to the parent implementation.
        partitioners: If given, replaces the stage's ``partitioners``.
        edge_compile_config: If given, replaces the stage's
            ``edge_compile_conf``.
        transform_passes: If given, replaces the stage's ``transform_passes``.

    Returns:
        Whatever the parent ``to_edge_transform_and_lower`` returns.

    Raises:
        ValueError: If ``transform_passes`` is given together with a stage
            that is not a ``ToEdgeTransformAndLower`` instance.
    """
    if to_edge_and_transform_stage is None:
        to_edge_and_transform_stage = ToEdgeTransformAndLower(
            partitioners=partitioners,
            edge_compile_config=edge_compile_config,
            transform_passes=transform_passes,
        )
    else:
        # Validate first, so a rejected call leaves the caller's stage
        # object untouched instead of half-updated.
        if transform_passes is not None and not isinstance(
            to_edge_and_transform_stage, ToEdgeTransformAndLower
        ):
            raise ValueError(
                "transform_passes requires the XNNPACK "
                "ToEdgeTransformAndLower stage."
            )
        if partitioners is not None:
            to_edge_and_transform_stage.partitioners = partitioners
        if edge_compile_config is not None:
            to_edge_and_transform_stage.edge_compile_conf = edge_compile_config
        if transform_passes is not None:
            to_edge_and_transform_stage.transform_passes = transform_passes
    return super().to_edge_transform_and_lower(
        to_edge_and_transform_stage,
        generate_etrecord=generate_etrecord,
    )
Loading
Loading