fix NotImplementedError: get_type is not implemented #2133

Merged · 8 commits · Apr 2, 2025
@@ -591,6 +591,13 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
        setattr(self, "w2_weight", None)
        self.forward = self.forward_orig

    def extra_repr(self) -> str:
        return extra_representation(
            self.extra_repr_org(),
            self.class_name_org,
            get_current_repr(self),
        )


# This patched module is called by the vllm-mixtral FusedMoE layer.
# We wrap each expert weight with this module since FusedMoE has a single tensor for all experts' weights.
@@ -853,6 +860,13 @@ def update_measure(self, prev, cur, dim, idx, inp_seq_len):
        measure_output((output,), self._mod_extra_config.outputs)
        return output

    def extra_repr(self) -> str:
        return extra_representation(
            self.extra_repr_org(),
            self.class_name_org,
            get_current_repr(self),
        )


class PatchedVLLMKVCache(PatchedModuleBase):
    # Module to patch VLLMKVCache module from llama model
@@ -891,6 +905,14 @@ def fetch_from_cache(self, cache, blocks, permutations=None):
        output_cache = self.orig_fetch_from_cache(quant_cache, blocks)
        return self.dequant_output(output_cache)

    def extra_repr(self) -> str:
        return extra_representation(
            self.extra_repr_org(),
            self.class_name_org,
            get_current_repr(self),
        )


def init_conv(instance, mod_extra_config):
    if instance.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
        instance.quant_input = instance._mod_extra_config.inputs[0]
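This first file repeats the same `extra_repr` override on each patched module, so a printed quantized model shows the wrapped module's original repr, the original class name, and the current quantization state. Only the call signature is taken from the diff above; the helper body below is an assumed sketch of how `extra_representation` might compose that string, not the repository's actual implementation.

```python
# Illustrative sketch only: the real extra_representation lives in the
# repository's patching utilities; this body is an assumption for clarity.
def extra_representation(org_repr: str, org_class_name: str, curr_repr: str) -> str:
    # Keep the wrapped module's own extra_repr, record which class was patched,
    # and append the quantization-specific details (mode, scales, ...).
    parts = [part for part in (org_repr, f"original={org_class_name}", curr_repr) if part]
    return ", ".join(parts)
```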
@@ -175,32 +175,34 @@ def forward_quant(self, *args, **kwargs):

    @classmethod
    def get_module_info(cls) -> ModuleInfo:
        """Return the module info for the module.
        """Only necessary for a newly registered patched module that isn't in _mod_default_dict.
        Return the module info for the module, which is used to determine the scaling methods for the module.

        For example, for a linear module, the module info is: ModuleInfo(type="linear", patched_module=cls).
        """
        return ModuleInfo(type=cls.get_type(), patched_module=cls)

    @classmethod
    @abstractmethod
    def get_type(cls) -> str:
        """Return the type of the patched module.
        """Only necessary for a newly registered patched module that isn't in _mod_default_dict.
        Return the type of the patched module, which is used to determine the scaling methods for the module.

        Multiple patched modules can have the same type and share the same scaling methods.
        """
        raise NotImplementedError("`get_type` is not implemented")

    @classmethod
    @abstractmethod
    def get_module_type(cls) -> ModuleType:
        """Return the module type for the module.
        """Only necessary for a newly registered patched module that isn't in _mod_default_dict.
        Return the module type for the module, which is used to determine the number of inputs, outputs, and parameters of the module.

        The module type is used to determine the number of inputs, outputs, and parameters of the module.
        For example, for a linear module, the module type is: ModuleType(1, ["weight"], 1, False).
        """
        raise NotImplementedError("`get_module_type` is not implemented")

    def extra_repr(self):
        """This extra_repr is only for a newly registered patched module that isn't in _mod_default_dict."""
        return f"quantization_mode={self.quantization_mode}, " + \
               f"module_info={self.get_module_info()}, " + \
               f"module_type={self.get_module_type()}"