Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions .agents/skills/tilelang-op-design/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ description: "根据算子需求生成 TileLang-Ascend 算子设计文档(desi
| **动态循环边界不支持** | 循环次数不能依赖 tensor 值(如 `batch_sizes[bz]`) | `T.Pipelined(batch_sizes[bz])` 报错 | 预计算最大循环次数,用 `T.serial(max_iters)` + 条件判断 |
| **流水线不支持动态边界** | `T.Pipelined` 的循环次数必须静态 | 动态批次无法流水线 | 改用 `T.serial` 或预计算固定迭代次数 |
| **部分 GPU API 不可用** | CUDA 专用 API 在 Ascend 不存在 | 直接移植 GPU 代码失败 | 查阅本项目 `examples/` 确认 Ascend API |
| **GEMM 要求 M,N 为 block 整数倍** | `M // block_M` 整除依赖;`M < block_M` 时零 block 启动 | 输出全零或除零编译崩溃 | 设计文档 §4/§5 必须明确处理策略:host 侧 padding+crop 或 Kernel 动态 block |
| **L0C 容量上限** | A2/A3 设备 L0C = 128KB | `block_M × block_N × sizeof(accum) > 128KB` 导致 segfault | 设计 block 时满足 `block_M × block_N ≤ 16384`(float32 accum) |

### 2.5.2 强制检测规则

Expand All @@ -80,6 +82,8 @@ description: "根据算子需求生成 TileLang-Ascend 算子设计文档(desi
| threads 参数 | 参考实现 threads > 2 | **立即警告**,建议 threads=2 或移除 |
| 动态循环边界 | 循环边界依赖 tensor 值 | **立即警告**,提出静态边界 + 条件判断方案 |
| GPU 专用 API | CUDA 相关 API(如 `T.gemm` 通用版) | **立即警告**,查阅本项目确认 Ascend API |
| GEMM 非整除风险 | `M` 或 `N` 不被 block size 整除(即 `M % block_M ≠ 0` 或 `N % block_N ≠ 0`) | **立即警告**,要求 design 中明确 padding 策略 |
| L0C 溢出风险 | block_M × block_N × sizeof(accum_dtype) > 131072 (128KB) | **立即警告**,建议减小 block 或拆分 |

### 2.5.3 警告输出格式

Expand Down Expand Up @@ -118,11 +122,13 @@ description: "根据算子需求生成 TileLang-Ascend 算子设计文档(desi
- 纯 Vector(element-wise / reduction)→ 仅需 UB
- 纯 Cube(仅 matmul)→ 需要 L1 + L0A/L0B/L0C
- 混合(matmul + element-wise 后处理)→ 核间流水线,需要 CV 融合
- **Host 预处理**:如 im2col 等 Python 侧预处理步骤,标明在 design 的 §1 和 §4 中
- **复杂度级别**:
- 单步(如 element-wise add)→ 无循环、单次搬运
- 多步(如 softmax = max + sub + exp + sum + div)→ 多次计算、可能需要中间缓冲
- 融合(如 flash attention = GEMM + softmax + GEMM)→ 核间协作、流水线
- **动态 shape 判定**:是否存在运行时才确定的维度
4. **非整除场景预判**:检查输入 shape 是否可能不被 block size 整除。GEMM 类算子的 `M // block_M` 和 `N // block_N` 在 `M < block_M` 或 `N < block_N` 时产生零 block 或不完整 tile,必须在设计中明确处理策略(host 侧 zero-padding + crop,或 Kernel 内动态 block size)

### Phase 2:信息收集

Expand Down Expand Up @@ -168,10 +174,10 @@ grep "T.Scope\|T.barrier" examples/{同类实现} # 同步方式
2. 编程模式选型
3. API 映射设计
4. 数据规格与内存规划
5. Tiling 策略
5. Tiling 策略(**必含:非整除时 padding+crop 策略,或 Kernel 内动态 block 方案**)
6. 循环与调度结构
7. 同步策略
8. CV 融合设计(如有)
8. CV 融合设计
9. 验证方案
10. 风险点与注意事项
11. 交付清单
Expand All @@ -192,7 +198,12 @@ grep "T.Scope\|T.barrier" examples/{同类实现} # 同步方式

