Skip to content

Commit 0aada4e

Browse files
Asma Kuriparambil ThekkumpaterealAsma
authored andcommitted
[3/N] Added autoquantize search state save/restore support
Some improvements for KLDiv Signed-off-by: realAsma <[email protected]> changelog update Signed-off-by: realAsma <[email protected]> minor Signed-off-by: realAsma <[email protected]> doc updates Signed-off-by: realAsma <[email protected]>
1 parent 1b52477 commit 0aada4e

File tree

11 files changed

+292
-52
lines changed

11 files changed

+292
-52
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Model Optimizer Changelog (Linux)
1515
- Add ``specdec_bench`` example to benchmark speculative decoding performance. See `examples/specdec_bench/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/specdec_bench#speculative-decoding-benchmark>`_ for more details.
1616
- Add FP8/NVFP4 KV cache quantization support for Megatron Core models.
1717
- Add KL Divergence loss based auto_quantize method. See `auto_quantize API docs <https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.quantization.model_quant.html#modelopt.torch.quantization.model_quant.auto_quantize>`_ for more details.
18+
- Add support for saving and resuming auto_quantize search state. This speeds up the auto_quantize process by skipping the score estimation step if the search state is provided.
1819
- Add flag ``trt_plugins_precision`` in ONNX autocast to indicate custom ops precision. This is similar to the flag already existing in the quantization workflow.
1920
- Add support for PyTorch Geometric quantization.
2021
- Add per tensor and per channel MSE calibrator support.

examples/llm_eval/gen_model_answer.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,11 @@ def get_model_answers(
201201
tokenizer,
202202
args.calib_batch_size,
203203
args.calib_size,
204-
args.auto_quantize_bits,
205204
test_generated=False,
205+
auto_quantize_bits=args.auto_quantize_bits,
206+
auto_quantize_method=args.auto_quantize_method,
207+
auto_quantize_score_size=args.auto_quantize_score_size,
208+
auto_quantize_checkpoint=args.auto_quantize_checkpoint,
206209
)
207210

208211
for question in tqdm(questions):
@@ -450,6 +453,36 @@ def reorg_answer_file(answer_file):
450453
"regular quantization without auto_quantize search will be applied."
451454
),
452455
)
456+
parser.add_argument(
457+
"--auto_quantize_method",
458+
type=str,
459+
default="gradient",
460+
choices=["gradient", "kl_div"],
461+
help=(
462+
"Method for auto_quantize sensitivity analysis. 'gradient' uses gradient-based method "
463+
"(requires labels in dataset). 'kl_div' uses KL divergence between original and "
464+
"quantized model outputs (no labels required). Default: 'gradient'"
465+
),
466+
)
467+
parser.add_argument(
468+
"--auto_quantize_score_size",
469+
type=int,
470+
default=128,
471+
help=(
472+
"Number of samples to use for auto_quantize scoring. Most of auto_quantize time is spent on "
473+
"sensitivity score estimation, so reducing this speeds it up while only minimally affecting "
474+
"final model accuracy compared to lowering --calib_size (the number of samples used for calibration)."
475+
),
476+
)
477+
parser.add_argument(
478+
"--auto_quantize_checkpoint",
479+
type=str,
480+
default=None,
481+
help=(
482+
"Path to checkpoint file for saving/restoring auto_quantize search state "
483+
"(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified."
484+
),
485+
)
453486
parser.add_argument(
454487
"--trust_remote_code",
455488
help="Set trust_remote_code for Huggingface models and tokenizers",

examples/llm_eval/lm_eval_hf.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
5454
quant_cfg = arg_dict.pop("quant_cfg", None)
5555
auto_quantize_bits = arg_dict.pop("auto_quantize_bits", None)
5656
auto_quantize_method = arg_dict.pop("auto_quantize_method", "gradient")
57+
auto_quantize_score_size = arg_dict.pop("auto_quantize_score_size", 128)
58+
auto_quantize_checkpoint = arg_dict.pop("auto_quantize_checkpoint", None)
5759
calib_batch_size = arg_dict.pop("calib_batch_size", None)
5860
calib_size = arg_dict.pop("calib_size", 512)
5961
compress = arg_dict.pop("compress", False)
@@ -83,8 +85,10 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
8385
calib_size=calib_size,
8486
auto_quantize_bits=auto_quantize_bits,
8587
auto_quantize_method=auto_quantize_method,
88+
auto_quantize_score_size=auto_quantize_score_size,
8689
test_generated=False,
8790
compress=compress,
91+
auto_quantize_checkpoint=auto_quantize_checkpoint,
8892
)
8993

