Skip to content

[diffusion, tests, doc] feat: add Qwen-Image SFT example#217

Open
zyd-ustc wants to merge 1 commit into
verl-project:mainfrom
zyd-ustc:qwen-image-sft-example
Open

[diffusion, tests, doc] feat: add Qwen-Image SFT example#217
zyd-ustc wants to merge 1 commit into
verl-project:mainfrom
zyd-ustc:qwen-image-sft-example

Conversation

@zyd-ustc

Copy link
Copy Markdown

What does this PR do?

This PR adds a standalone Qwen-Image SFT example for CoRT-style t2i / edit image data.

Main changes:

  • Add examples/qwen_image_sft_trainer/qwen_image_sft.py, a PyTorch/diffusers SFT entrypoint with LoRA, FSDP, CUDA/NPU/CPU device selection, checkpoint save/resume, and CoRT-style dataset expansion.
  • Add GPU and Ascend NPU LoRA launch scripts:
    • run_qwen_image_sft_lora.sh
    • run_qwen_image_sft_lora_npu.sh
  • Add README documentation for supported CoRT k-turn rows, atomic rows, weights, LoRA launch, NPU launch, and checkpoints.
  • Add CPU unit tests covering temporary dummy CoRT/atomic data expansion and the loss path.

This is a standalone recipe-style trainer, not a Ray/Hydra worker integration.

AI assistance was used to draft and prepare this change. The submitter has reviewed the changed files.

Checklist Before Starting

Test

PATH=/Users/zyd/miniconda3/bin:$PATH pre-commit run --files \
  examples/qwen_image_sft_trainer/README.md \
  examples/qwen_image_sft_trainer/qwen_image_sft.py \
  examples/qwen_image_sft_trainer/run_qwen_image_sft_lora.sh \
  examples/qwen_image_sft_trainer/run_qwen_image_sft_lora_npu.sh \
  tests/trainer/diffusion/test_qwen_image_sft_example_on_cpu.py

Add a standalone Qwen-Image SFT example for CoRT-style t2i/edit data, with LoRA/FSDP/NPU launch scripts and CPU unit tests.

Co-authored-by: GitHub Copilot

Signed-off-by: zhiyida2004 <zyd2004@mail.ustc.edu.cn>

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a standalone SFT trainer for Qwen-Image using a diffusion flow-matching MSE objective, along with supporting datasets, training scripts for GPU/NPU, and unit tests. The review feedback identifies several critical issues: a mismatch in global_step tracking when gradient accumulation is enabled, incorrect saving and loading of optimizer states under FSDP (which requires using FSDP-specific state dict utilities), and a performance bottleneck in latent encoding caused by recreating normalization tensors on every batch instead of caching them.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +1059 to +1104
optimizer.zero_grad(set_to_none=True)
for epoch in range(args.total_epochs):
train_sampler.set_epoch(epoch)
iterator = tqdm(train_loader, disable=rank != 0, desc=f"epoch {epoch + 1}/{args.total_epochs}")
for batch in iterator:
grad_norm = torch.tensor(0.0, device=device)
loss = compute_batch_loss(pipe, batch, args, device, dtype)
scaled_loss = loss / args.gradient_accumulation_steps
scaled_loss.backward()

if (global_step + 1) % args.gradient_accumulation_steps == 0:
if args.max_grad_norm > 0:
if use_fsdp:
grad_norm = pipe.transformer.clip_grad_norm_(args.max_grad_norm)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(pipe.transformer.parameters(), args.max_grad_norm)
else:
grad_norm = torch.tensor(0.0, device=device)
optimizer.step()
scheduler.step()
optimizer.zero_grad(set_to_none=True)

global_step += 1
if rank == 0 and global_step % args.log_freq == 0:
lr = scheduler.get_last_lr()[0]
iterator.set_postfix(loss=f"{loss.detach().item():.4f}", lr=f"{lr:.2e}")
LOGGER.info(
"step=%d loss=%.6f lr=%.6e grad_norm=%.4f",
global_step,
loss.detach().item(),
lr,
float(grad_norm.detach().item()) if torch.is_tensor(grad_norm) else float(grad_norm),
)

if args.test_freq > 0 and val_loader is not None and global_step % args.test_freq == 0:
val_loss = validate(pipe, val_loader, args, device, dtype)
if rank == 0:
LOGGER.info("step=%d val/loss=%.6f", global_step, val_loss)

if args.save_freq > 0 and global_step % args.save_freq == 0:
save_checkpoint(pipe, optimizer, scheduler, args, global_step, rank, use_fsdp)

