NOTE(review): line breaks and leading indentation in this patch were reconstructed
from a whitespace-collapsed copy. Hunk line counts were re-checked against every
@@ header. Indentation of context lines is inferred from syntax only -- confirm it
against the base files, or apply with `git apply --ignore-whitespace`.

diff --git a/lightx2v_train/configs/lora/flux2_klein_lora.yaml b/lightx2v_train/configs/lora/flux2_klein_lora.yaml
new file mode 100644
index 000000000..9d8b04417
--- /dev/null
+++ b/lightx2v_train/configs/lora/flux2_klein_lora.yaml
@@ -0,0 +1,86 @@
+model:
+  name: flux2_klein
+  pretrained_model_name_or_path: /mnt/miaohua/wangshankun/LightX2V/FLUX.2-klein-base-9B
+  max_sequence_length: 512
+  text_encoder_out_layers: [9, 18, 27]
+  running_dtype: bf16
+
+data:
+  train:
+    name: image_dataset
+    num_workers: 8
+    prompt_dropout_rate: 0.1
+    target_area: 1048576 # 1024 * 1024
+    shuffle: true
+    # examples: https://github.com/ModelTC/LightX2V_train_data_examples
+    data_path:
+      - /mnt/miaohua/wangshankun/LightX2V_train_data_examples/dataset_v1/train.jsonl
+  val:
+    name: image_dataset
+    num_workers: 8
+    shuffle: false
+    data_path:
+      - /mnt/miaohua/wangshankun/LightX2V_train_data_examples/dataset_v1/val.jsonl
+
+scheduler:
+  num_train_timesteps: 1000
+  timestep_distribution: logitnormal
+  logitnormal_mean: 0.0
+  logitnormal_std: 1.0
+  min_t: 0.001
+  max_t: 1.0
+  time_shift_settings:
+    do_time_shift: true
+    shift_type: exponential
+    time_shift_power: 1.0
+    dynamic_shift: true
+    shift_mu_strategy: flux2_empirical
+    shift_mu_num_steps: 50
+    # Flux2 latents are already 2x2-patchified before scheduler shift length calculation.
+    patch_size: [1, 1] # [H, W]
+
+training:
+  method: lora
+  max_train_iters: 100
+  gradient_accumulation_iters: 1
+  gradient_checkpointing: true
+  max_grad_norm: 1.0
+  lr_scheduler: constant
+  lr_warmup_iters: 10
+  save_every_iters: 100
+  save_total_limit: 10
+  lora:
+    rank: 16
+    alpha: 16
+    target_modules:
+      - to_q
+      - to_k
+      - to_v
+      - to_out.0
+      - add_q_proj
+      - add_k_proj
+      - add_v_proj
+      - to_add_out
+      - to_qkv_mlp_proj
+  optimizer:
+    learning_rate: 0.0001
+    adam_beta1: 0.9
+    adam_beta2: 0.999
+    weight_decay: 0.01
+    adam_epsilon: 0.00000001
+  output_dir: ./output_train/flux2_klein_lora
+
+inference:
+  method: image_infer
+  negative_prompt: ""
+  default_width: 1024
+  default_height: 1024
+  num_inference_steps: 50
+  enable_cfg: true
+  cfg_guidance_scale: 4.0
+  seed: 42
+  output_dir: ./output_infer/flux2_klein_lora
+  infer_every_iters: ${training.save_every_iters}
+
+resume:
+  auto_resume: true
diff --git a/lightx2v_train/configs/lora/longcat_image_lora.yaml b/lightx2v_train/configs/lora/longcat_image_lora.yaml
new file mode 100644
index 000000000..3053d850e
--- /dev/null
+++ b/lightx2v_train/configs/lora/longcat_image_lora.yaml
@@ -0,0 +1,83 @@
+model:
+  name: longcat_image
+  pretrained_model_name_or_path: /data/nvme1/models/meituan-longcat/LongCat-Image
+  max_sequence_length: 1024
+  running_dtype: bf16
+
+data:
+  train:
+    name: image_dataset
+    num_workers: 8
+    prompt_dropout_rate: 0.1
+    target_area: 1048576 # 1024 * 1024
+    shuffle: true
+    # examples: https://github.com/ModelTC/LightX2V_train_data_examples
+    data_path:
+      - /data/nvme1/yongyang/kkk/LightX2V_train_data_examples/dataset_v1/train.jsonl
+  val:
+    name: image_dataset
+    num_workers: 8
+    shuffle: false
+    data_path:
+      - /data/nvme1/yongyang/kkk/LightX2V_train_data_examples/dataset_v1/val.jsonl
+
+scheduler:
+  num_train_timesteps: 1000
+  timestep_distribution: logitnormal
+  logitnormal_mean: 0.0
+  logitnormal_std: 1.0
+  min_t: 0.001
+  max_t: 1.0
+  time_shift_settings:
+    do_time_shift: true
+    shift_type: exponential
+    # shift function: "linear" => mu/(mu+(1/t-1)^p), "exponential" => exp(mu)/(exp(mu)+(1/t-1)^p)
+    time_shift_power: 1.0
+    dynamic_shift: true
+    patch_size: [2, 2] # [H, W]
+    # https://github.com/huggingface/diffusers/blob/v0.38.0/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py#L59
+    shift_x1: 256
+    shift_x2: 4096
+    shift_y1: 0.5
+    shift_y2: 1.15
+
+training:
+  method: lora
+  max_train_iters: 3000
+  gradient_accumulation_iters: 1
+  gradient_checkpointing: true
+  max_grad_norm: 1.0
+  lr_scheduler: constant
+  lr_warmup_iters: 10
+  save_every_iters: 100
+  save_total_limit: 10
+  lora:
+    rank: 16
+    alpha: 16
+    target_modules:
+      - to_k
+      - to_q
+      - to_v
+      - to_out.0
+  optimizer:
+    learning_rate: 0.0001
+    adam_beta1: 0.9
+    adam_beta2: 0.999
+    weight_decay: 0.01
+    adam_epsilon: 0.00000001
+  output_dir: ./output_train/longcat_image_lora
+
+inference:
+  method: image_infer
+  negative_prompt: " "
+  default_width: 1024
+  default_height: 1024
+  num_inference_steps: 50
+  enable_cfg: true
+  cfg_guidance_scale: 4.0
+  seed: 42
+  output_dir: ./output_infer/longcat_image_lora
+  infer_every_iters: ${training.save_every_iters}
+
+resume:
+  auto_resume: true
diff --git a/lightx2v_train/lightx2v_train/infer/image.py b/lightx2v_train/lightx2v_train/infer/image.py
index a5a702fc9..3e424667c 100644
--- a/lightx2v_train/lightx2v_train/infer/image.py
+++ b/lightx2v_train/lightx2v_train/infer/image.py
@@ -51,7 +51,7 @@ def infer(self):
             generator = torch.Generator(device=self.model.device).manual_seed(base_seed + i)
             pos_cond = self.model.encode_condition({"prompt": prompt})
             latent = self.model.prepare_infer_latents(height, width, generator)
