-
Notifications
You must be signed in to change notification settings - Fork 115
[Fix] Fix PTO row-reduce temporary buffer size and type mismatch for TROWSUM #1027
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -11,6 +11,7 @@ | |||||||||||||||||
| #include <tvm/tir/index_map.h> | ||||||||||||||||||
| #include <tvm/tir/op.h> | ||||||||||||||||||
|
|
||||||||||||||||||
| #include <algorithm> | ||||||||||||||||||
| #include <cmath> | ||||||||||||||||||
| #include <iomanip> | ||||||||||||||||||
| #include <sstream> | ||||||||||||||||||
|
|
@@ -161,6 +162,16 @@ int GetValidShape(int shape, const std::string &dtype) { | |||||||||||||||||
| return shape + (32 - shape_mod) / dtype_len; | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| int GetRowReduceTmpCol(int valid_col, const std::string &dtype) { | ||||||||||||||||||
| constexpr int kVectorRepeatBytes = 256; | ||||||||||||||||||
| int dtype_len = GetTypeLen(dtype); | ||||||||||||||||||
| int elem_per_repeat = kVectorRepeatBytes / dtype_len; | ||||||||||||||||||
| int tmp_col = valid_col <= elem_per_repeat | ||||||||||||||||||
| ? 1 | ||||||||||||||||||
| : std::max(valid_col / 2, elem_per_repeat); | ||||||||||||||||||
| return GetValidShape(tmp_col, dtype); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| std::string CodeGenTileLangAscendPto::GetVarId(const Var &var) const { | ||||||||||||||||||
| auto it = var_idmap_.find(var.get()); | ||||||||||||||||||
| return (it != var_idmap_.end() && !it->second.empty()) | ||||||||||||||||||
|
|
@@ -1132,8 +1143,8 @@ void CodeGenTileLangAscendPto::GMCopyCall(const CallNode *call, | |||||||||||||||||
| stream << copy_base_addr_map_.at(gm_info.id) << " + " << gm_offset_string; | ||||||||||||||||||
|
|
||||||||||||||||||
| if (is_dynamic) { | ||||||||||||||||||
| stream << ", pto::Shape<" << shape_tmpl << ">()" | ||||||||||||||||||
| << ", pto::Stride<" << stride_tmpl << ">(" << stride_param << ")"; | ||||||||||||||||||
| stream << ", pto::Shape<" << shape_tmpl << ">()" << ", pto::Stride<" | ||||||||||||||||||
| << stride_tmpl << ">(" << stride_param << ")"; | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| stream << ", " << PrintExpr(buffer_address_map_.at(local_info.var)) << ", " | ||||||||||||||||||
|
|
@@ -1417,8 +1428,8 @@ void CodeGenTileLangAscendPto::CreateVecIndexCodegen( | |||||||||||||||||
|
|
||||||||||||||||||
| this->PrintIndent(); | ||||||||||||||||||
| this->stream << kAscendPtoScope << "tci" << "<" << getType(dst_info.dtype) | ||||||||||||||||||
| << ", " << PrintExpr(M) << ", " << PrintExpr(N) << ">" | ||||||||||||||||||
| << "(" << PrintExpr(dst_slice_info.first_addr) << ", " | ||||||||||||||||||
| << ", " << PrintExpr(M) << ", " << PrintExpr(N) << ">" << "(" | ||||||||||||||||||
| << PrintExpr(dst_slice_info.first_addr) << ", " | ||||||||||||||||||
| << dst_slice_info.offset << ", " | ||||||||||||||||||
| << GetTypeLen(dst_slice_info.type) << ", " << first_value | ||||||||||||||||||
| << ");\n"; | ||||||||||||||||||
|
|
@@ -1434,6 +1445,7 @@ void CodeGenTileLangAscendPto::GatherbCodegen(const CallNode *op, | |||||||||||||||||
| << idx_name << ");\n"; | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| void CodeGenTileLangAscendPto::GatherMaskCodegen(const CallNode *op, | ||||||||||||||||||
| const std::string &op_name) { | ||||||||||||||||||
| BufferInfo dst_info = GetBufferInfo(op->args[1]); | ||||||||||||||||||
|
|
@@ -1475,15 +1487,13 @@ void CodeGenTileLangAscendPto::PowCodegen(const CallNode *op) { | |||||||||||||||||
| this->PrintIndent(); | ||||||||||||||||||
| this->stream << kAscendPtoScope << "pow" << "<" << dst_shape_info.type | ||||||||||||||||||
| << ", " << dst_shape_info.slice_row << ", " | ||||||||||||||||||
| << dst_shape_info.slice_col << ">" | ||||||||||||||||||
| << "(" << dst_temp_name << ", " << src0_temp_name << ", " | ||||||||||||||||||
| << src1_temp_name << ");\n"; | ||||||||||||||||||
| << dst_shape_info.slice_col << ">" << "(" << dst_temp_name | ||||||||||||||||||
| << ", " << src0_temp_name << ", " << src1_temp_name << ");\n"; | ||||||||||||||||||
| } else { | ||||||||||||||||||
| this->PrintIndent(); | ||||||||||||||||||
| this->stream << kAscendPtoScope << "pow" << "<" << dst_shape_info.type | ||||||||||||||||||
| << ", " << dst_shape_info.row << ", " << dst_shape_info.col | ||||||||||||||||||
| << ">" | ||||||||||||||||||
| << "(" << dst_shape_info.ub_name << ", " | ||||||||||||||||||
| << ">" << "(" << dst_shape_info.ub_name << ", " | ||||||||||||||||||
| << src0_shape_info.ub_name << ", " << src1_shape_info.ub_name | ||||||||||||||||||
| << ");\n"; | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
@@ -1863,8 +1873,7 @@ void CodeGenTileLangAscendPto::CodegenColBroadcast(const ShapeInfo &dst, | |||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| this->PrintIndent(); | ||||||||||||||||||
| this->stream << "TCOLEXPAND" | ||||||||||||||||||
| << "(" << dst_name << ", " << src_name << ");\n"; | ||||||||||||||||||
| this->stream << "TCOLEXPAND" << "(" << dst_name << ", " << src_name << ");\n"; | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| void CodeGenTileLangAscendPto::BroadcastOpCodegen(const CallNode *op) { | ||||||||||||||||||
|
|
@@ -2076,8 +2085,7 @@ void CodeGenTileLangAscendPto::AxpyCodegen(const CallNode *op) { | |||||||||||||||||
| this->PrintIndent(); | ||||||||||||||||||
| this->stream << kAscendPtoScope << "axpy" << "<" << dst_shape_info.type | ||||||||||||||||||
| << ", " << dst_shape_info.row << ", " << dst_shape_info.col | ||||||||||||||||||
| << ">" | ||||||||||||||||||
| << "(" << dst_shape_info.ub_name << ", " | ||||||||||||||||||
| << ">" << "(" << dst_shape_info.ub_name << ", " | ||||||||||||||||||
| << src_shape_info.ub_name << ", " << scalar << ");\n"; | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
@@ -2350,9 +2358,23 @@ void CodeGenTileLangAscendPto::CodegenRowReduce(const ReduceOpInfo &op_info, | |||||||||||||||||
| CreateUbVariableND(src_name, src); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| ICHECK(dst.type == src.type) | ||||||||||||||||||
| << "Row reduce input dtype must be consistent with the output dtype."; | ||||||||||||||||||
|
|
||||||||||||||||||
| std::string temp_name = tmp.ub_name; | ||||||||||||||||||
| if (src.type != tmp.type) { | ||||||||||||||||||
| temp_name = GetTempVarName(temp_name); | ||||||||||||||||||
| int tmp_col = GetRowReduceTmpCol(src.slice_valid_col, src.type); | ||||||||||||||||||
| ShapeInfo tmp_cast = | ||||||||||||||||||
| ShapeInfo{src.slice_valid_row, tmp_col, src.slice_valid_row, tmp_col, | ||||||||||||||||||
| src.slice_valid_row, tmp_col, tmp.extent, tmp.first_addr, | ||||||||||||||||||
| "0", src.type, tmp.ub_name, false}; | ||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||
| CreateUbVariableND(temp_name, tmp_cast); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| this->PrintIndent(); | ||||||||||||||||||
| this->stream << op_name << "(" << dst_name << ", " << src_name << ", " | ||||||||||||||||||
| << tmp.ub_name << ");\n"; | ||||||||||||||||||
| << temp_name << ");\n"; | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| void CodeGenTileLangAscendPto::CodegenColReduce(const ReduceOpInfo &op_info, | ||||||||||||||||||
|
|
@@ -2722,17 +2744,16 @@ void CodeGenTileLangAscendPto::VisitExpr_(const SelectNode *op, | |||||||||||||||||
| auto true_value = PrintExpr(op->true_value); | ||||||||||||||||||
| auto false_value = PrintExpr(op->false_value); | ||||||||||||||||||
|
|
||||||||||||||||||
| os << "(" << condition << " ? " | ||||||||||||||||||
| << "" << true_value << " : " << false_value << ")"; | ||||||||||||||||||
| os << "(" << condition << " ? " << "" << true_value << " : " << false_value | ||||||||||||||||||
| << ")"; | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| static void ProcessHostInput(std::ostream &os, | ||||||||||||||||||
| std::vector<std::string> &arg_names, | ||||||||||||||||||
| std::vector<const tir::VarNode *> &shape_vars, | ||||||||||||||||||
| bool add_args = true) { | ||||||||||||||||||
| for (auto shape_var : shape_vars) { | ||||||||||||||||||
| os << ", " | ||||||||||||||||||
| << "int64_t " << shape_var->name_hint; | ||||||||||||||||||
| os << ", " << "int64_t " << shape_var->name_hint; | ||||||||||||||||||
| if (add_args) { | ||||||||||||||||||
| arg_names.push_back(shape_var->name_hint); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function
GetRowReduceTmpColusesGetTypeLen(dtype), which currently does not supportint4b_t(it will trigger anICHECK(false)). However,getType(line 67) does support 4-bit integers. While row reduction is primarily used for floating-point types, this inconsistency could lead to unexpected compiler crashes if 4-bit integer reduction is attempted.