Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 64 additions & 32 deletions examples/jit_cpp/chunk_gdn/dynamic_bsnd/chunk_o_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,12 @@ AICORE void chunk_o_kernel(
int64_t global_chunk_base = 0;
bool first_cube_iter = true;

for (int64_t work_idx = static_cast<int64_t>(cid);
work_idx < total_work;
work_idx += static_cast<int64_t>(block_num)) {
// Contiguous work partition: this core handles the work indices in
// [work_base, work_end). Both bounds use the same floor division
// (total_work * k) / block_num, so the end of core k equals the start of
// core k + 1. Assuming cid takes every value in [0, block_num), the ranges
// cover [0, total_work) exactly once and their sizes differ by at most one.
int64_t cid_i64 = static_cast<int64_t>(cid);
int64_t work_base = (total_work * cid_i64) / static_cast<int64_t>(block_num);
int64_t work_end =
(total_work * (cid_i64 + 1)) / static_cast<int64_t>(block_num);

for (int64_t work_idx = work_base; work_idx < work_end; ++work_idx) {
// Wait for Vec to finish with previous chunk's workspace (flag 3)
if (!first_cube_iter) wait_flag_dev(3);
set_flag(PIPE_FIX, PIPE_M, EVENT_ID0);
Expand Down Expand Up @@ -439,6 +442,18 @@ AICORE void chunk_o_kernel(
// flag 2: Cube→Vec "QKV (GEMM 3 result) is ready"
// flag 3: Vec→Cube "I'm done with this chunk, you can reuse workspace"
ffts_cross_core_sync(PIPE_FIX, 1 | (2 << 4) | (0 << 8));
// V is independent of vec-side gating and workspace stores. Start this
// prefetch earlier so its GM->L1 latency overlaps with store/gating work.
// The load sits between the cross-core sync above and the wait on flag 1
// below.
{
// Tile with compile-time extent ChunkSize x HiddenSize and runtime extent
// valid_rows x HiddenSize.
L1Mat<half, ChunkSize, HiddenSize, DYNAMIC, DYNAMIC> _l1(valid_rows, HiddenSize);
// NOTE(review): 131072 is presumably the offset reserved for the V tile in
// L1 — confirm nothing still in flight at this earlier point uses that region.
TASSIGN(_l1, 131072);
// Global-memory view of the source: only the last two dims are dynamic,
// set to valid_rows x HiddenSize.
Shape<1, 1, 1, DYNAMIC, DYNAMIC> _gs;
_gs.shape[3] = valid_rows; _gs.shape[4] = HiddenSize;
// Source starts at V_handle + qkv_offset. The stride arguments give
// NumHeads * HiddenSize for the row dimension and 1 for the column dimension
// (presumably heads are interleaved within each token row — confirm).
GlobalTensor<half, decltype(_gs), Stride<1, 1, 1, NumHeads * HiddenSize, 1>> _gm(V_handle + qkv_offset, _gs);
TLOAD(_l1, _gm);
// Partial chunk only: pad the tile in place. Presumably this fills the rows
// beyond valid_rows so the later GEMM sees a full tile — confirm fill value.
if (valid_rows != ChunkSize) TFILLPAD(_l1, _l1);
}


// Wait for Vec to write QK_gated back (flag 1, Vec→Cube)
wait_flag_dev(1);
Expand All @@ -457,17 +472,6 @@ AICORE void chunk_o_kernel(
static_cast<int64_t>(cid) * WsGatedSize, _gs);
TLOAD(_l1, _gm);
}
// ── Load V [valid_rows × D] from GM → L1 ────────────────────────
{
L1Mat<half, ChunkSize, HiddenSize, DYNAMIC, DYNAMIC> _l1(valid_rows, HiddenSize);
TASSIGN(_l1, 131072);
Shape<1, 1, 1, DYNAMIC, DYNAMIC> _gs;
_gs.shape[3] = valid_rows; _gs.shape[4] = HiddenSize;
GlobalTensor<half, decltype(_gs), Stride<1, 1, 1, NumHeads * HiddenSize, 1>> _gm(V_handle + qkv_offset, _gs);
TLOAD(_l1, _gm);
if (valid_rows != ChunkSize) TFILLPAD(_l1, _l1);
}

// ── GEMM 3: QKV = QK_gated @ V (gated attention → values) ──────
{
TileLeft<half, ChunkSize, ChunkSize, ChunkSize, ChunkSize> _l0a;
Expand Down Expand Up @@ -510,6 +514,19 @@ AICORE void chunk_o_kernel(
}
} else {
// ── Variable-length sequence path (cu_seqlens != nullptr) ──────────
// Count the work items of the variable-length batch: every sequence
// contributes ceil(length / ChunkSize) chunks, and every chunk contributes
// NumHeads items.
int64_t total_work_var = 0;
for (int64_t tsi = 0; tsi < num_seqs; ++tsi) {
// The length of sequence tsi is the difference between consecutive
// cu_seqlens entries.
int64_t tbos = static_cast<int64_t>(cu_seqlens[tsi]);
int64_t teos = static_cast<int64_t>(cu_seqlens[tsi + 1]);
int64_t tslen = teos - tbos;
// Ceiling division: a trailing partial chunk counts as one chunk.
int64_t tnc = (tslen + ChunkSize - 1) / ChunkSize;
total_work_var += tnc * NumHeads;
}
// Contiguous work partition: this core owns the global work indices in
// [work_base, work_end). Both bounds use the same floor division, so the end
// of core k equals the start of core k + 1. Assuming cid takes every value
// in [0, block_num), the ranges cover [0, total_work_var) exactly once.
int64_t cid_i64 = static_cast<int64_t>(cid);
int64_t work_base =
(total_work_var * cid_i64) / static_cast<int64_t>(block_num);
int64_t work_end =
(total_work_var * (cid_i64 + 1)) / static_cast<int64_t>(block_num);
int64_t gi = 0;
int64_t chunk_global_idx = 0;
bool first_cube_iter_v = true;
Expand All @@ -521,8 +538,7 @@ AICORE void chunk_o_kernel(

for (int64_t ci = 0; ci < nc; ++ci) {
for (int32_t h = 0; h < NumHeads; ++h) {
if (gi % static_cast<int64_t>(block_num) ==
static_cast<int64_t>(cid)) {
if (gi >= work_base && gi < work_end) {
if (!first_cube_iter_v) wait_flag_dev(3);
set_flag(PIPE_FIX, PIPE_M, EVENT_ID0);
wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0);
Expand Down Expand Up @@ -631,6 +647,18 @@ AICORE void chunk_o_kernel(

// Cube→Vec: QK & QS ready (flag 0)
ffts_cross_core_sync(PIPE_FIX, 1 | (2 << 4) | (0 << 8));
// V is independent of vec-side gating and workspace stores. Start this
// prefetch earlier so its GM->L1 latency overlaps with store/gating work.
// The load sits between the flag-0 cross-core sync above and the wait on
// flag 1 below.
{
// Tile with compile-time extent ChunkSize x HiddenSize and runtime extent
// valid_rows x HiddenSize.
L1Mat<half, ChunkSize, HiddenSize, DYNAMIC, DYNAMIC> _l1(valid_rows, HiddenSize);
// NOTE(review): 131072 is presumably the offset reserved for the V tile in
// L1 — confirm nothing still in flight at this earlier point uses that region.
TASSIGN(_l1, 131072);
// Global-memory view of the source: only the last two dims are dynamic,
// set to valid_rows x HiddenSize.
Shape<1, 1, 1, DYNAMIC, DYNAMIC> _gs;
_gs.shape[3] = valid_rows; _gs.shape[4] = HiddenSize;
// Source starts at V_handle + qkv_offset. The stride arguments give
// NumHeads * HiddenSize for the row dimension and 1 for the column dimension
// (presumably heads are interleaved within each token row — confirm).
GlobalTensor<half, decltype(_gs), Stride<1, 1, 1, NumHeads * HiddenSize, 1>> _gm(V_handle + qkv_offset, _gs);
TLOAD(_l1, _gm);
// Partial chunk only: pad the tile in place. Presumably this fills the rows
// beyond valid_rows so the later GEMM sees a full tile — confirm fill value.
if (valid_rows != ChunkSize) TFILLPAD(_l1, _l1);
}


// Wait Vec→Cube: QK_gated ready (flag 1)
wait_flag_dev(1);
Expand All @@ -649,17 +677,6 @@ AICORE void chunk_o_kernel(
static_cast<int64_t>(cid) * WsGatedSize, _gs);
TLOAD(_l1, _gm);
}
// Load V
{
L1Mat<half, ChunkSize, HiddenSize, DYNAMIC, DYNAMIC> _l1(valid_rows, HiddenSize);
TASSIGN(_l1, 131072);
Shape<1, 1, 1, DYNAMIC, DYNAMIC> _gs;
_gs.shape[3] = valid_rows; _gs.shape[4] = HiddenSize;
GlobalTensor<half, decltype(_gs), Stride<1, 1, 1, NumHeads * HiddenSize, 1>> _gm(V_handle + qkv_offset, _gs);
TLOAD(_l1, _gm);
if (valid_rows != ChunkSize) TFILLPAD(_l1, _l1);
}

// GEMM 3: QKV = QK_gated @ V
{
TileLeft<half, ChunkSize, ChunkSize, ChunkSize, ChunkSize> _l0a;
Expand Down Expand Up @@ -738,9 +755,12 @@ AICORE void chunk_o_kernel(
// ── Fixed-length sequence path ──────────────────────────────────────
int64_t chunks_per_seq = (seq_len + ChunkSize - 1) / ChunkSize;

for (int64_t work_idx = static_cast<int64_t>(cid);
work_idx < total_work;
work_idx += static_cast<int64_t>(block_num)) {
// Contiguous work partition: this core handles the work indices in
// [work_base, work_end). Both bounds use the same floor division
// (total_work * k) / block_num, so the end of core k equals the start of
// core k + 1. Assuming cid takes every value in [0, block_num), the ranges
// cover [0, total_work) exactly once and their sizes differ by at most one.
int64_t cid_i64 = static_cast<int64_t>(cid);
int64_t work_base = (total_work * cid_i64) / static_cast<int64_t>(block_num);
int64_t work_end =
(total_work * (cid_i64 + 1)) / static_cast<int64_t>(block_num);

for (int64_t work_idx = work_base; work_idx < work_end; ++work_idx) {
int32_t head_idx = static_cast<int32_t>(work_idx % NumHeads);
int64_t chunk_head_idx = work_idx / NumHeads;
int64_t seq_idx = chunk_head_idx / chunks_per_seq;
Expand Down Expand Up @@ -958,6 +978,19 @@ AICORE void chunk_o_kernel(
}
} else {
// ── Variable-length sequence path (cu_seqlens != nullptr) ──────────
// Count the work items of the variable-length batch: every sequence
// contributes ceil(length / ChunkSize) chunks, and every chunk contributes
// NumHeads items.
int64_t total_work_var = 0;
for (int64_t tsi = 0; tsi < num_seqs; ++tsi) {
// The length of sequence tsi is the difference between consecutive
// cu_seqlens entries.
int64_t tbos = static_cast<int64_t>(cu_seqlens[tsi]);
int64_t teos = static_cast<int64_t>(cu_seqlens[tsi + 1]);
int64_t tslen = teos - tbos;
// Ceiling division: a trailing partial chunk counts as one chunk.
int64_t tnc = (tslen + ChunkSize - 1) / ChunkSize;
total_work_var += tnc * NumHeads;
}
// Contiguous work partition: this core owns the global work indices in
// [work_base, work_end). Both bounds use the same floor division, so the end
// of core k equals the start of core k + 1. Assuming cid takes every value
// in [0, block_num), the ranges cover [0, total_work_var) exactly once.
int64_t cid_i64 = static_cast<int64_t>(cid);
int64_t work_base =
(total_work_var * cid_i64) / static_cast<int64_t>(block_num);
int64_t work_end =
(total_work_var * (cid_i64 + 1)) / static_cast<int64_t>(block_num);
int64_t gi = 0;
for (int64_t si = 0; si < num_seqs; ++si) {
int64_t bos = static_cast<int64_t>(cu_seqlens[si]);
Expand All @@ -967,8 +1000,7 @@ AICORE void chunk_o_kernel(

for (int64_t ci = 0; ci < nc; ++ci) {
for (int32_t h = 0; h < NumHeads; ++h) {
if (gi % static_cast<int64_t>(block_num) ==
static_cast<int64_t>(cid)) {
if (gi >= work_base && gi < work_end) {
int64_t chunk_start = ci * ChunkSize;
int64_t remaining = slen - chunk_start;
int32_t valid_rows = static_cast<int32_t>(
Expand Down
Loading