9094
return model_obj
@@ -103,6 +107,12 @@ def setup_parser_with_modelopt_args():
103107
"comma-separated list of quantization quantization formats that will be searched by `auto_quantize`"
104108
),
105109
)
110+
parser.add_argument(
111+
"--calib_batch_size", type=int, help="Batch size for quantization calibration"
112+
)
113+
parser.add_argument(
114+
"--calib_size", type=int, help="Calibration size for quantization", default=512
115+
)
106116
parser.add_argument(
107117
"--auto_quantize_bits",
108118
type=float,
@@ -123,10 +133,19 @@ def setup_parser_with_modelopt_args():
123133
),
124134
)
125135
parser.add_argument(
126-
"--calib_batch_size", type=int, help="Batch size for quantization calibration"
136+
"--auto_quantize_score_size",
137+
type=int,
138+
default=128,
139+
help=(
140+
"Number of samples to use for auto_quantize scoring. Most of auto_quantize time is spent on "
141+
"sensitivity score estimation, so reducing this speeds it up while only minimally affecting "
142+
"final model accuracy compared to lowering --calib_size (the number of samples used for calibration)."
143+
),
127144
)
128145
parser.add_argument(
129-
"--calib_size", type=int, help="Calibration size for quantization", default=512
146+
"--auto_quantize_checkpoint",
147+
type=str,
148+
help=("Path to checkpoint file for saving/restoring auto_quantize search state. "),
130149
)
131150
parser.add_argument(
132151
"--compress",
@@ -153,6 +172,8 @@ def setup_parser_with_modelopt_args():
153172
"quant_cfg": args.quant_cfg,
154173
"auto_quantize_bits": args.auto_quantize_bits,
155174
"auto_quantize_method": args.auto_quantize_method,
175+
"auto_quantize_score_size": args.auto_quantize_score_size,
176+
"auto_quantize_checkpoint": args.auto_quantize_checkpoint,
156177
"calib_batch_size": args.calib_batch_size,
157178
"calib_size": args.calib_size,
158179
"compress": args.compress,

examples/llm_eval/mmlu.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,12 @@ def main(
224224
ntrain: int = 5,
225225
quant_cfg: str | None = None,
226226
auto_quantize_bits: float | None = None,
227-
auto_quantize_method: str = "gradient",
228227
batch_size: int = 0,
229228
calib_size: int = 512,
230229
dtype: str = "bfloat16",
230+
auto_quantize_method: str = "gradient",
231+
auto_quantize_score_size: int = 128,
232+
auto_quantize_checkpoint: str | None = None,
231233
**kwargs,
232234
):
233235
random.seed(RAND_SEED)
@@ -283,6 +285,8 @@ def main(
283285
calib_size=calib_size,
284286
auto_quantize_bits=auto_quantize_bits,
285287
auto_quantize_method=auto_quantize_method,
288+
auto_quantize_score_size=auto_quantize_score_size,
289+
auto_quantize_checkpoint=auto_quantize_checkpoint,
286290
)
287291

288292
for subject in tqdm(subjects):

examples/llm_eval/quantization_utils.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,10 @@ def _quantize_model_with_dataset(
6767
calib_dataset,
6868
auto_quantize_bits=None,
6969
auto_quantize_method="gradient",
70+
auto_quantize_score_size=128,
7071
batch_size=1,
7172
compress=False,
73+
auto_quantize_checkpoint=None,
7274
):
7375
if hasattr(lm, "gpt2"):
7476
net = lm.gpt2
@@ -112,11 +114,12 @@ def forward_step(model, batch):
112114
forward_step=forward_step,
113115
loss_func=loss_func,
114116
num_calib_steps=len(calib_dataset),
115-
num_score_steps=min(
116-
len(calib_dataset), 128 // batch_size
117-
), # Limit the number of score steps to avoid long calibration time
117+
# Most time is spent on score estimation; fewer samples speed it up with little accuracy impact.
118+
num_score_steps=min(len(calib_dataset), max(auto_quantize_score_size // batch_size, 1)),
118119
verbose=True,
119120
method=auto_quantize_method,
121+
# disabled_layers=["*lm_head*", "*mlp.gate.*"],
122+
checkpoint=auto_quantize_checkpoint,
120123
)
121124
else:
122125
mtq_cfg = CUSTOM_CONFIG.get(quant_cfg) # type: ignore [arg-type]
@@ -160,11 +163,13 @@ def quantize_model(
160163
tokenizer,
161164
batch_size,
162165
calib_size,
163-
auto_quantize_bits=None,
164-
auto_quantize_method="gradient",
165166
data="cnn_dailymail",
166167
test_generated=True,
167168
compress=False,
169+
auto_quantize_bits=None,
170+
auto_quantize_method="gradient",
171+
auto_quantize_score_size=128,
172+
auto_quantize_checkpoint=None,
168173
):
169174
"""Quantizes the model with the provided calibration dataset.
170175
@@ -175,11 +180,14 @@ def quantize_model(
175180
tokenizer: the tokenizer.
176181
batch_size: the calibration batch size for each calibration inference run.
177182
calib_size: the total calibration dataset size.
178-
auto_quantize_bits: The effective bits constraint for auto_quantize.
179-
auto_quantize_method: The method for auto_quantize ('gradient' or 'kl_div').
180183
data: the name of the calibration dataset.
181184
test_generated: If ``True``, test the generated text before and after quantization.
182185
compress: If ``True``, compress the model after quantization.
186+
auto_quantize_bits: The effective bits constraint for auto_quantize.
187+
auto_quantize_method: The method for auto_quantize ('gradient' or 'kl_div').
188+
auto_quantize_score_size: Number of samples used for auto_quantize scoring.
189+
auto_quantize_checkpoint: Path to checkpoint file for saving/restoring auto_quantize search state
190+
(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified.
183191
"""
184192
if "AWQ" in quant_cfg:
185193
print(
@@ -191,8 +199,10 @@ def quantize_model(
191199
if hasattr(model, "model"):
192200
device = model.model.device
193201

202+
is_gradient_based = auto_quantize_bits is not None and auto_quantize_method == "gradient"
203+
194204
if batch_size == 0:
195-
if auto_quantize_bits is not None or torch.distributed.is_initialized():
205+
if is_gradient_based or torch.distributed.is_initialized():
196206
raise ValueError("We dont support automatic batch size inference for this case.")
197207

198208
net = model.gpt2 if hasattr(model, "gpt2") else model.model
@@ -201,16 +211,13 @@ def quantize_model(
201211
batch_size = get_max_batch_size(net)
202212
print(f"Update calib batch {batch_size}")
203213

204-
# Labels are only needed for gradient-based auto_quantize
205-
include_labels = auto_quantize_bits is not None and auto_quantize_method == "gradient"
206-
207214
calib_dataloader = get_dataset_dataloader(
208215
dataset_name=data,
209216
tokenizer=tokenizer,
210217
batch_size=batch_size,
211218
num_samples=calib_size,
212219
device=device,
213-
include_labels=include_labels,
220+
include_labels=is_gradient_based,
214221
)
215222

216223
if test_generated:
@@ -223,8 +230,10 @@ def quantize_model(
223230
calib_dataloader,
224231
auto_quantize_bits,
225232
auto_quantize_method,
233+
auto_quantize_score_size,
226234
batch_size,
227235
compress,
236+
auto_quantize_checkpoint,
228237
)
229238

230239
if test_generated:

examples/llm_ptq/hf_ptq.py

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,15 @@
9595

9696

9797
def auto_quantize(
98-
model, qformat, auto_quantize_bits, calib_dataloader, calibrate_loop, batch_size=1
98+
model,
99+
qformat,
100+
calib_dataloader,
101+
calibrate_loop,
102+
auto_quantize_bits,
103+
batch_size=1,
104+
auto_quantize_method="gradient",
105+
auto_quantize_score_size=128,
106+
auto_quantize_checkpoint=None,
99107
):
100108
qformat_list = qformat.split(",")
101109
assert qformat_list, "No quantization formats provided"
@@ -122,18 +130,34 @@ def loss_func(output, data):
122130
# which contains the loss attribute.
123131
return output.loss
124132

133+
if auto_quantize_method == "gradient":
134+
# For gradient-based method, return full output with loss
135+
def forward_step(model, batch):
136+
return model(**batch)
137+
elif auto_quantize_method == "kl_div":
138+
# For KL divergence method, return only logits
139+
def forward_step(model, batch):
140+
return model(**batch).logits
141+
else:
142+
raise ValueError(
143+
f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
144+
)
145+
125146
model, _ = mtq.auto_quantize(
126147
model,
127148
constraints={"effective_bits": auto_quantize_bits},
128149
data_loader=calib_dataloader,
129-
forward_step=lambda model, batch: model(**batch),
130-
loss_func=loss_func,
150+
forward_step=forward_step,
151+
loss_func=loss_func, # Only used for gradient-based method
131152
# TRTLLM only support one quantization format or None (do not quantize, internally supported)
132153
quantization_formats=[QUANT_CFG_CHOICES[format] for format in qformat_list],
133154
num_calib_steps=len(calib_dataloader),
134-
num_score_steps=len(calib_dataloader),
155+
# AutoQuantize scoring is the costly phase; allow smaller sample counts than calibration.
156+
num_score_steps=min(len(calib_dataloader), max(auto_quantize_score_size // batch_size, 1)),
135157
verbose=True,
136158
disabled_layers=["*lm_head*"],
159+
method=auto_quantize_method,
160+
checkpoint=auto_quantize_checkpoint,
137161
)
138162

139163
# We need to explicitly calibrate for kv cache quantization
@@ -191,10 +215,13 @@ def quantize_model(model, quant_cfg, args, calib_dataloader=None, calibration_on
191215
model = auto_quantize(
192216
model,
193217
args.qformat,
194-
args.auto_quantize_bits,
195218
calib_dataloader,
196219
calibrate_loop,
220+
args.auto_quantize_bits,
197221
args.batch_size,
222+
args.auto_quantize_method,
223+
args.auto_quantize_score_size,
224+
args.auto_quantize_checkpoint,
198225
)
199226
elif calibration_only:
200227
model = mtq.calibrate(model, quant_cfg["algorithm"], forward_loop=calibrate_loop)
@@ -444,13 +471,17 @@ def main(args):
444471
assert tokenizer is not None and isinstance(
445472
tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)
446473
), "The PreTrainedTokenizer must be set"
474+
# Labels are only needed for gradient-based auto_quantize
475+
include_labels = (
476+
args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient"
477+
)
447478
calib_dataloader = get_dataset_dataloader(
448479
dataset_name=args.dataset,
449480
tokenizer=tokenizer,
450481
batch_size=args.batch_size,
451482
num_samples=args.calib_size,
452483
device=device,
453-
include_labels=args.auto_quantize_bits is not None,
484+
include_labels=include_labels,
454485
)
455486

456487
quant_cfg = build_quant_cfg(
@@ -803,6 +834,36 @@ def output_decode(generated_ids, input_shape):
803834
default=None,
804835
type=str,
805836
)
837+
parser.add_argument(
838+
"--auto_quantize_method",
839+
type=str,
840+
default="gradient",
841+
choices=["gradient", "kl_div"],
842+
help=(
843+
"Method for auto_quantize sensitivity analysis. 'gradient' uses gradient-based method "
844+
"(requires labels in dataset). 'kl_div' uses KL divergence between original and "
845+
"quantized model outputs (no labels required). Default: 'gradient'"
846+
),
847+
)
848+
parser.add_argument(
849+
"--auto_quantize_score_size",
850+
type=int,
851+
default=128,
852+
help=(
853+
"Number of samples to use for auto_quantize scoring. Most of auto_quantize time is spent on "
854+
"sensitivity score estimation, so reducing this speeds it up while only minimally affecting "
855+
"final model accuracy compared to lowering --calib_size (the number of samples used for calibration)."
856+
),
857+
)
858+
parser.add_argument(
859+
"--auto_quantize_checkpoint",
860+
type=str,
861+
default=None,
862+
help=(
863+
"Path to checkpoint file for saving/restoring auto_quantize search state "
864+
"(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified."
865+
),
866+
)
806867

807868
args = parser.parse_args()
808869

0 commit comments

Comments
 (0)