Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 39 additions & 18 deletions src/target/codegen_ascend_pto.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>

#include <algorithm>
#include <cmath>
#include <iomanip>
#include <sstream>
Expand Down Expand Up @@ -161,6 +162,16 @@ int GetValidShape(int shape, const std::string &dtype) {
return shape + (32 - shape_mod) / dtype_len;
}

// Chooses the column count used for the temporary tile view of a row reduce.
//
// valid_col: number of valid columns in the source slice.
// dtype:     element type name; its size is obtained from GetTypeLen.
//
// If valid_col does not exceed the number of elements that fit in a 256-byte
// repeat, a single column is requested. Otherwise the request is the larger of
// valid_col / 2 and the per-repeat element count. The request is then passed
// through GetValidShape for this dtype before being returned.
// NOTE(review): GetValidShape presumably pads the count up to an alignment
// boundary -- confirm against its definition.
int GetRowReduceTmpCol(int valid_col, const std::string &dtype) {
  constexpr int kRepeatBytes = 256;
  const int lanes_per_repeat = kRepeatBytes / GetTypeLen(dtype);

  // Small rows: everything fits in one repeat, so one column is requested.
  if (valid_col <= lanes_per_repeat) {
    return GetValidShape(1, dtype);
  }

  // Wide rows: half of the row, but never fewer than one repeat of elements.
  const int half_cols = valid_col / 2;
  const int requested_cols = std::max(half_cols, lanes_per_repeat);
  return GetValidShape(requested_cols, dtype);
}
Comment on lines +165 to +173
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The function GetRowReduceTmpCol uses GetTypeLen(dtype), which currently does not support int4b_t (it will trigger an ICHECK(false)). However, getType (line 67) does support 4-bit integers. While row reduction is primarily used for floating-point types, this inconsistency could lead to unexpected compiler crashes if 4-bit integer reduction is attempted.


