
Commit 560dfc7

Updated docs

Signed-off-by: Kinjal Patel <[email protected]>

1 parent: cff3cc6

3 files changed (+58 −50 lines)

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@ Model Optimizer Changelog (Linux)
 - Add FP8/NVFP4 KV cache quantization support for Megatron Core models.
 - Add flag ``trt_plugins_precision`` in ONNX autocast to indicate custom ops precision. This is similar to the flag already existing in the quantization workflow.
 - Add support for PyTorch Geometric quantization.
+- Add support for QAT fakequant evaluation in vLLM for fast evaluation of arbitrary quantization formats. See ``examples/vllm_serve/README.md#load-qatptq-model-and-serve-in-vllm-wip`` for more details.
 
 **Documentation**
 

examples/vllm_serve/README.md

Lines changed: 10 additions & 6 deletions

@@ -55,15 +55,18 @@ lm_eval --model local-completions --tasks gsm8k --model_args model=<model_name>,
 
 ## Load QAT/PTQ model and serve in vLLM (WIP)
 
-Overwrite the calibrated amax value with prepared values from either PTQ/QAT. This is only tested for Llama3.1
+Overwrite the calibrated amax values with prepared values from either QAT/PTQ.
 
-Step 1: convert amax to merged amax, using llama3.1 as an example:
+Step 1: export the model with bf16 weights and amax values.
 
-```bash
-python convert_amax_hf2vllm.py -i <amax.pth> -o <vllm_amax.pth>
-```
+- For HF models, set `export_bf16_weights_amax` when exporting with `modelopt.torch.export.unified_export_hf.export_hf_checkpoint`.
+- For MCore models, set `export_bf16_weights_amax` when exporting with `modelopt.torch.export.unified_export_megatron.export_mcore_gpt_to_hf`.
+
+Step 2: point the `AMAX_FILE_PATH` environment variable to the `<quant_amax.pth>` file exported in step 1. For example:
 
-Step 2: add `<vllm_amax.pth>` to `quant_config` in `vllm_serve_fakequant.py`
+```
+AMAX_FILE_PATH=<quant_amax.pth> QUANT_CFG=<quant_config> python vllm_serve_fakequant.py <model_path> -tp 8 --host 0.0.0.0 --port 8000
+```
 
 ## Important Notes
 
@@ -85,3 +88,4 @@ torch.distributed.barrier()
 ## Known Problems
 
 1. AWQ is not yet supported in vLLM.
+2. PTQ/QAT checkpoints don't work with KV cache quantization enabled.
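
For reference, a minimal sketch of Step 1 of this README for an HF model, using the `export_hf_checkpoint` flag introduced by this commit; the model path, calibration step, and export directory are placeholders, not part of the commit:

```python
# Sketch only: assumes the model has already been quantized/calibrated with
# ModelOpt (QAT or PTQ) so its quantizers carry amax values.
import torch
from transformers import AutoModelForCausalLM

from modelopt.torch.export.unified_export_hf import export_hf_checkpoint

model = AutoModelForCausalLM.from_pretrained("<model_path>", torch_dtype=torch.bfloat16)

# ... run ModelOpt QAT/PTQ calibration here ...

# With export_bf16_weights_amax=True the checkpoint is saved with bf16 weights,
# and the calibrated amax values are written to <export_dir>/quant_amax.pth.
export_hf_checkpoint(
    model,
    export_dir="<export_dir>",
    export_bf16_weights_amax=True,
)
```

The resulting `<export_dir>/quant_amax.pth` is the file that `AMAX_FILE_PATH` should point to in Step 2.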

modelopt/torch/export/unified_export_hf.py

Lines changed: 47 additions & 44 deletions

@@ -77,7 +77,7 @@
     to_quantized_weight,
 )
 
-__all__ = ["export_hf_bf16_weights_amax", "export_hf_checkpoint"]
+__all__ = ["export_hf_checkpoint"]
 
 
 def _is_enabled_quantizer(quantizer):
@@ -536,11 +536,44 @@ def _export_hf_checkpoint(
     return quantized_state_dict, quant_config
 
 
+def _export_hf_bf16_weights_amax(
+    model: nn.Module,
+) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
+    """Exports the torch model weights and amax values separately.
+
+    This function:
+    1. Extracts amax values for calibration
+    2. Deletes all quantizer parameters from state dict to store only weights in original dtype
+
+    Args:
+        model: The quantized model to export
+
+    Returns:
+        post_state_dict: Dict containing quantized weights
+        amax_dict: Dict containing amax values
+    """
+    amax_dict = {
+        name + "._amax": param["_amax"].detach().clone().cpu()
+        for name, param in get_quantizer_state_dict(model).items()
+        if "_amax" in param
+    }
+
+    # remove quantizer from model
+    for name, module in model.named_modules():
+        if is_quantlinear(module):
+            delattr(module, "weight_quantizer")
+            delattr(module, "input_quantizer")
+            delattr(module, "output_quantizer")
+            module.export()
+    return model.state_dict(), amax_dict
+
+
 def export_hf_checkpoint(
     model: nn.Module,
     dtype: torch.dtype | None = None,
     export_dir: Path | str = tempfile.gettempdir(),
     save_modelopt_state: bool = False,
+    export_bf16_weights_amax: bool = False,
 ):
     """Exports the torch model to unified checkpoint and saves to export_dir.
 
@@ -562,13 +595,19 @@ def export_hf_checkpoint(
         return
 
     try:
-        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
+        if export_bf16_weights_amax:
+            post_state_dict, amax_dict = _export_hf_bf16_weights_amax(model)
+            hf_quant_config = None
+            torch.save(amax_dict, f"{export_dir}/quant_amax.pth")
+        else:
+            post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
 
-        # Save hf_quant_config.json for backward compatibility
-        with open(f"{export_dir}/hf_quant_config.json", "w") as file:
-            json.dump(hf_quant_config, file, indent=4)
+        if hf_quant_config is not None:
+            # Save hf_quant_config.json for backward compatibility
+            with open(f"{export_dir}/hf_quant_config.json", "w") as file:
+                json.dump(hf_quant_config, file, indent=4)
 
-        hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
+            hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
 
         # Save model
         model.save_pretrained(
@@ -581,7 +620,8 @@
         with open(original_config) as file:
             config_data = json.load(file)
 
-        config_data["quantization_config"] = hf_quant_config
+        if hf_quant_config is not None:
+            config_data["quantization_config"] = hf_quant_config
 
         with open(original_config, "w") as file:
            json.dump(config_data, file, indent=4)
@@ -592,40 +632,3 @@
            " can be saved with torch.save for further inspection."
        )
        raise e
-
-
-def export_hf_bf16_weights_amax(
-    model: nn.Module,
-    export_dir: Path | str = tempfile.gettempdir(),
-):
-    """Exports the torch model weights and amax values separately which can be used for vLLM fakequant serve.
-
-    This function:
-    1. Extracts amax values for calibration
-    2. Deletes all quantizer parameters from state dict to store only weights in original dtype
-    3. Saves model checkpoint (with weights in original dtype) and amax values separately
-
-    Args:
-        model: The quantized model to export
-        export_dir: Directory to save the model and artifacts
-    """
-    export_dir = Path(export_dir)
-    export_dir.mkdir(parents=True, exist_ok=True)
-
-    amax_dict = {
-        name + "._amax": param["_amax"].detach().clone().cpu()
-        for name, param in get_quantizer_state_dict(model).items()
-        if "_amax" in param
-    }
-
-    # remove quantizer from model
-    for name, module in model.named_modules():
-        if is_quantlinear(module):
-            delattr(module, "weight_quantizer")
-            delattr(module, "input_quantizer")
-            delattr(module, "output_quantizer")
-            module.export()
-
-    # Save with model without quantizer parameters
-    model.save_pretrained(export_dir)
-    torch.save(amax_dict, f"{export_dir}/quant_amax.pth")
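
As a quick sanity check of the new export path, the `quant_amax.pth` written by `export_hf_checkpoint(..., export_bf16_weights_amax=True)` is a plain dict of CPU tensors keyed by `<quantizer name>._amax`. A minimal inspection sketch (the export directory is a placeholder):

```python
# Sketch: load and inspect the amax artifact produced by the export above.
import torch

amax_dict = torch.load("<export_dir>/quant_amax.pth", map_location="cpu")

# Each entry maps "<quantizer name>._amax" to its calibrated amax tensor.
for name, amax in list(amax_dict.items())[:5]:
    print(name, tuple(amax.shape), amax.dtype)
```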
