74 changes: 74 additions & 0 deletions lightx2v_train/configs/lora/longcat_image_lora.yaml
@@ -0,0 +1,74 @@
+model:
+  name: longcat_image
+  pretrained_model_name_or_path: /mnt/miaohua/wangshankun/HF/hub/models--meituan-longcat--LongCat-Image/snapshots/d2ea50b79a930074c37b9b97ce45e3b2ea8cf4d8
medium

The configuration contains absolute paths specific to a local environment (e.g., /mnt/miaohua/wangshankun/...). This makes the configuration non-portable and likely to fail for other users or in different environments. Consider using relative paths or environment variables/placeholders for model and data paths.
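
One way to act on this, sketched under assumptions: keep `${VAR}`-style placeholders in the YAML and expand them from the environment after the file is loaded. The helper name `expand_env_paths` and the variables `MODEL_ROOT` and `DATA_ROOT` are hypothetical and not part of this PR, and the sketch operates on a plain dict rather than on the project's actual config loader.

```python
import os


def expand_env_paths(node):
    """Recursively expand $VAR / ${VAR} placeholders in every string of a config tree."""
    if isinstance(node, dict):
        return {key: expand_env_paths(value) for key, value in node.items()}
    if isinstance(node, list):
        return [expand_env_paths(item) for item in node]
    if isinstance(node, str):
        # os.path.expandvars leaves references to unset variables unchanged.
        return os.path.expandvars(node)
    return node


# Hypothetical environment variables standing in for the hard-coded /mnt/... prefixes.
os.environ["MODEL_ROOT"] = "/models"
os.environ["DATA_ROOT"] = "/datasets"

config = {
    "model": {"pretrained_model_name_or_path": "${MODEL_ROOT}/LongCat-Image"},
    "data": {"train": {"data_path": ["${DATA_ROOT}/dataset_v1/train.jsonl"]}},
}
resolved = expand_env_paths(config)
print(resolved["model"]["pretrained_model_name_or_path"])
```

If the project's loader already interprets `${...}` itself (as the `${training.save_every_iters}` reference in this file suggests), an environment resolver provided by that loader may be a better fit than a separate expansion pass.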

+  max_sequence_length: 1024
+  running_dtype: bf16
+
+data:
+  train:
+    name: image_dataset
+    num_workers: 8
+    prompt_dropout_rate: 0.1
+    target_area: 1048576 # 1024 * 1024
+    shuffle: true
+    # examples: https://github.com/ModelTC/LightX2V_train_data_examples
+    data_path:
+      - /mnt/miaohua/wangshankun/LightX2V_train_data_examples/dataset_v1/train.jsonl
+  val:
+    name: image_dataset
+    num_workers: 8
+    shuffle: false
+    data_path:
+      - /mnt/miaohua/wangshankun/LightX2V_train_data_examples/dataset_v1/val.jsonl
+
+scheduler:
+  num_train_timesteps: 1000
+  timestep_distribution: logitnormal
+  logitnormal_mean: 0.0
+  logitnormal_std: 1.0
+  min_t: 0.001
+  max_t: 1.0
+  do_time_shift: true
+  time_shift_mu: 5.0
+  time_shift_power: 1.0
+
+training:
+  method: lora
+  max_train_iters: 100
+  gradient_accumulation_iters: 1
+  gradient_checkpointing: true
+  max_grad_norm: 1.0
+  lr_scheduler: constant
+  lr_warmup_iters: 10
+  save_every_iters: 100
+  save_total_limit: 10
+  lora:
+    rank: 16
+    alpha: 16
+    target_modules:
+      - to_k
+      - to_q
+      - to_v
+      - to_out.0
+  optimizer:
+    learning_rate: 0.0001
+    adam_beta1: 0.9
+    adam_beta2: 0.999
+    weight_decay: 0.01
+    adam_epsilon: 0.00000001
+  output_dir: ./output_train/longcat_image_lora
+
+inference:
+  method: image_infer
+  negative_prompt: " "
+  default_width: 1024
+  default_height: 1024
+  num_inference_steps: 50
+  enable_cfg: true
+  cfg_guidance_scale: 4.0
+  seed: 42
+  output_dir: ./output_infer/longcat_image_lora
+  infer_every_iters: ${training.save_every_iters}
+
+resume:
+  auto_resume: true
19 changes: 16 additions & 3 deletions lightx2v_train/lightx2v_train/model_zoo/longcat_image.py
@@ -34,6 +34,7 @@ def load_components(self):
         ).to(self.device)
         self.vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(self.device, dtype=self.running_dtype)
         self.transformer = LongCatImageTransformer2DModel.from_pretrained(model_path, subfolder="transformer").to(self.device, dtype=self.running_dtype)
+        self.text_pipeline.text_encoder.requires_grad_(False)
         self.vae.requires_grad_(False)

     @property
@@ -50,15 +51,14 @@ def encode_to_latent(self, sample):
     def encode_condition(self, sample):
         prompt = sample["prompt"]
         if self.config.get("enable_prompt_rewrite_training", False):
-            prompt = self.text_pipeline.rewrite_prompt(prompt, self.device)
+            prompt = self.text_pipeline.rewire_prompt(prompt, self.device)
high

Typo detected: rewire_prompt should likely be rewrite_prompt. The configuration key enable_prompt_rewrite_training and the previous version of the code both use "rewrite". This will cause an AttributeError at runtime if the method does not exist.

Suggested change
-            prompt = self.text_pipeline.rewire_prompt(prompt, self.device)
+            prompt = self.text_pipeline.rewrite_prompt(prompt, self.device)
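
A quick way to settle which spelling actually exists before training starts, as a sketch: look the method up by name and fail early with a clear message. `resolve_method` and the `FakeTextPipeline` stub are hypothetical and only stand in for the real text pipeline object; which spelling the stub defines is arbitrary.

```python
def resolve_method(obj, *candidates):
    """Return the first callable attribute of obj among candidates, or raise."""
    for name in candidates:
        method = getattr(obj, name, None)
        if callable(method):
            return method
    raise AttributeError(f"{type(obj).__name__} has none of: {', '.join(candidates)}")


class FakeTextPipeline:
    """Stub that defines only one of the two spellings."""

    def rewire_prompt(self, prompt, device):
        return f"rewired: {prompt}"


pipeline = FakeTextPipeline()
# Try both spellings and use whichever the object provides.
rewrite = resolve_method(pipeline, "rewrite_prompt", "rewire_prompt")
print(rewrite("a cat", "cpu"))
```

Running the same lookup once against the real pipeline class would confirm whether the rename in this diff is a typo or matches the upstream method name.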

         prompt_embed, text_ids = self.text_pipeline.encode_prompt(
             prompt=prompt,
-            device=self.device,
             num_images_per_prompt=1,
         )
Comment on lines 55 to 58
medium

The device argument was removed from the encode_prompt call. While the pipeline might default to its own device, explicitly passing self.device ensures consistency and avoids potential device mismatch issues, especially since it was explicitly provided in the previous version.

Suggested change
-        prompt_embed, text_ids = self.text_pipeline.encode_prompt(
-            prompt=prompt,
-            num_images_per_prompt=1,
-        )
+        prompt_embed, text_ids = self.text_pipeline.encode_prompt(
+            prompt=prompt,
+            device=self.device,
+            num_images_per_prompt=1,
+        )

         return {"prompt_embed": prompt_embed, "text_ids": text_ids}

-    def prepare_denoiser_input(self, noisy_latent, sample, condition):
+    def prepare_denoiser_input(self, noisy_latent):
medium

The signature of prepare_denoiser_input has been changed to accept only one argument (noisy_latent), which violates the interface defined in the base class BaseModel (which expects noisy_latent, sample, condition). To maintain compatibility with the base class while supporting callers that only provide one argument (like the current LoraTrainer), consider using optional arguments.

Suggested change
-    def prepare_denoiser_input(self, noisy_latent):
+    def prepare_denoiser_input(self, noisy_latent, sample=None, condition=None):
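
The effect of the suggested signature can be illustrated with stand-in classes. This is a sketch only: `BaseModel` here is a minimal stub for the real base class, `LongCatLikeModel` is a hypothetical name, and the method body is a placeholder rather than the actual latent packing.

```python
class BaseModel:
    """Stand-in for the real base class; only the method signature matters here."""

    def prepare_denoiser_input(self, noisy_latent, sample, condition):
        raise NotImplementedError


class LongCatLikeModel(BaseModel):
    def prepare_denoiser_input(self, noisy_latent, sample=None, condition=None):
        # sample and condition are accepted for interface compatibility but unused.
        return {"latent": noisy_latent}


model = LongCatLikeModel()
# A caller that passes only the latent, as the comment says the LoRA trainer does.
one_arg = model.prepare_denoiser_input([0.0])
# A caller that follows the three-argument base-class interface.
three_args = model.prepare_denoiser_input([0.0], {"prompt": "a cat"}, {"prompt_embed": None})
print(one_arg == three_args)
```

With defaults on the extra parameters, both call styles reach the same override, so neither the trainer nor code written against the base class needs to change.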

         n = noisy_latent.shape[0]
         h, w = noisy_latent.shape[2], noisy_latent.shape[3]
         packed = LongCatImagePipeline._pack_latents(noisy_latent, n, noisy_latent.shape[1], h, w)
@@ -119,7 +119,20 @@ def assemble_pipeline(self, scheduler=None):
         return LongCatImagePipeline(
             tokenizer=self.text_pipeline.tokenizer,
             text_encoder=self.text_pipeline.text_encoder,
+            text_processor=self.text_pipeline.text_processor,
             vae=self.vae,
             transformer=self.transformer,
             scheduler=scheduler or self.text_pipeline.scheduler,
         ).to(self.device)
+
+    def get_pipeline_infer_kwargs(self, infer_config):
+        enable_cfg = infer_config.get("enable_cfg", False)
+        return {
+            "height": infer_config.get("height", infer_config.get("default_height", 1024)),
+            "width": infer_config.get("width", infer_config.get("default_width", 1024)),
+            "num_inference_steps": infer_config.get("num_inference_steps", 50),
+            "guidance_scale": infer_config.get("cfg_guidance_scale", 4.0) if enable_cfg else 1.0,
+            "enable_cfg_renorm": infer_config.get("enable_cfg_renorm", True),
+            "cfg_renorm_min": infer_config.get("cfg_renorm_min", 0.0),
+            "enable_prompt_rewrite": infer_config.get("enable_prompt_rewrite", True),
+        }
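
To see how these fallbacks resolve against the `inference:` section of the YAML in this PR, the method can be exercised as a free function on a plain dict. This is a sketch: the real `infer_config` may be a config object rather than a `dict`, and only its `.get` behaviour is assumed here.

```python
def get_pipeline_infer_kwargs(infer_config):
    """Standalone copy of the method body from the diff, without self."""
    enable_cfg = infer_config.get("enable_cfg", False)
    return {
        "height": infer_config.get("height", infer_config.get("default_height", 1024)),
        "width": infer_config.get("width", infer_config.get("default_width", 1024)),
        "num_inference_steps": infer_config.get("num_inference_steps", 50),
        "guidance_scale": infer_config.get("cfg_guidance_scale", 4.0) if enable_cfg else 1.0,
        "enable_cfg_renorm": infer_config.get("enable_cfg_renorm", True),
        "cfg_renorm_min": infer_config.get("cfg_renorm_min", 0.0),
        "enable_prompt_rewrite": infer_config.get("enable_prompt_rewrite", True),
    }


# Values copied from the inference section of longcat_image_lora.yaml.
infer_config = {
    "default_width": 1024,
    "default_height": 1024,
    "num_inference_steps": 50,
    "enable_cfg": True,
    "cfg_guidance_scale": 4.0,
}
kwargs = get_pipeline_infer_kwargs(infer_config)
no_cfg = get_pipeline_infer_kwargs({**infer_config, "enable_cfg": False})
print(kwargs["guidance_scale"], no_cfg["guidance_scale"])
```

Two things follow from the defaults: turning `enable_cfg` off forces `guidance_scale` to 1.0 regardless of `cfg_guidance_scale`, and because the YAML does not set `enable_prompt_rewrite`, inference runs with it enabled.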