diff --git a/include/pypto/codegen/pto/pto_codegen.h b/include/pypto/codegen/pto/pto_codegen.h
index 39a717480..9a40a1559 100644
--- a/include/pypto/codegen/pto/pto_codegen.h
+++ b/include/pypto/codegen/pto/pto_codegen.h
@@ -496,6 +496,7 @@ class PTOCodegen : public CodegenBase {
   std::map<const ir::Var*, std::string> memref_to_var_name;  ///< keyed by base_ Ptr
   std::vector<std::pair<ir::VarPtr, std::shared_ptr<ir::TileType>>> tile_var_allocs;
   std::set<const ir::Var*> emitted_tile_alloc_vars;
+  std::set<std::string> emitted_tile_alloc_ssas;
   std::map<const ir::Var*, std::string> tpop_result_vars;
   ir::FunctionPtr current_function;
@@ -535,6 +536,7 @@ class PTOCodegen : public CodegenBase {
     memref_to_var_name.clear();
     tile_var_allocs.clear();
     emitted_tile_alloc_vars.clear();
+    emitted_tile_alloc_ssas.clear();
     tpop_result_vars.clear();
     current_function.reset();
diff --git a/src/backend/common/pto_ops_common.cpp b/src/backend/common/pto_ops_common.cpp
index 16cffc53e..31244354c 100644
--- a/src/backend/common/pto_ops_common.cpp
+++ b/src/backend/common/pto_ops_common.cpp
@@ -2365,22 +2365,25 @@ void RegisterPTOOps(Backend& backend, const std::unordered_set<std::string>& exc
   });
 
   // In-place accumulation ops (matmul_acc, gemv_acc): ptoas expects the
-  // accumulator in ins() to be the same SSA value as outs(). InitMemRef
-  // guarantees that the output shares the MemRef of the accumulator input
-  // (via set_output_reuses_input), so we use the result buffer (dst) as the
-  // accumulator operand instead of the IR-level input arg.
+  // accumulator in ins() to be the same SSA value as outs(). Multiple IR tile
+  // vars may share one MemRef but still have distinct per-var alloc_tile SSA
+  // names, so use the accumulator input's SSA value as both src and dst and
+  // bind the assignment result to that same SSA value.
   auto make_acc_codegen = [](const std::string& pto_op) {
     return [pto_op](const ir::CallPtr& op, codegen::CodegenBase& codegen_base) -> std::string {
       auto& codegen = dynamic_cast<codegen::PTOCodegen&>(codegen_base);
       CHECK(op->args_.size() == 3) << pto_op << " requires 3 arguments: acc, lhs, rhs";
-      std::string dst = codegen.GetCurrentResultTarget();
+      std::string dst = codegen.GetExprAsCode(op->args_[0]);
       std::string lhs = codegen.GetExprAsCode(op->args_[1]);
       std::string rhs = codegen.GetExprAsCode(op->args_[2]);
-      std::string dst_type = codegen.GetCurrentResultTileBufTypeString();
+      std::string dst_type = codegen.GetExprTypeAnnotation(op->args_[0]);
       std::string lhs_type = codegen.GetExprTypeAnnotation(op->args_[1]);
       std::string rhs_type = codegen.GetExprTypeAnnotation(op->args_[2]);
+      INTERNAL_CHECK_SPAN(!dst.empty(), op->span_) << pto_op << " accumulator operand has no tile buffer";
+      codegen.SetCurrentResultBuf(dst);
+
       std::ostringstream acc_inst;
       acc_inst << pto_op << " ins(" << dst << ", " << lhs << ", " << rhs;
       std::vector<std::string> ins_type_parts;
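For context, the effect of this change is easiest to see in the emitted PTO assembly. The sketch below is illustrative rather than verbatim codegen output (operand names are invented; the unit test at the end of this patch pins down the real shape of these lines). Before, the accumulator in ins() and the result in outs() could carry different per-var SSA names even when they shared one MemRef; after, the accumulator input's SSA value appears on both sides:

    pto.tmatmul.acc ins(%acc_in, %x, %w : ...) outs(%acc_out : ...)   // before: ins() != outs(), violating what ptoas expects
    pto.tmatmul.acc ins(%acc, %x, %w : ...) outs(%acc : ...)          // after: one shared accumulator handle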
diff --git a/src/codegen/pto/pto_codegen.cpp b/src/codegen/pto/pto_codegen.cpp
index cdd04b3dc..7ee1243af 100644
--- a/src/codegen/pto/pto_codegen.cpp
+++ b/src/codegen/pto/pto_codegen.cpp
@@ -25,6 +25,7 @@
 #include
 #include
 #include
+#include <optional>
 #include
 #include
 
@@ -32,6 +33,7 @@
 #include "pypto/backend/common/backend_config.h"
 #include "pypto/backend/common/backend_handler.h"
 #include "pypto/codegen/pto/pto_type_utils.h"
+#include "pypto/codegen/pto/tile_buf_signature.h"
 #include "pypto/core/dtype.h"
 #include "pypto/core/logging.h"
 #include "pypto/ir/expr.h"
@@ -145,6 +147,42 @@ bool ShouldAliasScatterUpdateResultToInput(const AssignStmtPtr& stmt) {
 
 const auto& FlattenBody = transform_utils::FlattenToStmts;
 
+bool IsInPlaceAccumulatorCall(const CallPtr& call) {
+  if (!call || !call->op_) return false;
+  return call->op_->name_ == "tile.matmul_acc" || call->op_->name_ == "tile.gemv_acc";
+}
+
+bool HasStaticAllocTileShape(const std::shared_ptr<ir::TileType>& tile_type) {
+  if (!tile_type) return false;
+  if (tile_type->shape_.size() == 1) {
+    return As<ir::ConstInt>(tile_type->shape_[0]) != nullptr;
+  }
+  if (tile_type->shape_.size() >= 2) {
+    return As<ir::ConstInt>(tile_type->shape_[0]) != nullptr &&
+           As<ir::ConstInt>(tile_type->shape_[1]) != nullptr;
+  }
+  return false;
+}
+
+bool IsMatmulOperandBuffer(const std::shared_ptr<ir::TileType>& tile_type) {
+  if (!tile_type) return false;
+  return tile_type->memory_space_ == ir::MemorySpace::Left ||
+         tile_type->memory_space_ == ir::MemorySpace::Right;
+}
+
+std::optional<std::string> GetStaticMatmulOperandReuseKey(const std::shared_ptr<ir::TileType>& tile_type,
+                                                          const std::string& type_str) {
+  if (!IsMatmulOperandBuffer(tile_type)) return std::nullopt;
+  if (!HasStaticAllocTileShape(tile_type)) return std::nullopt;
+
+  const auto sig = TileBufSignature::FromTileType(*tile_type);
+  if (sig.v_row_dynamic || sig.v_col_dynamic) return std::nullopt;
+
+  std::ostringstream key;
+  key << type_str << "|v_row=" << sig.v_row << "|v_col=" << sig.v_col;
+  return key.str();
+}
+
 }  // namespace
 
 // Visitor to collect all MemRef objects from TileType variables
@@ -333,20 +371,42 @@ void PTOCodegen::GenerateFunction(const FunctionPtr& func) {
   // Still collect fs_.memref_to_tile_type for GetTileBufTypeString fallback paths
   fs_.memref_to_tile_type = collector.GetMemRefTileTypes();
 
-  // Per-var SSA binding: each tile variable gets its own SSA name
+  // Tile-buffer SSA binding. A PTO tile_buf SSA denotes a mutable tile handle,
+  // not an immutable value. Reuse handles only for static Left/Right matmul
+  // operand buffers that MemoryReuse placed in the same physical L0A/L0B
+  // MemRef; PTOAS otherwise cannot see the WAR/WAW dependency between a tmov
+  // writing the reused buffer and the following tmatmul consuming it. Keep the
+  // per-variable model for Vec/Mat/Acc, dynamic byte offsets, and
+  // shape/view-distinct signatures so unrelated ST kernels preserve their
+  // existing scheduling surface.
+  std::map<std::tuple<const ir::Var*, int64_t, std::string>, std::string> matmul_operand_reuse;
   for (const auto& [tile_var, tile_type] : fs_.tile_var_allocs) {
-    std::string ssa_name = NewNamedTemp(tile_var->name_hint_);
-    BindVarToMlir(tile_var, ssa_name);
-
     // Pre-populate type so body visitors (e.g., tile.reshape no-op check)
     // can query it before per-variable alloc_tile emission runs.
     std::string type_str = GetTileBufTypeStringFromTileType(tile_type);
-    fs_.ssa_to_tile_buf_type[ssa_name] = type_str;
-
     auto memref = ir::GetDefinedMemRef(tile_type);
+    const ir::Var* base_ptr = memref->base_.get();
+
+    std::string ssa_name;
+    auto reuse_key = GetStaticMatmulOperandReuseKey(tile_type, type_str);
+    auto const_offset = As<ir::ConstInt>(memref->byte_offset_);
+    if (reuse_key.has_value() && const_offset && fs_.tpop_result_vars.count(tile_var.get()) == 0) {
+      auto key = std::make_tuple(base_ptr, const_offset->value_, *reuse_key);
+      auto reuse_it = matmul_operand_reuse.find(key);
+      if (reuse_it != matmul_operand_reuse.end()) {
+        ssa_name = reuse_it->second;
+      } else {
+        ssa_name = NewNamedTemp(tile_var->name_hint_);
+        matmul_operand_reuse.emplace(std::move(key), ssa_name);
+      }
+    } else {
+      ssa_name = NewNamedTemp(tile_var->name_hint_);
+    }
+
+    BindVarToMlir(tile_var, ssa_name);
+    fs_.ssa_to_tile_buf_type[ssa_name] = type_str;
 
     // Also maintain fs_.memref_to_mlir for compatibility (first var per allocation)
-    const ir::Var* base_ptr = memref->base_.get();
     if (fs_.memref_to_mlir.find(base_ptr) == fs_.memref_to_mlir.end()) {
       fs_.memref_to_mlir[base_ptr] = ssa_name;
     }
@@ -822,6 +882,9 @@ void PTOCodegen::EmitAllocTileForVar(const ir::VarPtr& tile_var,
   INTERNAL_CHECK_SPAN(mlir_it != fs_.var_to_mlir.end(), tile_var->span_)
       << "Tile var " << tile_var->name_hint_ << " not found in fs_.var_to_mlir";
   std::string tile_buf = mlir_it->second;
+  if (!fs_.emitted_tile_alloc_ssas.insert(tile_buf).second) {
+    return;
+  }
 
   AllocTileFields fields = ComputeAllocTileFields(tile_type);
@@ -1059,7 +1122,7 @@ void PTOCodegen::VisitStmt_(const AssignStmtPtr& op) {
   if (auto tile_type = ir::GetTileTypeWithMemRef(op->var_->GetType())) {
     if (!is_set_validshape && fs_.tpop_result_vars.count(op->var_.get()) == 0 &&
-        !alias_scatter_result_to_input) {
+        !alias_scatter_result_to_input && !IsInPlaceAccumulatorCall(call)) {
       EmitAllocTileForVar(op->var_, tile_type);
     }
   }
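The scheduling hazard that the comment in GenerateFunction describes can be sketched the same way (again illustrative, not verbatim output). With per-variable handles, two K-chunks of a matmul operand that MemoryReuse placed at the same L0A address get distinct SSA names, so nothing ties the tmov that overwrites the buffer to the tmatmul that is still reading it:

    %x0 = alloc_tile ...   // addr = %c0_i64, loc=left
    %x1 = alloc_tile ...   // addr = %c0_i64, loc=left -- same storage, second handle
    pto.tmatmul ins(%x0, %w0 : ...) outs(%acc : ...)
    pto.tmov ... outs(%x1 : ...)                       // WAR against the tmatmul above is invisible
    pto.tmatmul.acc ins(%acc, %x1, %w1 : ...) outs(%acc : ...)

With the reuse key, both chunks bind to a single handle and the WAR/WAW ordering flows through one SSA name, while Vec/Mat/Acc buffers, dynamic byte offsets, and shape/view-distinct signatures keep the per-variable model.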
diff --git a/tests/st/runtime/test_matmul.py b/tests/st/runtime/test_matmul.py
index 570ecefc9..8a7ad48bb 100644
--- a/tests/st/runtime/test_matmul.py
+++ b/tests/st/runtime/test_matmul.py
@@ -23,6 +23,7 @@ class can run on multiple platforms via ``@pytest.mark.parametrize``.
 import torch
 from examples.kernels.matmul import matmul_acc_64
 from harness.core.harness import PLATFORMS, DataType, PTOTestCase, TensorSpec
+from pypto.runtime.runner import RunConfig
 
 
 class TestMatmul(PTOTestCase):
@@ -463,6 +464,70 @@ def compute_expected(self, tensors, params=None):
         tensors["c"][:] = torch.matmul(tensors["a"].to(torch.float32), tensors["b"].to(torch.float32))
 
 
+class TestMatmulAccBTransposeNopad(PTOTestCase):
+    """Issue #1213: C = X @ W^T with b_trans=True and K split across matmul_acc."""
+
+    __test__ = False
+
+    def __init__(
+        self,
+        dtype: DataType = DataType.FP32,
+        *,
+        platform: str | None = None,
+        config=None,
+    ):
+        super().__init__(config or RunConfig(rtol=4e-3, atol=4e-3), platform=platform)
+        self.M = 16
+        self.K = 1024
+        self.N = 32
+        self.K_CHUNK = 512
+        self.dtype = dtype
+
+    def get_name(self) -> str:
+        return f"matmulacc_btranspose_nopad_{self.dtype.value}_{self.M}x{self.K}x{self.N}"
+
+    def define_tensors(self) -> list[TensorSpec]:
+        return [
+            TensorSpec("x", [self.M, self.K], self.dtype, init_value=torch.randn),
+            TensorSpec("w", [self.N, self.K], self.dtype, init_value=torch.randn),
+            TensorSpec("out", [self.M, self.N], DataType.FP32, is_output=True),
+        ]
+
+    def get_program(self) -> Any:
+        M, K, N, K_CHUNK = self.M, self.K, self.N, self.K_CHUNK
+        K_BLOCKS = K // K_CHUNK
+        elem_dtype = pl.FP32 if self.dtype is DataType.FP32 else pl.BF16
+
+        @pl.program
+        class MatmulAccBTransposeNopadProgram:
+            @pl.function(type=pl.FunctionType.Opaque)
+            def main(
+                self,
+                x: pl.Tensor[[M, K], elem_dtype],
+                w: pl.Tensor[[N, K], elem_dtype],
+                out: pl.Out[pl.Tensor[[M, N], pl.FP32]],
+            ):
+                with pl.at(
+                    level=pl.Level.CORE_GROUP,
+                    optimization=pl.chunked_loop_optimizer,
+                    name_hint="linear",
+                ):
+                    x0 = pl.slice(x, [M, K_CHUNK], [0, 0])
+                    w0 = pl.slice(w, [N, K_CHUNK], [0, 0])
+                    acc = pl.matmul(x0, w0, b_trans=True, out_dtype=pl.FP32)
+                    for kb in pl.range(1, K_BLOCKS):
+                        k0 = kb * K_CHUNK
+                        x_chunk = pl.slice(x, [M, K_CHUNK], [0, k0])
+                        w_chunk = pl.slice(w, [N, K_CHUNK], [0, k0])
+                        acc = pl.matmul_acc(acc, x_chunk, w_chunk, b_trans=True)
+                    out = pl.assemble(out, acc, [0, 0])
+
+        return MatmulAccBTransposeNopadProgram
+
+    def compute_expected(self, tensors, params=None):
+        tensors["out"][:] = tensors["x"].float() @ tensors["w"].float().T
+
+
 # =============================================================================
 # pytest test functions
 # =============================================================================
@@ -484,6 +549,8 @@
 # (BATCH=16, K_CHUNK=128, OUT_CHUNK=256). Same 2-iter K-loop, BF16 inputs +
 # FP32 accumulator.
 _AUTOL0_BF16_SHAPES = [(16, 128, 256)]
+_ISSUE1213_DTYPES = [pytest.param(DataType.FP32, id="fp32"), pytest.param(DataType.BF16, id="bf16")]
+_A2A3_ONLY = [pytest.param("a2a3", id="a2a3")]
 
 
 class TestMatmulOperations:
@@ -559,6 +626,13 @@ def test_matmul_outer_pipelined_bf16(self, test_runner, platform):
         result = test_runner.run(TestMatmulOuterPipelinedBF16(platform=platform))
         assert result.passed, f"Test failed: {result.error}"
 
+    @pytest.mark.parametrize("platform", _A2A3_ONLY)
+    @pytest.mark.parametrize("dtype", _ISSUE1213_DTYPES)
+    def test_matmulacc_btranspose_nopad_issue1213(self, test_runner, platform, dtype):
+        """Regression for b_trans=True matmul_acc over reused L0A/L0B buffers."""
+        result = test_runner.run(TestMatmulAccBTransposeNopad(dtype=dtype, platform=platform))
+        assert result.passed, f"Test failed: {result.error}"
+
 
 if __name__ == "__main__":
     pytest.main([__file__, "-v"])
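The unit tests that follow pin both halves of this contract directly on the emitted alloc_tile lines. Schematically (invented names; the tests only assert on the substrings they match, not on this exact layout):

    %lhs_tile = alloc_tile ...    // addr = %c0_i64, loc=vec
    %rhs_tile = alloc_tile ...    // addr = %c0_i64, loc=vec -- shared Vec storage keeps two handles
    %lhs_first = alloc_tile ...   // addr = %c0_i64, loc=left
    %lhs_second = alloc_tile ...  // addr = %c32768_i64, loc=left -- distinct offset, distinct handle

Only static Left/Right tiles with the same base, the same constant byte offset, and the same static signature collapse to one handle, and emitted_tile_alloc_ssas ensures that a shared handle's alloc_tile is emitted only once.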
diff --git a/tests/ut/codegen/test_pto_codegen.py b/tests/ut/codegen/test_pto_codegen.py
index 3b6f16020..7ce087d5a 100644
--- a/tests/ut/codegen/test_pto_codegen.py
+++ b/tests/ut/codegen/test_pto_codegen.py
@@ -102,6 +102,19 @@ def _generate_default_mlir(program_cls) -> str:
     return _generate_mlir(_run_default_passes(program_cls))
 
 
+def _generate_default_mlir_for_func_type(program_cls, func_type: ir.FunctionType) -> str:
+    """Run default passes and generate MLIR for the single function of func_type."""
+    transformed = _run_default_passes(program_cls)
+    funcs = [func for func in transformed.functions.values() if func.func_type == func_type]
+    assert len(funcs) == 1, (
+        f"Expected exactly one {func_type} function, got "
+        f"{[(func.name, func.func_type) for func in transformed.functions.values()]}"
+    )
+    func = funcs[0]
+    program_name = getattr(program_cls, "__name__", "program")
+    return _generate_mlir(ir.Program([func], f"{program_name}_{func_type.name}", func.span))
+
+
 def _get_mlir_lines(mlir_code: str) -> list[str]:
     """Return stripped MLIR lines for line-oriented assertions."""
     return [line.strip() for line in mlir_code.splitlines()]
@@ -377,6 +390,150 @@
 )
 
 
+def test_pto_codegen_vec_shared_static_memref_keeps_per_var_ssa_handles():
+    """Static Vec tiles sharing storage must keep distinct handles for ST scheduling."""
+    span = ir.Span.unknown()
+    zero = ir.ConstInt(0, DataType.INDEX, span)
+    offset = ir.ConstInt(4096, DataType.INT64, span)
+    size = ir.ConstInt(32, DataType.INDEX, span)
+
+    lhs_tensor = ir.Var("lhs", ir.TensorType([32, 32], DataType.FP32), span)
+    rhs_tensor = ir.Var("rhs", ir.TensorType([32, 32], DataType.FP32), span)
+    output_tensor = ir.Var("out", ir.TensorType([32, 32], DataType.FP32), span)
+    shared_memref = ir.MemRef(ir.MemorySpace.Vec, zero, 32 * 32 * 4, 0)
+    result_memref = ir.MemRef(ir.MemorySpace.Vec, offset, 32 * 32 * 4, 1)
+
+    static_view = ir.TileView(valid_shape=[size, size])
+    lhs_tile_type = ir.TileType([32, 32], DataType.FP32, shared_memref, static_view, ir.MemorySpace.Vec)
+    rhs_tile_type = ir.TileType([32, 32], DataType.FP32, shared_memref, static_view, ir.MemorySpace.Vec)
+    result_tile_type = ir.TileType([32, 32], DataType.FP32, result_memref, static_view, ir.MemorySpace.Vec)
+    lhs_tile = ir.Var("lhs_tile", lhs_tile_type, span)
+    rhs_tile = ir.Var("rhs_tile", rhs_tile_type, span)
+    result_tile = ir.Var("result_tile", result_tile_type, span)
+    result_tensor = ir.Var("result", ir.TensorType([32, 32], DataType.FP32), span)
+
+    offsets = ir.MakeTuple([zero, zero], span)
+    shapes = ir.MakeTuple([size, size], span)
+    load_lhs = ir.Call(ir.Op("tile.load"), [lhs_tensor, offsets, shapes], {}, lhs_tile_type, span)
+    load_rhs = ir.Call(ir.Op("tile.load"), [rhs_tensor, offsets, shapes], {}, rhs_tile_type, span)
+    add = ir.Call(ir.Op("tile.add"), [lhs_tile, rhs_tile], {}, result_tile_type, span)
+    store = ir.Call(ir.Op("tile.store"), [result_tile, offsets, output_tensor], result_tensor.type, span)
+
+    body = ir.SeqStmts(
+        [
+            ir.AssignStmt(lhs_tile, load_lhs, span),
+            ir.AssignStmt(rhs_tile, load_rhs, span),
+            ir.AssignStmt(result_tile, add, span),
+            ir.AssignStmt(result_tensor, store, span),
+            ir.ReturnStmt([result_tensor], span),
+        ],
+        span,
+    )
+    func = ir.Function(
+        "vec_shared_static",
+        [
+            (lhs_tensor, ir.ParamDirection.In),
+            (rhs_tensor, ir.ParamDirection.In),
+            (output_tensor, ir.ParamDirection.Out),
+        ],
+        [ir.TensorType([32, 32], DataType.FP32)],
+        body,
+        span,
+        ir.FunctionType.InCore,
+    )
+
+    alloc_lines = _get_alloc_tile_lines(_generate_mlir(ir.Program([func], "vec_shared_static", span)))
+    shared_allocs = [line for line in alloc_lines if "addr = %c0_i64" in line]
+    assert len(shared_allocs) == 2, f"Expected distinct Vec handles for shared storage: {alloc_lines}"
+    assert shared_allocs[0].split(" = ", 1)[0] != shared_allocs[1].split(" = ", 1)[0]
+
+
+def test_pto_codegen_matmul_operand_reuse_respects_byte_offset():
+    """Left/Right tiles with one base but different offsets need distinct handles."""
+    span = ir.Span.unknown()
+    zero = ir.ConstInt(0, DataType.INT64, span)
+    second_tile_offset = ir.ConstInt(16 * 512 * 4, DataType.INT64, span)
+    index_zero = ir.ConstInt(0, DataType.INDEX, span)
+    rows = ir.ConstInt(16, DataType.INDEX, span)
+    cols = ir.ConstInt(512, DataType.INDEX, span)
+
+    lhs_tensor = ir.Var("lhs", ir.TensorType([32, 512], DataType.FP32), span)
+    rhs_tensor = ir.Var("rhs", ir.TensorType([32, 512], DataType.FP32), span)
+    out_tensor = ir.Var("out", ir.TensorType([16, 16], DataType.FP32), span)
+
+    lhs_first_memref = ir.MemRef(ir.MemorySpace.Left, zero, 16 * 512 * 4, 0)
+    lhs_second_memref = ir.MemRef(lhs_first_memref.base_, second_tile_offset, 16 * 512 * 4)
+    rhs_first_memref = ir.MemRef(ir.MemorySpace.Right, zero, 16 * 512 * 4, 1)
+    rhs_second_memref = ir.MemRef(rhs_first_memref.base_, second_tile_offset, 16 * 512 * 4)
+    acc_memref = ir.MemRef(ir.MemorySpace.Acc, zero, 16 * 16 * 4, 2)
+
+    lhs_first_type = ir.TileType([16, 512], DataType.FP32, lhs_first_memref, None, ir.MemorySpace.Left)
+    lhs_second_type = ir.TileType([16, 512], DataType.FP32, lhs_second_memref, None, ir.MemorySpace.Left)
+    rhs_first_type = ir.TileType([16, 512], DataType.FP32, rhs_first_memref, None, ir.MemorySpace.Right)
+    rhs_second_type = ir.TileType([16, 512], DataType.FP32, rhs_second_memref, None, ir.MemorySpace.Right)
+    acc_type = ir.TileType(
+        [16, 16],
+        DataType.FP32,
+        acc_memref,
+        ir.TileView(valid_shape=[rows, rows]),
+        ir.MemorySpace.Acc,
+    )
+
+    lhs_first = ir.Var("lhs_first", lhs_first_type, span)
+    lhs_second = ir.Var("lhs_second", lhs_second_type, span)
+    rhs_first = ir.Var("rhs_first", rhs_first_type, span)
+    rhs_second = ir.Var("rhs_second", rhs_second_type, span)
+    acc = ir.Var("acc", acc_type, span)
+    result_tensor = ir.Var("result", ir.TensorType([16, 16], DataType.FP32), span)
+
+    offsets = ir.MakeTuple([index_zero, index_zero], span)
+    shapes = ir.MakeTuple([rows, cols], span)
+    acc_offsets = ir.MakeTuple([index_zero, index_zero], span)
+    load_lhs_first = ir.Call(ir.Op("tile.load"), [lhs_tensor, offsets, shapes], {}, lhs_first_type, span)
+    load_lhs_second = ir.Call(ir.Op("tile.load"), [lhs_tensor, offsets, shapes], {}, lhs_second_type, span)
+    load_rhs_first = ir.Call(ir.Op("tile.load"), [rhs_tensor, offsets, shapes], {}, rhs_first_type, span)
+    load_rhs_second = ir.Call(ir.Op("tile.load"), [rhs_tensor, offsets, shapes], {}, rhs_second_type, span)
+    matmul = ir.Call(ir.Op("tile.matmul"), [lhs_first, rhs_first], {}, acc_type, span)
+    store = ir.Call(ir.Op("tile.store"), [acc, acc_offsets, out_tensor], result_tensor.type, span)
+    body = ir.SeqStmts(
+        [
+            ir.AssignStmt(lhs_first, load_lhs_first, span),
+            ir.AssignStmt(lhs_second, load_lhs_second, span),
+            ir.AssignStmt(rhs_first, load_rhs_first, span),
+            ir.AssignStmt(rhs_second, load_rhs_second, span),
+            ir.AssignStmt(acc, matmul, span),
+            ir.AssignStmt(result_tensor, store, span),
+            ir.ReturnStmt([result_tensor], span),
+        ],
+        span,
+    )
+    func = ir.Function(
+        "matmul_operand_offsets",
+        [
+            (lhs_tensor, ir.ParamDirection.In),
+            (rhs_tensor, ir.ParamDirection.In),
+            (out_tensor, ir.ParamDirection.Out),
+        ],
+        [ir.TensorType([16, 16], DataType.FP32)],
+        body,
+        span,
+        ir.FunctionType.InCore,
+    )
+
+    alloc_lines = _get_alloc_tile_lines(_generate_mlir(ir.Program([func], "matmul_operand_offsets", span)))
+    lhs_allocs = [line for line in alloc_lines if "loc=left" in line]
+    rhs_allocs = [line for line in alloc_lines if "loc=right" in line]
+
+    assert len(lhs_allocs) == 2, f"Expected separate L0A handles for distinct offsets: {alloc_lines}"
+    assert len(rhs_allocs) == 2, f"Expected separate L0B handles for distinct offsets: {alloc_lines}"
+    assert lhs_allocs[0].split(" = ", 1)[0] != lhs_allocs[1].split(" = ", 1)[0]
+    assert rhs_allocs[0].split(" = ", 1)[0] != rhs_allocs[1].split(" = ", 1)[0]
+    assert any("addr = %c0_i64" in line for line in lhs_allocs)
+    assert any("addr = %c32768_i64" in line for line in lhs_allocs)
+    assert any("addr = %c0_i64" in line for line in rhs_allocs)
+    assert any("addr = %c32768_i64" in line for line in rhs_allocs)
+
+
 def test_pto_codegen_fillpad_inplace():
     """Test that tile.fillpad_inplace emits pto.tfillpad and shares MemRef with input."""
     span = ir.Span.unknown()
@@ -1577,6 +1734,57 @@ def mixed(
     assert "index" in yield_line, f"Expected index type in scf.yield: {yield_line}"
 
 
+def test_pto_codegen_matmul_acc_uses_loop_carried_accumulator_buffer():
+    """matmul_acc in a tile-carried loop must preserve shared tile-buffer SSA dependencies."""
+
+    @pl.program
+    class MatmulAccLoopProgram:
+        @pl.function(type=pl.FunctionType.Opaque)
+        def main(
+            self,
+            x: pl.Tensor[[16, 1024], pl.FP32],
+            w: pl.Tensor[[32, 1024], pl.FP32],
+            out: pl.Out[pl.Tensor[[16, 32], pl.FP32]],
+        ):
+            with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="linear"):
+                x0 = pl.slice(x, [16, 512], [0, 0])
+                w0 = pl.slice(w, [32, 512], [0, 0])
+                acc = pl.matmul(x0, w0, b_trans=True, out_dtype=pl.FP32)
+                for kb in pl.range(1, 2):
+                    k0 = kb * 512
+                    x_chunk = pl.slice(x, [16, 512], [0, k0])
+                    w_chunk = pl.slice(w, [32, 512], [0, k0])
+                    acc = pl.matmul_acc(acc, x_chunk, w_chunk, b_trans=True)
+                out = pl.assemble(out, acc, [0, 0])
+
+    mlir_code = _generate_default_mlir_for_func_type(MatmulAccLoopProgram, ir.FunctionType.AIC)
+    lines = _get_mlir_lines(mlir_code)
+
+    acc_line = _single_line(lines, "pto.tmatmul.acc", startswith=True)
+    matmul_line = _single_line(lines, "pto.tmatmul ins", startswith=True)
+    store_line = _single_line(lines, "pto.tstore", startswith=True)
+
+    acc_ins = re.search(r"ins\((%[\w\d_]+),", acc_line)
+    acc_outs = re.search(r"outs\((%[\w\d_]+) :", acc_line)
+    store_ins = re.search(r"ins\((%[\w\d_]+) :", store_line)
+    assert acc_ins and acc_outs and store_ins, (
+        f"Expected tile buffer operands in matmul_acc/store:\n{acc_line}\n{store_line}"
+    )
+
+    assert acc_ins.group(1) == acc_outs.group(1), f"matmul_acc ins/outs must share accumulator: {acc_line}"
+    assert acc_outs.group(1) == store_ins.group(1), (
+        f"the final store must read the loop-carried accumulator buffer, got:\n{acc_line}\n{store_line}"
+    )
+
+    matmul_inputs = re.search(r"ins\((%[\w\d_]+), (%[\w\d_]+)", matmul_line)
+    acc_inputs = re.search(r"ins\(%[\w\d_]+, (%[\w\d_]+), (%[\w\d_]+)", acc_line)
+    assert matmul_inputs and acc_inputs, f"Expected L0 input operands:\n{matmul_line}\n{acc_line}"
+    assert matmul_inputs.groups() == acc_inputs.groups(), (
+        "matmul_acc must reuse the same L0A/L0B tile-buffer SSA handles when MemoryReuse "
+        f"reuses those MemRefs:\n{matmul_line}\n{acc_line}"
+    )
+
+
 def test_pto_codegen_slice_fillpad_partial_dynamic_valid_shape():
     """Slice with partially dynamic valid_shape followed by fillpad must lower without
     a scratch slice alloc, and feed a dynamic valid_shape tile into