diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
index 12c463927a3..b69804255fd 100755
--- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py
@@ -591,6 +591,13 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         setattr(self, "w2_weight", None)
         self.forward = self.forward_orig
 
+    def extra_repr(self) -> str:
+        return extra_representation(
+            self.extra_repr_org(),
+            self.class_name_org,
+            get_current_repr(self),
+        )
+
 
 # This patched module is called by the vllm-mixtral FusedMoE layer
 # we wrap each expert weight with this module since FusedMoE has a single tensor for all experts weights
@@ -853,6 +860,13 @@ def update_measure(self, prev, cur, dim, idx, inp_seq_len):
         measure_output((output,), self._mod_extra_config.outputs)
         return output
 
+    def extra_repr(self) -> str:
+        return extra_representation(
+            self.extra_repr_org(),
+            self.class_name_org,
+            get_current_repr(self),
+        )
+
 
 class PatchedVLLMKVCache(PatchedModuleBase):
     # Module to patch VLLMKVCache module from llama model
@@ -891,6 +905,14 @@ def fetch_from_cache(self, cache, blocks, permutations=None):
         output_cache = self.orig_fetch_from_cache(quant_cache, blocks)
         return self.dequant_output(output_cache)
 
+    def extra_repr(self) -> str:
+        return extra_representation(
+            self.extra_repr_org(),
+            self.class_name_org,
+            get_current_repr(self),
+        )
+
+
 def init_conv(instance, mod_extra_config):
     if instance.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
         instance.quant_input = instance._mod_extra_config.inputs[0]
diff --git a/neural_compressor/torch/algorithms/fp8_quant/patched_module_base.py b/neural_compressor/torch/algorithms/fp8_quant/patched_module_base.py
index d0c2c190666..02c219181aa 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/patched_module_base.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/patched_module_base.py
@@ -175,25 +175,26 @@ def forward_quant(self, *args, **kwargs):
 
     @classmethod
     def get_module_info(cls) -> ModuleInfo:
-        """Return the module info for the module.
+        """Only necessary for newly registered patched modules that are not in _mod_default_dict.
 
+        Return the module info for the module, which is used to determine the scaling methods for the module.
         For example, for linear module, the module info is: ModuleInfo(type="linear", patched_module=cls).
         """
         return ModuleInfo(type=cls.get_type(), patched_module=cls)
 
     @classmethod
-    @abstractmethod
     def get_type(cls) -> str:
-        """Return the type of the patched module.
+        """Only necessary for newly registered patched modules that are not in _mod_default_dict.
 
+        Return the type of the patched module, which is used to determine the scaling methods for the module.
         Multiple patched modules can have the same type, and share the same scaling methods.
         """
         raise NotImplementedError("`get_type` is not implemented")
 
     @classmethod
-    @abstractmethod
     def get_module_type(cls) -> ModuleType:
-        """Return the module type for the module.
+        """Only necessary for newly registered patched modules that are not in _mod_default_dict.
 
+        Return the module type for the module.
         The module type is used to determine the number of inputs, outputs, and parameters of the module.
         For example, for linear module, the module type is: ModuleType(1, ["weight"], 1, False).
@@ -201,6 +202,7 @@ def get_module_type(cls) -> ModuleType:
         raise NotImplementedError("`get_module_type` is not implemented")
 
     def extra_repr(self):
+        """This extra_repr is only for newly registered patched modules that are not in _mod_default_dict."""
         return f"quantization_mode={self.quantization_mode}, " + \
             f"module_info={self.get_module_info()}, " + \
             f"module_type={self.get_module_type()}"
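For reference, below is a minimal, hypothetical sketch (not part of this diff) of how a newly registered patched module, i.e. one that is not in _mod_default_dict, might override the classmethods whose docstrings are updated above. The class name PatchedMyLinear is invented, and the sketch assumes ModuleInfo, ModuleType, and PatchedModuleBase are importable from patched_module_base; a real patched module would also implement the forward/measurement logic expected by PatchedModuleBase.

```python
# Hypothetical example only: class name and import path are assumptions.
from neural_compressor.torch.algorithms.fp8_quant.patched_module_base import (
    ModuleInfo,
    ModuleType,
    PatchedModuleBase,
)


class PatchedMyLinear(PatchedModuleBase):
    """A custom patched module that is not in _mod_default_dict."""

    @classmethod
    def get_type(cls) -> str:
        # Patched modules that share this type string share the same scaling methods.
        return "linear"

    @classmethod
    def get_module_type(cls) -> ModuleType:
        # Mirrors the ModuleType(1, ["weight"], 1, False) example in the docstring above.
        return ModuleType(1, ["weight"], 1, False)

    @classmethod
    def get_module_info(cls) -> ModuleInfo:
        # Same pattern as the base-class default: the type string plus the patched class itself.
        return ModuleInfo(type=cls.get_type(), patched_module=cls)
```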