## 4. 算子特征分析决策树(修订版)

### 4.0 平台识别
### 4.0 函数设计原则

1. **维度参数自推导**:算子调用函数(如 `conv_im2col_gemm`)应从输入 tensor shape 提取 B/C/H/W 等维度,不依赖模块级全局变量。这保证多场景顺序测试时不发生变量污染。
2. **Host 预处理显式声明**:若计算的一部分在 Python 侧完成(如 im2col),必须在 §1 算法描述和 §4 数据流中明确标注。

### 4.1 平台识别

**本项目为 TileLang-Ascend(昇腾 NPU)**,与 GPU 版 TileLang 有差异:

Expand Down Expand Up @@ -283,16 +294,19 @@ grep "T.Scope\|T.barrier" examples/{同类实现} # 同步方式
| 3 | **内存搬运路径完整**:从 GM 到计算再到 GM 的每一步都有说明 | ✅ 必须 |
| 4 | **Tiling 策略有约束分析**:解释了为什么选择该 Block/Tile 大小 | ⭕ 推荐 |
| 5 | **同步策略与编程模式匹配**:Developer 用自动同步、Expert 标明手动同步点 | ⭕ 推荐 |
| 6 | **验证方案覆盖典型配置**:不是「待补充」 | ⭕ 推荐 |
| 6 | **验证方案覆盖 4 类典型配置**:完美对齐 + 单维 padding + 全维 padding + 多 block(GEMM 类必含),不是「待补充」 | ⭕ 推荐 |
| 7 | **无占位符或模糊描述**:无 `{placeholder}`、TODO、「待补充」(已确认的除外) | ✅ 必须 |
| 8 | **技术约束已确认**:三维 Kernel、threads、动态边界等问题已处理 | ✅ 必须 |
| 9 | **本项目同类实现已列出**:有具体的 examples/ 文件路径参考 | ✅ 必须 |
| 10 | **参考实现差异已说明**:如果有外部参考,列出 API/结构差异 | ⭕ 推荐 |
| 11 | **参考实现分析完整**(如有参考实现):记录了内存层级 API、同步策略、pass_configs 等技术决策 | ⭕ 推荐 |
| 12 | **CV 融合设计完整**(如需):workspace 规格、数据流、pass_configs | ⭕ 推荐 |
| 13 | **workspace_idx 配置正确**(如需 CV 融合):与 workspace 参数位置一致 | ✅ 必须 |
| 14 | **非整除处理策略明确**(GEMM 类必含):主机侧 padding+crop 或 Kernel 内动态 block,说明溢出/下溢处理 | ✅ 必须 |
| 15 | **L0C 容量约束已验证**(GEMM 类必含):`block_M × block_N × sizeof(accum_dtype) ≤ L0C_capacity (128KB)` | ⭕ 推荐 |
| 16 | **函数无全局变量依赖**:维度参数从 tensor shape 或函数参数获取,支持多场景顺序测试 | ⭕ 推荐 |

**通过条件**:必须项(1, 2, 3, 7, 8, 9)全部通过,推荐项(4, 5, 6, 10)至少通过 3/4
**通过条件**:必须项(1, 2, 3, 7, 8, 9, 14)全部通过,推荐项至少通过 4/9
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The list of mandatory items is missing item 14 (which was just added) and item 13 (which is marked as mandatory on line 304). Please update the pass criteria to include all required checks.

Suggested change
**通过条件**:必须项(1, 2, 3, 7, 8, 9, 14)全部通过,推荐项至少通过 4/9。
**通过条件**:必须项(1, 2, 3, 7, 8, 9, 13, 14)全部通过,推荐项至少通过 4/9。


---

