diff --git a/backends/nxp/aten_passes/convert_unsqueeze_to_view.py b/backends/nxp/aten_passes/convert_nodes_to_view.py similarity index 66% rename from backends/nxp/aten_passes/convert_unsqueeze_to_view.py rename to backends/nxp/aten_passes/convert_nodes_to_view.py index 613c0a3b9b1..66f30913337 100644 --- a/backends/nxp/aten_passes/convert_unsqueeze_to_view.py +++ b/backends/nxp/aten_passes/convert_nodes_to_view.py @@ -11,8 +11,21 @@ from torch.fx.passes.infra.pass_base import PassBase, PassResult -class ConvertUnsqueezeToViewPass(PassBase): - """Replace 'aten.unsqueeze.default' with 'aten.view.default'. +class ConvertNodesToViewPass(PassBase): + """Replaces: + - 'aten.squeeze.default', 'aten.squeeze.dims' and 'aten.squeeze.dim' with 'aten.view.default'. + + x x + │ │ + ┌─────────────▼─────────────┐ replace with ┌─────────────▼─────────────┐ + │ aten.squeeze(x, dim) │ ──────────────► │ aten.view.default(x, S) │ + └─────────────┬─────────────┘ └─────────────┬─────────────┘ + │ │ + ▼ ▼ + out out + + + - 'aten.unsqueeze.default' with 'aten.view.default'. x x │ │ @@ -22,8 +35,17 @@ class ConvertUnsqueezeToViewPass(PassBase): │ │ ▼ ▼ out out + """ + @staticmethod + def _is_squeeze(node_: Node) -> bool: + return node_.op == "call_function" and ( + node_.target == torch.ops.aten.squeeze.dim + or node_.target == torch.ops.aten.squeeze.dims + or node_.target == torch.ops.aten.squeeze.default + ) + @staticmethod def _is_unsqueeze(node_: Node) -> bool: return ( @@ -55,11 +77,8 @@ def call(self, graph_module: GraphModule) -> Optional[PassResult]: self.graph_module = graph_module made_changes = False - if not any(self._is_unsqueeze(n) for n in graph_module.graph.nodes): - return PassResult(graph_module, made_changes) - for node in list(graph_module.graph.nodes): - if not self._is_unsqueeze(node): + if not self._is_squeeze(node) and not self._is_unsqueeze(node): continue input_node = node.all_input_nodes[0] diff --git a/backends/nxp/aten_passes/neutron_aten_pass_manager.py b/backends/nxp/aten_passes/neutron_aten_pass_manager.py index 09b6efd5392..d47dc633acc 100644 --- a/backends/nxp/aten_passes/neutron_aten_pass_manager.py +++ b/backends/nxp/aten_passes/neutron_aten_pass_manager.py @@ -1,4 +1,4 @@ -# Copyright 2026 NXP +# Copyright 2025-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -7,8 +7,8 @@ import torch -from executorch.backends.nxp.aten_passes.convert_unsqueeze_to_view import ( - ConvertUnsqueezeToViewPass, +from executorch.backends.nxp.aten_passes.convert_nodes_to_view import ( + ConvertNodesToViewPass, ) from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_conv_pass import ( FuseBatchNormWithConvPass, @@ -52,7 +52,7 @@ def __init__( RemoveNodesWithKnownOutputs(), FuseLinearAndAddPass(), MoveActivationBeforeConcat(neutron_target_spec), - ConvertUnsqueezeToViewPass(), + ConvertNodesToViewPass(), ] super().__init__(passes) diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py index 19aa707aaa5..79d57bc84ee 100644 --- a/backends/nxp/tests/models.py +++ b/backends/nxp/tests/models.py @@ -679,3 +679,15 @@ def __init__(self, dim): def forward(self, x, y): return torch.unsqueeze(x + y, self.dim) + + +class SqueezeAddModel(torch.nn.Module): + def __init__(self, dim=None): + super().__init__() + self.dim = dim + + def forward(self, x, y): + if self.dim is None: + return torch.squeeze(x + y) + else: + return torch.squeeze(x + y, self.dim) diff --git a/backends/nxp/tests/test_convert_nodes_to_view.py b/backends/nxp/tests/test_convert_nodes_to_view.py new file mode 100644 index 00000000000..3c18dbca2e5 --- /dev/null +++ b/backends/nxp/tests/test_convert_nodes_to_view.py @@ -0,0 +1,281 @@ +# Copyright 2026 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import pytest +import torch +from executorch.backends.nxp.aten_passes.convert_nodes_to_view import ( + ConvertNodesToViewPass, +) +from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import ( + NeutronAtenPassManager, +) +from executorch.backends.nxp.backend.edge_program_converter import ( + EdgeProgramToIRConverter, +) +from executorch.backends.nxp.tests.executorch_pipeline import ( + neutron_target_spec, + to_quantized_edge_program, +) +from executorch.backends.nxp.tests.executors import ( + convert_run_compare, + graph_contains_any_of_ops, +) + +from executorch.backends.nxp.tests.models import SqueezeAddModel, UnsqueezeAddModel +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export import ExportedProgram + + +@pytest.fixture(autouse=True) +def reseed_model_per_test_run(): + torch.manual_seed(42) + np.random.seed(23) + + +@pytest.mark.parametrize( + "input_shape, dim", + [ + pytest.param((8, 1, 1), None, id="3D, dim = None."), + pytest.param((8, 4, 1), 2, id="3D, dim hit."), + pytest.param((8, 4, 1), 1, id="3D, dim miss."), + pytest.param((8, 4, 1), -1, id="3D, negative dim hit."), + pytest.param((8, 1, 1, 8), [1, 2], id="4D, full dims overlap."), + pytest.param((8, 1, 4, 8), [1, 2], id="4D, partial dims overlap."), + pytest.param((1, 8, 4, 8), [1, 2], id="4D, no dims overlap."), + pytest.param((8, 1, 1, 8), [-2, -3], id="4D, negative full dims overlap."), + pytest.param((8, 1, 4, 8), [-2, -3], id="4D, negative partial dims overlap."), + pytest.param((1, 8, 4, 8), [-2, -3], id="4D, negative no dims overlap."), + pytest.param( + (8, 1, 1, 8), (1, 2), id="4D, tuple instead of list, full dims overlap." + ), + pytest.param( + (8, 1, 4, 8), (1, 2), id="4D, tuple instead of list, partial dims overlap." + ), + pytest.param( + (1, 8, 4, 8), (1, 2), id="4D, tuple instead of list, no dims overlap." + ), + ], +) +def test_convert_squeeze_to_view_simple(mocker, input_shape, dim): + model = SqueezeAddModel(dim=dim) + + example_input_1 = torch.rand(input_shape) + example_input_2 = torch.rand(input_shape) + + exir_program_aten = torch.export.export( + model, + (example_input_1, example_input_2), + ).module() + + # Check that `Squeeze` is present in the model. + assert graph_contains_any_of_ops( + exir_program_aten.graph, + [ + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze.dims, + torch.ops.aten.squeeze.default, + ], + ) + + example_input = (example_input_1, example_input_2) + outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)] + + # Apply the optimization. + NeutronAtenPassManager(neutron_target_spec, [ConvertNodesToViewPass()])( + exir_program_aten + ) + + # Make sure no `Squeeze` is in the model. + assert not graph_contains_any_of_ops( + exir_program_aten.graph, + [ + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze.dims, + torch.ops.aten.squeeze.default, + ], + ) + + # Make sure there is `aten.view.default` in the model. + assert graph_contains_any_of_ops( + exir_program_aten.graph, + [torch.ops.aten.view.default], + ) + + outputs_after = [o.detach().numpy() for o in exir_program_aten(*example_input)] + + # Make sure the model still produces the exact same output. + assert len(outputs_before) == len(outputs_after) + + for i in range(len(outputs_before)): + assert np.allclose(outputs_before[i], outputs_after[i]) + + +@pytest.mark.parametrize( + "input_shape, dim", + [ + pytest.param((8, 1, 1), None, id="3D, dim = None."), + pytest.param((8, 4, 1), 2, id="3D, dim hit."), + pytest.param((8, 4, 1), 1, id="3D, dim miss."), + pytest.param((8, 4, 1), -1, id="3D, negative dim hit."), + pytest.param((8, 1, 4, 8), [1, 2], id="4D, partial dims overlap."), + pytest.param((8, 1, 4, 8), [-2, -3], id="4D, negative partial dims overlap."), + ], +) +def test_convert_squeeze_to_view_full_pipeline(mocker, input_shape, dim): + model = SqueezeAddModel(dim) + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + # Run conversion + edge_program = to_quantized_edge_program( + model, + [input_shape, input_shape], + ).exported_program() + + # Check that `Squeeze` is no longer present in the model + assert not graph_contains_any_of_ops( + edge_program.graph, + [ + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze.dims, + torch.ops.aten.squeeze.default, + ], + ) + + # Capture generated model + neutron_ir_model = converter_spy.spy_return[0] + exported_program: ExportedProgram = converter_spy.call_args.args[1] + + # Make sure `edge.aten.view_copy.default` is in the model. + assert graph_contains_any_of_ops( + exported_program.graph, + [ + exir_ops.edge.aten.view_copy.default, + ], + ) + + example_input_1 = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + example_input_2 = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + example_input = {0: example_input_1, 1: example_input_2} + + convert_run_compare( + exported_program, + input_data=example_input, + tfl_model=neutron_ir_model, + ) + + +@pytest.mark.parametrize( + "input_shape, dim", + [ + pytest.param((2,), 0, id="1D."), + pytest.param((8, 4, 6), 2, id="3D."), + pytest.param((8, 4, 6, 8), -2, id="4D, negative dim."), + pytest.param((8, 4, 6), 3, id="3D, dim arg is clipped."), + pytest.param((8, 4, 6), -4, id="3D, dim arg is clipped."), + ], +) +def test_convert_unsqueeze_to_view_simple(mocker, input_shape, dim): + model = UnsqueezeAddModel(dim) + + example_input_1 = torch.rand(input_shape) + example_input_2 = torch.rand(input_shape) + + exir_program_aten = torch.export.export( + model, + (example_input_1, example_input_2), + ).module() + + # Check "aten.unsqueeze.default" is present + assert graph_contains_any_of_ops( + exir_program_aten.graph, [torch.ops.aten.unsqueeze.default] + ) + + example_input = (example_input_1, example_input_2) + outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)] + + # Apply the optimization. + NeutronAtenPassManager(neutron_target_spec, [ConvertNodesToViewPass()])( + exir_program_aten + ) + + # Make sure no "aten.unsqueeze.default" is in the model. + assert not graph_contains_any_of_ops( + exir_program_aten.graph, + [torch.ops.aten.unsqueeze.default], + ) + + # Make sure there is "aten.view.default" in the model. + assert graph_contains_any_of_ops( + exir_program_aten.graph, + [torch.ops.aten.view.default], + ) + + outputs_after = [o.detach().numpy() for o in exir_program_aten(*example_input)] + + # Make sure the model still produces the exact same output. + assert len(outputs_before) == len(outputs_after) + + for i in range(len(outputs_before)): + assert np.allclose(outputs_before[i], outputs_after[i]) + + +@pytest.mark.parametrize( + "input_shape, dim", + [ + pytest.param((2,), 0, id="1D."), + pytest.param((8, 4, 6), 2, id="3D."), + pytest.param((8, 4, 6, 8), -2, id="4D, negative dim."), + pytest.param((8, 4, 6), 3, id="3D, dim arg is clipped."), + pytest.param((8, 4, 6), -4, id="3D, dim arg is clipped."), + ], +) +def test_convert_unsqueeze_to_view_full_pipeline(mocker, input_shape, dim): + model = UnsqueezeAddModel(dim) + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + # Run conversion + edge_program = to_quantized_edge_program( + model, + [input_shape, input_shape], + ).exported_program() + + # Make sure no "aten.unsqueeze.default" is in the model. + assert not graph_contains_any_of_ops( + edge_program.graph, + [ + torch.ops.aten.unsqueeze.default, + ], + ) + + # Capture generated model + neutron_ir_model = converter_spy.spy_return[0] + exported_program: ExportedProgram = converter_spy.call_args.args[1] + + # Make sure "edge.aten.view_copy.default" is in the model. + assert graph_contains_any_of_ops( + exported_program.graph, + [ + exir_ops.edge.aten.view_copy.default, + ], + ) + + example_input_1 = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + example_input_2 = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + example_input = {0: example_input_1, 1: example_input_2} + + convert_run_compare( + exported_program, + input_data=example_input, + tfl_model=neutron_ir_model, + ) diff --git a/backends/nxp/tests/test_convert_unsqueeze_to_view.py b/backends/nxp/tests/test_convert_unsqueeze_to_view.py deleted file mode 100644 index 1d2555e5809..00000000000 --- a/backends/nxp/tests/test_convert_unsqueeze_to_view.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2026 NXP -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import numpy as np -import pytest -import torch -from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import ( - ConvertUnsqueezeToViewPass, - NeutronAtenPassManager, -) -from executorch.backends.nxp.backend.edge_program_converter import ( - EdgeProgramToIRConverter, -) -from executorch.backends.nxp.tests.executorch_pipeline import ( - neutron_target_spec, - to_quantized_edge_program, -) -from executorch.backends.nxp.tests.executors import ( - convert_run_compare, - graph_contains_any_of_ops, -) - -from executorch.backends.nxp.tests.models import UnsqueezeAddModel -from executorch.exir.dialects._ops import ops as exir_ops -from torch.export import ExportedProgram - - -@pytest.fixture(autouse=True) -def reseed_model_per_test_run(): - torch.manual_seed(42) - np.random.seed(23) - - -@pytest.mark.parametrize( - "input_shape, dim", - [ - pytest.param((2,), 0, id="1D."), - pytest.param((8, 4, 6), 2, id="3D."), - pytest.param((8, 4, 6, 8), -2, id="4D, negative dim."), - pytest.param((8, 4, 6), 3, id="3D, dim arg is clipped."), - pytest.param((8, 4, 6), -4, id="3D, dim arg is clipped."), - ], -) -def test_convert_unsqueeze_to_view_simple(mocker, input_shape, dim): - model = UnsqueezeAddModel(dim) - - example_input_1 = torch.rand(input_shape) - example_input_2 = torch.rand(input_shape) - - exir_program_aten = torch.export.export( - model, - (example_input_1, example_input_2), - ).module() - - # Check "aten.unsqueeze.default" is present - assert graph_contains_any_of_ops( - exir_program_aten.graph, [torch.ops.aten.unsqueeze.default] - ) - - example_input = (example_input_1, example_input_2) - outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)] - - # Apply the optimization. - NeutronAtenPassManager(neutron_target_spec, [ConvertUnsqueezeToViewPass()])( - exir_program_aten - ) - - # Make sure no "aten.unsqueeze.default" is in the model. - assert not graph_contains_any_of_ops( - exir_program_aten.graph, - [torch.ops.aten.unsqueeze.default], - ) - - # Make sure there is "aten.view.default" in the model. - assert graph_contains_any_of_ops( - exir_program_aten.graph, - [torch.ops.aten.view.default], - ) - - outputs_after = [o.detach().numpy() for o in exir_program_aten(*example_input)] - - # Make sure the model still produces the exact same output. - assert len(outputs_before) == len(outputs_after) - - for i in range(len(outputs_before)): - assert np.allclose(outputs_before[i], outputs_after[i]) - - -@pytest.mark.parametrize( - "input_shape, dim", - [ - pytest.param((2,), 0, id="1D."), - pytest.param((8, 4, 6), 2, id="3D."), - pytest.param((8, 4, 6, 8), -2, id="4D, negative dim."), - pytest.param((8, 4, 6), 3, id="3D, dim arg is clipped."), - pytest.param((8, 4, 6), -4, id="3D, dim arg is clipped."), - ], -) -def test_convert_unsqueeze_to_view_full_pipeline(mocker, input_shape, dim): - model = UnsqueezeAddModel(dim) - converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - - # Run conversion - edge_program = to_quantized_edge_program( - model, - [input_shape, input_shape], - ).exported_program() - - # Make sure no "aten.unsqueeze.default" is in the model. - assert not graph_contains_any_of_ops( - edge_program.graph, - [ - torch.ops.aten.unsqueeze.default, - ], - ) - - # Capture generated model - neutron_ir_model = converter_spy.spy_return[0] - exported_program: ExportedProgram = converter_spy.call_args.args[1] - - # Make sure "edge.aten.view_copy.default" is in the model. - assert graph_contains_any_of_ops( - exported_program.graph, - [ - exir_ops.edge.aten.view_copy.default, - ], - ) - - example_input_1 = (np.random.random(input_shape).astype(np.float32) * 50).astype( - np.int8 - ) - example_input_2 = (np.random.random(input_shape).astype(np.float32) * 50).astype( - np.int8 - ) - example_input = {0: example_input_1, 1: example_input_2} - - convert_run_compare( - exported_program, - input_data=example_input, - tfl_model=neutron_ir_model, - )