Skip to content

GQA WS IntraWGOverlap: SASS analysis identifies FA3 split-accumulator as root cause of 10% gap #13

@superAngGao

Description

@superAngGao

Background

Issue #12 目标:将 TileOPs GQA WarpSpecialized kernel 的 IntraWGOverlap 性能从 ~598 TF/s (90% FA3) 提升至 FA3 水平 (~660 TF/s)。

经过数周迭代(delay_rescale、vwait_up、schedbar、anchor),通过详细的 SASS 逐指令对比,今天定位到了根本原因。


SASS 层差异(实测,H100 SXM5)

变体 wait<1>→wait<0> 间指令数 说明
FA3 163 条 包含完整 softmax + rescale(acc_o)
TileOPs anchor 16 条 仅 6 个 FMNMX + anchor store
TileOPs delay_rescale 10 条 仅几个 FMNMX
TileOPs scaleout 10 条 同上

目标是让 wait<1>→wait<0> 窗口内执行完整的 softmax+rescale(约 100 条指令),让 PV WGMMA 与之并行。


各层 pipeline 顺序对比

FA3(FlashAttention 3,CUTLASS 手写)

算法层:

Iter N:
  QK_N  commit  (WGMMA gsb0)
  PV_{N-1} commit  (WGMMA gsb0,使用上一iter的 P_{N-1}_raw)
  wait<1>                      → QK_N 完成,acc_s_N 可读
  reduce_max(acc_s_N)          → 找 max QK 分(FMNMX 树)
  rescale(acc_o_old, Δm)       → acc_o *= exp(m_{N-2} - m_{N-1})
  wait<0>                      → PV_{N-1} 完成,acc_pv_N-1 可读
  acc_pv /= sum_exp            → FMUL 归一化(标量除法)
  acc_o += acc_pv              → FADD 合并
  exp(acc_s_N), cast to FP16   → 为下一iter准备 P_N_raw

关键点:FA3 不在 QK 和 PV 之间等 V ready(无 barrier_wait(v_full))。 V 的 TMA load 由生产者 WG 异步完成,消费者无条件提交 PV WGMMA,数据就绪由 mbarrier 协议隐式保证。

关键的双缓冲寄存器设计(SASS 实测):

  • R112–R216:PV WGMMA 输出(专用临时缓冲 acc_pv
  • R13–R108:acc_o 主累积(rescale 读写对象)
  • 两个寄存器集完全不重叠 → ptxas 可以安全地将 wait<0> 推迟到 rescale 之后

TileOPs(当前实现)

算法层:

Iter N (n_idx ≥ 1):
  QK_N commit
  barrier_wait(v_full, N-1)    ← TileOPs 在此等 V 就绪
  PV_{N-1} commit(acc_o += acc_s_cast @ V,accumulate into acc_o)
  wait<1>
  softmax(acc_s_N) → ss
  wait<0>                      ← ptxas 提前插在这里,仅 10 条指令后
  rescale(acc_o, ss)           ← 只能在 wait<0> 后做(RAW 冲突)

RAW 冲突根本原因:
TileOPs 的 PV WGMMA 使用 clear_accum=False(ScaleOut::One),直接累积进 acc_oacc_o += P@V)。rescale 也读写同一组 acc_o 寄存器,形成 WAW/RAW 冲突。ptxas 无法将 wait<0> 推迟到 rescale 后面,只能在 PV commit 完成后立即插入 wait<0>。

所有之前的尝试(delay_rescale、schedbar、anchor 的 st.shared.u32 分支)都无法解决这个根本冲突——只要 rescale 和 PV 写同一组寄存器,ptxas 就必须在 rescale 前插入 wait<0>。


解决方案:Split Accumulator(本 PR 实现)

引入独立的 acc_pv fragment(与 acc_o 完全独立的寄存器集):

acc_pv_1 = T.alloc_fragment([half_m, D], accum_dtype)  # 新增

# 稳态循环:
T.wgmma_gemm(acc_s_cast_1, v, acc_pv_1, clear_accum=True)  # PV → acc_pv
# ... wait<1> ...
softmax(acc_s_1, ...)           # 读 acc_s,不读 acc_o / acc_pv
rescale(acc_o_1, ss_1)          # 只读写 acc_o(无冲突)
# ... wait<0> ...               # ptxas 可延迟到 rescale 之后
fence(acc_pv_1)
# acc_o += acc_pv(64 FADD,约 2 cycle)
for i, j in T.Parallel(half_m, D):
    acc_o_1[i, j] = acc_o_1[i, j] + acc_pv_1[i, j]

目标 SASS 顺序(预期与 FA3 一致):

PV commit (gsb0)
wait<1>
[~64 FMNMX: reduce_max]
[~31 FFMA: rescale acc_o]
wait<0>
[64 FADD: acc_o += acc_pv]

代价分析

项目 代价 收益
+64 FADD/iter ~2 cycle softmax+rescale (~100 insns) 与 PV 完全并行
额外寄存器压力 +64 FP32 regs 可能影响 occupancy(待测)
额外寄存器合并 epilogue 同上 无额外延迟

净收益预期:正数(100 cycle 节省 >> 2 cycle FADD)。


实验文件

  • /home/ga/TileOPs/_test_ws_fa3_v2_persistent_splitbuf.py — 本变体
  • /home/ga/tmp/sass_gqa/anchor2.sass — anchor SASS(wait<1>→wait<0> = 16 insns)
  • /home/ga/tmp/sass_gqa/fa3_target.sass — FA3 SASS(wait<1>→wait<0> = 163 insns)
  • /home/ga/TileOPs/_anchor_helper.htl::wait_wgmma_anchor<N> inline PTX 辅助

状态

  • SASS root-cause 分析完成
  • split-buf 变体代码编写完成
  • 正确性验证(运行中)
  • SASS 确认(wait<1>→wait<0> 窗口扩大)
  • 性能测试

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions