Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions backends/arm/_passes/decompose_add_sub_alpha_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -76,7 +76,11 @@ def call_operator(self, op, args, kwargs, meta, updated: bool | None = False):
lhs, rhs = args

alpha_full = super().call_operator(
full_op, ((1,), float(alpha)), {}, meta, updated=True
full_op,
((1,), float(alpha)),
{"device": meta["val"].device},
meta,
updated=True,
)
scaled_rhs = super().call_operator(
mul_op,
Expand Down
13 changes: 7 additions & 6 deletions backends/arm/_passes/decompose_avg_pool2d_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -51,6 +51,7 @@ def call_operator(self, op, args, kwargs, meta):
full_op, cat_op, avgpool_op, mul_op = get_decomposition(op)

x = args[0]
full_kwargs = {"device": x.data.device}
kernel_h, kernel_w = args[1]
kernel_size = kernel_h * kernel_w
if len(args) > 2 and args[2] is not None:
Expand All @@ -75,7 +76,7 @@ def call_operator(self, op, args, kwargs, meta):
if count_include_pad and pad_w > 0:
pre_pad_shape = [n, c, h, pad_w]
pre_pad = super().call_operator(
full_op, (pre_pad_shape, 0.0), kwargs, meta, updated=True
full_op, (pre_pad_shape, 0.0), full_kwargs, meta, updated=True
)

if ceil_mode and divisor_override is None:
Expand All @@ -88,7 +89,7 @@ def call_operator(self, op, args, kwargs, meta):
if post_pad_w > 0:
post_pad_shape = [n, c, h, post_pad_w]
post_pad = super().call_operator(
full_op, (post_pad_shape, 0.0), kwargs, meta, updated=True
full_op, (post_pad_shape, 0.0), full_kwargs, meta, updated=True
)
cat_nodes = [pre_pad, x, post_pad]
else:
Expand All @@ -103,7 +104,7 @@ def call_operator(self, op, args, kwargs, meta):
if count_include_pad and pad_h > 0:
pre_pad_shape = [n, c, pad_h, w + pad_w + post_pad_w]
pre_pad = super().call_operator(
full_op, (pre_pad_shape, 0.0), kwargs, meta, updated=True
full_op, (pre_pad_shape, 0.0), full_kwargs, meta, updated=True
)

if ceil_mode and divisor_override is None:
Expand All @@ -116,7 +117,7 @@ def call_operator(self, op, args, kwargs, meta):
if post_pad_h > 0:
post_pad_shape = [n, c, post_pad_h, w + pad_w + post_pad_w]
post_pad = super().call_operator(
full_op, (post_pad_shape, 0.0), kwargs, meta, updated=True
full_op, (post_pad_shape, 0.0), full_kwargs, meta, updated=True
)
cat_nodes = [pre_pad, x, post_pad]
else:
Expand All @@ -142,7 +143,7 @@ def call_operator(self, op, args, kwargs, meta):
override_multiplier = super().call_operator(
full_op,
([1, 1, 1, 1], kernel_size / divisor_override),
kwargs,
full_kwargs,
meta,
updated=True,
)
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/decompose_div_tensor_mode.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -78,7 +78,7 @@ def call_operator(self, op, args, kwargs, meta):
zero = super().call_operator(
opset["full"],
args=((1,) * len(meta["val"].size()), 0.0),
kwargs={"dtype": torch.float32},
kwargs={"dtype": torch.float32, "device": meta["val"].device},
meta=meta,
updated=True,
)
Expand Down
6 changes: 4 additions & 2 deletions backends/arm/_passes/decompose_groupnorm_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -92,9 +92,11 @@ def call(self, graph_module: torch.fx.GraphModule):
if isinstance(meta["val"], tuple):
shape = meta["val"][0].size()
dtype = meta["val"][0].dtype
device = meta["val"][0].device
else:
shape = meta["val"].size()
dtype = meta["val"].dtype
device = meta["val"].device
match len(args):
# MI profile always provides all the args: x, weight, bias, N, C, HxW, group, eps
case 8:
Expand Down Expand Up @@ -156,7 +158,7 @@ def call(self, graph_module: torch.fx.GraphModule):
graph_module.graph,
full_op,
args=(epsilon_reshaped_shape, eps),
kwargs={"dtype": dtype},
kwargs={"dtype": dtype, "device": device},
from_node=node,
)
add0 = create_node(
Expand Down
6 changes: 4 additions & 2 deletions backends/arm/_passes/decompose_layernorm_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -103,9 +103,11 @@ def call(self, graph_module: torch.fx.GraphModule):
if isinstance(meta["val"], tuple):
shape = meta["val"][0].size()
dtype = meta["val"][0].dtype
device = meta["val"][0].device
else:
shape = meta["val"].size()
dtype = meta["val"].dtype
device = meta["val"].device
rank = len(shape)
dims = list(range(-1, -1 * (n_dims + 1), -1))
dims = [dim % rank for dim in dims]
Expand Down Expand Up @@ -137,7 +139,7 @@ def call(self, graph_module: torch.fx.GraphModule):
graph_module.graph,
full_op,
args=(epsilon_reshaped_shape, epsilon),
kwargs={"dtype": dtype},
kwargs={"dtype": dtype, "device": device},
from_node=node,
)
add0 = create_node(
Expand Down
5 changes: 3 additions & 2 deletions backends/arm/_passes/decompose_leaky_relu_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -57,6 +57,7 @@ def call_operator(self, op, args, kwargs, meta):
x = args[0]
slope = args[1] if len(args) > 1 else 0.01
dtype = x.node.meta["val"].dtype
device = x.node.meta["val"].device
clamp, full, mul, add = _get_leaky_relu_ops(op)
op1 = super().call_operator(
op=clamp, args=(x, 0, None), kwargs=kwargs, meta=meta
Expand All @@ -67,7 +68,7 @@ def call_operator(self, op, args, kwargs, meta):
op3 = super().call_operator(
op=full,
args=(x.node.meta["val"].shape, slope),
kwargs={"dtype": dtype},
kwargs={"dtype": dtype, "device": device},
meta=meta,
)
op4 = super().call_operator(op=mul, args=(op3, op2), kwargs=kwargs, meta=meta)
Expand Down
8 changes: 6 additions & 2 deletions backends/arm/_passes/decompose_meandim_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -167,7 +167,11 @@ def _reduce_by_sum(self, op, input_node, dims, meta, dtype):

sum = super().call_operator(sum_op, (input_node, dims, True), {}, meta, True)
full = super().call_operator(
full_op, ([1] * len(output_shape), 1 / N), {"dtype": dtype}, meta, True
full_op,
([1] * len(output_shape), 1 / N),
{"dtype": dtype, "device": input_node.data.device},
meta,
True,
)
if (quant_ops := get_quantization(input_node.node.target)) is not None:
# Insert Q and DQ nodes after full op.
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/decompose_var_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -98,7 +98,7 @@ def call_operator(self, op, args, kwargs, meta):
full = super().call_operator(
full_op,
([], 1 / max(0, N - correction)),
{"dtype": dtype},
{"dtype": dtype, "device": x.data.device},
meta,
True,
)
Expand Down
8 changes: 5 additions & 3 deletions backends/arm/_passes/scalars_to_attribute_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -49,13 +49,14 @@ def call(self, graph_module: GraphModule) -> PassResult:
shape = get_first_fake_tensor(arg).shape
biggest_rank = max(biggest_rank, len(shape))

output_fake_tensor = get_first_fake_tensor(n)
new_args: list[Node | int] = []
for arg in n.args:
if isinstance(arg, Node):
new_args.append(arg)
continue
if isinstance(arg, int) and not torch.is_floating_point(
get_first_fake_tensor(n)
output_fake_tensor
):
new_args.append(arg)
continue
Expand All @@ -64,7 +65,8 @@ def call(self, graph_module: GraphModule) -> PassResult:
get_new_attr_name = get_new_attr_name_with_prefix(prefix)
tensor_constant_name = get_new_attr_name(graph_module)
float_tensor = torch.tensor(
float(cast(Union[int, float], arg))
float(cast(Union[int, float], arg)),
device=output_fake_tensor.device,
).reshape((1,) * biggest_rank)
graph_module.register_buffer(tensor_constant_name, float_tensor)
fake_mode = n.meta["val"].fake_mode
Expand Down
48 changes: 45 additions & 3 deletions backends/arm/quantizer/arm_quantizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -597,8 +597,50 @@ def _annotate_io(
mark_node_as_annotated(node)

def validate(self, model: GraphModule) -> None:
"""TODO: Implement validation of annotated graph for TOSA backend."""
pass
"""Validate the quantization results. Currently, this includes:
- Ensure tensor inputs to each operator live on the same device.

Args:
model (GraphModule): GraphModule being validated.
Raises:
ValueError: If tensor inputs for any operator span more than one
device.
"""
for node in model.graph.nodes:
if node.op != "call_function":
continue

devices = set()
for arg_node in node.all_input_nodes:
meta_val = arg_node.meta.get("val", None)
if meta_val is None:
continue
if isinstance(meta_val, (tuple, list)):
for tensor in meta_val:
devices.add(
str(
getattr(
tensor,
"device",
f"Could not get device from {tensor}",
)
)
)
else:
devices.add(
str(
getattr(
meta_val,
"device",
f"Could not get device from {meta_val}",
)
)
)

if len(devices) > 1:
raise ValueError(
f"Quantizer detected operator {node.name} with different device inputs: {devices}."
)

def quantize_with_submodules(
self,
Expand Down
55 changes: 55 additions & 0 deletions backends/arm/test/quantizer/test_tosa_quantizer_validate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from executorch.backends.arm.quantizer import TOSAQuantizer
from executorch.backends.arm.tosa import TosaSpecification
from torch.fx import symbolic_trace


def _annotate_placeholders_with_devices(gm, device_map):
for node in gm.graph.nodes:
if node.op == "placeholder":
device = device_map[node.target]
node.meta["val"] = torch.empty(1, device=device)


def _get_quantizer():
return TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))


class TwoIndependentAdds(torch.nn.Module):

def forward(self, x: torch.Tensor, y: torch.Tensor):
return x + 1, y + 1


class CrossDeviceAdd(torch.nn.Module):

def forward(self, x: torch.Tensor, y: torch.Tensor):
return x + y


def test_validate_allows_different_devices_across_operators():
gm = symbolic_trace(TwoIndependentAdds())
_annotate_placeholders_with_devices(
gm, {"x": torch.device("cpu"), "y": torch.device("meta")}
)

quantizer = _get_quantizer()
quantizer.validate(gm)


def test_validate_rejects_mixed_devices_within_operator():
gm = symbolic_trace(CrossDeviceAdd())
_annotate_placeholders_with_devices(
gm, {"x": torch.device("cpu"), "y": torch.device("meta")}
)

quantizer = _get_quantizer()
with pytest.raises(ValueError, match="Quantizer detected operator"):
quantizer.validate(gm)
7 changes: 5 additions & 2 deletions backends/transforms/replace_scalar_with_tensor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -59,7 +59,10 @@ def get_replacement(self, op, args, kwargs, meta):
(1,),
args[1],
),
kwargs={"dtype": args[0].to_tensor().dtype},
kwargs={
"dtype": args[0].to_tensor().dtype,
"device": args[0].to_tensor().device,
},
meta=meta,
),
# Other args.
Expand Down
Loading