Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions lightx2v_train/configs/dmd_lora/qwen_image_dmd_lora.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
model:
name: qwen_image
pretrained_model_name_or_path: /path/to/Qwen/Qwen-Image
max_sequence_length: 1024
running_dtype: bf16

data:
train:
name: image_dataset
num_workers: 8
prompt_dropout_rate: 0.0
target_area: 1048576 # 1024 * 1024
shuffle: true
data_path:
- /path/to/LightX2V_train_data_examples/dataset_v1/train.jsonl
val:
name: image_dataset
num_workers: 8
shuffle: false
data_path:
- /path/to/LightX2V_train_data_examples/dataset_v1/val.jsonl

scheduler:
num_train_timesteps: 1000
timestep_distribution: uniform
min_t: 0.001
max_t: 1.0
time_shift_settings:
do_time_shift: true
shift_type: exponential
# shift function: "linear" => mu/(mu+(1/t-1)^p), "exponential" => exp(mu)/(exp(mu)+(1/t-1)^p)
time_shift_power: 1.0
dynamic_shift: true
patch_size: [2, 2] # [H, W]
# https://github.com/huggingface/diffusers/blob/v0.38.0/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py#L59
shift_x1: 256
shift_x2: 4096
shift_y1: 0.5
shift_y2: 1.15

training:
method: dmd_lora
max_train_iters: 10
gradient_accumulation_iters: 1
gradient_checkpointing: true
max_grad_norm: 1.0
lr_scheduler: constant
lr_warmup_iters: 10
save_every_iters: 5
save_total_limit: 10
lora:
rank: 32
alpha: 32
target_modules:
- to_k
- to_q
- to_v
- to_out.0
# - add_q_proj
# - add_k_proj
# - add_v_proj
# - to_add_out
# - img_mlp.net.0.proj
# - img_mlp.net.2
# - txt_mlp.net.0.proj
# - txt_mlp.net.2
optimizer:
learning_rate: 0.0001
adam_beta1: 0.9
adam_beta2: 0.999
weight_decay: 0.001
adam_epsilon: 0.00000001
fake:
optimizer:
learning_rate: 0.00002
adam_beta1: 0.9
adam_beta2: 0.999
weight_decay: 0.001
adam_epsilon: 0.00000001
dmd:
num_inference_steps: 4
fake_update_ratio: 2
guidance_scale: 4.0
negative_prompt: " "
cfg_norm: layer_norm
image_sizes:
- [1024, 1024]
- [768, 1344]
- [1344, 768]
sigma_min: 0.02
sigma_max: 1.0
discrete_samples: 1000
renoise_shift: 5.0
inference_shift: 3.0
output_dir: ./output_train/qwen_image_dmd_lora

inference:
method: image_infer
default_width: 1024
default_height: 1024
num_inference_steps: 4
cfg_guidance_scale: 4.0
negative_prompt: " "

resume:
auto_resume: true
19 changes: 17 additions & 2 deletions lightx2v_train/lightx2v_train/model_zoo/qwen_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,15 @@ class QwenImageModel(BaseModel):

pipeline_cls = QwenImagePipeline

def load_components(self):
def load_components(self, transformer_only=False, reference_model=None):
if transformer_only:
if reference_model is not None:
self.text_pipeline = reference_model.text_pipeline
self.vae = reference_model.vae
self.vae_scale_factor = reference_model.vae_scale_factor
self.image_processor = reference_model.image_processor
self.transformer = self.load_transformer()
return
model_path = self.config["model"]["pretrained_model_name_or_path"]
self.text_pipeline = QwenImagePipeline.from_pretrained(
model_path,
Expand All @@ -35,13 +43,17 @@ def load_components(self):
torch_dtype=self.running_dtype,
).to(self.device)
self.vae = AutoencoderKLQwenImage.from_pretrained(model_path, subfolder="vae").to(self.device, dtype=self.running_dtype)
self.transformer = QwenImageTransformer2DModel.from_pretrained(model_path, subfolder="transformer").to(self.device, dtype=self.running_dtype)
self.transformer = self.load_transformer()

self.text_pipeline.text_encoder.requires_grad_(False)
self.vae.requires_grad_(False)
self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)

