Commit 32f7ae8

added test for vllm fq export
Signed-off-by: Kinjal Patel <[email protected]>
1 parent 7aa0559 commit 32f7ae8

1 file changed: 194 additions & 0 deletions
@@ -0,0 +1,194 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from copy import deepcopy
from functools import partial

import pytest
import torch
from transformers import AutoModelForCausalLM

import modelopt.torch.quantization as mtq
from modelopt.torch.export.unified_export_hf import export_hf_checkpoint
from modelopt.torch.export.unified_export_megatron import export_mcore_gpt_to_hf

from _test_utils.import_helper import skip_if_no_megatron
from _test_utils.torch.distributed.utils import spawn_multiprocess_job
from _test_utils.torch.megatron.models import get_mcore_gpt_model
from _test_utils.torch.transformers_models import create_tiny_llama_dir

skip_if_no_megatron(apex_or_te_required=True)


@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG])
def test_hf_vllm_export(tmp_path, quant_cfg):
    """Test HuggingFace model export for vLLM with fake quantization.

    This test verifies:
    1. Model weights match before and after export.
    2. quant_amax.pth is created and hf_quant_config.json is not.
    3. Amax values are correctly extracted and saved in quant_amax.pth.
    """

    # Create a tiny LLaMA model for testing
    tiny_model_dir = create_tiny_llama_dir(tmp_path, with_tokenizer=True, num_hidden_layers=2)

    # Load the model
    model = AutoModelForCausalLM.from_pretrained(tiny_model_dir)
    model = model.cuda()
    model.eval()

    # Quantize the model with a random-token calibration forward loop
    def forward_loop(model):
        input_ids = torch.randint(0, model.config.vocab_size, (1, 128)).cuda()
        with torch.no_grad():
            model(input_ids)

    model = mtq.quantize(model, quant_cfg, forward_loop)

    model_state_dict = deepcopy(model.state_dict())

    # Export directory
    export_dir = tmp_path / "vllm_export"
    export_dir.mkdir(exist_ok=True)

    # Export for vLLM
    export_hf_checkpoint(model, export_dir=export_dir, export_vllm_fq_weights_qstate=True)

    # Check that the quant_amax.pth file exists
    quant_amax_file = export_dir / "quant_amax.pth"
    assert quant_amax_file.exists(), f"quant_amax.pth file should be created in {export_dir}"

    # Make sure the hf_quant_config.json file does not exist
    hf_quant_config_file = export_dir / "hf_quant_config.json"
    assert not hf_quant_config_file.exists(), (
        f"hf_quant_config.json file should not be created in {export_dir}"
    )

    # Check that weights match before and after export
    model_after = AutoModelForCausalLM.from_pretrained(export_dir)
    model_after = model_after.cuda()
    model_after.eval()
    model_after_state_dict = model_after.state_dict()
    amax_state_dict = {}
    for key in model_state_dict.keys():
        if key.endswith("_amax"):
            amax_state_dict[key] = model_state_dict[key]
            continue

        assert torch.allclose(model_state_dict[key], model_after_state_dict[key], atol=1e-6), (
            f"Weight mismatch for {key}: "
            f"before shape={model_state_dict[key].shape}, after shape={model_after_state_dict[key].shape}, "
            f"max diff={torch.abs(model_state_dict[key] - model_after_state_dict[key]).max()}"
        )

    # Verify the saved amax values match the quantized model's amax buffers
    amax_dict = torch.load(quant_amax_file)
    assert len(amax_dict) > 0, "amax_dict should not be empty"
    assert amax_dict.keys() == amax_state_dict.keys(), "amax keys mismatch between before and after export"


def _test_mcore_vllm_export(tmp_path, quant_cfg, rank, size):
    """Test megatron-core model export for vLLM with fake quantization."""
    # Create a tiny mcore GPT model
    num_layers = 2
    hidden_size = 64
    num_attention_heads = 8
    num_query_groups = size
    ffn_hidden_size = 128
    max_sequence_length = 32
    vocab_size = 64

    model = get_mcore_gpt_model(
        tensor_model_parallel_size=size,
        pipeline_model_parallel_size=1,
        initialize_megatron=True,
        num_layers=num_layers,
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        num_query_groups=num_query_groups,
        ffn_hidden_size=ffn_hidden_size,
        max_sequence_length=max_sequence_length,
        vocab_size=vocab_size,
        activation_func="swiglu",
        normalization="RMSNorm",
        transformer_impl="modelopt",
    ).cuda()
    model.eval()

    # Quantize the model with a random-token calibration forward loop
    def forward_loop(model):
        batch_size = 1
        seq_len = 32
        input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)).cuda()
        position_ids = torch.arange(seq_len).unsqueeze(0).cuda()
        # Create causal attention mask
        attention_mask = torch.tril(torch.ones((1, 1, seq_len, seq_len))).cuda()
        attention_mask = attention_mask < 0.5  # Convert to boolean mask
        with torch.no_grad():
            model(input_ids, position_ids, attention_mask)

    model = mtq.quantize(model, quant_cfg, forward_loop)

    model_state_dict = deepcopy(model.state_dict())

    # Create HF config for export
    pretrained_config = {
        "architectures": ["LlamaForCausalLM"],
        "attention_bias": False,
        "hidden_size": hidden_size,
        "intermediate_size": ffn_hidden_size,
        "max_position_embeddings": max_sequence_length,
        "model_type": "llama",
        "num_attention_heads": num_attention_heads,
        "num_hidden_layers": num_layers,
        "num_key_value_heads": num_query_groups,
        "torch_dtype": "bfloat16",
    }

    with open(tmp_path / "config.json", "w") as f:
        json.dump(pretrained_config, f)

    # Export directory
    export_dir = tmp_path / "vllm_export"
    export_dir.mkdir(exist_ok=True)

    # Export for vLLM
    export_mcore_gpt_to_hf(
        model,
        pretrained_model_name_or_path=tmp_path,
        dtype=torch.bfloat16,
        export_dir=str(export_dir),
        export_vllm_fq_weights_qstate=True,
    )

    # Check that the quant_amax.pth file exists
    quant_amax_file = export_dir / "quant_amax.pth"
    assert quant_amax_file.exists(), f"quant_amax.pth file should be created in {export_dir}"

    # Make sure the hf_quant_config.json file does not exist
    hf_quant_config_file = export_dir / "hf_quant_config.json"
    assert not hf_quant_config_file.exists(), (
        f"hf_quant_config.json file should not be created in {export_dir}"
    )


@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG])
def test_mcore_vllm_export(tmp_path, quant_cfg):
    """Wrapper test function for mcore vLLM export."""
    spawn_multiprocess_job(
        size=1,
        job=partial(_test_mcore_vllm_export, tmp_path, quant_cfg),
        backend="nccl",
    )
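The amax values exported above land in quant_amax.pth as a plain dict of tensors keyed by the same quantizer state-dict names the first test compares against. As a point of reference only, here is a minimal consumer-side sketch of loading that file and deriving a per-tensor FP8 (E4M3) scale from each amax; the 448.0 divisor is the E4M3 maximum representable value, and whether vLLM derives its scales exactly this way is an assumption, not something this commit specifies.

import torch

# Hypothetical consumer-side sketch (not part of this commit):
# load the exported amax dict and derive per-tensor FP8 scales.
E4M3_MAX = 448.0

amax_dict = torch.load("vllm_export/quant_amax.pth")
assert len(amax_dict) > 0

for name, amax in amax_dict.items():
    # Reduce per-channel amax to a single per-tensor value for illustration.
    per_tensor_amax = amax.float().amax()
    scale = per_tensor_amax / E4M3_MAX
    print(f"{name}: amax={per_tensor_amax.item():.4f}, scale={scale.item():.6f}")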
