diff --git a/backends/cadence/aot/reorder_ops.py b/backends/cadence/aot/reorder_ops.py index 857446592ee..b1daefb8a2f 100644 --- a/backends/cadence/aot/reorder_ops.py +++ b/backends/cadence/aot/reorder_ops.py @@ -299,8 +299,9 @@ def advancing_feasible(self, quant_node: torch.fx.Node): # All the conditions satisfied, we advance. return True - def advance_quantize_op(self, graph_module: torch.fx.GraphModule): + def advance_quantize_op(self, graph_module: torch.fx.GraphModule) -> bool: graph = graph_module.graph + modified = False for node in reversed(graph.nodes): if get_overload_packet(node.target) not in ( exir_ops.edge.quantized_decomposed.quantize_per_tensor, @@ -339,15 +340,19 @@ def advance_quantize_op(self, graph_module: torch.fx.GraphModule): # We can safely remove the quant node and trivially quantizable op graph.erase_node(node) graph.erase_node(trivially_quantizable_op) + modified = True - graph_module.recompile() - graph_module.graph.eliminate_dead_code() + return modified def call(self, graph_module: torch.fx.GraphModule) -> PassResult: self.graph_module = graph_module - self.advance_quantize_op(graph_module) - result = super().call(graph_module) - return result + modified = self.advance_quantize_op(graph_module) + if modified: + graph_module.recompile() + graph_module.graph.eliminate_dead_code() + return super().call(graph_module) + + return PassResult(graph_module, False) @register_cadence_pass(CadencePassAttribute(opt_level=1)) @@ -474,14 +479,21 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # the graph (up to 3 times max, to avoid potential infinite loops) self.graph_module = graph_module iter_count = 0 - modified = True + local_modified = False + overall_modified = False + + while local_modified or iter_count == 0: + local_modified = self.postpone_dequantize_op(self.graph_module) + overall_modified |= local_modified + + if local_modified: + self.graph_module = super().call(self.graph_module).graph_module - while modified and iter_count < 3: - modified = self.postpone_dequantize_op(self.graph_module) - self.graph_module = super().call(self.graph_module).graph_module iter_count += 1 + if iter_count == 3: + break - return super().call(self.graph_module) + return PassResult(self.graph_module, overall_modified) @register_cadence_pass(CadencePassAttribute(opt_level=1)) diff --git a/backends/cadence/aot/tests/test_reorder_ops_passes.py b/backends/cadence/aot/tests/test_reorder_ops_passes.py index 9435957c5e5..998bfd7a676 100644 --- a/backends/cadence/aot/tests/test_reorder_ops_passes.py +++ b/backends/cadence/aot/tests/test_reorder_ops_passes.py @@ -286,13 +286,14 @@ def test_advance_branched_quantize(self) -> None: @torch.no_grad() def test_advance_quantize(self) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(16, 1, 6, 32, dtype=torch.float32)) - weights = builder.placeholder( - "weights", torch.randint(-128, 127, (32, 32), dtype=torch.int8) - ) + x_data = torch.randn(16, 1, 32, 6, dtype=torch.float32) + weight_data = torch.randint(-128, 127, (32, 32), dtype=torch.int8) + x = builder.placeholder("x", x_data) + weights = builder.placeholder("weights", weight_data) full = builder.call_operator( op=exir_ops.edge.aten.full.default, args=([1], -7), + kwargs={"dtype": torch.int32}, ) full_1 = builder.call_operator( op=exir_ops.edge.aten.full.default, @@ -304,7 +305,8 @@ def test_advance_quantize(self) -> None: ) full_3 = builder.call_operator( op=exir_ops.edge.aten.full.default, - args=([12], 0.0), + args=([1], 0), + kwargs={"dtype": torch.int32}, ) permute = builder.call_operator( op=exir_ops.edge.aten.permute_copy.default, @@ -337,8 +339,13 @@ def test_advance_quantize(self) -> None: p1 = AdvanceQuantizeOpAboveDefInBranchPass() tmp_graph = cast(PassResult, p1(original_graph)).graph_module - p2 = AdvanceQuantizeOpAboveDefChainPass() - converted_graph = cast(PassResult, p2(tmp_graph)).graph_module + result = transform_and_check_numerics( + tmp_graph, + (x_data, weight_data), + AdvanceQuantizeOpAboveDefChainPass(), + ) + self.assertFalse(result.modified) + converted_graph = result.graph_module # Assert that permute node is now the successor of the quant node. self.assertTrue( get_node_pos( @@ -349,13 +356,14 @@ def test_advance_quantize(self) -> None: def test_postpone_dequantize1(self) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(1, 16, 32, 6, dtype=torch.float32)) - weights = builder.placeholder( - "weights", torch.randint(-128, 127, (6, 6), dtype=torch.int8) - ) + x_data = torch.randn(1, 16, 32, 6, dtype=torch.float32) + weight_data = torch.randint(-128, 127, (6, 6), dtype=torch.int8) + x = builder.placeholder("x", x_data) + weights = builder.placeholder("weights", weight_data) full = builder.call_operator( op=exir_ops.edge.aten.full.default, args=([1], -7), + kwargs={"dtype": torch.int32}, ) full_1 = builder.call_operator( op=exir_ops.edge.aten.full.default, @@ -367,7 +375,8 @@ def test_postpone_dequantize1(self) -> None: ) full_3 = builder.call_operator( op=exir_ops.edge.aten.full.default, - args=([12], 0.0), + args=([1], 0), + kwargs={"dtype": torch.int32}, ) quantize_per_tensor = builder.call_operator( op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, @@ -397,8 +406,13 @@ def test_postpone_dequantize1(self) -> None: ) builder.output([permute]) original_graph = builder.get_graph_module() - p = PostponeDequantizeOpBelowUseChainPass() - converted_graph = cast(PassResult, p(original_graph)).graph_module + result = transform_and_check_numerics( + original_graph, + (x_data, weight_data), + PostponeDequantizeOpBelowUseChainPass(), + ) + self.assertTrue(result.modified) + converted_graph = result.graph_module # Assert that dequant node is now the successor of the permute node. self.assertTrue( get_node_pos(converted_graph, exir_ops.edge.aten.permute_copy.default) diff --git a/exir/passes/scalar_to_tensor_pass.py b/exir/passes/scalar_to_tensor_pass.py index 6dd80cd577d..f7857ae333d 100644 --- a/exir/passes/scalar_to_tensor_pass.py +++ b/exir/passes/scalar_to_tensor_pass.py @@ -7,26 +7,47 @@ # pyre-strict import torch -from executorch.exir.pass_base import ExportPass, map_args +from executorch.exir.pass_base import ExportPass, map_args, PassResult class ScalarToTensorPass(ExportPass): - # pyre-ignore - def call_operator(self, op, args, kwargs, meta): - # pyre-ignore + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + modified: bool = False def try_coerce(value, arg): # Note: we want to create tensor constants instead of # FakeTensor or ProxyTensor. If python_dispatcher is enabled, # the fake_tensor_mode of inputs will be used so that we won't # get a constant tensor with torch.tensor() call but instead # a fake tensor is created. + nonlocal modified with torch.utils._python_dispatch._disable_current_modes(): - return ( - torch.tensor(value) - if isinstance(value, (float, int, bool)) - and isinstance(arg.type, torch.TensorType) - else value - ) - - args, kwargs = map_args(op, try_coerce, args, kwargs) - return super().call_operator(op, args, kwargs, meta) + if isinstance(value, (float, int, bool)) and isinstance( + arg.type, torch.TensorType + ): + modified = True + return torch.tensor(value) + return value + + args, kwargs = map_args(node.target, try_coerce, node.args, node.kwargs) + if modified: + node.args = args + node.kwargs = kwargs + return modified + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + changed = False + for module in filter( + lambda m: isinstance(m, torch.fx.GraphModule), graph_module.modules() + ): + for node in module.graph.nodes: + if not isinstance( + node.target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket) + ): + continue + changed |= self.maybe_remove_or_replace(node) + + if changed: + graph_module.recompile() + return super().call(graph_module) + + return PassResult(graph_module, False)