diff --git a/atom/model_ops/vit_attention.py b/atom/model_ops/vit_attention.py new file mode 100644 index 000000000..9f6aa8292 --- /dev/null +++ b/atom/model_ops/vit_attention.py @@ -0,0 +1,175 @@ +import torch +import math +import triton +import triton.language as tl + + +@triton.jit +def _vit_attn_varlen( + Q, + K, + V, + Out, + B_Start, + B_Seqlen, + sm_scale, + sq_n, + sq_h, + sk_n, + sk_h, + sv_n, + sv_h, + so_n, + so_h, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_m = tl.program_id(2) + seqlen = tl.load(B_Seqlen + cur_b) + start = tl.load(B_Start + cur_b) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_t = tl.arange(0, 16) + m_mask = offs_m < seqlen + qb = Q + (start + offs_m)[:, None] * sq_n + pid_h * sq_h + d0 = 0 * 16 + offs_t + q0 = tl.load( + qb + d0[None, :], mask=m_mask[:, None] & (d0[None, :] < HEAD_DIM), other=0.0 + ) + a0 = tl.zeros([BLOCK_M, 16], tl.float32) + d1 = 1 * 16 + offs_t + q1 = tl.load( + qb + d1[None, :], mask=m_mask[:, None] & (d1[None, :] < HEAD_DIM), other=0.0 + ) + a1 = tl.zeros([BLOCK_M, 16], tl.float32) + d2 = 2 * 16 + offs_t + q2 = tl.load( + qb + d2[None, :], mask=m_mask[:, None] & (d2[None, :] < HEAD_DIM), other=0.0 + ) + a2 = tl.zeros([BLOCK_M, 16], tl.float32) + d3 = 3 * 16 + offs_t + q3 = tl.load( + qb + d3[None, :], mask=m_mask[:, None] & (d3[None, :] < HEAD_DIM), other=0.0 + ) + a3 = tl.zeros([BLOCK_M, 16], tl.float32) + d4 = 4 * 16 + offs_t + q4 = tl.load( + qb + d4[None, :], mask=m_mask[:, None] & (d4[None, :] < HEAD_DIM), other=0.0 + ) + a4 = tl.zeros([BLOCK_M, 16], tl.float32) + m_i = tl.full([BLOCK_M], -float("inf"), tl.float32) + l_i = tl.zeros([BLOCK_M], tl.float32) + for n0 in range(0, seqlen, BLOCK_N): + offs_n = n0 + tl.arange(0, BLOCK_N) + nmask = offs_n < seqlen + kb = K + (start + offs_n)[None, :] * sk_n + pid_h * sk_h + vb = V + (start + offs_n)[:, None] * sv_n + pid_h * sv_h + qk = tl.zeros([BLOCK_M, BLOCK_N], tl.float32) + k0 = tl.load( + kb + d0[:, None], mask=nmask[None, :] & (d0[:, None] < HEAD_DIM), other=0.0 + ) + qk += tl.dot(q0, k0) + k1 = tl.load( + kb + d1[:, None], mask=nmask[None, :] & (d1[:, None] < HEAD_DIM), other=0.0 + ) + qk += tl.dot(q1, k1) + k2 = tl.load( + kb + d2[:, None], mask=nmask[None, :] & (d2[:, None] < HEAD_DIM), other=0.0 + ) + qk += tl.dot(q2, k2) + k3 = tl.load( + kb + d3[:, None], mask=nmask[None, :] & (d3[:, None] < HEAD_DIM), other=0.0 + ) + qk += tl.dot(q3, k3) + k4 = tl.load( + kb + d4[:, None], mask=nmask[None, :] & (d4[:, None] < HEAD_DIM), other=0.0 + ) + qk += tl.dot(q4, k4) + qk = qk * sm_scale + tl.where(nmask[None, :], 0.0, -float("inf")) + m_new = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.exp(qk - m_new[:, None]) + alpha = tl.exp(m_i - m_new) + l_i = l_i * alpha + tl.sum(p, 1) + p = p.to(V.dtype.element_ty) + v0 = tl.load( + vb + d0[None, :], mask=nmask[:, None] & (d0[None, :] < HEAD_DIM), other=0.0 + ) + a0 = a0 * alpha[:, None] + tl.dot(p, v0) + v1 = tl.load( + vb + d1[None, :], mask=nmask[:, None] & (d1[None, :] < HEAD_DIM), other=0.0 + ) + a1 = a1 * alpha[:, None] + tl.dot(p, v1) + v2 = tl.load( + vb + d2[None, :], mask=nmask[:, None] & (d2[None, :] < HEAD_DIM), other=0.0 + ) + a2 = a2 * alpha[:, None] + tl.dot(p, v2) + v3 = tl.load( + vb + d3[None, :], mask=nmask[:, None] & (d3[None, :] < HEAD_DIM), other=0.0 + ) + a3 = a3 * alpha[:, None] + tl.dot(p, v3) + v4 = tl.load( + vb + d4[None, :], mask=nmask[:, None] & (d4[None, :] < HEAD_DIM), other=0.0 + ) + a4 = a4 * alpha[:, None] + tl.dot(p, v4) + m_i = m_new + ob = Out + (start + offs_m)[:, None] * so_n + pid_h * so_h + tl.store( + ob + d0[None, :], + (a0 / l_i[:, None]).to(Out.dtype.element_ty), + mask=m_mask[:, None] & (d0[None, :] < HEAD_DIM), + ) + tl.store( + ob + d1[None, :], + (a1 / l_i[:, None]).to(Out.dtype.element_ty), + mask=m_mask[:, None] & (d1[None, :] < HEAD_DIM), + ) + tl.store( + ob + d2[None, :], + (a2 / l_i[:, None]).to(Out.dtype.element_ty), + mask=m_mask[:, None] & (d2[None, :] < HEAD_DIM), + ) + tl.store( + ob + d3[None, :], + (a3 / l_i[:, None]).to(Out.dtype.element_ty), + mask=m_mask[:, None] & (d3[None, :] < HEAD_DIM), + ) + tl.store( + ob + d4[None, :], + (a4 / l_i[:, None]).to(Out.dtype.element_ty), + mask=m_mask[:, None] & (d4[None, :] < HEAD_DIM), + ) + + +def vit_flash_attn( + q, k, v, b_start_loc, b_seq_len, max_seqlen, BLOCK_M=128, BLOCK_N=32 +): + # q,k,v: (total_tokens, num_heads, head_dim) contiguous. Per-image varlen via cu_seqlens. + N, H, D = q.shape + assert D <= 80, "this kernel tiles head_dim into 5x16=80" + o = torch.empty_like(q) + batch = b_seq_len.shape[0] + grid = (batch, H, triton.cdiv(max_seqlen, BLOCK_M)) + _vit_attn_varlen[grid]( + q, + k, + v, + o, + b_start_loc, + b_seq_len, + 1.0 / math.sqrt(D), + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + v.stride(0), + v.stride(1), + o.stride(0), + o.stride(1), + HEAD_DIM=D, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_warps=4, + ) + return o diff --git a/atom/models/qwen3_5_vl.py b/atom/models/qwen3_5_vl.py index 8f028b734..9e7b34853 100644 --- a/atom/models/qwen3_5_vl.py +++ b/atom/models/qwen3_5_vl.py @@ -15,6 +15,19 @@ import torch.nn as nn import torch.nn.functional as F +# gfx1151 (RDNA3.5): torch SDPA falls back to the unfused math backend for the ViT +# (flash/mem-efficient are disabled). Use a custom Triton flash attention that +# tiles head_dim into 5x16=80 (vs padding to 128) -> ~24 TFLOPS, ~20x over SDPA. +# On gfx9/CDNA SDPA already has a fast flash backend, so keep SDPA there. +try: + from atom.model_ops.vit_attention import vit_flash_attn + from atom.utils.arch import aiter_hip_kernels_supported + + _USE_TRITON_VIT_ATTN = not aiter_hip_kernels_supported() +except Exception: # pragma: no cover - fallback if aiter/arch unavailable + vit_flash_attn = None + _USE_TRITON_VIT_ATTN = False + class Qwen3VisionPatchEmbed(nn.Module): def __init__( @@ -79,6 +92,7 @@ def forward( x: torch.Tensor, rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, + cu_seqlens=None, ) -> torch.Tensor: # x: [seq_len, 1, embed_dim] (the VisionBlock adds a batch dim) seq_len = x.shape[0] @@ -93,13 +107,32 @@ def forward( q = self._apply_rotary_emb(q, rotary_pos_emb_cos, rotary_pos_emb_sin) k = self._apply_rotary_emb(k, rotary_pos_emb_cos, rotary_pos_emb_sin) - # Reshape for SDPA: [batch, heads, seq_len, head_dim] - q = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, seq_len, head_dim] - k = k.unsqueeze(0).transpose(1, 2) - v = v.unsqueeze(0).transpose(1, 2) - - out = F.scaled_dot_product_attention(q, k, v) - out = out.transpose(1, 2).reshape(seq_len, self.embed_dim) + if ( + _USE_TRITON_VIT_ATTN + and vit_flash_attn is not None + and cu_seqlens is not None + and self.head_dim <= 80 + ): + # Custom head-dim-tiled flash attention (per-image varlen). Tiles + # head_dim into 5x16=80 instead of padding to 128, so the QK/AV + # contraction is 80-deep (1.6x fewer WMMA k-steps). No external pad. + b_start_loc, b_seq_len, max_seqlen = cu_seqlens + out = vit_flash_attn( + q.contiguous(), + k.contiguous(), + v.contiguous(), + b_start_loc, + b_seq_len, + max_seqlen, + ) + out = out.reshape(seq_len, self.embed_dim) + else: + # Reshape for SDPA: [batch, heads, seq_len, head_dim] + qh = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, seq_len, head_dim] + kh = k.unsqueeze(0).transpose(1, 2) + vh = v.unsqueeze(0).transpose(1, 2) + out = F.scaled_dot_product_attention(qh, kh, vh) + out = out.transpose(1, 2).reshape(seq_len, self.embed_dim) out = self.proj(out) return out.view(seq_len, batch, self.embed_dim) @@ -144,11 +177,13 @@ def forward( x: torch.Tensor, rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, + cu_seqlens=None, ) -> torch.Tensor: x = x + self.attn( self.norm1(x), rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, + cu_seqlens=cu_seqlens, ) x = x + self.mlp(self.norm2(x)) return x @@ -387,11 +422,25 @@ def forward( # Rotary position embeddings rotary_cos, rotary_sin = self._compute_rotary_emb(grid_thw_list) + # Per-image attention segments (cu_seqlens) for the Triton prefill kernel: + # each image attends only within its own t*h*w patches. + seqlens = [int(t) * int(h) * int(w) for t, h, w in grid_thw_list] + b_seq_len = torch.tensor( + seqlens, dtype=torch.int32, device=hidden_states.device + ) + b_start_loc = torch.zeros( + len(seqlens), dtype=torch.int32, device=hidden_states.device + ) + if len(seqlens) > 1: + b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], dim=0).to(torch.int32) + cu_seqlens = (b_start_loc, b_seq_len, max(seqlens) if seqlens else 0) + for blk in self.blocks: hidden_states = blk( hidden_states, rotary_pos_emb_cos=rotary_cos, rotary_pos_emb_sin=rotary_sin, + cu_seqlens=cu_seqlens, ) hidden_states = self.merger(hidden_states)