diff --git a/lightx2v_train/configs/dmd_lora/qwen_image_dmd_lora.yaml b/lightx2v_train/configs/dmd_lora/qwen_image_dmd_lora.yaml new file mode 100644 index 000000000..831372175 --- /dev/null +++ b/lightx2v_train/configs/dmd_lora/qwen_image_dmd_lora.yaml @@ -0,0 +1,106 @@ +model: + name: qwen_image + pretrained_model_name_or_path: /path/to/Qwen/Qwen-Image + max_sequence_length: 1024 + running_dtype: bf16 + +data: + train: + name: image_dataset + num_workers: 8 + prompt_dropout_rate: 0.0 + target_area: 1048576 # 1024 * 1024 + shuffle: true + data_path: + - /path/to/LightX2V_train_data_examples/dataset_v1/train.jsonl + val: + name: image_dataset + num_workers: 8 + shuffle: false + data_path: + - /path/to/LightX2V_train_data_examples/dataset_v1/val.jsonl + +scheduler: + num_train_timesteps: 1000 + timestep_distribution: uniform + min_t: 0.001 + max_t: 1.0 + time_shift_settings: + do_time_shift: true + shift_type: exponential + # shift function: "linear" => mu/(mu+(1/t-1)^p), "exponential" => exp(mu)/(exp(mu)+(1/t-1)^p) + time_shift_power: 1.0 + dynamic_shift: true + patch_size: [2, 2] # [H, W] + # https://github.com/huggingface/diffusers/blob/v0.38.0/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py#L59 + shift_x1: 256 + shift_x2: 4096 + shift_y1: 0.5 + shift_y2: 1.15 + +training: + method: dmd_lora + max_train_iters: 10 + gradient_accumulation_iters: 1 + gradient_checkpointing: true + max_grad_norm: 1.0 + lr_scheduler: constant + lr_warmup_iters: 10 + save_every_iters: 5 + save_total_limit: 10 + lora: + rank: 32 + alpha: 32 + target_modules: + - to_k + - to_q + - to_v + - to_out.0 + # - add_q_proj + # - add_k_proj + # - add_v_proj + # - to_add_out + # - img_mlp.net.0.proj + # - img_mlp.net.2 + # - txt_mlp.net.0.proj + # - txt_mlp.net.2 + optimizer: + learning_rate: 0.0001 + adam_beta1: 0.9 + adam_beta2: 0.999 + weight_decay: 0.001 + adam_epsilon: 0.00000001 + fake: + optimizer: + learning_rate: 0.00002 + adam_beta1: 0.9 + adam_beta2: 0.999 + weight_decay: 0.001 
+ adam_epsilon: 0.00000001 + dmd: + num_inference_steps: 4 + fake_update_ratio: 2 + guidance_scale: 4.0 + negative_prompt: " " + cfg_norm: layer_norm + image_sizes: + - [1024, 1024] + - [768, 1344] + - [1344, 768] + sigma_min: 0.02 + sigma_max: 1.0 + discrete_samples: 1000 + renoise_shift: 5.0 + inference_shift: 3.0 + output_dir: ./output_train/qwen_image_dmd_lora + +inference: + method: image_infer + default_width: 1024 + default_height: 1024 + num_inference_steps: 4 + cfg_guidance_scale: 4.0 + negative_prompt: " " + +resume: + auto_resume: true diff --git a/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py b/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py index 44ed7766c..f8f4ee947 100644 --- a/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py +++ b/lightx2v_train/lightx2v_train/model_zoo/qwen_image.py @@ -26,7 +26,15 @@ class QwenImageModel(BaseModel): pipeline_cls = QwenImagePipeline - def load_components(self): + def load_components(self, transformer_only=False, reference_model=None): + if transformer_only: + if reference_model is not None: + self.text_pipeline = reference_model.text_pipeline + self.vae = reference_model.vae + self.vae_scale_factor = reference_model.vae_scale_factor + self.image_processor = reference_model.image_processor + self.transformer = self.load_transformer() + return model_path = self.config["model"]["pretrained_model_name_or_path"] self.text_pipeline = QwenImagePipeline.from_pretrained( model_path, @@ -35,13 +43,17 @@ def load_components(self): torch_dtype=self.running_dtype, ).to(self.device) self.vae = AutoencoderKLQwenImage.from_pretrained(model_path, subfolder="vae").to(self.device, dtype=self.running_dtype) - self.transformer = QwenImageTransformer2DModel.from_pretrained(model_path, subfolder="transformer").to(self.device, dtype=self.running_dtype) + self.transformer = self.load_transformer() self.text_pipeline.text_encoder.requires_grad_(False) self.vae.requires_grad_(False) self.vae_scale_factor = 2 ** 
class DMDFlowMatchingScheduler(RectifiedFlowMatchingScheduler):
    """Flow-matching scheduler extended with the sampling helpers used by DMD training.

    On top of the base scheduler it provides a shifted few-step inference
    sigma schedule (``set_timesteps`` / ``sigma_at`` / ``step_by_index``) and a
    shifted, optionally discretised sigma sampler for re-noising generated
    latents (``sample_renoise_sigma`` / ``add_noise``).

    NOTE(review): ``self.device`` and ``self.num_train_timesteps`` are read but
    never assigned here; presumably the base class sets them -- confirm.
    """

    def __init__(self, config, dmd_config=None):
        """Read the DMD-specific settings.

        Args:
            config: Forwarded unchanged to the base scheduler.
            dmd_config: Optional mapping with the keys ``inference_shift``,
                ``renoise_shift``, ``sigma_min``, ``sigma_max`` and
                ``discrete_samples``. ``None`` means "use every default".
        """
        super().__init__(config)
        # ``None`` instead of a shared ``{}`` default avoids the mutable-default pitfall.
        dmd_config = {} if dmd_config is None else dmd_config
        self.inference_shift = float(dmd_config.get("inference_shift", 3.0))
        self.renoise_shift = float(dmd_config.get("renoise_shift", 5.0))
        self.min_sigma = float(dmd_config.get("sigma_min", 0.02))
        self.max_sigma = float(dmd_config.get("sigma_max", 1.0))
        self.discrete_samples = int(dmd_config.get("discrete_samples", 1000))

    @staticmethod
    def linear_shift(mu, t):
        """Return ``mu / (mu + (1 / t - 1))``."""
        return mu / (mu + (1 / t - 1))

    @staticmethod
    def _expand_per_sample(sigmas, reference):
        """Give a per-sample 1-D sigma trailing singleton axes so it broadcasts over ``reference``.

        Broadcasting aligns trailing axes, so a ``(B,)`` sigma multiplied with a
        ``(B, ...)`` tensor would otherwise fail (or scale along the last axis)
        whenever ``B > 1``. Anything that is not a 1-D tensor whose length equals
        ``reference.shape[0]`` is returned untouched.
        """
        if torch.is_tensor(sigmas) and sigmas.ndim == 1 and reference.ndim > 1 and sigmas.shape[0] == reference.shape[0]:
            return sigmas.reshape(-1, *([1] * (reference.ndim - 1)))
        return sigmas

    def set_timesteps(self, num_inference_steps, device=None):
        """Build the shifted inference schedule with ``num_inference_steps + 1`` sigmas.

        Stores ``self.num_inference_steps``, ``self.sigmas`` and ``self.timesteps``
        (the sigmas scaled by ``self.num_train_timesteps``).
        """
        self.num_inference_steps = int(num_inference_steps)
        device = device or self.device
        timesteps = torch.linspace(
            self.num_train_timesteps,
            0,
            self.num_inference_steps + 1,
            dtype=torch.float32,
            device=device,
        )
        # The final point is t == 0: 1 / t is inf in floating point, so the shift maps it to 0.
        self.sigmas = self.linear_shift(self.inference_shift, timesteps / self.num_train_timesteps)
        self.timesteps = self.sigmas * self.num_train_timesteps

    def sigma_at(self, step_idx, batch_size, device=None, dtype=None):
        """Return the schedule sigma at ``step_idx`` expanded to shape ``(batch_size,)``.

        Requires ``set_timesteps`` to have been called. The tensor is moved/cast
        only when ``device`` or ``dtype`` is given.
        """
        sigma = self.sigmas[int(step_idx)].expand(int(batch_size))
        if device is not None or dtype is not None:
            sigma = sigma.to(device=device, dtype=dtype)
        return sigma

    def sample_renoise_sigma(self, batch_size, device=None, dtype=None):
        """Draw ``batch_size`` re-noising sigmas.

        Uniform samples are optionally snapped up to a grid of
        ``discrete_samples`` levels, kept strictly inside (0, 1), passed through
        ``linear_shift`` with ``renoise_shift`` and clamped to
        ``[min_sigma, max_sigma]``. Sampling is done in float32; ``dtype`` only
        casts the result.
        """
        device = device or self.device
        raw = torch.rand((int(batch_size),), device=device, dtype=torch.float32)
        if self.discrete_samples > 0:
            raw = torch.ceil(raw * self.discrete_samples) / self.discrete_samples
        # Keep away from exactly 0 and 1 before evaluating 1 / t in the shift.
        raw = torch.clamp(raw, 1e-7, 1 - 1e-7)
        sigma = torch.clamp(self.linear_shift(self.renoise_shift, raw), self.min_sigma, self.max_sigma)
        if dtype is not None:
            sigma = sigma.to(dtype=dtype)
        return sigma

    def add_noise(self, latent, noise, sigmas):
        """Return ``(1 - sigmas) * latent + sigmas * noise`` cast to ``latent.dtype``.

        ``sigmas`` may be a per-sample 1-D tensor (it is broadcast over the
        remaining axes of ``latent``) or anything already broadcastable.
        """
        sigmas = self._expand_per_sample(sigmas, latent)
        return ((1.0 - sigmas) * latent + sigmas * noise).to(dtype=latent.dtype)

    def step_by_index(self, velocity, step_idx, sample):
        """Take one Euler step from schedule index ``step_idx`` to ``step_idx + 1``.

        Returns:
            ``(next_sample, x0)``: the sample advanced to the next sigma and the
            clean-sample estimate ``sample - sigma * velocity``, both cast to
            ``sample.dtype``.
        """
        batch_size = sample.shape[0]
        sigma = self._expand_per_sample(self.sigma_at(step_idx, batch_size, device=sample.device), sample)
        sigma_next = self._expand_per_sample(self.sigma_at(int(step_idx) + 1, batch_size, device=sample.device), sample)
        next_sample = sample + (sigma_next - sigma) * velocity
        x0 = sample - sigma * velocity
        return next_sample.to(sample.dtype), x0.to(sample.dtype)
