Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions atom/model_ops/vit_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import torch
import math
import triton
import triton.language as tl


@triton.jit
def _vit_attn_varlen(
    Q,
    K,
    V,
    Out,
    B_Start,
    B_Seqlen,
    sm_scale,
    sq_n,
    sq_h,
    sk_n,
    sk_h,
    sv_n,
    sv_h,
    so_n,
    so_h,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Non-causal variable-length attention with head_dim split into 5 chunks of 16.

    Program ids: axis 0 selects the sequence (``cur_b``), axis 1 the head,
    axis 2 the BLOCK_M-row query tile inside that sequence.

    Q/K/V/Out are addressed as ``base + token * s*_n + head * s*_h + d``; the
    head_dim index ``d`` is added without a stride, so the last dimension must
    have unit stride.  ``B_Start[b]`` is the first token of sequence ``b`` and
    ``B_Seqlen[b]`` its length.  Channels ``d >= HEAD_DIM`` are masked on every
    load and store, so any HEAD_DIM <= 80 is handled by the five chunks.

    Softmax is computed online: ``m_i`` holds the running row maximum and
    ``l_i`` the running row sum, with accumulators rescaled by ``alpha`` each
    time the maximum grows.
    """
    cur_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_m = tl.program_id(2)
    seqlen = tl.load(B_Seqlen + cur_b)
    start = tl.load(B_Start + cur_b)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_t = tl.arange(0, 16)
    m_mask = offs_m < seqlen

    # Query tile, loaded once as five [BLOCK_M, 16] chunks, each paired with
    # its own fp32 output accumulator.
    qb = Q + (start + offs_m)[:, None] * sq_n + pid_h * sq_h
    d0 = 0 * 16 + offs_t
    q0 = tl.load(
        qb + d0[None, :], mask=m_mask[:, None] & (d0[None, :] < HEAD_DIM), other=0.0
    )
    a0 = tl.zeros([BLOCK_M, 16], tl.float32)
    d1 = 1 * 16 + offs_t
    q1 = tl.load(
        qb + d1[None, :], mask=m_mask[:, None] & (d1[None, :] < HEAD_DIM), other=0.0
    )
    a1 = tl.zeros([BLOCK_M, 16], tl.float32)
    d2 = 2 * 16 + offs_t
    q2 = tl.load(
        qb + d2[None, :], mask=m_mask[:, None] & (d2[None, :] < HEAD_DIM), other=0.0
    )
    a2 = tl.zeros([BLOCK_M, 16], tl.float32)
    d3 = 3 * 16 + offs_t
    q3 = tl.load(
        qb + d3[None, :], mask=m_mask[:, None] & (d3[None, :] < HEAD_DIM), other=0.0
    )
    a3 = tl.zeros([BLOCK_M, 16], tl.float32)
    d4 = 4 * 16 + offs_t
    q4 = tl.load(
        qb + d4[None, :], mask=m_mask[:, None] & (d4[None, :] < HEAD_DIM), other=0.0
    )
    a4 = tl.zeros([BLOCK_M, 16], tl.float32)

    m_i = tl.full([BLOCK_M], -float("inf"), tl.float32)
    l_i = tl.zeros([BLOCK_M], tl.float32)

    # The launch grid is sized for the longest sequence, so for shorter ones
    # the trailing query tiles have no valid row (m_mask is all False) and
    # every store below is masked out.  Give those tiles an empty key range
    # instead of letting them walk all of K/V for nothing.
    has_rows = tl.where(pid_m * BLOCK_M < seqlen, 1, 0)
    for n0 in range(0, has_rows * seqlen, BLOCK_N):
        offs_n = n0 + tl.arange(0, BLOCK_N)
        nmask = offs_n < seqlen
        # K is addressed transposed ([16, BLOCK_N]) so it can be the right-hand
        # operand of the QK^T dot; V is addressed as [BLOCK_N, 16].
        kb = K + (start + offs_n)[None, :] * sk_n + pid_h * sk_h
        vb = V + (start + offs_n)[:, None] * sv_n + pid_h * sv_h

        qk = tl.zeros([BLOCK_M, BLOCK_N], tl.float32)
        k0 = tl.load(
            kb + d0[:, None], mask=nmask[None, :] & (d0[:, None] < HEAD_DIM), other=0.0
        )
        qk += tl.dot(q0, k0)
        k1 = tl.load(
            kb + d1[:, None], mask=nmask[None, :] & (d1[:, None] < HEAD_DIM), other=0.0
        )
        qk += tl.dot(q1, k1)
        k2 = tl.load(
            kb + d2[:, None], mask=nmask[None, :] & (d2[:, None] < HEAD_DIM), other=0.0
        )
        qk += tl.dot(q2, k2)
        k3 = tl.load(
            kb + d3[:, None], mask=nmask[None, :] & (d3[:, None] < HEAD_DIM), other=0.0
        )
        qk += tl.dot(q3, k3)
        k4 = tl.load(
            kb + d4[:, None], mask=nmask[None, :] & (d4[:, None] < HEAD_DIM), other=0.0
        )
        qk += tl.dot(q4, k4)

        # Scale the scores and push out-of-range key columns to -inf so they
        # contribute exp(-inf) = 0 to the softmax.
        qk = qk * sm_scale + tl.where(nmask[None, :], 0.0, -float("inf"))
        m_new = tl.maximum(m_i, tl.max(qk, 1))
        p = tl.exp(qk - m_new[:, None])
        alpha = tl.exp(m_i - m_new)
        l_i = l_i * alpha + tl.sum(p, 1)
        # Cast probabilities to V's element type for the second dot.
        p = p.to(V.dtype.element_ty)

        v0 = tl.load(
            vb + d0[None, :], mask=nmask[:, None] & (d0[None, :] < HEAD_DIM), other=0.0
        )
        a0 = a0 * alpha[:, None] + tl.dot(p, v0)
        v1 = tl.load(
            vb + d1[None, :], mask=nmask[:, None] & (d1[None, :] < HEAD_DIM), other=0.0
        )
        a1 = a1 * alpha[:, None] + tl.dot(p, v1)
        v2 = tl.load(
            vb + d2[None, :], mask=nmask[:, None] & (d2[None, :] < HEAD_DIM), other=0.0
        )
        a2 = a2 * alpha[:, None] + tl.dot(p, v2)
        v3 = tl.load(
            vb + d3[None, :], mask=nmask[:, None] & (d3[None, :] < HEAD_DIM), other=0.0
        )
        a3 = a3 * alpha[:, None] + tl.dot(p, v3)
        v4 = tl.load(
            vb + d4[None, :], mask=nmask[:, None] & (d4[None, :] < HEAD_DIM), other=0.0
        )
        a4 = a4 * alpha[:, None] + tl.dot(p, v4)
        m_i = m_new

    # Normalise by the softmax denominator and write back.  Rows past seqlen
    # have l_i == 0, but they are excluded by m_mask so nothing is stored.
    ob = Out + (start + offs_m)[:, None] * so_n + pid_h * so_h
    tl.store(
        ob + d0[None, :],
        (a0 / l_i[:, None]).to(Out.dtype.element_ty),
        mask=m_mask[:, None] & (d0[None, :] < HEAD_DIM),
    )
    tl.store(
        ob + d1[None, :],
        (a1 / l_i[:, None]).to(Out.dtype.element_ty),
        mask=m_mask[:, None] & (d1[None, :] < HEAD_DIM),
    )
    tl.store(
        ob + d2[None, :],
        (a2 / l_i[:, None]).to(Out.dtype.element_ty),
        mask=m_mask[:, None] & (d2[None, :] < HEAD_DIM),
    )
    tl.store(
        ob + d3[None, :],
        (a3 / l_i[:, None]).to(Out.dtype.element_ty),
        mask=m_mask[:, None] & (d3[None, :] < HEAD_DIM),
    )
    tl.store(
        ob + d4[None, :],
        (a4 / l_i[:, None]).to(Out.dtype.element_ty),
        mask=m_mask[:, None] & (d4[None, :] < HEAD_DIM),
    )


def vit_flash_attn(
    q, k, v, b_start_loc, b_seq_len, max_seqlen, BLOCK_M=128, BLOCK_N=32
):
    """Run the head-dim-tiled Triton attention kernel over packed sequences.

    Args:
        q, k, v: tensors indexed as (total_tokens, num_heads, head_dim).  The
            kernel indexes head_dim with an implicit stride of 1, so inputs
            whose last dimension is not unit-stride are copied to contiguous
            storage first.
        b_start_loc: per-sequence first-token offset into the token axis.
        b_seq_len: per-sequence token count; its length is the batch size.
        max_seqlen: longest sequence length, used to size the query-tile axis
            of the launch grid.
        BLOCK_M: query rows handled by one program.
        BLOCK_N: key/value rows consumed per inner-loop step.

    Returns:
        A tensor shaped like ``q`` holding the attention output.

    Raises:
        ValueError: if head_dim exceeds 80, the most the kernel's five
            16-wide chunks can cover.
    """
    N, H, D = q.shape
    # Raised explicitly rather than asserted: under ``python -O`` an assert is
    # stripped and channels >= 80 of the output would stay uninitialised.
    if D > 80:
        raise ValueError(
            f"head_dim={D} is unsupported: this kernel tiles head_dim into 5x16=80"
        )

    # The kernel adds the head_dim index to the pointer without a stride.
    q, k, v = (t if t.stride(-1) == 1 else t.contiguous() for t in (q, k, v))

    o = torch.empty_like(q)
    batch = b_seq_len.shape[0]
    # Nothing to compute: skip building a launch grid with a zero-sized axis.
    if N == 0 or batch == 0 or max_seqlen <= 0:
        return o

    grid = (batch, H, triton.cdiv(max_seqlen, BLOCK_M))
    _vit_attn_varlen[grid](
        q,
        k,
        v,
        o,
        b_start_loc,
        b_seq_len,
        1.0 / math.sqrt(D),
        q.stride(0),
        q.stride(1),
        k.stride(0),
        k.stride(1),
        v.stride(0),
        v.stride(1),
        o.stride(0),
        o.stride(1),
        HEAD_DIM=D,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=4,
    )
    return o
63 changes: 56 additions & 7 deletions atom/models/qwen3_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@
import torch.nn as nn
import torch.nn.functional as F

# gfx1151 (RDNA3.5): torch SDPA falls back to the unfused math backend for the ViT
# (flash/mem-efficient are disabled). Use a custom Triton flash attention that
# tiles head_dim into 5x16=80 (vs padding to 128) -> ~24 TFLOPS, ~20x over SDPA.
# On gfx9/CDNA SDPA already has a fast flash backend, so keep SDPA there.
try:
    from atom.model_ops.vit_attention import vit_flash_attn
    from atom.utils.arch import aiter_hip_kernels_supported

    _USE_TRITON_VIT_ATTN = not aiter_hip_kernels_supported()
except Exception as exc:  # pragma: no cover - fallback if aiter/arch unavailable
    # Deliberate best-effort: any failure here just disables the Triton path
    # and leaves SDPA in use.  Record the reason so a missing or broken
    # dependency can be diagnosed instead of vanishing silently.
    import logging

    logging.getLogger(__name__).debug(
        "Triton ViT attention disabled, falling back to SDPA: %r", exc
    )
    vit_flash_attn = None
    _USE_TRITON_VIT_ATTN = False


class Qwen3VisionPatchEmbed(nn.Module):
def __init__(
Expand Down Expand Up @@ -79,6 +92,7 @@ def forward(
x: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
cu_seqlens=None,
) -> torch.Tensor:
# x: [seq_len, 1, embed_dim] (the VisionBlock adds a batch dim)
seq_len = x.shape[0]
Expand All @@ -93,13 +107,32 @@ def forward(
q = self._apply_rotary_emb(q, rotary_pos_emb_cos, rotary_pos_emb_sin)
k = self._apply_rotary_emb(k, rotary_pos_emb_cos, rotary_pos_emb_sin)

# Reshape for SDPA: [batch, heads, seq_len, head_dim]
q = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, seq_len, head_dim]
k = k.unsqueeze(0).transpose(1, 2)
v = v.unsqueeze(0).transpose(1, 2)

out = F.scaled_dot_product_attention(q, k, v)
out = out.transpose(1, 2).reshape(seq_len, self.embed_dim)
if (
_USE_TRITON_VIT_ATTN
and vit_flash_attn is not None
and cu_seqlens is not None
and self.head_dim <= 80
):
# Custom head-dim-tiled flash attention (per-image varlen). Tiles
# head_dim into 5x16=80 instead of padding to 128, so the QK/AV
# contraction is 80-deep (1.6x fewer WMMA k-steps). No external pad.
b_start_loc, b_seq_len, max_seqlen = cu_seqlens
out = vit_flash_attn(
q.contiguous(),
k.contiguous(),
v.contiguous(),
b_start_loc,
b_seq_len,
max_seqlen,
)
out = out.reshape(seq_len, self.embed_dim)
else:
# Reshape for SDPA: [batch, heads, seq_len, head_dim]
qh = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, seq_len, head_dim]
kh = k.unsqueeze(0).transpose(1, 2)
vh = v.unsqueeze(0).transpose(1, 2)
out = F.scaled_dot_product_attention(qh, kh, vh)
out = out.transpose(1, 2).reshape(seq_len, self.embed_dim)
out = self.proj(out)
return out.view(seq_len, batch, self.embed_dim)

Expand Down Expand Up @@ -144,11 +177,13 @@ def forward(
x: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
cu_seqlens=None,
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
cu_seqlens=cu_seqlens,
)
x = x + self.mlp(self.norm2(x))
return x
Expand Down Expand Up @@ -387,11 +422,25 @@ def forward(
# Rotary position embeddings
rotary_cos, rotary_sin = self._compute_rotary_emb(grid_thw_list)

# Per-image attention segments (cu_seqlens) for the Triton prefill kernel:
# each image attends only within its own t*h*w patches.
seqlens = [int(t) * int(h) * int(w) for t, h, w in grid_thw_list]
b_seq_len = torch.tensor(
seqlens, dtype=torch.int32, device=hidden_states.device
)
b_start_loc = torch.zeros(
len(seqlens), dtype=torch.int32, device=hidden_states.device
)
if len(seqlens) > 1:
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], dim=0).to(torch.int32)
cu_seqlens = (b_start_loc, b_seq_len, max(seqlens) if seqlens else 0)

for blk in self.blocks:
hidden_states = blk(
hidden_states,
rotary_pos_emb_cos=rotary_cos,
rotary_pos_emb_sin=rotary_sin,
cu_seqlens=cu_seqlens,
)

hidden_states = self.merger(hidden_states)
Expand Down
Loading