diff --git a/colibri/commondata_utils.py b/colibri/commondata_utils.py index cdbe3ba3..a3b71483 100644 --- a/colibri/commondata_utils.py +++ b/colibri/commondata_utils.py @@ -37,7 +37,6 @@ def level_0_commondata_tuple( FIT_XGRID, fast_kernel_arrays, flavour_indices=None, - fill_fk_xgrid_with_zeros=False, ): """ Returns a tuple (validphys nodes should be immutable) @@ -66,13 +65,6 @@ def level_0_commondata_tuple( flavour_indices: list, default is None Subset of flavour (evolution basis) indices to be used. - fill_fk_xgrid_with_zeros: bool, default is False - If True, then the missing xgrid points in the FK table - will be filled with zeros. This is useful when the FK table - is needed as tensor of shape (Ndat, Nfl, Nfk_xgrid) with Nfk_xgrid and Nfl fixed - for all datasets. - - Returns ------- tuple @@ -92,7 +84,6 @@ def level_0_commondata_tuple( ds, FIT_XGRID, flavour_indices=flavour_indices, - fill_fk_xgrid_with_zeros=fill_fk_xgrid_with_zeros, )(closure_test_central_pdf_grid, fk_dataset) ) ) diff --git a/colibri/tests/test_theory_predictions.py b/colibri/tests/test_theory_predictions.py index f280f665..b362e2f6 100644 --- a/colibri/tests/test_theory_predictions.py +++ b/colibri/tests/test_theory_predictions.py @@ -9,6 +9,8 @@ from numpy.testing import assert_allclose from validphys.fkparser import load_fktable +import numpy as np + from colibri.api import API as colibriAPI from colibri.tests.conftest import ( CLOSURE_TEST_PDFSET, @@ -30,26 +32,12 @@ def __init__(self, xgrid): self.xgrid = xgrid -def test_fktable_xgrid_indices_fill_with_zeros(): - # Case where fill_fk_xgrid_with_zeros is True - fktable = FKTableDataMock(xgrid=jnp.array([0.1, 0.2, 0.3])) - FIT_XGRID = jnp.array([0.05, 0.1, 0.15, 0.2, 0.25, 0.3]) - - expected_indices = jnp.arange( - len(FIT_XGRID) - ) # Should return indices for the entire FIT_XGRID - result = fktable_xgrid_indices(fktable, FIT_XGRID, fill_fk_xgrid_with_zeros=True) - - assert jnp.array_equal(result, expected_indices) - - -def test_fktable_xgrid_indices_no_fill(): - # Case where fill_fk_xgrid_with_zeros is False +def test_fktable_xgrid_indices(): fktable = FKTableDataMock(xgrid=jnp.array([0.1, 0.2, 0.3])) FIT_XGRID = jnp.array([0.05, 0.1, 0.15, 0.2, 0.25, 0.3]) expected_indices = jnp.array([1, 3, 5]) # Indices where fk_xgrid matches FIT_XGRID - result = fktable_xgrid_indices(fktable, FIT_XGRID, fill_fk_xgrid_with_zeros=False) + result = fktable_xgrid_indices(fktable, FIT_XGRID) assert jnp.array_equal(result, expected_indices) @@ -61,7 +49,7 @@ def test_fktable_xgrid_indices_with_tolerance(): # Due to tolerance, the indices should match as if they were the same expected_indices = jnp.array([1, 3, 5]) - result = fktable_xgrid_indices(fktable, FIT_XGRID, fill_fk_xgrid_with_zeros=False) + result = fktable_xgrid_indices(fktable, FIT_XGRID) assert jnp.array_equal(result, expected_indices) @@ -74,7 +62,7 @@ def test_fktable_xgrid_indices_no_matches(): expected_indices = jnp.array( [] ) # No matching indices, closest_indices returns empty array - result = fktable_xgrid_indices(fktable, FIT_XGRID, fill_fk_xgrid_with_zeros=False) + result = fktable_xgrid_indices(fktable, FIT_XGRID) assert jnp.array_equal(result, expected_indices) @@ -122,6 +110,49 @@ def test_fast_kernel_arrays(): assert jnp.any(fk_arrays_filled[0][0][:, :, non_zero_indices] != 0) +def test_fast_kernel_arrays_hadronic_fill_with_zeros(): + """ + Test that fast_kernel_arrays correctly fills the x-grid with zeros for hadronic FK tables. + This is a regression test for the bug where the 4D hadronic array was assigned + into a 3D zeros array. + """ + from colibri.utils import closest_indices + from validphys.fkparser import load_fktable + + dataset = colibriAPI.data(**TEST_DATASETS_HAD) + ds = dataset.datasets[0] + FIT_XGRID = colibriAPI.FIT_XGRID(**TEST_DATASETS_HAD) + + # This should not raise an error (regression check) + fk_arrays_filled = colibriAPI.fast_kernel_arrays( + **{**TEST_DATASETS_HAD, "fill_fk_xgrid_with_zeros": True} + ) + + fk_arr = fk_arrays_filled[0][0] + + # Hadronic FK array should be 4D: (Ndat, Nfl, Nfit_x, Nfit_x) + assert fk_arr.ndim == 4 + assert fk_arr.shape[2] == len(FIT_XGRID) + assert fk_arr.shape[3] == len(FIT_XGRID) + + # Check that non-zero values are placed at the correct x-grid positions + fk_xgrid = load_fktable(ds.fkspecs[0]).xgrid + non_zero_indices = closest_indices(FIT_XGRID, fk_xgrid, atol=1e-8) + non_zero_indices = np.array(non_zero_indices) + + # The non-zero block should contain non-zero values + assert jnp.any( + fk_arr[:, :, non_zero_indices[:, None], non_zero_indices[None, :]] != 0 + ) + + # Entries outside the non-zero block should be zero + all_indices = np.arange(len(FIT_XGRID)) + zero_indices = np.setdiff1d(all_indices, non_zero_indices) + if len(zero_indices) > 0: + assert jnp.all(fk_arr[:, :, zero_indices, :] == 0) + assert jnp.all(fk_arr[:, :, :, zero_indices] == 0) + + def test_make_dis_prediction(): """ Test make_dis_prediction function gives the same results diff --git a/colibri/theory_penalties.py b/colibri/theory_penalties.py index 474a59cd..7e8e4752 100644 --- a/colibri/theory_penalties.py +++ b/colibri/theory_penalties.py @@ -80,7 +80,9 @@ def make_penalty_posdataset(posdataset, FIT_XGRID, flavour_indices=None): """ pred_funcs = pred_funcs_from_dataset( - posdataset, FIT_XGRID, flavour_indices, fill_fk_xgrid_with_zeros=False + posdataset, + FIT_XGRID, + flavour_indices, ) def pos_penalty(pdf, alpha, lambda_positivity, fk_dataset): diff --git a/colibri/theory_predictions.py b/colibri/theory_predictions.py index 01a7d200..64bdbbe4 100644 --- a/colibri/theory_predictions.py +++ b/colibri/theory_predictions.py @@ -17,15 +17,11 @@ OP = {key: jax.jit(val) for key, val in convolution.OP.items()} -def fktable_xgrid_indices(fktable, FIT_XGRID, fill_fk_xgrid_with_zeros=False): +def fktable_xgrid_indices(fktable, FIT_XGRID): """ Given an FKTableData instance and the xgrid used in the fit returns the indices of the xgrid of the FK table in the xgrid of the fit. - If fill_fk_xgrid_with_zeros is True, then the all indices of the fit xgrid - are returned. This is useful when the FK table is needed as tensor - of shape (Ndat, Nfl, Nfk_xgrid) with Nfk_xgrid and Nfl fixed for all datasets. - Parameters ---------- fktable : validphys.coredata.FKTableData @@ -33,16 +29,11 @@ def fktable_xgrid_indices(fktable, FIT_XGRID, fill_fk_xgrid_with_zeros=False): FIT_XGRID: jnp.ndarray array of xgrid points of the theory entering the fit - fill_fk_xgrid_with_zeros: bool, default is False - Returns ------- jnp.ndarray Indices mapping FK x-grid into fit x-grid. """ - if fill_fk_xgrid_with_zeros: - return jnp.arange(len(FIT_XGRID)) - # Extract xgrid of the FK table and find the indices fk_xgrid = fktable.xgrid # atol is chosen to be 1e-8 as this is the order of magnitude of the difference between the smallest entries of the XGRID @@ -52,7 +43,7 @@ def fktable_xgrid_indices(fktable, FIT_XGRID, fill_fk_xgrid_with_zeros=False): def fast_kernel_arrays( - data, FIT_XGRID, flavour_indices=None, fill_fk_xgrid_with_zeros=False + data, FIT_XGRID, flavour_indices=None, fill_fk_xgrid_with_zeros=True ): """ Returns a tuple of tuples of jax.numpy arrays. @@ -96,10 +87,23 @@ def fast_kernel_arrays( # fill with zeros the Xgrid dimension of the FK table so as to have tensor of shape (Ndat, Nfl, Nfk_xgrid) fk_xgrid = fk.xgrid non_zero_indices = closest_indices(FIT_XGRID, fk_xgrid, atol=1e-8) - new_fk_arr = np.zeros( - (fk_arr.shape[0], fk_arr.shape[1], len(FIT_XGRID)) - ) - new_fk_arr[:, :, non_zero_indices] = fk_arr + if fk.hadronic: + new_fk_arr = np.zeros( + ( + fk_arr.shape[0], + fk_arr.shape[1], + len(FIT_XGRID), + len(FIT_XGRID), + ) + ) + new_fk_arr[ + :, :, non_zero_indices[:, None], non_zero_indices[None, :] + ] = fk_arr + else: + new_fk_arr = np.zeros( + (fk_arr.shape[0], fk_arr.shape[1], len(FIT_XGRID)) + ) + new_fk_arr[:, :, non_zero_indices] = fk_arr fk_arr = jnp.array(new_fk_arr) fk_dataset_arr.append(fk_arr) @@ -109,7 +113,9 @@ def fast_kernel_arrays( def make_dis_prediction( - fktable, FIT_XGRID, flavour_indices=None, fill_fk_xgrid_with_zeros=False + fktable, + FIT_XGRID, + flavour_indices=None, ): """ Closure to compute the theory prediction for a DIS observable. @@ -126,21 +132,13 @@ def make_dis_prediction( flavour_indices: list, default is None - fill_fk_xgrid_with_zeros: bool, default is False - If True, then the missing xgrid points in the FK table - will be filled with zeros. This is useful when the FK table - is needed as tensor of shape (Ndat, Nfl, Nfk_xgrid) with Nfk_xgrid and Nfl fixed - for all datasets. - Returns ------- Callable """ lumi_indices = mask_luminosity_mapping(fktable, flavour_indices) - fk_xgrid_indices = fktable_xgrid_indices( - fktable, FIT_XGRID, fill_fk_xgrid_with_zeros=fill_fk_xgrid_with_zeros - ) + fk_xgrid_indices = fktable_xgrid_indices(fktable, FIT_XGRID) def dis_prediction(pdf, fk_arr): """ @@ -166,15 +164,20 @@ def dis_prediction(pdf, fk_arr): jnp.ndarray theory prediction for a hadronic observable (shape is Ndata, ) """ + # NOTE: for computational efficiency, in the convolution, we only sum over xgrid points that are non-zero in the FK table. return jnp.einsum( - "ijk, jk ->i", fk_arr, pdf[lumi_indices, :][:, fk_xgrid_indices] + "ijk, jk ->i", + fk_arr[:, :, fk_xgrid_indices], + pdf[lumi_indices, :][:, fk_xgrid_indices], ) return dis_prediction def make_had_prediction( - fktable, FIT_XGRID, flavour_indices=None, fill_fk_xgrid_with_zeros=False + fktable, + FIT_XGRID, + flavour_indices=None, ): """ Closure to compute the theory prediction for a Hadronic observable. @@ -189,12 +192,6 @@ def make_had_prediction( flavour_indices: list, default is None - fill_fk_xgrid_with_zeros: bool, default is False - If True, then the missing xgrid points in the FK table - will be filled with zeros. This is useful when the FK table - is needed as tensor of shape (Ndat, Nfl, Nfk_xgrid) with Nfk_xgrid and Nfl fixed - for all datasets. - Returns ------- Callable @@ -203,9 +200,7 @@ def make_had_prediction( first_lumi_indices = lumi_indices[0::2] second_lumi_indices = lumi_indices[1::2] - fk_xgrid_indices = fktable_xgrid_indices( - fktable, FIT_XGRID, fill_fk_xgrid_with_zeros=fill_fk_xgrid_with_zeros - ) + fk_xgrid_indices = fktable_xgrid_indices(fktable, FIT_XGRID) def had_prediction(pdf, fk_arr): """ @@ -231,9 +226,10 @@ def had_prediction(pdf, fk_arr): jnp.ndarray theory prediction for a hadronic observable (shape is Ndata, ) """ + # NOTE: for computational efficiency, in the convolution, we only sum over xgrid points that are non-zero in the FK table. return jnp.einsum( "ijkl,jk,jl->i", - fk_arr, + fk_arr[:, :, fk_xgrid_indices[:, None], fk_xgrid_indices[None, :]], pdf[first_lumi_indices, :][:, fk_xgrid_indices], pdf[second_lumi_indices, :][:, fk_xgrid_indices], ) @@ -242,7 +238,9 @@ def had_prediction(pdf, fk_arr): def pred_funcs_from_dataset( - dataset, FIT_XGRID, flavour_indices, fill_fk_xgrid_with_zeros=False + dataset, + FIT_XGRID, + flavour_indices, ): """ Returns a list containing the forward maps associated with the fkspecs of a dataset. @@ -255,8 +253,6 @@ def pred_funcs_from_dataset( flavour_indices: list, default is None - fill_fk_xgrid_with_zeros: bool, default is False - Returns ------- list @@ -268,20 +264,18 @@ def pred_funcs_from_dataset( fk = load_fktable(fkspec).with_cuts(dataset.cuts) if fk.hadronic: - pred = make_had_prediction( - fk, FIT_XGRID, flavour_indices, fill_fk_xgrid_with_zeros - ) + pred = make_had_prediction(fk, FIT_XGRID, flavour_indices) else: - pred = make_dis_prediction( - fk, FIT_XGRID, flavour_indices, fill_fk_xgrid_with_zeros - ) + pred = make_dis_prediction(fk, FIT_XGRID, flavour_indices) pred_funcs.append(pred) return pred_funcs def make_pred_dataset( - dataset, FIT_XGRID, flavour_indices=None, fill_fk_xgrid_with_zeros=False + dataset, + FIT_XGRID, + flavour_indices=None, ): """ Compute theory prediction for a DataSetSpec @@ -296,16 +290,12 @@ def make_pred_dataset( flavour_indices: list, default is None - fill_fk_xgrid_with_zeros: bool, default is False - Returns ------- Callable """ - pred_funcs = pred_funcs_from_dataset( - dataset, FIT_XGRID, flavour_indices, fill_fk_xgrid_with_zeros - ) + pred_funcs = pred_funcs_from_dataset(dataset, FIT_XGRID, flavour_indices) def prediction(pdf, fk_dataset): return OP[dataset.op]( @@ -316,7 +306,9 @@ def prediction(pdf, fk_dataset): def make_pred_data( - data, FIT_XGRID, flavour_indices=None, fill_fk_xgrid_with_zeros=False + data, + FIT_XGRID, + flavour_indices=None, ): """ Compute theory prediction for entire DataGroupSpec @@ -331,8 +323,6 @@ def make_pred_data( flavour_indices: list, default is None - fill_fk_xgrid_with_zeros: bool, default is False - Returns ------- Callable @@ -346,7 +336,6 @@ def make_pred_data( ds, FIT_XGRID, flavour_indices, - fill_fk_xgrid_with_zeros=fill_fk_xgrid_with_zeros, ) ) @@ -363,7 +352,9 @@ def eval_preds(pdf, fast_kernel_arrays): def make_pred_t0data( - data, FIT_XGRID, flavour_indices=None, fill_fk_xgrid_with_zeros=False + data, + FIT_XGRID, + flavour_indices=None, ): """ Compute theory prediction for entire DataGroupSpec. @@ -380,8 +371,6 @@ def make_pred_t0data( flavour_indices: list, default is None - fill_fk_xgrid_with_zeros: bool, default is False - Returns ------- Callable @@ -395,7 +384,6 @@ def make_pred_t0data( ds, FIT_XGRID, flavour_indices=flavour_indices, - fill_fk_xgrid_with_zeros=fill_fk_xgrid_with_zeros, ) )