127 changes: 125 additions & 2 deletions deepcompressor/quantizer/impl/base.py
@@ -108,7 +108,8 @@ def quantize(
        round_delta = kwargs.pop("round_delta", None)
        if round_delta is not None:
            round_delta = round_delta.view(-1, *shape[channels_dim:])
        if (
            hasattr(self.config.dtype, "name")
            and self.config.dtype.name in ("sfp4_e2m1_all", "sfp6_e2m3_all", "sfp6_e3m2_all")
            and self.config.scale_dtypes[0].name == "ufp8_e8m0_nan"
        ):
            result = self._quantize_mx(
                tensor,
                kernel=kernel,
                scale=scale,
@@ -122,7 +123,23 @@ def quantize(
                default_dtype=default_dtype or tensor.dtype,
                develop_dtype=develop_dtype,
                **kwargs,
            )
        else:
            result = self._quantize(
                tensor,
                kernel=kernel,
                scale=scale,
                zero=zero,
                dynamic_range=dynamic_range,
                range_bound=range_bound,
                quant_range=quant_range,
                round_delta=round_delta,
                return_with_dequant=return_with_dequant,
                return_with_quant=return_with_quant,
                default_dtype=default_dtype or tensor.dtype,
                develop_dtype=develop_dtype,
                **kwargs,
            )
        if result.data is not None:
            result._dequantized = result.data.view(shape)
        if result.qdata is not None:
@@ -321,3 +338,109 @@ def update(
            config, tensor_shape, default_dtype, quant_range=quant_range, range_bound=range_bound
        )
        return self.info

    def _quantize_mx(  # noqa: C901
        self,
        tensor: torch.Tensor,
        *,
        kernel: BaseQuantKernel | BaseQuantKernelConfig | None = None,
        # scale-based quantization arguments
        scale: torch.Tensor | tp.Sequence[torch.Tensor | None] | None = None,
        zero: torch.Tensor | None = None,
        # range-based quantization arguments
        dynamic_range: DynamicRange | tp.Sequence[DynamicRange | None] | None = None,
        # other arguments
        range_bound: RangeBound | None = None,
        quant_range: QuantRange | None = None,
        round_delta: torch.Tensor | None = None,
        return_with_dequant: bool = True,
        return_with_quant: bool = False,
        default_dtype: torch.dtype = torch.float16,
        develop_dtype: torch.dtype = torch.float32,
        **kwargs,
    ) -> QuantTensor:
        """Fake-quantize a floating-point tensor with an MX format (shared power-of-two block scale).

        Args:
            tensor (`torch.Tensor`):
                The floating-point tensor to be quantized.
            kernel (`QuantKernel` or `QuantKernelConfig` or `None`, *optional*, defaults to `None`):
                The quantization kernel or its configuration.
            scale (`torch.Tensor` or `Sequence[torch.Tensor]` or `None`, *optional*, defaults to `None`):
                The scale tensor.
            zero (`torch.Tensor` or `None`, *optional*, defaults to `None`):
                The zero point tensor.
            dynamic_range (`DynamicRange` or `Sequence[DynamicRange]` or `None`, *optional*, defaults to `None`):
                The dynamic range.
            range_bound (`RangeBound` or `None`, *optional*, defaults to `None`):
                The dynamic range bound.
            quant_range (`QuantRange` or `None`, *optional*, defaults to `None`):
                The quantization range.
            round_delta (`torch.Tensor` or `None`, *optional*, defaults to `None`):
                The rounding delta.
            return_with_dequant (`bool`, *optional*, defaults to `True`):
                Whether to return the dequantized tensor.
            return_with_quant (`bool`, *optional*, defaults to `False`):
                Whether to return the quantized tensor.
            default_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
                The default dtype for the scale.
            develop_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
                The development dtype used for intermediate computation.
            **kwargs:
                Other keyword arguments for the quantization kernel, e.g.,
                ``inputs`` for the input tensors of the GPTQ kernel.

        Returns:
            `QuantTensor`:
                The quantized tensor.
        """
        shape, dtype = tensor.shape, tensor.dtype
        self.update(shape, default_dtype, quant_range, range_bound)
        if self.info is None or self.info.num_steps == 0:
            return QuantTensor(dequantized=tensor, quantized=tensor, view_shape=shape)
        assert self.info.num_steps == 1, "MX quantization uses a single quantization step."
        # region compute and quantize the scale and zero point for quantization
        quant_scale = QuantScale()
        develop_tensor = tensor.to(dtype=develop_dtype) if dtype != develop_dtype else tensor.clone()
        step_scale, step_zero = self.info.steps[0].scale.quantize_mx(
            scale=scale,
            zero=None,
            tensor=develop_tensor,
            dynamic_range=dynamic_range,
        )
        quant_scale.append(step_scale)
        quant_zero = step_zero
        assert isinstance(step_scale, QuantScale), "The scale must be a QuantScale."
        assert isinstance(step_zero, torch.Tensor), "The zero point must be a tensor."
        assert not develop_tensor.isnan().any(), "Tensor to be quantized contains NaN."
        assert not develop_tensor.isinf().any(), "Tensor to be quantized contains Inf."
        # endregion
        # region quantize and dequantize the tensor
        assert return_with_dequant, "MX quantization currently requires returning the dequantized tensor."
        from .mx import fake_quantize_mx

        develop_tensor = develop_tensor.reshape(self.info.steps[-1].tensor_view_shape)
        quantized = fake_quantize_mx(develop_tensor, step_scale.data, self.config.dtype.name)
        dequantized = (quantized * step_scale.data).view(shape).to(dtype=dtype)
        # endregion
        return QuantTensor(
            dequantized=dequantized,
            quantized=quantized,
            scale=quant_scale if return_with_quant else None,
            zero=quant_zero if return_with_quant else None,
            view_shape=self.info.steps[-1].tensor_view_shape if return_with_quant else None,
        )
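For orientation, the largest representable magnitude of each MX element format gated by the new dispatch can be derived with the same arithmetic `fake_quantize_mx` uses in `mx.py` below; a quick sketch (plain Python, `max_norm` is a hypothetical helper name):

def max_norm(ebits: int, mbits: int) -> float:
    # Largest biased exponent (the "_all" formats spend no encodings on Inf/NaN).
    max_exp = 2.0**ebits - 1
    bias = 2.0 ** (ebits - 1) - 1
    return 2.0 ** (max_exp - bias) * (1 + (2.0**mbits - 1) / 2.0**mbits)

assert max_norm(2, 1) == 6.0   # sfp4_e2m1_all
assert max_norm(2, 3) == 7.5   # sfp6_e2m3_all
assert max_norm(3, 2) == 28.0  # sfp6_e3m2_all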
68 changes: 68 additions & 0 deletions deepcompressor/quantizer/impl/mx.py
@@ -0,0 +1,68 @@
import os

import torch
from torch.utils.cpp_extension import load

# Constants mirroring the float32 bit-layout macros used on the C++/CUDA side.
FLOAT32_EXP_BIAS = 127
FLOAT32_EXP_MAX = 255
FLOAT32_TRAILING_MBITS = 23
FLOAT32_IMPLIED1 = 1 << FLOAT32_TRAILING_MBITS
FLOAT32_FULL_MBITS = FLOAT32_TRAILING_MBITS + 1
FLOAT32_INF = 0x7FE00000
FLOAT32_EXP_OFFSET = 23
FLOAT32_SIGN_OFFSET = 31
FLOAT32_EXP_MASK = 0x7F800000
FLOAT32_MANTISSA_MASK = 0x007FFFFF

# The extension takes the rounding mode as a plain integer; replace with an
# enum if specific rounding modes are added.
RoundMode = int

os.environ["PATH"] = "/usr/lib/cuda/bin:" + os.environ["PATH"]

extra_cflags = ["-DUSE_CUDA"]
extra_cuda_cflags = ["-DUSE_CUDA"]

mx = load(
    name="mx",
    sources=[
        "deepcompressor/quantizer/impl/mx/funcs.cpp",
        "deepcompressor/quantizer/impl/mx/funcs.cu",
    ],
    extra_cflags=extra_cflags,
    extra_cuda_cflags=extra_cuda_cflags,
    extra_include_paths=[
        "/group/amdneuralopt/zhaofeng/tools/miniconda3/envs/deepcompressor/lib/python3.12/site-packages/triton/backends/nvidia/include/",
    ],
    verbose=True,
)


def get_dtype_params(dtype: str) -> tuple[int, int, int]:
    if dtype == "sfp6_e3m2_all":
        ebits, mbits = 3, 2
    elif dtype == "sfp6_e2m3_all":
        ebits, mbits = 2, 3
    elif dtype == "sfp4_e2m1_all":
        ebits, mbits = 2, 1
    else:
        raise ValueError(f"Unknown element format {dtype}")
    emax = 2 ** (ebits - 1)
    return ebits, mbits, emax


def fake_quantize_mx(input_tensor: torch.Tensor, scale: torch.Tensor, element_dtype: str) -> torch.Tensor:
    ebits, mbits, _ = get_dtype_params(element_dtype)
    # Largest representable magnitude of the element format, e.g. 6.0 for e2m1,
    # 7.5 for e2m3, and 28.0 for e3m2.
    max_exp = 2.0**ebits - 1
    offset_exp = 2.0 ** (ebits - 1) - 1
    quant_max = 2.0 ** (max_exp - offset_exp) * (1 + (2.0**mbits - 1) / 2.0**mbits)
    # Divide out the shared block scale, then snap to the low-precision fp grid.
    input_tensor = input_tensor / scale
    return mx.fake_quantize_to_low_precision_fp(input_tensor.contiguous(), ebits, mbits, quant_max, 0)
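As a sanity check on the compiled extension, a pure-PyTorch round-to-nearest reference is straightforward; this is a sketch assuming `round_mode=0` means round-to-nearest and that the `_all` formats treat every exponent encoding as finite (saturation handled by `max_norm`):

import torch

def fake_quantize_fp_reference(x: torch.Tensor, ebits: int, mbits: int, max_norm: float) -> torch.Tensor:
    # Pure-PyTorch sketch of fake_quantize_to_low_precision_fp (round-to-nearest).
    bias = 2 ** (ebits - 1) - 1
    # Per-element exponent, clamped below at the minimum normal exponent so
    # subnormal-range values share its quantization step.
    exp = torch.floor(torch.log2(x.abs().clamp_min(torch.finfo(x.dtype).tiny)))
    exp = exp.clamp_min(1 - bias)
    # Spacing of representable values at this exponent: 2^(exp - mbits).
    step = torch.exp2(exp - mbits)
    quantized = torch.round(x / step) * step
    # Saturate to the largest representable magnitude.
    return quantized.clamp(-max_norm, max_norm)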
55 changes: 55 additions & 0 deletions deepcompressor/quantizer/impl/mx/funcs.cpp
@@ -0,0 +1,55 @@
//
// Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <ATen/ATen.h>
#include <torch/extension.h>
#ifdef USE_CUDA
#include <c10/cuda/CUDAGuard.h>  // at::cuda::OptionalCUDAGuard used below
#endif
#include "funcs.cuh"

void fake_quantize_to_low_precision_fp_cpu(
    float * input,
    float * output,
    uint32_t num_elements,
    int ebits,
    int mbits,
    float max_norm,
    RoundMode round_mode
) {
    for (uint32_t i = 0; i < num_elements; i++) {
        output[i] = fake_quantize_element(input[i], max_norm, ebits, mbits, round_mode);
    }
}

torch::Tensor fake_quantize_to_low_precision_fp(
    torch::Tensor &input,
    int ebits,
    int mbits,
    float max_norm,
    uint32_t round_mode
) {
    // data_ptr<float>() assumes a contiguous float32 tensor; fail loudly otherwise.
    TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");
    float * input_data = input.data_ptr<float>();
    torch::Tensor output = torch::empty_like(input);
    float * output_data = output.data_ptr<float>();
#ifdef USE_CUDA
    if (input.is_cpu()) {
        fake_quantize_to_low_precision_fp_cpu(input_data, output_data, input.numel(), ebits, mbits, max_norm, (RoundMode)round_mode);
    } else {
        const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
        fake_quantize_to_low_precision_fp_cuda(input_data, output_data, input.numel(), ebits, mbits, max_norm, (RoundMode)round_mode);
    }
#else
    fake_quantize_to_low_precision_fp_cpu(input_data, output_data, input.numel(), ebits, mbits, max_norm, (RoundMode)round_mode);
#endif
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fake_quantize_to_low_precision_fp", &fake_quantize_to_low_precision_fp, "fake_quantize_to_low_precision_fp",
          py::arg("input"),
          py::arg("ebits"),
          py::arg("mbits"),
          py::arg("max_norm"),
          py::arg("round_mode")
    );
}
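Once the extension compiles, calling it from Python is direct; a small smoke test (hypothetical values, using the `mx` module loaded in `mx.py`) might look like:

import torch

x = torch.tensor([0.3, 1.2, 5.0, 100.0], dtype=torch.float32)
# e2m1 elements: ebits=2, mbits=1, max_norm=6.0, round_mode=0.
q = mx.fake_quantize_to_low_precision_fp(x.contiguous(), 2, 1, 6.0, 0)
# Outputs land on the fp4 grid and saturate at the max norm, e.g. 100.0 -> 6.0.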
37 changes: 37 additions & 0 deletions deepcompressor/quantizer/impl/mx/funcs.cu
@@ -0,0 +1,37 @@
//
// Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <c10/cuda/CUDAGuard.h>
#include "funcs.cuh"

__global__ void fake_quantize_kernel(
    float * input,
    float * output,
    uint32_t num_elements,
    int ebits,
    int mbits,
    float max_norm,
    RoundMode round_mode
) {
    uint32_t index = blockDim.x * blockIdx.x + threadIdx.x;
    if (index >= num_elements) {
        return;
    }
    output[index] = fake_quantize_element(input[index], max_norm, ebits, mbits, round_mode);
}

void fake_quantize_to_low_precision_fp_cuda(
    float * input,
    float * output,
    uint32_t num_elements,
    int ebits,
    int mbits,
    float max_norm,
    RoundMode round_mode
) {
    // Ceiling division so a partially filled block covers the tail elements.
    uint32_t blocks = (num_elements + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
    fake_quantize_kernel<<<blocks, THREADS_PER_BLOCK>>>(
        input, output, num_elements, ebits, mbits, max_norm, round_mode
    );
}
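One piece the diff does not show is how the shared block scale is computed (`quantize_mx` on the scale object lives elsewhere). Under the OCP MX convention, which this sketch assumes, the power-of-two e8m0-style scale comes from each 32-element block's absolute maximum; `mx_block_scale` is a hypothetical helper:

import torch

def mx_block_scale(x: torch.Tensor, emax_elem: int, block_size: int = 32) -> torch.Tensor:
    groups = x.reshape(-1, block_size)
    amax = groups.abs().amax(dim=-1, keepdim=True)
    # Shared exponent: floor(log2(amax)) minus the element format's emax,
    # so the largest element lands near the top of the representable range.
    shared_exp = torch.floor(torch.log2(amax.clamp_min(2.0**-126)))
    return torch.exp2(shared_exp - emax_elem)

# Example: fp4 e2m1 elements, emax = 2 as returned by get_dtype_params.
scale = mx_block_scale(torch.randn(4, 32), emax_elem=2)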