Skip to content

Commit a703e22

Browse files
authored
AutoQuantize minor improvement: limit grad enabled parameters, limit cpu-gpu sync during scoring (#551)
## What does this PR do? **Type of change:** ? AutoQuantize minor improvement: limit grad-enabled parameters, limit CPU-GPU sync during scoring **Overview:** ? Minor improvements for AutoQuantize. Added support for enabling grad only for selected parameters - this should roughly halve the number of GEMMs in the backward pass, since weight gradients are no longer computed for the remaining parameters. Scores are now accumulated in on-device tensors instead of calling `.item()` inside the backward hook, which limits CPU-GPU synchronization during scoring. ## Testing Covered by unit tests (both CPU and GPU) ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open a `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes - **Did you write any new necessary tests?**: Yes - **Did you add or update any necessary documentation?**: NA - **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: NA ## Additional Information <!-- E.g. related issue. --> Signed-off-by: realAsma <[email protected]>
1 parent c033276 commit a703e22

File tree

2 files changed

+64
-32
lines changed

2 files changed

+64
-32
lines changed

modelopt/torch/quantization/algorithms.py

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121
import warnings
2222
from collections import defaultdict
2323
from collections.abc import Callable, Sequence
24+
from contextlib import nullcontext
2425
from typing import Any
2526

2627
import regex as re
2728
import torch
28-
import torch.distributed
2929
import torch.nn as nn
3030
from tqdm import tqdm
3131

@@ -41,7 +41,7 @@
4141
from .config import QuantizeConfig, QuantizerAttributeConfig
4242
from .conversion import set_quantizer_by_cfg
4343
from .nn import QuantLinearConvBase, QuantModule, SequentialQuantizer, TensorQuantizer
44-
from .utils import is_quantized_linear, multi_context
44+
from .utils import is_quantized_linear
4545

4646

4747
def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float:
@@ -212,7 +212,11 @@ def __init__(
212212
self.active = self.original
213213

214214
self._importance_dict = {
215-
quant_recipe: dict.fromkeys(self.nn_modules, 0.0) for quant_recipe in self.choices
215+
quant_recipe: {
216+
mod: torch.zeros((), device=mod.weight.device, dtype=torch.float32)
217+
for mod in self.nn_modules
218+
}
219+
for quant_recipe in self.choices
216220
}
217221

218222
@property
@@ -238,11 +242,15 @@ def active(self, val: HPType | None):
238242
def importance(self) -> dict:
239243
"""Return the importance dict mapping recipe and importance."""
240244
return {
241-
quant_recipe: sum(importance_dict.values())
245+
quant_recipe: sum(v.cpu().item() for v in importance_dict.values())
242246
for quant_recipe, importance_dict in self._importance_dict.items()
243247
}
244248

245249

250+
def _add_auto_quantize_score(grad_output, output_diff, score_tensor):
251+
score_tensor += ((grad_output.float() ** 2) * (output_diff.float() ** 2)).sum()
252+
253+
246254
class AutoQuantizeSearcher(BaseSearcher):
247255
"""A searcher for AutoQuantize algorithm.
248256
@@ -261,7 +269,7 @@ class AutoQuantizeSearcher(BaseSearcher):
261269

262270
candidate_stats: dict[str, dict[str, list[float]]]
263271
best: dict[str, Any]
264-
gradient_checkpointing_enable_contexts: list[tuple[Callable, Callable]] = []
272+
custom_support: list[tuple[Callable, Callable, Callable]] = []
265273

266274
rules = [
267275
r"^(.*?)\.(q_proj|k_proj|v_proj)$", # q_proj, k_proj, v_proj for llama like models
@@ -336,15 +344,19 @@ def _get_search_recipes(quantization_formats):
336344
)
337345

338346
@classmethod
339-
def register_gradient_checkpointing_enable_context(
340-
cls, is_supported_checker: Callable, context: Callable
347+
def register_custom_support(
348+
cls,
349+
is_supported_checker: Callable,
350+
grad_ckpt_context: Callable,
351+
is_param_grad_enabled: Callable,
341352
):
342-
"""Register a gradient checkpointing enable context for `AutoQuantize` score estimation.
353+
"""Register custom support for `AutoQuantize` score estimation.
343354
344-
If the `is_supported_checker(model)` returns True, the `context(model)` will be used to enable gradient
345-
checkpointing.
355+
If the `is_supported_checker(model)` returns True, the `grad_ckpt_context(model)` will be
356+
used to enable gradient checkpointing and `is_param_grad_enabled(pname, model)`
357+
will be used to decide whether gradient is enabled for each parameter.
346358
"""
347-
cls.gradient_checkpointing_enable_contexts.append((is_supported_checker, context))
359+
cls.custom_support.append((is_supported_checker, grad_ckpt_context, is_param_grad_enabled))
348360

349361
def _get_default_forward_backward_step(self):
350362
def forward_backward_step(model, data):
@@ -361,7 +373,7 @@ def forward_backward_step(model, data):
361373
return forward_backward_step
362374

363375
@torch.enable_grad()
364-
def _estimate_auto_quantize_scores(self):
376+
def _estimate_auto_quantize_scores(self, is_param_grad_enabled):
365377
# TODO: remove the no-quant recipe
366378
def auto_quantize_score_estimate_forward(module, input, *args, **kwargs):
367379
module.quant_recipe = QuantRecipe(quant_cfg=None)
@@ -377,7 +389,7 @@ def auto_quantize_score_estimate_forward(module, input, *args, **kwargs):
377389
module.output_diff_dict = {}
378390
with torch.no_grad():
379391
for recipe in module.get_hparam("quant_recipe").choices:
380-
if recipe.compression >= 1.0:
392+
if recipe == QuantRecipe(quant_cfg=None):
381393
continue
382394
module.quant_recipe = recipe
383395
output_diff = module._forward_original(input, *args, **kwargs)
@@ -392,18 +404,21 @@ def auto_quantize_score_estimate_forward(module, input, *args, **kwargs):
392404

393405
def backward_hook(module, grad_input, grad_output):
394406
for recipe, output_diff in module.output_diff_dict.items():
395-
score = ((grad_output[0].float() ** 2) * (output_diff.float() ** 2)).sum()
396-
module.get_hparam("quant_recipe")._importance_dict[recipe][module] += score.item()
397-
module.output_diff_dict[recipe] = None
407+
score_tensor = module.get_hparam("quant_recipe")._importance_dict[recipe][module]
408+
_add_auto_quantize_score(grad_output[0], output_diff, score_tensor)
398409

399410
del module.output_diff_dict
400411

401-
def setup_params_for_score_estimation(name, param, params_metadata):
412+
def setup_params_for_score_estimation(name, param, params_metadata, enable_grad=True):
402413
# Let us delete the gradients as soon as they are computed to save memory
403414
# In addition, this method enables gradient for the parameter when `enable_grad` is True
404415
# This is needed to make sure the re-entrant activation checkpointing works
405416
params_metadata[name] = {"requires_grad": param.requires_grad}
406-
param.requires_grad = True
417+
param.requires_grad = enable_grad
418+
if not enable_grad:
419+
return
420+
if self.config.get("verbose", False):
421+
print_rank_0(f"AutoQuantize: Enabling gradient for param {name}.")
407422
accum_grad, handle = create_param_grad_clear_hook(param)
408423
params_metadata[name]["accum_grad"] = accum_grad # We need to keep the accum_grad alive
409424
params_metadata[name]["handle"] = handle
@@ -421,7 +436,9 @@ def cleanup_module_after_score_estimation(module):
421436

422437
def cleanup_params_after_score_estimation(name, param, params_metadata):
423438
param.requires_grad = params_metadata[name]["requires_grad"]
424-
params_metadata[name]["handle"].remove()
439+
handle = params_metadata[name].get("handle", None)
440+
if handle is not None:
441+
handle.remove()
425442

426443
for name, module in self.model.named_modules():
427444
if (
@@ -432,10 +449,11 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
432449
setup_module_for_score_estimation(module)
433450

434451
params_metadata = {}
452+
435453
for name, param in self.model.named_parameters():
436-
# TODO: Enabling gradient for all parameters is not needed and making backward slow
437-
# We need to enable gradient only for the the first parameter of the module such as embedding weights
438-
setup_params_for_score_estimation(name, param, params_metadata)
454+
setup_params_for_score_estimation(
455+
name, param, params_metadata, is_param_grad_enabled(name, self.model)
456+
)
439457

440458
gc.collect()
441459
if torch.cuda.is_available():
@@ -588,14 +606,20 @@ def forward_loop(model):
588606
ModeloptStateManager(self.model).state_dict().pop()
589607

590608
self.model.eval()
591-
with multi_context(
592-
*(
593-
context(self.model)
594-
for is_supported_checker, context in self.gradient_checkpointing_enable_contexts
595-
if is_supported_checker(self.model)
596-
)
597-
):
598-
self._estimate_auto_quantize_scores()
609+
610+
def _default_is_param_grad_enabled(pname, model):
611+
return True
612+
613+
grad_checkpointing_ctxt = None
614+
is_param_grad_enabled = _default_is_param_grad_enabled
615+
for is_supported_checker, ctxt_candidate, grad_enabled_candidate in self.custom_support:
616+
if is_supported_checker(self.model):
617+
grad_checkpointing_ctxt = ctxt_candidate
618+
is_param_grad_enabled = grad_enabled_candidate
619+
break
620+
621+
with grad_checkpointing_ctxt(self.model) if grad_checkpointing_ctxt else nullcontext():
622+
self._estimate_auto_quantize_scores(is_param_grad_enabled)
599623

600624
def run_search(self):
601625
"""Search for the best per-layer quantization configuration and return the best model and configuration.

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -739,8 +739,16 @@ def setup_model_for_gradient_checkpointing(model: nn.Module):
739739
model.config.use_cache = use_cache
740740

741741

742-
AutoQuantizeSearcher.register_gradient_checkpointing_enable_context(
743-
_is_supported_hf_model, setup_model_for_gradient_checkpointing
742+
def _is_param_grad_enabled_for_auto_quantize(pname, model):
743+
# Enable grad for embedding layers to propagate gradients through the model,
744+
# allowing each layer to compute its input gradients during the backward pass.
745+
return "embed" in pname
746+
747+
748+
AutoQuantizeSearcher.register_custom_support(
749+
_is_supported_hf_model,
750+
setup_model_for_gradient_checkpointing,
751+
_is_param_grad_enabled_for_auto_quantize,
744752
)
745753

746754
CUSTOM_MODEL_PLUGINS.update(

0 commit comments

Comments
 (0)