From 5c6be3dce13a6db8fcd255ceeb1a0057dd26a21d Mon Sep 17 00:00:00 2001 From: Luca Mantani Date: Mon, 16 Feb 2026 12:11:55 +0100 Subject: [PATCH 01/30] Factored out forward map --- colibri/app.py | 1 + colibri/likelihood.py | 18 +++++++++--------- colibri/pdf_model.py | 37 ------------------------------------- 3 files changed, 10 insertions(+), 46 deletions(-) diff --git a/colibri/app.py b/colibri/app.py index 435f852b9..74000d5be 100644 --- a/colibri/app.py +++ b/colibri/app.py @@ -31,6 +31,7 @@ "colibri.param_initialisation", "colibri.export_results", "colibri.closure_test", + "colibri.forward_map", "reportengine.report", ] diff --git a/colibri/likelihood.py b/colibri/likelihood.py index 586a8c878..d9ee7fe41 100644 --- a/colibri/likelihood.py +++ b/colibri/likelihood.py @@ -62,9 +62,8 @@ def __init__( self.positivity_penalty_settings = positivity_penalty_settings self.integrability_penalty = integrability_penalty - self.pred_and_pdf = pdf_model.pred_and_pdf_func( - fit_xgrid, forward_map=forward_map - ) + self.pdf_grid = pdf_model.grid_values_func(fit_xgrid) + self.forward_map = forward_map self.fast_kernel_arrays = fast_kernel_arrays self.positivity_fast_kernel_arrays = positivity_fast_kernel_arrays @@ -126,7 +125,8 @@ def log_likelihood( jnp.ndarray jax array with the value of the log-likelihood. """ - predictions, pdf = self.pred_and_pdf(params, fast_kernel_arrays) + pdf = self.pdf_grid(params) + predictions = self.forward_map(pdf, fast_kernel_arrays) # Select only the data relevant for this likelihood # Especially important when using a training/validation split predictions = predictions[self.central_values_idx] @@ -169,7 +169,7 @@ def log_likelihood( central_covmat_index, pdf_model, FIT_XGRID, - _pred_data, + forward_map, fast_kernel_arrays, positivity_fast_kernel_arrays, _penalty_posdata, @@ -186,7 +186,7 @@ def log_likelihood( central_covmat_index, pdf_model, FIT_XGRID, - _pred_data, + forward_map, fast_kernel_arrays, positivity_fast_kernel_arrays, _penalty_posdata, @@ -200,7 +200,7 @@ def mc_log_likelihood( fit_covariance_matrix, pdf_model, FIT_XGRID, - _pred_data, + forward_map, fast_kernel_arrays, positivity_fast_kernel_arrays, _penalty_posdata, @@ -228,7 +228,7 @@ def mc_log_likelihood( central_covmat_index_train, pdf_model, FIT_XGRID, - _pred_data, + forward_map, fast_kernel_arrays, positivity_fast_kernel_arrays, _penalty_posdata, @@ -254,7 +254,7 @@ def mc_log_likelihood( central_covmat_index_val, pdf_model, FIT_XGRID, - _pred_data, + forward_map, fast_kernel_arrays, positivity_fast_kernel_arrays, _penalty_posdata, diff --git a/colibri/pdf_model.py b/colibri/pdf_model.py index 215ea2ff1..11fbf71b5 100644 --- a/colibri/pdf_model.py +++ b/colibri/pdf_model.py @@ -53,40 +53,3 @@ def func(params): return func """ pass - - def pred_and_pdf_func( - self, - xgrid: ArrayLike, - forward_map: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], - ) -> Callable[[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]: - """Creates a function that returns a tuple of two arrays, given the model parameters and the fast kernel arrays as input. - - The returned function produces: - - The first array: 1D vector of theory predictions for the data. - - The second array: PDF values evaluated on the x-grid, using `self.grid_values_func`, with shape (Nfl, Nx). - - The `forward_map` is used to map the PDF values defined on the x-grid and the fast kernel arrays into the corresponding theory prediction vector. - """ - pdf_func = self.grid_values_func(xgrid) - - def pred_and_pdf(params, fast_kernel_arrays): - """ - Parameters - ---------- - params: jnp.array - The model parameters. - - fast_kernel_arrays: tuple - tuple of tuples of jnp.arrays - The FK tables to use. - - Returns - ------- - tuple - The predictions and the PDF values. - """ - pdf = pdf_func(params) - predictions = forward_map(pdf, fast_kernel_arrays) - return predictions, pdf - - return pred_and_pdf From 83bea26ec9db9f3785561dcb37a3f7bb0d6defed Mon Sep 17 00:00:00 2001 From: Luca Mantani Date: Mon, 16 Feb 2026 12:26:55 +0100 Subject: [PATCH 02/30] Added tests --- colibri/forward_map.py | 13 +++++++++++++ colibri/tests/test_likelihood.py | 33 ++++++++++++++++++-------------- colibri/tests/test_pdf_model.py | 14 -------------- 3 files changed, 32 insertions(+), 28 deletions(-) create mode 100644 colibri/forward_map.py diff --git a/colibri/forward_map.py b/colibri/forward_map.py new file mode 100644 index 000000000..c34295042 --- /dev/null +++ b/colibri/forward_map.py @@ -0,0 +1,13 @@ +""" +colibri.forward_map.py + +This module implements the forward map, i.e. the map from the PDF grid to the theory predictions for each dataset, using the FK tables. + +""" + + +def forward_map(make_pred_data): + """ + Internal alias function for make_pred_data. + """ + return make_pred_data diff --git a/colibri/tests/test_likelihood.py b/colibri/tests/test_likelihood.py index 4d69644cc..18ca2475a 100644 --- a/colibri/tests/test_likelihood.py +++ b/colibri/tests/test_likelihood.py @@ -66,8 +66,9 @@ def test_LogLikelihood_class(pos_penalty): ] ) # Compute expected value using actual prediction and covariance - predictions, pdf = log_likelihood_class.pred_and_pdf( - params, log_likelihood_class.fast_kernel_arrays + pdf = log_likelihood_class.pdf_grid(params) + predictions = log_likelihood_class.forward_map( + pdf, log_likelihood_class.fast_kernel_arrays ) predictions = predictions[log_likelihood_class.central_values_idx] diff = predictions - log_likelihood_class.central_values @@ -165,8 +166,9 @@ def test_log_likelihood_with_and_without_pos_penalty(): ) # Compute expectation directly: -0.5 * (chi2 + pos_pen + integ_pen) - predictions, pdf = log_likelihood_class.pred_and_pdf( - params, log_likelihood_class.fast_kernel_arrays + pdf = log_likelihood_class.pdf_grid(params) + predictions = log_likelihood_class.forward_map( + pdf, log_likelihood_class.fast_kernel_arrays ) predictions = predictions[log_likelihood_class.central_values_idx] diff = predictions - log_likelihood_class.central_values @@ -213,8 +215,9 @@ def test_log_likelihood_with_and_without_pos_penalty(): ) # Expectation: Only chi2 value (penalties zeroed) - predictions, pdf = log_likelihood_class.pred_and_pdf( - params, log_likelihood_class.fast_kernel_arrays + pdf = log_likelihood_class.pdf_grid(params) + predictions = log_likelihood_class.forward_map( + pdf, log_likelihood_class.fast_kernel_arrays ) predictions = predictions[log_likelihood_class.central_values_idx] diff = predictions - log_likelihood_class.central_values @@ -275,7 +278,8 @@ def test_mc_log_likelihood_with_split(pos_penalty): # Compute expected for train and validation independently def compute_expected(ll_obj): - preds, pdf = ll_obj.pred_and_pdf(params, ll_obj.fast_kernel_arrays) + pdf = ll_obj.pdf_grid(params) + preds = ll_obj.forward_map(pdf, ll_obj.fast_kernel_arrays) preds = preds[ll_obj.central_values_idx] diff = preds - ll_obj.central_values inv = ll_obj.inv_covmat @@ -348,9 +352,8 @@ def test_mc_log_likelihood_without_split_returns_nan_for_validation(pos_penalty) params = jnp.array([0.3, 0.4]) train_val = train_loglike(params) # Compute expected train value - predictions, pdf = train_loglike.pred_and_pdf( - params, train_loglike.fast_kernel_arrays - ) + pdf = train_loglike.pdf_grid(params) + predictions = train_loglike.forward_map(pdf, train_loglike.fast_kernel_arrays) predictions = predictions[train_loglike.central_values_idx] diff = predictions - train_loglike.central_values chi2_val = jnp.einsum("i,ij,j", diff, train_loglike.inv_covmat, diff) @@ -409,8 +412,9 @@ def test_LogLikelihood_call_with_batch_idx(pos_penalty): ll_value_batched = log_likelihood_class(params, batch=batch) # Compute expected on the batch index: recompute inv_covmat on the sub-covmat - predictions, pdf = log_likelihood_class.pred_and_pdf( - params, log_likelihood_class.fast_kernel_arrays + pdf = log_likelihood_class.pdf_grid(params) + predictions = log_likelihood_class.forward_map( + pdf, log_likelihood_class.fast_kernel_arrays ) predictions = predictions[log_likelihood_class.central_values_idx] predictions_b = predictions[batch.idx] @@ -476,8 +480,9 @@ def test_LogLikelihood_call_with_batch_with_inv_cov(pos_penalty): ll_value_batched = log_likelihood_class(params, batch=batch) # Compute expected value using the provided inv_b (should be identical) - predictions, pdf = log_likelihood_class.pred_and_pdf( - params, log_likelihood_class.fast_kernel_arrays + pdf = log_likelihood_class.pdf_grid(params) + predictions = log_likelihood_class.forward_map( + pdf, log_likelihood_class.fast_kernel_arrays ) predictions = predictions[log_likelihood_class.central_values_idx] predictions_b = predictions[batch.idx] diff --git a/colibri/tests/test_pdf_model.py b/colibri/tests/test_pdf_model.py index 11786b2f0..c3759cfd2 100644 --- a/colibri/tests/test_pdf_model.py +++ b/colibri/tests/test_pdf_model.py @@ -35,17 +35,3 @@ def test_grid_values_func(): expected_output = sum([param * TEST_PDF_GRID for param in params]) assert_array_equal(func(params), expected_output) - - -def test_pred_and_pdf_func(): - """ - Tests that the pred_and_pdf_func returns the correct values. - """ - pred_and_pdf = model.pred_and_pdf_func(TEST_XGRID, TEST_FORWARD_MAP_DIS) - - params = jnp.array([2, 3]) - predictions, pdf = pred_and_pdf(params, TEST_FK_ARRAYS) - - expected_predictions = jnp.einsum("ijk,jk->i", TEST_FK_ARRAYS[0], pdf) - - assert jnp.allclose(predictions, expected_predictions) From 458636d1406ce5bbcdddb9a605ef398db1698e97 Mon Sep 17 00:00:00 2001 From: Luca Mantani Date: Mon, 16 Feb 2026 13:12:02 +0100 Subject: [PATCH 03/30] removed pred_and_pdf from everywhere --- colibri/analytic_fit.py | 12 +++---- colibri/checks.py | 14 ++++---- .../doc/sphinx/source/theory/pdf_model.rst | 10 ++---- colibri/tests/conftest.py | 9 ------ colibri/tests/test_analytic_fit.py | 32 +++++++------------ colibri/tests/test_checks.py | 29 ++++++++++------- colibri/utils.py | 8 ++--- 7 files changed, 48 insertions(+), 66 deletions(-) diff --git a/colibri/analytic_fit.py b/colibri/analytic_fit.py index 3e75160c5..f1147e6f4 100644 --- a/colibri/analytic_fit.py +++ b/colibri/analytic_fit.py @@ -84,7 +84,7 @@ def analytic_evidence_uniform_prior(sol_covmat, sol_mean, max_logl, a_vec, b_vec @check_pdf_model_is_linear def analytic_fit( central_inv_covmat_index, - _pred_data, + forward_map, pdf_model, analytic_settings, prior_settings, @@ -105,8 +105,8 @@ def analytic_fit( central_inv_covmat_index: commondata_utils.CentralInvCovmatIndex dataclass containing central values and inverse covmat. - _pred_data: @jax.jit CompiledFunction - Prediction function for the fit. + forward_map: @jax.jit CompiledFunction + Forward map function for the fit. pdf_model: pdf_model.PDFModel PDF model to fit. @@ -131,14 +131,14 @@ def analytic_fit( ) parameters = pdf_model.param_names - pred_and_pdf = pdf_model.pred_and_pdf_func(FIT_XGRID, forward_map=_pred_data) # Precompute predictions for the basis of the model bases = jnp.identity(len(parameters)) + pdf_grid = pdf_model.grid_values_func(FIT_XGRID) predictions = jnp.array( - [pred_and_pdf(basis, fast_kernel_arrays)[0] for basis in bases] + [forward_map(pdf_grid(basis), fast_kernel_arrays) for basis in bases] ) - intercept = pred_and_pdf(jnp.zeros(len(parameters)), fast_kernel_arrays)[0] + intercept = forward_map(pdf_grid(jnp.zeros(len(parameters))), fast_kernel_arrays) # Construct the analytic solution central_values = central_inv_covmat_index.central_values diff --git a/colibri/checks.py b/colibri/checks.py index 1ad341879..20d7e7947 100644 --- a/colibri/checks.py +++ b/colibri/checks.py @@ -42,7 +42,7 @@ def check_pdf_models_equal(prior_settings, pdf_model, theoryid): @make_argcheck -def check_pdf_model_is_linear(pdf_model, FIT_XGRID, data): +def check_pdf_model_is_linear(pdf_model, forward_map, FIT_XGRID, data): """ Decorator that can be added to functions to check that the PDF model is linear. @@ -52,8 +52,8 @@ def check_pdf_model_is_linear(pdf_model, FIT_XGRID, data): fk = fast_kernel_arrays(data, FIT_XGRID) parameters = pdf_model.param_names - pred_and_pdf = pdf_model.pred_and_pdf_func(FIT_XGRID, forward_map=pred_data) - intercept = pred_and_pdf(jnp.zeros(len(parameters)), fk)[0] + pdf_grid = pdf_model.grid_values_func(FIT_XGRID) + intercept = forward_map(pdf_grid(jnp.zeros(len(parameters))), fk)[0] # Run the check for 10 random points in the parameter space for i in range(10): @@ -65,16 +65,16 @@ def check_pdf_model_is_linear(pdf_model, FIT_XGRID, data): # Test additivity add_check = jnp.isclose( - pred_and_pdf(x1, fk)[0] + pred_and_pdf(x2, fk)[0], - pred_and_pdf(x1 + x2, fk)[0] + intercept, + forward_map(pdf_grid(x1), fk)[0] + forward_map(pdf_grid(x2), fk)[0], + forward_map(pdf_grid(x1 + x2), fk)[0] + intercept, ) # Test homogeneity c = jax.random.uniform(key, (1,)) homogeneity_check = jnp.isclose( - c * (pred_and_pdf(x1, fk)[0] - intercept), - pred_and_pdf(c * x1, fk)[0] - intercept, + c * (forward_map(pdf_grid(x1), fk)[0] - intercept), + forward_map(pdf_grid(c * x1), fk)[0] - intercept, ) if not add_check.all() or not homogeneity_check.all(): diff --git a/colibri/doc/sphinx/source/theory/pdf_model.rst b/colibri/doc/sphinx/source/theory/pdf_model.rst index 3f59b595f..9a7e29a1c 100644 --- a/colibri/doc/sphinx/source/theory/pdf_model.rst +++ b/colibri/doc/sphinx/source/theory/pdf_model.rst @@ -69,16 +69,10 @@ Prediction Construction To compute physical observables (structure functions, cross sections, etc.), PDFs must be convolved with perturbative coefficient functions. In Colibri, this is handled via -the ``pred_and_pdf_func`` method, which takes the :math:`x`-grid and a forward map from -the PDF to the physical observable, and produces a function taking as input the PDF -parameters and a tuple of fast-kernel arrays: +the ``forward_map`` function, which takes the PDF values on the grid and maps them to predictions for the physical observables. .. math:: - (\boldsymbol{\theta}, FK) \to (\text{predictions}, f_{\rm grid}(\boldsymbol{\theta})) - -This function evaluates the PDF on the grid via ``grid_values_func``, and feeds the -resulting :math:`N_{\rm fl} \times N_{\rm x}` array into the supplied ``forward_map``, -to yield a 1D vector of theory predictions for all data points. + (f_{\rm grid}(\boldsymbol{\theta}), FK) \to \text{predictions} .. note:: Although the prediction function is already implemented, the user is allowed to override diff --git a/colibri/tests/conftest.py b/colibri/tests/conftest.py index 61433515e..c7874c671 100644 --- a/colibri/tests/conftest.py +++ b/colibri/tests/conftest.py @@ -292,16 +292,7 @@ def wmin_param(params): MOCK_PDF_MODEL.grid_values_func = lambda xgrid: lambda params: jnp.sum( jnp.array([param * TEST_PDF_GRID for param in params]), axis=0 ) -""" -Mock PDF model with 2 parameters and grid_values_func simple mult add operation on np.ones grid. -""" -MOCK_PDF_MODEL.pred_and_pdf_func = ( - lambda xgrid, forward_map: lambda params, fast_kernel_arrays: ( - forward_map(MOCK_PDF_MODEL.grid_values_func(xgrid)(params), fast_kernel_arrays), - MOCK_PDF_MODEL.grid_values_func(xgrid)(params), - ) -) """ Mock prediction function of PDF model. """ diff --git a/colibri/tests/test_analytic_fit.py b/colibri/tests/test_analytic_fit.py index 35af4280c..8e6b0d8e0 100644 --- a/colibri/tests/test_analytic_fit.py +++ b/colibri/tests/test_analytic_fit.py @@ -33,21 +33,17 @@ def test_analytic_fit_flat_direction(): """ Tests that the analytic fit raises a ValueError when the - pred_and_pdf_func returns a flat direction in the parameter space. + forward_map returns a flat direction in the parameter space. """ - # override the pred_and_pdf_func to return a flat direction - # in the parameter space - MOCK_PDF_MODEL.pred_and_pdf_func = lambda xgrid, forward_map: ( - lambda params, fkarrs: (jnp.ones_like(params), TEST_PDF_GRID) - ) + n_params = len(MOCK_PDF_MODEL.param_names) - _pred_data = TEST_FORWARD_MAP_DIS + forward_map = lambda pdf, fkarrs: jnp.ones(n_params) with pytest.raises(ValueError): # Run the analytic fit and make sure that the Value Error is raised analytic_fit( MOCK_CENTRAL_INV_COVMAT_INDEX, - _pred_data, + forward_map, MOCK_PDF_MODEL, analytic_settings, TEST_PRIOR_SETTINGS_UNIFORM, @@ -61,16 +57,14 @@ def test_analytic_fit(caplog): Tests basic functionality of the analytic fit function. """ - MOCK_PDF_MODEL.pred_and_pdf_func = lambda xgrid, forward_map: ( - lambda params, fkarrs: (params, TEST_PDF_GRID) - ) + MOCK_PDF_MODEL.grid_values_func = lambda xgrid: lambda params: params - _pred_data = TEST_FORWARD_MAP_DIS + forward_map = lambda pdf, fkarrs: pdf # Run the analytic fit result = analytic_fit( MOCK_CENTRAL_INV_COVMAT_INDEX, - _pred_data, + forward_map, MOCK_PDF_MODEL, analytic_settings, TEST_PRIOR_SETTINGS_UNIFORM, @@ -92,7 +86,7 @@ def test_analytic_fit(caplog): with caplog.at_level(logging.ERROR): # Set the log level to ERROR result_2 = analytic_fit( MOCK_CENTRAL_INV_COVMAT_INDEX, - _pred_data, + forward_map, MOCK_PDF_MODEL, analytic_settings, TEST_PRIOR_SETTINGS_UNIFORM, @@ -121,16 +115,14 @@ def test_analytic_fit_different_priors(caplog): } ) - MOCK_PDF_MODEL.pred_and_pdf_func = lambda xgrid, forward_map: ( - lambda params, fkarrs: (params, TEST_PDF_GRID) - ) + MOCK_PDF_MODEL.grid_values_func = lambda xgrid: lambda params: params - _pred_data = None + forward_map = lambda pdf, fkarrs: pdf # Run the analytic fit result = analytic_fit( MOCK_CENTRAL_INV_COVMAT_INDEX, - _pred_data, + forward_map, MOCK_PDF_MODEL, analytic_settings, PRIOR_SETTINGS1, @@ -156,7 +148,7 @@ def test_analytic_fit_different_priors(caplog): # Run the analytic fit with custom uniform prior result = analytic_fit( MOCK_CENTRAL_INV_COVMAT_INDEX, - _pred_data, + forward_map, MOCK_PDF_MODEL, analytic_settings, PRIOR_SETTINGS2, diff --git a/colibri/tests/test_checks.py b/colibri/tests/test_checks.py index 12495d1da..e25b74b04 100644 --- a/colibri/tests/test_checks.py +++ b/colibri/tests/test_checks.py @@ -125,24 +125,29 @@ def test_check_pdf_model_is_linear(mock_fast_kernel_arrays, mock_make_pred_data) mock_pdf_model = MagicMock() mock_pdf_model.param_names = ["a", "b", "c"] - # Mock the behavior of pred_and_pdf_func to return a linear model - def linear_model(params, fk): - # Simulating a simple linear model: f(x) = a*x + b*y + c*z + 3.0, where params = [a, b, c] - return (jnp.dot(params, fk) + 3.0, params) + # Mock the behavior of pdf_grid to return a linear model + def pdf_linear_model(params): + return params - # Set the mock's pred_and_pdf_func to return the linear_model function - mock_pdf_model.pred_and_pdf_func.return_value = linear_model + def forward_map_lin(pdf, fk): + # Simulating a simple linear model: f(x) = a*x + b*y + c*z + 3.0, where pdf = [a, b, c] + return (jnp.dot(pdf, fk) + 3.0, pdf) + + # Set the mock's grid_values_func to return the linear_model function + mock_pdf_model.grid_values_func.return_value = pdf_linear_model # Test for linear model (should not raise an exception) - check_pdf_model_is_linear.__wrapped__(mock_pdf_model, FIT_XGRID, data) + check_pdf_model_is_linear.__wrapped__( + mock_pdf_model, forward_map_lin, FIT_XGRID, data + ) # Now mock a non-linear model to ensure the ValueError is raised - def non_linear_model(params, fk): + def non_linear_model(pdf, fk): # Introduce some non-linearity - return (jnp.dot(params**2, FIT_XGRID) + fk, params) - - mock_pdf_model.pred_and_pdf_func.return_value = non_linear_model + return (jnp.dot(pdf**2, FIT_XGRID) + fk, pdf) # Ensure ValueError is raised for non-linear model with pytest.raises(ValueError): - check_pdf_model_is_linear.__wrapped__(mock_pdf_model, FIT_XGRID, data) + check_pdf_model_is_linear.__wrapped__( + mock_pdf_model, non_linear_model, FIT_XGRID, data + ) diff --git a/colibri/utils.py b/colibri/utils.py index 296d8aca6..29f3181ee 100644 --- a/colibri/utils.py +++ b/colibri/utils.py @@ -291,7 +291,7 @@ def wrapper(*args, **kwargs): def likelihood_float_type( - _pred_data, + forward_map, pdf_model, FIT_XGRID, bayesian_prior, @@ -308,11 +308,11 @@ def likelihood_float_type( central_values = central_inv_covmat_index.central_values inv_covmat = central_inv_covmat_index.inv_covmat - - pred_and_pdf = pdf_model.pred_and_pdf_func(FIT_XGRID, forward_map=_pred_data) + pdf_grid = pdf_model.grid_values_func(FIT_XGRID) def log_likelihood(params, central_values, inv_covmat, fast_kernel_arrays): - predictions, _ = pred_and_pdf(params, fast_kernel_arrays) + pdf = pdf_grid(params) + predictions, _ = forward_map(pdf, fast_kernel_arrays) return -0.5 * loss_function(central_values, predictions, inv_covmat) params = bayesian_prior( From 7a8c95a05f5fc44baccc15ab09500fadb0076302 Mon Sep 17 00:00:00 2001 From: Luca Mantani Date: Mon, 16 Feb 2026 16:25:25 +0100 Subject: [PATCH 04/30] New forward map class --- colibri/forward_map.py | 147 +++++++++++++++++++++++++++++++++++++++-- colibri/pdf_model.py | 2 +- 2 files changed, 144 insertions(+), 5 deletions(-) diff --git a/colibri/forward_map.py b/colibri/forward_map.py index c34295042..340834947 100644 --- a/colibri/forward_map.py +++ b/colibri/forward_map.py @@ -1,13 +1,152 @@ """ colibri.forward_map.py -This module implements the forward map, i.e. the map from the PDF grid to the theory predictions for each dataset, using the FK tables. +Forward maps: parameters → theory predictions. +A ``ForwardMap`` implements the final stage of the fit pipeline, turning the +fit parameter vector into theory predictions that can be compared with +data in the likelihood. + +Design choice: fixed call signature +----------------------------------- +The log-likelihood calls every forward map with the same fixed signature:: + + (pdf_grid_func, fk_tables, params) -> predictions + +Parameter convention +-------------------- +``params`` is a 1-D array containing *all* fit parameters. In colibri we allow +for "extra" fit parameters beyond the PDF model parameters (e.g. nuisance-like factors, +or parameters of a custom prediction function). + +By convention: + +``params[:self.n_pdf_params]`` are PDF parameters consumed by ``pdf_grid_func``; +any remaining entries are "extra" parameters interpreted by the chosen + ``ForwardMap`` implementation. + +Example - fitting a normalisation factor on top of the PDF +---------------------------------------------------------- +:: + + class NormForwardMap(ForwardMap): + def __init__(self, pred_func, n_pdf_params: int): + super().__init__(n_pdf_params) + self._pred_func = pred_func + + def __call__(self, pdf_grid_func, fk_tables, params): + pdf = pdf_grid_func(params[: self.n_pdf_params]) + norm = params[self.n_pdf_params] # first extra parameter + return norm * self._pred_func(pdf, fk_tables) + +Example - fixed PDF, fitting only extra parameters +--------------------------------------------------- +:: + + class FixedPDFForwardMap(ForwardMap): + def __init__(self, pred_func, fixed_pdf, fk_tables, n_pdf_params: int = 0): + super().__init__(n_pdf_params) + self._pred_func = pred_func + self._fixed_pred = self._pred_func(fixed_pdf, fk_tables) + + def __call__(self, pdf_grid_func, fk_tables, params): + scale = params[0] + return scale * self._fixed_pred """ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Callable + +import jax.numpy as jnp + + +class ForwardMap(ABC): + """Abstract base class for forward maps. -def forward_map(make_pred_data): + A forward map turns fit parameters into theory predictions that can be + compared with experimental data inside the likelihood. + + All forward maps share the same call signature: + + ``(pdf_grid_func, fk_tables, params) -> predictions`` + + Notes + ----- + The split point between PDF parameters and "extra" parameters is owned + by the forward map via ``self.n_pdf_params``. """ - Internal alias function for make_pred_data. + + def __init__(self, n_pdf_params: int): + + self.n_pdf_params = n_pdf_params + + @abstractmethod + def __call__( + self, + pdf_grid_func: Callable[[jnp.ndarray], jnp.ndarray], + fk_tables: Any, + params: jnp.ndarray, + ) -> jnp.ndarray: + """Compute theory predictions from fit parameters. + + Parameters + ---------- + pdf_grid_func : callable + Callable that evaluates PDF values on the fit x-grid from the PDF + parameters. + + Expected call signature: + ``pdf = pdf_grid_func(pdf_params)`` + with ``pdf`` shaped ``(N_fl, N_x)``. + + fk_tables : jnp.ndarray + Fast-kernel tables needed by the prediction function. + + params : jnp.ndarray + 1-D array containing all fit parameters. By convention: + * ``params[:self.n_pdf_params]`` are PDF parameters + * the remaining entries are extra parameters interpreted by the + specific ``ForwardMap`` implementation. + + Returns + ------- + jnp.ndarray + Theory predictions (1-D array with one entry per data point). + """ + raise NotImplementedError + + +class FKTableForwardMap(ForwardMap): + """Default forward map: params → PDF → FK-table convolution. + + This is the standard pipeline used in colibri PDF fits. """ - return make_pred_data + + def __init__( + self, pred_func: Callable[[jnp.ndarray, Any], jnp.ndarray], n_pdf_params: int + ): + super().__init__(n_pdf_params) + self._pred_func = pred_func + + def __call__(self, pdf_grid_func, fk_tables, params): + pdf_params = params[: self.n_pdf_params] + pdf = pdf_grid_func(pdf_params) + return self._pred_func(pdf, fk_tables) + + +def forward_map(_pred_data, pdf_model): + """Reportengine provider that builds the default FK-table forward map. + + Parameters + ---------- + _pred_data : callable + Prediction function of the form ``pred_func(pdf, fk_tables) -> predictions``. + pdf_model : optional + Used to infer ``n_pdf_params`` from ``len(pdf_model.param_names)``. + + """ + + n_pdf_params = len(pdf_model.param_names) + return FKTableForwardMap(_pred_data, n_pdf_params=n_pdf_params) diff --git a/colibri/pdf_model.py b/colibri/pdf_model.py index 11fbf71b5..bdec05866 100644 --- a/colibri/pdf_model.py +++ b/colibri/pdf_model.py @@ -6,7 +6,7 @@ """ from abc import ABC, abstractmethod -from typing import Callable, Tuple +from typing import Callable import jax.numpy as jnp from jax.typing import ArrayLike From 2b898bcfd87146db0540dfb756445bff0c4135b4 Mon Sep 17 00:00:00 2001 From: Luca Mantani Date: Mon, 16 Feb 2026 16:50:47 +0100 Subject: [PATCH 05/30] Refined implementation --- colibri/analytic_fit.py | 4 ++-- colibri/checks.py | 10 +++++----- colibri/forward_map.py | 16 +++++++++++----- colibri/likelihood.py | 3 +-- colibri/utils.py | 3 +-- 5 files changed, 20 insertions(+), 16 deletions(-) diff --git a/colibri/analytic_fit.py b/colibri/analytic_fit.py index f1147e6f4..87ee56fd5 100644 --- a/colibri/analytic_fit.py +++ b/colibri/analytic_fit.py @@ -136,9 +136,9 @@ def analytic_fit( bases = jnp.identity(len(parameters)) pdf_grid = pdf_model.grid_values_func(FIT_XGRID) predictions = jnp.array( - [forward_map(pdf_grid(basis), fast_kernel_arrays) for basis in bases] + [forward_map(pdf_grid, fast_kernel_arrays, basis)[0] for basis in bases] ) - intercept = forward_map(pdf_grid(jnp.zeros(len(parameters))), fast_kernel_arrays) + intercept = forward_map(pdf_grid, fast_kernel_arrays, jnp.zeros(len(parameters)))[0] # Construct the analytic solution central_values = central_inv_covmat_index.central_values diff --git a/colibri/checks.py b/colibri/checks.py index 20d7e7947..d014cb678 100644 --- a/colibri/checks.py +++ b/colibri/checks.py @@ -53,7 +53,7 @@ def check_pdf_model_is_linear(pdf_model, forward_map, FIT_XGRID, data): parameters = pdf_model.param_names pdf_grid = pdf_model.grid_values_func(FIT_XGRID) - intercept = forward_map(pdf_grid(jnp.zeros(len(parameters))), fk)[0] + intercept, _ = forward_map(pdf_grid, fk, jnp.zeros(len(parameters))) # Run the check for 10 random points in the parameter space for i in range(10): @@ -65,16 +65,16 @@ def check_pdf_model_is_linear(pdf_model, forward_map, FIT_XGRID, data): # Test additivity add_check = jnp.isclose( - forward_map(pdf_grid(x1), fk)[0] + forward_map(pdf_grid(x2), fk)[0], - forward_map(pdf_grid(x1 + x2), fk)[0] + intercept, + forward_map(pdf_grid, fk, x1)[0] + forward_map(pdf_grid, fk, x2)[0], + forward_map(pdf_grid, fk, x1 + x2)[0] + intercept, ) # Test homogeneity c = jax.random.uniform(key, (1,)) homogeneity_check = jnp.isclose( - c * (forward_map(pdf_grid(x1), fk)[0] - intercept), - forward_map(pdf_grid(c * x1), fk)[0] - intercept, + c * (forward_map(pdf_grid, fk, x1)[0] - intercept), + forward_map(pdf_grid, fk, c * x1)[0] - intercept, ) if not add_check.all() or not homogeneity_check.all(): diff --git a/colibri/forward_map.py b/colibri/forward_map.py index 340834947..60fde8162 100644 --- a/colibri/forward_map.py +++ b/colibri/forward_map.py @@ -5,13 +5,15 @@ A ``ForwardMap`` implements the final stage of the fit pipeline, turning the fit parameter vector into theory predictions that can be compared with -data in the likelihood. +data in the likelihood. It will also also return the PDF values on the fit x-grid, +which is sometimes needed for computing penalties. + Design choice: fixed call signature ----------------------------------- The log-likelihood calls every forward map with the same fixed signature:: - (pdf_grid_func, fk_tables, params) -> predictions + (pdf_grid_func, fk_tables, params) -> predictions, pdf Parameter convention -------------------- @@ -37,7 +39,7 @@ def __init__(self, pred_func, n_pdf_params: int): def __call__(self, pdf_grid_func, fk_tables, params): pdf = pdf_grid_func(params[: self.n_pdf_params]) norm = params[self.n_pdf_params] # first extra parameter - return norm * self._pred_func(pdf, fk_tables) + return norm * self._pred_func(pdf, fk_tables), pdf Example - fixed PDF, fitting only extra parameters --------------------------------------------------- @@ -47,11 +49,12 @@ class FixedPDFForwardMap(ForwardMap): def __init__(self, pred_func, fixed_pdf, fk_tables, n_pdf_params: int = 0): super().__init__(n_pdf_params) self._pred_func = pred_func + self.fixed_pdf = fixed_pdf self._fixed_pred = self._pred_func(fixed_pdf, fk_tables) def __call__(self, pdf_grid_func, fk_tables, params): scale = params[0] - return scale * self._fixed_pred + return scale * self._fixed_pred, self.fixed_pdf """ from __future__ import annotations @@ -114,6 +117,9 @@ def __call__( ------- jnp.ndarray Theory predictions (1-D array with one entry per data point). + jnp.ndarray + The PDF values (2-D array with shape (N_fl, N_x)). + """ raise NotImplementedError @@ -133,7 +139,7 @@ def __init__( def __call__(self, pdf_grid_func, fk_tables, params): pdf_params = params[: self.n_pdf_params] pdf = pdf_grid_func(pdf_params) - return self._pred_func(pdf, fk_tables) + return self._pred_func(pdf, fk_tables), pdf def forward_map(_pred_data, pdf_model): diff --git a/colibri/likelihood.py b/colibri/likelihood.py index d9ee7fe41..d53d4a211 100644 --- a/colibri/likelihood.py +++ b/colibri/likelihood.py @@ -125,8 +125,7 @@ def log_likelihood( jnp.ndarray jax array with the value of the log-likelihood. """ - pdf = self.pdf_grid(params) - predictions = self.forward_map(pdf, fast_kernel_arrays) + predictions, pdf = self.forward_map(self.pdf_grid, fast_kernel_arrays, params) # Select only the data relevant for this likelihood # Especially important when using a training/validation split predictions = predictions[self.central_values_idx] diff --git a/colibri/utils.py b/colibri/utils.py index 29f3181ee..3d059d14e 100644 --- a/colibri/utils.py +++ b/colibri/utils.py @@ -311,8 +311,7 @@ def likelihood_float_type( pdf_grid = pdf_model.grid_values_func(FIT_XGRID) def log_likelihood(params, central_values, inv_covmat, fast_kernel_arrays): - pdf = pdf_grid(params) - predictions, _ = forward_map(pdf, fast_kernel_arrays) + predictions, pdf = forward_map(pdf_grid, fast_kernel_arrays, params) return -0.5 * loss_function(central_values, predictions, inv_covmat) params = bayesian_prior( From 0d8a24172b12cd9082a40f258a9ef9386a5cbceb Mon Sep 17 00:00:00 2001 From: Luca Mantani Date: Mon, 16 Feb 2026 16:51:20 +0100 Subject: [PATCH 06/30] Changed conftest --- colibri/tests/conftest.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/colibri/tests/conftest.py b/colibri/tests/conftest.py index c7874c671..e231ef504 100644 --- a/colibri/tests/conftest.py +++ b/colibri/tests/conftest.py @@ -11,6 +11,7 @@ from colibri.pdf_model import PDFModel from colibri.core import PriorSettings +from colibri.forward_map import FKTableForwardMap CONFIG_YML_PATH = "test_runcards/test_config.yaml" @@ -324,7 +325,10 @@ def wmin_param(params): """ -TEST_FORWARD_MAP_DIS = lambda pdf, fk_arrays: jnp.einsum("ijk,jk->i", fk_arrays[0], pdf) +TEST_FORWARD_MAP_DIS = FKTableForwardMap( + lambda pdf, fk_arrays: jnp.einsum("ijk,jk->i", fk_arrays[0], pdf), + n_pdf_params=2, +) """ Mock DIS forward map function for testing purposes. Function expects a tuple of DIS-like fast kernel array of shape (N_data, TEST_N_FL, TEST_N_XGRID) and a PDF of shape (TEST_N_FL, TEST_N_XGRID). From 744b27558c233191363144132a3a3688a4e31c48 Mon Sep 17 00:00:00 2001 From: Luca Mantani Date: Mon, 16 Feb 2026 17:13:15 +0100 Subject: [PATCH 07/30] Fixed tests --- colibri/tests/test_analytic_fit.py | 13 ++++++++--- colibri/tests/test_checks.py | 14 ++++++++---- colibri/tests/test_likelihood.py | 35 +++++++++++++---------------- colibri/tests/test_ultranest_fit.py | 26 ++++++++++++++++----- colibri/tests/test_utils.py | 6 ++++- 5 files changed, 62 insertions(+), 32 deletions(-) diff --git a/colibri/tests/test_analytic_fit.py b/colibri/tests/test_analytic_fit.py index 8e6b0d8e0..6b4178e50 100644 --- a/colibri/tests/test_analytic_fit.py +++ b/colibri/tests/test_analytic_fit.py @@ -13,6 +13,7 @@ from colibri.analytic_fit import AnalyticFit, analytic_fit, run_analytic_fit from colibri.core import PriorSettings +from colibri.forward_map import FKTableForwardMap from colibri.tests.conftest import ( MOCK_CENTRAL_INV_COVMAT_INDEX, MOCK_PDF_MODEL, @@ -37,7 +38,9 @@ def test_analytic_fit_flat_direction(): """ n_params = len(MOCK_PDF_MODEL.param_names) - forward_map = lambda pdf, fkarrs: jnp.ones(n_params) + forward_map = FKTableForwardMap( + lambda pdf, fkarrs: jnp.ones(n_params), n_pdf_params=n_params + ) with pytest.raises(ValueError): # Run the analytic fit and make sure that the Value Error is raised @@ -59,7 +62,9 @@ def test_analytic_fit(caplog): MOCK_PDF_MODEL.grid_values_func = lambda xgrid: lambda params: params - forward_map = lambda pdf, fkarrs: pdf + forward_map = FKTableForwardMap( + lambda pdf, fkarrs: pdf, n_pdf_params=len(MOCK_PDF_MODEL.param_names) + ) # Run the analytic fit result = analytic_fit( @@ -117,7 +122,9 @@ def test_analytic_fit_different_priors(caplog): MOCK_PDF_MODEL.grid_values_func = lambda xgrid: lambda params: params - forward_map = lambda pdf, fkarrs: pdf + forward_map = FKTableForwardMap( + lambda pdf, fkarrs: pdf, n_pdf_params=len(MOCK_PDF_MODEL.param_names) + ) # Run the analytic fit result = analytic_fit( diff --git a/colibri/tests/test_checks.py b/colibri/tests/test_checks.py index e25b74b04..5b02e458f 100644 --- a/colibri/tests/test_checks.py +++ b/colibri/tests/test_checks.py @@ -9,6 +9,8 @@ import jax.numpy as jnp import pytest +from colibri.forward_map import FKTableForwardMap + from colibri.checks import check_pdf_model_is_linear, check_pdf_models_equal from colibri.core import PriorSettings @@ -129,9 +131,11 @@ def test_check_pdf_model_is_linear(mock_fast_kernel_arrays, mock_make_pred_data) def pdf_linear_model(params): return params - def forward_map_lin(pdf, fk): + forward_map_lin = FKTableForwardMap( # Simulating a simple linear model: f(x) = a*x + b*y + c*z + 3.0, where pdf = [a, b, c] - return (jnp.dot(pdf, fk) + 3.0, pdf) + lambda pdf, fk: jnp.dot(pdf, fk) + 3.0, + n_pdf_params=3, + ) # Set the mock's grid_values_func to return the linear_model function mock_pdf_model.grid_values_func.return_value = pdf_linear_model @@ -142,9 +146,11 @@ def forward_map_lin(pdf, fk): ) # Now mock a non-linear model to ensure the ValueError is raised - def non_linear_model(pdf, fk): + non_linear_model = FKTableForwardMap( # Introduce some non-linearity - return (jnp.dot(pdf**2, FIT_XGRID) + fk, pdf) + lambda pdf, fk: jnp.dot(pdf**2, FIT_XGRID) + fk, + n_pdf_params=3, + ) # Ensure ValueError is raised for non-linear model with pytest.raises(ValueError): diff --git a/colibri/tests/test_likelihood.py b/colibri/tests/test_likelihood.py index 18ca2475a..d218cae60 100644 --- a/colibri/tests/test_likelihood.py +++ b/colibri/tests/test_likelihood.py @@ -66,9 +66,8 @@ def test_LogLikelihood_class(pos_penalty): ] ) # Compute expected value using actual prediction and covariance - pdf = log_likelihood_class.pdf_grid(params) - predictions = log_likelihood_class.forward_map( - pdf, log_likelihood_class.fast_kernel_arrays + predictions, pdf = log_likelihood_class.forward_map( + log_likelihood_class.pdf_grid, log_likelihood_class.fast_kernel_arrays, params ) predictions = predictions[log_likelihood_class.central_values_idx] diff = predictions - log_likelihood_class.central_values @@ -166,9 +165,8 @@ def test_log_likelihood_with_and_without_pos_penalty(): ) # Compute expectation directly: -0.5 * (chi2 + pos_pen + integ_pen) - pdf = log_likelihood_class.pdf_grid(params) - predictions = log_likelihood_class.forward_map( - pdf, log_likelihood_class.fast_kernel_arrays + predictions, pdf = log_likelihood_class.forward_map( + log_likelihood_class.pdf_grid, log_likelihood_class.fast_kernel_arrays, params ) predictions = predictions[log_likelihood_class.central_values_idx] diff = predictions - log_likelihood_class.central_values @@ -215,9 +213,8 @@ def test_log_likelihood_with_and_without_pos_penalty(): ) # Expectation: Only chi2 value (penalties zeroed) - pdf = log_likelihood_class.pdf_grid(params) - predictions = log_likelihood_class.forward_map( - pdf, log_likelihood_class.fast_kernel_arrays + predictions, pdf = log_likelihood_class.forward_map( + log_likelihood_class.pdf_grid, log_likelihood_class.fast_kernel_arrays, params ) predictions = predictions[log_likelihood_class.central_values_idx] diff = predictions - log_likelihood_class.central_values @@ -278,8 +275,9 @@ def test_mc_log_likelihood_with_split(pos_penalty): # Compute expected for train and validation independently def compute_expected(ll_obj): - pdf = ll_obj.pdf_grid(params) - preds = ll_obj.forward_map(pdf, ll_obj.fast_kernel_arrays) + preds, pdf = ll_obj.forward_map( + ll_obj.pdf_grid, ll_obj.fast_kernel_arrays, params + ) preds = preds[ll_obj.central_values_idx] diff = preds - ll_obj.central_values inv = ll_obj.inv_covmat @@ -352,8 +350,9 @@ def test_mc_log_likelihood_without_split_returns_nan_for_validation(pos_penalty) params = jnp.array([0.3, 0.4]) train_val = train_loglike(params) # Compute expected train value - pdf = train_loglike.pdf_grid(params) - predictions = train_loglike.forward_map(pdf, train_loglike.fast_kernel_arrays) + predictions, pdf = train_loglike.forward_map( + train_loglike.pdf_grid, train_loglike.fast_kernel_arrays, params + ) predictions = predictions[train_loglike.central_values_idx] diff = predictions - train_loglike.central_values chi2_val = jnp.einsum("i,ij,j", diff, train_loglike.inv_covmat, diff) @@ -412,9 +411,8 @@ def test_LogLikelihood_call_with_batch_idx(pos_penalty): ll_value_batched = log_likelihood_class(params, batch=batch) # Compute expected on the batch index: recompute inv_covmat on the sub-covmat - pdf = log_likelihood_class.pdf_grid(params) - predictions = log_likelihood_class.forward_map( - pdf, log_likelihood_class.fast_kernel_arrays + predictions, pdf = log_likelihood_class.forward_map( + log_likelihood_class.pdf_grid, log_likelihood_class.fast_kernel_arrays, params ) predictions = predictions[log_likelihood_class.central_values_idx] predictions_b = predictions[batch.idx] @@ -480,9 +478,8 @@ def test_LogLikelihood_call_with_batch_with_inv_cov(pos_penalty): ll_value_batched = log_likelihood_class(params, batch=batch) # Compute expected value using the provided inv_b (should be identical) - pdf = log_likelihood_class.pdf_grid(params) - predictions = log_likelihood_class.forward_map( - pdf, log_likelihood_class.fast_kernel_arrays + predictions, pdf = log_likelihood_class.forward_map( + log_likelihood_class.pdf_grid, log_likelihood_class.fast_kernel_arrays, params ) predictions = predictions[log_likelihood_class.central_values_idx] predictions_b = predictions[batch.idx] diff --git a/colibri/tests/test_ultranest_fit.py b/colibri/tests/test_ultranest_fit.py index 323a1c776..19f222136 100644 --- a/colibri/tests/test_ultranest_fit.py +++ b/colibri/tests/test_ultranest_fit.py @@ -21,6 +21,7 @@ ) from colibri.ultranest_fit import UltranestFit, run_ultranest_fit, ultranest_fit from colibri.likelihood import LogLikelihood +from colibri.forward_map import FKTableForwardMap jax.config.update("jax_enable_x64", True) @@ -47,11 +48,14 @@ def test_ultranest_fit(pos_penalty): _pred_data = lambda *args: jnp.array([0.0]) + forward_map = FKTableForwardMap( + _pred_data, n_pdf_params=len(MOCK_PDF_MODEL.param_names) + ) mock_log_likelihood = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, MOCK_PDF_MODEL, TEST_XGRID, - _pred_data, + forward_map, TEST_FK_ARRAYS, TEST_POS_FK_ARRAYS, MOCK_PENALTY_POSDATA, @@ -84,13 +88,16 @@ def test_ultranest_fit(pos_penalty): def test_ultranest_fit_vectorized(pos_penalty): _pred_data = lambda *args: jnp.array([0.0]) + forward_map = FKTableForwardMap( + _pred_data, n_pdf_params=len(MOCK_PDF_MODEL.param_names) + ) ultranest_settings["ReactiveNS_settings"]["vectorized"] = True mock_log_likelihood = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, MOCK_PDF_MODEL, TEST_XGRID, - _pred_data, + forward_map, TEST_FK_ARRAYS, TEST_POS_FK_ARRAYS, MOCK_PENALTY_POSDATA, @@ -133,12 +140,15 @@ def test_ultranest_fit_with_SliceSampler(pos_penalty): } _pred_data = lambda *args: jnp.array([0.0]) + forward_map = FKTableForwardMap( + _pred_data, n_pdf_params=len(MOCK_PDF_MODEL.param_names) + ) mock_log_likelihood = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, MOCK_PDF_MODEL, TEST_XGRID, - _pred_data, + forward_map, TEST_FK_ARRAYS, TEST_POS_FK_ARRAYS, MOCK_PENALTY_POSDATA, @@ -181,12 +191,15 @@ def test_ultranest_fit_with_popSliceSampler(pos_penalty): } _pred_data = lambda *args: jnp.array([0.0]) + forward_map = FKTableForwardMap( + _pred_data, n_pdf_params=len(MOCK_PDF_MODEL.param_names) + ) mock_log_likelihood = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, MOCK_PDF_MODEL, TEST_XGRID, - _pred_data, + forward_map, TEST_FK_ARRAYS, TEST_POS_FK_ARRAYS, MOCK_PENALTY_POSDATA, @@ -233,12 +246,15 @@ def test_ultranest_fit_with_sampler_plot(mock_sampler_class, pos_penalty): } _pred_data = lambda *args: jnp.array([0.0]) + forward_map = FKTableForwardMap( + _pred_data, n_pdf_params=len(MOCK_PDF_MODEL.param_names) + ) mock_log_likelihood = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, MOCK_PDF_MODEL, TEST_XGRID, - _pred_data, + forward_map, TEST_FK_ARRAYS, TEST_POS_FK_ARRAYS, MOCK_PENALTY_POSDATA, diff --git a/colibri/tests/test_utils.py b/colibri/tests/test_utils.py index 32a2dc360..9affe7bbd 100644 --- a/colibri/tests/test_utils.py +++ b/colibri/tests/test_utils.py @@ -31,6 +31,7 @@ TEST_DATASET, TEST_DATASET_HAD, ) +from colibri.forward_map import FKTableForwardMap from colibri.utils import ( cast_to_numpy, closest_indices, @@ -338,6 +339,9 @@ def test_likelihood_float_type( _pred_data = lambda x, fks: jnp.ones( len(MOCK_CENTRAL_INV_COVMAT_INDEX.central_values) ) # Mock _pred_data + forward_map = FKTableForwardMap( + _pred_data, n_pdf_params=len(MOCK_PDF_MODEL.param_names) + ) # Mock forward_map FIT_XGRID = jnp.linspace(0, 1, 10) # Mock FIT_XGRID output_path = tmp_path @@ -347,7 +351,7 @@ def test_likelihood_float_type( # Call the function under test likelihood_float_type( - _pred_data=_pred_data, + forward_map=forward_map, pdf_model=MOCK_PDF_MODEL, FIT_XGRID=FIT_XGRID, bayesian_prior=mock_bayesian_prior, From 29dfd7ec22a244f35517fda7da95a51fd278f99c Mon Sep 17 00:00:00 2001 From: Luca Mantani Date: Mon, 16 Feb 2026 17:20:13 +0100 Subject: [PATCH 08/30] Make sure we write pdfs with the first parameters --- colibri/export_results.py | 3 ++- colibri/mc_utils.py | 3 ++- colibri/utils.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/colibri/export_results.py b/colibri/export_results.py index d294874d9..8df0bfce6 100644 --- a/colibri/export_results.py +++ b/colibri/export_results.py @@ -246,10 +246,11 @@ def write_replicas( # Create the exportgrid lhapdf_interpolator = pdf_model.grid_values_func(xgrid) + n_pdf_params = len(pdf_model.param_names) # Finish by writing the replicas to export grids, ready for evolution for i in indices_per_process: - parameters = jnp.array(bayes_fit.resampled_posterior[i, :]) + parameters = jnp.array(bayes_fit.resampled_posterior[i, :n_pdf_params]) grid_for_writing = np.array(lhapdf_interpolator(parameters)) replica_index = i + 1 diff --git a/colibri/mc_utils.py b/colibri/mc_utils.py index 5dc9301bb..b1a4cfad6 100644 --- a/colibri/mc_utils.py +++ b/colibri/mc_utils.py @@ -109,9 +109,10 @@ def write_exportgrid_mc( # Create the exportgrid lhapdf_interpolator = pdf_model.grid_values_func(LHAPDF_XGRID) + n_pdf_params = len(pdf_model.param_names) # Rotate the grid from the evolution basis into the export grid basis - grid_for_writing = np.array(lhapdf_interpolator(parameters)) + grid_for_writing = np.array(lhapdf_interpolator(parameters[:n_pdf_params])) write_exportgrid( grid_for_writing=grid_for_writing, diff --git a/colibri/utils.py b/colibri/utils.py index 3d059d14e..ab419c0e7 100644 --- a/colibri/utils.py +++ b/colibri/utils.py @@ -453,6 +453,7 @@ def write_resampled_bayesian_fit( # overwrite old ns_result.csv with resampled posterior parameters = pdf_model.param_names + n_pdf_params = len(pdf_model.param_names) df = pd.DataFrame(resampled_posterior, columns=parameters) df.to_csv(str(resampled_fit_path) + f"/{csv_results_name}.csv", float_format="%.5e") @@ -465,7 +466,7 @@ def write_resampled_bayesian_fit( for i, parameters in enumerate(resampled_posterior): # Get the PDF grid in the evolution basis lhapdf_interpolator = pdf_model.grid_values_func(LHAPDF_XGRID) - grid_for_writing = np.array(lhapdf_interpolator(parameters)) + grid_for_writing = np.array(lhapdf_interpolator(parameters[:n_pdf_params])) replica_index = i + 1 From 0fd84f31a76462e3d017bac73c563f4738bb6bf3 Mon Sep 17 00:00:00 2001 From: Luca Mantani Date: Mon, 16 Feb 2026 17:29:54 +0100 Subject: [PATCH 09/30] Restored doc --- colibri/doc/sphinx/source/theory/pdf_model.rst | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/colibri/doc/sphinx/source/theory/pdf_model.rst b/colibri/doc/sphinx/source/theory/pdf_model.rst index 9a7e29a1c..3f59b595f 100644 --- a/colibri/doc/sphinx/source/theory/pdf_model.rst +++ b/colibri/doc/sphinx/source/theory/pdf_model.rst @@ -69,10 +69,16 @@ Prediction Construction To compute physical observables (structure functions, cross sections, etc.), PDFs must be convolved with perturbative coefficient functions. In Colibri, this is handled via -the ``forward_map`` function, which takes the PDF values on the grid and maps them to predictions for the physical observables. +the ``pred_and_pdf_func`` method, which takes the :math:`x`-grid and a forward map from +the PDF to the physical observable, and produces a function taking as input the PDF +parameters and a tuple of fast-kernel arrays: .. math:: - (f_{\rm grid}(\boldsymbol{\theta}), FK) \to \text{predictions} + (\boldsymbol{\theta}, FK) \to (\text{predictions}, f_{\rm grid}(\boldsymbol{\theta})) + +This function evaluates the PDF on the grid via ``grid_values_func``, and feeds the +resulting :math:`N_{\rm fl} \times N_{\rm x}` array into the supplied ``forward_map``, +to yield a 1D vector of theory predictions for all data points. .. note:: Although the prediction function is already implemented, the user is allowed to override From 5dcb3c3ce5718f72f150b3c1b4bc930f2bf21c67 Mon Sep 17 00:00:00 2001 From: Luca Mantani Date: Mon, 16 Feb 2026 17:56:48 +0100 Subject: [PATCH 10/30] Fixed bug in tests --- colibri/tests/test_analytic_fit.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/colibri/tests/test_analytic_fit.py b/colibri/tests/test_analytic_fit.py index 6b4178e50..b48aa8ba4 100644 --- a/colibri/tests/test_analytic_fit.py +++ b/colibri/tests/test_analytic_fit.py @@ -55,12 +55,15 @@ def test_analytic_fit_flat_direction(): ) -def test_analytic_fit(caplog): +def test_analytic_fit(caplog, monkeypatch): """ Tests basic functionality of the analytic fit function. """ - MOCK_PDF_MODEL.grid_values_func = lambda xgrid: lambda params: params + # Mock the grid_values_func of the PDF model within the test to return the input parameters as the PDF grid values + monkeypatch.setattr( + MOCK_PDF_MODEL, "grid_values_func", lambda xgrid: lambda params: params + ) forward_map = FKTableForwardMap( lambda pdf, fkarrs: pdf, n_pdf_params=len(MOCK_PDF_MODEL.param_names) @@ -111,7 +114,7 @@ def test_analytic_fit(caplog): assert len(result_2.param_names) == len(MOCK_PDF_MODEL.param_names) -def test_analytic_fit_different_priors(caplog): +def test_analytic_fit_different_priors(caplog, monkeypatch): PRIOR_SETTINGS1 = PriorSettings( **{ @@ -120,7 +123,10 @@ def test_analytic_fit_different_priors(caplog): } ) - MOCK_PDF_MODEL.grid_values_func = lambda xgrid: lambda params: params + # Mock the grid_values_func of the PDF model within the test to return the input parameters as the PDF grid values + monkeypatch.setattr( + MOCK_PDF_MODEL, "grid_values_func", lambda xgrid: lambda params: params + ) forward_map = FKTableForwardMap( lambda pdf, fkarrs: pdf, n_pdf_params=len(MOCK_PDF_MODEL.param_names) From 1e57a68977ed007f5ead991e66a5a9b8dde940e6 Mon Sep 17 00:00:00 2001 From: Mark Nestor Costantini Date: Mon, 23 Mar 2026 11:48:52 +0000 Subject: [PATCH 11/30] use check_pdf_model_is_linear as function rather than decorator --- colibri/analytic_fit.py | 8 +++++++- colibri/checks.py | 1 - colibri/tests/conftest.py | 2 +- colibri/tests/test_analytic_fit.py | 11 +++++++++-- colibri/tests/test_checks.py | 8 ++------ 5 files changed, 19 insertions(+), 11 deletions(-) diff --git a/colibri/analytic_fit.py b/colibri/analytic_fit.py index 87ee56fd5..452f33de0 100644 --- a/colibri/analytic_fit.py +++ b/colibri/analytic_fit.py @@ -81,7 +81,6 @@ def analytic_evidence_uniform_prior(sol_covmat, sol_mean, max_logl, a_vec, b_vec return log_evidence, log_occam_factor -@check_pdf_model_is_linear def analytic_fit( central_inv_covmat_index, forward_map, @@ -90,6 +89,7 @@ def analytic_fit( prior_settings, FIT_XGRID, fast_kernel_arrays, + data, ): """ Analytic fits, for any *linear* PDF model. @@ -123,7 +123,13 @@ def analytic_fit( fast_kernel_arrays: tuple Tuple containing the fast kernel arrays. + + data: validphys.core.DataGroupSpec + The data group specification for the fit. """ + # Ensure that the PDF model is linear before running the fit. + log.info("Checking that the PDF model is linear...") + check_pdf_model_is_linear(pdf_model, forward_map, FIT_XGRID, data) log.warning("The prior is assumed to be flat in the parameters.") log.warning( diff --git a/colibri/checks.py b/colibri/checks.py index d014cb678..0c7e0571e 100644 --- a/colibri/checks.py +++ b/colibri/checks.py @@ -41,7 +41,6 @@ def check_pdf_models_equal(prior_settings, pdf_model, theoryid): ) -@make_argcheck def check_pdf_model_is_linear(pdf_model, forward_map, FIT_XGRID, data): """ Decorator that can be added to functions to check that the diff --git a/colibri/tests/conftest.py b/colibri/tests/conftest.py index e231ef504..d1ce8f59c 100644 --- a/colibri/tests/conftest.py +++ b/colibri/tests/conftest.py @@ -149,7 +149,7 @@ "dataset_inputs": [ # Hadronic {"dataset": "DYE866_Z0_800GEV_DW_RATIO_PDXSECRATIO", "variant": "legacy"}, - {"dataset": "DYE866_Z0_800GEV_PXSEC", "variant": "legacy"}, + {"dataset": "DYE866_Z0_800GEV_PXSEC", "varian t": "legacy"}, {"dataset": "DYE605_Z0_38P8GEV_DW_PXSEC", "variant": "legacy"}, { "dataset": "DYE906_Z0_120GEV_DW_PDXSECRATIO", diff --git a/colibri/tests/test_analytic_fit.py b/colibri/tests/test_analytic_fit.py index b48aa8ba4..ffe96d599 100644 --- a/colibri/tests/test_analytic_fit.py +++ b/colibri/tests/test_analytic_fit.py @@ -11,6 +11,7 @@ import jax.random import pytest +from colibri.api import API as colibriAPI from colibri.analytic_fit import AnalyticFit, analytic_fit, run_analytic_fit from colibri.core import PriorSettings from colibri.forward_map import FKTableForwardMap @@ -18,10 +19,9 @@ MOCK_CENTRAL_INV_COVMAT_INDEX, MOCK_PDF_MODEL, TEST_FK_ARRAYS, - TEST_FORWARD_MAP_DIS, - TEST_PDF_GRID, TEST_PRIOR_SETTINGS_UNIFORM, TEST_XGRID, + TEST_DATASETS, ) analytic_settings = { @@ -30,6 +30,8 @@ "n_posterior_samples": 10, } +TEST_DATA = colibriAPI.data(**TEST_DATASETS) + def test_analytic_fit_flat_direction(): """ @@ -52,6 +54,7 @@ def test_analytic_fit_flat_direction(): TEST_PRIOR_SETTINGS_UNIFORM, TEST_XGRID, TEST_FK_ARRAYS, + TEST_DATA, ) @@ -78,6 +81,7 @@ def test_analytic_fit(caplog, monkeypatch): TEST_PRIOR_SETTINGS_UNIFORM, TEST_XGRID, TEST_FK_ARRAYS, + TEST_DATA, ) assert isinstance(result, AnalyticFit) @@ -100,6 +104,7 @@ def test_analytic_fit(caplog, monkeypatch): TEST_PRIOR_SETTINGS_UNIFORM, TEST_XGRID, TEST_FK_ARRAYS, + TEST_DATA, ) # Check that an error message was logged, because the prior was not wide enough @@ -141,6 +146,7 @@ def test_analytic_fit_different_priors(caplog, monkeypatch): PRIOR_SETTINGS1, TEST_XGRID, TEST_FK_ARRAYS, + TEST_DATA, ) assert isinstance(result, AnalyticFit) @@ -167,6 +173,7 @@ def test_analytic_fit_different_priors(caplog, monkeypatch): PRIOR_SETTINGS2, TEST_XGRID, TEST_FK_ARRAYS, + TEST_DATA, ) diff --git a/colibri/tests/test_checks.py b/colibri/tests/test_checks.py index 5b02e458f..f1aaf4c9f 100644 --- a/colibri/tests/test_checks.py +++ b/colibri/tests/test_checks.py @@ -141,9 +141,7 @@ def pdf_linear_model(params): mock_pdf_model.grid_values_func.return_value = pdf_linear_model # Test for linear model (should not raise an exception) - check_pdf_model_is_linear.__wrapped__( - mock_pdf_model, forward_map_lin, FIT_XGRID, data - ) + check_pdf_model_is_linear(mock_pdf_model, forward_map_lin, FIT_XGRID, data) # Now mock a non-linear model to ensure the ValueError is raised non_linear_model = FKTableForwardMap( @@ -154,6 +152,4 @@ def pdf_linear_model(params): # Ensure ValueError is raised for non-linear model with pytest.raises(ValueError): - check_pdf_model_is_linear.__wrapped__( - mock_pdf_model, non_linear_model, FIT_XGRID, data - ) + check_pdf_model_is_linear(mock_pdf_model, non_linear_model, FIT_XGRID, data) From 1afebf6d1c36a99c169f6f3b1b834654824d58ef Mon Sep 17 00:00:00 2001 From: Mark Nestor Costantini <85164495+comane@users.noreply.github.com> Date: Mon, 23 Mar 2026 17:21:30 +0000 Subject: [PATCH 12/30] Apply suggestion from @comane --- colibri/forward_map.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colibri/forward_map.py b/colibri/forward_map.py index 60fde8162..4e28efe54 100644 --- a/colibri/forward_map.py +++ b/colibri/forward_map.py @@ -155,4 +155,4 @@ def forward_map(_pred_data, pdf_model): """ n_pdf_params = len(pdf_model.param_names) - return FKTableForwardMap(_pred_data, n_pdf_params=n_pdf_params) + return FKTableForwardMap(pred_func=_pred_data, n_pdf_params=n_pdf_params) From 8f7cffd2bbc6c255e738920db0f36c8bb9755199 Mon Sep 17 00:00:00 2001 From: Mark Nestor Costantini Date: Tue, 24 Mar 2026 10:34:08 +0000 Subject: [PATCH 13/30] added tests for forward map --- colibri/tests/conftest.py | 2 +- colibri/tests/test_forward_map.py | 240 ++++++++++++++++++++++++++++++ 2 files changed, 241 insertions(+), 1 deletion(-) create mode 100644 colibri/tests/test_forward_map.py diff --git a/colibri/tests/conftest.py b/colibri/tests/conftest.py index d1ce8f59c..e231ef504 100644 --- a/colibri/tests/conftest.py +++ b/colibri/tests/conftest.py @@ -149,7 +149,7 @@ "dataset_inputs": [ # Hadronic {"dataset": "DYE866_Z0_800GEV_DW_RATIO_PDXSECRATIO", "variant": "legacy"}, - {"dataset": "DYE866_Z0_800GEV_PXSEC", "varian t": "legacy"}, + {"dataset": "DYE866_Z0_800GEV_PXSEC", "variant": "legacy"}, {"dataset": "DYE605_Z0_38P8GEV_DW_PXSEC", "variant": "legacy"}, { "dataset": "DYE906_Z0_120GEV_DW_PDXSECRATIO", diff --git a/colibri/tests/test_forward_map.py b/colibri/tests/test_forward_map.py new file mode 100644 index 000000000..c1bcb7fab --- /dev/null +++ b/colibri/tests/test_forward_map.py @@ -0,0 +1,240 @@ +""" +colibri.tests.test_forward_map + +Tests for the ForwardMap abstract base class, FKTableForwardMap, and the +forward_map provider function. +""" + +import pytest +import numpy as np +import jax.numpy as jnp +from unittest.mock import Mock +from numpy.testing import assert_array_almost_equal + +from colibri.forward_map import ForwardMap, FKTableForwardMap, forward_map +from colibri.tests.conftest import ( + TEST_FK_ARRAYS, + TEST_PDF_GRID, + TEST_N_DATA, + TEST_N_FL, + TEST_N_XGRID, + MOCK_PDF_MODEL, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _simple_pred_func(pdf, fk_tables): + """DIS-like prediction: einsum over the first FK table.""" + return jnp.einsum("ijk,jk->i", fk_tables[0], pdf) + + +def _make_pdf_grid_func(pdf_grid): + """Return a callable that ignores params and always returns pdf_grid.""" + return lambda params: pdf_grid + + +# --------------------------------------------------------------------------- +# ForwardMap (abstract base class) +# --------------------------------------------------------------------------- + + +def test_forward_map_cannot_be_instantiated(): + """ForwardMap is abstract; direct instantiation must raise TypeError.""" + with pytest.raises(TypeError): + ForwardMap(n_pdf_params=2) + + +def test_forward_map_subclass_without_call_cannot_be_instantiated(): + """A subclass that does not implement __call__ must also raise TypeError.""" + + class NoCallSubclass(ForwardMap): + pass + + with pytest.raises(TypeError): + NoCallSubclass(n_pdf_params=2) + + +def test_forward_map_subclass_stores_n_pdf_params(): + """n_pdf_params passed to super().__init__ must be stored on the instance.""" + + class MinimalForwardMap(ForwardMap): + def __call__(self, pdf_grid_func, fk_tables, params): + pdf = pdf_grid_func(params[: self.n_pdf_params]) + return _simple_pred_func(pdf, fk_tables), pdf + + fm = MinimalForwardMap(n_pdf_params=5) + assert fm.n_pdf_params == 5 + + +# --------------------------------------------------------------------------- +# FKTableForwardMap.__init__ +# --------------------------------------------------------------------------- + + +def test_fktable_forward_map_stores_n_pdf_params(): + """FKTableForwardMap.__init__ must store n_pdf_params via the base class.""" + fm = FKTableForwardMap(pred_func=_simple_pred_func, n_pdf_params=3) + assert fm.n_pdf_params == 3 + + +def test_fktable_forward_map_stores_pred_func(): + """FKTableForwardMap.__init__ must store the pred_func.""" + fm = FKTableForwardMap(pred_func=_simple_pred_func, n_pdf_params=3) + assert fm._pred_func is _simple_pred_func + + +# --------------------------------------------------------------------------- +# FKTableForwardMap.__call__ +# --------------------------------------------------------------------------- + + +def test_fktable_forward_map_returns_tuple(): + """__call__ must return a 2-tuple (predictions, pdf).""" + fm = FKTableForwardMap(pred_func=_simple_pred_func, n_pdf_params=2) + pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) + params = jnp.array([1.0, 2.0]) + + result = fm(pdf_grid_func, TEST_FK_ARRAYS, params) + + assert isinstance(result, tuple) + assert len(result) == 2 + + +def test_fktable_forward_map_predictions_shape(): + """Predictions returned by __call__ must have shape (N_data,).""" + fm = FKTableForwardMap(pred_func=_simple_pred_func, n_pdf_params=2) + pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) + params = jnp.array([1.0, 2.0]) + + predictions, _ = fm(pdf_grid_func, TEST_FK_ARRAYS, params) + + assert predictions.shape == (TEST_N_DATA,) + + +def test_fktable_forward_map_pdf_shape(): + """PDF returned by __call__ must have shape (N_fl, N_x).""" + fm = FKTableForwardMap(pred_func=_simple_pred_func, n_pdf_params=2) + pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) + params = jnp.array([1.0, 2.0]) + + _, pdf = fm(pdf_grid_func, TEST_FK_ARRAYS, params) + + assert pdf.shape == (TEST_N_FL, TEST_N_XGRID) + + +def test_fktable_forward_map_slices_pdf_params(): + """ + __call__ must pass only params[:n_pdf_params] to pdf_grid_func; extra + parameters appended to params must not affect the PDF or predictions. + """ + n_pdf = 2 + fm = FKTableForwardMap(pred_func=_simple_pred_func, n_pdf_params=n_pdf) + pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) + + pdf_params = jnp.array([1.0, 2.0]) + extra_params = jnp.array([99.0, -99.0]) # should be ignored + + params_no_extra = pdf_params + params_with_extra = jnp.concatenate([pdf_params, extra_params]) + + preds_no_extra, pdf_no_extra = fm(pdf_grid_func, TEST_FK_ARRAYS, params_no_extra) + preds_with_extra, pdf_with_extra = fm( + pdf_grid_func, TEST_FK_ARRAYS, params_with_extra + ) + + assert_array_almost_equal(preds_no_extra, preds_with_extra) + assert_array_almost_equal(pdf_no_extra, pdf_with_extra) + + +def test_fktable_forward_map_uses_pdf_grid_func(): + """ + __call__ must feed the pdf returned by pdf_grid_func into pred_func. + We verify this by using a pdf_grid_func that scales by a known factor. + """ + scale = 3.0 + n_pdf = 2 + fm = FKTableForwardMap(pred_func=_simple_pred_func, n_pdf_params=n_pdf) + + params = jnp.array([1.0, 2.0]) + base_pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) + scaled_pdf_grid_func = lambda p: scale * base_pdf_grid_func(p) # noqa: E731 + + preds_base, _ = fm(base_pdf_grid_func, TEST_FK_ARRAYS, params) + preds_scaled, _ = fm(scaled_pdf_grid_func, TEST_FK_ARRAYS, params) + + np.testing.assert_allclose(preds_scaled, scale * preds_base, rtol=1e-5) + + +def test_fktable_forward_map_correct_values(): + """ + __call__ must produce predictions equal to pred_func(pdf_grid_func(params), fk). + """ + n_pdf = 2 + fm = FKTableForwardMap(pred_func=_simple_pred_func, n_pdf_params=n_pdf) + + params = jnp.array([1.0, 2.0]) + pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) + + predictions, pdf = fm(pdf_grid_func, TEST_FK_ARRAYS, params) + + expected_pdf = pdf_grid_func(params) + expected_preds = _simple_pred_func(expected_pdf, TEST_FK_ARRAYS) + + assert_array_almost_equal(predictions, expected_preds) + assert_array_almost_equal(pdf, expected_pdf) + + +# --------------------------------------------------------------------------- +# forward_map provider function +# --------------------------------------------------------------------------- + + +def test_forward_map_provider_returns_fktable_forward_map(): + """forward_map() must return an FKTableForwardMap instance.""" + result = forward_map(_pred_data=_simple_pred_func, pdf_model=MOCK_PDF_MODEL) + assert isinstance(result, FKTableForwardMap) + + +def test_forward_map_provider_infers_n_pdf_params(): + """ + forward_map() must set n_pdf_params equal to len(pdf_model.param_names). + """ + result = forward_map(_pred_data=_simple_pred_func, pdf_model=MOCK_PDF_MODEL) + assert result.n_pdf_params == len(MOCK_PDF_MODEL.param_names) + + +def test_forward_map_provider_stores_pred_func(): + """forward_map() must wire _pred_data into the FKTableForwardMap.""" + result = forward_map(_pred_data=_simple_pred_func, pdf_model=MOCK_PDF_MODEL) + assert result._pred_func is _simple_pred_func + + +def test_forward_map_provider_functional(): + """ + The FKTableForwardMap built by forward_map() must produce correct results + when called. + """ + fm = forward_map(_pred_data=_simple_pred_func, pdf_model=MOCK_PDF_MODEL) + pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) + params = jnp.array([1.0, 2.0]) + + predictions, pdf = fm(pdf_grid_func, TEST_FK_ARRAYS, params) + + assert predictions.shape == (TEST_N_DATA,) + assert pdf.shape == (TEST_N_FL, TEST_N_XGRID) + + +def test_forward_map_provider_with_different_param_counts(): + """ + forward_map() must correctly handle pdf_models with different numbers of + parameters. + """ + for n in [1, 3, 7]: + mock_model = Mock() + mock_model.param_names = [f"p_{i}" for i in range(n)] + fm = forward_map(_pred_data=_simple_pred_func, pdf_model=mock_model) + assert fm.n_pdf_params == n From a4b01b4ec416cc10043582ac53923979518b75c2 Mon Sep 17 00:00:00 2001 From: Mark Nestor Costantini Date: Tue, 24 Mar 2026 10:54:06 +0000 Subject: [PATCH 14/30] fixed tests from merge --- colibri/tests/test_blackjax_fit.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/colibri/tests/test_blackjax_fit.py b/colibri/tests/test_blackjax_fit.py index c8845666e..f84e8df23 100644 --- a/colibri/tests/test_blackjax_fit.py +++ b/colibri/tests/test_blackjax_fit.py @@ -21,6 +21,7 @@ from colibri.core import BlackJAXFit, BayesianPrior from colibri.blackjax_fit import blackjax_fit, run_blackjax_fit +from colibri.forward_map import FKTableForwardMap from colibri.likelihood import LogLikelihood jax.config.update("jax_enable_x64", True) @@ -61,12 +62,15 @@ def mock_sample(rng_key, n_samples): @pytest.mark.parametrize("pos_penalty", [True, False]) def test_blackjax_fit(pos_penalty): - _pred_data = lambda *args: jnp.array([0.0]) + forward_map = FKTableForwardMap( + lambda pdf, fk: jnp.zeros(len(MOCK_PDF_MODEL.param_names)), + n_pdf_params=len(MOCK_PDF_MODEL.param_names), + ) mock_log_likelihood = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, MOCK_PDF_MODEL, TEST_XGRID, - _pred_data, + forward_map, TEST_FK_ARRAYS, TEST_POS_FK_ARRAYS, MOCK_PENALTY_POSDATA, From cbf5ff68675aba5b011a59a034510709199aacd6 Mon Sep 17 00:00:00 2001 From: Mark Nestor Costantini Date: Tue, 24 Mar 2026 10:56:46 +0000 Subject: [PATCH 15/30] upgraded local black and formatted forward map tests --- colibri/tests/test_forward_map.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colibri/tests/test_forward_map.py b/colibri/tests/test_forward_map.py index c1bcb7fab..55668c439 100644 --- a/colibri/tests/test_forward_map.py +++ b/colibri/tests/test_forward_map.py @@ -21,7 +21,6 @@ MOCK_PDF_MODEL, ) - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- From 74d0f4e2b41efa695239f6fe2737a53e91ff1a25 Mon Sep 17 00:00:00 2001 From: Mark Nestor Costantini Date: Tue, 24 Mar 2026 11:12:46 +0000 Subject: [PATCH 16/30] added line for raise not implemented in forward map --- colibri/tests/test_forward_map.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/colibri/tests/test_forward_map.py b/colibri/tests/test_forward_map.py index 55668c439..b6540aaa4 100644 --- a/colibri/tests/test_forward_map.py +++ b/colibri/tests/test_forward_map.py @@ -57,6 +57,18 @@ class NoCallSubclass(ForwardMap): NoCallSubclass(n_pdf_params=2) +def test_forward_map_abstract_call_raises_not_implemented(): + """Calling super().__call__() must hit the raise NotImplementedError body.""" + + class SuperCallingForwardMap(ForwardMap): + def __call__(self, pdf_grid_func, fk_tables, params): + return super().__call__(pdf_grid_func, fk_tables, params) + + fm = SuperCallingForwardMap(n_pdf_params=2) + with pytest.raises(NotImplementedError): + fm(_make_pdf_grid_func(TEST_PDF_GRID), TEST_FK_ARRAYS, jnp.array([1.0, 2.0])) + + def test_forward_map_subclass_stores_n_pdf_params(): """n_pdf_params passed to super().__init__ must be stored on the instance.""" From 7da4f04e06a1b2b5f1169e193416af9d17271190 Mon Sep 17 00:00:00 2001 From: Mark Nestor Costantini Date: Thu, 26 Mar 2026 09:09:19 +0000 Subject: [PATCH 17/30] forward model initialised with pdf parameter names --- colibri/forward_map.py | 30 +++++++++++++-------- colibri/tests/conftest.py | 2 +- colibri/tests/test_analytic_fit.py | 7 ++--- colibri/tests/test_blackjax_fit.py | 2 +- colibri/tests/test_checks.py | 4 +-- colibri/tests/test_forward_map.py | 41 ++++++++++++++++------------- colibri/tests/test_ultranest_fit.py | 10 +++---- colibri/tests/test_utils.py | 2 +- 8 files changed, 56 insertions(+), 42 deletions(-) diff --git a/colibri/forward_map.py b/colibri/forward_map.py index 4e28efe54..ccbccf03a 100644 --- a/colibri/forward_map.py +++ b/colibri/forward_map.py @@ -32,8 +32,8 @@ :: class NormForwardMap(ForwardMap): - def __init__(self, pred_func, n_pdf_params: int): - super().__init__(n_pdf_params) + def __init__(self, pred_func, pdf_param_names: list[str]): + super().__init__(pdf_param_names) self._pred_func = pred_func def __call__(self, pdf_grid_func, fk_tables, params): @@ -46,8 +46,8 @@ def __call__(self, pdf_grid_func, fk_tables, params): :: class FixedPDFForwardMap(ForwardMap): - def __init__(self, pred_func, fixed_pdf, fk_tables, n_pdf_params: int = 0): - super().__init__(n_pdf_params) + def __init__(self, pred_func, fixed_pdf, fk_tables, pdf_param_names: list[str] | None = None): + super().__init__(pdf_param_names if pdf_param_names is not None else []) self._pred_func = pred_func self.fixed_pdf = fixed_pdf self._fixed_pred = self._pred_func(fixed_pdf, fk_tables) @@ -81,9 +81,14 @@ class ForwardMap(ABC): by the forward map via ``self.n_pdf_params``. """ - def __init__(self, n_pdf_params: int): + def __init__(self, pdf_param_names: list[str]): - self.n_pdf_params = n_pdf_params + self.pdf_param_names = pdf_param_names + + @property + def n_pdf_params(self) -> int: + """Number of PDF parameters, derived from ``pdf_param_names``.""" + return len(self.pdf_param_names) @abstractmethod def __call__( @@ -131,9 +136,11 @@ class FKTableForwardMap(ForwardMap): """ def __init__( - self, pred_func: Callable[[jnp.ndarray, Any], jnp.ndarray], n_pdf_params: int + self, + pred_func: Callable[[jnp.ndarray, Any], jnp.ndarray], + pdf_param_names: list[str], ): - super().__init__(n_pdf_params) + super().__init__(pdf_param_names) self._pred_func = pred_func def __call__(self, pdf_grid_func, fk_tables, params): @@ -150,9 +157,10 @@ def forward_map(_pred_data, pdf_model): _pred_data : callable Prediction function of the form ``pred_func(pdf, fk_tables) -> predictions``. pdf_model : optional - Used to infer ``n_pdf_params`` from ``len(pdf_model.param_names)``. + Used to obtain ``pdf_param_names`` from ``pdf_model.param_names``. """ - n_pdf_params = len(pdf_model.param_names) - return FKTableForwardMap(pred_func=_pred_data, n_pdf_params=n_pdf_params) + return FKTableForwardMap( + pred_func=_pred_data, pdf_param_names=pdf_model.param_names + ) diff --git a/colibri/tests/conftest.py b/colibri/tests/conftest.py index 94cf798f4..650077a3b 100644 --- a/colibri/tests/conftest.py +++ b/colibri/tests/conftest.py @@ -327,7 +327,7 @@ def wmin_param(params): TEST_FORWARD_MAP_DIS = FKTableForwardMap( lambda pdf, fk_arrays: jnp.einsum("ijk,jk->i", fk_arrays[0], pdf), - n_pdf_params=2, + pdf_param_names=["param1", "param2"], ) """ Mock DIS forward map function for testing purposes. diff --git a/colibri/tests/test_analytic_fit.py b/colibri/tests/test_analytic_fit.py index 522a773e6..579e603a5 100644 --- a/colibri/tests/test_analytic_fit.py +++ b/colibri/tests/test_analytic_fit.py @@ -41,7 +41,8 @@ def test_analytic_fit_flat_direction(): n_params = len(MOCK_PDF_MODEL.param_names) forward_map = FKTableForwardMap( - lambda pdf, fkarrs: jnp.ones(n_params), n_pdf_params=n_params + lambda pdf, fkarrs: jnp.ones(n_params), + pdf_param_names=MOCK_PDF_MODEL.param_names, ) with pytest.raises(ValueError): @@ -69,7 +70,7 @@ def test_analytic_fit(caplog, monkeypatch): ) forward_map = FKTableForwardMap( - lambda pdf, fkarrs: pdf, n_pdf_params=len(MOCK_PDF_MODEL.param_names) + lambda pdf, fkarrs: pdf, pdf_param_names=MOCK_PDF_MODEL.param_names ) # Run the analytic fit @@ -134,7 +135,7 @@ def test_analytic_fit_different_priors(caplog, monkeypatch): ) forward_map = FKTableForwardMap( - lambda pdf, fkarrs: pdf, n_pdf_params=len(MOCK_PDF_MODEL.param_names) + lambda pdf, fkarrs: pdf, pdf_param_names=MOCK_PDF_MODEL.param_names ) # Run the analytic fit diff --git a/colibri/tests/test_blackjax_fit.py b/colibri/tests/test_blackjax_fit.py index f84e8df23..018df86e4 100644 --- a/colibri/tests/test_blackjax_fit.py +++ b/colibri/tests/test_blackjax_fit.py @@ -64,7 +64,7 @@ def mock_sample(rng_key, n_samples): def test_blackjax_fit(pos_penalty): forward_map = FKTableForwardMap( lambda pdf, fk: jnp.zeros(len(MOCK_PDF_MODEL.param_names)), - n_pdf_params=len(MOCK_PDF_MODEL.param_names), + pdf_param_names=MOCK_PDF_MODEL.param_names, ) mock_log_likelihood = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, diff --git a/colibri/tests/test_checks.py b/colibri/tests/test_checks.py index f1aaf4c9f..35877cb51 100644 --- a/colibri/tests/test_checks.py +++ b/colibri/tests/test_checks.py @@ -134,7 +134,7 @@ def pdf_linear_model(params): forward_map_lin = FKTableForwardMap( # Simulating a simple linear model: f(x) = a*x + b*y + c*z + 3.0, where pdf = [a, b, c] lambda pdf, fk: jnp.dot(pdf, fk) + 3.0, - n_pdf_params=3, + pdf_param_names=["a", "b", "c"], ) # Set the mock's grid_values_func to return the linear_model function @@ -147,7 +147,7 @@ def pdf_linear_model(params): non_linear_model = FKTableForwardMap( # Introduce some non-linearity lambda pdf, fk: jnp.dot(pdf**2, FIT_XGRID) + fk, - n_pdf_params=3, + pdf_param_names=["a", "b", "c"], ) # Ensure ValueError is raised for non-linear model diff --git a/colibri/tests/test_forward_map.py b/colibri/tests/test_forward_map.py index b6540aaa4..c4a0abc1b 100644 --- a/colibri/tests/test_forward_map.py +++ b/colibri/tests/test_forward_map.py @@ -44,7 +44,7 @@ def _make_pdf_grid_func(pdf_grid): def test_forward_map_cannot_be_instantiated(): """ForwardMap is abstract; direct instantiation must raise TypeError.""" with pytest.raises(TypeError): - ForwardMap(n_pdf_params=2) + ForwardMap(pdf_param_names=["a", "b"]) def test_forward_map_subclass_without_call_cannot_be_instantiated(): @@ -54,7 +54,7 @@ class NoCallSubclass(ForwardMap): pass with pytest.raises(TypeError): - NoCallSubclass(n_pdf_params=2) + NoCallSubclass(pdf_param_names=["a", "b"]) def test_forward_map_abstract_call_raises_not_implemented(): @@ -64,20 +64,21 @@ class SuperCallingForwardMap(ForwardMap): def __call__(self, pdf_grid_func, fk_tables, params): return super().__call__(pdf_grid_func, fk_tables, params) - fm = SuperCallingForwardMap(n_pdf_params=2) + fm = SuperCallingForwardMap(pdf_param_names=["a", "b"]) with pytest.raises(NotImplementedError): fm(_make_pdf_grid_func(TEST_PDF_GRID), TEST_FK_ARRAYS, jnp.array([1.0, 2.0])) -def test_forward_map_subclass_stores_n_pdf_params(): - """n_pdf_params passed to super().__init__ must be stored on the instance.""" +def test_forward_map_subclass_stores_pdf_param_names(): + """pdf_param_names passed to super().__init__ must be stored on the instance.""" class MinimalForwardMap(ForwardMap): def __call__(self, pdf_grid_func, fk_tables, params): pdf = pdf_grid_func(params[: self.n_pdf_params]) return _simple_pred_func(pdf, fk_tables), pdf - fm = MinimalForwardMap(n_pdf_params=5) + fm = MinimalForwardMap(pdf_param_names=["p0", "p1", "p2", "p3", "p4"]) + assert fm.pdf_param_names == ["p0", "p1", "p2", "p3", "p4"] assert fm.n_pdf_params == 5 @@ -86,15 +87,16 @@ def __call__(self, pdf_grid_func, fk_tables, params): # --------------------------------------------------------------------------- -def test_fktable_forward_map_stores_n_pdf_params(): - """FKTableForwardMap.__init__ must store n_pdf_params via the base class.""" - fm = FKTableForwardMap(pred_func=_simple_pred_func, n_pdf_params=3) +def test_fktable_forward_map_stores_pdf_param_names(): + """FKTableForwardMap.__init__ must store pdf_param_names via the base class.""" + fm = FKTableForwardMap(pred_func=_simple_pred_func, pdf_param_names=["a", "b", "c"]) + assert fm.pdf_param_names == ["a", "b", "c"] assert fm.n_pdf_params == 3 def test_fktable_forward_map_stores_pred_func(): """FKTableForwardMap.__init__ must store the pred_func.""" - fm = FKTableForwardMap(pred_func=_simple_pred_func, n_pdf_params=3) + fm = FKTableForwardMap(pred_func=_simple_pred_func, pdf_param_names=["a", "b", "c"]) assert fm._pred_func is _simple_pred_func @@ -105,7 +107,7 @@ def test_fktable_forward_map_stores_pred_func(): def test_fktable_forward_map_returns_tuple(): """__call__ must return a 2-tuple (predictions, pdf).""" - fm = FKTableForwardMap(pred_func=_simple_pred_func, n_pdf_params=2) + fm = FKTableForwardMap(pred_func=_simple_pred_func, pdf_param_names=["a", "b"]) pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) params = jnp.array([1.0, 2.0]) @@ -117,7 +119,7 @@ def test_fktable_forward_map_returns_tuple(): def test_fktable_forward_map_predictions_shape(): """Predictions returned by __call__ must have shape (N_data,).""" - fm = FKTableForwardMap(pred_func=_simple_pred_func, n_pdf_params=2) + fm = FKTableForwardMap(pred_func=_simple_pred_func, pdf_param_names=["a", "b"]) pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) params = jnp.array([1.0, 2.0]) @@ -128,7 +130,7 @@ def test_fktable_forward_map_predictions_shape(): def test_fktable_forward_map_pdf_shape(): """PDF returned by __call__ must have shape (N_fl, N_x).""" - fm = FKTableForwardMap(pred_func=_simple_pred_func, n_pdf_params=2) + fm = FKTableForwardMap(pred_func=_simple_pred_func, pdf_param_names=["a", "b"]) pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) params = jnp.array([1.0, 2.0]) @@ -143,7 +145,7 @@ def test_fktable_forward_map_slices_pdf_params(): parameters appended to params must not affect the PDF or predictions. """ n_pdf = 2 - fm = FKTableForwardMap(pred_func=_simple_pred_func, n_pdf_params=n_pdf) + fm = FKTableForwardMap(pred_func=_simple_pred_func, pdf_param_names=["a", "b"]) pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) pdf_params = jnp.array([1.0, 2.0]) @@ -168,7 +170,7 @@ def test_fktable_forward_map_uses_pdf_grid_func(): """ scale = 3.0 n_pdf = 2 - fm = FKTableForwardMap(pred_func=_simple_pred_func, n_pdf_params=n_pdf) + fm = FKTableForwardMap(pred_func=_simple_pred_func, pdf_param_names=["a", "b"]) params = jnp.array([1.0, 2.0]) base_pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) @@ -185,7 +187,7 @@ def test_fktable_forward_map_correct_values(): __call__ must produce predictions equal to pred_func(pdf_grid_func(params), fk). """ n_pdf = 2 - fm = FKTableForwardMap(pred_func=_simple_pred_func, n_pdf_params=n_pdf) + fm = FKTableForwardMap(pred_func=_simple_pred_func, pdf_param_names=["a", "b"]) params = jnp.array([1.0, 2.0]) pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) @@ -210,11 +212,13 @@ def test_forward_map_provider_returns_fktable_forward_map(): assert isinstance(result, FKTableForwardMap) -def test_forward_map_provider_infers_n_pdf_params(): +def test_forward_map_provider_infers_pdf_param_names(): """ - forward_map() must set n_pdf_params equal to len(pdf_model.param_names). + forward_map() must set pdf_param_names equal to pdf_model.param_names, + and n_pdf_params must equal len(pdf_model.param_names). """ result = forward_map(_pred_data=_simple_pred_func, pdf_model=MOCK_PDF_MODEL) + assert result.pdf_param_names == MOCK_PDF_MODEL.param_names assert result.n_pdf_params == len(MOCK_PDF_MODEL.param_names) @@ -248,4 +252,5 @@ def test_forward_map_provider_with_different_param_counts(): mock_model = Mock() mock_model.param_names = [f"p_{i}" for i in range(n)] fm = forward_map(_pred_data=_simple_pred_func, pdf_model=mock_model) + assert fm.pdf_param_names == mock_model.param_names assert fm.n_pdf_params == n diff --git a/colibri/tests/test_ultranest_fit.py b/colibri/tests/test_ultranest_fit.py index dfda7429d..26e077348 100644 --- a/colibri/tests/test_ultranest_fit.py +++ b/colibri/tests/test_ultranest_fit.py @@ -68,7 +68,7 @@ def test_ultranest_fit(pos_penalty): _pred_data = lambda *args: jnp.array([0.0]) forward_map = FKTableForwardMap( - _pred_data, n_pdf_params=len(MOCK_PDF_MODEL.param_names) + _pred_data, pdf_param_names=MOCK_PDF_MODEL.param_names ) mock_log_likelihood = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, @@ -108,7 +108,7 @@ def test_ultranest_fit_vectorized(pos_penalty): _pred_data = lambda *args: jnp.array([0.0]) forward_map = FKTableForwardMap( - _pred_data, n_pdf_params=len(MOCK_PDF_MODEL.param_names) + _pred_data, pdf_param_names=MOCK_PDF_MODEL.param_names ) ultranest_settings["ReactiveNS_settings"]["vectorized"] = True @@ -160,7 +160,7 @@ def test_ultranest_fit_with_SliceSampler(pos_penalty): _pred_data = lambda *args: jnp.array([0.0]) forward_map = FKTableForwardMap( - _pred_data, n_pdf_params=len(MOCK_PDF_MODEL.param_names) + _pred_data, pdf_param_names=MOCK_PDF_MODEL.param_names ) mock_log_likelihood = LogLikelihood( @@ -211,7 +211,7 @@ def test_ultranest_fit_with_popSliceSampler(pos_penalty): _pred_data = lambda *args: jnp.array([0.0]) forward_map = FKTableForwardMap( - _pred_data, n_pdf_params=len(MOCK_PDF_MODEL.param_names) + _pred_data, pdf_param_names=MOCK_PDF_MODEL.param_names ) mock_log_likelihood = LogLikelihood( @@ -266,7 +266,7 @@ def test_ultranest_fit_with_sampler_plot(mock_sampler_class, pos_penalty): _pred_data = lambda *args: jnp.array([0.0]) forward_map = FKTableForwardMap( - _pred_data, n_pdf_params=len(MOCK_PDF_MODEL.param_names) + _pred_data, pdf_param_names=MOCK_PDF_MODEL.param_names ) mock_log_likelihood = LogLikelihood( diff --git a/colibri/tests/test_utils.py b/colibri/tests/test_utils.py index b469e15a4..7fbae6bd5 100644 --- a/colibri/tests/test_utils.py +++ b/colibri/tests/test_utils.py @@ -343,7 +343,7 @@ def test_likelihood_float_type( len(MOCK_CENTRAL_COVMAT_INDEX.central_values) ) # Mock _pred_data forward_map = FKTableForwardMap( - _pred_data, n_pdf_params=len(MOCK_PDF_MODEL.param_names) + _pred_data, pdf_param_names=MOCK_PDF_MODEL.param_names ) # Mock forward_map FIT_XGRID = jnp.linspace(0, 1, 10) # Mock FIT_XGRID output_path = tmp_path From 054c97f0b85cca050074535a660767242b9e0b5d Mon Sep 17 00:00:00 2001 From: Mark Nestor Costantini Date: Thu, 26 Mar 2026 09:31:29 +0000 Subject: [PATCH 18/30] pass forward model to bayesian prior for total model params --- colibri/bayes_prior.py | 8 +-- colibri/blackjax_fit.py | 12 ++-- colibri/checks.py | 11 ++-- colibri/forward_map.py | 23 ++++++-- colibri/tests/test_bayes_prior.py | 18 ++++-- colibri/tests/test_blackjax_fit.py | 15 +++-- colibri/tests/test_checks.py | 86 ++++++++++++++--------------- colibri/tests/test_forward_map.py | 54 ++++++++++++++++++ colibri/tests/test_ultranest_fit.py | 10 ++-- colibri/ultranest_fit.py | 8 +-- 10 files changed, 156 insertions(+), 89 deletions(-) diff --git a/colibri/bayes_prior.py b/colibri/bayes_prior.py index 443e9b64f..a15a69c71 100644 --- a/colibri/bayes_prior.py +++ b/colibri/bayes_prior.py @@ -5,15 +5,13 @@ cast_to_numpy, get_full_posterior, ) -from colibri.checks import check_pdf_models_equal from colibri.core import BayesianPrior import tensorflow_probability.substrates.jax as tfp tfd = tfp.distributions -@check_pdf_models_equal -def bayesian_prior(prior_settings, pdf_model): +def bayesian_prior(prior_settings, forward_map): """ Produces a prior transform function. @@ -31,8 +29,8 @@ def bayesian_prior(prior_settings, pdf_model): prior_specs = prior_settings.prior_distribution_specs if "bounds" in prior_specs: - # Use param names from the model to order bounds correctly - param_names = pdf_model.param_names + # Use param names from the forward map to order bounds correctly + param_names = forward_map.param_names bounds_dict = prior_specs["bounds"] missing = [p for p in param_names if p not in bounds_dict] if missing: diff --git a/colibri/blackjax_fit.py b/colibri/blackjax_fit.py index 8f27c88bc..c35c3032e 100644 --- a/colibri/blackjax_fit.py +++ b/colibri/blackjax_fit.py @@ -37,7 +37,7 @@ def blackjax_fit( - pdf_model, + forward_map, bayesian_prior, blackjax_settings, log_likelihood, @@ -47,8 +47,8 @@ def blackjax_fit( Parameters ---------- - pdf_model: pdf_model.PDFModel - The PDF model to fit. + forward_map: ForwardMap + The forward map whose ``param_names`` enumerate all fit parameters. bayesian_prior: BayesianPrior, @jax.jit CompiledFunction The prior function for the model. @@ -70,7 +70,7 @@ def blackjax_fit( # set the BlackJAX seed rng_key = jax.random.PRNGKey(blackjax_settings["seed"]) log.info(f"BlackJAX initialisation seed: {rng_key}") - n_dims = pdf_model.n_parameters + n_dims = len(forward_map.param_names) n_live = blackjax_settings["n_live"] n_delete = int(blackjax_settings["delete_fraction"] * n_live) @@ -141,7 +141,7 @@ def one_step(carry, xs): data=final_states.particles, logL=final_states.loglikelihood, logL_birth=final_states.loglikelihood_birth, - columns=pdf_model.param_names, + columns=forward_map.param_names, ) # write nested_samples.csv to blackjax_logs log_dir = blackjax_settings["log_dir"] @@ -167,7 +167,7 @@ def one_step(carry, xs): "logZ_err": logzs.std(), "ess": ess_value, }, - param_names=pdf_model.param_names, + param_names=forward_map.param_names, resampled_posterior=resampled_posterior, full_posterior_samples=full_samples, bayesian_metrics={ diff --git a/colibri/checks.py b/colibri/checks.py index 0c7e0571e..e38ea6b4d 100644 --- a/colibri/checks.py +++ b/colibri/checks.py @@ -9,15 +9,15 @@ import jax from colibri.theory_predictions import make_pred_data, fast_kernel_arrays -from colibri.utils import get_fit_path, get_pdf_model, pdf_models_equal +from colibri.utils import get_fit_path, get_pdf_model @make_argcheck -def check_pdf_models_equal(prior_settings, pdf_model, theoryid): +def check_pdf_models_equal(prior_settings, forward_map, theoryid): """ Decorator that can be added to functions to check that the PDF model used as prior (eg when using prior_settings["type"] == "prior_from_gauss_posterior") - matches the PDF model used in the current fit (pdf_model). + matches the PDF model used in the current fit (via ``forward_map.pdf_param_names``). """ if prior_settings.prior_distribution == "prior_from_gauss_posterior": @@ -25,9 +25,10 @@ def check_pdf_models_equal(prior_settings, pdf_model, theoryid): prior_fit = prior_settings.prior_distribution_specs["prior_fit"] prior_pdf_model = get_pdf_model(prior_fit) - if not pdf_models_equal(prior_pdf_model, pdf_model): + if prior_pdf_model.param_names != list(forward_map.pdf_param_names): raise ValueError( - f"PDF model {pdf_model} does not match prior settings {prior_pdf_model}" + f"PDF param names from forward_map {list(forward_map.pdf_param_names)} " + f"do not match prior PDF model param names {prior_pdf_model.param_names}" ) # load filter.yml runcard of the prior fit diff --git a/colibri/forward_map.py b/colibri/forward_map.py index ccbccf03a..f5d1a509e 100644 --- a/colibri/forward_map.py +++ b/colibri/forward_map.py @@ -33,7 +33,7 @@ class NormForwardMap(ForwardMap): def __init__(self, pred_func, pdf_param_names: list[str]): - super().__init__(pdf_param_names) + super().__init__(pdf_param_names, extra_param_names=["norm"]) self._pred_func = pred_func def __call__(self, pdf_grid_func, fk_tables, params): @@ -81,15 +81,21 @@ class ForwardMap(ABC): by the forward map via ``self.n_pdf_params``. """ - def __init__(self, pdf_param_names: list[str]): + def __init__(self, pdf_param_names: list[str], extra_param_names: list[str] = ()): self.pdf_param_names = pdf_param_names + self.extra_param_names = extra_param_names @property def n_pdf_params(self) -> int: """Number of PDF parameters, derived from ``pdf_param_names``.""" return len(self.pdf_param_names) + @property + def param_names(self) -> list[str]: + """All fit parameter names: PDF parameters followed by extra parameters.""" + return list(self.pdf_param_names) + list(self.extra_param_names) + @abstractmethod def __call__( self, @@ -139,8 +145,9 @@ def __init__( self, pred_func: Callable[[jnp.ndarray, Any], jnp.ndarray], pdf_param_names: list[str], + extra_param_names: list[str] = (), ): - super().__init__(pdf_param_names) + super().__init__(pdf_param_names, extra_param_names=extra_param_names) self._pred_func = pred_func def __call__(self, pdf_grid_func, fk_tables, params): @@ -149,18 +156,22 @@ def __call__(self, pdf_grid_func, fk_tables, params): return self._pred_func(pdf, fk_tables), pdf -def forward_map(_pred_data, pdf_model): +def forward_map(_pred_data, pdf_model, extra_param_names=()): """Reportengine provider that builds the default FK-table forward map. Parameters ---------- _pred_data : callable Prediction function of the form ``pred_func(pdf, fk_tables) -> predictions``. - pdf_model : optional + pdf_model : object Used to obtain ``pdf_param_names`` from ``pdf_model.param_names``. + extra_param_names : list[str], optional + Names of any additional fit parameters beyond the PDF parameters. """ return FKTableForwardMap( - pred_func=_pred_data, pdf_param_names=pdf_model.param_names + pred_func=_pred_data, + pdf_param_names=pdf_model.param_names, + extra_param_names=extra_param_names, ) diff --git a/colibri/tests/test_bayes_prior.py b/colibri/tests/test_bayes_prior.py index 4343bc15a..a6b9ef845 100644 --- a/colibri/tests/test_bayes_prior.py +++ b/colibri/tests/test_bayes_prior.py @@ -17,13 +17,19 @@ from colibri.bayes_prior import bayesian_prior from colibri.core import PriorSettings from colibri.tests.conftest import MOCK_PDF_MODEL, TEST_PRIOR_SETTINGS_UNIFORM +from unittest.mock import Mock + +# Create a mock forward_map that exposes param_names matching MOCK_PDF_MODEL +MOCK_FORWARD_MAP = Mock() +MOCK_FORWARD_MAP.param_names = MOCK_PDF_MODEL.param_names +MOCK_FORWARD_MAP.pdf_param_names = MOCK_PDF_MODEL.param_names def test_uniform_prior(): """ Test the transformation of a uniform prior distribution. """ - prior_transform = bayesian_prior(TEST_PRIOR_SETTINGS_UNIFORM, MOCK_PDF_MODEL) + prior_transform = bayesian_prior(TEST_PRIOR_SETTINGS_UNIFORM, MOCK_FORWARD_MAP) key = random.PRNGKey(0) cube = random.uniform(key, shape=(10,)) @@ -54,7 +60,7 @@ def test_uniform_prior(): } ) - prior_transform_bounds = bayesian_prior(prior_settings_bounds, MOCK_PDF_MODEL) + prior_transform_bounds = bayesian_prior(prior_settings_bounds, MOCK_FORWARD_MAP) cube_bounds = random.uniform(key, shape=(2,)) expected_bounds = jnp.array( @@ -84,7 +90,7 @@ def test_uniform_prior(): ) with pytest.raises(ValueError, match="Missing bounds for parameters"): - bayesian_prior(prior_settings_missing_bounds, MOCK_PDF_MODEL) + bayesian_prior(prior_settings_missing_bounds, MOCK_FORWARD_MAP) # ---- Test missing min_val/max_val and bounds ---- prior_settings_invalid = PriorSettings( @@ -95,7 +101,7 @@ def test_uniform_prior(): ) with pytest.raises(ValueError, match="prior_distribution_specs must define either"): - bayesian_prior(prior_settings_invalid, MOCK_PDF_MODEL) + bayesian_prior(prior_settings_invalid, MOCK_FORWARD_MAP) @patch("colibri.bayes_prior.get_full_posterior") @@ -121,7 +127,7 @@ def cov(self): } ) - prior_transform = bayesian_prior(prior_settings, MOCK_PDF_MODEL) + prior_transform = bayesian_prior(prior_settings, MOCK_FORWARD_MAP) key = random.PRNGKey(0) cube = random.uniform(key, shape=(10, 2)) @@ -140,4 +146,4 @@ def test_invalid_prior_type(): ) with pytest.raises(ValueError) as e: - bayesian_prior(prior_settings, MOCK_PDF_MODEL) + bayesian_prior(prior_settings, MOCK_FORWARD_MAP) diff --git a/colibri/tests/test_blackjax_fit.py b/colibri/tests/test_blackjax_fit.py index 018df86e4..ae088bb56 100644 --- a/colibri/tests/test_blackjax_fit.py +++ b/colibri/tests/test_blackjax_fit.py @@ -43,7 +43,7 @@ def mock_sample(rng_key, n_samples): bayesian_prior = BayesianPrior( prior_transform=lambda x: x, log_prob=lambda x: -jnp.sum(x**2, axis=-1), - sample=lambda rng, n: jnp.zeros((n, MOCK_PDF_MODEL.n_parameters)), + sample=lambda rng, n: jnp.zeros((n, len(MOCK_PDF_MODEL.param_names))), ) integrability_penalty = lambda pdf: jnp.array([0.0]) @@ -82,12 +82,10 @@ def test_blackjax_fit(pos_penalty): integrability_penalty=integrability_penalty, ) - MOCK_PDF_MODEL.n_parameters = len(MOCK_PDF_MODEL.param_names) - with patch("colibri.blackjax_fit.anesthetic.NestedSamples"): fit_result = blackjax_fit( - MOCK_PDF_MODEL, + forward_map, bayesian_prior, blackjax_settings, mock_log_likelihood, @@ -97,13 +95,14 @@ def test_blackjax_fit(pos_penalty): def test_blackjax_fit_truncates_posterior_and_warns(caplog): - # --- ensure pdf_model is consistent --- - MOCK_PDF_MODEL.n_parameters = len(MOCK_PDF_MODEL.param_names) + # --- build a forward_map with the right param_names --- + mock_forward_map = Mock() + mock_forward_map.param_names = ["param1", "param2"] bayesian_prior = BayesianPrior( prior_transform=lambda x: x, log_prob=lambda x: -jnp.sum(x**2, axis=-1), - sample=lambda rng, n: jnp.zeros((n, MOCK_PDF_MODEL.n_parameters)), + sample=lambda rng, n: jnp.zeros((n, len(mock_forward_map.param_names))), ) blackjax_settings = { @@ -147,7 +146,7 @@ def test_blackjax_fit_truncates_posterior_and_warns(caplog): caplog.set_level("WARNING") fit_result = blackjax_fit( - MOCK_PDF_MODEL, + mock_forward_map, bayesian_prior, blackjax_settings, log_likelihood, diff --git a/colibri/tests/test_checks.py b/colibri/tests/test_checks.py index 35877cb51..d8da860e2 100644 --- a/colibri/tests/test_checks.py +++ b/colibri/tests/test_checks.py @@ -21,11 +21,8 @@ read_data="theoryid: 123\nt0pdfset: t0pdfset1", ) @patch("os.path.exists", return_value=True) -@patch("colibri.checks.get_pdf_model", return_value="model1") -@patch("colibri.checks.pdf_models_equal") -def test_check_pdf_models_equal_true( - mock_pdf_models_equal, mock_get_pdf_model, mock_exists, mock_open -): +@patch("colibri.checks.get_pdf_model") +def test_check_pdf_models_equal_true(mock_get_pdf_model, mock_exists, mock_open): # Setup prior_settings = PriorSettings( **{ @@ -33,16 +30,20 @@ def test_check_pdf_models_equal_true( "prior_distribution_specs": {"prior_fit": "fit1"}, } ) - pdf_model = "model1" + + # The prior model returned by get_pdf_model must have matching param_names + mock_prior_model = MagicMock() + mock_prior_model.param_names = ["param1", "param2"] + mock_get_pdf_model.return_value = mock_prior_model + + forward_map = MagicMock() + forward_map.pdf_param_names = ["param1", "param2"] theoryid = MagicMock() theoryid.id = 123 - # Configure mock behavior - mock_pdf_models_equal.side_effect = lambda x, y: x == y - - # Act - check_pdf_models_equal.__wrapped__(prior_settings, pdf_model, theoryid) + # Act — should not raise + check_pdf_models_equal.__wrapped__(prior_settings, forward_map, theoryid) @patch( @@ -51,10 +52,9 @@ def test_check_pdf_models_equal_true( read_data="theoryid: 456\nt0pdfset: t0pdfset1", ) @patch("os.path.exists", return_value=True) -@patch("colibri.checks.get_pdf_model", return_value="model1") -@patch("colibri.checks.pdf_models_equal") +@patch("colibri.checks.get_pdf_model") def test_check_pdf_models_equal_false_theoryid( - mock_pdf_models_equal, mock_get_pdf_model, mock_exists, mock_open + mock_get_pdf_model, mock_exists, mock_open ): # Setup prior_settings = PriorSettings( @@ -63,21 +63,20 @@ def test_check_pdf_models_equal_false_theoryid( "prior_distribution_specs": {"prior_fit": "fit1"}, } ) - pdf_model = "model1" - theoryid = MagicMock() - theoryid.id = 123 + mock_prior_model = MagicMock() + mock_prior_model.param_names = ["param1", "param2"] + mock_get_pdf_model.return_value = mock_prior_model - t0pdfset = MagicMock() - t0pdfset.name = "t0pdfset1" + forward_map = MagicMock() + forward_map.pdf_param_names = ["param1", "param2"] - # Configure mock behavior - mock_pdf_models_equal.side_effect = lambda x, y: x == y + theoryid = MagicMock() + theoryid.id = 123 + # Theory ID mismatch (file says 456, fit says 123) with pytest.raises(Exception): - check_pdf_models_equal.__wrapped__( - prior_settings, pdf_model, theoryid, t0pdfset - ) + check_pdf_models_equal.__wrapped__(prior_settings, forward_map, theoryid) @patch( @@ -86,31 +85,30 @@ def test_check_pdf_models_equal_false_theoryid( read_data="theoryid: 123\nt0pdfset: t0pdfset2", ) @patch("os.path.exists", return_value=True) -@patch("colibri.checks.get_pdf_model", return_value="model1") -@patch("colibri.checks.pdf_models_equal") -def test_check_pdf_models_equal_false_t0pdf( - mock_pdf_models_equal, mock_get_pdf_model, mock_exists, mock_open +@patch("colibri.checks.get_pdf_model") +def test_check_pdf_models_equal_false_param_names( + mock_get_pdf_model, mock_exists, mock_open ): - # Setup - prior_settings = { - "prior_distribution": "prior_from_gauss_posterior", - "prior_distribution_specs": {"prior_fit": "fit1"}, - } - pdf_model = "model1" + # Setup — param names mismatch between prior model and forward_map + prior_settings = PriorSettings( + **{ + "prior_distribution": "prior_from_gauss_posterior", + "prior_distribution_specs": {"prior_fit": "fit1"}, + } + ) - theoryid = MagicMock() - theoryid.id = 123 + mock_prior_model = MagicMock() + mock_prior_model.param_names = ["param1", "param2", "param3"] # different + mock_get_pdf_model.return_value = mock_prior_model - t0pdfset = MagicMock() - t0pdfset.name = "t0pdfset1" + forward_map = MagicMock() + forward_map.pdf_param_names = ["param1", "param2"] - # Configure mock behavior - mock_pdf_models_equal.side_effect = lambda x, y: x == y + theoryid = MagicMock() + theoryid.id = 123 - with pytest.raises(Exception): - check_pdf_models_equal.__wrapped__( - prior_settings, pdf_model, theoryid, t0pdfset - ) + with pytest.raises(ValueError): + check_pdf_models_equal.__wrapped__(prior_settings, forward_map, theoryid) @patch("colibri.checks.make_pred_data") diff --git a/colibri/tests/test_forward_map.py b/colibri/tests/test_forward_map.py index c4a0abc1b..4b7f2e8f0 100644 --- a/colibri/tests/test_forward_map.py +++ b/colibri/tests/test_forward_map.py @@ -82,6 +82,33 @@ def __call__(self, pdf_grid_func, fk_tables, params): assert fm.n_pdf_params == 5 +def test_forward_map_extra_param_names_default(): + """extra_param_names defaults to an empty tuple.""" + + class MinimalForwardMap(ForwardMap): + def __call__(self, pdf_grid_func, fk_tables, params): + return None + + fm = MinimalForwardMap(pdf_param_names=["a", "b"]) + assert list(fm.extra_param_names) == [] + assert fm.param_names == ["a", "b"] + + +def test_forward_map_extra_param_names(): + """extra_param_names are stored and appear in param_names after pdf_param_names.""" + + class MinimalForwardMap(ForwardMap): + def __call__(self, pdf_grid_func, fk_tables, params): + return None + + fm = MinimalForwardMap( + pdf_param_names=["a", "b"], extra_param_names=["norm", "scale"] + ) + assert list(fm.extra_param_names) == ["norm", "scale"] + assert fm.param_names == ["a", "b", "norm", "scale"] + assert fm.n_pdf_params == 2 + + # --------------------------------------------------------------------------- # FKTableForwardMap.__init__ # --------------------------------------------------------------------------- @@ -94,6 +121,18 @@ def test_fktable_forward_map_stores_pdf_param_names(): assert fm.n_pdf_params == 3 +def test_fktable_forward_map_extra_param_names(): + """FKTableForwardMap must accept and store extra_param_names.""" + fm = FKTableForwardMap( + pred_func=_simple_pred_func, + pdf_param_names=["a", "b"], + extra_param_names=["norm"], + ) + assert fm.param_names == ["a", "b", "norm"] + assert fm.n_pdf_params == 2 + assert list(fm.extra_param_names) == ["norm"] + + def test_fktable_forward_map_stores_pred_func(): """FKTableForwardMap.__init__ must store the pred_func.""" fm = FKTableForwardMap(pred_func=_simple_pred_func, pdf_param_names=["a", "b", "c"]) @@ -254,3 +293,18 @@ def test_forward_map_provider_with_different_param_counts(): fm = forward_map(_pred_data=_simple_pred_func, pdf_model=mock_model) assert fm.pdf_param_names == mock_model.param_names assert fm.n_pdf_params == n + + +def test_forward_map_provider_with_extra_param_names(): + """ + forward_map() must forward extra_param_names to FKTableForwardMap + and expose them via param_names. + """ + extra = ["norm", "scale"] + fm = forward_map( + _pred_data=_simple_pred_func, + pdf_model=MOCK_PDF_MODEL, + extra_param_names=extra, + ) + assert fm.param_names == MOCK_PDF_MODEL.param_names + extra + assert list(fm.extra_param_names) == extra diff --git a/colibri/tests/test_ultranest_fit.py b/colibri/tests/test_ultranest_fit.py index 26e077348..a2b91dec0 100644 --- a/colibri/tests/test_ultranest_fit.py +++ b/colibri/tests/test_ultranest_fit.py @@ -87,7 +87,7 @@ def test_ultranest_fit(pos_penalty): ) fit_result = ultranest_fit( - MOCK_PDF_MODEL, + forward_map, bayesian_prior, ultranest_settings, mock_log_likelihood, @@ -129,7 +129,7 @@ def test_ultranest_fit_vectorized(pos_penalty): ) fit_result = ultranest_fit( - MOCK_PDF_MODEL, + forward_map, bayesian_prior, ultranest_settings, mock_log_likelihood, @@ -180,7 +180,7 @@ def test_ultranest_fit_with_SliceSampler(pos_penalty): ) fit_result = ultranest_fit( - MOCK_PDF_MODEL, + forward_map, bayesian_prior, ultranest_settings, mock_log_likelihood, @@ -231,7 +231,7 @@ def test_ultranest_fit_with_popSliceSampler(pos_penalty): ) fit_result = ultranest_fit( - MOCK_PDF_MODEL, + forward_map, bayesian_prior, ultranest_settings, mock_log_likelihood, @@ -301,7 +301,7 @@ def test_ultranest_fit_with_sampler_plot(mock_sampler_class, pos_penalty): mock_sampler_instance.plot = Mock() fit_result = ultranest_fit( - MOCK_PDF_MODEL, + forward_map, bayesian_prior, ultranest_settings_with_plot, mock_log_likelihood, diff --git a/colibri/ultranest_fit.py b/colibri/ultranest_fit.py index 9896a2188..ac6e396c2 100644 --- a/colibri/ultranest_fit.py +++ b/colibri/ultranest_fit.py @@ -41,7 +41,7 @@ def ultranest_fit( - pdf_model, + forward_map, bayesian_prior, ultranest_settings, log_likelihood, @@ -51,8 +51,8 @@ def ultranest_fit( Parameters ---------- - pdf_model: pdf_model.PDFModel - The PDF model to fit. + forward_map: ForwardMap + The forward map whose ``param_names`` enumerate all fit parameters. bayesian_prior: BayesianPrior The prior object containing prior_transform, log_prob, and sample functions. @@ -74,7 +74,7 @@ def ultranest_fit( # set the ultranest seed np.random.seed(ultranest_settings["ultranest_seed"]) - parameters = pdf_model.param_names + parameters = forward_map.param_names if ultranest_settings["ReactiveNS_settings"]["vectorized"]: log.info("Vectorized likelihood for ultranest fit.") From 22207f3c73df76327d78e6b9a11980988bdc8764 Mon Sep 17 00:00:00 2001 From: Mark Nestor Costantini Date: Thu, 26 Mar 2026 10:23:09 +0000 Subject: [PATCH 19/30] pass pdf_model object to forward map --- colibri/forward_map.py | 24 +++++++------ colibri/tests/conftest.py | 2 +- colibri/tests/test_analytic_fit.py | 10 ++---- colibri/tests/test_blackjax_fit.py | 2 +- colibri/tests/test_checks.py | 4 +-- colibri/tests/test_forward_map.py | 55 +++++++++++++++++++++-------- colibri/tests/test_ultranest_fit.py | 20 +++-------- colibri/tests/test_utils.py | 2 +- 8 files changed, 67 insertions(+), 52 deletions(-) diff --git a/colibri/forward_map.py b/colibri/forward_map.py index f5d1a509e..d5cee0e10 100644 --- a/colibri/forward_map.py +++ b/colibri/forward_map.py @@ -32,8 +32,8 @@ :: class NormForwardMap(ForwardMap): - def __init__(self, pred_func, pdf_param_names: list[str]): - super().__init__(pdf_param_names, extra_param_names=["norm"]) + def __init__(self, pred_func, pdf_model): + super().__init__(pdf_model, extra_param_names=["norm"]) self._pred_func = pred_func def __call__(self, pdf_grid_func, fk_tables, params): @@ -46,8 +46,8 @@ def __call__(self, pdf_grid_func, fk_tables, params): :: class FixedPDFForwardMap(ForwardMap): - def __init__(self, pred_func, fixed_pdf, fk_tables, pdf_param_names: list[str] | None = None): - super().__init__(pdf_param_names if pdf_param_names is not None else []) + def __init__(self, pred_func, fixed_pdf, fk_tables, pdf_model=None): + super().__init__(pdf_model) self._pred_func = pred_func self.fixed_pdf = fixed_pdf self._fixed_pred = self._pred_func(fixed_pdf, fk_tables) @@ -81,9 +81,13 @@ class ForwardMap(ABC): by the forward map via ``self.n_pdf_params``. """ - def __init__(self, pdf_param_names: list[str], extra_param_names: list[str] = ()): + def __init__(self, pdf_model, extra_param_names: list[str] = ()): - self.pdf_param_names = pdf_param_names + self.pdf_model = pdf_model + if pdf_model is not None: + self.pdf_param_names = pdf_model.param_names + else: + self.pdf_param_names = [] self.extra_param_names = extra_param_names @property @@ -144,10 +148,10 @@ class FKTableForwardMap(ForwardMap): def __init__( self, pred_func: Callable[[jnp.ndarray, Any], jnp.ndarray], - pdf_param_names: list[str], + pdf_model, extra_param_names: list[str] = (), ): - super().__init__(pdf_param_names, extra_param_names=extra_param_names) + super().__init__(pdf_model, extra_param_names=extra_param_names) self._pred_func = pred_func def __call__(self, pdf_grid_func, fk_tables, params): @@ -164,7 +168,7 @@ def forward_map(_pred_data, pdf_model, extra_param_names=()): _pred_data : callable Prediction function of the form ``pred_func(pdf, fk_tables) -> predictions``. pdf_model : object - Used to obtain ``pdf_param_names`` from ``pdf_model.param_names``. + The PDF model object; must expose a ``param_names`` attribute. extra_param_names : list[str], optional Names of any additional fit parameters beyond the PDF parameters. @@ -172,6 +176,6 @@ def forward_map(_pred_data, pdf_model, extra_param_names=()): return FKTableForwardMap( pred_func=_pred_data, - pdf_param_names=pdf_model.param_names, + pdf_model=pdf_model, extra_param_names=extra_param_names, ) diff --git a/colibri/tests/conftest.py b/colibri/tests/conftest.py index 650077a3b..f76dc16e9 100644 --- a/colibri/tests/conftest.py +++ b/colibri/tests/conftest.py @@ -327,7 +327,7 @@ def wmin_param(params): TEST_FORWARD_MAP_DIS = FKTableForwardMap( lambda pdf, fk_arrays: jnp.einsum("ijk,jk->i", fk_arrays[0], pdf), - pdf_param_names=["param1", "param2"], + pdf_model=MOCK_PDF_MODEL, ) """ Mock DIS forward map function for testing purposes. diff --git a/colibri/tests/test_analytic_fit.py b/colibri/tests/test_analytic_fit.py index 579e603a5..a4fd6ec24 100644 --- a/colibri/tests/test_analytic_fit.py +++ b/colibri/tests/test_analytic_fit.py @@ -42,7 +42,7 @@ def test_analytic_fit_flat_direction(): forward_map = FKTableForwardMap( lambda pdf, fkarrs: jnp.ones(n_params), - pdf_param_names=MOCK_PDF_MODEL.param_names, + pdf_model=MOCK_PDF_MODEL, ) with pytest.raises(ValueError): @@ -69,9 +69,7 @@ def test_analytic_fit(caplog, monkeypatch): MOCK_PDF_MODEL, "grid_values_func", lambda xgrid: lambda params: params ) - forward_map = FKTableForwardMap( - lambda pdf, fkarrs: pdf, pdf_param_names=MOCK_PDF_MODEL.param_names - ) + forward_map = FKTableForwardMap(lambda pdf, fkarrs: pdf, pdf_model=MOCK_PDF_MODEL) # Run the analytic fit result = analytic_fit( @@ -134,9 +132,7 @@ def test_analytic_fit_different_priors(caplog, monkeypatch): MOCK_PDF_MODEL, "grid_values_func", lambda xgrid: lambda params: params ) - forward_map = FKTableForwardMap( - lambda pdf, fkarrs: pdf, pdf_param_names=MOCK_PDF_MODEL.param_names - ) + forward_map = FKTableForwardMap(lambda pdf, fkarrs: pdf, pdf_model=MOCK_PDF_MODEL) # Run the analytic fit result = analytic_fit( diff --git a/colibri/tests/test_blackjax_fit.py b/colibri/tests/test_blackjax_fit.py index ae088bb56..e4a41d722 100644 --- a/colibri/tests/test_blackjax_fit.py +++ b/colibri/tests/test_blackjax_fit.py @@ -64,7 +64,7 @@ def mock_sample(rng_key, n_samples): def test_blackjax_fit(pos_penalty): forward_map = FKTableForwardMap( lambda pdf, fk: jnp.zeros(len(MOCK_PDF_MODEL.param_names)), - pdf_param_names=MOCK_PDF_MODEL.param_names, + pdf_model=MOCK_PDF_MODEL, ) mock_log_likelihood = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, diff --git a/colibri/tests/test_checks.py b/colibri/tests/test_checks.py index d8da860e2..256d3a428 100644 --- a/colibri/tests/test_checks.py +++ b/colibri/tests/test_checks.py @@ -132,7 +132,7 @@ def pdf_linear_model(params): forward_map_lin = FKTableForwardMap( # Simulating a simple linear model: f(x) = a*x + b*y + c*z + 3.0, where pdf = [a, b, c] lambda pdf, fk: jnp.dot(pdf, fk) + 3.0, - pdf_param_names=["a", "b", "c"], + pdf_model=mock_pdf_model, ) # Set the mock's grid_values_func to return the linear_model function @@ -145,7 +145,7 @@ def pdf_linear_model(params): non_linear_model = FKTableForwardMap( # Introduce some non-linearity lambda pdf, fk: jnp.dot(pdf**2, FIT_XGRID) + fk, - pdf_param_names=["a", "b", "c"], + pdf_model=mock_pdf_model, ) # Ensure ValueError is raised for non-linear model diff --git a/colibri/tests/test_forward_map.py b/colibri/tests/test_forward_map.py index 4b7f2e8f0..6eaed3f5f 100644 --- a/colibri/tests/test_forward_map.py +++ b/colibri/tests/test_forward_map.py @@ -41,10 +41,17 @@ def _make_pdf_grid_func(pdf_grid): # --------------------------------------------------------------------------- +def _mock_pdf_model(param_names): + """Create a mock pdf_model with the given param_names.""" + model = Mock() + model.param_names = param_names + return model + + def test_forward_map_cannot_be_instantiated(): """ForwardMap is abstract; direct instantiation must raise TypeError.""" with pytest.raises(TypeError): - ForwardMap(pdf_param_names=["a", "b"]) + ForwardMap(pdf_model=_mock_pdf_model(["a", "b"])) def test_forward_map_subclass_without_call_cannot_be_instantiated(): @@ -54,7 +61,7 @@ class NoCallSubclass(ForwardMap): pass with pytest.raises(TypeError): - NoCallSubclass(pdf_param_names=["a", "b"]) + NoCallSubclass(pdf_model=_mock_pdf_model(["a", "b"])) def test_forward_map_abstract_call_raises_not_implemented(): @@ -64,7 +71,7 @@ class SuperCallingForwardMap(ForwardMap): def __call__(self, pdf_grid_func, fk_tables, params): return super().__call__(pdf_grid_func, fk_tables, params) - fm = SuperCallingForwardMap(pdf_param_names=["a", "b"]) + fm = SuperCallingForwardMap(pdf_model=_mock_pdf_model(["a", "b"])) with pytest.raises(NotImplementedError): fm(_make_pdf_grid_func(TEST_PDF_GRID), TEST_FK_ARRAYS, jnp.array([1.0, 2.0])) @@ -77,8 +84,10 @@ def __call__(self, pdf_grid_func, fk_tables, params): pdf = pdf_grid_func(params[: self.n_pdf_params]) return _simple_pred_func(pdf, fk_tables), pdf - fm = MinimalForwardMap(pdf_param_names=["p0", "p1", "p2", "p3", "p4"]) + mock_model = _mock_pdf_model(["p0", "p1", "p2", "p3", "p4"]) + fm = MinimalForwardMap(pdf_model=mock_model) assert fm.pdf_param_names == ["p0", "p1", "p2", "p3", "p4"] + assert fm.pdf_model is mock_model assert fm.n_pdf_params == 5 @@ -89,7 +98,7 @@ class MinimalForwardMap(ForwardMap): def __call__(self, pdf_grid_func, fk_tables, params): return None - fm = MinimalForwardMap(pdf_param_names=["a", "b"]) + fm = MinimalForwardMap(pdf_model=_mock_pdf_model(["a", "b"])) assert list(fm.extra_param_names) == [] assert fm.param_names == ["a", "b"] @@ -102,7 +111,7 @@ def __call__(self, pdf_grid_func, fk_tables, params): return None fm = MinimalForwardMap( - pdf_param_names=["a", "b"], extra_param_names=["norm", "scale"] + pdf_model=_mock_pdf_model(["a", "b"]), extra_param_names=["norm", "scale"] ) assert list(fm.extra_param_names) == ["norm", "scale"] assert fm.param_names == ["a", "b", "norm", "scale"] @@ -116,7 +125,9 @@ def __call__(self, pdf_grid_func, fk_tables, params): def test_fktable_forward_map_stores_pdf_param_names(): """FKTableForwardMap.__init__ must store pdf_param_names via the base class.""" - fm = FKTableForwardMap(pred_func=_simple_pred_func, pdf_param_names=["a", "b", "c"]) + fm = FKTableForwardMap( + pred_func=_simple_pred_func, pdf_model=_mock_pdf_model(["a", "b", "c"]) + ) assert fm.pdf_param_names == ["a", "b", "c"] assert fm.n_pdf_params == 3 @@ -125,7 +136,7 @@ def test_fktable_forward_map_extra_param_names(): """FKTableForwardMap must accept and store extra_param_names.""" fm = FKTableForwardMap( pred_func=_simple_pred_func, - pdf_param_names=["a", "b"], + pdf_model=_mock_pdf_model(["a", "b"]), extra_param_names=["norm"], ) assert fm.param_names == ["a", "b", "norm"] @@ -135,7 +146,9 @@ def test_fktable_forward_map_extra_param_names(): def test_fktable_forward_map_stores_pred_func(): """FKTableForwardMap.__init__ must store the pred_func.""" - fm = FKTableForwardMap(pred_func=_simple_pred_func, pdf_param_names=["a", "b", "c"]) + fm = FKTableForwardMap( + pred_func=_simple_pred_func, pdf_model=_mock_pdf_model(["a", "b", "c"]) + ) assert fm._pred_func is _simple_pred_func @@ -146,7 +159,9 @@ def test_fktable_forward_map_stores_pred_func(): def test_fktable_forward_map_returns_tuple(): """__call__ must return a 2-tuple (predictions, pdf).""" - fm = FKTableForwardMap(pred_func=_simple_pred_func, pdf_param_names=["a", "b"]) + fm = FKTableForwardMap( + pred_func=_simple_pred_func, pdf_model=_mock_pdf_model(["a", "b"]) + ) pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) params = jnp.array([1.0, 2.0]) @@ -158,7 +173,9 @@ def test_fktable_forward_map_returns_tuple(): def test_fktable_forward_map_predictions_shape(): """Predictions returned by __call__ must have shape (N_data,).""" - fm = FKTableForwardMap(pred_func=_simple_pred_func, pdf_param_names=["a", "b"]) + fm = FKTableForwardMap( + pred_func=_simple_pred_func, pdf_model=_mock_pdf_model(["a", "b"]) + ) pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) params = jnp.array([1.0, 2.0]) @@ -169,7 +186,9 @@ def test_fktable_forward_map_predictions_shape(): def test_fktable_forward_map_pdf_shape(): """PDF returned by __call__ must have shape (N_fl, N_x).""" - fm = FKTableForwardMap(pred_func=_simple_pred_func, pdf_param_names=["a", "b"]) + fm = FKTableForwardMap( + pred_func=_simple_pred_func, pdf_model=_mock_pdf_model(["a", "b"]) + ) pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) params = jnp.array([1.0, 2.0]) @@ -184,7 +203,9 @@ def test_fktable_forward_map_slices_pdf_params(): parameters appended to params must not affect the PDF or predictions. """ n_pdf = 2 - fm = FKTableForwardMap(pred_func=_simple_pred_func, pdf_param_names=["a", "b"]) + fm = FKTableForwardMap( + pred_func=_simple_pred_func, pdf_model=_mock_pdf_model(["a", "b"]) + ) pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) pdf_params = jnp.array([1.0, 2.0]) @@ -209,7 +230,9 @@ def test_fktable_forward_map_uses_pdf_grid_func(): """ scale = 3.0 n_pdf = 2 - fm = FKTableForwardMap(pred_func=_simple_pred_func, pdf_param_names=["a", "b"]) + fm = FKTableForwardMap( + pred_func=_simple_pred_func, pdf_model=_mock_pdf_model(["a", "b"]) + ) params = jnp.array([1.0, 2.0]) base_pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) @@ -226,7 +249,9 @@ def test_fktable_forward_map_correct_values(): __call__ must produce predictions equal to pred_func(pdf_grid_func(params), fk). """ n_pdf = 2 - fm = FKTableForwardMap(pred_func=_simple_pred_func, pdf_param_names=["a", "b"]) + fm = FKTableForwardMap( + pred_func=_simple_pred_func, pdf_model=_mock_pdf_model(["a", "b"]) + ) params = jnp.array([1.0, 2.0]) pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) diff --git a/colibri/tests/test_ultranest_fit.py b/colibri/tests/test_ultranest_fit.py index a2b91dec0..122f557a1 100644 --- a/colibri/tests/test_ultranest_fit.py +++ b/colibri/tests/test_ultranest_fit.py @@ -67,9 +67,7 @@ def mock_sample(rng_key, n_samples): def test_ultranest_fit(pos_penalty): _pred_data = lambda *args: jnp.array([0.0]) - forward_map = FKTableForwardMap( - _pred_data, pdf_param_names=MOCK_PDF_MODEL.param_names - ) + forward_map = FKTableForwardMap(_pred_data, pdf_model=MOCK_PDF_MODEL) mock_log_likelihood = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, MOCK_PDF_MODEL, @@ -107,9 +105,7 @@ def test_ultranest_fit(pos_penalty): def test_ultranest_fit_vectorized(pos_penalty): _pred_data = lambda *args: jnp.array([0.0]) - forward_map = FKTableForwardMap( - _pred_data, pdf_param_names=MOCK_PDF_MODEL.param_names - ) + forward_map = FKTableForwardMap(_pred_data, pdf_model=MOCK_PDF_MODEL) ultranest_settings["ReactiveNS_settings"]["vectorized"] = True mock_log_likelihood = LogLikelihood( @@ -159,9 +155,7 @@ def test_ultranest_fit_with_SliceSampler(pos_penalty): } _pred_data = lambda *args: jnp.array([0.0]) - forward_map = FKTableForwardMap( - _pred_data, pdf_param_names=MOCK_PDF_MODEL.param_names - ) + forward_map = FKTableForwardMap(_pred_data, pdf_model=MOCK_PDF_MODEL) mock_log_likelihood = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, @@ -210,9 +204,7 @@ def test_ultranest_fit_with_popSliceSampler(pos_penalty): } _pred_data = lambda *args: jnp.array([0.0]) - forward_map = FKTableForwardMap( - _pred_data, pdf_param_names=MOCK_PDF_MODEL.param_names - ) + forward_map = FKTableForwardMap(_pred_data, pdf_model=MOCK_PDF_MODEL) mock_log_likelihood = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, @@ -265,9 +257,7 @@ def test_ultranest_fit_with_sampler_plot(mock_sampler_class, pos_penalty): } _pred_data = lambda *args: jnp.array([0.0]) - forward_map = FKTableForwardMap( - _pred_data, pdf_param_names=MOCK_PDF_MODEL.param_names - ) + forward_map = FKTableForwardMap(_pred_data, pdf_model=MOCK_PDF_MODEL) mock_log_likelihood = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, diff --git a/colibri/tests/test_utils.py b/colibri/tests/test_utils.py index 7fbae6bd5..57d6a8f46 100644 --- a/colibri/tests/test_utils.py +++ b/colibri/tests/test_utils.py @@ -343,7 +343,7 @@ def test_likelihood_float_type( len(MOCK_CENTRAL_COVMAT_INDEX.central_values) ) # Mock _pred_data forward_map = FKTableForwardMap( - _pred_data, pdf_param_names=MOCK_PDF_MODEL.param_names + _pred_data, pdf_model=MOCK_PDF_MODEL ) # Mock forward_map FIT_XGRID = jnp.linspace(0, 1, 10) # Mock FIT_XGRID output_path = tmp_path From 18dc793f9d60fa933bcddfbad632b0b2f58b72c6 Mon Sep 17 00:00:00 2001 From: Luca Mantani Date: Tue, 28 Apr 2026 11:28:48 +0200 Subject: [PATCH 20/30] Update colibri/forward_map.py --- colibri/forward_map.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colibri/forward_map.py b/colibri/forward_map.py index d5cee0e10..ad752ea0b 100644 --- a/colibri/forward_map.py +++ b/colibri/forward_map.py @@ -149,7 +149,7 @@ def __init__( self, pred_func: Callable[[jnp.ndarray, Any], jnp.ndarray], pdf_model, - extra_param_names: list[str] = (), + extra_param_names: list[str] = [], ): super().__init__(pdf_model, extra_param_names=extra_param_names) self._pred_func = pred_func From 089110c4abf3df08e3452a0a49b3af1eec62c7a7 Mon Sep 17 00:00:00 2001 From: Luca Mantani Date: Tue, 28 Apr 2026 11:28:59 +0200 Subject: [PATCH 21/30] Update colibri/forward_map.py --- colibri/forward_map.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colibri/forward_map.py b/colibri/forward_map.py index ad752ea0b..aa0bd2a09 100644 --- a/colibri/forward_map.py +++ b/colibri/forward_map.py @@ -81,7 +81,7 @@ class ForwardMap(ABC): by the forward map via ``self.n_pdf_params``. """ - def __init__(self, pdf_model, extra_param_names: list[str] = ()): + def __init__(self, pdf_model, extra_param_names: list[str] = []): self.pdf_model = pdf_model if pdf_model is not None: From cf02826661be73a248101f48fb7d74ae21cf427d Mon Sep 17 00:00:00 2001 From: Luca Mantani Date: Tue, 28 Apr 2026 11:29:07 +0200 Subject: [PATCH 22/30] Update colibri/forward_map.py --- colibri/forward_map.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colibri/forward_map.py b/colibri/forward_map.py index aa0bd2a09..855e467ab 100644 --- a/colibri/forward_map.py +++ b/colibri/forward_map.py @@ -160,7 +160,7 @@ def __call__(self, pdf_grid_func, fk_tables, params): return self._pred_func(pdf, fk_tables), pdf -def forward_map(_pred_data, pdf_model, extra_param_names=()): +def forward_map(_pred_data, pdf_model, extra_param_names=[]): """Reportengine provider that builds the default FK-table forward map. Parameters From 0cdce577e3b26ef27b936018f7ae8847f71b4184 Mon Sep 17 00:00:00 2001 From: Luca Mantani Date: Tue, 28 Apr 2026 11:32:45 +0200 Subject: [PATCH 23/30] Added test --- colibri/tests/test_forward_map.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/colibri/tests/test_forward_map.py b/colibri/tests/test_forward_map.py index 6eaed3f5f..20de39fb8 100644 --- a/colibri/tests/test_forward_map.py +++ b/colibri/tests/test_forward_map.py @@ -76,6 +76,19 @@ def __call__(self, pdf_grid_func, fk_tables, params): fm(_make_pdf_grid_func(TEST_PDF_GRID), TEST_FK_ARRAYS, jnp.array([1.0, 2.0])) +def test_forward_map_none_pdf_model_sets_empty_param_names(): + """When pdf_model is None, pdf_param_names must be set to an empty list.""" + + class MinimalForwardMap(ForwardMap): + def __call__(self, pdf_grid_func, fk_tables, params): + return None + + fm = MinimalForwardMap(pdf_model=None) + assert fm.pdf_param_names == [] + assert fm.n_pdf_params == 0 + assert fm.pdf_model is None + + def test_forward_map_subclass_stores_pdf_param_names(): """pdf_param_names passed to super().__init__ must be stored on the instance.""" From 20941aac29b40a6c8d72d900512453ab5c7b44fd Mon Sep 17 00:00:00 2001 From: Luca Mantani Date: Tue, 28 Apr 2026 12:32:34 +0200 Subject: [PATCH 24/30] removed grid_func passing to forward map --- colibri/analytic_fit.py | 5 +- colibri/checks.py | 11 +-- colibri/forward_map.py | 48 +++++----- colibri/likelihood.py | 11 +-- colibri/tests/conftest.py | 1 + colibri/tests/test_forward_map.py | 146 ++++++++++++++++++++---------- colibri/tests/test_likelihood.py | 25 ++--- colibri/utils.py | 3 +- 8 files changed, 140 insertions(+), 110 deletions(-) diff --git a/colibri/analytic_fit.py b/colibri/analytic_fit.py index 1db196c38..637ea0bd4 100644 --- a/colibri/analytic_fit.py +++ b/colibri/analytic_fit.py @@ -141,11 +141,10 @@ def analytic_fit( # Precompute predictions for the basis of the model bases = jnp.identity(len(parameters)) - pdf_grid = pdf_model.grid_values_func(FIT_XGRID) predictions = jnp.array( - [forward_map(pdf_grid, fast_kernel_arrays, basis)[0] for basis in bases] + [forward_map(fast_kernel_arrays, basis)[0] for basis in bases] ) - intercept = forward_map(pdf_grid, fast_kernel_arrays, jnp.zeros(len(parameters)))[0] + intercept = forward_map(fast_kernel_arrays, jnp.zeros(len(parameters)))[0] # Construct the analytic solution central_values = central_covmat_index.central_values diff --git a/colibri/checks.py b/colibri/checks.py index e38ea6b4d..fc4aab4e0 100644 --- a/colibri/checks.py +++ b/colibri/checks.py @@ -52,8 +52,7 @@ def check_pdf_model_is_linear(pdf_model, forward_map, FIT_XGRID, data): fk = fast_kernel_arrays(data, FIT_XGRID) parameters = pdf_model.param_names - pdf_grid = pdf_model.grid_values_func(FIT_XGRID) - intercept, _ = forward_map(pdf_grid, fk, jnp.zeros(len(parameters))) + intercept, _ = forward_map(fk, jnp.zeros(len(parameters))) # Run the check for 10 random points in the parameter space for i in range(10): @@ -65,16 +64,16 @@ def check_pdf_model_is_linear(pdf_model, forward_map, FIT_XGRID, data): # Test additivity add_check = jnp.isclose( - forward_map(pdf_grid, fk, x1)[0] + forward_map(pdf_grid, fk, x2)[0], - forward_map(pdf_grid, fk, x1 + x2)[0] + intercept, + forward_map(fk, x1)[0] + forward_map(fk, x2)[0], + forward_map(fk, x1 + x2)[0] + intercept, ) # Test homogeneity c = jax.random.uniform(key, (1,)) homogeneity_check = jnp.isclose( - c * (forward_map(pdf_grid, fk, x1)[0] - intercept), - forward_map(pdf_grid, fk, c * x1)[0] - intercept, + c * (forward_map(fk, x1)[0] - intercept), + forward_map(fk, c * x1)[0] - intercept, ) if not add_check.all() or not homogeneity_check.all(): diff --git a/colibri/forward_map.py b/colibri/forward_map.py index 855e467ab..f05fba48c 100644 --- a/colibri/forward_map.py +++ b/colibri/forward_map.py @@ -13,7 +13,12 @@ ----------------------------------- The log-likelihood calls every forward map with the same fixed signature:: - (pdf_grid_func, fk_tables, params) -> predictions, pdf + (fk_tables, params) -> predictions, pdf + +The PDF grid function is bound at construction time (via ``pdf_grid_func`` +stored on the instance), so ``fk_tables`` remain a dynamic argument that JAX +traces as an abstract input — keeping them out of the compiled binary and +avoiding expensive recompilation when values change. Parameter convention -------------------- @@ -23,21 +28,22 @@ By convention: -``params[:self.n_pdf_params]`` are PDF parameters consumed by ``pdf_grid_func``; -any remaining entries are "extra" parameters interpreted by the chosen - ``ForwardMap`` implementation. +``params[:self.n_pdf_params]`` are PDF parameters consumed by the bound +``pdf_grid_func``; any remaining entries are "extra" parameters interpreted by +the chosen ``ForwardMap`` implementation. Example - fitting a normalisation factor on top of the PDF ---------------------------------------------------------- :: class NormForwardMap(ForwardMap): - def __init__(self, pred_func, pdf_model): + def __init__(self, pred_func, pdf_model, pdf_grid_func): super().__init__(pdf_model, extra_param_names=["norm"]) self._pred_func = pred_func + self._pdf_grid_func = pdf_grid_func - def __call__(self, pdf_grid_func, fk_tables, params): - pdf = pdf_grid_func(params[: self.n_pdf_params]) + def __call__(self, fk_tables, params): + pdf = self._pdf_grid_func(params[: self.n_pdf_params]) norm = params[self.n_pdf_params] # first extra parameter return norm * self._pred_func(pdf, fk_tables), pdf @@ -52,7 +58,7 @@ def __init__(self, pred_func, fixed_pdf, fk_tables, pdf_model=None): self.fixed_pdf = fixed_pdf self._fixed_pred = self._pred_func(fixed_pdf, fk_tables) - def __call__(self, pdf_grid_func, fk_tables, params): + def __call__(self, fk_tables, params): scale = params[0] return scale * self._fixed_pred, self.fixed_pdf """ @@ -73,7 +79,7 @@ class ForwardMap(ABC): All forward maps share the same call signature: - ``(pdf_grid_func, fk_tables, params) -> predictions`` + ``(fk_tables, params) -> predictions`` Notes ----- @@ -103,7 +109,6 @@ def param_names(self) -> list[str]: @abstractmethod def __call__( self, - pdf_grid_func: Callable[[jnp.ndarray], jnp.ndarray], fk_tables: Any, params: jnp.ndarray, ) -> jnp.ndarray: @@ -111,14 +116,6 @@ def __call__( Parameters ---------- - pdf_grid_func : callable - Callable that evaluates PDF values on the fit x-grid from the PDF - parameters. - - Expected call signature: - ``pdf = pdf_grid_func(pdf_params)`` - with ``pdf`` shaped ``(N_fl, N_x)``. - fk_tables : jnp.ndarray Fast-kernel tables needed by the prediction function. @@ -149,18 +146,20 @@ def __init__( self, pred_func: Callable[[jnp.ndarray, Any], jnp.ndarray], pdf_model, + pdf_grid_func: Callable[[jnp.ndarray], jnp.ndarray], extra_param_names: list[str] = [], ): super().__init__(pdf_model, extra_param_names=extra_param_names) self._pred_func = pred_func + self._pdf_grid_func = pdf_grid_func - def __call__(self, pdf_grid_func, fk_tables, params): + def __call__(self, fk_tables, params): pdf_params = params[: self.n_pdf_params] - pdf = pdf_grid_func(pdf_params) + pdf = self._pdf_grid_func(pdf_params) return self._pred_func(pdf, fk_tables), pdf -def forward_map(_pred_data, pdf_model, extra_param_names=[]): +def forward_map(_pred_data, pdf_model, FIT_XGRID, extra_param_names=[]): """Reportengine provider that builds the default FK-table forward map. Parameters @@ -168,14 +167,17 @@ def forward_map(_pred_data, pdf_model, extra_param_names=[]): _pred_data : callable Prediction function of the form ``pred_func(pdf, fk_tables) -> predictions``. pdf_model : object - The PDF model object; must expose a ``param_names`` attribute. + The PDF model object; must expose ``param_names`` and ``grid_values_func``. + FIT_XGRID : array-like + The x-grid on which the PDF is evaluated. extra_param_names : list[str], optional Names of any additional fit parameters beyond the PDF parameters. """ - + pdf_grid_func = pdf_model.grid_values_func(FIT_XGRID) return FKTableForwardMap( pred_func=_pred_data, pdf_model=pdf_model, + pdf_grid_func=pdf_grid_func, extra_param_names=extra_param_names, ) diff --git a/colibri/likelihood.py b/colibri/likelihood.py index d53d4a211..43ed85797 100644 --- a/colibri/likelihood.py +++ b/colibri/likelihood.py @@ -23,7 +23,6 @@ def __init__( self, central_covmat_index, pdf_model, - fit_xgrid, forward_map, fast_kernel_arrays, positivity_fast_kernel_arrays, @@ -38,8 +37,6 @@ def __init__( pdf_model: pdf_model.PDFModel - fit_xgrid: np.ndarray - forward_map: Callable fast_kernel_arrays: tuple @@ -62,7 +59,6 @@ def __init__( self.positivity_penalty_settings = positivity_penalty_settings self.integrability_penalty = integrability_penalty - self.pdf_grid = pdf_model.grid_values_func(fit_xgrid) self.forward_map = forward_map self.fast_kernel_arrays = fast_kernel_arrays @@ -125,7 +121,7 @@ def log_likelihood( jnp.ndarray jax array with the value of the log-likelihood. """ - predictions, pdf = self.forward_map(self.pdf_grid, fast_kernel_arrays, params) + predictions, pdf = self.forward_map(fast_kernel_arrays, params) # Select only the data relevant for this likelihood # Especially important when using a training/validation split predictions = predictions[self.central_values_idx] @@ -167,7 +163,6 @@ def log_likelihood( def log_likelihood( central_covmat_index, pdf_model, - FIT_XGRID, forward_map, fast_kernel_arrays, positivity_fast_kernel_arrays, @@ -184,7 +179,6 @@ def log_likelihood( return LogLikelihood( central_covmat_index, pdf_model, - FIT_XGRID, forward_map, fast_kernel_arrays, positivity_fast_kernel_arrays, @@ -198,7 +192,6 @@ def mc_log_likelihood( mc_pseudodata, fit_covariance_matrix, pdf_model, - FIT_XGRID, forward_map, fast_kernel_arrays, positivity_fast_kernel_arrays, @@ -226,7 +219,6 @@ def mc_log_likelihood( train_loglike = LogLikelihood( central_covmat_index_train, pdf_model, - FIT_XGRID, forward_map, fast_kernel_arrays, positivity_fast_kernel_arrays, @@ -252,7 +244,6 @@ def mc_log_likelihood( val_loglike = LogLikelihood( central_covmat_index_val, pdf_model, - FIT_XGRID, forward_map, fast_kernel_arrays, positivity_fast_kernel_arrays, diff --git a/colibri/tests/conftest.py b/colibri/tests/conftest.py index f76dc16e9..cb4fedda3 100644 --- a/colibri/tests/conftest.py +++ b/colibri/tests/conftest.py @@ -328,6 +328,7 @@ def wmin_param(params): TEST_FORWARD_MAP_DIS = FKTableForwardMap( lambda pdf, fk_arrays: jnp.einsum("ijk,jk->i", fk_arrays[0], pdf), pdf_model=MOCK_PDF_MODEL, + pdf_grid_func=MOCK_PDF_MODEL.grid_values_func(TEST_XGRID), ) """ Mock DIS forward map function for testing purposes. diff --git a/colibri/tests/test_forward_map.py b/colibri/tests/test_forward_map.py index 20de39fb8..75601f5b8 100644 --- a/colibri/tests/test_forward_map.py +++ b/colibri/tests/test_forward_map.py @@ -18,6 +18,7 @@ TEST_N_DATA, TEST_N_FL, TEST_N_XGRID, + TEST_XGRID, MOCK_PDF_MODEL, ) @@ -68,19 +69,19 @@ def test_forward_map_abstract_call_raises_not_implemented(): """Calling super().__call__() must hit the raise NotImplementedError body.""" class SuperCallingForwardMap(ForwardMap): - def __call__(self, pdf_grid_func, fk_tables, params): - return super().__call__(pdf_grid_func, fk_tables, params) + def __call__(self, fk_tables, params): + return super().__call__(fk_tables, params) fm = SuperCallingForwardMap(pdf_model=_mock_pdf_model(["a", "b"])) with pytest.raises(NotImplementedError): - fm(_make_pdf_grid_func(TEST_PDF_GRID), TEST_FK_ARRAYS, jnp.array([1.0, 2.0])) + fm(TEST_FK_ARRAYS, jnp.array([1.0, 2.0])) def test_forward_map_none_pdf_model_sets_empty_param_names(): """When pdf_model is None, pdf_param_names must be set to an empty list.""" class MinimalForwardMap(ForwardMap): - def __call__(self, pdf_grid_func, fk_tables, params): + def __call__(self, fk_tables, params): return None fm = MinimalForwardMap(pdf_model=None) @@ -93,9 +94,8 @@ def test_forward_map_subclass_stores_pdf_param_names(): """pdf_param_names passed to super().__init__ must be stored on the instance.""" class MinimalForwardMap(ForwardMap): - def __call__(self, pdf_grid_func, fk_tables, params): - pdf = pdf_grid_func(params[: self.n_pdf_params]) - return _simple_pred_func(pdf, fk_tables), pdf + def __call__(self, fk_tables, params): + return None mock_model = _mock_pdf_model(["p0", "p1", "p2", "p3", "p4"]) fm = MinimalForwardMap(pdf_model=mock_model) @@ -108,7 +108,7 @@ def test_forward_map_extra_param_names_default(): """extra_param_names defaults to an empty tuple.""" class MinimalForwardMap(ForwardMap): - def __call__(self, pdf_grid_func, fk_tables, params): + def __call__(self, fk_tables, params): return None fm = MinimalForwardMap(pdf_model=_mock_pdf_model(["a", "b"])) @@ -120,7 +120,7 @@ def test_forward_map_extra_param_names(): """extra_param_names are stored and appear in param_names after pdf_param_names.""" class MinimalForwardMap(ForwardMap): - def __call__(self, pdf_grid_func, fk_tables, params): + def __call__(self, fk_tables, params): return None fm = MinimalForwardMap( @@ -138,8 +138,11 @@ def __call__(self, pdf_grid_func, fk_tables, params): def test_fktable_forward_map_stores_pdf_param_names(): """FKTableForwardMap.__init__ must store pdf_param_names via the base class.""" + pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) fm = FKTableForwardMap( - pred_func=_simple_pred_func, pdf_model=_mock_pdf_model(["a", "b", "c"]) + pred_func=_simple_pred_func, + pdf_model=_mock_pdf_model(["a", "b", "c"]), + pdf_grid_func=pdf_grid_func, ) assert fm.pdf_param_names == ["a", "b", "c"] assert fm.n_pdf_params == 3 @@ -147,9 +150,11 @@ def test_fktable_forward_map_stores_pdf_param_names(): def test_fktable_forward_map_extra_param_names(): """FKTableForwardMap must accept and store extra_param_names.""" + pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) fm = FKTableForwardMap( pred_func=_simple_pred_func, pdf_model=_mock_pdf_model(["a", "b"]), + pdf_grid_func=pdf_grid_func, extra_param_names=["norm"], ) assert fm.param_names == ["a", "b", "norm"] @@ -159,12 +164,26 @@ def test_fktable_forward_map_extra_param_names(): def test_fktable_forward_map_stores_pred_func(): """FKTableForwardMap.__init__ must store the pred_func.""" + pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) fm = FKTableForwardMap( - pred_func=_simple_pred_func, pdf_model=_mock_pdf_model(["a", "b", "c"]) + pred_func=_simple_pred_func, + pdf_model=_mock_pdf_model(["a", "b", "c"]), + pdf_grid_func=pdf_grid_func, ) assert fm._pred_func is _simple_pred_func +def test_fktable_forward_map_stores_pdf_grid_func(): + """FKTableForwardMap.__init__ must store the pdf_grid_func.""" + pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) + fm = FKTableForwardMap( + pred_func=_simple_pred_func, + pdf_model=_mock_pdf_model(["a", "b"]), + pdf_grid_func=pdf_grid_func, + ) + assert fm._pdf_grid_func is pdf_grid_func + + # --------------------------------------------------------------------------- # FKTableForwardMap.__call__ # --------------------------------------------------------------------------- @@ -172,13 +191,15 @@ def test_fktable_forward_map_stores_pred_func(): def test_fktable_forward_map_returns_tuple(): """__call__ must return a 2-tuple (predictions, pdf).""" + pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) fm = FKTableForwardMap( - pred_func=_simple_pred_func, pdf_model=_mock_pdf_model(["a", "b"]) + pred_func=_simple_pred_func, + pdf_model=_mock_pdf_model(["a", "b"]), + pdf_grid_func=pdf_grid_func, ) - pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) params = jnp.array([1.0, 2.0]) - result = fm(pdf_grid_func, TEST_FK_ARRAYS, params) + result = fm(TEST_FK_ARRAYS, params) assert isinstance(result, tuple) assert len(result) == 2 @@ -186,26 +207,30 @@ def test_fktable_forward_map_returns_tuple(): def test_fktable_forward_map_predictions_shape(): """Predictions returned by __call__ must have shape (N_data,).""" + pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) fm = FKTableForwardMap( - pred_func=_simple_pred_func, pdf_model=_mock_pdf_model(["a", "b"]) + pred_func=_simple_pred_func, + pdf_model=_mock_pdf_model(["a", "b"]), + pdf_grid_func=pdf_grid_func, ) - pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) params = jnp.array([1.0, 2.0]) - predictions, _ = fm(pdf_grid_func, TEST_FK_ARRAYS, params) + predictions, _ = fm(TEST_FK_ARRAYS, params) assert predictions.shape == (TEST_N_DATA,) def test_fktable_forward_map_pdf_shape(): """PDF returned by __call__ must have shape (N_fl, N_x).""" + pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) fm = FKTableForwardMap( - pred_func=_simple_pred_func, pdf_model=_mock_pdf_model(["a", "b"]) + pred_func=_simple_pred_func, + pdf_model=_mock_pdf_model(["a", "b"]), + pdf_grid_func=pdf_grid_func, ) - pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) params = jnp.array([1.0, 2.0]) - _, pdf = fm(pdf_grid_func, TEST_FK_ARRAYS, params) + _, pdf = fm(TEST_FK_ARRAYS, params) assert pdf.shape == (TEST_N_FL, TEST_N_XGRID) @@ -215,11 +240,12 @@ def test_fktable_forward_map_slices_pdf_params(): __call__ must pass only params[:n_pdf_params] to pdf_grid_func; extra parameters appended to params must not affect the PDF or predictions. """ - n_pdf = 2 + pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) fm = FKTableForwardMap( - pred_func=_simple_pred_func, pdf_model=_mock_pdf_model(["a", "b"]) + pred_func=_simple_pred_func, + pdf_model=_mock_pdf_model(["a", "b"]), + pdf_grid_func=pdf_grid_func, ) - pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) pdf_params = jnp.array([1.0, 2.0]) extra_params = jnp.array([99.0, -99.0]) # should be ignored @@ -227,10 +253,8 @@ def test_fktable_forward_map_slices_pdf_params(): params_no_extra = pdf_params params_with_extra = jnp.concatenate([pdf_params, extra_params]) - preds_no_extra, pdf_no_extra = fm(pdf_grid_func, TEST_FK_ARRAYS, params_no_extra) - preds_with_extra, pdf_with_extra = fm( - pdf_grid_func, TEST_FK_ARRAYS, params_with_extra - ) + preds_no_extra, pdf_no_extra = fm(TEST_FK_ARRAYS, params_no_extra) + preds_with_extra, pdf_with_extra = fm(TEST_FK_ARRAYS, params_with_extra) assert_array_almost_equal(preds_no_extra, preds_with_extra) assert_array_almost_equal(pdf_no_extra, pdf_with_extra) @@ -238,21 +262,27 @@ def test_fktable_forward_map_slices_pdf_params(): def test_fktable_forward_map_uses_pdf_grid_func(): """ - __call__ must feed the pdf returned by pdf_grid_func into pred_func. - We verify this by using a pdf_grid_func that scales by a known factor. + __call__ must feed the pdf returned by the bound pdf_grid_func into pred_func. + We verify this by constructing two forward maps with differently scaled pdf_grid_funcs. """ scale = 3.0 - n_pdf = 2 - fm = FKTableForwardMap( - pred_func=_simple_pred_func, pdf_model=_mock_pdf_model(["a", "b"]) - ) - - params = jnp.array([1.0, 2.0]) base_pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) scaled_pdf_grid_func = lambda p: scale * base_pdf_grid_func(p) # noqa: E731 - preds_base, _ = fm(base_pdf_grid_func, TEST_FK_ARRAYS, params) - preds_scaled, _ = fm(scaled_pdf_grid_func, TEST_FK_ARRAYS, params) + fm_base = FKTableForwardMap( + pred_func=_simple_pred_func, + pdf_model=_mock_pdf_model(["a", "b"]), + pdf_grid_func=base_pdf_grid_func, + ) + fm_scaled = FKTableForwardMap( + pred_func=_simple_pred_func, + pdf_model=_mock_pdf_model(["a", "b"]), + pdf_grid_func=scaled_pdf_grid_func, + ) + + params = jnp.array([1.0, 2.0]) + preds_base, _ = fm_base(TEST_FK_ARRAYS, params) + preds_scaled, _ = fm_scaled(TEST_FK_ARRAYS, params) np.testing.assert_allclose(preds_scaled, scale * preds_base, rtol=1e-5) @@ -261,15 +291,15 @@ def test_fktable_forward_map_correct_values(): """ __call__ must produce predictions equal to pred_func(pdf_grid_func(params), fk). """ - n_pdf = 2 + pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) fm = FKTableForwardMap( - pred_func=_simple_pred_func, pdf_model=_mock_pdf_model(["a", "b"]) + pred_func=_simple_pred_func, + pdf_model=_mock_pdf_model(["a", "b"]), + pdf_grid_func=pdf_grid_func, ) params = jnp.array([1.0, 2.0]) - pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) - - predictions, pdf = fm(pdf_grid_func, TEST_FK_ARRAYS, params) + predictions, pdf = fm(TEST_FK_ARRAYS, params) expected_pdf = pdf_grid_func(params) expected_preds = _simple_pred_func(expected_pdf, TEST_FK_ARRAYS) @@ -285,7 +315,9 @@ def test_fktable_forward_map_correct_values(): def test_forward_map_provider_returns_fktable_forward_map(): """forward_map() must return an FKTableForwardMap instance.""" - result = forward_map(_pred_data=_simple_pred_func, pdf_model=MOCK_PDF_MODEL) + result = forward_map( + _pred_data=_simple_pred_func, pdf_model=MOCK_PDF_MODEL, FIT_XGRID=TEST_XGRID + ) assert isinstance(result, FKTableForwardMap) @@ -294,27 +326,40 @@ def test_forward_map_provider_infers_pdf_param_names(): forward_map() must set pdf_param_names equal to pdf_model.param_names, and n_pdf_params must equal len(pdf_model.param_names). """ - result = forward_map(_pred_data=_simple_pred_func, pdf_model=MOCK_PDF_MODEL) + result = forward_map( + _pred_data=_simple_pred_func, pdf_model=MOCK_PDF_MODEL, FIT_XGRID=TEST_XGRID + ) assert result.pdf_param_names == MOCK_PDF_MODEL.param_names assert result.n_pdf_params == len(MOCK_PDF_MODEL.param_names) def test_forward_map_provider_stores_pred_func(): """forward_map() must wire _pred_data into the FKTableForwardMap.""" - result = forward_map(_pred_data=_simple_pred_func, pdf_model=MOCK_PDF_MODEL) + result = forward_map( + _pred_data=_simple_pred_func, pdf_model=MOCK_PDF_MODEL, FIT_XGRID=TEST_XGRID + ) assert result._pred_func is _simple_pred_func +def test_forward_map_provider_binds_pdf_grid_func(): + """forward_map() must call pdf_model.grid_values_func(FIT_XGRID) and store the result.""" + result = forward_map( + _pred_data=_simple_pred_func, pdf_model=MOCK_PDF_MODEL, FIT_XGRID=TEST_XGRID + ) + assert callable(result._pdf_grid_func) + + def test_forward_map_provider_functional(): """ The FKTableForwardMap built by forward_map() must produce correct results when called. """ - fm = forward_map(_pred_data=_simple_pred_func, pdf_model=MOCK_PDF_MODEL) - pdf_grid_func = _make_pdf_grid_func(TEST_PDF_GRID) + fm = forward_map( + _pred_data=_simple_pred_func, pdf_model=MOCK_PDF_MODEL, FIT_XGRID=TEST_XGRID + ) params = jnp.array([1.0, 2.0]) - predictions, pdf = fm(pdf_grid_func, TEST_FK_ARRAYS, params) + predictions, pdf = fm(TEST_FK_ARRAYS, params) assert predictions.shape == (TEST_N_DATA,) assert pdf.shape == (TEST_N_FL, TEST_N_XGRID) @@ -328,7 +373,9 @@ def test_forward_map_provider_with_different_param_counts(): for n in [1, 3, 7]: mock_model = Mock() mock_model.param_names = [f"p_{i}" for i in range(n)] - fm = forward_map(_pred_data=_simple_pred_func, pdf_model=mock_model) + fm = forward_map( + _pred_data=_simple_pred_func, pdf_model=mock_model, FIT_XGRID=TEST_XGRID + ) assert fm.pdf_param_names == mock_model.param_names assert fm.n_pdf_params == n @@ -342,6 +389,7 @@ def test_forward_map_provider_with_extra_param_names(): fm = forward_map( _pred_data=_simple_pred_func, pdf_model=MOCK_PDF_MODEL, + FIT_XGRID=TEST_XGRID, extra_param_names=extra, ) assert fm.param_names == MOCK_PDF_MODEL.param_names + extra diff --git a/colibri/tests/test_likelihood.py b/colibri/tests/test_likelihood.py index d218cae60..a8910008a 100644 --- a/colibri/tests/test_likelihood.py +++ b/colibri/tests/test_likelihood.py @@ -18,7 +18,7 @@ TEST_FK_ARRAYS, TEST_FORWARD_MAP_DIS, TEST_POS_FK_ARRAYS, - TEST_XGRID, + ) from colibri.data_batch import BatchSpec @@ -38,7 +38,6 @@ def test_LogLikelihood_class(pos_penalty): log_likelihood_class = LogLikelihood( central_covmat_index=MOCK_CENTRAL_COVMAT_INDEX, pdf_model=MOCK_PDF_MODEL, - fit_xgrid=TEST_XGRID, forward_map=TEST_FORWARD_MAP_DIS, fast_kernel_arrays=TEST_FK_ARRAYS, positivity_fast_kernel_arrays=TEST_POS_FK_ARRAYS, @@ -67,7 +66,7 @@ def test_LogLikelihood_class(pos_penalty): ) # Compute expected value using actual prediction and covariance predictions, pdf = log_likelihood_class.forward_map( - log_likelihood_class.pdf_grid, log_likelihood_class.fast_kernel_arrays, params + log_likelihood_class.fast_kernel_arrays, params ) predictions = predictions[log_likelihood_class.central_values_idx] diff = predictions - log_likelihood_class.central_values @@ -104,7 +103,6 @@ def test_log_likelihood(pos_penalty): log_likelihood_class = LogLikelihood( central_covmat_index=MOCK_CENTRAL_COVMAT_INDEX, pdf_model=MOCK_PDF_MODEL, - fit_xgrid=TEST_XGRID, forward_map=TEST_FORWARD_MAP_DIS, fast_kernel_arrays=TEST_FK_ARRAYS, positivity_fast_kernel_arrays=TEST_POS_FK_ARRAYS, @@ -115,7 +113,6 @@ def test_log_likelihood(pos_penalty): log_like = log_likelihood( MOCK_CENTRAL_COVMAT_INDEX, MOCK_PDF_MODEL, - TEST_XGRID, TEST_FORWARD_MAP_DIS, TEST_FK_ARRAYS, TEST_POS_FK_ARRAYS, @@ -143,7 +140,6 @@ def test_log_likelihood_with_and_without_pos_penalty(): log_likelihood_class = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, MOCK_PDF_MODEL, - TEST_XGRID, TEST_FORWARD_MAP_DIS, TEST_FK_ARRAYS, TEST_POS_FK_ARRAYS, @@ -166,7 +162,7 @@ def test_log_likelihood_with_and_without_pos_penalty(): # Compute expectation directly: -0.5 * (chi2 + pos_pen + integ_pen) predictions, pdf = log_likelihood_class.forward_map( - log_likelihood_class.pdf_grid, log_likelihood_class.fast_kernel_arrays, params + log_likelihood_class.fast_kernel_arrays, params ) predictions = predictions[log_likelihood_class.central_values_idx] diff = predictions - log_likelihood_class.central_values @@ -195,7 +191,6 @@ def test_log_likelihood_with_and_without_pos_penalty(): log_likelihood_class = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, MOCK_PDF_MODEL, - TEST_XGRID, TEST_FORWARD_MAP_DIS, TEST_FK_ARRAYS, TEST_POS_FK_ARRAYS, @@ -214,7 +209,7 @@ def test_log_likelihood_with_and_without_pos_penalty(): # Expectation: Only chi2 value (penalties zeroed) predictions, pdf = log_likelihood_class.forward_map( - log_likelihood_class.pdf_grid, log_likelihood_class.fast_kernel_arrays, params + log_likelihood_class.fast_kernel_arrays, params ) predictions = predictions[log_likelihood_class.central_values_idx] diff = predictions - log_likelihood_class.central_values @@ -255,7 +250,6 @@ def test_mc_log_likelihood_with_split(pos_penalty): mc_pd, fit_covariance_matrix, MOCK_PDF_MODEL, - TEST_XGRID, TEST_FORWARD_MAP_DIS, TEST_FK_ARRAYS, TEST_POS_FK_ARRAYS, @@ -276,7 +270,7 @@ def test_mc_log_likelihood_with_split(pos_penalty): # Compute expected for train and validation independently def compute_expected(ll_obj): preds, pdf = ll_obj.forward_map( - ll_obj.pdf_grid, ll_obj.fast_kernel_arrays, params + ll_obj.fast_kernel_arrays, params ) preds = preds[ll_obj.central_values_idx] diff = preds - ll_obj.central_values @@ -335,7 +329,6 @@ def test_mc_log_likelihood_without_split_returns_nan_for_validation(pos_penalty) mc_pd, fit_covariance_matrix, MOCK_PDF_MODEL, - TEST_XGRID, TEST_FORWARD_MAP_DIS, TEST_FK_ARRAYS, TEST_POS_FK_ARRAYS, @@ -351,7 +344,7 @@ def test_mc_log_likelihood_without_split_returns_nan_for_validation(pos_penalty) train_val = train_loglike(params) # Compute expected train value predictions, pdf = train_loglike.forward_map( - train_loglike.pdf_grid, train_loglike.fast_kernel_arrays, params + train_loglike.fast_kernel_arrays, params ) predictions = predictions[train_loglike.central_values_idx] diff = predictions - train_loglike.central_values @@ -394,7 +387,6 @@ def test_LogLikelihood_call_with_batch_idx(pos_penalty): log_likelihood_class = LogLikelihood( central_covmat_index=MOCK_CENTRAL_COVMAT_INDEX, pdf_model=MOCK_PDF_MODEL, - fit_xgrid=TEST_XGRID, forward_map=TEST_FORWARD_MAP_DIS, fast_kernel_arrays=TEST_FK_ARRAYS, positivity_fast_kernel_arrays=TEST_POS_FK_ARRAYS, @@ -412,7 +404,7 @@ def test_LogLikelihood_call_with_batch_idx(pos_penalty): # Compute expected on the batch index: recompute inv_covmat on the sub-covmat predictions, pdf = log_likelihood_class.forward_map( - log_likelihood_class.pdf_grid, log_likelihood_class.fast_kernel_arrays, params + log_likelihood_class.fast_kernel_arrays, params ) predictions = predictions[log_likelihood_class.central_values_idx] predictions_b = predictions[batch.idx] @@ -456,7 +448,6 @@ def test_LogLikelihood_call_with_batch_with_inv_cov(pos_penalty): log_likelihood_class = LogLikelihood( central_covmat_index=MOCK_CENTRAL_COVMAT_INDEX, pdf_model=MOCK_PDF_MODEL, - fit_xgrid=TEST_XGRID, forward_map=TEST_FORWARD_MAP_DIS, fast_kernel_arrays=TEST_FK_ARRAYS, positivity_fast_kernel_arrays=TEST_POS_FK_ARRAYS, @@ -479,7 +470,7 @@ def test_LogLikelihood_call_with_batch_with_inv_cov(pos_penalty): # Compute expected value using the provided inv_b (should be identical) predictions, pdf = log_likelihood_class.forward_map( - log_likelihood_class.pdf_grid, log_likelihood_class.fast_kernel_arrays, params + log_likelihood_class.fast_kernel_arrays, params ) predictions = predictions[log_likelihood_class.central_values_idx] predictions_b = predictions[batch.idx] diff --git a/colibri/utils.py b/colibri/utils.py index e66910b17..aee86d0da 100644 --- a/colibri/utils.py +++ b/colibri/utils.py @@ -308,10 +308,9 @@ def likelihood_float_type( central_values = central_covmat_index.central_values covmat = central_covmat_index.covmat - pdf_grid = pdf_model.grid_values_func(FIT_XGRID) def log_likelihood(params, central_values, inv_covmat, fast_kernel_arrays): - predictions, pdf = forward_map(pdf_grid, fast_kernel_arrays, params) + predictions, pdf = forward_map(fast_kernel_arrays, params) return -0.5 * loss_function(central_values, predictions, inv_covmat) params = bayesian_prior.prior_transform( From 8581a03083249b61276ff02479d64a0bbfcff1e1 Mon Sep 17 00:00:00 2001 From: Luca Mantani Date: Tue, 28 Apr 2026 12:34:28 +0200 Subject: [PATCH 25/30] black --- colibri/tests/test_likelihood.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/colibri/tests/test_likelihood.py b/colibri/tests/test_likelihood.py index a8910008a..3a6ee23bc 100644 --- a/colibri/tests/test_likelihood.py +++ b/colibri/tests/test_likelihood.py @@ -18,7 +18,6 @@ TEST_FK_ARRAYS, TEST_FORWARD_MAP_DIS, TEST_POS_FK_ARRAYS, - ) from colibri.data_batch import BatchSpec @@ -269,9 +268,7 @@ def test_mc_log_likelihood_with_split(pos_penalty): # Compute expected for train and validation independently def compute_expected(ll_obj): - preds, pdf = ll_obj.forward_map( - ll_obj.fast_kernel_arrays, params - ) + preds, pdf = ll_obj.forward_map(ll_obj.fast_kernel_arrays, params) preds = preds[ll_obj.central_values_idx] diff = preds - ll_obj.central_values inv = ll_obj.inv_covmat From 11641ba9faf6f1ee61e25a62ca4d61ce0cc83194 Mon Sep 17 00:00:00 2001 From: Luca Mantani Date: Tue, 28 Apr 2026 12:57:50 +0200 Subject: [PATCH 26/30] Adapted tests --- colibri/tests/test_analytic_fit.py | 13 +++++++++-- colibri/tests/test_blackjax_fit.py | 2 +- colibri/tests/test_checks.py | 8 ++++--- colibri/tests/test_ultranest_fit.py | 35 ++++++++++++++++++++--------- colibri/tests/test_utils.py | 6 +++-- 5 files changed, 46 insertions(+), 18 deletions(-) diff --git a/colibri/tests/test_analytic_fit.py b/colibri/tests/test_analytic_fit.py index a4fd6ec24..9120e8315 100644 --- a/colibri/tests/test_analytic_fit.py +++ b/colibri/tests/test_analytic_fit.py @@ -43,6 +43,7 @@ def test_analytic_fit_flat_direction(): forward_map = FKTableForwardMap( lambda pdf, fkarrs: jnp.ones(n_params), pdf_model=MOCK_PDF_MODEL, + pdf_grid_func=MOCK_PDF_MODEL.grid_values_func(TEST_XGRID), ) with pytest.raises(ValueError): @@ -69,7 +70,11 @@ def test_analytic_fit(caplog, monkeypatch): MOCK_PDF_MODEL, "grid_values_func", lambda xgrid: lambda params: params ) - forward_map = FKTableForwardMap(lambda pdf, fkarrs: pdf, pdf_model=MOCK_PDF_MODEL) + forward_map = FKTableForwardMap( + lambda pdf, fkarrs: pdf, + pdf_model=MOCK_PDF_MODEL, + pdf_grid_func=MOCK_PDF_MODEL.grid_values_func(TEST_XGRID), + ) # Run the analytic fit result = analytic_fit( @@ -132,7 +137,11 @@ def test_analytic_fit_different_priors(caplog, monkeypatch): MOCK_PDF_MODEL, "grid_values_func", lambda xgrid: lambda params: params ) - forward_map = FKTableForwardMap(lambda pdf, fkarrs: pdf, pdf_model=MOCK_PDF_MODEL) + forward_map = FKTableForwardMap( + lambda pdf, fkarrs: pdf, + pdf_model=MOCK_PDF_MODEL, + pdf_grid_func=MOCK_PDF_MODEL.grid_values_func(TEST_XGRID), + ) # Run the analytic fit result = analytic_fit( diff --git a/colibri/tests/test_blackjax_fit.py b/colibri/tests/test_blackjax_fit.py index e4a41d722..618314714 100644 --- a/colibri/tests/test_blackjax_fit.py +++ b/colibri/tests/test_blackjax_fit.py @@ -65,11 +65,11 @@ def test_blackjax_fit(pos_penalty): forward_map = FKTableForwardMap( lambda pdf, fk: jnp.zeros(len(MOCK_PDF_MODEL.param_names)), pdf_model=MOCK_PDF_MODEL, + pdf_grid_func=MOCK_PDF_MODEL.grid_values_func(TEST_XGRID), ) mock_log_likelihood = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, MOCK_PDF_MODEL, - TEST_XGRID, forward_map, TEST_FK_ARRAYS, TEST_POS_FK_ARRAYS, diff --git a/colibri/tests/test_checks.py b/colibri/tests/test_checks.py index 256d3a428..a0145512a 100644 --- a/colibri/tests/test_checks.py +++ b/colibri/tests/test_checks.py @@ -129,15 +129,16 @@ def test_check_pdf_model_is_linear(mock_fast_kernel_arrays, mock_make_pred_data) def pdf_linear_model(params): return params + # Set the mock's grid_values_func to return the linear_model function + mock_pdf_model.grid_values_func.return_value = pdf_linear_model + forward_map_lin = FKTableForwardMap( # Simulating a simple linear model: f(x) = a*x + b*y + c*z + 3.0, where pdf = [a, b, c] lambda pdf, fk: jnp.dot(pdf, fk) + 3.0, pdf_model=mock_pdf_model, + pdf_grid_func=mock_pdf_model.grid_values_func(FIT_XGRID), ) - # Set the mock's grid_values_func to return the linear_model function - mock_pdf_model.grid_values_func.return_value = pdf_linear_model - # Test for linear model (should not raise an exception) check_pdf_model_is_linear(mock_pdf_model, forward_map_lin, FIT_XGRID, data) @@ -146,6 +147,7 @@ def pdf_linear_model(params): # Introduce some non-linearity lambda pdf, fk: jnp.dot(pdf**2, FIT_XGRID) + fk, pdf_model=mock_pdf_model, + pdf_grid_func=mock_pdf_model.grid_values_func(FIT_XGRID), ) # Ensure ValueError is raised for non-linear model diff --git a/colibri/tests/test_ultranest_fit.py b/colibri/tests/test_ultranest_fit.py index 122f557a1..912a1f169 100644 --- a/colibri/tests/test_ultranest_fit.py +++ b/colibri/tests/test_ultranest_fit.py @@ -67,11 +67,14 @@ def mock_sample(rng_key, n_samples): def test_ultranest_fit(pos_penalty): _pred_data = lambda *args: jnp.array([0.0]) - forward_map = FKTableForwardMap(_pred_data, pdf_model=MOCK_PDF_MODEL) + forward_map = FKTableForwardMap( + _pred_data, + pdf_model=MOCK_PDF_MODEL, + pdf_grid_func=MOCK_PDF_MODEL.grid_values_func(TEST_XGRID), + ) mock_log_likelihood = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, MOCK_PDF_MODEL, - TEST_XGRID, forward_map, TEST_FK_ARRAYS, TEST_POS_FK_ARRAYS, @@ -105,13 +108,16 @@ def test_ultranest_fit(pos_penalty): def test_ultranest_fit_vectorized(pos_penalty): _pred_data = lambda *args: jnp.array([0.0]) - forward_map = FKTableForwardMap(_pred_data, pdf_model=MOCK_PDF_MODEL) + forward_map = FKTableForwardMap( + _pred_data, + pdf_model=MOCK_PDF_MODEL, + pdf_grid_func=MOCK_PDF_MODEL.grid_values_func(TEST_XGRID), + ) ultranest_settings["ReactiveNS_settings"]["vectorized"] = True mock_log_likelihood = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, MOCK_PDF_MODEL, - TEST_XGRID, forward_map, TEST_FK_ARRAYS, TEST_POS_FK_ARRAYS, @@ -155,12 +161,15 @@ def test_ultranest_fit_with_SliceSampler(pos_penalty): } _pred_data = lambda *args: jnp.array([0.0]) - forward_map = FKTableForwardMap(_pred_data, pdf_model=MOCK_PDF_MODEL) + forward_map = FKTableForwardMap( + _pred_data, + pdf_model=MOCK_PDF_MODEL, + pdf_grid_func=MOCK_PDF_MODEL.grid_values_func(TEST_XGRID), + ) mock_log_likelihood = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, MOCK_PDF_MODEL, - TEST_XGRID, forward_map, TEST_FK_ARRAYS, TEST_POS_FK_ARRAYS, @@ -204,12 +213,15 @@ def test_ultranest_fit_with_popSliceSampler(pos_penalty): } _pred_data = lambda *args: jnp.array([0.0]) - forward_map = FKTableForwardMap(_pred_data, pdf_model=MOCK_PDF_MODEL) + forward_map = FKTableForwardMap( + _pred_data, + pdf_model=MOCK_PDF_MODEL, + pdf_grid_func=MOCK_PDF_MODEL.grid_values_func(TEST_XGRID), + ) mock_log_likelihood = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, MOCK_PDF_MODEL, - TEST_XGRID, forward_map, TEST_FK_ARRAYS, TEST_POS_FK_ARRAYS, @@ -257,12 +269,15 @@ def test_ultranest_fit_with_sampler_plot(mock_sampler_class, pos_penalty): } _pred_data = lambda *args: jnp.array([0.0]) - forward_map = FKTableForwardMap(_pred_data, pdf_model=MOCK_PDF_MODEL) + forward_map = FKTableForwardMap( + _pred_data, + pdf_model=MOCK_PDF_MODEL, + pdf_grid_func=MOCK_PDF_MODEL.grid_values_func(TEST_XGRID), + ) mock_log_likelihood = LogLikelihood( MOCK_CENTRAL_COVMAT_INDEX, MOCK_PDF_MODEL, - TEST_XGRID, forward_map, TEST_FK_ARRAYS, TEST_POS_FK_ARRAYS, diff --git a/colibri/tests/test_utils.py b/colibri/tests/test_utils.py index 57d6a8f46..d7a7459c6 100644 --- a/colibri/tests/test_utils.py +++ b/colibri/tests/test_utils.py @@ -342,10 +342,12 @@ def test_likelihood_float_type( _pred_data = lambda x, fks: jnp.ones( len(MOCK_CENTRAL_COVMAT_INDEX.central_values) ) # Mock _pred_data + FIT_XGRID = jnp.linspace(0, 1, 10) # Mock FIT_XGRID forward_map = FKTableForwardMap( - _pred_data, pdf_model=MOCK_PDF_MODEL + _pred_data, + pdf_model=MOCK_PDF_MODEL, + pdf_grid_func=MOCK_PDF_MODEL.grid_values_func(FIT_XGRID), ) # Mock forward_map - FIT_XGRID = jnp.linspace(0, 1, 10) # Mock FIT_XGRID output_path = tmp_path fast_kernel_arrays = jax.random.uniform( From 28daeabc225bfeae838fbfe9331a4420c63d400d Mon Sep 17 00:00:00 2001 From: Luca Mantani Date: Tue, 28 Apr 2026 15:02:35 +0200 Subject: [PATCH 27/30] Adapted after merging main --- colibri/bayes_prior.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/colibri/bayes_prior.py b/colibri/bayes_prior.py index bb8177c91..aa098ed59 100644 --- a/colibri/bayes_prior.py +++ b/colibri/bayes_prior.py @@ -43,8 +43,9 @@ def bayesian_prior(prior_settings, forward_map): elif "min_val" in prior_specs and "max_val" in prior_specs: # Global bounds for all parameters - mins = jnp.array([float(prior_specs["min_val"])] * pdf_model.n_parameters) - maxs = jnp.array([float(prior_specs["max_val"])] * pdf_model.n_parameters) + n_params = len(forward_map.param_names) + mins = jnp.array([float(prior_specs["min_val"])] * n_params) + maxs = jnp.array([float(prior_specs["max_val"])] * n_params) else: raise ValueError( From a5cddec57efa801fc4d53d1e0924e3b906731158 Mon Sep 17 00:00:00 2001 From: Luca Mantani Date: Tue, 28 Apr 2026 15:51:03 +0200 Subject: [PATCH 28/30] Fixed bug --- colibri/blackjax_fit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colibri/blackjax_fit.py b/colibri/blackjax_fit.py index fbe1b6f0b..bcb3bd498 100644 --- a/colibri/blackjax_fit.py +++ b/colibri/blackjax_fit.py @@ -150,7 +150,7 @@ def one_step(carry, xs): nested_samples.to_csv(log_dir + "/nested_samples.csv") # Export resampled posterior samples - posterior_df = pd.DataFrame(resampled_posterior, columns=pdf_model.param_names) + posterior_df = pd.DataFrame(resampled_posterior, columns=forward_map.param_names) posterior_df.to_csv(os.path.join(log_dir, "posterior_samples.csv"), index=False) # Compute bayesian metrics (similar to UltraNest) From bface548ce17be8904eab1a105fafdd16709c529 Mon Sep 17 00:00:00 2001 From: vschutze-alt Date: Wed, 29 Apr 2026 16:12:36 +0100 Subject: [PATCH 29/30] typo --- colibri/forward_map.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colibri/forward_map.py b/colibri/forward_map.py index f05fba48c..cfe229a57 100644 --- a/colibri/forward_map.py +++ b/colibri/forward_map.py @@ -5,7 +5,7 @@ A ``ForwardMap`` implements the final stage of the fit pipeline, turning the fit parameter vector into theory predictions that can be compared with -data in the likelihood. It will also also return the PDF values on the fit x-grid, +data in the likelihood. It will also return the PDF values on the fit x-grid, which is sometimes needed for computing penalties. From 177898a52b158c8a3579f14be1ac29d9d9627c12 Mon Sep 17 00:00:00 2001 From: vschutze-alt Date: Wed, 29 Apr 2026 16:24:34 +0100 Subject: [PATCH 30/30] typo --- colibri/forward_map.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colibri/forward_map.py b/colibri/forward_map.py index cfe229a57..e7e1e111c 100644 --- a/colibri/forward_map.py +++ b/colibri/forward_map.py @@ -48,7 +48,7 @@ def __call__(self, fk_tables, params): return norm * self._pred_func(pdf, fk_tables), pdf Example - fixed PDF, fitting only extra parameters ---------------------------------------------------- +-------------------------------------------------- :: class FixedPDFForwardMap(ForwardMap):