Expand Down Expand Up @@ -352,11 +366,14 @@ grep "T.Scope\|T.barrier" examples/{同类实现} # 同步方式
3. 内存搬运完整性: ✅ / ❌
4. Tiling 约束分析: ✅ / ❌
5. 同步策略匹配: ✅ / ❌
6. 验证方案覆盖: ✅ / ❌
6. 验证方案覆盖(4 类): ✅ / ❌
7. 无占位符: ✅ / ❌
8. 技术约束确认: ✅ / ❌
9. 本项目同类实现列出: ✅ / ❌
10. 参考实现差异说明: ✅ / ❌ / N/A
11. 非整除处理策略: ✅ / ❌ / N/A
12. L0C 容量约束: ✅ / ❌ / N/A
13. 无全局变量依赖: ✅ / ❌

### 待确认项
- {列出需要用户进一步确认的内容}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,16 @@ TypeError: get_configs() missing 1 required positional argument: 'K'

**现象**: autotune 编译通过但 benchmark 时进程直接 crash(Segfault),无 Python 异常。

**原因**: `block_M * block_N * sizeof(accum_dtype) > L0C_capacity`(A2/A3 设备 L0C=64KB,float32 cum元素≤16384)。如 `block_M=256, block_N=256 → 256KB > 64KB`。
**Segfault类似问题排查建议**: Segment fault 需要通过 gdb 等工具定位具体 crash 位置和调用栈,再结合 kernel 配置、访存范围、片上内存使用量等因素判断根因。
可能原因之一:当 `block_M * block_N * sizeof(accum_dtype) > L0C_capacity` 时,可能导致片上 buffer 使用超过硬件限制。例如 A2/A3 设备 L0C 为 128KB,float32 accum 元素数不应超过 32768;

**解决方案**: autotune 的 `get_configs` 中过滤超大 block:
```python
block_M = [bs for bs in [64, 128] if bs <= M] # 排除 256
```



## 运行时错误

### 1. 结果不正确
Expand Down
260 changes: 260 additions & 0 deletions examples/seer_attention/block_sparse_attn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
import math
import torch

import tilelang
import tilelang.language as T
import torch.nn.functional as F


def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
dense_mask.tril_()
return dense_mask


def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
dense_mask = x > threshold
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
dense_mask.tril_()
return dense_mask


pass_configs = {
tilelang.PassConfigKey.TL_ASCEND_AUTO_CV_COMBINE: True,
tilelang.PassConfigKey.TL_ASCEND_AUTO_CV_SYNC: True,
tilelang.PassConfigKey.TL_ASCEND_AUTO_SYNC: True,
tilelang.PassConfigKey.TL_ASCEND_MEMORY_PLANNING: True,
}


@tilelang.jit(out_idx=[4], workspace_idx=[5, 6, 7], pass_configs=pass_configs)
def blocksparse_flashattn_ascend(batch, heads, seq_q, seq_kv, dim, downsample_len, is_causal):
block_M = 64
block_N = 64
dtype = "float16"
accum_dtype = "float"
block_mask_dtype = "int8"
sm_scale = (1.0 / dim) ** 0.5

