diff --git a/colibri/blackjax_fit.py b/colibri/blackjax_fit.py index 22ea86cc..bfcba4d3 100644 --- a/colibri/blackjax_fit.py +++ b/colibri/blackjax_fit.py @@ -69,7 +69,7 @@ def blackjax_fit( log.info(f"Running fit with backend: {jax.default_backend()}") # set the BlackJAX seed - rng_key = jax.random.PRNGKey(blackjax_settings["seed"]) + rng_key = jax.random.PRNGKey(blackjax_settings["blackjax_seed"]) log.info(f"BlackJAX initialisation seed: {rng_key}") n_dims = pdf_model.n_parameters n_live = blackjax_settings["n_live"] diff --git a/colibri/doc/sphinx/source/tutorials/running_fits/bayesian/blackjax.rst b/colibri/doc/sphinx/source/tutorials/running_fits/bayesian/blackjax.rst index 3142af79..b3201443 100644 --- a/colibri/doc/sphinx/source/tutorials/running_fits/bayesian/blackjax.rst +++ b/colibri/doc/sphinx/source/tutorials/running_fits/bayesian/blackjax.rst @@ -111,7 +111,7 @@ Runcard delete_fraction: 0.5 log_precision: -3 posterior_resampling_seed: 52 - seed: 0 + blackjax_seed: 0 diff --git a/colibri/tests/test_blackjax_fit.py b/colibri/tests/test_blackjax_fit.py index c8845666..971a0033 100644 --- a/colibri/tests/test_blackjax_fit.py +++ b/colibri/tests/test_blackjax_fit.py @@ -48,7 +48,7 @@ def mock_sample(rng_key, n_samples): integrability_penalty = lambda pdf: jnp.array([0.0]) blackjax_settings = { - "seed": 42, + "blackjax_seed": 42, "n_live": 50, "delete_fraction": 0.5, "repeats": 2, @@ -103,7 +103,7 @@ def test_blackjax_fit_truncates_posterior_and_warns(caplog): ) blackjax_settings = { - "seed": 0, + "blackjax_seed": 0, "n_live": 4, "delete_fraction": 0.5, "repeats": 1,