diff --git a/videox_fun/utils/fp8_optimization.py b/videox_fun/utils/fp8_optimization.py
index 1aa6d26..058b06f 100644
--- a/videox_fun/utils/fp8_optimization.py
+++ b/videox_fun/utils/fp8_optimization.py
@@ -3,54 +3,54 @@
 import torch
 import torch.nn as nn
 
-def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs):
-    weight_dtype = cls.weight.dtype
-    cls.to(origin_dtype)
-
-    # Convert all inputs to the original dtype
-    inputs = [input.to(origin_dtype) for input in inputs]
-    out = cls.original_forward(*inputs, **kwargs)
-
-    cls.to(weight_dtype)
-    return out
 
 def replace_parameters_by_name(module, name_keywords, device):
-    from torch import nn
     for name, param in list(module.named_parameters(recurse=False)):
         if any(keyword in name for keyword in name_keywords):
             if isinstance(param, nn.Parameter):
                 tensor = param.data
                 delattr(module, name)
                 setattr(module, name, tensor.to(device=device))
-    for child_name, child_module in module.named_children():
+    for _, child_module in module.named_children():
         replace_parameters_by_name(child_module, name_keywords, device)
 
+
 def convert_model_weight_to_float8(model, exclude_module_name=['embed_tokens']):
     for name, module in model.named_modules():
-        flag = False
-        for _exclude_module_name in exclude_module_name:
-            if _exclude_module_name in name:
-                flag = True
-        if flag:
+        if any(ex in name for ex in exclude_module_name):
             continue
+
         for param_name, param in module.named_parameters():
-            flag = False
-            for _exclude_module_name in exclude_module_name:
-                if _exclude_module_name in param_name:
-                    flag = True
-            if flag:
+            if any(ex in param_name for ex in exclude_module_name):
                 continue
+
             param.data = param.data.to(torch.float8_e4m3fn)
 
+
 def convert_weight_dtype_wrapper(module, origin_dtype):
-    for name, module in module.named_modules():
+    for name, mod in module.named_modules():
+        # skip root and embedding layers
         if name == "" or "embed_tokens" in name:
             continue
-        original_forward = module.forward
-        if hasattr(module, "weight") and module.weight is not None:
-            setattr(module, "original_forward", original_forward)
-            setattr(
-                module,
-                "forward",
-                lambda *inputs, m=module, **kwargs: autocast_model_forward(m, origin_dtype, *inputs, **kwargs)
-            )
+
+        # avoid wrapping twice
+        if hasattr(mod, "original_forward"):
+            continue
+
+        if hasattr(mod, "weight") and mod.weight is not None:
+            orig_forward = mod.forward
+
+            # unwrap accelerate / decorator wrappers if present
+            while hasattr(orig_forward, "__wrapped__"):
+                orig_forward = orig_forward.__wrapped__
+
+            mod.original_forward = orig_forward
+
+            def new_forward(*inputs, m=mod, **kwargs):
+                casted_inputs = [
+                    inp.to(origin_dtype) if torch.is_tensor(inp) else inp
+                    for inp in inputs
+                ]
+                return m.original_forward(*casted_inputs, **kwargs)
+
+            mod.forward = new_forward
\ No newline at end of file