diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 5858c34aa..8b0b7be11 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -10,6 +10,7 @@ from .adaptation.window_adaptation import window_adaptation from .base import SamplingAlgorithm, VIAlgorithm from .diagnostics import effective_sample_size as ess +from .diagnostics import nested_rhat as nested_rhat from .diagnostics import potential_scale_reduction as rhat from .mcmc import barker from .mcmc import dynamic_hmc as _dynamic_hmc