Skip to content

qwen edit 2509 训练lora效果很差 #23

@hello-program

Description

@hello-program

你好,我参考了 https://github.com/inclusionAI/TwinFlow/tree/main/src 中的 LoRA 使用说明来蒸馏 Qwen-Image-Edit-2509,但效果很差。能否帮忙看看问题出在哪里?谢谢!
这是我的推理代码:

import os
import sys
sys.path.append("/home/workspace/code/TwinFlow/src")
import torch
from functools import partial
from torchvision.utils import save_image
from PIL import Image
from torchvision import transforms

from src.networks.qwen_image.modeling_qwen_image import QwenImage
from peft import PeftModel
from unified_sampler import UnifiedSampler
from safetensors.torch import load_file
from torch.amp import autocast as torch_autocast

# Fix the RNG seed so repeated runs of this script are comparable.
seed = 42
torch.manual_seed(seed)
device = torch.device("cuda")

# Base weights and the LoRA checkpoint directory to evaluate.
# NOTE(review): this points at the ``global_step_30`` checkpoint — confirm it
# is the checkpoint you actually intend to evaluate.
base_model_path = "/home/workspace/hf_model/Qwen-Image-Edit-2509"
lora_checkpoint_path = "../outputs/qwenimage_task/qwenimage_lora_2order/checkpoints/global_step_30/model"

# The image to edit and the edit instruction.
input_image_path = "/home/workspace/code/DiffSynth-Studio/example_image_dataset/edit/image1_F.jpg"
prompt = "将裙子改为粉色"

# Target resolution, used both for preprocessing and for sampling.
height = width = 512

# Load the image as RGB and resize it to exactly (width, height).
input_image = Image.open(input_image_path).convert("RGB")
input_image = input_image.resize((width, height), Image.LANCZOS)

