perf(ir): Hoist rhs.ExtractBatchPage in batch_matmul unroll when rhs has no batch #1276
lyfne123 wants to merge 2 commits into
Conversation
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Code Review
This pull request introduces a single-matmul fast path for tile.batch_matmul within the FlattenTileNdTo2D pass, specifically targeting cases where the RHS has an effective batch count of one. By collapsing the computation into a single 2D tile.matmul, the pass avoids per-batch unrolling. The review feedback highlights a critical issue where the transpose detection logic is not scope-aware, potentially leading to incorrect fast-path selection for variables defined in parent scopes. Additionally, it is recommended to adopt the INTERNAL_CHECK_SPAN macro for internal IR checks to improve debugging by including source location information.
```cpp
  }
};

auto local_def_map = BuildAssignDefMap(stmts);
```
The local_def_map is built only from the current list of statements. Since TransformBody is recursive, this map will not contain definitions from parent scopes. Consequently, IsLhsTransposed (and the corresponding logic in LowerBatchMatmul) will fail to detect transposes for variables defined outside the current block (e.g., a tile loaded before an if or for block). This could lead to the single-matmul fast path being incorrectly selected for transposed operands, which would result in incorrect code generation because the merged axis would not be contiguous in memory. Consider building a function-wide definition map in TransformFunction and passing it down to ensure correct transpose detection across all scopes, adhering to the project's scope-aware transformation guidelines.
References
- When implementing AST transformations that involve variable substitution, use a scope-aware mechanism (e.g., a stack of substitution maps or save/restore semantics) to correctly handle variable shadowing in nested scopes.
- When performing function-scoped analysis (e.g., collecting variable uses), running the analysis per-function is the correct pattern as variable scopes do not cross function boundaries.
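The scope-aware mechanism the reviewer asks for can be sketched as a stack of per-scope definition maps with save/restore semantics. This is an illustrative Python model, not code from the pass (which is C++); the class and method names here are assumptions.

```python
class ScopedDefMap:
    """Stack of per-scope dicts; lookups fall through to parent scopes."""

    def __init__(self):
        self.scopes = [{}]  # innermost scope is last

    def enter_scope(self):
        self.scopes.append({})

    def exit_scope(self):
        self.scopes.pop()

    def define(self, name, value):
        self.scopes[-1][name] = value

    def lookup(self, name):
        # Search innermost-to-outermost so shadowing resolves correctly.
        for scope in reversed(self.scopes):
            if name in scope:
                return scope[name]
        return None


# A tile defined in the outer scope stays visible inside a nested block,
# which a map built only from the inner statement list would miss.
defs = ScopedDefMap()
defs.define("lhs_tile", "tile.load(...)")  # defined before the if/for block
defs.enter_scope()                          # entering the nested block
defs.define("tmp", "transpose(lhs_tile)")
assert defs.lookup("lhs_tile") == "tile.load(...)"  # parent-scope def found
defs.exit_scope()
assert defs.lookup("tmp") is None           # inner def gone after scope exit
```

The key property is that `enter_scope`/`exit_scope` bracket each nested block, so inner definitions never leak outward while outer definitions remain visible, matching the save/restore guideline quoted above.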
```cpp
CHECK(store_tensor_type)
    << "FlattenTileNdTo2D: tile.batch_matmul direct store target must be TensorType";
CHECK(store_offsets) << "FlattenTileNdTo2D: tile.store offsets must be a MakeTuple";
```
According to the project's general rules, internal checks on IR nodes that have source location information should use the INTERNAL_CHECK_SPAN macro. This ensures that the source location is automatically included in the error message, which is crucial for debugging compiler passes.
```diff
-CHECK(store_tensor_type)
-    << "FlattenTileNdTo2D: tile.batch_matmul direct store target must be TensorType";
-CHECK(store_offsets) << "FlattenTileNdTo2D: tile.store offsets must be a MakeTuple";
+INTERNAL_CHECK_SPAN(span, store_tensor_type)
+    << "FlattenTileNdTo2D: tile.batch_matmul direct store target must be TensorType";
+INTERNAL_CHECK_SPAN(span, store_offsets) << "FlattenTileNdTo2D: tile.store offsets must be a MakeTuple";
```
References
- For internal checks on statements with source location information, use the INTERNAL_CHECK_SPAN macro to automatically include the source location in the error message as per project convention.
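As a language-neutral illustration of why span-carrying checks help, here is a hypothetical Python analogue of such a check. The real `INTERNAL_CHECK_SPAN` is a C++ macro in this codebase; the function and exception names below are assumptions made for the sketch.

```python
class InternalCheckError(Exception):
    """Raised when an internal IR invariant is violated."""


def internal_check_span(span, cond, message):
    # Prefix the failure message with the IR node's source span so a
    # compiler-pass bug points straight at the offending source location,
    # rather than only at the pass that tripped the check.
    if not cond:
        raise InternalCheckError(f"{span}: {message}")


# A failed check surfaces the user-code location in the error text.
try:
    internal_check_span("model.py:42:7", False,
                        "tile.store offsets must be a MakeTuple")
except InternalCheckError as e:
    assert "model.py:42:7" in str(e)
```

A bare `CHECK` in the same situation would report only the pass-internal message, leaving the developer to hunt for which statement in the input program triggered it.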
…rhs has no batch

When `tile.batch_matmul`'s rhs has no effective batch dimensions (`prod(rhs_batch_dims) == 1`, covering both rhs-is-2D and ND-with-all-1 batch dims), `ExtractBatchPage` returns the same IR for every iteration of the per-batch unroll loop because `BuildOperandFlatBatchIndex` evaluates to 0 every time. Hoist that page extraction once outside the loop and reuse the resulting Var across all per-batch matmuls. Saves `batch_count - 1` redundant rhs `tile.load` calls without changing IR shape, codegen templates, or any other invariant.

Cases covered:
- `[B, M, K] x [K, N]` (B>1) — rhs is 2D
- `[B, M, K] x [1, K, N]` (B>1) — rhs has size-1 batch broadcast
- `[B1, B2, M, K] x [K, N]` (B1*B2>1) — multiple lhs batch dims, rhs is 2D

Cases unchanged:
- `batch_count == 1` — hits the existing fast path before the unroll
- `[B, M, K] x [B, K, N]` (true batch on both) — rhs varies per batch, no hoist

Earlier iterations of this PR attempted to collapse the entire unroll into a single 2D `tile.matmul` over an axis-merged lhs (`[prod(batch)*M, K]`). The IR was correct but codegen's `TLOAD<Tile, GlobalTensor>` template requires the GlobalTensor's last 2 dims to match the Tile's shape, which the auto-flattened "ND source view -> 2D dest tile" form violates. That optimization needs codegen-side support (see the companion bug hw-native-sys#1278) and is deferred. The rhs-hoist optimization in this PR uses the same per-batch IR shape that codegen has always supported, so no codegen changes are required.

Adds 3 unit tests in test_flatten_tile_nd_to_2d.py:
- test_batch_matmul_rhs_2d_hoists_rhs_load — rhs truly 2D
- test_batch_matmul_rhs_size1_batch_hoists_rhs_load — size-1 batch broadcast
- test_batch_matmul_both_batched_no_hoist — true batch on both, no hoist

Adds an end-to-end TestBatchMatmulLhsNdRhs2d class in test_batch_matmul.py with three parametrized shape configurations against torch.matmul.
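Why the rhs page is loop-invariant can be seen from a small model of the flat-index computation. This is a hypothetical Python rendering of what the commit message describes `BuildOperandFlatBatchIndex` as doing (the real code builds C++ IR); the function name and shape conventions are assumptions.

```python
from math import prod


def flat_batch_index(batch_coords, operand_batch_dims):
    """Row-major flat index into an operand's batch dims, with size-1
    broadcasting: a size-1 dim always contributes coordinate 0."""
    idx = 0
    for coord, dim in zip(batch_coords, operand_batch_dims):
        idx = idx * dim + (coord if dim != 1 else 0)
    return idx


# rhs with all-1 batch dims: the flat index is 0 on every unroll iteration,
# so the extracted rhs page is loop-invariant and can be hoisted.
lhs_batch = [2, 3]
rhs_batch = [1, 1]
indices = [flat_batch_index((b1, b2), rhs_batch)
           for b1 in range(lhs_batch[0]) for b2 in range(lhs_batch[1])]
assert indices == [0] * prod(lhs_batch)

# rhs with a true batch dim varies per iteration, so no hoist is possible.
assert [flat_batch_index((b,), [2]) for b in range(2)] == [0, 1]
```

The hoist is therefore just classic loop-invariant code motion applied at the IR-construction level: when the computed index is provably the constant 0, the page extraction is emitted once before the unroll instead of once per iteration.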
Force-pushed from 1b0a56b to ee6ce91 (Compare)
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/st/runtime/test_batch_matmul.py`:
- Line 335: Replace the Unicode multiplication sign '×' with a plain ASCII 'x'
in the docstrings of the batch matmul tests to avoid Ruff ambiguous-character
warnings; specifically update the docstring that currently reads "lhs ND × rhs
2D triggers FlattenTileNdTo2D's rhs-hoist optimization." (and the similar
docstring at the nearby line) to use "lhs ND x rhs 2D ..." so the test
function/docstring text containing that phrase in
tests/st/runtime/test_batch_matmul.py is changed accordingly.
In `@tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py`:
- Line 1548: In test_flatten_tile_nd_to_2d.py replace the Unicode multiplication
character '×' used in comment lines with a plain ASCII 'x' to satisfy Ruff
RUF003; update the three comment occurrences (the one shown and the two other
occurrences flagged) so comments like "× 2" become "x 2" — no behavior change,
only ASCII comment character replacement.
📒 Files selected for processing (6)
- docs/en/dev/passes/14-flatten_tile_nd_to_2d.md
- docs/zh-cn/dev/passes/14-flatten_tile_nd_to_2d.md
- runtimesrc/ir/transforms/flatten_tile_nd_to_2d_pass.cpp
- tests/st/runtime/test_batch_matmul.py
- tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py
```python
    ],
)
def test_batch_matmul_lhs_nd_rhs_2d(self, test_runner, batch, m, k, n):
    """lhs ND × rhs 2D triggers FlattenTileNdTo2D's rhs-hoist optimization."""
```
Use plain x instead of × in test text to avoid Ruff ambiguous-character warnings.
Suggested edit

```diff
-    """lhs ND × rhs 2D triggers FlattenTileNdTo2D's rhs-hoist optimization."""
+    """lhs ND x rhs 2D triggers FlattenTileNdTo2D's rhs-hoist optimization."""
 ...
-    assert result.passed, f"Test failed (lhs ND × rhs 2D): {result.error}"
+    assert result.passed, f"Test failed (lhs ND x rhs 2D): {result.error}"
```

Also applies to: 338-338
🧰 Tools
🪛 Ruff (0.15.12)
[warning] 335-335: Docstring contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF002)
```python
func = self._flattened_incore(Before)
calls = self._top_level_calls(func)
# Expect: 1 rhs.load (hoisted), then per-batch (lhs.load, matmul, store) × 2.
```
Replace ambiguous × characters in comments to satisfy Ruff.
These flagged comment lines should use plain x to avoid RUF003 warnings.
Suggested edit

```diff
-# Expect: 1 rhs.load (hoisted), then per-batch (lhs.load, matmul, store) × 2.
+# Expect: 1 rhs.load (hoisted), then per-batch (lhs.load, matmul, store) x 2.
 ...
-# 1 hoisted rhs.load + 3 × (lhs.load, matmul, store).
+# 1 hoisted rhs.load + 3 x (lhs.load, matmul, store).
 ...
-# Both batched -> no hoist: 2 × (lhs.load, rhs.load, matmul, store).
+# Both batched -> no hoist: 2 x (lhs.load, rhs.load, matmul, store).
```

Also applies to: 1609-1609, 1661-1661
🧰 Tools
🪛 Ruff (0.15.12)
[warning] 1548-1548: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
Summary
When `tile.batch_matmul`'s rhs has no effective batch dimensions (`prod(rhs_batch_dims) == 1`, covering both rhs-is-2D and ND-with-all-1 leading dims), `ExtractBatchPage` produces identical IR each iteration of the per-batch unroll. This PR hoists that page extraction once outside the loop and reuses the resulting Var across all per-batch matmuls, saving `B-1` redundant rhs `tile.load` calls.

The IR shape stays exactly what codegen has always supported (per-batch unroll with 2D `tile.matmul`), so no codegen changes are required.

Why this scope

Earlier iterations attempted to collapse the entire unroll into a single 2D `tile.matmul` over an axis-merged lhs (`[prod(batch)*M, K]`). The IR was semantically correct but the `TLOAD<Tile, GlobalTensor>` C++ template requires the source tensor's last 2 dims to match the tile shape, which the auto-flattened "ND source view → 2D dest tile" form violates. That optimization needs codegen-side support and is tracked in #1278.

The rhs-hoist optimization in this PR uses the same per-batch IR shape that already passes CI, so it lands cleanly today while we wait for the codegen work to enable the bigger optimization.
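Under the assumptions stated in this summary, the hoist-eligibility condition reduces to a one-line predicate. This is a Python sketch of the decision only; the shape-list convention and the function name are illustrative, not the pass's actual API.

```python
from math import prod


def rhs_page_is_hoistable(lhs_shape, rhs_shape):
    """True when the per-batch unroll runs more than once but the rhs has
    no effective batch, so its extracted page is identical every iteration."""
    batch_count = prod(lhs_shape[:-2])      # dims before [M, K]; prod([]) == 1
    rhs_batch_count = prod(rhs_shape[:-2])  # dims before [K, N]
    return batch_count > 1 and rhs_batch_count == 1


# The three hoisted cases and the two unchanged cases described in this PR:
assert rhs_page_is_hoistable([4, 16, 128], [128, 64])         # [B,M,K] x [K,N]
assert rhs_page_is_hoistable([4, 16, 128], [1, 128, 64])      # size-1 batch broadcast
assert rhs_page_is_hoistable([2, 3, 16, 128], [128, 64])      # multiple lhs batch dims
assert not rhs_page_is_hoistable([4, 16, 128], [4, 128, 64])  # true batch on both
assert not rhs_page_is_hoistable([1, 16, 128], [128, 64])     # batch_count == 1
```

The `batch_count == 1` case is excluded not because the hoist would be wrong there, but because that case already takes the existing single-matmul fast path before the unroll is ever emitted.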
Cases

Hoisted:
- `[B,M,K] × [K,N]` (B>1)
- `[B,M,K] × [1,K,N]` (B>1)
- `[B1,B2,M,K] × [K,N]` (B1*B2>1)

Unchanged:
- `[B,M,K] × [B,K,N]` (true batch on both): no hoist
- `batch_count == 1`: existing fast path

Test coverage
- `tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py` (3 new):
  - `test_batch_matmul_rhs_2d_hoists_rhs_load` — verifies rhs.load appears once for `[2,16,128] × [128,64]`
  - `test_batch_matmul_rhs_size1_batch_hoists_rhs_load` — verifies hoist also fires for size-1 batch broadcast
  - `test_batch_matmul_both_batched_no_hoist` — verifies hoist does NOT fire when rhs has true batch
- `tests/st/runtime/test_batch_matmul.py`: `TestBatchMatmulLhsNdRhs2d` end-to-end class + 3 parametrized configs against `torch.matmul`
Test plan

- `cmake --build build --parallel`
- `python -m pytest tests/ut/ir/` — 2957 passed, 26 skipped
- `tile.batch_matmul` and `tile.batch_matmul_acc` tests still pass
- `python tests/lint/clang_tidy.py --diff-base HEAD` clean
- `system-tests` (a2a3 simulator) — was failing on the previous axis-merge approach due to codegen template mismatch; should pass now since IR shape is unchanged from existing per-batch unroll

Related