Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions colibri/analytic_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def analytic_fit(
"Assuming that the prior is wide enough to fully cover the gaussian likelihood."
)

parameters = pdf_model.param_names
parameters = pdf_model.full_param_names
pred_and_pdf = pdf_model.pred_and_pdf_func(FIT_XGRID, forward_map=_pred_data)

# Precompute predictions for the basis of the model
Expand Down Expand Up @@ -263,7 +263,7 @@ def analytic_fit(
return AnalyticFit(
analytic_specs=analytic_settings,
resampled_posterior=samples,
param_names=parameters,
full_param_names=parameters,
full_posterior_samples=full_samples,
bayesian_metrics={
"bayes_complexity": Cb,
Expand Down
6 changes: 3 additions & 3 deletions colibri/bayes_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ def bayesian_prior(prior_settings, pdf_model):

if "bounds" in prior_specs:
# Use param names from the model to order bounds correctly
param_names = pdf_model.param_names
full_param_names = pdf_model.full_param_names
bounds_dict = prior_specs["bounds"]

missing = [p for p in param_names if p not in bounds_dict]
missing = [p for p in full_param_names if p not in bounds_dict]
if missing:
raise ValueError(f"Missing bounds for parameters: {missing}")

# Per-parameter bounds
bounds = jnp.array([bounds_dict[param] for param in param_names])
bounds = jnp.array([bounds_dict[param] for param in full_param_names])
mins = bounds[:, 0]
maxs = bounds[:, 1]

Expand Down
2 changes: 1 addition & 1 deletion colibri/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def check_pdf_model_is_linear(pdf_model, FIT_XGRID, data):
pred_data = make_pred_data(data, FIT_XGRID)
fk = fast_kernel_arrays(data, FIT_XGRID)

parameters = pdf_model.param_names
parameters = pdf_model.full_param_names
pred_and_pdf = pdf_model.pred_and_pdf_func(FIT_XGRID, forward_map=pred_data)
intercept = pred_and_pdf(jnp.zeros(len(parameters)), fk)[0]

Expand Down
11 changes: 7 additions & 4 deletions colibri/closure_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,21 @@ def closure_test_colibri_model_pdf(closure_test_model_settings, FIT_XGRID):

# Compute the pdf grid
pdf_grid_func = pdf_model.grid_values_func(FIT_XGRID)
# check that parameters keys are the same as pdf_model.param_names
# check that parameters keys are the same as pdf_model.full_param_names
if set(closure_test_model_settings["parameters"].keys()) != set(
pdf_model.param_names
pdf_model.full_param_names
):
raise ValueError(
"The provided parameters do not match the model's parameter names:\n"
f"Provided: {list(closure_test_model_settings['parameters'].keys())}\n"
f"Expected: {pdf_model.param_names}"
f"Expected: {pdf_model.full_param_names}"
)

params = jnp.array(
[closure_test_model_settings["parameters"][p] for p in pdf_model.param_names]
[
closure_test_model_settings["parameters"][p]
for p in pdf_model.full_param_names
]
)
pdf_grid = pdf_grid_func(params)

Expand Down
6 changes: 3 additions & 3 deletions colibri/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class BayesianFit:

Attributes
----------
param_names: list
full_param_names: list
List of the names of the parameters.
resampled_posterior: jnp.array
Array containing the resampled posterior samples.
Expand All @@ -77,7 +77,7 @@ class BayesianFit:
The log evidence of the model.
"""

param_names: list
full_param_names: list
resampled_posterior: jnp.array
full_posterior_samples: jnp.array
bayesian_metrics: dict
Expand Down Expand Up @@ -149,7 +149,7 @@ class HessianFit:
hessian: jnp.ndarray
cov_params: jnp.ndarray
resampled_posterior: jnp.ndarray
param_names: list
full_param_names: list


@dataclass(frozen=True)
Expand Down
8 changes: 5 additions & 3 deletions colibri/export_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ def export_bayes_results(

# Write full sample to csv
full_samples_df = pd.DataFrame(
bayes_fit.full_posterior_samples, columns=bayes_fit.param_names
bayes_fit.full_posterior_samples, columns=bayes_fit.full_param_names
)
full_samples_df.to_csv(
str(output_path) + "/full_posterior_sample.csv", float_format="%.5e"
)

# Save the resampled results
df = pd.DataFrame(bayes_fit.resampled_posterior, columns=bayes_fit.param_names)
df = pd.DataFrame(bayes_fit.resampled_posterior, columns=bayes_fit.full_param_names)
df.to_csv(str(output_path) + f"/{results_name}.csv", float_format="%.5e")

# Save bayesian metrics to csv file
Expand Down Expand Up @@ -88,7 +88,9 @@ def export_hessian_results(
"""

# Save the resampled results
df = pd.DataFrame(hessian_fit.resampled_posterior, columns=hessian_fit.param_names)
df = pd.DataFrame(
hessian_fit.resampled_posterior, columns=hessian_fit.full_param_names
)
df.to_csv(str(output_path) + f"/{results_name}.csv", float_format="%.5e")

# Write the optimized parameters, the covmat and the min chi2, the training loss to a csv file
Expand Down
2 changes: 1 addition & 1 deletion colibri/hessian_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def valid_chi2(params):
hessian=hessian,
cov_params=cov_params,
resampled_posterior=hessian_param_set,
param_names=pdf_model.param_names,
full_param_names=pdf_model.full_param_names,
)


Expand Down
6 changes: 3 additions & 3 deletions colibri/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@ def __init__(
self.positivity_penalty_settings = positivity_penalty_settings
self.integrability_penalty = integrability_penalty

self.pred_and_pdf = pdf_model.pred_and_pdf_func(
fit_xgrid, forward_map=forward_map
)
self.pred_and_pdf = pdf_model.predictions(fit_xgrid, forward_map=forward_map)

self.fast_kernel_arrays = fast_kernel_arrays
self.positivity_fast_kernel_arrays = positivity_fast_kernel_arrays
Expand Down Expand Up @@ -123,6 +121,8 @@ def log_likelihood(
jnp.ndarray
jax array with the value of the log-likelihood.
"""
# NOTE: here when passing params to pred_and_pdf (or more generally to predictions), we could first
# change its data structure to separate pdf model parameters and extra parameters
predictions, pdf = self.pred_and_pdf(params, fast_kernel_arrays)
# Select only the data relevant for this likelihood
# Especially important when using a training/validation split
Expand Down
2 changes: 1 addition & 1 deletion colibri/monte_carlo_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def run_monte_carlo_fit(monte_carlo_fit, pdf_model, output_path, replica_index,
"""
mc_fit = monte_carlo_fit

df = pd.DataFrame(mc_fit.optimized_parameters, index=pdf_model.param_names).T
df = pd.DataFrame(mc_fit.optimized_parameters, index=pdf_model.full_param_names).T

# In a Monte Carlo fit, replicas are written to the fit_replicas
# directory, and mc_postfit must then be applied to select valid ones
Expand Down
22 changes: 12 additions & 10 deletions colibri/param_initialisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def pdf_initial_parameters(pdf_model, param_initialiser_settings, replica_index=
param_initialiser_settings["type"] = "zeros"

if param_initialiser_settings["type"] == "zeros":
return jnp.array([0.0] * len(pdf_model.param_names))
return jnp.array([0.0] * len(pdf_model.full_param_names))

if "random_seed" in param_initialiser_settings:
random_seed = jax.random.PRNGKey(
Expand All @@ -43,7 +43,7 @@ def pdf_initial_parameters(pdf_model, param_initialiser_settings, replica_index=
else:
random_seed = jax.random.PRNGKey(replica_index)

param_names = pdf_model.param_names
full_param_names = pdf_model.full_param_names

if param_initialiser_settings["type"] == "normal":
means_setting = param_initialiser_settings.get("means", 0.0)
Expand Down Expand Up @@ -79,35 +79,37 @@ def pdf_initial_parameters(pdf_model, param_initialiser_settings, replica_index=
def expand(setting, default, name):
# If dict → check consistency
if isinstance(setting, dict):
if len(setting) != len(param_names):
if len(setting) != len(full_param_names):
raise ValueError(
f"'{name}' dict must have one entry per parameter "
f"(got {len(setting)} for {len(param_names)} parameters)."
f"(got {len(setting)} for {len(full_param_names)} parameters)."
)
return jnp.array([setting.get(p, default) for p in param_names])
return jnp.array([setting.get(p, default) for p in full_param_names])
# If scalar → broadcast
elif isinstance(setting, (int, float)):
return jnp.full(len(param_names), setting)
return jnp.full(len(full_param_names), setting)
else:
raise TypeError(f"'{name}' must be dict or scalar, got {type(setting)}")

means = expand(means_setting, 0.0, "means")
stds = expand(stds_setting, 1.0, "stds")

normal_samples = jax.random.normal(key=random_seed, shape=(len(param_names),))
normal_samples = jax.random.normal(
key=random_seed, shape=(len(full_param_names),)
)
return means + stds * normal_samples

if param_initialiser_settings["type"] == "uniform":
if "bounds" in param_initialiser_settings:
# Use param names from the model to order bounds correctly
bounds_dict = param_initialiser_settings["bounds"]

missing = [p for p in param_names if p not in bounds_dict]
missing = [p for p in full_param_names if p not in bounds_dict]
if missing:
raise ValueError(f"Missing bounds for parameters: {missing}")

# Per-parameter bounds
bounds = jnp.array([bounds_dict[param] for param in param_names])
bounds = jnp.array([bounds_dict[param] for param in full_param_names])
min_val = bounds[:, 0]
max_val = bounds[:, 1]

Expand All @@ -127,7 +129,7 @@ def expand(setting, default, name):

initial_values = jax.random.uniform(
key=random_seed,
shape=(len(pdf_model.param_names),),
shape=(len(pdf_model.full_param_names),),
minval=min_val,
maxval=max_val,
)
Expand Down
68 changes: 67 additions & 1 deletion colibri/pdf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,27 @@ class PDFModel(ABC):
@property
@abstractmethod
def param_names(self) -> list:
"""This should return a list of names for the fitted parameters of the model.
"""This should return a list of names for the fitted parameters of the PDF model.
The order of the names is important as it will be assumed to be the order of the parameters
fed to the model.
"""
pass

@property
def extra_params(self) -> list:
"""
This should return a list of names for parameters that are not used to parametrize the PDF model,
but are still relevant for the model (e.g. heavy quark masses, SMEFT parameters, etc.).

Default is an empty list.
"""
return []

@property
def full_param_names(self) -> list:
"""Returns a list of all parameter names, including both PDF model parameters and extra parameters."""
return self.param_names + self.extra_params

@abstractmethod
def grid_values_func(self, xgrid: ArrayLike) -> Callable[[jnp.array], jnp.ndarray]:
"""This function should produce a grid values function, which takes
Expand Down Expand Up @@ -90,3 +105,54 @@ def pred_and_pdf(params, fast_kernel_arrays):
return predictions, pdf

return pred_and_pdf

def extra_forward_map(self, predictions, extra_params):
"""
NOTE:
this function here is user specified and can potentially be anything.
The default forward_map always computes theory predictions from PDFs and FK tables.

This function can be completed to modify the predictions computed from the PDFs,
using extra parameters that are not part of the PDF parametrization.

e.g. (what the function could look like for SMEFT):
def extra_forward_map(self, predictions, extra_params):
# extra_params could contain SMEFT Wilson coefficients
# Modify predictions based on SMEFT contributions
modified_predictions = predictions + compute_smeft_contributions(predictions, extra_params)
return modified_predictions

Parameters
----------
predictions: jnp.ndarray
The theory predictions computed from the PDFs.
extra_params: list
The extra parameters to modify the predictions.
"""
raise NotImplementedError(
"extra_forward_map is not implemented for this PDFModel."
)

def predictions(self, xgrid, forward_map):
"""
The default simply returns self.pred_and_pdf_func when extra_forward_map is not implemented.

TODO: ...
"""
pred_and_pdf = self.pred_and_pdf_func(xgrid, forward_map)

# Check if extra_forward_map is implemented
if self.extra_forward_map.__func__ is PDFModel.extra_forward_map:
return pred_and_pdf
else:

def modified_pred_and_pdf(params, fast_kernel_arrays):
predictions, pdf = pred_and_pdf(
params[: len(self.param_names)], fast_kernel_arrays
)
modified_predictions = self.extra_forward_map(
predictions, params[len(self.param_names) :]
)
return modified_predictions, pdf
Comment on lines +149 to +156
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thinking more about this, I am not sure how this can be done.
The call function in the likelihood.py module is passed directly to the sampler and requires an array as input.
Do you have some ideas @LucaMantani ?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it shouldn't be impossible to modify the call method so that params is not an array. I did that in the batching PR, modifying the batch_idx to a batch dataclass.


return modified_pred_and_pdf
Comment on lines +136 to +158
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function takes the same inputs of pred_and_pdf_func, so maybe all of this can be implemented directly inside of it?

2 changes: 1 addition & 1 deletion colibri/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def wmin_param(params):


MOCK_PDF_MODEL = Mock()
MOCK_PDF_MODEL.param_names = ["param1", "param2"]
MOCK_PDF_MODEL.full_param_names = ["param1", "param2"]
MOCK_PDF_MODEL.grid_values_func = lambda xgrid: lambda params: jnp.sum(
jnp.array([param * TEST_PDF_GRID for param in params]), axis=0
)
Expand Down
8 changes: 4 additions & 4 deletions colibri/tests/test_analytic_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_analytic_fit(caplog):
assert (
result.resampled_posterior.shape[0] == analytic_settings["n_posterior_samples"]
)
assert len(result.param_names) == len(MOCK_PDF_MODEL.param_names)
assert len(result.full_param_names) == len(MOCK_PDF_MODEL.full_param_names)

# Check that it works if min_max_prior is False
analytic_settings["min_max_prior"] = False
Expand All @@ -109,7 +109,7 @@ def test_analytic_fit(caplog):
result_2.resampled_posterior.shape[0]
== analytic_settings["n_posterior_samples"]
)
assert len(result_2.param_names) == len(MOCK_PDF_MODEL.param_names)
assert len(result_2.full_param_names) == len(MOCK_PDF_MODEL.full_param_names)


def test_analytic_fit_different_priors(caplog):
Expand Down Expand Up @@ -144,7 +144,7 @@ def test_analytic_fit_different_priors(caplog):
assert (
result.resampled_posterior.shape[0] == analytic_settings["n_posterior_samples"]
)
assert len(result.param_names) == len(MOCK_PDF_MODEL.param_names)
assert len(result.full_param_names) == len(MOCK_PDF_MODEL.full_param_names)

PRIOR_SETTINGS2 = PriorSettings(
**{
Expand Down Expand Up @@ -174,7 +174,7 @@ def test_run_analytic_fit(mock_write_exportgrid, tmp_path):
mock_analytic_fit.resampled_posterior = jax.random.normal(
jax.random.PRNGKey(0), (10, 2)
)
mock_analytic_fit.param_names = ["param1", "param2"]
mock_analytic_fit.full_param_names = ["param1", "param2"]
mock_analytic_fit.full_posterior_samples = jax.random.normal(
jax.random.PRNGKey(0), (100, 2)
)
Expand Down
2 changes: 1 addition & 1 deletion colibri/tests/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def test_check_pdf_model_is_linear(mock_fast_kernel_arrays, mock_make_pred_data)

# Create a mock for the PDF model
mock_pdf_model = MagicMock()
mock_pdf_model.param_names = ["a", "b", "c"]
mock_pdf_model.full_param_names = ["a", "b", "c"]

# Mock the behavior of pred_and_pdf_func to return a linear model
def linear_model(params, fk):
Expand Down
Loading
Loading