
Commit cff3cc6

Added support to export BF16 weights and amax

Signed-off-by: Kinjal Patel <[email protected]>
1 parent 1d0ee04

File tree

4 files changed: +231 -228 lines

examples/vllm_serve/convert_amax_hf2vllm.py

Lines changed: 0 additions & 213 deletions
This file was deleted.

examples/vllm_serve/fakequant_worker.py

Lines changed: 74 additions & 1 deletion
@@ -15,7 +15,9 @@
 
 import dataclasses
 import os
+import re
 import warnings
+from collections import defaultdict
 from contextlib import contextmanager
 from typing import Any
 
@@ -30,6 +32,68 @@
 from modelopt.torch.utils.dataset_utils import get_dataset_dataloader
 
 
+def convert_amax_hf2vllm(
+    hf_state_dict: dict[str, torch.Tensor],
+) -> dict[str, torch.Tensor]:
+    """
+    Convert amax values from HuggingFace format to vLLM format.
+
+    This function merges:
+    - q_proj, k_proj, v_proj amax values into qkv_proj (taking the max)
+    - gate_proj, up_proj amax values into gate_up_proj (taking the max)
+
+    Args:
+        hf_state_dict: HuggingFace state dict containing amax values
+
+    Returns:
+        vLLM-format state dict with merged amax values
+    """
+    vllm_state_dict = {}
+
+    # Group keys by their base pattern (without the specific projection name)
+    merge_groups = defaultdict(list)
+
+    for key, value in hf_state_dict.items():
+        if "_amax" not in key:
+            # Copy non-amax keys as-is
+            vllm_state_dict[key] = value
+            continue
+
+        # Check if this is a q/k/v projection that needs merging
+        qkv_match = re.search(r"(.*\.)([qkv])_proj(\..+_amax)$", key)
+        if qkv_match:
+            base_pattern = qkv_match.group(1) + "qkv_proj" + qkv_match.group(3)
+            merge_groups[base_pattern].append((key, value))
+            continue
+
+        # Check if this is a gate/up projection that needs merging
+        gate_up_match = "mixer" not in key and re.search(r"(.*\.)(gate|up)_proj(\..+_amax)$", key)
+        if gate_up_match:
+            base_pattern = gate_up_match.group(1) + "gate_up_proj" + gate_up_match.group(3)
+            merge_groups[base_pattern].append((key, value))
+            continue
+
+        # Copy other amax keys as-is (e.g. o_proj, down_proj)
+        vllm_state_dict[key] = value
+
+    # Merge grouped amax values by taking the maximum
+    for merged_key, key_value_pairs in merge_groups.items():
+        if len(key_value_pairs) > 1:
+            # Take the maximum across all values for this merged key
+            values = [value for _, value in key_value_pairs]
+            merged_value = torch.stack(values).max(dim=0)[0]
+            vllm_state_dict[merged_key] = merged_value
+            print(f"Merged {len(key_value_pairs)} keys into {merged_key}")
+            for orig_key, _ in key_value_pairs:
+                print(f"  - {orig_key}")
+        else:
+            # Single key, just rename it
+            _, value = key_value_pairs[0]
+            vllm_state_dict[merged_key] = value
+
+    return vllm_state_dict
+
+
 @contextmanager
 def disable_compilation(model):
     do_not_compile = True
@@ -154,8 +218,17 @@ def calibrate_loop(model: Any = None) -> None:
         if amax_file_path:
             print(f"Loading amax values from {amax_file_path}")
             saved_amax_dict = torch.load(amax_file_path)
-            current_state_dict = model.state_dict()
+            # convert amax keys to vLLM format
+            if hasattr(self.model_runner.model, "hf_to_vllm_mapper"):
+                saved_amax_dict = self.model_runner.model.hf_to_vllm_mapper.apply_dict(saved_amax_dict)
+                saved_amax_dict = {
+                    key.replace("quantizer_amax", "quantizer._amax"): value
+                    for key, value in saved_amax_dict.items()
+                    if key.endswith("quantizer_amax")
+                }
+            saved_amax_dict = convert_amax_hf2vllm(saved_amax_dict)
 
+            current_state_dict = model.state_dict()
             # Count amax keys in checkpoint and model
             checkpoint_amax_keys = [key for key in saved_amax_dict if key.endswith("_amax")]
             model_amax_keys = [key for key in current_state_dict if key.endswith("_amax")]
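For reference, here is a minimal sketch of what the new convert_amax_hf2vllm helper (defined in fakequant_worker.py above) does to a HuggingFace-style amax dict. The layer names and values are hypothetical and only illustrate the q/k/v and gate/up merging:

import torch

# Hypothetical HF-style amax entries for one decoder layer (illustrative names only).
hf_amax = {
    "model.layers.0.self_attn.q_proj.input_quantizer._amax": torch.tensor(1.0),
    "model.layers.0.self_attn.k_proj.input_quantizer._amax": torch.tensor(2.0),
    "model.layers.0.self_attn.v_proj.input_quantizer._amax": torch.tensor(0.5),
    "model.layers.0.mlp.gate_proj.input_quantizer._amax": torch.tensor(3.0),
    "model.layers.0.mlp.up_proj.input_quantizer._amax": torch.tensor(4.0),
    "model.layers.0.mlp.down_proj.input_quantizer._amax": torch.tensor(1.5),
}

vllm_amax = convert_amax_hf2vllm(hf_amax)

# q/k/v entries collapse into qkv_proj and gate/up into gate_up_proj, each taking the
# elementwise max; keys that need no merging (e.g. o_proj, down_proj) are copied as-is.
assert vllm_amax["model.layers.0.self_attn.qkv_proj.input_quantizer._amax"] == 2.0
assert vllm_amax["model.layers.0.mlp.gate_up_proj.input_quantizer._amax"] == 4.0
assert vllm_amax["model.layers.0.mlp.down_proj.input_quantizer._amax"] == 1.5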

modelopt/torch/export/unified_export_hf.py

Lines changed: 43 additions & 2 deletions
@@ -33,7 +33,11 @@
 from modelopt.torch.quantization import set_quantizer_by_cfg_context
 from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
 from modelopt.torch.quantization.qtensor import NVFP4QTensor
-from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names
+from modelopt.torch.quantization.utils import (
+    fsdp2_aware_weight_update,
+    get_quantizer_state_dict,
+    quantizer_attr_names,
+)
 
 from .convert_hf_config import convert_hf_quant_config_format
 from .layer_utils import (
@@ -73,7 +77,7 @@
     to_quantized_weight,
 )
 
-__all__ = ["export_hf_checkpoint"]
+__all__ = ["export_hf_bf16_weights_amax", "export_hf_checkpoint"]
 
 
 def _is_enabled_quantizer(quantizer):
@@ -588,3 +592,40 @@ def export_hf_checkpoint(
             " can be saved with torch.save for further inspection."
         )
         raise e
+
+
+def export_hf_bf16_weights_amax(
+    model: nn.Module,
+    export_dir: Path | str = tempfile.gettempdir(),
+):
+    """Export the model weights and amax values separately for use with vLLM fakequant serving.
+
+    This function:
+    1. Extracts amax values for calibration
+    2. Deletes all quantizer parameters from the state dict so only weights in the original dtype are stored
+    3. Saves the model checkpoint (with weights in the original dtype) and the amax values separately
+
+    Args:
+        model: The quantized model to export
+        export_dir: Directory to save the model and artifacts
+    """
+    export_dir = Path(export_dir)
+    export_dir.mkdir(parents=True, exist_ok=True)
+
+    amax_dict = {
+        name + "._amax": param["_amax"].detach().clone().cpu()
+        for name, param in get_quantizer_state_dict(model).items()
+        if "_amax" in param
+    }
+
+    # Remove quantizers from the model
+    for name, module in model.named_modules():
+        if is_quantlinear(module):
+            delattr(module, "weight_quantizer")
+            delattr(module, "input_quantizer")
+            delattr(module, "output_quantizer")
+            module.export()
+
+    # Save the model without quantizer parameters
+    model.save_pretrained(export_dir)
+    torch.save(amax_dict, f"{export_dir}/quant_amax.pth")
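A rough usage sketch of the new export_hf_bf16_weights_amax entry point follows. The model ID, the FP8 config choice, and the placeholder forward loop are assumptions for illustration; the function itself is defined in modelopt.torch.export.unified_export_hf as shown above:

import torch
import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM
from modelopt.torch.export.unified_export_hf import export_hf_bf16_weights_amax

# Hypothetical model; any HF causal LM supported by modelopt should work similarly.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", torch_dtype=torch.bfloat16
)

def forward_loop(m):
    # Placeholder calibration loop; in practice run representative prompts through m here.
    pass

# Quantize (fake-quant) the model so amax values are attached to the quantizers.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=forward_loop)

# Writes the HF checkpoint with BF16 weights to export_dir and the amax values to
# export_dir/quant_amax.pth, which the vLLM fakequant worker loads via amax_file_path.
export_hf_bf16_weights_amax(model, export_dir="exported_bf16_amax")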