std::string CodeGenTileLangAscendPto::GetVarId(const Var &var) const {
auto it = var_idmap_.find(var.get());
return (it != var_idmap_.end() && !it->second.empty())
Expand Down Expand Up @@ -1132,8 +1143,8 @@ void CodeGenTileLangAscendPto::GMCopyCall(const CallNode *call,
stream << copy_base_addr_map_.at(gm_info.id) << " + " << gm_offset_string;

if (is_dynamic) {
stream << ", pto::Shape<" << shape_tmpl << ">()"
<< ", pto::Stride<" << stride_tmpl << ">(" << stride_param << ")";
stream << ", pto::Shape<" << shape_tmpl << ">()" << ", pto::Stride<"
<< stride_tmpl << ">(" << stride_param << ")";
}

stream << ", " << PrintExpr(buffer_address_map_.at(local_info.var)) << ", "
Expand Down Expand Up @@ -1417,8 +1428,8 @@ void CodeGenTileLangAscendPto::CreateVecIndexCodegen(

this->PrintIndent();
this->stream << kAscendPtoScope << "tci" << "<" << getType(dst_info.dtype)
<< ", " << PrintExpr(M) << ", " << PrintExpr(N) << ">"
<< "(" << PrintExpr(dst_slice_info.first_addr) << ", "
<< ", " << PrintExpr(M) << ", " << PrintExpr(N) << ">" << "("
<< PrintExpr(dst_slice_info.first_addr) << ", "
<< dst_slice_info.offset << ", "
<< GetTypeLen(dst_slice_info.type) << ", " << first_value
<< ");\n";
Expand All @@ -1434,6 +1445,7 @@ void CodeGenTileLangAscendPto::GatherbCodegen(const CallNode *op,
<< idx_name << ");\n";
}


void CodeGenTileLangAscendPto::GatherMaskCodegen(const CallNode *op,
const std::string &op_name) {
BufferInfo dst_info = GetBufferInfo(op->args[1]);
Expand Down Expand Up @@ -1475,15 +1487,13 @@ void CodeGenTileLangAscendPto::PowCodegen(const CallNode *op) {
this->PrintIndent();
this->stream << kAscendPtoScope << "pow" << "<" << dst_shape_info.type
<< ", " << dst_shape_info.slice_row << ", "
<< dst_shape_info.slice_col << ">"
<< "(" << dst_temp_name << ", " << src0_temp_name << ", "
<< src1_temp_name << ");\n";
<< dst_shape_info.slice_col << ">" << "(" << dst_temp_name
<< ", " << src0_temp_name << ", " << src1_temp_name << ");\n";
} else {
this->PrintIndent();
this->stream << kAscendPtoScope << "pow" << "<" << dst_shape_info.type
<< ", " << dst_shape_info.row << ", " << dst_shape_info.col
<< ">"
<< "(" << dst_shape_info.ub_name << ", "
<< ">" << "(" << dst_shape_info.ub_name << ", "
<< src0_shape_info.ub_name << ", " << src1_shape_info.ub_name
<< ");\n";
}
Expand Down Expand Up @@ -1863,8 +1873,7 @@ void CodeGenTileLangAscendPto::CodegenColBroadcast(const ShapeInfo &dst,
}

this->PrintIndent();
this->stream << "TCOLEXPAND"
<< "(" << dst_name << ", " << src_name << ");\n";
this->stream << "TCOLEXPAND" << "(" << dst_name << ", " << src_name << ");\n";
}

void CodeGenTileLangAscendPto::BroadcastOpCodegen(const CallNode *op) {
Expand Down Expand Up @@ -2076,8 +2085,7 @@ void CodeGenTileLangAscendPto::AxpyCodegen(const CallNode *op) {
this->PrintIndent();
this->stream << kAscendPtoScope << "axpy" << "<" << dst_shape_info.type
<< ", " << dst_shape_info.row << ", " << dst_shape_info.col
<< ">"
<< "(" << dst_shape_info.ub_name << ", "
<< ">" << "(" << dst_shape_info.ub_name << ", "
<< src_shape_info.ub_name << ", " << scalar << ");\n";
}
}
Expand Down Expand Up @@ -2350,9 +2358,23 @@ void CodeGenTileLangAscendPto::CodegenRowReduce(const ReduceOpInfo &op_info,
CreateUbVariableND(src_name, src);
}

ICHECK(dst.type == src.type)
<< "Row reduce input dtype must be consistent with the output dtype.";

std::string temp_name = tmp.ub_name;
if (src.type != tmp.type) {
temp_name = GetTempVarName(temp_name);
int tmp_col = GetRowReduceTmpCol(src.slice_valid_col, src.type);
ShapeInfo tmp_cast =
ShapeInfo{src.slice_valid_row, tmp_col, src.slice_valid_row, tmp_col,
src.slice_valid_row, tmp_col, tmp.extent, tmp.first_addr,
"0", src.type, tmp.ub_name, false};
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The offset for the temporary buffer view tmp_cast is hardcoded to "0". While injected scratchpad buffers currently start at offset 0, this hardcoding is fragile and inconsistent with the implementation in CodegenColReduce (line 2414), which correctly uses tmp.offset. If the temporary buffer is ever a slice or part of a larger allocation, this will lead to incorrect memory access.

Suggested change
ShapeInfo tmp_cast =
ShapeInfo{src.slice_valid_row, tmp_col, src.slice_valid_row, tmp_col,
src.slice_valid_row, tmp_col, tmp.extent, tmp.first_addr,
"0", src.type, tmp.ub_name, false};
ShapeInfo tmp_cast =
ShapeInfo{src.slice_valid_row, tmp_col, src.slice_valid_row, tmp_col,
src.slice_valid_row, tmp_col, tmp.extent, tmp.first_addr,
tmp.offset, src.type, tmp.ub_name, false};

CreateUbVariableND(temp_name, tmp_cast);
}

this->PrintIndent();
this->stream << op_name << "(" << dst_name << ", " << src_name << ", "
<< tmp.ub_name << ");\n";
<< temp_name << ");\n";
}

void CodeGenTileLangAscendPto::CodegenColReduce(const ReduceOpInfo &op_info,
Expand Down Expand Up @@ -2722,17 +2744,16 @@ void CodeGenTileLangAscendPto::VisitExpr_(const SelectNode *op,
auto true_value = PrintExpr(op->true_value);
auto false_value = PrintExpr(op->false_value);

os << "(" << condition << " ? "
<< "" << true_value << " : " << false_value << ")";
os << "(" << condition << " ? " << "" << true_value << " : " << false_value
<< ")";
}

static void ProcessHostInput(std::ostream &os,
std::vector<std::string> &arg_names,
std::vector<const tir::VarNode *> &shape_vars,
bool add_args = true) {
for (auto shape_var : shape_vars) {
os << ", "
<< "int64_t " << shape_var->name_hint;
os << ", " << "int64_t " << shape_var->name_hint;
if (add_args) {
arg_names.push_back(shape_var->name_hint);
}
Expand Down
25 changes: 20 additions & 5 deletions src/transform/allocate_tmp_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ int64_t AlignReduceOutputCols(int64_t valid_col, int64_t dtype_bytes) {
return aligned_bytes / dtype_bytes;
}

// Computes the column count of the temporary buffer for a PTO row reduce.
//
// valid_col:   number of valid columns in the source buffer.
// dtype_bytes: size of one element in bytes.
//
// When valid_col does not exceed the number of elements in a 256-byte repeat,
// one column is requested. Otherwise the request is the larger of
// valid_col / 2 and the per-repeat element count. The result is whatever
// AlignReduceOutputCols returns for that request and dtype_bytes.
int64_t GetPtoRowReduceTmpCols(int64_t valid_col, int64_t dtype_bytes) {
  constexpr int64_t kRepeatBytes = 256;
  const int64_t lanes_per_repeat = kRepeatBytes / dtype_bytes;

  int64_t requested_cols = 1;
  if (valid_col > lanes_per_repeat) {
    // Wide rows: half of the row, but never fewer than one repeat of elements.
    requested_cols = std::max(valid_col / 2, lanes_per_repeat);
  }
  return AlignReduceOutputCols(requested_cols, dtype_bytes);
}

} // namespace

class CallNodeCollector : public ExprVisitor, public StmtVisitor {
Expand Down Expand Up @@ -807,14 +816,20 @@ class TmpBufferInjector : public StmtExprMutator {
valid_row = src_buffer_node->shape[1].as<IntImmNode>()->value;
valid_col = src_buffer_node->shape[2].as<IntImmNode>()->value;
}
if (op_name == "TROWMAX" || op_name == "TROWMIN") {
int64_t tmp_shape_size = valid_row;
if (valid_row > shape_size) {
Array<PrimExpr> tmp_shape = {IntImm(DataType::Int(32), valid_row)};
if (op_name == "TROWMAX" || op_name == "TROWMIN" ||
op_name == "TROWSUM") {
const int64_t dtype_bytes = src_buffer_node->dtype.bytes();
const int64_t tmp_col =
GetPtoRowReduceTmpCols(valid_col, dtype_bytes);
const int64_t tmp_shape_size = valid_row * tmp_col * dtype_bytes;
if (tmp_shape_size > shape_size) {
Array<PrimExpr> tmp_shape = {
IntImm(DataType::Int(32), tmp_shape_size),
};
shape = tmp_shape;
shape_size = tmp_shape_size;
}
} else if (op_name == "TROWSUM" || op_name == "TCOLSUM") {
} else if (op_name == "TCOLSUM") {
int64_t tmp_shape_size = valid_row * valid_col / 2;
if (tmp_shape_size > shape_size) {
Array<PrimExpr> tmp_shape = {
Expand Down
2 changes: 1 addition & 1 deletion tilelang/language/ascend_tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -1382,7 +1382,7 @@ def gather(dst: Buffer | BufferRegion, src: Buffer | BufferRegion, src_offset: B
src_ptr = src.access_ptr("r")
size = math.prod(src.shape)

if isinstance(src, BufferRegion):
if isinstance(src_offset, BufferRegion):
src_offset_ptr, _ = _handle_buffer_region(src_offset, "r")
else:
src_offset_ptr = src_offset.access_ptr("r")
Expand Down
Loading