diff --git a/backends/cadence/aot/reorder_ops.py b/backends/cadence/aot/reorder_ops.py
index 5a9b76b473a..2ca766316f3 100644
--- a/backends/cadence/aot/reorder_ops.py
+++ b/backends/cadence/aot/reorder_ops.py
@@ -721,15 +721,16 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
 class PropagateSlice(RemoveOrReplacePassInterface):
-    """Propagate slice_copy before unary element-wise ops when the cost
-    model indicates it reduces total data movement.
+    """Propagate slice_copy before element-wise ops when the cost model
+    indicates it reduces total data movement.
 
     Supported ops (extensible via dispatch table):
-    - quantize_per_tensor: element-wise, slice passes through unchanged
-    - dequantize_per_tensor: element-wise, slice passes through unchanged
+    - quantize_per_tensor: unary element-wise
+    - dequantize_per_tensor: unary element-wise
+    - add.Tensor: binary with broadcast — slices non-broadcasting inputs
+    - mul.Tensor: binary with broadcast — slices non-broadcasting inputs
 
-    Handles any slice dim and any step size. Runs in the iterative pass
-    loop — chains are handled by repeated application.
+    Handles any slice dim and any step size.
     """
 
     def __init__(self) -> None:
@@ -740,16 +741,28 @@ def __init__(self) -> None:
             exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
             exir_ops.edge.cadence.dequantize_per_tensor.default,
         ]
+        binary_targets = [
+            exir_ops.edge.aten.add.Tensor,
+            exir_ops.edge.aten.mul.Tensor,
+        ]
         self._dispatch: dict[
             EdgeOpOverload,
             tuple[
                 Callable[[torch.fx.Node, torch.fx.Node], bool],
                 Callable[[torch.fx.Node, torch.fx.Node], bool],
             ],
-        ] = {
-            t: (self._should_swap_elementwise, self._swap_elementwise_slice)
-            for t in elementwise_targets
-        }
+        ] = {}
+        for t in elementwise_targets:
+            self._dispatch[t] = (
+                self._should_swap_elementwise,
+                self._swap_elementwise_slice,
+            )
+
+        for t in binary_targets:
+            self._dispatch[t] = (
+                self._should_swap_binary_elementwise,
+                self._swap_binary_elementwise_slice,
+            )
 
     @property
     def targets(self) -> list[EdgeOpOverload]:
@@ -765,19 +778,21 @@ def _should_swap_elementwise(
 
     def _swap_elementwise_slice(
         self, op_node: torch.fx.Node, slice_node: torch.fx.Node
     ) -> bool:
-        op_input = op_node.args[0]
-        assert isinstance(op_input, torch.fx.Node)
+        op_input = get_arg(op_node, "input", torch.fx.Node)
         graph = slice_node.graph
-        slice_args = slice_node.args[1:]
+        slice_dim = get_arg(slice_node, "dim", int)
+        slice_start = get_arg(slice_node, "start")
+        slice_end = get_arg(slice_node, "end")
+        slice_step = get_arg(slice_node, "step", int)
 
         with graph.inserting_before(op_node):
             new_slice = graph.call_function(
                 exir_ops.edge.aten.slice_copy.Tensor,
-                args=(op_input, *slice_args),
+                args=(op_input, slice_dim, slice_start, slice_end, slice_step),
             )
             new_slice.meta["val"] = exir_ops.edge.aten.slice_copy.Tensor(
-                op_input.meta["val"], *slice_args
+                op_input.meta["val"], slice_dim, slice_start, slice_end, slice_step
             )
 
         new_args = list(op_node.args)
@@ -805,10 +820,68 @@ def _swap_elementwise_slice(
         graph.erase_node(op_node)
         return True
 
-    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
-        parent = node.args[0]
-        if not isinstance(parent, torch.fx.Node):
+    def _should_swap_binary_elementwise(
+        self, op_node: torch.fx.Node, slice_node: torch.fx.Node
+    ) -> bool:
+        lhs, rhs = op_node.args[0], op_node.args[1]
+        assert isinstance(lhs, torch.fx.Node) and isinstance(rhs, torch.fx.Node)
+        if lhs.meta["val"].shape == rhs.meta["val"].shape:
             return False
+        full_size = prod(op_node.meta["val"].shape)
+        sliced_size = prod(slice_node.meta["val"].shape)
+        return sliced_size < full_size
+
+    def _swap_binary_elementwise_slice(
+        self, op_node: torch.fx.Node, slice_node: torch.fx.Node
+    ) -> bool:
+        lhs, rhs = op_node.args[0], op_node.args[1]
+        assert isinstance(lhs, torch.fx.Node) and isinstance(rhs, torch.fx.Node)
+        graph = slice_node.graph
+
+        slice_dim = get_arg(slice_node, "dim", int)
+        slice_start = get_arg(slice_node, "start")
+        slice_end = get_arg(slice_node, "end")
+        slice_step = get_arg(slice_node, "step", int)
+
+        output_shape = op_node.meta["val"].shape
+
+        new_args = list(op_node.args)
+        with graph.inserting_before(op_node):
+            for i, inp in enumerate([lhs, rhs]):
+                if inp.meta["val"].shape[slice_dim] == output_shape[slice_dim]:
+                    new_slice = graph.call_function(
+                        exir_ops.edge.aten.slice_copy.Tensor,
+                        args=(inp, slice_dim, slice_start, slice_end, slice_step),
+                    )
+                    new_slice.meta["val"] = exir_ops.edge.aten.slice_copy.Tensor(
+                        inp.meta["val"], slice_dim, slice_start, slice_end, slice_step
+                    )
+                    new_args[i] = new_slice
+
+        target = cast(EdgeOpOverload, op_node.target)
+        new_op = graph.call_function(
+            target,
+            args=tuple(new_args),
+            kwargs=op_node.kwargs,
+        )
+        new_op.meta["val"] = target(
+            *[
+                a.meta["val"] if isinstance(a, torch.fx.Node) else a
+                for a in new_args
+            ],
+            **{
+                k: v.meta["val"] if isinstance(v, torch.fx.Node) else v
+                for k, v in op_node.kwargs.items()
+            },
+        )
+
+        slice_node.replace_all_uses_with(new_op)
+        graph.erase_node(slice_node)
+        graph.erase_node(op_node)
+        return True
+
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        parent = get_arg(node, "input", torch.fx.Node)
         if len(parent.users) != 1:
             return False
         if not isinstance(parent.target, EdgeOpOverload):
diff --git a/backends/cadence/aot/tests/test_reorder_ops_passes.py b/backends/cadence/aot/tests/test_reorder_ops_passes.py
index cf3a6840179..ea8943df8e8 100644
--- a/backends/cadence/aot/tests/test_reorder_ops_passes.py
+++ b/backends/cadence/aot/tests/test_reorder_ops_passes.py
@@ -927,3 +927,100 @@ def test_unsupported_parent_not_swapped(self) -> None:
         result = PropagateSlice().call(gm)
 
         self.assertFalse(result.modified)
+
+    def test_swap_broadcast_mul_slice_on_broadcast_dim(self) -> None:
+        """[1,60,1,1] * [4,1,1,1] → [4,60,1,1] → slice(dim=0, step=2)
+        Only the [4,1,1,1] input should be sliced."""
+        builder = GraphBuilder()
+        a = builder.placeholder("a", torch.randn(1, 60, 1, 1))
+        b = builder.placeholder("b", torch.randn(4, 1, 1, 1))
+        mul = builder.call_operator(exir_ops.edge.aten.mul.Tensor, args=(a, b))
+        sliced = builder.call_operator(
+            exir_ops.edge.aten.slice_copy.Tensor,
+            args=(mul, 0, 0, 4, 2),
+        )
+        builder.output([sliced])
+        gm = builder.get_graph_module()
+
+        result = PropagateSlice().call(gm)
+
+        self.assertTrue(result.modified)
+
+        slice_nodes = gm.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
+        )
+        self.assertEqual(len(slice_nodes), 1)
+        self.assertEqual(slice_nodes[0].args[0].name, "b")
+        self.assertEqual(list(slice_nodes[0].meta["val"].shape), [2, 1, 1, 1])
+
+        mul_nodes = gm.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.mul.Tensor
+        )
+        self.assertEqual(len(mul_nodes), 1)
+        self.assertEqual(list(mul_nodes[0].meta["val"].shape), [2, 60, 1, 1])
+
+    def test_swap_broadcast_add_lhs_broadcasts(self) -> None:
+        """[1,60,4,4] + [4,60,4,4] → [4,60,4,4] → slice(dim=0, step=2)
+        Only the [4,60,4,4] (rhs) should be sliced."""
+        builder = GraphBuilder()
+        a = builder.placeholder("a", torch.randn(1, 60, 4, 4))
+        b = builder.placeholder("b", torch.randn(4, 60, 4, 4))
+        add = builder.call_operator(exir_ops.edge.aten.add.Tensor, args=(a, b))
+        sliced = builder.call_operator(
+            exir_ops.edge.aten.slice_copy.Tensor,
+            args=(add, 0, 0, 4, 2),
+        )
+        builder.output([sliced])
+        gm = builder.get_graph_module()
+
+        result = PropagateSlice().call(gm)
+
+        self.assertTrue(result.modified)
+
+        slice_nodes = gm.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
+        )
+        self.assertEqual(len(slice_nodes), 1)
+        self.assertEqual(slice_nodes[0].args[0].name, "b")
+
+    def test_swap_broadcast_mul_slice_on_non_broadcast_dim(self) -> None:
+        """[4,60,1,1] * [4,1,1,1] → [4,60,1,1] → slice(dim=1, start=0, end=30)
+        Only the [4,60,1,1] (lhs) should be sliced since rhs has dim1=1."""
+        builder = GraphBuilder()
+        a = builder.placeholder("a", torch.randn(4, 60, 1, 1))
+        b = builder.placeholder("b", torch.randn(4, 1, 1, 1))
+        mul = builder.call_operator(exir_ops.edge.aten.mul.Tensor, args=(a, b))
+        sliced = builder.call_operator(
+            exir_ops.edge.aten.slice_copy.Tensor,
+            args=(mul, 1, 0, 30, 1),
+        )
+        builder.output([sliced])
+        gm = builder.get_graph_module()
+
+        result = PropagateSlice().call(gm)
+
+        self.assertTrue(result.modified)
+
+        slice_nodes = gm.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
+        )
+        self.assertEqual(len(slice_nodes), 1)
+        self.assertEqual(slice_nodes[0].args[0].name, "a")
+        self.assertEqual(list(slice_nodes[0].meta["val"].shape), [4, 30, 1, 1])
+
+    def test_no_swap_binary_same_shape(self) -> None:
+        """Same-shape binary ops are not swapped (no broadcast)."""
+        builder = GraphBuilder()
+        a = builder.placeholder("a", torch.randn(4, 60, 4, 4))
+        b = builder.placeholder("b", torch.randn(4, 60, 4, 4))
+        add = builder.call_operator(exir_ops.edge.aten.add.Tensor, args=(a, b))
+        sliced = builder.call_operator(
+            exir_ops.edge.aten.slice_copy.Tensor,
+            args=(add, 0, 0, 4, 2),
+        )
+        builder.output([sliced])
+        gm = builder.get_graph_module()
+
+        result = PropagateSlice().call(gm)
+
+        self.assertFalse(result.modified)