Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
5c6be3d
Factored out forward map
LucaMantani Feb 16, 2026
83bea26
Added tests
LucaMantani Feb 16, 2026
458636d
removed pred_and_pdf from everywhere
LucaMantani Feb 16, 2026
7a8c95a
New forward map class
LucaMantani Feb 16, 2026
2b898bc
Refined implementation
LucaMantani Feb 16, 2026
0d8a241
Changed conftest
LucaMantani Feb 16, 2026
744b275
Fixed tests
LucaMantani Feb 16, 2026
29dfd7e
Make sure we write pdfs with the first parameters
LucaMantani Feb 16, 2026
0fd84f3
Restored doc
LucaMantani Feb 16, 2026
5dcb3c3
Fixed bug in tests
LucaMantani Feb 16, 2026
1e57a68
use check_pdf_model_is_linear as function rather than decorator
comane Mar 23, 2026
1afebf6
Apply suggestion from @comane
comane Mar 23, 2026
8f7cffd
added tests for forward map
comane Mar 24, 2026
97ceb1b
merge commit
comane Mar 24, 2026
a4b01b4
fixed tests from merge
comane Mar 24, 2026
cbf5ff6
upgraded local black and formatted forward map tests
comane Mar 24, 2026
74d0f4e
added line for raise not implemented in forward map
comane Mar 24, 2026
7da4f04
forward model initialised with pdf parameter names
comane Mar 26, 2026
054c97f
pass forward model to bayesian prior for total model params
comane Mar 26, 2026
22207f3
pass pdf_model object to forward map
comane Mar 26, 2026
18dc793
Update colibri/forward_map.py
LucaMantani Apr 28, 2026
089110c
Update colibri/forward_map.py
LucaMantani Apr 28, 2026
cf02826
Update colibri/forward_map.py
LucaMantani Apr 28, 2026
0cdce57
Added test
LucaMantani Apr 28, 2026
20941aa
removed grid_func passing to forward map
LucaMantani Apr 28, 2026
8581a03
black
LucaMantani Apr 28, 2026
11641ba
Adapted tests
LucaMantani Apr 28, 2026
f134f71
Merge branch 'main' into separate-forward-map
LucaMantani Apr 28, 2026
79c1e73
Merge branch 'main' into separate-forward-map
LucaMantani Apr 28, 2026
28daeab
Adapted after merging main
LucaMantani Apr 28, 2026
a5cddec
Fixed bug
LucaMantani Apr 28, 2026
bface54
typo
vschutze-alt Apr 29, 2026
177898a
typo
vschutze-alt Apr 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions colibri/analytic_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ def analytic_evidence_uniform_prior(sol_covmat, sol_mean, max_logl, a_vec, b_vec
return log_evidence, log_occam_factor


@check_pdf_model_is_linear
def analytic_fit(
central_covmat_index,
_pred_data,
forward_map,
pdf_model,
analytic_settings,
prior_settings,
FIT_XGRID,
fast_kernel_arrays,
data,
):
"""
Analytic fits, for any *linear* PDF model.
Expand All @@ -106,8 +106,8 @@ def analytic_fit(
central_covmat_index: commondata_utils.CentralCovmatIndex
dataclass containing central values and covariance matrix.

_pred_data: @jax.jit CompiledFunction
Prediction function for the fit.
forward_map: @jax.jit CompiledFunction
Forward map function for the fit.

pdf_model: pdf_model.PDFModel
PDF model to fit.
Expand All @@ -124,22 +124,27 @@ def analytic_fit(

fast_kernel_arrays: tuple
Tuple containing the fast kernel arrays.

data: validphys.core.DataGroupSpec
The data group specification for the fit.
"""
# Ensure that the PDF model is linear before running the fit.
log.info("Checking that the PDF model is linear...")
check_pdf_model_is_linear(pdf_model, forward_map, FIT_XGRID, data)

log.warning("The prior is assumed to be flat in the parameters.")
log.warning(
"Assuming that the prior is wide enough to fully cover the gaussian likelihood."
)

parameters = pdf_model.param_names
pred_and_pdf = pdf_model.pred_and_pdf_func(FIT_XGRID, forward_map=_pred_data)

# Precompute predictions for the basis of the model
bases = jnp.identity(len(parameters))
predictions = jnp.array(
[pred_and_pdf(basis, fast_kernel_arrays)[0] for basis in bases]
[forward_map(fast_kernel_arrays, basis)[0] for basis in bases]
)
intercept = pred_and_pdf(jnp.zeros(len(parameters)), fast_kernel_arrays)[0]
intercept = forward_map(fast_kernel_arrays, jnp.zeros(len(parameters)))[0]

# Construct the analytic solution
central_values = central_covmat_index.central_values
Expand Down
1 change: 1 addition & 0 deletions colibri/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"colibri.param_initialisation",
"colibri.export_results",
"colibri.closure_test",
"colibri.forward_map",
"reportengine.report",
]

