diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md
index 46260db8f..246624d6b 100755
--- a/examples/llm_ptq/README.md
+++ b/examples/llm_ptq/README.md
@@ -131,9 +131,21 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
 | QwQ | ✅ | - | - | - | ✅ |
 | T5 | ✅ | ✅ | ✅ | ✅ | - |
 | Whisper | ✅ | ❌ | ❌ | ❌ | - |
+| Kimi-K2-Thinking-BF16 | ✅ | ❌ | ❌ | ❌ | ✅ |
 
 > *This is a subset of the models supported. For the full list please check the [TensorRT-LLM support matrix](https://nvidia.github.io/TensorRT-LLM/reference/precision.html#support-matrix)*
 
+> We recommend upcasting Kimi-K2-Thinking from INT4 to BF16 before running quantization.
+
+```python
+from transformers import AutoModelForCausalLM
+from transformers.utils.quantization_config import CompressedTensorsConfig
+
+model = AutoModelForCausalLM.from_pretrained("moonshotai/Kimi-K2-Thinking", torch_dtype="auto", device_map="auto", local_files_only=True, trust_remote_code=True, quantization_config=CompressedTensorsConfig(run_compressed=False))
+
+# And then save it with save_pretrained
+```
+
 > *1.The w4a8_awq is an experimental quantization scheme that may result in a higher accuracy penalty.* \
 > *2.For some models, there is only support for exporting quantized checkpoints.* \
 > *3.W4A8_AWQ is only available on some models but not all* \
diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index 57f0b5a89..40cf94d75 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -80,6 +80,7 @@
     "w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG,
     "w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
     "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
+    "nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG,
 }
 
 KV_QUANT_CFG_CHOICES = {
@@ -121,6 +122,7 @@ def auto_quantize(
             "fp8_pb_wo",
             "w4a8_mxfp4_fp8",
             "nvfp4_mlp_only",
+            "nvfp4_experts_only",
         ]
         for qformat in qformat_list
     ), "One or more quantization formats provided are not supported for unified checkpoint export"
diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index 3a43113df..611c48fb9 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -623,6 +623,25 @@
     "algorithm": "max",
 }
 
+NVFP4_EXPERTS_ONLY_CFG = {
+    "quant_cfg": {
+        "*mlp.experts*weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
+            "enable": True,
+            "pass_through_bwd": True,
+        },
+        "*mlp.experts*input_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
+            "enable": True,
+            "pass_through_bwd": True,
+        },
+        **_default_disabled_quantizer_cfg,
+    },
+    "algorithm": "max",
+}
+
 choices: set[str] = {
     "FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG",
     "FP8_AFFINE_KV_CFG",
@@ -652,6 +671,7 @@
     "NVFP4_MLP_WEIGHT_ONLY_CFG",
     "MXFP4_MLP_WEIGHT_ONLY_CFG",
     "NVFP4_MLP_ONLY_CFG",
+    "NVFP4_EXPERTS_ONLY_CFG",
 }
 
 BiasType = Literal["static", "dynamic"]
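
Usage note (not part of the diff): below is a minimal sketch of how the new `NVFP4_EXPERTS_ONLY_CFG` could be exercised through ModelOpt's `mtq.quantize` entry point on a BF16 MoE checkpoint such as the upcasted Kimi-K2-Thinking produced by the README snippet above. `BF16_CKPT_DIR` and `calib_loader` are hypothetical placeholders, not names from this PR.

```python
# Sketch only: quantize just the MoE expert layers with the new config.
# BF16_CKPT_DIR and calib_loader are hypothetical placeholders for the saved
# upcasted checkpoint and a small calibration dataloader.
import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    BF16_CKPT_DIR, torch_dtype="auto", device_map="auto", trust_remote_code=True
)

def forward_loop(m):
    # Short calibration pass; with the fully dynamic NVFP4 quantizers in this
    # config it mainly verifies that the quantized expert modules run end to end.
    for batch in calib_loader:
        m(**batch)

# Only quantizers matching "*mlp.experts*" are enabled by NVFP4_EXPERTS_ONLY_CFG;
# all other layers keep their original precision, per the config.py change above.
model = mtq.quantize(model, mtq.NVFP4_EXPERTS_ONLY_CFG, forward_loop)
```

The hf_ptq.py change above also registers the same config under the qformat string `"nvfp4_experts_only"`, so the example script path can select it without touching Python code.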