feat: remove the dependency on FusedMoE (#2153)
zhyncs authored Nov 24, 2024
1 parent dbe1729 commit b509db5
Showing 7 changed files with 1,602 additions and 7 deletions.
20 changes: 15 additions & 5 deletions python/sglang/srt/layers/quantization/__init__.py
@@ -57,12 +57,23 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     "QUANTIZATION_METHODS",
 ]
 
-"""
-def fp8_get_quant_method(
-    self, layer: torch.nn.Module, prefix: str
-) -> Optional["QuantizeMethodBase"]:
+
+def fp8_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.fp8 import (
+        Fp8LinearMethod,
+        Fp8MoEMethod,
+    )
+    from vllm.model_executor.layers.quantization.utils.quant_utils import (
+        is_layer_skipped,
+    )
+
+    from sglang.srt.layers.triton_fused_moe.layer import FusedMoE
+
     if isinstance(layer, LinearBase):
         if is_layer_skipped(prefix, self.ignored_layers):
+            from sglang.srt.layers.linear import UnquantizedLinearMethod
+
             return UnquantizedLinearMethod()
         return Fp8LinearMethod(self)
     elif isinstance(layer, FusedMoE):
@@ -71,4 +82,3 @@ def fp8_get_quant_method
 
 
 setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
-"""
44 changes: 44 additions & 0 deletions python/sglang/srt/layers/triton_fused_moe/__init__.py
@@ -0,0 +1,44 @@
from contextlib import contextmanager
from typing import Any, Dict, Optional

import sglang.srt.layers.triton_fused_moe.fused_moe  # noqa
from sglang.srt.layers.triton_fused_moe.fused_moe import (
    fused_experts,
    fused_topk,
    get_config_file_name,
    grouped_topk,
)
from sglang.srt.layers.triton_fused_moe.layer import (
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)

_config: Optional[Dict[str, Any]] = None


@contextmanager
def override_config(config):
    global _config
    old_config = _config
    _config = config
    yield
    _config = old_config


def get_config() -> Optional[Dict[str, Any]]:
    return _config


__all__ = [
    "FusedMoE",
    "FusedMoEMethodBase",
    "FusedMoeWeightScaleSupported",
    "override_config",
    "get_config",
    "fused_moe",
    "fused_topk",
    "fused_experts",
    "get_config_file_name",
    "grouped_topk",
]
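
A brief usage sketch of the override_config/get_config pair defined above: the context manager temporarily swaps the module-level _config and restores the previous value on exit. The tuning keys in the dict below are illustrative Triton-style knobs, not values taken from this commit.

# Usage sketch for override_config/get_config defined above.
# The config keys below are illustrative tuning knobs, not values from this commit.
from sglang.srt.layers.triton_fused_moe import get_config, override_config

tuned = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "num_warps": 4}

assert get_config() is None           # no override active by default
with override_config(tuned):
    assert get_config() == tuned      # kernel launches inside the block see the override
assert get_config() is None           # previous value restored on exit
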
10 changes: 10 additions & 0 deletions python/sglang/srt/layers/triton_fused_moe/configs/README
@@ -0,0 +1,10 @@
This directory contains tuned configurations for different settings of the fused_moe kernel.
For different settings of
- E (number of experts)
- N (intermediate size)
- device_name (torch.cuda.get_device_name())
the JSON file contains a mapping from M (batch size) to the chosen configuration.

The example configurations provided are for the Mixtral model for TP2 on H100
and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have
N = 7168 and for TP4 we have N = 3584.
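
As a hedged illustration of how such a JSON file might be consumed: the helper below and its filename pattern are assumptions based on the description above, not code from this commit; the package's exported get_config_file_name is the canonical way to build the name.

# Hypothetical loader for a tuned fused_moe config, following the README above.
# The filename pattern and this helper are assumptions for illustration.
import json
import os
from typing import Any, Dict, Optional

import torch


def load_tuned_config(E: int, N: int, M: int, configs_dir: str) -> Optional[Dict[str, Any]]:
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    path = os.path.join(configs_dir, f"E={E},N={N},device_name={device_name}.json")
    if not os.path.exists(path):
        return None  # caller falls back to the kernel's default heuristics
    with open(path) as f:
        table = json.load(f)  # maps batch size M (string keys) to a config dict
    # Pick the tuned entry for the closest available batch size.
    best_m = min(table, key=lambda m: abs(int(m) - M))
    return table[best_m]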