diff --git a/docs/en/dev/passes/14-flatten_tile_nd_to_2d.md b/docs/en/dev/passes/14-flatten_tile_nd_to_2d.md
index 974759f6c..936b43e9b 100644
--- a/docs/en/dev/passes/14-flatten_tile_nd_to_2d.md
+++ b/docs/en/dev/passes/14-flatten_tile_nd_to_2d.md
@@ -53,7 +53,7 @@ Per-statement handling:
 | `tile.store` (2D tensor) | Pass through unchanged |
 | `tile.create`/`tile.full` (>2D) | Rebuild with flattened 2D shape directly |
 | `tile.sum`/`tile.max`/`tile.min` (>2D) | Remap axis to 1 (last axis of 2D) |
-| `tile.batch_matmul` | Expand to per-batch 2D `tile.matmul`, honoring batch broadcast and any operand-side transpose carried in the producer `tile.load(target_memory=Mat, transpose=True)` |
+| `tile.batch_matmul` | Expand to per-batch 2D `tile.matmul`, honoring batch broadcast and any operand-side transpose carried in the producer `tile.load(target_memory=Mat, transpose=True)`. When the rhs has no effective batch (`prod(rhs_batch_dims) == 1`), the rhs page extraction is hoisted out of the per-batch unroll so only one rhs.load is emitted (saves B-1 redundant rhs loads) |
 | `tile.batch_matmul_acc` | Expand to per-batch 2D `tile.matmul_acc`, slicing the (already-flattened) accumulator per batch index; an explicit `tile.move(target_memory=Acc)` is inserted when the accumulator is in another memory space |
 | Other tile ops (>2D) | Substitute vars, re-create with 2D types |
 | 1D/2D tile ops | Unchanged |
diff --git a/docs/zh-cn/dev/passes/14-flatten_tile_nd_to_2d.md b/docs/zh-cn/dev/passes/14-flatten_tile_nd_to_2d.md
index 4710d8eb7..d4ce23834 100644
--- a/docs/zh-cn/dev/passes/14-flatten_tile_nd_to_2d.md
+++ b/docs/zh-cn/dev/passes/14-flatten_tile_nd_to_2d.md
@@ -52,7 +52,7 @@ program_2d = flatten_pass(program)
 | `tile.store`(2D 张量) | 直接透传 |
 | `tile.create`/`tile.full`(>2D) | 直接使用展平的 2D 形状重建 |
 | `tile.sum`/`tile.max`/`tile.min`(>2D) | 将 axis 映射为 1(2D 的最后轴) |
-| `tile.batch_matmul` | 展开为逐 batch 的 2D `tile.matmul`,处理 batch broadcast;operand 的 transpose 通过生产侧 `tile.load(target_memory=Mat, transpose=True)` 携带 |
+| `tile.batch_matmul` | 展开为逐 batch 的 2D `tile.matmul`,并处理 batch broadcast;operand 的 transpose 通过生产侧 `tile.load(target_memory=Mat, transpose=True)` 携带。当 rhs 实质无 batch(`prod(rhs_batch_dims) == 1`)时,rhs 的 page 提取会被提到 unroll 循环外,整段只发射一次 rhs.load(省掉 B-1 次冗余 rhs 加载) |
 | `tile.batch_matmul_acc` | 展开为逐 batch 的 2D `tile.matmul_acc`,按 batch 索引切分(已展平的)累加器;累加器若不在 Acc 内存空间会插入显式 `tile.move(target_memory=Acc)` |
 | 其他 Tile 操作(>2D) | 替换变量,使用 2D 类型重新创建 |
 | 1D/2D Tile 操作 | 不变 |
diff --git a/runtime b/runtime
index 08f6f7693..551a79c00 160000
--- a/runtime
+++ b/runtime
@@ -1 +1 @@
-Subproject commit 08f6f76937f6be121b975f80188a615d3541bcfe
+Subproject commit 551a79c00eab0f58c947338f7d4ba062a893ad7c
diff --git a/src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp b/src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp
index 154ad0e67..c9a280262 100644
--- a/src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp
+++ b/src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp
@@ -850,6 +850,19 @@ BatchMatmulResult LowerBatchMatmul(const AssignStmtPtr& assign, const CallPtr& c
     }
   }
 
+  // Hoist rhs.ExtractBatchPage when rhs has no effective batch (all rhs batch
+  // dims are 1, or rhs is 2D). In that case BuildOperandFlatBatchIndex would
+  // return 0 for every iteration, so ExtractBatchPage emits identical IR each
+  // round — emit it once outside the loop and reuse the resulting Var across
+  // all per-batch matmuls. Saves (batch_count - 1) redundant rhs loads.
+  const int64_t batch_count_rhs = MultiplyStaticDims(rhs_batch_dims, "tile.batch_matmul rhs batch size");
+  std::optional<BatchPageResult> hoisted_rhs_page;
+  if (batch_count_rhs == 1 && batch_count > 1) {
+    hoisted_rhs_page = ExtractBatchPage(rhs_info, rhs_dims, rhs_batch_dims, /*batch_index=*/0, "rhs", def_map,
+                                        ctx, op_registry, span);
+    out.stmts.insert(out.stmts.end(), hoisted_rhs_page->stmts.begin(), hoisted_rhs_page->stmts.end());
+  }
+
   // Unroll batch dimensions.
   for (int64_t i = 0; i < batch_count; ++i) {
     auto output_batch_indices = BuildBatchIndices(i, output_batch_dims);
@@ -858,16 +871,24 @@ BatchMatmulResult LowerBatchMatmul(const AssignStmtPtr& assign, const CallPtr& c
     int64_t rhs_batch_idx = BuildOperandFlatBatchIndex(rhs_batch_dims, output_batch_dims, output_batch_indices);
 
-    // Extract 2D pages.
+    // Extract 2D pages. lhs is always extracted per batch; rhs reuses the
+    // hoisted var when applicable.
     auto lhs_page = ExtractBatchPage(lhs_info, lhs_dims, lhs_batch_dims, lhs_batch_idx, "lhs", def_map, ctx,
                                      op_registry, span);
-    auto rhs_page = ExtractBatchPage(rhs_info, rhs_dims, rhs_batch_dims, rhs_batch_idx, "rhs", def_map, ctx,
-                                     op_registry, span);
     out.stmts.insert(out.stmts.end(), lhs_page.stmts.begin(), lhs_page.stmts.end());
-    out.stmts.insert(out.stmts.end(), rhs_page.stmts.begin(), rhs_page.stmts.end());
+
+    VarPtr rhs_var;
+    if (hoisted_rhs_page) {
+      rhs_var = hoisted_rhs_page->var;
+    } else {
+      auto rhs_page = ExtractBatchPage(rhs_info, rhs_dims, rhs_batch_dims, rhs_batch_idx, "rhs", def_map, ctx,
+                                       op_registry, span);
+      out.stmts.insert(out.stmts.end(), rhs_page.stmts.begin(), rhs_page.stmts.end());
+      rhs_var = rhs_page.var;
+    }
 
     // Emit tile.matmul.
-    auto matmul = op_registry.Create("tile.matmul", {lhs_page.var, rhs_page.var}, span);
+    auto matmul = op_registry.Create("tile.matmul", {lhs_page.var, rhs_var}, span);
     auto matmul_var = std::make_shared<Var>("matmul_" + std::to_string(i), matmul->GetType(), span);
     out.stmts.push_back(std::make_shared<AssignStmt>(matmul_var, matmul, span));
 
@@ -1147,6 +1168,11 @@ std::vector<StmtPtr> TransformBody(const std::vector<StmtPtr>& stmts, FlattenCon
   // from the original tensor, the full-batch load becomes dead code. Skip emitting
   // it to avoid wasted memory and potential hardware pipeline interference.
   //
+  // Exception: when LowerBatchMatmul will take the unified single-matmul fast path,
+  // the operand loads are NOT dead — the fast path uses tile.reshape on the original
+  // load result, not Strategy-1 per-batch re-emission. Operands of such batch_matmuls
+  // must stay in the IR.
+  //
   // Safety: we count ALL Var references across every statement type (Return, Yield,
   // If conditions, For/While bounds, etc.), not just Call arguments. A Var used
   // anywhere outside a tile.batch_matmul Call prevents it from being skipped.
diff --git a/tests/st/runtime/test_batch_matmul.py b/tests/st/runtime/test_batch_matmul.py
index c9831dea7..545457512 100644
--- a/tests/st/runtime/test_batch_matmul.py
+++ b/tests/st/runtime/test_batch_matmul.py
@@ -205,6 +205,68 @@ def compute_expected(self, tensors, params=None):
         tensors["c"][:] = torch.bmm(tensors["a"].transpose(-2, -1), tensors["b"])
 
 
+class TestBatchMatmulLhsNdRhs2d(PTOTestCase):
+    """Tile-level batch matmul with rhs as a true 2D (non-batched) operand.
+
+    Exercises FlattenTileNdTo2D's rhs-hoist optimization: lhs is `[B, M, K]`,
+    rhs is shared `[K, N]` across all batches, so the lowering hoists the
+    rhs.load out of the per-batch unroll loop (saving B-1 redundant loads).
+    """
+
+    __test__ = False
+
+    def __init__(self, batch: int = 2, m: int = 64, k: int = 64, n: int = 64, config=None):
+        super().__init__(config)
+        self.batch = batch
+        self.M = m
+        self.K = k
+        self.N = n
+
+    def get_name(self) -> str:
+        return f"batch_matmul_lhs_nd_rhs_2d_{self.batch}x{self.M}x{self.K}x{self.N}"
+
+    def define_tensors(self) -> list[TensorSpec]:
+        return [
+            TensorSpec("a", [self.batch, self.M, self.K], DataType.FP32, init_value=torch.randn),
+            TensorSpec("b", [self.K, self.N], DataType.FP32, init_value=torch.randn),
+            TensorSpec("c", [self.batch, self.M, self.N], DataType.FP32, is_output=True),
+        ]
+
+    def get_program(self) -> Any:
+        B, M, K, N = self.batch, self.M, self.K, self.N
+
+        @pl.program
+        class BatchMatmulLhsNdRhs2dProgram:
+            @pl.function(type=pl.FunctionType.InCore)
+            def batch_matmul_lhs_nd_rhs_2d(
+                self,
+                a: pl.Tensor[[B, M, K], pl.FP32],
+                b: pl.Tensor[[K, N], pl.FP32],
+                c: pl.Out[pl.Tensor[[B, M, N], pl.FP32]],
+            ) -> pl.Tensor[[B, M, N], pl.FP32]:
+                tile_a = pl.load(a, offsets=[0, 0, 0], shapes=[B, M, K], target_memory=pl.MemorySpace.Mat)
+                tile_b = pl.load(b, offsets=[0, 0], shapes=[K, N], target_memory=pl.MemorySpace.Mat)
+                tile_c = pl.batch_matmul(tile_a, tile_b)
+                out_c = pl.store(tile_c, offsets=[0, 0, 0], output_tensor=c)
+                return out_c
+
+            @pl.function(type=pl.FunctionType.Orchestration)
+            def orchestrator(
+                self,
+                a: pl.Tensor[[B, M, K], pl.FP32],
+                b: pl.Tensor[[K, N], pl.FP32],
+                c: pl.Out[pl.Tensor[[B, M, N], pl.FP32]],
+            ) -> pl.Tensor[[B, M, N], pl.FP32]:
+                out_c = self.batch_matmul_lhs_nd_rhs_2d(a, b, c)
+                return out_c
+
+        return BatchMatmulLhsNdRhs2dProgram
+
+    def compute_expected(self, tensors, params=None):
+        # torch.matmul broadcasts a 2D rhs across all batches of a 3D lhs.
+        tensors["c"][:] = torch.matmul(tensors["a"], tensors["b"])
+
+
 class TestBatchMatmulOperations:
     """Test suite for tile-level batch matrix multiplication.
 
@@ -260,6 +322,21 @@ def test_batch_matmul_a_transpose(self, test_runner, batch, m, k, n):
         result = test_runner.run(test_case)
         assert result.passed, f"Test failed (A-trans): {result.error}"
 
+    @pytest.mark.parametrize(
+        "batch,m,k,n",
+        [
+            (2, 64, 64, 64),
+            (4, 32, 64, 32),
+            # batch=1 falls through the existing batch_count==1 fast path.
+            (1, 64, 64, 64),
+        ],
+    )
+    def test_batch_matmul_lhs_nd_rhs_2d(self, test_runner, batch, m, k, n):
+        """lhs ND × rhs 2D triggers FlattenTileNdTo2D's rhs-hoist optimization."""
+        test_case = TestBatchMatmulLhsNdRhs2d(batch=batch, m=m, k=k, n=n)
+        result = test_runner.run(test_case)
+        assert result.passed, f"Test failed (lhs ND × rhs 2D): {result.error}"
+
 
 if __name__ == "__main__":
     pytest.main([__file__, "-v", "--forked"])
diff --git a/tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py b/tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py
index db3a72c71..fee00b996 100644
--- a/tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py
+++ b/tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py
@@ -1507,6 +1507,169 @@ def load_tile_shape(shape: list[int], transpose: bool) -> list[int]:
             "expected_store_shapes"
         ]
 
+    def test_batch_matmul_rhs_2d_hoists_rhs_load(self):
+        """lhs [B, M, K] x rhs [K, N] (B>1): rhs.load is hoisted out of the per-batch loop.
+
+        Saves (B-1) redundant rhs tile.load calls. The per-batch loop still emits B copies of
+        (lhs.load, matmul, store), but only ONE rhs.load is emitted for the whole sequence.
+        """
+
+        @pl.program
+        class Before:
+            @pl.function(type=pl.FunctionType.InCore)
+            def main_incore_0(
+                self,
+                lhs: pl.Tensor[[2, 16, 128], pl.FP16],
+                rhs: pl.Tensor[[128, 64], pl.FP16],
+                out_0: pl.Out[pl.Tensor[[2, 16, 64], pl.FP16]],
+            ) -> pl.Tensor[[2, 16, 64], pl.FP16]:
+                lhs_tile: pl.Tile[[2, 16, 128], pl.FP16] = pl.load(
+                    lhs, [0, 0, 0], [2, 16, 128], target_memory=pl.MemorySpace.Mat
+                )
+                rhs_tile: pl.Tile[[128, 64], pl.FP16] = pl.load(
+                    rhs, [0, 0], [128, 64], target_memory=pl.MemorySpace.Mat
+                )
+                out_tile: pl.Tile[[2, 16, 64], pl.FP32] = pl.tile.batch_matmul(lhs_tile, rhs_tile)
+                out_0 = pl.store(out_tile, [0, 0, 0], out_0)
+                return out_0
+
+            @pl.function
+            def main(
+                self,
+                lhs: pl.Tensor[[2, 16, 128], pl.FP16],
+                rhs: pl.Tensor[[128, 64], pl.FP16],
+            ) -> pl.Tensor[[2, 16, 64], pl.FP16]:
+                out_0 = pl.create_tensor([2, 16, 64], dtype=pl.FP16)
+                y = self.main_incore_0(lhs, rhs, out_0)
+                return y
+
+        func = self._flattened_incore(Before)
+        calls = self._top_level_calls(func)
+        # Expect: 1 rhs.load (hoisted), then per-batch (lhs.load, matmul, store) × 2.
+        assert [call.op.name for call in calls] == [
+            "tile.load",  # hoisted rhs
+            "tile.load",  # batch 0 lhs
+            "tile.matmul",
+            "tile.store",
+            "tile.load",  # batch 1 lhs
+            "tile.matmul",
+            "tile.store",
+        ]
+        load_calls = [call for call in calls if call.op.name == "tile.load"]
+        # First load is rhs (2D, no batch shift). Subsequent loads are per-batch lhs.
+        assert [self._tuple_const_values(call.args[1]) for call in load_calls] == [
+            [0, 0],  # hoisted rhs offset
+            [0, 0, 0],  # batch 0 lhs
+            [1, 0, 0],  # batch 1 lhs
+        ]
+        assert [self._tuple_const_values(call.args[2]) for call in load_calls] == [
+            [128, 64],  # rhs shape
+            [1, 16, 128],  # per-batch lhs shape (single batch slice)
+            [1, 16, 128],
+        ]
+
+    def test_batch_matmul_rhs_size1_batch_hoists_rhs_load(self):
+        """lhs [B, M, K] x rhs [1, K, N] (B>1): the rhs hoist also fires for a size-1-batch rhs.
+
+        rhs has a leading batch dim of 1 (broadcast across all output batches), so
+        ExtractBatchPage with batch_idx=0 yields the same page every iteration and
+        the load can be hoisted.
+        """
+
+        @pl.program
+        class Before:
+            @pl.function(type=pl.FunctionType.InCore)
+            def main_incore_0(
+                self,
+                lhs: pl.Tensor[[3, 16, 128], pl.FP16],
+                rhs: pl.Tensor[[1, 128, 64], pl.FP16],
+                out_0: pl.Out[pl.Tensor[[3, 16, 64], pl.FP16]],
+            ) -> pl.Tensor[[3, 16, 64], pl.FP16]:
+                lhs_tile: pl.Tile[[3, 16, 128], pl.FP16] = pl.load(
+                    lhs, [0, 0, 0], [3, 16, 128], target_memory=pl.MemorySpace.Mat
+                )
+                rhs_tile: pl.Tile[[1, 128, 64], pl.FP16] = pl.load(
+                    rhs, [0, 0, 0], [1, 128, 64], target_memory=pl.MemorySpace.Mat
+                )
+                out_tile: pl.Tile[[3, 16, 64], pl.FP32] = pl.tile.batch_matmul(lhs_tile, rhs_tile)
+                out_0 = pl.store(out_tile, [0, 0, 0], out_0)
+                return out_0
+
+            @pl.function
+            def main(
+                self,
+                lhs: pl.Tensor[[3, 16, 128], pl.FP16],
+                rhs: pl.Tensor[[1, 128, 64], pl.FP16],
+            ) -> pl.Tensor[[3, 16, 64], pl.FP16]:
+                out_0 = pl.create_tensor([3, 16, 64], dtype=pl.FP16)
+                y = self.main_incore_0(lhs, rhs, out_0)
+                return y
+
+        func = self._flattened_incore(Before)
+        op_names = [call.op.name for call in self._top_level_calls(func)]
+        # 1 hoisted rhs.load + 3 × (lhs.load, matmul, store).
+        assert op_names == [
+            "tile.load",  # hoisted rhs
+            "tile.load",
+            "tile.matmul",
+            "tile.store",
+            "tile.load",
+            "tile.matmul",
+            "tile.store",
+            "tile.load",
+            "tile.matmul",
+            "tile.store",
+        ]
+        # Exactly one load whose shape is the rhs shape (the rest are lhs).
+        load_calls = [call for call in self._top_level_calls(func) if call.op.name == "tile.load"]
+        rhs_loads = [call for call in load_calls if self._tuple_const_values(call.args[2]) == [1, 128, 64]]
+        assert len(rhs_loads) == 1, f"expected 1 hoisted rhs load, got {len(rhs_loads)}"
+
+    def test_batch_matmul_both_batched_no_hoist(self):
+        """lhs [B, M, K] x rhs [B, K, N] (B>1): rhs varies per batch, so the hoist must NOT fire."""
+
+        @pl.program
+        class Before:
+            @pl.function(type=pl.FunctionType.InCore)
+            def main_incore_0(
+                self,
+                lhs: pl.Tensor[[2, 16, 128], pl.FP16],
+                rhs: pl.Tensor[[2, 128, 64], pl.FP16],
+                out_0: pl.Out[pl.Tensor[[2, 16, 64], pl.FP16]],
+            ) -> pl.Tensor[[2, 16, 64], pl.FP16]:
+                lhs_tile: pl.Tile[[2, 16, 128], pl.FP16] = pl.load(
+                    lhs, [0, 0, 0], [2, 16, 128], target_memory=pl.MemorySpace.Mat
+                )
+                rhs_tile: pl.Tile[[2, 128, 64], pl.FP16] = pl.load(
+                    rhs, [0, 0, 0], [2, 128, 64], target_memory=pl.MemorySpace.Mat
+                )
+                out_tile: pl.Tile[[2, 16, 64], pl.FP32] = pl.tile.batch_matmul(lhs_tile, rhs_tile)
+                out_0 = pl.store(out_tile, [0, 0, 0], out_0)
+                return out_0
+
+            @pl.function
+            def main(
+                self,
+                lhs: pl.Tensor[[2, 16, 128], pl.FP16],
+                rhs: pl.Tensor[[2, 128, 64], pl.FP16],
+            ) -> pl.Tensor[[2, 16, 64], pl.FP16]:
+                out_0 = pl.create_tensor([2, 16, 64], dtype=pl.FP16)
+                y = self.main_incore_0(lhs, rhs, out_0)
+                return y
+
+        func = self._flattened_incore(Before)
+        op_names = [call.op.name for call in self._top_level_calls(func)]
+        # Both batched -> no hoist: 2 × (lhs.load, rhs.load, matmul, store).
+        assert op_names == [
+            "tile.load",
+            "tile.load",
+            "tile.matmul",
+            "tile.store",
+            "tile.load",
+            "tile.load",
+            "tile.matmul",
+            "tile.store",
+        ]
+
 
 # ----------------------------------------------------------------------------
 # tile.batch_matmul_acc lowering
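
For reference, the op sequences asserted in the unit tests above follow mechanically from the hoist rule in LowerBatchMatmul. The following is a minimal Python model of that rule, not code from the pass: lower_batch_matmul_ops and its list-of-strings output are hypothetical stand-ins for the real IR emission, and only the decision prod(rhs_batch_dims) == 1 with batch_count > 1 mirrors the patch.

    # Illustrative model of the batch_matmul unroll with the rhs-hoist rule.
    # All names here are hypothetical; only the hoist condition mirrors the patch.
    from math import prod

    def lower_batch_matmul_ops(out_batch_dims: list[int], rhs_batch_dims: list[int]) -> list[str]:
        """Return the op-name sequence the per-batch unroll would emit."""
        batch_count = prod(out_batch_dims)          # prod([]) == 1
        rhs_batch_count = prod(rhs_batch_dims)

        ops: list[str] = []
        # Hoist: rhs has no effective batch, so every iteration would extract
        # the identical rhs page. Emit one rhs load up front and reuse it.
        hoist_rhs = rhs_batch_count == 1 and batch_count > 1
        if hoist_rhs:
            ops.append("tile.load")                 # hoisted rhs
        for _ in range(batch_count):
            ops.append("tile.load")                 # per-batch lhs page
            if not hoist_rhs:
                ops.append("tile.load")             # per-batch rhs page
            ops.append("tile.matmul")
            ops.append("tile.store")
        return ops

    # lhs [2, M, K] x rhs [K, N]: one hoisted rhs.load, then 2 x (lhs.load,
    # matmul, store) — the sequence test_batch_matmul_rhs_2d_hoists_rhs_load asserts.
    assert lower_batch_matmul_ops([2], []) == [
        "tile.load",
        "tile.load", "tile.matmul", "tile.store",
        "tile.load", "tile.matmul", "tile.store",
    ]
    # lhs [2, M, K] x rhs [2, K, N]: both batched, no hoist.
    assert lower_batch_matmul_ops([2], [2]) == [
        "tile.load", "tile.load", "tile.matmul", "tile.store",
        "tile.load", "tile.load", "tile.matmul", "tile.store",
    ]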
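
The hoist's precondition rests on BuildOperandFlatBatchIndex returning 0 on every iteration when the operand has no effective batch. A sketch of that broadcast rule, assuming it follows standard trailing-dim broadcasting; operand_flat_batch_index is a hypothetical stand-in, not the pass's actual signature.

    def operand_flat_batch_index(operand_batch_dims: list[int], output_batch_indices: list[int]) -> int:
        """Flat batch index of an operand page, broadcasting size-1 dims."""
        # Align trailing dims; missing leading operand dims act like size 1.
        pad = len(output_batch_indices) - len(operand_batch_dims)
        dims = [1] * pad + operand_batch_dims
        flat = 0
        for dim, out_idx in zip(dims, output_batch_indices):
            flat = flat * dim + (out_idx if dim > 1 else 0)  # size-1 dim -> index 0
        return flat

    # rhs [1, K, N] under 3 output batches: always page 0 -> hoistable.
    assert [operand_flat_batch_index([1], [i]) for i in range(3)] == [0, 0, 0]
    # rhs [2, K, N]: index tracks the output batch -> not hoistable.
    assert [operand_flat_batch_index([2], [i]) for i in range(2)] == [0, 1]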