if global_step >= args.total_training_steps:
save_checkpoint(pipe, optimizer, scheduler, args, global_step, rank, use_fsdp)
cleanup_distributed()
return

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There is a mismatch between global_step (which is incremented on every batch) and scheduler.step() (which is called only on optimizer steps, i.e., once every gradient_accumulation_steps batches). Since command-line arguments like total_training_steps and warmup_steps are compared against global_step in some places but used as optimizer steps in the scheduler, the learning rate schedule will not decay properly and training will terminate prematurely. Tracking accumulation steps separately and incrementing global_step only on optimizer updates resolves this mismatch.

    optimizer.zero_grad(set_to_none=True)
    batch_accumulator = 0
    for epoch in range(args.total_epochs):
        train_sampler.set_epoch(epoch)
        iterator = tqdm(train_loader, disable=rank != 0, desc=f"epoch {epoch + 1}/{args.total_epochs}")
        for batch in iterator:
            grad_norm = torch.tensor(0.0, device=device)
            loss = compute_batch_loss(pipe, batch, args, device, dtype)
            scaled_loss = loss / args.gradient_accumulation_steps
            scaled_loss.backward()
            batch_accumulator += 1

            if batch_accumulator % args.gradient_accumulation_steps == 0:
                if args.max_grad_norm > 0:
                    if use_fsdp:
                        grad_norm = pipe.transformer.clip_grad_norm_(args.max_grad_norm)
                    else:
                        grad_norm = torch.nn.utils.clip_grad_norm_(pipe.transformer.parameters(), args.max_grad_norm)
                else:
                    grad_norm = torch.tensor(0.0, device=device)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1

                if rank == 0 and global_step % args.log_freq == 0:
                    lr = scheduler.get_last_lr()[0]
                    iterator.set_postfix(loss=f"{loss.detach().item():.4f}", lr=f"{lr:.2e}")
                    LOGGER.info(
                        "step=%d loss=%.6f lr=%.6e grad_norm=%.4f",
                        global_step,
                        loss.detach().item(),
                        lr,
                        float(grad_norm.detach().item()) if torch.is_tensor(grad_norm) else float(grad_norm),
                    )

                if args.test_freq > 0 and val_loader is not None and global_step % args.test_freq == 0:
                    val_loss = validate(pipe, val_loader, args, device, dtype)
                    if rank == 0:
                        LOGGER.info(
                            "step=%d val/loss=%.6f",
                            global_step,
                            val_loss,
                        )

                if args.save_freq > 0 and global_step % args.save_freq == 0:
                    save_checkpoint(pipe, optimizer, scheduler, args, global_step, rank, use_fsdp)

                if global_step >= args.total_training_steps:
                    save_checkpoint(pipe, optimizer, scheduler, args, global_step, rank, use_fsdp)
                    cleanup_distributed()
                    return

Comment on lines +926 to +945
transformer = pipe.transformer
state_dict = None
if use_fsdp:
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(transformer, StateDictType.FULL_STATE_DICT, save_policy):
state_dict = transformer.state_dict()
else:
state_dict = transformer.state_dict()

