Summary
Implemented FA3-aligned warp-specialized (WS) GQA attention kernels. Two variants work correctly; the third, IntraWGOverlap, is blocked by TileLang bugs (#8).
Working Implementations
1. LargeHeadDimV WS (D=64, Dv=512)
File: _test_ws_fa3_large_hdv.py
Strictly follows FA3's LargeHeadDimV path:
- WG0 (producer): TMA load K/V
- WG1 (mma): QK + softmax + PV first half (O[:, :256])
- WG2 (mma_pv): read P from smem + PV second half (O[:, 256:])
- P is passed from WG1 to WG2 through smem, guarded by PFull/PEmpty barriers
- Softmax rescale factors are passed to WG2 through smem_scale
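The column split above can be checked numerically with a scalar Python model: "WG1" runs the online softmax and accumulates the first half of O's columns, then publishes P and the rescale factor; "WG2" only uses those published values for the second half. This is a sketch of the data flow, not of the kernel: the smem slots are plain variables, the PFull/PEmpty handshake is reduced to comments describing what those barriers presumably signal, and it assumes the final row sum l is also made available to WG2 for normalization (the report does not say how the kernel does this). The function names are hypothetical.

```python
import math
import random


def attention_ref(q, K, V):
    """Reference softmax(q . k / sqrt(d)) @ V for a single query row."""
    scale = 1.0 / math.sqrt(len(q))
    s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    l = sum(p)
    return [sum(p[j] * V[j][c] for j in range(len(V))) / l for c in range(len(V[0]))]


def large_hdv_split(q, K, V, block_n):
    """Model of the WG1/WG2 column split of O for one query row."""
    dv = len(V[0])
    half = dv // 2
    scale = 1.0 / math.sqrt(len(q))
    m, l = -math.inf, 0.0
    o_lo = [0.0] * half          # WG1 accumulator: O[:, :half]
    o_hi = [0.0] * (dv - half)   # WG2 accumulator: O[:, half:]
    for start in range(0, len(K), block_n):
        Kb, Vb = K[start:start + block_n], V[start:start + block_n]
        # WG1: QK + online softmax for this KV block.
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in Kb]
        m_new = max(m, max(s))
        alpha = math.exp(m - m_new)  # 0.0 on the first block, since m = -inf
        p = [math.exp(x - m_new) for x in s]
        l = l * alpha + sum(p)
        m = m_new
        # WG1 publishes P and the rescale factor (in the kernel: signal PFull).
        smem_p, smem_scale = p, alpha
        # WG1: rescale and accumulate its half of PV.
        o_lo = [o * alpha + sum(p[j] * Vb[j][c] for j in range(len(Vb)))
                for c, o in enumerate(o_lo)]
        # WG2: uses only the published values for the other half
        # (in the kernel: signal PEmpty afterwards so WG1 may overwrite P).
        o_hi = [o * smem_scale + sum(smem_p[j] * Vb[j][half + c] for j in range(len(Vb)))
                for c, o in enumerate(o_hi)]
    # Assumption: the final row sum l is shared with WG2 for normalization.
    return [x / l for x in o_lo + o_hi]


random.seed(0)
q = [random.uniform(-1, 1) for _ in range(4)]
K = [[random.uniform(-1, 1) for _ in range(4)] for _ in range(10)]
V = [[random.uniform(-1, 1) for _ in range(8)] for _ in range(10)]
out = large_hdv_split(q, K, V, block_n=4)
ref = attention_ref(q, K, V)
print(max(abs(a - b) for a, b in zip(out, ref)))  # only floating-point rounding
```

The key property is that WG2 never needs S or the running max: the per-block P and alpha are sufficient to keep its half of O consistent with WG1's.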
Results (vs pipelined baseline):
| Config | Pipelined best | FA3-WS best | Speedup |
|---|---|---|---|
| B=2 S=2048 non-causal | 78 TFLOPS | 104 TFLOPS (bm64 bn128) | 1.33x |
| B=2 S=1024 non-causal | 69 TFLOPS | 87 TFLOPS | 1.27x |
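The TFLOPS figures depend on how FLOPs are counted, which this report does not state. A common convention for attention counts only the two matmuls per head, 2·Sq·Sk·D for QKᵀ and 2·Sq·Sk·Dv for PV, and halves the total for causal masking. The helper below is a sketch under that assumption; the head count in the example is hypothetical, since the benchmark's H is not given here.

```python
def attention_flops(batch, heads, seqlen_q, seqlen_k, dim, dim_v, causal=False):
    """Matmul FLOPs of one attention forward pass (softmax ignored).

    `heads` is the number of query heads; under GQA, sharing K/V heads
    changes memory traffic but not this count.
    """
    qk = 2 * seqlen_q * seqlen_k * dim      # S = Q K^T
    pv = 2 * seqlen_q * seqlen_k * dim_v    # O = P V
    total = batch * heads * (qk + pv)
    return total / 2 if causal else total


def tflops(flops, elapsed_ms):
    """Convert a FLOP count and a kernel time in milliseconds to TFLOPS."""
    return flops / (elapsed_ms * 1e-3) / 1e12


# LargeHeadDimV shape from the table above, with a hypothetical head count of 32.
f = attention_flops(batch=2, heads=32, seqlen_q=2048, seqlen_k=2048, dim=64, dim_v=512)
print(f"{f / 1e9:.1f} GFLOP per call")
```

With D=64 and Dv=512, the PV matmul accounts for 8/9 of the counted work, which is why this shape splits PV (rather than QK) across two warpgroups.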
2. Cooperative 2-WG WS (D=128, dim==dim_v)
File: _test_ws_fa3_cooperative.py
Both consumer WGs cooperate on the same 128-row tile:
- WG0 (producer): TMA load K/V
- WG1 (consumer): rows [0, 64) — QK + softmax + PV
- WG2 (consumer): rows [64, 128) — QK + softmax + PV
Correctness: all D=128 cases pass (max_diff ≤ 0.002). D=64 fails due to an smem aliasing issue (#8).
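Since a warpgroup is four warps (128 threads) and each consumer's WGMMA tile has M=64, a thread's role and row range follow directly from its warpgroup index. A minimal sketch of that mapping, assuming 3 warpgroups (384 threads) with WG0 as producer as listed above; `role_and_rows` is a hypothetical name:

```python
WARPGROUP_SIZE = 128  # 4 warps x 32 threads


def role_and_rows(tid, block_m=128, num_consumers=2):
    """Return (role, (row_start, row_end)) for thread `tid` of the CTA."""
    wg = tid // WARPGROUP_SIZE
    if wg == 0:
        return "producer", None          # WG0 only issues TMA loads of K/V
    rows_per_wg = block_m // num_consumers
    start = (wg - 1) * rows_per_wg
    return "consumer", (start, start + rows_per_wg)


for tid in (0, 127, 128, 255, 256, 383):
    print(tid, role_and_rows(tid))
```

Both consumers read every K/V block the producer loads, so each block is fetched once per 128-row tile rather than once per 64 rows; softmax is row-wise, so the two consumers never need to exchange data.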
Results (D=128, official benchmark params):
| Config | Pipelined best | WS best | Ratio |
|---|---|---|---|
| llama8b-4k | 316 TFLOPS (bm128 bn64 thr128) | 172 TFLOPS | 0.55x |
| llama8b-8k | 396 TFLOPS | 202 TFLOPS | 0.51x |
WS reaches only about half of the pipelined throughput because it lacks IntraWGOverlap: QK and PV run serially within each consumer WG. The pipelined kernel gets its overlap from the order/stage/group arguments of T.Pipelined.
IntraWGOverlap: What's Needed
FA3's key performance feature: overlap QK[n] with PV[n-1] within each consumer WG using async WGMMA.
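Overlapping QK[n] with PV[n-1] is legal because the online-softmax update can be reordered: PV[n-1] uses the P from the previous iteration, and O is rescaled only after that product has landed. The scalar sketch below follows the intra-warpgroup pipelining order described in the FA3 paper (prologue QK[0], steady state QK[n] then PV[n-1], epilogue PV[last]) and checks it against plain softmax attention. Python runs everything synchronously, so the comments mark which operations would be in flight together on the GPU; the function names are hypothetical.

```python
import math
import random


def attention_ref(q, K, V):
    """Reference softmax(q . k / sqrt(d)) @ V for a single query row."""
    scale = 1.0 / math.sqrt(len(q))
    s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    return [sum(p[j] * V[j][c] for j in range(len(V))) / sum(p) for c in range(len(V[0]))]


def overlapped_attention(q, K, V, block_n):
    """Online softmax in the QK[n] / PV[n-1] overlapped order, one query row."""
    scale = 1.0 / math.sqrt(len(q))
    dv = len(V[0])
    blocks = [(K[i:i + block_n], V[i:i + block_n]) for i in range(0, len(K), block_n)]

    def qk(Kb):
        return [scale * sum(a * b for a, b in zip(q, k)) for k in Kb]

    def pv(p, Vb):
        return [sum(p[j] * Vb[j][c] for j in range(len(Vb))) for c in range(dv)]

    # Prologue: QK[0] and its softmax; no PV is ready yet.
    s = qk(blocks[0][0])
    m = max(s)
    p = [math.exp(x - m) for x in s]
    l = sum(p)
    o = [0.0] * dv
    for n in range(1, len(blocks)):
        s_next = qk(blocks[n][0])                  # issue QK[n] ...
        pv_prev = pv(p, blocks[n - 1][1])          # ... and PV[n-1]; on the GPU these overlap
        m_new = max(m, max(s_next))                # softmax on S[n] once QK[n] has finished
        alpha = math.exp(m - m_new)
        p = [math.exp(x - m_new) for x in s_next]
        l = l * alpha + sum(p)
        # PV[n-1] used P relative to the old max, so fold it in before rescaling O.
        o = [(oc + pc) * alpha for oc, pc in zip(o, pv_prev)]
        m = m_new
    # Epilogue: the last PV has nothing left to overlap with.
    o = [oc + pc for oc, pc in zip(o, pv(p, blocks[-1][1]))]
    return [x / l for x in o]


random.seed(1)
q = [random.uniform(-2, 2) for _ in range(8)]
K = [[random.uniform(-2, 2) for _ in range(8)] for _ in range(20)]
V = [[random.uniform(-1, 1) for _ in range(4)] for _ in range(20)]
diff = max(abs(a - b) for a, b in zip(overlapped_attention(q, K, V, 6), attention_ref(q, K, V)))
print(diff)
```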
Attempted approaches and blockers:
- T.wgmma_gemm + T.wait_wgmma: TileLang emits fence_operand before wait_wgmma, but it must come after the wait. In addition, T.wgmma_gemm lowers PV in RS mode while T.gemm uses SS mode; these are different lowering paths, and in a T.ws context the RS path does not produce correct results.
- Manual warpgroup primitives around T.gemm: T.gemm already emits its own fence/arrive/commit/wait, so wrapping it in manual primitives synchronizes twice.
- CUDA postproc: fixing the fence/wait ordering with a regex over the generated CUDA works, but the underlying RS-vs-SS path difference in T.wgmma_gemm still produces incorrect results.
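The ordering fix from the postproc attempt amounts to a line-level rewrite of the generated CUDA. A minimal sketch, assuming each fence_operand(...) and wait_wgmma<N>() call sits on its own line (the tl:: prefix in the example input is illustrative, not copied from real TileLang output, and the function name is hypothetical). It only moves fences; it does nothing about the RS-vs-SS lowering difference.

```python
import re

# Assumed shapes of the generated lines; adjust to the real TileLang output.
FENCE_RE = re.compile(r"^\s*(?:\w+::)*fence_operand\(.*\);\s*$")
WAIT_RE = re.compile(r"^\s*(?:\w+::)*wait_wgmma<\d+>\(\);\s*$")


def move_fences_after_waits(cuda_src: str) -> str:
    """Move each run of fence_operand lines that directly precedes a
    wait_wgmma<N>() line to just after that wait."""
    lines = cuda_src.splitlines()
    out = []
    i = 0
    while i < len(lines):
        j = i
        while j < len(lines) and FENCE_RE.match(lines[j]):
            j += 1                       # j now points past a run of fences (maybe empty)
        if j > i and j < len(lines) and WAIT_RE.match(lines[j]):
            out.append(lines[j])         # wait for the WGMMA group first ...
            out.extend(lines[i:j])       # ... then fence the accumulator registers before use
            i = j + 1
        elif j > i:
            out.extend(lines[i:j])       # fences not followed by a wait: leave as-is
            i = j
        else:
            out.append(lines[i])
            i += 1
    return "\n".join(out)


before = "  tl::fence_operand(acc_s);\n  tl::wait_wgmma<0>();\n  softmax(acc_s);"
print(move_fences_after_waits(before))
```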
Viable path forward: hand-write WGMMA
Extract the exact WGMMA instruction sequence from T.gemm's compiled CUDA and reimplement it with T.ptx_wgmma_ss/T.ptx_wgmma_rs plus manually ordered fence/arrive/commit/wait. This gives full control over:
- Which WGMMA to wait for (wait count)
- When to fence accumulators
- Overlap scheduling between QK and PV
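The wait count is what makes the overlap work, and it depends on issue order. PTX's wgmma.wait_group N returns once at most N of the most recently committed groups are still pending, with all older groups complete. The toy model below (class and function names are hypothetical) commits QK[n] before PV[n-1], so wait_group(1) guarantees S[n] is ready for softmax while PV[n-1] may still be running, and wait_group(0) is needed only before O is rescaled.

```python
from collections import deque


class WgmmaQueue:
    """Toy model of the committed WGMMA groups of one warpgroup."""

    def __init__(self):
        self.pending = deque()

    def commit(self, name):
        self.pending.append(name)

    def wait_group(self, n):
        # Like wgmma.wait_group N: afterwards at most n of the most recent
        # groups are pending, and every older group is complete.
        while len(self.pending) > n:
            self.pending.popleft()


def consumer_schedule(num_blocks):
    """Return (event, groups still pending) for one consumer's main loop."""
    q = WgmmaQueue()
    trace = []
    q.commit("QK0")
    q.wait_group(0)
    trace.append(("softmax0", tuple(q.pending)))
    for n in range(1, num_blocks):
        q.commit(f"QK{n}")       # issued first, so it is the older group
        q.commit(f"PV{n - 1}")
        q.wait_group(1)          # QK[n] complete; PV[n-1] may still be in flight
        trace.append((f"softmax{n}", tuple(q.pending)))
        q.wait_group(0)          # PV[n-1] complete before O is rescaled
        trace.append((f"rescale{n}", tuple(q.pending)))
    q.commit(f"PV{num_blocks - 1}")
    q.wait_group(0)
    trace.append(("epilogue", tuple(q.pending)))
    return trace


for event in consumer_schedule(3):
    print(event)
```

Swapping the two commits would make wait_group(1) retire PV[n-1] instead, leaving QK[n] possibly unfinished when softmax reads S; this is the kind of ordering the hand-written version has to control explicitly.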
Key parameters extracted from compiled CUDA (D=128, block_n=64):
QK (SS): wgmma_ss<f16,f16,f32, 64,64,16, false,false>, 8 iters, desc <LBO=1, SBO=1, swizzle=64B>
PV (SS via stmatrix): wgmma_ss<f16,f16,f32, 64,128,16, false,true>, 4 iters, desc <LBO=1, SBO=512, swizzle=64B>
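The iteration counts follow from the fixed K=16 of f16 WGMMA: QK reduces over D=128 (128/16 = 8 instructions) and PV reduces over block_n=64 (64/16 = 4), with M=64 rows per consumer warpgroup in both. A hand-written version also has to build the shared-memory descriptors itself. The sketch below packs one following the matrix-descriptor field layout in the PTX ISA (start address, LBO, SBO, swizzle mode); it assumes the LBO/SBO values listed above are already in the encoded 16-byte units, which should be confirmed against the generated CUDA before reuse. The helper names are hypothetical.

```python
WGMMA_K = 16  # K extent of one f16/bf16 wgmma.mma_async instruction


def k_iters(reduction_dim):
    """Number of m64nNk16 instructions needed to cover a reduction dimension."""
    assert reduction_dim % WGMMA_K == 0
    return reduction_dim // WGMMA_K


D, BLOCK_N = 128, 64
assert k_iters(D) == 8        # QK: wgmma 64x64x16, reduces over head dim D
assert k_iters(BLOCK_N) == 4  # PV: wgmma 64x128x16, reduces over block_n

SWIZZLE_CODES = {"none": 0, "128B": 1, "64B": 2, "32B": 3}


def make_smem_desc(smem_addr, lbo, sbo, swizzle):
    """Pack a 64-bit WGMMA shared-memory matrix descriptor (PTX ISA layout).

    Assumption: lbo and sbo are given in encoded 16-byte units; raw byte
    offsets would need >> 4 first. Base offset (bits 49-51) is left at 0.
    """
    desc = (smem_addr & 0x3FFFF) >> 4        # bits 0-13: start address
    desc |= (lbo & 0x3FFF) << 16             # bits 16-29: leading dimension byte offset
    desc |= (sbo & 0x3FFF) << 32             # bits 32-45: stride dimension byte offset
    desc |= SWIZZLE_CODES[swizzle] << 62     # bits 62-63: swizzle mode
    return desc


qk_desc = make_smem_desc(0x0, lbo=1, sbo=1, swizzle="64B")
pv_desc = make_smem_desc(0x0, lbo=1, sbo=512, swizzle="64B")
print(hex(qk_desc), hex(pv_desc))
```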
Files
- _test_ws_fa3_large_hdv.py: LargeHeadDimV kernel + tests
- _bench_ws_fa3_large_hdv.py: LargeHeadDimV benchmarks
- _test_ws_fa3_cooperative.py: Cooperative 2-WG kernel + pipelined baseline + tests
- _bench_ws_fa3_coop.py: Cooperative benchmarks (official params)
- _test_ws_fa3_overlap.py: IntraWGOverlap attempts (not working)
- _test_ws_postproc_overlap.py: CUDA postproc approach (not working)
Related