diff --git a/examples/jit_cpp/chunk_gdn/dynamic_bsnd/chunk_o_kernel.cpp b/examples/jit_cpp/chunk_gdn/dynamic_bsnd/chunk_o_kernel.cpp index 2090c762..5122ec92 100644 --- a/examples/jit_cpp/chunk_gdn/dynamic_bsnd/chunk_o_kernel.cpp +++ b/examples/jit_cpp/chunk_gdn/dynamic_bsnd/chunk_o_kernel.cpp @@ -280,9 +280,12 @@ AICORE void chunk_o_kernel( int64_t global_chunk_base = 0; bool first_cube_iter = true; - for (int64_t work_idx = static_cast(cid); - work_idx < total_work; - work_idx += static_cast(block_num)) { + int64_t cid_i64 = static_cast(cid); + int64_t work_base = (total_work * cid_i64) / static_cast(block_num); + int64_t work_end = + (total_work * (cid_i64 + 1)) / static_cast(block_num); + + for (int64_t work_idx = work_base; work_idx < work_end; ++work_idx) { // Wait for Vec to finish with previous chunk's workspace (flag 3) if (!first_cube_iter) wait_flag_dev(3); set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); @@ -439,6 +442,18 @@ AICORE void chunk_o_kernel( // flag 2: Cube→Vec "QKV (GEMM 3 result) is ready" // flag 3: Vec→Cube "I'm done with this chunk, you can reuse workspace" ffts_cross_core_sync(PIPE_FIX, 1 | (2 << 4) | (0 << 8)); + // V is independent of vec-side gating and workspace stores. Start this + // prefetch earlier so its GM->L1 latency overlaps with store/gating work. + { + L1Mat _l1(valid_rows, HiddenSize); + TASSIGN(_l1, 131072); + Shape<1, 1, 1, DYNAMIC, DYNAMIC> _gs; + _gs.shape[3] = valid_rows; _gs.shape[4] = HiddenSize; + GlobalTensor> _gm(V_handle + qkv_offset, _gs); + TLOAD(_l1, _gm); + if (valid_rows != ChunkSize) TFILLPAD(_l1, _l1); + } + // Wait for Vec to write QK_gated back (flag 1, Vec→Cube) wait_flag_dev(1); @@ -457,17 +472,6 @@ AICORE void chunk_o_kernel( static_cast(cid) * WsGatedSize, _gs); TLOAD(_l1, _gm); } - // ── Load V [valid_rows × D] from GM → L1 ──────────────────────── - { - L1Mat _l1(valid_rows, HiddenSize); - TASSIGN(_l1, 131072); - Shape<1, 1, 1, DYNAMIC, DYNAMIC> _gs; - _gs.shape[3] = valid_rows; _gs.shape[4] = HiddenSize; - GlobalTensor> _gm(V_handle + qkv_offset, _gs); - TLOAD(_l1, _gm); - if (valid_rows != ChunkSize) TFILLPAD(_l1, _l1); - } - // ── GEMM 3: QKV = QK_gated @ V (gated attention → values) ────── { TileLeft _l0a; @@ -510,6 +514,19 @@ AICORE void chunk_o_kernel( } } else { // ── Variable-length sequence path (cu_seqlens != nullptr) ────────── + int64_t total_work_var = 0; + for (int64_t tsi = 0; tsi < num_seqs; ++tsi) { + int64_t tbos = static_cast(cu_seqlens[tsi]); + int64_t teos = static_cast(cu_seqlens[tsi + 1]); + int64_t tslen = teos - tbos; + int64_t tnc = (tslen + ChunkSize - 1) / ChunkSize; + total_work_var += tnc * NumHeads; + } + int64_t cid_i64 = static_cast(cid); + int64_t work_base = + (total_work_var * cid_i64) / static_cast(block_num); + int64_t work_end = + (total_work_var * (cid_i64 + 1)) / static_cast(block_num); int64_t gi = 0; int64_t chunk_global_idx = 0; bool first_cube_iter_v = true; @@ -521,8 +538,7 @@ AICORE void chunk_o_kernel( for (int64_t ci = 0; ci < nc; ++ci) { for (int32_t h = 0; h < NumHeads; ++h) { - if (gi % static_cast(block_num) == - static_cast(cid)) { + if (gi >= work_base && gi < work_end) { if (!first_cube_iter_v) wait_flag_dev(3); set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); @@ -631,6 +647,18 @@ AICORE void chunk_o_kernel( // Cube→Vec: QK & QS ready (flag 0) ffts_cross_core_sync(PIPE_FIX, 1 | (2 << 4) | (0 << 8)); + // V is independent of vec-side gating and workspace stores. Start this + // prefetch earlier so its GM->L1 latency overlaps with store/gating work. + { + L1Mat _l1(valid_rows, HiddenSize); + TASSIGN(_l1, 131072); + Shape<1, 1, 1, DYNAMIC, DYNAMIC> _gs; + _gs.shape[3] = valid_rows; _gs.shape[4] = HiddenSize; + GlobalTensor> _gm(V_handle + qkv_offset, _gs); + TLOAD(_l1, _gm); + if (valid_rows != ChunkSize) TFILLPAD(_l1, _l1); + } + // Wait Vec→Cube: QK_gated ready (flag 1) wait_flag_dev(1); @@ -649,17 +677,6 @@ AICORE void chunk_o_kernel( static_cast(cid) * WsGatedSize, _gs); TLOAD(_l1, _gm); } - // Load V - { - L1Mat _l1(valid_rows, HiddenSize); - TASSIGN(_l1, 131072); - Shape<1, 1, 1, DYNAMIC, DYNAMIC> _gs; - _gs.shape[3] = valid_rows; _gs.shape[4] = HiddenSize; - GlobalTensor> _gm(V_handle + qkv_offset, _gs); - TLOAD(_l1, _gm); - if (valid_rows != ChunkSize) TFILLPAD(_l1, _l1); - } - // GEMM 3: QKV = QK_gated @ V { TileLeft _l0a; @@ -738,9 +755,12 @@ AICORE void chunk_o_kernel( // ── Fixed-length sequence path ────────────────────────────────────── int64_t chunks_per_seq = (seq_len + ChunkSize - 1) / ChunkSize; - for (int64_t work_idx = static_cast(cid); - work_idx < total_work; - work_idx += static_cast(block_num)) { + int64_t cid_i64 = static_cast(cid); + int64_t work_base = (total_work * cid_i64) / static_cast(block_num); + int64_t work_end = + (total_work * (cid_i64 + 1)) / static_cast(block_num); + + for (int64_t work_idx = work_base; work_idx < work_end; ++work_idx) { int32_t head_idx = static_cast(work_idx % NumHeads); int64_t chunk_head_idx = work_idx / NumHeads; int64_t seq_idx = chunk_head_idx / chunks_per_seq; @@ -958,6 +978,19 @@ AICORE void chunk_o_kernel( } } else { // ── Variable-length sequence path (cu_seqlens != nullptr) ────────── + int64_t total_work_var = 0; + for (int64_t tsi = 0; tsi < num_seqs; ++tsi) { + int64_t tbos = static_cast(cu_seqlens[tsi]); + int64_t teos = static_cast(cu_seqlens[tsi + 1]); + int64_t tslen = teos - tbos; + int64_t tnc = (tslen + ChunkSize - 1) / ChunkSize; + total_work_var += tnc * NumHeads; + } + int64_t cid_i64 = static_cast(cid); + int64_t work_base = + (total_work_var * cid_i64) / static_cast(block_num); + int64_t work_end = + (total_work_var * (cid_i64 + 1)) / static_cast(block_num); int64_t gi = 0; for (int64_t si = 0; si < num_seqs; ++si) { int64_t bos = static_cast(cu_seqlens[si]); @@ -967,8 +1000,7 @@ AICORE void chunk_o_kernel( for (int64_t ci = 0; ci < nc; ++ci) { for (int32_t h = 0; h < NumHeads; ++h) { - if (gi % static_cast(block_num) == - static_cast(cid)) { + if (gi >= work_base && gi < work_end) { int64_t chunk_start = ci * ChunkSize; int64_t remaining = slen - chunk_start; int32_t valid_rows = static_cast(