@TRAINER_REGISTER("dmd_lora")
class DmdLoraTrainer(LoraTrainer):
    """LoRA trainer that alternates DMD student updates with fake-model updates.

    Three models take part: ``self.model`` (the student, set up by the parent
    trainer), ``self.fake_model`` (a second LoRA-wrapped transformer trained on
    the student's re-noised outputs) and ``self.teacher_model`` (a frozen
    transformer queried with classifier-free guidance).

    NOTE(review): attributes such as ``training_config``, ``optimizer_*``,
    ``lora_*``, ``max_train_iters``, ``dataloader_train``, ``optimizer`` and
    ``lr_scheduler`` are read but not assigned here; presumably the parent
    ``LoraTrainer`` provides them -- confirm.
    """

    def __init__(self, config):
        """Read the ``training.fake`` and ``training.dmd`` sections of the config.

        Fake-optimizer hyper-parameters fall back to the student optimizer's
        values when they are not given.
        """
        super().__init__(config)
        fake_config = self.training_config.get("fake", {})
        fake_optimizer_config = fake_config.get("optimizer", {})
        self.fake_optimizer_learning_rate = fake_optimizer_config.get("learning_rate", self.optimizer_learning_rate)
        self.fake_optimizer_adam_beta1 = fake_optimizer_config.get("adam_beta1", self.optimizer_adam_beta1)
        self.fake_optimizer_adam_beta2 = fake_optimizer_config.get("adam_beta2", self.optimizer_adam_beta2)
        self.fake_optimizer_weight_decay = fake_optimizer_config.get("weight_decay", self.optimizer_weight_decay)
        self.fake_optimizer_adam_epsilon = fake_optimizer_config.get("adam_epsilon", self.optimizer_adam_epsilon)

        self.dmd_config = self.training_config.get("dmd", {})
        self.num_inference_steps = int(self.dmd_config.get("num_inference_steps", 4))
        self.fake_update_ratio = int(self.dmd_config.get("fake_update_ratio", 1))
        self.guidance_scale = float(self.dmd_config.get("guidance_scale", 3.0))
        self.negative_prompt = self.dmd_config.get("negative_prompt", " ")
        self.cfg_norm = self.dmd_config.get("cfg_norm", "layer_norm")
        self.image_sizes = self.dmd_config.get("image_sizes", [])

    def setup(self, resume_ckpt_path=None):
        """Build the fake and teacher models, the fake optimizer/scheduler and the DMD scheduler.

        Runs the parent setup first, then restores fake-side state from
        ``resume_ckpt_path`` when one is given.
        """
        super().setup(resume_ckpt_path=resume_ckpt_path)

        # Fake model: own transformer with LoRA; other components are shared with the student.
        self.fake_model = build_model(self.config)
        self.fake_model.load_components(transformer_only=True, reference_model=self.model)
        self.fake_model.add_lora(self.lora_rank, self.lora_alpha, self.lora_target_modules)
        self.fake_model.set_lora_trainable()
        if self.gradient_checkpointing:
            self.fake_model.enable_gradient_checkpointing()

        # Teacher: frozen transformer kept in eval mode.
        self.teacher_model = build_model(self.config)
        self.teacher_model.load_components(transformer_only=True, reference_model=self.model)
        self.teacher_model.transformer.requires_grad_(False)
        self.teacher_model.transformer.eval()

        self.fake_optimizer = torch.optim.AdamW(
            self.fake_model.trainable_parameters(),
            lr=self.fake_optimizer_learning_rate,
            betas=(self.fake_optimizer_adam_beta1, self.fake_optimizer_adam_beta2),
            weight_decay=self.fake_optimizer_weight_decay,
            eps=self.fake_optimizer_adam_epsilon,
        )
        # The fake model steps ``fake_update_ratio`` times per training iteration.
        self.fake_lr_scheduler = get_scheduler(
            self.lr_scheduler_name,
            optimizer=self.fake_optimizer,
            num_warmup_steps=0,
            num_training_steps=max(1, self.max_train_iters * self.fake_update_ratio),
        )

        self.scheduler = DMDFlowMatchingScheduler(self.config, self.dmd_config)

        if resume_ckpt_path is not None:
            self.load_resume_ckpt(resume_ckpt_path)

        print(f"[dmd_lora] student trainable params={self._count_trainable(self.model.transformer)}")
        print(f"[dmd_lora] fake trainable params={self._count_trainable(self.fake_model.transformer)}")

    @staticmethod
    def _count_trainable(module):
        """Return how many parameter tensors of ``module`` require gradients.

        This counts tensors, not individual scalar elements.
        """
        return sum(1 for param in module.parameters() if param.requires_grad)

    @staticmethod
    def _expand_sigma(sigma, reference):
        """Give a per-sample 1-D ``sigma`` trailing singleton axes so it broadcasts over ``reference``.

        Broadcasting aligns trailing axes, so multiplying a ``(B,)`` sigma with a
        ``(B, ...)`` tensor would otherwise fail (or scale the wrong axis) for
        ``B > 1``. Other shapes are returned untouched.
        """
        if sigma.ndim == 1 and reference.ndim > 1 and sigma.shape[0] == reference.shape[0]:
            return sigma.reshape(-1, *([1] * (reference.ndim - 1)))
        return sigma

    @staticmethod
    def _do_cfg(cond_pred, uncond_pred, cfg_scale, cfg_norm):
        """Combine conditional and unconditional predictions with classifier-free guidance.

        Args:
            cond_pred: Conditional prediction.
            uncond_pred: Unconditional prediction.
            cfg_scale: Guidance scale applied to ``cond_pred - uncond_pred``.
            cfg_norm: ``None``/``"none"`` (no rescaling), ``"layer_norm"``
                (rescale so the norm over the last axis matches ``cond_pred``)
                or ``"scalar"`` (one global factor, capped at 1.0).

        Raises:
            ValueError: If ``cfg_norm`` is not one of the supported values.
        """
        pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred)
        if cfg_norm in (None, "none"):
            return pred
        if cfg_norm == "layer_norm":
            cond_norm = torch.norm(cond_pred, dim=-1, keepdim=True)
            pred_norm = torch.norm(pred, dim=-1, keepdim=True)
            return pred * (cond_norm / torch.clamp(pred_norm, min=1e-12))
        if cfg_norm == "scalar":
            cond_norm = torch.norm(cond_pred)
            pred_norm = torch.norm(pred)
            return pred * min(1.0, (cond_norm / torch.clamp(pred_norm, min=1e-12)).item())
        raise ValueError(f"Unsupported cfg_norm: {cfg_norm}")

    @staticmethod
    def _dmd_loss(latents, x_pred_fake_flow, x_pred_teacher):
        """Return the DMD surrogate loss for the student output ``latents``.

        The direction ``x_pred_fake_flow - x_pred_teacher`` is normalised by the
        per-sample mean of ``|latents - x_pred_teacher|`` and treated as a
        constant, so the MSE against ``latents - grad`` back-propagates that
        direction into ``latents`` only.
        """
        with torch.no_grad():
            grad = x_pred_fake_flow - x_pred_teacher
            dims = tuple(range(1, latents.ndim))
            normalizer = torch.abs(latents - x_pred_teacher).mean(dim=dims, keepdim=True)
            # nan_to_num absorbs the 0/0 and x/0 cases of a vanishing normalizer.
            grad = torch.nan_to_num(grad / normalizer)
        return 0.5 * F.mse_loss(latents.float(), (latents.float() - grad.float()).detach(), reduction="mean")

    def _latent_shape(self, sample):
        """Return the 5-tuple latent shape ``(batch, channels, 1, h, w)`` to generate for ``sample``.

        The pixel size is drawn at random from ``self.image_sizes`` when that
        list is non-empty, otherwise taken from the last two axes of
        ``sample["target_image"]``; both are divided by the VAE scale factor.
        """
        image = sample["target_image"]
        batch_size = image.shape[0]
        if self.image_sizes:
            height, width = self.image_sizes[torch.randint(0, len(self.image_sizes), (1,), device=self.model.device).item()]
        else:
            height, width = image.shape[-2], image.shape[-1]

        latent_channels = getattr(self.model.vae.config, "z_dim", None)
        if latent_channels is None:
            # Fallback when the VAE config has no ``z_dim``.
            latent_channels = self.model.transformer.config.in_channels // 4
        return (
            batch_size,
            int(latent_channels),
            1,
            height // self.model.vae_scale_factor,
            width // self.model.vae_scale_factor,
        )

    def _encode_conditions(self, sample):
        """Encode ``sample["prompt"]`` and the matching negative prompt(s) without gradients.

        Returns:
            ``(condition, negative_condition)`` as produced by the student
            model's ``encode_prompt_condition``.
        """
        prompt = sample["prompt"]
        if isinstance(prompt, str):
            negative_prompt = self.negative_prompt
        else:
            # One negative prompt per entry of the prompt batch.
            negative_prompt = [self.negative_prompt] * len(prompt)
        with torch.no_grad():
            condition = self.model.encode_prompt_condition(prompt)
            negative_condition = self.model.encode_prompt_condition(negative_prompt)
        return condition, negative_condition

    def _predict_velocity(self, model, latents, sigma, condition):
        """Run ``model``'s prepare -> denoise -> postprocess pipeline and return its prediction."""
        denoiser_input = model.prepare_denoiser_input(latents)
        prediction = model.denoise(denoiser_input, sigma, condition)
        prediction = model.postprocess_denoiser_output(prediction, denoiser_input)
        return prediction

    def sample_initial_latents(self, latent_shape):
        """Return standard-normal latents of ``latent_shape`` on the student's device in the running dtype."""
        return torch.randn(latent_shape, device=self.model.device, dtype=self.running_dtype)

    def sample_end_step(self):
        """Return a random step index in ``[0, num_inference_steps)`` at which the simulation stops."""
        return int(torch.randint(0, self.num_inference_steps, (1,), device=self.model.device).item())

    def run_back_simulation(self, condition, latent_shape, end_step_idx, grad_enabled, xt=None):
        """Run the student sampler from pure noise through step ``end_step_idx``.

        Args:
            condition: Prompt condition passed to the student.
            latent_shape: Shape of the latents to generate.
            end_step_idx: Last schedule index to execute (inclusive).
            grad_enabled: When true, gradients are recorded for the final step
                only; all earlier steps always run under ``torch.no_grad``.
            xt: Optional starting latents; sampled when ``None``.

        Returns:
            The clean-sample estimate produced by the last executed step.
        """
        self.scheduler.set_timesteps(self.num_inference_steps)
        if xt is None:
            xt = self.sample_initial_latents(latent_shape)
        x0 = None
        self.model.transformer.train()
        for idx in range(end_step_idx + 1):
            sigma = self.scheduler.sigma_at(idx, latent_shape[0], device=self.model.device, dtype=self.running_dtype)
            context = torch.enable_grad if (grad_enabled and idx == end_step_idx) else torch.no_grad
            with context():
                velocity = self._predict_velocity(self.model, xt, sigma, condition)
                xt, x0 = self.scheduler.step_by_index(velocity, idx, xt)
        return x0

    def forward_loss(self, latent_shape, conditions, stage):
        """Compute the loss for one stage.

        Args:
            latent_shape: Shape of the latents to generate.
            conditions: ``(condition, negative_condition)`` pair.
            stage: ``"fake"`` returns the fake model's flow-matching MSE on the
                re-noised student output; any other value returns the DMD loss
                for the student.
        """
        condition, negative_condition = conditions
        end_step_idx = self.sample_end_step()
        xt_start = self.sample_initial_latents(latent_shape)
        x0_ref = self.run_back_simulation(condition, latent_shape, end_step_idx, grad_enabled=False, xt=xt_start)

        sigma = self.scheduler.sample_renoise_sigma(latent_shape[0], device=self.model.device, dtype=self.running_dtype)
        noise = torch.randn(latent_shape, device=self.model.device, dtype=torch.float32)
        # Per-sample sigma needs trailing singleton axes to broadcast over the latents
        # when the batch holds more than one sample; the denoiser still gets the 1-D form.
        sigma_b = self._expand_sigma(sigma, x0_ref)
        renoised_xt = self.scheduler.add_noise(x0_ref, noise, sigma_b)

        if stage == "fake":
            self.fake_model.transformer.train()
            velocity_fake = self._predict_velocity(self.fake_model, renoised_xt, sigma, condition)
            velocity_gt = self.scheduler.build_train_gt(x0_ref.float(), noise)
            return F.mse_loss(velocity_fake.float(), velocity_gt.float(), reduction="mean")

        with torch.no_grad():
            self.fake_model.transformer.eval()
            velocity_fake = self._predict_velocity(self.fake_model, renoised_xt, sigma, condition)
            velocity_teacher_cond = self._predict_velocity(self.teacher_model, renoised_xt, sigma, condition)
            velocity_teacher_uncond = self._predict_velocity(self.teacher_model, renoised_xt, sigma, negative_condition)
            velocity_teacher = self._do_cfg(velocity_teacher_cond, velocity_teacher_uncond, self.guidance_scale, self.cfg_norm)

            x_pred_fake = renoised_xt - sigma_b * velocity_fake
            x_pred_teacher = renoised_xt - sigma_b * velocity_teacher
        # Second pass over the same starting noise, now with gradients on the final step.
        x0 = self.run_back_simulation(condition, latent_shape, end_step_idx, grad_enabled=True, xt=xt_start)
        return self._dmd_loss(x0, x_pred_fake, x_pred_teacher)

    def train(self):
        """Run the alternating training loop until ``max_train_iters`` is reached.

        Each iteration performs one student (DMD) update followed by
        ``fake_update_ratio`` fake-model updates on the same batch, and saves a
        checkpoint every ``save_every_iters`` iterations.

        Raises:
            RuntimeError: If the training dataloader yields no batches, which
                would otherwise make the loop spin forever.
        """
        resume_ckpt_path, current_iter = self._resolve_resume()
        self.setup(resume_ckpt_path=resume_ckpt_path)
        os.makedirs(self.output_train_dir, exist_ok=True)

        max_train_iters = self.max_train_iters
        fake_update_ratio = self.fake_update_ratio
        max_grad_norm = self.max_grad_norm
        save_every_iters = self.save_every_iters
        save_total_limit = self.save_total_limit
        running_dmd = 0.0
        running_fake = 0.0

        progress = tqdm(total=max_train_iters, desc="DMD-LoRA iterations", initial=current_iter)

        try:
            while current_iter < max_train_iters:
                iter_at_epoch_start = current_iter
                for sample in self.dataloader_train:
                    conditions = self._encode_conditions(sample)
                    latent_shape = self._latent_shape(sample)

                    # Student update.
                    loss_dmd = self.forward_loss(latent_shape, conditions, stage="student")
                    loss_dmd.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.transformer.parameters(), max_grad_norm)
                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad(set_to_none=True)
                    running_dmd += loss_dmd.item()

                    # Fake-model updates.
                    fake_loss = 0.0
                    for _ in range(fake_update_ratio):
                        loss_fake = self.forward_loss(latent_shape, conditions, stage="fake")
                        loss_fake.backward()
                        torch.nn.utils.clip_grad_norm_(self.fake_model.transformer.parameters(), max_grad_norm)
                        self.fake_optimizer.step()
                        self.fake_lr_scheduler.step()
                        self.fake_optimizer.zero_grad(set_to_none=True)
                        fake_loss += loss_fake.item()
                    # max(1, ...) keeps a ratio of 0 from raising ZeroDivisionError.
                    running_fake += fake_loss / max(1, fake_update_ratio)

                    current_iter += 1
                    progress.update(1)
                    progress.set_postfix(
                        dmd=running_dmd,
                        fake=running_fake,
                        lr=self.lr_scheduler.get_last_lr()[0],
                    )
                    running_dmd = 0.0
                    running_fake = 0.0

                    if save_every_iters and current_iter % save_every_iters == 0:
                        self.save_checkpoint(current_iter, save_total_limit)

                    if current_iter >= max_train_iters:
                        break

                if current_iter == iter_at_epoch_start:
                    # A full pass produced nothing; looping again can never make progress.
                    raise RuntimeError("Training dataloader yielded no batches; cannot make progress.")
        finally:
            # Release the progress bar even when training aborts.
            progress.close()

    def load_resume_ckpt(self, resume_ckpt_path):
        """Restore the fake model's LoRA weights, optimizer and LR scheduler from a checkpoint.

        Missing pieces are reported with a printed warning instead of raising.

        NOTE(review): only fake-side state is handled here; the student state is
        presumably restored by the parent ``setup`` -- confirm.
        """
        training_state_path = os.path.join(resume_ckpt_path, "training_state.pt")
        fake_lora_path = os.path.join(resume_ckpt_path, "fake_lora")
        fake_lora_weights_path = os.path.join(fake_lora_path, "pytorch_lora_weights.safetensors")

        if os.path.exists(fake_lora_weights_path):
            self.fake_model.load_lora_weights_for_resume(fake_lora_path)
        else:
            print(f"Warning: fake LoRA weights not found in {fake_lora_path}. Fake model not restored.")

        if not os.path.exists(training_state_path):
            return

        # NOTE(security): weights_only=False unpickles arbitrary objects; only resume
        # from checkpoints written by this trainer or another trusted source.
        state = torch.load(training_state_path, map_location="cpu", weights_only=False)
        if "fake_optimizer" in state:
            self.fake_optimizer.load_state_dict(state["fake_optimizer"])
        else:
            print(f"Warning: fake optimizer state not found in {training_state_path}.")

        if "fake_lr_scheduler" in state:
            self.fake_lr_scheduler.load_state_dict(state["fake_lr_scheduler"])
        else:
            print(f"Warning: fake lr scheduler state not found in {training_state_path}.")

    def save_checkpoint(self, iteration, save_total_limit):
        """Write ``checkpoint-<iteration>`` under the output directory.

        The checkpoint holds the student LoRA weights, the fake LoRA weights (in
        a ``fake_lora`` sub-directory), a copy of the config file when its path
        is known, and ``training_state.pt`` with both optimizers and LR
        schedulers. ``prune_checkpoints`` is invoked before the new checkpoint
        is written.
        """
        prune_checkpoints(self.output_train_dir, save_total_limit)

        save_dir = os.path.join(self.output_train_dir, f"checkpoint-{iteration:09d}")
        os.makedirs(save_dir, exist_ok=True)
        self.model.save_lora_weights(save_dir)

        fake_save_dir = os.path.join(save_dir, "fake_lora")
        os.makedirs(fake_save_dir, exist_ok=True)
        self.fake_model.save_lora_weights(fake_save_dir)

        config_path = self.config.get("config_path")
        if config_path is not None:
            shutil.copy2(config_path, os.path.join(save_dir, "config.yaml"))

        training_state = {
            "iteration": iteration,
            "optimizer": self.optimizer.state_dict(),
            "lr_scheduler": self.lr_scheduler.state_dict(),
            "fake_optimizer": self.fake_optimizer.state_dict(),
            "fake_lr_scheduler": self.fake_lr_scheduler.state_dict(),
        }
        torch.save(training_state, os.path.join(save_dir, "training_state.pt"))