Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions python/sglang/multimodal_gen/runtime/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/linear.py

from abc import abstractmethod
from fnmatch import fnmatch

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -49,6 +50,7 @@
is_cpu,
is_hip
)
from sglang.srt.environ import envs

_is_cpu_amx_available = cpu_has_amx_support()
_is_hip = is_hip()
Expand All @@ -60,6 +62,46 @@

logger = init_logger(__name__)

# Flipped to True once `aiter.hipb_create_extension()` has been called, so the
# extension is created at most once per process.
_qwen_image_a8w8_hipb_initialized = False

# Transformer block indices whose attention output projections are left out of
# the default pattern list (presumably kept in BF16, per the name — confirm).
_QWEN_IMAGE_A8W8_OUTPROJ_BF16_BLOCKS = {56, 57, 58, 59}

# Default fnmatch-style globs, matched against a layer's `prefix`, that select
# the linear layers routed through the A8W8 GEMM path.
_QWEN_IMAGE_A8W8_DEFAULT_PATTERNS = [
    "transformer_blocks.*.attn.to_qkv",
    "transformer_blocks.*.attn.to_added_qkv",
    "transformer_blocks.*.attn.to_q",
    "transformer_blocks.*.attn.to_k",
    "transformer_blocks.*.attn.to_v",
    "transformer_blocks.*.attn.add_q_proj",
    "transformer_blocks.*.attn.add_k_proj",
    "transformer_blocks.*.attn.add_v_proj",
    "transformer_blocks.*.img_mlp.*",
    "transformer_blocks.*.txt_mlp.*",
]
# Output projections are listed per block (indices 0..59) so the excluded
# blocks can be omitted: first every `to_out.0`, then every `to_add_out`.
_QWEN_IMAGE_A8W8_DEFAULT_PATTERNS += [
    f"transformer_blocks.{block_idx}.attn.{out_proj}"
    for out_proj in ("to_out.0", "to_add_out")
    for block_idx in range(60)
    if block_idx not in _QWEN_IMAGE_A8W8_OUTPROJ_BF16_BLOCKS
]


def _qwen_image_a8w8_patterns() -> list[str]:
    """Return the glob patterns that select layers for the A8W8 GEMM path.

    Reads ``SGLANG_QWEN_IMAGE_A8W8_GEMM_PATTERNS``. When it is non-empty, it is
    treated as a comma-separated list: each entry is stripped of surrounding
    whitespace and blank entries are dropped. Otherwise the module-level
    default pattern list is returned as-is (the same list object, not a copy).
    """
    override = envs.SGLANG_QWEN_IMAGE_A8W8_GEMM_PATTERNS.get()
    if not override:
        return _QWEN_IMAGE_A8W8_DEFAULT_PATTERNS

    trimmed = (piece.strip() for piece in override.split(","))
    return [piece for piece in trimmed if piece]


def _use_qwen_image_a8w8_gemm(prefix: str) -> bool:
    """Decide whether the layer named ``prefix`` should use the A8W8 GEMM path.

    The path is taken only when running on HIP, the
    ``SGLANG_QWEN_IMAGE_A8W8_GEMM`` switch is enabled, and ``prefix`` matches
    at least one of the active fnmatch patterns.
    """
    # Cheap gates first: the pattern list is not consulted unless both hold.
    if not (_is_hip and envs.SGLANG_QWEN_IMAGE_A8W8_GEMM.get()):
        return False

    for glob in _qwen_image_a8w8_patterns():
        if fnmatch(prefix, glob):
            return True
    return False

WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod",
"AWQMarlinLinearMethod",
Expand Down Expand Up @@ -173,6 +215,55 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if _is_cpu and _is_cpu_amx_available:
_amx_process_weight_after_loading(layer, ["weight"])

if _use_qwen_image_a8w8_gemm(getattr(layer, "prefix", "")):
try:
global _qwen_image_a8w8_hipb_initialized
import aiter
from aiter.ops.shuffle import shuffle_weight

layout = (16, 16)
old_weight = layer.weight
weight = old_weight.detach()
if not AiterHipblaslt.can_shuffle(
weight.shape[0], weight.shape[1], layout
):
logger.warning(
"Skipping Qwen-Image A8W8 bpreshuffle GEMM for %s: "
"unsupported weight shape %s",
layer.prefix,
tuple(weight.shape),
)
self._qwen_image_a8w8_gemm = False
return

if (
not _qwen_image_a8w8_hipb_initialized
and hasattr(aiter, "hipb_create_extension")
):
aiter.hipb_create_extension()
_qwen_image_a8w8_hipb_initialized = True
weight_q, weight_scale = aiter.pertoken_quant(
weight.contiguous(), quant_dtype=aiter.dtypes.fp8
)
layer.weight = Parameter(
shuffle_weight(weight_q, layout).contiguous(), requires_grad=False
)
layer.register_buffer(
"weight_scale", weight_scale.contiguous(), persistent=False
)
self._qwen_image_a8w8_gemm = True
del old_weight, weight, weight_q, weight_scale
torch.cuda.empty_cache()
return
except Exception as exc:
logger.warning(
"Skipping Qwen-Image A8W8 bpreshuffle GEMM for %s: %s",
getattr(layer, "prefix", ""),
exc,
)
self._qwen_image_a8w8_gemm = False
return

