-
Notifications
You must be signed in to change notification settings - Fork 67
Fix singleton-batch matmul accumulator layout #1237
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -155,6 +155,15 @@ int64_t MultiplyStaticDims(const std::vector<int64_t>& dims, const std::string& | |
| return product; | ||
| } | ||
|
|
||
| /// Return true when an ND tile has only singleton batch dimensions, i.e. | ||
| /// [...batch, M, N] is effectively one 2D matrix page. | ||
| bool HasSingletonBatchDims(const TileTypePtr& tile_type, const std::string& context) { | ||
| if (!tile_type || tile_type->shape_.size() <= 2) return false; | ||
| std::vector<ExprPtr> batch_shape(tile_type->shape_.begin(), tile_type->shape_.end() - 2); | ||
| auto batch_dims = ToStaticDims(batch_shape, context + " batch"); | ||
| return MultiplyStaticDims(batch_dims, context + " batch size") == 1; | ||
| } | ||
|
|
||
| /// Decompose a flat batch index into per-dimension indices for the given batch shape. | ||
| /// e.g. flat_index=5 with batch_shape=[2,3] → indices=[1,2]. | ||
| std::vector<int64_t> BuildBatchIndices(int64_t flat_index, const std::vector<int64_t>& batch_shape) { | ||
|
|
@@ -218,6 +227,58 @@ bool IsTrailingMatrixAxisSwap(int64_t axis1, int64_t axis2, size_t ndim) { | |
| (axis1 == trailing_axis1 && axis2 == trailing_axis0); | ||
| } | ||
|
|
||
| std::vector<std::pair<std::string, std::any>> WithTargetMemory( | ||
| const std::vector<std::pair<std::string, std::any>>& kwargs, MemorySpace target_memory) { | ||
| auto updated = kwargs; | ||
| for (auto& kv : updated) { | ||
| if (kv.first == "target_memory") { | ||
| kv.second = target_memory; | ||
| return updated; | ||
| } | ||
| } | ||
| updated.emplace_back("target_memory", target_memory); | ||
| return updated; | ||
| } | ||
|
|
||
| void CollectBatchMatmulAccOperands(const std::vector<StmtPtr>& stmts, | ||
| std::unordered_set<const Var*>& acc_operands) { | ||
| for (const auto& stmt : stmts) { | ||
| if (auto assign = As<AssignStmt>(stmt)) { | ||
| if (auto call = As<Call>(assign->value_)) { | ||
| if (call->op_ && call->op_->name_ == "tile.batch_matmul_acc" && !call->args_.empty()) { | ||
| if (auto acc = AsVarLike(call->args_[0])) { | ||
| acc_operands.insert(acc.get()); | ||
| } | ||
| } | ||
| } | ||
| continue; | ||
| } | ||
| if (auto seq = As<SeqStmts>(stmt)) { | ||
| CollectBatchMatmulAccOperands(seq->stmts_, acc_operands); | ||
| continue; | ||
| } | ||
| if (auto scope = As<ScopeStmt>(stmt)) { | ||
| CollectBatchMatmulAccOperands(FlattenToStmts(scope->body_), acc_operands); | ||
| continue; | ||
| } | ||
| if (auto if_stmt = As<IfStmt>(stmt)) { | ||
| CollectBatchMatmulAccOperands(FlattenToStmts(if_stmt->then_body_), acc_operands); | ||
| if (if_stmt->else_body_.has_value()) { | ||
| CollectBatchMatmulAccOperands(FlattenToStmts(*if_stmt->else_body_), acc_operands); | ||
| } | ||
| continue; | ||
| } | ||
| if (auto for_stmt = As<ForStmt>(stmt)) { | ||
| CollectBatchMatmulAccOperands(FlattenToStmts(for_stmt->body_), acc_operands); | ||
| continue; | ||
| } | ||
| if (auto while_stmt = As<WhileStmt>(stmt)) { | ||
| CollectBatchMatmulAccOperands(FlattenToStmts(while_stmt->body_), acc_operands); | ||
| continue; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // ============================================================================ | ||
| // Precondition validation | ||
| // ============================================================================ | ||
|
|
@@ -1151,6 +1212,7 @@ std::vector<StmtPtr> TransformBody(const std::vector<StmtPtr>& stmts, FlattenCon | |
| // If conditions, For/While bounds, etc.), not just Call arguments. A Var used | ||
| // anywhere outside a tile.batch_matmul Call prevents it from being skipped. | ||
| std::unordered_set<const Var*> batch_matmul_only_vars; | ||
| std::unordered_set<const Var*> singleton_batch_acc_init_vars; | ||
| { | ||
| std::unordered_map<const Var*, int> use_count; | ||
| std::vector<const Var*> batch_matmul_operands; // ordered to avoid nondeterministic iteration | ||
|
|
@@ -1222,13 +1284,33 @@ std::vector<StmtPtr> TransformBody(const std::vector<StmtPtr>& stmts, FlattenCon | |
| CountVarRefs(for_stmt->start_); | ||
| CountVarRefs(for_stmt->stop_); | ||
| CountVarRefs(for_stmt->step_); | ||
| for (const auto& ia : for_stmt->iter_args_) CountVarRefs(ia->initValue_); | ||
| std::unordered_set<const Var*> batch_matmul_acc_operands; | ||
| CollectBatchMatmulAccOperands(FlattenToStmts(for_stmt->body_), batch_matmul_acc_operands); | ||
| for (const auto& ia : for_stmt->iter_args_) { | ||
| CountVarRefs(ia->initValue_); | ||
| auto init_var = As<Var>(ia->initValue_); | ||
| auto init_tile = init_var ? As<TileType>(init_var->GetType()) : nullptr; | ||
| if (init_var && HasSingletonBatchDims(init_tile, "tile.batch_matmul_acc init") && | ||
| batch_matmul_acc_operands.count(ia.get())) { | ||
| singleton_batch_acc_init_vars.insert(init_var.get()); | ||
| } | ||
| } | ||
|
Comment on lines
+1297
to
+1307
|
||
| continue; | ||
| } | ||
| // WhileStmt: count condition and iter_arg init Var refs. | ||
| if (auto while_stmt = As<WhileStmt>(s)) { | ||
| CountVarRefs(while_stmt->condition_); | ||
| for (const auto& ia : while_stmt->iter_args_) CountVarRefs(ia->initValue_); | ||
| std::unordered_set<const Var*> batch_matmul_acc_operands; | ||
| CollectBatchMatmulAccOperands(FlattenToStmts(while_stmt->body_), batch_matmul_acc_operands); | ||
| for (const auto& ia : while_stmt->iter_args_) { | ||
| CountVarRefs(ia->initValue_); | ||
| auto init_var = As<Var>(ia->initValue_); | ||
| auto init_tile = init_var ? As<TileType>(init_var->GetType()) : nullptr; | ||
| if (init_var && HasSingletonBatchDims(init_tile, "tile.batch_matmul_acc init") && | ||
| batch_matmul_acc_operands.count(ia.get())) { | ||
| singleton_batch_acc_init_vars.insert(init_var.get()); | ||
| } | ||
| } | ||
|
Comment on lines
+1297
to
+1323
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| continue; | ||
| } | ||
| } | ||
|
|
@@ -1599,7 +1681,13 @@ std::vector<StmtPtr> TransformBody(const std::vector<StmtPtr>& stmts, FlattenCon | |
| new_args.push_back(Substitute(call->args_[i], ctx.var_map)); | ||
| } | ||
|
|
||
| auto new_call = op_registry.Create(op_name, new_args, call->kwargs_, span); | ||
| auto new_kwargs = call->kwargs_; | ||
| if (singleton_batch_acc_init_vars.count(assign->var_.get())) { | ||
| // Batch=1 matmul_acc lowers to tile.matmul_acc; keep its loop init in Acc | ||
| // instead of round-tripping a dummy accumulator through Vec. | ||
| new_kwargs = WithTargetMemory(new_kwargs, MemorySpace::Acc); | ||
| } | ||
| auto new_call = op_registry.Create(op_name, new_args, new_kwargs, span); | ||
| auto flat_var = | ||
| std::make_shared<Var>(assign->var_->name_hint_, new_call->GetType(), assign->var_->span_); | ||
| result.push_back(std::make_shared<AssignStmt>(flat_var, new_call, assign->span_)); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function manually traverses different statement types to find
`tile.batch_matmul_acc` calls. This logic should be simplified and made more maintainable by using an `IRVisitor`. Using a visitor ensures that the traversal is recursive, which is necessary to correctly handle nested control flow structures like `SeqStmts` and `ScopeStmt`.

References