105 changes: 104 additions & 1 deletion backends/cadence/aot/reorder_ops.py
@@ -11,7 +11,7 @@

from collections import defaultdict
from math import prod
-from typing import cast, DefaultDict, List, Tuple
+from typing import Callable, cast, DefaultDict, List, Tuple

import torch
import torch.fx
@@ -719,6 +719,109 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
return True


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class PropagateSlice(RemoveOrReplacePassInterface):
"""Propagate slice_copy before unary element-wise ops when the cost
model indicates it reduces total data movement.

Supported ops (extensible via dispatch table):
- quantize_per_tensor: element-wise, slice passes through unchanged
- dequantize_per_tensor: element-wise, slice passes through unchanged

    Handles any slice dim and any step size. Runs in the iterative pass
    loop; chains of supported ops are handled by repeated application.
    """

def __init__(self) -> None:
super().__init__()
elementwise_targets = [
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.cadence.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.cadence.dequantize_per_tensor.default,
]
self._dispatch: dict[
EdgeOpOverload,
tuple[
Callable[[torch.fx.Node, torch.fx.Node], bool],
Callable[[torch.fx.Node, torch.fx.Node], bool],
],
] = {
t: (self._should_swap_elementwise, self._swap_elementwise_slice)
for t in elementwise_targets
}

@property
def targets(self) -> list[EdgeOpOverload]:
return [exir_ops.edge.aten.slice_copy.Tensor]

def _should_swap_elementwise(
self, op_node: torch.fx.Node, slice_node: torch.fx.Node
) -> bool:
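        # Cost model: the swap pays off whenever the slice output holds fewer
        # elements than the op output, since the element-wise op would then
        # run on less data.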
full_size = prod(op_node.meta["val"].shape)
sliced_size = prod(slice_node.meta["val"].shape)
return sliced_size < full_size

def _swap_elementwise_slice(
self, op_node: torch.fx.Node, slice_node: torch.fx.Node
) -> bool:
op_input = op_node.args[0]
assert isinstance(op_input, torch.fx.Node)
graph = slice_node.graph

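        # The op is element-wise, so the original slice parameters
        # (dim, start, end, step) apply unchanged to the op's input.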
slice_args = slice_node.args[1:]

with graph.inserting_before(op_node):
new_slice = graph.call_function(
exir_ops.edge.aten.slice_copy.Tensor,
args=(op_input, *slice_args),
)
new_slice.meta["val"] = exir_ops.edge.aten.slice_copy.Tensor(
op_input.meta["val"], *slice_args
)

new_args = list(op_node.args)
new_args[0] = new_slice
target = cast(EdgeOpOverload, op_node.target)
new_op = graph.call_function(
target,
args=tuple(new_args),
kwargs=op_node.kwargs,
)
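        # Recompute the relocated op's fake-tensor metadata by invoking the
        # target on the new slice's meta value plus the remaining args and
        # kwargs, with node arguments resolved to their meta values.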
new_op.meta["val"] = target(
new_slice.meta["val"],
*[
a.meta["val"] if isinstance(a, torch.fx.Node) else a
for a in new_args[1:]
],
**{
k: v.meta["val"] if isinstance(v, torch.fx.Node) else v
for k, v in op_node.kwargs.items()
},
)

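        # Rewire consumers of the old slice to the new op, then erase the
        # dead nodes: the slice (the user) first, then the original op
        # (its producer).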
slice_node.replace_all_uses_with(new_op)
graph.erase_node(slice_node)
graph.erase_node(op_node)
return True

def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
parent = node.args[0]
if not isinstance(parent, torch.fx.Node):
return False
if len(parent.users) != 1:
return False
if not isinstance(parent.target, EdgeOpOverload):
return False

entry = self._dispatch.get(parent.target)
if entry is None:
return False

should_swap, do_swap = entry
return should_swap(parent, node) and do_swap(parent, node)


# The following class consolidates functions to reorder ops (i.e., either hoist
# or sink some ops in the graph).
class CadenceReorderOpsInGraph:
166 changes: 166 additions & 0 deletions backends/cadence/aot/tests/test_reorder_ops_passes.py
@@ -26,6 +26,7 @@
MoveSliceBeforePermutePass,
PostponeDequantizeOpBelowUseChainPass,
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
PropagateSlice,
SinkOpsCloserToUsePass,
)
from executorch.backends.test.graph_builder import GraphBuilder
@@ -761,3 +762,168 @@ def test_non_dim0_slice_always_moved(self) -> None:
MoveSliceBeforePermutePass(),
)
self.assertTrue(result.modified)


class TestPropagateSlice(unittest.TestCase):
def test_swap_quantize_slice(self) -> None:
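        # slice_copy(dim=0, start=0, end=4, step=2) keeps 2 of 4 rows, so the
        # cost model favors hoisting the slice above quantize_per_tensor.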
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(4, 60, 1, 1))
quant = builder.call_operator(
exir_ops.edge.cadence.quantize_per_tensor.default,
args=(x, 0.5, 0, 0, 255, torch.uint8),
)
sliced = builder.call_operator(
exir_ops.edge.aten.slice_copy.Tensor,
args=(quant, 0, 0, 4, 2),
)
builder.output([sliced])
gm = builder.get_graph_module()

result = PropagateSlice().call(gm)

self.assertTrue(result.modified)

slice_nodes = gm.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
)
self.assertEqual(len(slice_nodes), 1)
slice_node = slice_nodes[0]
self.assertEqual(slice_node.args[0].name, "x")
self.assertEqual(list(slice_node.meta["val"].shape), [2, 60, 1, 1])

quant_nodes = gm.graph.find_nodes(
op="call_function",
target=exir_ops.edge.cadence.quantize_per_tensor.default,
)
self.assertEqual(len(quant_nodes), 1)
self.assertEqual(quant_nodes[0].args[0], slice_node)
self.assertEqual(list(quant_nodes[0].meta["val"].shape), [2, 60, 1, 1])

def test_swap_dequantize_slice(self) -> None:
builder = GraphBuilder()
x = builder.placeholder(
"x", torch.randint(0, 255, (4, 60, 4, 4), dtype=torch.uint8)
)
dequant = builder.call_operator(
exir_ops.edge.cadence.dequantize_per_tensor.default,
args=(x, 0.5, 0, 0, 255, torch.uint8),
)
sliced = builder.call_operator(
exir_ops.edge.aten.slice_copy.Tensor,
args=(dequant, 0, 0, 4, 2),
)
builder.output([sliced])
gm = builder.get_graph_module()

result = PropagateSlice().call(gm)

self.assertTrue(result.modified)

slice_nodes = gm.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
)
self.assertEqual(len(slice_nodes), 1)
self.assertEqual(slice_nodes[0].args[0].name, "x")

def test_step_2_through_quantize(self) -> None:
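        # Same graph as test_swap_quantize_slice; this test additionally pins
        # down that the step argument (2) survives the move.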
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(4, 60, 1, 1))
quant = builder.call_operator(
exir_ops.edge.cadence.quantize_per_tensor.default,
args=(x, 0.5, 0, 0, 255, torch.uint8),
)
sliced = builder.call_operator(
exir_ops.edge.aten.slice_copy.Tensor,
args=(quant, 0, 0, 4, 2),
)
builder.output([sliced])
gm = builder.get_graph_module()

result = PropagateSlice().call(gm)

self.assertTrue(result.modified)

slice_nodes = gm.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
)
self.assertEqual(len(slice_nodes), 1)
self.assertEqual(slice_nodes[0].args[4], 2)
self.assertEqual(list(slice_nodes[0].meta["val"].shape), [2, 60, 1, 1])

def test_non_batch_dim_slice(self) -> None:
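        # Slicing dim 1 (60 -> 30 channels) also halves the element count, so
        # the swap applies on non-leading dims as well.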
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(4, 60, 4, 4))
quant = builder.call_operator(
exir_ops.edge.cadence.quantize_per_tensor.default,
args=(x, 0.5, 0, 0, 255, torch.uint8),
)
sliced = builder.call_operator(
exir_ops.edge.aten.slice_copy.Tensor,
args=(quant, 1, 0, 30, 1),
)
builder.output([sliced])
gm = builder.get_graph_module()

result = PropagateSlice().call(gm)

self.assertTrue(result.modified)

slice_nodes = gm.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
)
self.assertEqual(len(slice_nodes), 1)
self.assertEqual(list(slice_nodes[0].meta["val"].shape), [4, 30, 4, 4])

def test_no_swap_when_multi_user(self) -> None:
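        # quant feeds both the slice and the graph output, so the full-size
        # quantize cannot be eliminated; the pass must leave the graph alone.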
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(4, 60, 1, 1))
quant = builder.call_operator(
exir_ops.edge.cadence.quantize_per_tensor.default,
args=(x, 0.5, 0, 0, 255, torch.uint8),
)
sliced = builder.call_operator(
exir_ops.edge.aten.slice_copy.Tensor,
args=(quant, 0, 0, 4, 2),
)
builder.output([sliced, quant])
gm = builder.get_graph_module()

result = PropagateSlice().call(gm)

self.assertFalse(result.modified)

def test_no_swap_noop_slice(self) -> None:
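        # A full-range, step-1 slice keeps every element, so sliced_size
        # equals full_size and the cost model rejects the swap.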
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(4, 60, 1, 1))
quant = builder.call_operator(
exir_ops.edge.cadence.quantize_per_tensor.default,
args=(x, 0.5, 0, 0, 255, torch.uint8),
)
sliced = builder.call_operator(
exir_ops.edge.aten.slice_copy.Tensor,
args=(quant, 0, 0, 4, 1),
)
builder.output([sliced])
gm = builder.get_graph_module()

result = PropagateSlice().call(gm)

self.assertFalse(result.modified)

def test_unsupported_parent_not_swapped(self) -> None:
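        # relu is not in the dispatch table of supported element-wise ops, so
        # the slice is not propagated.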
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(4, 60, 1, 1))
relu = builder.call_operator(
exir_ops.edge.aten.relu.default,
args=(x,),
)
sliced = builder.call_operator(
exir_ops.edge.aten.slice_copy.Tensor,
args=(relu, 0, 0, 4, 2),
)
builder.output([sliced])
gm = builder.get_graph_module()

result = PropagateSlice().call(gm)

self.assertFalse(result.modified)