q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
block_mask_shape = [batch, heads, downsample_len, downsample_len]
block_num = ((seq_q + block_M - 1) // block_M) * heads * batch

@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
Output: T.Tensor(q_shape, dtype),
workspace_1: T.Tensor([block_num, block_M, block_N], accum_dtype),
workspace_2: T.Tensor([block_num, block_M, block_N], dtype),
workspace_3: T.Tensor([block_num, block_M, dim], accum_dtype),
):
with T.Kernel(block_num, is_npu=True) as (cid, vid):
bx = cid % ((seq_q + block_M - 1) // block_M)
by = cid // ((seq_q + block_M - 1) // block_M) % heads
bz = cid // ((seq_q + block_M - 1) // block_M) // heads % batch

# L1 buffers (Cube core)
q_l1 = T.alloc_L1([block_M, dim], dtype)
k_l1 = T.alloc_L1([block_N, dim], dtype)
v_l1 = T.alloc_L1([block_N, dim], dtype)
acc_s_l1 = T.alloc_L1([block_M, block_N], dtype)
acc_s_l0c = T.alloc_L0C([block_M, block_N], accum_dtype)
acc_o_l0c = T.alloc_L0C([block_M, dim], accum_dtype)

# UB buffers (Vector core) for online softmax
acc_o = T.alloc_ub([block_M // 2, dim], accum_dtype)
sumexp = T.alloc_ub([block_M // 2], accum_dtype)
m_i = T.alloc_ub([block_M // 2], accum_dtype)
acc_s_ub = T.alloc_ub([block_M // 2, block_N], accum_dtype)
m_i_prev = T.alloc_ub([block_M // 2], accum_dtype)
acc_s_ub_ = T.alloc_ub([block_M // 2, block_N], accum_dtype)
sumexp_i_ub = T.alloc_ub([block_M // 2], accum_dtype)
acc_s_half = T.alloc_ub([block_M // 2, block_N], dtype)
acc_o_ub = T.alloc_ub([block_M // 2, dim], accum_dtype)
acc_o_half = T.alloc_ub([block_M // 2, dim], dtype)

# UB buffers for block-sparse mask and causal mask
mask_row = T.alloc_ub([downsample_len], block_mask_dtype)
col_idx = T.alloc_ub([block_N], "int32")
col_idx_f = T.alloc_ub([block_N], "float")
cmp_mask = T.alloc_ub([block_N // 8], "uint8")

# === Cube: Load Q, run GEMM pipeline for all KV blocks ===
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], q_l1)

loop_range = (seq_kv + block_N - 1) // block_N
for k in T.serial(loop_range):
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], k_l1)
T.gemm_v0(q_l1, k_l1, acc_s_l0c, transpose_B=True, init=True)
T.copy(acc_s_l0c, workspace_1[cid, :, :])

T.copy(workspace_2[cid, :, :], acc_s_l1)
T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], v_l1)
T.gemm_v0(acc_s_l1, v_l1, acc_o_l0c, init=True)
T.copy(acc_o_l0c, workspace_3[cid, :, :])

# === Vector: Online softmax + output accumulation ===
T.tile.fill(acc_o, 0.0)
T.tile.fill(sumexp, 0.0)
T.tile.fill(m_i, -T.infinity(accum_dtype))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using T.infinity might lead to compatibility issues or unexpected behavior on some NPU compiler versions. It is safer and more consistent with other examples in this repository (e.g., flash_attn_bhsd_cc_sync.py) to use a large concrete negative value like -2**30 for masking purposes.

Suggested change
T.tile.fill(m_i, -T.infinity(accum_dtype))
T.tile.fill(m_i, -2**30)


T.copy(BlockSparseMask[bz, by, bx, :], mask_row)

past_len = seq_kv - seq_q

for _k in T.serial(loop_range):
T.tile.fill(acc_s_ub, 0.0)
T.copy(m_i, m_i_prev)

# Load S = Q@K^T from Cube workspace
T.copy(workspace_1[cid, vid * block_M // 2 : vid * block_M // 2 + block_M // 2, :], acc_s_ub_)

# Block sparsity mask
if mask_row[_k] == 0:
T.tile.fill(acc_s_ub_, -T.infinity(accum_dtype))

# Element-level causal mask via compare+select (float32)
elif is_causal:
T.tile.createvecindex(col_idx, _k * block_N)
T.copy(col_idx, col_idx_f)

for h_i in range(block_M // 2):
q_pos = bx * block_M + vid * (block_M // 2) + h_i + past_len
T.tile.compare(cmp_mask, col_idx_f, T.float32(q_pos), "LE")
T.tile.select(acc_s_ub_[h_i, :], cmp_mask, acc_s_ub_[h_i, :], -T.infinity(accum_dtype), "VSEL_TENSOR_SCALAR_MODE")

T.tile.add(acc_s_ub, acc_s_ub, acc_s_ub_)
T.tile.mul(acc_s_ub, acc_s_ub, sm_scale)

T.reduce_max(acc_s_ub, m_i, dim=-1)
T.tile.max(m_i, m_i, m_i_prev)
T.tile.sub(m_i_prev, m_i_prev, m_i)
T.tile.exp(m_i_prev, m_i_prev)

for h_i in range(block_M // 2):
T.tile.sub(acc_s_ub[h_i, :], acc_s_ub[h_i, :], m_i[h_i])

T.tile.exp(acc_s_ub, acc_s_ub)
T.reduce_sum(acc_s_ub, sumexp_i_ub, dim=-1)
T.tile.mul(sumexp, sumexp, m_i_prev)
T.tile.add(sumexp, sumexp, sumexp_i_ub)

for h_i in range(block_M // 2):
T.tile.mul(acc_o[h_i, :], acc_o[h_i, :], m_i_prev[h_i])

T.copy(acc_s_ub, acc_s_half)
T.copy(acc_s_half, workspace_2[cid, vid * block_M // 2 : vid * block_M // 2 + block_M // 2, :])

T.copy(workspace_3[cid, vid * block_M // 2 : vid * block_M // 2 + block_M // 2, :], acc_o_ub)
T.tile.add(acc_o, acc_o, acc_o_ub)

# Final normalization: acc_o /= sumexp
for h_i in range(block_M // 2):
T.tile.div(acc_o[h_i, :], acc_o[h_i, :], sumexp[h_i])

T.copy(acc_o, acc_o_half)
T.copy(acc_o_half, Output[bz, by, bx * block_M + vid * block_M // 2 : bx * block_M + vid * block_M // 2 + block_M // 2, :])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This T.copy operation poses an out-of-bounds (OOB) risk if seq_q is not a multiple of block_M. For example, if seq_q=100 and block_M=64, the second block will attempt to write to indices 64:128, exceeding the tensor boundary. This violates the rule added in SKILL.md (Item 14) requiring explicit handling of non-divisible shapes. You should either implement a tail-handling strategy (like using T.min for the slice end or validRow logic) or ensure the input is padded on the host side.


return main


def test_topk_sparse_attention():
BATCH, N_HEADS, SEQ_LEN, D_HEAD = 4, 2, 256, 64
TOPK = 2
BLOCK = 64
torch.manual_seed(0)

q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="npu", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="npu", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="npu", dtype=torch.float16)

sm_scale = 1.0 / (D_HEAD**0.5)

downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="npu", dtype=torch.float16)
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

kernel = blocksparse_flashattn_ascend(BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
torch.npu.synchronize()
tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))
torch.npu.synchronize()

# Reference with FULL element-level causal + block sparsity
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="npu"))
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
full_mask = full_mask & torch.tril(torch.ones_like(full_mask))

attn = torch.einsum("bhsd,bhtd->bhst", q.float(), k.float()) * sm_scale
attn = attn.masked_fill(~full_mask, float("-inf"))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v.float()).to(torch.float16)

torch.testing.assert_close(tilelang_output, ref_output, atol=1e-2, rtol=1e-2)
print("Pass topk sparse attention test with qlen == klen")


def test_topk_sparse_attention_qlen_lt_klen():
BATCH, N_HEADS = 1, 1
Q_LEN, K_LEN, D_HEAD = 128, 256, 64
TOPK = 1
BLOCK = 64
torch.manual_seed(0)

q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="npu", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="npu", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="npu", dtype=torch.float16)
sm_scale = 1.0 / (D_HEAD**0.5)

downsample_factor = BLOCK
downsample_len = math.ceil(K_LEN / downsample_factor)
x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="npu", dtype=torch.float16)
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

kernel = blocksparse_flashattn_ascend(BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True)
torch.npu.synchronize()
tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))
torch.npu.synchronize()

past_len = K_LEN - Q_LEN

attn = torch.einsum("bhsd,bhtd->bhst", q.float(), k.float()) * sm_scale

full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="npu")).bool()
full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]
effective_mask = full_mask_full[..., past_len:K_LEN, :]

i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1)
j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0)
causal_mask = j_global <= i_global

final_mask = effective_mask & causal_mask

attn = attn.masked_fill(~final_mask, float("-inf"))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v.float()).to(torch.float16)

torch.testing.assert_close(tilelang_output, ref_output, atol=1e-2, rtol=1e-2)
print("Pass topk sparse attention test with qlen < klen")


def main():
test_topk_sparse_attention()
test_topk_sparse_attention_qlen_lt_klen()
print("Test passed!")


if __name__ == "__main__":
main()
Loading