RFC: Update-based training loop refactor, scale loss and token count calculation rather than gradients #2543
Conversation
FYI @joecummings this should allow HSDP and other collective-based optimisation strategies to be quite clean :)
I know you mentioned that the memory increase is a minor thing, but I wonder if it'll always be minor if you're dealing with multimodal data. Another concern is that prefetching puts more pressure on the dataloader, since training is blocked until it prepares n batches instead of 1 batch that it can prepare async to the model forward/backward. Once again this is likely a minor thing in a text-only regime, but it becomes a bigger issue if data processing ever gets expensive. I do think our gradient accumulation needs to be cleaned up though. I mocked up a utility for it here. "is_last_backwards_step" could be included in the context manager as well.
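For illustration only, here is a minimal sketch of what such a context-manager utility could look like. This is not the mocked-up branch referenced above; the name `grad_accumulation_step` and the surrounding structure are hypothetical.

```python
import contextlib

import torch.nn as nn


@contextlib.contextmanager
def grad_accumulation_step(model: nn.Module, step: int, accum_steps: int):
    """Hypothetical helper: suppress gradient sync on all but the last
    accumulation step, and tell the caller whether this backward is the
    final one before the optimizer update."""
    is_last_backwards_step = (step + 1) % accum_steps == 0
    # DDP/FSDP wrappers expose `no_sync()`; fall back to a no-op otherwise.
    ctx = (
        model.no_sync()
        if hasattr(model, "no_sync") and not is_last_backwards_step
        else contextlib.nullcontext()
    )
    with ctx:
        yield is_last_backwards_step


# Usage sketch:
# for i, batch in enumerate(batches):
#     with grad_accumulation_step(model, i, accum_steps) as is_last:
#         (model(batch).loss / accum_steps).backward()
#     if is_last:
#         optimizer.step()
#         optimizer.zero_grad()
```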
Thanks for the comment @pbontrager! I like the idea of separating the gradient accumulation logic per your branch (will leave that discussion for a separate thread). I have some thoughts on your points:
1. Agreed - in extreme cases inputs might consume a non-negligible amount of VRAM. I've delayed the move to GPU so that only the example currently in the forward pass is on device. I expect this is fine since RAM is more plentiful. I was concerned about the sum op on CPU, but testing on my very unimpressive home machine (Ryzen 5600x), performing this operation on batches of shape (16, 2**17) indices required less than 2ms.
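A rough version of that timing check might look like the following (shapes per the comment above; the padding id and exact methodology are assumptions of mine):

```python
import time

import torch

# Token-id batch of the shape described above: (16, 2**17).
tokens = torch.randint(0, 32_000, (16, 2**17))
pad_id = 0  # assumed padding id

start = time.perf_counter()
num_tokens = (tokens != pad_id).sum().item()  # CPU-side count of non-pad tokens
elapsed_ms = (time.perf_counter() - start) * 1000
print(f"{num_tokens} tokens counted in {elapsed_ms:.2f} ms")
```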
2. Do we not pay this either way? (Is the total time of blocked operations not equal whether we front-load the examples for an update step or do it between forward/backward passes?) I'm pretty sure the changes in this PR don't actually make this any slower, since the total blocking retrieval time should be the same, but please correct me. Suppose this were an issue in a multi-modal setting - could it be solved with a combination of the dataloader's async options (e.g. workers and prefetching)?

Let me know what you think on the above, or if you have other thoughts!
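For reference, the standard PyTorch dataloader knobs alluded to above look like this (generic PyTorch usage, not specific to this PR; the dummy dataset is illustrative):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randint(0, 32_000, (1024, 512)))  # dummy token data

# Workers prepare batches asynchronously; each worker keeps `prefetch_factor`
# batches queued ahead of consumption, so expensive preprocessing can overlap
# with the model forward/backward.
loader = DataLoader(
    dataset,
    batch_size=16,
    num_workers=4,
    prefetch_factor=2,  # per-worker; only valid when num_workers > 0
    pin_memory=True,    # speeds up later host-to-device copies
)
```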
Hey @nathan-az sorry for the delay in getting to this one. Tbh I am a bit on the fence here -- I took a look at HF's prefetching approach when first implementing our version and I didn't like certain aspects (e.g. they tied the total number of tokens into the loss). It seems like your approach avoids that, which is nice (and it gets us out of the usage of scale_grads, which can cause its own perf hit).

Overall though I am a bit wary -- maybe let's have a broader design discussion about what the right UX is prior to moving forward with this PR. Personally I think the prefetching logic makes the training loop harder to parse (I acknowledge the current token normalization logic is also tricky though), but maybe there is a way we can sequence things here. E.g. what would a good training utility for prefetching look like? If we can agree on that, then maybe we could see what the experience is integrating that utility into the recipe. Lmk if that makes sense to you. And thanks for opening this RFC! I think it's a worthwhile discussion for us to have.
Also, please feel free to join our Discord if you haven't already: https://discord.gg/N4aVcgHS! Might be faster to communicate there.
Thanks @joecummings (have joined the Discord server now!) and @ebsmothers for the comments! I'll close this PR for now so it doesn't clog the list, but hopefully it can be used as a reference for a future Discussion thread (and if one is made I'll happily weigh in!)

To summarise my thoughts - I agree the prefetch is not the most intuitive, although I personally find the logical grouping of the forward/backward steps then update step more interpretable. In addition, I think this loss scaling is a bit more interpretable than the grad scaling done currently (and it is more performant). However, I also agree that there is likely a better pattern that achieves all of this. And I understand it's likely not the highest priority with open PRs like LLaMA-4 support (including tp2ep, woo!) and fp8 training support in-flight!
Context
What is the purpose of this PR? It addresses #2515.
Changelog
The motivation behind this PR was to remove gradient scaling. `num_tokens` is corrected to calculate the actual number of tokens in unique training examples for the update step, and `loss` is scaled accordingly (and adjusted for over-counting from non-DP workers). This required a larger refactor of the training loop to count up to `update_steps` rather than iterating through the dataloader, then grouping the prefetch of gradient-accumulation batches.
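As a rough sketch of the resulting loop shape (function and batch-key names here are hypothetical, assuming a sum-reduced cross-entropy; the actual recipe also handles the non-DP over-counting adjustment, omitted here):

```python
import itertools

import torch
import torch.nn.functional as F


def token_sum_loss(model, batch, ignore_idx=-100):
    # Cross-entropy *summed* over tokens (not averaged), so that dividing
    # by the true token count later yields a correct per-token mean.
    logits = model(batch["tokens"])  # (B, T, V)
    return F.cross_entropy(
        logits.flatten(0, 1),
        batch["labels"].flatten(),
        reduction="sum",
        ignore_index=ignore_idx,
    )


def train(model, optimizer, dataloader, update_steps, accum_steps, ignore_idx=-100):
    batch_iter = iter(dataloader)
    for _ in range(update_steps):
        # Prefetch all gradient-accumulation batches for this update step.
        batches = list(itertools.islice(batch_iter, accum_steps))
        if not batches:
            break
        # Actual number of trainable tokens across the whole update step.
        num_tokens = sum(
            (b["labels"] != ignore_idx).sum().item() for b in batches
        )
        for batch in batches:
            # Scale the loss by the true token count up front, instead of
            # rescaling gradients after all the backwards are done.
            loss = token_sum_loss(model, batch, ignore_idx) / num_tokens
            loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
```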
Upsides

- `num_tokens` now contains the actual number of tokens across the update step, on each rank
- Logic like `is_last_backwards_step` can be localised to the forward/backward loop, which should be quite clean
- The `num_tokens` metric is nice for gauging how effective packing a given dataset is
Downsides

- Counting `steps` up to `update_steps` rather than iterating the dataloader changes the loop structure (though arguably it makes `steps` more consistent)

I recognise that this is beyond the scope of the initial motivation, but I think this format may be preferable in terms of readability and flexibility.
Losses are near identical (seeded, with cudnn determinism mode enabled). There is a very minor TPS benefit to avoiding the grad scaling. This is more substantial without gradient accumulation, since the parameter read/write occurs more frequently. It may be more notable with much larger models.
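To make the comparison concrete: gradient scaling requires a read/write over every parameter's `.grad` after backward, while loss scaling is a single scalar multiply folded into the backward itself. A generic sketch of the former (not torchtune's actual `scale_grads` implementation):

```python
import torch


@torch.no_grad()
def scale_grads_(model: torch.nn.Module, scale: float) -> None:
    # One pass over every parameter gradient, paid on each optimizer step.
    for p in model.parameters():
        if p.grad is not None:
            p.grad.mul_(scale)


# The loss-scaling alternative folds the factor into backward:
#     (loss / num_tokens).backward()
# a single scalar op, with no post-backward pass over the grads.
```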
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- run pre-commit hooks and linters (make sure you've first installed via `pre-commit install`)
- run unit tests via `pytest tests`
- run recipe tests via `pytest tests -m integration_test`
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.