Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 9 additions & 31 deletions examples/lightning_indexer/example_lightning_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,16 @@
tilelang.disable_cache()


@tilelang.jit(out_idx=[-1], workspace_idx=[-3]) # for jit
def indexer(B, N2, G, S1, S2, D, TOP_K, VECTOR_BASEN, VECTOR_BASEG, BLOCK_M, BLOCK_N, BLOCK_K, input_dtype="float16", calc_dtype="float"):
# Memory planning is the "address reuse" pass: it lets non-overlapping
# buffers (and the auto-injected sort/topk tmp buffers that aren't in
# T.annotate_address) share UB space.
pass_configs = {
tilelang.PassConfigKey.TL_ASCEND_MEMORY_PLANNING: True,
}


@tilelang.jit(out_idx=[-1], workspace_idx=[-3], pass_configs=pass_configs)
def indexer(B, N2, G, S1, S2, D, TOP_K, VECTOR_BASEN, VECTOR_BASEG, BLOCK_M, BLOCK_N, BLOCK_K, input_dtype="float16", calc_dtype="float"):
@T.prim_func
def main(
Query: T.Tensor((B, S1, N2, G * D), input_dtype),
Expand All @@ -30,15 +37,6 @@ def main(

C_L0 = T.alloc_L0C((BLOCK_M, BLOCK_N), calc_dtype)

T.annotate_address(
{
# L1 address
Q_L1: 0,
K_L1: 16384,
# L0C address
C_L0: 0,
}
)
T.barrier_all()
for n2 in T.serial(N2):
for g in T.serial(G):
Expand All @@ -63,10 +61,7 @@ def main(

with T.Scope("V"):
mm_res_ub = T.alloc_ub((VECTOR_BASEG, VECTOR_BASEN), calc_dtype)
mm_res_ub_flat = T.alloc_ub((VECTOR_BASEG * VECTOR_BASEN), calc_dtype)
mm_res_ub_uint8 = T.alloc_ub((VECTOR_BASEG, VECTOR_BASEN), "uint8")
weight_ub = T.alloc_ub(VECTOR_BASEG, calc_dtype)
weight_brcb_ub = T.alloc_ub((VECTOR_BASEG, 8), calc_dtype)
reduce_tmp_ub = T.alloc_ub((VECTOR_BASEG, VECTOR_BASEN), calc_dtype)
reduce_g_ub = T.alloc_ub(VECTOR_BASEN, calc_dtype)
# Accumulate all S2 scores, then topk once
Expand All @@ -75,23 +70,6 @@ def main(
topk_index_ub = T.alloc_ub(TOP_K, calc_dtype)
output_ub = T.alloc_ub(TOP_K, "int")

T.annotate_address(
{
# ub address
mm_res_ub: 0,
mm_res_ub_flat: 0,
mm_res_ub_uint8: 0,
weight_ub: 32768,
weight_brcb_ub: 32832,
reduce_tmp_ub: 33344,
reduce_g_ub: 66112,
score_accum_ub: 67136,
topk_dst_ub: 83520,
topk_index_ub: 91712,
output_ub: 95808,
}
)

s1_start_idx = vid * each_core_process_num
s1_end_idx = s1_start_idx + each_core_process_num

Expand Down
110 changes: 110 additions & 0 deletions src/target/codegen_ascend_pto.cc
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,10 @@ void CodeGenTileLangAscendPto::VisitExpr_(const CallNode *op,
PowCodegen(op);
} else if (op->op.same_as(tl::ascend_sort32())) {
Sort32Codegen(op, "TSORT32");
} else if (op->op.same_as(tl::ascend_sort())) {
SortCodegen(op);
} else if (op->op.same_as(tl::ascend_topk())) {
TopKCodegen(op);
} else if (op->op.same_as(tl::ascend_merge_sort())) {
MergeSortCodegen(op, "TMRGSORT");
} else if (op->op.same_as(tl::ascend_transpose())) {
Expand Down Expand Up @@ -1575,6 +1579,112 @@ void CodeGenTileLangAscendPto::MergeSortCodegen(const CallNode *op,
this->stream << ");\n";
}

void CodeGenTileLangAscendPto::SortCodegen(const CallNode *op) {
// After tmp injection, args layout:
// [0] func_name (e.g. "Sort<float>")
// [1] dst access_ptr -- 2*alignedCount user_T elements
// [2] src access_ptr -- alignedCount user_T elements (may be mutated)
// [3] tmp access_ptr -- internal workspace allocated by
// allocate_tmp_buffer [4] repeatTimes (constant) [5] actual_num (constant)
ICHECK(op->args.size() == 6)
<< "ascend_sort expects 6 args after tmp injection, got "
<< op->args.size();

auto dst_call = op->args[1].as<CallNode>();
auto src_call = op->args[2].as<CallNode>();
auto tmp_call = op->args[3].as<CallNode>();
ICHECK(dst_call && dst_call->op.same_as(builtin::tvm_access_ptr()));
ICHECK(src_call && src_call->op.same_as(builtin::tvm_access_ptr()));
ICHECK(tmp_call && tmp_call->op.same_as(builtin::tvm_access_ptr()));

int32_t repeat_times = Downcast<IntImm>(op->args[4])->value;
int32_t actual_num = Downcast<IntImm>(op->args[5])->value;

EmitSortAlgorithm(dst_call, src_call, tmp_call, repeat_times, actual_num,
/*top_k=*/-1);
}

void CodeGenTileLangAscendPto::TopKCodegen(const CallNode *op) {
// After tmp injection, args layout:
// [0] func_name (e.g. "TopK<float>")
// [1] dst access_ptr -- 2*K user_T elements (UB-rounded)
// [2] src access_ptr -- alignedCount user_T elements
// [3] tmp access_ptr -- internal workspace
// [4] K (constant)
// [5] repeatTimes (constant)
// [6] actual_num (constant)
ICHECK(op->args.size() == 7)
<< "ascend_topk expects 7 args after tmp injection, got "
<< op->args.size();

auto dst_call = op->args[1].as<CallNode>();
auto src_call = op->args[2].as<CallNode>();
auto tmp_call = op->args[3].as<CallNode>();
ICHECK(dst_call && dst_call->op.same_as(builtin::tvm_access_ptr()));
ICHECK(src_call && src_call->op.same_as(builtin::tvm_access_ptr()));
ICHECK(tmp_call && tmp_call->op.same_as(builtin::tvm_access_ptr()));

int32_t k = Downcast<IntImm>(op->args[4])->value;
int32_t repeat_times = Downcast<IntImm>(op->args[5])->value;
int32_t actual_num = Downcast<IntImm>(op->args[6])->value;
ICHECK(k > 0) << "TopK requires K > 0, got " << k;

EmitSortAlgorithm(dst_call, src_call, tmp_call, repeat_times, actual_num,
/*top_k=*/k);
}

// =============================================================================
// Sort/TopK pipeline: thin codegen wrapper
// =============================================================================
//
// The full algorithm (pad, sort32, merge tree, finalize) lives in
// pto/common.h as the device template tl::ascend_pto::Sort. This codegen
// just forwards parsed parameters and emits a single template call.

void CodeGenTileLangAscendPto::EmitSortAlgorithm(const CallNode *dst_call,
const CallNode *src_call,
const CallNode *tmp_call,
int32_t repeat_times,
int32_t actual_num,
int32_t top_k) {
int32_t aligned_count = repeat_times * 32;

DataType dtype = src_call->args[0].dtype();
bool is_half = dtype.is_float() && dtype.bits() == 16;
bool is_float = dtype.is_float() && dtype.bits() == 32;
ICHECK(is_half || is_float)
<< "PTO Sort/TopK supports float32 / float16 input, got " << dtype;
std::string user_T = is_half ? "half" : "float";
int32_t user_T_bytes = is_half ? 2 : 4;

Var dst_var = Downcast<Var>(dst_call->args[1]);
Var src_var = Downcast<Var>(src_call->args[1]);
Var tmp_var = Downcast<Var>(tmp_call->args[1]);
ICHECK(buffer_address_map_.count(dst_var))
<< "Buffer address not found for dst: " << dst_var->name_hint;
ICHECK(buffer_address_map_.count(src_var))
<< "Buffer address not found for src: " << src_var->name_hint;
ICHECK(buffer_address_map_.count(tmp_var))
<< "Buffer address not found for tmp: " << tmp_var->name_hint;

// Emit "<base> + ((offset) * elem_bytes)" as a runtime byte address.
auto byte_addr = [this](Var var, PrimExpr offset, int32_t elem_bytes) {
std::string base = PrintExpr(buffer_address_map_.at(var));
std::string off = PrintExpr(offset);
return base + " + ((" + off + ") * " + std::to_string(elem_bytes) + ")";
};

std::string dst_addr = byte_addr(dst_var, dst_call->args[2], user_T_bytes);
std::string src_addr = byte_addr(src_var, src_call->args[2], user_T_bytes);
std::string tmp_addr =
byte_addr(tmp_var, tmp_call->args[2], /*elem_bytes=*/4);

this->PrintIndent();
this->stream << kAscendPtoScope << "Sort<" << user_T << ", " << aligned_count
<< ", " << actual_num << ", " << top_k << ">(" << dst_addr
<< ", " << src_addr << ", " << tmp_addr << ");\n";
}

void CodeGenTileLangAscendPto::TransposeCodegen(const CallNode *op,
const std::string &op_name) {
this->PrintIndent();
Expand Down
11 changes: 11 additions & 0 deletions src/target/codegen_ascend_pto.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,17 @@ class CodeGenTileLangAscendPto final : public CodeGenC {

void MergeSortCodegen(const CallNode *op, const std::string &op_name);

void SortCodegen(const CallNode *op);

void TopKCodegen(const CallNode *op);

// Emits a single tl::ascend_pto::Sort<UserT, N, ActualCount, TopK>(...)
// call. The full algorithm (pad, sort32, merge tree, finalize) lives in
// pto/common.h.
void EmitSortAlgorithm(const CallNode *dst_call, const CallNode *src_call,
const CallNode *tmp_call, int32_t repeat_times,
int32_t actual_num, int32_t top_k);

void TransposeCodegen(const CallNode *op, const std::string &op_name);

void XorCodegen(const CallNode *op, const std::string &op_name);
Expand Down
Loading
Loading