Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 52 additions & 52 deletions tileops/kernels/attention/deepseek_dsa_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def _sparse_mla_fwd_main(
alpha_local = T.alloc_fragment([h_per_block], accum_dtype)
m_i = T.alloc_fragment([h_per_block], accum_dtype)
m_i_prev = T.alloc_fragment([h_per_block], accum_dtype)
indices_local = T.alloc_local([1], indices_dtype)
indices_local = T.alloc_var(indices_dtype)

# TODO: Multi buffer
bar_q = T.alloc_barrier(arrive_count=384)
Expand All @@ -212,9 +212,9 @@ def _sparse_mla_fwd_main(

tx = T.get_thread_binding()

T.copy(q[b_i, s_i, h0:h1, 0:d // 2], q_shared_l)
T.copy(q[b_i, s_i, h0:h1, d // 2:d], q_shared_r)
T.copy(q[b_i, s_i, h0:h1, d:], q_tail_shared)
T.tma_copy(q[b_i, s_i, h0:h1, 0:d // 2], q_shared_l, barrier=bar_q)
T.tma_copy(q[b_i, s_i, h0:h1, d // 2:d], q_shared_r, barrier=bar_q)
T.tma_copy(q[b_i, s_i, h0:h1, d:], q_tail_shared, barrier=bar_q)
T.barrier_arrive(bar_q)

if tx < 128:
Expand All @@ -232,9 +232,9 @@ def _sparse_mla_fwd_main(
for h_i, bi_i in T.Parallel(h_per_block, i_block):
acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0,
-T.infinity(acc_s.dtype))
T.gemm(q_shared_l, kv_shared_0_l, acc_s, transpose_B=True, wg_wait=-1)
T.gemm(q_shared_r, kv_shared_0_r, acc_s, transpose_B=True, wg_wait=-1)
T.gemm(q_tail_shared, k_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1)
T.wgmma_gemm(q_shared_l, kv_shared_0_l, acc_s, transpose_B=True)
T.wgmma_gemm(q_shared_r, kv_shared_0_r, acc_s, transpose_B=True)
T.wgmma_gemm(q_tail_shared, k_tail_shared_0, acc_s, transpose_B=True)

T.wait_wgmma(0)

Expand All @@ -243,7 +243,9 @@ def _sparse_mla_fwd_main(
T.barrier_wait(bar_s_scale_and_s_free, ((i_i * 2) & 1) ^ 1)

T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
T.reduce_max(acc_s, m_i, dim=1, clear=True)
for h_i in T.Parallel(h_per_block):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
for h_i in T.Parallel(h_per_block):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(h_per_block, i_block):
Expand All @@ -268,17 +270,19 @@ def _sparse_mla_fwd_main(
for h_i, bi_i in T.Parallel(h_per_block, i_block):
acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0,
-T.infinity(acc_s.dtype))
T.gemm(q_shared_l, kv_shared_1_l, acc_s, transpose_B=True, wg_wait=-1)
T.gemm(q_shared_r, kv_shared_1_r, acc_s, transpose_B=True, wg_wait=-1)
T.gemm(q_tail_shared, k_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1)
T.wgmma_gemm(q_shared_l, kv_shared_1_l, acc_s, transpose_B=True)
T.wgmma_gemm(q_shared_r, kv_shared_1_r, acc_s, transpose_B=True)
T.wgmma_gemm(q_tail_shared, k_tail_shared_1, acc_s, transpose_B=True)

T.wait_wgmma(0)

T.barrier_arrive(bar_s_scale_and_s_free)
T.barrier_wait(bar_s_scale_and_s_free, ((i_i * 2 + 1) & 1) ^ 1)

T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
T.reduce_max(acc_s, m_i, dim=1, clear=True)
for h_i in T.Parallel(h_per_block):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
for h_i in T.Parallel(h_per_block):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(h_per_block, i_block):
Expand Down Expand Up @@ -343,55 +347,51 @@ def _sparse_mla_fwd_main(
# Buffer 0
T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
for r in T.serial(4):
indices_local[0] = indices[b_i, s_i, g_i, (i_i * 2) * i_block + r * 16 +
indices_local = indices[b_i, s_i, g_i, (i_i * 2) * i_block + r * 16 +
(tx - 256) // 8]
is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i
is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local <= max_kv_i
if is_kv_valid[r * 16 + (tx - 256) // 8]:
with T.attr("default", "async_scope", 1):
for u in T.serial(4):
for v in T.vectorized(8):
kv_shared_0_l[r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 +
v] = kv[b_i, indices_local[0], g_i,
64 * u + (tx - 256) % 8 * 8 + v]
kv_shared_0_r[r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 +
v] = kv[b_i, indices_local[0], g_i,
d // 2 + 64 * u +
(tx - 256) % 8 * 8 + v]
with T.attr("default", "async_scope", 1):
for v in T.vectorized(8):
k_tail_shared_0[r * 16 + (tx - 256) // 8,
(tx - 256) % 8 * 8 +
v] = kv[b_i, indices_local[0], g_i,
d + (tx - 256) % 8 * 8 + v]
for u in T.serial(4):
T.ptx_cp_async(
T.access_ptr(kv_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8),
T.access_ptr(kv[b_i, indices_local, g_i, 64 * u + (tx - 256) % 8 * 8], "r", 8),
16,
)
T.ptx_cp_async(
T.access_ptr(kv_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8),
T.access_ptr(kv[b_i, indices_local, g_i, d // 2 + 64 * u + (tx - 256) % 8 * 8], "r", 8),
16,
)
T.ptx_cp_async(
T.access_ptr(k_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8], "w", 8),
T.access_ptr(kv[b_i, indices_local, g_i, d + (tx - 256) % 8 * 8], "r", 8),
16,
)
T.cp_async_barrier_noinc(bar_k_0_ready[0])

# Buffer 1
T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
for r in T.serial(4):
indices_local[0] = indices[b_i, s_i, g_i, (i_i * 2 + 1) * i_block +
indices_local = indices[b_i, s_i, g_i, (i_i * 2 + 1) * i_block +
r * 16 + (tx - 256) // 8]
is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i
is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local <= max_kv_i
if is_kv_valid[r * 16 + (tx - 256) // 8]:
with T.attr("default", "async_scope", 1):
for u in T.serial(4):
for v in T.vectorized(8):
kv_shared_1_l[r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 +
v] = kv[b_i, indices_local[0], g_i,
64 * u + (tx - 256) % 8 * 8 + v]
kv_shared_1_r[r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 +
v] = kv[b_i, indices_local[0], g_i,
d // 2 + 64 * u +
(tx - 256) % 8 * 8 + v]
with T.attr("default", "async_scope", 1):
for v in T.vectorized(8):
k_tail_shared_1[r * 16 + (tx - 256) // 8,
(tx - 256) % 8 * 8 +
v] = kv[b_i, indices_local[0], g_i,
d + (tx - 256) % 8 * 8 + v]
for u in T.serial(4):
T.ptx_cp_async(
T.access_ptr(kv_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8),
T.access_ptr(kv[b_i, indices_local, g_i, 64 * u + (tx - 256) % 8 * 8], "r", 8),
16,
)
T.ptx_cp_async(
T.access_ptr(kv_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8),
T.access_ptr(kv[b_i, indices_local, g_i, d // 2 + 64 * u + (tx - 256) % 8 * 8], "r", 8),
16,
)
T.ptx_cp_async(
T.access_ptr(k_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8], "w", 8),
T.access_ptr(kv[b_i, indices_local, g_i, d + (tx - 256) % 8 * 8], "r", 8),
16,
)
T.cp_async_barrier_noinc(bar_k_1_ready[0])

return _sparse_mla_fwd_main
Expand Down
Loading