Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
5c6be3d
Factored out forward map
LucaMantani Feb 16, 2026
83bea26
Added tests
LucaMantani Feb 16, 2026
458636d
removed pred_and_pdf from everywhere
LucaMantani Feb 16, 2026
7a8c95a
New forward map class
LucaMantani Feb 16, 2026
2b898bc
Refined implementation
LucaMantani Feb 16, 2026
0d8a241
Changed conftest
LucaMantani Feb 16, 2026
744b275
Fixed tests
LucaMantani Feb 16, 2026
29dfd7e
Make sure we write pdfs with the first parameters
LucaMantani Feb 16, 2026
0fd84f3
Restored doc
LucaMantani Feb 16, 2026
5dcb3c3
Fixed bug in tests
LucaMantani Feb 16, 2026
1e57a68
use check_pdf_model_is_linear as function rather than decorator
comane Mar 23, 2026
1afebf6
Apply suggestion from @comane
comane Mar 23, 2026
8f7cffd
added tests for forward map
comane Mar 24, 2026
97ceb1b
merge commit
comane Mar 24, 2026
a4b01b4
fixed tests from merge
comane Mar 24, 2026
cbf5ff6
upgraded local black and formatted forward map tests
comane Mar 24, 2026
74d0f4e
added line for raise not implemented in forward map
comane Mar 24, 2026
7da4f04
forward model initialised with pdf parameter names
comane Mar 26, 2026
054c97f
pass forward model to bayesian prior for total model params
comane Mar 26, 2026
22207f3
pass pdf_model object to forward map
comane Mar 26, 2026
18dc793
Update colibri/forward_map.py
LucaMantani Apr 28, 2026
089110c
Update colibri/forward_map.py
LucaMantani Apr 28, 2026
cf02826
Update colibri/forward_map.py
LucaMantani Apr 28, 2026
0cdce57
Added test
LucaMantani Apr 28, 2026
20941aa
removed grid_func passing to forward map
LucaMantani Apr 28, 2026
8581a03
black
LucaMantani Apr 28, 2026
11641ba
Adapted tests
LucaMantani Apr 28, 2026
f134f71
Merge branch 'main' into separate-forward-map
LucaMantani Apr 28, 2026
79c1e73
Merge branch 'main' into separate-forward-map
LucaMantani Apr 28, 2026
28daeab
Adapted after merging main
LucaMantani Apr 28, 2026
a5cddec
Fixed bug
LucaMantani Apr 28, 2026
bface54
typo
vschutze-alt Apr 29, 2026
177898a
typo
vschutze-alt Apr 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions colibri/analytic_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def analytic_evidence_uniform_prior(sol_covmat, sol_mean, max_logl, a_vec, b_vec
@check_pdf_model_is_linear
def analytic_fit(
central_inv_covmat_index,
_pred_data,
forward_map,
pdf_model,
analytic_settings,
prior_settings,
Expand All @@ -105,8 +105,8 @@ def analytic_fit(
central_inv_covmat_index: commondata_utils.CentralInvCovmatIndex
dataclass containing central values and inverse covmat.

_pred_data: @jax.jit CompiledFunction
Prediction function for the fit.
forward_map: @jax.jit CompiledFunction
Forward map function for the fit.

pdf_model: pdf_model.PDFModel
PDF model to fit.
Expand All @@ -131,14 +131,14 @@ def analytic_fit(
)

parameters = pdf_model.param_names
pred_and_pdf = pdf_model.pred_and_pdf_func(FIT_XGRID, forward_map=_pred_data)

# Precompute predictions for the basis of the model
bases = jnp.identity(len(parameters))
pdf_grid = pdf_model.grid_values_func(FIT_XGRID)
predictions = jnp.array(
[pred_and_pdf(basis, fast_kernel_arrays)[0] for basis in bases]
[forward_map(pdf_grid, fast_kernel_arrays, basis)[0] for basis in bases]
)
intercept = pred_and_pdf(jnp.zeros(len(parameters)), fast_kernel_arrays)[0]
intercept = forward_map(pdf_grid, fast_kernel_arrays, jnp.zeros(len(parameters)))[0]

# Construct the analytic solution
central_values = central_inv_covmat_index.central_values
Expand Down
1 change: 1 addition & 0 deletions colibri/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"colibri.param_initialisation",
"colibri.export_results",
"colibri.closure_test",
"colibri.forward_map",
"reportengine.report",
]

Expand Down
14 changes: 7 additions & 7 deletions colibri/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def check_pdf_models_equal(prior_settings, pdf_model, theoryid):


@make_argcheck
def check_pdf_model_is_linear(pdf_model, FIT_XGRID, data):
def check_pdf_model_is_linear(pdf_model, forward_map, FIT_XGRID, data):
"""
Decorator that can be added to functions to check that the
Comment thread
comane marked this conversation as resolved.
Outdated
PDF model is linear.
Expand All @@ -52,8 +52,8 @@ def check_pdf_model_is_linear(pdf_model, FIT_XGRID, data):
fk = fast_kernel_arrays(data, FIT_XGRID)

parameters = pdf_model.param_names
pred_and_pdf = pdf_model.pred_and_pdf_func(FIT_XGRID, forward_map=pred_data)
intercept = pred_and_pdf(jnp.zeros(len(parameters)), fk)[0]
pdf_grid = pdf_model.grid_values_func(FIT_XGRID)
intercept, _ = forward_map(pdf_grid, fk, jnp.zeros(len(parameters)))

# Run the check for 10 random points in the parameter space
for i in range(10):
Expand All @@ -65,16 +65,16 @@ def check_pdf_model_is_linear(pdf_model, FIT_XGRID, data):

# Test additivity
add_check = jnp.isclose(
pred_and_pdf(x1, fk)[0] + pred_and_pdf(x2, fk)[0],
pred_and_pdf(x1 + x2, fk)[0] + intercept,
forward_map(pdf_grid, fk, x1)[0] + forward_map(pdf_grid, fk, x2)[0],
forward_map(pdf_grid, fk, x1 + x2)[0] + intercept,
)

# Test homogeneity
c = jax.random.uniform(key, (1,))

homogeneity_check = jnp.isclose(
c * (pred_and_pdf(x1, fk)[0] - intercept),
pred_and_pdf(c * x1, fk)[0] - intercept,
c * (forward_map(pdf_grid, fk, x1)[0] - intercept),
forward_map(pdf_grid, fk, c * x1)[0] - intercept,
)

if not add_check.all() or not homogeneity_check.all():
Expand Down
3 changes: 2 additions & 1 deletion colibri/export_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,11 @@ def write_replicas(

# Create the exportgrid
lhapdf_interpolator = pdf_model.grid_values_func(xgrid)
n_pdf_params = len(pdf_model.param_names)

# Finish by writing the replicas to export grids, ready for evolution
for i in indices_per_process:
parameters = jnp.array(bayes_fit.resampled_posterior[i, :])
parameters = jnp.array(bayes_fit.resampled_posterior[i, :n_pdf_params])
grid_for_writing = np.array(lhapdf_interpolator(parameters))

replica_index = i + 1
Expand Down
158 changes: 158 additions & 0 deletions colibri/forward_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""
colibri.forward_map.py

Forward maps: parameters → theory predictions.

A ``ForwardMap`` implements the final stage of the fit pipeline, turning the
fit parameter vector into theory predictions that can be compared with
data in the likelihood. It will also also return the PDF values on the fit x-grid,
which is sometimes needed for computing penalties.


Design choice: fixed call signature
-----------------------------------
The log-likelihood calls every forward map with the same fixed signature::

(pdf_grid_func, fk_tables, params) -> predictions, pdf

Parameter convention
--------------------
``params`` is a 1-D array containing *all* fit parameters. In colibri we allow
for "extra" fit parameters beyond the PDF model parameters (e.g. nuisance-like factors,
or parameters of a custom prediction function).

By convention:

``params[:self.n_pdf_params]`` are PDF parameters consumed by ``pdf_grid_func``;
any remaining entries are "extra" parameters interpreted by the chosen
``ForwardMap`` implementation.

Example - fitting a normalisation factor on top of the PDF
----------------------------------------------------------
::

class NormForwardMap(ForwardMap):
def __init__(self, pred_func, n_pdf_params: int):
super().__init__(n_pdf_params)
self._pred_func = pred_func

def __call__(self, pdf_grid_func, fk_tables, params):
pdf = pdf_grid_func(params[: self.n_pdf_params])
norm = params[self.n_pdf_params] # first extra parameter
return norm * self._pred_func(pdf, fk_tables), pdf

Example - fixed PDF, fitting only extra parameters
---------------------------------------------------
::

class FixedPDFForwardMap(ForwardMap):
def __init__(self, pred_func, fixed_pdf, fk_tables, n_pdf_params: int = 0):
super().__init__(n_pdf_params)
self._pred_func = pred_func
self.fixed_pdf = fixed_pdf
self._fixed_pred = self._pred_func(fixed_pdf, fk_tables)

def __call__(self, pdf_grid_func, fk_tables, params):
scale = params[0]
return scale * self._fixed_pred, self.fixed_pdf
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Callable

import jax.numpy as jnp


class ForwardMap(ABC):
"""Abstract base class for forward maps.

A forward map turns fit parameters into theory predictions that can be
compared with experimental data inside the likelihood.

All forward maps share the same call signature:

``(pdf_grid_func, fk_tables, params) -> predictions``

Notes
-----
The split point between PDF parameters and "extra" parameters is owned
by the forward map via ``self.n_pdf_params``.
"""

def __init__(self, n_pdf_params: int):

self.n_pdf_params = n_pdf_params

@abstractmethod
def __call__(
self,
pdf_grid_func: Callable[[jnp.ndarray], jnp.ndarray],
fk_tables: Any,
params: jnp.ndarray,
) -> jnp.ndarray:
"""Compute theory predictions from fit parameters.

Parameters
----------
pdf_grid_func : callable
Callable that evaluates PDF values on the fit x-grid from the PDF
parameters.

Expected call signature:
``pdf = pdf_grid_func(pdf_params)``
with ``pdf`` shaped ``(N_fl, N_x)``.

fk_tables : jnp.ndarray
Fast-kernel tables needed by the prediction function.

params : jnp.ndarray
1-D array containing all fit parameters. By convention:
* ``params[:self.n_pdf_params]`` are PDF parameters
* the remaining entries are extra parameters interpreted by the
specific ``ForwardMap`` implementation.

Returns
-------
jnp.ndarray
Theory predictions (1-D array with one entry per data point).
jnp.ndarray
The PDF values (2-D array with shape (N_fl, N_x)).

"""
raise NotImplementedError


class FKTableForwardMap(ForwardMap):
"""Default forward map: params → PDF → FK-table convolution.

This is the standard pipeline used in colibri PDF fits.
"""

def __init__(
self, pred_func: Callable[[jnp.ndarray, Any], jnp.ndarray], n_pdf_params: int
):
super().__init__(n_pdf_params)
self._pred_func = pred_func

def __call__(self, pdf_grid_func, fk_tables, params):
pdf_params = params[: self.n_pdf_params]
pdf = pdf_grid_func(pdf_params)
return self._pred_func(pdf, fk_tables), pdf


def forward_map(_pred_data, pdf_model):
"""Reportengine provider that builds the default FK-table forward map.

Parameters
----------
_pred_data : callable
Prediction function of the form ``pred_func(pdf, fk_tables) -> predictions``.
pdf_model : optional
Used to infer ``n_pdf_params`` from ``len(pdf_model.param_names)``.

"""

n_pdf_params = len(pdf_model.param_names)
return FKTableForwardMap(_pred_data, n_pdf_params=n_pdf_params)
Comment thread
comane marked this conversation as resolved.
Outdated
17 changes: 8 additions & 9 deletions colibri/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,8 @@ def __init__(
self.positivity_penalty_settings = positivity_penalty_settings
self.integrability_penalty = integrability_penalty

self.pred_and_pdf = pdf_model.pred_and_pdf_func(
fit_xgrid, forward_map=forward_map
)
self.pdf_grid = pdf_model.grid_values_func(fit_xgrid)
self.forward_map = forward_map

self.fast_kernel_arrays = fast_kernel_arrays
self.positivity_fast_kernel_arrays = positivity_fast_kernel_arrays
Expand Down Expand Up @@ -126,7 +125,7 @@ def log_likelihood(
jnp.ndarray
jax array with the value of the log-likelihood.
"""
predictions, pdf = self.pred_and_pdf(params, fast_kernel_arrays)
predictions, pdf = self.forward_map(self.pdf_grid, fast_kernel_arrays, params)
# Select only the data relevant for this likelihood
# Especially important when using a training/validation split
predictions = predictions[self.central_values_idx]
Expand Down Expand Up @@ -169,7 +168,7 @@ def log_likelihood(
central_covmat_index,
pdf_model,
FIT_XGRID,
_pred_data,
forward_map,
fast_kernel_arrays,
positivity_fast_kernel_arrays,
_penalty_posdata,
Expand All @@ -186,7 +185,7 @@ def log_likelihood(
central_covmat_index,
pdf_model,
FIT_XGRID,
_pred_data,
forward_map,
fast_kernel_arrays,
positivity_fast_kernel_arrays,
_penalty_posdata,
Expand All @@ -200,7 +199,7 @@ def mc_log_likelihood(
fit_covariance_matrix,
pdf_model,
FIT_XGRID,
_pred_data,
forward_map,
fast_kernel_arrays,
positivity_fast_kernel_arrays,
_penalty_posdata,
Expand Down Expand Up @@ -228,7 +227,7 @@ def mc_log_likelihood(
central_covmat_index_train,
pdf_model,
FIT_XGRID,
_pred_data,
forward_map,
fast_kernel_arrays,
positivity_fast_kernel_arrays,
_penalty_posdata,
Expand All @@ -254,7 +253,7 @@ def mc_log_likelihood(
central_covmat_index_val,
pdf_model,
FIT_XGRID,
_pred_data,
forward_map,
fast_kernel_arrays,
positivity_fast_kernel_arrays,
_penalty_posdata,
Expand Down
3 changes: 2 additions & 1 deletion colibri/mc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,10 @@ def write_exportgrid_mc(

# Create the exportgrid
lhapdf_interpolator = pdf_model.grid_values_func(LHAPDF_XGRID)
n_pdf_params = len(pdf_model.param_names)

# Rotate the grid from the evolution basis into the export grid basis
grid_for_writing = np.array(lhapdf_interpolator(parameters))
grid_for_writing = np.array(lhapdf_interpolator(parameters[:n_pdf_params]))

write_exportgrid(
grid_for_writing=grid_for_writing,
Expand Down
39 changes: 1 addition & 38 deletions colibri/pdf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""

from abc import ABC, abstractmethod
from typing import Callable, Tuple
from typing import Callable

import jax.numpy as jnp
from jax.typing import ArrayLike
Expand Down Expand Up @@ -53,40 +53,3 @@ def func(params):
return func
"""
pass

def pred_and_pdf_func(
self,
xgrid: ArrayLike,
forward_map: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
) -> Callable[[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:
"""Creates a function that returns a tuple of two arrays, given the model parameters and the fast kernel arrays as input.

The returned function produces:
- The first array: 1D vector of theory predictions for the data.
- The second array: PDF values evaluated on the x-grid, using `self.grid_values_func`, with shape (Nfl, Nx).

The `forward_map` is used to map the PDF values defined on the x-grid and the fast kernel arrays into the corresponding theory prediction vector.
"""
pdf_func = self.grid_values_func(xgrid)

def pred_and_pdf(params, fast_kernel_arrays):
"""
Parameters
----------
params: jnp.array
The model parameters.

fast_kernel_arrays: tuple
tuple of tuples of jnp.arrays
The FK tables to use.

Returns
-------
tuple
The predictions and the PDF values.
"""
pdf = pdf_func(params)
predictions = forward_map(pdf, fast_kernel_arrays)
return predictions, pdf

return pred_and_pdf
Loading
Loading