
GQA Forward: Warp Specialization (FA3-aligned) — Findings & Blockers #4

@superAngGao


Summary

Implemented a warp-specialized (WS) GQA forward kernel for TileOPs, aligned with FlashAttention-3's producer-consumer architecture using TileLang's T.ws() + T.tma_copy() + barrier primitives.

Correctness: WS-256 (2 warp groups) passes all tested configurations.
Performance: WS-256 is ~2-3.5x slower than the existing GqaFwdWgmmaPipelinedKernel.
Blocker: WS-384 (3 warp groups) produces incorrect results, so the WS kernel cannot yet match the pipelined kernel's compute throughput.

Architecture

FA3 reference pattern

  • Producer (1 warp, 32 threads): issues async TMA loads of K/V with near-zero compute overhead
  • Consumer (1-3 warp groups, 128-384 threads): QK GEMM → online softmax → rescale O → PV GEMM
  • Synchronization: mbarrier pingpong (k_ready, v_ready, compute_done)
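The consumer's per-iteration math (QK GEMM, online softmax, rescale O, PV GEMM) can be checked on the CPU with a small NumPy sketch. It follows the standard FlashAttention online-softmax recurrence for a single query tile; `consumer_tile` and `block_n` are hypothetical names for illustration, not TileOPs or TileLang APIs, and NumPy is assumed to be available.

```python
import numpy as np


def consumer_tile(q, k, v, block_n):
    """One consumer query tile: loop over K/V blocks with an online softmax.

    q: (M, D), k and v: (N, D). Returns softmax(q @ k.T / sqrt(D)) @ v.
    """
    scale = 1.0 / np.sqrt(q.shape[1])
    row_max = np.full(q.shape[0], -np.inf)      # running row max (scores_max)
    row_sum = np.zeros(q.shape[0])              # running denominator (logsum)
    acc_o = np.zeros((q.shape[0], v.shape[1]))  # unnormalized output accumulator
    for start in range(0, k.shape[0], block_n):
        kb = k[start:start + block_n]
        vb = v[start:start + block_n]
        acc_s = (q @ kb.T) * scale                    # QK GEMM
        new_max = np.maximum(row_max, acc_s.max(axis=1))
        p = np.exp(acc_s - new_max[:, None])          # softmax numerators for this block
        correction = np.exp(row_max - new_max)        # 0 on the first block (exp(-inf))
        row_sum = row_sum * correction + p.sum(axis=1)
        acc_o = acc_o * correction[:, None] + p @ vb  # rescale O, then PV GEMM
        row_max = new_max
    return acc_o / row_sum[:, None]                   # final normalization


# Compare against a plain (non-blocked) softmax attention on random data.
rng = np.random.default_rng(0)
q = rng.standard_normal((64, 32))
k = rng.standard_normal((256, 32))
v = rng.standard_normal((256, 32))
s = q @ k.T / np.sqrt(32)
ref = np.exp(s - s.max(axis=1, keepdims=True))
ref = (ref / ref.sum(axis=1, keepdims=True)) @ v
print(np.abs(consumer_tile(q, k, v, block_n=64) - ref).max())
```

The correction factor is what "rescale O" refers to: whenever a new block raises a row's maximum, the previously accumulated sum and output for that row are scaled down before the new PV contribution is added.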

TileLang WS implementation

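```python
for k_idx in T.Pipelined(loop_range, num_stages=0):
    with T.ws(0):  # Producer WG0 (128 threads)
        # Wait until the consumer has released k_shared/v_shared from the previous
        # iteration; on k_idx == 0 a parity-1 wait on a fresh barrier passes at once.
        T.barrier_wait(compute_done, (k_idx + 1) % 2)
        T.tma_copy(k[...], k_shared, barrier=k_ready)  # async TMA load of the K tile
        T.barrier_arrive(k_ready)
        T.tma_copy(v[...], v_shared, barrier=v_ready)  # async TMA load of the V tile
        T.barrier_arrive(v_ready)
    with T.ws(1):  # Consumer WG1 (128 threads)
        T.barrier_wait(k_ready, k_idx % 2)        # K tile for this iteration has landed
        T.gemm(q_shared, k_shared, acc_s, ...)    # S = Q K^T
        online_softmax(...)                       # update scores_max/logsum, rescale acc_o
        T.barrier_wait(v_ready, k_idx % 2)        # V tile for this iteration has landed
        T.gemm(acc_s_cast, v_shared, acc_o, ...)  # O += P V
        T.barrier_arrive(compute_done)            # release k_shared/v_shared to the producer
```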
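The parity arguments in this loop rely on mbarrier phase semantics: a wait on parity `p` returns once the phase with parity `p` has completed, and a freshly initialized barrier treats its preceding (parity 1) phase as completed. The pure-Python model below is an illustration only; `PhaseBarrier` and `replay` are hypothetical helpers, not TileLang or CUDA APIs. It replays the loop in program order and checks that the producer's first `barrier_wait(compute_done, (k_idx + 1) % 2)` passes immediately while the next one blocks until the consumer arrives.

```python
class PhaseBarrier:
    """Toy model of an mbarrier (hypothetical helper, not a TileLang/CUDA API)."""

    def __init__(self, arrive_count):
        self.arrive_count = arrive_count
        self.pending = arrive_count  # arrivals still needed to complete the current phase
        self.phase = 0               # completed phases (= index of the phase in progress)

    def arrive(self, n=1):
        for _ in range(n):
            self.pending -= 1
            if self.pending == 0:    # phase completes; parity flips
                self.phase += 1
                self.pending = self.arrive_count

    def try_wait(self, parity):
        # Succeeds once the phase with this parity has completed, i.e. the phase now
        # in progress has the opposite parity. A fresh barrier therefore passes parity 1.
        return self.phase % 2 != parity


def replay(num_iters, consumer_threads=128):
    """Replay the WS loop in program order, checking each wait as it is reached."""
    k_ready, v_ready = PhaseBarrier(128), PhaseBarrier(128)
    compute_done = PhaseBarrier(consumer_threads)
    events = []
    for k_idx in range(num_iters):
        # Producer: may reuse k_shared/v_shared only after the previous compute_done.
        assert compute_done.try_wait((k_idx + 1) % 2)
        k_ready.arrive(128)
        v_ready.arrive(128)
        events.append(("load", k_idx))
        # The producer's wait for iteration k_idx + 1 would still block here.
        assert not compute_done.try_wait((k_idx + 2) % 2)
        # Consumer: both tiles for this iteration are ready.
        assert k_ready.try_wait(k_idx % 2) and v_ready.try_wait(k_idx % 2)
        compute_done.arrive(consumer_threads)
        events.append(("compute", k_idx))
    return events


print(replay(3))
```

In this model, with the single k_shared/v_shared buffer written in the excerpt, the producer cannot start the next K/V load until the consumer has arrived on compute_done for the current iteration.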

Benchmark Results (H100)

Correctness (WS-256, threads=256)

| Config | max_diff | Status |
|---|---|---|
| B=1 S=256 H=8 Hkv=4 D=64 non-causal | 0.000244 | PASS |
| B=1 S=1024 H=8 Hkv=4 D=64 causal | 0.001953 | PASS |
| B=4 S=2048 H=64 Hkv=4 D=128 non-causal | 0.000122 | PASS |
| B=4 S=2048 H=64 Hkv=4 D=128 causal | 0.001953 | PASS |
| B=2 S=4096 H=32 Hkv=8 D=128 causal | 0.001953 | PASS |
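The max_diff column compares the kernel output against a reference. A minimal NumPy reference for GQA might look like the sketch below; the (B, S, H, D) layout, the mapping of query head `i` to KV head `i // (H // Hkv)`, the meaning of max_diff as the largest absolute elementwise difference, and the function names are all assumptions for illustration, not the TileOPs test harness.

```python
import numpy as np


def gqa_reference(q, k, v, causal=False):
    """Naive float64 GQA forward (assumed layout: q (B, S, H, D), k/v (B, S, Hkv, D))."""
    q, k, v = (np.asarray(x, dtype=np.float64) for x in (q, k, v))
    _, s, h, d = q.shape
    group = h // k.shape[2]              # query heads per KV head (H / Hkv)
    k = np.repeat(k, group, axis=2)      # query head i reads KV head i // group
    v = np.repeat(v, group, axis=2)
    scores = np.einsum("bqhd,bkhd->bhqk", q, k) / np.sqrt(d)
    if causal:                           # hide keys after the query position
        future = np.triu(np.ones((s, s), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    p = np.exp(scores - scores.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bkhd->bqhd", p, v)


def max_diff(out, ref):
    """Largest absolute elementwise difference (assumed meaning of max_diff)."""
    return float(np.abs(np.asarray(out, dtype=np.float64) - ref).max())


# Random inputs shaped like the first table row (B=1 S=256 H=8 Hkv=4 D=64).
rng = np.random.default_rng(0)
q = rng.standard_normal((1, 256, 8, 64))
k = rng.standard_normal((1, 256, 4, 64))
v = rng.standard_normal((1, 256, 4, 64))
ref = gqa_reference(q, k, v)
# kernel_out = ...  # hypothetical: the WS kernel's output for the same inputs
# print(max_diff(kernel_out, ref))
```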

Performance (WS-256 vs Pipelined)

| Config | Pipelined (ms) | WS-256 (ms) | Ratio (Pipelined ms / WS-256 ms) |
|---|---|---|---|
| B=1 S=2048 H=32 Hkv=8 D=128 non-causal | 0.271 | 0.550 | 0.49x |
| B=1 S=2048 H=32 Hkv=8 D=128 causal | 0.183 | 0.601 | 0.31x |
| B=4 S=4096 H=64 Hkv=8 D=128 non-causal | 7.102 | 15.796 | 0.45x |
| B=4 S=4096 H=64 Hkv=8 D=128 causal | 3.986 | 13.947 | 0.29x |
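Converting the timings into throughput makes the gap easier to compare across sizes. The sketch below assumes the common FlashAttention FLOP convention for the forward pass, 4 * B * H * S^2 * D (two GEMMs of 2 * S^2 * D FLOPs per query head), roughly halved for causal; the timings are copied from the table, and `attn_fwd_flops` is a hypothetical helper.

```python
def attn_fwd_flops(b, s, h, d, causal):
    # QK^T and PV GEMMs: 2 * S * S * D FLOPs each per query head; causal keeps ~half.
    flops = 4 * b * h * s * s * d
    return flops / 2 if causal else flops


rows = [
    # (B, S, H, D, causal, pipelined_ms, ws256_ms), copied from the table above
    (1, 2048, 32, 128, False, 0.271, 0.550),
    (1, 2048, 32, 128, True, 0.183, 0.601),
    (4, 4096, 64, 128, False, 7.102, 15.796),
    (4, 4096, 64, 128, True, 3.986, 13.947),
]
for b, s, h, d, causal, t_pipe, t_ws in rows:
    f = attn_fwd_flops(b, s, h, d, causal)
    # FLOPs / milliseconds / 1e9 = TFLOP/s
    print(f"B={b} S={s} causal={causal}: pipelined {f / t_pipe / 1e9:.0f} TFLOP/s, "
          f"WS-256 {f / t_ws / 1e9:.0f} TFLOP/s, slowdown {t_ws / t_pipe:.2f}x")
```

GQA does not change the FLOP count relative to MHA with H query heads; Hkv only reduces the K/V traffic.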

WS-384 (3 warp groups) — INCORRECT

| Config | Threads | diff | Status |
|---|---|---|---|
| D=64 non-causal | 384 | 0.728577 | FAIL |
| D=128 non-causal | 384 | 0.514252 | FAIL |
| D=128 causal | 384 | 4.488281 | FAIL |

Root Cause Analysis

Why WS-256 is slow

TileLang's T.ws() allocates entire warp groups (128 threads). With 2 WGs:

  • Producer: 128 threads (only 1 thread is needed to issue TMA; the other 127 are idle)
  • Consumer: 128 threads (half the compute threads of the pipelined kernel's 256)

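FA3 issues its TMA loads from a single producer warp, but WGMMA is issued per warp group, so its consumers still come in 128-thread warp groups; FA3 keeps GEMM throughput high by running multiple consumer WGs (e.g. 2, for 384 threads in total) and moving registers from the producer to the consumers with setmaxnreg. In WS-256, TileLang's 128-thread WG granularity leaves room for only one consumer WG, so the kernel has half the compute threads of the pipelined kernel, and that is the main bottleneck.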
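The thread budget behind this comparison is simple arithmetic. The helper below is hypothetical and assumes one 128-thread producer warp group, as described above; it shows why only the 3-WG layout reaches the pipelined kernel's 256 compute threads.

```python
WARP_GROUP = 128           # threads per Hopper warp group
PIPELINED_COMPUTE = 256    # the pipelined kernel computes with all 256 threads


def compute_threads(total_threads, producer_wgs=1):
    """Threads left for the GEMMs when whole warp groups are given to the producer."""
    if total_threads % WARP_GROUP:
        raise ValueError("T.ws() assigns whole 128-thread warp groups")
    return total_threads - producer_wgs * WARP_GROUP


for name, total in [("WS-256", 256), ("WS-384", 384)]:
    c = compute_threads(total)
    print(f"{name}: {c} compute threads ({c / PIPELINED_COMPUTE:.1f}x the pipelined kernel)")
```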

Why WS-384 produces incorrect results

With T.ws(1, 2), both WG1 and WG2 execute the consumer block (256 threads). TileLang's C++ WarpSpecialize implementation correctly merges the consecutive WG IDs into a single condition, thread_idx >= 128 && thread_idx < 384.
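The merge can be illustrated in a few lines of Python. This re-implements the behavior described above for explanation only; it is not the code of TileLang's C++ WarpSpecialize implementation, and `merge_wg_ranges` and `guard` are hypothetical helpers.

```python
WARP_GROUP = 128


def merge_wg_ranges(wg_ids):
    """Merge warp-group IDs into half-open [start, end) thread-index ranges."""
    ranges = []
    for wg in sorted(set(wg_ids)):
        start, end = wg * WARP_GROUP, (wg + 1) * WARP_GROUP
        if ranges and ranges[-1][1] == start:  # consecutive WG: extend the last range
            ranges[-1] = (ranges[-1][0], end)
        else:
            ranges.append((start, end))
    return ranges


def guard(ranges):
    """Render a branch condition over thread_idx for the merged ranges."""
    return " || ".join(f"(thread_idx >= {s} && thread_idx < {e})" for s, e in ranges)


print(guard(merge_wg_ranges([1, 2])))
```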

However, the GEMM + online_softmax interaction across 2 consumer WGs appears broken:

  1. Each WG has its own fragment (register) copies of acc_s, acc_o, scores_max, etc.
  2. T.gemm with GemmWarpPolicy.FullRow distributes rows across WGs (WG1 → rows 0-63, WG2 → rows 64-127)
  3. Online softmax operates per-WG on partial row sets — this should be correct IF each WG's scores_max, logsum etc. only cover its own rows
  4. Hypothesis: Fragment allocation or copy-back may not correctly handle the per-WG row partitioning, causing data corruption
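Point 3 can be checked numerically: softmax statistics are row-local, so two WGs that each compute scores_max/logsum for only their own rows reproduce the whole-tile result. The NumPy sketch below assumes the contiguous FullRow split described above (rows 0-63 and 64-127 of a 128-row tile); all names are hypothetical. Its last step simulates the kind of row-mapping mismatch hypothesized in point 4, where one WG applies statistics that belong to the other WG's rows.

```python
import numpy as np

BLOCK_M, NUM_CONSUMER_WGS = 128, 2  # tile rows and consumer WGs, as with T.ws(1, 2)


def rows_owned(wg_local):
    """Contiguous row split assumed for FullRow: WG1 -> rows 0-63, WG2 -> rows 64-127."""
    per_wg = BLOCK_M // NUM_CONSUMER_WGS
    return slice(wg_local * per_wg, (wg_local + 1) * per_wg)


def row_stats(scores):
    """Per-row max and sum of exponentials (the roles of scores_max and logsum)."""
    row_max = scores.max(axis=1)
    return row_max, np.exp(scores - row_max[:, None]).sum(axis=1)


def normalize(scores, row_max, row_sum):
    return np.exp(scores - row_max[:, None]) / row_sum[:, None]


rng = np.random.default_rng(0)
acc_s = rng.standard_normal((BLOCK_M, 64))
full = normalize(acc_s, *row_stats(acc_s))

# Each WG computes statistics only for its own rows.
parts = [acc_s[rows_owned(w)] for w in range(NUM_CONSUMER_WGS)]
split = np.vstack([normalize(s, *row_stats(s)) for s in parts])
print("per-WG rows match whole tile:", np.allclose(full, split))

# Simulated row-mapping bug: WG2 applies statistics computed for WG1's rows.
wrong = normalize(parts[1], *row_stats(parts[0]))
print("rows still sum to 1:", np.allclose(wrong.sum(axis=1), 1.0))
```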

Further investigation is needed in TileLang's WGMMA lowering for multi-consumer WS.

Barrier Configuration

```
k_ready:       arrive_count=128   (producer WG0 arrives)
v_ready:       arrive_count=128   (producer WG0 arrives)
compute_done:  arrive_count=N*128 (N consumer WGs arrive)
```

Tested compute_done with arrive_count=128 and arrive_count=256 for the 3-WG kernel; both produce incorrect results, so the failure does not appear to be caused by barrier misconfiguration.
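The N*128 rule follows from how arrivals complete a phase. The toy counter below (a hypothetical helper, not a real API) models each consumer WG as 128 thread arrivals on compute_done and reports how many phases have completed after each WG arrives.

```python
THREADS_PER_WG = 128


def phases_after_each_arrival(arrive_count, wg_arrivals):
    """Completed phases of a barrier after each WG's arrivals (toy model, not a real API)."""
    pending, completed, history = arrive_count, 0, []
    for threads in wg_arrivals:
        for _ in range(threads):
            pending -= 1
            if pending == 0:          # arrive_count reached: phase completes
                completed += 1
                pending = arrive_count
        history.append(completed)
    return history


two_consumer_wgs = [THREADS_PER_WG, THREADS_PER_WG]  # WG1 arrives, then WG2
print("arrive_count=256:", phases_after_each_arrival(256, two_consumer_wgs))
print("arrive_count=128:", phases_after_each_arrival(128, two_consumer_wgs))
```

With arrive_count=128, a compute_done phase completes as soon as WG1 arrives, so the producer could overwrite k_shared/v_shared while WG2 is still reading them, and WG2's arrival then completes an extra phase. With 256 (N*128 for N=2), the phase completes only after both consumer WGs arrive.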

Environment

  • TileLang: 0.1.8+cuda.git5f70374c
  • GPU: NVIDIA H100
  • PyTorch: 2.9.1+cu128
  • Note: T.tma_copy() is only available in this newer build. Docker image tileops-runner:latest (TileLang v0.1.8) lacks T.tma_copy.
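A quick way to tell whether a given environment (for example the tileops-runner image) has the newer build is to check for the attribute directly. A minimal sketch, assuming TileLang is installed as the `tilelang` distribution and that `T` refers to `tilelang.language` (the usual `import tilelang.language as T` convention); `tilelang_info` is a hypothetical helper.

```python
import importlib
from importlib import metadata


def tilelang_info():
    """Return (installed tilelang version or None, whether T.tma_copy exists)."""
    try:
        version = metadata.version("tilelang")
    except metadata.PackageNotFoundError:
        version = None
    try:
        T = importlib.import_module("tilelang.language")  # usually imported as T
    except Exception:  # missing package or a failed native/CUDA import
        return version, False
    return version, hasattr(T, "tma_copy")


version, has_tma_copy = tilelang_info()
print(f"tilelang {version}: T.tma_copy {'available' if has_tma_copy else 'missing'}")
```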

Next Steps

  1. File upstream TileLang issue for multi-WG consumer (T.ws(1, 2)) correctness
  2. Investigate finer-grained producer (< 128 threads) in TileLang's WS model
  3. Once 3-WG is fixed, WS kernel should match or exceed pipelined performance (2 consumer WGs = same 256-thread compute power, plus TMA overlap)
  4. Consider adding T.tma_copy to the Docker image's TileLang build

Files Modified (on branch feat/gqa-fwd-ws, currently reverted)

  • tileops/kernels/flash_attn/fwd.py: GqaFwdWsKernel, _gqa_fwd_ws_kernel()
  • tileops/kernels/flash_attn/__init__.py: export
  • tileops/ops/gqa.py: dispatch (WS not made the default; pipelined kept as default)
