diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 174a1960aab..1fc1efae797 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -103,6 +103,7 @@ QuantizeClampArgumentsPass, ) from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass # noqa +from .fuse_concat_pass import FuseConcatPass # noqa from .fuse_consecutive_concat_shapes import FuseConsecutiveConcatShapesPass # noqa from .fuse_consecutive_rescales_pass import FuseConsecutiveRescalesPass # noqa from .fuse_constant_ops_pass import ( # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index ebe6c4591e6..e3c3c35b033 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -100,6 +100,7 @@ EnsureUniqueOutputNodesPass, FoldAndAnnotateQParamsPass, FuseBatchNorm2dPass, + FuseConcatPass, FuseConsecutiveConcatShapesPass, FuseConsecutiveRescalesPass, FuseConstantArgsPass, @@ -532,6 +533,7 @@ def _tosa_pipeline( # Aten -> TOSA transformation passes self.add_passes( [ + FuseConcatPass(), RewriteUpsamplePass(), RewriteMaxPool2dPass(), RewriteConvPass(exported_program), diff --git a/backends/arm/_passes/fuse_concat_pass.py b/backends/arm/_passes/fuse_concat_pass.py new file mode 100644 index 00000000000..bce32b58d8e --- /dev/null +++ b/backends/arm/_passes/fuse_concat_pass.py @@ -0,0 +1,379 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
import logging
from typing import Set, Type

import torch.fx
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

logger = logging.getLogger(__name__)


def _int_arg(node: torch.fx.Node, index: int, default: int) -> int:
    """Return positional argument ``index`` of ``node`` as an ``int``.

    ``default`` is returned when the argument is absent *or* explicitly
    ``None``. The latter matters for ``slice_copy``, whose ``start``/``end``
    are optional in the ATen schema; previously a ``None`` bound tripped the
    assertion below instead of being treated as "use the default".

    """
    val = node.args[index] if len(node.args) > index else default
    if val is None:
        val = default
    assert isinstance(val, int)
    return val


def _normalize_dim(dim: int, rank: int) -> int:
    """Normalize a (possibly negative) dim index to ``[0, rank)``."""
    return (dim + rank) % rank


def _slice_params(node: torch.fx.Node, dim_size: int) -> tuple[int, int, int, int]:
    """Extract (dim, start, end, step) from a slice_copy node.

    ``dim``, ``start``, and ``end`` are normalized to non-negative indices in
    ``[0, dim_size]`` (matching PyTorch slice semantics, where negative bounds
    count from the end of the dimension). ``dim_size`` is the size of the
    source tensor along the slice dimension.

    ``end`` is additionally clamped to be no smaller than ``start`` so that an
    empty slice is always reported as ``start == end`` rather than as a
    reversed range.

    """
    rank = len(get_first_fake_tensor(node).shape)
    dim = _normalize_dim(_int_arg(node, 1, 0), rank)
    start = _int_arg(node, 2, 0)
    end = _int_arg(node, 3, dim_size)
    if start < 0:
        start += dim_size
    if end < 0:
        end += dim_size
    start = max(0, min(start, dim_size))
    # start >= 0 here, so this also enforces the lower bound of 0 on end.
    end = max(start, min(end, dim_size))
    step = _int_arg(node, 4, 1)
    return dim, start, end, step


_SLICE_OP = exir_ops.edge.aten.slice_copy.Tensor


def _assert_qparams_match(source: torch.fx.Node, cat_node: torch.fx.Node) -> None:
    """Verify the data-providing node's output_qparams match cat_node's.

    FuseConcatPass runs after FoldAndAnnotateQParamsPass, so cat_node may carry
    output_qparams. TOSA cat preserves scale across inputs and output, so any
    node we route consumers to (directly, or via a new slice that reads from
    it) must produce data with matching qparams.

    Skipped when either side has no output_qparams entries: FP graphs (key
    missing), Q ops (qparams live in args), and tosa_rescale ops (qparams in
    args; folder leaves an empty dict). The assertion is best-effort, intended
    to catch upstream bugs that produce two foldable ops with diverging
    output_qparams.

    """
    cat_qp = cat_node.meta.get("output_qparams")
    src_qp = source.meta.get("output_qparams")
    if not cat_qp or not src_qp:
        return
    assert cat_qp == src_qp, (
        f"FuseConcatPass: output_qparams mismatch — {source.name} {src_qp} != "
        f"cat {cat_node.name} {cat_qp}"
    )


