FA3-aligned WS GQA: working implementations and IntraWGOverlap roadmap #9

@superAngGao


Summary

Implemented FA3-aligned warp-specialized GQA attention kernels. Two variants work correctly; IntraWGOverlap is blocked by TileLang bugs (#8).

Working Implementations

1. LargeHeadDimV WS (D=64, Dv=512)

File: _test_ws_fa3_large_hdv.py

Strictly follows FA3's LargeHeadDimV path:

  • WG0 (producer): TMA load K/V
  • WG1 (mma): QK + softmax + PV first half (O[:, :256])
  • WG2 (mma_pv): read P from smem + PV second half (O[:, 256:])
  • P communicated WG1→WG2 via smem with PFull/PEmpty barriers
  • Rescale factors via smem_scale
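The split works numerically because both halves of O use the same P and the same per-block rescale factor. The following is a minimal pure-Python sketch of that idea for a single query row, not the kernel itself. The function names, the `split` parameter, and the omission of the softmax scale are assumptions made for illustration.

```python
import math

def attention_reference(q, K, V):
    """Plain softmax(q . K^T) @ V for one query row (no softmax scale, for brevity)."""
    s = [sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    denom = sum(p)
    return [sum(p[j] * V[j][c] for j in range(len(V))) / denom
            for c in range(len(V[0]))]

def attention_split_pv(q, K, V, block_n, split):
    """Online softmax over KV blocks with the output columns held in two halves.

    The "WG1" half owns QK + softmax and produces (p, scale) per block;
    the "WG2" half only consumes them, mirroring the P / smem_scale handoff.
    """
    dv = len(V[0])
    o_lo = [0.0] * split           # columns [0, split), the "WG1" accumulator
    o_hi = [0.0] * (dv - split)    # columns [split, dv), the "WG2" accumulator
    m_run, l_run = -math.inf, 0.0
    for start in range(0, len(K), block_n):
        Kb, Vb = K[start:start + block_n], V[start:start + block_n]
        # "WG1": scores, running max, rescale factor, probabilities
        s = [sum(a * b for a, b in zip(q, k)) for k in Kb]
        m_new = max(m_run, max(s))
        scale = math.exp(m_run - m_new)   # exp(-inf) is 0.0 on the first block
        p = [math.exp(x - m_new) for x in s]
        l_run = l_run * scale + sum(p)
        m_run = m_new
        # Both halves apply the same scale and the same p to their own columns
        for c in range(split):
            o_lo[c] = o_lo[c] * scale + sum(p[j] * Vb[j][c] for j in range(len(Vb)))
        for c in range(dv - split):
            o_hi[c] = o_hi[c] * scale + sum(p[j] * Vb[j][split + c] for j in range(len(Vb)))
    return [x / l_run for x in o_lo + o_hi]
```

With Dv=512 the split point would be 256; any small `split` exercises the same logic, and the result should match `attention_reference` up to floating-point rounding.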

Results (vs pipelined baseline):

| Config | Pipelined best | FA3-WS best | Speedup |
|---|---|---|---|
| B=2 S=2048 non-causal | 78 TFLOPS | 104 TFLOPS (bm64 bn128) | 1.33x |
| B=2 S=1024 non-causal | 69 TFLOPS | 87 TFLOPS | 1.27x |

2. Cooperative 2-WG WS (D=128, dim==dim_v)

File: _test_ws_fa3_cooperative.py

Both consumer WGs cooperate on the same 128-row tile:

  • WG0 (producer): TMA load K/V
  • WG1 (consumer): rows [0, 64) — QK + softmax + PV
  • WG2 (consumer): rows [64, 128) — QK + softmax + PV
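The role split above can be written as a small index calculation. This sketch assumes the warpgroup index is `tid // 128` (a warpgroup is 4 warps, i.e. 128 threads) and that the kernel launches three warpgroups in the order listed; `role_for_thread` is a hypothetical helper, not a TileLang API.

```python
WARPGROUP_SIZE = 128  # 4 warps x 32 threads

def role_for_thread(tid, block_m=128, num_consumers=2):
    """Map a thread index to (role, row_range) for the cooperative layout."""
    wg = tid // WARPGROUP_SIZE          # assumed thread-to-warpgroup mapping
    if wg == 0:
        return ("producer", None)       # WG0 only issues TMA loads for K/V
    rows_per_wg = block_m // num_consumers
    start = (wg - 1) * rows_per_wg      # WG1 -> 0, WG2 -> 64
    return ("consumer", (start, start + rows_per_wg))

for tid in (0, 128, 256):
    print(tid, role_for_thread(tid))
```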

Correctness: all D=128 cases pass (max_diff ≤ 0.002). D=64 fails due to smem aliasing (#8).

Results (D=128, official benchmark params):

| Config | Pipelined best | WS best | Ratio |
|---|---|---|---|
| llama8b-4k | 316 TFLOPS (bm128 bn64 thr128) | 172 TFLOPS | 0.55x |
| llama8b-8k | 396 TFLOPS | 202 TFLOPS | 0.51x |

The WS kernel reaches only about half the pipelined throughput because it lacks IntraWGOverlap: QK and PV run serially within each consumer WG. The pipelined kernel achieves that overlap via T.Pipelined order/stage/group.

IntraWGOverlap: What's Needed

FA3's key performance feature: overlap QK[n] with PV[n-1] within each consumer WG using async WGMMA.
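To make the intended ordering concrete, here is a simplified sketch that prints the issue order for a serial consumer loop and for an overlapped one. It is an illustration only: each "issue" stands for an async WGMMA batch committed as its own group, `wait(N)` stands for "block until at most N committed groups are still pending", and the O rescale step and barrier traffic are left out. The function names are hypothetical.

```python
def serial_schedule(num_blocks):
    """QK and PV run back to back; every GEMM is fully waited on before the next step."""
    steps = []
    for n in range(num_blocks):
        steps += [f"issue QK[{n}]", "wait(0)", f"softmax[{n}]",
                  f"issue PV[{n}]", "wait(0)"]
    return steps

def overlapped_schedule(num_blocks):
    """QK[n] is issued together with PV[n-1]; softmax[n] runs while PV[n-1] may still be in flight."""
    # Prologue: the first block has no previous PV to overlap with.
    steps = ["issue QK[0]", "wait(0)", "softmax[0]"]
    for n in range(1, num_blocks):
        steps += [
            f"issue QK[{n}]",
            f"issue PV[{n - 1}]",
            "wait(1)",            # QK[n] is done; PV[n-1] may still be running
            f"softmax[{n}]",
            "wait(0)",            # PV[n-1] is done; its P and V buffers can be reused
        ]
    # Epilogue: the last PV has nothing left to overlap with.
    steps += [f"issue PV[{num_blocks - 1}]", "wait(0)"]
    return steps

if __name__ == "__main__":
    for step in overlapped_schedule(3):
        print(step)
```

The point of the sketch is the `wait(1)` in the steady state: it is the one place where the consumer deliberately leaves a GEMM outstanding, which is exactly the control that the attempts below could not express.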

Attempted approaches and blockers:

  1. T.wgmma_gemm + T.wait_wgmma: TileLang emits fence_operand before wait_wgmma, but it should come after. In addition, T.wgmma_gemm uses RS mode for PV while T.gemm uses SS mode; these are different lowering paths, and they differ in correctness inside a T.ws context.

  2. Manual warpgroup primitives around T.gemm: T.gemm already includes its own fence/arrive/commit/wait, so adding manual primitives results in double synchronization.

  3. CUDA postproc: a regex pass over the generated CUDA fixed the fence/wait ordering, but the underlying RS-vs-SS path difference in T.wgmma_gemm still produces incorrect results.

Viable path forward: hand-write WGMMA

Extract the exact WGMMA instruction sequence from T.gemm's compiled CUDA and reimplement with T.ptx_wgmma_ss/T.ptx_wgmma_rs + manual fence/arrive/commit/wait ordering. This gives full control over:

  • Which WGMMA to wait for (wait count)
  • When to fence accumulators
  • Overlap scheduling between QK and PV

Key parameters extracted from compiled CUDA (D=128, block_n=64):

QK (SS): wgmma_ss<f16,f16,f32, 64,64,16, false,false>, 8 iters, desc <LBO=1, SBO=1, swizzle=64B>
PV (SS via stmatrix): wgmma_ss<f16,f16,f32, 64,128,16, false,true>, 4 iters, desc <LBO=1, SBO=512, swizzle=64B>
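The iteration counts are consistent with one WGMMA instruction covering 16 elements of the reduction dimension: QK reduces over D, PV reduces over block_n. A quick arithmetic check, assuming that reading of the shapes (`wgmma_tile_plan` is a hypothetical helper, and M=64 per warpgroup is taken from the shapes listed above):

```python
WGMMA_K_F16 = 16  # K extent of one f16 WGMMA, the "16" in the 64xNx16 shapes above

def wgmma_tile_plan(reduction_dim, n, m=64, k=WGMMA_K_F16):
    """Return the instruction shape and how many WGMMAs one GEMM needs."""
    if reduction_dim % k != 0:
        raise ValueError("reduction dim must be a multiple of the WGMMA K extent")
    return {"shape": (m, n, k), "iters": reduction_dim // k}

D, dim_v, block_n = 128, 128, 64

qk = wgmma_tile_plan(reduction_dim=D, n=block_n)       # S = Q @ K^T reduces over D
pv = wgmma_tile_plan(reduction_dim=block_n, n=dim_v)   # O += P @ V reduces over block_n

print(qk)  # {'shape': (64, 64, 16), 'iters': 8}
print(pv)  # {'shape': (64, 128, 16), 'iters': 4}
```

The same helper gives a starting estimate of the iteration counts for other D or block_n values before inspecting the compiled CUDA.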

Files

  • _test_ws_fa3_large_hdv.py — LargeHeadDimV kernel + tests
  • _bench_ws_fa3_large_hdv.py — LargeHeadDimV benchmarks
  • _test_ws_fa3_cooperative.py — Cooperative 2-WG kernel + pipelined baseline + tests
  • _bench_ws_fa3_coop.py — Cooperative benchmarks (official params)
  • _test_ws_fa3_overlap.py — IntraWGOverlap attempts (not working)
  • _test_ws_postproc_overlap.py — CUDA postproc approach (not working)