Expand Down
13 changes: 6 additions & 7 deletions colibri/bayes_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
cast_to_numpy,
get_full_posterior,
)
from colibri.checks import check_pdf_models_equal
from colibri.core import BayesianPrior
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions


@check_pdf_models_equal
def bayesian_prior(prior_settings, pdf_model):
def bayesian_prior(prior_settings, forward_map):
"""
Produces a prior transform function.

Expand All @@ -31,8 +29,8 @@ def bayesian_prior(prior_settings, pdf_model):
prior_specs = prior_settings.prior_distribution_specs

if "bounds" in prior_specs:
# Use param names from the model to order bounds correctly
param_names = pdf_model.param_names
# Use param names from the forward map to order bounds correctly
param_names = forward_map.param_names
bounds_dict = prior_specs["bounds"]
missing = [p for p in param_names if p not in bounds_dict]
if missing:
Expand All @@ -45,8 +43,9 @@ def bayesian_prior(prior_settings, pdf_model):

elif "min_val" in prior_specs and "max_val" in prior_specs:
# Global bounds for all parameters
mins = jnp.array([float(prior_specs["min_val"])] * pdf_model.n_parameters)
maxs = jnp.array([float(prior_specs["max_val"])] * pdf_model.n_parameters)
n_params = len(forward_map.param_names)
mins = jnp.array([float(prior_specs["min_val"])] * n_params)
maxs = jnp.array([float(prior_specs["max_val"])] * n_params)

else:
raise ValueError(
Expand Down
14 changes: 7 additions & 7 deletions colibri/blackjax_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@


def blackjax_fit(
pdf_model,
forward_map,
bayesian_prior,
blackjax_settings,
log_likelihood,
Expand All @@ -48,8 +48,8 @@ def blackjax_fit(

Parameters
----------
pdf_model: pdf_model.PDFModel
The PDF model to fit.
forward_map: ForwardMap
The forward map whose ``param_names`` enumerate all fit parameters.

bayesian_prior: BayesianPrior, @jax.jit CompiledFunction
The prior function for the model.
Expand All @@ -71,7 +71,7 @@ def blackjax_fit(
# set the BlackJAX seed
rng_key = jax.random.PRNGKey(blackjax_settings["seed"])
log.info(f"BlackJAX initialisation seed: {rng_key}")
n_dims = pdf_model.n_parameters
n_dims = len(forward_map.param_names)
n_live = blackjax_settings["n_live"]
n_delete = int(blackjax_settings["delete_fraction"] * n_live)

Expand Down Expand Up @@ -142,15 +142,15 @@ def one_step(carry, xs):
data=final_states.particles,
logL=final_states.loglikelihood,
logL_birth=final_states.loglikelihood_birth,
columns=pdf_model.param_names,
columns=forward_map.param_names,
)
# write nested_samples.csv to blackjax_logs
log_dir = blackjax_settings["log_dir"]
os.makedirs(log_dir, exist_ok=True) # Create directory if it doesn't exist
nested_samples.to_csv(log_dir + "/nested_samples.csv")

# Export resampled posterior samples
posterior_df = pd.DataFrame(resampled_posterior, columns=pdf_model.param_names)
posterior_df = pd.DataFrame(resampled_posterior, columns=forward_map.param_names)
posterior_df.to_csv(os.path.join(log_dir, "posterior_samples.csv"), index=False)

# Compute bayesian metrics (similar to UltraNest)
Expand All @@ -172,7 +172,7 @@ def one_step(carry, xs):
"logZ_err": logzs.std(),
"ess": ess_value,
},
param_names=pdf_model.param_names,
param_names=forward_map.param_names,
resampled_posterior=resampled_posterior,
full_posterior_samples=full_samples,
bayesian_metrics={
Expand Down
25 changes: 12 additions & 13 deletions colibri/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,26 @@
import jax
from colibri.theory_predictions import make_pred_data, fast_kernel_arrays

from colibri.utils import get_fit_path, get_pdf_model, pdf_models_equal
from colibri.utils import get_fit_path, get_pdf_model


@make_argcheck
def check_pdf_models_equal(prior_settings, pdf_model, theoryid):
def check_pdf_models_equal(prior_settings, forward_map, theoryid):
"""
Decorator that can be added to functions to check that the
PDF model used as prior (eg when using prior_settings["type"] == "prior_from_gauss_posterior")
matches the PDF model used in the current fit (pdf_model).
matches the PDF model used in the current fit (via ``forward_map.pdf_param_names``).
"""

if prior_settings.prior_distribution == "prior_from_gauss_posterior":

prior_fit = prior_settings.prior_distribution_specs["prior_fit"]
prior_pdf_model = get_pdf_model(prior_fit)

if not pdf_models_equal(prior_pdf_model, pdf_model):
if prior_pdf_model.param_names != list(forward_map.pdf_param_names):
raise ValueError(
f"PDF model {pdf_model} does not match prior settings {prior_pdf_model}"
f"PDF param names from forward_map {list(forward_map.pdf_param_names)} "
f"do not match prior PDF model param names {prior_pdf_model.param_names}"
)

# load filter.yml runcard of the prior fit
Expand All @@ -41,8 +42,7 @@ def check_pdf_models_equal(prior_settings, pdf_model, theoryid):
)


@make_argcheck
def check_pdf_model_is_linear(pdf_model, FIT_XGRID, data):
def check_pdf_model_is_linear(pdf_model, forward_map, FIT_XGRID, data):
"""
Decorator that can be added to functions to check that the
PDF model is linear.
Expand All @@ -52,8 +52,7 @@ def check_pdf_model_is_linear(pdf_model, FIT_XGRID, data):
fk = fast_kernel_arrays(data, FIT_XGRID)

parameters = pdf_model.param_names
pred_and_pdf = pdf_model.pred_and_pdf_func(FIT_XGRID, forward_map=pred_data)
intercept = pred_and_pdf(jnp.zeros(len(parameters)), fk)[0]
intercept, _ = forward_map(fk, jnp.zeros(len(parameters)))

# Run the check for 10 random points in the parameter space
for i in range(10):
Expand All @@ -65,16 +64,16 @@ def check_pdf_model_is_linear(pdf_model, FIT_XGRID, data):

# Test additivity
add_check = jnp.isclose(
pred_and_pdf(x1, fk)[0] + pred_and_pdf(x2, fk)[0],
pred_and_pdf(x1 + x2, fk)[0] + intercept,
forward_map(fk, x1)[0] + forward_map(fk, x2)[0],
forward_map(fk, x1 + x2)[0] + intercept,
)

# Test homogeneity
c = jax.random.uniform(key, (1,))

homogeneity_check = jnp.isclose(
c * (pred_and_pdf(x1, fk)[0] - intercept),
pred_and_pdf(c * x1, fk)[0] - intercept,
c * (forward_map(fk, x1)[0] - intercept),
forward_map(fk, c * x1)[0] - intercept,
)

if not add_check.all() or not homogeneity_check.all():
Expand Down
3 changes: 2 additions & 1 deletion colibri/export_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,11 @@ def write_replicas(

# Create the exportgrid
lhapdf_interpolator = pdf_model.grid_values_func(xgrid)
n_pdf_params = len(pdf_model.param_names)

# Finish by writing the replicas to export grids, ready for evolution
for i in indices_per_process:
parameters = jnp.array(bayes_fit.resampled_posterior[i, :])
parameters = jnp.array(bayes_fit.resampled_posterior[i, :n_pdf_params])
grid_for_writing = np.array(lhapdf_interpolator(parameters))

replica_index = i + 1
Expand Down
Loading
Loading