diff --git a/backends/cadence/aot/reorder_ops.py b/backends/cadence/aot/reorder_ops.py index a8eda5cc457..5a9b76b473a 100644 --- a/backends/cadence/aot/reorder_ops.py +++ b/backends/cadence/aot/reorder_ops.py @@ -11,7 +11,7 @@ from collections import defaultdict from math import prod -from typing import cast, DefaultDict, List, Tuple +from typing import Callable, cast, DefaultDict, List, Tuple import torch import torch.fx @@ -719,6 +719,109 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return True +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class PropagateSlice(RemoveOrReplacePassInterface): + """Propagate slice_copy before unary element-wise ops when the cost + model indicates it reduces total data movement. + + Supported ops (extensible via dispatch table): + - quantize_per_tensor: element-wise, slice passes through unchanged + - dequantize_per_tensor: element-wise, slice passes through unchanged + + Handles any slice dim and any step size. Runs in the iterative pass + loop — chains are handled by repeated application. + """ + + def __init__(self) -> None: + super().__init__() + elementwise_targets = [ + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.cadence.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.cadence.dequantize_per_tensor.default, + ] + self._dispatch: dict[ + EdgeOpOverload, + tuple[ + Callable[[torch.fx.Node, torch.fx.Node], bool], + Callable[[torch.fx.Node, torch.fx.Node], bool], + ], + ] = { + t: (self._should_swap_elementwise, self._swap_elementwise_slice) + for t in elementwise_targets + } + + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.slice_copy.Tensor] + + def _should_swap_elementwise( + self, op_node: torch.fx.Node, slice_node: torch.fx.Node + ) -> bool: + full_size = prod(op_node.meta["val"].shape) + sliced_size = prod(slice_node.meta["val"].shape) + return sliced_size < full_size + + def _swap_elementwise_slice( + self, op_node: torch.fx.Node, slice_node: torch.fx.Node + ) -> bool: + op_input = op_node.args[0] + assert isinstance(op_input, torch.fx.Node) + graph = slice_node.graph + + slice_args = slice_node.args[1:] + + with graph.inserting_before(op_node): + new_slice = graph.call_function( + exir_ops.edge.aten.slice_copy.Tensor, + args=(op_input, *slice_args), + ) + new_slice.meta["val"] = exir_ops.edge.aten.slice_copy.Tensor( + op_input.meta["val"], *slice_args + ) + + new_args = list(op_node.args) + new_args[0] = new_slice + target = cast(EdgeOpOverload, op_node.target) + new_op = graph.call_function( + target, + args=tuple(new_args), + kwargs=op_node.kwargs, + ) + new_op.meta["val"] = target( + new_slice.meta["val"], + *[ + a.meta["val"] if isinstance(a, torch.fx.Node) else a + for a in new_args[1:] + ], + **{ + k: v.meta["val"] if isinstance(v, torch.fx.Node) else v + for k, v in op_node.kwargs.items() + }, + ) + + slice_node.replace_all_uses_with(new_op) + graph.erase_node(slice_node) + graph.erase_node(op_node) + return True + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + parent = node.args[0] + if not isinstance(parent, torch.fx.Node): + return False + if len(parent.users) != 1: + return False + if not isinstance(parent.target, EdgeOpOverload): + return False + + entry = self._dispatch.get(parent.target) + if entry is None: + return False + + should_swap, do_swap = entry + return should_swap(parent, node) and do_swap(parent, node) + + # The following class consolidates functions to reoder ops (i.e., either hoist # or sink some ops in the graph). class CadenceReorderOpsInGraph: diff --git a/backends/cadence/aot/tests/test_reorder_ops_passes.py b/backends/cadence/aot/tests/test_reorder_ops_passes.py index ba9089a652e..cf3a6840179 100644 --- a/backends/cadence/aot/tests/test_reorder_ops_passes.py +++ b/backends/cadence/aot/tests/test_reorder_ops_passes.py @@ -26,6 +26,7 @@ MoveSliceBeforePermutePass, PostponeDequantizeOpBelowUseChainPass, PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView, + PropagateSlice, SinkOpsCloserToUsePass, ) from executorch.backends.test.graph_builder import GraphBuilder @@ -761,3 +762,168 @@ def test_non_dim0_slice_always_moved(self) -> None: MoveSliceBeforePermutePass(), ) self.assertTrue(result.modified) + + +class TestPropagateSlice(unittest.TestCase): + def test_swap_quantize_slice(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(4, 60, 1, 1)) + quant = builder.call_operator( + exir_ops.edge.cadence.quantize_per_tensor.default, + args=(x, 0.5, 0, 0, 255, torch.uint8), + ) + sliced = builder.call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + args=(quant, 0, 0, 4, 2), + ) + builder.output([sliced]) + gm = builder.get_graph_module() + + result = PropagateSlice().call(gm) + + self.assertTrue(result.modified) + + slice_nodes = gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor + ) + self.assertEqual(len(slice_nodes), 1) + slice_node = slice_nodes[0] + self.assertEqual(slice_node.args[0].name, "x") + self.assertEqual(list(slice_node.meta["val"].shape), [2, 60, 1, 1]) + + quant_nodes = gm.graph.find_nodes( + op="call_function", + target=exir_ops.edge.cadence.quantize_per_tensor.default, + ) + self.assertEqual(len(quant_nodes), 1) + self.assertEqual(quant_nodes[0].args[0], slice_node) + self.assertEqual(list(quant_nodes[0].meta["val"].shape), [2, 60, 1, 1]) + + def test_swap_dequantize_slice(self) -> None: + builder = GraphBuilder() + x = builder.placeholder( + "x", torch.randint(0, 255, (4, 60, 4, 4), dtype=torch.uint8) + ) + dequant = builder.call_operator( + exir_ops.edge.cadence.dequantize_per_tensor.default, + args=(x, 0.5, 0, 0, 255, torch.uint8), + ) + sliced = builder.call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + args=(dequant, 0, 0, 4, 2), + ) + builder.output([sliced]) + gm = builder.get_graph_module() + + result = PropagateSlice().call(gm) + + self.assertTrue(result.modified) + + slice_nodes = gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor + ) + self.assertEqual(len(slice_nodes), 1) + self.assertEqual(slice_nodes[0].args[0].name, "x") + + def test_step_2_through_quantize(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(4, 60, 1, 1)) + quant = builder.call_operator( + exir_ops.edge.cadence.quantize_per_tensor.default, + args=(x, 0.5, 0, 0, 255, torch.uint8), + ) + sliced = builder.call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + args=(quant, 0, 0, 4, 2), + ) + builder.output([sliced]) + gm = builder.get_graph_module() + + result = PropagateSlice().call(gm) + + self.assertTrue(result.modified) + + slice_nodes = gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor + ) + self.assertEqual(len(slice_nodes), 1) + self.assertEqual(slice_nodes[0].args[4], 2) + self.assertEqual(list(slice_nodes[0].meta["val"].shape), [2, 60, 1, 1]) + + def test_non_batch_dim_slice(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(4, 60, 4, 4)) + quant = builder.call_operator( + exir_ops.edge.cadence.quantize_per_tensor.default, + args=(x, 0.5, 0, 0, 255, torch.uint8), + ) + sliced = builder.call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + args=(quant, 1, 0, 30, 1), + ) + builder.output([sliced]) + gm = builder.get_graph_module() + + result = PropagateSlice().call(gm) + + self.assertTrue(result.modified) + + slice_nodes = gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor + ) + self.assertEqual(len(slice_nodes), 1) + self.assertEqual(list(slice_nodes[0].meta["val"].shape), [4, 30, 4, 4]) + + def test_no_swap_when_multi_user(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(4, 60, 1, 1)) + quant = builder.call_operator( + exir_ops.edge.cadence.quantize_per_tensor.default, + args=(x, 0.5, 0, 0, 255, torch.uint8), + ) + sliced = builder.call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + args=(quant, 0, 0, 4, 2), + ) + builder.output([sliced, quant]) + gm = builder.get_graph_module() + + result = PropagateSlice().call(gm) + + self.assertFalse(result.modified) + + def test_no_swap_noop_slice(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(4, 60, 1, 1)) + quant = builder.call_operator( + exir_ops.edge.cadence.quantize_per_tensor.default, + args=(x, 0.5, 0, 0, 255, torch.uint8), + ) + sliced = builder.call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + args=(quant, 0, 0, 4, 1), + ) + builder.output([sliced]) + gm = builder.get_graph_module() + + result = PropagateSlice().call(gm) + + self.assertFalse(result.modified) + + def test_unsupported_parent_not_swapped(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(4, 60, 1, 1)) + relu = builder.call_operator( + exir_ops.edge.aten.relu.default, + args=(x,), + ) + sliced = builder.call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + args=(relu, 0, 0, 4, 2), + ) + builder.output([sliced]) + gm = builder.get_graph_module() + + result = PropagateSlice().call(gm) + + self.assertFalse(result.modified)