[Feat] support longcat image lora train #1075
Code Review
This pull request introduces a new LoRA training configuration for the longcat_image model and updates the model implementation to freeze the text encoder, include a text processor in the pipeline, and provide inference keyword arguments. Feedback includes correcting a typo in a method call, replacing absolute paths in the configuration with portable alternatives, restoring the device argument for consistency in prompt encoding, and ensuring the denoiser input preparation method remains compatible with the base class interface.
  prompt = sample["prompt"]
  if self.config.get("enable_prompt_rewrite_training", False):
-     prompt = self.text_pipeline.rewrite_prompt(prompt, self.device)
+     prompt = self.text_pipeline.rewire_prompt(prompt, self.device)
Typo detected: rewire_prompt should likely be rewrite_prompt. The configuration key enable_prompt_rewrite_training and the previous version of the code both use "rewrite". This will cause an AttributeError at runtime if the method does not exist.
-     prompt = self.text_pipeline.rewire_prompt(prompt, self.device)
+     prompt = self.text_pipeline.rewrite_prompt(prompt, self.device)
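One consequence worth spelling out: because the misspelled call sits behind the `enable_prompt_rewrite_training` flag, the failure stays hidden until that flag is turned on. The sketch below reproduces this with a hypothetical `StubTextPipeline` and a free-standing `encode` function; neither is the project's real code, and the stub only assumes that `rewrite_prompt` is the method that actually exists.

```python
class StubTextPipeline:
    """Hypothetical stand-in for the text pipeline; only rewrite_prompt exists."""

    def rewrite_prompt(self, prompt, device):
        return f"rewritten: {prompt}"


def encode(sample, config, text_pipeline, device="cpu"):
    """Mirror the reviewed branch, including the misspelled method name."""
    prompt = sample["prompt"]
    if config.get("enable_prompt_rewrite_training", False):
        # Misspelled on purpose to reproduce the line under review.
        prompt = text_pipeline.rewire_prompt(prompt, device)
    return prompt


sample = {"prompt": "a cat"}
pipeline = StubTextPipeline()

# Flag off: the misspelled branch never runs, so nothing fails.
result_flag_off = encode(sample, {}, pipeline)

# Flag on: the attribute lookup fails at call time.
try:
    encode(sample, {"enable_prompt_rewrite_training": True}, pipeline)
    error_name = None
except AttributeError as exc:
    error_name = type(exc).__name__
```

With the default configuration the training run would appear healthy, which is why a test that enables the flag is a reasonable companion to the one-word fix.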
@@ -0,0 +1,74 @@
+ model:
+   name: longcat_image
+   pretrained_model_name_or_path: /mnt/miaohua/wangshankun/HF/hub/models--meituan-longcat--LongCat-Image/snapshots/d2ea50b79a930074c37b9b97ce45e3b2ea8cf4d8
The configuration contains absolute paths specific to a local environment (e.g., /mnt/miaohua/wangshankun/...). This makes the configuration non-portable and likely to fail for other users or in different environments. Consider using relative paths or environment variables/placeholders for model and data paths.
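One possible way to apply the environment-variable suggestion is to store a placeholder such as `${LONGCAT_MODEL_DIR}` in the YAML and expand it when the config is loaded. The sketch below is an illustration under assumptions: the variable name `LONGCAT_MODEL_DIR`, the helper `resolve_path`, and the example directory are hypothetical, and whether the project's config loader already expands variables is not known from this PR.

```python
import os


def resolve_path(raw: str) -> str:
    """Expand $VAR / ${VAR} and ~ in a config path; fail if a variable is unset."""
    expanded = os.path.expanduser(os.path.expandvars(raw))
    # os.path.expandvars leaves unknown variables untouched, so a leftover
    # "$" means the environment is missing something the config expects.
    if "$" in expanded:
        raise ValueError(f"unresolved variable in path: {raw!r}")
    return expanded


# Hypothetical value; each user points this at their own checkpoint directory.
os.environ["LONGCAT_MODEL_DIR"] = "/data/models/LongCat-Image"

config = {"model": {"pretrained_model_name_or_path": "${LONGCAT_MODEL_DIR}"}}
resolved = resolve_path(config["model"]["pretrained_model_name_or_path"])
```

Failing loudly on an unset variable keeps a missing environment setup from turning into a confusing "file not found" error deeper in model loading.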
  prompt_embed, text_ids = self.text_pipeline.encode_prompt(
      prompt=prompt,
-     device=self.device,
      num_images_per_prompt=1,
  )
The device argument was removed from the encode_prompt call. While the pipeline might default to its own device, explicitly passing self.device ensures consistency and avoids potential device mismatch issues, especially since it was explicitly provided in the previous version.
  prompt_embed, text_ids = self.text_pipeline.encode_prompt(
      prompt=prompt,
+     device=self.device,
      num_images_per_prompt=1,
  )
      return {"prompt_embed": prompt_embed, "text_ids": text_ids}

- def prepare_denoiser_input(self, noisy_latent, sample, condition):
+ def prepare_denoiser_input(self, noisy_latent):
The signature of prepare_denoiser_input has been changed to accept only one argument (noisy_latent), which violates the interface defined in the base class BaseModel (which expects noisy_latent, sample, condition). To maintain compatibility with the base class while supporting callers that only provide one argument (like the current LoraTrainer), consider using optional arguments.
- def prepare_denoiser_input(self, noisy_latent):
+ def prepare_denoiser_input(self, noisy_latent, sample=None, condition=None):
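To illustrate why optional arguments satisfy both kinds of caller, here is a minimal sketch. The `BaseModel` below is a simplified stand-in that only mirrors the three-argument signature described in the comment, and the returned dictionary (including the `hidden_states` key) is a hypothetical placeholder rather than the model's real output.

```python
class BaseModel:
    """Simplified stand-in for the base class interface named in the review."""

    def prepare_denoiser_input(self, noisy_latent, sample, condition):
        raise NotImplementedError


class LongCatImageModel(BaseModel):
    def prepare_denoiser_input(self, noisy_latent, sample=None, condition=None):
        # sample and condition are accepted for interface compatibility;
        # this sketch only needs the latent. The key name is hypothetical.
        return {"hidden_states": noisy_latent}


model = LongCatImageModel()
latent = [0.1, 0.2]

# Caller that passes only the latent, as the current trainer reportedly does.
one_arg = model.prepare_denoiser_input(latent)

# Caller that follows the full base-class signature.
three_arg = model.prepare_denoiser_input(
    latent, {"prompt": "a cat"}, {"prompt_embed": None}
)
```

Both calls succeed and return the same result, so code written against `BaseModel` keeps working while the single-argument call site needs no change.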