Summary
Implemented a warp-specialized (WS) GQA forward kernel for TileOPs, aligned with FlashAttention-3's producer-consumer architecture using TileLang's T.ws() + T.tma_copy() + barrier primitives.
Correctness: WS-256 (2 warp groups) passes all tests.
Performance: WS-256 is roughly 2.0-3.5x slower than the existing GqaFwdWgmmaPipelinedKernel (0.29x-0.49x of its speed in the benchmarks below).
Blocker: WS-384 (3 warp groups) produces incorrect results, so the WS kernel cannot yet match the pipelined kernel's compute throughput.
Architecture
FA3 reference pattern
- Producer (1 warp, 32 threads): TMA async load K/V, near-zero compute overhead
- Consumer (1-3 warp groups, 128-384 threads): QK GEMM → online softmax → rescale O → PV GEMM
- Synchronization: mbarrier pingpong (k_ready, v_ready, compute_done)
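The consumer's per-block sequence (QK GEMM, online softmax, rescale O, PV GEMM) can be illustrated for a single query row in plain Python. This is a minimal sketch rather than the kernel code: the function names, the block size, and the toy inputs are all assumptions made for illustration.

```python
import math


def attention_row_reference(q, keys, values, scale):
    """Softmax attention for one query row, computed in a single pass."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    row_max = max(scores)
    weights = [math.exp(s - row_max) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / denom
            for i in range(dim)]


def attention_row_online(q, keys, values, scale, block=2):
    """Same result, but K/V are consumed block by block (online softmax)."""
    dim = len(values[0])
    running_max = float("-inf")
    running_sum = 0.0
    acc_o = [0.0] * dim
    for start in range(0, len(keys), block):
        k_blk = keys[start:start + block]
        v_blk = values[start:start + block]
        # "QK GEMM": scores of this row against the current K block.
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Factor that rescales everything accumulated so far
        # (exp(-inf) == 0.0 on the first block).
        rescale = math.exp(running_max - new_max)
        probs = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * rescale + sum(probs)
        # "Rescale O" followed by "PV GEMM".
        acc_o = [o * rescale + sum(p * v[i] for p, v in zip(probs, v_blk))
                 for i, o in enumerate(acc_o)]
        running_max = new_max
    return [o / running_sum for o in acc_o]


# Toy inputs (hypothetical): 6 keys/values of dimension 4.
q = [0.5, -1.0, 2.0, 0.25]
keys = [[0.1 * (i + j) for j in range(4)] for i in range(6)]
values = [[float(i), float(i * i), 1.0, -float(i)] for i in range(6)]
ref = attention_row_reference(q, keys, values, scale=0.5)
out = attention_row_online(q, keys, values, scale=0.5, block=2)
```

Both functions should agree up to floating-point rounding. The blockwise version is the form that lets K and V be streamed in tiles.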
TileLang WS implementation
    for k_idx in T.Pipelined(loop_range, num_stages=0):
        with T.ws(0):  # Producer WG0 (128 threads)
            # Wait until the consumer has released the K/V buffers
            T.barrier_wait(compute_done, (k_idx + 1) % 2)
            T.tma_copy(k[...], k_shared, barrier=k_ready)
            T.barrier_arrive(k_ready)
            T.tma_copy(v[...], v_shared, barrier=v_ready)
            T.barrier_arrive(v_ready)
        with T.ws(1):  # Consumer WG1 (128 threads)
            T.barrier_wait(k_ready, k_idx % 2)
            T.gemm(q_shared, k_shared, acc_s, ...)  # S = Q @ K^T
            online_softmax(...)
            T.barrier_wait(v_ready, k_idx % 2)
            T.gemm(acc_s_cast, v_shared, acc_o, ...)  # O += P @ V
            # Signal the producer that this K/V tile has been consumed
            T.barrier_arrive(compute_done)
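The parity arithmetic in the loop above can be checked with a small pure-Python model. This is a sketch under stated assumptions, not TileLang code: `PhaseBarrier` is a hypothetical toy that flips a phase bit once `arrive_count` arrivals have been seen, a wait on parity `p` is modelled as "the phase with parity `p` has completed", and each role arrives once per iteration (arrive count 1 instead of 128). Generators stand in for the two warp groups.

```python
class PhaseBarrier:
    """Toy phase-bit barrier (hypothetical model, not a TileLang API)."""

    def __init__(self, arrive_count):
        self.arrive_count = arrive_count
        self.pending = arrive_count
        self.phase = 0  # parity of the phase currently in progress

    def arrive(self):
        self.pending -= 1
        if self.pending == 0:  # phase complete: flip parity and re-arm
            self.phase ^= 1
            self.pending = self.arrive_count

    def passed(self, parity):
        # A wait on `parity` succeeds once that phase has completed.
        return self.phase != parity


def producer(n_iters, k_ready, v_ready, compute_done, log):
    for k_idx in range(n_iters):
        while not compute_done.passed((k_idx + 1) % 2):
            yield  # blocked: let the other role run
        log.append(("load_k", k_idx))
        k_ready.arrive()
        log.append(("load_v", k_idx))
        v_ready.arrive()


def consumer(n_iters, k_ready, v_ready, compute_done, log):
    for k_idx in range(n_iters):
        while not k_ready.passed(k_idx % 2):
            yield
        log.append(("qk_gemm", k_idx))
        while not v_ready.passed(k_idx % 2):
            yield
        log.append(("pv_gemm", k_idx))
        compute_done.arrive()


def simulate(n_iters, max_steps=10_000):
    """Round-robin both roles; raise if they stop making progress."""
    k_ready, v_ready, compute_done = (PhaseBarrier(1) for _ in range(3))
    log = []
    tasks = [producer(n_iters, k_ready, v_ready, compute_done, log),
             consumer(n_iters, k_ready, v_ready, compute_done, log)]
    for _ in range(max_steps):
        if not tasks:
            return log
        for task in list(tasks):
            try:
                next(task)
            except StopIteration:
                tasks.remove(task)
    raise RuntimeError("deadlock or step limit reached")


log = simulate(4)
```

In this toy model the first producer wait (parity 1) passes immediately, and the producer cannot begin loading tile k+1 until the consumer has arrived on `compute_done` for tile k.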
Benchmark Results (H100)
Correctness (WS-256, threads=256)
| Config | max_diff | Status |
|---|---|---|
| B=1 S=256 H=8 Hkv=4 D=64 non-causal | 0.000244 | PASS |
| B=1 S=1024 H=8 Hkv=4 D=64 causal | 0.001953 | PASS |
| B=4 S=2048 H=64 Hkv=4 D=128 non-causal | 0.000122 | PASS |
| B=4 S=2048 H=64 Hkv=4 D=128 causal | 0.001953 | PASS |
| B=2 S=4096 H=32 Hkv=8 D=128 causal | 0.001953 | PASS |
Performance (WS-256 vs Pipelined)
| Config | Pipelined (ms) | WS-256 (ms) | Ratio |
|---|---|---|---|
| B=1 S=2048 H=32 Hkv=8 D=128 non-causal | 0.271 | 0.550 | 0.49x |
| B=1 S=2048 H=32 Hkv=8 D=128 causal | 0.183 | 0.601 | 0.31x |
| B=4 S=4096 H=64 Hkv=8 D=128 non-causal | 7.102 | 15.796 | 0.45x |
| B=4 S=4096 H=64 Hkv=8 D=128 causal | 3.986 | 13.947 | 0.29x |
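For reference, the Ratio column reads as pipelined time divided by WS-256 time, so its reciprocal is the slowdown factor. The snippet below is only a sketch of that arithmetic using the timings from the table. The dictionary keys are shortened labels chosen here, and values recomputed from the rounded timings may differ from the table in the last digit.

```python
# (pipelined_ms, ws256_ms) copied from the table above.
timings_ms = {
    "B=1 S=2048 non-causal": (0.271, 0.550),
    "B=1 S=2048 causal": (0.183, 0.601),
    "B=4 S=4096 non-causal": (7.102, 15.796),
    "B=4 S=4096 causal": (3.986, 13.947),
}


def ratio_and_slowdown(pipelined_ms, ws_ms):
    """Return (pipelined/WS ratio, WS/pipelined slowdown)."""
    return pipelined_ms / ws_ms, ws_ms / pipelined_ms


summary = {name: ratio_and_slowdown(*t) for name, t in timings_ms.items()}
for name, (ratio, slowdown) in summary.items():
    print(f"{name}: ratio={ratio:.2f}x, WS-256 is {slowdown:.2f}x slower")
```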
WS-384 (3 warp groups) — INCORRECT
| Config | threads | diff | Status |
|---|---|---|---|
| D=64 non-causal | 384 | 0.728577 | FAIL |
| D=128 non-causal | 384 | 0.514252 | FAIL |
| D=128 causal | 384 | 4.488281 | FAIL |
Root Cause Analysis
Why WS-256 is slow
TileLang's T.ws() allocates entire warp groups (128 threads). With 2 WG:
- Producer: 128 threads (only need 1 thread for TMA, 127 idle)
- Consumer: 128 threads (half the compute of pipelined's 256 threads)
FA3 uses only 1 warp (32 threads) for the producer, keeping 224 threads for compute. TileLang's minimum WG granularity of 128 threads is the bottleneck.
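The thread budgets being compared can be written out explicitly. This is a sketch of the arithmetic only: the configuration labels are made up here, and the 256-thread block size for the FA3-style row is taken from the "224 threads for compute" figure above.

```python
WARP = 32
WARP_GROUP = 4 * WARP  # 128 threads


def compute_threads(total_threads, producer_threads):
    """Threads left for GEMM/softmax after reserving the producer."""
    return total_threads - producer_threads


budgets = {
    "pipelined (no dedicated producer)": compute_threads(256, 0),
    "FA3-style 1-warp producer": compute_threads(256, WARP),
    "TileLang WS-256": compute_threads(256, WARP_GROUP),
    "TileLang WS-384": compute_threads(384, WARP_GROUP),
}
for name, threads in budgets.items():
    print(f"{name}: {threads} compute threads")
```

With a whole warp group reserved for the producer, only the 384-thread configuration gets back to the pipelined kernel's 256 compute threads, which is why the WS-384 correctness issue is the blocker.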
Why WS-384 produces incorrect results
With T.ws(1, 2), both WG1 and WG2 execute the consumer block (256 threads). The C++ WarpSpecialize correctly merges consecutive WG IDs into thread_idx >= 128 && thread_idx < 384.
However, the GEMM + online_softmax interaction across 2 consumer WGs appears broken:
- Each WG has its own fragment (register) copies of acc_s, acc_o, scores_max, etc.
- T.gemm with GemmWarpPolicy.FullRow distributes rows across WGs (WG1 → rows 0-63, WG2 → rows 64-127)
- Online softmax operates per-WG on partial row sets. This should be correct IF each WG's scores_max, logsum, etc. only cover its own rows
- Hypothesis: Fragment allocation or copy-back may not correctly handle the per-WG row partitioning, causing data corruption
Further investigation needed in TileLang's WGMMA lowering for multi-consumer WS.
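The "should be correct if each WG only covers its own rows" argument can be sanity-checked outside the kernel. The sketch below assumes a 128-row score tile split evenly across WG1 and WG2 as described above. `warp_group_of`, `consumer_rows`, and `row_stats` are hypothetical helpers, not TileLang functions. It shows that per-row softmax statistics computed independently on each half equal those computed on the whole tile, which points the investigation at layout and lowering rather than at the math.

```python
import math
import random

WG_SIZE = 128


def warp_group_of(thread_idx):
    """Warp-group ID of a thread (WG0 = producer, WG1/WG2 = consumers)."""
    return thread_idx // WG_SIZE


def consumer_rows(wg_id, block_m=128, first_consumer_wg=1, num_consumer_wgs=2):
    """Rows of the score tile owned by a consumer WG under an even row split."""
    rows_per_wg = block_m // num_consumer_wgs
    start = (wg_id - first_consumer_wg) * rows_per_wg
    return range(start, start + rows_per_wg)


def row_stats(score_rows):
    """Per-row (max, sum of exp(score - max)), as tracked by online softmax."""
    stats = []
    for row in score_rows:
        row_max = max(row)
        stats.append((row_max, sum(math.exp(s - row_max) for s in row)))
    return stats


random.seed(0)
scores = [[random.uniform(-4, 4) for _ in range(16)] for _ in range(128)]

whole = row_stats(scores)
per_wg = []
for wg_id in (1, 2):
    rows = consumer_rows(wg_id)
    per_wg.extend(row_stats(scores[rows.start:rows.stop]))
```

A similar host-side reference, compared row by row against the kernel output, could help show whether the corruption is confined to one WG's row range.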
Barrier Configuration
k_ready: arrive_count=128 (producer WG0 arrives)
v_ready: arrive_count=128 (producer WG0 arrives)
compute_done: arrive_count=N*128 (N consumer WGs arrive)
Tested compute_done with arrive_count=128 and arrive_count=256 for the 3-WG configuration; both produce incorrect results, which suggests the failure is not caused by the compute_done arrive count.
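The arrive counts above follow from the warp-group size. The sketch below writes the rule out and uses a toy counter to show how many barrier phases complete when the configured count does not match the number of arriving threads. `barrier_arrive_counts` and `phase_flips` are hypothetical helpers, and the counter is a simplification rather than a model of real mbarrier over-arrival behaviour.

```python
WG_SIZE = 128


def barrier_arrive_counts(num_consumer_wgs, wg_size=WG_SIZE):
    """Arrive counts when every thread of a participating WG arrives."""
    return {
        "k_ready": wg_size,                          # producer WG0
        "v_ready": wg_size,                          # producer WG0
        "compute_done": num_consumer_wgs * wg_size,  # all consumer WGs
    }


def phase_flips(arrive_count, arrivals_per_iter, n_iters):
    """Toy counter: phases completed after n_iters loop iterations."""
    pending, flips = arrive_count, 0
    for _ in range(n_iters):
        pending -= arrivals_per_iter
        while pending <= 0:
            flips += 1
            pending += arrive_count
    return flips


ws256 = barrier_arrive_counts(num_consumer_wgs=1)
ws384 = barrier_arrive_counts(num_consumer_wgs=2)

# Two consumer WGs (256 arrivals per iteration) over 4 iterations:
matched = phase_flips(ws384["compute_done"], 256, 4)  # one flip per iteration
too_small = phase_flips(128, 256, 4)                  # flips twice per iteration
```

Under this simplification, only the matched count keeps one phase completion per loop iteration, which is what the `(k_idx + 1) % 2` parity in the producer wait relies on.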
Environment
- TileLang: 0.1.8+cuda.git5f70374c
- GPU: NVIDIA H100
- PyTorch: 2.9.1+cu128
- Note: T.tma_copy() is only available in this newer build. The Docker image tileops-runner:latest (TileLang v0.1.8) lacks T.tma_copy.
Next Steps
- File upstream TileLang issue for multi-WG consumer (T.ws(1, 2)) correctness
- Investigate finer-grained producer (< 128 threads) in TileLang's WS model
- Once 3-WG is fixed, WS kernel should match or exceed pipelined performance (2 consumer WGs = same 256-thread compute power, plus TMA overlap)
- Consider adding T.tma_copy to the Docker image's TileLang build
Files Modified (on branch feat/gqa-fwd-ws, currently reverted)
- tileops/kernels/flash_attn/fwd.py: GqaFwdWsKernel, _gqa_fwd_ws_kernel()
- tileops/kernels/flash_attn/__init__.py: export
- tileops/ops/gqa.py: dispatch (not switched to default, kept pipelined)