From f5919267edd5a32f38a4a6ec51becad26fa6b8c4 Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Wed, 6 May 2026 14:47:58 -0700 Subject: [PATCH] More generic slice propagation before unary ops which works for non-contiguous slices (#19345) Summary: The existing MoveSliceToInputPass in the Jarvis compiler propagates slices backward through computation chains using the tiling/region infrastructure. However, it only supports contiguous slices (step=1) because non-unitary steps cannot be represented as contiguous regions. This diff adds PropagateSlice, a lightweight pass that swaps slice_copy past element-wise quantize/dequantize ops when a cost model indicates it reduces data movement. Unlike MoveSliceToInputPass, it handles any step size (including non-contiguous slices like step=2), but only swaps past a single adjacent op rather than walking entire chains. The two passes are complementary: MoveSliceToInputPass handles deep propagation of contiguous slices through complex op chains with tiling, while PropagateSlice handles the simpler case of moving strided slices past quant/dequant boundaries where tiling is irrelevant (so far, but will be extended to other cases). The idea is eventually it can be applied iteratively to keep pushing the slice up. Changes: - Add PropagateSlice pass to reorder_ops.py with a dispatch-table design for extensibility - Cost model: only swap when the slice actually reduces tensor volume (sliced < full) - Supported ops: quantize_per_tensor, dequantize_per_tensor (both cadence and quantized_decomposed variants) - Tests moved into test_reorder_ops_passes.py alongside other reorder pass tests Reviewed By: ethansfng Differential Revision: D103752840 --- backends/cadence/aot/reorder_ops.py | 105 ++++++++++- .../aot/tests/test_reorder_ops_passes.py | 166 ++++++++++++++++++ 2 files changed, 270 insertions(+), 1 deletion(-) diff --git a/backends/cadence/aot/reorder_ops.py b/backends/cadence/aot/reorder_ops.py index a8eda5cc457..5a9b76b473a 100644 --- a/backends/cadence/aot/reorder_ops.py +++ b/backends/cadence/aot/reorder_ops.py @@ -11,7 +11,7 @@ from collections import defaultdict from math import prod -from typing import cast, DefaultDict, List, Tuple +from typing import Callable, cast, DefaultDict, List, Tuple import torch import torch.fx @@ -719,6 +719,109 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return True +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class PropagateSlice(RemoveOrReplacePassInterface): + """Propagate slice_copy before unary element-wise ops when the cost + model indicates it reduces total data movement. + + Supported ops (extensible via dispatch table): + - quantize_per_tensor: element-wise, slice passes through unchanged + - dequantize_per_tensor: element-wise, slice passes through unchanged + + Handles any slice dim and any step size. Runs in the iterative pass + loop — chains are handled by repeated application. + """ + + def __init__(self) -> None: + super().__init__() + elementwise_targets = [ + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.cadence.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.cadence.dequantize_per_tensor.default, + ] + self._dispatch: dict[ + EdgeOpOverload, + tuple[ + Callable[[torch.fx.Node, torch.fx.Node], bool], + Callable[[torch.fx.Node, torch.fx.Node], bool], + ], + ] = { + t: (self._should_swap_elementwise, self._swap_elementwise_slice) + for t in elementwise_targets + } + + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.slice_copy.Tensor] + + def _should_swap_elementwise( + self, op_node: torch.fx.Node, slice_node: torch.fx.Node + ) -> bool: + full_size = prod(op_node.meta["val"].shape) + sliced_size = prod(slice_node.meta["val"].shape) + return sliced_size < full_size + + def _swap_elementwise_slice( + self, op_node: torch.fx.Node, slice_node: torch.fx.Node + ) -> bool: + op_input = op_node.args[0] + assert isinstance(op_input, torch.fx.Node) + graph = slice_node.graph + + slice_args = slice_node.args[1:] + + with graph.inserting_before(op_node): + new_slice = graph.call_function( + exir_ops.edge.aten.slice_copy.Tensor, + args=(op_input, *slice_args), + ) + new_slice.meta["val"] = exir_ops.edge.aten.slice_copy.Tensor( + op_input.meta["val"], *slice_args + ) + + new_args = list(op_node.args) + new_args[0] = new_slice + target = cast(EdgeOpOverload, op_node.target) + new_op = graph.call_function( + target, + args=tuple(new_args), + kwargs=op_node.kwargs, + ) + new_op.meta["val"] = target( + new_slice.meta["val"], + *[ + a.meta["val"] if isinstance(a, torch.fx.Node) else a + for a in new_args[1:] + ], + **{ + k: v.meta["val"] if isinstance(v, torch.fx.Node) else v + for k, v in op_node.kwargs.items() + }, + ) + + slice_node.replace_all_uses_with(new_op) + graph.erase_node(slice_node) + graph.erase_node(op_node) + return True + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + parent = node.args[0] + if not isinstance(parent, torch.fx.Node): + return False + if len(parent.users) != 1: + return False + if not isinstance(parent.target, EdgeOpOverload): + return False + + entry = self._dispatch.get(parent.target) + if entry is None: + return False + + should_swap, do_swap = entry + return should_swap(parent, node) and do_swap(parent, node) + + # The following class consolidates functions to reoder ops (i.e., either hoist # or sink some ops in the graph). class CadenceReorderOpsInGraph: diff --git a/backends/cadence/aot/tests/test_reorder_ops_passes.py b/backends/cadence/aot/tests/test_reorder_ops_passes.py index ba9089a652e..cf3a6840179 100644 --- a/backends/cadence/aot/tests/test_reorder_ops_passes.py +++ b/backends/cadence/aot/tests/test_reorder_ops_passes.py @@ -26,6 +26,7 @@ MoveSliceBeforePermutePass, PostponeDequantizeOpBelowUseChainPass, PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView, + PropagateSlice, SinkOpsCloserToUsePass, ) from executorch.backends.test.graph_builder import GraphBuilder @@ -761,3 +762,168 @@ def test_non_dim0_slice_always_moved(self) -> None: MoveSliceBeforePermutePass(), ) self.assertTrue(result.modified) + + +class TestPropagateSlice(unittest.TestCase): + def test_swap_quantize_slice(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(4, 60, 1, 1)) + quant = builder.call_operator( + exir_ops.edge.cadence.quantize_per_tensor.default, + args=(x, 0.5, 0, 0, 255, torch.uint8), + ) + sliced = builder.call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + args=(quant, 0, 0, 4, 2), + ) + builder.output([sliced]) + gm = builder.get_graph_module() + + result = PropagateSlice().call(gm) + + self.assertTrue(result.modified) + + slice_nodes = gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor + ) + self.assertEqual(len(slice_nodes), 1) + slice_node = slice_nodes[0] + self.assertEqual(slice_node.args[0].name, "x") + self.assertEqual(list(slice_node.meta["val"].shape), [2, 60, 1, 1]) + + quant_nodes = gm.graph.find_nodes( + op="call_function", + target=exir_ops.edge.cadence.quantize_per_tensor.default, + ) + self.assertEqual(len(quant_nodes), 1) + self.assertEqual(quant_nodes[0].args[0], slice_node) + self.assertEqual(list(quant_nodes[0].meta["val"].shape), [2, 60, 1, 1]) + + def test_swap_dequantize_slice(self) -> None: + builder = GraphBuilder() + x = builder.placeholder( + "x", torch.randint(0, 255, (4, 60, 4, 4), dtype=torch.uint8) + ) + dequant = builder.call_operator( + exir_ops.edge.cadence.dequantize_per_tensor.default, + args=(x, 0.5, 0, 0, 255, torch.uint8), + ) + sliced = builder.call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + args=(dequant, 0, 0, 4, 2), + ) + builder.output([sliced]) + gm = builder.get_graph_module() + + result = PropagateSlice().call(gm) + + self.assertTrue(result.modified) + + slice_nodes = gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor + ) + self.assertEqual(len(slice_nodes), 1) + self.assertEqual(slice_nodes[0].args[0].name, "x") + + def test_step_2_through_quantize(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(4, 60, 1, 1)) + quant = builder.call_operator( + exir_ops.edge.cadence.quantize_per_tensor.default, + args=(x, 0.5, 0, 0, 255, torch.uint8), + ) + sliced = builder.call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + args=(quant, 0, 0, 4, 2), + ) + builder.output([sliced]) + gm = builder.get_graph_module() + + result = PropagateSlice().call(gm) + + self.assertTrue(result.modified) + + slice_nodes = gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor + ) + self.assertEqual(len(slice_nodes), 1) + self.assertEqual(slice_nodes[0].args[4], 2) + self.assertEqual(list(slice_nodes[0].meta["val"].shape), [2, 60, 1, 1]) + + def test_non_batch_dim_slice(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(4, 60, 4, 4)) + quant = builder.call_operator( + exir_ops.edge.cadence.quantize_per_tensor.default, + args=(x, 0.5, 0, 0, 255, torch.uint8), + ) + sliced = builder.call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + args=(quant, 1, 0, 30, 1), + ) + builder.output([sliced]) + gm = builder.get_graph_module() + + result = PropagateSlice().call(gm) + + self.assertTrue(result.modified) + + slice_nodes = gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor + ) + self.assertEqual(len(slice_nodes), 1) + self.assertEqual(list(slice_nodes[0].meta["val"].shape), [4, 30, 4, 4]) + + def test_no_swap_when_multi_user(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(4, 60, 1, 1)) + quant = builder.call_operator( + exir_ops.edge.cadence.quantize_per_tensor.default, + args=(x, 0.5, 0, 0, 255, torch.uint8), + ) + sliced = builder.call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + args=(quant, 0, 0, 4, 2), + ) + builder.output([sliced, quant]) + gm = builder.get_graph_module() + + result = PropagateSlice().call(gm) + + self.assertFalse(result.modified) + + def test_no_swap_noop_slice(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(4, 60, 1, 1)) + quant = builder.call_operator( + exir_ops.edge.cadence.quantize_per_tensor.default, + args=(x, 0.5, 0, 0, 255, torch.uint8), + ) + sliced = builder.call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + args=(quant, 0, 0, 4, 1), + ) + builder.output([sliced]) + gm = builder.get_graph_module() + + result = PropagateSlice().call(gm) + + self.assertFalse(result.modified) + + def test_unsupported_parent_not_swapped(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(4, 60, 1, 1)) + relu = builder.call_operator( + exir_ops.edge.aten.relu.default, + args=(x,), + ) + sliced = builder.call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + args=(relu, 0, 0, 4, 2), + ) + builder.output([sliced]) + gm = builder.get_graph_module() + + result = PropagateSlice().call(gm) + + self.assertFalse(result.modified)