Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions backends/arm/_passes/decompose_add_sub_alpha_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -76,7 +76,11 @@ def call_operator(self, op, args, kwargs, meta, updated: bool | None = False):
lhs, rhs = args

alpha_full = super().call_operator(
full_op, ((1,), float(alpha)), {}, meta, updated=True
full_op,
((1,), float(alpha)),
{"device": meta["val"].device},
meta,
updated=True,
)
scaled_rhs = super().call_operator(
mul_op,
Expand Down
13 changes: 7 additions & 6 deletions backends/arm/_passes/decompose_avg_pool2d_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -51,6 +51,7 @@ def call_operator(self, op, args, kwargs, meta):
full_op, cat_op, avgpool_op, mul_op = get_decomposition(op)

x = args[0]
full_kwargs = {"device": x.data.device}
kernel_h, kernel_w = args[1]
kernel_size = kernel_h * kernel_w
if len(args) > 2 and args[2] is not None:
Expand All @@ -75,7 +76,7 @@ def call_operator(self, op, args, kwargs, meta):
if count_include_pad and pad_w > 0:
pre_pad_shape = [n, c, h, pad_w]
pre_pad = super().call_operator(
full_op, (pre_pad_shape, 0.0), kwargs, meta, updated=True
full_op, (pre_pad_shape, 0.0), full_kwargs, meta, updated=True
)

if ceil_mode and divisor_override is None:
Expand All @@ -88,7 +89,7 @@ def call_operator(self, op, args, kwargs, meta):
if post_pad_w > 0:
post_pad_shape = [n, c, h, post_pad_w]
post_pad = super().call_operator(
full_op, (post_pad_shape, 0.0), kwargs, meta, updated=True
full_op, (post_pad_shape, 0.0), full_kwargs, meta, updated=True
)
cat_nodes = [pre_pad, x, post_pad]
else:
Expand All @@ -103,7 +104,7 @@ def call_operator(self, op, args, kwargs, meta):
if count_include_pad and pad_h > 0:
pre_pad_shape = [n, c, pad_h, w + pad_w + post_pad_w]
pre_pad = super().call_operator(
full_op, (pre_pad_shape, 0.0), kwargs, meta, updated=True
full_op, (pre_pad_shape, 0.0), full_kwargs, meta, updated=True
)

if ceil_mode and divisor_override is None:
Expand All @@ -116,7 +117,7 @@ def call_operator(self, op, args, kwargs, meta):
if post_pad_h > 0:
post_pad_shape = [n, c, post_pad_h, w + pad_w + post_pad_w]
post_pad = super().call_operator(
full_op, (post_pad_shape, 0.0), kwargs, meta, updated=True
full_op, (post_pad_shape, 0.0), full_kwargs, meta, updated=True
)
cat_nodes = [pre_pad, x, post_pad]
else:
Expand All @@ -142,7 +143,7 @@ def call_operator(self, op, args, kwargs, meta):
override_multiplier = super().call_operator(
full_op,
([1, 1, 1, 1], kernel_size / divisor_override),
kwargs,
full_kwargs,
meta,
updated=True,
)
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/decompose_div_tensor_mode.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -78,7 +78,7 @@ def call_operator(self, op, args, kwargs, meta):
zero = super().call_operator(
opset["full"],
args=((1,) * len(meta["val"].size()), 0.0),
kwargs={"dtype": torch.float32},
kwargs={"dtype": torch.float32, "device": meta["val"].device},
meta=meta,
updated=True,
)
Expand Down
6 changes: 4 additions & 2 deletions backends/arm/_passes/decompose_groupnorm_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -92,9 +92,11 @@ def call(self, graph_module: torch.fx.GraphModule):
if isinstance(meta["val"], tuple):
shape = meta["val"][0].size()
dtype = meta["val"][0].dtype
device = meta["val"][0].device
else:
shape = meta["val"].size()
dtype = meta["val"].dtype
device = meta["val"].device
match len(args):
# MI profile always provides all the args: x, weight, bias, N, C, HxW, group, eps
case 8:
Expand Down Expand Up @@ -156,7 +158,7 @@ def call(self, graph_module: torch.fx.GraphModule):
graph_module.graph,
full_op,
args=(epsilon_reshaped_shape, eps),
kwargs={"dtype": dtype},
kwargs={"dtype": dtype, "device": device},
from_node=node,
)
add0 = create_node(
Expand Down
6 changes: 4 additions & 2 deletions backends/arm/_passes/decompose_layernorm_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -108,9 +108,11 @@ def call(self, graph_module: torch.fx.GraphModule):
if isinstance(meta["val"], tuple):
shape = meta["val"][0].size()
dtype = meta["val"][0].dtype
device = meta["val"][0].device
else:
shape = meta["val"].size()
dtype = meta["val"].dtype
device = meta["val"].device
rank = len(shape)
dims = list(range(-1, -1 * (n_dims + 1), -1))
dims = [dim % rank for dim in dims]
Expand Down Expand Up @@ -142,7 +144,7 @@ def call(self, graph_module: torch.fx.GraphModule):
graph_module.graph,
full_op,
args=(epsilon_reshaped_shape, epsilon),
kwargs={"dtype": dtype},
kwargs={"dtype": dtype, "device": device},
from_node=node,
)
add0 = create_node(
Expand Down
5 changes: 3 additions & 2 deletions backends/arm/_passes/decompose_leaky_relu_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -57,6 +57,7 @@ def call_operator(self, op, args, kwargs, meta):
x = args[0]
slope = args[1] if len(args) > 1 else 0.01
dtype = x.node.meta["val"].dtype
device = x.node.meta["val"].device
clamp, full, mul, add = _get_leaky_relu_ops(op)
op1 = super().call_operator(
op=clamp, args=(x, 0, None), kwargs=kwargs, meta=meta
Expand All @@ -67,7 +68,7 @@ def call_operator(self, op, args, kwargs, meta):
op3 = super().call_operator(
op=full,
args=(x.node.meta["val"].shape, slope),
kwargs={"dtype": dtype},
kwargs={"dtype": dtype, "device": device},
meta=meta,
)
op4 = super().call_operator(op=mul, args=(op3, op2), kwargs=kwargs, meta=meta)
Expand Down
8 changes: 6 additions & 2 deletions backends/arm/_passes/decompose_meandim_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -167,7 +167,11 @@ def _reduce_by_sum(self, op, input_node, dims, meta, dtype):

sum = super().call_operator(sum_op, (input_node, dims, True), {}, meta, True)
full = super().call_operator(
full_op, ([1] * len(output_shape), 1 / N), {"dtype": dtype}, meta, True
full_op,
([1] * len(output_shape), 1 / N),
{"dtype": dtype, "device": input_node.data.device},
meta,
True,
)
if (quant_ops := get_quantization(input_node.node.target)) is not None:
# Insert Q and DQ nodes after full op.
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/decompose_var_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -98,7 +98,7 @@ def call_operator(self, op, args, kwargs, meta):
full = super().call_operator(
full_op,
([], 1 / max(0, N - correction)),
{"dtype": dtype},
{"dtype": dtype, "device": x.data.device},
meta,
True,
)
Expand Down
8 changes: 5 additions & 3 deletions backends/arm/_passes/scalars_to_attribute_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -49,13 +49,14 @@ def call(self, graph_module: GraphModule) -> PassResult:
shape = get_first_fake_tensor(arg).shape
biggest_rank = max(biggest_rank, len(shape))

output_fake_tensor = get_first_fake_tensor(n)
new_args: list[Node | int] = []
for arg in n.args:
if isinstance(arg, Node):
new_args.append(arg)
continue
if isinstance(arg, int) and not torch.is_floating_point(
get_first_fake_tensor(n)
output_fake_tensor
):
new_args.append(arg)
continue
Expand All @@ -64,7 +65,8 @@ def call(self, graph_module: GraphModule) -> PassResult:
get_new_attr_name = get_new_attr_name_with_prefix(prefix)
tensor_constant_name = get_new_attr_name(graph_module)
float_tensor = torch.tensor(
float(cast(Union[int, float], arg))
float(cast(Union[int, float], arg)),
device=output_fake_tensor.device,
).reshape((1,) * biggest_rank)
graph_module.register_buffer(tensor_constant_name, float_tensor)
fake_mode = n.meta["val"].fake_mode
Expand Down
48 changes: 45 additions & 3 deletions backends/arm/quantizer/arm_quantizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -597,8 +597,50 @@ def _annotate_io(
mark_node_as_annotated(node)

def validate(self, model: GraphModule) -> None:
"""TODO: Implement validation of annotated graph for TOSA backend."""
pass
"""Validate the quantization results. Currently, this includes:
- Ensure tensor inputs to each operator live on the same device.

Args:
model (GraphModule): GraphModule being validated.
Raises:
ValueError: If tensor inputs for any operator span more than one
device.
"""
for node in model.graph.nodes:
if node.op != "call_function":
continue

devices = set()
for arg_node in node.all_input_nodes:
meta_val = arg_node.meta.get("val", None)
if meta_val is None:
continue
if isinstance(meta_val, (tuple, list)):
for tensor in meta_val:
devices.add(
str(
getattr(
tensor,
"device",
f"Could not get device from {tensor}",
)
)
)
else:
devices.add(
str(
getattr(
meta_val,
"device",
f"Could not get device from {meta_val}",
)
)
)

if len(devices) > 1:
raise ValueError(
f"Quantizer detected operator {node.name} with different device inputs: {devices}."
)

def quantize_with_submodules(
self,
Expand Down
55 changes: 55 additions & 0 deletions backends/arm/test/quantizer/test_tosa_quantizer_validate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from executorch.backends.arm.quantizer import TOSAQuantizer
from executorch.backends.arm.tosa import TosaSpecification
from torch.fx import symbolic_trace


def _annotate_placeholders_with_devices(gm, device_map):
for node in gm.graph.nodes:
if node.op == "placeholder":
device = device_map[node.target]
node.meta["val"] = torch.empty(1, device=device)


def _get_quantizer():
return TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))


class TwoIndependentAdds(torch.nn.Module):

def forward(self, x: torch.Tensor, y: torch.Tensor):
return x + 1, y + 1


class CrossDeviceAdd(torch.nn.Module):

def forward(self, x: torch.Tensor, y: torch.Tensor):
return x + y


def test_validate_allows_different_devices_across_operators():
gm = symbolic_trace(TwoIndependentAdds())
_annotate_placeholders_with_devices(
gm, {"x": torch.device("cpu"), "y": torch.device("meta")}
)

quantizer = _get_quantizer()
quantizer.validate(gm)


def test_validate_rejects_mixed_devices_within_operator():
gm = symbolic_trace(CrossDeviceAdd())
_annotate_placeholders_with_devices(
gm, {"x": torch.device("cpu"), "y": torch.device("meta")}
)

quantizer = _get_quantizer()
with pytest.raises(ValueError, match="Quantizer detected operator"):
quantizer.validate(gm)
7 changes: 5 additions & 2 deletions backends/transforms/replace_scalar_with_tensor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -59,7 +59,10 @@ def get_replacement(self, op, args, kwargs, meta):
(1,),
args[1],
),
kwargs={"dtype": args[0].to_tensor().dtype},
kwargs={
"dtype": args[0].to_tensor().dtype,
"device": args[0].to_tensor().device,
},
meta=meta,
),
# Other args.
Expand Down
Loading