
Conversation

@kinjalpatel27
Contributor

@kinjalpatel27 kinjalpatel27 commented Nov 19, 2025

What does this PR do?

Type of change: New Feature

Overview:

Support for evaluating vLLM fake-quantized QAT/QAD checkpoints. This PR adds the ability to export a checkpoint as BF16 weights plus amax values via `export_hf_checkpoint` for HF and `export_mcore_gpt_to_hf` for MCore, using the new `export_bf16_weights_amax` option. The exported weights and amax values can then be used with the `vllm_serve_fakequant.py` script to serve the saved checkpoint.
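
For context, a minimal sketch of the HF export flow described above. The model name, calibration loop, and the exact placement of the `export_bf16_weights_amax` keyword are illustrative assumptions based on this description, not a confirmed final API:

```python
# Minimal sketch, not a confirmed API: assumes export_hf_checkpoint accepts the
# export_bf16_weights_amax option added in this PR. Model name and calibration
# data are placeholders.
import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-1.5B-Instruct"  # hypothetical example model
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="bfloat16", device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name)


def forward_loop(m):
    # Tiny calibration pass; a real PTQ/QAT flow would iterate a proper dataset.
    batch = tokenizer("Hello, world!", return_tensors="pt").to(m.device)
    m(**batch)


model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)

# Export BF16 weights plus per-quantizer amax values (instead of real-quantized
# weights) so vllm_serve_fakequant.py can rebuild the fake-quantized model.
export_hf_checkpoint(model, export_dir="qwen2.5-1.5b-bf16-amax", export_bf16_weights_amax=True)
```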

Usage

Refer to README.md

Testing

  • Tested the HF path by exporting a BF16 model with the QAT script and running the vLLM server; verified that the amax values match (see the comparison sketch after this list)
  • Tested the MCore path by quantizing and exporting a BF16 model with the quantize.sh and export.sh scripts and running the vLLM server; verified that the amax values match
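
For illustration, one way to spot-check that the amax values match is to compare the exported tensors directly. The checkpoint paths and the `_amax` key suffix below are assumptions for the example, not taken from this PR:

```python
# Hedged sketch: compare amax tensors between two safetensors checkpoints,
# e.g. the QAT export and a dump taken from the vLLM side. File paths and the
# "_amax" key suffix are illustrative assumptions.
import torch
from safetensors.torch import load_file

ref = load_file("qat_export/model.safetensors")
test = load_file("vllm_dump/model.safetensors")

amax_keys = [k for k in ref if k.endswith("_amax")]
for k in amax_keys:
    assert k in test, f"missing amax key: {k}"
    assert torch.allclose(ref[k].float(), test[k].float()), f"amax mismatch: {k}"
print(f"checked {len(amax_keys)} amax tensors")
```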

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: No
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: Yes

Additional Information

The MCore export script does not currently expose an option to enable this export.

@copy-pr-bot

copy-pr-bot bot commented Nov 19, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@codecov

codecov bot commented Nov 19, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 74.47%. Comparing base (1d0ee04) to head (13f6bcd).
⚠️ Report is 7 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #579      +/-   ##
==========================================
+ Coverage   74.43%   74.47%   +0.04%     
==========================================
  Files         182      182              
  Lines       18234    18255      +21     
==========================================
+ Hits        13572    13596      +24     
+ Misses       4662     4659       -3     

☔ View full report in Codecov by Sentry.

@kinjalpatel27 kinjalpatel27 force-pushed the kinjal/bf16_weight_amax_export branch from 9946463 to 560dfc7 Compare November 19, 2025 22:09
@kinjalpatel27 kinjalpatel27 marked this pull request as ready for review November 19, 2025 22:10
@kinjalpatel27 kinjalpatel27 requested review from a team as code owners November 19, 2025 22:10
@kinjalpatel27 kinjalpatel27 self-assigned this Nov 19, 2025
## Known Problems

1. AWQ is not yet supported in vLLM.
2. PTQ/QAT checkpoints don't work with KV cache quantization enabled.
Contributor


Thanks, Kinjal, for documenting this. Created a Jira ticket to address it: https://jirasw.nvidia.com/browse/OMNIML-3051

Contributor Author


Thank you for creating the ticket

Contributor

@realAsma realAsma left a comment


Is it possible to export the entire TensorQuantizer state? That way we could seamlessly support PTQ/QAT and fake quantization.

The current support does not work for cases such as mixed-precision quantization (some layers in FP4, some in FP8, some disabled, etc.); we would need manual workarounds there. It also does not work for other quantization methods such as AWQ.

We rely on the fact that the model is quantized during vllm_serve with the same quantization format that was used for PTQ/QAT.

@kinjalpatel27 kinjalpatel27 force-pushed the kinjal/bf16_weight_amax_export branch from e5a095d to 13f6bcd Compare November 21, 2025 18:49
@kinjalpatel27
Contributor Author

kinjalpatel27 commented Nov 21, 2025

> Is it possible to export the entire TensorQuantizer state? That way we could seamlessly support PTQ/QAT and fake quantization.
>
> The current support does not work for cases such as mixed-precision quantization (some layers in FP4, some in FP8, some disabled, etc.); we would need manual workarounds there. It also does not work for other quantization methods such as AWQ.
>
> We rely on the fact that the model is quantized during vllm_serve with the same quantization format that was used for PTQ/QAT.

I have created two tickets to explore mixed-precision and other quantization algorithm support. Exporting and loading the TensorQuantizer state may require additional effort, since the vLLM model also fuses multiple layers.
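
For illustration only, a rough sketch of what collecting the full TensorQuantizer state could look like. This is a hypothetical alternative flow, not what this PR implements, and loading it into vLLM would still require remapping names for fused layers:

```python
# Hypothetical sketch: gather the complete state of every TensorQuantizer so
# that quantization format, enabled/disabled status, and calibrated amax values
# all travel with the checkpoint. Not part of this PR.
import torch
from modelopt.torch.quantization.nn import TensorQuantizer


def collect_quantizer_state(model: torch.nn.Module) -> dict:
    return {
        name: module.state_dict()  # includes the _amax buffer when calibrated
        for name, module in model.named_modules()
        if isinstance(module, TensorQuantizer)
    }


# torch.save(collect_quantizer_state(model), "tensor_quantizer_state.pt")
```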

Contributor

@meenchen meenchen left a comment


Can you also add unit tests for `modelopt.torch.export.unified_export_hf.export_hf_checkpoint` and `modelopt.torch.export.unified_export_megatron.export_mcore_gpt_to_hf`?

Comment on lines +70 to +74
gate_up_match = "mixer" not in key and re.search(r"(.*\.)(gate|up)_proj(\..+_amax)$", key)
if gate_up_match:
    base_pattern = gate_up_match.group(1) + "gate_up_proj" + gate_up_match.group(3)
    merge_groups[base_pattern].append((key, value))
    continue
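
For readers skimming this hunk, a small self-contained illustration of how the regex merges gate_proj/up_proj amax keys under a shared gate_up_proj key (the key names below are made up for the example):

```python
# Illustration only: the regex from the hunk above groups gate_proj/up_proj
# amax keys under one gate_up_proj key, matching vLLM's fused gate_up_proj
# layer. Key names here are hypothetical.
import re
from collections import defaultdict

merge_groups = defaultdict(list)
keys = {
    "model.layers.0.mlp.gate_proj.input_quantizer._amax": 1.0,
    "model.layers.0.mlp.up_proj.input_quantizer._amax": 2.0,
}
for key, value in keys.items():
    m = "mixer" not in key and re.search(r"(.*\.)(gate|up)_proj(\..+_amax)$", key)
    if m:
        merged = m.group(1) + "gate_up_proj" + m.group(3)
        merge_groups[merged].append((key, value))

print(dict(merge_groups))
# Both amax entries end up under
# "model.layers.0.mlp.gate_up_proj.input_quantizer._amax".
```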
Contributor


Does this work with MoE models that use this quant module:

class _QuantFusedMoEBase(QuantModule):

Contributor Author


Can you give an example of which model you are referring to?

Contributor


Sure, you can try Qwen/Qwen3-30B-A3B-Instruct-2507

@kinjalpatel27 kinjalpatel27 force-pushed the kinjal/bf16_weight_amax_export branch from 32968c9 to 32f7ae8 Compare November 21, 2025 21:43
@kinjalpatel27
Contributor Author

kinjalpatel27 commented Nov 21, 2025

> Can you also add unit tests for `modelopt.torch.export.unified_export_hf.export_hf_checkpoint` and `modelopt.torch.export.unified_export_megatron.export_mcore_gpt_to_hf`?

@meenchen Added tests for both in a separate file: `tests/gpu/torch/export/test_vllm_fakequant_export.py`

Contributor

@meenchen meenchen left a comment


LGTM. Could you also check whether Qwen/Qwen3-30B-A3B-Instruct-2507 works with this change?

@kinjalpatel27
Contributor Author

> LGTM. Could you also check whether Qwen/Qwen3-30B-A3B-Instruct-2507 works with this change?

@meenchen Thank you. I checked Qwen/Qwen3-30B-A3B-Instruct-2507; it doesn't work with FusedMoE yet. I am looking into fixing it.
