window_adaptation excessive memory usage #667
Comments
Agree. I think numpyro also has a flag to control what gets exposed. I think we will need to add a kwarg to blackjax/blackjax/adaptation/window_adaptation.py (lines 244 to 252 in 40efb6c):

    def return_all_adapt_info(state, info, adaptation_state):
        return AdaptationInfo(state, info, adaptation_state)


    def window_adaptation(
        algorithm,
        logdensity_fn: Callable,
        is_mass_matrix_diagonal: bool = True,
        initial_step_size: float = 1.0,
        target_acceptance_rate: float = 0.80,
        progress_bar: bool = False,
        adaptation_info_fn: Callable = return_all_adapt_info,
        **extra_parameters,
    )

And then add some utility function for filtering common info (e.g.,
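One possible shape for such a filtering utility, as a rough sketch only. The `AdaptationInfo` below is a local stand-in `NamedTuple` with the three fields from the snippet above, not the class imported from blackjax, and `keep_only` is a hypothetical helper name. Dropped fields are replaced with `None`, which JAX treats as a pytree with no leaves, so nothing would be stacked for them inside a scan.

```python
from typing import Any, Callable, NamedTuple


class AdaptationInfo(NamedTuple):
    # Stand-in for blackjax's AdaptationInfo (assumption: same three fields
    # as in the snippet above).
    state: Any
    info: Any
    adaptation_state: Any


def return_all_adapt_info(state, info, adaptation_state):
    """Default from the snippet above: keep everything."""
    return AdaptationInfo(state, info, adaptation_state)


def keep_only(*fields: str) -> Callable:
    """Hypothetical helper: build an adaptation_info_fn that keeps only `fields`.

    Every other field is replaced with None.
    """
    unknown = set(fields) - set(AdaptationInfo._fields)
    if unknown:
        raise ValueError(f"unknown AdaptationInfo fields: {sorted(unknown)}")

    def adaptation_info_fn(state, info, adaptation_state):
        full = AdaptationInfo(state, info, adaptation_state)
        kept = {
            name: (getattr(full, name) if name in fields else None)
            for name in AdaptationInfo._fields
        }
        return AdaptationInfo(**kept)

    return adaptation_info_fn
```

Under the proposed signature, a caller would then pass something like `adaptation_info_fn=keep_only("adaptation_state")` to `window_adaptation`, while the default `return_all_adapt_info` keeps the current behaviour.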
Really good point and thank you for the deep dive!! Feel free to send a PR.
Describe the issue as clearly as possible:
The scan in window_adaptation saves the AdaptationInfo for every sample by default. This results in memory usage many times in excess of (num_samples)*(num_variables) and leads to out-of-memory issues. However, it looks like the last states are the only information necessary to perform the window adaptation. As such, it'd be helpful to be able to disable or select what info is stored along the way, so that the auxiliary info doesn't cause out-of-memory issues. Removing it altogether also doesn't seem ideal. I'd be happy to submit a PR if you have an idea of what/how to best include/exclude the extra info.
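To illustrate the effect in isolation, here is a toy sketch that does not use blackjax at all. `one_step` is a made-up update standing in for an adaptation step, and `aux` stands in for the per-step info. `jax.lax.scan` stacks whatever the step function returns as its second output, so that buffer grows with the number of steps, whereas returning `None` stores nothing and only the final carry survives.

```python
import jax
import jax.numpy as jnp

NUM_STEPS = 200
NUM_VARIABLES = 10


def one_step(position, _):
    # Toy update standing in for one adaptation step (not blackjax code).
    new_position = 0.9 * position + 0.1
    # Stand-in for auxiliary per-step info, larger than the position itself.
    aux = jnp.outer(new_position, new_position)
    return new_position, aux


def one_step_no_info(position, _):
    new_position, _aux = one_step(position, None)
    # None is an empty pytree, so scan has nothing to stack for this step.
    return new_position, None


init = jnp.zeros(NUM_VARIABLES)

# Per-step outputs are stacked along a leading axis of length NUM_STEPS.
last_with_info, stacked = jax.lax.scan(one_step, init, None, length=NUM_STEPS)

# Same final carry, but no per-step buffer is kept.
last_without_info, nothing = jax.lax.scan(
    one_step_no_info, init, None, length=NUM_STEPS
)

print(stacked.shape)  # (NUM_STEPS, NUM_VARIABLES, NUM_VARIABLES)
print(nothing)        # None
```

Both scans end in the same final state; only the stacked second output differs, which is the part that scales with the number of samples.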
I believe #529 has the same cause: the extra buffers are likely used to store the sample-by-sample info. I get similar outputs.
Thanks
Steps/code to reproduce the bug:
Expected result:
Error message:
Blackjax/JAX/jaxlib/Python version information:
Context for the issue:
I found this while trying to reduce memory consumption for a pymc model sampled with blackjax. There's a similar issue there with storing extra info during the actual sampling process. With both fixes, memory consumption and performance are initially looking more than comparable with pymc's numpyro sampler.