-            latent_hw = (latent.shape[3], latent.shape[4])
+            latent_hw = (latent.shape[-2], latent.shape[-1])
             self.scheduler.set_timesteps(num_inference_steps, latent_hw=latent_hw)
 
             for step_idx, current_timestep in enumerate(tqdm(self.scheduler.infer_timesteps, desc=f"[{i + 1}/{len(prompts)}] Denoising")):
diff --git a/lightx2v_train/lightx2v_train/model_zoo/__init__.py b/lightx2v_train/lightx2v_train/model_zoo/__init__.py
index 97c480bdc..b665ba177 100644
--- a/lightx2v_train/lightx2v_train/model_zoo/__init__.py
+++ b/lightx2v_train/lightx2v_train/model_zoo/__init__.py
@@ -1,6 +1,7 @@
 from lightx2v_train.utils.registry import build_model
 
+from .flux2_klein import Flux2KleinModel
 from .longcat_image import LongCatImageModel
 from .qwen_image import QwenImageModel
 
-__all__ = ["build_model", "QwenImageModel", "LongCatImageModel"]
+__all__ = ["build_model", "QwenImageModel", "LongCatImageModel", "Flux2KleinModel"]
diff --git a/lightx2v_train/lightx2v_train/model_zoo/flux2_klein.py b/lightx2v_train/lightx2v_train/model_zoo/flux2_klein.py
new file mode 100644
index 000000000..17b39811e
--- /dev/null
+++ b/lightx2v_train/lightx2v_train/model_zoo/flux2_klein.py
@@ -0,0 +1,133 @@
+from dataclasses import dataclass
+
+import torch
+from diffusers import AutoencoderKLFlux2, Flux2KleinPipeline, Flux2Transformer2DModel
+from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
+
+from lightx2v_train.utils.registry import MODEL_REGISTER
+
+from .base import BaseModel
+
+
+@dataclass
+class Flux2KleinDenoiserInput:
+    hidden_states: torch.Tensor
+    img_ids: torch.Tensor
+    height: int
+    width: int
+
+
+@MODEL_REGISTER("flux2_klein")
+class Flux2KleinModel(BaseModel):
+    pipeline_cls = Flux2KleinPipeline
+
+    def load_components(self):
+        model_path = self.config["model"]["pretrained_model_name_or_path"]
+        self.text_pipeline = Flux2KleinPipeline.from_pretrained(
+            model_path,
+            transformer=None,
+            vae=None,
+            torch_dtype=self.running_dtype,
+        ).to(self.device)
+        self.vae = AutoencoderKLFlux2.from_pretrained(model_path, subfolder="vae").to(self.device, dtype=self.running_dtype)
+        self.transformer = Flux2Transformer2DModel.from_pretrained(model_path, subfolder="transformer").to(self.device, dtype=self.running_dtype)
+
+        self.text_pipeline.text_encoder.requires_grad_(False)
+        self.vae.requires_grad_(False)
+        self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+
+    @property
+    def vae_scale_factor(self):
+        return 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+    def _normalize_patch_latents(self, latents):
+        latents = Flux2KleinPipeline._patchify_latents(latents)
+        latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
+        latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(latents.device, latents.dtype)
+        return (latents - latents_bn_mean) / latents_bn_std
+
+    def _denormalize_patch_latents(self, latents):
+        latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
+        latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(latents.device, latents.dtype)
+        latents = latents * latents_bn_std + latents_bn_mean
+        return Flux2KleinPipeline._unpatchify_latents(latents)
+
+    def encode_to_latent(self, sample):
+        image = sample["target_image"].to(device=self.device, dtype=self.running_dtype)
+        latent = self.vae.encode(image).latent_dist.sample()
+        return self._normalize_patch_latents(latent)
+
+    def encode_condition(self, sample):
+        prompt = sample["prompt"]
+        model_config = self.config["model"]
+        prompt_embed, text_ids = self.text_pipeline.encode_prompt(
+            prompt=prompt,
+            device=self.device,
+            num_images_per_prompt=1,
+            max_sequence_length=model_config.get("max_sequence_length", 512),
+            text_encoder_out_layers=tuple(model_config.get("text_encoder_out_layers", (9, 18, 27))),
+        )
+        return {"prompt_embed": prompt_embed, "text_ids": text_ids}
+
+    def prepare_denoiser_input(self, noisy_latent):
+        h, w = noisy_latent.shape[2], noisy_latent.shape[3]
+        packed = Flux2KleinPipeline._pack_latents(noisy_latent)
+        img_ids = Flux2KleinPipeline._prepare_latent_ids(noisy_latent).to(self.device)
+        return Flux2KleinDenoiserInput(
+            hidden_states=packed,
+            img_ids=img_ids,
+            height=h,
+            width=w,
+        )
+
+    def denoise(self, denoiser_input, timestep_or_sigma, condition):
+        return self.transformer(
+            hidden_states=denoiser_input.hidden_states,
+            timestep=timestep_or_sigma,
+            guidance=None,
+            encoder_hidden_states=condition["prompt_embed"],
+            txt_ids=condition["text_ids"],
+            img_ids=denoiser_input.img_ids,
+            joint_attention_kwargs={},
+            return_dict=False,
+        )[0]
+
+    def postprocess_denoiser_output(self, prediction, denoiser_input):
+        return Flux2KleinPipeline._unpack_latents_with_ids(
+            prediction,
+            denoiser_input.img_ids,
+            height=denoiser_input.height,
+            width=denoiser_input.width,
+        )
+
+    def prepare_infer_latents(self, height, width, generator=None):
+        latent_h = 2 * (int(height) // (self.vae_scale_factor * 2))
+        latent_w = 2 * (int(width) // (self.vae_scale_factor * 2))
+        shape = (1, self.transformer.config.in_channels, latent_h // 2, latent_w // 2)
+        return torch.randn(shape, generator=generator, device=self.device, dtype=self.running_dtype)
+
+    def decode_latent(self, latent):
+        latent = self._denormalize_patch_latents(latent)
+        image = self.vae.decode(latent).sample
+        return self.image_processor.postprocess(image, output_type="pil")
+
+    def assemble_pipeline(self, scheduler=None):
+        return Flux2KleinPipeline(
+            tokenizer=self.text_pipeline.tokenizer,
+            text_encoder=self.text_pipeline.text_encoder,
+            vae=self.vae,
+            transformer=self.transformer,
+            scheduler=scheduler or self.text_pipeline.scheduler,
+            is_distilled=self.text_pipeline.config.is_distilled,
+        ).to(self.device)
+
+    def get_pipeline_infer_kwargs(self, infer_config):
+        enable_cfg = infer_config.get("enable_cfg", True)
+        return {
+            "height": infer_config.get("height", infer_config.get("default_height", 1024)),
+            "width": infer_config.get("width", infer_config.get("default_width", 1024)),
+            "num_inference_steps": infer_config.get("num_inference_steps", 50),
+            "guidance_scale": infer_config.get("cfg_guidance_scale", 4.0) if enable_cfg else 1.0,
+            "max_sequence_length": self.config["model"].get("max_sequence_length", 512),
+            "text_encoder_out_layers": tuple(self.config["model"].get("text_encoder_out_layers", (9, 18, 27))),
+        }
diff --git a/lightx2v_train/lightx2v_train/model_zoo/longcat_image.py b/lightx2v_train/lightx2v_train/model_zoo/longcat_image.py
index 3158d3aca..95a2989a7 100644
--- a/lightx2v_train/lightx2v_train/model_zoo/longcat_image.py
+++ b/lightx2v_train/lightx2v_train/model_zoo/longcat_image.py
@@ -34,6 +34,7 @@ def load_components(self):
         ).to(self.device)
         self.vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(self.device, dtype=self.running_dtype)
         self.transformer = LongCatImageTransformer2DModel.from_pretrained(model_path, subfolder="transformer").to(self.device, dtype=self.running_dtype)
+        self.text_pipeline.text_encoder.requires_grad_(False)
         self.vae.requires_grad_(False)
 
     @property
@@ -50,15 +51,14 @@ def encode_to_latent(self, sample):
     def encode_condition(self, sample):
         prompt = sample["prompt"]
         if self.config.get("enable_prompt_rewrite_training", False):
-            prompt = self.text_pipeline.rewrite_prompt(prompt, self.device)
+            prompt = self.text_pipeline.rewire_prompt(prompt, self.device)
         prompt_embed, text_ids = self.text_pipeline.encode_prompt(
             prompt=prompt,
-            device=self.device,
             num_images_per_prompt=1,
         )
         return {"prompt_embed": prompt_embed, "text_ids": text_ids}
 
-    def prepare_denoiser_input(self, noisy_latent, sample, condition):
+    def prepare_denoiser_input(self, noisy_latent):
         n = noisy_latent.shape[0]
         h, w = noisy_latent.shape[2], noisy_latent.shape[3]
         packed = LongCatImagePipeline._pack_latents(noisy_latent, n, noisy_latent.shape[1], h, w)
@@ -119,7 +119,20 @@ def assemble_pipeline(self, scheduler=None):
         return LongCatImagePipeline(
             tokenizer=self.text_pipeline.tokenizer,
             text_encoder=self.text_pipeline.text_encoder,
+            text_processor=self.text_pipeline.text_processor,
             vae=self.vae,
             transformer=self.transformer,
             scheduler=scheduler or self.text_pipeline.scheduler,
         ).to(self.device)
+
+    def get_pipeline_infer_kwargs(self, infer_config):
+        enable_cfg = infer_config.get("enable_cfg", False)
+        return {
+            "height": infer_config.get("height", infer_config.get("default_height", 1024)),
+            "width": infer_config.get("width", infer_config.get("default_width", 1024)),
+            "num_inference_steps": infer_config.get("num_inference_steps", 50),
+            "guidance_scale": infer_config.get("cfg_guidance_scale", 4.0) if enable_cfg else 1.0,
+            "enable_cfg_renorm": infer_config.get("enable_cfg_renorm", True),
+            "cfg_renorm_min": infer_config.get("cfg_renorm_min", 0.0),
+            "enable_prompt_rewrite": infer_config.get("enable_prompt_rewrite", True),
+        }
diff --git a/lightx2v_train/lightx2v_train/schedulers/flow_matching.py b/lightx2v_train/lightx2v_train/schedulers/flow_matching.py
index 23451100e..ae2104c22 100644
--- a/lightx2v_train/lightx2v_train/schedulers/flow_matching.py
+++ b/lightx2v_train/lightx2v_train/schedulers/flow_matching.py
@@ -26,12 +26,18 @@ def __init__(self, config):
         self.shift_type = time_shift_settings.get("shift_type", "linear")
         self.dynamic_shift = time_shift_settings.get("dynamic_shift", False)
         if self.dynamic_shift:
-            self.shift_x1 = time_shift_settings["shift_x1"]
-            self.shift_x2 = time_shift_settings["shift_x2"]
-            self.shift_y1 = time_shift_settings["shift_y1"]
-            self.shift_y2 = time_shift_settings["shift_y2"]
-            self._mu_slope = (self.shift_y2 - self.shift_y1) / (self.shift_x2 - self.shift_x1)
-            self._mu_bias = self.shift_y1 - self._mu_slope * self.shift_x1
+            self.shift_mu_strategy = time_shift_settings.get("shift_mu_strategy", "linear")
+            if self.shift_mu_strategy == "linear":
+                self.shift_x1 = time_shift_settings["shift_x1"]
+                self.shift_x2 = time_shift_settings["shift_x2"]
+                self.shift_y1 = time_shift_settings["shift_y1"]
+                self.shift_y2 = time_shift_settings["shift_y2"]
+                self._mu_slope = (self.shift_y2 - self.shift_y1) / (self.shift_x2 - self.shift_x1)
+                self._mu_bias = self.shift_y1 - self._mu_slope * self.shift_x1
+            elif self.shift_mu_strategy == "flux2_empirical":
+                self.shift_mu_num_steps = time_shift_settings.get("shift_mu_num_steps", 50)
+            else:
+                raise ValueError(f"Unsupported shift_mu_strategy: {self.shift_mu_strategy}")
             self.patch_size = time_shift_settings.get("patch_size", [2, 2])
         else:
             self.time_shift_mu = time_shift_settings.get("time_shift_mu", 5.0)
@@ -45,12 +51,31 @@ def __init__(self, config):
         self.infer_timesteps = None
         self.num_inference_steps = None
 
-    def _get_time_shift_mu(self, latent_hw=None):
+    @staticmethod
+    def _compute_flux2_empirical_mu(image_seq_len, num_steps):
+        a1, b1 = 8.73809524e-05, 1.89833333
+        a2, b2 = 0.00016927, 0.45666666
+
+        if image_seq_len > 4300:
+            return float(a2 * image_seq_len + b2)
+
+        m_200 = a2 * image_seq_len + b2
+        m_10 = a1 * image_seq_len + b1
+        a = (m_200 - m_10) / 190.0
+        b = m_200 - 200.0 * a
+        return float(a * num_steps + b)
+
+    def _get_time_shift_mu(self, latent_hw=None, num_steps=None):
         if self.dynamic_shift:
             if latent_hw is None:
                 raise ValueError("latent_hw=(H, W) must be provided when dynamic_shift=True")
             h, w = latent_hw
             image_seq_len = (h // self.patch_size[0]) * (w // self.patch_size[1])
+            if self.shift_mu_strategy == "flux2_empirical":
+                return self._compute_flux2_empirical_mu(
+                    image_seq_len=image_seq_len,
+                    num_steps=num_steps or self.shift_mu_num_steps,
+                )
             return self._mu_slope * image_seq_len + self._mu_bias
         return self.time_shift_mu
 
@@ -68,8 +93,8 @@ def sample_timestep_or_sigma(self, num_samples, latent_hw=None):
             timestep_or_sigma = self.time_shift(timestep_or_sigma, latent_hw=latent_hw)
         return timestep_or_sigma.to(self.running_dtype)
 
-    def time_shift(self, t, latent_hw=None):
-        mu = self._get_time_shift_mu(latent_hw)
+    def time_shift(self, t, latent_hw=None, num_steps=None):
+        mu = self._get_time_shift_mu(latent_hw, num_steps=num_steps)
         if self.shift_type == "exponential":
             mu = math.exp(mu)
         return mu / (mu + (1 / t - 1) ** self.time_shift_power)
@@ -89,7 +114,7 @@ def set_timesteps(self, num_inference_steps, sigmas=None, latent_hw=None):
         if sigmas is None:
             sigmas = torch.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps)
             if self.do_time_shift:
-                sigmas = self.time_shift(sigmas, latent_hw=latent_hw)
+                sigmas = self.time_shift(sigmas, latent_hw=latent_hw, num_steps=num_inference_steps)
         else:
             sigmas = torch.tensor(sigmas, dtype=torch.float32)
         self.infer_sigmas = torch.cat([sigmas, torch.zeros(1)]).to(self.device)
diff --git a/lightx2v_train/lightx2v_train/trainers/lora.py b/lightx2v_train/lightx2v_train/trainers/lora.py
index 4bcd6b36d..8372b3e76 100644
--- a/lightx2v_train/lightx2v_train/trainers/lora.py
+++ b/lightx2v_train/lightx2v_train/trainers/lora.py
@@ -89,7 +89,7 @@ def compute_loss_on_sample(self, sample):
         latent = self.model.encode_to_latent(sample)
         n = latent.shape[0]
         noise = torch.randn_like(latent, dtype=self.running_dtype)
-        latent_hw = (latent.shape[3], latent.shape[4])
+        latent_hw = (latent.shape[-2], latent.shape[-1])
         timestep_or_sigma = self.noise_scheduler.sample_timestep_or_sigma(n, latent_hw=latent_hw)
         noisy_latent = self.noise_scheduler.add_noise(latent, noise, timestep_or_sigma)
         condition = self.model.encode_condition(sample)