14 changes: 13 additions & 1 deletion vllm_gaudi/extension/ops.py
@@ -13,10 +13,12 @@
from vllm_gaudi.extension.runtime import get_config
from vllm_gaudi.extension.utils import get_kv_fetch_extra_args
from vllm_gaudi.extension.scales import ConvertScaleToHwAligned

import vllm.model_executor.layers.quantization as vllm_quant
import habana_frameworks.torch.utils.experimental as htexp
import types
from vllm.model_executor.layers.fused_moe import FusedMoeWeightScaleSupported
from vllm.model_executor.layers.quantization import get_quantization_config as vllm_get_quantization_config
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

is_hpu_gaudi2 = htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2
is_hpu_gaudi3 = htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi3
@@ -1330,3 +1332,13 @@ def forward(
            else:
                final_hidden_states += slice_final_hidden_states
        return final_hidden_states


def oot_get_quantization_config(quantization: str) -> type[QuantizationConfig]:
    from .quant import _FakeINCConfig
    if quantization == "inc":
        return _FakeINCConfig
    return vllm_get_quantization_config(quantization)


# Rebind vLLM's registry lookup so "inc" resolves to the HPU placeholder config;
# every other method name falls through to vLLM's own registry.
vllm_quant.get_quantization_config = oot_get_quantization_config
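
Reviewer note: a minimal sketch of the intended effect, assuming vllm_gaudi.extension.ops has been imported so the rebinding above has run. Both lookups return config classes, not instances:

    import vllm.model_executor.layers.quantization as vllm_quant

    inc_cls = vllm_quant.get_quantization_config("inc")  # -> _FakeINCConfig
    fp8_cls = vllm_quant.get_quantization_config("fp8")  # -> vLLM's built-in fp8 config
    assert inc_cls.get_name() == "inc"

One caveat: the assignment rebinds the module attribute, so call sites that did `from vllm.model_executor.layers.quantization import get_quantization_config` before this module ran still hold the original function.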
40 changes: 40 additions & 0 deletions vllm_gaudi/extension/quant.py
@@ -0,0 +1,40 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional

import torch

from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig, QuantizeMethodBase)


class _FakeINCConfig(QuantizationConfig):
    """Placeholder INC quantization config class for FP8 using Intel Neural Compressor."""

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "inc"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "_FakeINCConfig":
        raise AssertionError("_FakeINCConfig is never built from a checkpoint config")

    def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]:
        # INC applies quantization through its own hooks, so vLLM-side layers
        # keep their unquantized compute methods here.
        if isinstance(layer, LinearBase):
            return UnquantizedLinearMethod()
        elif isinstance(layer, FusedMoE):
            return UnquantizedFusedMoEMethod(layer.moe_config)
        return None

    @classmethod
    def get_min_capability(cls) -> int:
        raise AssertionError("no CUDA compute capability applies on HPU")

    @staticmethod
    def get_config_filenames() -> list[str]:
        return []
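
A quick dispatch sketch for reviewers; `layer` here is a stand-in for any vLLM module and is not part of this PR:

    from vllm_gaudi.extension.quant import _FakeINCConfig

    cfg = _FakeINCConfig()
    method = cfg.get_quant_method(layer, prefix="")
    # LinearBase    -> UnquantizedLinearMethod()
    # FusedMoE      -> UnquantizedFusedMoEMethod(layer.moe_config)
    # anything else -> None, leaving the layer untouched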