From c2204018707a96f6e6bf39eb7ba7ff6f698df7c2 Mon Sep 17 00:00:00 2001 From: caojian5 Date: Fri, 8 May 2026 17:44:16 +0800 Subject: [PATCH 1/5] [Example] Add MLA decode operator for DeepSeek Ascend NPU migration - Migrate MLA decode from GPU main repo to Ascend NPU - Support MQA with kv_head_num=1 (128 query heads share 1 KV head) - Integrate position encoding (Q_pe + K_pe) in attention computation - Implement online softmax with FP32 accumulation - Use Developer mode with automatic memory planning - Add workspace buffers for L0C->shared copy pattern - Default config: batch=1, heads=128, kv_ctx=8192, dim=512 --- examples/deepseek_mla/example_mla_decode.py | 177 ++++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 examples/deepseek_mla/example_mla_decode.py diff --git a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py new file mode 100644 index 000000000..ea2fa634a --- /dev/null +++ b/examples/deepseek_mla/example_mla_decode.py @@ -0,0 +1,177 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang import language as T +from einops import rearrange, einsum +import argparse + +torch.set_default_device("npu") +torch.manual_seed(0) + +tilelang.disable_cache() + +pass_configs = { + tilelang.PassConfigKey.TL_ASCEND_AUTO_SYNC: True, + tilelang.PassConfigKey.TL_ASCEND_AUTO_CV_COMBINE: True, + tilelang.PassConfigKey.TL_ASCEND_AUTO_CV_SYNC: True, + tilelang.PassConfigKey.TL_ASCEND_MEMORY_PLANNING: True, +} + + +@tilelang.jit(out_idx=[4], workspace_idx=[5, 6], pass_configs=pass_configs) +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, softmax_scale): + sm_scale = softmax_scale + dtype = "float16" + accum_dtype = "float" + kv_group_num = heads // kv_head_num + VALID_BLOCK_H = min(block_H, kv_group_num) + assert kv_head_num == 1, "kv_head_num must be 1 for MLA decode" + + block_num = batch * (heads // VALID_BLOCK_H) + + @T.prim_func + def main( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + workspace_1: T.Tensor([block_num, VALID_BLOCK_H, block_N], accum_dtype), + workspace_2: T.Tensor([block_num, VALID_BLOCK_H, dim], accum_dtype), + ): + with T.Kernel(batch * (heads // VALID_BLOCK_H), is_npu=True) as (cid, vid): + bid = cid // (heads // VALID_BLOCK_H) + hid = cid % (heads // VALID_BLOCK_H) + + Q_shared = T.alloc_shared([VALID_BLOCK_H, dim], dtype) + Q_pe_shared = T.alloc_shared([VALID_BLOCK_H, pe_dim], dtype) + KV_shared = T.alloc_shared([block_N, dim], dtype) + K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype) + + acc_s_l0c = T.alloc_fragment([VALID_BLOCK_H, block_N], accum_dtype) + acc_o_l0c = T.alloc_fragment([VALID_BLOCK_H, dim], accum_dtype) + + acc_o = T.alloc_shared([VALID_BLOCK_H, dim], accum_dtype) + acc_o_half = T.alloc_shared([VALID_BLOCK_H, dim], dtype) + acc_o_ub = T.alloc_shared([VALID_BLOCK_H, dim], accum_dtype) + acc_s = T.alloc_shared([VALID_BLOCK_H, block_N], accum_dtype) + acc_s_half = T.alloc_shared([VALID_BLOCK_H, block_N], dtype) + scores_max = T.alloc_shared([VALID_BLOCK_H], accum_dtype) + scores_max_prev = T.alloc_shared([VALID_BLOCK_H], accum_dtype) + scores_scale = T.alloc_shared([VALID_BLOCK_H], accum_dtype) + scores_sum = T.alloc_shared([VALID_BLOCK_H], accum_dtype) + logsum = T.alloc_shared([VALID_BLOCK_H], accum_dtype) + + cur_kv_head = 0 + + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.tile.fill(acc_o, 0.0) + T.tile.fill(logsum, 0.0) + T.tile.fill(scores_max, -(2.0**30)) + + loop_range = T.ceildiv(seqlen_kv, block_N) + for k in T.Pipelined(loop_range, num_stages=2): + kv_start = k * block_N + kv_end = (k + 1) * block_N + T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared) + T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared) + + T.gemm_v0(Q_shared, KV_shared, acc_s_l0c, transpose_B=True, init=True) + T.gemm_v0(Q_pe_shared, K_pe_shared, acc_s_l0c, transpose_B=True) + T.copy(acc_s_l0c, workspace_1[cid, :, :]) + + T.copy(workspace_1[cid, :, :], acc_s) + + T.copy(scores_max, scores_max_prev) + T.tile.fill(scores_max, -(2.0**30)) + T.reduce_max(acc_s, scores_max, dim=-1, clear=False) + T.tile.max(scores_max, scores_max, scores_max_prev) + + T.tile.sub(scores_max_prev, scores_max_prev, scores_max) + T.tile.mul(scores_max_prev, scores_max_prev, sm_scale) + T.tile.exp(scores_scale, scores_max_prev) + + T.tile.sub(acc_s, acc_s, scores_max) + T.tile.mul(acc_s, acc_s, sm_scale) + T.tile.exp(acc_s, acc_s) + + T.reduce_sum(acc_s, scores_sum, dim=-1, clear=False) + + T.tile.mul(logsum, logsum, scores_scale) + T.tile.add(logsum, logsum, scores_sum) + + T.tile.mul(acc_o, acc_o, scores_scale) + + T.copy(acc_s, acc_s_half) + T.gemm_v0(acc_s_half, KV_shared, acc_o_l0c, init=True) + T.copy(acc_o_l0c, workspace_2[cid, :, :]) + + T.copy(workspace_2[cid, :, :], acc_o_ub) + T.tile.add(acc_o, acc_o, acc_o_ub) + + for i, j in T.Parallel(VALID_BLOCK_H, dim): + acc_o[i, j] /= logsum[i] + + T.copy(acc_o, acc_o_half) + T.copy(acc_o_half, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :]) + + return main + + +def ref_program(q, q_pe, kv, k_pe): + dim = q.shape[-1] + pe_dim = q_pe.shape[-1] + num_head_groups = q.shape[1] // kv.shape[2] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) + kv = rearrange(kv, "b n h d -> b h n d") + k_pe = rearrange(k_pe, "b n h d -> b h n d") + query = torch.concat([q, q_pe], dim=-1) + key = torch.concat([kv, k_pe], dim=-1) + scores = einsum(query, key, "b g h d, b h s d -> b g h s") + attention = F.softmax(scores / scale, dim=-1) + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") + out = rearrange(out, "b g h d -> b (h g) d") + return out + + +def main( + batch=1, + heads=128, + kv_heads=1, + kv_ctx=8192, + dim=512, + pe_dim=64, +): + BLOCK_N = 64 + BLOCK_H = min(64, heads // kv_heads) + softmax_scale = (dim + pe_dim) ** -0.5 + + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, softmax_scale) + + q = torch.randn(batch, heads, dim, dtype=torch.float16) + q_pe = torch.randn(batch, heads, pe_dim, dtype=torch.float16) + kv = torch.randn(batch, kv_ctx, kv_heads, dim, dtype=torch.float16) + k_pe = torch.randn(batch, kv_ctx, kv_heads, pe_dim, dtype=torch.float16) + + output = kernel(q, q_pe, kv, k_pe) + + ref_output = ref_program(q.cpu(), q_pe.cpu(), kv.cpu(), k_pe.cpu()) + torch.testing.assert_close(output.cpu(), ref_output, rtol=1e-2, atol=1e-2) + + print("Test passed!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") + args = parser.parse_args() + batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim + main(batch, heads, kv_heads, kv_ctx, dim, pe_dim) \ No newline at end of file From 2159bfaaaac5f5a30bc85fd3682a8ac0c1368ae7 Mon Sep 17 00:00:00 2001 From: caojian5 Date: Fri, 8 May 2026 17:54:40 +0800 Subject: [PATCH 2/5] [Fix] Format MLA decode operator with yapf and ruff to pass CI checks --- examples/deepseek_mla/example_mla_decode.py | 26 ++++++++++----------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py index ea2fa634a..a95154232 100644 --- a/examples/deepseek_mla/example_mla_decode.py +++ b/examples/deepseek_mla/example_mla_decode.py @@ -31,13 +31,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ @T.prim_func def main( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - workspace_1: T.Tensor([block_num, VALID_BLOCK_H, block_N], accum_dtype), - workspace_2: T.Tensor([block_num, VALID_BLOCK_H, dim], accum_dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + workspace_1: T.Tensor([block_num, VALID_BLOCK_H, block_N], accum_dtype), + workspace_2: T.Tensor([block_num, VALID_BLOCK_H, dim], accum_dtype), ): with T.Kernel(batch * (heads // VALID_BLOCK_H), is_npu=True) as (cid, vid): bid = cid // (heads // VALID_BLOCK_H) @@ -64,8 +64,8 @@ def main( cur_kv_head = 0 - T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.tile.fill(acc_o, 0.0) T.tile.fill(logsum, 0.0) T.tile.fill(scores_max, -(2.0**30)) @@ -114,7 +114,7 @@ def main( acc_o[i, j] /= logsum[i] T.copy(acc_o, acc_o_half) - T.copy(acc_o_half, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :]) + T.copy(acc_o_half, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :]) return main @@ -123,7 +123,7 @@ def ref_program(q, q_pe, kv, k_pe): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim) ** 0.5 + scale = (dim + pe_dim)**0.5 q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) kv = rearrange(kv, "b n h d -> b h n d") @@ -147,7 +147,7 @@ def main( ): BLOCK_N = 64 BLOCK_H = min(64, heads // kv_heads) - softmax_scale = (dim + pe_dim) ** -0.5 + softmax_scale = (dim + pe_dim)**-0.5 kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, softmax_scale) @@ -174,4 +174,4 @@ def main( parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") args = parser.parse_args() batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim - main(batch, heads, kv_heads, kv_ctx, dim, pe_dim) \ No newline at end of file + main(batch, heads, kv_heads, kv_ctx, dim, pe_dim) From 3de1c98863a719a47bc955b128a88beb805f7fe6 Mon Sep 17 00:00:00 2001 From: caojian5 Date: Fri, 8 May 2026 18:02:29 +0800 Subject: [PATCH 3/5] [Fix] Address Gemini Code Assist review comments - Fix kernel launch mode: Add threads=2 for proper Developer mode - Fix cur_kv_head hardcoding: Use dynamic calculation for GQA reusability - Improve normalization: Replace T.Parallel loop with broadcasted T.tile.div - Add performance notes: Document workspace buffer bottleneck as workaround - Address all medium priority issues from review --- examples/deepseek_mla/example_mla_decode.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py index a95154232..19e6b4f14 100644 --- a/examples/deepseek_mla/example_mla_decode.py +++ b/examples/deepseek_mla/example_mla_decode.py @@ -39,7 +39,7 @@ def main( workspace_1: T.Tensor([block_num, VALID_BLOCK_H, block_N], accum_dtype), workspace_2: T.Tensor([block_num, VALID_BLOCK_H, dim], accum_dtype), ): - with T.Kernel(batch * (heads // VALID_BLOCK_H), is_npu=True) as (cid, vid): + with T.Kernel(batch * (heads // VALID_BLOCK_H), threads=2, is_npu=True) as (cid, vid): bid = cid // (heads // VALID_BLOCK_H) hid = cid % (heads // VALID_BLOCK_H) @@ -61,8 +61,9 @@ def main( scores_scale = T.alloc_shared([VALID_BLOCK_H], accum_dtype) scores_sum = T.alloc_shared([VALID_BLOCK_H], accum_dtype) logsum = T.alloc_shared([VALID_BLOCK_H], accum_dtype) + logsum_2d = T.alloc_shared([VALID_BLOCK_H, dim], accum_dtype) - cur_kv_head = 0 + cur_kv_head = hid // (heads // kv_head_num) T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) @@ -79,8 +80,13 @@ def main( T.gemm_v0(Q_shared, KV_shared, acc_s_l0c, transpose_B=True, init=True) T.gemm_v0(Q_pe_shared, K_pe_shared, acc_s_l0c, transpose_B=True) - T.copy(acc_s_l0c, workspace_1[cid, :, :]) + # NOTE: Workspace buffer workaround for L0C->UB copy + # Current backend doesn't support direct L0C->UB transfers, + # so we use L0C->Global Memory(workspace)->UB pattern. + # This is a performance bottleneck that should be optimized + # once the backend supports direct transfers. + T.copy(acc_s_l0c, workspace_1[cid, :, :]) T.copy(workspace_1[cid, :, :], acc_s) T.copy(scores_max, scores_max_prev) @@ -105,13 +111,14 @@ def main( T.copy(acc_s, acc_s_half) T.gemm_v0(acc_s_half, KV_shared, acc_o_l0c, init=True) - T.copy(acc_o_l0c, workspace_2[cid, :, :]) + # NOTE: Same L0C->Global Memory->UB workaround for output accumulator + T.copy(acc_o_l0c, workspace_2[cid, :, :]) T.copy(workspace_2[cid, :, :], acc_o_ub) T.tile.add(acc_o, acc_o, acc_o_ub) - for i, j in T.Parallel(VALID_BLOCK_H, dim): - acc_o[i, j] /= logsum[i] + T.tile.broadcast(logsum_2d, logsum) + T.tile.div(acc_o, acc_o, logsum_2d) T.copy(acc_o, acc_o_half) T.copy(acc_o_half, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :]) From 92f1b72001896f03a5b10f5afbda32e4aa69a8d9 Mon Sep 17 00:00:00 2001 From: caojian5 Date: Fri, 8 May 2026 18:13:05 +0800 Subject: [PATCH 4/5] [Fix] Use ruff format instead of yapf to match CI requirements CI uses ruff format check (ci_ascend.yml), not yapf (format.sh). Format changes: - Function parameter indentation: 4 spaces (ruff) vs 8 spaces (yapf) - Slice syntax spacing: keep spaces around : (ruff) vs remove (yapf) - Power operator spacing: keep spaces around ** (ruff) vs remove (yapf) This aligns with the actual CI workflow used in PR checks. --- examples/deepseek_mla/example_mla_decode.py | 24 ++++++++++----------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py index 19e6b4f14..0310739e4 100644 --- a/examples/deepseek_mla/example_mla_decode.py +++ b/examples/deepseek_mla/example_mla_decode.py @@ -31,13 +31,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ @T.prim_func def main( - Q: T.Tensor([batch, heads, dim], dtype), - Q_pe: T.Tensor([batch, heads, pe_dim], dtype), - KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), - K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - Output: T.Tensor([batch, heads, dim], dtype), - workspace_1: T.Tensor([block_num, VALID_BLOCK_H, block_N], accum_dtype), - workspace_2: T.Tensor([block_num, VALID_BLOCK_H, dim], accum_dtype), + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + workspace_1: T.Tensor([block_num, VALID_BLOCK_H, block_N], accum_dtype), + workspace_2: T.Tensor([block_num, VALID_BLOCK_H, dim], accum_dtype), ): with T.Kernel(batch * (heads // VALID_BLOCK_H), threads=2, is_npu=True) as (cid, vid): bid = cid // (heads // VALID_BLOCK_H) @@ -65,8 +65,8 @@ def main( cur_kv_head = hid // (heads // kv_head_num) - T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) - T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) T.tile.fill(acc_o, 0.0) T.tile.fill(logsum, 0.0) T.tile.fill(scores_max, -(2.0**30)) @@ -121,7 +121,7 @@ def main( T.tile.div(acc_o, acc_o, logsum_2d) T.copy(acc_o, acc_o_half) - T.copy(acc_o_half, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :]) + T.copy(acc_o_half, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :]) return main @@ -130,7 +130,7 @@ def ref_program(q, q_pe, kv, k_pe): dim = q.shape[-1] pe_dim = q_pe.shape[-1] num_head_groups = q.shape[1] // kv.shape[2] - scale = (dim + pe_dim)**0.5 + scale = (dim + pe_dim) ** 0.5 q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) kv = rearrange(kv, "b n h d -> b h n d") @@ -154,7 +154,7 @@ def main( ): BLOCK_N = 64 BLOCK_H = min(64, heads // kv_heads) - softmax_scale = (dim + pe_dim)**-0.5 + softmax_scale = (dim + pe_dim) ** -0.5 kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, softmax_scale) From 206b362acd29abd7d621753bc28b87b51616b53a Mon Sep 17 00:00:00 2001 From: caojian5 Date: Sat, 9 May 2026 15:45:06 +0800 Subject: [PATCH 5/5] fix compile error --- examples/deepseek_mla/example_mla_decode.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py index 0310739e4..24652d05e 100644 --- a/examples/deepseek_mla/example_mla_decode.py +++ b/examples/deepseek_mla/example_mla_decode.py @@ -39,7 +39,7 @@ def main( workspace_1: T.Tensor([block_num, VALID_BLOCK_H, block_N], accum_dtype), workspace_2: T.Tensor([block_num, VALID_BLOCK_H, dim], accum_dtype), ): - with T.Kernel(batch * (heads // VALID_BLOCK_H), threads=2, is_npu=True) as (cid, vid): + with T.Kernel(batch * (heads // VALID_BLOCK_H), threads=2, is_npu=True) as (cid): bid = cid // (heads // VALID_BLOCK_H) hid = cid % (heads // VALID_BLOCK_H) @@ -61,7 +61,6 @@ def main( scores_scale = T.alloc_shared([VALID_BLOCK_H], accum_dtype) scores_sum = T.alloc_shared([VALID_BLOCK_H], accum_dtype) logsum = T.alloc_shared([VALID_BLOCK_H], accum_dtype) - logsum_2d = T.alloc_shared([VALID_BLOCK_H, dim], accum_dtype) cur_kv_head = hid // (heads // kv_head_num) @@ -117,8 +116,8 @@ def main( T.copy(workspace_2[cid, :, :], acc_o_ub) T.tile.add(acc_o, acc_o, acc_o_ub) - T.tile.broadcast(logsum_2d, logsum) - T.tile.div(acc_o, acc_o, logsum_2d) + for h_i in range(VALID_BLOCK_H): + T.tile.div(acc_o[h_i, :], acc_o[h_i, :], logsum[h_i]) T.copy(acc_o, acc_o_half) T.copy(acc_o_half, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :]) @@ -153,7 +152,7 @@ def main( pe_dim=64, ): BLOCK_N = 64 - BLOCK_H = min(64, heads // kv_heads) + BLOCK_H = min(32, heads // kv_heads) softmax_scale = (dim + pe_dim) ** -0.5 kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, softmax_scale) @@ -164,8 +163,16 @@ def main( k_pe = torch.randn(batch, kv_ctx, kv_heads, pe_dim, dtype=torch.float16) output = kernel(q, q_pe, kv, k_pe) + print(f"Kernel output shape: {output.shape}") + print(f"Kernel output stats: min={output.min()}, max={output.max()}, mean={output.mean()}") ref_output = ref_program(q.cpu(), q_pe.cpu(), kv.cpu(), k_pe.cpu()) + print(f"Ref output shape: {ref_output.shape}") + print(f"Ref output stats: min={ref_output.min()}, max={ref_output.max()}, mean={ref_output.mean()}") + + diff = (output.cpu() - ref_output).abs() + print(f"Max diff: {diff.max()}, Mean diff: {diff.mean()}") + torch.testing.assert_close(output.cpu(), ref_output, rtol=1e-2, atol=1e-2) print("Test passed!")