diff --git a/tileops/kernels/attention/deepseek_dsa_decode.py b/tileops/kernels/attention/deepseek_dsa_decode.py index c83439f49..aafd5fdf9 100644 --- a/tileops/kernels/attention/deepseek_dsa_decode.py +++ b/tileops/kernels/attention/deepseek_dsa_decode.py @@ -190,7 +190,7 @@ def _sparse_mla_fwd_main( alpha_local = T.alloc_fragment([h_per_block], accum_dtype) m_i = T.alloc_fragment([h_per_block], accum_dtype) m_i_prev = T.alloc_fragment([h_per_block], accum_dtype) - indices_local = T.alloc_local([1], indices_dtype) + indices_local = T.alloc_var(indices_dtype) # TODO: Multi buffer bar_q = T.alloc_barrier(arrive_count=384) @@ -212,9 +212,9 @@ def _sparse_mla_fwd_main( tx = T.get_thread_binding() - T.copy(q[b_i, s_i, h0:h1, 0:d // 2], q_shared_l) - T.copy(q[b_i, s_i, h0:h1, d // 2:d], q_shared_r) - T.copy(q[b_i, s_i, h0:h1, d:], q_tail_shared) + T.tma_copy(q[b_i, s_i, h0:h1, 0:d // 2], q_shared_l, barrier=bar_q) + T.tma_copy(q[b_i, s_i, h0:h1, d // 2:d], q_shared_r, barrier=bar_q) + T.tma_copy(q[b_i, s_i, h0:h1, d:], q_tail_shared, barrier=bar_q) T.barrier_arrive(bar_q) if tx < 128: @@ -232,9 +232,9 @@ def _sparse_mla_fwd_main( for h_i, bi_i in T.Parallel(h_per_block, i_block): acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype)) - T.gemm(q_shared_l, kv_shared_0_l, acc_s, transpose_B=True, wg_wait=-1) - T.gemm(q_shared_r, kv_shared_0_r, acc_s, transpose_B=True, wg_wait=-1) - T.gemm(q_tail_shared, k_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1) + T.wgmma_gemm(q_shared_l, kv_shared_0_l, acc_s, transpose_B=True) + T.wgmma_gemm(q_shared_r, kv_shared_0_r, acc_s, transpose_B=True) + T.wgmma_gemm(q_tail_shared, k_tail_shared_0, acc_s, transpose_B=True) T.wait_wgmma(0) @@ -243,7 +243,9 @@ def _sparse_mla_fwd_main( T.barrier_wait(bar_s_scale_and_s_free, ((i_i * 2) & 1) ^ 1) T.copy(m_i, m_i_prev) - T.reduce_max(acc_s, m_i, dim=1, clear=False) + T.reduce_max(acc_s, m_i, dim=1, clear=True) + for h_i in T.Parallel(h_per_block): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(h_per_block): alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(h_per_block, i_block): @@ -268,9 +270,9 @@ def _sparse_mla_fwd_main( for h_i, bi_i in T.Parallel(h_per_block, i_block): acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype)) - T.gemm(q_shared_l, kv_shared_1_l, acc_s, transpose_B=True, wg_wait=-1) - T.gemm(q_shared_r, kv_shared_1_r, acc_s, transpose_B=True, wg_wait=-1) - T.gemm(q_tail_shared, k_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1) + T.wgmma_gemm(q_shared_l, kv_shared_1_l, acc_s, transpose_B=True) + T.wgmma_gemm(q_shared_r, kv_shared_1_r, acc_s, transpose_B=True) + T.wgmma_gemm(q_tail_shared, k_tail_shared_1, acc_s, transpose_B=True) T.wait_wgmma(0) @@ -278,7 +280,9 @@ def _sparse_mla_fwd_main( T.barrier_wait(bar_s_scale_and_s_free, ((i_i * 2 + 1) & 1) ^ 1) T.copy(m_i, m_i_prev) - T.reduce_max(acc_s, m_i, dim=1, clear=False) + T.reduce_max(acc_s, m_i, dim=1, clear=True) + for h_i in T.Parallel(h_per_block): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) for h_i in T.Parallel(h_per_block): alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(h_per_block, i_block): @@ -343,55 +347,51 @@ def _sparse_mla_fwd_main( # Buffer 0 T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - indices_local[0] = indices[b_i, s_i, g_i, (i_i * 2) * i_block + r * 16 + + indices_local = indices[b_i, s_i, g_i, (i_i * 2) * i_block + r * 16 + (tx - 256) // 8] - is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i + is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local <= max_kv_i if is_kv_valid[r * 16 + (tx - 256) // 8]: - with T.attr("default", "async_scope", 1): - for u in T.serial(4): - for v in T.vectorized(8): - kv_shared_0_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = kv[b_i, indices_local[0], g_i, - 64 * u + (tx - 256) % 8 * 8 + v] - kv_shared_0_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = kv[b_i, indices_local[0], g_i, - d // 2 + 64 * u + - (tx - 256) % 8 * 8 + v] - with T.attr("default", "async_scope", 1): - for v in T.vectorized(8): - k_tail_shared_0[r * 16 + (tx - 256) // 8, - (tx - 256) % 8 * 8 + - v] = kv[b_i, indices_local[0], g_i, - d + (tx - 256) % 8 * 8 + v] + for u in T.serial(4): + T.ptx_cp_async( + T.access_ptr(kv_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(kv[b_i, indices_local, g_i, 64 * u + (tx - 256) % 8 * 8], "r", 8), + 16, + ) + T.ptx_cp_async( + T.access_ptr(kv_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(kv[b_i, indices_local, g_i, d // 2 + 64 * u + (tx - 256) % 8 * 8], "r", 8), + 16, + ) + T.ptx_cp_async( + T.access_ptr(k_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(kv[b_i, indices_local, g_i, d + (tx - 256) % 8 * 8], "r", 8), + 16, + ) T.cp_async_barrier_noinc(bar_k_0_ready[0]) # Buffer 1 T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - indices_local[0] = indices[b_i, s_i, g_i, (i_i * 2 + 1) * i_block + + indices_local = indices[b_i, s_i, g_i, (i_i * 2 + 1) * i_block + r * 16 + (tx - 256) // 8] - is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i + is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local <= max_kv_i if is_kv_valid[r * 16 + (tx - 256) // 8]: - with T.attr("default", "async_scope", 1): - for u in T.serial(4): - for v in T.vectorized(8): - kv_shared_1_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = kv[b_i, indices_local[0], g_i, - 64 * u + (tx - 256) % 8 * 8 + v] - kv_shared_1_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = kv[b_i, indices_local[0], g_i, - d // 2 + 64 * u + - (tx - 256) % 8 * 8 + v] - with T.attr("default", "async_scope", 1): - for v in T.vectorized(8): - k_tail_shared_1[r * 16 + (tx - 256) // 8, - (tx - 256) % 8 * 8 + - v] = kv[b_i, indices_local[0], g_i, - d + (tx - 256) % 8 * 8 + v] + for u in T.serial(4): + T.ptx_cp_async( + T.access_ptr(kv_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(kv[b_i, indices_local, g_i, 64 * u + (tx - 256) % 8 * 8], "r", 8), + 16, + ) + T.ptx_cp_async( + T.access_ptr(kv_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(kv[b_i, indices_local, g_i, d // 2 + 64 * u + (tx - 256) % 8 * 8], "r", 8), + 16, + ) + T.ptx_cp_async( + T.access_ptr(k_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(kv[b_i, indices_local, g_i, d + (tx - 256) % 8 * 8], "r", 8), + 16, + ) T.cp_async_barrier_noinc(bar_k_1_ready[0]) return _sparse_mla_fwd_main