if rank == 0:
unwrapped = unwrap_transformer(transformer)
unwrapped.save_pretrained(ckpt_dir / "transformer", state_dict=state_dict, safe_serialization=True)
torch.save(
{
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"step": step,
},
ckpt_dir / "trainer_state.pt",
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Under FSDP, calling optimizer.state_dict() directly on rank 0 only saves the sharded optimizer state for that rank. To correctly save the full optimizer state dict, you must use FSDP.optim_state_dict(transformer, optimizer) on all ranks within the FSDP state dict type context.

    transformer = pipe.transformer
    state_dict = None
    optim_state = None
    if use_fsdp:
        save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(transformer, StateDictType.FULL_STATE_DICT, save_policy):
            state_dict = transformer.state_dict()
            optim_state = FSDP.optim_state_dict(transformer, optimizer)
    else:
        state_dict = transformer.state_dict()
        optim_state = optimizer.state_dict()

    if rank == 0:
        unwrapped = unwrap_transformer(transformer)
        unwrapped.save_pretrained(ckpt_dir / "transformer", state_dict=state_dict, safe_serialization=True)
        torch.save(
            {
                "optimizer": optim_state,
                "scheduler": scheduler.state_dict(),
                "step": step,
            },
            ckpt_dir / "trainer_state.pt",
        )

Comment on lines +953 to +964
def load_training_state(path: str, optimizer, scheduler, rank: int) -> int:
state_path = Path(path) / "trainer_state.pt"
if not state_path.exists():
LOGGER.warning("No trainer_state.pt found under %s; model resume must be handled via model_name_or_path.", path)
return 0
state = torch.load(state_path, map_location="cpu")
optimizer.load_state_dict(state["optimizer"])
scheduler.load_state_dict(state["scheduler"])
step = int(state.get("step", 0))
if rank == 0:
LOGGER.info("Loaded trainer state from %s at step %d", state_path, step)
return step

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

When resuming training under FSDP, loading a full optimizer state dict directly into a sharded optimizer will fail or be incorrect. You must use FSDP.optim_state_dict_to_load to translate the full optimizer state dict into the sharded/flattened form expected by FSDP.

Suggested change
def load_training_state(path: str, optimizer, scheduler, rank: int) -> int:
state_path = Path(path) / "trainer_state.pt"
if not state_path.exists():
LOGGER.warning("No trainer_state.pt found under %s; model resume must be handled via model_name_or_path.", path)
return 0
state = torch.load(state_path, map_location="cpu")
optimizer.load_state_dict(state["optimizer"])
scheduler.load_state_dict(state["scheduler"])
step = int(state.get("step", 0))
if rank == 0:
LOGGER.info("Loaded trainer state from %s at step %d", state_path, step)
return step
def load_training_state(transformer, optimizer, scheduler, path: str, rank: int, use_fsdp: bool) -> int:
state_path = Path(path) / "trainer_state.pt"
if not state_path.exists():
LOGGER.warning("No trainer_state.pt found under %s; model resume must be handled via model_name_or_path.", path)
return 0
state = torch.load(state_path, map_location="cpu")
if use_fsdp:
sharded_osd = FSDP.optim_state_dict_to_load(transformer, optimizer, state["optimizer"])
optimizer.load_state_dict(sharded_osd)
else:
optimizer.load_state_dict(state["optimizer"])
scheduler.load_state_dict(state["scheduler"])
step = int(state.get("step", 0))
if rank == 0:
LOGGER.info("Loaded trainer state from %s at step %d", state_path, step)
return step


optimizer = create_optimizer(pipe, args)
scheduler = create_scheduler(optimizer, args)
global_step = load_training_state(args.resume_from, optimizer, scheduler, rank) if args.resume_from else 0

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Update the call site of load_training_state to pass the transformer and the use_fsdp flag.

Suggested change
global_step = load_training_state(args.resume_from, optimizer, scheduler, rank) if args.resume_from else 0
global_step = load_training_state(pipe.transformer, optimizer, scheduler, args.resume_from, rank, use_fsdp) if args.resume_from else 0

Comment on lines +727 to +734
z_dim = int(getattr(pipe.vae.config, "z_dim", image_latents.shape[1]))
latents_mean = (
torch.tensor(pipe.vae.config.latents_mean).view(1, z_dim, 1, 1, 1).to(image_latents.device, image_latents.dtype)
)
latents_std = (
torch.tensor(pipe.vae.config.latents_std).view(1, z_dim, 1, 1, 1).to(image_latents.device, image_latents.dtype)
)
image_latents = (image_latents - latents_mean) / latents_std

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Recreating latents_mean and latents_std tensors from configuration lists on CPU and copying them to GPU on every single batch is inefficient and introduces host-device synchronization overhead. Caching these tensors on pipe.vae avoids this overhead.

Suggested change
z_dim = int(getattr(pipe.vae.config, "z_dim", image_latents.shape[1]))
latents_mean = (
torch.tensor(pipe.vae.config.latents_mean).view(1, z_dim, 1, 1, 1).to(image_latents.device, image_latents.dtype)
)
latents_std = (
torch.tensor(pipe.vae.config.latents_std).view(1, z_dim, 1, 1, 1).to(image_latents.device, image_latents.dtype)
)
image_latents = (image_latents - latents_mean) / latents_std
z_dim = int(getattr(pipe.vae.config, "z_dim", image_latents.shape[1]))
if not hasattr(pipe.vae, "_latents_mean_tensor"):
pipe.vae._latents_mean_tensor = torch.tensor(pipe.vae.config.latents_mean).view(1, z_dim, 1, 1, 1)
pipe.vae._latents_std_tensor = torch.tensor(pipe.vae.config.latents_std).view(1, z_dim, 1, 1, 1)
latents_mean = pipe.vae._latents_mean_tensor.to(device=image_latents.device, dtype=image_latents.dtype)
latents_std = pipe.vae._latents_std_tensor.to(device=image_latents.device, dtype=image_latents.dtype)
image_latents = (image_latents - latents_mean) / latents_std

@SamitHuang

Copy link
Copy Markdown
Collaborator

Thanks for the PR. Can you attach the training results for reference?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants