Bug Report
Describe the bug
When training with --compile enabled, training crashes at the first epoch in which the GAN loss is activated. The error is raised inside calculate_adaptive_weight(), which calls torch.autograd.grad(..., retain_graph=True) twice to compute the gradient norms used for adaptive GAN loss weighting.
The backward function compiled by AOT Autograd uses the donated-buffer optimization, which is enabled by default in PyTorch ≥ 2.6 and is incompatible with retain_graph=True.
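For context, the adaptive-weight pattern below is a minimal sketch modeled on the widely used VQGAN (taming-transformers) formulation; the actual function in src/train_stage1.py may differ. The toy decoder/discriminator, the disc_weight parameter, the 1e-4 epsilon and the 1e4 clamp are assumptions for illustration. It runs in eager mode and shows why retain_graph=True is needed: both gradient calls must leave the graph intact for the final backward() on the total loss.

```python
import torch
import torch.nn as nn


def calculate_adaptive_weight(recon_loss, gan_loss, last_layer, disc_weight=1.0):
    # Gradient of each loss w.r.t. the same parameter (hypothetical signature).
    # retain_graph=True keeps the graph alive for the later total-loss backward();
    # this is the pattern that conflicts with donated buffers under --compile.
    recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
    gan_grads = torch.autograd.grad(gan_loss, last_layer, retain_graph=True)[0]
    # Scale the GAN term so its gradient norm roughly matches the reconstruction one.
    weight = recon_grads.norm() / (gan_grads.norm() + 1e-4)
    weight = torch.clamp(weight, 0.0, 1e4).detach()
    return weight * disc_weight


# Toy decoder and discriminator (hypothetical, for illustration only).
torch.manual_seed(0)
decoder = nn.Linear(8, 8)
disc = nn.Linear(8, 1)
x = torch.randn(4, 8)

recon = decoder(x)
recon_loss = (recon - x).abs().mean()
gan_loss = -disc(recon).mean()

w = calculate_adaptive_weight(recon_loss, gan_loss, decoder.weight)
total = recon_loss + w * gan_loss
total.backward()  # succeeds in eager mode because the graph was retained
```

Without --compile this runs fine; with a compiled model, the same retain_graph=True calls hit the donated-buffer check.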
Environment
PyTorch: 2.8.0+cu129
Python: 3.10
Training script: src/train_stage1.py
To Reproduce
1. Enable --compile in the training config.
2. Set gan_start_epoch to any epoch > 0 (so the GAN loss is delayed).
3. Training runs normally until the GAN loss activates, then crashes immediately.
Error Message
RuntimeError: This backward function was compiled with non-empty donated buffers
which requires create_graph=False and retain_graph=False. Please keep
backward(create_graph=False, retain_graph=False) across all backward() function
calls, or set torch._functorch.config.donated_buffer=False to disable donated buffer.
Full traceback points to:
File "src/train_stage1.py", line 74, in calculate_adaptive_weight
recon_grads = torch.autograd.grad(recon_loss, layer, retain_graph=True)[0]