-
Notifications
You must be signed in to change notification settings - Fork 67
Fix singleton-batch matmul accumulator layout #1237
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -155,6 +155,15 @@ int64_t MultiplyStaticDims(const std::vector<int64_t>& dims, const std::string& | |
| return product; | ||
| } | ||
|
|
||
| /// Return true when an ND tile has only singleton batch dimensions, i.e. | ||
| /// [...batch, M, N] is effectively one 2D matrix page. | ||
| bool HasSingletonBatchDims(const TileTypePtr& tile_type, const std::string& context) { | ||
| if (!tile_type || tile_type->shape_.size() <= 2) return false; | ||
| std::vector<ExprPtr> batch_shape(tile_type->shape_.begin(), tile_type->shape_.end() - 2); | ||
| auto batch_dims = ToStaticDims(batch_shape, context + " batch"); | ||
| return MultiplyStaticDims(batch_dims, context + " batch size") == 1; | ||
| } | ||
|
|
||
| /// Decompose a flat batch index into per-dimension indices for the given batch shape. | ||
| /// e.g. flat_index=5 with batch_shape=[2,3] → indices=[1,2]. | ||
| std::vector<int64_t> BuildBatchIndices(int64_t flat_index, const std::vector<int64_t>& batch_shape) { | ||
|
|
@@ -218,6 +227,58 @@ bool IsTrailingMatrixAxisSwap(int64_t axis1, int64_t axis2, size_t ndim) { | |
| (axis1 == trailing_axis1 && axis2 == trailing_axis0); | ||
| } | ||
|
|
||
| std::vector<std::pair<std::string, std::any>> WithTargetMemory( | ||
| const std::vector<std::pair<std::string, std::any>>& kwargs, MemorySpace target_memory) { | ||
| auto updated = kwargs; | ||
| for (auto& kv : updated) { | ||
| if (kv.first == "target_memory") { | ||
| kv.second = target_memory; | ||
| return updated; | ||
| } | ||
| } | ||
| updated.emplace_back("target_memory", target_memory); | ||
| return updated; | ||
| } | ||
|
|
||
| void CollectBatchMatmulAccOperands(const std::vector<StmtPtr>& stmts, | ||
| std::unordered_set<const Var*>& acc_operands) { | ||
| for (const auto& stmt : stmts) { | ||
| if (auto assign = As<AssignStmt>(stmt)) { | ||
| if (auto call = As<Call>(assign->value_)) { | ||
| if (call->op_ && call->op_->name_ == "tile.batch_matmul_acc" && !call->args_.empty()) { | ||
| if (auto acc = AsVarLike(call->args_[0])) { | ||
| acc_operands.insert(acc.get()); | ||
| } | ||
| } | ||
| } | ||
| continue; | ||
| } | ||
| if (auto seq = As<SeqStmts>(stmt)) { | ||
| CollectBatchMatmulAccOperands(seq->stmts_, acc_operands); | ||
| continue; | ||
| } | ||
| if (auto scope = As<ScopeStmt>(stmt)) { | ||
| CollectBatchMatmulAccOperands(FlattenToStmts(scope->body_), acc_operands); | ||
| continue; | ||
| } | ||
| if (auto if_stmt = As<IfStmt>(stmt)) { | ||
| CollectBatchMatmulAccOperands(FlattenToStmts(if_stmt->then_body_), acc_operands); | ||
| if (if_stmt->else_body_.has_value()) { | ||
| CollectBatchMatmulAccOperands(FlattenToStmts(*if_stmt->else_body_), acc_operands); | ||
| } | ||
| continue; | ||
| } | ||
| if (auto for_stmt = As<ForStmt>(stmt)) { | ||
| CollectBatchMatmulAccOperands(FlattenToStmts(for_stmt->body_), acc_operands); | ||
| continue; | ||
| } | ||
| if (auto while_stmt = As<WhileStmt>(stmt)) { | ||
| CollectBatchMatmulAccOperands(FlattenToStmts(while_stmt->body_), acc_operands); | ||
| continue; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // ============================================================================ | ||
| // Precondition validation | ||
| // ============================================================================ | ||
|
|
@@ -628,11 +689,21 @@ BatchPageResult ExtractBatchPage(const BatchOperandInfo& info, const std::vector | |
|
|
||
| } else if (operand_type->shape_.size() == 2) { | ||
| // Strategy 2: Slice from already-flattened 2D tile. | ||
| auto offset = MakeShapeTupleFromInts({batch_index * source_rows, 0}, span); | ||
| auto shape = MakeShapeTupleFromInts({source_rows, source_cols}, span); | ||
| auto slice = op_registry.Create("tile.slice", {operand, shape, offset}, span); | ||
| current = std::make_shared<Var>(base_name + "_slice_" + suffix, slice->GetType(), span); | ||
| page.stmts.push_back(std::make_shared<AssignStmt>(current, slice, span)); | ||
| auto flat_rows = As<ConstInt>(operand_type->shape_[0]); | ||
| auto flat_cols = As<ConstInt>(operand_type->shape_[1]); | ||
| if (batch_index == 0 && flat_rows && flat_cols && flat_rows->value_ == source_rows && | ||
| flat_cols->value_ == source_cols) { | ||
| // Singleton-batch operands are already the exact 2D page. Avoid a full-tile | ||
| // slice, which would lower to an unsupported Mat->Mat tmov on a2a3. | ||
| current = AsVarLike(operand); | ||
| CHECK(current) << "FlattenTileNdTo2D: expected 2D batch_matmul operand to be Var-like"; | ||
| } else { | ||
| auto offset = MakeShapeTupleFromInts({batch_index * source_rows, 0}, span); | ||
| auto shape = MakeShapeTupleFromInts({source_rows, source_cols}, span); | ||
| auto slice = op_registry.Create("tile.slice", {operand, shape, offset}, span); | ||
| current = std::make_shared<Var>(base_name + "_slice_" + suffix, slice->GetType(), span); | ||
| page.stmts.push_back(std::make_shared<AssignStmt>(current, slice, span)); | ||
| } | ||
|
Comment on lines
+692
to
+706
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid hard-failing when the 2D singleton-page operand is not Var-like Line 698–699 assumes the operand is always Var/IterArg and Suggested fix- if (batch_index == 0 && flat_rows && flat_cols && flat_rows->value_ == source_rows &&
- flat_cols->value_ == source_cols) {
+ if (batch_index == 0 && flat_rows && flat_cols && flat_rows->value_ == source_rows &&
+ flat_cols->value_ == source_cols) {
// Singleton-batch operands are already the exact 2D page. Avoid a full-tile
// slice, which would lower to an unsupported Mat->Mat tmov on a2a3.
- current = AsVarLike(operand);
- CHECK(current) << "FlattenTileNdTo2D: expected 2D batch_matmul operand to be Var-like";
+ if (auto operand_var = AsVarLike(operand)) {
+ current = operand_var;
+ } else {
+ auto offset = MakeShapeTupleFromInts({batch_index * source_rows, 0}, span);
+ auto shape = MakeShapeTupleFromInts({source_rows, source_cols}, span);
+ auto slice = op_registry.Create("tile.slice", {operand, shape, offset}, span);
+ current = std::make_shared<Var>(base_name + "_slice_" + suffix, slice->GetType(), span);
+ page.stmts.push_back(std::make_shared<AssignStmt>(current, slice, span));
+ }
} else {
auto offset = MakeShapeTupleFromInts({batch_index * source_rows, 0}, span);
auto shape = MakeShapeTupleFromInts({source_rows, source_cols}, span);
auto slice = op_registry.Create("tile.slice", {operand, shape, offset}, span);
current = std::make_shared<Var>(base_name + "_slice_" + suffix, slice->GetType(), span);
page.stmts.push_back(std::make_shared<AssignStmt>(current, slice, span));
}🤖 Prompt for AI Agents |
||
|
|
||
| } else { | ||
| // Strategy 3: rank>2 tile.slice + tile.reshape to 2D. | ||
|
|
@@ -1151,6 +1222,7 @@ std::vector<StmtPtr> TransformBody(const std::vector<StmtPtr>& stmts, FlattenCon | |
| // If conditions, For/While bounds, etc.), not just Call arguments. A Var used | ||
| // anywhere outside a tile.batch_matmul Call prevents it from being skipped. | ||
| std::unordered_set<const Var*> batch_matmul_only_vars; | ||
| std::unordered_set<const Var*> singleton_batch_acc_init_vars; | ||
| { | ||
| std::unordered_map<const Var*, int> use_count; | ||
| std::vector<const Var*> batch_matmul_operands; // ordered to avoid nondeterministic iteration | ||
|
|
@@ -1222,13 +1294,33 @@ std::vector<StmtPtr> TransformBody(const std::vector<StmtPtr>& stmts, FlattenCon | |
| CountVarRefs(for_stmt->start_); | ||
| CountVarRefs(for_stmt->stop_); | ||
| CountVarRefs(for_stmt->step_); | ||
| for (const auto& ia : for_stmt->iter_args_) CountVarRefs(ia->initValue_); | ||
| std::unordered_set<const Var*> batch_matmul_acc_operands; | ||
| CollectBatchMatmulAccOperands(FlattenToStmts(for_stmt->body_), batch_matmul_acc_operands); | ||
| for (const auto& ia : for_stmt->iter_args_) { | ||
| CountVarRefs(ia->initValue_); | ||
| auto init_var = As<Var>(ia->initValue_); | ||
| auto init_tile = init_var ? As<TileType>(init_var->GetType()) : nullptr; | ||
| if (init_var && HasSingletonBatchDims(init_tile, "tile.batch_matmul_acc init") && | ||
| batch_matmul_acc_operands.count(ia.get())) { | ||
| singleton_batch_acc_init_vars.insert(init_var.get()); | ||
| } | ||
| } | ||
|
Comment on lines
+1297
to
+1307
|
||
| continue; | ||
| } | ||
| // WhileStmt: count condition and iter_arg init Var refs. | ||
| if (auto while_stmt = As<WhileStmt>(s)) { | ||
| CountVarRefs(while_stmt->condition_); | ||
| for (const auto& ia : while_stmt->iter_args_) CountVarRefs(ia->initValue_); | ||
| std::unordered_set<const Var*> batch_matmul_acc_operands; | ||
| CollectBatchMatmulAccOperands(FlattenToStmts(while_stmt->body_), batch_matmul_acc_operands); | ||
| for (const auto& ia : while_stmt->iter_args_) { | ||
| CountVarRefs(ia->initValue_); | ||
| auto init_var = As<Var>(ia->initValue_); | ||
| auto init_tile = init_var ? As<TileType>(init_var->GetType()) : nullptr; | ||
| if (init_var && HasSingletonBatchDims(init_tile, "tile.batch_matmul_acc init") && | ||
| batch_matmul_acc_operands.count(ia.get())) { | ||
| singleton_batch_acc_init_vars.insert(init_var.get()); | ||
| } | ||
| } | ||
|
Comment on lines
+1297
to
+1323
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| continue; | ||
| } | ||
| } | ||
|
|
@@ -1599,7 +1691,26 @@ std::vector<StmtPtr> TransformBody(const std::vector<StmtPtr>& stmts, FlattenCon | |
| new_args.push_back(Substitute(call->args_[i], ctx.var_map)); | ||
| } | ||
|
|
||
| auto new_call = op_registry.Create(op_name, new_args, call->kwargs_, span); | ||
| auto new_kwargs = call->kwargs_; | ||
| bool force_acc_init = singleton_batch_acc_init_vars.count(assign->var_.get()) != 0; | ||
| if (force_acc_init) { | ||
| // Batch=1 matmul_acc lowers to tile.matmul_acc; keep its loop init in Acc | ||
| // instead of round-tripping a dummy accumulator through Vec. | ||
| new_kwargs = WithTargetMemory(new_kwargs, MemorySpace::Acc); | ||
| } | ||
| auto new_call = op_registry.Create(op_name, new_args, new_kwargs, span); | ||
| if (force_acc_init) { | ||
| auto new_call_tile = As<TileType>(new_call->GetType()); | ||
| CHECK(new_call_tile) << "FlattenTileNdTo2D: expected flattened accumulator init to be TileType"; | ||
| // Refresh the old N-D/default TileView when changing the dummy init | ||
| // to Acc so memory reuse sees the same layout as matmul outputs. | ||
| auto acc_view = tile_view_semantics::GetImplicitTileView(new_call_tile->shape_, MemorySpace::Acc); | ||
| auto acc_type = | ||
| std::make_shared<TileType>(new_call_tile->shape_, new_call_tile->dtype_, new_call_tile->memref_, | ||
| std::move(acc_view), MemorySpace::Acc); | ||
| new_call = std::make_shared<Call>(new_call->op_, new_call->args_, new_call->kwargs_, | ||
| std::move(acc_type), new_call->span_); | ||
| } | ||
| auto flat_var = | ||
| std::make_shared<Var>(assign->var_->name_hint_, new_call->GetType(), assign->var_->span_); | ||
| result.push_back(std::make_shared<AssignStmt>(flat_var, new_call, assign->span_)); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function manually traverses different statement types to find
`tile.batch_matmul_acc` calls. This logic should be simplified and made more maintainable by using an `IRVisitor`. Using a visitor ensures that the traversal is recursive, which is necessary to correctly handle nested control flow structures like `SeqStmts` and `ScopeStmt`.