Background
Issue #12 目标:将 TileOPs GQA WarpSpecialized kernel 的 IntraWGOverlap 性能从 ~598 TF/s (90% FA3) 提升至 FA3 水平 (~660 TF/s)。
经过数周迭代(delay_rescale、vwait_up、schedbar、anchor),通过详细的 SASS 逐指令对比,今天定位到了根本原因。
SASS 层差异(实测,H100 SXM5)
| 变体 |
wait<1>→wait<0> 间指令数 |
说明 |
| FA3 |
163 条 |
包含完整 softmax + rescale(acc_o) |
| TileOPs anchor |
16 条 |
仅 6 个 FMNMX + anchor store |
| TileOPs delay_rescale |
10 条 |
仅几个 FMNMX |
| TileOPs scaleout |
10 条 |
同上 |
目标是让 wait<1>→wait<0> 窗口内执行完整的 softmax+rescale(约 100 条指令),让 PV WGMMA 与之并行。
各层 pipeline 顺序对比
FA3(FlashAttention 3,CUTLASS 手写)
算法层:
Iter N:
QK_N commit (WGMMA gsb0)
PV_{N-1} commit (WGMMA gsb0,使用上一iter的 P_{N-1}_raw)
wait<1> → QK_N 完成,acc_s_N 可读
reduce_max(acc_s_N) → 找 max QK 分(FMNMX 树)
rescale(acc_o_old, Δm) → acc_o *= exp(m_{N-2} - m_{N-1})
wait<0> → PV_{N-1} 完成,acc_pv_N-1 可读
acc_pv /= sum_exp → FMUL 归一化(标量除法)
acc_o += acc_pv → FADD 合并
exp(acc_s_N), cast to FP16 → 为下一iter准备 P_N_raw
关键点:FA3 不在 QK 和 PV 之间等 V ready(无 barrier_wait(v_full))。 V 的 TMA load 由生产者 WG 异步完成,消费者无条件提交 PV WGMMA,数据就绪由 mbarrier 协议隐式保证。
关键的双缓冲寄存器设计(SASS 实测):
R112–R216:PV WGMMA 输出(专用临时缓冲 acc_pv)
R13–R108:acc_o 主累积(rescale 读写对象)
- 两个寄存器集完全不重叠 → ptxas 可以安全地将 wait<0> 推迟到 rescale 之后
TileOPs(当前实现)
算法层:
Iter N (n_idx ≥ 1):
QK_N commit
barrier_wait(v_full, N-1) ← TileOPs 在此等 V 就绪
PV_{N-1} commit(acc_o += acc_s_cast @ V,accumulate into acc_o)
wait<1>
softmax(acc_s_N) → ss
wait<0> ← ptxas 提前插在这里,仅 10 条指令后
rescale(acc_o, ss) ← 只能在 wait<0> 后做(RAW 冲突)
RAW 冲突根本原因:
TileOPs 的 PV WGMMA 使用 clear_accum=False(ScaleOut::One),直接累积进 acc_o(acc_o += P@V)。rescale 也读写同一组 acc_o 寄存器,形成 WAW/RAW 冲突。ptxas 无法将 wait<0> 推迟到 rescale 后面,只能在 PV commit 完成后立即插入 wait<0>。
所有之前的尝试(delay_rescale、schedbar、anchor 的 st.shared.u32 分支)都无法解决这个根本冲突——只要 rescale 和 PV 写同一组寄存器,ptxas 就必须在 rescale 前插入 wait<0>。
解决方案:Split Accumulator(本 PR 实现)
引入独立的 acc_pv fragment(与 acc_o 完全独立的寄存器集):
acc_pv_1 = T.alloc_fragment([half_m, D], accum_dtype) # 新增
# 稳态循环:
T.wgmma_gemm(acc_s_cast_1, v, acc_pv_1, clear_accum=True) # PV → acc_pv
# ... wait<1> ...
softmax(acc_s_1, ...) # 读 acc_s,不读 acc_o / acc_pv
rescale(acc_o_1, ss_1) # 只读写 acc_o(无冲突)
# ... wait<0> ... # ptxas 可延迟到 rescale 之后
fence(acc_pv_1)
# acc_o += acc_pv(64 FADD,约 2 cycle)
for i, j in T.Parallel(half_m, D):
acc_o_1[i, j] = acc_o_1[i, j] + acc_pv_1[i, j]
目标 SASS 顺序(预期与 FA3 一致):
PV commit (gsb0)
wait<1>
[~64 FMNMX: reduce_max]
[~31 FFMA: rescale acc_o]
wait<0>
[64 FADD: acc_o += acc_pv]
代价分析
| 项目 |
代价 |
收益 |
| +64 FADD/iter |
~2 cycle |
softmax+rescale (~100 insns) 与 PV 完全并行 |
| 额外寄存器压力 |
+64 FP32 regs |
可能影响 occupancy(待测) |
| 额外寄存器合并 epilogue |
同上 |
无额外延迟 |
净收益预期:正数(100 cycle 节省 >> 2 cycle FADD)。
实验文件
/home/ga/TileOPs/_test_ws_fa3_v2_persistent_splitbuf.py — 本变体
/home/ga/tmp/sass_gqa/anchor2.sass — anchor SASS(wait<1>→wait<0> = 16 insns)
/home/ga/tmp/sass_gqa/fa3_target.sass — FA3 SASS(wait<1>→wait<0> = 163 insns)
/home/ga/TileOPs/_anchor_helper.h — tl::wait_wgmma_anchor<N> inline PTX 辅助
状态
Background
Issue #12 目标:将 TileOPs GQA WarpSpecialized kernel 的 IntraWGOverlap 性能从 ~598 TF/s (90% FA3) 提升至 FA3 水平 (~660 TF/s)。
经过数周迭代(delay_rescale、vwait_up、schedbar、anchor),通过详细的 SASS 逐指令对比,今天定位到了根本原因。
SASS 层差异(实测,H100 SXM5)
目标是让 wait<1>→wait<0> 窗口内执行完整的 softmax+rescale(约 100 条指令),让 PV WGMMA 与之并行。
各层 pipeline 顺序对比
FA3(FlashAttention 3,CUTLASS 手写)
算法层:
关键点:FA3 不在 QK 和 PV 之间等 V ready(无
barrier_wait(v_full))。 V 的 TMA load 由生产者 WG 异步完成,消费者无条件提交 PV WGMMA,数据就绪由 mbarrier 协议隐式保证。关键的双缓冲寄存器设计(SASS 实测):
R112–R216:PV WGMMA 输出(专用临时缓冲acc_pv)R13–R108:acc_o 主累积(rescale 读写对象)TileOPs(当前实现)
算法层:
RAW 冲突根本原因:
TileOPs 的 PV WGMMA 使用
clear_accum=False(ScaleOut::One),直接累积进 acc_o(acc_o += P@V)。rescale 也读写同一组 acc_o 寄存器,形成 WAW/RAW 冲突。ptxas 无法将 wait<0> 推迟到 rescale 后面,只能在 PV commit 完成后立即插入 wait<0>。所有之前的尝试(delay_rescale、schedbar、anchor 的
st.shared.u32分支)都无法解决这个根本冲突——只要 rescale 和 PV 写同一组寄存器,ptxas 就必须在 rescale 前插入 wait<0>。解决方案:Split Accumulator(本 PR 实现)
引入独立的
acc_pvfragment(与acc_o完全独立的寄存器集):目标 SASS 顺序(预期与 FA3 一致):
代价分析
净收益预期:正数(100 cycle 节省 >> 2 cycle FADD)。
实验文件
/home/ga/TileOPs/_test_ws_fa3_v2_persistent_splitbuf.py— 本变体/home/ga/tmp/sass_gqa/anchor2.sass— anchor SASS(wait<1>→wait<0> = 16 insns)/home/ga/tmp/sass_gqa/fa3_target.sass— FA3 SASS(wait<1>→wait<0> = 163 insns)/home/ga/TileOPs/_anchor_helper.h—tl::wait_wgmma_anchor<N>inline PTX 辅助状态