Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 31 additions & 31 deletions videox_fun/utils/fp8_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,54 +3,54 @@
import torch
import torch.nn as nn

def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs):
    """Run a module's original forward with weights temporarily cast to ``origin_dtype``.

    The module's weights are stored in a low-precision dtype (the caller wraps
    fp8-converted modules with this); they are upcast for the duration of the
    call and restored afterwards so the storage savings persist between calls.

    Args:
        cls: module to run; must expose ``weight``, ``to`` and ``original_forward``.
        origin_dtype: dtype to compute in; tensor positional args are cast to it.
        *inputs: positional arguments forwarded to ``original_forward``.
        **kwargs: keyword arguments forwarded unchanged.

    Returns:
        Whatever ``cls.original_forward`` returns.
    """
    weight_dtype = cls.weight.dtype
    cls.to(origin_dtype)
    try:
        # Cast only tensor inputs; non-tensor arguments (ints, flags, None)
        # pass through untouched. Calling ``.to`` unconditionally crashed on
        # non-tensor positional arguments.
        casted = [
            inp.to(origin_dtype) if torch.is_tensor(inp) else inp
            for inp in inputs
        ]
        out = cls.original_forward(*casted, **kwargs)
    finally:
        # Restore the low-precision storage dtype even if forward raises.
        cls.to(weight_dtype)
    return out

def replace_parameters_by_name(module, name_keywords, device):
    """Recursively replace ``nn.Parameter`` attributes whose name contains any
    of ``name_keywords`` with plain tensors moved to ``device``.

    Detaching a parameter this way removes it from ``module.parameters()``
    (optimizers and ``.to()`` no longer track it) while the attribute stays
    readable by the forward pass.

    Args:
        module: root module, modified in place.
        name_keywords: substrings; a parameter is replaced when any keyword
            occurs in its (non-recursive) attribute name.
        device: target device for the replacement tensor.
    """
    for name, param in list(module.named_parameters(recurse=False)):
        if any(keyword in name for keyword in name_keywords):
            if isinstance(param, nn.Parameter):
                tensor = param.data
                # delattr removes the entry from module._parameters;
                # setattr then stores a plain (untracked) tensor attribute.
                delattr(module, name)
                setattr(module, name, tensor.to(device=device))
    # Recurse into children; child names are irrelevant here.
    for _, child_module in module.named_children():
        replace_parameters_by_name(child_module, name_keywords, device)


def convert_model_weight_to_float8(model, exclude_module_name=['embed_tokens']):
for name, module in model.named_modules():
flag = False
for _exclude_module_name in exclude_module_name:
if _exclude_module_name in name:
flag = True
if flag:
if any(ex in name for ex in exclude_module_name):
continue

for param_name, param in module.named_parameters():
flag = False
for _exclude_module_name in exclude_module_name:
if _exclude_module_name in param_name:
flag = True
if flag:
if any(ex in param_name for ex in exclude_module_name):
continue

param.data = param.data.to(torch.float8_e4m3fn)


def convert_weight_dtype_wrapper(module, origin_dtype):
    """Wrap the forward of every weighted submodule so tensor positional
    arguments are cast to ``origin_dtype`` before the original forward runs.

    The root module (empty name) and embedding layers are skipped, and a
    module is never wrapped twice (``original_forward`` marks wrapped ones).

    NOTE(review): unlike ``autocast_model_forward``, this wrapper does not
    temporarily upcast the module's weights — confirm the weights are in a
    dtype the wrapped forward can compute with.

    Args:
        module: root module whose submodules get wrapped in place.
        origin_dtype: dtype tensor positional arguments are cast to.
    """
    for name, mod in module.named_modules():
        # skip root and embedding layers
        if name == "" or "embed_tokens" in name:
            continue

        # avoid wrapping twice
        if hasattr(mod, "original_forward"):
            continue

        if hasattr(mod, "weight") and mod.weight is not None:
            orig_forward = mod.forward

            # unwrap accelerate / decorator wrappers if present
            while hasattr(orig_forward, "__wrapped__"):
                orig_forward = orig_forward.__wrapped__

            mod.original_forward = orig_forward

            # Bind the module via a default argument so each closure keeps
            # its own module (avoids the late-binding loop-variable pitfall).
            def new_forward(*inputs, m=mod, **kwargs):
                casted_inputs = [
                    inp.to(origin_dtype) if torch.is_tensor(inp) else inp
                    for inp in inputs
                ]
                return m.original_forward(*casted_inputs, **kwargs)

            mod.forward = new_forward