feat(gfx1151): Triton flash attention for Qwen3.x vision encoder #1357

Open: carlushuang wants to merge 2 commits into main from carhuang/gfx1151_vit_triton_attn
carlushuang (Collaborator) commented Jun 25, 2026
feat(gfx1151): Triton flash attention for the Qwen3.x vision encoder

On gfx1151 (RDNA3.5 / Radeon 8060S), torch's scaled_dot_product_attention falls back to the unfused math backend for the ViT, because flash and mem-efficient SDPA are disabled on this arch (AOTRITON "experimental" path). The math backend is slow and uses O(N²) memory, so large images run slowly and can OOM.
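To put the O(N²) term in perspective, here is a rough back-of-the-envelope sketch. It assumes the unfused path materializes one full [heads, N, N] score tensor in bf16 (2 bytes per element) for the 16-head ViT; the real math backend may hold more than one such buffer or use a wider dtype, so treat these numbers as a lower bound rather than a measurement. `score_matrix_bytes` is a hypothetical helper, not part of torch or aiter.

```python
def score_matrix_bytes(num_tokens: int, num_heads: int = 16, bytes_per_elem: int = 2) -> int:
    """Bytes needed for one materialized [heads, N, N] attention-score tensor."""
    return num_heads * num_tokens * num_tokens * bytes_per_elem


if __name__ == "__main__":
    # Doubling the token count quadruples the score-tensor size.
    for n in (1024, 2048, 4096, 8192):
        mib = score_matrix_bytes(n) / 2**20
        print(f"{n:5d} tokens: {mib:7.0f} MiB per layer (assumed single bf16 buffer)")
```

A flash-style kernel never materializes this tensor, which is why its extra memory grows only linearly with N.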

This routes the vision self-attention to aiter's Triton prefill kernel (context_attention_fwd), which:

  • runs correctly on gfx1151 at the ViT head_dim=72 (non-power-of-2),
  • is a proper flash attention (O(N) memory),
  • is selected only on non-gfx9 arches (not aiter_hip_kernels_supported()); gfx9/CDNA keep SDPA (fast flash there), and torch SDPA stays as the fallback when aiter is unavailable.
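The selection rule in the last bullet can be sketched as a small pure function. This is only an illustration of the described gating, not the code in this PR: `select_vision_attn_backend` and its two boolean parameters are hypothetical stand-ins for "aiter could be imported" and the result of `aiter_hip_kernels_supported()`.

```python
def select_vision_attn_backend(aiter_available: bool, hip_kernels_supported: bool) -> str:
    """Return which attention path the vision encoder would take.

    aiter_available:       hypothetical flag, True if aiter is importable.
    hip_kernels_supported: hypothetical flag mirroring aiter_hip_kernels_supported(),
                           i.e. True on gfx9/CDNA.
    """
    if not aiter_available:
        return "sdpa"    # torch SDPA stays as the fallback without aiter
    if hip_kernels_supported:
        return "sdpa"    # gfx9/CDNA keep SDPA, which has a fast flash backend there
    return "triton"      # non-gfx9 arches such as gfx1151 use the Triton prefill kernel


if __name__ == "__main__":
    print(select_vision_attn_backend(aiter_available=True, hip_kernels_supported=False))
```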

Per-image cu_seqlens are built from grid_thw and threaded through VisionTransformer → VisionBlock → VisionAttention, so each image attends only within its own patches. This also fixes the previous full cross-image attention for multi-image inputs.
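As a minimal sketch of the idea, cumulative sequence lengths can be derived from grid_thw as below. It assumes each grid_thw row is a (t, h, w) patch grid and that one row forms one attention segment of t*h*w tokens; the PR's actual code works on tensors and may segment differently (for example per temporal frame). `build_cu_seqlens` and `segment_of` are hypothetical helper names.

```python
from bisect import bisect_right
from itertools import accumulate


def build_cu_seqlens(grid_thw):
    """Cumulative token offsets [0, n0, n0+n1, ...], one segment per grid_thw row."""
    lengths = [t * h * w for t, h, w in grid_thw]
    return [0] + list(accumulate(lengths))


def segment_of(token_idx, cu_seqlens):
    """Index of the segment (image) that a flattened token position belongs to."""
    return bisect_right(cu_seqlens, token_idx) - 1


if __name__ == "__main__":
    grid_thw = [(1, 32, 32), (1, 16, 24)]   # two images: 1024 and 384 patches
    cu = build_cu_seqlens(grid_thw)
    print(cu)                                # [0, 1024, 1408]
    # Tokens 1023 and 1024 are adjacent in the packed sequence but belong to
    # different images, so a varlen kernel must not let them attend to each other.
    print(segment_of(1023, cu), segment_of(1024, cu))
```

A varlen attention kernel uses these offsets to restrict each query to keys within the same [cu[i], cu[i+1]) range, which is the block-diagonal pattern described above.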

Correctness

Output unchanged — image color-identification test still correct (Red/Blue/Green). Numerically matches SDPA (max relerr < 0.008 across sizes).

Performance (single Radeon 8060S, gfx1151)

Attention kernel, 16 heads / head_dim 72 / bf16 / non-causal:

head_dim is also zero-padded 72→80 (next multiple of 16) — the Triton kernel is ~2x slower on the unaligned head_dim 72 (masked head loads) than on a 16-aligned dim, for identical results. Final attention kernel:

  tokens  ~image   Triton (d72→80)   unaligned d72   torch SDPA   vs SDPA
    1024   512px        0.45 ms          0.89 ms        4.33 ms      9.6x
    2048   724px        1.50 ms          2.99 ms        16.6 ms       11x
    4096  1024px        5.34 ms          11.0 ms        65.8 ms       12x

(the head_dim padding alone is ~2.0x; combined with the Triton switch it's ~10–12x over the SDPA math fallback).
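The reason padding leaves the result unchanged can be checked with a small pure-Python reference, shown here as a sketch rather than the kernel path itself. Zero columns add nothing to the q·k dot products, and zero value columns only produce extra output columns that are sliced off. One assumption worth flagging: the softmax scale must still be computed from the original head_dim (72), not the padded one; how the PR passes the scale to the kernel is not shown in this description. `pad_to_multiple` and `attention` are hypothetical helpers.

```python
import math
import random


def pad_to_multiple(vec, multiple=16):
    """Right-pad one head vector with zeros up to the next multiple of `multiple`."""
    return vec + [0.0] * ((-len(vec)) % multiple)


def attention(q, k, v, scale):
    """Reference non-causal softmax attention for a single head (lists of rows)."""
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)                                  # subtract max for stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) / total
                    for d in range(len(v[0]))])
    return out


random.seed(0)
head_dim, tokens = 72, 5
q, k, v = ([[random.gauss(0, 1) for _ in range(head_dim)] for _ in range(tokens)]
           for _ in range(3))
scale = head_dim ** -0.5      # uses the original head_dim, not the padded 80

ref = attention(q, k, v, scale)
padded_out = attention([pad_to_multiple(r) for r in q],
                       [pad_to_multiple(r) for r in k],
                       [pad_to_multiple(r) for r in v], scale)
sliced = [row[:head_dim] for row in padded_out]          # drop the 8 padded columns

max_diff = max(abs(a - b) for ra, rb in zip(ref, sliced) for a, b in zip(ra, rb))
print(len(padded_out[0]), max_diff)
```

When head_dim is already a multiple of 16, `pad_to_multiple` returns the vector unchanged, matching the no-op behaviour described for aligned head dims.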

Full 27-layer ViT forward: 1.8x @512px, 2.6x @768px (grows with image size; O(N) memory also avoids the large-image OOM).

Scope

Vision-attention only. ViT LayerNorm/GEMM/Conv3d stay torch (aiter's Triton LayerNorm runs but is slower; gemm_a16w16 ≈ torch hipBLASLt; aiter CK/asm FMHA is CDNA-only and fails on gfx1151 — Triton is the only working flash path here).

On gfx1151 (RDNA3.5) torch's scaled_dot_product_attention falls back to the
unfused math backend for the vision encoder (flash/mem-efficient SDPA are
disabled on this arch), which is slow and O(N^2) memory -- large images OOM.

Route the ViT self-attention to aiter's Triton prefill kernel
(context_attention_fwd) instead, which runs on gfx1151 at the ViT head_dim
(72) and is a proper flash attention (O(N) memory). Per-image cu_seqlens are
built from grid_thw and threaded through VisionTransformer -> VisionBlock ->
VisionAttention, so each image attends only within its own patches (this also
fixes the previous full cross-image attention for multi-image inputs).

Gated on `not aiter_hip_kernels_supported()` so only non-gfx9 arches use the
Triton path; gfx9/CDNA keep SDPA (which has a fast flash backend there).
torch SDPA remains the fallback when aiter is unavailable.

Output is unchanged (validated: image color identification still correct).

Attention kernel speedup vs torch SDPA on a single Radeon 8060S (16 heads,
head_dim 72, bf16, non-causal):

  tokens  ~image   triton    sdpa   speedup
    1024   512px   0.89ms   4.33ms    4.9x
    2048   724px   2.99ms  16.58ms    5.6x
    4096  1024px  11.0ms   65.8ms     6.0x
    8192  1448px  45.3ms   287ms      6.4x

Full 27-layer ViT forward: 1.8x at 512px, 2.6x at 768px (grows with size).

The Triton prefill kernel is pathologically slow when head_dim is not a
multiple of 16 (the WMMA load granularity): at the Qwen3 ViT head_dim 72 it
runs ~2x slower than head_dim 80 or 128 due to unaligned masked head-dim
loads (measured on a Radeon 8060S: d72=11.0ms vs d80=5.4ms vs d128=6.7ms at
N=4096). It still pads internally to next_power_of_2=128 either way, so the
penalty is purely the unaligned actual head_dim.

Zero-pad q/k/v to the next multiple of 16 before the kernel and slice the
output back. The padded dims contribute nothing to QK^T / AV, so the result
is unchanged (max relerr unchanged, ~0.008 vs SDPA). ~2x across sizes:

  tokens   d72      pad80   speedup
   1024   0.91ms   0.45ms   2.03x
   2048   2.99ms   1.50ms   2.00x
   4096  11.0ms    5.34ms   2.06x

Combined with the Triton switch this is ~10-12x over the torch SDPA math
fallback. No-op when head_dim is already a multiple of 16.