diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index ad66777ae..25c9a72d1 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -580,14 +580,14 @@ def build_layerwise_device_map( device_ids = list(range(num_gpus)) device_map: Dict[str, str] = {} mod2name = {m: n for n, m in model.named_modules()} - + if torch.cuda.is_available(): device_strs = [f"cuda:{i}" for i in range(num_gpus)] elif hasattr(torch, "xpu") and torch.xpu.is_available(): device_strs = [f"xpu:{i}" for i in range(num_gpus)] else: device_strs = ["cpu"] * num_gpus - + def assign(mod, device_id): if mod is None: return @@ -726,7 +726,7 @@ def assign(mod, device_id): ) if not _validate_machete_device_support(): raise ValueError( - f"Kernel: Machete kernel requires compute capability >= 9.0. Detected capability: {torch.cuda.get_device_capability()}" + f"Kernel: Machete kernel currently supports Hopper GPUs (SM 90). Detected capability: {torch.cuda.get_device_capability()}." ) if backend in [BACKEND.MARLIN, BACKEND.MARLIN_FP16] and ( diff --git a/gptqmodel/nn_modules/qlinear/awq_machete.py b/gptqmodel/nn_modules/qlinear/awq_machete.py deleted file mode 100644 index 622d7dd35..000000000 --- a/gptqmodel/nn_modules/qlinear/awq_machete.py +++ /dev/null @@ -1,200 +0,0 @@ -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai -# SPDX-License-Identifier: Apache-2.0 -# Contact: qubitium@modelcloud.ai, x.com/qubitium - -from __future__ import annotations - -from typing import Optional, Tuple - -import torch - -from ...adapter.adapter import Adapter, Lora -from ...models._const import DEVICE, PLATFORM -from ...nn_modules.qlinear import AWQuantLinear -from ...utils.backend import BACKEND -from ...utils.logger import setup_logger -from ...utils.machete import ( - _validate_machete_device_support, - machete_import_exception, - machete_mm, - machete_prepack_B, - pack_quantized_values_into_int32, -) -from ...utils.marlin import replace_parameter, unpack_cols -from ...utils.marlin_scalar_type import scalar_types -from ...utils.rocm import IS_ROCM - - -log = setup_logger() - - -class AwqMacheteQuantLinear(AWQuantLinear): - SUPPORTS_BITS = [4, 8] - SUPPORTS_GROUP_SIZE = [-1, 32, 64, 128] - SUPPORTS_DESC_ACT = [False] # AWQ kernels do not reorder activations - SUPPORTS_SYM = [True, False] - SUPPORTS_SHARDS = True - SUPPORTS_TRAINING = False - SUPPORTS_AUTO_PADDING = False - SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [64] - SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [128] - - SUPPORTS_DEVICES = [DEVICE.CUDA] - SUPPORTS_PLATFORM = [PLATFORM.LINUX] - SUPPORTS_PACK_DTYPES = [torch.int32] - SUPPORTS_ADAPTERS = [Lora] - - SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] - - REQUIRES_FORMAT_V2 = False - - QUANT_TYPE = "awq_machete" - - TYPE_MAP = { - 4: scalar_types.uint4, - 8: scalar_types.uint8, - } - - def __init__( - self, - bits: int, - group_size: int, - desc_act: bool, - sym: bool, - in_features: int, - out_features: int, - bias: bool = False, - pack_dtype: torch.dtype = torch.int32, - adapter: Adapter = None, - register_buffers: bool = False, - **kwargs): - if machete_import_exception is not None: - raise ValueError( - "Trying to use the machete backend, but could not import the " - f"C++/CUDA dependencies with the following error: {machete_import_exception}" - ) - - if bits not in self.TYPE_MAP: - raise ValueError(f"Unsupported num_bits = {bits}. 
Supported: {list(self.TYPE_MAP.keys())}") - - super().__init__( - bits=bits, - group_size=group_size, - sym=sym, - desc_act=False, - in_features=in_features, - out_features=out_features, - bias=bias, - pack_dtype=pack_dtype, - backend=kwargs.pop("backend", BACKEND.MACHETE), - adapter=adapter, - register_buffers=register_buffers, - **kwargs) - - self.weight_type = self.TYPE_MAP[self.bits] - self.has_zero_points = True - - @classmethod - def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: - if machete_import_exception is not None: - return False, ImportError(machete_import_exception) - return cls._validate(**args) - - @classmethod - def validate_device(cls, device: DEVICE): - super().validate_device(device) - if device == DEVICE.CUDA: - if IS_ROCM: - raise NotImplementedError("Machete kernel is not supported on ROCm.") - if not _validate_machete_device_support(): - raise NotImplementedError("Machete kernel requires compute capability >= 9.0.") - - def post_init(self): - device = self.qweight.device - - # Reconstruct integer weights from packed AWQ representation - qweight_int = unpack_cols( - self.qweight, - self.bits, - self.in_features, - self.out_features, - ).to(device=device) - - packed = pack_quantized_values_into_int32( - qweight_int, - self.weight_type, - packed_dim=0, - ) - packed = packed.t().contiguous().t() - prepacked = machete_prepack_B( - packed, - a_type=self.scales.dtype, - b_type=self.weight_type, - group_scales_type=self.scales.dtype, - ) - replace_parameter( - self, - "qweight", - torch.nn.Parameter(prepacked.contiguous(), requires_grad=False), - ) - - # Ensure scales are contiguous and resident on the correct device. - replace_parameter( - self, - "scales", - torch.nn.Parameter(self.scales.contiguous(), requires_grad=False), - ) - - # Convert zero-points: unpack columns, then pre-apply scales as expected by machete_mm - effective_group_size = self.in_features if self.group_size == -1 else self.group_size - num_groups = self.in_features // effective_group_size - - qzeros_unpacked = unpack_cols( - self.qzeros, - self.bits, - num_groups, - self.out_features, - ).to(device=device) - - scales = self.scales - qzeros_fp = (-1.0 * scales.to(dtype=scales.dtype) * qzeros_unpacked.to(scales.dtype)).contiguous() - replace_parameter( - self, - "qzeros", - torch.nn.Parameter(qzeros_fp, requires_grad=False), - ) - - if self.bias is not None: - self.bias = self.bias.to(device=device) - - super().post_init() - - def forward(self, x: torch.Tensor): - if x.shape[0] == 0: - return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device) - - input_2d = x.reshape(-1, x.shape[-1]) - group_scales = self.scales.to(dtype=input_2d.dtype) - group_zeros = self.qzeros.to(dtype=input_2d.dtype) - - output = machete_mm( - a=input_2d, - b_q=self.qweight, - b_type=self.weight_type, - b_group_scales=group_scales, - b_group_zeros=group_zeros, - b_group_size=self.group_size, - ) - - if self.bias is not None: - output.add_(self.bias) - - result = output.reshape(x.shape[:-1] + (self.out_features,)) - - if self.adapter: - result = self.adapter.apply(x=x, out=result) - - return result - - -__all__ = ["AwqMacheteQuantLinear"] diff --git a/gptqmodel/nn_modules/qlinear/machete.py b/gptqmodel/nn_modules/qlinear/machete.py index 177979e29..4c729d9cf 100644 --- a/gptqmodel/nn_modules/qlinear/machete.py +++ b/gptqmodel/nn_modules/qlinear/machete.py @@ -2,10 +2,10 @@ # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, 
x.com/qubitium - from __future__ import annotations -from typing import List, Optional, Tuple +import math +from typing import Dict, Optional, Tuple import torch @@ -20,13 +20,8 @@ machete_import_exception, machete_mm, machete_prepack_B, - pack_quantized_values_into_int32, - query_machete_supported_group_sizes, - unpack_quantized_values_into_int32, ) -from ...utils.marlin import replace_parameter -from ...utils.marlin_scalar_type import scalar_types -from ...utils.rocm import IS_ROCM +from ...utils.marlin_scalar_type import ScalarType, scalar_types log = setup_logger() @@ -35,50 +30,55 @@ class MacheteQuantLinear(BaseQuantLinear): SUPPORTS_BITS = [4, 8] SUPPORTS_GROUP_SIZE = [-1, 64, 128] - SUPPORTS_DESC_ACT = [True, False] + SUPPORTS_DESC_ACT = [False] SUPPORTS_SYM = [True] - SUPPORTS_SHARDS = True + SUPPORTS_SHARDS = False SUPPORTS_TRAINING = False SUPPORTS_AUTO_PADDING = False SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [64] SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [128] - SUPPORTS_DEVICES = [DEVICE.CUDA] - SUPPORTS_PLATFORM = [PLATFORM.LINUX] SUPPORTS_PACK_DTYPES = [torch.int32] SUPPORTS_ADAPTERS = [Lora] + SUPPORTS_DEVICES = [DEVICE.CUDA] + SUPPORTS_PLATFORM = [PLATFORM.LINUX] + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] REQUIRES_FORMAT_V2 = False - QUANT_TYPE = "machete" - TYPE_MAP = { + TYPE_MAP: Dict[Tuple[int, bool], ScalarType] = { (4, True): scalar_types.uint4b8, (8, True): scalar_types.uint8b128, } def __init__( - self, - bits: int, - group_size: int, - desc_act: bool, - sym: bool, - in_features: int, - out_features: int, - bias: bool = False, - pack_dtype: torch.dtype = torch.int32, - register_buffers: bool = False, - adapter: Adapter = None, - **kwargs): + self, + bits: int, + group_size: int, + desc_act: bool, + sym: bool, + in_features: int, + out_features: int, + bias: bool = False, + pack_dtype: torch.dtype = torch.int32, + register_buffers: bool = False, + adapter: Adapter | None = None, + **kwargs, + ): if machete_import_exception is not None: raise ValueError( - "Trying to use the machete backend, but could not import the " - f"C++/CUDA dependencies with the following error: {machete_import_exception}" + "Trying to use the machete backend, but its CUDA extension could " + f"not be imported: {machete_import_exception}" ) + ok_shape, msg = check_machete_supports_shape(in_features, out_features) + if not ok_shape: + raise ValueError(msg) + if (bits, sym) not in self.TYPE_MAP: - raise ValueError(f"Unsupported quantization config: bits={bits}, sym={sym}") + raise ValueError(f"Unsupported quantization config bits={bits}, sym={sym}") super().__init__( bits=bits, @@ -92,200 +92,151 @@ def __init__( backend=kwargs.pop("backend", BACKEND.MACHETE), adapter=adapter, register_buffers=False, - **kwargs) + **kwargs, + ) - # Quantized weights (packed) + self.sym = sym + self.weight_type = self.TYPE_MAP[(self.bits, sym)] + + rows = self.in_features // self.pack_factor self.register_parameter( "qweight", torch.nn.Parameter( - torch.empty( - self.in_features // self.pack_factor, - self.out_features, - dtype=torch.int32, - ), + torch.empty(rows, self.out_features, dtype=torch.int32), requires_grad=False, ), ) - # Activation order indices + groups = max(1, math.ceil(self.in_features / self.group_size)) self.register_parameter( - "g_idx", + "scales", torch.nn.Parameter( - torch.empty(self.in_features, dtype=torch.int32), + torch.empty(groups, self.out_features, dtype=torch.float16), requires_grad=False, ), ) - # Scales - scales_rows = self.in_features if self.group_size == -1 else 
self.in_features // self.group_size self.register_parameter( - "scales", + "qzeros", torch.nn.Parameter( - torch.empty( - scales_rows, - self.out_features, - dtype=torch.float16, - ), + torch.empty(groups, self.out_features // self.pack_factor, dtype=torch.int32), requires_grad=False, ), ) - # Zero points unused for symmetric GPTQ self.register_parameter( - "qzeros", + "g_idx", torch.nn.Parameter( - torch.empty(0, dtype=torch.float16), - requires_grad=False, + torch.empty(self.in_features, dtype=torch.int32), requires_grad=False ), ) if bias: - self.register_buffer("bias", torch.zeros((self.out_features), dtype=torch.float16)) + self.register_parameter( + "bias", + torch.nn.Parameter( + torch.zeros(self.out_features, dtype=torch.float16), requires_grad=False + ), + ) else: self.bias = None - self.weight_type = self.TYPE_MAP[(self.bits, sym)] - self.has_zero_points = False - - # Buffer storing permutation applied to activations (empty when unused) - self.register_buffer("input_perm", torch.empty(0, dtype=torch.int32)) + self._prepacked_cache: Dict[torch.dtype, torch.Tensor] = {} @classmethod - def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: + def validate(cls, **args): if machete_import_exception is not None: return False, ImportError(machete_import_exception) - - ok, err = cls._validate(**args) - if not ok: - return ok, err - in_features = args.get("in_features") out_features = args.get("out_features") if in_features is not None and out_features is not None: - supported, reason = check_machete_supports_shape(in_features, out_features) - if not supported: - return False, ValueError(reason) - - bits = args.get("bits") - sym = args.get("sym", True) - quant_type = cls.TYPE_MAP.get((bits, sym)) - if quant_type is None: - return False, ValueError(f"Machete does not support bits={bits}, sym={sym}") - + ok, msg = check_machete_supports_shape(in_features, out_features) + if not ok: + return False, ValueError(msg) group_size = args.get("group_size") - dtype = args.get("dtype", torch.float16) - if group_size not in query_machete_supported_group_sizes(dtype): - return False, ValueError( - f"Machete does not support group_size={group_size} for dtype={dtype}" + desc_act = args.get("desc_act") + dtype = args.get("pack_dtype", torch.int32) + if desc_act and desc_act not in cls.SUPPORTS_DESC_ACT: + return False, NotImplementedError("Machete does not support desc_act=True.") + if dtype not in cls.SUPPORTS_PACK_DTYPES: + return False, NotImplementedError( + f"Machete only supports pack_dtype=torch.int32, got {dtype}." ) - - return True, None + return cls._validate(**args) @classmethod def validate_device(cls, device: DEVICE): super().validate_device(device) - if device == DEVICE.CUDA: - if IS_ROCM: - raise NotImplementedError("Machete kernel is not supported on ROCm.") - if not _validate_machete_device_support(): - raise NotImplementedError("Machete kernel requires compute capability >= 9.0.") + if device == DEVICE.CUDA and not _validate_machete_device_support(): + raise NotImplementedError( + "Machete kernels require an NVIDIA Hopper (SM90+) GPU." + ) def post_init(self): - device = self.qweight.device - - perm = None - if self.desc_act: - perm = torch.argsort(self.g_idx).to(torch.int32) - sorted_g_idx = self.g_idx[perm] - replace_parameter( - self, - "g_idx", - torch.nn.Parameter(sorted_g_idx.to(device=device), requires_grad=False), + if not _validate_machete_device_support(): + raise RuntimeError( + "Machete kernel currently supports Hopper GPUs (SM90+) and CUDA." 
) - self.input_perm = perm.to(device=device) - else: - self.input_perm = torch.empty(0, dtype=torch.int32, device=device) - qweight_unpacked = unpack_quantized_values_into_int32( - self.qweight.data, self.weight_type, packed_dim=0) - if perm is not None: - qweight_unpacked = qweight_unpacked[perm, :] + self._prepacked_cache.clear() - qweight_packed = pack_quantized_values_into_int32( - qweight_unpacked, self.weight_type, packed_dim=0) - qweight_packed = qweight_packed.t().contiguous().t() - prepacked = machete_prepack_B( - qweight_packed, - a_type=self.scales.dtype, - b_type=self.weight_type, - group_scales_type=self.scales.dtype, - ) - replace_parameter( - self, - "qweight", - torch.nn.Parameter(prepacked.contiguous(), requires_grad=False), - ) + super().post_init() - replace_parameter( - self, - "scales", - torch.nn.Parameter(self.scales.data.contiguous(), requires_grad=False), - ) + def _ensure_prepacked(self, act_dtype: torch.dtype) -> torch.Tensor: + cached = self._prepacked_cache.get(act_dtype) + if cached is not None and cached.device == self.qweight.device: + return cached - replace_parameter( - self, - "qzeros", - torch.nn.Parameter(torch.empty(0, dtype=self.scales.dtype, device=device), requires_grad=False), - ) - self.has_zero_points = False + group_scales_type = self.scales.dtype if self.scales is not None else None + weight = self.qweight.data + if weight.stride(0) != 1: + weight = weight.t().contiguous().t() - if self.bias is not None: - self.bias = self.bias.to(device=device) - - super().post_init() + prepacked = machete_prepack_B( + weight, + a_type=act_dtype, + b_type=self.weight_type, + group_scales_type=group_scales_type, + ).detach() - def list_buffers(self) -> List: - buf = super().list_buffers() - if hasattr(self, "input_perm") and self.input_perm is not None: - buf.append(self.input_perm) - return buf + self._prepacked_cache[act_dtype] = prepacked + return prepacked - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: if x.shape[0] == 0: return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device) - input_2d = x.reshape(-1, x.shape[-1]) + if machete_import_exception is not None: + raise RuntimeError(machete_import_exception) - if self.input_perm.numel() > 0: - perm = self.input_perm - if perm.device != input_2d.device: - perm = perm.to(device=input_2d.device) - input_2d = input_2d[:, perm] + act_dtype = x.dtype + if act_dtype not in self.SUPPORTS_DTYPES: + raise ValueError(f"Machete kernel does not support dtype {act_dtype}.") - group_scales = self.scales - if group_scales.dtype != input_2d.dtype: - group_scales = group_scales.to(dtype=input_2d.dtype) + x_2d = x.reshape(-1, x.shape[-1]) - group_zeros = self.qzeros if self.has_zero_points and self.qzeros.numel() > 0 else None + prepacked = self._ensure_prepacked(act_dtype) + group_scales = self.scales.to(dtype=act_dtype, device=x_2d.device) - output = machete_mm( - a=input_2d, - b_q=self.qweight, + output_2d = machete_mm( + a=x_2d.contiguous(), + b_q=prepacked, b_type=self.weight_type, b_group_scales=group_scales, - b_group_zeros=group_zeros, + b_group_zeros=None, b_group_size=self.group_size, + out_dtype=act_dtype, ) if self.bias is not None: - output.add_(self.bias) + output_2d = output_2d + self.bias.to(dtype=output_2d.dtype, device=output_2d.device) - result = output.reshape(x.shape[:-1] + (self.out_features,)) + output = output_2d.reshape(*x.shape[:-1], self.out_features) - if self.adapter: - result = self.adapter.apply(x=x, out=result) + if self.adapter is not 
None: + output = self.adapter.apply(x=x, out=output) - return result + return output __all__ = ["MacheteQuantLinear"] diff --git a/gptqmodel/nn_modules/qlinear/torch.py b/gptqmodel/nn_modules/qlinear/torch.py index d3275651b..523156e50 100644 --- a/gptqmodel/nn_modules/qlinear/torch.py +++ b/gptqmodel/nn_modules/qlinear/torch.py @@ -207,7 +207,7 @@ def _forward(self, x, out_shape): return out def _forward_eager(self, x: torch.Tensor, out_shape): - num_itr = self.g_idx.shape[0] // x.shape[-1] + num_itr = max(1, self.g_idx.shape[0] // x.shape[-1]) weights = self._consume_prefetched_weights(x.dtype) if weights is None: weights = self.dequantize_weight(num_itr=num_itr).to(x.dtype) diff --git a/gptqmodel/nn_modules/qlinear/torch_fused.py b/gptqmodel/nn_modules/qlinear/torch_fused.py index 633c324af..eff4c3485 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused.py @@ -203,7 +203,7 @@ def forward(self, x: torch.Tensor): return out def _forward(self, x, out_shape): - num_itr = self.g_idx.shape[0] // x.shape[-1] + num_itr = max(1, self.g_idx.shape[0] // x.shape[-1]) if not self.training and not self.transformed and TORCH_HAS_FUSED_OPS: # one-time transform per module for xpu aten fused ops diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py index 3bb46f107..deb007dd5 100644 --- a/gptqmodel/utils/importer.py +++ b/gptqmodel/utils/importer.py @@ -18,7 +18,6 @@ from ..nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear from ..nn_modules.qlinear.awq_gemv import AwqGEMVQuantLinear from ..nn_modules.qlinear.awq_gemv_fast import AwqGEMVFastQuantLinear -from ..nn_modules.qlinear.awq_machete import AwqMacheteQuantLinear from ..nn_modules.qlinear.awq_marlin import AwqMarlinQuantLinear from ..nn_modules.qlinear.bitblas import BitBLASQuantLinear from ..nn_modules.qlinear.exllama import ExllamaQuantLinear @@ -62,7 +61,6 @@ BACKEND.QQQ: QQQQuantLinear, # qqq kernel based on marlin }), METHOD.AWQ: OrderedDict({ - BACKEND.MACHETE: AwqMacheteQuantLinear, BACKEND.MARLIN: AwqMarlinQuantLinear, BACKEND.EXLLAMA_V2: AwqExllamaV2QuantLinear, BACKEND.EXLLAMA_V1: AwqExllamaQuantLinear, @@ -83,10 +81,10 @@ FORMAT.QQQ: [BACKEND.QQQ], }, METHOD.AWQ: { - FORMAT.GEMM: [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.GEMM], + FORMAT.GEMM: [BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.GEMM], FORMAT.GEMV: [BACKEND.GEMV], FORMAT.GEMV_FAST: [BACKEND.GEMV_FAST], - FORMAT.MARLIN: [BACKEND.MACHETE, BACKEND.MARLIN], + FORMAT.MARLIN: [BACKEND.MARLIN], } } @@ -286,7 +284,7 @@ def select_quant_linear( qlinear = BitBLASQuantLinear elif backend == BACKEND.MACHETE: if quant_method == METHOD.AWQ: - qlinear = AwqMacheteQuantLinear + qlinear = AwqMarlinQuantLinear else: qlinear = MacheteQuantLinear elif backend in [BACKEND.MARLIN, BACKEND.MARLIN_FP16]: diff --git a/gptqmodel/utils/machete.py b/gptqmodel/utils/machete.py index 57aaee535..f7e758530 100644 --- a/gptqmodel/utils/machete.py +++ b/gptqmodel/utils/machete.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium - from __future__ import annotations from typing import List, Optional @@ -11,6 +11,7 @@ from ._extension_loader import load_extension_module from .logger import setup_logger from .marlin_scalar_type import ScalarType, scalar_types +from .rocm import IS_ROCM log = setup_logger() @@ -18,158 +19,115 
@@ machete_import_exception: Optional[str] = None try: gptqmodel_machete_kernels = load_extension_module("gptqmodel_machete_kernels") -except ImportError as e: # pragma: no cover - surfaced at runtime - machete_import_exception = str(e) - gptqmodel_machete_kernels = None +except ImportError as exc: # pragma: no cover - runtime guard + machete_import_exception = str(exc) + MACHETE_PREPACKED_BLOCK_SHAPE = (64, 128) def _validate_machete_device_support() -> bool: - return (torch.cuda.is_available() - and torch.cuda.get_device_capability()[0] >= 9) - - -def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]: - if zero_points: - return [scalar_types.uint4, scalar_types.uint8] - return [scalar_types.uint4b8, scalar_types.uint8b128] - - -def query_machete_supported_act_types(_zero_points: bool) -> List[torch.dtype]: - return [torch.float16, torch.bfloat16] - - -def query_machete_supported_group_sizes(act_type: torch.dtype) -> List[int]: - if act_type in (torch.float16, torch.bfloat16): - return [-1, 64, 128] - return [-1, 128] - - -def check_machete_supports_shape(in_features: int, - out_features: int) -> tuple[bool, Optional[str]]: - if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0: - return (False, - f"Input features size must be divisible by {MACHETE_PREPACKED_BLOCK_SHAPE[0]}") - if out_features % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0: - return (False, - f"Output features size must be divisible by {MACHETE_PREPACKED_BLOCK_SHAPE[1]}") - return (True, None) - - -def _ensure_machete_loaded(): - if machete_import_exception is not None: - raise ImportError( - f"Trying to use the machete backend, but could not import the C++/CUDA dependencies: {machete_import_exception}" - ) - - -def _maybe_scalar_type(t: Optional[torch.Tensor]) -> Optional[torch.dtype]: - return t.dtype if t is not None else None + """ + Returns ``True`` when the active CUDA device can execute the Machete kernel. 
+ """ + if not torch.cuda.is_available() or IS_ROCM: + return False + major, _minor = torch.cuda.get_device_capability() + return major >= 9 def machete_prepack_B( - weight: torch.Tensor, - a_type: torch.dtype, - b_type: ScalarType, - group_scales_type: Optional[torch.dtype]) -> torch.Tensor: - _ensure_machete_loaded() - return gptqmodel_machete_kernels.machete_prepack_B( - weight, - a_type, - b_type.id, - group_scales_type, - ) - - -def machete_supported_schedules( - a_type: torch.dtype, - b_type: ScalarType, - group_scales_type: Optional[torch.dtype] = None, - group_zeros_type: Optional[torch.dtype] = None, - channel_scales_type: Optional[torch.dtype] = None, - token_scales_type: Optional[torch.dtype] = None, - out_type: Optional[torch.dtype] = None) -> List[str]: - _ensure_machete_loaded() - return gptqmodel_machete_kernels.machete_supported_schedules( + B: torch.Tensor, + *, + a_type: torch.dtype, + b_type: ScalarType, + group_scales_type: Optional[torch.dtype], +) -> torch.Tensor: + if machete_import_exception is not None: + raise ImportError(machete_import_exception) + return torch.ops.gptqmodel_machete_kernels.machete_prepack_B( + B, a_type, b_type.id, group_scales_type, - group_zeros_type, - channel_scales_type, - token_scales_type, - out_type, ) def machete_mm( - *, - a: torch.Tensor, - b_q: torch.Tensor, - b_type: ScalarType, - b_group_scales: Optional[torch.Tensor] = None, - b_group_zeros: Optional[torch.Tensor] = None, - b_group_size: Optional[int] = None, - b_channel_scales: Optional[torch.Tensor] = None, - a_token_scales: Optional[torch.Tensor] = None, - out_type: Optional[torch.dtype] = None, - schedule: Optional[str] = None) -> torch.Tensor: - _ensure_machete_loaded() - return gptqmodel_machete_kernels.machete_mm( + *, + a: torch.Tensor, + b_q: torch.Tensor, + b_type: ScalarType, + b_group_scales: Optional[torch.Tensor], + b_group_zeros: Optional[torch.Tensor], + b_group_size: Optional[int], + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + if machete_import_exception is not None: + raise ImportError(machete_import_exception) + return torch.ops.gptqmodel_machete_kernels.machete_mm( a, b_q, b_type.id, - out_type, + out_dtype, b_group_scales, b_group_zeros, b_group_size, - b_channel_scales, - a_token_scales, - schedule, + None, # channel scales currently unused + None, # token scales currently unused + None, # schedule hint ) -def pack_quantized_values_into_int32( - tensor: torch.Tensor, - qtype: ScalarType, - packed_dim: int = 0) -> torch.Tensor: - perm = tuple(i for i in range(tensor.ndim) if i != packed_dim) + (packed_dim,) - inv_perm = tuple(perm.index(i) for i in range(len(perm))) - temp = tensor.permute(perm) - - pack_factor = 32 // qtype.size_bits - mask = (1 << qtype.size_bits) - 1 - - assert temp.shape[-1] % pack_factor == 0 - new_shape = list(temp.shape) - new_shape[-1] //= pack_factor +def machete_supported_schedules( + *, + a_type: torch.dtype, + b_type: ScalarType, + b_group_scales_type: Optional[torch.dtype], + b_group_zeros_type: Optional[torch.dtype], + out_type: Optional[torch.dtype] = None, +) -> List[str]: + if machete_import_exception is not None: + raise ImportError(machete_import_exception) + return torch.ops.gptqmodel_machete_kernels.machete_supported_schedules( + a_type, + b_type.id, + b_group_scales_type, + b_group_zeros_type, + None, + None, + out_type, + ) - result = torch.zeros(new_shape, dtype=torch.int32, device=tensor.device) - for i in range(pack_factor): - result |= ((temp[..., i::pack_factor] & mask) << (qtype.size_bits * i)) - 
return result.permute(inv_perm) +def query_machete_supported_quant_types(include_zero_points: bool) -> List[ScalarType]: + if include_zero_points: + return [scalar_types.uint4, scalar_types.uint8] + return [scalar_types.uint4b8, scalar_types.uint8b128] -def unpack_quantized_values_into_int32( - tensor: torch.Tensor, - qtype: ScalarType, - packed_dim: int = 0) -> torch.Tensor: - perm = tuple(i for i in range(tensor.ndim) if i != packed_dim) + (packed_dim,) - inv_perm = tuple(perm.index(i) for i in range(len(perm))) - temp = tensor.permute(perm) +def query_machete_supported_act_types(_with_zero_points: bool) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] - pack_factor = 32 // qtype.size_bits - mask = (1 << qtype.size_bits) - 1 - new_shape = list(temp.shape) - new_shape[-1] *= pack_factor +def query_machete_supported_group_sizes(act_type: torch.dtype) -> List[int]: + if act_type in (torch.float16, torch.bfloat16): + return [-1, 64, 128] + return [-1, 128] - result = torch.zeros(new_shape, dtype=torch.int32, device=tensor.device) - for i in range(pack_factor): - result[..., i::pack_factor] = (temp >> (qtype.size_bits * i)) & mask - return result.permute(inv_perm) +def check_machete_supports_shape(in_features: int, out_features: int) -> tuple[bool, Optional[str]]: + if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0: + return ( + False, + f"in_features must be divisible by {MACHETE_PREPACKED_BLOCK_SHAPE[0]} (got {in_features})", + ) + if out_features % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0: + return ( + False, + f"out_features must be divisible by {MACHETE_PREPACKED_BLOCK_SHAPE[1]} (got {out_features})", + ) + return True, None __all__ = [ @@ -179,9 +137,7 @@ def unpack_quantized_values_into_int32( "machete_mm", "machete_prepack_B", "machete_supported_schedules", - "pack_quantized_values_into_int32", "query_machete_supported_act_types", "query_machete_supported_group_sizes", "query_machete_supported_quant_types", - "unpack_quantized_values_into_int32", ] diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 9e51e6fc3..32a1ed122 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -396,6 +396,19 @@ def create_quant_module( if err is not None: raise err + # determine the dtype the original module expected to operate in + module_dtype = None + if isinstance(submodule, BaseQuantLinear): + module_dtype = getattr(submodule, "module_dtype", None) + if module_dtype is None and hasattr(submodule, "scales"): + module_dtype = getattr(submodule.scales, "dtype", None) + elif hasattr(submodule, "weight") and submodule.weight is not None: + module_dtype = submodule.weight.dtype + elif hasattr(submodule, "bias") and submodule.bias is not None: + module_dtype = submodule.bias.dtype + if module_dtype is None: + module_dtype = torch.float16 + new_layer = linear_cls( bits=tmp_bits, group_size=tmp_group_size, @@ -412,6 +425,7 @@ def create_quant_module( register_buffers=register_buffers, adapter=adapter, ) + setattr(new_layer, "module_dtype", module_dtype) new_layer.device = ori_layer_device recurse_setattr(module, name, new_layer.to(ori_layer_device)) diff --git a/gptqmodel_ext/machete/core/batch_invariant.hpp b/gptqmodel_ext/machete/core/batch_invariant.hpp new file mode 100644 index 000000000..fffe96b86 --- /dev/null +++ b/gptqmodel_ext/machete/core/batch_invariant.hpp @@ -0,0 +1,19 @@ +#pragma once +#include +#include +#include + +namespace vllm { + +// vllm_is_batch_invariant(); returns true +// if env VLLM_BATCH_INVARIANT=1 +inline bool 
vllm_is_batch_invariant() { + static bool cached = []() { + std::string env_key = "VLLM_BATCH_INVARIANT"; + const char* val = std::getenv(env_key.c_str()); + return (val && std::atoi(val) != 0) ? 1 : 0; + }(); + return cached; +} + +} // namespace vllm diff --git a/gptqmodel_ext/machete/core/exception.hpp b/gptqmodel_ext/machete/core/exception.hpp new file mode 100644 index 000000000..f3b2ffaef --- /dev/null +++ b/gptqmodel_ext/machete/core/exception.hpp @@ -0,0 +1,3 @@ +#pragma once + +#define VLLM_IMPLIES(p, q) (!(p) || (q)) diff --git a/gptqmodel_ext/machete/core/math.hpp b/gptqmodel_ext/machete/core/math.hpp new file mode 100644 index 000000000..6764e1fd6 --- /dev/null +++ b/gptqmodel_ext/machete/core/math.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include +#include + +inline constexpr uint32_t next_pow_2(uint32_t const num) { + if (num <= 1) return num; + return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} + +template +static inline constexpr auto div_ceil(A a, B b) { + return (a + b - 1) / b; +} + +// Round a down to the next multiple of b. The caller is responsible for making +// sure that b is non-zero +template +inline constexpr T round_to_previous_multiple_of(T a, T b) { + return a % b == 0 ? a : (a / b) * b; +} + +// Round a up to the next multiple of b. The caller is responsible for making +// sure that b is non-zero +template +inline constexpr T round_to_next_multiple_of(T a, T b) { + return a % b == 0 ? a : ((a / b) + 1) * b; +} diff --git a/gptqmodel_ext/machete/core/scalar_type.hpp b/gptqmodel_ext/machete/core/scalar_type.hpp index 97078169d..68a8750f5 100644 --- a/gptqmodel_ext/machete/core/scalar_type.hpp +++ b/gptqmodel_ext/machete/core/scalar_type.hpp @@ -3,19 +3,8 @@ // For TORCH_CHECK #include -#include - namespace vllm { -template -inline To bit_cast_like(const From& src) noexcept { - static_assert(sizeof(To) == sizeof(From), - "bit_cast_like requires source and destination to be the same size"); - To dst{}; - std::memcpy(&dst, &src, sizeof(To)); - return dst; -} - // // ScalarType can represent a wide range of floating point and integer types, // in particular it can be used to represent sub-byte data types (something @@ -219,29 +208,30 @@ class ScalarType { // the exponent uint64_t double_raw = (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); - return bit_cast_like(double_raw); + + return *reinterpret_cast(&double_raw); } - std::variant _raw_max() const { + constexpr std::variant _raw_max() const { if (is_floating_point()) { return {_floating_point_max()}; } else { - TORCH_CHECK(size_bits() < 64 || (size_bits() == 64 && is_signed()), + TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(), "Cannot represent max as a int64_t"); return {(int64_t(1) << mantissa) - 1}; } } - std::variant _raw_min() const { + constexpr std::variant _raw_min() const { if (is_floating_point()) { TORCH_CHECK(is_signed(), "We currently assume all floating point types are signed"); constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); double max = _floating_point_max(); - uint64_t max_raw = bit_cast_like(max); + uint64_t max_raw = *reinterpret_cast(&max); uint64_t min_raw = max_raw | sign_bit_double; - return {bit_cast_like(min_raw)}; + return {*reinterpret_cast(&min_raw)}; } else { TORCH_CHECK(!is_signed() || size_bits() <= 64, "Cannot represent min as a int64_t"); @@ -259,7 +249,7 @@ class ScalarType { public: // Max representable value for this scalar type. 
// (accounting for bias if there is one) - std::variant max() const { + constexpr std::variant max() const { return std::visit( [this](auto x) -> std::variant { return {x - bias}; }, _raw_max()); @@ -267,7 +257,7 @@ class ScalarType { // Min representable value for this scalar type. // (accounting for bias if there is one) - std::variant min() const { + constexpr std::variant min() const { return std::visit( [this](auto x) -> std::variant { return {x - bias}; }, _raw_min()); diff --git a/gptqmodel_ext/machete/generate.py b/gptqmodel_ext/machete/generate.py index 52bd805de..0a81eb494 100644 --- a/gptqmodel_ext/machete/generate.py +++ b/gptqmodel_ext/machete/generate.py @@ -5,6 +5,7 @@ import math import os import shutil +import sys from collections.abc import Iterable from copy import deepcopy from dataclasses import dataclass, fields @@ -13,29 +14,23 @@ import jinja2 -import sys - -_ROOT = Path(__file__).resolve().parents[2] -_CUTLASS_EXT_DIR = _ROOT / "gptqmodel_ext" / "cutlass_extensions" +_THIS_DIR = Path(__file__).resolve().parent +_CUTLASS_EXT_DIR = _THIS_DIR.parent / "cutlass_extensions" +if str(_CUTLASS_EXT_DIR) not in sys.path: + sys.path.insert(0, str(_CUTLASS_EXT_DIR)) _CUTLASS_ROOT = os.environ.get("GPTQMODEL_CUTLASS_DIR") -if _CUTLASS_ROOT is not None: - _CUTLASS_ROOT = Path(_CUTLASS_ROOT) -else: - _CUTLASS_ROOT = _ROOT / "cutlass" - -_CUTLASS_PYTHON_DIR = _CUTLASS_ROOT / "python" - -_CUTLASS_PYTHON_DIR.mkdir(parents=True, exist_ok=True) - -if str(_CUTLASS_EXT_DIR) not in sys.path: - sys.path.append(str(_CUTLASS_EXT_DIR)) -if _CUTLASS_PYTHON_DIR.exists() and str(_CUTLASS_PYTHON_DIR) not in sys.path: - sys.path.append(str(_CUTLASS_PYTHON_DIR)) -if not _CUTLASS_PYTHON_DIR.exists(): - raise RuntimeError( - "CUTLASS python bindings not found. Set GPTQMODEL_CUTLASS_DIR to a valid CUTLASS checkout." - ) +if not _CUTLASS_ROOT: + deps_dir = _THIS_DIR.parent.parent / "build" / "_deps" + if deps_dir.exists(): + candidates = sorted(deps_dir.glob("cutlass-v*/"), reverse=True) + if candidates: + _CUTLASS_ROOT = str(candidates[0]) + +if _CUTLASS_ROOT: + cutlass_scripts = Path(_CUTLASS_ROOT) / "tools" / "library" / "scripts" / "py" + if cutlass_scripts.exists() and str(cutlass_scripts) not in sys.path: + sys.path.insert(0, str(cutlass_scripts)) from vllm_cutlass_library_extension import ( DataType, diff --git a/gptqmodel_ext/machete/machete_pytorch.cu b/gptqmodel_ext/machete/machete_pytorch.cu index 05a51ee21..2f4463483 100644 --- a/gptqmodel_ext/machete/machete_pytorch.cu +++ b/gptqmodel_ext/machete/machete_pytorch.cu @@ -8,6 +8,15 @@ namespace machete { using namespace vllm; +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { + m.def( + "machete_prepack_B(Tensor B, ScalarType a_type, int b_type_id, ScalarType? group_scales_type) -> Tensor"); + m.def( + "machete_mm(Tensor a, Tensor b, int b_type_id, ScalarType? out_type, Tensor? group_scales, Tensor? group_zeros, int? group_size, Tensor? channel_scales, Tensor? token_scales, str? schedule) -> Tensor"); + m.def( + "machete_supported_schedules(ScalarType a_type, int b_type_id, ScalarType? group_scales_type, ScalarType? group_zeros_type, ScalarType? channel_scales_type, ScalarType? token_scales_type, ScalarType? 
out_type) -> str[]"); +} + std::vector supported_schedules( at::ScalarType a_type, int64_t b_type_id, std::optional maybe_group_scales_type, @@ -71,3 +80,5 @@ TORCH_LIBRARY_IMPL(TORCH_EXTENSION_NAME, CatchAll, m) { } }; // namespace machete + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME); diff --git a/pyproject.toml b/pyproject.toml index 555e88b71..a8f940e30 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ dependencies = [ "pyarrow>=21.0", "dill>=0.3.8", # datasets requirements "pypcre>=0.2.4", + "ninja>=1.13.0", # "cython>=3.1.4", # required by hf-xet/hf-transfer # "flash-attn>=2.8.3", <-- install for lower vram usage ] diff --git a/requirements.txt b/requirements.txt index 82d2ac6f1..995042d2c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,4 +19,5 @@ datasets>=3.6.0 pyarrow>=21.0 dill>=0.3.8 pypcre>=0.2.4 - +ninja>=1.13.0 +jinja2>=3.1.4 diff --git a/setup.py b/setup.py index 7fd343759..1b7de9134 100644 --- a/setup.py +++ b/setup.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +import copy import os import re import subprocess @@ -11,11 +12,12 @@ from pathlib import Path from shutil import rmtree +from packaging.version import InvalidVersion, Version from setuptools import find_namespace_packages, find_packages, setup from setuptools.command.bdist_wheel import bdist_wheel as _bdist_wheel -CUTLASS_VERSION = "3.5.0" +CUTLASS_VERSION = "3.5.1" CUTLASS_RELEASE_URL = f"https://github.com/NVIDIA/cutlass/archive/refs/tags/v{CUTLASS_VERSION}.tar.gz" @@ -559,6 +561,13 @@ def _env_enabled_any(names, default="1") -> bool: cutlass_include_flags = [f"-I{path}" for path in cutlass_include_paths] extra_compile_args["cxx"] += cutlass_include_flags extra_compile_args["nvcc"] += cutlass_include_flags + # Blackwell (SM 120) requires explicit family-enabled targets in CUDA 12.9+ for FP4 instructions. 
+ if CUDA_ARCH_LIST and _version_geq(CUDA_VERSION, 12, 9): + parsed_arches = _parse_arch_list(CUDA_ARCH_LIST) + if any(a.split("+", 1)[0] in {"12.0", "12"} for a in parsed_arches): + flag = "-gencode=arch=compute_120,code=sm_120f" + if flag not in extra_compile_args["nvcc"]: + extra_compile_args["nvcc"].append(flag) # Windows/OpenMP note: adjust flags as needed for MSVC if you add native Windows wheels if sys.platform == "win32": @@ -646,42 +655,57 @@ def _hipify_compile_flags(flags): ] if BUILD_MACHETE and HAS_CUDA_V9 and _version_geq(NVCC_VERSION, 12, 0): + machete_dir = Path("gptqmodel_ext/machete") + machete_generated_dir = machete_dir / "generated" + if machete_generated_dir.exists(): + rmtree(machete_generated_dir) + machete_generated_dir.mkdir(parents=True, exist_ok=True) + + generator_script = machete_dir / "generate.py" + if not generator_script.exists(): + raise RuntimeError("Machete generator script not found; expected gptqmodel_ext/machete/generate.py") try: - result = subprocess.run( - [sys.executable, "gptqmodel_ext/machete/generate.py"], + subprocess.run( + [sys.executable, str(generator_script)], check=True, text=True, - capture_output=True + capture_output=True, ) except subprocess.CalledProcessError as e: raise RuntimeError( - f"Error generating machete kernel templates:\n" + "Error generating machete kernel templates:\n" f"Return code: {e.returncode}\n" f"Stderr: {e.stderr}\n" f"Stdout: {e.stdout}" ) - machete_dir = Path("gptqmodel_ext/machete") - machete_generated_dir = machete_dir / "generated" machete_sources = [str(machete_dir / "machete_pytorch.cu")] machete_generated_sources = sorted(machete_generated_dir.glob("*.cu")) - if not machete_generated_sources: raise RuntimeError( - "No generated machete kernel templates detected. Run gptqmodel_ext/machete/generate.py" - " with CUTLASS checkout before building." + "No generated machete kernel templates detected. " + "Run gptqmodel_ext/machete/generate.py manually to diagnose generation issues." 
) machete_sources += [str(path) for path in machete_generated_sources] machete_include_dirs = [str(Path("gptqmodel_ext").resolve())] + [str(path) for path in cutlass_include_paths] + machete_extra_compile_args = copy.deepcopy(extra_compile_args) + machete_arch_flags = [ + "--generate-code=arch=compute_90,code=sm_90", + "--generate-code=arch=compute_90a,code=sm_90a", + ] + for flag in machete_arch_flags: + if flag not in machete_extra_compile_args["nvcc"]: + machete_extra_compile_args["nvcc"].append(flag) + extensions += [ cpp_ext.CUDAExtension( "gptqmodel_machete_kernels", machete_sources, extra_link_args=extra_link_args, - extra_compile_args=extra_compile_args, + extra_compile_args=machete_extra_compile_args, include_dirs=machete_include_dirs, ) ] @@ -773,12 +797,6 @@ def _hipify_compile_flags(flags): ), ] - # Ensure machete kernels are compiled before other extensions - machete_exts = [ext for ext in extensions if getattr(ext, "name", "") == "gptqmodel_machete_kernels"] - if machete_exts: - other_exts = [ext for ext in extensions if getattr(ext, "name", "") != "gptqmodel_machete_kernels"] - extensions[:] = machete_exts + other_exts - # additional_setup_kwargs = { # "ext_modules": extensions, # "cmdclass": {"build_ext": cpp_ext.BuildExtension}, @@ -788,10 +806,10 @@ def _hipify_compile_flags(flags): "ext_modules": extensions, "cmdclass": {"build_ext": cpp_ext.BuildExtension.with_options( use_ninja=True, - no_python_abi_suffix=True, - build_temp="build/temp", - build_lib="build/lib", - clean_first=False # keep intermediates for reuse + no_python_abi_suffix=False, + # build_temp="build/temp", + # build_lib="build/lib", + clean_first=True # keep intermediates for reuse )}, } diff --git a/tests/models/model_test.py b/tests/models/model_test.py index c4302cddc..7c403eeed 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -84,7 +84,7 @@ class ModelTest(unittest.TestCase): TORCH_DTYPE = "auto" EVAL_BATCH_SIZE = "auto" QUANT_BATCH_SIZE = 1 - LOAD_BACKEND = BACKEND.MARLIN + LOAD_BACKEND = BACKEND.AUTO QUANT_BACKEND = BACKEND.AUTO USE_VLLM = False INPUTS_MAX_LENGTH = 2048 @@ -357,87 +357,132 @@ def run_eval_tasks(self, model, backend, trust_remote_code=False): return task_results def _current_load_backend(self): - effective = getattr(self, "_effective_load_backend", None) - if effective is not None and self.LOAD_BACKEND == BACKEND.MARLIN: - return effective + resolved = getattr(self, "_resolved_load_backend", None) + if self.LOAD_BACKEND == BACKEND.AUTO and resolved is not None: + return resolved return self.LOAD_BACKEND + def _detect_model_backend(self, model): + for _, module in model.named_modules(): + if isinstance(module, BaseQuantLinear): + backend = getattr(module, "backend", None) + if backend is None: + continue + if isinstance(backend, BACKEND): + return backend + try: + return BACKEND(str(backend)) + except Exception: + return None + return None + + def _is_machete_supported(self) -> bool: + try: + from gptqmodel.nn_modules.qlinear.machete import MacheteQuantLinear # type: ignore + except Exception: + return False + + try: + from gptqmodel.utils.machete import ( # type: ignore + _validate_machete_device_support, + machete_import_exception, + ) + except Exception: + return False + + if machete_import_exception: + return False + + if not _validate_machete_device_support(): + return False + + requested_bits = getattr(self, "BITS", None) + if requested_bits is not None: + machete_bits = tuple(getattr(MacheteQuantLinear, "SUPPORTS_BITS", ())) + if machete_bits and 
requested_bits not in machete_bits: + return False + + requested_group_size = getattr(self, "GROUP_SIZE", None) + if requested_group_size is not None: + machete_group_sizes = tuple(getattr(MacheteQuantLinear, "SUPPORTS_GROUP_SIZE", ())) + if machete_group_sizes and requested_group_size not in machete_group_sizes: + return False + + requested_sym = getattr(self, "SYM", None) + if requested_sym is not None: + machete_sym = tuple(getattr(MacheteQuantLinear, "SUPPORTS_SYM", ())) + if machete_sym and requested_sym not in machete_sym: + return False + + return True + def perform_post_quant_validation(self, model_path, trust_remote_code=False): inference_records = {} eval_records = {} reuse_candidates = {} - compare_backends = (BACKEND.MARLIN,) if self.FORMAT is FORMAT.GPTQ else (BACKEND.MARLIN, BACKEND.GEMM) - fallback_backend = None - if BACKEND.MARLIN in compare_backends: - try: - from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear # type: ignore - except Exception: # pragma: no cover - fallback if module unavailable - marlin_group_sizes = () - marlin_sym = () - else: - marlin_group_sizes = tuple(getattr(MarlinQuantLinear, "SUPPORTS_GROUP_SIZE", ())) - marlin_sym = tuple(getattr(MarlinQuantLinear, "SUPPORTS_SYM", ())) - - requested_group_size = getattr(self, "GROUP_SIZE", None) - requested_sym = getattr(self, "SYM", None) - - marlin_supported = True - if marlin_group_sizes and requested_group_size not in marlin_group_sizes: - marlin_supported = False - if marlin_sym and requested_sym not in marlin_sym: - marlin_supported = False - - if not marlin_supported: - fallback_backend = BACKEND.TORCH - compare_backends = tuple( - BACKEND.TORCH if backend == BACKEND.MARLIN else backend - for backend in compare_backends - ) - log.info( - f"Marlin backend unsupported for current quant config (group_size={requested_group_size}, sym={requested_sym}); " - "falling back to BACKEND.TORCH for validation." - ) - - if fallback_backend is not None and self.LOAD_BACKEND == BACKEND.MARLIN: - self._effective_load_backend = fallback_backend - else: - self._effective_load_backend = None + self._resolved_load_backend = None - target_backend = self._current_load_backend() - can_reuse = target_backend not in (BACKEND.AUTO, BACKEND.AUTO_TRAINABLE) - - for backend in compare_backends: - log.info(f"Loading post-quant model with backend `{backend.name}`") - # Pin post-quant loads to the first CUDA device to avoid auto sharding across GPUs. 
- use_cuda_map = torch.cuda.is_available() and backend != BACKEND.TORCH_FUSED + def _load_and_evaluate_backend(requested_backend: BACKEND, retain_model: bool = False): + log.info(f"Loading post-quant model with backend `{requested_backend.name}`") + use_cuda_map = torch.cuda.is_available() and requested_backend != BACKEND.TORCH_FUSED if use_cuda_map: model = self.loadQuantModel( model_path, trust_remote_code=trust_remote_code, - backend=backend, + backend=requested_backend, device_map={"": "cuda:0"}, ) else: model = self.loadQuantModel( model_path, trust_remote_code=trust_remote_code, - backend=backend, + backend=requested_backend, + ) + + resolved_backend = self._detect_model_backend(model) or requested_backend + if resolved_backend != requested_backend: + log.info( + f"Backend `{requested_backend.name}` resolved to `{resolved_backend.name}` during load" ) - tokenizer = model.tokenizer or self.load_tokenizer(model_path, trust_remote_code=trust_remote_code) - inference_records[backend] = self.run_generic_inference_checks(model, tokenizer, backend) - should_reuse = can_reuse and backend == target_backend and not self.USE_VLLM + tokenizer = model.tokenizer or self.load_tokenizer(model_path, trust_remote_code=trust_remote_code) + inference_summary = self.run_generic_inference_checks(model, tokenizer, resolved_backend) try: - eval_records[backend] = self.run_eval_tasks(model, backend, trust_remote_code=trust_remote_code) - finally: - if should_reuse: - reuse_candidates[backend] = model - else: - del model + eval_summary = self.run_eval_tasks(model, resolved_backend, trust_remote_code=trust_remote_code) + except Exception: + del model + torch_empty_cache() + raise + + if retain_model: + reuse_candidates[resolved_backend] = model + else: + del model torch_empty_cache() + return resolved_backend, inference_summary, eval_summary + + requested_backend = self.LOAD_BACKEND + retain_target = not self.USE_VLLM + target_backend, target_inference, target_eval = _load_and_evaluate_backend( + requested_backend, retain_model=retain_target + ) + inference_records[target_backend] = target_inference + eval_records[target_backend] = target_eval + self._resolved_load_backend = target_backend + + if target_backend == BACKEND.TORCH: + log.info("Target backend resolved to BACKEND.TORCH; skipping separate baseline comparison.") + else: + baseline_requested = BACKEND.TORCH + baseline_resolved, baseline_inference, baseline_eval = _load_and_evaluate_backend( + baseline_requested, retain_model=False + ) + inference_records[baseline_resolved] = baseline_inference + eval_records[baseline_resolved] = baseline_eval + self.render_inference_summary(inference_records) self.render_eval_summary(eval_records) @@ -526,7 +571,14 @@ def _colorize(text, matched): def render_inference_summary(self, inference_records): if not inference_records: return - ordered_backends = [backend for backend in (BACKEND.MARLIN, BACKEND.TORCH) if backend in inference_records] + backend_priority = ( + BACKEND.AUTO, + BACKEND.MACHETE, + BACKEND.MARLIN, + BACKEND.GEMM, + BACKEND.TORCH, + ) + ordered_backends = [backend for backend in backend_priority if backend in inference_records] if not ordered_backends: return @@ -587,7 +639,14 @@ def _format_inference_entry(self, entry): def render_eval_summary(self, eval_records): if not eval_records: return - ordered_backends = [backend for backend in (BACKEND.MARLIN, BACKEND.TORCH) if backend in eval_records] + backend_priority = ( + BACKEND.AUTO, + BACKEND.MACHETE, + BACKEND.MARLIN, + BACKEND.GEMM, + 
BACKEND.TORCH, + ) + ordered_backends = [backend for backend in backend_priority if backend in eval_records] if not ordered_backends: return @@ -628,7 +687,7 @@ def load_tokenizer(self, model_id_or_path, trust_remote_code=False): @classmethod def load_dataset(cls, tokenizer=None, rows: int = 0): try: - dataset = load_dataset(path="/monster/data/model/dataset/nm-calibration", name="LLM", split="train") + dataset = load_dataset(path="neuralmagic/calibration", name="LLM", split="train") except Exception as exc: # pragma: no cover - exercised in fallbacks log.warning("load_dataset failed; falling back to local parquet: %s", exc) dataset = cls._load_calibration_parquet() @@ -743,7 +802,7 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne tokenizer = model.tokenizer self._post_quant_eval_records = {} - self._effective_load_backend = None + self._resolved_load_backend = None is_image_to_text_model = MODALITY.IMAGE_TO_TEXT in model.modality calibration_dataset = get_calib_dataset(model) if is_image_to_text_model else self.load_dataset(tokenizer, self.DATASET_SIZE) diff --git a/tests/models/test_llama3_2.py b/tests/models/test_llama3_2.py index 5860f3133..d8b3ce25a 100644 --- a/tests/models/test_llama3_2.py +++ b/tests/models/test_llama3_2.py @@ -20,7 +20,7 @@ # | arc_challenge :: acc_norm,none | 0.3601 | # | mmlu :: acc,none | 0.3186 | class TestLlama3_2(ModelTest): - NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct" + NATIVE_MODEL_ID = "unsloth/Llama-3.2-3B-Instruct" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { "acc": { diff --git a/tests/test_awq.py b/tests/test_awq.py index 1afa83198..66c5938b6 100644 --- a/tests/test_awq.py +++ b/tests/test_awq.py @@ -18,11 +18,9 @@ from gptqmodel.nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear from gptqmodel.nn_modules.qlinear.awq_gemv import AwqGEMVQuantLinear from gptqmodel.nn_modules.qlinear.awq_gemv_fast import AwqGEMVFastQuantLinear -from gptqmodel.nn_modules.qlinear.awq_machete import AwqMacheteQuantLinear from gptqmodel.nn_modules.qlinear.awq_marlin import AwqMarlinQuantLinear from gptqmodel.quantization import FORMAT, METHOD, QUANT_CONFIG_FILENAME from gptqmodel.utils.machete import _validate_machete_device_support, machete_import_exception -from gptqmodel.utils.torch import torch_empty_cache os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" @@ -127,7 +125,6 @@ def tearDownClass(cls): @parameterized.expand([ (FORMAT.GEMM, BACKEND.GEMM, 128), - (FORMAT.GEMM, BACKEND.MACHETE, 128), (FORMAT.GEMM, BACKEND.MARLIN, 128), (FORMAT.GEMV, BACKEND.GEMV, 128), (FORMAT.GEMV_FAST, BACKEND.GEMV_FAST, 128), @@ -166,8 +163,6 @@ def assert_awq_linear(self, model, backend): for _, module in model.named_modules(): if backend == BACKEND.GEMM: linear = AwqGEMMQuantLinear - elif backend == BACKEND.MACHETE: - linear = AwqMacheteQuantLinear elif backend == BACKEND.MARLIN: linear = AwqMarlinQuantLinear elif backend == BACKEND.GEMV: diff --git a/tests/test_benchmark_gar.py b/tests/test_benchmark_gar.py index 4a2f74e51..616a1c5e9 100644 --- a/tests/test_benchmark_gar.py +++ b/tests/test_benchmark_gar.py @@ -6,8 +6,7 @@ import torch from tabulate import tabulate -from gptqmodel.quantization import gar -from gptqmodel.quantization import gar_ref +from gptqmodel.quantization import gar, gar_ref def _benchmark_fn(label, fn, device, warmup_runs=3, measured_runs=10): diff --git a/tests/test_machete_smoke.py b/tests/test_machete_smoke.py new file mode 100644 index 000000000..6e93c86ec --- /dev/null +++ 
b/tests/test_machete_smoke.py @@ -0,0 +1,47 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium +import pytest +import torch + +from gptqmodel.nn_modules.qlinear.machete import MacheteQuantLinear +from gptqmodel.utils.machete import ( + _validate_machete_device_support, + machete_import_exception, +) + + +@pytest.mark.cuda +def test_machete_forward_smoke(): + if machete_import_exception is not None: + pytest.skip(f"Machete extension unavailable: {machete_import_exception}") + if not torch.cuda.is_available() or not _validate_machete_device_support(): + pytest.skip("Machete smoke test requires a Hopper (SM90+) CUDA device") + + layer = MacheteQuantLinear( + bits=4, + group_size=128, + desc_act=False, + sym=True, + in_features=256, + out_features=512, + bias=True, + ).cuda() + + layer.qweight.data.zero_() + layer.scales.data.fill_(1.0) + layer.qzeros.data.zero_() + g_idx = ( + torch.arange(layer.in_features, dtype=torch.int32, device=layer.g_idx.device) + // layer.group_size + ) + layer.g_idx.data.copy_(g_idx) + layer.bias.data.zero_() + layer.post_init() + + x = torch.randn(4, layer.in_features, dtype=torch.float16, device="cuda") + out = layer(x) + + assert out.shape == (4, layer.out_features) + assert out.dtype == x.dtype + assert torch.all(torch.isfinite(out))
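
For a quick end-to-end check of the new Machete path outside the unit test above, a minimal usage sketch follows. It is an assumption-laden example, not part of this diff: it presumes the public GPTQModel.load entry point accepts a backend argument and exposes .tokenizer/.generate (as the loadQuantModel helper in tests/models/model_test.py suggests), and the checkpoint path is a placeholder for any 4-bit symmetric GPTQ model with a Machete-compatible group size.

import torch

from gptqmodel import BACKEND, GPTQModel
from gptqmodel.utils.machete import (
    _validate_machete_device_support,
    machete_import_exception,
)

# Placeholder: any 4-bit GPTQ checkpoint with sym=True and group_size in {-1, 64, 128}.
MODEL_PATH = "path/to/gptq-4bit-checkpoint"

# Machete needs the compiled CUDA extension and an SM90+ (Hopper) GPU;
# otherwise fall back to Marlin, mirroring the guards used in the kernel code.
if machete_import_exception is None and _validate_machete_device_support():
    backend = BACKEND.MACHETE
else:
    backend = BACKEND.MARLIN

model = GPTQModel.load(MODEL_PATH, backend=backend, device_map={"": "cuda:0"})
inputs = model.tokenizer("The capital of France is", return_tensors="pt").to("cuda:0")
with torch.inference_mode():
    out = model.generate(**inputs, max_new_tokens=8)
print(model.tokenizer.decode(out[0], skip_special_tokens=True))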