From 854007d77f94a5144932ea513ba98588a9fcf959 Mon Sep 17 00:00:00 2001 From: LiuYinfeng01 Date: Fri, 29 May 2026 16:32:05 +0800 Subject: [PATCH] [Perf] Support Qwen-Image-Edit A8W8 GEMM and fused QKV --- .../multimodal_gen/runtime/layers/linear.py | 109 ++++++++++++++++++ .../runtime/models/dits/qwen_image.py | 49 +++++++- python/sglang/srt/environ.py | 3 + 3 files changed, 159 insertions(+), 2 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/linear.py b/python/sglang/multimodal_gen/runtime/layers/linear.py index ca5325e97964..8c4670785bfe 100644 --- a/python/sglang/multimodal_gen/runtime/layers/linear.py +++ b/python/sglang/multimodal_gen/runtime/layers/linear.py @@ -4,6 +4,7 @@ # Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/linear.py from abc import abstractmethod +from fnmatch import fnmatch import torch import torch.distributed as dist @@ -49,6 +50,7 @@ is_cpu, is_hip ) +from sglang.srt.environ import envs _is_cpu_amx_available = cpu_has_amx_support() _is_hip = is_hip() @@ -60,6 +62,46 @@ logger = init_logger(__name__) +_qwen_image_a8w8_hipb_initialized = False +_QWEN_IMAGE_A8W8_OUTPROJ_BF16_BLOCKS = set(range(56, 60)) +_QWEN_IMAGE_A8W8_DEFAULT_PATTERNS = [ + "transformer_blocks.*.attn.to_qkv", + "transformer_blocks.*.attn.to_added_qkv", + "transformer_blocks.*.attn.to_q", + "transformer_blocks.*.attn.to_k", + "transformer_blocks.*.attn.to_v", + "transformer_blocks.*.attn.add_q_proj", + "transformer_blocks.*.attn.add_k_proj", + "transformer_blocks.*.attn.add_v_proj", + "transformer_blocks.*.img_mlp.*", + "transformer_blocks.*.txt_mlp.*", + *[ + f"transformer_blocks.{i}.attn.to_out.0" + for i in range(60) + if i not in _QWEN_IMAGE_A8W8_OUTPROJ_BF16_BLOCKS + ], + *[ + f"transformer_blocks.{i}.attn.to_add_out" + for i in range(60) + if i not in _QWEN_IMAGE_A8W8_OUTPROJ_BF16_BLOCKS + ], +] + + +def _qwen_image_a8w8_patterns() -> list[str]: + spec = envs.SGLANG_QWEN_IMAGE_A8W8_GEMM_PATTERNS.get() + if spec: + return [pattern.strip() for pattern in spec.split(",") if pattern.strip()] + return _QWEN_IMAGE_A8W8_DEFAULT_PATTERNS + + +def _use_qwen_image_a8w8_gemm(prefix: str) -> bool: + return ( + _is_hip + and envs.SGLANG_QWEN_IMAGE_A8W8_GEMM.get() + and any(fnmatch(prefix, pattern) for pattern in _qwen_image_a8w8_patterns()) + ) + WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", @@ -173,6 +215,55 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if _is_cpu and _is_cpu_amx_available: _amx_process_weight_after_loading(layer, ["weight"]) + if _use_qwen_image_a8w8_gemm(getattr(layer, "prefix", "")): + try: + global _qwen_image_a8w8_hipb_initialized + import aiter + from aiter.ops.shuffle import shuffle_weight + + layout = (16, 16) + old_weight = layer.weight + weight = old_weight.detach() + if not AiterHipblaslt.can_shuffle( + weight.shape[0], weight.shape[1], layout + ): + logger.warning( + "Skipping Qwen-Image A8W8 bpreshuffle GEMM for %s: " + "unsupported weight shape %s", + layer.prefix, + tuple(weight.shape), + ) + self._qwen_image_a8w8_gemm = False + return + + if ( + not _qwen_image_a8w8_hipb_initialized + and hasattr(aiter, "hipb_create_extension") + ): + aiter.hipb_create_extension() + _qwen_image_a8w8_hipb_initialized = True + weight_q, weight_scale = aiter.pertoken_quant( + weight.contiguous(), quant_dtype=aiter.dtypes.fp8 + ) + layer.weight = Parameter( + shuffle_weight(weight_q, layout).contiguous(), requires_grad=False + ) + layer.register_buffer( + "weight_scale", weight_scale.contiguous(), persistent=False + ) + self._qwen_image_a8w8_gemm = True + del old_weight, weight, weight_q, weight_scale + torch.cuda.empty_cache() + return + except Exception as exc: + logger.warning( + "Skipping Qwen-Image A8W8 bpreshuffle GEMM for %s: %s", + getattr(layer, "prefix", ""), + exc, + ) + self._qwen_image_a8w8_gemm = False + return + if _use_aiter and get_bool_env_var("SGLANG_ROCM_USE_AITER_LINEAR_SHUFFLE"): AiterHipblaslt._initialize_hipblaslt() layout = (16, 16) @@ -192,6 +283,24 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None ) -> torch.Tensor: + if getattr(self, "_qwen_image_a8w8_gemm", False): + import aiter + + orig_shape = x.shape + x_2d = x.reshape(-1, x.size(-1)).contiguous() + x_q, x_scale = aiter.pertoken_quant(x_2d, quant_dtype=aiter.dtypes.fp8) + output = aiter.gemm_a8w8_bpreshuffle( + x_q, + layer.weight, + x_scale.contiguous(), + layer.weight_scale, + None, + x.dtype, + ) + if bias is not None: + output = output + bias + return output.view(*orig_shape[:-1], layer.weight.shape[0]) + if ( _use_aiter and get_bool_env_var("SGLANG_ROCM_USE_AITER_LINEAR_SHUFFLE") diff --git a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py index b57241c54473..3f6dedf0e55f 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py @@ -58,6 +58,11 @@ logger = init_logger(__name__) # pylint: disable=invalid-name _is_cuda = current_platform.is_cuda() + +def _qwen_image_use_fused_qkv() -> bool: + return envs.SGLANG_QWEN_IMAGE_FUSED_QKV.get() + + try: from nunchaku.models.attention import NunchakuFeedForward # type: ignore[import] except Exception: @@ -594,7 +599,9 @@ def __init__( self.added_kv_proj_dim = added_kv_proj_dim self.prefix = prefix - self.use_fused_qkv = isinstance(quant_config, NunchakuConfig) + self.use_fused_qkv = isinstance(quant_config, NunchakuConfig) or ( + quant_config is None and _qwen_image_use_fused_qkv() + ) self.inner_dim = out_dim if out_dim is not None else head_dim * num_heads self.inner_kv_dim = self.inner_dim @@ -637,7 +644,9 @@ def __init__( self.norm_k = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() if added_kv_proj_dim is not None: - self.use_fused_added_qkv = isinstance(quant_config, NunchakuConfig) + self.use_fused_added_qkv = isinstance(quant_config, NunchakuConfig) or ( + quant_config is None and _qwen_image_use_fused_qkv() + ) if self.use_fused_added_qkv: self.to_added_qkv = MergedColumnParallelLinear( added_kv_proj_dim, @@ -1209,6 +1218,42 @@ def __init__( self.zero_cond_t = getattr(config.arch_config, "zero_cond_t", False) self.out_channels = out_channels or in_channels self.inner_dim = num_attention_heads * attention_head_dim + self.param_names_mapping = dict(self.param_names_mapping) + if quant_config is None and _qwen_image_use_fused_qkv(): + self.param_names_mapping.update( + { + r"^(transformer_blocks\.\d+\.attn\.)to_q\.(.+)$": ( + r"\1to_qkv.\2", + 0, + 3, + ), + r"^(transformer_blocks\.\d+\.attn\.)to_k\.(.+)$": ( + r"\1to_qkv.\2", + 1, + 3, + ), + r"^(transformer_blocks\.\d+\.attn\.)to_v\.(.+)$": ( + r"\1to_qkv.\2", + 2, + 3, + ), + r"^(transformer_blocks\.\d+\.attn\.)add_q_proj\.(.+)$": ( + r"\1to_added_qkv.\2", + 0, + 3, + ), + r"^(transformer_blocks\.\d+\.attn\.)add_k_proj\.(.+)$": ( + r"\1to_added_qkv.\2", + 1, + 3, + ), + r"^(transformer_blocks\.\d+\.attn\.)add_v_proj\.(.+)$": ( + r"\1to_added_qkv.\2", + 2, + 3, + ), + } + ) self.use_additional_t_cond: bool = getattr( config.arch_config, "use_additional_t_cond", False diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index a0f3322c582d..f588515e8864 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -305,6 +305,9 @@ class Envs: SGLANG_USE_AITER_FA_ROUND_MODE = EnvBool(False) SGLANG_ENABLE_FUSED_ROPE_RMS_2WAY = EnvBool(False) SGLANG_ROCM_USE_AITER_LINEAR_SHUFFLE = EnvBool(False) + SGLANG_QWEN_IMAGE_A8W8_GEMM = EnvBool(False) + SGLANG_QWEN_IMAGE_A8W8_GEMM_PATTERNS = EnvStr("") + SGLANG_QWEN_IMAGE_FUSED_QKV = EnvBool(False) # NPU SGLANG_NPU_DISABLE_ACL_FORMAT_WEIGHT = EnvBool(False)