
perf(ir): Hoist rhs.ExtractBatchPage in batch_matmul unroll when rhs has no batch#1276

Open
lyfne123 wants to merge 2 commits into hw-native-sys:main from lyfne123:perf/batch-matmul-collapse-rhs-no-batch

Conversation

@lyfne123
Collaborator

@lyfne123 lyfne123 commented May 6, 2026

Summary

When tile.batch_matmul's rhs has no effective batch dimensions (prod(rhs_batch_dims) == 1 — covers both rhs-is-2D and ND-with-all-1 leading dims), ExtractBatchPage produces identical IR each iteration of the per-batch unroll. This PR hoists that page extraction once outside the loop and reuses the resulting Var across all per-batch matmuls, saving B-1 redundant rhs tile.load calls.

The IR shape stays exactly what codegen has always supported (per-batch unroll with 2D tile.matmul), so no codegen changes are required.
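The saving can be sketched in plain NumPy (all names here are illustrative stand-ins; the real pass rewrites IR, not arrays — the point is only that the rhs page is loop-invariant when rhs has no batch):

```python
# Hypothetical model of the hoist: when rhs has no effective batch dims, the
# page extracted by ExtractBatchPage is identical every iteration, so the
# corresponding tile.load is emitted once instead of once per batch.
import numpy as np

load_count = {"rhs": 0}

def load_rhs_page(rhs):
    """Stand-in for the rhs tile.load emitted by ExtractBatchPage."""
    load_count["rhs"] += 1
    return rhs  # rhs is 2D here: its only "page" is itself

def batch_matmul_unrolled(lhs, rhs):
    """Per-batch unroll with the rhs load hoisted out of the loop."""
    b = lhs.shape[0]
    rhs_page = load_rhs_page(rhs)       # hoisted: emitted once, not B times
    return np.stack([lhs[i] @ rhs_page for i in range(b)])

lhs = np.random.rand(4, 16, 8)          # [B, M, K] with B = 4
rhs = np.random.rand(8, 32)             # [K, N] — no batch dims
out = batch_matmul_unrolled(lhs, rhs)

assert load_count["rhs"] == 1           # B - 1 = 3 redundant loads saved
assert np.allclose(out, lhs @ rhs)      # matches broadcasting matmul semantics
```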

Why this scope

Earlier iterations attempted to collapse the entire unroll into a single 2D tile.matmul over an axis-merged lhs ([prod(batch)*M, K]). The IR was semantically correct but the TLOAD<Tile, GlobalTensor> C++ template requires the source tensor's last 2 dims to match the tile shape, which the auto-flattened "ND source view → 2D dest tile" form violates. That optimization needs codegen-side support and is tracked in #1278.

The rhs-hoist optimization in this PR uses the same per-batch IR shape that already passes CI, so it lands cleanly today while we wait for the codegen work to enable the bigger optimization.
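The shape constraint that blocked the axis-merged form can be stated as a tiny check (a sketch of the rule described above; `tload_shape_ok` is a hypothetical helper, not the actual C++ template):

```python
# TLOAD<Tile, GlobalTensor> accepts the copy only when the source tensor's
# last two dims equal the tile shape.
def tload_shape_ok(src_shape, tile_shape):
    return tuple(src_shape[-2:]) == tuple(tile_shape)

B, M, K = 4, 16, 8

# Per-batch form (what this PR keeps): source page [M, K] -> tile [M, K]. OK.
assert tload_shape_ok((M, K), (M, K))

# Axis-merged form (deferred to #1278): ND source [B, M, K] loaded into a
# [B*M, K] tile — the source's last two dims are (M, K), not (B*M, K).
assert not tload_shape_ok((B, M, K), (B * M, K))
```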

Cases

| lhs × rhs | batch_count | batch_count_rhs | Behavior |
| --- | --- | --- | --- |
| [B,M,K] × [K,N] (B>1) | B | 1 | Hoist rhs.load — saves B-1 loads |
| [B,M,K] × [1,K,N] (B>1) | B | 1 | Hoist rhs.load — saves B-1 loads |
| [B1,B2,M,K] × [K,N] (B1*B2>1) | B1*B2 | 1 | Hoist rhs.load — saves B1*B2-1 loads |
| [B,M,K] × [B,K,N] | B | B | No hoist (rhs varies per batch) |
| Anything with batch_count == 1 | 1 | any | Existing batch_count==1 fast path (unchanged) |
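The case split above amounts to a three-way decision on the two batch counts. A hypothetical model (the function name and return labels are illustrative, not the pass's actual API):

```python
from math import prod

def plan_batch_matmul(lhs_shape, rhs_shape):
    """Model of the lowering decision in the cases table."""
    batch_count = prod(lhs_shape[:-2])      # product of lhs leading dims (1 if 2D)
    batch_count_rhs = prod(rhs_shape[:-2])  # product of rhs leading dims (1 if 2D)
    if batch_count == 1:
        return "fast-path"                  # existing batch_count == 1 path
    if batch_count_rhs == 1:
        return "hoist-rhs"                  # saves batch_count - 1 rhs loads
    return "per-batch-rhs"                  # rhs varies per batch, no hoist

assert plan_batch_matmul((4, 16, 8), (8, 32)) == "hoist-rhs"         # [B,M,K] x [K,N]
assert plan_batch_matmul((4, 16, 8), (1, 8, 32)) == "hoist-rhs"      # [B,M,K] x [1,K,N]
assert plan_batch_matmul((2, 3, 16, 8), (8, 32)) == "hoist-rhs"      # [B1,B2,M,K] x [K,N]
assert plan_batch_matmul((4, 16, 8), (4, 8, 32)) == "per-batch-rhs"  # true batch on both
assert plan_batch_matmul((16, 8), (8, 32)) == "fast-path"            # batch_count == 1
```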

Test coverage

tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py (3 new):

  • test_batch_matmul_rhs_2d_hoists_rhs_load — verifies rhs.load appears once for [2,16,128] × [128,64]
  • test_batch_matmul_rhs_size1_batch_hoists_rhs_load — verifies hoist also fires for size-1 batch broadcast
  • test_batch_matmul_both_batched_no_hoist — verifies hoist does NOT fire when rhs has true batch

tests/st/runtime/test_batch_matmul.py:

  • TestBatchMatmulLhsNdRhs2d end-to-end class + 3 parametrized configs against torch.matmul
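The shape of the end-to-end check, shown without the pytest plumbing and with NumPy standing in for both the compiled kernel and the torch.matmul reference (the real test drives the compiled tile.batch_matmul kernel; the shape configurations below are illustrative):

```python
import numpy as np

def batch_matmul_kernel(lhs, rhs):
    # placeholder for the compiled tile.batch_matmul kernel under test
    return lhs @ rhs

# three illustrative (batch, m, k, n) configurations, as in the parametrized test
for batch, m, k, n in [(2, 16, 128, 64), (4, 32, 64, 32), (3, 8, 256, 16)]:
    lhs = np.random.rand(batch, m, k).astype(np.float32)
    rhs = np.random.rand(k, n).astype(np.float32)
    expected = lhs @ rhs  # reference semantics: rhs broadcasts over the batch dim
    np.testing.assert_allclose(batch_matmul_kernel(lhs, rhs), expected, rtol=1e-5)
```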

Test plan

  • cmake --build build --parallel
  • python -m pytest tests/ut/ir/ — 2957 passed, 26 skipped
  • All existing tile.batch_matmul and tile.batch_matmul_acc tests still pass
  • python tests/lint/clang_tidy.py --diff-base HEAD clean
  • CI system-tests (a2a3 simulator) — was failing on the previous axis-merge approach due to codegen template mismatch; should pass now since IR shape is unchanged from existing per-batch unroll

Related


coderabbitai Bot commented May 6, 2026

Walkthrough

The PR optimizes the FlattenTileNdTo2D compiler pass to reduce redundant RHS page extraction operations in batch matmul lowering. When the RHS operand has no effective batch dimension (size 1), the pass now hoists RHS page extraction outside the per-batch unroll loop, eliminating B-1 duplicate load operations while maintaining per-batch LHS extraction and computation.

Changes

RHS Page Hoisting Optimization

  • Core Batch Matmul Lowering — src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp: LowerBatchMatmul now computes RHS batch size separately and hoists ExtractBatchPage for RHS outside the unroll loop when RHS batch == 1 and overall batch > 1. Per-batch iterations reuse the hoisted RHS while extracting LHS per batch.
  • Unit Tests for Hoisting — tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py: Added two tests: test_batch_matmul_rhs_2d_hoists_rhs_load validates 2D RHS hoisting; test_batch_matmul_rhs_size1_batch_hoists_rhs_load validates explicit batch-size-1 RHS hoisting. Both assert exact operation sequences with a single hoisted RHS load.
  • System Tests — tests/st/runtime/test_batch_matmul.py: Added TestBatchMatmulLhsNdRhs2d class and parametrized test_batch_matmul_lhs_nd_rhs_2d method validating ND LHS with 2D RHS batch matmul correctness against PyTorch with RHS broadcasting across the batch dimension.
  • Documentation — docs/en/dev/passes/14-flatten_tile_nd_to_2d.md, docs/zh-cn/dev/passes/14-flatten_tile_nd_to_2d.md: Updated tile.batch_matmul descriptions in both English and Chinese docs to document RHS page extraction hoisting when RHS has no effective batch, reducing redundant load emissions during per-batch expansion.
  • Runtime Submodule — runtime: Updated submodule pointer to a new commit hash.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes


Poem

🐰 A rabbit hops through batches fast,
RHS pages—now just one, not cast!
Loop unrolling, pages hoisted high,
Redundant loads? We say goodbye!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

❌ Failed checks (1 warning)
  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 36.36%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)
  • Title check — ✅ Passed: The title clearly describes the main optimization: hoisting RHS page extraction in the batch_matmul unroll when RHS has no batch dimensions. It is specific, concise, and directly reflects the core change.
  • Description check — ✅ Passed: The description provides comprehensive context for the PR including the optimization rationale, scope explanation, case coverage table, test details, and related issues. It is well-structured and directly relates to all changes in the changeset.
  • Linked Issues check — ✅ Passed: Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check — ✅ Passed: Check skipped because no linked issues were found for this pull request.


✨ Finishing Touches
⚔️ Resolve merge conflicts
  • Resolve merge conflict in branch perf/batch-matmul-collapse-rhs-no-batch



@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a single-matmul fast path for tile.batch_matmul within the FlattenTileNdTo2D pass, specifically targeting cases where the RHS has an effective batch count of one. By collapsing the computation into a single 2D tile.matmul, the pass avoids per-batch unrolling. The review feedback highlights a critical issue where the transpose detection logic is not scope-aware, potentially leading to incorrect fast-path selection for variables defined in parent scopes. Additionally, it is recommended to adopt the INTERNAL_CHECK_SPAN macro for internal IR checks to improve debugging by including source location information.

}
};

auto local_def_map = BuildAssignDefMap(stmts);

high

The local_def_map is built only from the current list of statements. Since TransformBody is recursive, this map will not contain definitions from parent scopes. Consequently, IsLhsTransposed (and the corresponding logic in LowerBatchMatmul) will fail to detect transposes for variables defined outside the current block (e.g., a tile loaded before an if or for block). This could lead to the single-matmul fast path being incorrectly selected for transposed operands, which would result in incorrect code generation because the merged axis would not be contiguous in memory. Consider building a function-wide definition map in TransformFunction and passing it down to ensure correct transpose detection across all scopes, adhering to the project's scope-aware transformation guidelines.

References
  1. When implementing AST transformations that involve variable substitution, use a scope-aware mechanism (e.g., a stack of substitution maps or save/restore semantics) to correctly handle variable shadowing in nested scopes.
  2. When performing function-scoped analysis (e.g., collecting variable uses), running the analysis per-function is the correct pattern as variable scopes do not cross function boundaries.
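The scope-aware mechanism the reviewer suggests — a stack of definition maps with save/restore semantics — can be sketched as follows (class and method names are illustrative, not the pass's actual API):

```python
# A def map that sees parent-scope definitions, so a tile loaded before an
# if/for block is still visible inside it, while definitions made inside a
# nested block are discarded on scope exit.
class ScopedDefMap:
    def __init__(self):
        self.scopes = [{}]                   # innermost scope is scopes[-1]

    def define(self, var, def_node):
        self.scopes[-1][var] = def_node

    def lookup(self, var):
        for scope in reversed(self.scopes):  # inner scopes shadow outer ones
            if var in scope:
                return scope[var]
        return None                          # truly undefined

    def push(self):                          # entering a nested block
        self.scopes.append({})

    def pop(self):                           # leaving it: restore outer state
        self.scopes.pop()

defs = ScopedDefMap()
defs.define("t", "tile.load")                # defined before the nested block
defs.push()                                  # enter e.g. an if/for body
assert defs.lookup("t") == "tile.load"       # visible, unlike a per-block map
defs.define("t", "tile.transpose")           # shadowing inside the block
assert defs.lookup("t") == "tile.transpose"
defs.pop()
assert defs.lookup("t") == "tile.load"       # restored on scope exit
```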

Comment on lines +863 to +865
CHECK(store_tensor_type)
<< "FlattenTileNdTo2D: tile.batch_matmul direct store target must be TensorType";
CHECK(store_offsets) << "FlattenTileNdTo2D: tile.store offsets must be a MakeTuple";

medium

According to the project's general rules, internal checks on IR nodes that have source location information should use the INTERNAL_CHECK_SPAN macro. This ensures that the source location is automatically included in the error message, which is crucial for debugging compiler passes.

Suggested change
-CHECK(store_tensor_type)
-    << "FlattenTileNdTo2D: tile.batch_matmul direct store target must be TensorType";
-CHECK(store_offsets) << "FlattenTileNdTo2D: tile.store offsets must be a MakeTuple";
+INTERNAL_CHECK_SPAN(span, store_tensor_type)
+    << "FlattenTileNdTo2D: tile.batch_matmul direct store target must be TensorType";
+INTERNAL_CHECK_SPAN(span, store_offsets)
+    << "FlattenTileNdTo2D: tile.store offsets must be a MakeTuple";
References
  1. For internal checks on statements with source location information, use the INTERNAL_CHECK_SPAN macro to automatically include the source location in the error message as per project convention.

…rhs has no batch

When `tile.batch_matmul`'s rhs has no effective batch dimensions
(`prod(rhs_batch_dims) == 1`, covering both rhs-is-2D and ND-with-all-1
batch dims), `ExtractBatchPage` returns the same IR for every iteration of
the per-batch unroll loop because `BuildOperandFlatBatchIndex` evaluates to
0 every time. Hoist that page extraction once outside the loop and reuse
the resulting Var across all per-batch matmuls. Saves `batch_count - 1`
redundant rhs `tile.load` calls without changing IR shape, codegen
templates, or any other invariant.

Cases covered:
  - `[B, M, K] x [K, N]` (B>1)            — rhs is 2D
  - `[B, M, K] x [1, K, N]` (B>1)         — rhs has size-1 batch broadcast
  - `[B1, B2, M, K] x [K, N]` (B1*B2>1)   — multiple lhs batch dims, rhs is 2D
Cases unchanged:
  - `batch_count == 1` — hits the existing fast path before the unroll
  - `[B, M, K] x [B, K, N]` (true batch on both) — rhs varies per batch, no hoist

Earlier iterations of this PR attempted to collapse the entire unroll into
a single 2D `tile.matmul` over an axis-merged lhs (`[prod(batch)*M, K]`).
The IR was correct but codegen's `TLOAD<Tile, GlobalTensor>` template
requires the GlobalTensor's last 2 dims to match the Tile's shape, which
the auto-flattened "ND source view -> 2D dest tile" form violates. That
optimization needs codegen-side support (see the companion bug hw-native-sys#1278) and
is deferred. The rhs-hoist optimization in this PR uses the same per-batch
IR shape that codegen has always supported, so no codegen changes are
required.

Adds 3 unit tests in test_flatten_tile_nd_to_2d.py:
  - test_batch_matmul_rhs_2d_hoists_rhs_load           — rhs truly 2D
  - test_batch_matmul_rhs_size1_batch_hoists_rhs_load  — size-1 batch broadcast
  - test_batch_matmul_both_batched_no_hoist            — true batch on both, no hoist

Adds an end-to-end TestBatchMatmulLhsNdRhs2d class in test_batch_matmul.py
with three parametrized shape configurations against torch.matmul.
@lyfne123 lyfne123 force-pushed the perf/batch-matmul-collapse-rhs-no-batch branch from 1b0a56b to ee6ce91 Compare May 6, 2026 09:31
@lyfne123 lyfne123 changed the title perf(ir): Collapse batch_matmul to one tile.matmul when rhs has no batch perf(ir): Hoist rhs.ExtractBatchPage in batch_matmul unroll when rhs has no batch May 6, 2026

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tests/st/runtime/test_batch_matmul.py`:
- Line 335: Replace the Unicode multiplication sign '×' with a plain ASCII 'x'
in the docstrings of the batch matmul tests to avoid Ruff ambiguous-character
warnings; specifically update the docstring that currently reads "lhs ND × rhs
2D triggers FlattenTileNdTo2D's rhs-hoist optimization." (and the similar
docstring at the nearby line) to use "lhs ND x rhs 2D ..." so the test
function/docstring text containing that phrase in
tests/st/runtime/test_batch_matmul.py is changed accordingly.

In `@tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py`:
- Line 1548: In test_flatten_tile_nd_to_2d.py replace the Unicode multiplication
character '×' used in comment lines with a plain ASCII 'x' to satisfy Ruff
RUF003; update the three comment occurrences (the one shown and the two other
occurrences flagged) so comments like "× 2" become "x 2" — no behavior change,
only ASCII comment character replacement.


📥 Commits

Reviewing files that changed from the base of the PR and between 3eb3b38 and 17a3c6f.

📒 Files selected for processing (6)
  • docs/en/dev/passes/14-flatten_tile_nd_to_2d.md
  • docs/zh-cn/dev/passes/14-flatten_tile_nd_to_2d.md
  • runtime
  • src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp
  • tests/st/runtime/test_batch_matmul.py
  • tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py

],
)
def test_batch_matmul_lhs_nd_rhs_2d(self, test_runner, batch, m, k, n):
"""lhs ND × rhs 2D triggers FlattenTileNdTo2D's rhs-hoist optimization."""

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Use plain x instead of × in test text to avoid Ruff ambiguous-character warnings.

Suggested edit
-        """lhs ND × rhs 2D triggers FlattenTileNdTo2D's rhs-hoist optimization."""
+        """lhs ND x rhs 2D triggers FlattenTileNdTo2D's rhs-hoist optimization."""
...
-        assert result.passed, f"Test failed (lhs ND × rhs 2D): {result.error}"
+        assert result.passed, f"Test failed (lhs ND x rhs 2D): {result.error}"

Also applies to: 338-338

🧰 Tools
🪛 Ruff (0.15.12)

[warning] 335-335: Docstring contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF002)



func = self._flattened_incore(Before)
calls = self._top_level_calls(func)
# Expect: 1 rhs.load (hoisted), then per-batch (lhs.load, matmul, store) × 2.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Replace ambiguous × characters in comments to satisfy Ruff.

These flagged comment lines should use plain x to avoid RUF003 warnings.

Suggested edit
-        # Expect: 1 rhs.load (hoisted), then per-batch (lhs.load, matmul, store) × 2.
+        # Expect: 1 rhs.load (hoisted), then per-batch (lhs.load, matmul, store) x 2.
...
-        # 1 hoisted rhs.load + 3 × (lhs.load, matmul, store).
+        # 1 hoisted rhs.load + 3 x (lhs.load, matmul, store).
...
-        # Both batched -> no hoist: 2 × (lhs.load, rhs.load, matmul, store).
+        # Both batched -> no hoist: 2 x (lhs.load, rhs.load, matmul, store).

Also applies to: 1609-1609, 1661-1661

🧰 Tools
🪛 Ruff (0.15.12)

[warning] 1548-1548: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF003)


