Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 91 additions & 3 deletions src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,15 @@ int64_t MultiplyStaticDims(const std::vector<int64_t>& dims, const std::string&
return product;
}

/// Return true when an ND tile has only singleton batch dimensions, i.e.
/// [...batch, M, N] is effectively one 2D matrix page.
bool HasSingletonBatchDims(const TileTypePtr& tile_type, const std::string& context) {
  if (!tile_type) return false;
  const auto& shape = tile_type->shape_;
  // A rank <= 2 tile has no batch prefix at all.
  if (shape.size() <= 2) return false;
  // Everything except the trailing [M, N] pair is the batch prefix.
  const std::vector<ExprPtr> batch_prefix(shape.begin(), shape.end() - 2);
  const auto static_batch = ToStaticDims(batch_prefix, context + " batch");
  return MultiplyStaticDims(static_batch, context + " batch size") == 1;
}

/// Decompose a flat batch index into per-dimension indices for the given batch shape.
/// e.g. flat_index=5 with batch_shape=[2,3] → indices=[1,2].
std::vector<int64_t> BuildBatchIndices(int64_t flat_index, const std::vector<int64_t>& batch_shape) {
Expand Down Expand Up @@ -218,6 +227,58 @@ bool IsTrailingMatrixAxisSwap(int64_t axis1, int64_t axis2, size_t ndim) {
(axis1 == trailing_axis1 && axis2 == trailing_axis0);
}

/// Return a copy of `kwargs` in which the "target_memory" entry is set to
/// `target_memory`, overwriting an existing entry or appending a new one.
std::vector<std::pair<std::string, std::any>> WithTargetMemory(
    const std::vector<std::pair<std::string, std::any>>& kwargs, MemorySpace target_memory) {
  std::vector<std::pair<std::string, std::any>> result = kwargs;
  for (auto& [key, value] : result) {
    if (key == "target_memory") {
      // Existing key: overwrite in place and stop scanning.
      value = target_memory;
      return result;
    }
  }
  // Key absent: append a fresh entry.
  result.emplace_back("target_memory", target_memory);
  return result;
}

/// Recursively gather every Var used as the accumulator (first argument) of a
/// `tile.batch_matmul_acc` call anywhere under `stmts`, descending into
/// sequences, scopes, conditionals, and loop bodies.
void CollectBatchMatmulAccOperands(const std::vector<StmtPtr>& stmts,
                                   std::unordered_set<const Var*>& acc_operands) {
  for (const auto& stmt : stmts) {
    if (auto assign = As<AssignStmt>(stmt)) {
      auto call = As<Call>(assign->value_);
      if (call && call->op_ && call->op_->name_ == "tile.batch_matmul_acc" &&
          !call->args_.empty()) {
        // args_[0] is the accumulator operand of the batched matmul.
        if (auto acc = AsVarLike(call->args_[0])) acc_operands.insert(acc.get());
      }
    } else if (auto seq = As<SeqStmts>(stmt)) {
      CollectBatchMatmulAccOperands(seq->stmts_, acc_operands);
    } else if (auto scope = As<ScopeStmt>(stmt)) {
      CollectBatchMatmulAccOperands(FlattenToStmts(scope->body_), acc_operands);
    } else if (auto if_stmt = As<IfStmt>(stmt)) {
      CollectBatchMatmulAccOperands(FlattenToStmts(if_stmt->then_body_), acc_operands);
      if (if_stmt->else_body_.has_value()) {
        CollectBatchMatmulAccOperands(FlattenToStmts(*if_stmt->else_body_), acc_operands);
      }
    } else if (auto for_stmt = As<ForStmt>(stmt)) {
      CollectBatchMatmulAccOperands(FlattenToStmts(for_stmt->body_), acc_operands);
    } else if (auto while_stmt = As<WhileStmt>(stmt)) {
      CollectBatchMatmulAccOperands(FlattenToStmts(while_stmt->body_), acc_operands);
    }
  }
}
Comment on lines +243 to +280
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This function manually traverses different statement types to find tile.batch_matmul_acc calls. This logic should be simplified and made more maintainable by using an IRVisitor. Using a visitor ensures that the traversal is recursive, which is necessary to correctly handle nested control flow structures like SeqStmts and ScopeStmt.

References
  1. Helper functions that traverse IR statements to find specific nodes must be recursive to handle nested control flow structures like SeqStmts and ScopeStmt.


// ============================================================================
// Precondition validation
// ============================================================================
Expand Down Expand Up @@ -1151,6 +1212,7 @@ std::vector<StmtPtr> TransformBody(const std::vector<StmtPtr>& stmts, FlattenCon
// If conditions, For/While bounds, etc.), not just Call arguments. A Var used
// anywhere outside a tile.batch_matmul Call prevents it from being skipped.
std::unordered_set<const Var*> batch_matmul_only_vars;
std::unordered_set<const Var*> singleton_batch_acc_init_vars;
{
std::unordered_map<const Var*, int> use_count;
std::vector<const Var*> batch_matmul_operands; // ordered to avoid nondeterministic iteration
Expand Down Expand Up @@ -1222,13 +1284,33 @@ std::vector<StmtPtr> TransformBody(const std::vector<StmtPtr>& stmts, FlattenCon
CountVarRefs(for_stmt->start_);
CountVarRefs(for_stmt->stop_);
CountVarRefs(for_stmt->step_);
for (const auto& ia : for_stmt->iter_args_) CountVarRefs(ia->initValue_);
std::unordered_set<const Var*> batch_matmul_acc_operands;
CollectBatchMatmulAccOperands(FlattenToStmts(for_stmt->body_), batch_matmul_acc_operands);
for (const auto& ia : for_stmt->iter_args_) {
CountVarRefs(ia->initValue_);
auto init_var = As<Var>(ia->initValue_);
auto init_tile = init_var ? As<TileType>(init_var->GetType()) : nullptr;
if (init_var && HasSingletonBatchDims(init_tile, "tile.batch_matmul_acc init") &&
batch_matmul_acc_operands.count(ia.get())) {
singleton_batch_acc_init_vars.insert(init_var.get());
}
}
Comment on lines +1297 to +1307
Copy link

Copilot AI Apr 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The singleton-batch accumulator init detection only runs on ForStmt/WhileStmt nodes that are directly present in the current stmts vector. If the tile.create([1,...]) init is defined in an outer block while the loop is nested under a ScopeStmt/IfStmt/SeqStmts, this set won’t be populated in the outer TransformBody invocation, so the initializer won’t be forced to MemorySpace::Acc and the original Vec/Acc round-trip can remain. Consider performing this analysis via a recursive walk over the full statement tree (similar to CollectBatchMatmulAccOperands) and sharing the result across recursive TransformBody calls.

Copilot uses AI. Check for mistakes.
continue;
}
// WhileStmt: count condition and iter_arg init Var refs.
if (auto while_stmt = As<WhileStmt>(s)) {
CountVarRefs(while_stmt->condition_);
for (const auto& ia : while_stmt->iter_args_) CountVarRefs(ia->initValue_);
std::unordered_set<const Var*> batch_matmul_acc_operands;
CollectBatchMatmulAccOperands(FlattenToStmts(while_stmt->body_), batch_matmul_acc_operands);
for (const auto& ia : while_stmt->iter_args_) {
CountVarRefs(ia->initValue_);
auto init_var = As<Var>(ia->initValue_);
auto init_tile = init_var ? As<TileType>(init_var->GetType()) : nullptr;
if (init_var && HasSingletonBatchDims(init_tile, "tile.batch_matmul_acc init") &&
batch_matmul_acc_operands.count(ia.get())) {
singleton_batch_acc_init_vars.insert(init_var.get());
}
}
Comment on lines +1297 to +1323
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for processing iter_args in ForStmt and WhileStmt is identical. This duplicated code could be extracted into a helper function or a lambda to improve readability and maintainability.

continue;
}
}
Expand Down Expand Up @@ -1599,7 +1681,13 @@ std::vector<StmtPtr> TransformBody(const std::vector<StmtPtr>& stmts, FlattenCon
new_args.push_back(Substitute(call->args_[i], ctx.var_map));
}

auto new_call = op_registry.Create(op_name, new_args, call->kwargs_, span);
auto new_kwargs = call->kwargs_;
if (singleton_batch_acc_init_vars.count(assign->var_.get())) {
// Batch=1 matmul_acc lowers to tile.matmul_acc; keep its loop init in Acc
// instead of round-tripping a dummy accumulator through Vec.
new_kwargs = WithTargetMemory(new_kwargs, MemorySpace::Acc);
}
auto new_call = op_registry.Create(op_name, new_args, new_kwargs, span);
auto flat_var =
std::make_shared<Var>(assign->var_->name_hint_, new_call->GetType(), assign->var_->span_);
result.push_back(std::make_shared<AssignStmt>(flat_var, new_call, assign->span_));
Expand Down
72 changes: 72 additions & 0 deletions tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1714,6 +1714,78 @@ def collect_names(prog: ir.Program) -> list[str]:
assert names_flatten.count("tile.matmul") == 1
assert names_flatten.count("tile.matmul_acc") == 1

def test_singleton_batch_acc_init_stays_acc_in_loop(self):
    """A batch=1 dummy accumulator init should flatten directly to Acc.

    This covers the common ``acc = create; for k: if first matmul else
    matmul_acc`` style. When the 3D RHS selects ``tile.batch_matmul_acc``,
    flattening the singleton-batch init to Vec forces an unnecessary
    Vec/Acc round-trip on the loop-carried accumulator.
    """
    # NOTE(review): indentation in this block was reconstructed from the
    # builder nesting (program > function > for_loop) — confirm against the
    # sibling tests in this file.
    ib = IRBuilder()
    span = ir.Span.unknown()
    with ib.program("main") as prog:
        incore_gvar = prog.declare_function("main_incore_0")
        prog.declare_function("main")

        with ib.function("main_incore_0", type=ir.FunctionType.InCore) as f:
            h = f.param("h", ir.TensorType([16, 128], DataType.BF16))
            w = f.param("w", ir.TensorType([1, 64, 128], DataType.BF16))
            out_p = f.param(
                "out_0",
                ir.TensorType([1, 16, 64], DataType.FP32),
                direction=ir.ParamDirection.Out,
            )
            f.return_type(ir.TensorType([1, 16, 64], DataType.FP32))

            # Dummy accumulator with a singleton batch dim: [1, 16, 64].
            acc_init = ib.let("acc_init", tile_ops.create([1, 16, 64], DataType.FP32))
            kb = ib.var("kb", ir.ScalarType(DataType.INDEX), span)
            with ib.for_loop(kb, 0, 2, 1) as loop:
                acc_iter = loop.iter_arg("acc", acc_init)
                loop.return_var("acc_out")
                lhs = ib.let(
                    "lhs",
                    tile_ops.load(h, [0, 0], [16, 128], target_memory=ir.MemorySpace.Mat),
                )
                rhs = ib.let(
                    "rhs",
                    tile_ops.load(
                        w,
                        [0, 0, 0],
                        [1, 64, 128],
                        target_memory=ir.MemorySpace.Mat,
                        transpose=True,
                    ),
                )
                acc_next = ib.let("acc_next", tile_ops.batch_matmul_acc(acc_iter, lhs, rhs))
                ib.emit(ir.YieldStmt([acc_next], span))
            acc_final = loop.output()
            out_r = ib.let("out_0", tile_ops.store(acc_final, [0, 0, 0], out_p))
            ib.return_stmt(out_r)
        prog.add_function(f.get_result())

        with ib.function("main") as f:
            h = f.param("h", ir.TensorType([16, 128], DataType.BF16))
            w = f.param("w", ir.TensorType([1, 64, 128], DataType.BF16))
            f.return_type(ir.TensorType([1, 16, 64], DataType.FP32))
            out_v = ib.let("out_0", tensor_ops.create([1, 16, 64], DataType.FP32))
            y = ib.let("y", ir.Call(incore_gvar, [h, w, out_v], span))
            ib.return_stmt(y)
        prog.add_function(f.get_result())
    before = prog.get_result()

    # This fixture intentionally models pre-flatten tile IR; FlattenTileNdTo2D
    # repairs the loop-carried TileView contract.
    with passes.PassContext([], verification_level=passes.VerificationLevel.NONE):
        after = passes.flatten_tile_nd_to_2d()(before)
    ir_str = str(after.get_function("main_incore_0"))

    # Batch ops must be fully lowered, with no Vec/Acc round-trip move, and
    # the flattened accumulator init must land directly in Acc memory.
    assert "tile.batch_matmul" not in ir_str
    assert "tile.batch_matmul_acc" not in ir_str
    assert "tile.move" not in ir_str
    assert "target_memory=pl.Mem.Acc" in ir_str

# Support direct invocation (``python test_flatten_tile_nd_to_2d.py``) in
# addition to running under the pytest CLI. (Indentation restored: the call
# must sit inside the guard body.)
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
Loading