diff --git a/colibri/analytic_fit.py b/colibri/analytic_fit.py index 895befac7..af04ca284 100644 --- a/colibri/analytic_fit.py +++ b/colibri/analytic_fit.py @@ -84,7 +84,7 @@ def analytic_evidence_uniform_prior(sol_covmat, sol_mean, max_logl, a_vec, b_vec @check_pdf_model_is_linear def analytic_fit( - central_covmat_index, + central_sqrt_covmat_index, _pred_data, pdf_model, analytic_settings, @@ -103,8 +103,8 @@ def analytic_fit( Parameters ---------- - central_covmat_index: commondata_utils.CentralCovmatIndex - dataclass containing central values and covariance matrix. + central_sqrt_covmat_index: commondata_utils.CentralSqrtCovmatIndex + dataclass containing central values and square root of the covariance matrix. _pred_data: @jax.jit CompiledFunction Prediction function for the fit. @@ -142,8 +142,8 @@ def analytic_fit( intercept = pred_and_pdf(jnp.zeros(len(parameters)), fast_kernel_arrays)[0] # Construct the analytic solution - central_values = central_covmat_index.central_values - covmat = central_covmat_index.covmat + central_values = central_sqrt_covmat_index.central_values + sqrt_covmat = central_sqrt_covmat_index.sqrt_covmat # Solve chi2 analytically for the mean Y = central_values - intercept @@ -151,13 +151,9 @@ def analytic_fit( t0 = time.time() - # Cholesky factorization: S = L L^T - # upper False means that we want the lower triangular matrix L - L = jla.cholesky(covmat, upper=False) - - # Whiten the problem: Y' = L^-1 Y, X' = L^-1 X - Y_tilde = jlinalg.triangular_solve(L, Y, left_side=True, lower=True) - X_tilde = jlinalg.triangular_solve(L, X, left_side=True, lower=True) + # Whiten the problem: Y' = sqrt_covmat^-1 Y, X' = sqrt_covmat^-1 X + Y_tilde = jlinalg.triangular_solve(sqrt_covmat, Y, left_side=True, lower=True) + X_tilde = jlinalg.triangular_solve(sqrt_covmat, X, left_side=True, lower=True) if jnp.any(jla.eigh(X_tilde.T @ X_tilde)[0] <= 0.0): raise ValueError( @@ -257,7 +253,7 @@ def analytic_fit( min_chi2 = -2 * max_logl log.info(f"Minimum chi2 = {min_chi2}") - BIC = min_chi2 + sol_covmat.shape[0] * np.log(covmat.shape[0]) + BIC = min_chi2 + sol_covmat.shape[0] * np.log(sqrt_covmat.shape[0]) AIC = min_chi2 + 2 * sol_covmat.shape[0] # Compute average chi2 (in whitened basis) diff --git a/colibri/commondata_utils.py b/colibri/commondata_utils.py index cdbe3ba37..8de7cd5af 100644 --- a/colibri/commondata_utils.py +++ b/colibri/commondata_utils.py @@ -10,7 +10,7 @@ import jax.numpy as jnp from colibri.theory_predictions import make_pred_dataset -from colibri.core import CentralCovmatIndex +from colibri.core import CentralSqrtCovmatIndex def experimental_commondata_tuple(data): @@ -101,21 +101,21 @@ def level_0_commondata_tuple( def level_1_commondata_tuple( level_0_commondata_tuple, - data_generation_covariance_matrix, + general_covariance_matrix, level_1_seed=123456, ): """ Returns a tuple (validphys nodes should be immutable) of level 1 commondata instances. Noise is added to the level_0_commondata_tuple central values - according to a multivariate Gaussian with covariance data_generation_covariance_matrix + according to a multivariate Gaussian with covariance general_covariance_matrix Parameters ---------- level_0_commondata_tuple: tuple of nnpdf_data.coredata.CommonData instances A tuple of level_0 closure test data. - data_generation_covariance_matrix: jnp.ndarray + general_covariance_matrix: jnp.ndarray The covariance matrix used for data generation. level_1_seed: int @@ -133,11 +133,11 @@ def level_1_commondata_tuple( ) # Now, sample from the multivariate Gaussian with central values central_values - # and covariance matrix data_generation_covariance_matrix. This produces the + # and general_covariance_matrix. This produces the # level_1 data. rng = jax.random.PRNGKey(level_1_seed) sample = jax.random.multivariate_normal( - rng, central_values, data_generation_covariance_matrix + rng, central_values, general_covariance_matrix ) # Now, reconstruct the commondata tuple, by modifying the original commondata @@ -150,11 +150,11 @@ def level_1_commondata_tuple( return tuple(sample_list) -def central_covmat_index(commondata_tuple, fit_covariance_matrix): +def central_sqrt_covmat_index(commondata_tuple, general_sqrt_covariance_matrix): """ - Given a commondata_tuple and a covariance_matrix, generated + Given a commondata_tuple and a general_sqrt_covariance_matrix, generated according to respective explicit node in config.py, store - relevant data into CentralCovmatIndex dataclass. + relevant data into CentralSqrtCovmatIndex dataclass. Parameters ---------- @@ -163,15 +163,14 @@ def central_covmat_index(commondata_tuple, fit_covariance_matrix): (see config.produce_commondata_tuple) and accordingly to the specified options. - fit_covariance_matrix: jnp.ndarray - covariance matrix, is generated as explicit node - (see config.fit_covariance_matrix) can be either experimental + general_sqrt_covariance_matrix: jnp.ndarray + square root of the covariance matrix which can be either experimental or t0 covariance matrix depending on whether `use_fit_t0` is - True or False + True or False. Returns ------- - CentralCovmatIndex + CentralSqrtCovmatIndex Dataclass containing central values, covariance matrix and index of central values. """ @@ -180,17 +179,8 @@ def central_covmat_index(commondata_tuple, fit_covariance_matrix): ) central_values_idx = jnp.arange(central_values.shape[0]) - return CentralCovmatIndex( + return CentralSqrtCovmatIndex( central_values=central_values, central_values_idx=central_values_idx, - covmat=fit_covariance_matrix, + sqrt_covmat=general_sqrt_covariance_matrix, ) - - -def pseudodata_central_covmat_index( - commondata_tuple, data_generation_covariance_matrix -): - """Same as central_covmat_index, but with the pseudodata generation - covariance matrix for a Monte Carlo fit. - """ - return central_covmat_index(commondata_tuple, data_generation_covariance_matrix) diff --git a/colibri/config.py b/colibri/config.py index 0802f2dfe..e27059c16 100644 --- a/colibri/config.py +++ b/colibri/config.py @@ -607,7 +607,7 @@ def produce_commondata_tuple(self, closure_test_level=False): ) @explicit_node - def produce_fit_covariance_matrix(self, use_fit_t0: bool = True): + def produce_general_covariance_matrix(self, use_fit_t0: bool = True): """ Produces the covariance matrix used in the fit. This covariance matrix is used in: @@ -619,19 +619,6 @@ def produce_fit_covariance_matrix(self, use_fit_t0: bool = True): else: return colibri_covmats.dataset_inputs_covmat_from_systematics - @explicit_node - def produce_data_generation_covariance_matrix(self, use_gen_t0: bool = False): - """Produces the covariance matrix used in: - - level 1 closure test data construction (fluctuating around the level - 0 data) - - Monte Carlo pseudodata (fluctuating either around the level 0 data or - level 1 data) - """ - if use_gen_t0: - return colibri_covmats.dataset_inputs_t0_covmat_from_systematics - else: - return colibri_covmats.dataset_inputs_covmat_from_systematics - def parse_closure_test_pdf(self, name): """PDF set used to generate fakedata""" if name == "colibri_model": diff --git a/colibri/core.py b/colibri/core.py index 238b947e2..5dd7c2f91 100644 --- a/colibri/core.py +++ b/colibri/core.py @@ -256,9 +256,9 @@ class PosdataTrainValidationSplit(TrainValidationSplit): @dataclass(frozen=True) -class CentralCovmatIndex: +class CentralSqrtCovmatIndex: central_values: jnp.array - covmat: jnp.array + sqrt_covmat: jnp.array central_values_idx: jnp.array def to_dict(self): diff --git a/colibri/covmats.py b/colibri/covmats.py index bacc1dba7..343fd4a6d 100644 --- a/colibri/covmats.py +++ b/colibri/covmats.py @@ -14,17 +14,19 @@ from validphys import covmats -def sqrt_covmat_jax(covariance_matrix): +def general_sqrt_covariance_matrix(general_covariance_matrix): """ Same as `validphys.covmats.sqrt_covmat` but for jax.numpy arrays Parameters ---------- - covariance_matrix : jnp.ndarray + general_covariance_matrix : jnp.ndarray A positive definite covariance matrix, which is N_dat x N_dat (where N_dat is the number of data points after cuts) containing uncertainty and correlation information. + NOTE: for more details on what covariance matrix is used, see the production rule in `config.py` + for the options of covariance matrix. Returns ------- @@ -35,9 +37,9 @@ def sqrt_covmat_jax(covariance_matrix): ``jnp.allclose(sqrt_covmat @ sqrt_covmat.T, covariance_matrix)``. """ - dimensions = covariance_matrix.shape + dimensions = general_covariance_matrix.shape - if covariance_matrix.size == 0: + if general_covariance_matrix.size == 0: raise ValueError("Attempting the decomposition of an empty matrix.") elif dimensions[0] != dimensions[1]: raise ValueError( @@ -46,8 +48,11 @@ def sqrt_covmat_jax(covariance_matrix): f"{dimensions[1]}" ) - sqrt_diags = jnp.sqrt(jnp.diag(covariance_matrix)) - correlation_matrix = covariance_matrix / sqrt_diags[:, jnp.newaxis] / sqrt_diags + sqrt_diags = jnp.sqrt(jnp.diag(general_covariance_matrix)) + correlation_matrix = ( + general_covariance_matrix / sqrt_diags[:, jnp.newaxis] / sqrt_diags + ) + # NOTE: scipy.linalg.cholesky returns the upper triangular decomposition by default decomp = jla.cholesky(correlation_matrix) sqrt_matrix = (decomp * sqrt_diags).T return sqrt_matrix diff --git a/colibri/data_batch.py b/colibri/data_batch.py index 6ba2e3816..903991286 100644 --- a/colibri/data_batch.py +++ b/colibri/data_batch.py @@ -18,7 +18,7 @@ def data_batches( training_indices, batch_size=None, batch_seed=1, - fit_covariance_matrix=None, + general_covariance_matrix=None, shuffle_each_epoch=False, ) -> DataBatches: """ @@ -31,7 +31,7 @@ def data_batches( batch_seed: int, default is 1 - fit_covariance_matrix: jax.Array, optional + general_covariance_matrix: jax.Array, optional If provided together with shuffle_each_epoch=False, fixed batches are precomputed once and the corresponding inverse covariance submatrices are cached for reuse. This avoids inverting within the likelihood at @@ -79,8 +79,10 @@ def _slice_batches_from_perm(perm: jax.Array) -> List[jax.Array]: perm0 = _make_perm(key) fixed_batches = _slice_batches_from_perm(perm0) - if fit_covariance_matrix is not None: - train_covmat = fit_covariance_matrix[training_indices][:, training_indices] + if general_covariance_matrix is not None: + train_covmat = general_covariance_matrix[training_indices][ + :, training_indices + ] fixed_batches_specs = [ BatchSpec( idx=b, diff --git a/colibri/likelihood.py b/colibri/likelihood.py index 586a8c878..722be4e3f 100644 --- a/colibri/likelihood.py +++ b/colibri/likelihood.py @@ -9,7 +9,7 @@ import jax import jax.numpy as jnp from colibri.loss_functions import chi2 -from colibri.commondata_utils import CentralCovmatIndex +from colibri.commondata_utils import CentralSqrtCovmatIndex from colibri.data_batch import BatchSpec @@ -21,7 +21,7 @@ class LogLikelihood(object): def __init__( self, - central_covmat_index, + central_sqrt_covmat_index, pdf_model, fit_xgrid, forward_map, @@ -34,7 +34,7 @@ def __init__( """ Parameters ---------- - central_covmat_index: commondata_utils.CentralCovmatIndex + central_sqrt_covmat_index: commondata_utils.CentralSqrtCovmatIndex pdf_model: pdf_model.PDFModel @@ -53,10 +53,13 @@ def __init__( integrability_penalty: Callable """ - self.central_values = central_covmat_index.central_values - self.covmat = central_covmat_index.covmat - self.inv_covmat = jnp.linalg.inv(self.covmat) - self.central_values_idx = central_covmat_index.central_values_idx + self.central_values = central_sqrt_covmat_index.central_values + self.sqrt_covmat = central_sqrt_covmat_index.sqrt_covmat + # NOTE: we zero the upper triangle so XLA sees only n²/2 nonzeros and can optimize accordingly. + self.inv_sqrt_covmat = jnp.tril( + jnp.linalg.inv(central_sqrt_covmat_index.sqrt_covmat) + ) + self.central_values_idx = central_sqrt_covmat_index.central_values_idx self.pdf_model = pdf_model self.penalty_posdata = penalty_posdata self.positivity_penalty_settings = positivity_penalty_settings @@ -93,7 +96,7 @@ def __call__(self, params, batch: BatchSpec | None = None): return self.log_likelihood( params, self.central_values, - self.inv_covmat, + self.inv_sqrt_covmat, self.fast_kernel_arrays, self.positivity_fast_kernel_arrays, batch=batch, @@ -104,7 +107,7 @@ def log_likelihood( self, params: jnp.ndarray, central_values: jnp.ndarray, - inv_covmat: jnp.ndarray, + inv_sqrt_covmat: jnp.ndarray, fast_kernel_arrays: tuple, positivity_fast_kernel_arrays: tuple, batch: BatchSpec | None = None, @@ -117,7 +120,7 @@ def log_likelihood( ---------- params: jnp.ndarray central_values: jnp.ndarray - inv_covmat: jnp.ndarray + inv_sqrt_covmat: jnp.ndarray fast_kernel_arrays: tuple positivity_fast_kernel_arrays: tuple @@ -131,6 +134,7 @@ def log_likelihood( # Especially important when using a training/validation split predictions = predictions[self.central_values_idx] + # TODO: the code inside this if condition needs to be changed since we now have only sqrt_covmat. if batch is not None: predictions = predictions[batch.idx] central_values = central_values[batch.idx] @@ -161,12 +165,14 @@ def log_likelihood( ) return -0.5 * ( - chi2(central_values, predictions, inv_covmat) + pos_penalty + integ_penalty + chi2(central_values, predictions, inv_sqrt_covmat) + + pos_penalty + + integ_penalty ) def log_likelihood( - central_covmat_index, + central_sqrt_covmat_index, pdf_model, FIT_XGRID, _pred_data, @@ -183,7 +189,7 @@ def log_likelihood( model specific applications by changing the log_likelihood method of the LogLikelihood class. """ return LogLikelihood( - central_covmat_index, + central_sqrt_covmat_index, pdf_model, FIT_XGRID, _pred_data, @@ -197,7 +203,7 @@ def log_likelihood( def mc_log_likelihood( mc_pseudodata, - fit_covariance_matrix, + general_covariance_matrix, pdf_model, FIT_XGRID, _pred_data, @@ -216,11 +222,11 @@ def mc_log_likelihood( tr_idx = mc_pseudodata.training_indices central_values_train = mc_pseudodata.pseudodata[tr_idx] - covmat_train = fit_covariance_matrix[tr_idx][:, tr_idx] + covmat_train = general_covariance_matrix[tr_idx][:, tr_idx] - central_covmat_index_train = CentralCovmatIndex( + central_covmat_index_train = CentralSqrtCovmatIndex( central_values=central_values_train, - covmat=covmat_train, + sqrt_covmat=covmat_train, central_values_idx=tr_idx, ) @@ -242,11 +248,11 @@ def mc_log_likelihood( else: val_idx = mc_pseudodata.validation_indices central_values_val = mc_pseudodata.pseudodata[val_idx] - covmat_val = fit_covariance_matrix[val_idx][:, val_idx] + covmat_val = general_covariance_matrix[val_idx][:, val_idx] - central_covmat_index_val = CentralCovmatIndex( + central_covmat_index_val = CentralSqrtCovmatIndex( central_values=central_values_val, - covmat=covmat_val, + sqrt_covmat=covmat_val, central_values_idx=val_idx, ) diff --git a/colibri/loss_functions.py b/colibri/loss_functions.py index 3b816c6c9..a5a04f438 100644 --- a/colibri/loss_functions.py +++ b/colibri/loss_functions.py @@ -7,7 +7,7 @@ import jax.numpy as jnp -def chi2(central_values, predictions, inv_covmat): +def chi2(central_values, predictions, inv_sqrt_covmat): """ Compute the chi2 loss. @@ -19,8 +19,8 @@ def chi2(central_values, predictions, inv_covmat): predictions: jnp.ndarray The predictions of the model. - inv_covmat: jnp.ndarray - The inverse of the covariance matrix. + inv_sqrt_covmat: jnp.ndarray + The inverse of the square root of the covariance matrix. Returns ------- @@ -29,6 +29,8 @@ def chi2(central_values, predictions, inv_covmat): """ diff = predictions - central_values - loss = jnp.einsum("i,ij,j", diff, inv_covmat, diff) + # whiten the diff + z = jnp.einsum("ij,j->i", inv_sqrt_covmat, diff) + loss = jnp.dot(z, z) return loss diff --git a/colibri/mc_utils.py b/colibri/mc_utils.py index 5dc9301bb..254769c91 100644 --- a/colibri/mc_utils.py +++ b/colibri/mc_utils.py @@ -22,7 +22,7 @@ def mc_pseudodata( - pseudodata_central_covmat_index, + central_covmat_index, replica_index, trval_seed, shuffle_indices=True, @@ -33,9 +33,9 @@ def mc_pseudodata( a fraction mc_validation_fraction of the data. """ - central_values = pseudodata_central_covmat_index.central_values - covmat = pseudodata_central_covmat_index.covmat - all_indices = pseudodata_central_covmat_index.central_values_idx + central_values = central_covmat_index.central_values + covmat = central_covmat_index.covmat + all_indices = central_covmat_index.central_values_idx # Generate pseudodata according to a multivariate Gaussian centred on # central_values and with covariance matrix covmat. diff --git a/colibri/tests/conftest.py b/colibri/tests/conftest.py index a75b094cc..90e6e36e7 100644 --- a/colibri/tests/conftest.py +++ b/colibri/tests/conftest.py @@ -340,12 +340,12 @@ def wmin_param(params): """ -MOCK_CENTRAL_COVMAT_INDEX = Mock() -MOCK_CENTRAL_COVMAT_INDEX.central_values = jnp.ones(TEST_N_DATA) -MOCK_CENTRAL_COVMAT_INDEX.covmat = jnp.eye(TEST_N_DATA) -MOCK_CENTRAL_COVMAT_INDEX.central_values_idx = jnp.arange(TEST_N_DATA) +MOCK_CENTRAL_SQRT_COVMAT_INDEX = Mock() +MOCK_CENTRAL_SQRT_COVMAT_INDEX.central_values = jnp.ones(TEST_N_DATA) +MOCK_CENTRAL_SQRT_COVMAT_INDEX.sqrt_covmat = jnp.eye(TEST_N_DATA) +MOCK_CENTRAL_SQRT_COVMAT_INDEX.central_values_idx = jnp.arange(TEST_N_DATA) """ -Mock instance of Central covmat index object. +Mock instance of Central sqrt covmat index object. """ diff --git a/colibri/tests/test_analytic_fit.py b/colibri/tests/test_analytic_fit.py index 4067cdadb..afa6c14e6 100644 --- a/colibri/tests/test_analytic_fit.py +++ b/colibri/tests/test_analytic_fit.py @@ -14,7 +14,7 @@ from colibri.analytic_fit import AnalyticFit, analytic_fit, run_analytic_fit from colibri.core import PriorSettings from colibri.tests.conftest import ( - MOCK_CENTRAL_COVMAT_INDEX, + MOCK_CENTRAL_SQRT_COVMAT_INDEX, MOCK_PDF_MODEL, TEST_FK_ARRAYS, TEST_FORWARD_MAP_DIS, @@ -46,7 +46,7 @@ def test_analytic_fit_flat_direction(): with pytest.raises(ValueError): # Run the analytic fit and make sure that the Value Error is raised analytic_fit( - MOCK_CENTRAL_COVMAT_INDEX, + MOCK_CENTRAL_SQRT_COVMAT_INDEX, _pred_data, MOCK_PDF_MODEL, analytic_settings, @@ -69,7 +69,7 @@ def test_analytic_fit(caplog): # Run the analytic fit result = analytic_fit( - MOCK_CENTRAL_COVMAT_INDEX, + MOCK_CENTRAL_SQRT_COVMAT_INDEX, _pred_data, MOCK_PDF_MODEL, analytic_settings, @@ -91,7 +91,7 @@ def test_analytic_fit(caplog): # Run the analytic fit with caplog.at_level(logging.ERROR): # Set the log level to ERROR result_2 = analytic_fit( - MOCK_CENTRAL_COVMAT_INDEX, + MOCK_CENTRAL_SQRT_COVMAT_INDEX, _pred_data, MOCK_PDF_MODEL, analytic_settings, @@ -129,7 +129,7 @@ def test_analytic_fit_different_priors(caplog): # Run the analytic fit result = analytic_fit( - MOCK_CENTRAL_COVMAT_INDEX, + MOCK_CENTRAL_SQRT_COVMAT_INDEX, _pred_data, MOCK_PDF_MODEL, analytic_settings, @@ -155,7 +155,7 @@ def test_analytic_fit_different_priors(caplog): # Run the analytic fit with custom uniform prior result = analytic_fit( - MOCK_CENTRAL_COVMAT_INDEX, + MOCK_CENTRAL_SQRT_COVMAT_INDEX, _pred_data, MOCK_PDF_MODEL, analytic_settings, diff --git a/colibri/tests/test_blackjax_fit.py b/colibri/tests/test_blackjax_fit.py index c8845666e..17f80d1eb 100644 --- a/colibri/tests/test_blackjax_fit.py +++ b/colibri/tests/test_blackjax_fit.py @@ -11,7 +11,7 @@ import types from colibri.tests.conftest import ( - MOCK_CENTRAL_COVMAT_INDEX, + MOCK_CENTRAL_SQRT_COVMAT_INDEX, MOCK_PDF_MODEL, MOCK_PENALTY_POSDATA, TEST_FK_ARRAYS, @@ -63,7 +63,7 @@ def mock_sample(rng_key, n_samples): def test_blackjax_fit(pos_penalty): _pred_data = lambda *args: jnp.array([0.0]) mock_log_likelihood = LogLikelihood( - MOCK_CENTRAL_COVMAT_INDEX, + MOCK_CENTRAL_SQRT_COVMAT_INDEX, MOCK_PDF_MODEL, TEST_XGRID, _pred_data, diff --git a/colibri/tests/test_commondata/NMC_NC_NOTFIXED_P_EM-SIGMARED_level1_central_values.csv b/colibri/tests/test_commondata/NMC_NC_NOTFIXED_P_EM-SIGMARED_level1_central_values.csv index 8446ce4df..7115b3b94 100644 --- a/colibri/tests/test_commondata/NMC_NC_NOTFIXED_P_EM-SIGMARED_level1_central_values.csv +++ b/colibri/tests/test_commondata/NMC_NC_NOTFIXED_P_EM-SIGMARED_level1_central_values.csv @@ -1,205 +1,205 @@ entry,data -17,0.3603048338606682 -22,0.35818729276965333 -23,0.33935994020184174 -28,0.3664990774711061 -29,0.3467304060499512 -30,0.40521255214950547 -35,0.3531961258192715 -36,0.3637907395768924 -37,0.38575738790569325 -41,0.35227873193906156 -42,0.36252003504518365 -43,0.3440042704263882 -47,0.35837420663358976 -48,0.32985308561991883 -49,0.345677289038046 -52,0.31232801049115083 -53,0.35024401439955416 -54,0.33244422895894415 -55,0.29895920861396597 -58,0.3189413244213146 -59,0.31466869994618485 -60,0.31245746144905556 -61,0.34120274075180873 -64,0.2915009416514593 -65,0.30404012268294733 -66,0.2578634296440007 -69,0.24234003419324424 -70,0.2653291979914224 -84,0.37105586729678547 -85,0.39779266373407957 -88,0.36442618090479545 -89,0.3794579788223175 -90,0.3834055260582568 -92,0.3601622676555683 -93,0.37724336262119607 -94,0.379606220737562 -95,0.3801245043834281 -96,0.398594519453656 -98,0.3854347864792688 -99,0.37811069703826 -100,0.3859782785265421 -101,0.376311096471139 -102,0.4101422662963671 -105,0.3553078549871987 -106,0.3860653491117977 -107,0.3709452617986862 -108,0.3821451769966814 -109,0.33793883376015216 -111,0.37667152717341795 -112,0.3531980619266058 -113,0.3783190183905295 -114,0.3732906235757716 -115,0.3790778930161888 -116,0.37180625172238063 -117,0.3590059668172086 -118,0.35883576869375594 -119,0.35592019948730574 -120,0.3411703865914043 -121,0.38335872159805195 -122,0.3502076227157032 -123,0.33956145499586293 -124,0.3381559578280785 -125,0.32612959027712113 -126,0.29924397651055984 -127,0.33087547629341174 -128,0.3194599233652932 -129,0.30889576032317523 -130,0.29020273997835255 -131,0.2879337429078681 -132,0.26498927043902704 -133,0.24411640007257404 -134,0.24328513403322571 -135,0.21715722915230992 -137,0.1594824407227202 -138,0.15041563263063173 -148,0.37011668893344774 -149,0.36363643550279656 -153,0.40860815283656154 -154,0.3920906510982488 -155,0.4188731196613895 -158,0.3870779495512743 -159,0.3860255076362997 -160,0.4212783017274639 -161,0.4199918457339618 -162,0.3871457782043634 -163,0.40765240508644424 -164,0.4134364935358898 -165,0.4062927856002634 -166,0.40132106829281183 -167,0.3971068692049978 -168,0.35811149951166643 -169,0.38475442674287613 -170,0.38501063915201744 -171,0.3905130591845987 -172,0.41942545683917043 -173,0.4091550238163816 -174,0.4378801548580854 -175,0.378708380183995 -176,0.4042357675637269 -177,0.38818983767404247 -178,0.3792996039280643 -179,0.39029471156339696 -180,0.3281958831978928 -181,0.3714018207046274 -182,0.38154728406159377 -183,0.38739250557671373 -184,0.37011456266059983 -185,0.3843527462654067 -186,0.3939117448781657 -187,0.3787215693465303 -188,0.38371378599487277 -189,0.392780264860788 -190,0.34767773359773746 -191,0.3705212827548895 -192,0.37764046616851993 -193,0.3456333433975945 -194,0.3438700613715861 -195,0.36410550008028053 -196,0.3524225627725131 -197,0.37618521495538676 -198,0.3265095184938219 -199,0.33647065768693 -200,0.33880807082529196 -201,0.3292349379980065 -202,0.27607971209490717 -203,0.31327439121612466 -204,0.3043928847081169 -205,0.3027995826851733 -206,0.3058202259116505 -207,0.2551672805192405 -208,0.2348927655490289 -209,0.22632490219770265 -210,0.23667847041375836 -211,0.17151953920279323 -212,0.13902771457956487 -213,0.13268370797781665 -222,0.390713954868101 -223,0.3821041998509772 -226,0.3935734978046931 -227,0.4016382324509361 -228,0.3930417837676036 -230,0.39755547437723215 -231,0.4116832091916647 -232,0.4106980667369833 -233,0.39084285916200673 -234,0.4089064590428189 -235,0.38118990953810505 -236,0.40880898330165977 -237,0.41328555657638594 -238,0.3973715677385094 -239,0.4113250903927858 -240,0.4093674677004484 -241,0.39654523963911703 -242,0.3960657365605373 -243,0.39711580439609484 -244,0.40156577017408956 -245,0.39900448133883376 -246,0.40840287528500313 -247,0.4001621683293462 -248,0.39320020042015535 -249,0.3920786567613664 -250,0.3952672365105562 -251,0.40171538794233913 -252,0.3838344418537248 -253,0.3768840662135969 -254,0.39266898926306143 -255,0.40224720369051237 -256,0.3851695152121074 -257,0.38519556370407526 -258,0.3764714341280915 -259,0.35801365226337 -260,0.37241134938336184 -261,0.3708695597777215 -262,0.34867520304404953 -263,0.3776469278004451 -264,0.37602595220086626 -265,0.36090646084058364 -266,0.3570185007291533 -267,0.3530623280168961 -268,0.3431850049193758 -269,0.35110071604546356 -270,0.34337070849238216 -271,0.33459631137242307 -272,0.32452501070590306 -273,0.33493479661503534 -274,0.3249847685770449 -275,0.3132368775780197 -276,0.31909018586752824 -277,0.2934865491337534 -278,0.3144645575364671 -279,0.2886736820629583 -280,0.27178077177396465 -281,0.2791723533750581 -282,0.2757782194732628 -283,0.2358048758993952 -284,0.22835886898124533 -285,0.22388113083928893 -286,0.21580843815256529 -287,0.22498582081791113 -288,0.18446723149987512 -289,0.13591121287614563 -290,0.1205371025063535 -291,0.122473865408316 -292,0.1040850466795483 +17,0.3602960851260646 +22,0.3582420397001891 +23,0.33943116413896546 +28,0.36659581532552343 +29,0.34661495424133304 +30,0.4051189440922148 +35,0.3531665850969312 +36,0.36343827564358666 +37,0.38606914980645135 +41,0.3521409131153247 +42,0.3621596977535262 +43,0.34392845504377206 +47,0.3582250034548195 +48,0.32955846750525225 +49,0.345748311614541 +52,0.31260246279430925 +53,0.34958219017296754 +54,0.3325759666587536 +55,0.29992721243151854 +58,0.318417569556407 +59,0.3149477559767299 +60,0.31206293709830024 +61,0.34042285566169816 +64,0.29132802748438946 +65,0.30447730214599095 +66,0.2584225833573018 +69,0.242918480367915 +70,0.26500168686840736 +84,0.37090292857596235 +85,0.39757890003424556 +88,0.3641547662865116 +89,0.3790435299434382 +90,0.3832890336994821 +92,0.3598378494332204 +93,0.3769218997285954 +94,0.37940312354065786 +95,0.3798707613334809 +96,0.3986885214477179 +98,0.3850573213523201 +99,0.37783525127128365 +100,0.38572862264332425 +101,0.37602972429441567 +102,0.40985803278671173 +105,0.3550159092876564 +106,0.3857602061237348 +107,0.37061043775217845 +108,0.38181334117712395 +109,0.3375492837700707 +111,0.3763720710338359 +112,0.3528483894442309 +113,0.37795520119152826 +114,0.3729226710939394 +115,0.3786763133439441 +116,0.3716288651980735 +117,0.35877058270549167 +118,0.35854626691196717 +119,0.3556052909788879 +120,0.34082309109881925 +121,0.3829866258528797 +122,0.34999798122791115 +123,0.33930968310688264 +124,0.3378614352626538 +125,0.32579174425095153 +126,0.2989323354601122 +127,0.3306221444645402 +128,0.3191856065595237 +129,0.30860894217757034 +130,0.28991366900932963 +131,0.28763274077039613 +132,0.26468952108150845 +133,0.2438341891121763 +134,0.24301580348954296 +135,0.21689498865563167 +137,0.15921031835435062 +138,0.15020698370764388 +148,0.37008922759564256 +149,0.36366895076629935 +153,0.4086277756217774 +154,0.3921288650985351 +155,0.41881689283285567 +158,0.38704222853096176 +159,0.3860391871414111 +160,0.42110025926872835 +161,0.4201032250128759 +162,0.3875546100117665 +163,0.4077163705784447 +164,0.41327081685181577 +165,0.406406966652214 +166,0.4014169147482617 +167,0.39685104951075073 +168,0.3584908733573333 +169,0.3849033801175411 +170,0.384666469664671 +171,0.3904487271941932 +172,0.41904513817489136 +173,0.4091837130435501 +174,0.4380748076120528 +175,0.37829972842745746 +176,0.40441922185575974 +177,0.3881800677833913 +178,0.37887883414016016 +179,0.39032443260900895 +180,0.32842670827365167 +181,0.3715356687575716 +182,0.3814750118997806 +183,0.3873176183944553 +184,0.36928214594771946 +185,0.3844633300007523 +186,0.39360651347341596 +187,0.37857526350558685 +188,0.38375915771693103 +189,0.3927436069782699 +190,0.34741238056468676 +191,0.3712612054969631 +192,0.3775916908046756 +193,0.34563207165423 +194,0.34376008369560107 +195,0.36390639422732174 +196,0.3519477128866498 +197,0.37618288802534927 +198,0.32631359607988636 +199,0.3365701702718855 +200,0.338708118392744 +201,0.32893628193854807 +202,0.2762636569159889 +203,0.3132057840676821 +204,0.3043535885918375 +205,0.3028238698088419 +206,0.3048960653367345 +207,0.25547352949341706 +208,0.2349849658073128 +209,0.22651926124023591 +210,0.23626225630204126 +211,0.17155218289293417 +212,0.13908509834393842 +213,0.13299306328832683 +222,0.3907536667895795 +223,0.38223538609755314 +226,0.39370329506748425 +227,0.4017646775518575 +228,0.39306083937720365 +230,0.39752257751365794 +231,0.41173447554232645 +232,0.4107189474862238 +233,0.39113761313472406 +234,0.4095670722668964 +235,0.3815849402347026 +236,0.408872020320091 +237,0.41333534355822993 +238,0.39764216295856875 +239,0.41174247053866175 +240,0.4099149621897652 +241,0.3965618640741384 +242,0.3961288090834424 +243,0.39720443248836906 +244,0.40202390248972686 +245,0.39942767236933324 +246,0.409247583553478 +247,0.3999742201575351 +248,0.39328677218494984 +249,0.39253660500044496 +250,0.3957186524051729 +251,0.4022781001470672 +252,0.38515998161327764 +253,0.3763764469082246 +254,0.3932458866542779 +255,0.40255212118691835 +256,0.385501055207715 +257,0.3855819669277833 +258,0.3764825874710068 +259,0.3583850959811114 +260,0.37266379681244305 +261,0.3709910972431487 +262,0.34972062975843154 +263,0.37772341119534214 +264,0.376466250039088 +265,0.3610969807339979 +266,0.357342045684518 +267,0.35323388524511873 +268,0.3430568933316093 +269,0.3511065691769053 +270,0.343486274558213 +271,0.3349987004426417 +272,0.325212609892056 +273,0.33507811568994117 +274,0.3249408770860502 +275,0.31343871086515823 +276,0.31916881785804396 +277,0.2938015547122429 +278,0.314993391082701 +279,0.2888644440810636 +280,0.2717597676716878 +281,0.27944853327028596 +282,0.27606734233838753 +283,0.23489631678081063 +284,0.2285871789951518 +285,0.22397228537030542 +286,0.21594442527795765 +287,0.22528461402526953 +288,0.18499516950155376 +289,0.13618537168370887 +290,0.1207120565747806 +291,0.12253291459260823 +292,0.1043397953904383 diff --git a/colibri/tests/test_commondata/NMC_level2_central_values.csv b/colibri/tests/test_commondata/NMC_level2_central_values.csv index 01b1f2c0f..112aa8f77 100644 --- a/colibri/tests/test_commondata/NMC_level2_central_values.csv +++ b/colibri/tests/test_commondata/NMC_level2_central_values.csv @@ -1,205 +1,205 @@ ,cv -0,0.33192269298907395 -1,0.3747305469911535 -2,0.37637134153620494 -3,0.38699236624058664 -4,0.34914377135484564 -5,0.3935204018600434 -6,0.3726200423009672 -7,0.3429039535841628 -8,0.3939042148903802 -9,0.35071020371245787 -10,0.33815639649667256 -11,0.34342782808380284 -12,0.36283973499228694 -13,0.36098517712087697 -14,0.38886790874822547 -15,0.36190903665049085 -16,0.33871156899504795 -17,0.3503197974335905 -18,0.43769773484445057 -19,0.32232321679425435 -20,0.3302518920612063 -21,0.29763073499848763 -22,0.33049175606082404 -23,0.3075954061413668 -24,0.34876348540514285 -25,0.3006334842455802 -26,0.29090055928564723 -27,0.23629848663748892 -28,0.37730754100969 -29,0.3940793229403061 -30,0.36535699625193474 -31,0.3737916906084377 -32,0.41022428767157265 -33,0.34009366729122203 -34,0.3910760251255194 -35,0.39967293306717494 -36,0.3860945191549688 -37,0.42774382868761046 -38,0.34916621917410917 -39,0.398521408989425 -40,0.3804318994289178 -41,0.36782932114681804 -42,0.4055049872514805 -43,0.3890043289836042 -44,0.37922609217557773 -45,0.37394954878687353 -46,0.4708371315672829 -47,0.33914260572297517 -48,0.34576763014972833 -49,0.3273935916803767 -50,0.3362257777590238 -51,0.39126959353267043 -52,0.3981271353934514 -53,0.29276281191263337 -54,0.3461954517474651 -55,0.34049197550185595 -56,0.3579512898222646 -57,0.3333449994809165 -58,0.3885542451262865 -59,0.32142881867397183 -60,0.3056797116232881 -61,0.3435453049595251 -62,0.3010067079605381 -63,0.38721747019713965 -64,0.3048725476388016 -65,0.291241880491206 -66,0.29405696276026 -67,0.27816100455715387 -68,0.3232802987703025 -69,0.2798769090546532 -70,0.25792185532765594 -71,0.21788214273431783 -72,0.23253948226128238 -73,0.19020839368401615 -74,0.16574386604094346 -75,0.3628546484181653 -76,0.37488739449239394 -77,0.38224701246619097 -78,0.3816033030914252 -79,0.40793332689772094 -80,0.37064081663939524 -81,0.36193327196860114 -82,0.37966168758649216 -83,0.38679899744638013 -84,0.33931858696833356 -85,0.3737387873948957 -86,0.39272299282344897 -87,0.38872346173659533 -88,0.40023931272831176 -89,0.38802712637458275 -90,0.2950286243204989 -91,0.3678344626136501 -92,0.4039493550145363 -93,0.3948652681736802 -94,0.4073847057583242 -95,0.40595486711779466 -96,0.31098843011267197 -97,0.40062274467423614 -98,0.3663445344393611 -99,0.3638813162597222 -100,0.4277446747275281 -101,0.3908198699237685 -102,0.37060749518373015 -103,0.33204341069949295 -104,0.36304186715665804 -105,0.39526211438782266 -106,0.4497738149048414 -107,0.3134167122954642 -108,0.38905283720378486 -109,0.371334356257774 -110,0.3416662560081737 -111,0.35254417974935004 -112,0.38812075429633475 -113,0.28272284621715815 -114,0.3565445716841722 -115,0.3535820846277791 -116,0.3336603009295655 -117,0.36544980934887855 -118,0.35531421404071123 -119,0.329471104078231 -120,0.3352193317049103 -121,0.2899851388241994 -122,0.32283803088100893 -123,0.34836900291393064 -124,0.27672487240820876 -125,0.28513897259750937 -126,0.31462034038869907 -127,0.2719721794000607 -128,0.31456641248750683 -129,0.21083544404340504 -130,0.21577764079882883 -131,0.18575809952540867 -132,0.22787248112338387 -133,0.157536372999196 -134,0.12091437565076325 -135,0.11492346842941377 -136,0.36980657741437695 -137,0.37704984535896857 -138,0.38836541814174574 -139,0.38798165500107185 -140,0.3910462293122809 -141,0.38246676946923774 -142,0.3972821468483997 -143,0.40107046600450463 -144,0.4195259853931709 -145,0.44781722723439105 -146,0.39764909457840314 -147,0.4019225567395396 -148,0.38853154331253853 -149,0.4137242191973009 -150,0.4191073716910497 -151,0.4259905215496432 -152,0.38282069402512886 -153,0.4017865626139235 -154,0.3923180792963412 -155,0.4153158007119324 -156,0.41182464839176464 -157,0.4284021171910421 -158,0.35254195967629515 -159,0.3838009746603819 -160,0.38782943793782015 -161,0.4055340087162593 -162,0.41505666012649 -163,0.4371357350996553 -164,0.39187125762916863 -165,0.39048006997732526 -166,0.3935717806090454 -167,0.39520236630783184 -168,0.37840747037135264 -169,0.3739722949021634 -170,0.401869699672214 -171,0.38114258666765616 -172,0.38074115848623824 -173,0.4068203658415078 -174,0.32394332101526774 -175,0.3826200441169252 -176,0.36294758115805104 -177,0.36280538472954016 -178,0.3302993859786094 -179,0.338708760339188 -180,0.331650927874605 -181,0.34373744221533764 -182,0.3396404303432598 -183,0.35059735492935973 -184,0.33176743242931633 -185,0.2996778119664703 -186,0.3069245332158003 -187,0.3054725020692893 -188,0.2992491474547441 -189,0.2969076022374312 -190,0.2856714237116741 -191,0.2602185668720255 -192,0.25060160116224384 -193,0.2581954729895674 -194,0.1970804794268892 -195,0.2306310847317516 -196,0.21422459252669232 -197,0.2182253723325444 -198,0.2155336391792624 -199,0.20497258394733725 -200,0.13924986116068555 -201,0.1201581843342225 -202,0.12013699049490029 -203,0.10352317184321143 +0,0.3320000456343507 +1,0.3748749978777025 +2,0.3764907348126766 +3,0.3870942866939022 +4,0.34921620262211434 +5,0.39357195014227014 +6,0.37267578708442545 +7,0.3429207550900085 +8,0.39396604655847334 +9,0.35069746747013913 +10,0.338015176946116 +11,0.3433862447508126 +12,0.3627769702449634 +13,0.3608729582060437 +14,0.3888782195232877 +15,0.36196673242184424 +16,0.3385488912313887 +17,0.35034433050068675 +18,0.43789284809308293 +19,0.3222130553732431 +20,0.33030360414258736 +21,0.29752317291920305 +22,0.3302692012548227 +23,0.3075519432115465 +24,0.34884467306277694 +25,0.300712848516291 +26,0.2909712828471878 +27,0.23617448434382596 +28,0.3774528663931931 +29,0.3942142953876416 +30,0.36547545190675285 +31,0.3739949507638011 +32,0.4102083566868722 +33,0.3401629759500299 +34,0.3911423404940657 +35,0.3995705726940441 +36,0.3860178102434032 +37,0.4270141681428556 +38,0.34943990175030154 +39,0.3986020561266927 +40,0.3803279327762195 +41,0.36781651185929254 +42,0.40518880682845654 +43,0.3890412538924381 +44,0.3792253945021168 +45,0.37409235333888985 +46,0.47032405098913616 +47,0.3395681289913764 +48,0.34613291560797643 +49,0.32791810366852514 +50,0.33643600596320195 +51,0.391196470696606 +52,0.397752801815056 +53,0.29350204247227474 +54,0.34642522321724883 +55,0.34091688412438537 +56,0.3579768683187782 +57,0.3335011870685683 +58,0.38820961537798 +59,0.3218277742082934 +60,0.3059889423867525 +61,0.34365559543123536 +62,0.30123686589311605 +63,0.38653567989042614 +64,0.30501688875816135 +65,0.29134626443538447 +66,0.29398663564397093 +67,0.2782371393451725 +68,0.32305901275951204 +69,0.2796975381370556 +70,0.25772092269252656 +71,0.21785338468048834 +72,0.232067500707852 +73,0.18982918570119497 +74,0.16539555384253452 +75,0.36332631222214024 +76,0.37555942745002246 +77,0.38274370244191885 +78,0.3821584244324852 +79,0.4086903107727046 +80,0.37107675406099866 +81,0.3621629127257864 +82,0.3804651031635101 +83,0.38702509960312265 +84,0.33890790674671534 +85,0.37411294265451545 +86,0.3934267058869639 +87,0.38922553116073627 +88,0.40090511152034985 +89,0.38896612491560106 +90,0.2951572376829248 +91,0.36814482392993003 +92,0.40477027587079006 +93,0.39558975099219523 +94,0.40836124905896337 +95,0.40672966952298234 +96,0.311230533079052 +97,0.40132833978864163 +98,0.36692505845343004 +99,0.3646210351087103 +100,0.42870478171668336 +101,0.39160165945952136 +102,0.3710120535477687 +103,0.33256838851490433 +104,0.3636935066851795 +105,0.3959860609600574 +106,0.45078270710440793 +107,0.3138300011721631 +108,0.38968193891412756 +109,0.3719741449123973 +110,0.3422774061248695 +111,0.3532279338947229 +112,0.38893692138983166 +113,0.28287387759319577 +114,0.3570028738723356 +115,0.35409751039560794 +116,0.33426300398969283 +117,0.3661272072652692 +118,0.35614397512185386 +119,0.3298893983309484 +120,0.3357653934549284 +121,0.29046264183040327 +122,0.323428747312865 +123,0.3490619512062637 +124,0.2771031724096577 +125,0.28563317011863987 +126,0.31513640033474377 +127,0.2724894921757355 +128,0.3153413383159387 +129,0.2112514398771773 +130,0.21621055377413664 +131,0.18610584433998442 +132,0.2285008965147044 +133,0.15794306285112028 +134,0.12125418765786956 +135,0.11518825897607175 +136,0.36985957799016517 +137,0.377371092160903 +138,0.388280178284051 +139,0.38825238744889035 +140,0.39125774506505306 +141,0.3823432483173795 +142,0.3971979509988071 +143,0.40132509808386935 +144,0.41990380434801966 +145,0.44831216069144586 +146,0.397692093177 +147,0.4018193538980447 +148,0.3884271328329649 +149,0.4139273703161559 +150,0.4194351251528375 +151,0.42638615665521273 +152,0.38272534309741707 +153,0.40171928092415043 +154,0.39239019461790414 +155,0.4156019333658019 +156,0.41212729076183663 +157,0.4287864762823745 +158,0.35245566136902484 +159,0.3837579246958282 +160,0.3880282102582463 +161,0.40585379026652896 +162,0.4154137970650791 +163,0.43770750556594057 +164,0.39172089223872286 +165,0.39074174046135124 +166,0.39398656633510426 +167,0.3956248340974036 +168,0.37880337166182104 +169,0.37413222562301596 +170,0.402238618600373 +171,0.3815582770922897 +172,0.3810806987758028 +173,0.4072801527166225 +174,0.3240507619026542 +175,0.3829587951710855 +176,0.36335902947454585 +177,0.36320448526288673 +178,0.3306635376587522 +179,0.3389586689359901 +180,0.3319696765822389 +181,0.3440974287468366 +182,0.34005162643700954 +183,0.351026503988845 +184,0.3320558640515136 +185,0.2999760434069472 +186,0.3073096125429107 +187,0.30581739395991986 +188,0.2996180924586271 +189,0.29742825414440527 +190,0.2860342843494195 +191,0.2605428414504416 +192,0.25102305421287974 +193,0.2586237594957355 +194,0.1968304047076662 +195,0.23117309229786412 +196,0.21461976743670633 +197,0.2186208065636584 +198,0.21598676228912225 +199,0.20538299635574722 +200,0.14022659856993683 +201,0.12076060549265144 +202,0.1205180735596712 +203,0.1038507373756925 diff --git a/colibri/tests/test_commondata_utils.py b/colibri/tests/test_commondata_utils.py index b685c1029..d5afc7465 100644 --- a/colibri/tests/test_commondata_utils.py +++ b/colibri/tests/test_commondata_utils.py @@ -10,7 +10,10 @@ from numpy.testing import assert_allclose from colibri.api import API as colibriAPI -from colibri.commondata_utils import CentralCovmatIndex, experimental_commondata_tuple +from colibri.commondata_utils import ( + CentralSqrtCovmatIndex, + experimental_commondata_tuple, +) from colibri.tests.conftest import ( CLOSURE_TEST_PDFSET, PSEUDODATA_SEED, @@ -49,18 +52,18 @@ def test_experimental_commondata_tuple(): def test_central_covmat_index(): """ - Test that CentralCovmatIndex object is produced correctly. + Test that CentralSqrtCovmatIndex object is produced correctly. """ - result = colibriAPI.central_covmat_index(**{**TEST_DATASETS, **T0_PDFSET}) - # Check that central_covmat_index produces a CentralCovmatIndex object - assert isinstance(result, CentralCovmatIndex) + result = colibriAPI.central_sqrt_covmat_index(**{**TEST_DATASETS, **T0_PDFSET}) + # Check that central_covmat_index produces a CentralSqrtCovmatIndex object + assert isinstance(result, CentralSqrtCovmatIndex) - # Check that CentralCovmatIndex has the required attributes, of the correct types + # Check that CentralSqrtCovmatIndex has the required attributes, of the correct types assert hasattr(result, "central_values") assert isinstance(result.central_values, jnp.ndarray) - assert hasattr(result, "covmat") - assert isinstance(result.covmat, jnp.ndarray) + assert hasattr(result, "sqrt_covmat") + assert isinstance(result.sqrt_covmat, jnp.ndarray) assert hasattr(result, "central_values_idx") assert isinstance(result.central_values_idx, jnp.ndarray) @@ -69,7 +72,7 @@ def test_central_covmat_index(): assert isinstance(result_dict, dict) # Check that dimensions of attributes are correct - assert result.central_values.shape[0] == result.covmat.shape[0] + assert result.central_values.shape[0] == result.sqrt_covmat.shape[0] assert result.central_values_idx.shape[0] == result.central_values.shape[0] @@ -107,7 +110,12 @@ def test_level1_commondata_tuple(): ) current_level1_central_values = colibriAPI.level_1_commondata_tuple( - **{**TEST_DATASETS, **CLOSURE_TEST_PDFSET, "level_1_seed": PSEUDODATA_SEED} + **{ + **TEST_DATASETS, + **CLOSURE_TEST_PDFSET, + "level_1_seed": PSEUDODATA_SEED, + **T0_PDFSET, + } ) assert_allclose( diff --git a/colibri/tests/test_covmats.py b/colibri/tests/test_covmats.py index 34918130b..becc2023d 100644 --- a/colibri/tests/test_covmats.py +++ b/colibri/tests/test_covmats.py @@ -12,21 +12,21 @@ from numpy.testing import assert_allclose from colibri.api import API as colibriAPI -from colibri.covmats import sqrt_covmat_jax +from colibri.covmats import general_sqrt_covariance_matrix from colibri.tests.conftest import T0_PDFSET, TEST_DATASETS TEST_COVMATS_FOLDER = pathlib.Path(__file__).with_name("test_covmats") -def test_sqrt_covmat_jax(): +def test_general_sqrt_covariance_matrix(): """ - Test that sqrt_covmat_jax actually computes the square root of a matrix. + Test that general_sqrt_covariance_matrix actually computes the square root of a matrix. """ test_matrix = jnp.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]) # This matrix has square root [[2, 0, 0], [6, 1, 0], [-8, 5, 3]] - sqrt_matrix = sqrt_covmat_jax(test_matrix) + sqrt_matrix = general_sqrt_covariance_matrix(test_matrix) actual_sqrt = jnp.array([[2, 0, 0], [6, 1, 0], [-8, 5, 3]]) assert_allclose(sqrt_matrix, actual_sqrt, rtol=1e-5) diff --git a/colibri/tests/test_data_batch.py b/colibri/tests/test_data_batch.py index 74d58c34a..c6f79c6c4 100644 --- a/colibri/tests/test_data_batch.py +++ b/colibri/tests/test_data_batch.py @@ -40,7 +40,7 @@ def test_data_batches(): assert hasattr(next_batch, "idx") assert isinstance(next_batch.idx, jax.Array) assert len(next_batch.idx) == batch_size - # inv_cov is optional and for this call (no fit_covariance_matrix) should be None + # inv_cov is optional and for this call (no general_covariance_matrix) should be None assert getattr(next_batch, "inv_cov", None) is None # When shuffle_each_epoch=False (default) fixed_batches should be available @@ -69,7 +69,7 @@ def test_data_batches_with_covmat(): db = data_batches( training_indices, batch_size=batch_size, - fit_covariance_matrix=cov, + general_covariance_matrix=cov, batch_seed=42, ) diff --git a/colibri/tests/test_likelihood.py b/colibri/tests/test_likelihood.py index 4d69644cc..c639a2ccd 100644 --- a/colibri/tests/test_likelihood.py +++ b/colibri/tests/test_likelihood.py @@ -12,7 +12,7 @@ from colibri.likelihood import LogLikelihood, log_likelihood, mc_log_likelihood from colibri.mc_utils import MCPseudodata from colibri.tests.conftest import ( - MOCK_CENTRAL_COVMAT_INDEX, + MOCK_CENTRAL_SQRT_COVMAT_INDEX, MOCK_PDF_MODEL, MOCK_PENALTY_POSDATA, TEST_FK_ARRAYS, @@ -36,7 +36,7 @@ def test_LogLikelihood_class(pos_penalty): Tests the LogLikelihood class. """ log_likelihood_class = LogLikelihood( - central_covmat_index=MOCK_CENTRAL_COVMAT_INDEX, + central_sqrt_covmat_index=MOCK_CENTRAL_SQRT_COVMAT_INDEX, pdf_model=MOCK_PDF_MODEL, fit_xgrid=TEST_XGRID, forward_map=TEST_FORWARD_MAP_DIS, @@ -52,10 +52,12 @@ def test_LogLikelihood_class(pos_penalty): ) assert_allclose( - MOCK_CENTRAL_COVMAT_INDEX.central_values, + MOCK_CENTRAL_SQRT_COVMAT_INDEX.central_values, log_likelihood_class.central_values, ) - assert_allclose(MOCK_CENTRAL_COVMAT_INDEX.covmat, log_likelihood_class.covmat) + assert_allclose( + MOCK_CENTRAL_SQRT_COVMAT_INDEX.sqrt_covmat, log_likelihood_class.sqrt_covmat + ) assert MOCK_PDF_MODEL == log_likelihood_class.pdf_model assert MOCK_PENALTY_POSDATA == log_likelihood_class.penalty_posdata @@ -71,7 +73,8 @@ def test_LogLikelihood_class(pos_penalty): ) predictions = predictions[log_likelihood_class.central_values_idx] diff = predictions - log_likelihood_class.central_values - chi2_val = jnp.einsum("i,ij,j", diff, log_likelihood_class.inv_covmat, diff) + z = jnp.einsum("ij, j", log_likelihood_class.inv_sqrt_covmat, diff) + chi2_val = jnp.dot(z, z) pos_pen = ( jnp.sum( @@ -102,7 +105,7 @@ def test_log_likelihood(pos_penalty): {"positivity_penalty": pos_penalty, "alpha": 1e-7, "lambda_positivity": 1000}, ) log_likelihood_class = LogLikelihood( - central_covmat_index=MOCK_CENTRAL_COVMAT_INDEX, + central_sqrt_covmat_index=MOCK_CENTRAL_SQRT_COVMAT_INDEX, pdf_model=MOCK_PDF_MODEL, fit_xgrid=TEST_XGRID, forward_map=TEST_FORWARD_MAP_DIS, @@ -113,7 +116,7 @@ def test_log_likelihood(pos_penalty): integrability_penalty=integrability_penalty, ) log_like = log_likelihood( - MOCK_CENTRAL_COVMAT_INDEX, + MOCK_CENTRAL_SQRT_COVMAT_INDEX, MOCK_PDF_MODEL, TEST_XGRID, TEST_FORWARD_MAP_DIS, @@ -141,7 +144,7 @@ def test_log_likelihood_with_and_without_pos_penalty(): # Instantiate the class log_likelihood_class = LogLikelihood( - MOCK_CENTRAL_COVMAT_INDEX, + MOCK_CENTRAL_SQRT_COVMAT_INDEX, MOCK_PDF_MODEL, TEST_XGRID, TEST_FORWARD_MAP_DIS, @@ -159,7 +162,7 @@ def test_log_likelihood_with_and_without_pos_penalty(): ll_value_with_penalty = log_likelihood_class.log_likelihood( params, log_likelihood_class.central_values, - log_likelihood_class.inv_covmat, + log_likelihood_class.inv_sqrt_covmat, log_likelihood_class.fast_kernel_arrays, log_likelihood_class.positivity_fast_kernel_arrays, ) @@ -170,7 +173,9 @@ def test_log_likelihood_with_and_without_pos_penalty(): ) predictions = predictions[log_likelihood_class.central_values_idx] diff = predictions - log_likelihood_class.central_values - chi2_val = jnp.einsum("i,ij,j", diff, log_likelihood_class.inv_covmat, diff) + z = jnp.einsum("ij, j", log_likelihood_class.inv_sqrt_covmat, diff) + + chi2_val = jnp.dot(z, z) pos_pen = jnp.sum( log_likelihood_class.penalty_posdata( pdf, @@ -193,7 +198,7 @@ def test_log_likelihood_with_and_without_pos_penalty(): # Instantiate the class log_likelihood_class = LogLikelihood( - MOCK_CENTRAL_COVMAT_INDEX, + MOCK_CENTRAL_SQRT_COVMAT_INDEX, MOCK_PDF_MODEL, TEST_XGRID, TEST_FORWARD_MAP_DIS, @@ -207,7 +212,7 @@ def test_log_likelihood_with_and_without_pos_penalty(): ll_value_without_penalty = log_likelihood_class.log_likelihood( params, log_likelihood_class.central_values, - log_likelihood_class.inv_covmat, + log_likelihood_class.inv_sqrt_covmat, log_likelihood_class.fast_kernel_arrays, log_likelihood_class.positivity_fast_kernel_arrays, ) @@ -218,7 +223,8 @@ def test_log_likelihood_with_and_without_pos_penalty(): ) predictions = predictions[log_likelihood_class.central_values_idx] diff = predictions - log_likelihood_class.central_values - chi2_val = jnp.einsum("i,ij,j", diff, log_likelihood_class.inv_covmat, diff) + z = jnp.einsum("ij, j", log_likelihood_class.inv_sqrt_covmat, diff) + chi2_val = jnp.dot(z, z) expected_without_penalty = -0.5 * chi2_val assert float(ll_value_without_penalty) == pytest.approx( float(expected_without_penalty) @@ -234,7 +240,7 @@ def test_mc_log_likelihood_with_split(pos_penalty): # Create a tiny pseudodata setup consistent with TEST_N_DATA = 2 pseudodata = jnp.array([1.0, 2.0]) - fit_covariance_matrix = jnp.eye(2) + general_covariance_matrix = jnp.eye(2) training_indices = jnp.array([0]) validation_indices = jnp.array([1]) @@ -253,7 +259,7 @@ def test_mc_log_likelihood_with_split(pos_penalty): train_loglike, val_loglike = mc_log_likelihood( mc_pd, - fit_covariance_matrix, + general_covariance_matrix, MOCK_PDF_MODEL, TEST_XGRID, TEST_FORWARD_MAP_DIS, @@ -278,7 +284,7 @@ def compute_expected(ll_obj): preds, pdf = ll_obj.pred_and_pdf(params, ll_obj.fast_kernel_arrays) preds = preds[ll_obj.central_values_idx] diff = preds - ll_obj.central_values - inv = ll_obj.inv_covmat + inv = ll_obj.inv_sqrt_covmat.T @ ll_obj.inv_sqrt_covmat chi2_val = jnp.einsum("i,ij,j", diff, inv, diff) pos_pen = ( jnp.sum( @@ -312,7 +318,7 @@ def test_mc_log_likelihood_without_split_returns_nan_for_validation(pos_penalty) # Pseudodata across both points; training uses all when no split pseudodata = jnp.array([1.0, 2.0]) - fit_covariance_matrix = jnp.eye(2) + general_covariance_matrix = jnp.eye(2) training_indices = jnp.array([0, 1]) validation_indices = jnp.array([]) @@ -331,7 +337,7 @@ def test_mc_log_likelihood_without_split_returns_nan_for_validation(pos_penalty) train_loglike, val_loglike = mc_log_likelihood( mc_pd, - fit_covariance_matrix, + general_covariance_matrix, MOCK_PDF_MODEL, TEST_XGRID, TEST_FORWARD_MAP_DIS, @@ -353,7 +359,8 @@ def test_mc_log_likelihood_without_split_returns_nan_for_validation(pos_penalty) ) predictions = predictions[train_loglike.central_values_idx] diff = predictions - train_loglike.central_values - chi2_val = jnp.einsum("i,ij,j", diff, train_loglike.inv_covmat, diff) + z = jnp.einsum("ij, j", train_loglike.inv_sqrt_covmat, diff) + chi2_val = jnp.dot(z, z) pos_pen = ( jnp.sum( train_loglike.penalty_posdata( @@ -390,7 +397,7 @@ def test_LogLikelihood_call_with_batch_idx(pos_penalty): } log_likelihood_class = LogLikelihood( - central_covmat_index=MOCK_CENTRAL_COVMAT_INDEX, + central_sqrt_covmat_index=MOCK_CENTRAL_SQRT_COVMAT_INDEX, pdf_model=MOCK_PDF_MODEL, fit_xgrid=TEST_XGRID, forward_map=TEST_FORWARD_MAP_DIS, @@ -415,7 +422,7 @@ def test_LogLikelihood_call_with_batch_idx(pos_penalty): predictions = predictions[log_likelihood_class.central_values_idx] predictions_b = predictions[batch.idx] central_b = log_likelihood_class.central_values[batch.idx] - cov_b = log_likelihood_class.covmat[batch.idx][:, batch.idx] + cov_b = log_likelihood_class.sqrt_covmat[batch.idx][:, batch.idx] inv_b = jnp.linalg.inv(cov_b) diff_b = predictions_b - central_b chi2_b = jnp.einsum("i,ij,j", diff_b, inv_b, diff_b) @@ -452,7 +459,7 @@ def test_LogLikelihood_call_with_batch_with_inv_cov(pos_penalty): } log_likelihood_class = LogLikelihood( - central_covmat_index=MOCK_CENTRAL_COVMAT_INDEX, + central_sqrt_covmat_index=MOCK_CENTRAL_SQRT_COVMAT_INDEX, pdf_model=MOCK_PDF_MODEL, fit_xgrid=TEST_XGRID, forward_map=TEST_FORWARD_MAP_DIS, @@ -467,7 +474,7 @@ def test_LogLikelihood_call_with_batch_with_inv_cov(pos_penalty): # Select first two data points and precompute their inverse covariance batch_idx = jnp.array([0, 1]) - cov_b = log_likelihood_class.covmat[batch_idx][:, batch_idx] + cov_b = log_likelihood_class.sqrt_covmat[batch_idx][:, batch_idx] inv_b = jnp.linalg.inv(cov_b) # Provide the precomputed inverse covariance in the BatchSpec diff --git a/colibri/tests/test_mc_utils.py b/colibri/tests/test_mc_utils.py index a661b74fb..b67c42040 100644 --- a/colibri/tests/test_mc_utils.py +++ b/colibri/tests/test_mc_utils.py @@ -23,10 +23,12 @@ TEST_COMMONDATA_FOLDER, TEST_DATASETS, TRVAL_INDEX, + T0_PDFSET, ) MC_PSEUDODATA = { "level_1_seed": PSEUDODATA_SEED, + **T0_PDFSET, **CLOSURE_TEST_PDFSET, **TRVAL_INDEX, **REPLICA_INDEX, diff --git a/colibri/tests/test_ultranest_fit.py b/colibri/tests/test_ultranest_fit.py index 4b6df815f..4f6abb440 100644 --- a/colibri/tests/test_ultranest_fit.py +++ b/colibri/tests/test_ultranest_fit.py @@ -12,7 +12,7 @@ import pytest from colibri.tests.conftest import ( - MOCK_CENTRAL_COVMAT_INDEX, + MOCK_CENTRAL_SQRT_COVMAT_INDEX, MOCK_PDF_MODEL, MOCK_PENALTY_POSDATA, TEST_FK_ARRAYS, @@ -67,7 +67,7 @@ def test_ultranest_fit(pos_penalty): _pred_data = lambda *args: jnp.array([0.0]) mock_log_likelihood = LogLikelihood( - MOCK_CENTRAL_COVMAT_INDEX, + MOCK_CENTRAL_SQRT_COVMAT_INDEX, MOCK_PDF_MODEL, TEST_XGRID, _pred_data, @@ -106,7 +106,7 @@ def test_ultranest_fit_vectorized(pos_penalty): ultranest_settings["ReactiveNS_settings"]["vectorized"] = True mock_log_likelihood = LogLikelihood( - MOCK_CENTRAL_COVMAT_INDEX, + MOCK_CENTRAL_SQRT_COVMAT_INDEX, MOCK_PDF_MODEL, TEST_XGRID, _pred_data, @@ -154,7 +154,7 @@ def test_ultranest_fit_with_SliceSampler(pos_penalty): _pred_data = lambda *args: jnp.array([0.0]) mock_log_likelihood = LogLikelihood( - MOCK_CENTRAL_COVMAT_INDEX, + MOCK_CENTRAL_SQRT_COVMAT_INDEX, MOCK_PDF_MODEL, TEST_XGRID, _pred_data, @@ -202,7 +202,7 @@ def test_ultranest_fit_with_popSliceSampler(pos_penalty): _pred_data = lambda *args: jnp.array([0.0]) mock_log_likelihood = LogLikelihood( - MOCK_CENTRAL_COVMAT_INDEX, + MOCK_CENTRAL_SQRT_COVMAT_INDEX, MOCK_PDF_MODEL, TEST_XGRID, _pred_data, @@ -254,7 +254,7 @@ def test_ultranest_fit_with_sampler_plot(mock_sampler_class, pos_penalty): _pred_data = lambda *args: jnp.array([0.0]) mock_log_likelihood = LogLikelihood( - MOCK_CENTRAL_COVMAT_INDEX, + MOCK_CENTRAL_SQRT_COVMAT_INDEX, MOCK_PDF_MODEL, TEST_XGRID, _pred_data, diff --git a/colibri/tests/test_utils.py b/colibri/tests/test_utils.py index 9a4981ceb..8833bae52 100644 --- a/colibri/tests/test_utils.py +++ b/colibri/tests/test_utils.py @@ -26,7 +26,7 @@ from colibri.api import API as cAPI from colibri.tests.conftest import ( - MOCK_CENTRAL_COVMAT_INDEX, + MOCK_CENTRAL_SQRT_COVMAT_INDEX, MOCK_PDF_MODEL, TEST_DATASET, TEST_DATASET_HAD, @@ -339,7 +339,7 @@ def test_likelihood_float_type( ): _pred_data = lambda x, fks: jnp.ones( - len(MOCK_CENTRAL_COVMAT_INDEX.central_values) + len(MOCK_CENTRAL_SQRT_COVMAT_INDEX.central_values) ) # Mock _pred_data FIT_XGRID = jnp.linspace(0, 1, 10) # Mock FIT_XGRID output_path = tmp_path @@ -355,7 +355,7 @@ def test_likelihood_float_type( FIT_XGRID=FIT_XGRID, bayesian_prior=mock_bayesian_prior, output_path=output_path, - central_covmat_index=MOCK_CENTRAL_COVMAT_INDEX, + central_covmat_index=MOCK_CENTRAL_SQRT_COVMAT_INDEX, fast_kernel_arrays=fast_kernel_arrays, )