
Commit 9946463

Updated docs

Parent: 096ee13

File tree

3 files changed: +58 -51 lines

CHANGELOG.rst

Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ Model Optimizer Changelog (Linux)
 - Add ``specdec_bench`` example to benchmark speculative decoding performance. See `examples/specdec_bench/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/specdec_bench#speculative-decoding-benchmark>`_ for more details.
 - Add FP8/NVFP4 KV cache quantization support for Megatron Core models.
 - Add flag ``trt_plugins_precision`` in ONNX autocast to indicate custom ops precision. This is similar to the flag already existing in the quantization workflow.
-- Added support for QAT fakequant evaluation in vLLM.
+- Added support for QAT fakequant evaluation in vLLM for fast evaluation of arbitrary quantization formats. See ``examples/vllm_serve/README.md#load-qatptq-model-and-serve-in-vllm-wip`` for more details.

 0.39 (2025-11-11)
 ^^^^^^^^^^^^^^^^^

examples/vllm_serve/README.md

Lines changed: 10 additions & 6 deletions

@@ -55,15 +55,18 @@ lm_eval --model local-completions --tasks gsm8k --model_args model=<model_name>,

 ## Load QAT/PTQ model and serve in vLLM (WIP)

-Overwrite the calibrated amax value with prepared values from either PTQ/QAT. This is only tested for Llama3.1
+Overwrite the calibrated amax values with prepared values from either QAT/PTQ.

-Step 1: convert amax to merged amax, using llama3.1 as an example:
+Step 1: export the model with bf16 weights and amax values.

-```bash
-python convert_amax_hf2vllm.py -i <amax.pth> -o <vllm_amax.pth>
-```
+- For HF models, set `export_bf16_weights_amax` when exporting with `modelopt.torch.export.unified_export_hf.export_hf_checkpoint`.
+- For MCore models, set `export_bf16_weights_amax` when exporting with `modelopt.torch.export.unified_export_megatron.export_mcore_gpt_to_hf`.
+
+Step 2: point the `AMAX_FILE_PATH` environment variable at the `quant_amax.pth` file exported in step 1. For example:

-Step 2: add `<vllm_amax.pth>` to `quant_config` in `vllm_serve_fakequant.py`
+```
+AMAX_FILE_PATH=<quant_amax.pth> QUANT_CFG=<quant_config> python vllm_serve_fakequant.py <model_path> -tp 8 --host 0.0.0.0 --port 8000
+```

 ## Important Notes

@@ -85,3 +88,4 @@ torch.distributed.barrier()
 ## Known Problems

 1. AWQ is not yet supported in vLLM.
+2. PTQ/QAT checkpoints don't work with KV cache quantization enabled.
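
A minimal sketch of README step 1 for the HF path, based on the `export_hf_checkpoint` signature introduced in this commit. It assumes `model` is a Hugging Face model already quantized with ModelOpt (QAT or PTQ); the export directory name is illustrative:

```python
# Sketch of README step 1 (HF path); `model` is assumed to be already quantized with ModelOpt.
# The directory name "exported_bf16_amax" is illustrative.
import os

from modelopt.torch.export.unified_export_hf import export_hf_checkpoint

export_dir = "exported_bf16_amax"
os.makedirs(export_dir, exist_ok=True)

export_hf_checkpoint(
    model,
    export_dir=export_dir,            # receives the bf16 weights and quant_amax.pth
    export_bf16_weights_amax=True,    # skip quantized export; save amax values separately
)

# README step 2 then points AMAX_FILE_PATH at <export_dir>/quant_amax.pth when launching
# vllm_serve_fakequant.py.
```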

modelopt/torch/export/unified_export_hf.py

Lines changed: 47 additions & 44 deletions

@@ -76,7 +76,7 @@
     to_quantized_weight,
 )

-__all__ = ["export_hf_bf16_weights_amax", "export_hf_checkpoint"]
+__all__ = ["export_hf_checkpoint"]


 def _is_enabled_quantizer(quantizer):

@@ -531,11 +531,44 @@ def _export_hf_checkpoint(
     return quantized_state_dict, quant_config


+def _export_hf_bf16_weights_amax(
+    model: nn.Module,
+) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
+    """Exports the torch model weights and amax values separately.
+
+    This function:
+    1. Extracts amax values for calibration
+    2. Deletes all quantizer parameters from the state dict to store only weights in the original dtype
+
+    Args:
+        model: The quantized model to export
+
+    Returns:
+        post_state_dict: Dict containing the model weights in the original dtype
+        amax_dict: Dict containing amax values
+    """
+    amax_dict = {
+        name + "._amax": param["_amax"].detach().clone().cpu()
+        for name, param in get_quantizer_state_dict(model).items()
+        if "_amax" in param
+    }
+
+    # Remove quantizers from the model so only plain weights remain in the state dict
+    for name, module in model.named_modules():
+        if is_quantlinear(module):
+            delattr(module, "weight_quantizer")
+            delattr(module, "input_quantizer")
+            delattr(module, "output_quantizer")
+            module.export()
+    return model.state_dict(), amax_dict
+
+
 def export_hf_checkpoint(
     model: nn.Module,
     dtype: torch.dtype | None = None,
     export_dir: Path | str = tempfile.gettempdir(),
     save_modelopt_state: bool = False,
+    export_bf16_weights_amax: bool = False,
 ):
     """Exports the torch model to unified checkpoint and saves to export_dir.

@@ -557,13 +590,19 @@ def export_hf_checkpoint(
         return

     try:
-        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
+        if export_bf16_weights_amax:
+            post_state_dict, amax_dict = _export_hf_bf16_weights_amax(model)
+            hf_quant_config = None
+            torch.save(amax_dict, f"{export_dir}/quant_amax.pth")
+        else:
+            post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)

-        # Save hf_quant_config.json for backward compatibility
-        with open(f"{export_dir}/hf_quant_config.json", "w") as file:
-            json.dump(hf_quant_config, file, indent=4)
+        if hf_quant_config is not None:
+            # Save hf_quant_config.json for backward compatibility
+            with open(f"{export_dir}/hf_quant_config.json", "w") as file:
+                json.dump(hf_quant_config, file, indent=4)

-        hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
+            hf_quant_config = convert_hf_quant_config_format(hf_quant_config)

         # Save model
         model.save_pretrained(

@@ -576,7 +615,8 @@ def export_hf_checkpoint(
         with open(original_config) as file:
             config_data = json.load(file)

-        config_data["quantization_config"] = hf_quant_config
+        if hf_quant_config is not None:
+            config_data["quantization_config"] = hf_quant_config

         with open(original_config, "w") as file:
             json.dump(config_data, file, indent=4)

@@ -587,40 +627,3 @@ def export_hf_checkpoint(
             " can be saved with torch.save for further inspection."
         )
         raise e
-
-
-def export_hf_bf16_weights_amax(
-    model: nn.Module,
-    export_dir: Path | str = tempfile.gettempdir(),
-):
-    """Exports the torch model weights and amax values separately which can be used for vLLM fakequant serve.
-
-    This function:
-    1. Extracts amax values for calibration
-    2. Deletes all quantizer parameters from state dict to store only weights in original dtype
-    3. Saves model checkpoint (with weights in original dtype) and amax values separately
-
-    Args:
-        model: The quantized model to export
-        export_dir: Directory to save the model and artifacts
-    """
-    export_dir = Path(export_dir)
-    export_dir.mkdir(parents=True, exist_ok=True)
-
-    amax_dict = {
-        name + "._amax": param["_amax"].detach().clone().cpu()
-        for name, param in get_quantizer_state_dict(model).items()
-        if "_amax" in param
-    }
-
-    # remove quantizer from model
-    for name, module in model.named_modules():
-        if is_quantlinear(module):
-            delattr(module, "weight_quantizer")
-            delattr(module, "input_quantizer")
-            delattr(module, "output_quantizer")
-            module.export()
-
-    # Save with model without quantizer parameters
-    model.save_pretrained(export_dir)
-    torch.save(amax_dict, f"{export_dir}/quant_amax.pth")
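
As a side note on the new export path above, the `quant_amax.pth` file it writes is a plain dict of CPU tensors keyed by quantizer name with a `._amax` suffix, so it can be inspected directly with `torch.load`; the file path below is illustrative:

```python
# Inspect the amax dict written by the export_bf16_weights_amax path.
# The file path is illustrative; keys follow the "<quantizer name>._amax" pattern
# built in _export_hf_bf16_weights_amax above.
import torch

amax_dict = torch.load("exported_bf16_amax/quant_amax.pth", map_location="cpu")
for key, value in list(amax_dict.items())[:5]:
    print(key, tuple(value.shape), value.dtype)
```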
