Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 144 additions & 41 deletions python/sglang/multimodal_gen/runtime/models/dits/zimage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import os
from typing import Any, List, Optional, Tuple

import torch
Expand All @@ -24,6 +25,12 @@
from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

# aiter is an optional dependency: record whether the import succeeded so the
# rest of the module can gate the aiter code path on AITER_AVAILABLE.
try:
    import aiter
except ImportError:
    AITER_AVAILABLE = False
else:
    AITER_AVAILABLE = True

logger = init_logger(__name__)

ADALN_EMBED_DIM = 256
Expand Down Expand Up @@ -142,53 +149,108 @@ def __init__(
causal=False,
)

@staticmethod
@torch.compiler.disable
def _call_fused_rope_rms(
    qkv: torch.Tensor,
    norm_q_weight: torch.Tensor,
    norm_k_weight: torch.Tensor,
    cos_sin: torch.Tensor,
    positions: torch.Tensor,
    num_tokens: int,
    num_heads_q: int,
    num_heads_k: int,
    num_heads_v: int,
    head_dim: int,
    is_neox_style: bool,
    eps: float,
) -> None:
    """Invoke ``aiter.fused_rope_rms`` outside of torch.compile tracing.

    The wrapper is decorated with ``torch.compiler.disable`` so the aiter
    call is not decomposed by torch.compile. All arguments are forwarded
    positionally and unchanged; nothing is returned, and the caller reads
    its results back out of ``qkv`` after the call.

    Args:
        qkv: Packed query/key/value tensor. The caller in this file passes
            a contiguous tensor viewed as
            ``(num_tokens, num_heads_q + num_heads_k + num_heads_v, head_dim)``.
        norm_q_weight: Weight of the query norm layer.
        norm_k_weight: Weight of the key norm layer.
        cos_sin: Rotary table; the caller builds it by concatenating cos and
            sin along the last dimension and casting to ``qkv.dtype``.
        positions: Per-token position indices (int64 at the call site).
        num_tokens: Number of rows in ``qkv`` (batch * sequence length at
            the call site).
        num_heads_q: Number of query heads.
        num_heads_k: Number of key heads.
        num_heads_v: Number of value heads.
        head_dim: Size of each attention head.
        is_neox_style: Rotary layout flag forwarded to the kernel.
        eps: Epsilon forwarded to the kernel (the caller passes the query
            norm's ``variance_epsilon``).

    Raises:
        RuntimeError: If the optional ``aiter`` package could not be
            imported when this module was loaded.
    """
    # aiter is imported optionally at module level; without this check a
    # direct call would surface as an opaque NameError on `aiter`.
    if not AITER_AVAILABLE:
        raise RuntimeError(
            "aiter is not available: the fused RoPE + RMSNorm path cannot "
            "be used. Install aiter or do not pass `aiter_prepared`."
        )

    aiter.fused_rope_rms(
        qkv,
        norm_q_weight,
        norm_k_weight,
        cos_sin,
        positions,
        num_tokens,
        num_heads_q,
        num_heads_k,
        num_heads_v,
        head_dim,
        is_neox_style,
        eps,
    )

def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
aiter_prepared: Optional[dict] = None,
):
q, _ = self.to_q(hidden_states)
k, _ = self.to_k(hidden_states)
v, _ = self.to_v(hidden_states)

q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
k = k.view(*k.shape[:-1], self.num_kv_heads, self.head_dim)
v = v.view(*v.shape[:-1], self.num_kv_heads, self.head_dim)

if self.qk_norm:
if (
q.is_cuda
and (self.norm_q.variance_epsilon == self.norm_k.variance_epsilon)
and can_use_fused_inplace_qknorm(self.head_dim)
):
q, k = apply_qk_norm(
q=q,
k=k,
q_norm=self.norm_q,
k_norm=self.norm_k,
head_dim=self.head_dim,
allow_inplace=True,
)
else:
q = self.norm_q(q)
k = self.norm_k(k)

if freqs_cis is not None:
cos, sin = freqs_cis
if q.is_cuda and q.shape == k.shape:
cos_sin_cache = torch.cat(
[
cos.to(dtype=torch.float32).contiguous(),
sin.to(dtype=torch.float32).contiguous(),
],
dim=-1,
)
q, k = apply_flashinfer_rope_qk_inplace(
q, k, cos_sin_cache, is_neox=False
)
else:
q = _apply_rotary_emb(q, cos, sin, is_neox_style=False)
k = _apply_rotary_emb(k, cos, sin, is_neox_style=False)
# Check if aiter parameters are pre-prepared (from outer layer)
use_aiter_fused = aiter_prepared is not None

if use_aiter_fused:
# Use pre-prepared parameters to avoid redundant computation
B, L = hidden_states.shape[0], hidden_states.shape[1]
num_tokens = B * L

# Concatenate Q, K, V: [B, L, (Hq*D + Hkv*D + Hkv*D)]
qkv = torch.cat([q, k, v], dim=-1)
# Reshape to [num_tokens, num_heads_total, head_dim]
qkv = qkv.view(num_tokens, self.num_heads + 2 * self.num_kv_heads, self.head_dim).contiguous()

# Convert cos_sin to match qkv dtype (done once per layer now)
cos_sin = aiter_prepared['cos_sin'].to(qkv.dtype)
positions = aiter_prepared['positions']