def load_transformer(self):
model_path = self.config["model"]["pretrained_model_name_or_path"]
return QwenImageTransformer2DModel.from_pretrained(model_path, subfolder="transformer").to(self.device, dtype=self.running_dtype)

def encode_to_latent(self, sample):
image = sample["target_image"].to(device=self.device, dtype=self.running_dtype)
pixel_values = image.unsqueeze(2)
Expand All @@ -53,6 +65,9 @@ def encode_to_latent(self, sample):

def encode_condition(self, sample):
prompt = sample["prompt"]
return self.encode_prompt_condition(prompt)

def encode_prompt_condition(self, prompt):
prompt_embed, prompt_embed_mask = self.text_pipeline.encode_prompt(
prompt=prompt,
device=self.device,
Expand Down
4 changes: 4 additions & 0 deletions lightx2v_train/lightx2v_train/schedulers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .dmd_scheduler import DMDFlowMatchingScheduler
from .flow_matching import RectifiedFlowMatchingScheduler

__all__ = ["DMDFlowMatchingScheduler", "RectifiedFlowMatchingScheduler"]
57 changes: 57 additions & 0 deletions lightx2v_train/lightx2v_train/schedulers/dmd_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import torch

from .flow_matching import RectifiedFlowMatchingScheduler


class DMDFlowMatchingScheduler(RectifiedFlowMatchingScheduler):
def __init__(self, config, dmd_config={}):
super().__init__(config)
self.inference_shift = float(dmd_config.get("inference_shift", 3.0))
self.renoise_shift = float(dmd_config.get("renoise_shift", 5.0))
self.min_sigma = float(dmd_config.get("sigma_min", 0.02))
self.max_sigma = float(dmd_config.get("sigma_max", 1.0))
self.discrete_samples = int(dmd_config.get("discrete_samples", 1000))

@staticmethod
def linear_shift(mu, t):
return mu / (mu + (1 / t - 1))

def set_timesteps(self, num_inference_steps, device=None):
self.num_inference_steps = int(num_inference_steps)
device = device or self.device
timesteps = torch.linspace(
self.num_train_timesteps,
0,
self.num_inference_steps + 1,
dtype=torch.float32,
device=device,
)
self.sigmas = self.linear_shift(self.inference_shift, timesteps / self.num_train_timesteps)
self.timesteps = self.sigmas * self.num_train_timesteps

def sigma_at(self, step_idx, batch_size, device=None, dtype=None):
sigma = self.sigmas[int(step_idx)].expand(int(batch_size))
if device is not None or dtype is not None:
sigma = sigma.to(device=device, dtype=dtype)
return sigma

def sample_renoise_sigma(self, batch_size, device=None, dtype=None):
device = device or self.device
raw = torch.rand((int(batch_size),), device=device, dtype=torch.float32)
if self.discrete_samples > 0:
raw = torch.ceil(raw * self.discrete_samples) / self.discrete_samples
raw = torch.clamp(raw, 1e-7, 1 - 1e-7)
sigma = torch.clamp(self.linear_shift(self.renoise_shift, raw), self.min_sigma, self.max_sigma)
if dtype is not None:
sigma = sigma.to(dtype=dtype)
return sigma

def add_noise(self, latent, noise, sigmas):
return ((1.0 - sigmas) * latent + sigmas * noise).to(dtype=latent.dtype)

def step_by_index(self, velocity, step_idx, sample):
sigma = self.sigma_at(step_idx, sample.shape[0], device=sample.device)
sigma_next = self.sigma_at(int(step_idx) + 1, sample.shape[0], device=sample.device)
next_sample = sample + (sigma_next - sigma) * velocity
x0 = sample - sigma * velocity
return next_sample.to(sample.dtype), x0.to(sample.dtype)
3 changes: 2 additions & 1 deletion lightx2v_train/lightx2v_train/trainers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from lightx2v_train.utils.registry import build_trainer

from .dmd_lora import DmdLoraTrainer
from .lora import LoraTrainer

__all__ = ["build_trainer", "LoraTrainer"]
__all__ = ["build_trainer", "DmdLoraTrainer", "LoraTrainer"]
Loading
Loading