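diff --git a/.agents/skills/tilelang-op-design/SKILL.md b/.agents/skills/tilelang-op-design/SKILL.md
index 69e5cc99a..e18f60a53 100644
--- a/.agents/skills/tilelang-op-design/SKILL.md
+++ b/.agents/skills/tilelang-op-design/SKILL.md
@@ -69,6 +69,8 @@ description: "Generate a TileLang-Ascend operator design document from operator requirements (desi
 | **Dynamic loop bounds unsupported** | Loop trip counts cannot depend on tensor values (e.g. `batch_sizes[bz]`) | `T.Pipelined(batch_sizes[bz])` raises an error | Precompute the maximum trip count and use `T.serial(max_iters)` plus a conditional check |
 | **Pipelining does not support dynamic bounds** | The trip count of `T.Pipelined` must be static | Dynamic batches cannot be pipelined | Switch to `T.serial` or precompute a fixed iteration count |
 | **Some GPU APIs unavailable** | CUDA-specific APIs do not exist on Ascend | Directly porting GPU code fails | Check this project's `examples/` to confirm the Ascend API |
+| **GEMM requires M, N to be integer multiples of the block size** | Relies on `M // block_M` dividing evenly; when `M < block_M`, zero blocks are launched | All-zero output, or a division-by-zero compile crash | Design doc §4/§5 must state the handling strategy: host-side padding+crop or a dynamic block size in the kernel |
+| **L0C capacity limit** | L0C = 128KB on A2/A3 devices | `block_M × block_N × sizeof(accum) > 128KB` causes a segfault | Choose blocks so that `block_M × block_N ≤ 32768` (float32 accum) |
 
 ### 2.5.2 Mandatory Detection Rules
 
@@ -80,6 +82,8 @@ description: "Generate a TileLang-Ascend operator design document from operator requirements (desi
 | threads parameter | Reference implementation uses threads > 2 | **Warn immediately**; suggest threads=2 or removing it |
 | Dynamic loop bounds | Loop bound depends on a tensor value | **Warn immediately**; propose static bounds + a conditional check |
 | GPU-specific API | CUDA-related APIs (e.g. the generic `T.gemm`) | **Warn immediately**; check this project to confirm the Ascend API |
+| GEMM non-divisibility risk | `M` or `N` is not divisible by the block size (i.e. `M % block_M ≠ 0` or `N % block_N ≠ 0`) | **Warn immediately**; require the design to state a padding strategy |
+| L0C overflow risk | block_M × block_N × sizeof(accum_dtype) > 131072 (128KB) | **Warn immediately**; suggest shrinking the block or splitting |
 
 ### 2.5.3 Warning Output Format
 
@@ -118,11 +122,13 @@ description: "Generate a TileLang-Ascend operator design document from operator requirements (desi
      - Pure Vector (element-wise / reduction) → UB only
      - Pure Cube (matmul only) → requires L1 + L0A/L0B/L0C
      - Mixed (matmul + element-wise post-processing) → inter-core pipeline, requires CV fusion
+     - **Host preprocessing**: Python-side preprocessing steps such as im2col; mark them in §1 and §4 of the design
    - **Complexity level**:
      - Single-step (e.g. element-wise add) → no loop, single transfer
      - Multi-step (e.g. softmax = max + sub + exp + sum + div) → multiple computations, may need intermediate buffers
      - Fused (e.g. flash attention = GEMM + softmax + GEMM) → inter-core cooperation, pipelining
    - **Dynamic shape determination**: whether any dimension is only known at runtime
+4. **Non-divisible case pre-check**: Check whether the input shape may not be divisible by the block size. For GEMM-class operators, `M // block_M` and `N // block_N` yield zero blocks or incomplete tiles when `M < block_M` or `N < block_N`, so the design must state the handling strategy (host-side zero-padding + crop, or a dynamic block size inside the kernel)
 
 ### Phase 2: Information Gathering
 
@@ -168,10 +174,10 @@ grep "T.Scope\|T.barrier" examples/{similar implementation} # synchronization method
 2. Programming mode selection
 3. API mapping design
 4. Data specifications and memory planning
-5. Tiling strategy
+5. Tiling strategy (**must include: a padding+crop strategy for non-divisible shapes, or an in-kernel dynamic block scheme**)
 6. Loop and scheduling structure
 7. Synchronization strategy
-8. CV fusion design (if any)
+8. CV fusion design
 9. Verification plan
 10. Risks and caveats
 11. Deliverables checklist
@@ -192,7 +198,12 @@ grep "T.Scope\|T.barrier" examples/{similar implementation} # synchronization method
 
 ## 4. Operator Feature Analysis Decision Tree (Revised)
 
-### 4.0 Platform Identification
+### 4.0 Function Design Principles
+
+1. **Self-derived dimension parameters**: Operator entry functions (e.g. `conv_im2col_gemm`) should extract dimensions such as B/C/H/W from the input tensor shapes instead of relying on module-level globals. This prevents variable pollution when multiple scenarios are tested in sequence.
+2. **Explicitly declared host preprocessing**: If part of the computation runs on the Python side (e.g. im2col), it must be clearly marked in the §1 algorithm description and the §4 data flow.
+
+### 4.1 Platform Identification
 
 **This project is TileLang-Ascend (Ascend NPU)**, which differs from the GPU version of TileLang:
 
@@ -283,7 +294,7 @@ grep "T.Scope\|T.barrier" examples/{similar implementation} # synchronization method
 | 3 | **Memory transfer path complete**: every step from GM to compute and back to GM is described | ✅ Required |
 | 4 | **Tiling strategy has constraint analysis**: explains why this Block/Tile size was chosen | ⭕ Recommended |
 | 5 | **Sync strategy matches programming mode**: Developer uses automatic sync; Expert marks manual sync points | ⭕ Recommended |
-| 6 | **Verification plan covers typical configurations**: not "to be added" | ⭕ Recommended |
+| 6 | **Verification plan covers 4 classes of typical configurations**: perfectly aligned + single-dimension padding + all-dimension padding + multi-block (required for GEMM-class operators), not "to be added" | ⭕ Recommended |
 | 7 | **No placeholders or vague descriptions**: no `{placeholder}`, TODO, or "to be added" (except items already confirmed) | ✅ Required |
 | 8 | **Technical constraints confirmed**: 3D Kernel, threads, dynamic bounds, and similar issues have been handled | ✅ Required |
 | 9 | **Similar implementations in this project listed**: concrete examples/ file paths are referenced | ✅ Required |
@@ -291,8 +302,11 @@ grep "T.Scope\|T.barrier" examples/{similar implementation} # synchronization method
 | 11 | **Reference implementation analysis complete** (if a reference implementation exists): records technical decisions such as memory-hierarchy APIs, sync strategy, and pass_configs | ⭕ Recommended |
 | 12 | **CV fusion design complete** (if needed): workspace specs, data flow, pass_configs | ⭕ Recommended |
 | 13 | **workspace_idx configured correctly** (if CV fusion is needed): consistent with the workspace parameter positions | ✅ Required |
+| 14 | **Non-divisible handling strategy stated** (required for GEMM-class operators): host-side padding+crop or in-kernel dynamic block, with overflow/underflow handling described | ✅ Required |
+| 15 | **L0C capacity constraint verified** (required for GEMM-class operators): `block_M × block_N × sizeof(accum_dtype) ≤ L0C_capacity (128KB)` | ⭕ Recommended |
+| 16 | **Function has no global-variable dependencies**: dimension parameters come from tensor shapes or function arguments, supporting sequential multi-scenario tests | ⭕ Recommended |
 
-**Pass criteria**: all required items (1, 2, 3, 7, 8, 9) pass, and at least 3/4 of the recommended items (4, 5, 6, 10) pass.
+**Pass criteria**: all required items (1, 2, 3, 7, 8, 9, 14) pass, and at least 4/9 of the recommended items pass.
 
 ---
 
@@ -352,11 +366,14 @@ grep "T.Scope\|T.barrier" examples/{similar implementation} # synchronization method
 3. Memory transfer completeness: ✅ / ❌
 4. Tiling constraint analysis: ✅ / ❌
 5. Sync strategy match: ✅ / ❌
-6. Verification plan coverage: ✅ / ❌
+6. Verification plan coverage (4 classes): ✅ / ❌
 7. No placeholders: ✅ / ❌
 8. Technical constraints confirmed: ✅ / ❌
 9. Similar implementations in this project listed: ✅ / ❌
 10. Reference implementation differences explained: ✅ / ❌ / N/A
+11. Non-divisible handling strategy: ✅ / ❌ / N/A
+12. L0C capacity constraint: ✅ / ❌ / N/A
+13. No global-variable dependencies: ✅ / ❌
 
 ### Items to Confirm
 - {list the items that need further confirmation from the user}
diff --git a/.agents/skills/tilelang-op-generate/references/troubleshooting.md b/.agents/skills/tilelang-op-generate/references/troubleshooting.md
index 32bdaf36b..f8642cc26 100644
--- a/.agents/skills/tilelang-op-generate/references/troubleshooting.md
+++ b/.agents/skills/tilelang-op-generate/references/troubleshooting.md
@@ -117,13 +117,16 @@ TypeError: get_configs() missing 1 required positional argument: 'K'
 
 **Symptom**: autotune compiles successfully, but the process crashes outright (segfault) during the benchmark, with no Python exception.
 
-**Cause**: `block_M * block_N * sizeof(accum_dtype) > L0C_capacity` (on A2/A3 devices L0C=64KB, so float32 accum elements ≤ 16384). E.g. `block_M=256, block_N=256 → 256KB > 64KB`.
+**Troubleshooting advice for segfault-like issues**: A segmentation fault should first be located with a tool such as gdb to find the exact crash site and call stack; then determine the root cause by weighing the kernel configuration, memory access range, on-chip memory usage, and similar factors.
+One possible cause: when `block_M * block_N * sizeof(accum_dtype) > L0C_capacity`, on-chip buffer usage may exceed the hardware limit. For example, L0C on A2/A3 devices is 128KB, so the number of float32 accum elements should not exceed 32768.
 
 **Solution**: filter out oversized blocks in autotune's `get_configs`:
 ```python
 block_M = [bs for bs in [64, 128] if bs <= M]  # exclude 256
 ```
 
+
+
 ## Runtime Errors
 
 ### 1. Incorrect results
diff --git a/examples/seer_attention/block_sparse_attn.py b/examples/seer_attention/block_sparse_attn.py
new file mode 100644
index 000000000..884c5c0da
--- /dev/null
+++ b/examples/seer_attention/block_sparse_attn.py
@@ -0,0 +1,260 @@
+import math
+import torch
+
+import tilelang
+import tilelang.language as T
+import torch.nn.functional as F
+
+
+def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
+    bsz, num_head, downsample_len, _ = x.shape
+    sparse_index = torch.topk(x, topk, dim=-1).indices
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
+    dense_mask.scatter_(-1, sparse_index, True)
+    if use_dense_for_last_block:
+        dense_mask[:, :, -2:, :] = True
+    dense_mask.tril_()
+    return dense_mask
+
+
+def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
+    dense_mask = x > threshold
+    if use_dense_for_last_block:
+        dense_mask[:, :, -2:, :] = True
+    dense_mask.tril_()
+    return dense_mask
+
+
+pass_configs = {
+    tilelang.PassConfigKey.TL_ASCEND_AUTO_CV_COMBINE: True,
+    tilelang.PassConfigKey.TL_ASCEND_AUTO_CV_SYNC: True,
+    tilelang.PassConfigKey.TL_ASCEND_AUTO_SYNC: True,
+    tilelang.PassConfigKey.TL_ASCEND_MEMORY_PLANNING: True,
+}
+
+
+@tilelang.jit(out_idx=[4], workspace_idx=[5, 6, 7], pass_configs=pass_configs)
+def blocksparse_flashattn_ascend(batch, heads, seq_q, seq_kv, dim, downsample_len, is_causal):
+    block_M = 64
+    block_N = 64
+    dtype = "float16"
+    accum_dtype = "float"
+    block_mask_dtype = "int8"
+    sm_scale = (1.0 / dim) ** 0.5
+
+    q_shape = [batch, heads, seq_q, dim]
+    kv_shape = [batch, heads, seq_kv, dim]
+    block_mask_shape = [batch, heads, downsample_len, downsample_len]
+    block_num = ((seq_q + block_M - 1) // block_M) * heads * batch
+
+    @T.prim_func
+    def main(
+        Q: T.Tensor(q_shape, dtype),
+        K: T.Tensor(kv_shape, dtype),
+        V: T.Tensor(kv_shape, dtype),
+        BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),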
+        Output: T.Tensor(q_shape, dtype),
+        workspace_1: T.Tensor([block_num, block_M, block_N], accum_dtype),
+        workspace_2: T.Tensor([block_num, block_M, block_N], dtype),
+        workspace_3: T.Tensor([block_num, block_M, dim], accum_dtype),
+    ):
+        with T.Kernel(block_num, is_npu=True) as (cid, vid):
+            bx = cid % ((seq_q + block_M - 1) // block_M)
+            by = cid // ((seq_q + block_M - 1) // block_M) % heads
+            bz = cid // ((seq_q + block_M - 1) // block_M) // heads % batch
+
+            # L1 buffers (Cube core)
+            q_l1 = T.alloc_L1([block_M, dim], dtype)
+            k_l1 = T.alloc_L1([block_N, dim], dtype)
+            v_l1 = T.alloc_L1([block_N, dim], dtype)
+            acc_s_l1 = T.alloc_L1([block_M, block_N], dtype)
+            acc_s_l0c = T.alloc_L0C([block_M, block_N], accum_dtype)
+            acc_o_l0c = T.alloc_L0C([block_M, dim], accum_dtype)
+
+            # UB buffers (Vector core) for online softmax
+            acc_o = T.alloc_ub([block_M // 2, dim], accum_dtype)
+            sumexp = T.alloc_ub([block_M // 2], accum_dtype)
+            m_i = T.alloc_ub([block_M // 2], accum_dtype)
+            acc_s_ub = T.alloc_ub([block_M // 2, block_N], accum_dtype)
+            m_i_prev = T.alloc_ub([block_M // 2], accum_dtype)
+            acc_s_ub_ = T.alloc_ub([block_M // 2, block_N], accum_dtype)
+            sumexp_i_ub = T.alloc_ub([block_M // 2], accum_dtype)
+            acc_s_half = T.alloc_ub([block_M // 2, block_N], dtype)
+            acc_o_ub = T.alloc_ub([block_M // 2, dim], accum_dtype)
+            acc_o_half = T.alloc_ub([block_M // 2, dim], dtype)
+
+            # UB buffers for block-sparse mask and causal mask
+            mask_row = T.alloc_ub([downsample_len], block_mask_dtype)
+            col_idx = T.alloc_ub([block_N], "int32")
+            col_idx_f = T.alloc_ub([block_N], "float")
+            cmp_mask = T.alloc_ub([block_N // 8], "uint8")
+
+            # === Cube: Load Q, run GEMM pipeline for all KV blocks ===
+            T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], q_l1)
+
+            loop_range = (seq_kv + block_N - 1) // block_N
+            for k in T.serial(loop_range):
+                T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], k_l1)
+                T.gemm_v0(q_l1, k_l1, acc_s_l0c, transpose_B=True, init=True)
+                T.copy(acc_s_l0c, workspace_1[cid, :, :])
+
+                T.copy(workspace_2[cid, :, :], acc_s_l1)
+                T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], v_l1)
+                T.gemm_v0(acc_s_l1, v_l1, acc_o_l0c, init=True)
+                T.copy(acc_o_l0c, workspace_3[cid, :, :])
+
+            # === Vector: Online softmax + output accumulation ===
+            T.tile.fill(acc_o, 0.0)
+            T.tile.fill(sumexp, 0.0)
+            T.tile.fill(m_i, -T.infinity(accum_dtype))
+
+            T.copy(BlockSparseMask[bz, by, bx, :], mask_row)
+
+            past_len = seq_kv - seq_q
+
+            for _k in T.serial(loop_range):
+                T.tile.fill(acc_s_ub, 0.0)
+                T.copy(m_i, m_i_prev)
+
+                # Load S = Q@K^T from Cube workspace
+                T.copy(workspace_1[cid, vid * block_M // 2 : vid * block_M // 2 + block_M // 2, :], acc_s_ub_)
+
+                # Block sparsity mask
+                if mask_row[_k] == 0:
+                    T.tile.fill(acc_s_ub_, -T.infinity(accum_dtype))
+
+                # Element-level causal mask via compare+select (float32)
+                elif is_causal:
+                    T.tile.createvecindex(col_idx, _k * block_N)
+                    T.copy(col_idx, col_idx_f)
+
+                    for h_i in range(block_M // 2):
+                        q_pos = bx * block_M + vid * (block_M // 2) + h_i + past_len
+                        T.tile.compare(cmp_mask, col_idx_f, T.float32(q_pos), "LE")
+                        T.tile.select(acc_s_ub_[h_i, :], cmp_mask, acc_s_ub_[h_i, :], -T.infinity(accum_dtype), "VSEL_TENSOR_SCALAR_MODE")
+
+                T.tile.add(acc_s_ub, acc_s_ub, acc_s_ub_)
+                T.tile.mul(acc_s_ub, acc_s_ub, sm_scale)
+
+                T.reduce_max(acc_s_ub, m_i, dim=-1)
+                T.tile.max(m_i, m_i, m_i_prev)
+                T.tile.sub(m_i_prev, m_i_prev, m_i)
+                T.tile.exp(m_i_prev, m_i_prev)
+
+                for h_i in range(block_M // 2):
+                    T.tile.sub(acc_s_ub[h_i, :], acc_s_ub[h_i, :], m_i[h_i])
+
+                T.tile.exp(acc_s_ub, acc_s_ub)
+                T.reduce_sum(acc_s_ub, sumexp_i_ub, dim=-1)
+                T.tile.mul(sumexp, sumexp, m_i_prev)
+                T.tile.add(sumexp, sumexp, sumexp_i_ub)
+
+                for h_i in range(block_M // 2):
+                    T.tile.mul(acc_o[h_i, :], acc_o[h_i, :], m_i_prev[h_i])
+
+                T.copy(acc_s_ub, acc_s_half)
+                T.copy(acc_s_half, workspace_2[cid, vid * block_M // 2 : vid * block_M // 2 + block_M // 2, :])
+
+                T.copy(workspace_3[cid, vid * block_M // 2 : vid * block_M // 2 + block_M // 2, :], acc_o_ub)
+                T.tile.add(acc_o, acc_o, acc_o_ub)
+
+            # Final normalization: acc_o /= sumexp
+            for h_i in range(block_M // 2):
+                T.tile.div(acc_o[h_i, :], acc_o[h_i, :], sumexp[h_i])
+
+            T.copy(acc_o, acc_o_half)
+            T.copy(acc_o_half, Output[bz, by, bx * block_M + vid * block_M // 2 : bx * block_M + vid * block_M // 2 + block_M // 2, :])
+
+    return main
+
+
+def test_topk_sparse_attention():
+    BATCH, N_HEADS, SEQ_LEN, D_HEAD = 4, 2, 256, 64
+    TOPK = 2
+    BLOCK = 64
+    torch.manual_seed(0)
+
+    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="npu", dtype=torch.float16)
+    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="npu", dtype=torch.float16)
+    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="npu", dtype=torch.float16)
+
+    sm_scale = 1.0 / (D_HEAD**0.5)
+
+    downsample_factor = BLOCK
+    downsample_len = math.ceil(SEQ_LEN / downsample_factor)
+    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="npu", dtype=torch.float16)
+    x_ds[:, :, :, 0] = 100
+    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
+
+    kernel = blocksparse_flashattn_ascend(BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
+    torch.npu.synchronize()
+    tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))
+    torch.npu.synchronize()
+
+    # Reference with FULL element-level causal + block sparsity
+    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="npu"))
+    full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
+    full_mask = full_mask & torch.tril(torch.ones_like(full_mask))
+
+    attn = torch.einsum("bhsd,bhtd->bhst", q.float(), k.float()) * sm_scale
+    attn = attn.masked_fill(~full_mask, float("-inf"))
+    attn = F.softmax(attn, dim=-1)
+    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v.float()).to(torch.float16)
+
+    torch.testing.assert_close(tilelang_output, ref_output, atol=1e-2, rtol=1e-2)
+    print("Pass topk sparse attention test with qlen == klen")
+
+
+def test_topk_sparse_attention_qlen_lt_klen():
+    BATCH, N_HEADS = 1, 1
+    Q_LEN, K_LEN, D_HEAD = 128, 256, 64
+    TOPK = 1
+    BLOCK = 64
+    torch.manual_seed(0)
+
+    q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="npu", dtype=torch.float16)
+    k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="npu", dtype=torch.float16)
+    v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="npu", dtype=torch.float16)
+    sm_scale = 1.0 / (D_HEAD**0.5)
+
+    downsample_factor = BLOCK
+    downsample_len = math.ceil(K_LEN / downsample_factor)
+    x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="npu", dtype=torch.float16)
+    x_ds[:, :, :, 0] = 100
+    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
+
+    kernel = blocksparse_flashattn_ascend(BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True)
+    torch.npu.synchronize()
+    tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))
+    torch.npu.synchronize()
+
+    past_len = K_LEN - Q_LEN
+
+    attn = torch.einsum("bhsd,bhtd->bhst", q.float(), k.float()) * sm_scale
+
+    full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="npu")).bool()
+    full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]
+    effective_mask = full_mask_full[..., past_len:K_LEN, :]
+
+    i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1)
+    j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0)
+    causal_mask = j_global <= i_global
+
+    final_mask = effective_mask & causal_mask
+
+    attn = attn.masked_fill(~final_mask, float("-inf"))
+    attn = F.softmax(attn, dim=-1)
+    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v.float()).to(torch.float16)
+
+    torch.testing.assert_close(tilelang_output, ref_output, atol=1e-2, rtol=1e-2)
+    print("Pass topk sparse attention test with qlen < klen")
+
+
+def main():
+    test_topk_sparse_attention()
+    test_topk_sparse_attention_qlen_lt_klen()
+    print("Test passed!")
+
+
+if __name__ == "__main__":
+    main()
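
Note on the GEMM rules in the SKILL.md hunks: the non-divisibility check and the L0C capacity check can be sketched in plain Python. This is a minimal illustration, not code from this change. The names `L0C_CAPACITY_BYTES`, `fits_l0c`, `padded_dim`, and `check_gemm_tiling` are hypothetical, and the 128KB figure is the A2/A3 value quoted in the diff, assumed here rather than queried from hardware.

```python
# Hypothetical helper names; the capacity is the A2/A3 figure quoted in the diff.
L0C_CAPACITY_BYTES = 128 * 1024


def fits_l0c(block_m: int, block_n: int, accum_bytes: int = 4) -> bool:
    """Return True when a block_m x block_n accumulator tile fits in L0C."""
    return block_m * block_n * accum_bytes <= L0C_CAPACITY_BYTES


def padded_dim(size: int, block: int) -> int:
    """Round size up to the next multiple of block (the host-side padding target)."""
    return ((size + block - 1) // block) * block


def check_gemm_tiling(m: int, n: int, block_m: int, block_n: int, accum_bytes: int = 4) -> list:
    """Collect warnings for the two GEMM rules: divisibility and L0C capacity."""
    warnings = []
    if m % block_m != 0 or n % block_n != 0:
        pm, pn = padded_dim(m, block_m), padded_dim(n, block_n)
        warnings.append(f"non-divisible: pad to {pm}x{pn} on the host, crop back to {m}x{n}")
    if not fits_l0c(block_m, block_n, accum_bytes):
        warnings.append("L0C overflow: shrink block_M/block_N or split the tile")
    return warnings
```

With float32 accumulation, a 128 × 256 tile sits exactly at the 128KB limit, while 256 × 256 exceeds it. A shape with `M < block_M` is treated as non-divisible and padded up to one full block rather than launching zero blocks.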