# Call aiter's fused kernel (INPLACE modification of qkv)
self._call_fused_rope_rms(
qkv,
self.norm_q.weight,
self.norm_k.weight,
cos_sin,
positions,
num_tokens,
self.num_heads,
self.num_kv_heads,
self.num_kv_heads,
self.head_dim,
False,
self.norm_q.variance_epsilon,
)

# Unpack QKV after fusion
q_size = self.num_heads * self.head_dim
k_size = self.num_kv_heads * self.head_dim
v_size = self.num_kv_heads * self.head_dim
qkv_flat = qkv.view(num_tokens, q_size + k_size + v_size)
q, k, v = qkv_flat.split([q_size, k_size, v_size], dim=-1)

# Reshape back to [B, L, num_heads, head_dim]
q = q.view(B, L, self.num_heads, self.head_dim)
k = k.view(B, L, self.num_kv_heads, self.head_dim)
v = v.view(B, L, self.num_kv_heads, self.head_dim)
else:
# Fallback to original implementation: separate QK norm and RoPE
q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
k = k.view(*k.shape[:-1], self.num_kv_heads, self.head_dim)
v = v.view(*v.shape[:-1], self.num_kv_heads, self.head_dim)

# Apply QK normalization
q = self.norm_q(q)
k = self.norm_k(k)

# Apply RoPE
if freqs_cis is not None:
cos, sin = freqs_cis
q, k = _apply_rotary_emb(
q, cos, sin, is_neox_style=False
), _apply_rotary_emb(k, cos, sin, is_neox_style=False)

hidden_states = self.attn(q, k, v)
hidden_states = hidden_states.flatten(2)
Expand Down Expand Up @@ -241,6 +303,7 @@ def forward(
x: torch.Tensor,
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
adaln_input: Optional[torch.Tensor] = None,
aiter_prepared: Optional[dict] = None,
):
if self.modulation:
assert adaln_input is not None
Expand All @@ -255,6 +318,7 @@ def forward(
attn_out = self.attention(
self.attention_norm1(x) * scale_msa,
freqs_cis=freqs_cis,
aiter_prepared=aiter_prepared,
)
x = x + gate_msa * self.attention_norm2(attn_out)

Expand All @@ -269,6 +333,7 @@ def forward(
attn_out = self.attention(
self.attention_norm1(x),
freqs_cis=freqs_cis,
aiter_prepared=aiter_prepared,
)
x = x + self.attention_norm2(attn_out)

Expand Down Expand Up @@ -610,8 +675,24 @@ def forward(

x = x.unsqueeze(0)
x_freqs_cis = x_freqs_cis

# Pre-prepare aiter parameters (done once for all layers)
use_aiter = (
os.environ.get("SGLANG_USE_AITER", "0") == "1"
and AITER_AVAILABLE
and x.is_cuda
)
x_aiter_prepared = None
if use_aiter and x_freqs_cis is not None:
cos, sin = x_freqs_cis
B, L = x.shape[0], x.shape[1]
x_aiter_prepared = {
'cos_sin': torch.cat([cos, sin], dim=-1).contiguous(),
'positions': torch.arange(L, device=x.device, dtype=torch.int64).repeat(B),
}

for layer in self.noise_refiner:
x = layer(x, x_freqs_cis, adaln_input)
x = layer(x, x_freqs_cis, adaln_input, aiter_prepared=x_aiter_prepared)

cap_feats = torch.cat(cap_feats, dim=0)

Expand All @@ -620,17 +701,38 @@ def forward(
cap_freqs_cis = freqs_cis[0]

cap_feats = cap_feats.unsqueeze(0)

# Pre-prepare aiter parameters for caption features
cap_aiter_prepared = None
if use_aiter and cap_freqs_cis is not None:
cos, sin = cap_freqs_cis
B, L = cap_feats.shape[0], cap_feats.shape[1]
cap_aiter_prepared = {
'cos_sin': torch.cat([cos, sin], dim=-1).contiguous(),
'positions': torch.arange(L, device=cap_feats.device, dtype=torch.int64).repeat(B),
}

for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_freqs_cis)
cap_feats = layer(cap_feats, cap_freqs_cis, aiter_prepared=cap_aiter_prepared)

unified = torch.cat([x, cap_feats], dim=1)
unified_freqs_cis = (
torch.cat([x_freqs_cis[0], cap_freqs_cis[0]], dim=0),
torch.cat([x_freqs_cis[1], cap_freqs_cis[1]], dim=0),
)

# Pre-prepare aiter parameters for unified layers
unified_aiter_prepared = None
if use_aiter and unified_freqs_cis is not None:
cos, sin = unified_freqs_cis
B, L = unified.shape[0], unified.shape[1]
unified_aiter_prepared = {
'cos_sin': torch.cat([cos, sin], dim=-1).contiguous(),
'positions': torch.arange(L, device=unified.device, dtype=torch.int64).repeat(B),
}

for layer in self.layers:
unified = layer(unified, unified_freqs_cis, adaln_input)
unified = layer(unified, unified_freqs_cis, adaln_input, aiter_prepared=unified_aiter_prepared)

unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](
unified, adaln_input
Expand All @@ -642,3 +744,4 @@ def forward(


# Module-level alias bound to ZImageTransformer2DModel.
# NOTE(review): presumably the name the model loader/registry looks up to
# find this module's model class — confirm against the loader.
EntryClass = ZImageTransformer2DModel