6 changes: 3 additions & 3 deletions src/backend/common/pto_ops_common.cpp
@@ -42,6 +42,7 @@
#include "pypto/ir/kind_traits.h"
#include "pypto/ir/scalar_expr.h"
#include "pypto/ir/transforms/utils/memref_utils.h"
#include "pypto/ir/transforms/utils/tile_view_semantics.h"
#include "pypto/ir/type.h"

namespace pypto {
@@ -1038,9 +1039,8 @@ static std::string MakeTileStoreCodegenPTO(const CallPtr& op, codegen::CodegenBa

auto tile_type = As<ir::TileType>(tile->GetType());
INTERNAL_CHECK_SPAN(tile_type, op->span_) << "tile.store first argument must have TileType";
INTERNAL_CHECK_SPAN(tile_type->tile_view_.has_value(), op->span_)
<< "tile.store tile must have TileView with valid_shape";
const auto tile_view = tile_type->tile_view_.value_or(ir::TileView{});
const auto tile_view = ir::tile_view_semantics::NormalizeImplicitTileView(
tile_type->tile_view_, tile_type->shape_, tile_type->memory_space_);
const auto& valid_shape = tile_view.valid_shape;
INTERNAL_CHECK_SPAN(valid_shape.size() == 2, op->span_) << "tile.store tile valid_shape must be 2D";

16 changes: 7 additions & 9 deletions src/codegen/pto/pto_type_utils.cpp
@@ -21,6 +21,7 @@
#include "pypto/ir/expr.h"
#include "pypto/ir/kind_traits.h"
#include "pypto/ir/scalar_expr.h"
#include "pypto/ir/transforms/utils/tile_view_semantics.h"
#include "pypto/ir/type.h"

namespace pypto {
@@ -132,15 +133,12 @@ TileTypeComponents ExtractTileTypeInfo(const ir::TileType& tile_type, const std:
c.v_row_dynamic = true;
c.v_col_dynamic = true;

if (tile_type.tile_view_.has_value()) {
const auto& tv = *tile_type.tile_view_;
c.blayout = tv.blayout;
c.slayout = tv.slayout;
c.fractal = tv.fractal;
c.pad = tv.pad;
} else if (c.cols == 1 && c.rows > 1) {
c.blayout = ir::TileLayout::col_major;
}
auto tv = ir::tile_view_semantics::NormalizeImplicitTileView(tile_type.tile_view_, tile_type.shape_,
tile_type.memory_space_);
c.blayout = tv.blayout;
c.slayout = tv.slayout;
c.fractal = tv.fractal;
c.pad = tv.pad;
return c;
}

127 changes: 119 additions & 8 deletions src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp
@@ -155,6 +155,15 @@ int64_t MultiplyStaticDims(const std::vector<int64_t>& dims, const std::string&
return product;
}

/// Return true when an ND tile has only singleton batch dimensions, i.e.
/// [...batch, M, N] is effectively one 2D matrix page.
bool HasSingletonBatchDims(const TileTypePtr& tile_type, const std::string& context) {
if (!tile_type || tile_type->shape_.size() <= 2) return false;
std::vector<ExprPtr> batch_shape(tile_type->shape_.begin(), tile_type->shape_.end() - 2);
auto batch_dims = ToStaticDims(batch_shape, context + " batch");
return MultiplyStaticDims(batch_dims, context + " batch size") == 1;
}

/// Decompose a flat batch index into per-dimension indices for the given batch shape.
/// e.g. flat_index=5 with batch_shape=[2,3] → indices=[1,2].
std::vector<int64_t> BuildBatchIndices(int64_t flat_index, const std::vector<int64_t>& batch_shape) {
@@ -218,6 +227,58 @@ bool IsTrailingMatrixAxisSwap(int64_t axis1, int64_t axis2, size_t ndim) {
(axis1 == trailing_axis1 && axis2 == trailing_axis0);
}

std::vector<std::pair<std::string, std::any>> WithTargetMemory(
const std::vector<std::pair<std::string, std::any>>& kwargs, MemorySpace target_memory) {
auto updated = kwargs;
for (auto& kv : updated) {
if (kv.first == "target_memory") {
kv.second = target_memory;
return updated;
}
}
updated.emplace_back("target_memory", target_memory);
return updated;
}

void CollectBatchMatmulAccOperands(const std::vector<StmtPtr>& stmts,
std::unordered_set<const Var*>& acc_operands) {
for (const auto& stmt : stmts) {
if (auto assign = As<AssignStmt>(stmt)) {
if (auto call = As<Call>(assign->value_)) {
if (call->op_ && call->op_->name_ == "tile.batch_matmul_acc" && !call->args_.empty()) {
if (auto acc = AsVarLike(call->args_[0])) {
acc_operands.insert(acc.get());
}
}
}
continue;
}
if (auto seq = As<SeqStmts>(stmt)) {
CollectBatchMatmulAccOperands(seq->stmts_, acc_operands);
continue;
}
if (auto scope = As<ScopeStmt>(stmt)) {
CollectBatchMatmulAccOperands(FlattenToStmts(scope->body_), acc_operands);
continue;
}
if (auto if_stmt = As<IfStmt>(stmt)) {
CollectBatchMatmulAccOperands(FlattenToStmts(if_stmt->then_body_), acc_operands);
if (if_stmt->else_body_.has_value()) {
CollectBatchMatmulAccOperands(FlattenToStmts(*if_stmt->else_body_), acc_operands);
}
continue;
}
if (auto for_stmt = As<ForStmt>(stmt)) {
CollectBatchMatmulAccOperands(FlattenToStmts(for_stmt->body_), acc_operands);
continue;
}
if (auto while_stmt = As<WhileStmt>(stmt)) {
CollectBatchMatmulAccOperands(FlattenToStmts(while_stmt->body_), acc_operands);
continue;
}
}
}
Comment on lines +243 to +280
Contributor


medium

This function manually traverses different statement types to find tile.batch_matmul_acc calls. This logic should be simplified and made more maintainable by using an IRVisitor. Using a visitor ensures that the traversal is recursive, which is necessary to correctly handle nested control flow structures like SeqStmts and ScopeStmt.

References
  1. Helper functions that traverse IR statements to find specific nodes must be recursive to handle nested control flow structures like SeqStmts and ScopeStmt.
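
For illustration, a minimal sketch of the visitor-based alternative; the `IRVisitor` base class, the `VisitCall` override name, and the chained base call are assumptions about the pass infrastructure rather than the project's confirmed API:

class BatchMatmulAccCollector : public IRVisitor {  // hypothetical base class
 public:
  explicit BatchMatmulAccCollector(std::unordered_set<const Var*>& acc_operands)
      : acc_operands_(acc_operands) {}

  // Record the accumulator (first argument) of every tile.batch_matmul_acc call;
  // the base-class traversal handles nesting in SeqStmts/ScopeStmt/If/For/While.
  void VisitCall(const Call& call) override {
    if (call.op_ && call.op_->name_ == "tile.batch_matmul_acc" && !call.args_.empty()) {
      if (auto acc = AsVarLike(call.args_[0])) acc_operands_.insert(acc.get());
    }
    IRVisitor::VisitCall(call);  // assumed hook to keep descending
  }

 private:
  std::unordered_set<const Var*>& acc_operands_;
};

With such a collector, CollectBatchMatmulAccOperands would reduce to constructing it once and visiting the statement list, instead of hand-dispatching on each statement kind.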


// ============================================================================
// Precondition validation
// ============================================================================
@@ -628,11 +689,21 @@ BatchPageResult ExtractBatchPage(const BatchOperandInfo& info, const std::vector

} else if (operand_type->shape_.size() == 2) {
// Strategy 2: Slice from already-flattened 2D tile.
auto offset = MakeShapeTupleFromInts({batch_index * source_rows, 0}, span);
auto shape = MakeShapeTupleFromInts({source_rows, source_cols}, span);
auto slice = op_registry.Create("tile.slice", {operand, shape, offset}, span);
current = std::make_shared<Var>(base_name + "_slice_" + suffix, slice->GetType(), span);
page.stmts.push_back(std::make_shared<AssignStmt>(current, slice, span));
auto flat_rows = As<ConstInt>(operand_type->shape_[0]);
auto flat_cols = As<ConstInt>(operand_type->shape_[1]);
if (batch_index == 0 && flat_rows && flat_cols && flat_rows->value_ == source_rows &&
flat_cols->value_ == source_cols) {
// Singleton-batch operands are already the exact 2D page. Avoid a full-tile
// slice, which would lower to an unsupported Mat->Mat tmov on a2a3.
current = AsVarLike(operand);
CHECK(current) << "FlattenTileNdTo2D: expected 2D batch_matmul operand to be Var-like";
} else {
auto offset = MakeShapeTupleFromInts({batch_index * source_rows, 0}, span);
auto shape = MakeShapeTupleFromInts({source_rows, source_cols}, span);
auto slice = op_registry.Create("tile.slice", {operand, shape, offset}, span);
current = std::make_shared<Var>(base_name + "_slice_" + suffix, slice->GetType(), span);
page.stmts.push_back(std::make_shared<AssignStmt>(current, slice, span));
}
Comment on lines +692 to +706

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Avoid hard-failing when the 2D singleton-page operand is not Var-like

Lines 698–699 assume the operand is always a Var/IterArg and CHECK-fail otherwise. This can abort the pass on valid IR where the operand is an inline expression. Please fall back to the existing tile.slice branch when AsVarLike(operand) returns null, instead of crashing.

Suggested fix
-    if (batch_index == 0 && flat_rows && flat_cols && flat_rows->value_ == source_rows &&
-        flat_cols->value_ == source_cols) {
+    if (batch_index == 0 && flat_rows && flat_cols && flat_rows->value_ == source_rows &&
+        flat_cols->value_ == source_cols) {
       // Singleton-batch operands are already the exact 2D page. Avoid a full-tile
       // slice, which would lower to an unsupported Mat->Mat tmov on a2a3.
-      current = AsVarLike(operand);
-      CHECK(current) << "FlattenTileNdTo2D: expected 2D batch_matmul operand to be Var-like";
+      if (auto operand_var = AsVarLike(operand)) {
+        current = operand_var;
+      } else {
+        auto offset = MakeShapeTupleFromInts({batch_index * source_rows, 0}, span);
+        auto shape = MakeShapeTupleFromInts({source_rows, source_cols}, span);
+        auto slice = op_registry.Create("tile.slice", {operand, shape, offset}, span);
+        current = std::make_shared<Var>(base_name + "_slice_" + suffix, slice->GetType(), span);
+        page.stmts.push_back(std::make_shared<AssignStmt>(current, slice, span));
+      }
     } else {
       auto offset = MakeShapeTupleFromInts({batch_index * source_rows, 0}, span);
       auto shape = MakeShapeTupleFromInts({source_rows, source_cols}, span);
       auto slice = op_registry.Create("tile.slice", {operand, shape, offset}, span);
       current = std::make_shared<Var>(base_name + "_slice_" + suffix, slice->GetType(), span);
       page.stmts.push_back(std::make_shared<AssignStmt>(current, slice, span));
     }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp` around lines 692–706, the
pass currently CHECKs that AsVarLike(operand) is non-null for the 2D
singleton-page fast path (in FlattenTileNdTo2D), which aborts on valid inline
expressions. Change the logic to try AsVarLike(operand) and, if it returns null,
fall back to creating the same tile.slice + Var + AssignStmt sequence used in the
else branch (same offset/shape and page.stmts push) instead of calling CHECK, so
inline expressions are handled by emitting a slice-backed temporary rather than
crashing.


} else {
// Strategy 3: rank>2 tile.slice + tile.reshape to 2D.
@@ -1151,6 +1222,7 @@ std::vector<StmtPtr> TransformBody(const std::vector<StmtPtr>& stmts, FlattenCon
// If conditions, For/While bounds, etc.), not just Call arguments. A Var used
// anywhere outside a tile.batch_matmul Call prevents it from being skipped.
std::unordered_set<const Var*> batch_matmul_only_vars;
std::unordered_set<const Var*> singleton_batch_acc_init_vars;
{
std::unordered_map<const Var*, int> use_count;
std::vector<const Var*> batch_matmul_operands; // ordered to avoid nondeterministic iteration
Expand Down Expand Up @@ -1222,13 +1294,33 @@ std::vector<StmtPtr> TransformBody(const std::vector<StmtPtr>& stmts, FlattenCon
CountVarRefs(for_stmt->start_);
CountVarRefs(for_stmt->stop_);
CountVarRefs(for_stmt->step_);
for (const auto& ia : for_stmt->iter_args_) CountVarRefs(ia->initValue_);
std::unordered_set<const Var*> batch_matmul_acc_operands;
CollectBatchMatmulAccOperands(FlattenToStmts(for_stmt->body_), batch_matmul_acc_operands);
for (const auto& ia : for_stmt->iter_args_) {
CountVarRefs(ia->initValue_);
auto init_var = As<Var>(ia->initValue_);
auto init_tile = init_var ? As<TileType>(init_var->GetType()) : nullptr;
if (init_var && HasSingletonBatchDims(init_tile, "tile.batch_matmul_acc init") &&
batch_matmul_acc_operands.count(ia.get())) {
singleton_batch_acc_init_vars.insert(init_var.get());
}
}
Comment on lines +1297 to +1307

Copilot AI Apr 30, 2026


The singleton-batch accumulator init detection only runs on ForStmt/WhileStmt nodes that are directly present in the current stmts vector. If the tile.create([1,...]) init is defined in an outer block while the loop is nested under a ScopeStmt/IfStmt/SeqStmts, this set won't be populated in the outer TransformBody invocation, the initializer won't be forced to MemorySpace::Acc, and the original Vec/Acc round-trip can remain. Consider performing this analysis via a recursive walk over the full statement tree (similar to CollectBatchMatmulAccOperands) and sharing the result across recursive TransformBody calls.

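A minimal sketch of the recursive pre-pass this suggests, reusing the helpers already added in this PR; the function name is an assumption, and IfStmt/WhileStmt would be walked the same way as ForStmt here:

void CollectSingletonBatchAccInits(const std::vector<StmtPtr>& stmts,
                                   std::unordered_set<const Var*>& init_vars) {
  for (const auto& stmt : stmts) {
    if (auto for_stmt = As<ForStmt>(stmt)) {
      std::unordered_set<const Var*> acc_operands;
      CollectBatchMatmulAccOperands(FlattenToStmts(for_stmt->body_), acc_operands);
      for (const auto& ia : for_stmt->iter_args_) {
        auto init_var = As<Var>(ia->initValue_);
        auto init_tile = init_var ? As<TileType>(init_var->GetType()) : nullptr;
        if (init_var && HasSingletonBatchDims(init_tile, "tile.batch_matmul_acc init") &&
            acc_operands.count(ia.get())) {
          init_vars.insert(init_var.get());
        }
      }
      CollectSingletonBatchAccInits(FlattenToStmts(for_stmt->body_), init_vars);
    } else if (auto seq = As<SeqStmts>(stmt)) {
      CollectSingletonBatchAccInits(seq->stmts_, init_vars);
    } else if (auto scope = As<ScopeStmt>(stmt)) {
      CollectSingletonBatchAccInits(FlattenToStmts(scope->body_), init_vars);
    }
    // IfStmt/WhileStmt would recurse the same way; the populated set is then
    // threaded through every recursive TransformBody call instead of being
    // rebuilt per statement block.
  }
}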
continue;
}
// WhileStmt: count condition and iter_arg init Var refs.
if (auto while_stmt = As<WhileStmt>(s)) {
CountVarRefs(while_stmt->condition_);
for (const auto& ia : while_stmt->iter_args_) CountVarRefs(ia->initValue_);
std::unordered_set<const Var*> batch_matmul_acc_operands;
CollectBatchMatmulAccOperands(FlattenToStmts(while_stmt->body_), batch_matmul_acc_operands);
for (const auto& ia : while_stmt->iter_args_) {
CountVarRefs(ia->initValue_);
auto init_var = As<Var>(ia->initValue_);
auto init_tile = init_var ? As<TileType>(init_var->GetType()) : nullptr;
if (init_var && HasSingletonBatchDims(init_tile, "tile.batch_matmul_acc init") &&
batch_matmul_acc_operands.count(ia.get())) {
singleton_batch_acc_init_vars.insert(init_var.get());
}
}
Comment on lines +1297 to +1323
Contributor


medium

The logic for processing iter_args in ForStmt and WhileStmt is identical. This duplicated code could be extracted into a helper function or a lambda to improve readability and maintainability.
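
For illustration, a sketch of that shared helper as a lambda captured inside TransformBody; the lambda name and the use of a generic parameter for the iter_args container are assumptions about the surrounding code:

auto mark_singleton_acc_inits = [&](const auto& iter_args, const std::vector<StmtPtr>& body) {
  std::unordered_set<const Var*> acc_operands;
  CollectBatchMatmulAccOperands(body, acc_operands);
  for (const auto& ia : iter_args) {
    CountVarRefs(ia->initValue_);
    auto init_var = As<Var>(ia->initValue_);
    auto init_tile = init_var ? As<TileType>(init_var->GetType()) : nullptr;
    if (init_var && HasSingletonBatchDims(init_tile, "tile.batch_matmul_acc init") &&
        acc_operands.count(ia.get())) {
      singleton_batch_acc_init_vars.insert(init_var.get());
    }
  }
};
// ForStmt branch:   mark_singleton_acc_inits(for_stmt->iter_args_, FlattenToStmts(for_stmt->body_));
// WhileStmt branch: mark_singleton_acc_inits(while_stmt->iter_args_, FlattenToStmts(while_stmt->body_));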

continue;
}
}
@@ -1599,7 +1691,26 @@ std::vector<StmtPtr> TransformBody(const std::vector<StmtPtr>& stmts, FlattenCon
new_args.push_back(Substitute(call->args_[i], ctx.var_map));
}

auto new_call = op_registry.Create(op_name, new_args, call->kwargs_, span);
auto new_kwargs = call->kwargs_;
bool force_acc_init = singleton_batch_acc_init_vars.count(assign->var_.get()) != 0;
if (force_acc_init) {
// Batch=1 matmul_acc lowers to tile.matmul_acc; keep its loop init in Acc
// instead of round-tripping a dummy accumulator through Vec.
new_kwargs = WithTargetMemory(new_kwargs, MemorySpace::Acc);
}
auto new_call = op_registry.Create(op_name, new_args, new_kwargs, span);
if (force_acc_init) {
auto new_call_tile = As<TileType>(new_call->GetType());
CHECK(new_call_tile) << "FlattenTileNdTo2D: expected flattened accumulator init to be TileType";
// Refresh the old N-D/default TileView when changing the dummy init
// to Acc so memory reuse sees the same layout as matmul outputs.
auto acc_view = tile_view_semantics::GetImplicitTileView(new_call_tile->shape_, MemorySpace::Acc);
auto acc_type =
std::make_shared<TileType>(new_call_tile->shape_, new_call_tile->dtype_, new_call_tile->memref_,
std::move(acc_view), MemorySpace::Acc);
new_call = std::make_shared<Call>(new_call->op_, new_call->args_, new_call->kwargs_,
std::move(acc_type), new_call->span_);
}
auto flat_var =
std::make_shared<Var>(assign->var_->name_hint_, new_call->GetType(), assign->var_->span_);
result.push_back(std::make_shared<AssignStmt>(flat_var, new_call, assign->span_));
75 changes: 75 additions & 0 deletions tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py
@@ -1713,6 +1713,81 @@ def collect_names(prog: ir.Program) -> list[str]:
assert "tile.batch_matmul_acc" not in names_flatten
assert names_flatten.count("tile.matmul") == 1
assert names_flatten.count("tile.matmul_acc") == 1
assert "tile.slice" not in names_flatten

def test_singleton_batch_acc_init_stays_acc_in_loop(self):
"""A batch=1 dummy accumulator init should flatten directly to Acc.

This covers the common ``acc = create; for k: if first matmul else
matmul_acc`` style. When the 3D RHS selects ``tile.batch_matmul_acc``,
flattening the singleton-batch init to Vec forces an unnecessary
Vec/Acc round-trip on the loop-carried accumulator.
"""

ib = IRBuilder()
span = ir.Span.unknown()
with ib.program("main") as prog:
incore_gvar = prog.declare_function("main_incore_0")
prog.declare_function("main")

with ib.function("main_incore_0", type=ir.FunctionType.InCore) as f:
h = f.param("h", ir.TensorType([16, 128], DataType.BF16))
w = f.param("w", ir.TensorType([1, 64, 128], DataType.BF16))
out_p = f.param(
"out_0",
ir.TensorType([1, 16, 64], DataType.FP32),
direction=ir.ParamDirection.Out,
)
f.return_type(ir.TensorType([1, 16, 64], DataType.FP32))

acc_init = ib.let("acc_init", tile_ops.create([1, 16, 64], DataType.FP32))
kb = ib.var("kb", ir.ScalarType(DataType.INDEX), span)
with ib.for_loop(kb, 0, 2, 1) as loop:
acc_iter = loop.iter_arg("acc", acc_init)
loop.return_var("acc_out")
lhs = ib.let(
"lhs",
tile_ops.load(h, [0, 0], [16, 128], target_memory=ir.MemorySpace.Mat),
)
rhs = ib.let(
"rhs",
tile_ops.load(
w,
[0, 0, 0],
[1, 64, 128],
target_memory=ir.MemorySpace.Mat,
transpose=True,
),
)
acc_next = ib.let("acc_next", tile_ops.batch_matmul_acc(acc_iter, lhs, rhs))
ib.emit(ir.YieldStmt([acc_next], span))
acc_final = loop.output()
out_r = ib.let("out_0", tile_ops.store(acc_final, [0, 0, 0], out_p))
ib.return_stmt(out_r)
prog.add_function(f.get_result())

with ib.function("main") as f:
h = f.param("h", ir.TensorType([16, 128], DataType.BF16))
w = f.param("w", ir.TensorType([1, 64, 128], DataType.BF16))
f.return_type(ir.TensorType([1, 16, 64], DataType.FP32))
out_v = ib.let("out_0", tensor_ops.create([1, 16, 64], DataType.FP32))
y = ib.let("y", ir.Call(incore_gvar, [h, w, out_v], span))
ib.return_stmt(y)
prog.add_function(f.get_result())
before = prog.get_result()

# This fixture intentionally models pre-flatten tile IR; FlattenTileNdTo2D
# repairs the loop-carried TileView contract.
with passes.PassContext([], verification_level=passes.VerificationLevel.NONE):
after = passes.flatten_tile_nd_to_2d()(before)
ir_str = str(after.get_function("main_incore_0"))

assert "tile.batch_matmul" not in ir_str
assert "tile.batch_matmul_acc" not in ir_str
assert "tile.slice" not in ir_str
assert "tile.move" not in ir_str
assert "target_memory=pl.Mem.Acc" in ir_str
assert "pl.Mem.Acc, pl.TileView(blayout=pl.TileLayout.row_major" not in ir_str


if __name__ == "__main__":