
RFC: Update-based training loop refactor, scale loss and token count calculation rather than gradients #2543

Closed
wants to merge 2 commits

Conversation

nathan-az
Contributor

@nathan-az nathan-az commented Apr 1, 2025

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (enhancement)

#2515

Changelog

The motivation behind this PR was to remove gradient scaling. num_tokens is corrected to count the actual number of tokens in the unique training examples for each update step, and the loss is scaled accordingly (and adjusted for the over-counting from non-DP workers). This required a larger refactor of the training loop: counting up to update_steps rather than iterating through the dataloader, and prefetching the gradient accumulation batches for each update step as a group.
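
A minimal sketch of the pattern (not the recipe's exact code): it assumes a sum-reduced per-token loss, a `labels` key with -100 as the ignore index, and a simple world-size adjustment for gradient averaging; all names here are illustrative.

```python
from itertools import islice

import torch
import torch.distributed as dist


def train(model, optimizer, loss_fn, dataloader, update_steps, grad_accum_steps, device):
    data_iter = iter(dataloader)
    for _ in range(update_steps):
        # Prefetch the micro-batches that make up this update step.
        batches = list(islice(data_iter, grad_accum_steps))
        if not batches:
            break

        # Actual number of target tokens across the whole update step, counted on CPU
        # (-100 as an assumed ignore index).
        num_tokens = sum((b["labels"] != -100).sum() for b in batches)
        num_tokens = torch.tensor(float(num_tokens), device=device)
        if dist.is_initialized():
            # Denominator is the global token count across all data-parallel ranks.
            dist.all_reduce(num_tokens)

        # DDP/FSDP average gradients across ranks, so multiply back by world size to
        # recover the intended global-token normalization (how the real recipe adjusts
        # for non-DP workers may differ).
        scale = (dist.get_world_size() if dist.is_initialized() else 1.0) / num_tokens

        for batch in batches:
            batch = {k: v.to(device) for k, v in batch.items()}  # to GPU only for the forward
            logits = model(batch["tokens"])
            loss = loss_fn(logits, batch["labels"])  # assumed sum-reduced per-token loss
            (loss * scale).backward()                # normalization folded into the loss

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
```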

Upsides

  • Less nesting
  • More logical grouping of gradient accumulation / forward/backward steps before update
  • Variables now hold the correct value (e.g. num_tokens now contains the actual number of tokens across the update step, on each rank)
  • Implicitly corrects calculation of TPS per GPU
  • Use of new context setters for methods like HSDP (e.g. is_last_backwards_step) can be localised to the forward/backward loop, which should be quite clean (see the sketch after this list)
  • Pattern can relatively painlessly be reverted
  • Added a num_tokens metric, which is nice for gauging how effective packing is for a given dataset
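
To make the context-setter point concrete, a minimal sketch (again illustrative: `compute_loss` is a hypothetical helper, and `set_requires_gradient_sync` is an FSDP2-style hook that may not exist on every wrapped model):

```python
def backward_micro_batches(model, batches, scale, compute_loss):
    """Run the forward/backward passes for one update step."""
    for i, batch in enumerate(batches):
        is_last_backward = i == len(batches) - 1
        if hasattr(model, "set_requires_gradient_sync"):
            # Only reduce/sync gradients on the final backward of the update step.
            model.set_requires_gradient_sync(is_last_backward)
        loss = compute_loss(model, batch)
        (loss * scale).backward()
```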

Downsides

  • Currently does not account for future use of loss parallel, context parallel, etc.
  • Profiler wait/warmup steps are now relative to update steps rather than forward steps (see the profiler sketch after this list). This may actually be desirable (I assumed this was the case until this refactor, and it makes the use of the term "steps" more consistent)
  • Prefetching examples technically uses slightly more memory (pedantic note, but still)
  • Likely to cause merge conflicts with in-flight work which includes changes to the training loop for this recipe
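
For reference on the profiler point, the schedule below uses the standard torch.profiler API; under this refactor `prof.step()` would be called once per optimizer update, so wait/warmup/active are counted in update steps rather than forward passes:

```python
import torch

prof = torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
)
# prof.step() would be called after optimizer.step(), once per update step.
```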

I recognise that this is beyond the scope of the initial motivation, but I think this format may be preferable in terms of readability and flexibility.

Losses are near-identical (seeded, with cudnn determinism mode enabled). There is a very minor TPS benefit from avoiding the grad scaling. The benefit is more substantial without gradient accumulation, since the parameter read/write occurs more frequently. It may be more notable with much larger models.
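
To illustrate where the saving comes from (a sketch, not the recipe's exact implementation): scaling gradients requires an extra pass that reads and writes every parameter's `.grad` before each optimizer step, while scaling the loss is a single scalar multiply folded into the backward.

```python
import torch


@torch.no_grad()
def scale_grads_(model: torch.nn.Module, scale: float) -> None:
    # Extra per-parameter read/write, paid once per optimizer update.
    for p in model.parameters():
        if p.grad is not None:
            p.grad.mul_(scale)


# Loss-scaling alternative: no separate pass over parameters is needed.
#     (loss * scale).backward()
```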

[image: loss curve and TPS comparison]

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings


@nathan-az
Contributor Author

FYI @joecummings this should allow HSDP and other collective-based optimisation strategies to be quite clean :)

@nathan-az nathan-az changed the title RFC: Scaling loss and token count calculation rather than gradients in update step RFC: Scale loss and token count calculation rather than gradients, update-based training loop refactor Apr 1, 2025
@pbontrager
Contributor

I know you mentioned that the memory increase is a minor thing, but I wonder if it'll always be minor if you're dealing with multimodal data. Another concern is that prefetching puts more pressure on the dataloader, as training is blocked until it prepares n batches instead of 1 batch that it can prepare async to the model forward/backward. Once again, this is likely a minor thing in a text-only regime but becomes a bigger issue if data processing ever gets expensive.

I do think our gradient accumulation needs to be cleaned up though. I mocked up a utility for it here. "is_last_backwards_step" could be included in the context manager as well.
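
The linked mock-up isn't reproduced here, but as a rough illustration of the shape such a utility could take with the "is last backward" signal folded into the context manager (the `set_requires_gradient_sync` hook is an FSDP2-style method and may not exist on every wrapped model):

```python
from contextlib import contextmanager

import torch


@contextmanager
def grad_accum_context(model: torch.nn.Module, is_last_backward: bool):
    """Skip gradient reduction on all but the final micro-batch of an update step."""
    set_sync = getattr(model, "set_requires_gradient_sync", None)
    if set_sync is not None:
        set_sync(is_last_backward)
    try:
        yield
    finally:
        if set_sync is not None:
            set_sync(True)  # restore default sync behaviour

# Usage inside the micro-batch loop:
#     with grad_accum_context(model, is_last_backward=(i == len(batches) - 1)):
#         (loss * scale).backward()
```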

@nathan-az
Contributor Author

nathan-az commented Apr 2, 2025

Thanks for the comment @pbontrager! I like the idea of separating the gradient accumulation logic per your branch (will leave that discussion for a separate thread). I have some thoughts on your points:

memory increase is a minor thing, but I wonder if it'll always be minor if you're dealing with multimodal data

Agreed - in extreme cases inputs might consume a non-negligible amount of VRAM.

I've delayed the move to GPU so that only the example in the forward pass is on device. I expect this is fine since RAM is more plentiful. I was concerned about the sum op on CPU, but in testing on my very unimpressive home machine (Ryzen 5600X), performing this operation on batches of (16, 2**17) indices took less than 2 ms.
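
A rough reproduction of that CPU-side check (sizes follow the comment above; -100 as an assumed ignore index, and the ~2 ms figure will of course vary by machine):

```python
import time

import torch

labels = torch.randint(0, 32_000, (16, 2**17))  # a CPU batch of token indices
start = time.perf_counter()
num_tokens = (labels != -100).sum().item()      # the sum op in question, on CPU
elapsed_ms = (time.perf_counter() - start) * 1e3
print(f"counted {num_tokens} tokens in {elapsed_ms:.2f} ms")
```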

Another concern is that prefetching puts more pressure on the dataloader as training is blocked until it prepares n batches

Do we not pay this cost either way? (Is the total time of blocked operations not equal whether we front-load the examples for an update step or fetch them between forward/backward passes?) I'm pretty sure the changes in this PR don't actually make this any slower, since the total blocking retrieval time should be the same, but please correct me.

Suppose this were an issue in a multi-modal setting. Could it be solved with a combination of prefetch_factor (I assume this is async), and num_workers? I think we would want this whether we followed the pattern in this PR or not.
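
For reference, these are the standard torch.utils.data.DataLoader knobs: with num_workers > 0, batches are prepared asynchronously in worker processes, and prefetch_factor controls how many batches each worker keeps queued ahead of the training loop. The values below are placeholders.

```python
from torch.utils.data import DataLoader


def build_loader(dataset, collate_fn):
    return DataLoader(
        dataset,
        batch_size=4,
        num_workers=4,      # async workers overlap data prep with forward/backward
        prefetch_factor=2,  # batches each worker keeps queued (requires num_workers > 0)
        pin_memory=True,
        collate_fn=collate_fn,
    )
```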

Let me know what you think on the above, or if you have other thoughts!

@nathan-az nathan-az changed the title RFC: Scale loss and token count calculation rather than gradients, update-based training loop refactor RFC: Update-based training loop refactor, scale loss and token count calculation rather than gradients Apr 3, 2025
@ebsmothers
Contributor

Hey @nathan-az sorry for the delay in getting to this one. Tbh I am a bit on the fence here -- I took a look at HF's prefetching approach when first implementing our version and I didn't like certain aspects (e.g. they tied the total number of tokens into the loss). It seems like your approach avoids that, which is nice (and it gets us out of the usage of scale_grads, which can cause its own perf hit). Overall though I am a bit wary -- maybe let's have a broader design discussion about what the right UX is prior to moving forward with this PR. Personally I think the prefetching logic makes the training loop harder to parse (I acknowledge the current token normalization logic is also tricky though), but maybe there is a way we can sequence things here. E.g. what would a good training utility for prefetching look like? If we can agree on that, then maybe we could see what the experience is integrating that utility into the recipe. Lmk if that makes sense to you. And thanks for opening this RFC! I think it's a worthwhile discussion for us to have.
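
Purely as a discussion aid for "what would a good training utility for prefetching look like": one possible shape is an iterator that yields the group of micro-batches for each update step, keeping the recipe's loop flat. This is a sketch, not a proposed torchtune API.

```python
from itertools import islice
from typing import Iterable, Iterator, List


def iter_update_batches(dataloader: Iterable, grad_accum_steps: int) -> Iterator[List]:
    """Yield lists of up to `grad_accum_steps` batches, one list per optimizer update."""
    it = iter(dataloader)
    while True:
        group = list(islice(it, grad_accum_steps))
        if not group:
            return
        yield group
```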

@joecummings
Contributor


Also, please feel free to join our Discord if you haven't already: https://discord.gg/N4aVcgHS! Might be faster to communicate there.

@nathan-az
Contributor Author

Thanks @joecummings (have joined the Discord server now!) and @ebsmothers for the comments!

I'll close this PR for now so it doesn't clog the list, but hopefully it can be used as a reference for a future Discussion thread (and if one is made, I'll happily weigh in!).

To summarise my thoughts: I agree the prefetch is not the most intuitive, although I personally find the logical grouping of the forward/backward steps followed by the update step more interpretable. In addition, I think this loss scaling is a bit more interpretable than the grad scaling done currently (and it is more performant).

However, I also agree that there is likely a better pattern that achieves all of this. And I understand it's likely not the highest priority with open PRs like LLaMA-4 support (including tp2ep, woo!) and fp8 training support in flight!

@nathan-az nathan-az closed this Apr 9, 2025