109 changes: 91 additions & 18 deletions backends/cadence/aot/reorder_ops.py
@@ -721,15 +721,16 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:

 @register_cadence_pass(CadencePassAttribute(opt_level=1))
 class PropagateSlice(RemoveOrReplacePassInterface):
-    """Propagate slice_copy before unary element-wise ops when the cost
-    model indicates it reduces total data movement.
+    """Propagate slice_copy before element-wise ops when the cost model
+    indicates it reduces total data movement.
 
     Supported ops (extensible via dispatch table):
-    - quantize_per_tensor: element-wise, slice passes through unchanged
-    - dequantize_per_tensor: element-wise, slice passes through unchanged
+    - quantize_per_tensor: unary element-wise
+    - dequantize_per_tensor: unary element-wise
+    - add.Tensor: binary with broadcast — slices non-broadcasting inputs
+    - mul.Tensor: binary with broadcast — slices non-broadcasting inputs
 
-    Handles any slice dim and any step size. Runs in the iterative pass
-    loop — chains are handled by repeated application.
+    Handles any slice dim and any step size.
     """
 
     def __init__(self) -> None:
@@ -740,16 +741,28 @@ def __init__(self) -> None:
             exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
             exir_ops.edge.cadence.dequantize_per_tensor.default,
         ]
+        binary_targets = [
+            exir_ops.edge.aten.add.Tensor,
+            exir_ops.edge.aten.mul.Tensor,
+        ]
         self._dispatch: dict[
             EdgeOpOverload,
             tuple[
                 Callable[[torch.fx.Node, torch.fx.Node], bool],
                 Callable[[torch.fx.Node, torch.fx.Node], bool],
             ],
-        ] = {
-            t: (self._should_swap_elementwise, self._swap_elementwise_slice)
-            for t in elementwise_targets
-        }
+        ] = {}
+        for t in elementwise_targets:
+            self._dispatch[t] = (
+                self._should_swap_elementwise,
+                self._swap_elementwise_slice,
+            )
+
+        for t in binary_targets:
+            self._dispatch[t] = (
+                self._should_swap_binary_elementwise,
+                self._swap_binary_elementwise_slice,
+            )
 
     @property
     def targets(self) -> list[EdgeOpOverload]:
@@ -765,19 +778,21 @@ def _should_swap_elementwise(
     def _swap_elementwise_slice(
         self, op_node: torch.fx.Node, slice_node: torch.fx.Node
     ) -> bool:
-        op_input = op_node.args[0]
-        assert isinstance(op_input, torch.fx.Node)
+        op_input = get_arg(op_node, "input", torch.fx.Node)
         graph = slice_node.graph
 
-        slice_args = slice_node.args[1:]
+        slice_dim = get_arg(slice_node, "dim", int)
+        slice_start = get_arg(slice_node, "start")
+        slice_end = get_arg(slice_node, "end")
+        slice_step = get_arg(slice_node, "step", int)
 
         with graph.inserting_before(op_node):
             new_slice = graph.call_function(
                 exir_ops.edge.aten.slice_copy.Tensor,
-                args=(op_input, *slice_args),
+                args=(op_input, slice_dim, slice_start, slice_end, slice_step),
             )
             new_slice.meta["val"] = exir_ops.edge.aten.slice_copy.Tensor(
-                op_input.meta["val"], *slice_args
+                op_input.meta["val"], slice_dim, slice_start, slice_end, slice_step
             )
 
         new_args = list(op_node.args)
@@ -805,10 +820,68 @@ def _swap_elementwise_slice(
             graph.erase_node(op_node)
         return True
 
-    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
-        parent = node.args[0]
-        if not isinstance(parent, torch.fx.Node):
+    def _should_swap_binary_elementwise(
+        self, op_node: torch.fx.Node, slice_node: torch.fx.Node
+    ) -> bool:
+        lhs, rhs = op_node.args[0], op_node.args[1]
+        assert isinstance(lhs, torch.fx.Node) and isinstance(rhs, torch.fx.Node)
+        if lhs.meta["val"].shape == rhs.meta["val"].shape:
+            return False
+        full_size = prod(op_node.meta["val"].shape)
+        sliced_size = prod(slice_node.meta["val"].shape)
+        return sliced_size < full_size
+
+    def _swap_binary_elementwise_slice(
+        self, op_node: torch.fx.Node, slice_node: torch.fx.Node
+    ) -> bool:
+        lhs, rhs = op_node.args[0], op_node.args[1]
+        assert isinstance(lhs, torch.fx.Node) and isinstance(rhs, torch.fx.Node)
+        graph = slice_node.graph
+
+        slice_dim = get_arg(slice_node, "dim", int)
+        slice_start = get_arg(slice_node, "start")
+        slice_end = get_arg(slice_node, "end")
+        slice_step = get_arg(slice_node, "step", int)
+
+        output_shape = op_node.meta["val"].shape
+
+        new_args = list(op_node.args)
+        with graph.inserting_before(op_node):
+            for i, inp in enumerate([lhs, rhs]):
+                if inp.meta["val"].shape[slice_dim] == output_shape[slice_dim]:
+                    new_slice = graph.call_function(
+                        exir_ops.edge.aten.slice_copy.Tensor,
+                        args=(inp, slice_dim, slice_start, slice_end, slice_step),
+                    )
+                    new_slice.meta["val"] = exir_ops.edge.aten.slice_copy.Tensor(
+                        inp.meta["val"], slice_dim, slice_start, slice_end, slice_step
+                    )
+                    new_args[i] = new_slice
+
+            target = cast(EdgeOpOverload, op_node.target)
+            new_op = graph.call_function(
+                target,
+                args=tuple(new_args),
+                kwargs=op_node.kwargs,
+            )
+            new_op.meta["val"] = target(
+                *[
+                    a.meta["val"] if isinstance(a, torch.fx.Node) else a
+                    for a in new_args
+                ],
+                **{
+                    k: v.meta["val"] if isinstance(v, torch.fx.Node) else v
+                    for k, v in op_node.kwargs.items()
+                },
+            )
+
+        slice_node.replace_all_uses_with(new_op)
+        graph.erase_node(slice_node)
+        graph.erase_node(op_node)
+        return True
+
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        parent = get_arg(node, "input", torch.fx.Node)
+        if len(parent.users) != 1:
             return False
         if not isinstance(parent.target, EdgeOpOverload):
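Note (reviewer sketch, not part of the diff): a minimal standalone check of the identity the new binary swap relies on, with shapes borrowed from the tests below. Slicing the output of a broadcasting mul equals slicing only the input that is full-size along the sliced dim, so the op itself runs on less data.

import torch

a = torch.randn(1, 60, 1, 1)  # size 1 on dim 0: broadcasts, left untouched
b = torch.randn(4, 1, 1, 1)   # full size on dim 0: receives the slice

before = (a * b)[0:4:2]  # slice_copy(mul(a, b), dim=0, start=0, end=4, step=2)
after = a * b[0:4:2]     # mul(a, slice_copy(b, ...)): the pass's rewrite

# Identical values, but `after` computes a [2, 60, 1, 1] mul
# instead of a [4, 60, 1, 1] one.
assert torch.equal(before, after)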
97 changes: 97 additions & 0 deletions backends/cadence/aot/tests/test_reorder_ops_passes.py
@@ -927,3 +927,100 @@ def test_unsupported_parent_not_swapped(self) -> None:
         result = PropagateSlice().call(gm)
 
         self.assertFalse(result.modified)
+
+    def test_swap_broadcast_mul_slice_on_broadcast_dim(self) -> None:
+        """[1,60,1,1] * [4,1,1,1] → [4,60,1,1] → slice(dim=0, step=2)
+        Only the [4,1,1,1] input should be sliced."""
+        builder = GraphBuilder()
+        a = builder.placeholder("a", torch.randn(1, 60, 1, 1))
+        b = builder.placeholder("b", torch.randn(4, 1, 1, 1))
+        mul = builder.call_operator(exir_ops.edge.aten.mul.Tensor, args=(a, b))
+        sliced = builder.call_operator(
+            exir_ops.edge.aten.slice_copy.Tensor,
+            args=(mul, 0, 0, 4, 2),
+        )
+        builder.output([sliced])
+        gm = builder.get_graph_module()
+
+        result = PropagateSlice().call(gm)
+
+        self.assertTrue(result.modified)
+
+        slice_nodes = gm.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
+        )
+        self.assertEqual(len(slice_nodes), 1)
+        self.assertEqual(slice_nodes[0].args[0].name, "b")
+        self.assertEqual(list(slice_nodes[0].meta["val"].shape), [2, 1, 1, 1])
+
+        mul_nodes = gm.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.mul.Tensor
+        )
+        self.assertEqual(len(mul_nodes), 1)
+        self.assertEqual(list(mul_nodes[0].meta["val"].shape), [2, 60, 1, 1])
+
+    def test_swap_broadcast_add_lhs_broadcasts(self) -> None:
+        """[1,60,4,4] + [4,60,4,4] → [4,60,4,4] → slice(dim=0, step=2)
+        Only the [4,60,4,4] (rhs) should be sliced."""
+        builder = GraphBuilder()
+        a = builder.placeholder("a", torch.randn(1, 60, 4, 4))
+        b = builder.placeholder("b", torch.randn(4, 60, 4, 4))
+        add = builder.call_operator(exir_ops.edge.aten.add.Tensor, args=(a, b))
+        sliced = builder.call_operator(
+            exir_ops.edge.aten.slice_copy.Tensor,
+            args=(add, 0, 0, 4, 2),
+        )
+        builder.output([sliced])
+        gm = builder.get_graph_module()
+
+        result = PropagateSlice().call(gm)
+
+        self.assertTrue(result.modified)
+
+        slice_nodes = gm.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
+        )
+        self.assertEqual(len(slice_nodes), 1)
+        self.assertEqual(slice_nodes[0].args[0].name, "b")
+
+    def test_swap_broadcast_mul_slice_on_non_broadcast_dim(self) -> None:
+        """[4,60,1,1] * [4,1,1,1] → [4,60,1,1] → slice(dim=1, start=0, end=30)
+        Only the [4,60,1,1] (lhs) should be sliced since rhs has dim1=1."""
+        builder = GraphBuilder()
+        a = builder.placeholder("a", torch.randn(4, 60, 1, 1))
+        b = builder.placeholder("b", torch.randn(4, 1, 1, 1))
+        mul = builder.call_operator(exir_ops.edge.aten.mul.Tensor, args=(a, b))
+        sliced = builder.call_operator(
+            exir_ops.edge.aten.slice_copy.Tensor,
+            args=(mul, 1, 0, 30, 1),
+        )
+        builder.output([sliced])
+        gm = builder.get_graph_module()
+
+        result = PropagateSlice().call(gm)
+
+        self.assertTrue(result.modified)
+
+        slice_nodes = gm.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
+        )
+        self.assertEqual(len(slice_nodes), 1)
+        self.assertEqual(slice_nodes[0].args[0].name, "a")
+        self.assertEqual(list(slice_nodes[0].meta["val"].shape), [4, 30, 1, 1])
+
+    def test_no_swap_binary_same_shape(self) -> None:
+        """Same-shape binary ops are not swapped (no broadcast)."""
+        builder = GraphBuilder()
+        a = builder.placeholder("a", torch.randn(4, 60, 4, 4))
+        b = builder.placeholder("b", torch.randn(4, 60, 4, 4))
+        add = builder.call_operator(exir_ops.edge.aten.add.Tensor, args=(a, b))
+        sliced = builder.call_operator(
+            exir_ops.edge.aten.slice_copy.Tensor,
+            args=(add, 0, 0, 4, 2),
+        )
+        builder.output([sliced])
+        gm = builder.get_graph_module()
+
+        result = PropagateSlice().call(gm)
+
+        self.assertFalse(result.modified)
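Note (reviewer sketch, not part of the diff): the cost gate in _should_swap_binary_elementwise fires only when the slice strictly shrinks the op output. The check below mirrors its prod-of-shapes comparison, assuming prod is math.prod.

from math import prod

full_shape = (4, 60, 1, 1)    # mul output in the first test above
sliced_shape = (2, 60, 1, 1)  # after slice_copy(dim=0, start=0, end=4, step=2)

# 120 < 240, so the swap reduces the element count the mul touches.
assert prod(sliced_shape) < prod(full_shape)

# Same-shape operands are declined before this comparison by the shape
# equality early-return, which test_no_swap_binary_same_shape pins down.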