feat(gfx1151): Triton flash attention for Qwen3.x vision encoder #1357
Open
carlushuang wants to merge 2 commits into
On gfx1151 (RDNA3.5) torch's scaled_dot_product_attention falls back to the
unfused math backend for the vision encoder (flash/mem-efficient SDPA are
disabled on this arch), which is slow and O(N^2) memory -- large images OOM.
Route the ViT self-attention to aiter's Triton prefill kernel
(context_attention_fwd) instead, which runs on gfx1151 at the ViT head_dim
(72) and is a proper flash attention (O(N) memory). Per-image cu_seqlens are
built from grid_thw and threaded through VisionTransformer -> VisionBlock ->
VisionAttention, so each image attends only within its own patches (this also
fixes the previous full cross-image attention for multi-image inputs).
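The per-image boundaries described above can be sketched in plain Python. This is only an illustration, not the PR's code: `build_cu_seqlens` is a hypothetical helper, and it assumes each image contributes `t * h * w` tokens for its `grid_thw` entry (the real model would hold these values in a tensor).

```python
from itertools import accumulate


def build_cu_seqlens(grid_thw):
    """Return cumulative token offsets, one segment per image.

    grid_thw: list of (t, h, w) patch-grid sizes, one tuple per image.
    Assumption (not confirmed by the source): an image owns t * h * w tokens.
    """
    lengths = [t * h * w for t, h, w in grid_thw]
    # A leading 0 makes image i own tokens [cu[i], cu[i + 1]).
    return [0] + list(accumulate(lengths))


# Two images: a 32x32 patch grid and a 16x8 patch grid.
cu_seqlens = build_cu_seqlens([(1, 32, 32), (1, 16, 8)])

# Each (start, end) pair is the token range one image may attend within.
segments = list(zip(cu_seqlens[:-1], cu_seqlens[1:]))
```

A varlen attention kernel that receives these offsets restricts every query to keys in its own segment, which is what keeps images from attending to each other.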
Gated on `not aiter_hip_kernels_supported()` so only non-gfx9 arches use the
Triton path; gfx9/CDNA keep SDPA (which has a fast flash backend there).
torch SDPA remains the fallback when aiter is unavailable.
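The dispatch rule above amounts to a small decision table. The sketch below is hypothetical: the function name and the returned strings are made up for illustration, and the two booleans stand in for "aiter imports successfully" and the result of `aiter_hip_kernels_supported()`, whose real signature is not shown in this PR text.

```python
def select_vit_attention_backend(aiter_available: bool,
                                 hip_kernels_supported: bool) -> str:
    """Pick the attention path for the vision encoder (illustrative only)."""
    if not aiter_available:
        # No aiter at all: torch SDPA is the fallback everywhere.
        return "torch_sdpa"
    if hip_kernels_supported:
        # gfx9/CDNA: SDPA already has a fast flash backend there.
        return "torch_sdpa"
    # Non-gfx9 arches such as gfx1151: use the Triton prefill kernel.
    return "aiter_triton"


backend_on_gfx1151 = select_vit_attention_backend(True, False)
backend_on_gfx9 = select_vit_attention_backend(True, True)
backend_without_aiter = select_vit_attention_backend(False, False)
```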
Output is unchanged (validated: image color identification still correct).
Attention kernel speedup vs torch SDPA on a single Radeon 8060S (16 heads,
head_dim 72, bf16, non-causal):
tokens  ~image  triton   sdpa     speedup
1024    512px   0.89ms   4.33ms   4.9x
2048    724px   2.99ms   16.58ms  5.6x
4096    1024px  11.0ms   65.8ms   6.0x
8192    1448px  45.3ms   287ms    6.4x
Full 27-layer ViT forward: 1.8x at 512px, 2.6x at 768px (grows with size).
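To make the O(N^2) memory cost concrete, here is a rough back-of-the-envelope sketch. It assumes the unfused path materializes one N x N score matrix per head, for all 16 heads at once, in bf16 (2 bytes per element); the real math backend may use different intermediates, so treat the numbers as an order-of-magnitude estimate only.

```python
def score_matrix_bytes(num_tokens, num_heads=16, bytes_per_elem=2):
    """Bytes needed to hold the full attention score matrix for one layer.

    Assumption: scores for every head are materialized at the same time.
    """
    return num_heads * num_tokens * num_tokens * bytes_per_elem


# Token counts taken from the benchmark table above.
estimates_mib = {
    n: score_matrix_bytes(n) / 2**20 for n in (1024, 2048, 4096, 8192)
}

for tokens, mib in estimates_mib.items():
    print(f"{tokens:>5} tokens: {mib:>6.0f} MiB of scores per layer")
```

Doubling the token count quadruples this figure, whereas a flash-attention kernel never stores the full score matrix.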
The Triton prefill kernel is pathologically slow when head_dim is not a
multiple of 16 (the WMMA load granularity): at the Qwen3 ViT head_dim 72 it
runs ~2x slower than head_dim 80 or 128 due to unaligned masked head-dim
loads (measured on a Radeon 8060S: d72=11.0ms vs d80=5.4ms vs d128=6.7ms at
N=4096). It still pads internally to next_power_of_2=128 either way, so the
penalty is purely the unaligned actual head_dim.
Zero-pad q/k/v to the next multiple of 16 before the kernel and slice the
output back. The padded dims contribute nothing to QK^T / AV, so the result
is unchanged (max relerr unchanged, ~0.008 vs SDPA). ~2x across sizes:
tokens  d72      pad80    speedup
1024    0.91ms   0.45ms   2.03x
2048    2.99ms   1.50ms   2.00x
4096    11.0ms   5.34ms   2.06x
Combined with the Triton switch this is ~10-12x over the torch SDPA math
fallback. No-op when head_dim is already a multiple of 16.
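A minimal sketch of why the zero-padding is harmless, using plain Python lists rather than tensors. The helper names are hypothetical and this is not the PR's implementation; it only shows the rounding rule and that zero columns add nothing to a q·k dot product.

```python
def padded_head_dim(head_dim, multiple=16):
    """Round head_dim up to the next multiple of `multiple`."""
    return -(-head_dim // multiple) * multiple


def pad_vector(vec, target_len):
    """Append zeros so the vector has target_len entries."""
    return list(vec) + [0.0] * (target_len - len(vec))


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


head_dim = 72
padded = padded_head_dim(head_dim)

# Arbitrary example query/key rows of length head_dim.
q = [0.5 * i for i in range(head_dim)]
k = [1.0 - 0.25 * i for i in range(head_dim)]

score_unpadded = dot(q, k)
score_padded = dot(pad_vector(q, padded), pad_vector(k, padded))

# Zero-padded V columns yield zero output columns, which are sliced off.
out_padded = pad_vector([1.0] * head_dim, padded)
out = out_padded[:head_dim]
```

One thing to watch in a real implementation: the softmax scale should still be derived from the original head_dim (72), not the padded one. The text above does not say how the kernel is given its scale, so that detail is an open assumption here.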
feat(gfx1151): Triton flash attention for the Qwen3.x vision encoder
On gfx1151 (RDNA3.5 / Radeon 8060S) torch `scaled_dot_product_attention` falls back to the unfused math backend for the ViT, because flash/mem-efficient SDPA are disabled on this arch (AOTRITON "experimental" path). That's slow and O(N²) in memory, so large images can OOM.

This routes the vision self-attention to aiter's Triton prefill kernel (`context_attention_fwd`), which:
- runs on gfx1151 at the ViT `head_dim=72` (non-power-of-2),
- is a proper flash attention (O(N) memory),
- is used only on non-gfx9 arches (gated on `not aiter_hip_kernels_supported()`); gfx9/CDNA keep SDPA (fast flash there), and torch SDPA stays as the fallback when aiter is unavailable.

Per-image `cu_seqlens` are built from `grid_thw` and threaded through `VisionTransformer → VisionBlock → VisionAttention`, so each image attends only within its own patches. This also fixes the previous full cross-image attention for multi-image inputs.

Correctness
Output unchanged — image color-identification test still correct (Red/Blue/Green). Numerically matches SDPA (max relerr < 0.008 across sizes).
Performance (single Radeon 8060S, gfx1151)
Attention kernel, 16 heads / head_dim 72 / bf16 / non-causal:
head_dim is also zero-padded 72→80 (next multiple of 16) — the Triton kernel is ~2x slower on the unaligned head_dim 72 (masked head loads) than on a 16-aligned dim, for identical results. Final attention kernel:
(the head_dim padding alone is ~2.0x; combined with the Triton switch it's ~10–12x over the SDPA math fallback).
Full 27-layer ViT forward: 1.8x @512px, 2.6x @768px (grows with image size; O(N) memory also avoids the large-image OOM).
Scope
Vision-attention only. ViT LayerNorm/GEMM/Conv3d stay torch (aiter's Triton LayerNorm runs but is slower; `gemm_a16w16` ≈ torch hipBLASLt; aiter CK/asm FMHA is CDNA-only and fails on gfx1151, so Triton is the only working flash path here).