Skip to content

support PTO for topk and sort#973

Merged
fuhouyu-hw merged 11 commits intotile-ai:ascendc_ptofrom
hedi515:ascendc_pto
May 9, 2026
Merged

support PTO for topk and sort#973
fuhouyu-hw merged 11 commits intotile-ai:ascendc_ptofrom
hedi515:ascendc_pto

Conversation

@hedi515
Copy link
Copy Markdown
Contributor

@hedi515 hedi515 commented May 7, 2026

No description provided.

@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 7, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements the Sort and TopK operators for the Ascend PTO target, adding code generation logic, merge sort templates for asymmetric sources, and temporary buffer allocation support. Feedback suggests refactoring the lengthy EmitSortAlgorithm function into smaller phases, using variadic templates for MergeSortVar to reduce code duplication, and simplifying the logic for calculating buffer shape sizes.

Comment thread src/target/codegen_ascend_pto.cc Outdated
Comment on lines +1637 to +2049
void CodeGenTileLangAscendPto::EmitSortAlgorithm(
const CallNode *dst_call, const CallNode *src_call,
const CallNode *tmp_call, int32_t repeat_times, int32_t actual_num,
int32_t top_k) {
int32_t aligned_count = repeat_times * 32;
int32_t pad_count = aligned_count - actual_num;
int32_t block_num = repeat_times;
int32_t N = aligned_count;

// top_k < 0 -> full sort: emit 2*N (val,idx) pairs to dst.
// top_k >= 0 -> topk: emit 2*K (val,idx) pairs to dst, dst is too
// small to be reused as bufB, so bufB lives in tmp.
bool is_topk = (top_k >= 0);

DataType dtype = src_call->args[0].dtype();
bool is_half = dtype.is_float() && dtype.bits() == 16;
bool is_float = dtype.is_float() && dtype.bits() == 32;
ICHECK(is_half || is_float)
<< "PTO Sort/TopK supports float32 / float16 input, got " << dtype;

// Internally the algorithm always sorts in float (matching AscendC's B16
// workaround: TMRGSORT requires >= 256 bytes per source, but Sort32 only
// produces 128 bytes per block for B16, so half input is cast to float).
std::string T = "float";
int32_t T_bytes = 4;
std::string user_T = is_half ? "half" : "float";
int32_t user_T_bytes = is_half ? 2 : 4;

Var dst_var = Downcast<Var>(dst_call->args[1]);
Var src_var = Downcast<Var>(src_call->args[1]);
Var tmp_var = Downcast<Var>(tmp_call->args[1]);
ICHECK(buffer_address_map_.count(dst_var))
<< "Buffer address not found for dst: " << dst_var->name_hint;
ICHECK(buffer_address_map_.count(src_var))
<< "Buffer address not found for src: " << src_var->name_hint;
ICHECK(buffer_address_map_.count(tmp_var))
<< "Buffer address not found for tmp: " << tmp_var->name_hint;

std::string dst_base = PrintExpr(buffer_address_map_.at(dst_var));
std::string src_base = PrintExpr(buffer_address_map_.at(src_var));
std::string tmp_base = PrintExpr(buffer_address_map_.at(tmp_var));
std::string dst_off = PrintExpr(dst_call->args[2]);
std::string src_off = PrintExpr(src_call->args[2]);
std::string tmp_off = PrintExpr(tmp_call->args[2]);

// Address-expression helper: <base> + offset_in_elements * elem_bytes +
// extra_bytes.
auto addr_expr = [](const std::string &base, const std::string &off,
int32_t elem_bytes, int32_t extra_bytes) -> std::string {
std::string s =
base + " + ((" + off + ") * " + std::to_string(elem_bytes) + ")";
if (extra_bytes != 0) {
s += " + " + std::to_string(extra_bytes);
}
return s;
};

// Layout in tmp (interpreted as float):
// bufA: tmp[0 .. 2N) sort32 output, ping-pong A
// bufB:
// float full sort: dst (= 2N float, dst can host the ping-pong)
// half OR topk: tmp[2N .. 4N) (dst is too small / wrong type)
// bufC: mergesort wrapper internal tmp
// same-target as bufB: tmp[2N .. 4N) (when bufB == dst)
// bufB-in-tmp: tmp[4N .. 6N)
// indices (initial position, before sort32 consumes them):
// lives in bufB low half (= bufB + 0)
// float_src (half case only, before sort32 consumes it):
// lives in bufB high half (= bufB + N float)
//
// The total tmp footprint is:
// float full sort: 4N float = 16N bytes (multiplier=4)
// float topk: 6N float = 24N bytes (multiplier=6)
// half (any): 6N float = 24N bytes; allocated as 16N halves = 32N
// bytes (multiplier=16)
bool bufB_in_tmp = is_half || is_topk;
auto bufA_addr = [&](int32_t extra_bytes = 0) -> std::string {
return addr_expr(tmp_base, tmp_off, T_bytes, extra_bytes);
};
auto bufB_addr = [&](int32_t extra_bytes = 0) -> std::string {
if (bufB_in_tmp) {
return addr_expr(tmp_base, tmp_off, T_bytes,
2 * N * T_bytes + extra_bytes);
} else {
return addr_expr(dst_base, dst_off, T_bytes, extra_bytes);
}
};
auto bufC_addr = [&](int32_t extra_bytes = 0) -> std::string {
int32_t base_off = (bufB_in_tmp ? 4 : 2) * N * T_bytes;
return addr_expr(tmp_base, tmp_off, T_bytes, base_off + extra_bytes);
};
auto indices_addr = [&]() -> std::string { return bufB_addr(0); };
auto sort_src_addr = [&]() -> std::string {
// For float, the working source is the user-supplied src buffer.
// For half, the working source is the casted float copy at bufB high half.
if (is_half) {
return bufB_addr(N * T_bytes);
} else {
return addr_expr(src_base, src_off, user_T_bytes, 0);
}
};

// Open outer scope to keep helper view declarations local.
this->PrintIndent();
this->stream << "{\n";
int outer_scope = this->BeginScope();

// Phase 0 (half only): Cast src(half) → float_src(float) at bufB high half.
if (is_half) {
this->PrintIndent();
this->stream << "{\n";
int s = this->BeginScope();
this->PrintIndent();
this->stream << kAscendPtoScope << "TileUbDataND<half, 1, " << N << ", 1, "
<< N << "> sort_h_src;\n";
this->PrintIndent();
this->stream << "TASSIGN(sort_h_src, "
<< addr_expr(src_base, src_off, user_T_bytes, 0) << ");\n";
this->PrintIndent();
this->stream << kAscendPtoScope << "TileUbDataND<float, 1, " << N
<< ", 1, " << N << "> sort_f_src;\n";
this->PrintIndent();
this->stream << "TASSIGN(sort_f_src, " << sort_src_addr() << ");\n";
this->PrintIndent();
this->stream << "pto::TCVT(sort_f_src, sort_h_src, "
<< "pto::RoundMode::CAST_NONE);\n";
this->EndScope(s);
this->PrintIndent();
this->stream << "}\n";
this->PrintIndent();
this->stream << "pipe_barrier(PIPE_V);\n";
}

// Phase 1: Pad sort_src tail with -inf for [actual_num, aligned_count).
// TFILLPAD_INPLACE handles unaligned actual_num via the tile's PadValue::Min.
if (pad_count > 0) {
this->PrintIndent();
this->stream << "{\n";
int s = this->BeginScope();
this->PrintIndent();
this->stream << kAscendPtoScope << "TileUbDataND<" << T << ", 1, " << N
<< ", 1, " << actual_num
<< ", pto::PadValue::Min> sort_src_v;\n";
this->PrintIndent();
this->stream << "TASSIGN(sort_src_v, " << sort_src_addr() << ");\n";
this->PrintIndent();
this->stream << kAscendPtoScope << "TileUbDataND<" << T << ", 1, " << N
<< ", 1, " << N << ", pto::PadValue::Min> sort_src_f;\n";
this->PrintIndent();
this->stream << "TASSIGN(sort_src_f, " << sort_src_addr() << ");\n";
this->PrintIndent();
this->stream << "pto::TFILLPAD_INPLACE(sort_src_f, sort_src_v);\n";
this->EndScope(s);
this->PrintIndent();
this->stream << "}\n";
this->PrintIndent();
this->stream << "pipe_barrier(PIPE_V);\n";
}

// Phase 2: Generate ascending indices in bufB low half as float values
// 0,1,...,N-1. Sort32 reads this as uint32 (via reinterpret) so the bit
// pattern of the float values flows through as the sorted-pair index field.
this->PrintIndent();
this->stream << "{\n";
{
int s = this->BeginScope();
this->PrintIndent();
this->stream << kAscendPtoScope << "TileUbDataND<" << T << ", 1, " << N
<< ", 1, " << N << "> sort_idx;\n";
this->PrintIndent();
this->stream << "TASSIGN(sort_idx, " << indices_addr() << ");\n";
this->PrintIndent();
this->stream << "TCI<decltype(sort_idx), " << T
<< ", /*descending=*/0>(sort_idx, (" << T << ")0);\n";
this->EndScope(s);
}
this->PrintIndent();
this->stream << "}\n";
this->PrintIndent();
this->stream << "pipe_barrier(PIPE_V);\n";

// Phase 3: Sort32 — input float src + uint32 indices → bufA.
this->PrintIndent();
this->stream << "{\n";
{
int s = this->BeginScope();
this->PrintIndent();
this->stream << kAscendPtoScope << "TileUbDataND<" << T << ", 1, " << N
<< ", 1, " << N << "> sort_src;\n";
this->PrintIndent();
this->stream << "TASSIGN(sort_src, " << sort_src_addr() << ");\n";
this->PrintIndent();
this->stream << kAscendPtoScope << "TileUbDataND<uint32_t, 1, " << N
<< ", 1, " << N << "> sort_idx_u;\n";
this->PrintIndent();
this->stream << "TASSIGN(sort_idx_u, " << indices_addr() << ");\n";
this->PrintIndent();
this->stream << kAscendPtoScope << "TileUbDataND<" << T << ", 1, "
<< (2 * N) << ", 1, " << (2 * N) << "> sort_buf_a;\n";
this->PrintIndent();
this->stream << "TASSIGN(sort_buf_a, " << bufA_addr() << ");\n";
this->PrintIndent();
this->stream << "TSORT32(sort_buf_a, sort_src, sort_idx_u);\n";
this->EndScope(s);
}
this->PrintIndent();
this->stream << "}\n";
this->PrintIndent();
this->stream << "pipe_barrier(PIPE_V);\n";

// Phase 4: Merge tree — ping-pong between bufA and bufB.
// Each element in the merge stream is a (value, index) pair occupying 2
// floats in memory. After Sort32 we have block_num sorted segments of 32
// elements each in bufA.
std::vector<int32_t> segs(block_num, 32);
bool read_from_a = true; // sort32 wrote to bufA

while (segs.size() > 1) {
std::vector<int32_t> new_segs;
int32_t in_off = 0; // float-element offset from start of read buffer
int32_t out_off = 0; // float-element offset from start of write buffer

for (size_t g = 0; g < segs.size(); g += 4) {
size_t group_count = std::min<size_t>(4, segs.size() - g);
int32_t lengths[4] = {0, 0, 0, 0};
int32_t total_elems = 0;
for (size_t i = 0; i < group_count; i++) {
lengths[i] = segs[g + i];
total_elems += lengths[i];
}

auto read_addr = [&](int32_t off_floats) {
return read_from_a ? bufA_addr(off_floats * T_bytes)
: bufB_addr(off_floats * T_bytes);
};
auto write_addr = [&](int32_t off_floats) {
return read_from_a ? bufB_addr(off_floats * T_bytes)
: bufA_addr(off_floats * T_bytes);
};

this->PrintIndent();
this->stream << "{\n";
int s = this->BeginScope();

if (group_count == 1) {
// Single segment: TMOV from read buffer to write buffer.
int32_t copy_floats = lengths[0] * 2;
this->PrintIndent();
this->stream << kAscendPtoScope << "TileUbDataND<" << T << ", 1, "
<< copy_floats << ", 1, " << copy_floats
<< "> sort_cs;\n";
this->PrintIndent();
this->stream << "TASSIGN(sort_cs, " << read_addr(in_off) << ");\n";
this->PrintIndent();
this->stream << kAscendPtoScope << "TileUbDataND<" << T << ", 1, "
<< copy_floats << ", 1, " << copy_floats
<< "> sort_cd;\n";
this->PrintIndent();
this->stream << "TASSIGN(sort_cd, " << write_addr(out_off) << ");\n";
this->PrintIndent();
this->stream << "TMOV(sort_cd, sort_cs);\n";
} else {
int32_t dst_floats = total_elems * 2;
int32_t local_in = in_off;
for (size_t i = 0; i < group_count; i++) {
int32_t src_floats = lengths[i] * 2;
this->PrintIndent();
this->stream << kAscendPtoScope << "TileUbDataND<" << T << ", 1, "
<< src_floats << ", 1, " << src_floats << "> sort_s"
<< i << ";\n";
this->PrintIndent();
this->stream << "TASSIGN(sort_s" << i << ", " << read_addr(local_in)
<< ");\n";
local_in += src_floats;
}
this->PrintIndent();
this->stream << kAscendPtoScope << "TileUbDataND<" << T << ", 1, "
<< dst_floats << ", 1, " << dst_floats
<< "> sort_md;\n";
this->PrintIndent();
this->stream << "TASSIGN(sort_md, " << write_addr(out_off) << ");\n";
this->PrintIndent();
this->stream << kAscendPtoScope << "TileUbDataND<" << T << ", 1, "
<< dst_floats << ", 1, " << dst_floats
<< "> sort_mt;\n";
this->PrintIndent();
this->stream << "TASSIGN(sort_mt, " << bufC_addr() << ");\n";

bool uniform = true;
for (size_t i = 1; i < group_count; i++) {
if (lengths[i] != lengths[0]) {
uniform = false;
break;
}
}

this->PrintIndent();
if (uniform) {
int32_t src_floats = lengths[0] * 2;
this->stream << kAscendPtoScope << "MergeSort<" << T << ", "
<< src_floats << ", " << dst_floats
<< ">(sort_md, sort_mt";
} else {
this->stream << kAscendPtoScope << "MergeSortVar<" << T;
for (size_t i = 0; i < group_count; i++) {
this->stream << ", " << (lengths[i] * 2);
}
this->stream << ", " << dst_floats << ">(sort_md, sort_mt";
}
for (size_t i = 0; i < group_count; i++) {
this->stream << ", sort_s" << i;
}
this->stream << ");\n";
}

this->EndScope(s);
this->PrintIndent();
this->stream << "}\n";

in_off += total_elems * 2;
out_off += total_elems * 2;
new_segs.push_back(total_elems);
}

this->PrintIndent();
this->stream << "pipe_barrier(PIPE_V);\n";

segs = new_segs;
read_from_a = !read_from_a;
}

// After merge: result lives in bufA if no merges ran (block_num == 1) or
// if the toggle leaves read_from_a pointing back at bufA after the last
// pass; otherwise it's in bufB.
bool result_in_bufA = (block_num == 1) || (block_num > 1 && read_from_a);
std::string result_addr = result_in_bufA ? bufA_addr() : bufB_addr();

// Phase 5: Finalize into dst.
// Number of float pair-elements to emit:
// full sort: all 2*N pairs
// topk: first 2*K pairs, rounded up to user_T's block alignment
// (matches AscendC, which uses DataCopy with byte-aligned size).
int32_t output_pairs;
if (is_topk) {
int32_t elems_per_block = 32 / user_T_bytes; // 8 (float) or 16 (half)
int32_t topk_elems = 2 * top_k;
output_pairs =
((topk_elems + elems_per_block - 1) / elems_per_block) * elems_per_block;
} else {
output_pairs = 2 * N;
}

if (is_half) {
// Cast `output_pairs` floats (interleaved value-index) → `output_pairs`
// halves at dst. CAST_RINT keeps integer indices exact.
this->PrintIndent();
this->stream << "{\n";
int s = this->BeginScope();
this->PrintIndent();
this->stream << kAscendPtoScope << "TileUbDataND<float, 1, " << output_pairs
<< ", 1, " << output_pairs << "> sort_fs;\n";
this->PrintIndent();
this->stream << "TASSIGN(sort_fs, " << result_addr << ");\n";
this->PrintIndent();
this->stream << kAscendPtoScope << "TileUbDataND<half, 1, " << output_pairs
<< ", 1, " << output_pairs << "> sort_fd;\n";
this->PrintIndent();
this->stream << "TASSIGN(sort_fd, "
<< addr_expr(dst_base, dst_off, user_T_bytes, 0) << ");\n";
this->PrintIndent();
this->stream << "pto::TCVT(sort_fd, sort_fs, "
<< "pto::RoundMode::CAST_RINT);\n";
this->EndScope(s);
this->PrintIndent();
this->stream << "}\n";
this->PrintIndent();
this->stream << "pipe_barrier(PIPE_V);\n";
} else {
// Float case. For full sort the result already lives in dst (bufB) when
// result_in_bufA is false, so the TMOV is only needed otherwise. For topk
// bufB lives in tmp regardless, so we always emit the copy.
bool need_copy = is_topk || result_in_bufA;
if (need_copy) {
this->PrintIndent();
this->stream << "{\n";
int s = this->BeginScope();
this->PrintIndent();
this->stream << kAscendPtoScope << "TileUbDataND<" << T << ", 1, "
<< output_pairs << ", 1, " << output_pairs
<< "> sort_fs;\n";
this->PrintIndent();
this->stream << "TASSIGN(sort_fs, " << result_addr << ");\n";
this->PrintIndent();
this->stream << kAscendPtoScope << "TileUbDataND<" << T << ", 1, "
<< output_pairs << ", 1, " << output_pairs
<< "> sort_fd;\n";
this->PrintIndent();
this->stream << "TASSIGN(sort_fd, "
<< addr_expr(dst_base, dst_off, user_T_bytes, 0) << ");\n";
this->PrintIndent();
this->stream << "TMOV(sort_fd, sort_fs);\n";
this->EndScope(s);
this->PrintIndent();
this->stream << "}\n";
this->PrintIndent();
this->stream << "pipe_barrier(PIPE_V);\n";
}
}

this->EndScope(outer_scope);
this->PrintIndent();
this->stream << "}\n";
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The function EmitSortAlgorithm is very long (over 400 lines), which makes it difficult to read and maintain. Consider refactoring it by breaking it down into smaller, more manageable private helper functions, each corresponding to a phase of the algorithm (e.g., EmitSortPhase0_CastHalfToFloat, EmitSortPhase1_PadSource, etc.). This would improve code clarity and modularity.

Comment on lines +958 to +1012
// 2-way merge sort with asymmetric source sizes (used by Sort recursion).
template <typename T, int32_t Src0Cols, int32_t Src1Cols, int32_t DstCols>
AICORE PTO_INLINE void
MergeSortVar(TileUbDataND<T, 1, DstCols, 1, DstCols> &dst,
TileUbDataND<T, 1, DstCols, 1, DstCols> &tmp,
TileUbDataND<T, 1, Src0Cols, 1, Src0Cols> &src0,
TileUbDataND<T, 1, Src1Cols, 1, Src1Cols> &src1) {
pto::MrgSortExecutedNumList executedNumList;
pto::TMRGSORT<TileUbDataND<T, 1, DstCols, 1, DstCols>,
TileUbDataND<T, 1, DstCols, 1, DstCols>,
TileUbDataND<T, 1, Src0Cols, 1, Src0Cols>,
TileUbDataND<T, 1, Src1Cols, 1, Src1Cols>, false>(
dst, executedNumList, tmp, src0, src1);
pipe_barrier(PIPE_V);
}

// 3-way merge sort with asymmetric source sizes.
template <typename T, int32_t Src0Cols, int32_t Src1Cols, int32_t Src2Cols,
int32_t DstCols>
AICORE PTO_INLINE void
MergeSortVar(TileUbDataND<T, 1, DstCols, 1, DstCols> &dst,
TileUbDataND<T, 1, DstCols, 1, DstCols> &tmp,
TileUbDataND<T, 1, Src0Cols, 1, Src0Cols> &src0,
TileUbDataND<T, 1, Src1Cols, 1, Src1Cols> &src1,
TileUbDataND<T, 1, Src2Cols, 1, Src2Cols> &src2) {
pto::MrgSortExecutedNumList executedNumList;
pto::TMRGSORT<TileUbDataND<T, 1, DstCols, 1, DstCols>,
TileUbDataND<T, 1, DstCols, 1, DstCols>,
TileUbDataND<T, 1, Src0Cols, 1, Src0Cols>,
TileUbDataND<T, 1, Src1Cols, 1, Src1Cols>,
TileUbDataND<T, 1, Src2Cols, 1, Src2Cols>, false>(
dst, executedNumList, tmp, src0, src1, src2);
pipe_barrier(PIPE_V);
}

// 4-way merge sort with asymmetric source sizes.
template <typename T, int32_t Src0Cols, int32_t Src1Cols, int32_t Src2Cols,
int32_t Src3Cols, int32_t DstCols>
AICORE PTO_INLINE void
MergeSortVar(TileUbDataND<T, 1, DstCols, 1, DstCols> &dst,
TileUbDataND<T, 1, DstCols, 1, DstCols> &tmp,
TileUbDataND<T, 1, Src0Cols, 1, Src0Cols> &src0,
TileUbDataND<T, 1, Src1Cols, 1, Src1Cols> &src1,
TileUbDataND<T, 1, Src2Cols, 1, Src2Cols> &src2,
TileUbDataND<T, 1, Src3Cols, 1, Src3Cols> &src3) {
pto::MrgSortExecutedNumList executedNumList;
pto::TMRGSORT<TileUbDataND<T, 1, DstCols, 1, DstCols>,
TileUbDataND<T, 1, DstCols, 1, DstCols>,
TileUbDataND<T, 1, Src0Cols, 1, Src0Cols>,
TileUbDataND<T, 1, Src1Cols, 1, Src1Cols>,
TileUbDataND<T, 1, Src2Cols, 1, Src2Cols>,
TileUbDataND<T, 1, Src3Cols, 1, Src3Cols>, false>(
dst, executedNumList, tmp, src0, src1, src2, src3);
pipe_barrier(PIPE_V);
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The overloads for MergeSortVar for 2, 3, and 4 ways are nearly identical. This code duplication can be avoided by using a single variadic template. This would make the code more concise and easier to maintain.

For example:

template <typename T, int32_t DstCols, typename... Srcs>
AICORE PTO_INLINE void MergeSortVar(
    TileUbDataND<T, 1, DstCols, 1, DstCols>& dst,
    TileUbDataND<T, 1, DstCols, 1, DstCols>& tmp,
    Srcs&... srcs) {
  pto::MrgSortExecutedNumList executedNumList;
  pto::TMRGSORT<TileUbDataND<T, 1, DstCols, 1, DstCols>,
                TileUbDataND<T, 1, DstCols, 1, DstCols>,
                Srcs..., false>(dst, executedNumList, tmp, srcs...);
  pipe_barrier(PIPE_V);
}

The call site in codegen_ascend_pto.cc would also need to be updated to use this variadic template, which could simplify the code generation logic there as well.

Comment on lines +567 to +585
if (shapes.count(dtype) > 0) {
int64_t shape_size = 0;
for (size_t k = 0; k < shapes.at(dtype).size(); k++) {
if (shape_size == 0) {
shape_size = shapes.at(dtype)[k].as<IntImmNode>()->value;
} else {
shape_size *= shapes.at(dtype)[k].as<IntImmNode>()->value;
}
}
if (tmp_shape_size > shape_size) {
Array<PrimExpr> tmp_shape;
tmp_shape.push_back(IntImm(DataType::Int(32), tmp_shape_size));
shapes[dtype] = tmp_shape;
}
} else {
Array<PrimExpr> tmp_shape;
tmp_shape.push_back(IntImm(DataType::Int(32), tmp_shape_size));
shapes[dtype] = tmp_shape;
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic to calculate the total size of a shape and update it if a larger size is needed can be simplified. The current implementation has a slightly convoluted loop for calculating the product of dimensions and repeats the buffer update logic.
A refactoring could make this clearer and more concise.

          int64_t shape_size = 0;
          if (shapes.count(dtype) > 0 && !shapes.at(dtype).empty()) {
            shape_size = 1;
            for (const auto& dim : shapes.at(dtype)) {
              shape_size *= dim.as<IntImmNode>()->value;
            }
          }

          if (tmp_shape_size > shape_size) {
            Array<PrimExpr> tmp_shape;
            tmp_shape.push_back(IntImm(DataType::Int(32), tmp_shape_size));
            shapes[dtype] = tmp_shape;
          }

@hedi515 hedi515 force-pushed the ascendc_pto branch 2 times, most recently from b226f5a to 90e9c5c Compare May 7, 2026 03:36
# topk_index_ub: 91712,
# output_ub: 95808,
# }
# )
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

注释就删了吧,不要留着了

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经删除

fuhouyu-hw
fuhouyu-hw previously approved these changes May 7, 2026
LLMZhangYC
LLMZhangYC previously approved these changes May 7, 2026
Copy link
Copy Markdown
Collaborator

@LLMZhangYC LLMZhangYC left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

fuhouyu-hw
fuhouyu-hw previously approved these changes May 7, 2026
Copy link
Copy Markdown
Collaborator

@fuhouyu-hw fuhouyu-hw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/approve

- New .agents/skill-journal/ for collecting per-op skill feedback
- New tilelang-skill-review skill: aggregate, table-output, CLI apply/reject
- Append §6 to tilelang-op-generate: enumerate consulted skills + write entries

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
@hedi515 hedi515 dismissed stale reviews from fuhouyu-hw and LLMZhangYC via 702f83e May 9, 2026 03:37
Copy link
Copy Markdown
Collaborator

@LLMZhangYC LLMZhangYC left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

Copy link
Copy Markdown
Collaborator

@fuhouyu-hw fuhouyu-hw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/approve

@fuhouyu-hw fuhouyu-hw merged commit 5cc9d77 into tile-ai:ascendc_pto May 9, 2026
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants