Skip to content

Commit

Permalink
Nested rhat
Browse files Browse the repository at this point in the history
  • Loading branch information
gil2rok authored Oct 30, 2024
1 parent 65ae00e commit a103796
Showing 1 changed file with 61 additions and 1 deletion.
62 changes: 61 additions & 1 deletion blackjax/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from blackjax.types import Array, ArrayLike

__all__ = ["potential_scale_reduction", "effective_sample_size"]
__all__ = ["potential_scale_reduction", "nested_rhat", "effective_sample_size"]


def potential_scale_reduction(
Expand Down Expand Up @@ -75,6 +75,66 @@ def potential_scale_reduction(
return rhat_value.squeeze()


def nested_rhat(
input_array: ArrayLike,
superchain_axis: int = 0,
chain_axis: int = 1,
sample_axis: int = 2,
) -> Array:
"""Margossian et al. (2024)'s nested R-hat for computing multiple MCMC superchain convergence.
Parameters
----------
input_array
An array representing multiple superchains of MCMC smaples. The array must
contain a superchain dimension, chain dimension, and sample dimension.
superchain_axis
The axis indicating the multiple superchains. Default to 0.
chain_axis
The axis indicating the multiple chains. Default to 1.
sample_axis
The axis indicating a single chain of MCMC samples. Default to 2.
Returns
-------
NDArray of the resulting statistics (r-hat), with the chain and sample dimensions squeezed.
"""
assert input_array.ndim == 4, "The input array must have 4 dimensions."
num_chains = input_array.shape[chain_axis]
num_samples = input_array.shape[sample_axis]
param_axis = 3 - (chain_axis + sample_axis + superchain_axis)
num_params = input_array.shape[param_axis]
assert (
num_chains > 1 or num_samples > 1
), "num_chains or num_samples must be greater than 1 for valid nested R-hat."

chain_means = jnp.mean(input_array, axis=sample_axis)
super_means = jnp.mean(chain_means, axis=chain_axis)
total_mean = jnp.mean(super_means, axis=superchain_axis)

between_var = jnp.mean(jnp.square(super_means - total_mean), axis=superchain_axis)

if num_chains > 1:
within_chain_var = jnp.mean(
jnp.square(chain_means - super_means), axis=chain_axis
)
else:
within_chain_var = jnp.zeros(num_params)

if num_samples > 1:
within_super_var = jnp.mean(
jnp.square(input_array - chain_means), axis=(chain_axis, sample_axis)
)
else:
within_super_var = jnp.zeros(num_params)

within_var = jnp.mean(within_chain_var + within_super_var, axis=superchain_axis)

nested_rhat_value = jnp.sqrt(1 + between_var / within_var)
return nested_rhat_value.squeeze()


def effective_sample_size(
input_array: ArrayLike, chain_axis: int = 0, sample_axis: int = 1
) -> Array:
Expand Down

0 comments on commit a103796

Please sign in to comment.