Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions lightx2v_train/configs/lora/flux2_klein_lora.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
model:
name: flux2_klein
pretrained_model_name_or_path: /mnt/miaohua/wangshankun/LightX2V/FLUX.2-klein-base-9B
max_sequence_length: 512
text_encoder_out_layers: [9, 18, 27]
running_dtype: bf16

data:
train:
name: image_dataset
num_workers: 8
prompt_dropout_rate: 0.1
target_area: 1048576 # 1024 * 1024
shuffle: true
# examples: https://github.com/ModelTC/LightX2V_train_data_examples
data_path:
- /mnt/miaohua/wangshankun/LightX2V_train_data_examples/dataset_v1/train.jsonl
val:
name: image_dataset
num_workers: 8
shuffle: false
data_path:
- /mnt/miaohua/wangshankun/LightX2V_train_data_examples/dataset_v1/val.jsonl

scheduler:
num_train_timesteps: 1000
timestep_distribution: logitnormal
logitnormal_mean: 0.0
logitnormal_std: 1.0
min_t: 0.001
max_t: 1.0
time_shift_settings:
do_time_shift: true
shift_type: exponential
time_shift_power: 1.0
dynamic_shift: true
shift_mu_strategy: flux2_empirical
shift_mu_num_steps: 50
# Flux2 latents are already 2x2-patchified before the scheduler computes the sequence length used for the time shift, so patch_size stays [1, 1].
patch_size: [1, 1] # [H, W]

training:
method: lora
max_train_iters: 100
gradient_accumulation_iters: 1
gradient_checkpointing: true
max_grad_norm: 1.0
lr_scheduler: constant
lr_warmup_iters: 10
save_every_iters: 100
save_total_limit: 10
lora:
rank: 16
alpha: 16
target_modules:
- to_q
- to_k
- to_v
- to_out.0
- add_q_proj
- add_k_proj
- add_v_proj
- to_add_out
- to_qkv_mlp_proj
optimizer:
learning_rate: 0.0001
adam_beta1: 0.9
adam_beta2: 0.999
weight_decay: 0.01
adam_epsilon: 0.00000001
output_dir: ./output_train/flux2_klein_lora

inference:
method: image_infer
negative_prompt: ""
default_width: 1024
default_height: 1024
num_inference_steps: 50
enable_cfg: true
cfg_guidance_scale: 4.0
seed: 42
output_dir: ./output_infer/flux2_klein_lora
infer_every_iters: ${training.save_every_iters}

resume:
auto_resume: true
83 changes: 83 additions & 0 deletions lightx2v_train/configs/lora/longcat_image_lora.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
model:
name: longcat_image
pretrained_model_name_or_path: /data/nvme1/models/meituan-longcat/LongCat-Image
max_sequence_length: 1024
running_dtype: bf16

data:
train:
name: image_dataset
num_workers: 8
prompt_dropout_rate: 0.1
target_area: 1048576 # 1024 * 1024
shuffle: true
# examples: https://github.com/ModelTC/LightX2V_train_data_examples
data_path:
- /data/nvme1/yongyang/kkk/LightX2V_train_data_examples/dataset_v1/train.jsonl
val:
name: image_dataset
num_workers: 8
shuffle: false
data_path:
- /data/nvme1/yongyang/kkk/LightX2V_train_data_examples/dataset_v1/val.jsonl

scheduler:
num_train_timesteps: 1000
timestep_distribution: logitnormal
logitnormal_mean: 0.0
logitnormal_std: 1.0
min_t: 0.001
max_t: 1.0
time_shift_settings:
do_time_shift: true
shift_type: exponential
# shift function: "linear" => mu/(mu+(1/t-1)^p), "exponential" => exp(mu)/(exp(mu)+(1/t-1)^p)
time_shift_power: 1.0
dynamic_shift: true
patch_size: [2, 2] # [H, W]
# https://github.com/huggingface/diffusers/blob/v0.38.0/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py#L59
shift_x1: 256
shift_x2: 4096
shift_y1: 0.5
shift_y2: 1.15

