diff --git a/python/sglang/multimodal_gen/runtime/models/dits/zimage.py b/python/sglang/multimodal_gen/runtime/models/dits/zimage.py
index b14b634960e7..72aff751a066 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/zimage.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/zimage.py
@@ -1,4 +1,5 @@
 import math
+import os
 from typing import Any, List, Optional, Tuple
 
 import torch
@@ -24,6 +25,41 @@
 from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin
 from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
 
+try:
+    import aiter
+    AITER_AVAILABLE = True
+except ImportError:
+    AITER_AVAILABLE = False
+
+
+def _build_aiter_rope_inputs(
+    tokens: torch.Tensor,
+    freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]],
+) -> Optional[dict]:
+    """Build the inputs shared by every layer of one aiter-fused stage.
+
+    Returns ``None`` (selecting the regular QK-norm + RoPE path) unless aiter is
+    importable, ``SGLANG_USE_AITER=1`` is set, ``tokens`` lives on a GPU and
+    rotary tables are provided. ``tokens`` is expected to be ``[B, L, ...]``.
+    """
+    if (
+        not AITER_AVAILABLE
+        or freqs_cis is None
+        or not tokens.is_cuda
+        or os.environ.get("SGLANG_USE_AITER", "0") != "1"
+    ):
+        return None
+    cos, sin = freqs_cis
+    batch, seq_len = tokens.shape[0], tokens.shape[1]
+    positions = torch.arange(seq_len, device=tokens.device, dtype=torch.int64)
+    return {
+        # Cast once here so the per-layer ``.to(dtype)`` is normally a no-op.
+        "cos_sin": torch.cat([cos, sin], dim=-1).to(tokens.dtype).contiguous(),
+        # Token i of every batch row sits at sequence position i.
+        "positions": positions.repeat(batch),
+    }
+
+
 logger = init_logger(__name__)
 
 ADALN_EMBED_DIM = 256
@@ -142,53 +178,129 @@ def __init__(
             causal=False,
         )
 
