Skip to content

Commit 9ebd69f

Browse files
committed
cherry-picked some relevant changes
1 parent 25c41f7 commit 9ebd69f

File tree

1 file changed

+23
-14
lines changed

1 file changed

+23
-14
lines changed

modelopt/torch/quantization/algorithms.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -228,13 +228,9 @@ def __init__(
228228

229229
self.active = self.original
230230

231-
# Importance dict is keyed by quant_module (where the quantization is applied)
231+
# Importance dict is keyed by score_module (where the score is computed)
232232
self._importance_dict = {
233-
quant_recipe: {
234-
mod: torch.zeros((), device=mod.weight.device, dtype=torch.float32)
235-
for mod in self.quant_modules
236-
}
237-
for quant_recipe in self.choices
233+
quant_recipe: dict.fromkeys(self.score_modules) for quant_recipe in self.choices
238234
}
239235

240236
# Attach this hparam to each score_module's set of hparams it scores
@@ -266,7 +262,7 @@ def active(self, val: HPType | None):
266262
def importance(self) -> dict:
267263
"""Return the importance dict mapping recipe and importance."""
268264
return {
269-
quant_recipe: sum(v.cpu().item() for v in importance_dict.values())
265+
quant_recipe: sum(v.cpu().item() for v in importance_dict.values() if v is not None)
270266
for quant_recipe, importance_dict in self._importance_dict.items()
271267
}
272268

@@ -275,11 +271,6 @@ def attrs(self) -> list[str]:
275271
"""Return the attributes of the hparam for repr."""
276272
return ["name", *super().attrs]
277273

278-
279-
def _add_auto_quantize_score(grad_output, output_diff, score_tensor):
280-
score_tensor += ((grad_output.float() ** 2) * (output_diff.float() ** 2)).sum()
281-
282-
283274
class _AutoQuantizeBaseSearcher(BaseSearcher, ABC):
284275
"""A base searcher for AutoQuantize algorithm."""
285276

@@ -665,6 +656,18 @@ def run_search(self):
665656
QuantRecipe.fold_pqs_to_weights(self.model)
666657

667658

659+
660+
661+
@torch.compile
662+
def _get_auto_quantize_score(grad_output, output_diff):
663+
return ((grad_output.float() ** 2) * (output_diff.float() ** 2)).sum()
664+
665+
666+
@torch.compile
667+
def _add_auto_quantize_score(grad_output, output_diff, score_tensor):
668+
score_tensor += _get_auto_quantize_score(grad_output, output_diff)
669+
670+
668671
class AutoQuantizeGradientSearcher(_AutoQuantizeBaseSearcher):
669672
"""A searcher for AutoQuantize algorithm that uses gradient based score estimation.
670673
@@ -790,8 +793,14 @@ def auto_quantize_score_estimate_forward(module, input, *args, **kwargs):
790793
def backward_hook(module, grad_input, grad_output):
791794
for hparam, output_diff_dict in module.output_diff_dict.items():
792795
for recipe, output_diff in output_diff_dict.items():
793-
score_tensor = hparam._importance_dict[recipe][module]
794-
_add_auto_quantize_score(grad_output[0], output_diff, score_tensor)
796+
if hparam._importance_dict[recipe][module] is None:
797+
hparam._importance_dict[recipe][module] = _get_auto_quantize_score(
798+
grad_output[0], output_diff
799+
)
800+
else:
801+
_add_auto_quantize_score(
802+
grad_output[0], output_diff, hparam._importance_dict[recipe][module]
803+
)
795804

796805
def setup_params_for_score_estimation(name, param, params_metadata, enable_grad=True):
797806
# Let us delete the gradient as soon as they are computed to save memory

0 commit comments

Comments
 (0)