training:
method: lora
max_train_iters: 3000
gradient_accumulation_iters: 1
gradient_checkpointing: true
max_grad_norm: 1.0
lr_scheduler: constant
lr_warmup_iters: 10
save_every_iters: 100
save_total_limit: 10
lora:
rank: 16
alpha: 16
target_modules:
- to_k
- to_q
- to_v
- to_out.0
optimizer:
learning_rate: 0.0001
adam_beta1: 0.9
adam_beta2: 0.999
weight_decay: 0.01
adam_epsilon: 0.00000001
output_dir: ./output_train/longcat_image_lora

inference:
method: image_infer
negative_prompt: " "
default_width: 1024
default_height: 1024
num_inference_steps: 50
enable_cfg: true
cfg_guidance_scale: 4.0
seed: 42
output_dir: ./output_infer/longcat_image_lora
infer_every_iters: ${training.save_every_iters}

resume:
auto_resume: true
2 changes: 1 addition & 1 deletion lightx2v_train/lightx2v_train/infer/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def infer(self):
generator = torch.Generator(device=self.model.device).manual_seed(base_seed + i)
pos_cond = self.model.encode_condition({"prompt": prompt})
latent = self.model.prepare_infer_latents(height, width, generator)
latent_hw = (latent.shape[3], latent.shape[4])
latent_hw = (latent.shape[-2], latent.shape[-1])
self.scheduler.set_timesteps(num_inference_steps, latent_hw=latent_hw)

for step_idx, current_timestep in enumerate(tqdm(self.scheduler.infer_timesteps, desc=f"[{i + 1}/{len(prompts)}] Denoising")):
Expand Down
3 changes: 2 additions & 1 deletion lightx2v_train/lightx2v_train/model_zoo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from lightx2v_train.utils.registry import build_model

from .flux2_klein import Flux2KleinModel
from .longcat_image import LongCatImageModel
from .qwen_image import QwenImageModel

__all__ = ["build_model", "QwenImageModel", "LongCatImageModel"]
__all__ = ["build_model", "QwenImageModel", "LongCatImageModel", "Flux2KleinModel"]
133 changes: 133 additions & 0 deletions lightx2v_train/lightx2v_train/model_zoo/flux2_klein.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from dataclasses import dataclass

import torch
from diffusers import AutoencoderKLFlux2, Flux2KleinPipeline, Flux2Transformer2DModel
from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor

from lightx2v_train.utils.registry import MODEL_REGISTER

from .base import BaseModel


@dataclass
class Flux2KleinDenoiserInput:
    """Bundle of values handed from ``prepare_denoiser_input`` to the denoiser.

    The field order is part of the interface because the generated
    ``__init__`` accepts the fields positionally in this order.
    """

    # Output of ``Flux2KleinPipeline._pack_latents`` for the noisy latent.
    hidden_states: torch.Tensor
    # Output of ``Flux2KleinPipeline._prepare_latent_ids`` for the noisy latent.
    img_ids: torch.Tensor
    # ``noisy_latent.shape[2]`` at the time the input was prepared.
    height: int
    # ``noisy_latent.shape[3]`` at the time the input was prepared.
    width: int