+    @staticmethod
+    @torch.compiler.disable
+    def _call_fused_rope_rms(
+        qkv: torch.Tensor,
+        norm_q_weight: torch.Tensor,
+        norm_k_weight: torch.Tensor,
+        cos_sin: torch.Tensor,
+        positions: torch.Tensor,
+        num_tokens: int,
+        num_heads_q: int,
+        num_heads_k: int,
+        num_heads_v: int,
+        head_dim: int,
+        is_neox_style: bool,
+        eps: float,
+    ) -> None:
+        """Run aiter.fused_rope_rms (in place on ``qkv``) outside torch.compile."""
+        aiter.fused_rope_rms(
+            qkv,
+            norm_q_weight,
+            norm_k_weight,
+            cos_sin,
+            positions,
+            num_tokens,
+            num_heads_q,
+            num_heads_k,
+            num_heads_v,
+            head_dim,
+            is_neox_style,
+            eps,
+        )
+
+    def _aiter_fused_qk_norm_rope(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        aiter_prepared: dict,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Apply QK RMSNorm and RoPE with aiter's fused in-place kernel.
+
+        ``q``/``k``/``v`` are ``[B, L, heads, head_dim]``; the returned tensors are
+        views with the same shapes into one packed QKV buffer.
+        """
+        batch, seq_len = q.shape[0], q.shape[1]
+        num_tokens = batch * seq_len
+        # Pack per token as [Hq | Hk | Hv] heads: [num_tokens, heads_total, head_dim].
+        qkv = torch.cat([q, k, v], dim=2).view(num_tokens, -1, self.head_dim)
+        self._call_fused_rope_rms(
+            qkv,
+            self.norm_q.weight,
+            self.norm_k.weight,
+            aiter_prepared["cos_sin"].to(qkv.dtype),
+            aiter_prepared["positions"],
+            num_tokens,
+            self.num_heads,
+            self.num_kv_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            False,  # is_neox_style: same convention as the non-fused RoPE path
+            self.norm_q.variance_epsilon,
+        )
+        return qkv.view(batch, seq_len, -1, self.head_dim).split(
+            [self.num_heads, self.num_kv_heads, self.num_kv_heads], dim=2
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
         freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        aiter_prepared: Optional[dict] = None,
     ):
         q, _ = self.to_q(hidden_states)
         k, _ = self.to_k(hidden_states)
         v, _ = self.to_v(hidden_states)
 
         q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
         k = k.view(*k.shape[:-1], self.num_kv_heads, self.head_dim)
         v = v.view(*v.shape[:-1], self.num_kv_heads, self.head_dim)
 
-        if self.qk_norm:
+        # The fused kernel takes a single eps and always normalizes, so it is only
+        # valid when QK-norm is enabled and both norms share the same epsilon.
+        use_aiter_fused = (
+            aiter_prepared is not None
+            and self.qk_norm
+            and self.norm_q.variance_epsilon == self.norm_k.variance_epsilon
+        )
+        if use_aiter_fused:
+            q, k, v = self._aiter_fused_qk_norm_rope(q, k, v, aiter_prepared)
+        elif self.qk_norm:
             if (
                 q.is_cuda
                 and (self.norm_q.variance_epsilon == self.norm_k.variance_epsilon)
                 and can_use_fused_inplace_qknorm(self.head_dim)
             ):
                 q, k = apply_qk_norm(
                     q=q,
                     k=k,
                     q_norm=self.norm_q,
                     k_norm=self.norm_k,
                     head_dim=self.head_dim,
                     allow_inplace=True,
                 )
             else:
                 q = self.norm_q(q)
                 k = self.norm_k(k)
 
-        if freqs_cis is not None:
+        if freqs_cis is not None and not use_aiter_fused:
             cos, sin = freqs_cis
             if q.is_cuda and q.shape == k.shape:
                 cos_sin_cache = torch.cat(
                     [
                         cos.to(dtype=torch.float32).contiguous(),
                         sin.to(dtype=torch.float32).contiguous(),
                     ],
                     dim=-1,
                 )
                 q, k = apply_flashinfer_rope_qk_inplace(
                     q, k, cos_sin_cache, is_neox=False
                 )
             else:
                 q = _apply_rotary_emb(q, cos, sin, is_neox_style=False)
                 k = _apply_rotary_emb(k, cos, sin, is_neox_style=False)
 
         hidden_states = self.attn(q, k, v)
         hidden_states = hidden_states.flatten(2)
@@ -241,6 +353,7 @@ def forward(
         x: torch.Tensor,
         freqs_cis: Tuple[torch.Tensor, torch.Tensor],
         adaln_input: Optional[torch.Tensor] = None,
+        aiter_prepared: Optional[dict] = None,
     ):
         if self.modulation:
             assert adaln_input is not None
@@ -255,6 +368,7 @@ def forward(
             attn_out = self.attention(
                 self.attention_norm1(x) * scale_msa,
                 freqs_cis=freqs_cis,
+                aiter_prepared=aiter_prepared,
             )
             x = x + gate_msa * self.attention_norm2(attn_out)
 
@@ -269,6 +383,7 @@ def forward(
             attn_out = self.attention(
                 self.attention_norm1(x),
                 freqs_cis=freqs_cis,
+                aiter_prepared=aiter_prepared,
             )
             x = x + self.attention_norm2(attn_out)
 
@@ -610,8 +725,9 @@ def forward(
 
         x = x.unsqueeze(0)
         x_freqs_cis = x_freqs_cis
+        x_aiter_prepared = _build_aiter_rope_inputs(x, x_freqs_cis)
         for layer in self.noise_refiner:
-            x = layer(x, x_freqs_cis, adaln_input)
+            x = layer(x, x_freqs_cis, adaln_input, aiter_prepared=x_aiter_prepared)
 
         cap_feats = torch.cat(cap_feats, dim=0)
 
@@ -620,17 +736,26 @@ def forward(
 
         cap_freqs_cis = freqs_cis[0]
         cap_feats = cap_feats.unsqueeze(0)
+        cap_aiter_prepared = _build_aiter_rope_inputs(cap_feats, cap_freqs_cis)
         for layer in self.context_refiner:
-            cap_feats = layer(cap_feats, cap_freqs_cis)
+            cap_feats = layer(
+                cap_feats, cap_freqs_cis, aiter_prepared=cap_aiter_prepared
+            )
 
         unified = torch.cat([x, cap_feats], dim=1)
         unified_freqs_cis = (
             torch.cat([x_freqs_cis[0], cap_freqs_cis[0]], dim=0),
             torch.cat([x_freqs_cis[1], cap_freqs_cis[1]], dim=0),
         )
+        unified_aiter_prepared = _build_aiter_rope_inputs(unified, unified_freqs_cis)
 
         for layer in self.layers:
-            unified = layer(unified, unified_freqs_cis, adaln_input)
+            unified = layer(
+                unified,
+                unified_freqs_cis,
+                adaln_input,
+                aiter_prepared=unified_aiter_prepared,
+            )
 
         unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](
             unified, adaln_input