def _is_valid_slice(
    node: torch.fx.Node, cat_dim: int, dim_size: int | None = None
) -> bool:
    """Check that node is a slice_copy on cat_dim with step=1.

    Only ``dim`` and ``step`` are inspected, so the result does not depend on
    the size of the sliced dimension. ``dim_size`` is accepted purely for
    backward compatibility with existing callers and is not used.

    A slice whose dim/start/end/step arguments are not plain ints (or ``None``)
    is reported as invalid, so the pass skips it instead of asserting later
    when the bounds are normalized.

    """
    if node.target != _SLICE_OP:
        return False
    if not all(arg is None or isinstance(arg, int) for arg in node.args[1:5]):
        return False
    rank = len(get_first_fake_tensor(node).shape)
    s_dim = _normalize_dim(_int_arg(node, 1, 0), rank)
    s_step = _int_arg(node, 4, 1)
    return s_dim == cat_dim and s_step == 1


def _find_slice_replacement(
    slice_op: torch.fx.Node,
    cat_node: torch.fx.Node,
    cat_dim: int,
    s_start: int,
    s_end: int,
    offsets: list[tuple[int, int, torch.fx.Node]],
) -> torch.fx.Node | None:
    """Find a replacement for a slice that consumes a cat output.

    ``offsets`` maps each concat input to its range in the concatenated
    output: [(start, end, input_node), ...] along ``cat_dim``.

    Returns the replacement node (exact input match or adjusted sub-slice),
    or None if the slice crosses input boundaries. In the sub-slice case a new
    slice_copy node is inserted into the graph before ``slice_op``.

    """
    for o_start, o_end, inp in offsets:
        if s_start == o_start and s_end == o_end:
            _assert_qparams_match(inp, cat_node)
            return inp
        if s_start >= o_start and s_end <= o_end:
            _assert_qparams_match(inp, cat_node)
            graph = cat_node.graph
            with graph.inserting_before(slice_op):
                # Re-express the bounds relative to the start of ``inp``.
                new_slice = graph.call_function(
                    _SLICE_OP,
                    (inp, cat_dim, s_start - o_start, s_end - o_start),
                )
            # The output of the new slice is the same data as the old one.
            new_slice.meta = slice_op.meta.copy()
            return new_slice
    return None


def _find_common_slice_source(
    cat_inputs: list | tuple,
    cat_dim: int,
    dim_size: int | None = None,
) -> torch.fx.Node | None:
    """Check all inputs are valid slices of the same source.

    Returns the shared source node, or None if any input is not an fx Node,
    is not a step-1 slice_copy on ``cat_dim``, or reads from a different
    source than the others. ``dim_size`` is accepted for backward
    compatibility only; slice validity does not depend on it.

    """
    source_node = None
    for inp in cat_inputs:
        if not isinstance(inp, torch.fx.Node):
            return None
        if not _is_valid_slice(inp, cat_dim):
            return None
        slice_source = inp.args[0]
        if source_node is None:
            source_node = slice_source
        elif slice_source is not source_node:
            return None
    if not isinstance(source_node, torch.fx.Node):
        # Empty input list or a non-Node slice source: nothing to fuse.
        return None
    return source_node


def _check_contiguous_slices(
    cat_inputs: list | tuple,
    source_dim_size: int,
) -> tuple[int, int] | None:
    """Check slices are contiguous.

    Each slice must start exactly where the previous one ended.
    ``cat_inputs`` must be non-empty.

    Returns (first_start, last_end) or None.

    """
    _, first_start, _, _ = _slice_params(cat_inputs[0], source_dim_size)
    expected_start = first_start
    for inp in cat_inputs:
        _, s_start, s_end, _ = _slice_params(inp, source_dim_size)
        if s_start != expected_start:
            return None
        expected_start = s_end

    # expected_start is now the end of the last slice
    return first_start, expected_start


class FuseConcatPass(ArmPass):
    """Eliminate redundant concat (cat) operations via graph pattern matching.

    This pass recognizes and removes concat operations that can be proven to
    produce no useful data movement. Eliminating these at the FX/TOSA level
    prevents Vela from generating MemoryCopy operations on the Ethos-U NPU.

    Five patterns are handled:

    1. Single-input concat: cat([x], dim) is a no-op; replace with x.
    2. Concat-then-slice (exact): if a consumer of cat([a, b, ...], dim) is
       a slice_copy that extracts exactly one original input, replace it
       with the corresponding concat input directly.
    3. Slice-then-concat (full): if cat([slice(x, d, s0, e0),
       slice(x, d, s1, e1), ...], dim) reconstructs x exactly (contiguous
       slices covering the full source dimension), replace with x.
    4. Concat-then-sub-slice: if a consumer of cat([a, b, ...], dim) is a
       slice_copy whose range falls entirely within one original input,
       replace it with an adjusted slice on that input directly.
    5. Slice-then-concat (partial): if contiguous slices of the same tensor
       are concatenated but cover only a sub-range of the source dimension,
       replace with a single slice on the source.

    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    cat_ops = {
        exir_ops.edge.aten.cat.default,
    }
    slice_op = _SLICE_OP

    def call(self, graph_module: torch.fx.GraphModule):
        """Apply the fusion patterns to every cat node in ``graph_module``.

        Patterns are tried in order (1, then 2/4, then 3/5) and at most one is
        applied per cat node. Nodes left without users are removed by dead-code
        elimination once all cat nodes have been visited.

        """
        modified = False
        graph = graph_module.graph

        # Iterate over a snapshot: the patterns insert new nodes while we walk.
        for node in list(graph.nodes):
            if node.op != "call_function" or node.target not in self.cat_ops:
                continue
            if node.graph is None:
                continue

            if self._eliminate_single_input_cat(node):
                modified = True
                continue

            if self._eliminate_cat_then_slice(node):
                modified = True
                continue

            if self._eliminate_slice_then_cat(node):
                modified = True
                continue

        if modified:
            graph.eliminate_dead_code()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified)

    # ------------------------------------------------------------------
    # Pattern 1: single-input cat
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_single_input_cat(cat_node: torch.fx.Node) -> bool:
        """Replace ``cat([x], dim)`` with ``x``; return True if rewritten."""
        inputs = cat_node.args[0]
        if not isinstance(inputs, (list, tuple)) or len(inputs) != 1:
            return False
        sole_input = inputs[0]
        if not isinstance(sole_input, torch.fx.Node):
            return False
        _assert_qparams_match(sole_input, cat_node)
        cat_node.replace_all_uses_with(sole_input)
        logger.debug("Eliminated single-input cat: %s", cat_node.name)
        return True

    # ------------------------------------------------------------------
    # Patterns 2 & 4: cat -> slice (exact input or sub-range of input)
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_cat_then_slice(
        cat_node: torch.fx.Node,
    ) -> bool:
        """Redirect slices of a cat output to the cat inputs they read from.

        Applies only when every user of ``cat_node`` is a step-1 slice_copy on
        the cat dimension. Slices that cross an input boundary are left in
        place (and keep the cat alive). Returns True if at least one slice was
        redirected.

        """
        cat_inputs = cat_node.args[0]
        if not isinstance(cat_inputs, (list, tuple)) or len(cat_inputs) < 2:
            return False
        # Bail (rather than assert) on malformed inputs, matching the
        # behaviour of _find_common_slice_source.
        if not all(isinstance(inp, torch.fx.Node) for inp in cat_inputs):
            return False

        # if the dim does not exist as an arg, it defaults to '0'
        output_rank = len(get_first_fake_tensor(cat_node).shape)
        cat_dim = _normalize_dim(_int_arg(cat_node, 1, 0), output_rank)

        users = list(cat_node.users)
        if not users:
            return False

        # Build the offset map for each concat input along cat_dim.
        offsets = []
        offset = 0
        for inp in cat_inputs:
            size = get_first_fake_tensor(inp).shape[cat_dim]
            offsets.append((offset, offset + size, inp))
            offset += size
        # ``offset`` is now the size of the cat output along cat_dim.

        # Every user must be a slice_copy on the same dim with step=1.
        # Collect validated (node, start, end) for replacement below. No graph
        # mutation happens until validation of all users has succeeded.
        validated_slices: list[tuple[torch.fx.Node, int, int]] = []
        for slice_op in users:
            if not _is_valid_slice(slice_op, cat_dim, offset):
                return False
            if slice_op.args[0] is not cat_node:
                return False
            _, s_start, s_end, _ = _slice_params(slice_op, offset)
            validated_slices.append((slice_op, s_start, s_end))

        # For each user, try exact match (Pattern 2) then sub-range (Pattern 4).
        # Users that cross input boundaries are skipped.
        replacements: list[tuple[torch.fx.Node, torch.fx.Node]] = []

        for slice_op, s_start, s_end in validated_slices:
            replacement = _find_slice_replacement(
                slice_op, cat_node, cat_dim, s_start, s_end, offsets
            )
            if replacement is not None:
                replacements.append((slice_op, replacement))

        if not replacements:
            return False

        for old_node, new_node in replacements:
            old_node.replace_all_uses_with(new_node)

        logger.debug(
            "Eliminated cat-then-slice pattern: %s (%d slices redirected)",
            cat_node.name,
            len(replacements),
        )
        return True

    # ------------------------------------------------------------------
    # Patterns 3 & 5: slice -> cat (contiguous slices, full or partial)
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_slice_then_cat(
        cat_node: torch.fx.Node,
    ) -> bool:
        """Collapse a cat of contiguous slices of one tensor.

        Full coverage of the source dimension replaces the cat with the source
        tensor (Pattern 3); partial coverage replaces it with a single
        slice_copy on the source (Pattern 5). Returns True if rewritten.

        """
        cat_inputs = cat_node.args[0]
        if not isinstance(cat_inputs, (list, tuple)) or len(cat_inputs) < 2:
            return False

        output_rank = len(get_first_fake_tensor(cat_node).shape)
        cat_dim = _normalize_dim(_int_arg(cat_node, 1, 0), output_rank)

        # All inputs must be slice_copy on the same source tensor and dim,
        # with step=1. (The source's size along cat_dim is not needed for this
        # check; it is only known once the source has been identified.)
        source_node = _find_common_slice_source(cat_inputs, cat_dim)
        if source_node is None:
            return False

        source_shape = get_first_fake_tensor(source_node).shape
        source_dim_size = source_shape[cat_dim]

        # Verify slices are contiguous (but not necessarily starting at 0).
        bounds = _check_contiguous_slices(cat_inputs, source_dim_size)
        if bounds is None:
            return False
        first_start, last_end = bounds

        # Verify output shape matches expectations.
        cat_shape = get_first_fake_tensor(cat_node).shape

        if first_start == 0 and last_end == source_dim_size:
            # Pattern 3: full coverage — replace with source tensor.
            if list(cat_shape) != list(source_shape):
                return False
            _assert_qparams_match(source_node, cat_node)
            cat_node.replace_all_uses_with(source_node)
            logger.debug(
                "Eliminated slice-then-cat (full): %s -> %s",
                cat_node.name,
                source_node.name,
            )
        else:
            # Pattern 5: partial coverage — replace with single slice.
            expected_dim_size = last_end - first_start
            if cat_shape[cat_dim] != expected_dim_size:
                return False
            for i, (cs, ss) in enumerate(zip(cat_shape, source_shape)):
                if i != cat_dim and cs != ss:  # dims must match except for cat_dim
                    return False
            _assert_qparams_match(source_node, cat_node)
            graph = cat_node.graph
            with graph.inserting_before(cat_node):
                new_slice = graph.call_function(
                    _SLICE_OP,
                    (source_node, cat_dim, first_start, last_end),
                )
            # The single slice produces exactly what the cat produced.
            new_slice.meta = cat_node.meta.copy()
            cat_node.replace_all_uses_with(new_slice)
            logger.debug(
                "Eliminated slice-then-cat (partial): %s -> slice(%s, %d, %d:%d)",
                cat_node.name,
                source_node.name,
                cat_dim,
                first_start,
                last_end,
            )
        return True


# diff --git a/backends/arm/test/passes/test_fuse_concat_pass.py b/backends/arm/test/passes/test_fuse_concat_pass.py
# new file mode 100644
# index 00000000000..8f838e25bcc
# --- /dev/null
# +++ b/backends/arm/test/passes/test_fuse_concat_pass.py
# @@ -0,0 +1,490 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.backends.arm._passes.fuse_concat_pass import FuseConcatPass
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
    EthosU55PipelineINT,
    EthosU85PipelineINT,
    PassPipeline,
    TosaPipelineFP,
)


aten_cat_op = "torch.ops.aten.cat.default"
cat_op = "executorch_exir_dialects_edge__ops_aten_cat_default"
slice_op = "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"


class SingleInputCat(torch.nn.Module):
    """Pattern 1: a cat over one tensor is a no-op and is removed."""

    data = (torch.randn(2, 3, 4),)
    ops_before_pass = {cat_op: 1}
    ops_after_pass: dict = {}
    ops_not_after_pass = [cat_op]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cat([x], dim=0)


class CatThenSlice(torch.nn.Module):
    """Pattern 2: slices of a cat that recover exactly the cat inputs."""

    data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4))
    ops_before_pass = {cat_op: 1, slice_op: 2}
    ops_after_pass: dict = {}
    ops_not_after_pass = [cat_op, slice_op]

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, ...]:
        joined = torch.cat([a, b], dim=1)
        # The two slices below are exactly ``a`` and ``b``.
        first = joined[:, :3, :]
        second = joined[:, 3:, :]
        return first + 1, second + 1


class SliceThenCat(torch.nn.Module):
    """Pattern 3: contiguous slices of one tensor glued back together."""

    data = (torch.randn(1, 8, 4),)
    ops_before_pass = {cat_op: 1, slice_op: 2}
    ops_after_pass: dict = {}
    ops_not_after_pass = [cat_op, slice_op]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        head = x[:, :3, :]
        tail = x[:, 3:, :]
        return torch.cat([head, tail], dim=1)


class CatNotEliminated(torch.nn.Module):
    """Negative case: a cat of unrelated tensors must be kept."""

    data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4))
    ops_before_pass = {cat_op: 1}
    ops_after_pass = {cat_op: 1}

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.cat([a, b], dim=1)


class SliceThenCatPartial(torch.nn.Module):
    """Negative case: slices with a gap between them must not be fused."""

    data = (torch.randn(1, 8, 4),)
    ops_before_pass = {cat_op: 1}
    ops_after_pass = {cat_op: 1}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        head = x[:, :3, :]
        tail = x[:, 4:, :]  # index 3 is skipped
        return torch.cat([head, tail], dim=1)


class CatThenSliceMismatch(torch.nn.Module):
    """Negative case: the slice does not line up with any cat input."""

    data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4))
    ops_before_pass = {cat_op: 1}
    ops_after_pass = {cat_op: 1}

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        joined = torch.cat([a, b], dim=1)
        return joined[:, 1:5, :]  # straddles the a/b boundary


class CatThenSliceWithStep(torch.nn.Module):
    """Negative case: a slice with step != 1 blocks the rewrite."""

    data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4))
    ops_before_pass = {cat_op: 1}
    ops_after_pass = {cat_op: 1}

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, ...]:
        joined = torch.cat([a, b], dim=1)
        strided = joined[:, :3:2, :]  # step=2, so the result is not ``a``
        rest = joined[:, 3::1, :]
        return strided + 1, rest + 1


class CatThenSliceMixedUsers(torch.nn.Module):
    """One exact-match slice plus one slice that crosses the input boundary.

    The exact-match slice (Pattern 2) is redirected to ``a``; the crossing
    slice has no replacement and so the cat survives. Covers the
    partial-fusion branch of ``_eliminate_cat_then_slice``.

    """

    data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4))
    ops_before_pass = {cat_op: 1, slice_op: 2}
    ops_after_pass = {cat_op: 1, slice_op: 1}

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, ...]:
        joined = torch.cat([a, b], dim=1)  # dim1 == 8
        return joined[:, :3, :], joined[:, 1:5, :]  # exact, then crossing


class SliceThenCatDifferentSources(torch.nn.Module):
    """Slices of two different tensors are concatenated.

    ``_find_common_slice_source`` sees two sources and gives up, so the cat
    and both slices remain.

    """

    data = (torch.randn(1, 5, 4), torch.randn(1, 5, 4))
    ops_before_pass = {cat_op: 1, slice_op: 2}
    ops_after_pass = {cat_op: 1, slice_op: 2}

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        from_x = x[:, :3, :]
        from_y = y[:, :3, :]
        return torch.cat([from_x, from_y], dim=1)


class CatThenSubSlice(torch.nn.Module):
    """Pattern 4: the slice covers part of a single cat input."""

    data = (torch.randn(1, 6, 4), torch.randn(1, 4, 4))
    ops_before_pass = {cat_op: 1, slice_op: 1}
    ops_after_pass = {slice_op: 1}
    ops_not_after_pass = [cat_op]

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        joined = torch.cat([a, b], dim=1)  # a: dim1=6, b: dim1=4
        # [1, 5) lies inside a's range [0, 6).
        return joined[:, 1:5, :] + 1


class CatThenSubSliceSecondInput(torch.nn.Module):
    """Pattern 4 on the second cat input (exercises offset adjustment)."""

    data = (torch.randn(1, 3, 4), torch.randn(1, 8, 4))
    ops_before_pass = {cat_op: 1, slice_op: 1}
    ops_after_pass = {slice_op: 1}
    ops_not_after_pass = [cat_op]

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        joined = torch.cat([a, b], dim=1)  # a: dim1=3, b: dim1=8
        # [5, 9) lies inside b's range [3, 11) and becomes [2, 6) on b.
        return joined[:, 5:9, :] + 1


class SliceThenCatPartialContiguous(torch.nn.Module):
    """Pattern 5: contiguous slices covering only part of the dimension."""

    data = (torch.randn(1, 10, 4),)
    ops_before_pass = {cat_op: 1, slice_op: 2}
    ops_after_pass = {slice_op: 1}
    ops_not_after_pass = [cat_op]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        lower = x[:, 2:5, :]
        upper = x[:, 5:8, :]
        return torch.cat([lower, upper], dim=1)  # same as x[:, 2:8, :]


class CatThenSubSliceNegativeIndex(torch.nn.Module):
    """Pattern 4 where the slice bounds are negative.

    The cat output has dim1=11, so ``[-6:-2]`` normalizes to ``[5:9]``. That
    range sits inside b's range [3, 11) and turns into ``b[:, 2:6, :]``.

    """

    data = (torch.randn(1, 3, 4), torch.randn(1, 8, 4))
    ops_before_pass = {cat_op: 1, slice_op: 1}
    ops_after_pass = {slice_op: 1}
    ops_not_after_pass = [cat_op]

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        joined = torch.cat([a, b], dim=1)
        return joined[:, -6:-2, :] + 1


class SliceThenCatPartialNegativeIndex(torch.nn.Module):
    """Pattern 5 where the slice bounds are negative.

    For x with dim1=10, ``[-8:-5]`` normalizes to ``[2:5]`` and ``[-5:-2]`` to
    ``[5:8]``; together they equal a single ``x[:, 2:8, :]``.

    """

    data = (torch.randn(1, 10, 4),)
    ops_before_pass = {cat_op: 1, slice_op: 2}
    ops_after_pass = {slice_op: 1}
    ops_not_after_pass = [cat_op]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        lower = x[:, -8:-5, :]
        upper = x[:, -5:-2, :]
        return torch.cat([lower, upper], dim=1)


class CatNegDimThenSlice(torch.nn.Module):
    """Pattern 2 with a negative cat dim.

    Checks ``_normalize_dim`` on the cat side: ``dim=-2`` on a rank-3 tensor
    has to resolve to ``dim=1`` so the offset map and the slice matching refer
    to the same axis.

    """

    data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4))
    ops_before_pass = {cat_op: 1, slice_op: 2}
    ops_after_pass: dict = {}
    ops_not_after_pass = [cat_op, slice_op]

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, ...]:
        joined = torch.cat([a, b], dim=-2)
        first = joined[:, :3, :]
        second = joined[:, 3:, :]
        return first + 1, second + 1


class CatThenSliceWithNonSliceUser(torch.nn.Module):
    """Negative case: the cat feeds a slice and also a non-slice op (add).

    Because of the non-slice consumer the cat has to stay, even though the
    other consumer is an exact-match slice.

    """

    data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4))
    ops_before_pass = {cat_op: 1, slice_op: 1}
    ops_after_pass = {cat_op: 1, slice_op: 1}

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, ...]:
        joined = torch.cat([a, b], dim=1)
        return joined[:, :3, :] + 1, joined + 0.5


positive_tests = {
    "single_input_cat": SingleInputCat(),
    "cat_then_slice": CatThenSlice(),
    "slice_then_cat": SliceThenCat(),
    "cat_then_sub_slice": CatThenSubSlice(),
    "cat_then_sub_slice_second_input": CatThenSubSliceSecondInput(),
    "slice_then_cat_partial_contiguous": SliceThenCatPartialContiguous(),
    "cat_then_sub_slice_negative_index": CatThenSubSliceNegativeIndex(),
    "slice_then_cat_partial_negative_index": SliceThenCatPartialNegativeIndex(),
    "cat_neg_dim_then_slice": CatNegDimThenSlice(),
}

negative_tests = {
    "cat_not_eliminated": CatNotEliminated(),
    "slice_then_cat_partial": SliceThenCatPartial(),
    "cat_then_slice_mismatch": CatThenSliceMismatch(),
    "cat_then_slice_with_step": CatThenSliceWithStep(),
    "cat_then_slice_mixed_users": CatThenSliceMixedUsers(),
    "slice_then_cat_different_sources": SliceThenCatDifferentSources(),
    "cat_then_slice_with_non_slice_user": CatThenSliceWithNonSliceUser(),
}


def _fuse_concat_pass_pipeline(model, **extra_checks):
    """Build a ``PassPipeline`` that runs only ``FuseConcatPass`` on ``model``.

    ``extra_checks`` is forwarded as additional keyword arguments (for
    example ``ops_not_after_pass``).

    """
    return PassPipeline(
        model,
        model.data,
        quantize=False,
        ops_before_pass=model.ops_before_pass,
        ops_after_pass=model.ops_after_pass,
        pass_list=[FuseConcatPass],
        **extra_checks,
    )


@common.parametrize("model", positive_tests)
def test_fuse_concat_eliminates(model):
    forbidden_ops = getattr(model, "ops_not_after_pass", [])
    _fuse_concat_pass_pipeline(model, ops_not_after_pass=forbidden_ops).run()


@common.parametrize("model", negative_tests)
def test_fuse_concat_preserves(model):
    _fuse_concat_pass_pipeline(model).run()


def test_find_common_slice_source_skips_non_node_input():
    """``_find_common_slice_source`` returns None for non-Node cat inputs.

    A normal ``torch.cat()`` always yields Node inputs in FX, so this guard
    cannot be reached through an exported model; the helper is called
    directly instead of hand-building a malformed FX graph.

    """
    from executorch.backends.arm._passes.fuse_concat_pass import (
        _find_common_slice_source,
    )

    for bogus_inputs in ([None, None], ["x", "y"]):
        assert _find_common_slice_source(bogus_inputs, cat_dim=0, dim_size=10) is None


# End-to-end Ethos-U pipeline test models. The PassPipeline modules above wire
# the pattern directly to the model's forward args, which is fine for unit-level
# pass validation but causes I/O signature drift through Vela whenever the pass
# eliminates the only use of an input placeholder. These wrappers feed each
# pattern through a single input and surrounding compute so that lowering sees
# realistic graphs.


class SingleInputCatEthos(torch.nn.Module):
    """Single-input cat surrounded by elementwise adds."""

    data = (torch.randn(2, 3, 4),)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cat([x + 1], dim=0) + 2


class CatThenSliceEthos(torch.nn.Module):
    """Cat of two computed halves, then slices that recover each half."""

    data = (torch.randn(1, 8, 4),)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
        left = x[:, :3, :] + 0.5
        right = x[:, 3:, :] + 0.5
        joined = torch.cat([left, right], dim=1)
        return joined[:, :3, :] + 1, joined[:, 3:, :] + 1


class SliceThenCatEthos(torch.nn.Module):
    """Contiguous slices of a computed tensor concatenated back together."""

    data = (torch.randn(1, 8, 4),)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shifted = x + 1
        head = shifted[:, :3, :]
        tail = shifted[:, 3:, :]
        return torch.cat([head, tail], dim=1) + 2


class CatThenSubSliceEthos(torch.nn.Module):
    """Slice of the cat output that lies inside the first cat input."""

    data = (torch.randn(1, 10, 4),)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        left = x[:, :6, :] + 0.5
        right = x[:, 6:, :] + 0.5
        joined = torch.cat([left, right], dim=1)
        return joined[:, 1:5, :] + 1


class CatThenSubSliceSecondInputEthos(torch.nn.Module):
    """Slice of the cat output that lies inside the second cat input."""

    data = (torch.randn(1, 11, 4),)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        left = x[:, :3, :] + 0.5
        right = x[:, 3:, :] + 0.5
        joined = torch.cat([left, right], dim=1)
        return joined[:, 5:9, :] + 1


class SliceThenCatPartialContiguousEthos(torch.nn.Module):
    """Contiguous slices covering only a sub-range, concatenated."""

    data = (torch.randn(1, 10, 4),)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        lower = x[:, 2:5, :]
        upper = x[:, 5:8, :]
        return torch.cat([lower, upper], dim=1) + 1


class CatThenSubSliceNegativeIndexEthos(torch.nn.Module):
    """Sub-slice of the cat output expressed with negative bounds."""

    data = (torch.randn(1, 11, 4),)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        left = x[:, :3, :] + 0.5
        right = x[:, 3:, :] + 0.5
        joined = torch.cat([left, right], dim=1)
        return joined[:, -6:-2, :] + 1


class CatNegDimThenSliceEthos(torch.nn.Module):
    """Cat along a negative dim, then slices that recover each input."""

    data = (torch.randn(1, 8, 4),)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
        left = x[:, :3, :] + 0.5
        right = x[:, 3:, :] + 0.5
        joined = torch.cat([left, right], dim=-2)
        return joined[:, :3, :] + 1, joined[:, 3:, :] + 1


ethos_tests = {
    "single_input_cat": SingleInputCatEthos(),
    "cat_then_slice": CatThenSliceEthos(),
    "slice_then_cat": SliceThenCatEthos(),
    "cat_then_sub_slice": CatThenSubSliceEthos(),
    "cat_then_sub_slice_second_input": CatThenSubSliceSecondInputEthos(),
    "slice_then_cat_partial_contiguous": SliceThenCatPartialContiguousEthos(),
    "cat_then_sub_slice_negative_index": CatThenSubSliceNegativeIndexEthos(),
    "cat_neg_dim_then_slice": CatNegDimThenSliceEthos(),
}


def _check_tosa_concat_count(model, expected_concats: int) -> None:
    """Lower ``model`` to TOSA-FP and check the number of CONCAT ops."""
    tosa_pipeline = TosaPipelineFP[type(model.data)](
        model,
        model.data,
        aten_op=aten_cat_op,
        run_on_tosa_ref_model=False,
    )
    tosa_pipeline.count_tosa_ops({"CONCAT": expected_concats})
    tosa_pipeline.run()


def _run_ethos_int_pipeline(pipeline_cls, model) -> None:
    """Run ``model`` through the given Ethos-U INT pipeline class."""
    pipeline_cls(
        model,
        model.data,
        aten_cat_op,
        cat_op,
    ).run()


@common.parametrize("model", ethos_tests)
def test_fuse_concat_tosa_FP(model):
    """Check that fusion removes every CONCAT from the lowered TOSA graph.

    Ethos pipelines emit Vela command streams, which the operator distribution
    helper cannot inspect, so the model is lowered to TOSA-FP and the
    flatbuffer is parsed instead.

    """
    _check_tosa_concat_count(model, 0)


@common.parametrize("model", negative_tests)
def test_fuse_concat_tosa_FP_preserves(model):
    """Control for ``test_fuse_concat_tosa_FP``.

    CONCAT must survive lowering when no fusion pattern applies. This (a)
    shows the count mechanism can report non-zero counts, and (b) catches
    regressions where fusion wrongly fires on these patterns.

    """
    _check_tosa_concat_count(model, 1)


@common.parametrize("model", ethos_tests)
@common.XfailIfNoCorstone300
def test_fuse_concat_u55_INT(model):
    _run_ethos_int_pipeline(EthosU55PipelineINT, model)


@common.parametrize("model", ethos_tests)
@common.XfailIfNoCorstone320
def test_fuse_concat_u85_INT(model):
    _run_ethos_int_pipeline(EthosU85PipelineINT, model)