Gradient sync not working on some model architectures with FSDP in TorchTitan #1014
-
cc @mori360
-
@JohanSchalkwyk1 Thanks for the issue.
-
Some update on things I've tried that work. The following change results in weights being updated:

[rank0]: (adapter): OptimizedModule(

Note that here every layer became an FSDP layer. The following does not work:

[rank0]: (adapter): OptimizedModule(

i.e. FSDP is only on the outer layer. Looking at the transformer FSDP code, it doesn't look like it wraps each Linear in an FSDPLinear.
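For reference, a rough sketch of the two wrapping schemes described above, using the FSDP2 `fully_shard` API that TorchTitan builds on. The module names (`model.adapter.layers`, etc.) are placeholders rather than the actual model, and the import path can differ between PyTorch versions:

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard


def wrap_per_layer(model: nn.Module) -> nn.Module:
    # Variant where every layer becomes an FSDP layer: shard each adapter
    # layer individually, then the adapter, then the root module.
    for layer in model.adapter.layers:
        fully_shard(layer)
    fully_shard(model.adapter)
    fully_shard(model)
    return model


def wrap_root_only(model: nn.Module) -> nn.Module:
    # Variant where FSDP sits only on the outermost module, so inner
    # Linear / RMSNorm layers are not individually wrapped.
    fully_shard(model)
    return model
```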
-
The transformer structure would be as follows:

[rank0]: (0): FSDPOptimizedModule(

FSDP is only on the top level.
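One way to confirm how far down the tree FSDP was applied is to walk the module hierarchy and list which submodules were actually converted. A minimal sketch, assuming FSDP2, where `fully_shard` swaps the wrapped module's class to an `FSDP*` subclass (the `FSDPModule` import path may differ by PyTorch version):

```python
from torch.distributed._composable.fsdp import FSDPModule


def report_fsdp_wrapping(model):
    # Print every submodule that fully_shard converted; a wrapped Linear
    # shows up as FSDPLinear, while an unwrapped one stays Linear.
    for name, module in model.named_modules():
        if isinstance(module, FSDPModule):
            print(f"FSDP-wrapped: {name or '<root>'} -> {type(module).__name__}")
```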
-
I have the following model architecture (a Speech Language Model), which I parallelize with FSDP. The basic design looks like this:

FSDPWhisperEncoder -> FSDPAdapter -> FSDPLlamaLLM
Only the weights of the FSDPAdapter (26M) are trained. I set param.requires_grad accordingly.
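As a rough sketch of that freezing step (the `adapter` attribute name is a placeholder, not the real code):

```python
def freeze_all_but_adapter(model):
    # Freeze everything, then re-enable gradients only for the adapter,
    # so only its ~26M parameters receive updates.
    for param in model.parameters():
        param.requires_grad = False
    for param in model.adapter.parameters():
        param.requires_grad = True
```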
I've noticed that the model never trains. I track both the per-layer grad_norm and the weight_norm of the variables, and none of them change (for example, the weight norm of the RMSNorm layer stays constant throughout training). After some time debugging I tried the following composition:
FSDPWhisperEncoder -> DDPAdapter -> FSDPLlamaLLM
which does train correctly (on a smaller model): losses converge, the grad norm goes down, etc. However, larger models seem unstable (NCCL timeouts), so this doesn't feel like a robust solution; it feels like the standard FSDP approach should work. Any advice on how to pinpoint why the model is not updating under this configuration would be much appreciated.
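For reference, a minimal sketch of the per-layer tracking mentioned above; it assumes FSDP2-style DTensor parameters and reports local-shard norms, which is still enough to see whether a weight is changing at all:

```python
import torch


@torch.no_grad()
def log_param_norms(model, step):
    # Log weight norm and grad norm for every trainable parameter; if a
    # trainable weight's norm never changes across steps, it is not being
    # updated.
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        weight = param.to_local() if hasattr(param, "to_local") else param
        grad = param.grad
        if grad is not None and hasattr(grad, "to_local"):
            grad = grad.to_local()
        grad_norm = grad.norm().item() if grad is not None else 0.0
        print(f"step={step} {name}: weight_norm={weight.norm().item():.4f} "
              f"grad_norm={grad_norm:.4f}")
```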