diff --git a/src/backend/common/pto_ops_common.cpp b/src/backend/common/pto_ops_common.cpp
index e41426c82..258187e2a 100644
--- a/src/backend/common/pto_ops_common.cpp
+++ b/src/backend/common/pto_ops_common.cpp
@@ -42,6 +42,7 @@
 #include "pypto/ir/kind_traits.h"
 #include "pypto/ir/scalar_expr.h"
 #include "pypto/ir/transforms/utils/memref_utils.h"
+#include "pypto/ir/transforms/utils/tile_view_semantics.h"
 #include "pypto/ir/type.h"
 
 namespace pypto {
@@ -1038,9 +1039,8 @@ static std::string MakeTileStoreCodegenPTO(const CallPtr& op, codegen::CodegenBa
   auto tile_type = As(tile->GetType());
   INTERNAL_CHECK_SPAN(tile_type, op->span_) << "tile.store first argument must have TileType";
-  INTERNAL_CHECK_SPAN(tile_type->tile_view_.has_value(), op->span_)
-      << "tile.store tile must have TileView with valid_shape";
-  const auto tile_view = tile_type->tile_view_.value_or(ir::TileView{});
+  const auto tile_view = ir::tile_view_semantics::NormalizeImplicitTileView(
+      tile_type->tile_view_, tile_type->shape_, tile_type->memory_space_);
   const auto& valid_shape = tile_view.valid_shape;
   INTERNAL_CHECK_SPAN(valid_shape.size() == 2, op->span_) << "tile.store tile valid_shape must be 2D";
diff --git a/src/codegen/pto/pto_type_utils.cpp b/src/codegen/pto/pto_type_utils.cpp
index d6e7f71e9..b041080e4 100644
--- a/src/codegen/pto/pto_type_utils.cpp
+++ b/src/codegen/pto/pto_type_utils.cpp
@@ -21,6 +21,7 @@
 #include "pypto/ir/expr.h"
 #include "pypto/ir/kind_traits.h"
 #include "pypto/ir/scalar_expr.h"
+#include "pypto/ir/transforms/utils/tile_view_semantics.h"
 #include "pypto/ir/type.h"
 
 namespace pypto {
@@ -132,15 +133,12 @@ TileTypeComponents ExtractTileTypeInfo(const ir::TileType& tile_type, const std:
   c.v_row_dynamic = true;
   c.v_col_dynamic = true;
 
-  if (tile_type.tile_view_.has_value()) {
-    const auto& tv = *tile_type.tile_view_;
-    c.blayout = tv.blayout;
-    c.slayout = tv.slayout;
-    c.fractal = tv.fractal;
-    c.pad = tv.pad;
-  } else if (c.cols == 1 && c.rows > 1) {
-    c.blayout = ir::TileLayout::col_major;
-  }
+  auto tv = ir::tile_view_semantics::NormalizeImplicitTileView(tile_type.tile_view_, tile_type.shape_,
+                                                               tile_type.memory_space_);
+  c.blayout = tv.blayout;
+  c.slayout = tv.slayout;
+  c.fractal = tv.fractal;
+  c.pad = tv.pad;
 
   return c;
 }
diff --git a/src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp b/src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp
index 787657177..fccfd970e 100644
--- a/src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp
+++ b/src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp
@@ -155,6 +155,15 @@ int64_t MultiplyStaticDims(const std::vector& dims, const std::string&
   return product;
 }
 
+/// Return true when an ND tile has only singleton batch dimensions, i.e.
+/// [...batch, M, N] is effectively one 2D matrix page.
+bool HasSingletonBatchDims(const TileTypePtr& tile_type, const std::string& context) {
+  if (!tile_type || tile_type->shape_.size() <= 2) return false;
+  std::vector batch_shape(tile_type->shape_.begin(), tile_type->shape_.end() - 2);
+  auto batch_dims = ToStaticDims(batch_shape, context + " batch");
+  return MultiplyStaticDims(batch_dims, context + " batch size") == 1;
+}
+
 /// Decompose a flat batch index into per-dimension indices for the given batch shape.
 /// e.g. flat_index=5 with batch_shape=[2,3] → indices=[1,2].
 std::vector BuildBatchIndices(int64_t flat_index, const std::vector& batch_shape) {
@@ -218,6 +227,58 @@ bool IsTrailingMatrixAxisSwap(int64_t axis1, int64_t axis2, size_t ndim) {
          (axis1 == trailing_axis1 && axis2 == trailing_axis0);
 }
 
+std::vector> WithTargetMemory(
+    const std::vector>& kwargs, MemorySpace target_memory) {
+  auto updated = kwargs;
+  for (auto& kv : updated) {
+    if (kv.first == "target_memory") {
+      kv.second = target_memory;
+      return updated;
+    }
+  }
+  updated.emplace_back("target_memory", target_memory);
+  return updated;
+}
+
+void CollectBatchMatmulAccOperands(const std::vector& stmts,
+                                   std::unordered_set& acc_operands) {
+  for (const auto& stmt : stmts) {
+    if (auto assign = As(stmt)) {
+      if (auto call = As(assign->value_)) {
+        if (call->op_ && call->op_->name_ == "tile.batch_matmul_acc" && !call->args_.empty()) {
+          if (auto acc = AsVarLike(call->args_[0])) {
+            acc_operands.insert(acc.get());
+          }
+        }
+      }
+      continue;
+    }
+    if (auto seq = As(stmt)) {
+      CollectBatchMatmulAccOperands(seq->stmts_, acc_operands);
+      continue;
+    }
+    if (auto scope = As(stmt)) {
+      CollectBatchMatmulAccOperands(FlattenToStmts(scope->body_), acc_operands);
+      continue;
+    }
+    if (auto if_stmt = As(stmt)) {
+      CollectBatchMatmulAccOperands(FlattenToStmts(if_stmt->then_body_), acc_operands);
+      if (if_stmt->else_body_.has_value()) {
+        CollectBatchMatmulAccOperands(FlattenToStmts(*if_stmt->else_body_), acc_operands);
+      }
+      continue;
+    }
+    if (auto for_stmt = As(stmt)) {
+      CollectBatchMatmulAccOperands(FlattenToStmts(for_stmt->body_), acc_operands);
+      continue;
+    }
+    if (auto while_stmt = As(stmt)) {
+      CollectBatchMatmulAccOperands(FlattenToStmts(while_stmt->body_), acc_operands);
+      continue;
+    }
+  }
+}
+
 // ============================================================================
 // Precondition validation
 // ============================================================================
@@ -628,11 +689,21 @@ BatchPageResult ExtractBatchPage(const BatchOperandInfo& info, const std::vector
   } else if (operand_type->shape_.size() == 2) {
     // Strategy 2: Slice from already-flattened 2D tile.
-    auto offset = MakeShapeTupleFromInts({batch_index * source_rows, 0}, span);
-    auto shape = MakeShapeTupleFromInts({source_rows, source_cols}, span);
-    auto slice = op_registry.Create("tile.slice", {operand, shape, offset}, span);
-    current = std::make_shared(base_name + "_slice_" + suffix, slice->GetType(), span);
-    page.stmts.push_back(std::make_shared(current, slice, span));
+    auto flat_rows = As(operand_type->shape_[0]);
+    auto flat_cols = As(operand_type->shape_[1]);
+    if (batch_index == 0 && flat_rows && flat_cols && flat_rows->value_ == source_rows &&
+        flat_cols->value_ == source_cols) {
+      // Singleton-batch operands are already the exact 2D page. Avoid a full-tile
+      // slice, which would lower to an unsupported Mat->Mat tmov on a2a3.
+      current = AsVarLike(operand);
+      CHECK(current) << "FlattenTileNdTo2D: expected 2D batch_matmul operand to be Var-like";
+    } else {
+      auto offset = MakeShapeTupleFromInts({batch_index * source_rows, 0}, span);
+      auto shape = MakeShapeTupleFromInts({source_rows, source_cols}, span);
+      auto slice = op_registry.Create("tile.slice", {operand, shape, offset}, span);
+      current = std::make_shared(base_name + "_slice_" + suffix, slice->GetType(), span);
+      page.stmts.push_back(std::make_shared(current, slice, span));
+    }
   } else {
     // Strategy 3: rank>2 tile.slice + tile.reshape to 2D.
@@ -1151,6 +1222,7 @@ std::vector TransformBody(const std::vector& stmts, FlattenCon
   // If conditions, For/While bounds, etc.), not just Call arguments. A Var used
   // anywhere outside a tile.batch_matmul Call prevents it from being skipped.
   std::unordered_set batch_matmul_only_vars;
+  std::unordered_set singleton_batch_acc_init_vars;
   {
     std::unordered_map use_count;
     std::vector batch_matmul_operands;  // ordered to avoid nondeterministic iteration
@@ -1222,13 +1294,33 @@
       CountVarRefs(for_stmt->start_);
       CountVarRefs(for_stmt->stop_);
       CountVarRefs(for_stmt->step_);
-      for (const auto& ia : for_stmt->iter_args_) CountVarRefs(ia->initValue_);
+      std::unordered_set batch_matmul_acc_operands;
+      CollectBatchMatmulAccOperands(FlattenToStmts(for_stmt->body_), batch_matmul_acc_operands);
+      for (const auto& ia : for_stmt->iter_args_) {
+        CountVarRefs(ia->initValue_);
+        auto init_var = As(ia->initValue_);
+        auto init_tile = init_var ? As(init_var->GetType()) : nullptr;
+        if (init_var && HasSingletonBatchDims(init_tile, "tile.batch_matmul_acc init") &&
+            batch_matmul_acc_operands.count(ia.get())) {
+          singleton_batch_acc_init_vars.insert(init_var.get());
+        }
+      }
       continue;
     }
     // WhileStmt: count condition and iter_arg init Var refs.
     if (auto while_stmt = As(s)) {
       CountVarRefs(while_stmt->condition_);
-      for (const auto& ia : while_stmt->iter_args_) CountVarRefs(ia->initValue_);
+      std::unordered_set batch_matmul_acc_operands;
+      CollectBatchMatmulAccOperands(FlattenToStmts(while_stmt->body_), batch_matmul_acc_operands);
+      for (const auto& ia : while_stmt->iter_args_) {
+        CountVarRefs(ia->initValue_);
+        auto init_var = As(ia->initValue_);
+        auto init_tile = init_var ? As(init_var->GetType()) : nullptr;
+        if (init_var && HasSingletonBatchDims(init_tile, "tile.batch_matmul_acc init") &&
+            batch_matmul_acc_operands.count(ia.get())) {
+          singleton_batch_acc_init_vars.insert(init_var.get());
+        }
+      }
       continue;
     }
   }
@@ -1599,7 +1691,26 @@ std::vector TransformBody(const std::vector& stmts, FlattenCon
         new_args.push_back(Substitute(call->args_[i], ctx.var_map));
       }
 
-      auto new_call = op_registry.Create(op_name, new_args, call->kwargs_, span);
+      auto new_kwargs = call->kwargs_;
+      bool force_acc_init = singleton_batch_acc_init_vars.count(assign->var_.get()) != 0;
+      if (force_acc_init) {
+        // Batch=1 matmul_acc lowers to tile.matmul_acc; keep its loop init in Acc
+        // instead of round-tripping a dummy accumulator through Vec.
+        new_kwargs = WithTargetMemory(new_kwargs, MemorySpace::Acc);
+      }
+      auto new_call = op_registry.Create(op_name, new_args, new_kwargs, span);
+      if (force_acc_init) {
+        auto new_call_tile = As(new_call->GetType());
+        CHECK(new_call_tile) << "FlattenTileNdTo2D: expected flattened accumulator init to be TileType";
+        // Refresh the old N-D/default TileView when changing the dummy init
+        // to Acc so memory reuse sees the same layout as matmul outputs.
+        auto acc_view = tile_view_semantics::GetImplicitTileView(new_call_tile->shape_, MemorySpace::Acc);
+        auto acc_type =
+            std::make_shared(new_call_tile->shape_, new_call_tile->dtype_, new_call_tile->memref_,
+                             std::move(acc_view), MemorySpace::Acc);
+        new_call = std::make_shared(new_call->op_, new_call->args_, new_call->kwargs_,
+                                    std::move(acc_type), new_call->span_);
+      }
       auto flat_var = std::make_shared(assign->var_->name_hint_, new_call->GetType(), assign->var_->span_);
 
       result.push_back(std::make_shared(flat_var, new_call, assign->span_));
diff --git a/tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py b/tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py
index db3a72c71..f3736e78d 100644
--- a/tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py
+++ b/tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py
@@ -1713,6 +1713,81 @@ def collect_names(prog: ir.Program) -> list[str]:
         assert "tile.batch_matmul_acc" not in names_flatten
         assert names_flatten.count("tile.matmul") == 1
         assert names_flatten.count("tile.matmul_acc") == 1
+        assert "tile.slice" not in names_flatten
+
+    def test_singleton_batch_acc_init_stays_acc_in_loop(self):
+        """A batch=1 dummy accumulator init should flatten directly to Acc.
+
+        This covers the common ``acc = create; for k: if first matmul else
+        matmul_acc`` style. When the 3D RHS selects ``tile.batch_matmul_acc``,
+        flattening the singleton-batch init to Vec forces an unnecessary
+        Vec/Acc round-trip on the loop-carried accumulator.
+        """
+
+        ib = IRBuilder()
+        span = ir.Span.unknown()
+        with ib.program("main") as prog:
+            incore_gvar = prog.declare_function("main_incore_0")
+            prog.declare_function("main")
+
+            with ib.function("main_incore_0", type=ir.FunctionType.InCore) as f:
+                h = f.param("h", ir.TensorType([16, 128], DataType.BF16))
+                w = f.param("w", ir.TensorType([1, 64, 128], DataType.BF16))
+                out_p = f.param(
+                    "out_0",
+                    ir.TensorType([1, 16, 64], DataType.FP32),
+                    direction=ir.ParamDirection.Out,
+                )
+                f.return_type(ir.TensorType([1, 16, 64], DataType.FP32))
+
+                acc_init = ib.let("acc_init", tile_ops.create([1, 16, 64], DataType.FP32))
+                kb = ib.var("kb", ir.ScalarType(DataType.INDEX), span)
+                with ib.for_loop(kb, 0, 2, 1) as loop:
+                    acc_iter = loop.iter_arg("acc", acc_init)
+                    loop.return_var("acc_out")
+                    lhs = ib.let(
+                        "lhs",
+                        tile_ops.load(h, [0, 0], [16, 128], target_memory=ir.MemorySpace.Mat),
+                    )
+                    rhs = ib.let(
+                        "rhs",
+                        tile_ops.load(
+                            w,
+                            [0, 0, 0],
+                            [1, 64, 128],
+                            target_memory=ir.MemorySpace.Mat,
+                            transpose=True,
+                        ),
+                    )
+                    acc_next = ib.let("acc_next", tile_ops.batch_matmul_acc(acc_iter, lhs, rhs))
+                    ib.emit(ir.YieldStmt([acc_next], span))
+                acc_final = loop.output()
+                out_r = ib.let("out_0", tile_ops.store(acc_final, [0, 0, 0], out_p))
+                ib.return_stmt(out_r)
+            prog.add_function(f.get_result())
+
+            with ib.function("main") as f:
+                h = f.param("h", ir.TensorType([16, 128], DataType.BF16))
+                w = f.param("w", ir.TensorType([1, 64, 128], DataType.BF16))
+                f.return_type(ir.TensorType([1, 16, 64], DataType.FP32))
+                out_v = ib.let("out_0", tensor_ops.create([1, 16, 64], DataType.FP32))
+                y = ib.let("y", ir.Call(incore_gvar, [h, w, out_v], span))
+                ib.return_stmt(y)
+            prog.add_function(f.get_result())
+        before = prog.get_result()
+
+        # This fixture intentionally models pre-flatten tile IR; FlattenTileNdTo2D
+        # repairs the loop-carried TileView contract.
+        with passes.PassContext([], verification_level=passes.VerificationLevel.NONE):
+            after = passes.flatten_tile_nd_to_2d()(before)
+        ir_str = str(after.get_function("main_incore_0"))
+
+        assert "tile.batch_matmul" not in ir_str
+        assert "tile.batch_matmul_acc" not in ir_str
+        assert "tile.slice" not in ir_str
+        assert "tile.move" not in ir_str
+        assert "target_memory=pl.Mem.Acc" in ir_str
+        assert "pl.Mem.Acc, pl.TileView(blayout=pl.TileLayout.row_major" not in ir_str
 
 
 if __name__ == "__main__":
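
Note for reviewers: the patch calls `NormalizeImplicitTileView` / `GetImplicitTileView` from the new `tile_view_semantics.h` header but does not include that header's contents. The sketch below is only an illustration of the defaulting behavior the call sites imply, reconstructed from the fallback branch removed from `ExtractTileTypeInfo`; the exact signatures, the shape representation, and the `TileView` field set are assumptions, not the actual API.

```cpp
// Hypothetical sketch only: names and signatures are inferred from the call
// sites in this patch, not copied from tile_view_semantics.h.
#include <cstdint>
#include <optional>
#include <vector>

namespace ir_sketch {

enum class TileLayout { row_major, col_major };
enum class MemorySpace { Vec, Mat, Acc };

struct TileView {
  TileLayout blayout = TileLayout::row_major;
  TileLayout slayout = TileLayout::row_major;
  bool fractal = false;
  bool pad = false;
  std::vector<int64_t> valid_shape;
};

// Derive the layout a tile is assumed to have when no TileView was attached,
// e.g. a column vector (cols == 1, rows > 1) defaults to col_major, matching
// the branch this patch removes from ExtractTileTypeInfo.
TileView GetImplicitTileView(const std::vector<int64_t>& shape, MemorySpace /*memory_space*/) {
  TileView tv;
  tv.valid_shape = shape;
  if (shape.size() == 2 && shape[1] == 1 && shape[0] > 1) {
    tv.blayout = TileLayout::col_major;
  }
  return tv;
}

// Prefer the explicit TileView when present; otherwise fall back to the
// implicit one, so tile.store codegen, ExtractTileTypeInfo, and the flatten
// pass all apply the same defaulting rule instead of re-implementing it.
TileView NormalizeImplicitTileView(const std::optional<TileView>& tile_view,
                                   const std::vector<int64_t>& shape, MemorySpace memory_space) {
  return tile_view.has_value() ? *tile_view : GetImplicitTileView(shape, memory_space);
}

}  // namespace ir_sketch
```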