Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 9 additions & 13 deletions colibri/analytic_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def analytic_evidence_uniform_prior(sol_covmat, sol_mean, max_logl, a_vec, b_vec

@check_pdf_model_is_linear
def analytic_fit(
central_covmat_index,
central_sqrt_covmat_index,
_pred_data,
pdf_model,
analytic_settings,
Expand All @@ -103,8 +103,8 @@ def analytic_fit(

Parameters
----------
central_covmat_index: commondata_utils.CentralCovmatIndex
dataclass containing central values and covariance matrix.
central_sqrt_covmat_index: commondata_utils.CentralSqrtCovmatIndex
dataclass containing central values and square root of the covariance matrix.

_pred_data: @jax.jit CompiledFunction
Prediction function for the fit.
Expand Down Expand Up @@ -142,22 +142,18 @@ def analytic_fit(
intercept = pred_and_pdf(jnp.zeros(len(parameters)), fast_kernel_arrays)[0]

# Construct the analytic solution
central_values = central_covmat_index.central_values
covmat = central_covmat_index.covmat
central_values = central_sqrt_covmat_index.central_values
sqrt_covmat = central_sqrt_covmat_index.sqrt_covmat

# Solve chi2 analytically for the mean
Y = central_values - intercept
X = predictions.T - intercept[:, None]

t0 = time.time()

# Cholesky factorization: S = L L^T
# upper False means that we want the lower triangular matrix L
L = jla.cholesky(covmat, upper=False)

# Whiten the problem: Y' = L^-1 Y, X' = L^-1 X
Y_tilde = jlinalg.triangular_solve(L, Y, left_side=True, lower=True)
X_tilde = jlinalg.triangular_solve(L, X, left_side=True, lower=True)
# Whiten the problem: Y' = sqrt_covmat^-1 Y, X' = sqrt_covmat^-1 X
Y_tilde = jlinalg.triangular_solve(sqrt_covmat, Y, left_side=True, lower=True)
X_tilde = jlinalg.triangular_solve(sqrt_covmat, X, left_side=True, lower=True)

if jnp.any(jla.eigh(X_tilde.T @ X_tilde)[0] <= 0.0):
raise ValueError(
Expand Down Expand Up @@ -257,7 +253,7 @@ def analytic_fit(
min_chi2 = -2 * max_logl
log.info(f"Minimum chi2 = {min_chi2}")

BIC = min_chi2 + sol_covmat.shape[0] * np.log(covmat.shape[0])
BIC = min_chi2 + sol_covmat.shape[0] * np.log(sqrt_covmat.shape[0])
AIC = min_chi2 + 2 * sol_covmat.shape[0]

# Compute average chi2 (in whitened basis)
Expand Down
40 changes: 15 additions & 25 deletions colibri/commondata_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import jax.numpy as jnp

from colibri.theory_predictions import make_pred_dataset
from colibri.core import CentralCovmatIndex
from colibri.core import CentralSqrtCovmatIndex


def experimental_commondata_tuple(data):
Expand Down Expand Up @@ -101,21 +101,21 @@ def level_0_commondata_tuple(

def level_1_commondata_tuple(
level_0_commondata_tuple,
data_generation_covariance_matrix,
general_covariance_matrix,
level_1_seed=123456,
):
"""
Returns a tuple (validphys nodes should be immutable)
of level 1 commondata instances.
Noise is added to the level_0_commondata_tuple central values
according to a multivariate Gaussian with covariance data_generation_covariance_matrix
according to a multivariate Gaussian with covariance general_covariance_matrix

Parameters
----------
level_0_commondata_tuple: tuple of nnpdf_data.coredata.CommonData instances
A tuple of level_0 closure test data.

data_generation_covariance_matrix: jnp.ndarray
general_covariance_matrix: jnp.ndarray
The covariance matrix used for data generation.

level_1_seed: int
Expand All @@ -133,11 +133,11 @@ def level_1_commondata_tuple(
)

# Now, sample from the multivariate Gaussian with central values central_values
# and covariance matrix data_generation_covariance_matrix. This produces the
# and general_covariance_matrix. This produces the
# level_1 data.
rng = jax.random.PRNGKey(level_1_seed)
sample = jax.random.multivariate_normal(
rng, central_values, data_generation_covariance_matrix
rng, central_values, general_covariance_matrix
)

# Now, reconstruct the commondata tuple, by modifying the original commondata
Expand All @@ -150,11 +150,11 @@ def level_1_commondata_tuple(
return tuple(sample_list)


def central_covmat_index(commondata_tuple, fit_covariance_matrix):
def central_sqrt_covmat_index(commondata_tuple, general_sqrt_covariance_matrix):
"""
Given a commondata_tuple and a covariance_matrix, generated
Given a commondata_tuple and a general_sqrt_covariance_matrix, generated
according to respective explicit node in config.py, store
relevant data into CentralCovmatIndex dataclass.
relevant data into CentralSqrtCovmatIndex dataclass.

Parameters
----------
Expand All @@ -163,15 +163,14 @@ def central_covmat_index(commondata_tuple, fit_covariance_matrix):
(see config.produce_commondata_tuple) and accordingly to the
specified options.

fit_covariance_matrix: jnp.ndarray
covariance matrix, is generated as explicit node
(see config.fit_covariance_matrix) can be either experimental
general_sqrt_covariance_matrix: jnp.ndarray
square root of the covariance matrix which can be either experimental
or t0 covariance matrix depending on whether `use_fit_t0` is
True or False
True or False.

Returns
-------
CentralCovmatIndex
CentralSqrtCovmatIndex
Dataclass containing the central values, the square root of the
covariance matrix, and the index of the central values.
"""
Expand All @@ -180,17 +179,8 @@ def central_covmat_index(commondata_tuple, fit_covariance_matrix):
)
central_values_idx = jnp.arange(central_values.shape[0])

return CentralCovmatIndex(
return CentralSqrtCovmatIndex(
central_values=central_values,
central_values_idx=central_values_idx,
covmat=fit_covariance_matrix,
sqrt_covmat=general_sqrt_covariance_matrix,
)


def pseudodata_central_covmat_index(
commondata_tuple, data_generation_covariance_matrix
):
"""Same as central_covmat_index, but with the pseudodata generation
covariance matrix for a Monte Carlo fit.
"""
return central_covmat_index(commondata_tuple, data_generation_covariance_matrix)
15 changes: 1 addition & 14 deletions colibri/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ def produce_commondata_tuple(self, closure_test_level=False):
)

@explicit_node
def produce_fit_covariance_matrix(self, use_fit_t0: bool = True):
def produce_general_covariance_matrix(self, use_fit_t0: bool = True):
"""
Produces the covariance matrix used in the fit.
This covariance matrix is used in:
Expand All @@ -619,19 +619,6 @@ def produce_fit_covariance_matrix(self, use_fit_t0: bool = True):
else:
return colibri_covmats.dataset_inputs_covmat_from_systematics

@explicit_node
def produce_data_generation_covariance_matrix(self, use_gen_t0: bool = False):
"""Produces the covariance matrix used in:
- level 1 closure test data construction (fluctuating around the level
0 data)
- Monte Carlo pseudodata (fluctuating either around the level 0 data or
level 1 data)
"""
if use_gen_t0:
return colibri_covmats.dataset_inputs_t0_covmat_from_systematics
else:
return colibri_covmats.dataset_inputs_covmat_from_systematics

def parse_closure_test_pdf(self, name):
"""PDF set used to generate fakedata"""
if name == "colibri_model":
Expand Down
4 changes: 2 additions & 2 deletions colibri/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,9 @@ class PosdataTrainValidationSplit(TrainValidationSplit):


@dataclass(frozen=True)
class CentralCovmatIndex:
class CentralSqrtCovmatIndex:
central_values: jnp.array
covmat: jnp.array
sqrt_covmat: jnp.array
central_values_idx: jnp.array

def to_dict(self):
Expand Down
17 changes: 11 additions & 6 deletions colibri/covmats.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,19 @@
from validphys import covmats


def sqrt_covmat_jax(covariance_matrix):
def general_sqrt_covariance_matrix(general_covariance_matrix):
"""
Same as `validphys.covmats.sqrt_covmat` but
for jax.numpy arrays

Parameters
----------
covariance_matrix : jnp.ndarray
general_covariance_matrix : jnp.ndarray
A positive definite covariance matrix, which is N_dat x N_dat (where
N_dat is the number of data points after cuts) containing uncertainty
and correlation information.
NOTE: see the production rules in `config.py` for details on which
covariance matrix (experimental or t0) is passed here.

Returns
-------
Expand All @@ -35,9 +37,9 @@ def sqrt_covmat_jax(covariance_matrix):
``jnp.allclose(sqrt_covmat @ sqrt_covmat.T, covariance_matrix)``.
"""

dimensions = covariance_matrix.shape
dimensions = general_covariance_matrix.shape

if covariance_matrix.size == 0:
if general_covariance_matrix.size == 0:
raise ValueError("Attempting the decomposition of an empty matrix.")
elif dimensions[0] != dimensions[1]:
raise ValueError(
Expand All @@ -46,8 +48,11 @@ def sqrt_covmat_jax(covariance_matrix):
f"{dimensions[1]}"
)

sqrt_diags = jnp.sqrt(jnp.diag(covariance_matrix))
correlation_matrix = covariance_matrix / sqrt_diags[:, jnp.newaxis] / sqrt_diags
sqrt_diags = jnp.sqrt(jnp.diag(general_covariance_matrix))
correlation_matrix = (
general_covariance_matrix / sqrt_diags[:, jnp.newaxis] / sqrt_diags
)
# NOTE: scipy.linalg.cholesky returns the upper triangular decomposition by default
decomp = jla.cholesky(correlation_matrix)
sqrt_matrix = (decomp * sqrt_diags).T
return sqrt_matrix
Expand Down
10 changes: 6 additions & 4 deletions colibri/data_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def data_batches(
training_indices,
batch_size=None,
batch_seed=1,
fit_covariance_matrix=None,
general_covariance_matrix=None,
shuffle_each_epoch=False,
) -> DataBatches:
"""
Expand All @@ -31,7 +31,7 @@ def data_batches(

batch_seed: int, default is 1

fit_covariance_matrix: jax.Array, optional
general_covariance_matrix: jax.Array, optional
If provided together with shuffle_each_epoch=False, fixed batches are
precomputed once and the corresponding inverse covariance submatrices
are cached for reuse. This avoids inverting within the likelihood at
Expand Down Expand Up @@ -79,8 +79,10 @@ def _slice_batches_from_perm(perm: jax.Array) -> List[jax.Array]:
perm0 = _make_perm(key)
fixed_batches = _slice_batches_from_perm(perm0)

if fit_covariance_matrix is not None:
train_covmat = fit_covariance_matrix[training_indices][:, training_indices]
if general_covariance_matrix is not None:
train_covmat = general_covariance_matrix[training_indices][
:, training_indices
]
fixed_batches_specs = [
BatchSpec(
idx=b,
Expand Down
Loading
Loading