Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 68 additions & 6 deletions backends/cadence/aot/reorder_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,12 +248,22 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class AdvanceQuantizeOpAboveDefChainPass(ExportPass):
"""
If the input to a quantize op is a linear chain of view, transpose, permute, or
slice ops that are trivially quantized, we can convert the pattern
view/transpose/permute/slice(fp32) -> quantize(int8/uint8) to
quantize(int8/uint8) -> view/transpose/permute/slice(int8/uint8).
The benefit of such reordering is that the view/transpose/permute/slice
will move far less data.
Advances a quantize op above data-movement ops to reduce data volume.

Handles two cases:

1. Linear chain: if the input to a quantize op is a chain of trivially
quantizable ops (view, transpose, permute, slice), rewrite
data_movement(fp32) -> quantize to quantize -> data_movement(quantized)
so the data movement operates on smaller quantized tensors.

2. Cat: if the input to a quantize op is a cat with a single user (the
quantize), advance the quantize above the cat by quantizing each cat
input individually. A later pass can clean up any redundant
dequant-quant pairs on the inputs.

For the cat case, SplitDequantizedCatPass should run first to ensure
each cat has at most one quantize consumer.
"""

def __init__(self):
Expand Down Expand Up @@ -302,6 +312,47 @@ def advancing_feasible(self, quant_node: torch.fx.Node):
# All the conditions satisfied, we advance.
return True

def _advance_above_cat(
    self, quant_node: torch.fx.Node, cat_node: torch.fx.Node
) -> None:
    """Advance a quantize op above a cat by quantizing each cat input.

    Rewrites ``cat([x0, x1, ...], dim) -> quantize`` into
    ``cat([quantize(x0), quantize(x1), ...], dim)``. Every new quantize node
    reuses the target, the trailing positional args and the keyword args of
    ``quant_node``, so all inputs are quantized with identical parameters.

    Args:
        quant_node: The quantize node whose first positional arg is
            ``cat_node``. It is erased from the graph.
        cat_node: The cat node feeding ``quant_node``. The caller is
            expected to have verified that ``quant_node`` is its only user;
            if that holds, the now-dead cat is erased as well.

    Raises:
        AssertionError: If the first arg of ``cat_node`` is not a list/tuple
            or if any of its elements is not an fx node.
    """
    graph = quant_node.graph
    # Everything but the input tensor is forwarded unchanged to each new
    # quantize node. Keyword args are forwarded too, so parameters that were
    # passed by keyword are not silently dropped.
    quant_params = quant_node.args[1:]
    quant_kwargs = dict(quant_node.kwargs)

    cat_inputs = cat_node.args[0]
    assert isinstance(cat_inputs, (list, tuple))

    new_inputs: list[torch.fx.Node] = []
    # Successive insertions before cat_node keep the new quantize nodes in
    # the same relative order as the cat inputs.
    with graph.inserting_before(cat_node):
        for inp in cat_inputs:
            # cat concatenates tensors, so every input must be a node.
            assert isinstance(inp, torch.fx.Node)

            new_quant = graph.call_function(
                # pyre-ignore[6]
                quant_node.target,
                args=(inp, *quant_params),
                kwargs=quant_kwargs,
            )
            # This copies the fp32 input's meta, so meta["val"] keeps the
            # fp32 dtype rather than the quantized output dtype. That's fine:
            # nothing in this pass reads dtype from meta (only shape, which
            # is correct), and call() re-runs super().call() to re-propagate
            # fake tensors, making meta dtype-consistent before we return.
            new_quant.meta = inp.meta.copy()
            new_inputs.append(new_quant)

    dim = get_arg(cat_node, "dim", int)
    with graph.inserting_before(quant_node):
        new_cat = graph.call_function(
            # pyre-ignore[6]
            cat_node.target,
            args=(new_inputs, dim),
        )
    # The new cat produces what the old quantize produced (same shape,
    # quantized dtype), so it inherits the quantize node's meta.
    new_cat.meta = quant_node.meta.copy()

    quant_node.replace_all_uses_with(new_cat)
    graph.erase_node(quant_node)

    # With its only consumer gone the original fp32 cat is dead; remove it
    # here instead of relying on a later dead-code-elimination step. The
    # guard keeps this safe if the cat unexpectedly has other users.
    if not cat_node.users:
        graph.erase_node(cat_node)

