@@ -228,13 +228,9 @@ def __init__(
 
         self.active = self.original
 
-        # Importance dict is keyed by quant_module (where the quantization is applied)
+        # Importance dict is keyed by score_module (where the score is computed)
         self._importance_dict = {
-            quant_recipe: {
-                mod: torch.zeros((), device=mod.weight.device, dtype=torch.float32)
-                for mod in self.quant_modules
-            }
-            for quant_recipe in self.choices
+            quant_recipe: dict.fromkeys(self.score_modules) for quant_recipe in self.choices
         }
 
         # Attach this hparam to each score_module's set of hparams it scores
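
For context on the change above: dict.fromkeys(self.score_modules) seeds every slot with None instead of eagerly allocating a zero tensor per module, so a score tensor is only materialized the first time a gradient actually reaches that module. A minimal sketch of the pattern, using hypothetical module and recipe names rather than the searcher's real objects:

import torch
import torch.nn as nn

# Hypothetical stand-ins for the score modules and quantization recipes.
score_modules = [nn.Linear(4, 4), nn.Linear(4, 4)]
choices = ["int8_recipe", "nf4_recipe"]

# One None-valued slot per (recipe, module); no device memory is touched yet.
importance_dict = {recipe: dict.fromkeys(score_modules) for recipe in choices}

# On the first score for a module, the tensor replaces None; afterwards it accumulates.
mod = score_modules[0]
new_score = torch.tensor(0.123)
if importance_dict["int8_recipe"][mod] is None:
    importance_dict["int8_recipe"][mod] = new_score
else:
    importance_dict["int8_recipe"][mod] += new_score
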
@@ -266,7 +262,7 @@ def active(self, val: HPType | None):
     def importance(self) -> dict:
         """Return the importance dict mapping recipe and importance."""
         return {
-            quant_recipe: sum(v.cpu().item() for v in importance_dict.values())
+            quant_recipe: sum(v.cpu().item() for v in importance_dict.values() if v is not None)
             for quant_recipe, importance_dict in self._importance_dict.items()
         }
 
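
Continuing the sketch above, the if v is not None guard matters because modules that never accumulated a score still map to None and would otherwise break the reduction:

# Modules that never received a gradient stay None and are simply skipped.
importance = {
    recipe: sum(v.cpu().item() for v in per_module.values() if v is not None)
    for recipe, per_module in importance_dict.items()
}
print(importance)  # e.g. {'int8_recipe': ~0.123, 'nf4_recipe': 0}
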
@@ -275,11 +271,6 @@ def attrs(self) -> list[str]:
275271 """Return the attributes of the hparam for repr."""
276272 return ["name" , * super ().attrs ]
277273
278-
279- def _add_auto_quantize_score (grad_output , output_diff , score_tensor ):
280- score_tensor += ((grad_output .float () ** 2 ) * (output_diff .float () ** 2 )).sum ()
281-
282-
283274class _AutoQuantizeBaseSearcher (BaseSearcher , ABC ):
284275 """A base searcher for AutoQuantize algorithm."""
285276
@@ -665,6 +656,18 @@ def run_search(self):
         QuantRecipe.fold_pqs_to_weights(self.model)
 
 
+
+
+@torch.compile
+def _get_auto_quantize_score(grad_output, output_diff):
+    return ((grad_output.float() ** 2) * (output_diff.float() ** 2)).sum()
+
+
+@torch.compile
+def _add_auto_quantize_score(grad_output, output_diff, score_tensor):
+    score_tensor += _get_auto_quantize_score(grad_output, output_diff)
+
+
 class AutoQuantizeGradientSearcher(_AutoQuantizeBaseSearcher):
     """A searcher for AutoQuantize algorithm that uses gradient based score estimation.
 
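
The two new helpers compute the AutoQuantize sensitivity score: the sum over elements of the squared gradient of the loss w.r.t. the module output times the squared quantized-minus-original output difference. _get_auto_quantize_score returns a fresh scalar, _add_auto_quantize_score accumulates into an existing one, and @torch.compile can fuse the square/multiply/sum into a single kernel. A rough standalone sketch with illustrative shapes (the decorator is omitted here to keep the snippet dependency-light):

import torch

def get_score(grad_output: torch.Tensor, output_diff: torch.Tensor) -> torch.Tensor:
    # sum of (dL/dy)^2 * (y_quant - y_orig)^2 over all elements
    return ((grad_output.float() ** 2) * (output_diff.float() ** 2)).sum()

grad_output = torch.randn(8, 16)          # stand-in for the gradient w.r.t. a module output
output_diff = torch.randn(8, 16) * 1e-3   # stand-in for the quantization-induced output change

score = get_score(grad_output, output_diff)   # first batch: create the scalar score
score += get_score(grad_output, output_diff)  # later batches: accumulate in place
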
@@ -790,8 +793,14 @@ def auto_quantize_score_estimate_forward(module, input, *args, **kwargs):
         def backward_hook(module, grad_input, grad_output):
             for hparam, output_diff_dict in module.output_diff_dict.items():
                 for recipe, output_diff in output_diff_dict.items():
-                    score_tensor = hparam._importance_dict[recipe][module]
-                    _add_auto_quantize_score(grad_output[0], output_diff, score_tensor)
+                    if hparam._importance_dict[recipe][module] is None:
+                        hparam._importance_dict[recipe][module] = _get_auto_quantize_score(
+                            grad_output[0], output_diff
+                        )
+                    else:
+                        _add_auto_quantize_score(
+                            grad_output[0], output_diff, hparam._importance_dict[recipe][module]
+                        )
 
         def setup_params_for_score_estimation(name, param, params_metadata, enable_grad=True):
             # Let us delete the gradient as soon as they are computed to save memory
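
The reworked backward_hook now lazily creates the per-module score on the first backward pass and accumulates on subsequent ones, matching the None-seeded importance dict. A self-contained sketch of that create-or-accumulate pattern, using a hypothetical single module and a fixed output_diff (the real hook is keyed by hparam and recipe and gets output_diff from the forward pass):

import torch
import torch.nn as nn

def get_score(grad_output, output_diff):
    return ((grad_output.float() ** 2) * (output_diff.float() ** 2)).sum()

layer = nn.Linear(4, 4)
importance = dict.fromkeys([layer])          # {layer: None}, nothing allocated yet
output_diff = torch.randn(2, 4) * 1e-3       # stand-in for quantized-vs-original output delta

def backward_hook(module, grad_input, grad_output):
    if importance[module] is None:
        importance[module] = get_score(grad_output[0], output_diff)   # first backward: create
    else:
        importance[module] += get_score(grad_output[0], output_diff)  # later passes: accumulate

layer.register_full_backward_hook(backward_hook)
for _ in range(2):
    layer(torch.randn(2, 4)).sum().backward()
print(importance[layer])                     # accumulated scalar score for this module
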