Skip to content

Commit c822247

Browse files
apullin authored and facebook-github-bot committed
Quantize moveaxis/movedim so they delegate to Ethos-U (#20314)
Summary:

The ARM PT2 quantizer's pass-through shared-qspec set in quantization_annotator.py (_one_to_one_shared_input_qspec) covers permute/permute_copy/transpose/view/squeeze etc., but omits aten.moveaxis/aten.movedim. A model that uses torch.moveaxis therefore leaves those ops unquantized: the quantizer brackets each one with dequantize -> moveaxis(float) -> quantize. On lowering, moveaxis decomposes to a float permute_copy. The Ethos-U55 operator-support check (operator_support/ethos_u55_support.py) only delegates permute_copy for int8/int16/int32, so it rejects the float one. Each rejected permute is stranded on the host, splitting the model into many delegated partitions (one NPU island per permute), which bloats the .pte with per-partition delegate overhead and host round-trips.

Add aten.moveaxis.int / aten.moveaxis.intlist / aten.movedim.int / aten.movedim.intlist to _one_to_one_shared_input_qspec (listed directly in the set literal alongside the existing pass-through ops) so they share the input quantization spec exactly like transpose/permute. They then stay int8, decompose to int8 permute_copy, and delegate to the NPU -- eliminating the host float islands.

Impact: a quantized example ensemble (ConvNeXt-style blocks that use torch.moveaxis) that previously lowered into 9 Ethos-U55 partitions now lowers into a single delegate, with zero host permutes and a ~24% smaller .pte, with no model changes. This generalizes to any moveaxis/movedim-using model on the Ethos-U backend.

Differential Revision: D108478011
1 parent 23f9021 commit c822247

2 files changed

Lines changed: 35 additions & 0 deletions

File tree

backends/arm/quantizer/quantization_annotator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,10 @@ def _get_fixed_qparams_qspec(
619619
# dequant -> neg -> requant chain.
620620
torch.ops.aten.neg.default,
621621
torch.ops.aten.detach_copy.default,
622+
torch.ops.aten.moveaxis.int,
623+
torch.ops.aten.moveaxis.intlist,
624+
torch.ops.aten.movedim.int,
625+
torch.ops.aten.movedim.intlist,
622626
}
623627

624628
# Dimname has been removed from upstream PyTorch, but there may be a window
@@ -630,6 +634,7 @@ def _get_fixed_qparams_qspec(
630634
if _transpose_dimname is not None:
631635
_one_to_one_shared_input_qspec.add(_transpose_dimname)
632636

637+
633638
_one_to_one_shared_input_or_input_act_qspec: set[OpOverload] = {
634639
torch.ops.aten.alias.default,
635640
torch.ops.aten.clone.default,

backends/arm/test/ops/test_permute.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Tuple
1010

1111
import torch
12+
from executorch.backends.arm.quantizer import quantization_annotator
1213
from executorch.backends.arm.quantizer.arm_quantizer import (
1314
get_symmetric_a16w8_quantization_config,
1415
)
@@ -78,6 +79,12 @@ def forward(self, x):
7879
return torch.permute(x, self.dims)
7980

8081

82+
class SimpleMoveAxis(torch.nn.Module):
83+
84+
def forward(self, x):
85+
return torch.moveaxis(x, 1, -1)
86+
87+
8188
@common.parametrize(
8289
"test_data", test_data_suite | test_data_suite_fp16 | test_data_suite_bf16
8390
)
@@ -118,6 +125,29 @@ def test_permute_u55_INT(test_data):
118125
pipeline.run()
119126

120127

128+
def test_moveaxis_movedim_shared_qspec_annotations():
129+
expected_ops = {
130+
torch.ops.aten.moveaxis.int,
131+
torch.ops.aten.moveaxis.intlist,
132+
torch.ops.aten.movedim.int,
133+
torch.ops.aten.movedim.intlist,
134+
}
135+
136+
assert expected_ops <= quantization_annotator._one_to_one_shared_input_qspec
137+
138+
139+
@common.XfailIfNoCorstone300
140+
def test_moveaxis_u55_INT():
141+
pipeline = EthosU55PipelineINT[input_t1](
142+
SimpleMoveAxis(),
143+
(torch.rand(1, 4, 5, 6),),
144+
"torch.ops.aten.moveaxis.int",
145+
exir_ops="executorch_exir_dialects_edge__ops_aten_permute_copy_default",
146+
run_on_fvp=False,
147+
)
148+
pipeline.run()
149+
150+
121151
@common.parametrize("test_data", test_data_suite_u55_reject)
122152
def test_permute_u55_INT_not_delegated(test_data: torch.Tensor):
123153
test_data, dims = test_data()

0 commit comments

Comments
 (0)