Commit 32968c9

added separate file for vLLM export
1 parent 13f6bcd commit 32968c9

File tree

3 files changed, +155 -99 lines changed

modelopt/torch/export/plugins/vllm_fakequant.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pathlib import Path

import torch
import torch.nn as nn

from modelopt.torch.export.layer_utils import is_quantlinear
from modelopt.torch.export.model_config import QUANTIZATION_NONE
from modelopt.torch.quantization.utils import get_quantizer_state_dict


def export_hf_vllm_fq_checkpoint(
    model: nn.Module,
    export_dir: Path | str,
) -> dict[str, torch.Tensor]:
    """Export the model weights and the amax values separately.

    This function:
    1. Extracts the amax values collected during calibration and saves them to
       ``<export_dir>/quant_amax.pth``.
    2. Deletes all quantizer modules so the state dict contains only the weights in
       their original dtype.

    Args:
        model: The quantized model to export.
        export_dir: Directory to save the amax values.

    Returns:
        post_state_dict: Dict containing the model weights without quantizer state.
    """
    amax_dict = {
        name + "._amax": param["_amax"].detach().clone().cpu()
        for name, param in get_quantizer_state_dict(model).items()
        if "_amax" in param
    }

    # Remove the quantizers from the model so only plain weights remain.
    for _, module in model.named_modules():
        if is_quantlinear(module):
            delattr(module, "weight_quantizer")
            delattr(module, "input_quantizer")
            delattr(module, "output_quantizer")
            module.export()
    torch.save(amax_dict, f"{export_dir}/quant_amax.pth")
    return model.state_dict()


def get_mcore_vllm_fq_quantized_state(
    module: torch.nn.Module,
    name_to_value: dict,
    dtype: torch.dtype = torch.bfloat16,
):
    """Return the state dict, quantization format, and block_size of the quantized module.

    Args:
        module: The target module to perform real quantization.
        name_to_value: The dictionary to store the quantized state.
        dtype: The default data type.

    Returns:
        Tuple: state dict, quantization format, and block_size of the quantized module.
    """
    qformat: str = QUANTIZATION_NONE
    block_size = 0

    for name, param in get_quantizer_state_dict(module).items():
        if "_amax" in param:
            name_to_value[name + "._amax"] = param["_amax"].to(dtype).cpu()
    return name_to_value, qformat, block_size


def gather_mcore_vllm_fq_quantized_state_dict(
    state_dict: dict[str, torch.Tensor],
    save_directory: str | os.PathLike,
):
    """Gather the amax values from all ranks and save them to a single file.

    Args:
        state_dict: The state dictionary of the module.
        save_directory: The directory to save the quantized state dict.

    Returns:
        The state dictionary of the module without the quantizer state.
    """
    amax_state_dict = {
        k: v.detach().clone().cpu() for k, v in state_dict.items() if k.endswith("_amax")
    }

    # Gather all amax dicts to rank 0.
    world_size = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()

    if rank == 0:
        # Rank 0 collects the amax values from every rank.
        all_amax_dicts = [None] * world_size
        torch.distributed.gather_object(amax_state_dict, all_amax_dicts, dst=0)

        # Merge all amax dicts into one.
        merged_amax_dict = {}
        for amax_dict in all_amax_dicts:
            if amax_dict is not None:
                merged_amax_dict.update(amax_dict)

        print(f"Total amax entries from all ranks: {len(merged_amax_dict)}")
        torch.save(merged_amax_dict, os.path.join(save_directory, "quant_amax.pth"))
    else:
        # Other ranks just send their amax values.
        torch.distributed.gather_object(amax_state_dict, None, dst=0)

    torch.distributed.barrier()

    # Drop the amax values from the returned state dict.
    return {k: v for k, v in state_dict.items() if not k.endswith("_amax")}
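For context, a minimal end-to-end sketch of how this new export path might be driven. The model name, quantization config, and calibration loop below are placeholders and not part of this commit; the only pieces taken from the diff are export_hf_checkpoint(..., export_vllm_fq_weights_qstate=True) and the quant_amax.pth file it writes next to the weights. It assumes export_hf_checkpoint is importable from modelopt.torch.export (otherwise import it from modelopt.torch.export.unified_export_hf).

import torch
import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM

from modelopt.torch.export import export_hf_checkpoint

# Placeholder model and config; any modelopt fake-quantized HF model follows the same flow.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct").cuda()

def forward_loop(model):
    # Run a handful of calibration batches here so the quantizers record amax values.
    ...

model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)

# New in this commit: weights are saved in their original dtype, and the amax values
# are written separately to <export_dir>/quant_amax.pth for vLLM fakequant serving.
export_hf_checkpoint(model, export_dir="./llama-fq-export", export_vllm_fq_weights_qstate=True)

amax = torch.load("./llama-fq-export/quant_amax.pth")
print(f"exported {len(amax)} amax entries")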

modelopt/torch/export/unified_export_hf.py

Lines changed: 7 additions & 43 deletions
@@ -33,11 +33,7 @@
 from modelopt.torch.quantization import set_quantizer_by_cfg_context
 from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
 from modelopt.torch.quantization.qtensor import NVFP4QTensor
-from modelopt.torch.quantization.utils import (
-    fsdp2_aware_weight_update,
-    get_quantizer_state_dict,
-    quantizer_attr_names,
-)
+from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names
 
 from .convert_hf_config import convert_hf_quant_config_format
 from .layer_utils import (
@@ -62,6 +58,7 @@
 )
 from .model_utils import get_language_model_from_vl, is_multimodal_model
 from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
+from .plugins.vllm_fakequant import export_hf_vllm_fq_checkpoint
 from .quant_utils import (
     fuse_prequant_layernorm,
     fuse_prequant_to_linear,
@@ -536,44 +533,12 @@ def _export_hf_checkpoint(
     return quantized_state_dict, quant_config
 
 
-def _export_hf_bf16_weights_amax(
-    model: nn.Module,
-) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
-    """Exports the torch model weights and amax values separately.
-
-    This function:
-    1. Extracts amax values for calibration
-    2. Deletes all quantizer parameters from state dict to store only weights in original dtype
-
-    Args:
-        model: The quantized model to export
-
-    Returns:
-        post_state_dict: Dict containing quantized weights
-        amax_dict: Dict containing amax values
-    """
-    amax_dict = {
-        name + "._amax": param["_amax"].detach().clone().cpu()
-        for name, param in get_quantizer_state_dict(model).items()
-        if "_amax" in param
-    }
-
-    # remove quantizer from model
-    for name, module in model.named_modules():
-        if is_quantlinear(module):
-            delattr(module, "weight_quantizer")
-            delattr(module, "input_quantizer")
-            delattr(module, "output_quantizer")
-            module.export()
-    return model.state_dict(), amax_dict
-
-
 def export_hf_checkpoint(
     model: nn.Module,
     dtype: torch.dtype | None = None,
     export_dir: Path | str = tempfile.gettempdir(),
     save_modelopt_state: bool = False,
-    export_bf16_weights_amax: bool = False,
+    export_vllm_fq_weights_qstate: bool = False,
 ):
     """Exports the torch model to unified checkpoint and saves to export_dir.
 
@@ -582,8 +547,8 @@ def export_hf_checkpoint(
         dtype: the weights data type to export the unquantized layers or the default model data type if None.
         export_dir: the target export path.
         save_modelopt_state: whether to save the modelopt state_dict.
-        export_bf16_weights_amax: whether to export the bf16 weights and amax values separately. This can be used for
-            vLLM fakequant serving.
+        export_vllm_fq_weights_qstate: whether to export the weights and quantization state separately for vLLM
+            fakequant serving.
     """
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)
@@ -597,10 +562,9 @@ def export_hf_checkpoint(
         return
 
     try:
-        if export_bf16_weights_amax:
-            post_state_dict, amax_dict = _export_hf_bf16_weights_amax(model)
+        if export_vllm_fq_weights_qstate:
+            post_state_dict = export_hf_vllm_fq_checkpoint(model, export_dir)
             hf_quant_config = None
-            torch.save(amax_dict, f"{export_dir}/quant_amax.pth")
         else:
             post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
 
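On the serving side, the exported quant_amax.pth has to be reattached to the quantizers of a fakequant model. That loader is not part of this commit; the sketch below is one plausible shape for it, assuming the checkpoint keys follow the <quantizer module name>._amax convention used by export_hf_vllm_fq_checkpoint above and that TensorQuantizer accepts direct assignment of its amax value.

import torch

from modelopt.torch.quantization.nn import TensorQuantizer


def load_exported_amax(model, amax_path: str, device: str = "cuda"):
    """Reattach exported amax values to the TensorQuantizers of a fakequant model."""
    amax_dict = torch.load(amax_path, map_location="cpu")
    missing = []
    for name, module in model.named_modules():
        if isinstance(module, TensorQuantizer):
            key = name + "._amax"
            if key in amax_dict:
                # Assumes TensorQuantizer exposes a writable amax attribute.
                module.amax = amax_dict[key].to(device)
            else:
                missing.append(name)
    if missing:
        print(f"{len(missing)} quantizers had no exported amax entry")
    return model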
