From 5cc8286ded3e381dfbbcaa7b5e622e4c8a13801e Mon Sep 17 00:00:00 2001 From: Ryan Monroe Date: Fri, 8 May 2026 09:47:07 -0700 Subject: [PATCH] Add FuseConcatPass to eliminate redundant concat ops (#18827) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Adds an FX-level pass that eliminates concat ops which can be proven structurally redundant before the TOSA backend / Vela compiler ever sees them. In the Gen2 Executorch ARM / Ethos-U stack, `torch.cat` lowers to TOSA `CONCAT`, which Vela converts to N `MemoryCopy` ops — real DMA on the NPU. Catching the obvious cases up front keeps the TOSA flatbuffer fed to Vela smaller, keeps debug graphs honest, and provides defensive coverage on TOSA targets where Vela's own scheduler doesn't run (e.g., the VGF backend). Five rewrite patterns are handled (inspired by Espresso's `bolt/nn/espresso/transforms/remove_nops.py`): 1. **Single-input concat**: `cat([x], dim) ≡ x` — replace cat with x. 2. **Concat-then-slice (exact)**: `cat([a, b, ...], dim)` feeding a `slice_copy` that extracts exactly one original input — replace the slice with the corresponding cat input directly. 3. **Slice-then-concat (full)**: `cat([slice(x, d, s0, e0), slice(x, d, s1, e1), ...], dim)` reconstructing x exactly (contiguous slices covering the full source dimension) — replace cat with x. 4. **Concat-then-sub-slice**: a `slice_copy` whose range falls entirely within one cat input — replace with an adjusted slice on that input directly. 5. **Slice-then-concat (partial)**: contiguous slices of the same tensor concatenated back but covering only a sub-range of the source — replace with a single slice on the source. ## Empirical impact across the production EMG model fleet Measured by running every `frl/ctrl/torchstream/torchstream/pt2/tests/test_emg_lowering_*` quantize+lower test with FuseConcatPass instrumented to log per-call counters, then comparing against the same target with the pass commented out in `arm_pass_manager.py`. All 8 model targets pass under both configurations. | Model | Cats scanned | Eliminated | Pattern fired | | --- | --- | --- | --- | | cascade_classifier | 5 | **3 (60%)** | single-input | | mux_fusion | 8 | **3 (38%)** | single-input | | combined_control | 11 | **3 (27%)** | single-input | | cascade_detector | 14 | 0 | — | | cascade_hw_classifier | 12 | 0 | — | | handwriting | 106 | 0 | — | | wake | 11 | 0 | — | | auth | 6 | 0 | — | | **Total** | **173** | **9** | all single-input | ## Two findings worth highlighting **Patterns 2–5 (the slice-related rewrites) never matched on any production EMG model.** PyTorch's Aten lowering on this fleet doesn't produce the cat↔slice algebra these patterns target. They remain useful for non-EMG TOSA workloads — and for the VGF backend where Vela's own optimizer doesn't run — but on the current EMG production set they are unexercised. **Vela already folds single-input cats during compilation.** A before/after measurement on cascade_classifier (the model with the highest hit rate, 3/5 cats eliminated) shows Vela emits the same 9 `MemoryCopy` ops and consumes the same 481,339 NPU cycles either way. The eliminated cats reappear in the Vela operator table as `Reshape → MemoryCopy` instead of `Concat → MemoryCopy`. Total NPU runtime is unchanged. Pre-Vela artifacts do shrink (TOSA flatbuffer −16 KB / −0.68%, peak staging −1.5 KB / −0.45%), but post-Vela on-device performance is identical. ## Net effect This pass is value-additive even where it doesn't move NPU cycles: - Cleaner TOSA fed into Vela (~16 KB smaller per cascade_classifier instance). - Slightly tighter peak staging during Vela scheduling (~1.5 KB). - Defensive coverage for TOSA-only targets without a Vela-grade scheduler (notably the VGF / Vulkan path). - More truthful FX / EXIR debug graphs — concats that were genuinely no-ops no longer show up in `model-explorer`, `delegation_metadata.json`, or the lowered graph dumps. It does **not** produce measurable NPU cycle savings on the current EMG production fleet. The patterns that would have produced real Vela savings (cat↔slice algebra) don't appear in these models. Authored with Claude. Differential Revision: D97667069 --- backends/arm/_passes/__init__.py | 1 + backends/arm/_passes/arm_pass_manager.py | 2 + backends/arm/_passes/fuse_concat_pass.py | 379 ++++++++++++++ .../arm/test/passes/test_fuse_concat_pass.py | 490 ++++++++++++++++++ 4 files changed, 872 insertions(+) create mode 100644 backends/arm/_passes/fuse_concat_pass.py create mode 100644 backends/arm/test/passes/test_fuse_concat_pass.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 174a1960aab..1fc1efae797 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -103,6 +103,7 @@ QuantizeClampArgumentsPass, ) from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass # noqa +from .fuse_concat_pass import FuseConcatPass # noqa from .fuse_consecutive_concat_shapes import FuseConsecutiveConcatShapesPass # noqa from .fuse_consecutive_rescales_pass import FuseConsecutiveRescalesPass # noqa from .fuse_constant_ops_pass import ( # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index ebe6c4591e6..e3c3c35b033 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -100,6 +100,7 @@ EnsureUniqueOutputNodesPass, FoldAndAnnotateQParamsPass, FuseBatchNorm2dPass, + FuseConcatPass, FuseConsecutiveConcatShapesPass, FuseConsecutiveRescalesPass, FuseConstantArgsPass, @@ -532,6 +533,7 @@ def _tosa_pipeline( # Aten -> TOSA transformation passes self.add_passes( [ + FuseConcatPass(), RewriteUpsamplePass(), RewriteMaxPool2dPass(), RewriteConvPass(exported_program), diff --git a/backends/arm/_passes/fuse_concat_pass.py b/backends/arm/_passes/fuse_concat_pass.py new file mode 100644 index 00000000000..bce32b58d8e --- /dev/null +++ b/backends/arm/_passes/fuse_concat_pass.py @@ -0,0 +1,379 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Set, Type + +import torch.fx +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + +logger = logging.getLogger(__name__) + + +def _int_arg(node: torch.fx.Node, index: int, default: int) -> int: + """Get an integer argument from a node, with a default if missing.""" + val = node.args[index] if len(node.args) > index else default + assert isinstance(val, int) + return val + + +def _normalize_dim(dim: int, rank: int) -> int: + """Normalize a (possibly negative) dim index to ``[0, rank)``.""" + return (dim + rank) % rank + + +def _slice_params(node: torch.fx.Node, dim_size: int) -> tuple[int, int, int, int]: + """Extract (dim, start, end, step) from a slice_copy node. + + ``dim``, ``start``, and ``end`` are normalized to non-negative indices in + ``[0, dim_size]`` (matching PyTorch slice semantics, where negative bounds + count from the end of the dimension). ``dim_size`` is the size of the + source tensor along the slice dimension. + + """ + rank = len(get_first_fake_tensor(node).shape) + dim = _normalize_dim(_int_arg(node, 1, 0), rank) + start = _int_arg(node, 2, 0) + end = _int_arg(node, 3, dim_size) + if start < 0: + start += dim_size + if end < 0: + end += dim_size + start = max(0, min(start, dim_size)) + end = max(0, min(end, dim_size)) + step = _int_arg(node, 4, 1) + return dim, start, end, step + + +_SLICE_OP = exir_ops.edge.aten.slice_copy.Tensor + + +def _assert_qparams_match(source: torch.fx.Node, cat_node: torch.fx.Node) -> None: + """Verify the data-providing node's output_qparams match cat_node's. + + FuseConcatPass runs after FoldAndAnnotateQParamsPass, so cat_node may carry + output_qparams. TOSA cat preserves scale across inputs and output, so any + node we route consumers to (directly, or via a new slice that reads from + it) must produce data with matching qparams. + + Skipped when either side has no output_qparams entries: FP graphs (key + missing), Q ops (qparams live in args), and tosa_rescale ops (qparams in + args; folder leaves an empty dict). The assertion is best-effort, intended + to catch upstream bugs that produce two foldable ops with diverging + output_qparams. + + """ + cat_qp = cat_node.meta.get("output_qparams") + src_qp = source.meta.get("output_qparams") + if not cat_qp or not src_qp: + return + assert cat_qp == src_qp, ( + f"FuseConcatPass: output_qparams mismatch — {source.name} {src_qp} != " + f"cat {cat_node.name} {cat_qp}" + ) + + +def _is_valid_slice(node: torch.fx.Node, cat_dim: int, dim_size: int) -> bool: + """Check that node is a slice_copy on cat_dim with step=1.""" + if node.target != _SLICE_OP: + return False + s_dim, _, _, s_step = _slice_params(node, dim_size) + return s_dim == cat_dim and s_step == 1 + + +def _find_slice_replacement( + slice_op: torch.fx.Node, + cat_node: torch.fx.Node, + cat_dim: int, + s_start: int, + s_end: int, + offsets: list[tuple[int, int, torch.fx.Node]], +) -> torch.fx.Node | None: + """Find a replacement for a slice that consumes a cat output. + + ``offsets`` maps each concat input to its range in the concatenated + output: [(start, end, input_node), ...] along ``cat_dim``. + + Returns the replacement node (exact input match or adjusted sub-slice), + or None if the slice crosses input boundaries. + + """ + for o_start, o_end, inp in offsets: + if s_start == o_start and s_end == o_end: + _assert_qparams_match(inp, cat_node) + return inp + if s_start >= o_start and s_end <= o_end: + _assert_qparams_match(inp, cat_node) + graph = cat_node.graph + with graph.inserting_before(slice_op): + new_slice = graph.call_function( + _SLICE_OP, + (inp, cat_dim, s_start - o_start, s_end - o_start), + ) + new_slice.meta = slice_op.meta.copy() + return new_slice + return None + + +def _find_common_slice_source( + cat_inputs: list | tuple, + cat_dim: int, + dim_size: int, +) -> torch.fx.Node | None: + """Check all inputs are valid slices of the same source. + + Returns the source. + + """ + source_node = None + for inp in cat_inputs: + if not isinstance(inp, torch.fx.Node): + return None + if not _is_valid_slice(inp, cat_dim, dim_size): + return None + slice_source = inp.args[0] + if source_node is None: + source_node = slice_source + elif slice_source is not source_node: + return None + assert isinstance(source_node, torch.fx.Node) + return source_node + + +def _check_contiguous_slices( + cat_inputs: list | tuple, + source_dim_size: int, +) -> tuple[int, int] | None: + """Check slices are contiguous. + + Returns (first_start, last_end) or None. + + """ + _, first_start, _, _ = _slice_params(cat_inputs[0], source_dim_size) + expected_start = first_start + for inp in cat_inputs: + _, s_start, s_end, _ = _slice_params(inp, source_dim_size) + if s_start != expected_start: + return None + expected_start = s_end + + # expected_start is now the end of the last slice + return first_start, expected_start + + +class FuseConcatPass(ArmPass): + """Eliminate redundant concat (cat) operations via graph pattern matching. + + This pass recognizes and removes concat operations that can be proven to + produce no useful data movement. Eliminating these at the FX/TOSA level + prevents Vela from generating MemoryCopy operations on the Ethos-U NPU. + + Five patterns are handled: + + 1. Single-input concat: cat([x], dim) is a no-op; replace with x. + 2. Concat-then-slice (exact): if a consumer of cat([a, b, ...], dim) is + a slice_copy that extracts exactly one original input, replace it + with the corresponding concat input directly. + 3. Slice-then-concat (full): if cat([slice(x, d, s0, e0), + slice(x, d, s1, e1), ...], dim) reconstructs x exactly (contiguous + slices covering the full source dimension), replace with x. + 4. Concat-then-sub-slice: if a consumer of cat([a, b, ...], dim) is a + slice_copy whose range falls entirely within one original input, + replace it with an adjusted slice on that input directly. + 5. Slice-then-concat (partial): if contiguous slices of the same tensor + are concatenated but cover only a sub-range of the source dimension, + replace with a single slice on the source. + + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + cat_ops = { + exir_ops.edge.aten.cat.default, + } + slice_op = _SLICE_OP + + def call(self, graph_module: torch.fx.GraphModule): + modified = False + graph = graph_module.graph + + for node in list(graph.nodes): + if node.op != "call_function" or node.target not in self.cat_ops: + continue + if node.graph is None: + continue + + if self._eliminate_single_input_cat(node): + modified = True + continue + + if self._eliminate_cat_then_slice(node): + modified = True + continue + + if self._eliminate_slice_then_cat(node): + modified = True + continue + + if modified: + graph.eliminate_dead_code() + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, modified) + + # ------------------------------------------------------------------ + # Pattern 1: single-input cat + # ------------------------------------------------------------------ + @staticmethod + def _eliminate_single_input_cat(cat_node: torch.fx.Node) -> bool: + inputs = cat_node.args[0] + if not isinstance(inputs, (list, tuple)) or len(inputs) != 1: + return False + sole_input = inputs[0] + assert isinstance(sole_input, torch.fx.Node) + _assert_qparams_match(sole_input, cat_node) + cat_node.replace_all_uses_with(sole_input) + logger.debug("Eliminated single-input cat: %s", cat_node.name) + return True + + # ------------------------------------------------------------------ + # Patterns 2 & 4: cat -> slice (exact input or sub-range of input) + # ------------------------------------------------------------------ + @staticmethod + def _eliminate_cat_then_slice( + cat_node: torch.fx.Node, + ) -> bool: + cat_inputs = cat_node.args[0] + if not isinstance(cat_inputs, (list, tuple)) or len(cat_inputs) < 2: + return False + + # if the dim does not exist as an arg, it defaults to '0' + output_rank = len(get_first_fake_tensor(cat_node).shape) + cat_dim = _normalize_dim(_int_arg(cat_node, 1, 0), output_rank) + + users = list(cat_node.users.keys()) + if not users: + return False + + # Build the offset map for each concat input along cat_dim. + offsets = [] + offset = 0 + for inp in cat_inputs: + assert isinstance(inp, torch.fx.Node) + inp_shape = get_first_fake_tensor(inp).shape + size = inp_shape[cat_dim] + offsets.append((offset, offset + size, inp)) + offset += size + + # Every user must be a slice_copy on the same dim with step=1. + # Collect validated (node, start, end) for replacement below. + validated_slices: list[tuple[torch.fx.Node, int, int]] = [] + for slice_op in users: + if not _is_valid_slice(slice_op, cat_dim, offset): + return False + if slice_op.args[0] is not cat_node: + return False + _, s_start, s_end, _ = _slice_params(slice_op, offset) + validated_slices.append((slice_op, s_start, s_end)) + + # For each user, try exact match (Pattern 2) then sub-range (Pattern 4). + # Users that cross input boundaries are skipped. + replacements: list[tuple[torch.fx.Node, torch.fx.Node]] = [] + + for slice_op, s_start, s_end in validated_slices: + replacement = _find_slice_replacement( + slice_op, cat_node, cat_dim, s_start, s_end, offsets + ) + if replacement is not None: + replacements.append((slice_op, replacement)) + + if not replacements: + return False + + for old_node, new_node in replacements: + old_node.replace_all_uses_with(new_node) + + logger.debug( + "Eliminated cat-then-slice pattern: %s (%d slices redirected)", + cat_node.name, + len(replacements), + ) + return True + + # ------------------------------------------------------------------ + # Patterns 3 & 5: slice -> cat (contiguous slices, full or partial) + # ------------------------------------------------------------------ + @staticmethod + def _eliminate_slice_then_cat( + cat_node: torch.fx.Node, + ) -> bool: + cat_inputs = cat_node.args[0] + if not isinstance(cat_inputs, (list, tuple)) or len(cat_inputs) < 2: + return False + + output_rank = len(get_first_fake_tensor(cat_node).shape) + cat_dim = _normalize_dim(_int_arg(cat_node, 1, 0), output_rank) + + # All inputs must be slice_copy on the same source tensor and dim, + # with step=1. + source_node = _find_common_slice_source(cat_inputs, cat_dim, output_rank) + if source_node is None: + return False + + source_shape = get_first_fake_tensor(source_node).shape + source_dim_size = source_shape[cat_dim] + + # Verify slices are contiguous (but not necessarily starting at 0). + bounds = _check_contiguous_slices(cat_inputs, source_dim_size) + if bounds is None: + return False + first_start, last_end = bounds + + # Verify output shape matches expectations. + cat_shape = get_first_fake_tensor(cat_node).shape + + if first_start == 0 and last_end == source_dim_size: + # Pattern 3: full coverage — replace with source tensor. + if list(cat_shape) != list(source_shape): + return False + _assert_qparams_match(source_node, cat_node) + cat_node.replace_all_uses_with(source_node) + logger.debug( + "Eliminated slice-then-cat (full): %s -> %s", + cat_node.name, + source_node.name, + ) + else: + # Pattern 5: partial coverage — replace with single slice. + expected_dim_size = last_end - first_start + if cat_shape[cat_dim] != expected_dim_size: + return False + for i, (cs, ss) in enumerate(zip(cat_shape, source_shape)): + if i != cat_dim and cs != ss: # dims must match except for cat_dim + return False + _assert_qparams_match(source_node, cat_node) + graph = cat_node.graph + with graph.inserting_before(cat_node): + new_slice = graph.call_function( + _SLICE_OP, + (source_node, cat_dim, first_start, last_end), + ) + new_slice.meta = cat_node.meta.copy() + cat_node.replace_all_uses_with(new_slice) + logger.debug( + "Eliminated slice-then-cat (partial): %s -> slice(%s, %d, %d:%d)", + cat_node.name, + source_node.name, + cat_dim, + first_start, + last_end, + ) + return True diff --git a/backends/arm/test/passes/test_fuse_concat_pass.py b/backends/arm/test/passes/test_fuse_concat_pass.py new file mode 100644 index 00000000000..8f838e25bcc --- /dev/null +++ b/backends/arm/test/passes/test_fuse_concat_pass.py @@ -0,0 +1,490 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm._passes.fuse_concat_pass import FuseConcatPass +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, + PassPipeline, + TosaPipelineFP, +) + + +aten_cat_op = "torch.ops.aten.cat.default" +cat_op = "executorch_exir_dialects_edge__ops_aten_cat_default" +slice_op = "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor" + + +class SingleInputCat(torch.nn.Module): + """Pattern 1: cat with a single input is a no-op.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.cat([x], dim=0) + + data = (torch.randn(2, 3, 4),) + ops_before_pass = {cat_op: 1} + ops_after_pass: dict = {} + ops_not_after_pass = [cat_op] + + +class CatThenSlice(torch.nn.Module): + """Pattern 2: cat followed by slices that extract exactly the inputs.""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, ...]: + combined = torch.cat([a, b], dim=1) + # Extract exactly a and b back out + part_a = combined[:, :3, :] + part_b = combined[:, 3:, :] + return part_a + 1, part_b + 1 + + data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4)) + ops_before_pass = {cat_op: 1, slice_op: 2} + ops_after_pass: dict = {} + ops_not_after_pass = [cat_op, slice_op] + + +class SliceThenCat(torch.nn.Module): + """Pattern 3: contiguous slices of the same tensor concatenated back.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + a = x[:, :3, :] + b = x[:, 3:, :] + return torch.cat([a, b], dim=1) + + data = (torch.randn(1, 8, 4),) + ops_before_pass = {cat_op: 1, slice_op: 2} + ops_after_pass: dict = {} + ops_not_after_pass = [cat_op, slice_op] + + +class CatNotEliminated(torch.nn.Module): + """Negative test: cat of different tensors should NOT be eliminated.""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return torch.cat([a, b], dim=1) + + data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4)) + ops_before_pass = {cat_op: 1} + ops_after_pass = {cat_op: 1} + + +class SliceThenCatPartial(torch.nn.Module): + """Negative test: non-contiguous slices should NOT be eliminated.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + a = x[:, :3, :] + b = x[:, 4:, :] # Gap at index 3 + return torch.cat([a, b], dim=1) + + data = (torch.randn(1, 8, 4),) + ops_before_pass = {cat_op: 1} + ops_after_pass = {cat_op: 1} + + +class CatThenSliceMismatch(torch.nn.Module): + """Negative test: slices that don't match original inputs.""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + combined = torch.cat([a, b], dim=1) + return combined[:, 1:5, :] # Crosses the boundary + + data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4)) + ops_before_pass = {cat_op: 1} + ops_after_pass = {cat_op: 1} + + +class CatThenSliceWithStep(torch.nn.Module): + """Negative test: slices with step != 1 should NOT be eliminated.""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, ...]: + combined = torch.cat([a, b], dim=1) + part_a = combined[:, :3:2, :] # step=2, output shape differs from a + part_b = combined[:, 3::1, :] + return part_a + 1, part_b + 1 + + data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4)) + ops_before_pass = {cat_op: 1} + ops_after_pass = {cat_op: 1} + + +class CatThenSliceMixedUsers(torch.nn.Module): + """Mixed cat users: one exact-match slice and one cross-boundary slice. + + The exact-match slice (Pattern 2) is rewritten to point at ``a`` directly; + the cross-boundary slice has no replacement and keeps the cat alive. Tests + the partial-fusion branch in ``_eliminate_cat_then_slice``. + """ + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, ...]: + combined = torch.cat([a, b], dim=1) # combined dim1=8 + return combined[:, :3, :], combined[:, 1:5, :] # exact + cross-boundary + + data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4)) + ops_before_pass = {cat_op: 1, slice_op: 2} + ops_after_pass = {cat_op: 1, slice_op: 1} + + +class SliceThenCatDifferentSources(torch.nn.Module): + """Slice-then-cat where slices come from different source tensors. + + ``_find_common_slice_source`` detects the source mismatch and bails; + the cat (and both slices) survive. + + """ + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + a = x[:, :3, :] + b = y[:, :3, :] + return torch.cat([a, b], dim=1) + + data = (torch.randn(1, 5, 4), torch.randn(1, 5, 4)) + ops_before_pass = {cat_op: 1, slice_op: 2} + ops_after_pass = {cat_op: 1, slice_op: 2} + + +class CatThenSubSlice(torch.nn.Module): + """Pattern 4: slice extracts a sub-range within one concat input.""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + combined = torch.cat([a, b], dim=1) # a dim1=6, b dim1=4 + # Range [1,5) falls entirely within a's range [0,6) + return combined[:, 1:5, :] + 1 + + data = (torch.randn(1, 6, 4), torch.randn(1, 4, 4)) + ops_before_pass = {cat_op: 1, slice_op: 1} + ops_after_pass = {slice_op: 1} + ops_not_after_pass = [cat_op] + + +class CatThenSubSliceSecondInput(torch.nn.Module): + """Pattern 4: sub-slice within second concat input (tests offset adjust).""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + combined = torch.cat([a, b], dim=1) # a dim1=3, b dim1=8 + # Range [5,9) falls within b's range [3,11), adjusted to [2,6) on b + return combined[:, 5:9, :] + 1 + + data = (torch.randn(1, 3, 4), torch.randn(1, 8, 4)) + ops_before_pass = {cat_op: 1, slice_op: 1} + ops_after_pass = {slice_op: 1} + ops_not_after_pass = [cat_op] + + +class SliceThenCatPartialContiguous(torch.nn.Module): + """Pattern 5: contiguous slices covering a sub-range of the dimension.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + a = x[:, 2:5, :] + b = x[:, 5:8, :] + return torch.cat([a, b], dim=1) # Equivalent to x[:, 2:8, :] + + data = (torch.randn(1, 10, 4),) + ops_before_pass = {cat_op: 1, slice_op: 2} + ops_after_pass = {slice_op: 1} + ops_not_after_pass = [cat_op] + + +class CatThenSubSliceNegativeIndex(torch.nn.Module): + """Pattern 4 with negative slice bounds. + + ``combined[:, -6:-2, :]`` with combined dim1=11 normalizes to + ``combined[:, 5:9, :]``, which falls within b's range [3, 11) and + becomes ``b[:, 2:6, :]`` after fusion. + + """ + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + combined = torch.cat([a, b], dim=1) + return combined[:, -6:-2, :] + 1 + + data = (torch.randn(1, 3, 4), torch.randn(1, 8, 4)) + ops_before_pass = {cat_op: 1, slice_op: 1} + ops_after_pass = {slice_op: 1} + ops_not_after_pass = [cat_op] + + +class SliceThenCatPartialNegativeIndex(torch.nn.Module): + """Pattern 5 with negative slice bounds. + + With x dim1=10, ``x[:, -8:-5, :]`` normalizes to ``x[:, 2:5, :]`` and + ``x[:, -5:-2, :]`` to ``x[:, 5:8, :]``; together equivalent to a single + ``x[:, 2:8, :]``. + + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + a = x[:, -8:-5, :] + b = x[:, -5:-2, :] + return torch.cat([a, b], dim=1) + + data = (torch.randn(1, 10, 4),) + ops_before_pass = {cat_op: 1, slice_op: 2} + ops_after_pass = {slice_op: 1} + ops_not_after_pass = [cat_op] + + +class CatNegDimThenSlice(torch.nn.Module): + """Pattern 2 with a negative cat dim. + + Exercises ``_normalize_dim`` on the cat side: ``dim=-2`` on a rank-3 tensor + must resolve to ``dim=1`` so that the offset map and slice matching line up + against the same axis. + + """ + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, ...]: + combined = torch.cat([a, b], dim=-2) + part_a = combined[:, :3, :] + part_b = combined[:, 3:, :] + return part_a + 1, part_b + 1 + + data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4)) + ops_before_pass = {cat_op: 1, slice_op: 2} + ops_after_pass: dict = {} + ops_not_after_pass = [cat_op, slice_op] + + +class CatThenSliceWithNonSliceUser(torch.nn.Module): + """Negative test: a cat with both a slice consumer and a non-slice + consumer (an ``add``). The non-slice user means the cat must be kept + alive even though one user is an exact-match slice. + """ + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, ...]: + combined = torch.cat([a, b], dim=1) + return combined[:, :3, :] + 1, combined + 0.5 + + data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4)) + ops_before_pass = {cat_op: 1, slice_op: 1} + ops_after_pass = {cat_op: 1, slice_op: 1} + + +positive_tests = { + "single_input_cat": SingleInputCat(), + "cat_then_slice": CatThenSlice(), + "slice_then_cat": SliceThenCat(), + "cat_then_sub_slice": CatThenSubSlice(), + "cat_then_sub_slice_second_input": CatThenSubSliceSecondInput(), + "slice_then_cat_partial_contiguous": SliceThenCatPartialContiguous(), + "cat_then_sub_slice_negative_index": CatThenSubSliceNegativeIndex(), + "slice_then_cat_partial_negative_index": SliceThenCatPartialNegativeIndex(), + "cat_neg_dim_then_slice": CatNegDimThenSlice(), +} + +negative_tests = { + "cat_not_eliminated": CatNotEliminated(), + "slice_then_cat_partial": SliceThenCatPartial(), + "cat_then_slice_mismatch": CatThenSliceMismatch(), + "cat_then_slice_with_step": CatThenSliceWithStep(), + "cat_then_slice_mixed_users": CatThenSliceMixedUsers(), + "slice_then_cat_different_sources": SliceThenCatDifferentSources(), + "cat_then_slice_with_non_slice_user": CatThenSliceWithNonSliceUser(), +} + + +@common.parametrize("model", positive_tests) +def test_fuse_concat_eliminates(model): + pipeline = PassPipeline( + model, + model.data, + quantize=False, + ops_before_pass=model.ops_before_pass, + ops_after_pass=model.ops_after_pass, + ops_not_after_pass=getattr(model, "ops_not_after_pass", []), + pass_list=[FuseConcatPass], + ) + pipeline.run() + + +@common.parametrize("model", negative_tests) +def test_fuse_concat_preserves(model): + pipeline = PassPipeline( + model, + model.data, + quantize=False, + ops_before_pass=model.ops_before_pass, + ops_after_pass=model.ops_after_pass, + pass_list=[FuseConcatPass], + ) + pipeline.run() + + +def test_find_common_slice_source_skips_non_node_input(): + """Defensive guard in ``_find_common_slice_source``: when ``cat_inputs`` + contains a non-Node entry, return None so slice-then-cat fusion bails. + + Such graphs aren't producible from normal ``torch.cat()`` (which always + yields Node inputs in FX), so we exercise the helper directly rather + than constructing a malformed FX graph. + + """ + from executorch.backends.arm._passes.fuse_concat_pass import ( + _find_common_slice_source, + ) + + assert _find_common_slice_source([None, None], cat_dim=0, dim_size=10) is None + assert _find_common_slice_source(["x", "y"], cat_dim=0, dim_size=10) is None + + +# End-to-end Ethos-U pipeline test models. The PassPipeline modules above wire +# the pattern directly to the model's forward args, which is fine for unit-level +# pass validation but causes I/O signature drift through Vela whenever the pass +# eliminates the only use of an input placeholder. These wrappers feed each +# pattern through a single input and surrounding compute so that lowering sees +# realistic graphs. + + +class SingleInputCatEthos(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.cat([x + 1], dim=0) + 2 + + data = (torch.randn(2, 3, 4),) + + +class CatThenSliceEthos(torch.nn.Module): + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]: + a = x[:, :3, :] + 0.5 + b = x[:, 3:, :] + 0.5 + combined = torch.cat([a, b], dim=1) + return combined[:, :3, :] + 1, combined[:, 3:, :] + 1 + + data = (torch.randn(1, 8, 4),) + + +class SliceThenCatEthos(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = x + 1 + a = y[:, :3, :] + b = y[:, 3:, :] + return torch.cat([a, b], dim=1) + 2 + + data = (torch.randn(1, 8, 4),) + + +class CatThenSubSliceEthos(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + a = x[:, :6, :] + 0.5 + b = x[:, 6:, :] + 0.5 + combined = torch.cat([a, b], dim=1) + return combined[:, 1:5, :] + 1 + + data = (torch.randn(1, 10, 4),) + + +class CatThenSubSliceSecondInputEthos(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + a = x[:, :3, :] + 0.5 + b = x[:, 3:, :] + 0.5 + combined = torch.cat([a, b], dim=1) + return combined[:, 5:9, :] + 1 + + data = (torch.randn(1, 11, 4),) + + +class SliceThenCatPartialContiguousEthos(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + a = x[:, 2:5, :] + b = x[:, 5:8, :] + return torch.cat([a, b], dim=1) + 1 + + data = (torch.randn(1, 10, 4),) + + +class CatThenSubSliceNegativeIndexEthos(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + a = x[:, :3, :] + 0.5 + b = x[:, 3:, :] + 0.5 + combined = torch.cat([a, b], dim=1) + return combined[:, -6:-2, :] + 1 + + data = (torch.randn(1, 11, 4),) + + +class CatNegDimThenSliceEthos(torch.nn.Module): + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]: + a = x[:, :3, :] + 0.5 + b = x[:, 3:, :] + 0.5 + combined = torch.cat([a, b], dim=-2) + return combined[:, :3, :] + 1, combined[:, 3:, :] + 1 + + data = (torch.randn(1, 8, 4),) + + +ethos_tests = { + "single_input_cat": SingleInputCatEthos(), + "cat_then_slice": CatThenSliceEthos(), + "slice_then_cat": SliceThenCatEthos(), + "cat_then_sub_slice": CatThenSubSliceEthos(), + "cat_then_sub_slice_second_input": CatThenSubSliceSecondInputEthos(), + "slice_then_cat_partial_contiguous": SliceThenCatPartialContiguousEthos(), + "cat_then_sub_slice_negative_index": CatThenSubSliceNegativeIndexEthos(), + "cat_neg_dim_then_slice": CatNegDimThenSliceEthos(), +} + + +@common.parametrize("model", ethos_tests) +def test_fuse_concat_tosa_FP(model): + """Verify the fusion pass eliminates all CONCAT ops in the lowered TOSA graph. + + Ethos pipelines emit Vela command streams, which the operator distribution + helper cannot inspect, so we lower to TOSA-FP and parse the flatbuffer + instead. + + """ + pipeline = TosaPipelineFP[type(model.data)]( + model, + model.data, + aten_op=aten_cat_op, + run_on_tosa_ref_model=False, + ) + pipeline.count_tosa_ops({"CONCAT": 0}) + pipeline.run() + + +@common.parametrize("model", negative_tests) +def test_fuse_concat_tosa_FP_preserves(model): + """Control for ``test_fuse_concat_tosa_FP``. + + Confirms CONCAT survives lowering when the fusion patterns don't apply, + which (a) proves the count mechanism reports non-zero counts, and (b) + catches regressions where fusion incorrectly fires on these patterns. + + """ + pipeline = TosaPipelineFP[type(model.data)]( + model, + model.data, + aten_op=aten_cat_op, + run_on_tosa_ref_model=False, + ) + pipeline.count_tosa_ops({"CONCAT": 1}) + pipeline.run() + + +@common.parametrize("model", ethos_tests) +@common.XfailIfNoCorstone300 +def test_fuse_concat_u55_INT(model): + pipeline = EthosU55PipelineINT( + model, + model.data, + aten_cat_op, + cat_op, + ) + pipeline.run() + + +@common.parametrize("model", ethos_tests) +@common.XfailIfNoCorstone320 +def test_fuse_concat_u85_INT(model): + pipeline = EthosU85PipelineINT( + model, + model.data, + aten_cat_op, + cat_op, + ) + pipeline.run()