[diffusion, tests, doc] feat: add Qwen-Image SFT example#217
Conversation
Add a standalone Qwen-Image SFT example for CoRT-style t2i/edit data, with LoRA/FSDP/NPU launch scripts and CPU unit tests. Co-authored-by: GitHub Copilot Signed-off-by: zhiyida2004 <zyd2004@mail.ustc.edu.cn>
There was a problem hiding this comment.
Code Review
This pull request introduces a standalone SFT trainer for Qwen-Image using a diffusion flow-matching MSE objective, along with supporting datasets, training scripts for GPU/NPU, and unit tests. The review feedback identifies several critical issues: a mismatch in global_step tracking when gradient accumulation is enabled, incorrect saving and loading of optimizer states under FSDP (which requires using FSDP-specific state dict utilities), and a performance bottleneck in latent encoding caused by recreating normalization tensors on every batch instead of caching them.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| optimizer.zero_grad(set_to_none=True) | ||
| for epoch in range(args.total_epochs): | ||
| train_sampler.set_epoch(epoch) | ||
| iterator = tqdm(train_loader, disable=rank != 0, desc=f"epoch {epoch + 1}/{args.total_epochs}") | ||
| for batch in iterator: | ||
| grad_norm = torch.tensor(0.0, device=device) | ||
| loss = compute_batch_loss(pipe, batch, args, device, dtype) | ||
| scaled_loss = loss / args.gradient_accumulation_steps | ||
| scaled_loss.backward() | ||
|
|
||
| if (global_step + 1) % args.gradient_accumulation_steps == 0: | ||
| if args.max_grad_norm > 0: | ||
| if use_fsdp: | ||
| grad_norm = pipe.transformer.clip_grad_norm_(args.max_grad_norm) | ||
| else: | ||
| grad_norm = torch.nn.utils.clip_grad_norm_(pipe.transformer.parameters(), args.max_grad_norm) | ||
| else: | ||
| grad_norm = torch.tensor(0.0, device=device) | ||
| optimizer.step() | ||
| scheduler.step() | ||
| optimizer.zero_grad(set_to_none=True) | ||
|
|
||
| global_step += 1 | ||
| if rank == 0 and global_step % args.log_freq == 0: | ||
| lr = scheduler.get_last_lr()[0] | ||
| iterator.set_postfix(loss=f"{loss.detach().item():.4f}", lr=f"{lr:.2e}") | ||
| LOGGER.info( | ||
| "step=%d loss=%.6f lr=%.6e grad_norm=%.4f", | ||
| global_step, | ||
| loss.detach().item(), | ||
| lr, | ||
| float(grad_norm.detach().item()) if torch.is_tensor(grad_norm) else float(grad_norm), | ||
| ) | ||
|
|
||
| if args.test_freq > 0 and val_loader is not None and global_step % args.test_freq == 0: | ||
| val_loss = validate(pipe, val_loader, args, device, dtype) | ||
| if rank == 0: | ||
| LOGGER.info("step=%d val/loss=%.6f", global_step, val_loss) | ||
|
|
||
| if args.save_freq > 0 and global_step % args.save_freq == 0: | ||
| save_checkpoint(pipe, optimizer, scheduler, args, global_step, rank, use_fsdp) | ||
|
|
||
| if global_step >= args.total_training_steps: | ||
| save_checkpoint(pipe, optimizer, scheduler, args, global_step, rank, use_fsdp) | ||
| cleanup_distributed() | ||
| return |
There was a problem hiding this comment.
There is a mismatch between global_step (which is incremented on every batch) and scheduler.step() (which is called only on optimizer steps, i.e., once every gradient_accumulation_steps batches). Since command-line arguments like total_training_steps and warmup_steps are compared against global_step in some places but used as optimizer steps in the scheduler, the learning rate schedule will not decay properly and training will terminate prematurely. Tracking accumulation steps separately and incrementing global_step only on optimizer updates resolves this mismatch.
optimizer.zero_grad(set_to_none=True)
batch_accumulator = 0
for epoch in range(args.total_epochs):
train_sampler.set_epoch(epoch)
iterator = tqdm(train_loader, disable=rank != 0, desc=f"epoch {epoch + 1}/{args.total_epochs}")
for batch in iterator:
grad_norm = torch.tensor(0.0, device=device)
loss = compute_batch_loss(pipe, batch, args, device, dtype)
scaled_loss = loss / args.gradient_accumulation_steps
scaled_loss.backward()
batch_accumulator += 1
if batch_accumulator % args.gradient_accumulation_steps == 0:
if args.max_grad_norm > 0:
if use_fsdp:
grad_norm = pipe.transformer.clip_grad_norm_(args.max_grad_norm)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(pipe.transformer.parameters(), args.max_grad_norm)
else:
grad_norm = torch.tensor(0.0, device=device)
optimizer.step()
scheduler.step()
optimizer.zero_grad(set_to_none=True)
global_step += 1
if rank == 0 and global_step % args.log_freq == 0:
lr = scheduler.get_last_lr()[0]
iterator.set_postfix(loss=f"{loss.detach().item():.4f}", lr=f"{lr:.2e}")
LOGGER.info(
"step=%d loss=%.6f lr=%.6e grad_norm=%.4f",
global_step,
loss.detach().item(),
lr,
float(grad_norm.detach().item()) if torch.is_tensor(grad_norm) else float(grad_norm),
)
if args.test_freq > 0 and val_loader is not None and global_step % args.test_freq == 0:
val_loss = validate(pipe, val_loader, args, device, dtype)
if rank == 0:
LOGGER.info(
"step=%d val/loss=%.6f",
global_step,
val_loss,
)
if args.save_freq > 0 and global_step % args.save_freq == 0:
save_checkpoint(pipe, optimizer, scheduler, args, global_step, rank, use_fsdp)
if global_step >= args.total_training_steps:
save_checkpoint(pipe, optimizer, scheduler, args, global_step, rank, use_fsdp)
cleanup_distributed()
return| transformer = pipe.transformer | ||
| state_dict = None | ||
| if use_fsdp: | ||
| save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) | ||
| with FSDP.state_dict_type(transformer, StateDictType.FULL_STATE_DICT, save_policy): | ||
| state_dict = transformer.state_dict() | ||
| else: | ||
| state_dict = transformer.state_dict() | ||
|
|
||
| if rank == 0: | ||
| unwrapped = unwrap_transformer(transformer) | ||
| unwrapped.save_pretrained(ckpt_dir / "transformer", state_dict=state_dict, safe_serialization=True) | ||
| torch.save( | ||
| { | ||
| "optimizer": optimizer.state_dict(), | ||
| "scheduler": scheduler.state_dict(), | ||
| "step": step, | ||
| }, | ||
| ckpt_dir / "trainer_state.pt", | ||
| ) |
There was a problem hiding this comment.
Under FSDP, calling optimizer.state_dict() directly on rank 0 only saves the sharded optimizer state for that rank. To correctly save the full optimizer state dict, you must use FSDP.optim_state_dict(transformer, optimizer) on all ranks within the FSDP state dict type context.
transformer = pipe.transformer
state_dict = None
optim_state = None
if use_fsdp:
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(transformer, StateDictType.FULL_STATE_DICT, save_policy):
state_dict = transformer.state_dict()
optim_state = FSDP.optim_state_dict(transformer, optimizer)
else:
state_dict = transformer.state_dict()
optim_state = optimizer.state_dict()
if rank == 0:
unwrapped = unwrap_transformer(transformer)
unwrapped.save_pretrained(ckpt_dir / "transformer", state_dict=state_dict, safe_serialization=True)
torch.save(
{
"optimizer": optim_state,
"scheduler": scheduler.state_dict(),
"step": step,
},
ckpt_dir / "trainer_state.pt",
)| def load_training_state(path: str, optimizer, scheduler, rank: int) -> int: | ||
| state_path = Path(path) / "trainer_state.pt" | ||
| if not state_path.exists(): | ||
| LOGGER.warning("No trainer_state.pt found under %s; model resume must be handled via model_name_or_path.", path) | ||
| return 0 | ||
| state = torch.load(state_path, map_location="cpu") | ||
| optimizer.load_state_dict(state["optimizer"]) | ||
| scheduler.load_state_dict(state["scheduler"]) | ||
| step = int(state.get("step", 0)) | ||
| if rank == 0: | ||
| LOGGER.info("Loaded trainer state from %s at step %d", state_path, step) | ||
| return step |
There was a problem hiding this comment.
When resuming training under FSDP, loading a full optimizer state dict directly into a sharded optimizer will fail or be incorrect. You must use FSDP.optim_state_dict_to_load to translate the full optimizer state dict into the sharded/flattened form expected by FSDP.
| def load_training_state(path: str, optimizer, scheduler, rank: int) -> int: | |
| state_path = Path(path) / "trainer_state.pt" | |
| if not state_path.exists(): | |
| LOGGER.warning("No trainer_state.pt found under %s; model resume must be handled via model_name_or_path.", path) | |
| return 0 | |
| state = torch.load(state_path, map_location="cpu") | |
| optimizer.load_state_dict(state["optimizer"]) | |
| scheduler.load_state_dict(state["scheduler"]) | |
| step = int(state.get("step", 0)) | |
| if rank == 0: | |
| LOGGER.info("Loaded trainer state from %s at step %d", state_path, step) | |
| return step | |
| def load_training_state(transformer, optimizer, scheduler, path: str, rank: int, use_fsdp: bool) -> int: | |
| state_path = Path(path) / "trainer_state.pt" | |
| if not state_path.exists(): | |
| LOGGER.warning("No trainer_state.pt found under %s; model resume must be handled via model_name_or_path.", path) | |
| return 0 | |
| state = torch.load(state_path, map_location="cpu") | |
| if use_fsdp: | |
| sharded_osd = FSDP.optim_state_dict_to_load(transformer, optimizer, state["optimizer"]) | |
| optimizer.load_state_dict(sharded_osd) | |
| else: | |
| optimizer.load_state_dict(state["optimizer"]) | |
| scheduler.load_state_dict(state["scheduler"]) | |
| step = int(state.get("step", 0)) | |
| if rank == 0: | |
| LOGGER.info("Loaded trainer state from %s at step %d", state_path, step) | |
| return step |
|
|
||
| optimizer = create_optimizer(pipe, args) | ||
| scheduler = create_scheduler(optimizer, args) | ||
| global_step = load_training_state(args.resume_from, optimizer, scheduler, rank) if args.resume_from else 0 |
There was a problem hiding this comment.
Update the call site of load_training_state to pass the transformer and the use_fsdp flag.
| global_step = load_training_state(args.resume_from, optimizer, scheduler, rank) if args.resume_from else 0 | |
| global_step = load_training_state(pipe.transformer, optimizer, scheduler, args.resume_from, rank, use_fsdp) if args.resume_from else 0 |
| z_dim = int(getattr(pipe.vae.config, "z_dim", image_latents.shape[1])) | ||
| latents_mean = ( | ||
| torch.tensor(pipe.vae.config.latents_mean).view(1, z_dim, 1, 1, 1).to(image_latents.device, image_latents.dtype) | ||
| ) | ||
| latents_std = ( | ||
| torch.tensor(pipe.vae.config.latents_std).view(1, z_dim, 1, 1, 1).to(image_latents.device, image_latents.dtype) | ||
| ) | ||
| image_latents = (image_latents - latents_mean) / latents_std |
There was a problem hiding this comment.
Recreating latents_mean and latents_std tensors from configuration lists on CPU and copying them to GPU on every single batch is inefficient and introduces host-device synchronization overhead. Caching these tensors on pipe.vae avoids this overhead.
| z_dim = int(getattr(pipe.vae.config, "z_dim", image_latents.shape[1])) | |
| latents_mean = ( | |
| torch.tensor(pipe.vae.config.latents_mean).view(1, z_dim, 1, 1, 1).to(image_latents.device, image_latents.dtype) | |
| ) | |
| latents_std = ( | |
| torch.tensor(pipe.vae.config.latents_std).view(1, z_dim, 1, 1, 1).to(image_latents.device, image_latents.dtype) | |
| ) | |
| image_latents = (image_latents - latents_mean) / latents_std | |
| z_dim = int(getattr(pipe.vae.config, "z_dim", image_latents.shape[1])) | |
| if not hasattr(pipe.vae, "_latents_mean_tensor"): | |
| pipe.vae._latents_mean_tensor = torch.tensor(pipe.vae.config.latents_mean).view(1, z_dim, 1, 1, 1) | |
| pipe.vae._latents_std_tensor = torch.tensor(pipe.vae.config.latents_std).view(1, z_dim, 1, 1, 1) | |
| latents_mean = pipe.vae._latents_mean_tensor.to(device=image_latents.device, dtype=image_latents.dtype) | |
| latents_std = pipe.vae._latents_std_tensor.to(device=image_latents.device, dtype=image_latents.dtype) | |
| image_latents = (image_latents - latents_mean) / latents_std |
|
Thanks for the PR. Can you attach the training results for reference? |
What does this PR do?
This PR adds a standalone Qwen-Image SFT example for CoRT-style
t2i/editimage data.Main changes:
examples/qwen_image_sft_trainer/qwen_image_sft.py, a PyTorch/diffusers SFT entrypoint with LoRA, FSDP, CUDA/NPU/CPU device selection, checkpoint save/resume, and CoRT-style dataset expansion.run_qwen_image_sft_lora.shrun_qwen_image_sft_lora_npu.shThis is a standalone recipe-style trainer, not a Ray/Hydra worker integration.
AI assistance was used to draft and prepare this change. The submitter has reviewed the changed files.
Checklist Before Starting
[{modules}] {type}: {description}.Test
PATH=/Users/zyd/miniconda3/bin:$PATH pre-commit run --files \ examples/qwen_image_sft_trainer/README.md \ examples/qwen_image_sft_trainer/qwen_image_sft.py \ examples/qwen_image_sft_trainer/run_qwen_image_sft_lora.sh \ examples/qwen_image_sft_trainer/run_qwen_image_sft_lora_npu.sh \ tests/trainer/diffusion/test_qwen_image_sft_example_on_cpu.py