if _use_aiter and get_bool_env_var("SGLANG_ROCM_USE_AITER_LINEAR_SHUFFLE"):
AiterHipblaslt._initialize_hipblaslt()
layout = (16, 16)
Expand All @@ -192,6 +283,24 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
def apply(
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
if getattr(self, "_qwen_image_a8w8_gemm", False):
import aiter

orig_shape = x.shape
x_2d = x.reshape(-1, x.size(-1)).contiguous()
x_q, x_scale = aiter.pertoken_quant(x_2d, quant_dtype=aiter.dtypes.fp8)
output = aiter.gemm_a8w8_bpreshuffle(
x_q,
layer.weight,
x_scale.contiguous(),
layer.weight_scale,
None,
x.dtype,
)
if bias is not None:
output = output + bias
return output.view(*orig_shape[:-1], layer.weight.shape[0])

if (
_use_aiter
and get_bool_env_var("SGLANG_ROCM_USE_AITER_LINEAR_SHUFFLE")
Expand Down
49 changes: 47 additions & 2 deletions python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@
logger = init_logger(__name__) # pylint: disable=invalid-name
_is_cuda = current_platform.is_cuda()


def _qwen_image_use_fused_qkv() -> bool:
    """Return the current value of the ``SGLANG_QWEN_IMAGE_FUSED_QKV`` setting."""
    fused_qkv_enabled = envs.SGLANG_QWEN_IMAGE_FUSED_QKV.get()
    return fused_qkv_enabled


try:
from nunchaku.models.attention import NunchakuFeedForward # type: ignore[import]
except Exception:
Expand Down Expand Up @@ -594,7 +599,9 @@ def __init__(
self.added_kv_proj_dim = added_kv_proj_dim
self.prefix = prefix

self.use_fused_qkv = isinstance(quant_config, NunchakuConfig)
self.use_fused_qkv = isinstance(quant_config, NunchakuConfig) or (
quant_config is None and _qwen_image_use_fused_qkv()
)

self.inner_dim = out_dim if out_dim is not None else head_dim * num_heads
self.inner_kv_dim = self.inner_dim
Expand Down Expand Up @@ -637,7 +644,9 @@ def __init__(
self.norm_k = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()

if added_kv_proj_dim is not None:
self.use_fused_added_qkv = isinstance(quant_config, NunchakuConfig)
self.use_fused_added_qkv = isinstance(quant_config, NunchakuConfig) or (
quant_config is None and _qwen_image_use_fused_qkv()
)
if self.use_fused_added_qkv:
self.to_added_qkv = MergedColumnParallelLinear(
added_kv_proj_dim,
Expand Down Expand Up @@ -1209,6 +1218,42 @@ def __init__(
self.zero_cond_t = getattr(config.arch_config, "zero_cond_t", False)
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.param_names_mapping = dict(self.param_names_mapping)
if quant_config is None and _qwen_image_use_fused_qkv():
self.param_names_mapping.update(
{
r"^(transformer_blocks\.\d+\.attn\.)to_q\.(.+)$": (
r"\1to_qkv.\2",
0,
3,
),
r"^(transformer_blocks\.\d+\.attn\.)to_k\.(.+)$": (
r"\1to_qkv.\2",
1,
3,
),
r"^(transformer_blocks\.\d+\.attn\.)to_v\.(.+)$": (
r"\1to_qkv.\2",
2,
3,
),
r"^(transformer_blocks\.\d+\.attn\.)add_q_proj\.(.+)$": (
r"\1to_added_qkv.\2",
0,
3,
),
r"^(transformer_blocks\.\d+\.attn\.)add_k_proj\.(.+)$": (
r"\1to_added_qkv.\2",
1,
3,
),
r"^(transformer_blocks\.\d+\.attn\.)add_v_proj\.(.+)$": (
r"\1to_added_qkv.\2",
2,
3,
),
}
)

self.use_additional_t_cond: bool = getattr(
config.arch_config, "use_additional_t_cond", False
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,9 @@ class Envs:
SGLANG_USE_AITER_FA_ROUND_MODE = EnvBool(False)
SGLANG_ENABLE_FUSED_ROPE_RMS_2WAY = EnvBool(False)
SGLANG_ROCM_USE_AITER_LINEAR_SHUFFLE = EnvBool(False)
SGLANG_QWEN_IMAGE_A8W8_GEMM = EnvBool(False)
SGLANG_QWEN_IMAGE_A8W8_GEMM_PATTERNS = EnvStr("")
SGLANG_QWEN_IMAGE_FUSED_QKV = EnvBool(False)

# NPU
SGLANG_NPU_DISABLE_ACL_FORMAT_WEIGHT = EnvBool(False)
Expand Down
Loading