diff --git a/examples/lightning_indexer/example_lightning_indexer.py b/examples/lightning_indexer/example_lightning_indexer.py index 741dbbdd3..cabfdf46b 100644 --- a/examples/lightning_indexer/example_lightning_indexer.py +++ b/examples/lightning_indexer/example_lightning_indexer.py @@ -8,9 +8,16 @@ tilelang.disable_cache() -@tilelang.jit(out_idx=[-1], workspace_idx=[-3]) # for jit -def indexer(B, N2, G, S1, S2, D, TOP_K, VECTOR_BASEN, VECTOR_BASEG, BLOCK_M, BLOCK_N, BLOCK_K, input_dtype="float16", calc_dtype="float"): +# Memory planning is the "address reuse" pass: it lets non-overlapping +# buffers (and the auto-injected sort/topk tmp buffers that aren't in +# T.annotate_address) share UB space. +pass_configs = { + tilelang.PassConfigKey.TL_ASCEND_MEMORY_PLANNING: True, +} + +@tilelang.jit(out_idx=[-1], workspace_idx=[-3], pass_configs=pass_configs) +def indexer(B, N2, G, S1, S2, D, TOP_K, VECTOR_BASEN, VECTOR_BASEG, BLOCK_M, BLOCK_N, BLOCK_K, input_dtype="float16", calc_dtype="float"): @T.prim_func def main( Query: T.Tensor((B, S1, N2, G * D), input_dtype), @@ -30,15 +37,6 @@ def main( C_L0 = T.alloc_L0C((BLOCK_M, BLOCK_N), calc_dtype) - T.annotate_address( - { - # L1 address - Q_L1: 0, - K_L1: 16384, - # L0C address - C_L0: 0, - } - ) T.barrier_all() for n2 in T.serial(N2): for g in T.serial(G): @@ -63,10 +61,7 @@ def main( with T.Scope("V"): mm_res_ub = T.alloc_ub((VECTOR_BASEG, VECTOR_BASEN), calc_dtype) - mm_res_ub_flat = T.alloc_ub((VECTOR_BASEG * VECTOR_BASEN), calc_dtype) - mm_res_ub_uint8 = T.alloc_ub((VECTOR_BASEG, VECTOR_BASEN), "uint8") weight_ub = T.alloc_ub(VECTOR_BASEG, calc_dtype) - weight_brcb_ub = T.alloc_ub((VECTOR_BASEG, 8), calc_dtype) reduce_tmp_ub = T.alloc_ub((VECTOR_BASEG, VECTOR_BASEN), calc_dtype) reduce_g_ub = T.alloc_ub(VECTOR_BASEN, calc_dtype) # Accumulate all S2 scores, then topk once @@ -75,23 +70,6 @@ def main( topk_index_ub = T.alloc_ub(TOP_K, calc_dtype) output_ub = T.alloc_ub(TOP_K, "int") - T.annotate_address( - { - # ub address - mm_res_ub: 0, - mm_res_ub_flat: 0, - mm_res_ub_uint8: 0, - weight_ub: 32768, - weight_brcb_ub: 32832, - reduce_tmp_ub: 33344, - reduce_g_ub: 66112, - score_accum_ub: 67136, - topk_dst_ub: 83520, - topk_index_ub: 91712, - output_ub: 95808, - } - ) - s1_start_idx = vid * each_core_process_num s1_end_idx = s1_start_idx + each_core_process_num diff --git a/src/target/codegen_ascend_pto.cc b/src/target/codegen_ascend_pto.cc index 01cfa7844..b839cf97d 100644 --- a/src/target/codegen_ascend_pto.cc +++ b/src/target/codegen_ascend_pto.cc @@ -858,6 +858,10 @@ void CodeGenTileLangAscendPto::VisitExpr_(const CallNode *op, PowCodegen(op); } else if (op->op.same_as(tl::ascend_sort32())) { Sort32Codegen(op, "TSORT32"); + } else if (op->op.same_as(tl::ascend_sort())) { + SortCodegen(op); + } else if (op->op.same_as(tl::ascend_topk())) { + TopKCodegen(op); } else if (op->op.same_as(tl::ascend_merge_sort())) { MergeSortCodegen(op, "TMRGSORT"); } else if (op->op.same_as(tl::ascend_transpose())) { @@ -1599,6 +1603,112 @@ void CodeGenTileLangAscendPto::MergeSortCodegen(const CallNode *op, this->stream << ");\n"; } +void CodeGenTileLangAscendPto::SortCodegen(const CallNode *op) { + // After tmp injection, args layout: + // [0] func_name (e.g. "Sort") + // [1] dst access_ptr -- 2*alignedCount user_T elements + // [2] src access_ptr -- alignedCount user_T elements (may be mutated) + // [3] tmp access_ptr -- internal workspace allocated by + // allocate_tmp_buffer [4] repeatTimes (constant) [5] actual_num (constant) + ICHECK(op->args.size() == 6) + << "ascend_sort expects 6 args after tmp injection, got " + << op->args.size(); + + auto dst_call = op->args[1].as(); + auto src_call = op->args[2].as(); + auto tmp_call = op->args[3].as(); + ICHECK(dst_call && dst_call->op.same_as(builtin::tvm_access_ptr())); + ICHECK(src_call && src_call->op.same_as(builtin::tvm_access_ptr())); + ICHECK(tmp_call && tmp_call->op.same_as(builtin::tvm_access_ptr())); + + int32_t repeat_times = Downcast(op->args[4])->value; + int32_t actual_num = Downcast(op->args[5])->value; + + EmitSortAlgorithm(dst_call, src_call, tmp_call, repeat_times, actual_num, + /*top_k=*/-1); +} + +void CodeGenTileLangAscendPto::TopKCodegen(const CallNode *op) { + // After tmp injection, args layout: + // [0] func_name (e.g. "TopK") + // [1] dst access_ptr -- 2*K user_T elements (UB-rounded) + // [2] src access_ptr -- alignedCount user_T elements + // [3] tmp access_ptr -- internal workspace + // [4] K (constant) + // [5] repeatTimes (constant) + // [6] actual_num (constant) + ICHECK(op->args.size() == 7) + << "ascend_topk expects 7 args after tmp injection, got " + << op->args.size(); + + auto dst_call = op->args[1].as(); + auto src_call = op->args[2].as(); + auto tmp_call = op->args[3].as(); + ICHECK(dst_call && dst_call->op.same_as(builtin::tvm_access_ptr())); + ICHECK(src_call && src_call->op.same_as(builtin::tvm_access_ptr())); + ICHECK(tmp_call && tmp_call->op.same_as(builtin::tvm_access_ptr())); + + int32_t k = Downcast(op->args[4])->value; + int32_t repeat_times = Downcast(op->args[5])->value; + int32_t actual_num = Downcast(op->args[6])->value; + ICHECK(k > 0) << "TopK requires K > 0, got " << k; + + EmitSortAlgorithm(dst_call, src_call, tmp_call, repeat_times, actual_num, + /*top_k=*/k); +} + +// ============================================================================= +// Sort/TopK pipeline: thin codegen wrapper +// ============================================================================= +// +// The full algorithm (pad, sort32, merge tree, finalize) lives in +// pto/common.h as the device template tl::ascend_pto::Sort. This codegen +// just forwards parsed parameters and emits a single template call. + +void CodeGenTileLangAscendPto::EmitSortAlgorithm(const CallNode *dst_call, + const CallNode *src_call, + const CallNode *tmp_call, + int32_t repeat_times, + int32_t actual_num, + int32_t top_k) { + int32_t aligned_count = repeat_times * 32; + + DataType dtype = src_call->args[0].dtype(); + bool is_half = dtype.is_float() && dtype.bits() == 16; + bool is_float = dtype.is_float() && dtype.bits() == 32; + ICHECK(is_half || is_float) + << "PTO Sort/TopK supports float32 / float16 input, got " << dtype; + std::string user_T = is_half ? "half" : "float"; + int32_t user_T_bytes = is_half ? 2 : 4; + + Var dst_var = Downcast(dst_call->args[1]); + Var src_var = Downcast(src_call->args[1]); + Var tmp_var = Downcast(tmp_call->args[1]); + ICHECK(buffer_address_map_.count(dst_var)) + << "Buffer address not found for dst: " << dst_var->name_hint; + ICHECK(buffer_address_map_.count(src_var)) + << "Buffer address not found for src: " << src_var->name_hint; + ICHECK(buffer_address_map_.count(tmp_var)) + << "Buffer address not found for tmp: " << tmp_var->name_hint; + + // Emit " + ((offset) * elem_bytes)" as a runtime byte address. + auto byte_addr = [this](Var var, PrimExpr offset, int32_t elem_bytes) { + std::string base = PrintExpr(buffer_address_map_.at(var)); + std::string off = PrintExpr(offset); + return base + " + ((" + off + ") * " + std::to_string(elem_bytes) + ")"; + }; + + std::string dst_addr = byte_addr(dst_var, dst_call->args[2], user_T_bytes); + std::string src_addr = byte_addr(src_var, src_call->args[2], user_T_bytes); + std::string tmp_addr = + byte_addr(tmp_var, tmp_call->args[2], /*elem_bytes=*/4); + + this->PrintIndent(); + this->stream << kAscendPtoScope << "Sort<" << user_T << ", " << aligned_count + << ", " << actual_num << ", " << top_k << ">(" << dst_addr + << ", " << src_addr << ", " << tmp_addr << ");\n"; +} + void CodeGenTileLangAscendPto::TransposeCodegen(const CallNode *op, const std::string &op_name) { this->PrintIndent(); diff --git a/src/target/codegen_ascend_pto.h b/src/target/codegen_ascend_pto.h index 97957dbf4..4902114fb 100644 --- a/src/target/codegen_ascend_pto.h +++ b/src/target/codegen_ascend_pto.h @@ -153,6 +153,17 @@ class CodeGenTileLangAscendPto final : public CodeGenC { void MergeSortCodegen(const CallNode *op, const std::string &op_name); + void SortCodegen(const CallNode *op); + + void TopKCodegen(const CallNode *op); + + // Emits a single tl::ascend_pto::Sort(...) + // call. The full algorithm (pad, sort32, merge tree, finalize) lives in + // pto/common.h. + void EmitSortAlgorithm(const CallNode *dst_call, const CallNode *src_call, + const CallNode *tmp_call, int32_t repeat_times, + int32_t actual_num, int32_t top_k); + void TransposeCodegen(const CallNode *op, const std::string &op_name); void XorCodegen(const CallNode *op, const std::string &op_name); diff --git a/src/tl_templates/pto/common.h b/src/tl_templates/pto/common.h index 488d802fb..299b3dd31 100644 --- a/src/tl_templates/pto/common.h +++ b/src/tl_templates/pto/common.h @@ -955,6 +955,391 @@ MergeSort(TileUbDataND &dst, pipe_barrier(PIPE_V); } +// 2-way merge sort with asymmetric source sizes (used by Sort recursion). +template +AICORE PTO_INLINE void +MergeSortVar(TileUbDataND &dst, + TileUbDataND &tmp, + TileUbDataND &src0, + TileUbDataND &src1) { + pto::MrgSortExecutedNumList executedNumList; + pto::TMRGSORT, + TileUbDataND, + TileUbDataND, + TileUbDataND, false>( + dst, executedNumList, tmp, src0, src1); + pipe_barrier(PIPE_V); +} + +// 3-way merge sort with asymmetric source sizes. +template +AICORE PTO_INLINE void +MergeSortVar(TileUbDataND &dst, + TileUbDataND &tmp, + TileUbDataND &src0, + TileUbDataND &src1, + TileUbDataND &src2) { + pto::MrgSortExecutedNumList executedNumList; + pto::TMRGSORT, + TileUbDataND, + TileUbDataND, + TileUbDataND, + TileUbDataND, false>( + dst, executedNumList, tmp, src0, src1, src2); + pipe_barrier(PIPE_V); +} + +// 4-way merge sort with asymmetric source sizes. +template +AICORE PTO_INLINE void +MergeSortVar(TileUbDataND &dst, + TileUbDataND &tmp, + TileUbDataND &src0, + TileUbDataND &src1, + TileUbDataND &src2, + TileUbDataND &src3) { + pto::MrgSortExecutedNumList executedNumList; + pto::TMRGSORT, + TileUbDataND, + TileUbDataND, + TileUbDataND, + TileUbDataND, + TileUbDataND, false>( + dst, executedNumList, tmp, src0, src1, src2, src3); + pipe_barrier(PIPE_V); +} + +// ============================================================================ +// Full Sort / TopK: device-side template implementation +// ============================================================================ +// +// Layout in tmp (interpreted as float): +// bufA: tmp[0 .. 2N) sort32 output / ping-pong A +// bufB: dst (float full sort) OR tmp[2N .. 4N) (half OR topk) +// bufC: tmp[2N .. 4N) when bufB == dst; tmp[4N .. 6N) when bufB lives in tmp +// +// Indices live in bufB low half before sort32 consumes them. For half input, +// the casted float source lives in bufB high half before sort32. Both regions +// are then free for use as bufB ping-pong during the merge tree. + +namespace sort_detail { + +// All constexpr helpers below are tagged AICORE so the CCE compiler lets +// them be called from [aicore]-attributed templates. They're still pure +// compile-time computations -- the attribute is purely a visibility hint. + +// Length of the i-th segment in a level given (NumSegs, FullSize, LastSize). +// Returns 0 when i is out of range so callers can use a uniform 4-element +// length tuple. +AICORE constexpr int32_t seg_length(int32_t num_segs, int32_t full_size, + int32_t last_size, int32_t i) { + if (i >= num_segs) + return 0; + if (i == num_segs - 1) + return last_size; + return full_size; +} + +// Length of the (single) last segment after one level of up-to-4-way merging. +AICORE constexpr int32_t next_last_size(int32_t num_segs, int32_t full_size, + int32_t last_size) { + int32_t last_group_start = ((num_segs - 1) / 4) * 4; + int32_t last_group_count = num_segs - last_group_start; + return (last_group_count - 1) * full_size + last_size; +} + +// Length of a "full" segment after one level. If only one segment remains +// after this level (tree converged) "full" equals the single remaining size. +AICORE constexpr int32_t next_full_size(int32_t num_segs, int32_t full_size, + int32_t last_size) { + int32_t new_num = (num_segs + 3) / 4; + if (new_num <= 1) { + return next_last_size(num_segs, full_size, last_size); + } + return 4 * full_size; +} + +// Number of merge-tree levels needed to reduce BlockNum segments to 1. +AICORE constexpr int32_t compute_levels(int32_t block_num) { + int32_t n = block_num; + int32_t levels = 0; + while (n > 1) { + n = (n + 3) / 4; + levels++; + } + return levels; +} + +// Whether the final result lives in bufA after the merge tree finishes. +// read_from_a starts true and toggles every level, so result_in_bufA equals +// (levels % 2 == 0). For BlockNum == 1 (zero levels) this is also true. +template +constexpr bool result_in_bufA_v = (compute_levels(BlockNum) % 2 == 0); + +// Number of float pair-elements the finalize step has to copy. For full sort +// it's 2*N; for topk it's 2*K rounded up to user_T's block alignment so the +// generated TMOV/TCVT lands on aligned bytes (matches AscendC's DataCopy). +AICORE constexpr int32_t output_pairs(int32_t n, int32_t top_k, + int32_t user_t_bytes) { + if (top_k < 0) + return 2 * n; + int32_t elems_per_block = 32 / user_t_bytes; + int32_t topk_elems = 2 * top_k; + return ((topk_elems + elems_per_block - 1) / elems_per_block) * + elems_per_block; +} + +// One sorted segment moved from a read buffer to a write buffer (no merging +// needed because the segment is alone in its 4-group). +template +AICORE PTO_INLINE void merge_group_copy(int32_t src_addr, int32_t dst_addr) { + constexpr int32_t copy_floats = Len * 2; + TileUbDataND sort_cs; + TASSIGN(sort_cs, src_addr); + TileUbDataND sort_cd; + TASSIGN(sort_cd, dst_addr); + TMOV(sort_cd, sort_cs); +} + +template +AICORE PTO_INLINE void merge_group_2way(int32_t s0, int32_t s1, int32_t md, + int32_t mt) { + constexpr int32_t dst_floats = (Len0 + Len1) * 2; + TileUbDataND sort_s0; + TASSIGN(sort_s0, s0); + TileUbDataND sort_s1; + TASSIGN(sort_s1, s1); + TileUbDataND sort_md; + TASSIGN(sort_md, md); + TileUbDataND sort_mt; + TASSIGN(sort_mt, mt); + if constexpr (Len0 == Len1) { + MergeSort(sort_md, sort_mt, sort_s0, sort_s1); + } else { + MergeSortVar(sort_md, sort_mt, sort_s0, + sort_s1); + } +} + +template +AICORE PTO_INLINE void merge_group_3way(int32_t s0, int32_t s1, int32_t s2, + int32_t md, int32_t mt) { + constexpr int32_t dst_floats = (Len0 + Len1 + Len2) * 2; + TileUbDataND sort_s0; + TASSIGN(sort_s0, s0); + TileUbDataND sort_s1; + TASSIGN(sort_s1, s1); + TileUbDataND sort_s2; + TASSIGN(sort_s2, s2); + TileUbDataND sort_md; + TASSIGN(sort_md, md); + TileUbDataND sort_mt; + TASSIGN(sort_mt, mt); + if constexpr (Len0 == Len1 && Len1 == Len2) { + MergeSort(sort_md, sort_mt, sort_s0, sort_s1, + sort_s2); + } else { + MergeSortVar( + sort_md, sort_mt, sort_s0, sort_s1, sort_s2); + } +} + +template +AICORE PTO_INLINE void merge_group_4way(int32_t s0, int32_t s1, int32_t s2, + int32_t s3, int32_t md, int32_t mt) { + constexpr int32_t dst_floats = (Len0 + Len1 + Len2 + Len3) * 2; + TileUbDataND sort_s0; + TASSIGN(sort_s0, s0); + TileUbDataND sort_s1; + TASSIGN(sort_s1, s1); + TileUbDataND sort_s2; + TASSIGN(sort_s2, s2); + TileUbDataND sort_s3; + TASSIGN(sort_s3, s3); + TileUbDataND sort_md; + TASSIGN(sort_md, md); + TileUbDataND sort_mt; + TASSIGN(sort_mt, mt); + if constexpr (Len0 == Len1 && Len1 == Len2 && Len2 == Len3) { + MergeSort(sort_md, sort_mt, sort_s0, sort_s1, + sort_s2, sort_s3); + } else { + MergeSortVar( + sort_md, sort_mt, sort_s0, sort_s1, sort_s2, sort_s3); + } +} + +// Walk the groups within one merge-tree level. Recurses on group index G. +template +AICORE PTO_INLINE void merge_groups_loop(int32_t bufA_addr, int32_t bufB_addr, + int32_t bufC_addr) { + if constexpr (G < NumSegs) { + constexpr int32_t len0 = seg_length(NumSegs, FullSize, LastSize, G); + constexpr int32_t len1 = seg_length(NumSegs, FullSize, LastSize, G + 1); + constexpr int32_t len2 = seg_length(NumSegs, FullSize, LastSize, G + 2); + constexpr int32_t len3 = seg_length(NumSegs, FullSize, LastSize, G + 3); + constexpr int32_t group_count = + (len0 > 0) + (len1 > 0) + (len2 > 0) + (len3 > 0); + constexpr int32_t total_elems = len0 + len1 + len2 + len3; + constexpr int32_t T_BYTES = sizeof(T); // sort runs in float + + const int32_t read_base = ReadFromA ? bufA_addr : bufB_addr; + const int32_t write_base = ReadFromA ? bufB_addr : bufA_addr; + const int32_t in_byte_off = InOff * T_BYTES; + const int32_t out_byte_off = OutOff * T_BYTES; + + if constexpr (group_count == 1) { + merge_group_copy(read_base + in_byte_off, + write_base + out_byte_off); + } else if constexpr (group_count == 2) { + merge_group_2way( + read_base + in_byte_off, read_base + in_byte_off + len0 * 2 * T_BYTES, + write_base + out_byte_off, bufC_addr); + } else if constexpr (group_count == 3) { + merge_group_3way( + read_base + in_byte_off, read_base + in_byte_off + len0 * 2 * T_BYTES, + read_base + in_byte_off + (len0 + len1) * 2 * T_BYTES, + write_base + out_byte_off, bufC_addr); + } else { // group_count == 4 + merge_group_4way( + read_base + in_byte_off, read_base + in_byte_off + len0 * 2 * T_BYTES, + read_base + in_byte_off + (len0 + len1) * 2 * T_BYTES, + read_base + in_byte_off + (len0 + len1 + len2) * 2 * T_BYTES, + write_base + out_byte_off, bufC_addr); + } + + merge_groups_loop( + bufA_addr, bufB_addr, bufC_addr); + } +} + +// Drive one level of the merge tree, then recurse to the next level. +template +AICORE PTO_INLINE void merge_levels(int32_t bufA_addr, int32_t bufB_addr, + int32_t bufC_addr) { + if constexpr (NumSegs > 1) { + merge_groups_loop( + bufA_addr, bufB_addr, bufC_addr); + pipe_barrier(PIPE_V); + + constexpr int32_t new_num_segs = (NumSegs + 3) / 4; + constexpr int32_t new_full = next_full_size(NumSegs, FullSize, LastSize); + constexpr int32_t new_last = next_last_size(NumSegs, FullSize, LastSize); + + merge_levels( + bufA_addr, bufB_addr, bufC_addr); + } +} + +} // namespace sort_detail + +// Top-level entry point. UserT is the user-facing dtype (float or half), +// internally everything sorts in float (matches AscendC's B16 workaround). +// TopK == -1 means "full sort", TopK >= 0 means "topk; emit only 2*K pairs". +template +AICORE PTO_INLINE void Sort(int32_t dst_addr, int32_t src_addr, + int32_t tmp_addr) { + static_assert(N % 32 == 0, "Sort: N must be a multiple of 32"); + static_assert(ActualCount > 0 && ActualCount <= N, + "Sort: 0 < ActualCount <= N"); + + constexpr bool is_topk = (TopK >= 0); + constexpr bool is_half = std::is_same_v; + constexpr bool buf_b_in_tmp = is_half || is_topk; + + constexpr int32_t T_BYTES = 4; // float internally + constexpr int32_t USER_T_BYTES = sizeof(UserT); // 4 or 2 + constexpr int32_t BLOCK_NUM = N / 32; + constexpr int32_t PAD_COUNT = N - ActualCount; + + const int32_t bufA = tmp_addr; + const int32_t bufB = buf_b_in_tmp ? (tmp_addr + 2 * N * T_BYTES) : dst_addr; + const int32_t bufC = tmp_addr + (buf_b_in_tmp ? 4 : 2) * N * T_BYTES; + const int32_t indices_addr = bufB; // bufB low half before sort32 + const int32_t sort_src_addr = + is_half ? (bufB + N * T_BYTES) : src_addr; // bufB high half for half + + // Phase 0 (half only): cast user src(half) -> float at bufB high half. + if constexpr (is_half) { + TileUbDataND sort_h_src; + TASSIGN(sort_h_src, src_addr); + TileUbDataND sort_f_src; + TASSIGN(sort_f_src, sort_src_addr); + pto::TCVT(sort_f_src, sort_h_src, pto::RoundMode::CAST_NONE); + pipe_barrier(PIPE_V); + } + + // Phase 1: pad sort_src tail with -inf for [ActualCount, N). + if constexpr (PAD_COUNT > 0) { + TileUbDataND sort_src_v; + TASSIGN(sort_src_v, sort_src_addr); + TileUbDataND sort_src_f; + TASSIGN(sort_src_f, sort_src_addr); + pto::TFILLPAD_INPLACE(sort_src_f, sort_src_v); + pipe_barrier(PIPE_V); + } + + // Phase 2: generate ascending indices in bufB low half (float values 0..N-1 + // that sort32 will reinterpret as uint32 in the value-index pair output). + { + TileUbDataND sort_idx; + TASSIGN(sort_idx, indices_addr); + TCI(sort_idx, (float)0); + } + pipe_barrier(PIPE_V); + + // Phase 3: sort32 (float src + uint32 indices -> bufA, 32-block sorted pairs) + { + TileUbDataND sort_src; + TASSIGN(sort_src, sort_src_addr); + TileUbDataND sort_idx_u; + TASSIGN(sort_idx_u, indices_addr); + TileUbDataND sort_buf_a; + TASSIGN(sort_buf_a, bufA); + TSORT32(sort_buf_a, sort_src, sort_idx_u); + } + pipe_barrier(PIPE_V); + + // Phase 4: merge tree (compile-time unrolled by sort_detail::merge_levels). + sort_detail::merge_levels(bufA, bufB, bufC); + + // Phase 5: finalize into dst. + constexpr bool result_in_bufA = sort_detail::result_in_bufA_v; + constexpr int32_t OUTPUT_PAIRS = + sort_detail::output_pairs(N, TopK, USER_T_BYTES); + + const int32_t result_addr = result_in_bufA ? bufA : bufB; + + if constexpr (is_half) { + // Cast 2*K (or 2*N) float pairs -> halves at dst. CAST_RINT keeps the + // integer indices exact since they were generated as 0..N-1. + TileUbDataND sort_fs; + TASSIGN(sort_fs, result_addr); + TileUbDataND sort_fd; + TASSIGN(sort_fd, dst_addr); + pto::TCVT(sort_fd, sort_fs, pto::RoundMode::CAST_RINT); + pipe_barrier(PIPE_V); + } else { + // Float full sort: dst is bufB when bufB lives in dst, so the TMOV is + // only needed when the final write landed in bufA. For topk bufB is + // always in tmp, so we always have to copy. + constexpr bool need_copy = is_topk || result_in_bufA; + if constexpr (need_copy) { + TileUbDataND sort_fs; + TASSIGN(sort_fs, result_addr); + TileUbDataND sort_fd; + TASSIGN(sort_fd, dst_addr); + TMOV(sort_fd, sort_fs); + pipe_barrier(PIPE_V); + } + } +} + template AICORE PTO_INLINE void transpose(TileUbDataND &dst, TileUbDataND &src, diff --git a/src/transform/allocate_tmp_buffer.cc b/src/transform/allocate_tmp_buffer.cc index 4c923f9c8..e86cde8b4 100644 --- a/src/transform/allocate_tmp_buffer.cc +++ b/src/transform/allocate_tmp_buffer.cc @@ -107,7 +107,9 @@ class CallNodeModifier : public StmtExprMutator { if (op->op.same_as(tl::ascend_sigmoid()) || op->op.same_as(tl::ascend_pow()) || op->op.same_as(tl::ascend_bitwise_xor()) || - op->op.same_as(tl::ascend_merge_sort())) { + op->op.same_as(tl::ascend_merge_sort()) || + ("pto" == target_ && (op->op.same_as(tl::ascend_sort()) || + op->op.same_as(tl::ascend_topk())))) { return CallNodeAddTmp(op, tmp_buffer_param_offset, 2); } else { return CallNodeAddTmp(op, tmp_buffer_param_offset, 1); @@ -172,7 +174,10 @@ class CallNodeModifier : public StmtExprMutator { } } } - } else if ("pto" == target_ && op->op.same_as(tl::ascend_bitwise_xor()) && + } else if ("pto" == target_ && + (op->op.same_as(tl::ascend_bitwise_xor()) || + op->op.same_as(tl::ascend_sort()) || + op->op.same_as(tl::ascend_topk())) && tmp_bufs_.size() > 0) { const CallNode *src_access_ptr = Downcast(op->args[1]).get(); DataType dtype = src_access_ptr->args[0].as()->dtype; @@ -201,7 +206,11 @@ class CallNodeModifier : public StmtExprMutator { } } } else if ("pto" == target_ && op->op.same_as(tl::ascend_gather_mask()) && - tmp_bufs_.size() > 0) { + tmp_bufs_.size() > 0 && op->args[3].as()) { + // gather_mask args[3] is the src1 pattern: either a Call (Buffer + // pattern) or a StringImm ("P1010" etc.). Only the Buffer-pattern + // form needs a dtype-keyed tmp; the string-pattern form falls through + // to the generic uint8 tmp_buf_ below. const CallNode *src_access_ptr = Downcast(op->args[3]).get(); DataType dtype = src_access_ptr->args[0].as()->dtype; if (dtype == DataType::UInt(8)) { @@ -548,6 +557,50 @@ class TmpBufferInjector : public StmtExprMutator { shapes[dtype] = tmp_shape; } } + } else if (call->op.same_as(tl::ascend_sort()) || + call->op.same_as(tl::ascend_topk())) { + const CallNode *src_access_ptr = Downcast(call->args[2]).get(); + std::string src_buffer_name = + src_access_ptr->args[1].as()->name_hint; + const BufferNode *src_buffer_node = + GetBufferNodeByName_(alloc_buffers, src_buffer_name); + DataType dtype = src_buffer_node->dtype; + if (dtype != DataType::UInt(8)) { + // sort: bufA (2*alignedCount) + bufC (2*alignedCount) for float + // (dst doubles as bufB). Half: also a float cast scratch. + // topk: must fit bufA + bufB + bufC = 6*alignedCount float since + // user dst (size 2*K) is too small to host bufB ping-pong. + // Both paths share this allocation; use the larger (topk) sizing + // when topk is present so the same tmp pool serves both ops. + bool is_topk = call->op.same_as(tl::ascend_topk()); + int64_t multiplier; + if (dtype.bytes() == 2) { + multiplier = 16; // half: reserve enough for cast-to-float pool + } else { + multiplier = is_topk ? 6 : 4; + } + int64_t tmp_shape_size = + Downcast(src_access_ptr->args[3])->value * multiplier; + if (shapes.count(dtype) > 0) { + int64_t shape_size = 0; + for (size_t k = 0; k < shapes.at(dtype).size(); k++) { + if (shape_size == 0) { + shape_size = shapes.at(dtype)[k].as()->value; + } else { + shape_size *= shapes.at(dtype)[k].as()->value; + } + } + if (tmp_shape_size > shape_size) { + Array tmp_shape; + tmp_shape.push_back(IntImm(DataType::Int(32), tmp_shape_size)); + shapes[dtype] = tmp_shape; + } + } else { + Array tmp_shape; + tmp_shape.push_back(IntImm(DataType::Int(32), tmp_shape_size)); + shapes[dtype] = tmp_shape; + } + } } else if (call->op.same_as(tl::ascend_gather_mask())) { if (call->args[3].as()) { const CallNode *dst_access_ptr = Downcast(call->args[1]).get(); diff --git a/src/transform/common/operation_config.h b/src/transform/common/operation_config.h index d340eb202..4f9bda62b 100644 --- a/src/transform/common/operation_config.h +++ b/src/transform/common/operation_config.h @@ -354,6 +354,7 @@ const std::unordered_map pto_tmp_arg_ops = { {tl::ascend_bitwise_xor().get(), 3}, {tl::ascend_round().get(), 2}, {tl::ascend_broadcast().get(), 3}, {tl::ascend_merge_sort().get(), 3}, {tl::ascend_select().get(), 3}, {tl::ascend_gather_mask().get(), 4}, + {tl::ascend_sort().get(), 3}, {tl::ascend_topk().get(), 3}, }; } // namespace tl diff --git a/testing/python/language/test_tilelang_ascend_language_elementwise.py b/testing/python/language/test_tilelang_ascend_language_elementwise.py index 7314f0736..b389ee4fe 100644 --- a/testing/python/language/test_tilelang_ascend_language_elementwise.py +++ b/testing/python/language/test_tilelang_ascend_language_elementwise.py @@ -3875,7 +3875,7 @@ def run_test_sort(M, N, block_M, block_N, dtype, target): func = sort(M, N, block_M, block_N, dtype) func = tilelang.compile(func, out_idx=[-1], pass_configs=pass_configs, target=target) - torch_dtype = torch.float if dtype == "float" else torch.float16 + torch_dtype = torch.float if dtype in ("float", "float32") else torch.float16 a = torch.arange(0, M * N, dtype=torch_dtype).reshape(M, N).npu() b = func(a) @@ -3891,8 +3891,8 @@ def run_test_sort(M, N, block_M, block_N, dtype, target): torch.testing.assert_close(out_indices, ref_index.float(), rtol=1e-3, atol=1e-3) -@pytest.mark.parametrize("dtype", ["float16", "float"]) -@pytest.mark.parametrize("target", ["ascendc"]) +@pytest.mark.parametrize("dtype", ["float16", "float32"]) +@pytest.mark.parametrize("target", ["ascendc", "pto"]) @pytest.mark.parametrize("shape", [(1, 131)]) def test_sort(dtype, target, shape): M, N = shape @@ -4034,7 +4034,7 @@ def run_test_topk(M, N, K, block_M, block_N, dtype, target): func = topk(M, N, K, block_M, block_N, dtype) func = tilelang.compile(func, out_idx=[-1], pass_configs=pass_configs, target=target) - torch_dtype = torch.float if dtype == "float" else torch.float16 + torch_dtype = torch.float if dtype in ("float", "float32") else torch.float16 a = torch.arange(0, M * N, dtype=torch_dtype).reshape(M, N).npu() b = func(a) @@ -4052,8 +4052,8 @@ def run_test_topk(M, N, K, block_M, block_N, dtype, target): torch.testing.assert_close(out_indices, ref_indices.float(), rtol=1e-3, atol=1e-3) -@pytest.mark.parametrize("dtype", ["float16", "float"]) -@pytest.mark.parametrize("target", ["ascendc"]) +@pytest.mark.parametrize("dtype", ["float16", "float32"]) +@pytest.mark.parametrize("target", ["ascendc", "pto"]) @pytest.mark.parametrize("shape", [(1, 51)]) def test_topk(dtype, target, shape): M, N = shape