diff --git a/docs/features/quantization/README.md b/docs/features/quantization/README.md index 8b4dcf01969a..c84f6ede1c1d 100644 --- a/docs/features/quantization/README.md +++ b/docs/features/quantization/README.md @@ -5,12 +5,11 @@ Quantization trades off model precision for smaller memory footprint, allowing l Contents: - [AutoAWQ](auto_awq.md) -- [AutoRound](auto_round.md) - [BitsAndBytes](bnb.md) - [BitBLAS](bitblas.md) - [GGUF](gguf.md) - [GPTQModel](gptqmodel.md) -- [INC](inc.md) +- [Intel Neural Compressor](inc.md) - [INT4 W4A16](int4.md) - [INT8 W8A8](int8.md) - [FP8 W8A8](fp8.md) @@ -43,23 +42,23 @@ th:not(:first-child) { } -| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | -|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------| -| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | -| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | -| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | -| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | -| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | -| BitBLAS | ✅︎ | ✅ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | -| BitBLAS (GPTQ) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | -| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | -| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | -| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | -| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | +| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU | +|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-----------| +| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | +| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | +| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | +| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | +| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | +| BitBLAS | ✅︎ | ✅ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | +| BitBLAS (GPTQ) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | +| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | +| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | +| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | - Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0. - ✅︎ indicates that the quantization method is supported on the specified hardware. - ❌ indicates that the quantization method is not supported on the specified hardware. +- All Intel Gaudi quantization support has been migrated to [vLLM-Gaudi](https://github.com/vllm-project/vllm-gaudi). !!! note For information on quantization support on Google TPU, please refer to the [TPU-Inference Recommended Models and Features](https://docs.vllm.ai/projects/tpu/en/latest/recommended_models_features/) documentation. diff --git a/docs/features/quantization/auto_round.md b/docs/features/quantization/auto_round.md deleted file mode 100644 index 9c14f362b663..000000000000 --- a/docs/features/quantization/auto_round.md +++ /dev/null @@ -1,103 +0,0 @@ -# AutoRound - -[AutoRound](https://github.com/intel/auto-round) is Intel’s advanced quantization algorithm designed to produce highly efficient **INT2, INT3, INT4, and INT8** -quantized large language models—striking an optimal balance between accuracy and deployment performance. - -AutoRound applies weight-only quantization to transformer-based models, enabling significant memory savings and faster -inference while maintaining near-original accuracy. 
It supports a wide range of hardware platforms, including **CPUs, -Intel GPUs, HPUs, and CUDA-enabled devices**. - -Please refer to the [AutoRound guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md) for more details. - -Key Features: - -✅ **AutoRound, AutoAWQ, AutoGPTQ, and GGUF** are supported - -✅ **10+ vision-language models (VLMs)** are supported - -✅ **Per-layer mixed-bit quantization** for fine-grained control - -✅ **RTN (Round-To-Nearest) mode** for quick quantization with slight accuracy loss - -✅ **Multiple quantization recipes**: best, base, and light - -✅ Advanced utilities such as immediate packing and support for **10+ backends** - -## Installation - -```bash -uv pip install auto-round -``` - -## Quantizing a model - -For VLMs, please change to `auto-round-mllm` in CLI usage and `AutoRoundMLLM` in API usage. - -### CLI usage - -```bash -auto-round \ - --model Qwen/Qwen3-0.6B \ - --bits 4 \ - --group_size 128 \ - --format "auto_round" \ - --output_dir ./tmp_autoround -``` - -```bash -auto-round \ - --model Qwen/Qwen3-0.6B \ - --format "gguf:q4_k_m" \ - --output_dir ./tmp_autoround -``` - -### API usage - -```python -from transformers import AutoModelForCausalLM, AutoTokenizer -from auto_round import AutoRound - -model_name = "Qwen/Qwen3-0.6B" -model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto") -tokenizer = AutoTokenizer.from_pretrained(model_name) - -bits, group_size, sym = 4, 128, True -autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym) - -# the best accuracy, 4-5X slower, low_gpu_mem_usage could save ~20G but ~30% slower -# autoround = AutoRound(model, tokenizer, nsamples=512, iters=1000, low_gpu_mem_usage=True, bits=bits, group_size=group_size, sym=sym) - -# 2-3X speedup, slight accuracy drop at W4G128 -# autoround = AutoRound(model, tokenizer, nsamples=128, iters=50, lr=5e-3, bits=bits, group_size=group_size, sym=sym ) - -output_dir = "./tmp_autoround" -# format= 'auto_round'(default), 'auto_gptq', 'auto_awq' -autoround.quantize_and_save(output_dir, format="auto_round") -``` - -## Running a quantized model with vLLM - -Here is some example code to run auto-round format in vLLM: - -```python -from vllm import LLM, SamplingParams - -prompts = [ - "Hello, my name is", -] -sampling_params = SamplingParams(temperature=0.6, top_p=0.95) -model_name = "Intel/DeepSeek-R1-0528-Qwen3-8B-int4-AutoRound" -llm = LLM(model=model_name) - -outputs = llm.generate(prompts, sampling_params) - -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") -``` - -## Acknowledgement - -Special thanks to open-source low precision libraries such as AutoGPTQ, AutoAWQ, GPTQModel, Triton, Marlin, and -ExLLaMAV2 for providing low-precision CUDA kernels, which are leveraged in AutoRound. diff --git a/docs/features/quantization/inc.md b/docs/features/quantization/inc.md index f2bbca498cd0..adb6b3ae8e2f 100644 --- a/docs/features/quantization/inc.md +++ b/docs/features/quantization/inc.md @@ -1,50 +1,89 @@ -# FP8 INC +# Intel Quantization Support -vLLM supports FP8 (8-bit floating point) weight and activation quantization using Intel® Neural Compressor (INC) on Intel® Gaudi® 2 and Intel® Gaudi® 3 AI accelerators. -Currently, quantization is validated only in Llama models. +[AutoRound](https://github.com/intel/auto-round) is Intel’s advanced quantization algorithm designed for large language models(LLMs). 
It produces highly efficient **INT2, INT3, INT4, INT8, MXFP8, MXFP4, NVFP4**, and **GGUF** quantized models, balancing accuracy and inference performance. AutoRound is also part of the [Intel® Neural Compressor](https://github.com/intel/neural-compressor). For a deeper introduction, see the [AutoRound step-by-step guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md). -Intel Gaudi supports quantization of various modules and functions, including, but not limited to `Linear`, `KVCache`, `Matmul` and `Softmax`. For more information, please refer to: -[Supported Modules\\Supported Functions\\Custom Patched Modules](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-modules). +## Key Features -!!! note - Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vLLM HPU extension](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package. +✅ Superior Accuracy Delivers strong performance even at 2–3 bits [example models](https://huggingface.co/collections/OPEA/2-3-bits) -!!! note - `QUANT_CONFIG` is an environment variable that points to the measurement or quantization [JSON config file](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-json-config-file-options). - The measurement configuration file is used during the calibration procedure to collect measurements for a given model. The quantization configuration is used during inference. +✅ Fast Mixed `Bits`/`Dtypes` Scheme Generation Automatically configure in minutes + +✅ Support for exporting **AutoRound, AutoAWQ, AutoGPTQ, and GGUF** formats + +✅ **10+ vision-language models (VLMs)** are supported + +✅ **Per-layer mixed-bit quantization** for fine-grained control + +✅ **RTN (Round-To-Nearest) mode** for quick quantization with slight accuracy loss + +✅ **Multiple quantization recipes**: best, base, and light -## Run Online Inference Using FP8 +✅ Advanced utilities such as immediate packing and support for **10+ backends** -Once you've completed the model calibration process and collected the measurements, you can run FP8 inference with vLLM using the following command: +## Supported Recipes on Intel Platforms + +On Intel platforms, AutoRound recipes are being enabled progressively by format and hardware. Currently, vLLM supports: + +- **`W4A16`**: weight-only, 4-bit weights with 16-bit activations +- **`W8A16`**: weight-only, 8-bit weights with 16-bit activations + +Additional recipes and formats will be supported in future releases. + +## Quantizing a Model + +### Installation ```bash -export QUANT_CONFIG=/path/to/quant/config/inc/meta-llama-3.1-405b-instruct/maxabs_measure_g3.json -vllm serve meta-llama/Llama-3.1-405B-Instruct --quantization inc --kv-cache-dtype fp8_inc --tensor-parallel-size 8 +uv pip install auto-round ``` -!!! tip - When using FP8 models, you may experience timeouts caused by the long compilation time of FP8 operations. To mitigate this problem, you can use the below environment variables: - `VLLM_ENGINE_ITERATION_TIMEOUT_S` - to adjust the vLLM server timeout. You can set the value in seconds, e.g., 600 equals 10 minutes. - `VLLM_RPC_TIMEOUT` - to adjust the RPC protocol timeout used by the OpenAI-compatible API. This value is in microseconds, e.g., 600000 equals 10 minutes. 
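To make the `W4A16` and `W8A16` recipes above concrete, the sketch below shows what symmetric group-wise weight-only quantization does numerically. This is plain PyTorch for illustration only, not AutoRound's API; AutoRound additionally tunes the rounding and clipping values per group rather than using plain round-to-nearest.

```python
# Illustrative only: W4A16-style fake quantization, i.e. 4-bit symmetric
# weights with group size 128 while activations stay in 16-bit. Real
# AutoRound checkpoints store packed int4 weights plus per-group scales,
# which vLLM's kernels dequantize at runtime.
import torch


def fake_quant_w4a16(weight: torch.Tensor, group_size: int = 128) -> torch.Tensor:
    out_features, in_features = weight.shape
    assert in_features % group_size == 0
    w = weight.float().reshape(out_features, in_features // group_size, group_size)
    # Per-group symmetric scale: map the largest magnitude onto the int4 grid.
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 7.0
    q = torch.clamp(torch.round(w / scale), -8, 7)  # int4 range [-8, 7]
    return (q * scale).reshape(out_features, in_features).to(weight.dtype)


w = torch.randn(512, 1024, dtype=torch.bfloat16)
print((w - fake_quant_w4a16(w)).abs().mean())  # small reconstruction error
```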
- -## Run Offline Inference Using FP8 +### Quantize with CLI -To run offline inference (after completing the model calibration process): +```bash +auto-round \ + --model Qwen/Qwen3-0.6B \ + --scheme W4A16 \ + --format auto_round \ + --output_dir ./tmp_autoround +``` -* Set the "QUANT_CONFIG" environment variable to point to a JSON configuration file with QUANTIZE mode. -* Pass `quantization=inc` and `kv_cache_dtype=fp8_inc` as parameters to the `LLM` object. -* Call shutdown method of the model_executor at the end of the run. +### Quantize with Python API ```python -from vllm import LLM -llm = LLM("llama3.1/Meta-Llama-3.1-8B-Instruct", quantization="inc", kv_cache_dtype="fp8_inc") -... -# Call llm.generate on the required prompts and sampling params. -... -llm.llm_engine.model_executor.shutdown() +from transformers import AutoModelForCausalLM, AutoTokenizer +from auto_round import AutoRound + +model_name = "Qwen/Qwen3-0.6B" +autoround = AutoRound(model_name, scheme="W4A16") + +# the best accuracy, 4-5X slower, low_gpu_mem_usage could save ~20G but ~30% slower +# autoround = AutoRound(model, tokenizer, nsamples=512, iters=1000, low_gpu_mem_usage=True, bits=bits, group_size=group_size, sym=sym) + +# 2-3X speedup, slight accuracy drop at W4G128 +# autoround = AutoRound(model, tokenizer, nsamples=128, iters=50, lr=5e-3, bits=bits, group_size=group_size, sym=sym ) + +output_dir = "./tmp_autoround" +# format= 'auto_round'(default), 'auto_gptq', 'auto_awq' +autoround.quantize_and_save(output_dir, format="auto_round") ``` -## Device for the Model's Weights Uploading +## Deploying AutoRound Quantized Models in vLLM -The unquantized weights are first loaded onto the CPU, then quantized and transferred to the target device (HPU) for model execution. -This reduces the device memory footprint of model weights, as only quantized weights are stored in the device memory. +```bash +vllm serve Intel/DeepSeek-R1-0528-Qwen3-8B-int4-AutoRound \ + --gpu-memory-utilization 0.8 \ + --max-model-len 4096 +``` + +!!! note + To deploy `wNa16` models on Intel GPU/CPU, please add `--enforce-eager` for now. 
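The quantized checkpoint can also be used offline through the Python API. The example below mirrors the offline snippet that previously lived in `auto_round.md` (same model and sampling parameters); the quantization method is detected from the checkpoint's config, and `enforce_eager=True` only needs to be added in the cases called out in the note above.

```python
from vllm import LLM, SamplingParams

# Offline counterpart of the `vllm serve` command above; the AutoRound
# quantization settings are picked up from the checkpoint automatically.
llm = LLM(
    model="Intel/DeepSeek-R1-0528-Qwen3-8B-int4-AutoRound",
    gpu_memory_utilization=0.8,
    max_model_len=4096,
)
sampling_params = SamplingParams(temperature=0.6, top_p=0.95)

outputs = llm.generate(["Hello, my name is"], sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}")
```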
+ +## Evaluating the Quantized Model with vLLM + +```bash +lm_eval --model vllm \ + --model_args pretrained="Intel/DeepSeek-R1-0528-Qwen3-8B-int4-AutoRound,max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enforce_eager=True" \ + --tasks gsm8k \ + --num_fewshot 5 \ + --batch_size 128 +``` diff --git a/tests/quantization/test_auto_round.py b/tests/quantization/test_auto_round.py index a2a1ebc014cb..9f5db8219501 100644 --- a/tests/quantization/test_auto_round.py +++ b/tests/quantization/test_auto_round.py @@ -26,9 +26,7 @@ ) @pytest.mark.parametrize("model", MODELS) def test_auto_round(vllm_runner, model): - with vllm_runner( - model, enforce_eager=True, allow_deprecated_quantization=True - ) as llm: + with vllm_runner(model, enforce_eager=True) as llm: output = llm.generate_greedy(["The capital of France is"], max_tokens=8) assert output print(f"{output[0][1]}") diff --git a/vllm/config/model.py b/vllm/config/model.py index f46918029571..166fc950c69f 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -884,6 +884,7 @@ def _verify_quantization(self) -> None: "gptq_bitblas", "awq_marlin", "ipex", + "inc", "moe_wna16", "modelopt", "modelopt_fp4", diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b7f3969ee21a..82ca5403ff0a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -223,10 +223,6 @@ def get_type_hints(type_hint: TypeHint) -> set[TypeHint]: return type_hints -def is_online_quantization(quantization: Any) -> bool: - return quantization in ["inc"] - - NEEDS_HELP = ( any("--help" in arg for arg in sys.argv) # vllm SUBCOMMAND --help or (argv0 := sys.argv[0]).endswith("mkdocs") # mkdocs SUBCOMMAND @@ -1304,7 +1300,6 @@ def create_load_config(self) -> LoadConfig: load_format=self.load_format, download_dir=self.download_dir, safetensors_load_strategy=self.safetensors_load_strategy, - device="cpu" if is_online_quantization(self.quantization) else None, model_loader_extra_config=self.model_loader_extra_config, ignore_patterns=self.ignore_patterns, use_tqdm_on_load=self.use_tqdm_on_load, diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index f0f8868c1a7e..718976eb18cc 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -33,7 +33,6 @@ "quark", "moe_wna16", "torchao", - "auto-round", "rtn", "inc", "mxfp4", @@ -54,7 +53,6 @@ "hqq", "experts_int8", "ipex", - "auto-round", "rtn", "petit_nvfp4", ] @@ -120,7 +118,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: # lazy import to avoid triggering `torch.compile` too early from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig - from .auto_round import AutoRoundConfig from .awq import AWQConfig from .awq_marlin import AWQMarlinConfig from .bitblas import BitBLASConfig @@ -174,8 +171,8 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "quark": QuarkConfig, "moe_wna16": MoeWNA16Config, "torchao": TorchAOConfig, - "auto-round": AutoRoundConfig, "rtn": RTNConfig, + "auto-round": INCConfig, "inc": INCConfig, "mxfp4": Mxfp4Config, "petit_nvfp4": PetitNvFp4Config, diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py deleted file mode 100644 index 5d77d1e3c7dd..000000000000 --- a/vllm/model_executor/layers/quantization/auto_round.py +++ /dev/null 
@@ -1,454 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from fractions import Fraction -from typing import TYPE_CHECKING, Any - -import regex as re -import torch - -from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod -from vllm.model_executor.layers.quantization import ( - QuantizationConfig, - QuantizationMethods, -) -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.platforms import current_platform -from vllm.scalar_type import scalar_types - -if TYPE_CHECKING: - from vllm.model_executor.models.utils import WeightsMapper - -logger = init_logger(__name__) - - -class AutoRoundConfig(QuantizationConfig): - """Config class for AutoRound. - Reference: https://arxiv.org/pdf/2309.05516 - """ - - SUPPORTED_BITS = {2, 3, 4, 8} - SUPPORTED_DTYPES = {"int"} - SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"} - SUPPORTED_BACKENDS = { - "auto", - "gptq", - "gptq:marlin", - "awq", - "awq:marlin", - "marlin", - "ipex", - } - - def __init__( - self, - weight_bits: int, - group_size: int, - sym: bool = True, - packing_format: str = "auto_round:auto_gptq", - block_name_to_quantize: str | list[str] | None = None, - extra_config: dict[str, Any] | None = None, - data_type: str = "int", - backend: str = "auto", - ) -> None: - super().__init__() - if weight_bits not in self.SUPPORTED_BITS: - raise ValueError( - f"Unsupported weight_bits: {weight_bits}, " - f"currently only support {self.SUPPORTED_BITS}." - ) - if data_type not in self.SUPPORTED_DTYPES: - raise ValueError( - f"Unsupported data_type: {data_type}, " - f"currently only support {self.SUPPORTED_DTYPES}." - ) - if packing_format not in self.SUPPORTED_FORMATS: - raise ValueError( - f"Unsupported packing_format: {packing_format}, " - f"currently only support {self.SUPPORTED_FORMATS}." - ) - if backend not in self.SUPPORTED_BACKENDS: - raise ValueError( - f"Unsupported backend: {backend}, " - f"currently only support {self.SUPPORTED_BACKENDS}." 
- ) - - self.weight_bits = weight_bits - self.group_size = group_size - self.sym = sym - self.packing_format = packing_format - self.block_name_to_quantize = ( - block_name_to_quantize.split(",") - if isinstance(block_name_to_quantize, str) - else block_name_to_quantize - ) - self.extra_config = extra_config - self.data_type = data_type - self.backend = backend - self.pack_factor = Fraction(32, weight_bits) - - def __repr__(self) -> str: - return ( - f"AutoRoundConfig(weight_bits={self.weight_bits}, " - f"group_size={self.group_size}, sym={self.sym})" - ) - - @classmethod - def get_name(cls) -> QuantizationMethods: - return "auto-round" - - @classmethod - def get_supported_act_dtypes(cls) -> list[torch.dtype]: - return [torch.half, torch.bfloat16] - - @classmethod - def get_min_capability(cls) -> int: - return 60 - - @classmethod - def get_config_filenames(cls) -> list[str]: - return ["quantization_config.json"] - - @classmethod - def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig": - return cls( - weight_bits=cls.get_from_keys(config, ["bits"]), - group_size=cls.get_from_keys(config, ["group_size"]), - sym=cls.get_from_keys(config, ["sym"]), - packing_format=cls.get_from_keys_or( - config, ["packing_format"], "auto_round:auto_gptq" - ), - block_name_to_quantize=cls.get_from_keys_or( - config, ["block_name_to_quantize", "to_quant_block_names"], None - ), - extra_config=cls.get_from_keys_or(config, ["extra_config"], None), - data_type=cls.get_from_keys_or(config, ["data_type"], "int"), - backend=cls.get_from_keys_or(config, ["backend", "vllm_backend"], "auto"), - ) - - def get_layer_config(self, layer, layer_name: str): - def get_config(name: str, quantized: bool = True): - if not self.extra_config: - return ( - self.weight_bits if quantized else 16, - self.group_size if quantized else -1, - self.sym if quantized else True, - ) - - # exact match first - if name in self.extra_config: - cfg = self.extra_config[name] - return ( - cfg.get("bits", self.weight_bits if quantized else 16), - cfg.get("group_size", self.group_size if quantized else -1), - cfg.get("sym", self.sym if quantized else True), - ) - - REGEX_SPECIAL_CHARS = set(r"*+?^$()[]{}|\\") - for pattern, cfg in self.extra_config.items(): - if not isinstance(pattern, str) or not any( - c in REGEX_SPECIAL_CHARS for c in pattern - ): - continue - - try: - if re.search(re.compile(pattern), name) is not None: - return ( - cfg.get("bits", self.weight_bits if quantized else 16), - cfg.get("group_size", self.group_size if quantized else -1), - cfg.get("sym", self.sym if quantized else True), - ) - except re.error: - # Invalid regex, ignore. - continue - - return ( - self.weight_bits if quantized else 16, - self.group_size if quantized else -1, - self.sym if quantized else True, - ) - - # 1. Exact match from config - if self.extra_config and layer_name in self.extra_config: - return get_config(layer_name) - - # 2. Determine whether layer should be quantized - quantized = not isinstance(layer, ParallelLMHead) - if self.block_name_to_quantize: - quantized = any( - layer_name.startswith(name) for name in self.block_name_to_quantize - ) - - # 3. 
Handle fused MoE - if self.extra_config and "fusedmoe" in layer.__class__.__name__.lower(): - moe_configs = [ - get_config(name, quantized) - for name in self.extra_config - if name.startswith(layer_name) - ] - if moe_configs: - if len(set(moe_configs)) == 1: - return moe_configs[0] - raise ValueError( - f"Fused MoE layer '{layer_name}' requires " - f"consistent quant config for all sub-layers" - ) - - # 4. Handle fused QKV or other patterns - if self.extra_config: - for fusion_key, sub_keys in self.packed_modules_mapping.items(): - if fusion_key in layer_name and layer_name.count(fusion_key) == 1: - sub_names = [ - layer_name.replace(fusion_key, sub_key) for sub_key in sub_keys - ] - sub_configs = [get_config(name, quantized) for name in sub_names] - if len(set(sub_configs)) == 1: - return sub_configs[0] - raise ValueError( - f"Fused module '{layer_name}' requires " - f"consistent quant config for {sub_names}" - ) - - # 5. Fallback or try a regular expression match - return get_config(layer_name, quantized) - - def check_quantized(self, weight_bits: int) -> bool: - return weight_bits < 16 - - def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): - if self.block_name_to_quantize is not None: - self.block_name_to_quantize = hf_to_vllm_mapper.apply_list( - self.block_name_to_quantize - ) - if self.extra_config is not None: - self.extra_config = hf_to_vllm_mapper.apply_dict(self.extra_config) - - def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"): - from vllm.model_executor.layers.fused_moe import FusedMoE - from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - check_marlin_supported, - check_moe_marlin_supports_layer, - ) - - weight_bits, group_size, sym = self.get_layer_config(layer, prefix) - if not self.check_quantized(weight_bits): - if isinstance(layer, (LinearBase, ParallelLMHead)): - return UnquantizedLinearMethod() - else: - return None - - logger.debug( - "[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s", - prefix, - layer.__class__.__name__, - weight_bits, - group_size, - sym, - ) - if backend == "auto" or "marlin" in backend: - AWQ_TYPE_MAP = { - 4: scalar_types.uint4, - 8: scalar_types.uint8, - } - use_marlin = (weight_bits in AWQ_TYPE_MAP) and check_marlin_supported( - AWQ_TYPE_MAP[weight_bits], group_size, not sym - ) - - if isinstance(layer, FusedMoE): - use_marlin = use_marlin and check_moe_marlin_supports_layer( - layer, group_size - ) - - else: - use_marlin = False - if use_marlin: - from vllm.model_executor.layers.quantization.awq_marlin import ( - AWQMarlinConfig, - AWQMarlinLinearMethod, - AWQMarlinMoEMethod, - ) - - quant_args_marlin = AWQMarlinConfig( - weight_bits=weight_bits, - group_size=group_size, - zero_point=not sym, - lm_head_quantized=False, - full_config={}, - modules_to_not_convert=[], - ) - else: - from vllm.model_executor.layers.quantization.awq import ( - AWQConfig, - AWQLinearMethod, - ) - - quant_args = AWQConfig( - weight_bits=weight_bits, - group_size=group_size, - zero_point=not sym, - ) - - if isinstance(layer, FusedMoE): - if use_marlin: - return AWQMarlinMoEMethod(quant_args_marlin, layer.moe) - from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config - - config = { - "quant_method": "awq", - "bits": weight_bits, - "group_size": group_size, - "zero_point": not sym, - "lm_head": False, - } - return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix) - - if isinstance(layer, (LinearBase, ParallelLMHead)): - if use_marlin: - return 
AWQMarlinLinearMethod(quant_args_marlin) - else: - return AWQLinearMethod(quant_args) - return None - - def apply_gptq_quant_layer(self, layer, prefix: str, backend: str = "auto"): - from vllm.model_executor.layers.fused_moe import FusedMoE - from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - check_marlin_supported, - check_moe_marlin_supports_layer, - ) - - weight_bits, group_size, sym = self.get_layer_config(layer, prefix) - if not self.check_quantized(weight_bits): - if isinstance(layer, (LinearBase, ParallelLMHead)): - return UnquantizedLinearMethod() - else: - return None - - logger.debug( - "[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s", - prefix, - layer.__class__.__name__, - weight_bits, - group_size, - sym, - ) - if backend == "auto" or "marlin" in backend: - GPTQ_TYPE_MAP = { - (4, True): scalar_types.uint4b8, - (8, True): scalar_types.uint8b128, - } - use_marlin = (weight_bits, sym) in GPTQ_TYPE_MAP and check_marlin_supported( - GPTQ_TYPE_MAP[(weight_bits, sym)], group_size, has_zp=not sym - ) - if isinstance(layer, FusedMoE): - use_marlin = use_marlin and check_moe_marlin_supports_layer( - layer, group_size - ) - else: - use_marlin = False - if use_marlin: - from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig, - GPTQMarlinLinearMethod, - GPTQMarlinMoEMethod, - ) - - quant_args_marlin = GPTQMarlinConfig( - weight_bits=weight_bits, - group_size=group_size, - is_sym=sym, - lm_head_quantized=False, - desc_act=False, - dynamic={}, - full_config={}, - ) - else: - from vllm.model_executor.layers.quantization.gptq import ( - GPTQConfig, - GPTQLinearMethod, - ) - - quant_args = GPTQConfig( - weight_bits=weight_bits, - group_size=group_size, - lm_head_quantized=False, - desc_act=False, - dynamic={}, - ) - - if isinstance(layer, FusedMoE): - if use_marlin: - return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe_config) - else: - from vllm.model_executor.layers.quantization.moe_wna16 import ( - MoeWNA16Config, - ) - - config = { - "quant_method": "gptq", - "bits": weight_bits, - "group_size": group_size, - "sym": sym, - "lm_head": False, - } - return MoeWNA16Config.from_config(config).get_quant_method( - layer, prefix - ) - - if isinstance(layer, (LinearBase, ParallelLMHead)): - if use_marlin: - return GPTQMarlinLinearMethod(quant_args_marlin) - else: - return GPTQLinearMethod(quant_args) - - return None - - def apply_ipex_quant_layer(self, layer, prefix: str): - weight_bits, group_size, sym = self.get_layer_config(layer, prefix) - if not self.check_quantized(weight_bits): - if isinstance(layer, (LinearBase, ParallelLMHead)): - return UnquantizedLinearMethod() - else: - return None - from vllm.model_executor.layers.quantization.ipex_quant import ( - IPEXAWQLinearMethod, - IPEXConfig, - IPEXGPTQLinearMethod, - ) - - if isinstance(layer, (LinearBase, ParallelLMHead)): - if "awq" in self.packing_format: - config = IPEXConfig( - method="awq", weight_bits=weight_bits, group_size=group_size - ) - return IPEXAWQLinearMethod(config) - elif "gptq" in self.packing_format: - config = IPEXConfig( - method="gptq", weight_bits=weight_bits, group_size=group_size - ) - return IPEXGPTQLinearMethod(config) - else: - raise ValueError( - f"ipex backend only supports awq " - f"and gtpq format,but got {self.packing_format}" - ) - else: - return None - - def get_quant_method(self, layer: torch.nn.Module, prefix: str): - if prefix and self.extra_config: - for layer_name in self.extra_config: - if ( - layer_name == prefix or layer_name == 
f"model.{prefix}" - ) and self.extra_config[layer_name].get("bits", 16) >= 16: - return UnquantizedLinearMethod() - if ( - current_platform.is_cpu() - or current_platform.is_xpu() - or self.backend == "ipex" - ): - return self.apply_ipex_quant_layer(layer, prefix) - if "gptq" in self.packing_format or "gptq" in self.backend: - return self.apply_gptq_quant_layer(layer, prefix) - if "awq" in self.packing_format or "awq" in self.backend: - return self.apply_awq_quant_layer(layer, prefix) diff --git a/vllm/model_executor/layers/quantization/inc.py b/vllm/model_executor/layers/quantization/inc.py index 4e736378e9da..63392b763f13 100644 --- a/vllm/model_executor/layers/quantization/inc.py +++ b/vllm/model_executor/layers/quantization/inc.py @@ -1,39 +1,98 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# -# Intel Gaudi supports quantization of various modules and functions, -# including, but not limited to `Linear`, `KVCache`, `Matmul` and `Softmax`. -# During model loading, -# INC will patch layers with quantization/dequantization operators. -# Meanwhile, INC will convert original weight to target datatype -# and loading to target device. -# static scaling should be provided through Quant_CONFIG: -# `QUANT_CONFIG` is an environment variable, -# that points to the measurement or quantization JSON config file. -# The measurement configuration file is used during the calibration procedure, -# to collect measurements for a given model. -# The quantization configuration is used during inference. -# For more information, please refer to: -# https://docs.habana.ai/en/v1.21.1/PyTorch/vLLM_Inference/vLLM_FP8_Inference.html - -from typing import Any, Optional +from fractions import Fraction +from typing import TYPE_CHECKING, Any, Optional + +import regex as re import torch -from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, - UnquantizedFusedMoEMethod, -) +from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( +from vllm.model_executor.layers.quantization import ( QuantizationConfig, - QuantizeMethodBase, + QuantizationMethods, ) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +if TYPE_CHECKING: + from vllm.model_executor.models.utils import WeightsMapper + +logger = init_logger(__name__) class INCConfig(QuantizationConfig): - """Config class for FP8 using Intel Neural Compressor.""" + """Config class for Intel Neural Compressor (INC). + Repo: https://github.com/intel/neural-compressor + """ + + SUPPORTED_BITS = {2, 3, 4, 8} + SUPPORTED_DTYPES = {"int"} + SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"} + SUPPORTED_BACKENDS = { + "auto", + "gptq", + "gptq:marlin", + "awq", + "awq:marlin", + "marlin", + "ipex", + } + + def __init__( + self, + weight_bits: int, + group_size: int, + sym: bool = True, + packing_format: str = "auto_round:auto_gptq", + block_name_to_quantize: str | list[str] | None = None, + extra_config: dict[str, Any] | None = None, + data_type: str = "int", + backend: str = "auto", + ) -> None: + super().__init__() + if weight_bits not in self.SUPPORTED_BITS: + raise ValueError( + f"Unsupported weight_bits: {weight_bits}, " + f"currently only support {self.SUPPORTED_BITS}." 
+ ) + if data_type not in self.SUPPORTED_DTYPES: + raise ValueError( + f"Unsupported data_type: {data_type}," + f" currently only support {self.SUPPORTED_DTYPES}." + ) + if packing_format not in self.SUPPORTED_FORMATS: + raise ValueError( + f"Unsupported packing_format: {packing_format}, " + f"currently only support {self.SUPPORTED_FORMATS}." + ) + if backend not in self.SUPPORTED_BACKENDS: + raise ValueError( + f"Unsupported backend: {backend}, " + f"currently only support {self.SUPPORTED_BACKENDS}." + ) + + self.weight_bits = weight_bits + self.group_size = group_size + self.sym = sym + self.packing_format = packing_format + self.block_name_to_quantize = ( + block_name_to_quantize.split(",") + if isinstance(block_name_to_quantize, str) + else block_name_to_quantize + ) + self.extra_config = extra_config + self.data_type = data_type + self.backend = backend + self.pack_factor = Fraction(32, weight_bits) + + def __repr__(self) -> str: + return ( + f"INCConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, sym={self.sym})" + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -41,25 +100,365 @@ def get_name(cls) -> QuantizationMethods: @classmethod def get_supported_act_dtypes(cls) -> list[torch.dtype]: - return [torch.bfloat16] + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["quantization_config.json"] @classmethod def from_config(cls, config: dict[str, Any]) -> "INCConfig": - raise AssertionError - - def get_quant_method( - self, layer: torch.nn.Module, prefix: str - ) -> Optional["QuantizeMethodBase"]: - if isinstance(layer, LinearBase): - return UnquantizedLinearMethod() - elif isinstance(layer, FusedMoE): - return UnquantizedFusedMoEMethod(layer.moe_config) + return cls( + weight_bits=cls.get_from_keys(config, ["bits"]), + group_size=cls.get_from_keys(config, ["group_size"]), + sym=cls.get_from_keys(config, ["sym"]), + packing_format=cls.get_from_keys_or( + config, ["packing_format"], "auto_round:auto_gptq" + ), + block_name_to_quantize=cls.get_from_keys_or( + config, ["block_name_to_quantize", "to_quant_block_names"], None + ), + extra_config=cls.get_from_keys_or(config, ["extra_config"], None), + data_type=cls.get_from_keys_or(config, ["data_type"], "int"), + backend=cls.get_from_keys_or(config, ["backend", "vllm_backend"], "auto"), + ) + + def get_layer_config(self, layer, layer_name: str): + def get_config(name: str, quantized: bool = True): + if not self.extra_config: + return ( + self.weight_bits if quantized else 16, + self.group_size if quantized else -1, + self.sym if quantized else True, + ) + + # exact match first + if name in self.extra_config: + cfg = self.extra_config[name] + return ( + cfg.get("bits", self.weight_bits if quantized else 16), + cfg.get("group_size", self.group_size if quantized else -1), + cfg.get("sym", self.sym if quantized else True), + ) + + REGEX_SPECIAL_CHARS = set(r"*+?^$()[]{}|\\") + for pattern, cfg in self.extra_config.items(): + if not isinstance(pattern, str) or not any( + c in REGEX_SPECIAL_CHARS for c in pattern + ): + continue + + try: + if re.search(re.compile(pattern), name) is not None: + return ( + cfg.get("bits", self.weight_bits if quantized else 16), + cfg.get("group_size", self.group_size if quantized else -1), + cfg.get("sym", self.sym if quantized else True), + ) + except re.error: + # Invalid regex, ignore. 
+ continue + + return ( + self.weight_bits if quantized else 16, + self.group_size if quantized else -1, + self.sym if quantized else True, + ) + + # 1. Exact match from config + if self.extra_config and layer_name in self.extra_config: + return get_config(layer_name) + + # 2. Determine whether layer should be quantized + quantized = not isinstance(layer, ParallelLMHead) + if self.block_name_to_quantize: + quantized = any( + layer_name.startswith(name) for name in self.block_name_to_quantize + ) + + # 3. Handle fused MoE + if self.extra_config and "fusedmoe" in layer.__class__.__name__.lower(): + moe_configs = [ + get_config(name, quantized) + for name in self.extra_config + if name.startswith(layer_name) + ] + if moe_configs: + if len(set(moe_configs)) == 1: + return moe_configs[0] + raise ValueError( + f"Fused MoE layer '{layer_name}' requires " + f"consistent quant config for all sub-layers" + ) + + # 4. Handle fused QKV or other patterns + if self.extra_config: + for fusion_key, sub_keys in self.packed_modules_mapping.items(): + if fusion_key in layer_name and layer_name.count(fusion_key) == 1: + sub_names = [ + layer_name.replace(fusion_key, sub_key) for sub_key in sub_keys + ] + sub_configs = [get_config(name, quantized) for name in sub_names] + if len(set(sub_configs)) == 1: + return sub_configs[0] + raise ValueError( + f"Fused module '{layer_name}' requires " + f"consistent quant config for {sub_names}" + ) + + # 5. Fallback or try a regular expression match + return get_config(layer_name, quantized) + + def check_quantized(self, weight_bits: int) -> bool: + return weight_bits < 16 + + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + if self.block_name_to_quantize is not None: + self.block_name_to_quantize = hf_to_vllm_mapper.apply_list( + self.block_name_to_quantize + ) + if self.extra_config is not None: + self.extra_config = hf_to_vllm_mapper.apply_dict(self.extra_config) + + def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"): + from vllm.model_executor.layers.fused_moe import FusedMoE + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supported, + check_moe_marlin_supports_layer, + ) + + weight_bits, group_size, sym = self.get_layer_config(layer, prefix) + if not self.check_quantized(weight_bits): + if isinstance(layer, (LinearBase, ParallelLMHead)): + return UnquantizedLinearMethod() + else: + return None + + logger.debug( + "[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s", + prefix, + layer.__class__.__name__, + weight_bits, + group_size, + sym, + ) + if backend == "auto" or "marlin" in backend: + AWQ_TYPE_MAP = { + 4: scalar_types.uint4, + 8: scalar_types.uint8, + } + use_marlin = (weight_bits in AWQ_TYPE_MAP) and check_marlin_supported( + AWQ_TYPE_MAP[weight_bits], group_size, not sym + ) + + if isinstance(layer, FusedMoE): + use_marlin = use_marlin and check_moe_marlin_supports_layer( + layer, group_size + ) + + else: + use_marlin = False + if use_marlin: + from vllm.model_executor.layers.quantization.awq_marlin import ( + AWQMarlinConfig, + AWQMarlinLinearMethod, + AWQMarlinMoEMethod, + ) + + quant_args_marlin = AWQMarlinConfig( + weight_bits=weight_bits, + group_size=group_size, + zero_point=not sym, + lm_head_quantized=False, + full_config={}, + modules_to_not_convert=[], + ) + else: + from vllm.model_executor.layers.quantization.awq import ( + AWQConfig, + AWQLinearMethod, + ) + + quant_args = AWQConfig( + weight_bits=weight_bits, + group_size=group_size, + zero_point=not sym, + ) + + if 
isinstance(layer, FusedMoE): + if use_marlin: + return AWQMarlinMoEMethod(quant_args_marlin, layer.moe_config) + from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config + + config = { + "quant_method": "awq", + "bits": weight_bits, + "group_size": group_size, + "zero_point": not sym, + "lm_head": False, + } + return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix) + + if isinstance(layer, (LinearBase, ParallelLMHead)): + if use_marlin: + return AWQMarlinLinearMethod(quant_args_marlin) + else: + return AWQLinearMethod(quant_args) return None - @classmethod - def get_min_capability(cls) -> int: - raise AssertionError + def apply_gptq_quant_layer(self, layer, prefix: str, backend: str = "auto"): + from vllm.model_executor.layers.fused_moe import FusedMoE + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supported, + check_moe_marlin_supports_layer, + ) + + weight_bits, group_size, sym = self.get_layer_config(layer, prefix) + if not self.check_quantized(weight_bits): + if isinstance(layer, (LinearBase, ParallelLMHead)): + return UnquantizedLinearMethod() + else: + return None + + logger.debug( + "[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s", + prefix, + layer.__class__.__name__, + weight_bits, + group_size, + sym, + ) + if backend == "auto" or "marlin" in backend: + GPTQ_TYPE_MAP = { + (4, True): scalar_types.uint4b8, + (8, True): scalar_types.uint8b128, + } + use_marlin = (weight_bits, sym) in GPTQ_TYPE_MAP and check_marlin_supported( + GPTQ_TYPE_MAP[(weight_bits, sym)], group_size, has_zp=not sym + ) + if isinstance(layer, FusedMoE): + use_marlin = use_marlin and check_moe_marlin_supports_layer( + layer, group_size + ) + else: + use_marlin = False + if use_marlin: + from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig, + GPTQMarlinLinearMethod, + GPTQMarlinMoEMethod, + ) - @staticmethod - def get_config_filenames() -> list[str]: - return [] + quant_args_marlin = GPTQMarlinConfig( + weight_bits=weight_bits, + group_size=group_size, + is_sym=sym, + lm_head_quantized=False, + desc_act=False, + dynamic={}, + full_config={}, + ) + else: + from vllm.model_executor.layers.quantization.gptq import ( + GPTQConfig, + GPTQLinearMethod, + ) + + quant_args = GPTQConfig( + weight_bits=weight_bits, + group_size=group_size, + lm_head_quantized=False, + desc_act=False, + dynamic={}, + ) + + if isinstance(layer, FusedMoE): + if use_marlin: + return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe_config) + else: + from vllm.model_executor.layers.quantization.moe_wna16 import ( + MoeWNA16Config, + ) + + config = { + "quant_method": "gptq", + "bits": weight_bits, + "group_size": group_size, + "sym": sym, + "lm_head": False, + } + return MoeWNA16Config.from_config(config).get_quant_method( + layer, prefix + ) + + if isinstance(layer, (LinearBase, ParallelLMHead)): + if use_marlin: + return GPTQMarlinLinearMethod(quant_args_marlin) + else: + return GPTQLinearMethod(quant_args) + + return None + + def apply_ipex_quant_layer(self, layer, prefix: str): + weight_bits, group_size, sym = self.get_layer_config(layer, prefix) + if not self.check_quantized(weight_bits): + if isinstance(layer, (LinearBase, ParallelLMHead)): + return UnquantizedLinearMethod() + else: + return None + from vllm.model_executor.layers.quantization.ipex_quant import ( + IPEXAWQLinearMethod, + IPEXConfig, + IPEXGPTQLinearMethod, + ) + + if isinstance(layer, (LinearBase, ParallelLMHead)): + if "awq" in self.packing_format: + config = 
IPEXConfig( + method="awq", weight_bits=weight_bits, group_size=group_size + ) + return IPEXAWQLinearMethod(config) + elif "gptq" in self.packing_format: + config = IPEXConfig( + method="gptq", weight_bits=weight_bits, group_size=group_size + ) + return IPEXGPTQLinearMethod(config) + else: + raise ValueError( + f"ipex backend only supports awq " + f"and gptq format,but got {self.packing_format}" + ) + else: + return None + + def get_quant_method(self, layer: torch.nn.Module, prefix: str): + if prefix and self.extra_config: + for layer_name in self.extra_config: + if ( + layer_name == prefix or layer_name == f"model.{prefix}" + ) and self.extra_config[layer_name].get("bits", 16) >= 16: + return UnquantizedLinearMethod() + if ( + current_platform.is_cpu() + or current_platform.is_xpu() + or self.backend == "ipex" + ): + return self.apply_ipex_quant_layer(layer, prefix) + if "gptq" in self.packing_format or "gptq" in self.backend: + return self.apply_gptq_quant_layer(layer, prefix) + if "awq" in self.packing_format or "awq" in self.backend: + return self.apply_awq_quant_layer(layer, prefix) + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant + ) -> Optional["QuantizationMethods"]: + """Override the `auto-round` method to `inc`.""" + is_auto_round_format = hf_quant_cfg.get("quant_method", None) == "auto-round" + if is_auto_round_format: + return cls.get_name() + return None diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 02f10eb2abbb..b20638c7eb28 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -233,7 +233,7 @@ def get_quant_config( quant_cls = get_quantization_config(model_config.quantization) # GGUF doesn't have config file - if model_config.quantization in ("gguf", "inc"): + if model_config.quantization == "gguf": return quant_cls() # Read the quantization config from the HF model config, if available.
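After this change, AutoRound checkpoints are routed to the INC config at load time: `INCConfig.override_quantization_method` rewrites `"quant_method": "auto-round"` from the HF quantization config to the registered `inc` method, and `"inc"` is also listed in `ModelConfig._verify_quantization` as an allowed override. A minimal sketch of both loading paths, reusing the model name from the docs above for illustration:

```python
from vllm import LLM

# Implicit: the checkpoint's quantization_config declares "auto-round",
# which INCConfig.override_quantization_method() resolves to "inc".
llm = LLM(model="Intel/DeepSeek-R1-0528-Qwen3-8B-int4-AutoRound")

# Explicit: passing quantization="inc" is equivalent for such checkpoints.
# llm = LLM(
#     model="Intel/DeepSeek-R1-0528-Qwen3-8B-int4-AutoRound",
#     quantization="inc",
# )
```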