
Commit 6467ec2

realAsma and Asma Kuriparambil Thekkumpate authored
[2/N] Added KDLoss based AutoQuantize (#592)
## What does this PR do?

**Type of change:** New Feature

**Overview:** This PR extends AutoQuantize with KL Divergence Loss-based sensitivity measurement as an alternative to the existing gradient-based approach. The KD Loss mode uses a binary searcher similar to the one in FastNAS. Gradient-based AutoQuantize is faster than the KL Divergence-based mode; however, KL Divergence does not require the model implementation to support backward passes. In addition, the KL Divergence scores collected by AutoQuantize are useful for sensitivity analysis of the model, since KL Divergence is a more direct measure of sensitivity than gradient scores.

## Usage

See `tests/unit/torch/quantization/test_autoquant.py`.

## Testing

Tested with unit tests. Result for Qwen3 8B:

[Image: auto_quantize result for Qwen3 8B — https://github.com/user-attachments/assets/6cc36425-ea60-4a76-a3c6-25293667c742]

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

---------

Signed-off-by: realAsma <[email protected]>
Co-authored-by: Asma Kuriparambil Thekkumpate <[email protected]>
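For context, a minimal sketch of how the new KL Divergence mode is driven through `mtq.auto_quantize`, following the call pattern in `examples/llm_eval/quantization_utils.py` below. The model, dataloader, effective-bits budget, and format names are illustrative assumptions, not part of this PR:

```python
import modelopt.torch.quantization as mtq

# calib_loader is a hypothetical dataloader yielding dicts of model inputs,
# e.g. {"input_ids": ..., "attention_mask": ...}; no labels are needed for kl_div.
def forward_step(model, batch):
    # KL Divergence scoring compares output distributions, so only logits are returned.
    return model(**batch).logits

model, _ = mtq.auto_quantize(
    model,
    constraints={"effective_bits": 6.4},  # illustrative effective-bits budget
    quantization_formats=[mtq.NVFP4_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG],  # illustrative formats
    data_loader=calib_loader,
    forward_step=forward_step,
    loss_func=None,  # kl_div does not need a custom loss function
    num_calib_steps=len(calib_loader),
    num_score_steps=min(len(calib_loader), 128),  # scoring dominates runtime; fewer steps run faster
    verbose=True,
    method="kl_div",  # new in this PR; the default remains "gradient"
)
```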
1 parent 4b72089 commit 6467ec2

File tree: 12 files changed, +614 −66 lines

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@ Model Optimizer Changelog (Linux)
 - Add MoE (e.g. Qwen3-30B-A3B, gpt-oss-20b) pruning support for ``num_moe_experts``, ``moe_ffn_hidden_size`` and ``moe_shared_expert_intermediate_size`` parameters in Minitron pruning (``mcore_minitron``).
 - Add ``specdec_bench`` example to benchmark speculative decoding performance. See `examples/specdec_bench/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/specdec_bench#speculative-decoding-benchmark>`_ for more details.
 - Add FP8/NVFP4 KV cache quantization support for Megatron Core models.
+- Add KL Divergence loss based auto_quantize method. See `auto_quantize API docs <https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.quantization.model_quant.html#modelopt.torch.quantization.model_quant.auto_quantize>`_ for more details.
+- Add support for saving and resuming auto_quantize search state. This speeds up the auto_quantize process by skipping the score estimation step if the search state is provided.
 - Add flag ``trt_plugins_precision`` in ONNX autocast to indicate custom ops precision. This is similar to the flag already existing in the quantization workflow.
 - Add support for PyTorch Geometric quantization.
 - Add per tensor and per channel MSE calibrator support.
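The second new changelog entry (saving and resuming the auto_quantize search state) maps to the `checkpoint` argument this PR threads through the example scripts. A short hedged sketch of the intended save/resume flow; the path and surrounding arguments are illustrative:

```python
import modelopt.torch.quantization as mtq

# The first run estimates sensitivity scores and costs and writes them to the checkpoint;
# a later run that passes the same path skips score estimation and resumes the search.
model, _ = mtq.auto_quantize(
    model,
    constraints={"effective_bits": 6.4},  # illustrative budget
    quantization_formats=[mtq.FP8_DEFAULT_CFG],  # illustrative format list
    data_loader=calib_loader,  # hypothetical calibration dataloader
    forward_step=lambda model, batch: model(**batch).logits,
    loss_func=None,
    num_calib_steps=len(calib_loader),
    num_score_steps=128,
    method="kl_div",
    checkpoint="autoquant_search_state.pth",  # hypothetical path holding the saved search state
)
```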

examples/llm_eval/gen_model_answer.py

Lines changed: 34 additions & 1 deletion
@@ -201,8 +201,11 @@ def get_model_answers(
         tokenizer,
         args.calib_batch_size,
         args.calib_size,
-        args.auto_quantize_bits,
         test_generated=False,
+        auto_quantize_bits=args.auto_quantize_bits,
+        auto_quantize_method=args.auto_quantize_method,
+        auto_quantize_score_size=args.auto_quantize_score_size,
+        auto_quantize_checkpoint=args.auto_quantize_checkpoint,
     )
 
     for question in tqdm(questions):
@@ -450,6 +453,36 @@ def reorg_answer_file(answer_file):
             "regular quantization without auto_quantize search will be applied."
         ),
     )
+    parser.add_argument(
+        "--auto_quantize_method",
+        type=str,
+        default="gradient",
+        choices=["gradient", "kl_div"],
+        help=(
+            "Method for auto_quantize sensitivity analysis. 'gradient' uses gradient-based method "
+            "(requires labels in dataset). 'kl_div' uses KL divergence between original and "
+            "quantized model outputs (no labels required). Default: 'gradient'"
+        ),
+    )
+    parser.add_argument(
+        "--auto_quantize_score_size",
+        type=int,
+        default=128,
+        help=(
+            "Number of samples to use for auto_quantize scoring. Most of auto_quantize time is spent on "
+            "sensitivity score estimation, so reducing this speeds it up while only minimally affecting "
+            "final model accuracy compared to lowering --calib_size (the number of samples used for calibration)."
+        ),
+    )
+    parser.add_argument(
+        "--auto_quantize_checkpoint",
+        type=str,
+        default=None,
+        help=(
+            "Path to checkpoint file for saving/restoring auto_quantize search state "
+            "(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified."
+        ),
+    )
     parser.add_argument(
         "--trust_remote_code",
         help="Set trust_remote_code for Huggingface models and tokenizers",

examples/llm_eval/lm_eval_hf.py

Lines changed: 37 additions & 2 deletions
@@ -53,6 +53,9 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
 
         quant_cfg = arg_dict.pop("quant_cfg", None)
         auto_quantize_bits = arg_dict.pop("auto_quantize_bits", None)
+        auto_quantize_method = arg_dict.pop("auto_quantize_method", "gradient")
+        auto_quantize_score_size = arg_dict.pop("auto_quantize_score_size", 128)
+        auto_quantize_checkpoint = arg_dict.pop("auto_quantize_checkpoint", None)
         calib_batch_size = arg_dict.pop("calib_batch_size", None)
         calib_size = arg_dict.pop("calib_size", 512)
         compress = arg_dict.pop("compress", False)
@@ -81,8 +84,11 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
             batch_size=calib_batch_size,
             calib_size=calib_size,
             auto_quantize_bits=auto_quantize_bits,
+            auto_quantize_method=auto_quantize_method,
+            auto_quantize_score_size=auto_quantize_score_size,
             test_generated=False,
             compress=compress,
+            auto_quantize_checkpoint=auto_quantize_checkpoint,
         )
 
         return model_obj
@@ -101,6 +107,12 @@ def setup_parser_with_modelopt_args():
             "comma-separated list of quantization quantization formats that will be searched by `auto_quantize`"
         ),
     )
+    parser.add_argument(
+        "--calib_batch_size", type=int, help="Batch size for quantization calibration"
+    )
+    parser.add_argument(
+        "--calib_size", type=int, help="Calibration size for quantization", default=512
+    )
     parser.add_argument(
         "--auto_quantize_bits",
         type=float,
@@ -110,10 +122,30 @@ def setup_parser_with_modelopt_args():
         ),
     )
     parser.add_argument(
-        "--calib_batch_size", type=int, help="Batch size for quantization calibration"
+        "--auto_quantize_method",
+        type=str,
+        default="gradient",
+        choices=["gradient", "kl_div"],
+        help=(
+            "Method for auto_quantize sensitivity analysis. 'gradient' uses gradient-based method "
+            "(requires labels in dataset). 'kl_div' uses KL divergence between original and "
+            "quantized model outputs (no labels required). Default: 'gradient'"
+        ),
     )
     parser.add_argument(
-        "--calib_size", type=int, help="Calibration size for quantization", default=512
+        "--auto_quantize_score_size",
+        type=int,
+        default=128,
+        help=(
+            "Number of samples to use for auto_quantize scoring. Most of auto_quantize time is spent on "
+            "sensitivity score estimation, so reducing this speeds it up while only minimally affecting "
+            "final model accuracy compared to lowering --calib_size (the number of samples used for calibration)."
+        ),
+    )
+    parser.add_argument(
+        "--auto_quantize_checkpoint",
+        type=str,
+        help=("Path to checkpoint file for saving/restoring auto_quantize search state. "),
     )
     parser.add_argument(
         "--compress",
@@ -139,6 +171,9 @@ def setup_parser_with_modelopt_args():
         {
             "quant_cfg": args.quant_cfg,
             "auto_quantize_bits": args.auto_quantize_bits,
+            "auto_quantize_method": args.auto_quantize_method,
+            "auto_quantize_score_size": args.auto_quantize_score_size,
+            "auto_quantize_checkpoint": args.auto_quantize_checkpoint,
             "calib_batch_size": args.calib_batch_size,
             "calib_size": args.calib_size,
             "compress": args.compress,

examples/llm_eval/mmlu.py

Lines changed: 6 additions & 0 deletions
@@ -227,6 +227,9 @@ def main(
     batch_size: int = 0,
     calib_size: int = 512,
     dtype: str = "bfloat16",
+    auto_quantize_method: str = "gradient",
+    auto_quantize_score_size: int = 128,
+    auto_quantize_checkpoint: str | None = None,
     **kwargs,
 ):
     random.seed(RAND_SEED)
@@ -281,6 +284,9 @@ def main(
         batch_size=batch_size,
         calib_size=calib_size,
         auto_quantize_bits=auto_quantize_bits,
+        auto_quantize_method=auto_quantize_method,
+        auto_quantize_score_size=auto_quantize_score_size,
+        auto_quantize_checkpoint=auto_quantize_checkpoint,
     )
 
     for subject in tqdm(subjects):

examples/llm_eval/quantization_utils.py

Lines changed: 52 additions & 13 deletions
@@ -66,8 +66,11 @@ def _quantize_model_with_dataset(
     quant_cfg: str | list[str],
     calib_dataset,
     auto_quantize_bits=None,
+    auto_quantize_method="gradient",
+    auto_quantize_score_size=128,
     batch_size=1,
     compress=False,
+    auto_quantize_checkpoint=None,
 ):
     if hasattr(lm, "gpt2"):
         net = lm.gpt2
@@ -81,23 +84,42 @@ def _quantize_model_with_dataset(
             getattr(mtq, quant_fmt) for quant_fmt in quant_cfg if quant_fmt != "NONE"
         ]
 
-        def loss_func(output, data):
-            # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
-            # which contains the loss attribute.
-            return output.loss
+        # Configure forward_step and loss_func based on method
+        if auto_quantize_method == "gradient":
+            # For gradient-based method, return full output with loss
+            def forward_step(model, batch):
+                return model(**batch)
+
+            def loss_func(output, data):
+                # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
+                # which contains the loss attribute.
+                return output.loss
+        elif auto_quantize_method == "kl_div":
+            # For KL divergence method, return only logits
+            def forward_step(model, batch):
+                return model(**batch).logits
+
+            loss_func = None  # KL divergence doesn't need a custom loss function
+        else:
+            raise ValueError(
+                f"Invalid auto_quantize_method: {auto_quantize_method}. "
+                "Must be 'gradient' or 'kl_div'"
+            )
 
         net, _ = mtq.auto_quantize(
             net,
             constraints={"effective_bits": auto_quantize_bits},
             quantization_formats=quant_cfg_for_search,
             data_loader=calib_dataset,
-            forward_step=lambda model, batch: model(**batch),
+            forward_step=forward_step,
             loss_func=loss_func,
             num_calib_steps=len(calib_dataset),
-            num_score_steps=min(
-                len(calib_dataset), 128 // batch_size
-            ),  # Limit the number of score steps to avoid long calibration time
+            # Most time is spent on score estimation; fewer samples speed it up with little accuracy impact.
+            num_score_steps=min(len(calib_dataset), max(auto_quantize_score_size // batch_size, 1)),
             verbose=True,
+            method=auto_quantize_method,
+            # disabled_layers=["*lm_head*", "*mlp.gate.*"],
+            checkpoint=auto_quantize_checkpoint,
         )
     else:
         mtq_cfg = CUSTOM_CONFIG.get(quant_cfg)  # type: ignore [arg-type]
@@ -141,10 +163,13 @@ def quantize_model(
     tokenizer,
     batch_size,
     calib_size,
-    auto_quantize_bits=None,
     data="cnn_dailymail",
     test_generated=True,
     compress=False,
+    auto_quantize_bits=None,
+    auto_quantize_method="gradient",
+    auto_quantize_score_size=128,
+    auto_quantize_checkpoint=None,
 ):
     """Quantizes the model with the provided calibration dataset.
 
@@ -155,10 +180,14 @@ def quantize_model(
         tokenizer: the tokenizer.
         batch_size: the calibration batch size for each calibration inference run.
        calib_size: the total calibration dataset size.
-        auto_quantize_bits: The effective bits constraint for auto_quantize.
         data: the name of the calibration dataset.
         test_generated: If ``True``, test the generated text before and after quantization.
         compress: If ``True``, compress the model after quantization.
+        auto_quantize_bits: The effective bits constraint for auto_quantize.
+        auto_quantize_method: The method for auto_quantize ('gradient' or 'kl_div').
+        auto_quantize_score_size: Number of samples used for auto_quantize scoring.
+        auto_quantize_checkpoint: Path to checkpoint file for saving/restoring auto_quantize search state
+            (sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified.
     """
     if "AWQ" in quant_cfg:
         print(
@@ -170,8 +199,10 @@ def quantize_model(
     if hasattr(model, "model"):
         device = model.model.device
 
+    is_gradient_based = auto_quantize_bits is not None and auto_quantize_method == "gradient"
+
     if batch_size == 0:
-        if auto_quantize_bits is not None or torch.distributed.is_initialized():
+        if is_gradient_based or torch.distributed.is_initialized():
             raise ValueError("We dont support automatic batch size inference for this case.")
 
         net = model.gpt2 if hasattr(model, "gpt2") else model.model
@@ -186,15 +217,23 @@ def quantize_model(
         batch_size=batch_size,
         num_samples=calib_size,
         device=device,
-        include_labels=auto_quantize_bits is not None,
+        include_labels=is_gradient_based,
     )
 
     if test_generated:
         input_str = tokenizer.decode(next(iter(calib_dataloader))["input_ids"][0])
         generated_str_before_ptq = model.run(input_str)
 
     _quantize_model_with_dataset(
-        model, quant_cfg, calib_dataloader, auto_quantize_bits, batch_size, compress
+        model,
+        quant_cfg,
+        calib_dataloader,
+        auto_quantize_bits,
+        auto_quantize_method,
+        auto_quantize_score_size,
+        batch_size,
+        compress,
+        auto_quantize_checkpoint,
     )
 
     if test_generated:
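As a side note on what the `kl_div` method measures: per the new CLI help text, the score is a KL divergence between the original and quantized models' output distributions, computed from logits alone, which is why no labels are needed. The sketch below only illustrates that metric and is not ModelOpt's internal implementation; the direction of the divergence is an assumption:

```python
import torch
import torch.nn.functional as F

def kl_div_between_logits(original_logits: torch.Tensor, quantized_logits: torch.Tensor) -> torch.Tensor:
    """KL(P_original || P_quantized) averaged over tokens, from raw logits of shape [..., vocab]."""
    # Flatten all leading dimensions so every token position contributes one distribution.
    log_p = F.log_softmax(original_logits, dim=-1).flatten(0, -2)   # original (unquantized) model
    log_q = F.log_softmax(quantized_logits, dim=-1).flatten(0, -2)  # quantized model
    # F.kl_div(input, target) computes KL(target || input); "batchmean" averages over tokens here.
    return F.kl_div(log_q, log_p, log_target=True, reduction="batchmean")
```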
