diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py
index a40388d3e..d02ffc6cc 100644
--- a/vllm_gaudi/extension/ops.py
+++ b/vllm_gaudi/extension/ops.py
@@ -13,10 +13,12 @@
 from vllm_gaudi.extension.runtime import get_config
 from vllm_gaudi.extension.utils import get_kv_fetch_extra_args
 from vllm_gaudi.extension.scales import ConvertScaleToHwAligned
-
+import vllm.model_executor.layers.quantization as vllm_quant
 import habana_frameworks.torch.utils.experimental as htexp
 import types
 from vllm.model_executor.layers.fused_moe import FusedMoeWeightScaleSupported
+from vllm.model_executor.layers.quantization import get_quantization_config as vllm_get_quantization_config
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 
 is_hpu_gaudi2 = htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2
 is_hpu_gaudi3 = htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi3
@@ -1330,3 +1332,13 @@ def forward(
         else:
             final_hidden_states += slice_final_hidden_states
         return final_hidden_states
+
+
+def oot_get_quantization_config(quantization: str) -> QuantizationConfig:
+    from .quant import _FakeINCConfig
+    if quantization == "inc":
+        return _FakeINCConfig
+    return vllm_get_quantization_config(quantization)
+
+
+vllm_quant.get_quantization_config = oot_get_quantization_config
diff --git a/vllm_gaudi/extension/quant.py b/vllm_gaudi/extension/quant.py
new file mode 100644
index 000000000..92b6a4f4b
--- /dev/null
+++ b/vllm_gaudi/extension/quant.py
@@ -0,0 +1,40 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any, Optional
+
+import torch
+
+from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, UnquantizedFusedMoEMethod)
+from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
+from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig, QuantizeMethodBase)
+
+
+class _FakeINCConfig(QuantizationConfig):
+    """Placeholder INC quantization config class for FP8 using Intel Neural Compressor."""
+
+    @classmethod
+    def get_name(cls) -> QuantizationMethods:
+        return "inc"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+        return [torch.bfloat16]
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "_FakeINCConfig":
+        raise AssertionError
+
+    def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]:
+        if isinstance(layer, LinearBase):
+            return UnquantizedLinearMethod()
+        elif isinstance(layer, FusedMoE):
+            return UnquantizedFusedMoEMethod(layer.moe_config)
+        return None
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        raise AssertionError
+
+    @staticmethod
+    def get_config_filenames() -> list[str]:
+        return []
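
Usage sketch (illustrative, not part of the patch): importing vllm_gaudi.extension.ops executes the module-level assignment above, which monkey-patches vLLM's get_quantization_config so that the "inc" method resolves to the placeholder _FakeINCConfig while every other method falls through to the stock vLLM lookup. The snippet assumes vLLM and vllm_gaudi are installed and importable on the host.

# Illustrative only: shows the effect of the override installed as an import
# side effect of vllm_gaudi.extension.ops.
import vllm_gaudi.extension.ops  # noqa: F401  (importing installs the patched lookup)
from vllm.model_executor.layers.quantization import get_quantization_config

inc_cls = get_quantization_config("inc")  # resolves to _FakeINCConfig
print(inc_cls.get_name())                 # -> "inc"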