2121import warnings
2222from collections import defaultdict
2323from collections .abc import Callable , Sequence
24+ from contextlib import nullcontext
2425from typing import Any
2526
2627import regex as re
2728import torch
28- import torch .distributed
2929import torch .nn as nn
3030from tqdm import tqdm
3131
4141from .config import QuantizeConfig , QuantizerAttributeConfig
4242from .conversion import set_quantizer_by_cfg
4343from .nn import QuantLinearConvBase , QuantModule , SequentialQuantizer , TensorQuantizer
44- from .utils import is_quantized_linear , multi_context
44+ from .utils import is_quantized_linear
4545
4646
4747def estimate_quant_compression (quant_cfg : QuantizeConfig ) -> float :
@@ -212,7 +212,11 @@ def __init__(
212212 self .active = self .original
213213
214214 self ._importance_dict = {
215- quant_recipe : dict .fromkeys (self .nn_modules , 0.0 ) for quant_recipe in self .choices
215+ quant_recipe : {
216+ mod : torch .zeros ((), device = mod .weight .device , dtype = torch .float32 )
217+ for mod in self .nn_modules
218+ }
219+ for quant_recipe in self .choices
216220 }
217221
218222 @property
@@ -238,11 +242,15 @@ def active(self, val: HPType | None):
238242 def importance (self ) -> dict :
239243 """Return the importance dict mapping recipe and importance."""
240244 return {
241- quant_recipe : sum (importance_dict .values ())
245+ quant_recipe : sum (v . cpu (). item () for v in importance_dict .values ())
242246 for quant_recipe , importance_dict in self ._importance_dict .items ()
243247 }
244248
245249
250+ def _add_auto_quantize_score (grad_output , output_diff , score_tensor ):
251+ score_tensor += ((grad_output .float () ** 2 ) * (output_diff .float () ** 2 )).sum ()
252+
253+
246254class AutoQuantizeSearcher (BaseSearcher ):
247255 """A searcher for AutoQuantize algorithm.
248256
@@ -261,7 +269,7 @@ class AutoQuantizeSearcher(BaseSearcher):
261269
262270 candidate_stats : dict [str , dict [str , list [float ]]]
263271 best : dict [str , Any ]
264- gradient_checkpointing_enable_contexts : list [tuple [Callable , Callable ]] = []
272+ custom_support : list [tuple [Callable , Callable , Callable ]] = []
265273
266274 rules = [
267275 r"^(.*?)\.(q_proj|k_proj|v_proj)$" , # q_proj, k_proj, v_proj for llama like models
@@ -336,15 +344,19 @@ def _get_search_recipes(quantization_formats):
336344 )
337345
338346 @classmethod
339- def register_gradient_checkpointing_enable_context (
340- cls , is_supported_checker : Callable , context : Callable
347+ def register_custom_support (
348+ cls ,
349+ is_supported_checker : Callable ,
350+ grad_ckpt_context : Callable ,
351+ is_param_grad_enabled : Callable ,
341352 ):
342- """Register a gradient checkpointing enable context for `AutoQuantize` score estimation.
353+ """Register custom support for `AutoQuantize` score estimation.
343354
344- If the `is_supported_checker(model)` returns True, the `context(model)` will be used to enable gradient
345- checkpointing.
355+ If the `is_supported_checker(model)` returns True, the `grad_ckpt_context(model)` will be
356+ used to enable gradient checkpointing and `is_param_grad_enabled(pname, model)`
357+ will be used to enable gradient for the parameter.
346358 """
347- cls .gradient_checkpointing_enable_contexts .append ((is_supported_checker , context ))
359+ cls .custom_support .append ((is_supported_checker , grad_ckpt_context , is_param_grad_enabled ))
348360
349361 def _get_default_forward_backward_step (self ):
350362 def forward_backward_step (model , data ):
@@ -361,7 +373,7 @@ def forward_backward_step(model, data):
361373 return forward_backward_step
362374
363375 @torch .enable_grad ()
364- def _estimate_auto_quantize_scores (self ):
376+ def _estimate_auto_quantize_scores (self , is_param_grad_enabled ):
365377 # TODO: remove the no-quant recipe
366378 def auto_quantize_score_estimate_forward (module , input , * args , ** kwargs ):
367379 module .quant_recipe = QuantRecipe (quant_cfg = None )
@@ -377,7 +389,7 @@ def auto_quantize_score_estimate_forward(module, input, *args, **kwargs):
377389 module .output_diff_dict = {}
378390 with torch .no_grad ():
379391 for recipe in module .get_hparam ("quant_recipe" ).choices :
380- if recipe . compression >= 1.0 :
392+ if recipe == QuantRecipe ( quant_cfg = None ) :
381393 continue
382394 module .quant_recipe = recipe
383395 output_diff = module ._forward_original (input , * args , ** kwargs )
@@ -392,18 +404,21 @@ def auto_quantize_score_estimate_forward(module, input, *args, **kwargs):
392404
393405 def backward_hook (module , grad_input , grad_output ):
394406 for recipe , output_diff in module .output_diff_dict .items ():
395- score = ((grad_output [0 ].float () ** 2 ) * (output_diff .float () ** 2 )).sum ()
396- module .get_hparam ("quant_recipe" )._importance_dict [recipe ][module ] += score .item ()
397- module .output_diff_dict [recipe ] = None
407+ score_tensor = module .get_hparam ("quant_recipe" )._importance_dict [recipe ][module ]
408+ _add_auto_quantize_score (grad_output [0 ], output_diff , score_tensor )
398409
399410 del module .output_diff_dict
400411
401- def setup_params_for_score_estimation (name , param , params_metadata ):
412+ def setup_params_for_score_estimation (name , param , params_metadata , enable_grad = True ):
402413 # Let us delete the gradient as soon as they are computed to save memory
403414 # In addition, this method enables gradient for all parameters
404415 # This is needed to make sure the re-entrant activation checkpointing works
405416 params_metadata [name ] = {"requires_grad" : param .requires_grad }
406- param .requires_grad = True
417+ param .requires_grad = enable_grad
418+ if not enable_grad :
419+ return
420+ if self .config .get ("verbose" , False ):
421+ print_rank_0 (f"AutoQuantize: Enabling gradient for param { name } ." )
407422 accum_grad , handle = create_param_grad_clear_hook (param )
408423 params_metadata [name ]["accum_grad" ] = accum_grad # We need to keep the accum_grad alive
409424 params_metadata [name ]["handle" ] = handle
@@ -421,7 +436,9 @@ def cleanup_module_after_score_estimation(module):
421436
422437 def cleanup_params_after_score_estimation (name , param , params_metadata ):
423438 param .requires_grad = params_metadata [name ]["requires_grad" ]
424- params_metadata [name ]["handle" ].remove ()
439+ handle = params_metadata [name ].get ("handle" , None )
440+ if handle is not None :
441+ handle .remove ()
425442
426443 for name , module in self .model .named_modules ():
427444 if (
@@ -432,10 +449,11 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
432449 setup_module_for_score_estimation (module )
433450
434451 params_metadata = {}
452+
435453 for name , param in self .model .named_parameters ():
436- # TODO: Enabling gradient for all parameters is not needed and making backward slow
437- # We need to enable gradient only for the the first parameter of the module such as embedding weights
438- setup_params_for_score_estimation ( name , param , params_metadata )
454+ setup_params_for_score_estimation (
455+ name , param , params_metadata , is_param_grad_enabled ( name , self . model )
456+ )
439457
440458 gc .collect ()
441459 if torch .cuda .is_available ():
@@ -588,14 +606,20 @@ def forward_loop(model):
588606 ModeloptStateManager (self .model ).state_dict ().pop ()
589607
590608 self .model .eval ()
591- with multi_context (
592- * (
593- context (self .model )
594- for is_supported_checker , context in self .gradient_checkpointing_enable_contexts
595- if is_supported_checker (self .model )
596- )
597- ):
598- self ._estimate_auto_quantize_scores ()
609+
610+ def _default_is_param_grad_enabled (pname , model ):
611+ return True
612+
613+ grad_checkpointing_ctxt = None
614+ is_param_grad_enabled = _default_is_param_grad_enabled
615+ for is_supported_checker , ctxt_candidate , grad_enabled_candidate in self .custom_support :
616+ if is_supported_checker (self .model ):
617+ grad_checkpointing_ctxt = ctxt_candidate
618+ is_param_grad_enabled = grad_enabled_candidate
619+ break
620+
621+ with grad_checkpointing_ctxt (self .model ) if grad_checkpointing_ctxt else nullcontext ():
622+ self ._estimate_auto_quantize_scores (is_param_grad_enabled )
599623
600624 def run_search (self ):
601625 """Search for the best per-layer quantization configuration and return the best model and configuration.
0 commit comments