# Tensor conversion pipeline. The image has already been resized above, so the
# Resize step here is not expected to change its size for a square target.
# Normalize with mean 0.5 / std 0.5 computes (x - 0.5) / 0.5 per channel.
preprocess_steps = [
    transforms.Resize(min(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
]
transform = transforms.Compose(preprocess_steps)

# Add a leading batch dimension, then move to the GPU in bfloat16.
input_tensor = transform(input_image).unsqueeze(0)
input_tensor = input_tensor.to(device, dtype=torch.bfloat16)

print("load base model...")
# Constructor options for the edit variant; text and image branches both run
# in bfloat16, and no auxiliary time embedding is requested.
qwen_image_kwargs = dict(
    model_type="edit",
    aux_time_embed=False,
    text_dtype=torch.bfloat16,
    imgs_dtype=torch.bfloat16,
    device=device,
)
model = QwenImage(base_model_path, **qwen_image_kwargs)
model = model.to(device)

print("load lora...")
base_transformer = model.model.transformer


def _remap_adapter_key(key):
    """Rename one saved adapter key to the name used by the live PeftModel.

    Two adjustments are made:
    * the ``base_model.model.transformer.`` prefix is shortened to
      ``base_model.model.``;
    * the adapter name ``default`` is inserted before ``.weight`` for
      ``lora_A`` / ``lora_B`` / ``lora_embedding_A`` / ``lora_embedding_B``
      entries that do not already carry it.

    Args:
        key: Tensor name as stored in ``adapter_model.safetensors``.

    Returns:
        The remapped key; unchanged if no rule applies.
    """
    if "base_model.model.transformer." in key:
        key = key.replace("base_model.model.transformer.", "base_model.model.")

    # Keys that already carry the adapter name need no further rewriting.
    if ".default.weight" in key:
        return key

    if "lora_A.weight" in key:
        return key.replace("lora_A.weight", "lora_A.default.weight")
    if "lora_B.weight" in key:
        return key.replace("lora_B.weight", "lora_B.default.weight")
    if "lora_embedding_A" in key:
        return key.replace("lora_embedding_A.weight", "lora_embedding_A.default.weight")
    if "lora_embedding_B" in key:
        return key.replace("lora_embedding_B.weight", "lora_embedding_B.default.weight")
    return key


if lora_checkpoint_path and os.path.exists(lora_checkpoint_path):
    print(f"Loading LoRA from {lora_checkpoint_path}...")

    # Wrap the base transformer with the saved adapter configuration.
    lora_transformer = PeftModel.from_pretrained(
        base_transformer,
        lora_checkpoint_path,
        adapter_name="default",
        is_trainable=False
    )

    # Re-load the adapter tensors by hand with remapped names, so that they
    # land on the wrapped module even if the saved key layout differs.
    adapter_path = os.path.join(lora_checkpoint_path, "adapter_model.safetensors")
    raw_state_dict = load_file(adapter_path)
    final_state_dict = {
        _remap_adapter_key(key): value for key, value in raw_state_dict.items()
    }

    # strict=False never raises on a name mismatch, so the returned key lists
    # are the only evidence of whether the LoRA weights were actually applied.
    missing, unexpected = lora_transformer.load_state_dict(final_state_dict, strict=False)

    # Base-model weights are legitimately absent from an adapter file; only
    # LoRA entries matter here.
    real_missing = [k for k in missing if "lora" in k]
    real_unexpected = [k for k in unexpected if "lora" in k]

    if real_missing:
        print("missing")
        print(
            f"{len(real_missing)} LoRA keys of the model were not found in the "
            f"checkpoint, e.g. {real_missing[:3]}"
        )
    if real_unexpected:
        # These checkpoint tensors matched nothing in the model and were
        # dropped, i.e. part of the trained adapter is NOT in use.
        print(
            f"unexpected: {len(real_unexpected)} LoRA keys of the checkpoint "
            f"did not match the model, e.g. {real_unexpected[:3]}"
        )
    if not real_missing and not real_unexpected:
        print("success!")

    lora_transformer.set_adapter("default")
    lora_transformer.to(device, dtype=torch.bfloat16)
    lora_transformer.eval()

    # Replace the transformer inside the model with the LoRA-wrapped one.
    # NOTE(review): two attribute paths are updated; presumably both reference
    # the same transformer inside QwenImage — confirm against its definition.
    model.model.transformer = lora_transformer
    model.transformer.transformer = lora_transformer
else:
    print("can not find lora")
model.eval()

# Sampler settings for a 4-step run.
# NOTE(review): the checkpoint directory is named ``..._2order`` while
# ``sampling_order`` is 1 here — confirm the sampler settings match the ones
# the LoRA was trained for.
sampler_config = dict(
    sampling_steps=4,
    stochast_ratio=1.0,
    extrapol_ratio=0.0,
    sampling_order=1,
    time_dist_ctrl=[1.0, 1.0, 1.0],
    rfba_gap_steps=[0.001, 0.5],
)
sampling_loop = UnifiedSampler().sampling_loop
sampler = partial(sampling_loop, **sampler_config)

# Run inference without gradients and under bfloat16 autocast on CUDA.
with torch.no_grad():
    with torch_autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
        edited_image = model.sample(
            prompts=[prompt],
            images=input_tensor,
            cfg_scale=0.0,  # presumably disables classifier-free guidance — confirm
            seed=seed,
            height=height,
            width=width,
            sampler=sampler,
            return_traj=False,
        )

# Undo the (x - 0.5) / 0.5 input normalisation before writing the image.
image_unit_range = (edited_image.squeeze(0) + 1) / 2
save_image(image_unit_range, "edited_output.jpg")
print("done")

我保存模型的实现

def save_ckpt(
    ckpt_root_dir,
    model_to_save,
    global_step,
):
    """Save the model's transformer for one training step.

    The weights are written to
    ``<ckpt_root_dir>/global_step_<global_step>/model``; the directory is
    created if it does not exist yet.

    Args:
        ckpt_root_dir: Root directory that holds all checkpoints.
        model_to_save: Object with a ``transformer`` attribute. The
            transformer, or its ``.module`` attribute when it has one, must
            provide ``save_pretrained``.
        global_step: Step number used in the checkpoint directory name.
    """
    step_folder = f"global_step_{global_step}"
    target_dir = os.path.join(ckpt_root_dir, step_folder, "model")
    os.makedirs(target_dir, exist_ok=True)

    # A wrapped transformer exposes the real module as ``.module``
    # (presumably a distributed-training wrapper); save the inner module.
    transformer = model_to_save.transformer
    unwrapped = getattr(transformer, "module", transformer)
    unwrapped.save_pretrained(target_dir, safe_serialization=True)

左图是蒸馏前的结果,右图是蒸馏后的结果。我只用了一张图进行训练,跑了 120 个 step,但仍然无法过拟合。配置文件基本没有改动,只改了以下 2 个地方:

  # Model family being distilled; the trailing "#Flux" looks like the previous
  # value kept as a note — TODO confirm.
  model_name: QwenImageEdit #Flux
  # The inference script also constructs the model with aux_time_embed=False.
  aux_time_embed: false
  # NOTE(review): if the trainer uses PEFT's default LoRA scaling
  # (lora_alpha / lora_rank), these values give a scale of 1.0 — confirm.
  lora_rank: 64
  lora_alpha: 64
Image

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions