diff --git a/backends/cadence/aot/reorder_ops.py b/backends/cadence/aot/reorder_ops.py index 1e6682c5943..2774b3d7477 100644 --- a/backends/cadence/aot/reorder_ops.py +++ b/backends/cadence/aot/reorder_ops.py @@ -248,12 +248,22 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: @register_cadence_pass(CadencePassAttribute(opt_level=1)) class AdvanceQuantizeOpAboveDefChainPass(ExportPass): """ - If the input to quantize op is linear chain of view, transpose, permute, or - slice ops that are trivially quantized, we can convert the pattern - view/transpose/permute/slice(fp32) -> quantize(int8/uint8) to - quantize(int8/uint8) -> view/transpose/permute/slice(int8/uint8). - The benefit of such reordering is that the view/transpose/permute/slice - will move far less data. + Advances a quantize op above data-movement ops to reduce data volume. + + Handles two cases: + + 1. Linear chain: if the input to a quantize op is a chain of trivially + quantizable ops (view, transpose, permute, slice), rewrite + data_movement(fp32) -> quantize to quantize -> data_movement(quantized) + so the data movement operates on smaller quantized tensors. + + 2. Cat: if the input to a quantize op is a cat with a single user (the + quantize), advance the quantize above the cat by quantizing each cat + input individually. A later pass can clean up any redundant + dequant-quant pairs on the inputs. + + For the cat case, SplitDequantizedCatPass should run first to ensure + each cat has at most one quantize consumer. """ def __init__(self): @@ -302,6 +312,47 @@ def advancing_feasible(self, quant_node: torch.fx.Node): # All the conditions satisfied, we advance. 
def _advance_above_cat(
    self, quant_node: torch.fx.Node, cat_node: torch.fx.Node
) -> None:
    """Advance a quantize op above a cat by quantizing each cat input.

    Rewrites ``cat([x0, x1, ...]) -> quantize`` into
    ``cat([quantize(x0), quantize(x1), ...])``. Every new quantize node
    reuses the target, the trailing positional arguments and the keyword
    arguments of ``quant_node``.

    Args:
        quant_node: The quantize node to advance. All of its uses are
            redirected to the new cat and the node is erased.
        cat_node: The cat node feeding ``quant_node``. It is erased as well
            if it has no users left once ``quant_node`` is gone.

    Raises:
        AssertionError: If the first argument of ``cat_node`` is not a
            list/tuple, or if one of its elements is not an fx node.
    """
    graph = quant_node.graph
    # Everything after the input tensor parametrizes the quantization.
    quant_params = quant_node.args[1:]

    cat_inputs = cat_node.args[0]
    assert isinstance(cat_inputs, (list, tuple))

    # A tensor may appear several times in the cat input list; quantize it
    # only once and reuse the resulting node for every occurrence.
    quantized: dict[torch.fx.Node, torch.fx.Node] = {}
    new_inputs: list[torch.fx.Node] = []
    for inp in cat_inputs:
        # cat concatenates tensors, so every input must be a node.
        assert isinstance(inp, torch.fx.Node)

        new_quant = quantized.get(inp)
        if new_quant is None:
            with graph.inserting_before(cat_node):
                new_quant = graph.call_function(
                    # pyre-ignore[6]
                    quant_node.target,
                    args=(inp, *quant_params),
                    # Forward keyword arguments too, so parameters passed
                    # by keyword are not silently dropped.
                    kwargs=dict(quant_node.kwargs),
                )
            # This copies the fp32 input's meta, so meta["val"] keeps the
            # fp32 dtype rather than the quantized output dtype. Only the
            # shape is correct at this point.
            # NOTE(review): this relies on the caller re-propagating fake
            # tensors afterwards to make the meta dtype-consistent --
            # confirm against call().
            new_quant.meta = inp.meta.copy()
            quantized[inp] = new_quant
        new_inputs.append(new_quant)

    dim = get_arg(cat_node, "dim", int)
    with graph.inserting_before(quant_node):
        new_cat = graph.call_function(
            # pyre-ignore[6]
            cat_node.target,
            args=(new_inputs, dim),
        )
    # The new cat produces what the old quantize produced, so it inherits
    # the quantize node's meta.
    new_cat.meta = quant_node.meta.copy()

    quant_node.replace_all_uses_with(new_cat)
    graph.erase_node(quant_node)
    # Do not leave the original fp32 cat behind as a dead node. The guard
    # keeps this safe if the cat still has other consumers.
    if not cat_node.users:
        graph.erase_node(cat_node)
class TestAdvanceQuantAboveCat(unittest.TestCase):
    """Tests for the cat handling of AdvanceQuantizeOpAboveDefChainPass."""

    def test_float_inputs_get_quantized(self) -> None:
        """Float (non-dq) inputs to cat should get a quant inserted."""
        quantize_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
        cat_op = exir_ops.edge.aten.cat.default
        quant_params = (0.01, 0, -128, 127, torch.int8)

        builder = GraphBuilder()
        a = builder.placeholder("a", torch.randn(2, 4))
        b = builder.placeholder("b", torch.randn(2, 4))
        cat = builder.call_operator(cat_op, args=([a, b], 0))
        q = builder.call_operator(quantize_op, args=(cat, *quant_params))
        builder.output([q])
        gm = builder.get_graph_module()

        result = AdvanceQuantizeOpAboveDefChainPass().call(gm)

        self.assertTrue(result.modified)
        converted = result.graph_module

        # Two new quants (one per input) should exist; the original
        # post-cat quant is gone.
        self.assertEqual(count_node(converted, quantize_op), 2)

        # Cat should take quantized inputs that carry the original
        # quantization parameters.
        cat_nodes = converted.graph.find_nodes(op="call_function", target=cat_op)
        self.assertEqual(len(cat_nodes), 1)
        cat_inputs = cat_nodes[0].args[0]
        self.assertEqual(len(cat_inputs), 2)
        for inp in cat_inputs:
            self.assertEqual(inp.target, quantize_op)
            self.assertEqual(tuple(inp.args[1:]), quant_params)

    def test_cat_with_multiple_users_not_advanced(self) -> None:
        """Cat with multiple users should not be advanced (split pass handles this first)."""
        quantize_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
        dequantize_op = (
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
        )
        cat_op = exir_ops.edge.aten.cat.default
        quant_params = (0.02, -5, -128, 127, torch.int8)

        builder = GraphBuilder()
        # torch.randint's upper bound is exclusive; 128 covers the whole
        # int8 range.
        x_int8 = builder.placeholder(
            "x_int8", torch.randint(-128, 128, (2, 4), dtype=torch.int8)
        )
        dq = builder.call_operator(dequantize_op, args=(x_int8, *quant_params))
        b = builder.placeholder("b", torch.randn(2, 4))
        cat = builder.call_operator(cat_op, args=([dq, b], 0))
        # Second consumer of the cat, next to the quantize below.
        sliced = builder.call_operator(
            exir_ops.edge.aten.slice_copy.Tensor, args=(cat, 0, 0, 2)
        )
        q = builder.call_operator(quantize_op, args=(cat, *quant_params))
        q_dq = builder.call_operator(dequantize_op, args=(q, *quant_params))
        builder.output([sliced, q_dq])
        gm = builder.get_graph_module()

        result = AdvanceQuantizeOpAboveDefChainPass().call(gm)

        self.assertFalse(result.modified)
        # Inspect the module returned by the pass, not the input module.
        unchanged = result.graph_module
        self.assertEqual(count_node(unchanged, cat_op), 1)
        self.assertEqual(count_node(unchanged, quantize_op), 1)