
Commit 7aa0559

added separate file for vLLM export
Signed-off-by: Kinjal Patel <[email protected]>
1 parent 13f6bcd commit 7aa0559

3 files changed: +159 −100 lines changed

modelopt/torch/export/plugins/vllm_fakequant.py

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Export functions for vLLM fakequant."""
+
+import os
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+
+from modelopt.torch.export.layer_utils import is_quantlinear
+from modelopt.torch.export.model_config import QUANTIZATION_NONE
+from modelopt.torch.quantization.utils import get_quantizer_state_dict
+
+
+def export_hf_vllm_fq_checkpoint(
+    model: nn.Module,
+    export_dir: Path | str,
+) -> dict[str, torch.Tensor]:
+    """Exports the torch model weights and amax values separately.
+
+    This function:
+    1. Extracts amax values for calibration
+    2. Deletes all quantizer parameters from the state dict to store only weights in the original dtype
+
+    Args:
+        model: The quantized model to export
+        export_dir: Directory to save the amax values
+
+    Returns:
+        post_state_dict: Dict containing the model weights (original dtype) without quantizer state
+    """
+    amax_dict = {
+        name + "._amax": param["_amax"].detach().clone().cpu()
+        for name, param in get_quantizer_state_dict(model).items()
+        if "_amax" in param
+    }
+
+    # Remove the quantizers from the model so only plain weights remain
+    for _, module in model.named_modules():
+        if is_quantlinear(module):
+            delattr(module, "weight_quantizer")
+            delattr(module, "input_quantizer")
+            delattr(module, "output_quantizer")
+            module.export()
+    torch.save(amax_dict, f"{export_dir}/quant_amax.pth")
+    return model.state_dict()
+
+
+def get_mcore_vllm_fq_quantized_state(
+    module: torch.nn.Module, name_to_value: dict, dtype: torch.dtype = torch.bfloat16
+):
+    """Return a state_dict, quantization format, and block_size of the quantized module.
+
+    Args:
+        module: The target module to perform real quantization.
+        name_to_value: The dictionary to store the quantized state.
+        dtype: The default data type.
+
+    Returns:
+        Tuple: state dict, quantization format, and block_size of the quantized module.
+
+    """
+    qformat: str = QUANTIZATION_NONE
+    block_size = 0
+
+    for name, param in get_quantizer_state_dict(module).items():
+        if "_amax" in param:
+            name_to_value[name + "._amax"] = param["_amax"].to(dtype).cpu()
+    return name_to_value, qformat, block_size
+
+
+def gather_mcore_vllm_fq_quantized_state_dict(
+    state_dict: dict[str, torch.Tensor], save_directory: str | os.PathLike
+):
+    """Gather all quantized state dict from all ranks and save them to a file.
+
+    Args:
+        state_dict: The state dictionary of the module.
+        save_directory: The directory to save the quantized state dict.
+
+    Returns:
+        The state dictionary of the module without quantized state.
+    """
+    amax_state_dict = {
+        k: v.detach().clone().cpu() for k, v in state_dict.items() if k.endswith("_amax")
+    }
+
+    # Gather all amax dicts to rank 0
+    world_size = torch.distributed.get_world_size()
+    rank = torch.distributed.get_rank()
+
+    if rank == 0:
+        # Rank 0 will collect all amax values
+        all_amax_dicts = [None] * world_size
+        torch.distributed.gather_object(amax_state_dict, all_amax_dicts, dst=0)
+
+        # Merge all amax dicts into one
+        merged_amax_dict = {}
+        for amax_dict in all_amax_dicts:
+            if amax_dict is not None:
+                merged_amax_dict.update(amax_dict)
+
+        print(f"Total amax entries from all ranks: {len(merged_amax_dict.keys())}")
+        torch.save(merged_amax_dict, save_directory + "/quant_amax.pth")
+    else:
+        # Other ranks just send their amax values
+        torch.distributed.gather_object(amax_state_dict, None, dst=0)
+
+    torch.distributed.barrier()
+
+    # remove amax values from state_dict
+    return {k: v for k, v in state_dict.items() if not k.endswith("_amax")}
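
Note: a minimal usage sketch of the Hugging Face path added above, not part of the commit. It assumes `model` has already been fake-quantized and calibrated with modelopt; the output directory, variable names, and the absolute import path (inferred from the relative import in unified_export_hf.py below) are illustrative.

    from pathlib import Path

    import torch

    from modelopt.torch.export.plugins.vllm_fakequant import export_hf_vllm_fq_checkpoint

    export_dir = Path("./vllm_fq_export")  # illustrative output directory
    export_dir.mkdir(parents=True, exist_ok=True)

    # Strips the quantizers from the model and returns the plain-dtype weights;
    # quant_amax.pth is written into export_dir as a side effect.
    weights_state_dict = export_hf_vllm_fq_checkpoint(model, export_dir)

    # The saved calibration scales can later be reloaded for fakequant serving.
    amax_dict = torch.load(export_dir / "quant_amax.pth")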

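For the Megatron-Core helpers, a rough sketch of how they compose, again not part of the commit. It assumes torch.distributed is already initialized on every rank and that `state_dict` holds the per-rank entries (including the "_amax" keys collected via get_mcore_vllm_fq_quantized_state); the directory name is illustrative.

    save_dir = "./mcore_vllm_fq_export"  # illustrative save directory

    # Every rank must participate in the gather; rank 0 merges the per-rank amax
    # entries and writes save_dir/quant_amax.pth, and each rank gets back its
    # state_dict with the "_amax" entries stripped.
    weights_only_state_dict = gather_mcore_vllm_fq_quantized_state_dict(state_dict, save_dir)
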
modelopt/torch/export/unified_export_hf.py

Lines changed: 8 additions & 44 deletions
@@ -33,11 +33,7 @@
 from modelopt.torch.quantization import set_quantizer_by_cfg_context
 from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
 from modelopt.torch.quantization.qtensor import NVFP4QTensor
-from modelopt.torch.quantization.utils import (
-    fsdp2_aware_weight_update,
-    get_quantizer_state_dict,
-    quantizer_attr_names,
-)
+from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names

 from .convert_hf_config import convert_hf_quant_config_format
 from .layer_utils import (
@@ -62,6 +58,7 @@
 )
 from .model_utils import get_language_model_from_vl, is_multimodal_model
 from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
+from .plugins.vllm_fakequant import export_hf_vllm_fq_checkpoint
 from .quant_utils import (
     fuse_prequant_layernorm,
     fuse_prequant_to_linear,
@@ -536,44 +533,12 @@ def _export_hf_checkpoint(
     return quantized_state_dict, quant_config


-def _export_hf_bf16_weights_amax(
-    model: nn.Module,
-) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
-    """Exports the torch model weights and amax values separately.
-
-    This function:
-    1. Extracts amax values for calibration
-    2. Deletes all quantizer parameters from state dict to store only weights in original dtype
-
-    Args:
-        model: The quantized model to export
-
-    Returns:
-        post_state_dict: Dict containing quantized weights
-        amax_dict: Dict containing amax values
-    """
-    amax_dict = {
-        name + "._amax": param["_amax"].detach().clone().cpu()
-        for name, param in get_quantizer_state_dict(model).items()
-        if "_amax" in param
-    }
-
-    # remove quantizer from model
-    for name, module in model.named_modules():
-        if is_quantlinear(module):
-            delattr(module, "weight_quantizer")
-            delattr(module, "input_quantizer")
-            delattr(module, "output_quantizer")
-            module.export()
-    return model.state_dict(), amax_dict
-
-
 def export_hf_checkpoint(
     model: nn.Module,
     dtype: torch.dtype | None = None,
     export_dir: Path | str = tempfile.gettempdir(),
     save_modelopt_state: bool = False,
-    export_bf16_weights_amax: bool = False,
+    export_vllm_fq_weights_qstate: bool = False,
 ):
     """Exports the torch model to unified checkpoint and saves to export_dir.

@@ -582,8 +547,8 @@ def export_hf_checkpoint(
         dtype: the weights data type to export the unquantized layers or the default model data type if None.
         export_dir: the target export path.
         save_modelopt_state: whether to save the modelopt state_dict.
-        export_bf16_weights_amax: whether to export the bf16 weights and amax values separately. This can be used for
-            vLLM fakequant serving.
+        export_vllm_fq_weights_qstate: whether to export the weights and quantization state separately for vLLM
+            fakequant serving.
     """
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)
@@ -597,15 +562,14 @@ def export_hf_checkpoint(
         return

     try:
-        if export_bf16_weights_amax:
-            post_state_dict, amax_dict = _export_hf_bf16_weights_amax(model)
+        if export_vllm_fq_weights_qstate:
+            post_state_dict = export_hf_vllm_fq_checkpoint(model, export_dir)
             hf_quant_config = None
-            torch.save(amax_dict, f"{export_dir}/quant_amax.pth")
         else:
             post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)

         if hf_quant_config is not None:
-            # Save hf_quant_config.json for backward compatibility
+            # Save hf_quant_config.json for backward compatibility
             with open(f"{export_dir}/hf_quant_config.json", "w") as file:
                 json.dump(hf_quant_config, file, indent=4)

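
With this change, callers select the vLLM fakequant export through the new flag instead of the removed private helper. A minimal call sketch, not part of the commit, assuming a modelopt fake-quantized Hugging Face model; the export directory is illustrative.

    from modelopt.torch.export.unified_export_hf import export_hf_checkpoint

    # On this path the weights are saved in their original dtype, the quantizer
    # amax values go to quant_amax.pth, and no hf_quant_config.json is written.
    export_hf_checkpoint(
        model,
        export_dir="./vllm_fq_export",  # illustrative export directory
        export_vllm_fq_weights_qstate=True,
    )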
