diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
index a2cd3ac258f..b716d361ab0 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
@@ -1082,10 +1082,10 @@ def forward_qdq(self, input, *args, **kwargs):
         output_cache = self.orig_mod(qinput, *args, **kwargs)
         return output_cache
 
-    # def forward_quant(self, input, *args, **kwargs):
-    #     qinput = self.quant_input(input)
-    #     output_cache = self.orig_mod(qinput, *args, **kwargs)
-    #     return self.dequant_output(output_cache)
+    def forward_quant(self, input, *args, **kwargs):
+        qinput = self.quant_input(input)
+        output_cache = self.orig_mod(qinput, *args, **kwargs)
+        return self.dequant_output(output_cache)
 
     def forward_measure(self, input, *args, **kwargs):
         measure_input((input, ), self._mod_extra_config.inputs)
@@ -1093,22 +1093,8 @@ def forward_measure(self, input, *args, **kwargs):
         measure_output((output_cache, ), self._mod_extra_config.outputs)
         return output_cache
 
-    # def fetch_from_cache(self, cache, blocks, permutations=None):
-    #     # quant_cache = self.quant_input(cache)
-    #     quant_cache = cache
-    #     if permutations:
-    #         output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks, permutations)
-    #         for i in range(len(output_cache)):
-    #             output_cache[i] = self.dequant_output(output_cache[i])
-    #         return output_cache
-    #     output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks)
-    #     return self.dequant_output(output_cache)
-
-    def forward_quant(self, input, *args, **kwargs):
-        qinput = self.quant_input(input)
-        return self.orig_mod(qinput, *args, **kwargs)
-
-    def fetch_from_cache(self, quant_cache, blocks, permutations=None):
+    def fetch_from_cache(self, cache, blocks, permutations=None):
+        quant_cache = self.quant_input(cache)
         if permutations:
             output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks, permutations)
             for i in range(len(output_cache)):
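
For reference, below is a minimal, self-contained sketch of the quantize -> delegate -> dequantize pattern that this change restores for the patched KV-cache module: forward_quant quantizes the input before calling the wrapped module and dequantizes its output, and fetch_from_cache now quantizes the incoming cache instead of passing it through unchanged. The _FakeKVCache stand-in, the PatchedKVCacheSketch class name, and the lambda quant/dequant callables are placeholders invented for illustration; in the real module these are wired from self._mod_extra_config, not hand-built scales.

import torch


class _FakeKVCache(torch.nn.Module):
    """Stand-in for the original (unpatched) KV-cache module (assumed interface)."""

    def forward(self, input, *args, **kwargs):
        # The real module would write `input` into the cache; here we just echo it.
        return input

    def fetch_from_cache(self, cache, blocks, permutations=None):
        if permutations:
            # The real module returns one tensor per permutation; we mimic that shape.
            return [cache[blocks] for _ in permutations]
        return cache[blocks]


class PatchedKVCacheSketch(torch.nn.Module):
    """Illustration only: quantize inputs, delegate to orig_mod, dequantize outputs."""

    def __init__(self, orig_mod, quant_input, dequant_output):
        super().__init__()
        self.orig_mod = orig_mod
        self.quant_input = quant_input
        self.dequant_output = dequant_output

    def forward_quant(self, input, *args, **kwargs):
        qinput = self.quant_input(input)
        output_cache = self.orig_mod(qinput, *args, **kwargs)
        return self.dequant_output(output_cache)

    def fetch_from_cache(self, cache, blocks, permutations=None):
        # Quantize the cache before fetching, dequantize whatever comes back.
        quant_cache = self.quant_input(cache)
        if permutations:
            output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks, permutations)
            return [self.dequant_output(t) for t in output_cache]
        output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks)
        return self.dequant_output(output_cache)


if __name__ == "__main__":
    # Toy per-tensor scale; requires a PyTorch build with FP8 dtypes (>= 2.1).
    scale = 0.1
    patched = PatchedKVCacheSketch(
        _FakeKVCache(),
        quant_input=lambda t: (t / scale).to(torch.float8_e4m3fn),
        dequant_output=lambda t: t.to(torch.bfloat16) * scale,
    )
    cache = torch.randn(4, 8, dtype=torch.bfloat16)
    blocks = torch.tensor([0, 2])
    out = patched.fetch_from_cache(cache, blocks)
    print(out.shape, out.dtype)  # torch.Size([2, 8]) torch.bfloat16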