diff --git a/src/target/codegen_ascend_pto.cc b/src/target/codegen_ascend_pto.cc index dc39ef741..abee8c037 100644 --- a/src/target/codegen_ascend_pto.cc +++ b/src/target/codegen_ascend_pto.cc @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -161,6 +162,16 @@ int GetValidShape(int shape, const std::string &dtype) { return shape + (32 - shape_mod) / dtype_len; } +int GetRowReduceTmpCol(int valid_col, const std::string &dtype) { + constexpr int kVectorRepeatBytes = 256; + int dtype_len = GetTypeLen(dtype); + int elem_per_repeat = kVectorRepeatBytes / dtype_len; + int tmp_col = valid_col <= elem_per_repeat + ? 1 + : std::max(valid_col / 2, elem_per_repeat); + return GetValidShape(tmp_col, dtype); +} + std::string CodeGenTileLangAscendPto::GetVarId(const Var &var) const { auto it = var_idmap_.find(var.get()); return (it != var_idmap_.end() && !it->second.empty()) @@ -1132,8 +1143,8 @@ void CodeGenTileLangAscendPto::GMCopyCall(const CallNode *call, stream << copy_base_addr_map_.at(gm_info.id) << " + " << gm_offset_string; if (is_dynamic) { - stream << ", pto::Shape<" << shape_tmpl << ">()" - << ", pto::Stride<" << stride_tmpl << ">(" << stride_param << ")"; + stream << ", pto::Shape<" << shape_tmpl << ">()" << ", pto::Stride<" + << stride_tmpl << ">(" << stride_param << ")"; } stream << ", " << PrintExpr(buffer_address_map_.at(local_info.var)) << ", " @@ -1417,8 +1428,8 @@ void CodeGenTileLangAscendPto::CreateVecIndexCodegen( this->PrintIndent(); this->stream << kAscendPtoScope << "tci" << "<" << getType(dst_info.dtype) - << ", " << PrintExpr(M) << ", " << PrintExpr(N) << ">" - << "(" << PrintExpr(dst_slice_info.first_addr) << ", " + << ", " << PrintExpr(M) << ", " << PrintExpr(N) << ">" << "(" + << PrintExpr(dst_slice_info.first_addr) << ", " << dst_slice_info.offset << ", " << GetTypeLen(dst_slice_info.type) << ", " << first_value << ");\n"; @@ -1475,15 +1486,13 @@ void CodeGenTileLangAscendPto::PowCodegen(const CallNode *op) { this->PrintIndent(); this->stream << kAscendPtoScope << "pow" << "<" << dst_shape_info.type << ", " << dst_shape_info.slice_row << ", " - << dst_shape_info.slice_col << ">" - << "(" << dst_temp_name << ", " << src0_temp_name << ", " - << src1_temp_name << ");\n"; + << dst_shape_info.slice_col << ">" << "(" << dst_temp_name + << ", " << src0_temp_name << ", " << src1_temp_name << ");\n"; } else { this->PrintIndent(); this->stream << kAscendPtoScope << "pow" << "<" << dst_shape_info.type << ", " << dst_shape_info.row << ", " << dst_shape_info.col - << ">" - << "(" << dst_shape_info.ub_name << ", " + << ">" << "(" << dst_shape_info.ub_name << ", " << src0_shape_info.ub_name << ", " << src1_shape_info.ub_name << ");\n"; } @@ -1863,8 +1872,7 @@ void CodeGenTileLangAscendPto::CodegenColBroadcast(const ShapeInfo &dst, } this->PrintIndent(); - this->stream << "TCOLEXPAND" - << "(" << dst_name << ", " << src_name << ");\n"; + this->stream << "TCOLEXPAND" << "(" << dst_name << ", " << src_name << ");\n"; } void CodeGenTileLangAscendPto::BroadcastOpCodegen(const CallNode *op) { @@ -2076,8 +2084,7 @@ void CodeGenTileLangAscendPto::AxpyCodegen(const CallNode *op) { this->PrintIndent(); this->stream << kAscendPtoScope << "axpy" << "<" << dst_shape_info.type << ", " << dst_shape_info.row << ", " << dst_shape_info.col - << ">" - << "(" << dst_shape_info.ub_name << ", " + << ">" << "(" << dst_shape_info.ub_name << ", " << src_shape_info.ub_name << ", " << scalar << ");\n"; } } @@ -2350,9 +2357,31 @@ void CodeGenTileLangAscendPto::CodegenRowReduce(const ReduceOpInfo &op_info, CreateUbVariableND(src_name, src); } + ICHECK(dst.type == src.type) + << "Row reduce input dtype must be consistent with the output dtype."; + + std::string temp_name = tmp.ub_name; + if (src.type != tmp.type) { + temp_name = GetTempVarName(temp_name); + int tmp_col = GetRowReduceTmpCol(src.slice_valid_col, src.type); + ShapeInfo tmp_cast = ShapeInfo{src.slice_valid_row, + tmp_col, + src.slice_valid_row, + tmp_col, + src.slice_valid_row, + tmp_col, + tmp.extent, + tmp.first_addr, + "0", + src.type, + tmp.ub_name, + false}; + CreateUbVariableND(temp_name, tmp_cast); + } + this->PrintIndent(); this->stream << op_name << "(" << dst_name << ", " << src_name << ", " - << tmp.ub_name << ");\n"; + << temp_name << ");\n"; } void CodeGenTileLangAscendPto::CodegenColReduce(const ReduceOpInfo &op_info, @@ -2722,8 +2751,8 @@ void CodeGenTileLangAscendPto::VisitExpr_(const SelectNode *op, auto true_value = PrintExpr(op->true_value); auto false_value = PrintExpr(op->false_value); - os << "(" << condition << " ? " - << "" << true_value << " : " << false_value << ")"; + os << "(" << condition << " ? " << "" << true_value << " : " << false_value + << ")"; } static void ProcessHostInput(std::ostream &os, @@ -2731,8 +2760,7 @@ static void ProcessHostInput(std::ostream &os, std::vector &shape_vars, bool add_args = true) { for (auto shape_var : shape_vars) { - os << ", " - << "int64_t " << shape_var->name_hint; + os << ", " << "int64_t " << shape_var->name_hint; if (add_args) { arg_names.push_back(shape_var->name_hint); } diff --git a/src/transform/allocate_tmp_buffer.cc b/src/transform/allocate_tmp_buffer.cc index 82e551e34..ef5b8db4e 100644 --- a/src/transform/allocate_tmp_buffer.cc +++ b/src/transform/allocate_tmp_buffer.cc @@ -36,6 +36,15 @@ int64_t AlignReduceOutputCols(int64_t valid_col, int64_t dtype_bytes) { return aligned_bytes / dtype_bytes; } +int64_t GetPtoRowReduceTmpCols(int64_t valid_col, int64_t dtype_bytes) { + constexpr int64_t kVectorRepeatBytes = 256; + const int64_t elem_per_repeat = kVectorRepeatBytes / dtype_bytes; + const int64_t tmp_col = valid_col <= elem_per_repeat + ? 1 + : std::max(valid_col / 2, elem_per_repeat); + return AlignReduceOutputCols(tmp_col, dtype_bytes); +} + } // namespace class CallNodeCollector : public ExprVisitor, public StmtVisitor { @@ -807,14 +816,20 @@ class TmpBufferInjector : public StmtExprMutator { valid_row = src_buffer_node->shape[1].as()->value; valid_col = src_buffer_node->shape[2].as()->value; } - if (op_name == "TROWMAX" || op_name == "TROWMIN") { - int64_t tmp_shape_size = valid_row; - if (valid_row > shape_size) { - Array tmp_shape = {IntImm(DataType::Int(32), valid_row)}; + if (op_name == "TROWMAX" || op_name == "TROWMIN" || + op_name == "TROWSUM") { + const int64_t dtype_bytes = src_buffer_node->dtype.bytes(); + const int64_t tmp_col = + GetPtoRowReduceTmpCols(valid_col, dtype_bytes); + const int64_t tmp_shape_size = valid_row * tmp_col * dtype_bytes; + if (tmp_shape_size > shape_size) { + Array tmp_shape = { + IntImm(DataType::Int(32), tmp_shape_size), + }; shape = tmp_shape; shape_size = tmp_shape_size; } - } else if (op_name == "TROWSUM" || op_name == "TCOLSUM") { + } else if (op_name == "TCOLSUM") { int64_t tmp_shape_size = valid_row * valid_col / 2; if (tmp_shape_size > shape_size) { Array tmp_shape = { diff --git a/tilelang/language/ascend_tile.py b/tilelang/language/ascend_tile.py index 72ca34470..88bd9b560 100644 --- a/tilelang/language/ascend_tile.py +++ b/tilelang/language/ascend_tile.py @@ -1382,7 +1382,7 @@ def gather(dst: Buffer | BufferRegion, src: Buffer | BufferRegion, src_offset: B src_ptr = src.access_ptr("r") size = math.prod(src.shape) - if isinstance(src, BufferRegion): + if isinstance(src_offset, BufferRegion): src_offset_ptr, _ = _handle_buffer_region(src_offset, "r") else: src_offset_ptr = src_offset.access_ptr("r")