diff --git a/backends/arm/_passes/decompose_add_sub_alpha_pass.py b/backends/arm/_passes/decompose_add_sub_alpha_pass.py index c0ed1bae09b..e59fb548a44 100644 --- a/backends/arm/_passes/decompose_add_sub_alpha_pass.py +++ b/backends/arm/_passes/decompose_add_sub_alpha_pass.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -76,7 +76,11 @@ def call_operator(self, op, args, kwargs, meta, updated: bool | None = False): lhs, rhs = args alpha_full = super().call_operator( - full_op, ((1,), float(alpha)), {}, meta, updated=True + full_op, + ((1,), float(alpha)), + {"device": meta["val"].device}, + meta, + updated=True, ) scaled_rhs = super().call_operator( mul_op, diff --git a/backends/arm/_passes/decompose_avg_pool2d_pass.py b/backends/arm/_passes/decompose_avg_pool2d_pass.py index d259dfb203f..c46a54b0efa 100644 --- a/backends/arm/_passes/decompose_avg_pool2d_pass.py +++ b/backends/arm/_passes/decompose_avg_pool2d_pass.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -51,6 +51,7 @@ def call_operator(self, op, args, kwargs, meta): full_op, cat_op, avgpool_op, mul_op = get_decomposition(op) x = args[0] + full_kwargs = {"device": x.data.device} kernel_h, kernel_w = args[1] kernel_size = kernel_h * kernel_w if len(args) > 2 and args[2] is not None: @@ -75,7 +76,7 @@ def call_operator(self, op, args, kwargs, meta): if count_include_pad and pad_w > 0: pre_pad_shape = [n, c, h, pad_w] pre_pad = super().call_operator( - full_op, (pre_pad_shape, 0.0), kwargs, meta, updated=True + full_op, (pre_pad_shape, 0.0), full_kwargs, meta, updated=True ) if ceil_mode and divisor_override is None: @@ -88,7 +89,7 @@ def call_operator(self, op, args, kwargs, meta): if post_pad_w > 0: post_pad_shape = [n, c, h, post_pad_w] post_pad = super().call_operator( - full_op, (post_pad_shape, 0.0), kwargs, meta, updated=True + full_op, (post_pad_shape, 0.0), full_kwargs, meta, updated=True ) cat_nodes = [pre_pad, x, post_pad] else: @@ -103,7 +104,7 @@ def call_operator(self, op, args, kwargs, meta): if count_include_pad and pad_h > 0: pre_pad_shape = [n, c, pad_h, w + pad_w + post_pad_w] pre_pad = super().call_operator( - full_op, (pre_pad_shape, 0.0), kwargs, meta, updated=True + full_op, (pre_pad_shape, 0.0), full_kwargs, meta, updated=True ) if ceil_mode and divisor_override is None: @@ -116,7 +117,7 @@ def call_operator(self, op, args, kwargs, meta): if post_pad_h > 0: post_pad_shape = [n, c, post_pad_h, w + pad_w + post_pad_w] post_pad = super().call_operator( - full_op, (post_pad_shape, 0.0), kwargs, meta, updated=True + full_op, (post_pad_shape, 0.0), full_kwargs, meta, updated=True ) cat_nodes = [pre_pad, x, post_pad] else: @@ -142,7 +143,7 @@ def call_operator(self, op, args, kwargs, meta): override_multiplier = super().call_operator( full_op, ([1, 1, 1, 1], kernel_size / divisor_override), - kwargs, + full_kwargs, meta, updated=True, ) diff --git a/backends/arm/_passes/decompose_div_tensor_mode.py b/backends/arm/_passes/decompose_div_tensor_mode.py index a3be67a572b..39e22c447e2 100644 --- a/backends/arm/_passes/decompose_div_tensor_mode.py +++ b/backends/arm/_passes/decompose_div_tensor_mode.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -78,7 +78,7 @@ def call_operator(self, op, args, kwargs, meta): zero = super().call_operator( opset["full"], args=((1,) * len(meta["val"].size()), 0.0), - kwargs={"dtype": torch.float32}, + kwargs={"dtype": torch.float32, "device": meta["val"].device}, meta=meta, updated=True, ) diff --git a/backends/arm/_passes/decompose_groupnorm_pass.py b/backends/arm/_passes/decompose_groupnorm_pass.py index 65c373cdae9..322f6949c64 100644 --- a/backends/arm/_passes/decompose_groupnorm_pass.py +++ b/backends/arm/_passes/decompose_groupnorm_pass.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -92,9 +92,11 @@ def call(self, graph_module: torch.fx.GraphModule): if isinstance(meta["val"], tuple): shape = meta["val"][0].size() dtype = meta["val"][0].dtype + device = meta["val"][0].device else: shape = meta["val"].size() dtype = meta["val"].dtype + device = meta["val"].device match len(args): # MI profile always provides all the args: x, weight, bias, N, C, HxW, group, eps case 8: @@ -156,7 +158,7 @@ def call(self, graph_module: torch.fx.GraphModule): graph_module.graph, full_op, args=(epsilon_reshaped_shape, eps), - kwargs={"dtype": dtype}, + kwargs={"dtype": dtype, "device": device}, from_node=node, ) add0 = create_node( diff --git a/backends/arm/_passes/decompose_layernorm_pass.py b/backends/arm/_passes/decompose_layernorm_pass.py index 5ebb7e92dad..992b21fd592 100644 --- a/backends/arm/_passes/decompose_layernorm_pass.py +++ b/backends/arm/_passes/decompose_layernorm_pass.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -108,9 +108,11 @@ def call(self, graph_module: torch.fx.GraphModule): if isinstance(meta["val"], tuple): shape = meta["val"][0].size() dtype = meta["val"][0].dtype + device = meta["val"][0].device else: shape = meta["val"].size() dtype = meta["val"].dtype + device = meta["val"].device rank = len(shape) dims = list(range(-1, -1 * (n_dims + 1), -1)) dims = [dim % rank for dim in dims] @@ -142,7 +144,7 @@ def call(self, graph_module: torch.fx.GraphModule): graph_module.graph, full_op, args=(epsilon_reshaped_shape, epsilon), - kwargs={"dtype": dtype}, + kwargs={"dtype": dtype, "device": device}, from_node=node, ) add0 = create_node( diff --git a/backends/arm/_passes/decompose_leaky_relu_pass.py b/backends/arm/_passes/decompose_leaky_relu_pass.py index 6ebe4dd8a51..3e52efec33b 100644 --- a/backends/arm/_passes/decompose_leaky_relu_pass.py +++ b/backends/arm/_passes/decompose_leaky_relu_pass.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -57,6 +57,7 @@ def call_operator(self, op, args, kwargs, meta): x = args[0] slope = args[1] if len(args) > 1 else 0.01 dtype = x.node.meta["val"].dtype + device = x.node.meta["val"].device clamp, full, mul, add = _get_leaky_relu_ops(op) op1 = super().call_operator( op=clamp, args=(x, 0, None), kwargs=kwargs, meta=meta @@ -67,7 +68,7 @@ def call_operator(self, op, args, kwargs, meta): op3 = super().call_operator( op=full, args=(x.node.meta["val"].shape, slope), - kwargs={"dtype": dtype}, + kwargs={"dtype": dtype, "device": device}, meta=meta, ) op4 = super().call_operator(op=mul, args=(op3, op2), kwargs=kwargs, meta=meta) diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index 7a48f6c941d..a8cac5a197d 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -167,7 +167,11 @@ def _reduce_by_sum(self, op, input_node, dims, meta, dtype): sum = super().call_operator(sum_op, (input_node, dims, True), {}, meta, True) full = super().call_operator( - full_op, ([1] * len(output_shape), 1 / N), {"dtype": dtype}, meta, True + full_op, + ([1] * len(output_shape), 1 / N), + {"dtype": dtype, "device": input_node.data.device}, + meta, + True, ) if (quant_ops := get_quantization(input_node.node.target)) is not None: # Insert Q and DQ nodes after full op. diff --git a/backends/arm/_passes/decompose_var_pass.py b/backends/arm/_passes/decompose_var_pass.py index f3b7e6ab67f..cc6e87c0455 100644 --- a/backends/arm/_passes/decompose_var_pass.py +++ b/backends/arm/_passes/decompose_var_pass.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -98,7 +98,7 @@ def call_operator(self, op, args, kwargs, meta): full = super().call_operator( full_op, ([], 1 / max(0, N - correction)), - {"dtype": dtype}, + {"dtype": dtype, "device": x.data.device}, meta, True, ) diff --git a/backends/arm/_passes/scalars_to_attribute_pass.py b/backends/arm/_passes/scalars_to_attribute_pass.py index ddef9c75213..87f58252e67 100644 --- a/backends/arm/_passes/scalars_to_attribute_pass.py +++ b/backends/arm/_passes/scalars_to_attribute_pass.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -49,13 +49,14 @@ def call(self, graph_module: GraphModule) -> PassResult: shape = get_first_fake_tensor(arg).shape biggest_rank = max(biggest_rank, len(shape)) + output_fake_tensor = get_first_fake_tensor(n) new_args: list[Node | int] = [] for arg in n.args: if isinstance(arg, Node): new_args.append(arg) continue if isinstance(arg, int) and not torch.is_floating_point( - get_first_fake_tensor(n) + output_fake_tensor ): new_args.append(arg) continue @@ -64,7 +65,8 @@ def call(self, graph_module: GraphModule) -> PassResult: get_new_attr_name = get_new_attr_name_with_prefix(prefix) tensor_constant_name = get_new_attr_name(graph_module) float_tensor = torch.tensor( - float(cast(Union[int, float], arg)) + float(cast(Union[int, float], arg)), + device=output_fake_tensor.device, ).reshape((1,) * biggest_rank) graph_module.register_buffer(tensor_constant_name, float_tensor) fake_mode = n.meta["val"].fake_mode diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 425fea0987b..735bf03b1aa 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -597,8 +597,50 @@ def _annotate_io( mark_node_as_annotated(node) def validate(self, model: GraphModule) -> None: - """TODO: Implement validation of annotated graph for TOSA backend.""" - pass + """Validate the quantization results. Currently, this includes: + - Ensure tensor inputs to each operator live on the same device. + + Args: + model (GraphModule): GraphModule being validated. + Raises: + ValueError: If tensor inputs for any operator span more than one + device. + """ + for node in model.graph.nodes: + if node.op != "call_function": + continue + + devices = set() + for arg_node in node.all_input_nodes: + meta_val = arg_node.meta.get("val", None) + if meta_val is None: + continue + if isinstance(meta_val, (tuple, list)): + for tensor in meta_val: + devices.add( + str( + getattr( + tensor, + "device", + f"Could not get device from {tensor}", + ) + ) + ) + else: + devices.add( + str( + getattr( + meta_val, + "device", + f"Could not get device from {meta_val}", + ) + ) + ) + + if len(devices) > 1: + raise ValueError( + f"Quantizer detected operator {node.name} with different device inputs: {devices}." + ) def quantize_with_submodules( self, diff --git a/backends/arm/test/quantizer/test_tosa_quantizer_validate.py b/backends/arm/test/quantizer/test_tosa_quantizer_validate.py new file mode 100644 index 00000000000..081e9ecabd6 --- /dev/null +++ b/backends/arm/test/quantizer/test_tosa_quantizer_validate.py @@ -0,0 +1,55 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch + +from executorch.backends.arm.quantizer import TOSAQuantizer +from executorch.backends.arm.tosa import TosaSpecification +from torch.fx import symbolic_trace + + +def _annotate_placeholders_with_devices(gm, device_map): + for node in gm.graph.nodes: + if node.op == "placeholder": + device = device_map[node.target] + node.meta["val"] = torch.empty(1, device=device) + + +def _get_quantizer(): + return TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT")) + + +class TwoIndependentAdds(torch.nn.Module): + + def forward(self, x: torch.Tensor, y: torch.Tensor): + return x + 1, y + 1 + + +class CrossDeviceAdd(torch.nn.Module): + + def forward(self, x: torch.Tensor, y: torch.Tensor): + return x + y + + +def test_validate_allows_different_devices_across_operators(): + gm = symbolic_trace(TwoIndependentAdds()) + _annotate_placeholders_with_devices( + gm, {"x": torch.device("cpu"), "y": torch.device("meta")} + ) + + quantizer = _get_quantizer() + quantizer.validate(gm) + + +def test_validate_rejects_mixed_devices_within_operator(): + gm = symbolic_trace(CrossDeviceAdd()) + _annotate_placeholders_with_devices( + gm, {"x": torch.device("cpu"), "y": torch.device("meta")} + ) + + quantizer = _get_quantizer() + with pytest.raises(ValueError, match="Quantizer detected operator"): + quantizer.validate(gm) diff --git a/backends/transforms/replace_scalar_with_tensor.py b/backends/transforms/replace_scalar_with_tensor.py index 8ce05a3d4d4..d54b549409f 100644 --- a/backends/transforms/replace_scalar_with_tensor.py +++ b/backends/transforms/replace_scalar_with_tensor.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -59,7 +59,10 @@ def get_replacement(self, op, args, kwargs, meta): (1,), args[1], ), - kwargs={"dtype": args[0].to_tensor().dtype}, + kwargs={ + "dtype": args[0].to_tensor().dtype, + "device": args[0].to_tensor().device, + }, meta=meta, ), # Other args.