Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce JAX post-processing memory usage #7311

Merged
merged 8 commits into from
Jul 11, 2024
103 changes: 48 additions & 55 deletions pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,11 @@ def _get_log_likelihood(
elemwise_logp = model.logp(model.observed_RVs, sum=False)
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp)
result = _postprocess_samples(
jax_fn, samples, backend, postprocessing_vectorize=postprocessing_vectorize
jax_fn,
samples,
backend,
postprocessing_vectorize=postprocessing_vectorize,
donate_samples=False,
)
return {v.name: r for v, r in zip(model.observed_RVs, result)}

Expand All @@ -181,7 +185,8 @@ def _postprocess_samples(
jax_fn: Callable,
raw_mcmc_samples: list[TensorVariable],
postprocessing_backend: Literal["cpu", "gpu"] | None = None,
postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
postprocessing_vectorize: Literal["vmap", "scan"] = "vmap",
donate_samples: bool = False,
) -> list[TensorVariable]:
if postprocessing_vectorize == "scan":
t_raw_mcmc_samples = [jnp.swapaxes(t, 0, 1) for t in raw_mcmc_samples]
Expand All @@ -193,7 +198,12 @@ def _postprocess_samples(
)
return [jnp.swapaxes(t, 0, 1) for t in outs]
elif postprocessing_vectorize == "vmap":
return jax.vmap(jax.vmap(jax_fn))(*_device_put(raw_mcmc_samples, postprocessing_backend))

def process_fn(x):
return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))

return jax.jit(process_fn, donate_argnums=0 if donate_samples else None)(raw_mcmc_samples)

else:
raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}")

Expand Down Expand Up @@ -253,7 +263,16 @@ def _blackjax_inference_loop(
def _one_step(state, xs):
_, rng_key = xs
state, info = kernel(rng_key, state)
return state, (state, info)
position = state.position
stats = {
"diverging": info.is_divergent,
"energy": info.energy,
"tree_depth": info.num_trajectory_expansions,
"n_steps": info.num_integration_steps,
"acceptance_rate": info.acceptance_rate,
"lp": state.logdensity,
}
return state, (position, stats)

progress_bar = adaptation_kwargs.pop("progress_bar", False)
if progress_bar:
Expand All @@ -264,43 +283,9 @@ def _one_step(state, xs):
one_step = jax.jit(_one_step)

keys = jax.random.split(seed, draws)
_, (states, infos) = jax.lax.scan(one_step, last_state, (jnp.arange(draws), keys))

return states, infos


def _blackjax_stats_to_dict(sample_stats, potential_energy) -> dict:
"""Extract compatible stats from blackjax NUTS sampler
with PyMC/Arviz naming conventions.

Parameters
----------
sample_stats: NUTSInfo
Blackjax NUTSInfo object containing sampler statistics
potential_energy: ArrayLike
Potential energy values of sampled positions.
_, (samples, stats) = jax.lax.scan(one_step, last_state, (jnp.arange(draws), keys))

Returns
-------
Dict[str, ArrayLike]
Dictionary of sampler statistics.
"""
rename_key = {
"is_divergent": "diverging",
"energy": "energy",
"num_trajectory_expansions": "tree_depth",
"num_integration_steps": "n_steps",
"acceptance_rate": "acceptance_rate", # naming here is
"acceptance_probability": "acceptance_rate", # depending on blackjax version
}
converted_stats = {}
converted_stats["lp"] = potential_energy
for old_name, new_name in rename_key.items():
value = getattr(sample_stats, old_name, None)
if value is None:
continue
converted_stats[new_name] = value
return converted_stats
return samples, stats


def _sample_blackjax_nuts(
Expand Down Expand Up @@ -410,11 +395,7 @@ def _sample_blackjax_nuts(
**nuts_kwargs,
)

states, stats = map_fn(get_posterior_samples)(keys, initial_points)
raw_mcmc_samples = states.position
potential_energy = states.logdensity.block_until_ready()
sample_stats = _blackjax_stats_to_dict(stats, potential_energy)

raw_mcmc_samples, sample_stats = map_fn(get_posterior_samples)(keys, initial_points)
return raw_mcmc_samples, sample_stats, blackjax


Expand Down Expand Up @@ -515,7 +496,7 @@ def sample_jax_nuts(
keep_untransformed: bool = False,
chain_method: str = "parallel",
postprocessing_backend: Literal["cpu", "gpu"] | None = None,
postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
postprocessing_vectorize: Literal["vmap", "scan"] | None = None,
postprocessing_chunks=None,
idata_kwargs: dict | None = None,
compute_convergence_checks: bool = True,
Expand Down Expand Up @@ -597,6 +578,16 @@ def sample_jax_nuts(
DeprecationWarning,
)

if postprocessing_vectorize is not None:
import warnings

warnings.warn(
'postprocessing_vectorize={"scan", "vmap"} will be removed in a future release.',
FutureWarning,
)
else:
postprocessing_vectorize = "vmap"

model = modelcontext(model)

if var_names is not None:
Expand Down Expand Up @@ -645,15 +636,6 @@ def sample_jax_nuts(
)
tic2 = datetime.now()

jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
result = _postprocess_samples(
jax_fn,
raw_mcmc_samples,
postprocessing_backend=postprocessing_backend,
postprocessing_vectorize=postprocessing_vectorize,
)
mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

if idata_kwargs is None:
idata_kwargs = {}
else:
Expand All @@ -669,6 +651,17 @@ def sample_jax_nuts(
else:
log_likelihood = None

jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
result = _postprocess_samples(
jax_fn,
raw_mcmc_samples,
postprocessing_backend=postprocessing_backend,
postprocessing_vectorize=postprocessing_vectorize,
donate_samples=True,
)
del raw_mcmc_samples
mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

attrs = {
"sampling_time": (tic2 - tic1).total_seconds(),
"tuning_steps": tune,
Expand Down
4 changes: 4 additions & 0 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def _sample_external_nuts(
var_names: Sequence[str] | None,
progressbar: bool,
idata_kwargs: dict | None,
compute_convergence_checks: bool,
nuts_sampler_kwargs: dict | None,
**kwargs,
):
Expand Down Expand Up @@ -364,6 +365,7 @@ def _sample_external_nuts(
progressbar=progressbar,
nuts_sampler=sampler,
idata_kwargs=idata_kwargs,
compute_convergence_checks=compute_convergence_checks,
**nuts_sampler_kwargs,
)
return idata
Expand Down Expand Up @@ -718,6 +720,7 @@ def joined_blas_limiter():
raise ValueError(
"Model can not be sampled with NUTS alone. Your model is probably not continuous."
)

with joined_blas_limiter():
return _sample_external_nuts(
sampler=nuts_sampler,
Expand All @@ -731,6 +734,7 @@ def joined_blas_limiter():
var_names=var_names,
progressbar=progressbar,
idata_kwargs=idata_kwargs,
compute_convergence_checks=compute_convergence_checks,
nuts_sampler_kwargs=nuts_sampler_kwargs,
**kwargs,
)
Expand Down
Loading