Commit 53b018e

[Bugfix] Get available quantization methods from quantization registry (vllm-project#4098)
1 parent 66ded03 commit 53b018e

6 files changed (+18 -13 lines)

benchmarks/benchmark_latency.py (+2 -1)

@@ -9,6 +9,7 @@
 from tqdm import tqdm

 from vllm import LLM, SamplingParams
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS


 def main(args: argparse.Namespace):
@@ -101,7 +102,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', 'gptq', 'squeezellm', None],
+                        choices=[*QUANTIZATION_METHODS, None],
                         default=None)
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
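
Since QUANTIZATION_METHODS is a dict keyed by method name, unpacking it into the choices list yields exactly the registered names, so the CLI options stay in sync with the registry. A minimal standalone sketch of that pattern (the registry contents here are placeholders, not vLLM's real config classes):

# Sketch only: iterating a dict yields its keys, so the unpacked
# registry doubles as the list of valid CLI choices.
import argparse

QUANTIZATION_METHODS = {"awq": object, "gptq": object,
                        "squeezellm": object, "marlin": object}

parser = argparse.ArgumentParser()
parser.add_argument('--quantization', '-q',
                    choices=[*QUANTIZATION_METHODS, None],
                    default=None)

args = parser.parse_args(['-q', 'marlin'])
print(args.quantization)  # -> marlin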

benchmarks/benchmark_throughput.py (+3 -1)

@@ -10,6 +10,8 @@
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
                           PreTrainedTokenizerBase)

+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+

 def sample_requests(
     dataset_path: str,
@@ -267,7 +269,7 @@ def main(args: argparse.Namespace):
     parser.add_argument("--tokenizer", type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', 'gptq', 'squeezellm', None],
+                        choices=[*QUANTIZATION_METHODS, None],
                         default=None)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n",

tests/models/test_marlin.py (+3 -4)

@@ -16,13 +16,12 @@
 import pytest
 import torch

-from vllm.model_executor.layers.quantization import (
-    _QUANTIZATION_CONFIG_REGISTRY)
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

 capability = torch.cuda.get_device_capability()
 capability = capability[0] * 10 + capability[1]
-marlin_not_supported = (
-    capability < _QUANTIZATION_CONFIG_REGISTRY["marlin"].get_min_capability())
+marlin_not_supported = (capability <
+                        QUANTIZATION_METHODS["marlin"].get_min_capability())


 @dataclass
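
The gate above packs the (major, minor) compute-capability tuple into a single integer, e.g. (8, 6) becomes 86, and compares it against the method's declared minimum. A standalone illustration, with 80 as an illustrative stand-in for whatever get_min_capability() returns:

# Standalone illustration of the capability check; 80 is a stand-in
# for the value returned by the config class's get_min_capability().
def meets_min_capability(major: int, minor: int, minimum: int = 80) -> bool:
    # Pack (8, 6) -> 86, then compare against the minimum.
    return major * 10 + minor >= minimum

print(meets_min_capability(8, 6))  # True
print(meets_min_capability(7, 5))  # False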

vllm/config.py (+4 -3)

@@ -9,6 +9,7 @@
 from transformers import PretrainedConfig

 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.transformers_utils.config import get_config, get_hf_text_config
 from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
                         is_neuron)
@@ -118,8 +119,8 @@ def _verify_tokenizer_mode(self) -> None:
         self.tokenizer_mode = tokenizer_mode

     def _verify_quantization(self) -> None:
-        supported_quantization = ["awq", "gptq", "squeezellm", "marlin"]
-        rocm_not_supported_quantization = ["awq", "marlin"]
+        supported_quantization = [*QUANTIZATION_METHODS]
+        rocm_supported_quantization = ["gptq", "squeezellm"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()

@@ -155,7 +156,7 @@ def _verify_quantization(self) -> None:
                 f"Unknown quantization method: {self.quantization}. Must "
                 f"be one of {supported_quantization}.")
         if is_hip(
-        ) and self.quantization in rocm_not_supported_quantization:
+        ) and self.quantization not in rocm_supported_quantization:
             raise ValueError(
                 f"{self.quantization} quantization is currently not "
                 f"supported in ROCm.")

vllm/engine/arg_utils.py (+2 -1)

@@ -7,6 +7,7 @@
                          EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
                          TokenizerPoolConfig, VisionLanguageConfig)
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.utils import str_to_int_tuple


@@ -286,7 +287,7 @@ def add_cli_args(
         parser.add_argument('--quantization',
                             '-q',
                             type=str,
-                            choices=['awq', 'gptq', 'squeezellm', None],
+                            choices=[*QUANTIZATION_METHODS, None],
                             default=EngineArgs.quantization,
                             help='Method used to quantize the weights. If '
                             'None, we first check the `quantization_config` '

vllm/model_executor/layers/quantization/__init__.py (+4 -3)

@@ -7,7 +7,7 @@
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig

-_QUANTIZATION_CONFIG_REGISTRY = {
+QUANTIZATION_METHODS = {
     "awq": AWQConfig,
     "gptq": GPTQConfig,
     "squeezellm": SqueezeLLMConfig,
@@ -16,12 +16,13 @@


 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
-    if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
+    if quantization not in QUANTIZATION_METHODS:
         raise ValueError(f"Invalid quantization method: {quantization}")
-    return _QUANTIZATION_CONFIG_REGISTRY[quantization]
+    return QUANTIZATION_METHODS[quantization]


 __all__ = [
     "QuantizationConfig",
     "get_quantization_config",
+    "QUANTIZATION_METHODS",
 ]
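
With the registry renamed from _QUANTIZATION_CONFIG_REGISTRY to QUANTIZATION_METHODS and exported via __all__, callers can enumerate methods and fetch config classes through the public surface. Illustrative usage against a vLLM checkout that contains this commit (printed values depend on the installed version):

# Illustrative usage of the now-public registry (requires vLLM at this
# commit or later; exact registry contents depend on the version).
from vllm.model_executor.layers.quantization import (
    QUANTIZATION_METHODS, get_quantization_config)

print(list(QUANTIZATION_METHODS))  # e.g. ['awq', 'gptq', 'squeezellm', 'marlin']

marlin_config_cls = get_quantization_config("marlin")
print(marlin_config_cls.get_min_capability())  # minimum GPU compute capability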
