diff --git a/backends/cadence/aot/reorder_ops.py b/backends/cadence/aot/reorder_ops.py index 857446592ee..b1daefb8a2f 100644 --- a/backends/cadence/aot/reorder_ops.py +++ b/backends/cadence/aot/reorder_ops.py @@ -299,8 +299,9 @@ def advancing_feasible(self, quant_node: torch.fx.Node): # All the conditions satisfied, we advance. return True - def advance_quantize_op(self, graph_module: torch.fx.GraphModule): + def advance_quantize_op(self, graph_module: torch.fx.GraphModule) -> bool: graph = graph_module.graph + modified = False for node in reversed(graph.nodes): if get_overload_packet(node.target) not in ( exir_ops.edge.quantized_decomposed.quantize_per_tensor, @@ -339,15 +340,19 @@ def advance_quantize_op(self, graph_module: torch.fx.GraphModule): # We can safely remove the quant node and trivially quantizable op graph.erase_node(node) graph.erase_node(trivially_quantizable_op) + modified = True - graph_module.recompile() - graph_module.graph.eliminate_dead_code() + return modified def call(self, graph_module: torch.fx.GraphModule) -> PassResult: self.graph_module = graph_module - self.advance_quantize_op(graph_module) - result = super().call(graph_module) - return result + modified = self.advance_quantize_op(graph_module) + if modified: + graph_module.recompile() + graph_module.graph.eliminate_dead_code() + return super().call(graph_module) + + return PassResult(graph_module, False) @register_cadence_pass(CadencePassAttribute(opt_level=1)) @@ -474,14 +479,21 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # the graph (up to 3 times max, to avoid potential infinite loops) self.graph_module = graph_module iter_count = 0 - modified = True + local_modified = False + overall_modified = False + + while local_modified or iter_count == 0: + local_modified = self.postpone_dequantize_op(self.graph_module) + overall_modified |= local_modified + + if local_modified: + self.graph_module = super().call(self.graph_module).graph_module - while modified and iter_count < 3: - modified = self.postpone_dequantize_op(self.graph_module) - self.graph_module = super().call(self.graph_module).graph_module iter_count += 1 + if iter_count == 3: + break - return super().call(self.graph_module) + return PassResult(self.graph_module, overall_modified) @register_cadence_pass(CadencePassAttribute(opt_level=1)) diff --git a/backends/cadence/aot/tests/test_reorder_ops_passes.py b/backends/cadence/aot/tests/test_reorder_ops_passes.py index 9435957c5e5..998bfd7a676 100644 --- a/backends/cadence/aot/tests/test_reorder_ops_passes.py +++ b/backends/cadence/aot/tests/test_reorder_ops_passes.py @@ -286,13 +286,14 @@ def test_advance_branched_quantize(self) -> None: @torch.no_grad() def test_advance_quantize(self) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(16, 1, 6, 32, dtype=torch.float32)) - weights = builder.placeholder( - "weights", torch.randint(-128, 127, (32, 32), dtype=torch.int8) - ) + x_data = torch.randn(16, 1, 32, 6, dtype=torch.float32) + weight_data = torch.randint(-128, 127, (32, 32), dtype=torch.int8) + x = builder.placeholder("x", x_data) + weights = builder.placeholder("weights", weight_data) full = builder.call_operator( op=exir_ops.edge.aten.full.default, args=([1], -7), + kwargs={"dtype": torch.int32}, ) full_1 = builder.call_operator( op=exir_ops.edge.aten.full.default, @@ -304,7 +305,8 @@ def test_advance_quantize(self) -> None: ) full_3 = builder.call_operator( op=exir_ops.edge.aten.full.default, - args=([12], 0.0), + args=([1], 0), + kwargs={"dtype": torch.int32}, ) permute = builder.call_operator( op=exir_ops.edge.aten.permute_copy.default, @@ -337,8 +339,13 @@ def test_advance_quantize(self) -> None: p1 = AdvanceQuantizeOpAboveDefInBranchPass() tmp_graph = cast(PassResult, p1(original_graph)).graph_module - p2 = AdvanceQuantizeOpAboveDefChainPass() - converted_graph = cast(PassResult, p2(tmp_graph)).graph_module + result = transform_and_check_numerics( + tmp_graph, + (x_data, weight_data), + AdvanceQuantizeOpAboveDefChainPass(), + ) + self.assertFalse(result.modified) + converted_graph = result.graph_module # Assert that permute node is now the successor of the quant node. self.assertTrue( get_node_pos( @@ -349,13 +356,14 @@ def test_advance_quantize(self) -> None: def test_postpone_dequantize1(self) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(1, 16, 32, 6, dtype=torch.float32)) - weights = builder.placeholder( - "weights", torch.randint(-128, 127, (6, 6), dtype=torch.int8) - ) + x_data = torch.randn(1, 16, 32, 6, dtype=torch.float32) + weight_data = torch.randint(-128, 127, (6, 6), dtype=torch.int8) + x = builder.placeholder("x", x_data) + weights = builder.placeholder("weights", weight_data) full = builder.call_operator( op=exir_ops.edge.aten.full.default, args=([1], -7), + kwargs={"dtype": torch.int32}, ) full_1 = builder.call_operator( op=exir_ops.edge.aten.full.default, @@ -367,7 +375,8 @@ def test_postpone_dequantize1(self) -> None: ) full_3 = builder.call_operator( op=exir_ops.edge.aten.full.default, - args=([12], 0.0), + args=([1], 0), + kwargs={"dtype": torch.int32}, ) quantize_per_tensor = builder.call_operator( op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, @@ -397,8 +406,13 @@ def test_postpone_dequantize1(self) -> None: ) builder.output([permute]) original_graph = builder.get_graph_module() - p = PostponeDequantizeOpBelowUseChainPass() - converted_graph = cast(PassResult, p(original_graph)).graph_module + result = transform_and_check_numerics( + original_graph, + (x_data, weight_data), + PostponeDequantizeOpBelowUseChainPass(), + ) + self.assertTrue(result.modified) + converted_graph = result.graph_module # Assert that dequant node is now the successor of the permute node. self.assertTrue( get_node_pos(converted_graph, exir_ops.edge.aten.permute_copy.default)