-
Notifications
You must be signed in to change notification settings - Fork 67
perf(ir): Hoist rhs.ExtractBatchPage in batch_matmul unroll when rhs has no batch #1276
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1507,6 +1507,169 @@ def load_tile_shape(shape: list[int], transpose: bool) -> list[int]: | |
| "expected_store_shapes" | ||
| ] | ||
|
|
||
| def test_batch_matmul_rhs_2d_hoists_rhs_load(self): | ||
| """lhs [B, M, K] x rhs [K, N] (B>1): rhs.load is hoisted out of the per-batch loop. | ||
|
|
||
| Saves (B-1) redundant rhs tile.load calls. Per-batch loop still emits B copies of | ||
| (lhs.load, matmul, store), but only ONE rhs.load is emitted for the whole sequence. | ||
| """ | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore) | ||
| def main_incore_0( | ||
| self, | ||
| lhs: pl.Tensor[[2, 16, 128], pl.FP16], | ||
| rhs: pl.Tensor[[128, 64], pl.FP16], | ||
| out_0: pl.Out[pl.Tensor[[2, 16, 64], pl.FP16]], | ||
| ) -> pl.Tensor[[2, 16, 64], pl.FP16]: | ||
| lhs_tile: pl.Tile[[2, 16, 128], pl.FP16] = pl.load( | ||
| lhs, [0, 0, 0], [2, 16, 128], target_memory=pl.MemorySpace.Mat | ||
| ) | ||
| rhs_tile: pl.Tile[[128, 64], pl.FP16] = pl.load( | ||
| rhs, [0, 0], [128, 64], target_memory=pl.MemorySpace.Mat | ||
| ) | ||
| out_tile: pl.Tile[[2, 16, 64], pl.FP32] = pl.tile.batch_matmul(lhs_tile, rhs_tile) | ||
| out_0 = pl.store(out_tile, [0, 0, 0], out_0) | ||
| return out_0 | ||
|
|
||
| @pl.function | ||
| def main( | ||
| self, | ||
| lhs: pl.Tensor[[2, 16, 128], pl.FP16], | ||
| rhs: pl.Tensor[[128, 64], pl.FP16], | ||
| ) -> pl.Tensor[[2, 16, 64], pl.FP16]: | ||
| out_0 = pl.create_tensor([2, 16, 64], dtype=pl.FP16) | ||
| y = self.main_incore_0(lhs, rhs, out_0) | ||
| return y | ||
|
|
||
| func = self._flattened_incore(Before) | ||
| calls = self._top_level_calls(func) | ||
| # Expect: 1 rhs.load (hoisted), then per-batch (lhs.load, matmul, store) x 2. | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Replace ambiguous These flagged comment lines should use plain Suggested edit- # Expect: 1 rhs.load (hoisted), then per-batch (lhs.load, matmul, store) × 2.
+ # Expect: 1 rhs.load (hoisted), then per-batch (lhs.load, matmul, store) x 2.
...
- # 1 hoisted rhs.load + 3 × (lhs.load, matmul, store).
+ # 1 hoisted rhs.load + 3 x (lhs.load, matmul, store).
...
- # Both batched -> no hoist: 2 × (lhs.load, rhs.load, matmul, store).
+ # Both batched -> no hoist: 2 x (lhs.load, rhs.load, matmul, store).Also applies to: 1609-1609, 1661-1661 🧰 Tools🪛 Ruff (0.15.12)[warning] 1548-1548: Comment contains ambiguous (RUF003) 🤖 Prompt for AI Agents |
||
| assert [call.op.name for call in calls] == [ | ||
| "tile.load", # hoisted rhs | ||
| "tile.load", # batch 0 lhs | ||
| "tile.matmul", | ||
| "tile.store", | ||
| "tile.load", # batch 1 lhs | ||
| "tile.matmul", | ||
| "tile.store", | ||
| ] | ||
| load_calls = [call for call in calls if call.op.name == "tile.load"] | ||
| # First load is rhs (2D, no batch shift). Subsequent loads are per-batch lhs. | ||
| assert [self._tuple_const_values(call.args[1]) for call in load_calls] == [ | ||
| [0, 0], # hoisted rhs offset | ||
| [0, 0, 0], # batch 0 lhs | ||
| [1, 0, 0], # batch 1 lhs | ||
| ] | ||
| assert [self._tuple_const_values(call.args[2]) for call in load_calls] == [ | ||
| [128, 64], # rhs shape | ||
| [1, 16, 128], # per-batch lhs shape (single batch slice) | ||
| [1, 16, 128], | ||
| ] | ||
|
|
||
| def test_batch_matmul_rhs_size1_batch_hoists_rhs_load(self): | ||
| """lhs [B, M, K] x rhs [1, K, N] (B>1): rhs hoist also fires for size-1-batch rhs. | ||
|
|
||
| rhs has a leading batch dim of 1 (broadcasted across all output batches), so | ||
| ExtractBatchPage with batch_idx=0 gives the same result every iteration. Hoist. | ||
| """ | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore) | ||
| def main_incore_0( | ||
| self, | ||
| lhs: pl.Tensor[[3, 16, 128], pl.FP16], | ||
| rhs: pl.Tensor[[1, 128, 64], pl.FP16], | ||
| out_0: pl.Out[pl.Tensor[[3, 16, 64], pl.FP16]], | ||
| ) -> pl.Tensor[[3, 16, 64], pl.FP16]: | ||
| lhs_tile: pl.Tile[[3, 16, 128], pl.FP16] = pl.load( | ||
| lhs, [0, 0, 0], [3, 16, 128], target_memory=pl.MemorySpace.Mat | ||
| ) | ||
| rhs_tile: pl.Tile[[1, 128, 64], pl.FP16] = pl.load( | ||
| rhs, [0, 0, 0], [1, 128, 64], target_memory=pl.MemorySpace.Mat | ||
| ) | ||
| out_tile: pl.Tile[[3, 16, 64], pl.FP32] = pl.tile.batch_matmul(lhs_tile, rhs_tile) | ||
| out_0 = pl.store(out_tile, [0, 0, 0], out_0) | ||
| return out_0 | ||
|
|
||
| @pl.function | ||
| def main( | ||
| self, | ||
| lhs: pl.Tensor[[3, 16, 128], pl.FP16], | ||
| rhs: pl.Tensor[[1, 128, 64], pl.FP16], | ||
| ) -> pl.Tensor[[3, 16, 64], pl.FP16]: | ||
| out_0 = pl.create_tensor([3, 16, 64], dtype=pl.FP16) | ||
| y = self.main_incore_0(lhs, rhs, out_0) | ||
| return y | ||
|
|
||
| func = self._flattened_incore(Before) | ||
| op_names = [call.op.name for call in self._top_level_calls(func)] | ||
| # 1 hoisted rhs.load + 3 x (lhs.load, matmul, store). | ||
| assert op_names == [ | ||
| "tile.load", # hoisted rhs | ||
| "tile.load", | ||
| "tile.matmul", | ||
| "tile.store", | ||
| "tile.load", | ||
| "tile.matmul", | ||
| "tile.store", | ||
| "tile.load", | ||
| "tile.matmul", | ||
| "tile.store", | ||
| ] | ||
| # Exactly one load whose shape is the rhs shape (rest are lhs). | ||
| load_calls = [call for call in self._top_level_calls(func) if call.op.name == "tile.load"] | ||
| rhs_loads = [call for call in load_calls if self._tuple_const_values(call.args[2]) == [1, 128, 64]] | ||
| assert len(rhs_loads) == 1, f"expected 1 hoisted rhs load, got {len(rhs_loads)}" | ||
|
|
||
| def test_batch_matmul_both_batched_no_hoist(self): | ||
| """lhs [B, M, K] x rhs [B, K, N] (B>1): rhs varies per batch, hoist must NOT fire.""" | ||
|
|
||
| @pl.program | ||
| class Before: | ||
| @pl.function(type=pl.FunctionType.InCore) | ||
| def main_incore_0( | ||
| self, | ||
| lhs: pl.Tensor[[2, 16, 128], pl.FP16], | ||
| rhs: pl.Tensor[[2, 128, 64], pl.FP16], | ||
| out_0: pl.Out[pl.Tensor[[2, 16, 64], pl.FP16]], | ||
| ) -> pl.Tensor[[2, 16, 64], pl.FP16]: | ||
| lhs_tile: pl.Tile[[2, 16, 128], pl.FP16] = pl.load( | ||
| lhs, [0, 0, 0], [2, 16, 128], target_memory=pl.MemorySpace.Mat | ||
| ) | ||
| rhs_tile: pl.Tile[[2, 128, 64], pl.FP16] = pl.load( | ||
| rhs, [0, 0, 0], [2, 128, 64], target_memory=pl.MemorySpace.Mat | ||
| ) | ||
| out_tile: pl.Tile[[2, 16, 64], pl.FP32] = pl.tile.batch_matmul(lhs_tile, rhs_tile) | ||
| out_0 = pl.store(out_tile, [0, 0, 0], out_0) | ||
| return out_0 | ||
|
|
||
| @pl.function | ||
| def main( | ||
| self, | ||
| lhs: pl.Tensor[[2, 16, 128], pl.FP16], | ||
| rhs: pl.Tensor[[2, 128, 64], pl.FP16], | ||
| ) -> pl.Tensor[[2, 16, 64], pl.FP16]: | ||
| out_0 = pl.create_tensor([2, 16, 64], dtype=pl.FP16) | ||
| y = self.main_incore_0(lhs, rhs, out_0) | ||
| return y | ||
|
|
||
| func = self._flattened_incore(Before) | ||
| op_names = [call.op.name for call in self._top_level_calls(func)] | ||
| # Both batched -> no hoist: 2 x (lhs.load, rhs.load, matmul, store). | ||
| assert op_names == [ | ||
| "tile.load", | ||
| "tile.load", | ||
| "tile.matmul", | ||
| "tile.store", | ||
| "tile.load", | ||
| "tile.load", | ||
| "tile.matmul", | ||
| "tile.store", | ||
| ] | ||
|
|
||
|
|
||
| # ---------------------------------------------------------------------------- | ||
| # tile.batch_matmul_acc lowering | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use plain
xinstead of×in test text to avoid Ruff ambiguous-character warnings.Suggested edit
Also applies to: 338-338
🧰 Tools
🪛 Ruff (0.15.12)
[warning] 335-335: Docstring contains ambiguous
×(MULTIPLICATION SIGN). Did you meanx(LATIN SMALL LETTER X)?(RUF002)
🤖 Prompt for AI Agents