From b509db5832c96e2a47dd82d500bc9d4c855c9b4c Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Sun, 24 Nov 2024 20:09:27 +0800
Subject: [PATCH] feat: remove the dependency on FusedMoE (#2153)

---
 .../srt/layers/quantization/__init__.py       |  20 +-
 .../srt/layers/triton_fused_moe/__init__.py   |  44 +
 .../layers/triton_fused_moe/configs/README    |  10 +
 .../srt/layers/triton_fused_moe/fused_moe.py  | 858 ++++++++++++++++++
 .../srt/layers/triton_fused_moe/layer.py      | 631 +++++++++++++
 python/sglang/srt/models/deepseek_v2.py       |   2 +-
 python/sglang/srt/utils.py                    |  44 +-
 7 files changed, 1602 insertions(+), 7 deletions(-)
 create mode 100644 python/sglang/srt/layers/triton_fused_moe/__init__.py
 create mode 100644 python/sglang/srt/layers/triton_fused_moe/configs/README
 create mode 100644 python/sglang/srt/layers/triton_fused_moe/fused_moe.py
 create mode 100644 python/sglang/srt/layers/triton_fused_moe/layer.py

diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py
index 88a05c6d000..584ae0d89e4 100644
--- a/python/sglang/srt/layers/quantization/__init__.py
+++ b/python/sglang/srt/layers/quantization/__init__.py
@@ -57,12 +57,23 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     "QUANTIZATION_METHODS",
 ]
 
-"""
-def fp8_get_quant_method(
-    self, layer: torch.nn.Module, prefix: str
-) -> Optional["QuantizeMethodBase"]:
+
+def fp8_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.fp8 import (
+        Fp8LinearMethod,
+        Fp8MoEMethod,
+    )
+    from vllm.model_executor.layers.quantization.utils.quant_utils import (
+        is_layer_skipped,
+    )
+
+    from sglang.srt.layers.triton_fused_moe.layer import FusedMoE
+
     if isinstance(layer, LinearBase):
         if is_layer_skipped(prefix, self.ignored_layers):
+            from sglang.srt.layers.linear import UnquantizedLinearMethod
+
             return UnquantizedLinearMethod()
         return Fp8LinearMethod(self)
     elif isinstance(layer, FusedMoE):
@@ -71,4 +82,3 @@ def fp8_get_quant_method(
 
 
 setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
-"""
diff --git a/python/sglang/srt/layers/triton_fused_moe/__init__.py b/python/sglang/srt/layers/triton_fused_moe/__init__.py
new file mode 100644
index 00000000000..b2eb118358d
--- /dev/null
+++ b/python/sglang/srt/layers/triton_fused_moe/__init__.py
@@ -0,0 +1,44 @@
+from contextlib import contextmanager
+from typing import Any, Dict, Optional
+
+import sglang.srt.layers.triton_fused_moe.fused_moe  # noqa
+from sglang.srt.layers.triton_fused_moe.fused_moe import (
+    fused_experts,
+    fused_topk,
+    get_config_file_name,
+    grouped_topk,
+)
+from sglang.srt.layers.triton_fused_moe.layer import (
+    FusedMoE,
+    FusedMoEMethodBase,
+    FusedMoeWeightScaleSupported,
+)
+
+_config: Optional[Dict[str, Any]] = None
+
+
+@contextmanager
+def override_config(config):
+    global _config
+    old_config = _config
+    _config = config
+    yield
+    _config = old_config
+
+
+def get_config() -> Optional[Dict[str, Any]]:
+    return _config
+
+
+__all__ = [
+    "FusedMoE",
+    "FusedMoEMethodBase",
+    "FusedMoeWeightScaleSupported",
+    "override_config",
+    "get_config",
+    "fused_moe",
+    "fused_topk",
+    "fused_experts",
+    "get_config_file_name",
+    "grouped_topk",
+]
diff --git a/python/sglang/srt/layers/triton_fused_moe/configs/README b/python/sglang/srt/layers/triton_fused_moe/configs/README
new file mode 100644
index 00000000000..45d40cbfb1a
--- /dev/null
+++ b/python/sglang/srt/layers/triton_fused_moe/configs/README
@@ -0,0 +1,10 @@
+This directory contains tuned configurations for different settings of the fused_moe kernel.
+For different settings of
+- E (number of experts)
+- N (intermediate size)
+- device_name (torch.cuda.get_device_name())
+the JSON file contains a mapping from M (batch size) to the chosen configuration.
+
+The example configurations provided are for the Mixtral model for TP2 on H100
+and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have
+N = 7168 and for TP4 we have N = 3584.
diff --git a/python/sglang/srt/layers/triton_fused_moe/fused_moe.py b/python/sglang/srt/layers/triton_fused_moe/fused_moe.py
new file mode 100644
index 00000000000..8a2c7257bc0
--- /dev/null
+++ b/python/sglang/srt/layers/triton_fused_moe/fused_moe.py
@@ -0,0 +1,858 @@
+"""Fused MoE kernel."""
+
+import functools
+import json
+import logging
+import os
+from typing import Any, Callable, Dict, Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+from vllm import _custom_ops as ops
+
+from sglang.srt.utils import direct_register_custom_op, get_device_name
+
+logger = logging.getLogger(__name__)
+
+
+@triton.jit
+def fused_moe_kernel(
+    # Pointers to matrices
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    a_scale_ptr,
+    b_scale_ptr,
+    topk_weights_ptr,
+    sorted_token_ids_ptr,
+    expert_ids_ptr,
+    num_tokens_post_padded_ptr,
+    # Matrix dimensions
+    N,
+    K,
+    EM,
+    num_valid_tokens,
+    # The stride variables represent how much to increase the ptr by when
+    # moving by 1 element in a particular dimension. E.g. `stride_am` is
+    # how much to increase `a_ptr` by to get the element one row down
+    # (A has M rows).
+    stride_am,
+    stride_ak,
+    stride_be,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    stride_bse,
+    stride_bsn,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    MUL_ROUTED_WEIGHT: tl.constexpr,
+    top_k: tl.constexpr,
+    compute_type: tl.constexpr,
+    use_fp8_w8a8: tl.constexpr,
+    use_int8_w8a16: tl.constexpr,
+):
+    """
+    Implements the fused computation for a Mixture of Experts (MoE) using
+    token and expert matrices.
+
+    Key Parameters:
+    - A: The input tensor representing tokens with shape (*, K), where '*' can
+      be any shape representing batches and K is the feature dimension of
+      each token.
+    - B: The stacked MoE weight tensor with shape (E, N, K), where E is
+      the number of experts, K is the input feature dimension, and N is
+      the output feature dimension.
+    - C: The output cache tensor with shape (M, topk, N), where M is the
+      total number of tokens post padding, topk is the number of times
+      each token is repeated, and N is the output feature dimension.
+    - sorted_token_ids: A tensor containing the sorted indices of tokens,
+      repeated topk times and arranged by the expert index they are
+      assigned to.
+    - expert_ids: A tensor containing the indices of the expert for each
+      block. It determines which expert matrix from B should be used for
+      each block in A.
+    This kernel performs the multiplication of a token by its corresponding
+    expert matrix as determined by `expert_ids`. The sorting of
+    `sorted_token_ids` by expert index and padding ensures divisibility by
+    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
+    multiplication across different blocks processed by the same expert.
+    """
+    # -----------------------------------------------------------
+    # Map program ids `pid` to the block of C it should compute.
+    # This is done in a grouped ordering to promote L2 data reuse.
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    # ----------------------------------------------------------
+    # Create pointers for the first blocks of A and B.
+    # We will advance this pointer as we move in the K direction
+    # and accumulate
+    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
+    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
+        return
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+    token_mask = offs_token < num_valid_tokens
+
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (
+        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+    )
+
+    off_experts = tl.load(expert_ids_ptr + pid_m)
+    b_ptrs = (
+        b_ptr
+        + off_experts * stride_be
+        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    )
+    if use_int8_w8a16:
+        b_scale_ptrs = (
+            b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
+        )
+        b_scale = tl.load(b_scale_ptrs)
+
+    if use_fp8_w8a8:
+        a_scale = tl.load(a_scale_ptr)
+        b_scale = tl.load(b_scale_ptr + off_experts)
+
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix.
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy.
+    # `accumulator` will be converted back to fp16 after the loop.
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        # Load the next block of A and B, generate a mask by checking the
+        # K dimension.
+        a = tl.load(
+            a_ptrs,
+            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+            other=0.0,
+        )
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        # We accumulate along the K dimension.
+        if use_int8_w8a16:
+            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
+        elif use_fp8_w8a8:
+            accumulator = tl.dot(a, b, acc=accumulator)
+        else:
+            accumulator += tl.dot(a, b)
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if MUL_ROUTED_WEIGHT:
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
+        accumulator = accumulator * moe_weight[:, None]
+    if use_int8_w8a16:
+        accumulator = (accumulator * b_scale).to(compute_type)
+    elif use_fp8_w8a8:
+        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
+    else:
+        accumulator = accumulator.to(compute_type)
+    # -----------------------------------------------------------
+    # Write back the block of the output
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+def moe_align_block_size(
+    topk_ids: torch.Tensor, block_size: int, num_experts: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Aligns the token distribution across experts to be compatible with block
+    size for matrix multiplication.
+
+    Parameters:
+    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
+      top-k expert indices for each token.
+    - block_size: The block size used in block matrix multiplication.
+    - num_experts: The total number of experts.
+
+    Returns:
+    - sorted_token_ids: A tensor containing the sorted token indices according
+      to their allocated expert.
+    - expert_ids: A tensor indicating the assigned expert index for each block.
+    - num_tokens_post_padded: The total number of tokens after padding,
+      ensuring divisibility by block_size.
+
+    This function pads the number of tokens that each expert needs to process
+    so that it is divisible by block_size.
+    Padding ensures that during block matrix multiplication, the dimensions
+    align correctly.
+
+    Example:
+    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
+    block_size = 4, and num_experts = 4:
+    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
+      with each expert needing to process 3 tokens.
+    - As block_size is 4, we pad 1 token for each expert.
+    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
+    - Then append padding tokens [12, 12, 12, 12] for each block.
+    - After sorting by expert index, we obtain token_ids
+      [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
+      Tokens 12 are non-existent (padding) and are ignored in
+      the subsequent matrix multiplication.
+    - The padding ensures that the total number of tokens is now divisible
+      by block_size for proper block matrix operations.
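+    - The padded slots are filled with the sentinel value topk_ids.numel()
+      (12 in this example); sorted_ids is pre-filled with it below, and the
+      kernel masks those slots out via `token_mask`, since the sentinel is
+      never smaller than num_valid_tokens.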
+ """ + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + sorted_ids = torch.empty( + (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device + ) + sorted_ids.fill_(topk_ids.numel()) + max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + expert_ids = torch.empty( + (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device + ) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) + ops.moe_align_block_size( + topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad + ) + return sorted_ids, expert_ids, num_tokens_post_pad + + +def invoke_fused_moe_kernel( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + top_k: int, + config: Dict[str, Any], + compute_type: tl.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, +) -> None: + assert topk_weights.stride(1) == 1 + assert sorted_token_ids.stride(0) == 1 + + if use_fp8_w8a8: + A, A_scale = ops.scaled_fp8_quant(A, A_scale) + assert B_scale is not None + elif use_int8_w8a16: + assert B_scale is not None + else: + assert A_scale is None + assert B_scale is None + + grid = lambda META: ( + triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) + * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), + ) + + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + B.shape[2], + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, + B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + **config, + ) + + +def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: + device_name = get_device_name().replace(" ", "_") + dtype_selector = "" if not dtype else f",dtype={dtype}" + return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" + + +@functools.lru_cache +def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: + """ + Return optimized configurations for the fused MoE kernel. + + The return value will be a dictionary that maps an irregular grid of + batch sizes to configurations of the fused_moe kernel. To evaluate the + kernel on a given batch size bs, the closest batch size in the grid should + be picked and the associated configuration chosen to invoke the kernel. 
+ """ + + # First look up if an optimized configuration is available in the configs + # directory + json_file_name = get_config_file_name(E, N, dtype) + + config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name + ) + if os.path.exists(config_file_path): + with open(config_file_path) as f: + logger.info("Using configuration from %s for MoE layer.", config_file_path) + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + + # If no optimized configuration is available, we will use the default + # configuration + logger.warning( + ( + "Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s" + ), + config_file_path, + ) + return None + + +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], + is_marlin: bool, +) -> Dict[str, int]: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } + return config + + +def try_get_optimal_moe_config( + w1_shape: Tuple[int, ...], + w2_shape: Tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + is_marlin: bool = False, +): + from sglang.srt.layers.triton_fused_moe import get_config + + override_config = get_config() + if override_config: + config = override_config + else: + # First try to load optimal config from the file + E, _, N = w2_shape + configs = get_moe_configs(E, N, dtype) + + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin) + return config + + +def fused_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +): + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + + M, _ = hidden_states.shape + + topk_weights = torch.empty( + M, topk, dtype=torch.float32, device=hidden_states.device + ) + topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) + token_expert_indicies = torch.empty( + M, topk, dtype=torch.int32, device=hidden_states.device + ) + + ops.topk_softmax( + topk_weights, + topk_ids, + token_expert_indicies, + gating_output.float(), # TODO(woosuk): Optimize this. + ) + del token_expert_indicies # Not used. Will be used in the future. 
+ + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids + + +# This is used by the Deepseek-V2 model +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, +): + + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + + scores = torch.softmax(gating_output, dim=-1) + num_token = scores.shape[0] + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +def get_config_dtype_str( + dtype: torch.dtype, + use_int8_w8a16: Optional[bool] = False, + use_fp8_w8a8: Optional[bool] = False, +): + if use_fp8_w8a8: + return "fp8_w8a8" + elif use_int8_w8a16: + return "int8_w8a16" + elif dtype == torch.float: + # avoiding cases where kernel fails when float32 MoE + # use fp16/bfloat16 configs + return "float32" + return None + + +def inplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> None: + fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + True, + use_fp8_w8a8, + use_int8_w8a16, + w1_scale, + w2_scale, + a1_scale, + a2_scale, + ) + + +def inplace_fused_experts_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> None: + pass + + +direct_register_custom_op( + op_name="inplace_fused_experts", + op_func=inplace_fused_experts, + mutates_args=["hidden_states"], + fake_impl=inplace_fused_experts_fake, +) + + +def outplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + False, + use_fp8_w8a8, + use_int8_w8a16, + w1_scale, + w2_scale, + a1_scale, + a2_scale, + ) + + +def outplace_fused_experts_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + 
topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +direct_register_custom_op( + op_name="outplace_fused_experts", + op_func=outplace_fused_experts, + mutates_args=[], + fake_impl=outplace_fused_experts_fake, +) + + +def fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +): + if inplace: + torch.ops.sglang.inplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + w1_scale, + w2_scale, + a1_scale, + a2_scale, + ) + return hidden_states + else: + return torch.ops.sglang.outplace_fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8, + use_int8_w8a16, + w1_scale, + w2_scale, + a1_scale, + a2_scale, + ) + + +def fused_experts_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +): + # Check constraints. 
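+    # Shape conventions used below (E = number of experts):
+    #   hidden_states: (num_tokens, hidden_size)
+    #   w1:            (E, 2 * intermediate_size, hidden_size)
+    #   w2:            (E, hidden_size, intermediate_size)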
+    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
+
+    num_tokens, _ = hidden_states.shape
+    E, N, _ = w1.shape
+    # We execute the fused_moe kernel in chunks to circumvent this issue:
+    # https://github.com/vllm-project/vllm/issues/5938
+    CHUNK_SIZE = 64 * 1024
+    M = min(num_tokens, CHUNK_SIZE)
+    config_dtype = get_config_dtype_str(
+        use_fp8_w8a8=use_fp8_w8a8,
+        use_int8_w8a16=use_int8_w8a16,
+        dtype=hidden_states.dtype,
+    )
+
+    get_config_func = functools.partial(
+        try_get_optimal_moe_config,
+        w1.shape,
+        w2.shape,
+        topk_ids.shape[1],
+        config_dtype,
+    )
+
+    config = get_config_func(M)
+
+    intermediate_cache1 = torch.empty(
+        (M, topk_ids.shape[1], N),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache2 = torch.empty(
+        (M * topk_ids.shape[1], N // 2),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache3 = torch.empty(
+        (M, topk_ids.shape[1], w2.shape[1]),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+
+    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
+
+    if inplace:
+        out_hidden_states = hidden_states
+    else:
+        out_hidden_states = torch.empty_like(hidden_states)
+
+    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
+        begin_chunk_idx, end_chunk_idx = (
+            chunk * CHUNK_SIZE,
+            min((chunk + 1) * CHUNK_SIZE, num_tokens),
+        )
+        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
+        tokens_in_chunk, _ = curr_hidden_states.shape
+
+        if tokens_in_chunk == 0:
+            break
+
+        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
+            # Adjust the intermediate cache size and config for the last
+            # chunk. Note that in most cases we only have one chunk
+            # so the cache size and config are already set correctly and
+            # do not need to be adjusted.
+            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
+            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
+            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
+            config = get_config_func(tokens_in_chunk)
+
+        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
+        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
+
+        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+            curr_topk_ids, config["BLOCK_SIZE_M"], E
+        )
+
+        invoke_fused_moe_kernel(
+            curr_hidden_states,
+            w1,
+            intermediate_cache1,
+            a1_scale,
+            w1_scale,
+            curr_topk_weights,
+            curr_topk_ids,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            False,
+            topk_ids.shape[1],
+            config,
+            compute_type=compute_type,
+            use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a16=use_int8_w8a16,
+        )
+
+        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+
+        invoke_fused_moe_kernel(
+            intermediate_cache2,
+            w2,
+            intermediate_cache3,
+            a2_scale,
+            w2_scale,
+            curr_topk_weights,
+            curr_topk_ids,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            True,
+            1,
+            config,
+            compute_type=compute_type,
+            use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a16=use_int8_w8a16,
+        )
+
+        ops.moe_sum(
+            intermediate_cache3.view(*intermediate_cache3.shape),
+            out_hidden_states[begin_chunk_idx:end_chunk_idx],
+        )
+    return out_hidden_states
+
+
+def fused_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+    use_grouped_topk: bool = False,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    This function computes a Mixture of Experts (MoE) layer using two sets of
+    weights, w1 and w2, and a top-k gating mechanism.
+
+    Parameters:
+    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
+    - w1 (torch.Tensor): The first set of expert weights.
+    - w2 (torch.Tensor): The second set of expert weights.
+    - gating_output (torch.Tensor): The output of the gating operation
+      (before softmax).
+    - topk (int): The number of top-k experts to select.
+    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
+    - inplace (bool): If True, perform the operation in-place.
+      Defaults to False.
+    - use_grouped_topk (bool): If True, use grouped_topk instead of fused_topk.
+      Note: DeepSeek-V2 uses grouped_topk.
+    - num_expert_group (Optional[int]): additional parameter for grouped_topk.
+    - topk_group (Optional[int]): additional parameter for grouped_topk.
+    - custom_routing_function (Optional[Callable]): drop-in replacement for
+      fused_topk that computes topk_weights and topk_ids from gating_output.
+    - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
+      products for w1 and w2. Defaults to False.
+    - use_int8_w8a16 (bool): If True, use int8 weights with 16-bit activations
+      (w8a16) to compute the inner products for w1 and w2. Defaults to False.
+    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
+    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
+    - a1_scale / a2_scale (Optional[torch.Tensor]): Optional activation scales
+      used with fp8 w8a8.
+
+    Returns:
+    - torch.Tensor: The output tensor after applying the MoE layer.
+    """
+    # Check constraints.
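+    # gating_output has shape (num_tokens, E); its expert dimension must
+    # match the expert dimension of w1.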
+    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
+
+    if use_grouped_topk:
+        assert num_expert_group is not None and topk_group is not None
+        topk_weights, topk_ids = grouped_topk(
+            hidden_states,
+            gating_output,
+            topk,
+            renormalize,
+            num_expert_group,
+            topk_group,
+        )
+    elif custom_routing_function is None:
+        topk_weights, topk_ids = fused_topk(
+            hidden_states, gating_output, topk, renormalize
+        )
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            hidden_states, gating_output, topk, renormalize
+        )
+
+    return fused_experts(
+        hidden_states,
+        w1,
+        w2,
+        topk_weights,
+        topk_ids,
+        inplace=inplace,
+        use_fp8_w8a8=use_fp8_w8a8,
+        use_int8_w8a16=use_int8_w8a16,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+    )
diff --git a/python/sglang/srt/layers/triton_fused_moe/layer.py b/python/sglang/srt/layers/triton_fused_moe/layer.py
new file mode 100644
index 00000000000..3ec2f7a340b
--- /dev/null
+++ b/python/sglang/srt/layers/triton_fused_moe/layer.py
@@ -0,0 +1,631 @@
+from abc import abstractmethod
+from enum import Enum
+from typing import Callable, List, Optional, Tuple
+
+import torch
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from vllm.model_executor.custom_op import CustomOp
+
+from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.utils import set_weight_attrs
+
+if torch.cuda.is_available() or (hasattr(torch, "hip") and torch.hip.is_available()):
+    from .fused_moe import fused_experts
+else:
+    fused_experts = None  # type: ignore
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class FusedMoeWeightScaleSupported(Enum):
+    TENSOR = "tensor"
+    CHANNEL = "channel"
+    GROUP = "group"
+
+
+class FusedMoEMethodBase(QuantizeMethodBase):
+
+    @abstractmethod
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+
+@register_custom_op("sglang_unquantized_fused_moe")
+class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
+    """MoE method without quantization."""
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        # Fused gate_up_proj (column parallel)
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, hidden_size, intermediate_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+    ) -> torch.Tensor:
+        return self.forward(
+            x=x,
+            layer=layer,
+            router_logits=router_logits,
+            top_k=top_k,
+            renormalize=renormalize,
+            use_grouped_topk=use_grouped_topk,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+        )
+
+    def forward_cuda(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+    ) -> torch.Tensor:
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+        )
+
+        return fused_experts(
+            hidden_states=x,
+            w1=layer.w13_weight,
+            w2=layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+        )
+
+    def forward_cpu(self, *args, **kwargs):
+        raise NotImplementedError("The CPU backend currently does not support MoE.")
+
+    def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
+        raise NotImplementedError("The TPU backend currently does not support MoE.")
+
+    forward_native = forward_cuda
+
+
+class FusedMoE(torch.nn.Module):
+    """FusedMoE layer for MoE models.
+
+    This layer contains both MergedColumnParallel weights (gate_up_proj /
+    w13) and RowParallelLinear weights (down_proj / w2).
+
+    Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
+    copy that naming convention here and handle any remapping in the
+    load_weights function in each model implementation.
+
+    Args:
+        num_experts: Number of experts in the model
+        top_k: Number of experts selected for each token
+        hidden_size: Input hidden state size of the transformer
+        intermediate_size: Intermediate size of the experts
+        params_dtype: Data type for the parameters.
+        reduce_results: Whether to apply all_reduce on the output of the layer
+        renormalize: Whether to renormalize the logits in the fused_moe kernel
+        quant_config: Quantization configuration.
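+
+    A minimal construction sketch (the argument values below are
+    illustrative, not taken from any particular model):
+
+        experts = FusedMoE(num_experts=8, top_k=2, hidden_size=4096,
+                           intermediate_size=14336, reduce_results=True)
+        out = experts(hidden_states, router_logits)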
+ """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + custom_routing_function: Optional[Callable] = None, + ): + super().__init__() + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + self.tp_size = ( + tp_size if tp_size is not None else get_tensor_model_parallel_world_size() + ) + self.top_k = top_k + self.num_experts = num_experts + self.intermediate_size_per_partition = intermediate_size // self.tp_size + self.reduce_results = reduce_results + self.renormalize = renormalize + self.use_grouped_topk = use_grouped_topk + if self.use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + self.num_expert_group = num_expert_group + self.topk_group = topk_group + self.custom_routing_function = custom_routing_function + + if quant_config is None: + self.quant_method: Optional[QuantizeMethodBase] = ( + UnquantizedFusedMoEMethod() + ) + else: + self.quant_method = quant_config.get_quant_method(self, prefix) + assert self.quant_method is not None + + self.quant_method.create_weights( + layer=self, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=self.intermediate_size_per_partition, + params_dtype=params_dtype, + weight_loader=self.weight_loader, + ) + + def _load_per_tensor_weight_scale( + self, + shard_id: str, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + expert_id: int, + ): + param_data = param.data + # for per tensor weight quantization + if shard_id in ("w1", "w3"): + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. + idx = 0 if shard_id == "w1" else 1 + param_data[expert_id][idx] = loaded_weight + # If we are in the row parallel case (down_proj) + elif shard_id == "w2": + param_data[expert_id] = loaded_weight + + def _load_model_weight_or_group_weight_scale( + self, + shard_dim: int, + expert_data: torch.Tensor, + shard_id: str, + loaded_weight: torch.tensor, + tp_rank: int, + ): + # Load grouped weight scales for group quantization + # or model weights + if shard_id == "w2": + self._load_w2( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + ) + elif shard_id in ("w1", "w3"): + self._load_w13( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + ) + + def _load_per_channel_weight_scale( + self, + expert_data: torch.Tensor, + shard_dim: int, + shard_id: str, + loaded_weight: torch.tensor, + tp_rank: int, + ): + # for per channel weight quantization + if shard_id == "w2": + expert_data.copy_(loaded_weight) + elif shard_id in ("w1", "w3"): + self._load_w13( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + ) + + def _load_w13( + self, + expert_data: torch.Tensor, + shard_dim: int, + shard_id: str, + loaded_weight: torch.tensor, + tp_rank: int, + ): + + # Index the loaded weight for tp sharding. 
+ # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim + shard_size = expert_data.shape[shard_dim] // 2 + loaded_weight = loaded_weight.narrow( + shard_dim, shard_size * tp_rank, shard_size + ) + # Narrow parameter and load. + # w1, gate_proj: Load into first logical weight of w13. + if shard_id == "w1": + expert_data = expert_data.narrow(shard_dim, 0, shard_size) + # w3, up_proj: Load into second logical weight of w13. + else: + assert shard_id == "w3" + expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + expert_data.copy_(loaded_weight) + + def _load_w2( + self, + expert_data: torch.Tensor, + shard_dim: int, + shard_id: str, + loaded_weight: torch.tensor, + tp_rank: int, + ): + + # Index the loaded weight for tp sharding. + # down_proj: "RowParallel" so tp sharding on input_dim + # Narrow parameter and load. + shard_size = expert_data.shape[shard_dim] + loaded_weight = loaded_weight.narrow( + shard_dim, shard_size * tp_rank, shard_size + ) + # w2, down_proj: Load into only logical weight of w2. + expert_data.copy_(loaded_weight) + + def _load_single_value( + self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int + ): + param_data = param.data + + # Input scales can be loaded directly and should be equal. + param_data[expert_id] = loaded_weight + + def _load_g_idx( + self, + shard_id: str, + expert_data: torch.Tensor, + shard_dim: int, + loaded_weight: torch.tensor, + tp_rank: int, + ): + + if shard_id == "w2": + self._load_w2( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + ) + else: + assert shard_id in ("w1", "w3") + expert_data.copy_(loaded_weight) + + def weight_loader( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + ) -> None: + + # compressed-tensors checkpoints with packed weights are stored flipped + # TODO (mgoin): check self.quant_method.quant_config.quant_format + # against known CompressionFormat enum values that have this quality + loaded_weight = ( + loaded_weight.t().contiguous() + if ( + self.quant_method.__class__.__name__ + == "CompressedTensorsWNA16MoEMethod" + ) + else loaded_weight + ) + + if shard_id not in ("w1", "w2", "w3"): + raise ValueError( + f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}." + ) + + WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported] + # Fetch the dim to shard the parameter/loaded weight + # based on the shard id. This will be whatever + # dimension intermediate_size is used. + SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} + + expert_data = param.data[expert_id] + tp_rank = get_tensor_model_parallel_rank() + + # is_transposed: if the dim to shard the weight + # should be flipped. Required by GPTQ, compressed-tensors + # should be whatever dimension intermediate_size is + is_transposed = getattr(param, "is_transposed", False) + shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] + if is_transposed: + shard_dim = ~shard_dim + + # Case input scale: input_scale loading is only supported for fp8 + if "input_scale" in weight_name: + # this is needed for compressed-tensors only + loaded_weight = loaded_weight.to(param.data.device) + + if ( + param.data[expert_id] != 1 + and (param.data[expert_id] - loaded_weight).abs() > 1e-5 + ): + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param.data[expert_id]} " + f"vs. 
{loaded_weight}" + ) + + self._load_single_value( + param=param, loaded_weight=loaded_weight, expert_id=expert_id + ) + return + + # Case g_idx + if "g_idx" in weight_name: + self._load_g_idx( + shard_dim=0, + shard_id=shard_id, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + ) + return + + # Case weight scales and zero_points + if "scale" in weight_name or "zero" in weight_name: + # load the weight scales and zp based on the quantization scheme + # supported weight scales/zp can be found in + # FusedMoeWeightScaleSupported + # TODO @dsikka: once hardened, refactor to use vLLM Parameters + # specific to each case + quant_method = getattr(param, "quant_method", None) + if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value: + self._load_per_channel_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + ) + elif quant_method == FusedMoeWeightScaleSupported.GROUP.value: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + ) + elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value: + self._load_per_tensor_weight_scale( + shard_id=shard_id, + param=param, + loaded_weight=loaded_weight, + expert_id=expert_id, + ) + else: + raise ValueError( + f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}" + ) + return + + # Case weight_shape + if "weight_shape" in weight_name: + # only required by compressed-tensors + self._load_single_value( + param=param, loaded_weight=loaded_weight, expert_id=expert_id + ) + return + + # Case model weights + if "weight" in weight_name: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + ) + return + + @staticmethod + def select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ): + from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, + grouped_topk, + ) + + # DeekSeekv2 uses grouped_top_k + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + topk_weights, topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + ) + elif custom_routing_function is None: + topk_weights, topk_ids = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + ) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + ) + + return topk_weights, topk_ids + + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + assert self.quant_method is not None + + # Matrix multiply. 
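+        # (expert routing + grouped expert GEMMs; the unquantized path
+        # dispatches to UnquantizedFusedMoEMethod.forward_cuda above)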
+        final_hidden_states = self.quant_method.apply(
+            layer=self,
+            x=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            renormalize=self.renormalize,
+            use_grouped_topk=self.use_grouped_topk,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            custom_routing_function=self.custom_routing_function,
+        )
+
+        if self.reduce_results and self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states
+
+    @classmethod
+    def make_expert_params_mapping(
+        cls,
+        ckpt_gate_proj_name: str,
+        ckpt_down_proj_name: str,
+        ckpt_up_proj_name: str,
+        num_experts: int,
+    ) -> List[Tuple[str, str, int, str]]:
+
+        return [
+            # (param_name, weight_name, expert_id, shard_id)
+            (
+                (
+                    "experts.w13_"
+                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
+                    else "experts.w2_"
+                ),
+                f"experts.{expert_id}.{weight_name}.",
+                expert_id,
+                shard_id,
+            )
+            for expert_id in range(num_experts)
+            for shard_id, weight_name in [
+                ("w1", ckpt_gate_proj_name),
+                ("w2", ckpt_down_proj_name),
+                ("w3", ckpt_up_proj_name),
+            ]
+        ]
+
+    def _load_fp8_scale(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
+    ) -> None:
+        param_data = param.data
+
+        # Input scales can be loaded directly and should be equal.
+        if "input_scale" in weight_name:
+            if (
+                param_data[expert_id] != 1
+                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
+            ):
+                raise ValueError(
+                    "input_scales of w1 and w3 of a layer "
+                    f"must be equal. But got {param_data[expert_id]} "
+                    f"vs. {loaded_weight}"
+                )
+            param_data[expert_id] = loaded_weight
+        # Weight scales
+        elif "weight_scale" in weight_name:
+            # If we are in merged column case (gate_up_proj)
+            if shard_id in ("w1", "w3"):
+                # We have to keep the weight scales of w1 and w3 because
+                # we need to re-quantize w1/w3 weights after weight loading.
+                idx = 0 if shard_id == "w1" else 1
+                param_data[expert_id][idx] = loaded_weight
+            # If we are in the row parallel case (down_proj)
+            else:
+                param_data[expert_id] = loaded_weight
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 815d7771768..73ab9c05919 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -27,7 +27,6 @@
     get_tp_group,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
@@ -42,6 +41,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.triton_fused_moe import FusedMoE
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index fcba31a56bf..f0d129a47e7 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -31,7 +31,7 @@
 import warnings
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
-from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
 
 import numpy as np
 import psutil
@@ -45,6 +45,7 @@
 from starlette.routing import Mount
 from torch import nn
 from torch.func import functional_call
+from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from triton.runtime.cache import (
     FileCacheManager,
@@ -930,3 +931,44 @@ def get_nvgpu_memory_capacity():
 def crash_on_warnings():
     # Crash on warning if we are running CI tests
     return os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+
+
+def get_device_name(device_id: int = 0) -> str:
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        return torch.cuda.get_device_name(device_id)
+
+    if hasattr(torch, "hip") and torch.hip.is_available():
+        return torch.hip.get_device_name(device_id)
+
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        return torch.xpu.get_device_name(device_id)
+
+    if hasattr(torch, "hpu") and torch.hpu.is_available():
+        return torch.hpu.get_device_name(device_id)
+
+
+sglang_lib = Library("sglang", "FRAGMENT")  # noqa
+
+
+def direct_register_custom_op(
+    op_name: str,
+    op_func: Callable,
+    mutates_args: List[str],
+    fake_impl: Optional[Callable] = None,
+    target_lib: Optional[Library] = None,
+):
+    import torch.library
+
+    if hasattr(torch.library, "infer_schema"):
+        schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
+    else:
+        # for pytorch 2.4
+        import torch._custom_op.impl
+
+        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
+
+    my_lib = target_lib or sglang_lib
+    my_lib.define(op_name + schema_str)
+    my_lib.impl(op_name, op_func, "CUDA")
+    if fake_impl is not None:
+        my_lib._register_fake(op_name, fake_impl)
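
A quick sketch of the registration pattern the new `direct_register_custom_op`
helper enables (the op below is hypothetical; the real call sites in this patch
are `inplace_fused_experts` and `outplace_fused_experts` in fused_moe.py):

    import torch
    from sglang.srt.utils import direct_register_custom_op

    def scale_rows(x: torch.Tensor, factor: float) -> torch.Tensor:
        # Toy eager implementation, registered for CUDA below.
        return x * factor

    def scale_rows_fake(x: torch.Tensor, factor: float) -> torch.Tensor:
        # Shape/dtype-only stand-in used during tracing/compilation.
        return torch.empty_like(x)

    direct_register_custom_op(
        op_name="scale_rows",
        op_func=scale_rows,
        mutates_args=[],
        fake_impl=scale_rows_fake,
    )

    # The op is then reachable through the dispatcher:
    # y = torch.ops.sglang.scale_rows(x, 2.0)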