def advance_quantize_op(self, graph_module: torch.fx.GraphModule) -> bool:
graph = graph_module.graph
modified = False
Expand All @@ -314,6 +365,17 @@ def advance_quantize_op(self, graph_module: torch.fx.GraphModule) -> bool:
):
continue

inp = node.args[0]
if (
isinstance(inp, torch.fx.Node)
and get_overload_packet(inp.target)
in (exir_ops.edge.aten.cat, torch.ops.aten.cat)
and len(inp.users) == 1
):
self._advance_above_cat(node, inp)
modified = True
continue

if not self.advancing_feasible(node):
continue

Expand Down
71 changes: 71 additions & 0 deletions backends/cadence/aot/tests/test_reorder_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1268,3 +1268,74 @@ def test_two_quant_outputs_different_params_separate_cats(self) -> None:
)
quant_cat_inputs = {node.args[0] for node in quant_nodes}
self.assertEqual(len(quant_cat_inputs), 2)


class TestAdvanceQuantAboveCat(unittest.TestCase):
    """Tests for the cat case of AdvanceQuantizeOpAboveDefChainPass."""

    def test_float_inputs_get_quantized(self) -> None:
        """Float (non-dq) inputs to cat should get a quant inserted."""
        quant_params = (0.01, 0, -128, 127, torch.int8)

        builder = GraphBuilder()
        a = builder.placeholder("a", torch.randn(2, 4))
        b = builder.placeholder("b", torch.randn(2, 4))
        cat = builder.call_operator(exir_ops.edge.aten.cat.default, args=([a, b], 0))
        q = builder.call_operator(
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            args=(cat, *quant_params),
        )
        builder.output([q])
        gm = builder.get_graph_module()

        result = AdvanceQuantizeOpAboveDefChainPass().call(gm)

        self.assertTrue(result.modified)
        converted = result.graph_module

        # Two new quants (one per input) should exist; the original post-cat quant is gone.
        self.assertEqual(
            count_node(
                converted,
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            ),
            2,
        )

        # Cat should take quantized inputs.
        cat_nodes = converted.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.cat.default
        )
        self.assertEqual(len(cat_nodes), 1)
        cat_inputs = cat_nodes[0].args[0]
        self.assertEqual(len(cat_inputs), 2)
        for inp in cat_inputs:
            self.assertEqual(
                inp.target,
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            )
            # Each per-input quant must consume an original graph input and
            # carry the exact parameters of the original post-cat quant.
            self.assertEqual(inp.args[0].op, "placeholder")
            self.assertEqual(tuple(inp.args[1:]), quant_params)

    def test_cat_with_multiple_users_not_advanced(self) -> None:
        """Cat with multiple users should not be advanced (split pass handles this first)."""
        builder = GraphBuilder()
        x_int8 = builder.placeholder(
            "x_int8", torch.randint(-128, 127, (2, 4), dtype=torch.int8)
        )
        dq = builder.call_operator(
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            args=(x_int8, 0.02, -5, -128, 127, torch.int8),
        )
        b = builder.placeholder("b", torch.randn(2, 4))
        cat = builder.call_operator(exir_ops.edge.aten.cat.default, args=([dq, b], 0))
        sliced = builder.call_operator(
            exir_ops.edge.aten.slice_copy.Tensor, args=(cat, 0, 0, 2)
        )
        q = builder.call_operator(
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            args=(cat, 0.02, -5, -128, 127, torch.int8),
        )
        q_dq = builder.call_operator(
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            args=(q, 0.02, -5, -128, 127, torch.int8),
        )
        builder.output([sliced, q_dq])
        gm = builder.get_graph_module()

        result = AdvanceQuantizeOpAboveDefChainPass().call(gm)

        self.assertFalse(result.modified)
        # Neither the input module nor the returned module may be rewritten:
        # still exactly one cat and the single original post-cat quantize.
        for module in (gm, result.graph_module):
            self.assertEqual(count_node(module, exir_ops.edge.aten.cat.default), 1)
            self.assertEqual(
                count_node(
                    module,
                    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
                ),
                1,
            )
Loading