@MODEL_REGISTER("flux2_klein")
class Flux2KleinModel(BaseModel):
    """FLUX.2 klein model wrapper assembled from diffusers components.

    Registered under the name ``"flux2_klein"``. The static helpers of
    ``Flux2KleinPipeline`` are reused for patchifying, packing and unpacking
    latents.
    """

    pipeline_cls = Flux2KleinPipeline

    def load_components(self):
        """Load the text pipeline, VAE and transformer from the configured path.

        The text pipeline is loaded without a transformer or VAE; those two
        are loaded separately from their ``vae`` / ``transformer`` subfolders.
        Gradients are disabled for the text encoder and the VAE; the
        transformer is not frozen here.
        """
        checkpoint = self.config["model"]["pretrained_model_name_or_path"]
        device = self.device
        dtype = self.running_dtype

        text_pipeline = Flux2KleinPipeline.from_pretrained(
            checkpoint,
            transformer=None,
            vae=None,
            torch_dtype=dtype,
        )
        self.text_pipeline = text_pipeline.to(device)

        vae = AutoencoderKLFlux2.from_pretrained(checkpoint, subfolder="vae")
        self.vae = vae.to(device, dtype=dtype)

        transformer = Flux2Transformer2DModel.from_pretrained(checkpoint, subfolder="transformer")
        self.transformer = transformer.to(device, dtype=dtype)

        for frozen_module in (self.text_pipeline.text_encoder, self.vae):
            frozen_module.requires_grad_(False)

        self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)

    @property
    def vae_scale_factor(self):
        """Return ``2 ** (len(vae.config.block_out_channels) - 1)``.

        NOTE(review): presumably the VAE's spatial downsampling factor; confirm
        against the diffusers pipeline.
        """
        exponent = len(self.vae.config.block_out_channels) - 1
        return 2**exponent

    def _latent_bn_stats(self, reference):
        """Return the VAE batch-norm ``(mean, std)`` matched to ``reference``.

        Both tensors are viewed as ``(1, -1, 1, 1)`` and moved to the device
        and dtype of ``reference``. The standard deviation is computed as
        ``sqrt(running_var + batch_norm_eps)`` before the cast.
        """
        bn = self.vae.bn
        eps = self.vae.config.batch_norm_eps
        mean = bn.running_mean.view(1, -1, 1, 1).to(reference.device, reference.dtype)
        std = torch.sqrt(bn.running_var.view(1, -1, 1, 1) + eps).to(reference.device, reference.dtype)
        return mean, std

    def _normalize_patch_latents(self, latents):
        """Patchify ``latents`` and standardise them with the VAE batch-norm stats."""
        patched = Flux2KleinPipeline._patchify_latents(latents)
        mean, std = self._latent_bn_stats(patched)
        return (patched - mean) / std

    def _denormalize_patch_latents(self, latents):
        """Undo the batch-norm standardisation, then unpatchify the result."""
        mean, std = self._latent_bn_stats(latents)
        restored = latents * std + mean
        return Flux2KleinPipeline._unpatchify_latents(restored)

    def encode_to_latent(self, sample):
        """Encode ``sample["target_image"]`` with the VAE and normalise the latent.

        The image is moved to the model device and running dtype, a latent is
        sampled from the VAE's latent distribution, and the result is passed
        through ``_normalize_patch_latents``.
        """
        pixels = sample["target_image"].to(device=self.device, dtype=self.running_dtype)
        posterior = self.vae.encode(pixels).latent_dist
        return self._normalize_patch_latents(posterior.sample())

    def encode_condition(self, sample):
        """Encode ``sample["prompt"]`` into prompt embeddings and text ids.

        ``max_sequence_length`` (default 512) and ``text_encoder_out_layers``
        (default ``(9, 18, 27)``) are read from the ``model`` config section.

        Returns:
            A dict with the keys ``"prompt_embed"`` and ``"text_ids"``.
        """
        model_cfg = self.config["model"]
        max_len = model_cfg.get("max_sequence_length", 512)
        out_layers = tuple(model_cfg.get("text_encoder_out_layers", (9, 18, 27)))
        embeds, ids = self.text_pipeline.encode_prompt(
            prompt=sample["prompt"],
            device=self.device,
            num_images_per_prompt=1,
            max_sequence_length=max_len,
            text_encoder_out_layers=out_layers,
        )
        return {"prompt_embed": embeds, "text_ids": ids}

    def prepare_denoiser_input(self, noisy_latent):
        """Pack ``noisy_latent`` and build its position ids for the transformer.

        The height and width recorded on the result are dimensions 2 and 3 of
        ``noisy_latent``.
        """
        latent_shape = noisy_latent.shape
        rows, cols = latent_shape[2], latent_shape[3]
        tokens = Flux2KleinPipeline._pack_latents(noisy_latent)
        position_ids = Flux2KleinPipeline._prepare_latent_ids(noisy_latent).to(self.device)
        return Flux2KleinDenoiserInput(
            hidden_states=tokens,
            img_ids=position_ids,
            height=rows,
            width=cols,
        )

    def denoise(self, denoiser_input, timestep_or_sigma, condition):
        """Run the transformer once and return the first element of its output.

        ``guidance`` is always passed as ``None`` and ``joint_attention_kwargs``
        as an empty dict.
        """
        outputs = self.transformer(
            hidden_states=denoiser_input.hidden_states,
            timestep=timestep_or_sigma,
            guidance=None,
            encoder_hidden_states=condition["prompt_embed"],
            txt_ids=condition["text_ids"],
            img_ids=denoiser_input.img_ids,
            joint_attention_kwargs={},
            return_dict=False,
        )
        return outputs[0]

    def postprocess_denoiser_output(self, prediction, denoiser_input):
        """Unpack ``prediction`` using the ids and size stored on ``denoiser_input``."""
        unpack = Flux2KleinPipeline._unpack_latents_with_ids
        return unpack(
            prediction,
            denoiser_input.img_ids,
            height=denoiser_input.height,
            width=denoiser_input.width,
        )

    def prepare_infer_latents(self, height, width, generator=None):
        """Draw random initial latents for an image of ``height`` x ``width``.

        Each spatial size is floor-divided by ``vae_scale_factor * 2``, doubled,
        and then halved again for the tensor shape. The channel count is
        ``transformer.config.in_channels`` and the batch size is 1.
        """
        stride = self.vae_scale_factor * 2
        even_h = 2 * (int(height) // stride)
        even_w = 2 * (int(width) // stride)
        latent_shape = (1, self.transformer.config.in_channels, even_h // 2, even_w // 2)
        return torch.randn(
            latent_shape,
            generator=generator,
            device=self.device,
            dtype=self.running_dtype,
        )

    def decode_latent(self, latent):
        """Denormalise ``latent``, decode it with the VAE and return PIL output."""
        vae_input = self._denormalize_patch_latents(latent)
        decoded = self.vae.decode(vae_input).sample
        return self.image_processor.postprocess(decoded, output_type="pil")

    def assemble_pipeline(self, scheduler=None):
        """Build a full ``Flux2KleinPipeline`` from the loaded components.

        Args:
            scheduler: Scheduler to use; when falsy, the text pipeline's own
                scheduler is used instead.
        """
        text_pipeline = self.text_pipeline
        pipeline = Flux2KleinPipeline(
            tokenizer=text_pipeline.tokenizer,
            text_encoder=text_pipeline.text_encoder,
            vae=self.vae,
            transformer=self.transformer,
            scheduler=scheduler or text_pipeline.scheduler,
            is_distilled=text_pipeline.config.is_distilled,
        )
        return pipeline.to(self.device)

    def get_pipeline_infer_kwargs(self, infer_config):
        """Translate ``infer_config`` into keyword arguments for the pipeline call.

        ``height`` / ``width`` fall back to ``default_height`` /
        ``default_width`` and then to 1024. ``guidance_scale`` is
        ``cfg_guidance_scale`` (default 4.0) when ``enable_cfg`` is truthy
        (default ``True``), otherwise 1.0.
        """
        model_cfg = self.config["model"]
        if infer_config.get("enable_cfg", True):
            guidance_scale = infer_config.get("cfg_guidance_scale", 4.0)
        else:
            guidance_scale = 1.0
        return {
            "height": infer_config.get("height", infer_config.get("default_height", 1024)),
            "width": infer_config.get("width", infer_config.get("default_width", 1024)),
            "num_inference_steps": infer_config.get("num_inference_steps", 50),
            "guidance_scale": guidance_scale,
            "max_sequence_length": model_cfg.get("max_sequence_length", 512),
            "text_encoder_out_layers": tuple(model_cfg.get("text_encoder_out_layers", (9, 18, 27))),
        }
19 changes: 16 additions & 3 deletions lightx2v_train/lightx2v_train/model_zoo/longcat_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def load_components(self):
).to(self.device)
self.vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(self.device, dtype=self.running_dtype)
self.transformer = LongCatImageTransformer2DModel.from_pretrained(model_path, subfolder="transformer").to(self.device, dtype=self.running_dtype)
self.text_pipeline.text_encoder.requires_grad_(False)
self.vae.requires_grad_(False)

@property
Expand All @@ -50,15 +51,14 @@ def encode_to_latent(self, sample):
def encode_condition(self, sample):
prompt = sample["prompt"]
if self.config.get("enable_prompt_rewrite_training", False):
prompt = self.text_pipeline.rewrite_prompt(prompt, self.device)
prompt = self.text_pipeline.rewire_prompt(prompt, self.device)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Typo detected: rewire_prompt should likely be rewrite_prompt. The configuration key enable_prompt_rewrite_training and the previous version of the code both use "rewrite". This will cause an AttributeError at runtime if the method does not exist.

Suggested change
prompt = self.text_pipeline.rewire_prompt(prompt, self.device)
prompt = self.text_pipeline.rewrite_prompt(prompt, self.device)

prompt_embed, text_ids = self.text_pipeline.encode_prompt(
prompt=prompt,
device=self.device,
num_images_per_prompt=1,
)
Comment on lines 55 to 58
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The device argument was removed from the encode_prompt call. While the pipeline might default to its own device, explicitly passing self.device ensures consistency and avoids potential device mismatch issues, especially since it was explicitly provided in the previous version.

Suggested change
prompt_embed, text_ids = self.text_pipeline.encode_prompt(
prompt=prompt,
device=self.device,
num_images_per_prompt=1,
)
prompt_embed, text_ids = self.text_pipeline.encode_prompt(
prompt=prompt,
device=self.device,
num_images_per_prompt=1,
)

return {"prompt_embed": prompt_embed, "text_ids": text_ids}

def prepare_denoiser_input(self, noisy_latent, sample, condition):
def prepare_denoiser_input(self, noisy_latent):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The signature of prepare_denoiser_input has been changed to accept only one argument (noisy_latent), which violates the interface defined in the base class BaseModel (which expects noisy_latent, sample, condition). To maintain compatibility with the base class while supporting callers that only provide one argument (like the current LoraTrainer), consider using optional arguments.

Suggested change
def prepare_denoiser_input(self, noisy_latent):
def prepare_denoiser_input(self, noisy_latent, sample=None, condition=None):

n = noisy_latent.shape[0]
h, w = noisy_latent.shape[2], noisy_latent.shape[3]
packed = LongCatImagePipeline._pack_latents(noisy_latent, n, noisy_latent.shape[1], h, w)
Expand Down Expand Up @@ -119,7 +119,20 @@ def assemble_pipeline(self, scheduler=None):
return LongCatImagePipeline(
tokenizer=self.text_pipeline.tokenizer,
text_encoder=self.text_pipeline.text_encoder,
text_processor=self.text_pipeline.text_processor,
vae=self.vae,
transformer=self.transformer,
scheduler=scheduler or self.text_pipeline.scheduler,
).to(self.device)

def get_pipeline_infer_kwargs(self, infer_config):
    """Translate ``infer_config`` into keyword arguments for the pipeline call.

    ``height`` / ``width`` fall back to ``default_height`` /
    ``default_width`` and then to 1024. ``guidance_scale`` is
    ``cfg_guidance_scale`` (default 4.0) when ``enable_cfg`` is truthy
    (default ``False``), otherwise 1.0. The remaining keys are copied from
    ``infer_config`` with the defaults shown below.
    """
    if infer_config.get("enable_cfg", False):
        guidance_scale = infer_config.get("cfg_guidance_scale", 4.0)
    else:
        guidance_scale = 1.0

    fallback_height = infer_config.get("default_height", 1024)
    fallback_width = infer_config.get("default_width", 1024)

    return {
        "height": infer_config.get("height", fallback_height),
        "width": infer_config.get("width", fallback_width),
        "num_inference_steps": infer_config.get("num_inference_steps", 50),
        "guidance_scale": guidance_scale,
        "enable_cfg_renorm": infer_config.get("enable_cfg_renorm", True),
        "cfg_renorm_min": infer_config.get("cfg_renorm_min", 0.0),
        "enable_prompt_rewrite": infer_config.get("enable_prompt_rewrite", True),
    }
Loading
Loading