-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reduce JAX post-processing memory usage #7311
Conversation
] |
Looks like test failures are due to an older version of jax and a recent blackJAX PR to fix an argument deprecation to jnp.clip: blackjax-devs/blackjax#664 |
Also added in a bugfix for pm.sample not respecting compute_convergence_checks with numpyro/blackjax sampler |
Thanks @andrewdipper! |
Probably need to wait for #7317 to merge. |
Let's start deprecating it! Can you add a FutureWarning about this argument/ value being removed in the future when the user manually specifies it? |
Rebasing from main should fix the test failures and allow us to confirm nothing got broken |
I added the fix for reducing memory from the blackjax window_adaptation. But that depends on blackjax-devs/blackjax#674 which was just merged earlier today. So blackjax would have to be up to date - not sure how that's best handled. |
Unfortunately we can't test newer versions of blackjax because of conda-forge/jaxlib-feedstock#249 |
I made some more changes for my own experiments to offload on lighter hardware as follows:
I planned on later putting it up as a proposal if I liked it but since this change is stalled I figured it might be worth doing all together if you think the modifications are worthwhile. It's been helpful so far for me - let me know |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #7311 +/- ##
==========================================
- Coverage 92.18% 92.18% -0.01%
==========================================
Files 103 103
Lines 17261 17257 -4
==========================================
- Hits 15912 15908 -4
Misses 1349 1349
|
Just removed the blackjax window adaptation memory fix that is stalled by no conda jaxlib updates. It's now covered by #7407. The remaining changes help with the sampling memory and should avoid repeat investigations into it |
Thanks @andrewdipper looks like a great improvement |
Co-authored-by: andrewdipper <[email protected]>
-Change blackjax sampling to only retain the relevant sampling info - reduces sampling memory requirements
-Change
_postprocess_samples
to reuse the input arrays - reduces postprocessing memory requirementsDescription
By default the current pymc blackjax sampler accumulates all the info provided by blackjax only to subsequently delete it. This results in memory usage several times what is expected (some info is
num_samples * num_vars
in size). The change only stores what is used resulting in memory scaling similar to that of the numpyro jax sampler. It's worth noting thatblackjax.window_adaptation
also has excessive memory usage but that needs to be fixed from blackjax blackjax-devs/blackjax#667. As such iftune
is not set sufficiently small memory usage will still be excessive.Changes the
"vmap"
mode of_postprocess_samples
to donate the input device arrays resulting in (for my rough tests / models) constant additional memory usage. This should make the"scan"
mode unnecessary. However, I left it in to not break anything.Related Issue
This should resolve by using
"vmap"
mode: #6744This should be unnecessary given the reduction in
"vmap"
memory usage: #7116Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7311.org.readthedocs.build/en/7311/