Reduce JAX post-processing memory usage #7311

Merged: 8 commits, Jul 11, 2024

Conversation

andrewdipper
Contributor

@andrewdipper andrewdipper commented May 13, 2024

- Change blackjax sampling to retain only the relevant sampling info, which reduces sampling memory requirements
- Change _postprocess_samples to reuse the input arrays, which reduces postprocessing memory requirements

Description

By default, the current PyMC blackjax sampler accumulates all the info provided by blackjax, only to delete it afterwards. This results in memory usage several times higher than expected (some of the info is num_samples * num_vars in size). This change stores only what is used, so memory scales similarly to the numpyro JAX sampler. Note that blackjax.window_adaptation also uses excessive memory, but that needs to be fixed in blackjax itself: blackjax-devs/blackjax#667. Until then, memory usage will still be excessive unless tune is set sufficiently small.
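To illustrate the idea of storing only what is used, here is a minimal JAX sketch. `toy_step`, `sample`, and the `energy`/`momentum` info fields are hypothetical stand-ins chosen for illustration; they are not blackjax's or PyMC's actual API. The point is only that whatever the scan body returns gets stacked across all draws, so dropping a num_vars-sized field inside the body avoids allocating a num_samples * num_vars array for it.

```python
import jax
import jax.numpy as jnp


def toy_step(state, key):
    """Hypothetical sampler step (NOT blackjax): returns new state plus per-step info."""
    noise = jax.random.normal(key, state.shape)
    new_state = state + 0.1 * noise
    info = {
        "energy": 0.5 * jnp.sum(new_state**2),  # one scalar per draw
        "momentum": noise,  # num_vars values per draw
    }
    return new_state, info


def sample(init, keys, keep_full_info):
    def body(state, key):
        new_state, info = toy_step(state, key)
        # Everything returned here is stacked over all draws by lax.scan,
        # so filter the info before returning it.
        kept = info if keep_full_info else {"energy": info["energy"]}
        return new_state, (new_state, kept)

    _, (draws, stats) = jax.lax.scan(body, init, keys)
    return draws, stats


init = jnp.zeros(3)
keys = jax.random.split(jax.random.PRNGKey(0), 50)
draws, slim_stats = sample(init, keys, keep_full_info=False)
_, full_stats = sample(init, keys, keep_full_info=True)
```

In this toy setup `full_stats["momentum"]` has the same shape as `draws`, which is the kind of extra allocation the description refers to, while `slim_stats` only holds one scalar per draw.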

Changes the "vmap" mode of _postprocess_samples to donate the input device arrays, which results in constant additional memory usage in my rough tests and models. This should make the "scan" mode unnecessary; however, I left it in so as not to break anything.
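As a rough illustration of buffer donation, the sketch below uses `jax.jit` with `donate_argnums` around a doubly `vmap`-ed function. `transform_one_draw` and the array shapes are assumptions made up for this example, and this is not the actual body of _postprocess_samples.

```python
import jax
import jax.numpy as jnp


def transform_one_draw(raw):
    # Hypothetical per-draw transform, e.g. unconstrained -> constrained space.
    return jnp.exp(raw)


# vmap over chains, then over draws. Donating argument 0 tells XLA it may
# reuse the input buffer for the output instead of allocating a new one.
postprocess = jax.jit(
    jax.vmap(jax.vmap(transform_one_draw)),
    donate_argnums=0,
)

raw_samples = jnp.zeros((4, 1000, 3))  # (chains, draws, vars)
constrained = postprocess(raw_samples)
# Do not read raw_samples after this call: on backends that honor donation
# its buffer may have been invalidated.
```

Donation is only a hint; a backend that cannot reuse the buffer may emit a warning and fall back to a regular allocation.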

Related Issue

This should be resolved by using "vmap" mode: #6744
This should be unnecessary given the reduction in "vmap" memory usage: #7116

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7311.org.readthedocs.build/en/7311/


welcome bot commented May 13, 2024

💖 Thanks for opening this pull request! 💖 The PyMC community really appreciates your time and effort to contribute to the project. Please make sure you have read our Contributing Guidelines and filled in our pull request template to the best of your ability.

@andrewdipper
Contributor Author

Looks like the test failures are due to an older version of jax combined with a recent blackjax PR that fixes an argument deprecation in jnp.clip: blackjax-devs/blackjax#664
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.clip.html

@andrewdipper
Contributor Author

Also added a bugfix for pm.sample not respecting compute_convergence_checks with the numpyro/blackjax samplers.

@twiecki
Member

twiecki commented May 16, 2024

Thanks @andrewdipper!

@twiecki
Member

twiecki commented May 16, 2024

Probably need to wait for #7317 to merge.

@ricardoV94
Member

> This should make the "scan" mode unnecessary. However, I left it in to not break anything.

Let's start deprecating it! Can you add a FutureWarning saying that this argument/value will be removed in the future, raised when the user manually specifies it?
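One way such a deprecation could look is sketched below. The helper name `resolve_postprocessing_vectorize`, the `None` sentinel, and the "vmap" fallback are assumptions for illustration, not necessarily what ended up in pymc/sampling/jax.py. The `None` default is what lets the function tell "user passed a value" apart from "user left it alone".

```python
import warnings


def resolve_postprocessing_vectorize(postprocessing_vectorize=None):
    """Hypothetical helper: warn only when the caller sets the argument explicitly."""
    if postprocessing_vectorize is not None:
        warnings.warn(
            "postprocessing_vectorize is deprecated and will be removed "
            "in a future release.",
            FutureWarning,
            stacklevel=2,
        )
        return postprocessing_vectorize
    # Assumed default when the user does not choose a mode.
    return "vmap"
```

Calling `resolve_postprocessing_vectorize()` stays silent, while `resolve_postprocessing_vectorize("scan")` emits the warning and still honors the requested mode, so existing code keeps working during the deprecation period.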

@ricardoV94 ricardoV94 changed the title reduce blackjax sampler and postprocessing memory Reduce JAX sampler and postprocessing memory May 16, 2024
@ricardoV94 ricardoV94 changed the title Reduce JAX sampler and postprocessing memory Reduce JAX sampler memory usage May 16, 2024
@ricardoV94
Member

Rebasing from main should fix the test failures and allow us to confirm nothing got broken

@andrewdipper
Contributor Author

I added the fix for reducing memory from the blackjax window_adaptation. But that depends on blackjax-devs/blackjax#674, which was merged earlier today, so blackjax would have to be up to date. Not sure how that's best handled.

@ricardoV94
Member

ricardoV94 commented May 17, 2024

Unfortunately we can't test newer versions of blackjax because of conda-forge/jaxlib-feedstock#249

@andrewdipper
Contributor Author

I made some more changes for my own experiments to offload on lighter hardware as follows:

  • Enable sampling in chunks to make memory independent of number of samples (if chunks > 1 then save on host memory)
  • Allow pmap of vmap with shard map - so sampling of 4 chains can be done on 2 devices with 2 chains per device etc.
  • Postprocessing / likelihood is done on a per chunk basis instead of all at the end (should eliminate the need for choosing the device for postprocessing)
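The chunking idea in the first bullet can be sketched roughly as follows. This is not the code from the experiments described here: `toy_step`, `run_chunk`, and `sample_in_chunks` are hypothetical names, the random-walk update stands in for a real sampler step, and the sketch assumes num_samples divides evenly by num_chunks. Each chunk runs on device, and its draws are copied to host memory before the next chunk starts, so device memory depends on the chunk size rather than the total number of samples.

```python
import jax
import jax.numpy as jnp
import numpy as np


def toy_step(state, key):
    # Hypothetical random-walk update standing in for a real sampler step.
    new_state = state + 0.1 * jax.random.normal(key, state.shape)
    return new_state, new_state


@jax.jit
def run_chunk(state, keys):
    # Runs one chunk on device; returns the final state and the chunk's draws.
    return jax.lax.scan(toy_step, state, keys)


def sample_in_chunks(init, key, num_samples, num_chunks):
    chunk_size = num_samples // num_chunks  # assumes an even split
    state, host_chunks = init, []
    for chunk_key in jax.random.split(key, num_chunks):
        keys = jax.random.split(chunk_key, chunk_size)
        state, draws = run_chunk(state, keys)
        # Copy this chunk to host memory; the device copy can then be freed.
        host_chunks.append(np.asarray(draws))
    return np.concatenate(host_chunks, axis=0)


all_draws = sample_in_chunks(
    jnp.zeros(3), jax.random.PRNGKey(0), num_samples=100, num_chunks=4
)
```

Per-chunk postprocessing, as in the third bullet, would slot in right before the host copy in this sketch.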

I planned to put this up later as a separate proposal if I liked it, but since this change is stalled I figured it might be worth doing it all together if you think the modifications are worthwhile. It's been helpful for me so far. Let me know.


codecov bot commented Jul 11, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.18%. Comparing base (a4ea9fc) to head (f43f002).
Report is 96 commits behind head on main.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #7311      +/-   ##
==========================================
- Coverage   92.18%   92.18%   -0.01%     
==========================================
  Files         103      103              
  Lines       17261    17257       -4     
==========================================
- Hits        15912    15908       -4     
  Misses       1349     1349              
Files with missing lines Coverage Δ
pymc/sampling/jax.py 94.00% <100.00%> (-0.11%) ⬇️
pymc/sampling/mcmc.py 87.97% <ø> (ø)

@andrewdipper
Contributor Author

Just removed the blackjax window adaptation memory fix, which is stalled because there are no conda jaxlib updates. It's now covered by #7407. The remaining changes help with sampling memory and should prevent repeated investigations into it.

@andrewdipper andrewdipper requested a review from ricardoV94 July 11, 2024 03:34
@ricardoV94 ricardoV94 changed the title Reduce JAX sampler memory usage Reduce JAX post-processing memory usage Jul 11, 2024
@ricardoV94 ricardoV94 merged commit 2216b59 into pymc-devs:main Jul 11, 2024
22 checks passed
@ricardoV94
Member

Thanks @andrewdipper, looks like a great improvement

@andrewdipper andrewdipper deleted the memfix branch July 11, 2024 14:41
mkusnetsov pushed a commit to mkusnetsov/pymc that referenced this pull request Oct 26, 2024