diff --git a/deepcompressor/quantizer/impl/base.py b/deepcompressor/quantizer/impl/base.py
index 2b26f03..12069b1 100644
--- a/deepcompressor/quantizer/impl/base.py
+++ b/deepcompressor/quantizer/impl/base.py
@@ -108,7 +108,8 @@ def quantize(
         round_delta = kwargs.pop("round_delta", None)
         if round_delta is not None:
             round_delta = round_delta.view(-1, *shape[channels_dim:])
-        result = self._quantize(
+        if hasattr(self.config.dtype, "name") and self.config.dtype.name in ["sfp4_e2m1_all", "sfp6_e2m3_all", "sfp6_e3m2_all"] and self.config.scale_dtypes[0].name in ["ufp8_e8m0_nan"]:
+            result = self._quantize_mx(
             tensor,
             kernel=kernel,
             scale=scale,
@@ -122,7 +123,23 @@ def quantize(
             default_dtype=default_dtype or tensor.dtype,
             develop_dtype=develop_dtype,
             **kwargs,
-        )
+            )
+        else:
+            result = self._quantize(
+                tensor,
+                kernel=kernel,
+                scale=scale,
+                zero=zero,
+                dynamic_range=dynamic_range,
+                range_bound=range_bound,
+                quant_range=quant_range,
+                round_delta=round_delta,
+                return_with_dequant=return_with_dequant,
+                return_with_quant=return_with_quant,
+                default_dtype=default_dtype or tensor.dtype,
+                develop_dtype=develop_dtype,
+                **kwargs,
+            )
         if result.data is not None:
             result._dequantized = result.data.view(shape)
         if result.qdata is not None:
@@ -321,3 +338,109 @@ def update(
             config, tensor_shape, default_dtype, quant_range=quant_range, range_bound=range_bound
         )
         return self.info
+
+    def _quantize_mx(  # noqa: C901
+        self,
+        tensor: torch.Tensor,
+        *,
+        kernel: BaseQuantKernel | BaseQuantKernelConfig | None = None,
+        # scale-based quantization arguments
+        scale: torch.Tensor | tp.Sequence[torch.Tensor | None] | None = None,
+        zero: torch.Tensor | None = None,
+        # range-based quantization arguments
+        dynamic_range: DynamicRange | tp.Sequence[DynamicRange | None] | None = None,
+        # other arguments
+        range_bound: RangeBound | None = None,
+        quant_range: QuantRange | None = None,
+        round_delta: torch.Tensor | None = None,
+        return_with_dequant: bool = True,
+        return_with_quant: bool = False,
+        default_dtype: torch.dtype = torch.float16,
+        develop_dtype: torch.dtype = torch.float32,
+        **kwargs,
+    ) -> QuantTensor:
+        """Quantize a floating point tensor.
+
+        Args:
+            tensor (`torch.Tensor`):
+                The floating-point tensor to be quantized.
+            kernel (`QuantKernel` or `QuantKernelConfig` or `None`, *optional*, defaults to `None`):
+                The quantization kernel or its configuration.
+            scale (`torch.Tensor` or `Sequence[torch.Tensor]` or `None`, *optional*, defaults to `None`):
+                The scale tensor.
+            zero (`torch.Tensor` or `None`, *optional*, defaults to `None`):
+                The zero point tensor.
+            dynamic_range (`DynamicRange` or `Sequence[DynamicRange]` or `None`, *optional*, defaults to `None`):
+                The dynamic range.
+            range_bound (`RangeBound` or `None`, *optional*, defaults to `None`):
+                The dynamic range bound.
+            quant_range (`QuantRange` or `None`, *optional*, defaults to `None`):
+                The quantization range.
+            return_with_dequant (`bool`, *optional*, defaults to `True`):
+                Whether to return with dequantized tensor.
+            return_with_quant (`bool`, *optional*, defaults to `False`):
+                Whether to return with quantized tensor.
+            default_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
+                The default dtype for scale.
+            develop_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+                The develop dtype.
+            **kwargs:
+                Other keyword arguments for the quantization kernel. For example,
+                ``inputs`` for the input tensors in GPTQ kernel,
+                ``round_delta`` for the rounding delta in the RTN kernel.
+
+        Returns:
+            `QuantTensor`:
+                The quantized tensor.
+        """
+        shape, dtype = tensor.shape, tensor.dtype
+        self.update(shape, default_dtype, quant_range, range_bound)
+        if self.info is None or self.info.num_steps == 0:
+            return QuantTensor(dequantized=tensor, quantized=tensor, view_shape=shape)
+        assert self.info.num_steps == 1
+        # region compute and quantize the scales and zero point for quantization
+        quant_scale = QuantScale()
+        develop_tensor = tensor.to(dtype=develop_dtype) if dtype != develop_dtype else tensor.clone()
+
+        step_scale, step_zero = self.info.steps[0].scale.quantize_mx(
+            scale=scale,
+            zero=None,
+            tensor=develop_tensor,
+            dynamic_range=dynamic_range,
+        )
+        quant_scale.append(step_scale)
+        quant_zero = step_zero
+        # endregion
+        # region quantize the tensor
+        assert isinstance(step_scale, QuantScale), "The last scale must be a QuantScale."
+        assert isinstance(step_zero, torch.Tensor), "The last zero point must be a tensor."
+
+        assert not develop_tensor.isnan().any(), "Quantized tensor contains NaN."
+        assert not develop_tensor.isinf().any(), "Quantized tensor contains Inf."
+        # endregion
+        # region update the quantized tensor
+        quantized = None
+
+        # endregion
+        # region update the dequantized tensor
+
+        assert return_with_dequant
+
+        from .mx import fake_quantize_mx
+
+        develop_tensor = develop_tensor.reshape(self.info.steps[-1].tensor_view_shape)
+
+        quantized = fake_quantize_mx(develop_tensor, step_scale.data, self.config.dtype.name)
+
+        dequantized = quantized * step_scale.data
+
+        dequantized = dequantized.view(shape).to(dtype=dtype)
+
+        # endregion
+        return QuantTensor(
+            dequantized=dequantized,
+            quantized=quantized,
+            scale=quant_scale if return_with_quant else None,
+            zero=quant_zero if return_with_quant else None,
+            view_shape=self.info.steps[-1].tensor_view_shape if return_with_quant else None,
+        )
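Note: the MX path added above amounts to one shared power-of-two scale per block followed by element-wise fake quantization and rescaling. A minimal stand-alone sketch of that flow, assuming a single 32-element block and the MXFP4 (e2m1) element format; this is illustrative only and does not call the kernels added in this diff:

    import torch

    block = torch.randn(32)                 # one MX block (the configs below use groups of 32)
    emax, max_norm = 2, 6.0                 # e2m1: maximum exponent 2, largest magnitude 6.0
    amax = block.abs().max()
    # shared e8m0-style power-of-two scale; even_round() in scale.py is a rounding-aware variant of this
    scale = torch.pow(2.0, torch.floor(torch.log2(amax)) - emax)
    elems = block / scale                   # elements now fall inside the e2m1 range
    # a real implementation rounds `elems` to e2m1 here; fake_quantize_mx() does that on the GPU
    dequantized = elems.clamp(-max_norm, max_norm) * scale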
diff --git a/deepcompressor/quantizer/impl/mx.py b/deepcompressor/quantizer/impl/mx.py
new file mode 100644
index 0000000..91dee21
--- /dev/null
+++ b/deepcompressor/quantizer/impl/mx.py
@@ -0,0 +1,68 @@
+import struct
+import torch
+
+# Constants based on the provided macros
+FLOAT32_EXP_BIAS = 127
+FLOAT32_EXP_MAX = 255
+FLOAT32_TRAILING_MBITS = 23
+FLOAT32_IMPLIED1 = (1 << FLOAT32_TRAILING_MBITS)
+FLOAT32_FULL_MBITS = (FLOAT32_TRAILING_MBITS + 1)
+FLOAT32_INF = 0x7fe00000
+FLOAT32_EXP_OFFSET = 23
+FLOAT32_SIGN_OFFSET = 31
+FLOAT32_EXP_MASK = 0x7f800000
+FLOAT32_MANTISSA_MASK = 0x007fffff
+
+# RoundMode is assumed to be an integer; you can define it based on your specific use case.
+RoundMode = int  # This can be further defined if you have specific enum values for rounding modes.
+
+from torch.utils.cpp_extension import load
+import torch
+
+import os
+os.environ['PATH'] = '/usr/lib/cuda/bin:' + os.environ['PATH']
+
+extra_cflags = ["-DUSE_CUDA"]
+extra_cuda_cflags = ["-DUSE_CUDA"]
+
+mx = load(name="mx",
+          sources=[
+              "deepcompressor/quantizer/impl/mx/funcs.cpp",
+              "deepcompressor/quantizer/impl/mx/funcs.cu"
+          ],
+          extra_cuda_cflags=extra_cuda_cflags,
+          extra_cflags=extra_cflags,
+          extra_include_paths=[
+              "/group/amdneuralopt/zhaofeng/tools/miniconda3/envs/deepcompressor/lib/python3.12/site-packages/triton/backends/nvidia/include/",
+          ],
+          verbose=True)
+
+
+def get_dtype_params(dtype: str) -> tuple[int, int, int]:
+    if dtype == "sfp6_e3m2_all":
+        ebits, mbits = 3, 2
+        emax = 2**(ebits - 1)
+    elif dtype == "sfp6_e2m3_all":
+        ebits, mbits = 2, 3
+        emax = 2**(ebits - 1)
+    elif dtype == "sfp4_e2m1_all":
+        ebits, mbits = 2, 1
+        emax = 2**(ebits - 1)
+    else:
+        raise Exception("Unknown element format %s" % dtype)
+
+    return ebits, mbits, emax
+
+
+
+def fake_quantize_mx(input_tensor, scale, element_dtype):
+    ebits, mbits, _ = get_dtype_params(element_dtype)
+    max_exp = pow(2.0, ebits) - 1
+    offset_exp = pow(2.0, ebits - 1) - 1
+    quant_max = pow(2.0, max_exp - offset_exp) * (1 + (pow(2.0, mbits) - 1) / (pow(2.0, mbits)))
+
+    input_tensor = input_tensor / scale
+
+    output_tensor = mx.fake_quantize_to_low_precision_fp(input_tensor.contiguous(), ebits, mbits, quant_max, 0)
+
+    return output_tensor
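Note: with the formula used in fake_quantize_mx above, quant_max reproduces the usual MX element maxima, which doubles as a sanity check on get_dtype_params. Illustrative arithmetic only:

    def quant_max(ebits: int, mbits: int) -> float:
        max_exp = 2.0 ** ebits - 1
        offset_exp = 2.0 ** (ebits - 1) - 1
        return 2.0 ** (max_exp - offset_exp) * (1 + (2.0 ** mbits - 1) / 2.0 ** mbits)

    assert quant_max(2, 1) == 6.0     # sfp4_e2m1_all
    assert quant_max(2, 3) == 7.5     # sfp6_e2m3_all
    assert quant_max(3, 2) == 28.0    # sfp6_e3m2_all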
diff --git a/deepcompressor/quantizer/impl/mx/funcs.cpp b/deepcompressor/quantizer/impl/mx/funcs.cpp
new file mode 100644
index 0000000..8464f0a
--- /dev/null
+++ b/deepcompressor/quantizer/impl/mx/funcs.cpp
@@ -0,0 +1,55 @@
+//
+// Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAGuard.h>
+#include "funcs.cuh"
+
+void fake_quantize_to_low_precision_fp_cpu(
+    float * input,
+    float * output,
+    uint32_t num_elements,
+    int ebits,
+    int mbits,
+    float max_norm,
+    RoundMode round_mode
+) {
+    for (int i = 0; i < num_elements; i++) {
+        output[i] = fake_quantize_element(input[i], max_norm, ebits, mbits, round_mode);
+    }
+}
+
+torch::Tensor fake_quantize_to_low_precision_fp(
+    torch::Tensor &input,
+    int ebits,
+    int mbits,
+    float max_norm,
+    uint32_t round_mode
+) {
+    float * input_data = input.data_ptr<float>();
+    torch::Tensor output = torch::empty_like(input);
+    float * output_data = output.data_ptr<float>();
+#ifdef USE_CUDA
+    if (input.is_cpu()) {
+        fake_quantize_to_low_precision_fp_cpu(input_data, output_data, input.numel(), ebits, mbits, max_norm, (RoundMode)round_mode);
+    } else {
+        const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+        fake_quantize_to_low_precision_fp_cuda(input_data, output_data, input.numel(), ebits, mbits, max_norm, (RoundMode)round_mode);
+    }
+#else
+    fake_quantize_to_low_precision_fp_cpu(input_data, output_data, input.numel(), ebits, mbits, max_norm, (RoundMode)round_mode);
+#endif
+    return output;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("fake_quantize_to_low_precision_fp", &fake_quantize_to_low_precision_fp, "fake_quantize_to_low_precision_fp",
+        py::arg("input"),
+        py::arg("ebits"),
+        py::arg("mbits"),
+        py::arg("max_norm"),
+        py::arg("round_mode")
+    );
+}
\ No newline at end of file
diff --git a/deepcompressor/quantizer/impl/mx/funcs.cu b/deepcompressor/quantizer/impl/mx/funcs.cu
new file mode 100644
index 0000000..de62c0c
--- /dev/null
+++ b/deepcompressor/quantizer/impl/mx/funcs.cu
@@ -0,0 +1,37 @@
+//
+// Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include <cuda_runtime.h>
+#include "funcs.cuh"
+
+__global__ void fake_quantize_kernel(
+    float * input,
+    float * output,
+    uint32_t num_elements,
+    int ebits,
+    int mbits,
+    float max_norm,
+    RoundMode round_mode
+) {
+    int index = blockDim.x * blockIdx.x + threadIdx.x;
+    if (index >= num_elements) {
+        return;
+    }
+    output[index] = fake_quantize_element(input[index], max_norm, ebits, mbits, round_mode);
+}
+
+void fake_quantize_to_low_precision_fp_cuda(
+    float * input,
+    float * output,
+    uint32_t num_elements,
+    int ebits,
+    int mbits,
+    float max_norm,
+    RoundMode round_mode
+) {
+    int blocks = num_elements % THREADS_PER_BLOCK == 0 ? num_elements / THREADS_PER_BLOCK : num_elements / THREADS_PER_BLOCK + 1;
+    fake_quantize_kernel<<<blocks, THREADS_PER_BLOCK>>>(
+        input, output, num_elements, ebits, mbits, max_norm, round_mode
+    );
+}
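Note: the blocks computation in fake_quantize_to_low_precision_fp_cuda is a ceiling division by THREADS_PER_BLOCK (32), so the grid always covers every element and the bounds check in the kernel discards the overshoot. The same arithmetic as a quick check (illustrative only):

    def grid_size(num_elements: int, threads_per_block: int = 32) -> int:
        # ceil(num_elements / threads_per_block), matching the ternary in funcs.cu
        return (num_elements + threads_per_block - 1) // threads_per_block

    assert grid_size(1024) == 32 and grid_size(1000) == 32 and grid_size(1025) == 33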
diff --git a/deepcompressor/quantizer/impl/mx/funcs.cuh b/deepcompressor/quantizer/impl/mx/funcs.cuh
new file mode 100644
index 0000000..5c15c0e
--- /dev/null
+++ b/deepcompressor/quantizer/impl/mx/funcs.cuh
@@ -0,0 +1,148 @@
+//
+// Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <cstdint>
+
+#ifdef USE_CUDA
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#endif
+
+#define FLOAT32_EXP_BIAS 127
+#define FLOAT32_EXP_MAX 255
+#define FLOAT32_TRAILING_MBITS 23
+#define FLOAT32_IMPLIED1 (1 << FLOAT32_TRAILING_MBITS)
+#define FLOAT32_FULL_MBITS (FLOAT32_TRAILING_MBITS + 1)
+#define FLOAT32_INF 0x7fe00000
+#define FLOAT32_EXP_OFFSET 23
+#define FLOAT32_SIGN_OFFSET 31
+#define FLOAT32_EXP_MASK 0x7f800000
+#define FLOAT32_MANTISSA_MASK 0x007fffff
+
+#define THREADS_PER_BLOCK 32
+
+enum RoundMode {
+    ROUND_HALF_TO_EVEN = 8
+};
+
+union u_float_int {
+    float float_val;
+    uint32_t int_val;
+};
+
+#ifdef USE_CUDA
+__host__ __device__ __forceinline__
+#else
+inline
+#endif
+int get_exponent(float f) {
+    u_float_int u;
+    u.float_val = f;
+    u.int_val &= FLOAT32_EXP_MASK;
+    return u.int_val >> FLOAT32_TRAILING_MBITS;
+}
+
+#ifdef USE_CUDA
+__host__ __device__ __forceinline__
+#else
+inline
+#endif
+uint32_t get_mantissa(float f) {
+    u_float_int u;
+    u.float_val = f;
+    return u.int_val &= FLOAT32_MANTISSA_MASK;
+}
+
+#ifdef USE_CUDA
+__host__ __device__ __forceinline__
+#else
+inline
+#endif
+uint32_t get_sign(float f) {
+    u_float_int u;
+    u.float_val = f;
+    return u.int_val >> FLOAT32_SIGN_OFFSET;
+}
+
+#ifdef USE_CUDA
+__host__ __device__ __forceinline__
+#else
+inline
+#endif
+uint32_t shift_and_round(uint32_t mantissa, int tail_bits, RoundMode round_mode) {
+    if (tail_bits == 0) return mantissa;
+    if (tail_bits > 25) return 0;
+    uint32_t half = 1 << (tail_bits - 1);
+    uint32_t tail = mantissa & ((1 << tail_bits) - 1);
+    uint32_t ret = mantissa >> tail_bits;
+    if (tail < half) return ret;
+    else if (tail > half) return ret + 1;
+    else return (ret) % 2 == 1 ? ret + 1 : ret;
+}
+
+#ifdef USE_CUDA
+__host__ __device__ __forceinline__
+#else
+inline
+#endif
+float construct_float(uint32_t sign, uint32_t exponent, uint32_t mantissa) {
+    u_float_int u;
+    u.int_val = (sign << FLOAT32_SIGN_OFFSET) + (exponent << FLOAT32_EXP_OFFSET) + (mantissa & FLOAT32_MANTISSA_MASK);
+    return u.float_val;
+}
+
+#ifdef USE_CUDA
+__host__ __device__ __forceinline__
+#else
+inline
+#endif
+float fake_quantize_element(
+    float element,
+    float max_norm,
+    int ebits,
+    int mbits,
+    RoundMode round_mode
+) {
+    int exp = get_exponent(element);
+    if (exp == FLOAT32_EXP_MAX) return element;
+    int new_bias = (1 << (ebits - 1)) - 1;
+    uint32_t mantissa = get_mantissa(element);
+    int mantissa_bits = FLOAT32_TRAILING_MBITS;
+    if (exp != 0) {
+        mantissa = (mantissa | FLOAT32_IMPLIED1);
+        mantissa_bits++;
+    }
+
+    int new_exp = exp - FLOAT32_EXP_BIAS + new_bias;
+    int exp_shift = new_exp > 0 ? 0 : 1 - new_exp;
+
+    int tail_bits = FLOAT32_TRAILING_MBITS - mbits + exp_shift;
+    mantissa = shift_and_round(mantissa, tail_bits, round_mode);
+    if (mantissa == 0) return 0.0;
+    mantissa = mantissa << tail_bits;
+    if (mantissa >= (1 << mantissa_bits)) {
+        if (exp != 0) mantissa = mantissa >> 1;
+        exp++;
+    }
+    float absolute_ret = construct_float(0, exp, mantissa);
+    if (absolute_ret > max_norm) absolute_ret = max_norm;
+
+    u_float_int u;
+    u.float_val = absolute_ret;
+    u.int_val += (get_sign(element) << FLOAT32_SIGN_OFFSET);
+    return u.float_val;
+}
+
+void fake_quantize_to_low_precision_fp_cuda(
+    float * input,
+    float * output,
+    uint32_t num_elements,
+    int ebits,
+    int mbits,
+    float max_norm,
+    RoundMode round_mode);
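Note: for testing the extension, a pure-Python mirror of fake_quantize_element can act as a CPU oracle. The sketch below is hypothetical (it is not part of this diff, and fake_quantize_element_ref is an assumed name); it follows the header's round-half-to-even logic step by step:

    import struct

    def fake_quantize_element_ref(x: float, ebits: int, mbits: int, max_norm: float) -> float:
        # Reinterpret the float as its IEEE 754 bit pattern.
        bits = struct.unpack("<I", struct.pack("<f", x))[0]
        sign, exp, mant = bits >> 31, (bits >> 23) & 0xFF, bits & 0x7FFFFF
        if exp == 0xFF:                              # NaN/Inf pass through, as in funcs.cuh
            return x
        if exp != 0:
            mant |= 1 << 23                          # restore the implied leading one
        new_exp = exp - 127 + (1 << (ebits - 1)) - 1
        exp_shift = 0 if new_exp > 0 else 1 - new_exp
        tail_bits = 23 - mbits + exp_shift
        if tail_bits > 25:                           # value is too small to survive rounding
            mant = 0
        elif tail_bits > 0:                          # round half to even at the target precision
            half = 1 << (tail_bits - 1)
            tail = mant & ((1 << tail_bits) - 1)
            mant >>= tail_bits
            if tail > half or (tail == half and mant % 2 == 1):
                mant += 1
            mant <<= tail_bits
        if mant == 0:
            return 0.0
        if mant >= 1 << (24 if exp != 0 else 23):    # rounding carried into the next binade
            if exp != 0:
                mant >>= 1
            exp += 1
        out = struct.unpack("<f", struct.pack("<I", (exp << 23) | (mant & 0x7FFFFF)))[0]
        out = min(out, max_norm)                     # saturate at the largest representable magnitude
        return -out if sign else out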
diff --git a/deepcompressor/quantizer/impl/scale.py b/deepcompressor/quantizer/impl/scale.py
index cb09af0..27a61a2 100644
--- a/deepcompressor/quantizer/impl/scale.py
+++ b/deepcompressor/quantizer/impl/scale.py
@@ -13,6 +13,7 @@
 from ...data.utils import ScaleUtils
 from ...data.zero import ZeroPointDomain
 from .simple import simple_quantize
+from .mx import get_dtype_params
 
 __all__ = ["quantize_scale", "QuantScaleInfo"]
 
@@ -240,3 +241,85 @@ def quantize(
         assert not z.isinf().any(), "Zero point tensor contains Inf."
         # endregion
         return s, z
+
+    def even_round(self, max_abs: torch.Tensor, mbits, emax) -> torch.Tensor:
+        f32_min_normal = 2 ** (-127 + 1)
+        eps = f32_min_normal * (max_abs == 0).type(max_abs.dtype)
+
+        nan_mask = torch.isnan(max_abs)
+        max_abs = max_abs.to(torch.float32).view(torch.int32)
+        val_to_add = 1 << (23 - mbits - 1)
+        mask = ((1 << (8 + 1)) - 1) << 23
+        max_abs = (max_abs + val_to_add) & mask
+        max_abs = max_abs.view(torch.float32)
+        max_abs[nan_mask] = torch.tensor(float("nan"), device=max_abs.device)
+        scale_e8m0_unbiased = torch.floor(torch.log2(max_abs + eps)) - emax
+        scale_e8m0_unbiased = torch.clamp(scale_e8m0_unbiased, min=-127, max=127)
+        scale_float = torch.pow(2, scale_e8m0_unbiased)
+        return scale_float
+
+    def quantize_mx(
+        self,
+        *,
+        # scale-based quantization related arguments
+        scale: torch.Tensor | None = None,
+        zero: torch.Tensor | None = None,
+        # range-based quantization related arguments
+        tensor: torch.Tensor | None = None,
+        dynamic_range: DynamicRange | None = None,
+    ) -> tuple[QuantScale, torch.Tensor]:
+        """Get the quantization scale and zero point of the tensor to be quantized.
+
+        Args:
+            scale (`torch.Tensor` or `None`, *optional*, defaults to `None`):
+                The scale tensor.
+            zero (`torch.Tensor` or `None`, *optional*, defaults to `None`):
+                The zero point tensor.
+            tensor (`torch.Tensor` or `None`, *optional*, defaults to `None`):
+                The tensor to be quantized. This is only used for range-based quantization.
+            dynamic_range (`DynamicRange` or `None`, *optional*, defaults to `None`):
+                The dynamic range of the tensor to be quantized.
+
+        Returns:
+            `tuple[QuantScale, torch.Tensor]`:
+                The scale and the zero point.
+        """
+        # region step 1: get the dynamic span for range-based scale or the scale tensor
+        # breakpoint()
+        assert scale is None
+
+        assert isinstance(tensor, torch.Tensor), "View tensor must be a tensor."
+        dynamic_range = dynamic_range or DynamicRange()
+        dynamic_range = dynamic_range.measure(
+            tensor.view(self.tensor_view_shape),
+            zero_domain=self.tensor_zero_domain,
+            is_float_point=self.tensor_quant_dtype.is_float_point,
+        )
+        dynamic_range = dynamic_range.intersect(self.tensor_range_bound)
+        dynamic_span = dynamic_range.max
+
+        # endregion
+        # region step 2: get the scale
+        # breakpoint()
+        ebits, mbits, emax = get_dtype_params(self.tensor_quant_dtype.name)
+        lin_s = self.even_round(dynamic_span, mbits, emax)
+        # lin_s = torch.pow(2, torch.floor(torch.log2(dynamic_span)) - emax)
+        eps = torch.finfo(torch.float32).eps
+        lin_s = lin_s.masked_fill(lin_s == 0.0, eps)
+        lin_s = lin_s.to(torch.float32)
+        lin_s = QuantScale().append(lin_s)
+        assert lin_s.data is not None, "ufp8_e8m0_nan scale tensor is None."
+        assert not lin_s.data.isnan().any(), "ufp8_e8m0_nan scale tensor contains NaN."
+
+        s = lin_s
+        assert s.data is not None, "Scale tensor is None."
+        assert not s.data.isnan().any(), "Scale tensor contains NaN."
+        assert not s.data.isinf().any(), "Scale tensor contains Inf."
+        # endregion
+        # region step 3: get the zero point
+        z = torch.tensor(0, dtype=s.data.dtype, device=s.data.device)
+        assert not z.isnan().any(), "Zero point tensor contains NaN."
+        assert not z.isinf().any(), "Zero point tensor contains Inf."
+        # endregion
+        # breakpoint()
+        return s, z
\ No newline at end of file
diff --git a/examples/diffusion/configs/svdquant/mxfp4.yaml b/examples/diffusion/configs/svdquant/mxfp4.yaml
new file mode 100644
index 0000000..49f8ec3
--- /dev/null
+++ b/examples/diffusion/configs/svdquant/mxfp4.yaml
@@ -0,0 +1,22 @@
+quant:
+  wgts:
+    dtype: sfp4_e2m1_all
+    group_shapes:
+      - - 1
+        - 32
+        - 1
+        - 1
+        - 1
+    scale_dtypes:
+      - ufp8_e8m0_nan
+  ipts:
+    static: false
+    dtype: sfp4_e2m1_all
+    group_shapes:
+      - - 1
+        - 32 # tensor_view_shape
+        - 1
+        - 1
+        - 1
+    scale_dtypes:
+      - ufp8_e8m0_nan
\ No newline at end of file
diff --git a/examples/diffusion/configs/svdquant/mxfp6_e2m3.yaml b/examples/diffusion/configs/svdquant/mxfp6_e2m3.yaml
new file mode 100644
index 0000000..ebe47f5
--- /dev/null
+++ b/examples/diffusion/configs/svdquant/mxfp6_e2m3.yaml
@@ -0,0 +1,22 @@
+quant:
+  wgts:
+    dtype: sfp6_e2m3_all
+    group_shapes:
+      - - 1
+        - 32
+        - 1
+        - 1
+        - 1
+    scale_dtypes:
+      - ufp8_e8m0_nan
+  ipts:
+    static: false
+    dtype: sfp6_e2m3_all
+    group_shapes:
+      - - 1
+        - 32 # tensor_view_shape
+        - 1
+        - 1
+        - 1
+    scale_dtypes:
+      - ufp8_e8m0_nan
\ No newline at end of file
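Note: in both configs the 32 in group_shapes is the MX block size: each run of 32 consecutive values along the second tensor dimension shares one ufp8_e8m0_nan power-of-two scale. A rough illustration of the resulting scale granularity, assuming a plain 2-D [out_features, in_features] weight (the shapes are made up for the example):

    import torch

    weight = torch.randn(4096, 4096)
    blocks = weight.view(4096, 4096 // 32, 32)   # split in_features into 32-wide MX blocks
    block_amax = blocks.abs().amax(dim=-1)       # one dynamic span, and thus one scale, per block
    print(block_amax.shape)                      # torch.Size([4096, 128])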