Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 36 additions & 60 deletions src/smefit/analyze/fisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from matplotlib.legend_handler import HandlerPatch
from matplotlib.patches import Polygon
from mpl_toolkits.axes_grid1 import make_axes_locatable
from rich.progress import track

Check notice on line 11 in src/smefit/analyze/fisher.py

View check run for this annotation

codefactor.io / CodeFactor

src/smefit/analyze/fisher.py#L11

Unused track imported from rich.progress (unused-import)

from .latex_tools import latex_packages
from .pca import impose_constrain
Expand Down Expand Up @@ -59,13 +59,15 @@
coefficient manager
datasets: smefit.loader.DataTuple
DataTuple object with all the data information

best_fit_point: pandas.DataFrame
best fit point of the coefficients
"""

def __init__(self, coefficients, datasets, compute_quad):
def __init__(self, coefficients, datasets, best_fit_point, compute_quad):
self.coefficients = coefficients
self.free_parameters = self.coefficients.free_parameters.index
self.datasets = datasets
self.best_fit_point = best_fit_point

# update eft corrections with the constraints
if compute_quad:
Expand Down Expand Up @@ -97,72 +99,39 @@
fisher_tab, index=self.datasets.ExpNames, columns=self.free_parameters
)

def compute_quadratic(self, posterior_df, smeft_predictions):
def compute_quadratic(self):
"""Compute quadratic Fisher information."""
quad_fisher = []

# compute some average values over the replicas
# delta exp - th (n_dat)
delta_th = self.datasets.Commondata - np.mean(smeft_predictions, axis=0)
# c, c**2 mean (n_free_op)
posterior_df = posterior_df[self.free_parameters]
c_mean = np.mean(posterior_df.values, axis=0)
c2_mean = np.mean(posterior_df.values**2, axis=0)

# squared quad corr
diag_corr = np.diagonal(self.new_QuadraticCorrections, axis1=0, axis2=1)
off_diag_corr = self.new_QuadraticCorrections
diag_index = np.diag_indices(self.free_parameters.size)
off_diag_corr[diag_index[0], diag_index[1], :] = 0

# additional tensors
tmp = np.einsum("ri,ijk->rjk", posterior_df, off_diag_corr, optimize="optimal")
A_all = np.mean(tmp, axis=0) # (n_free_op, n_dat)
B_all = (
np.einsum("rj,rjk->jk", posterior_df, tmp, optimize="optimal")
/ posterior_df.shape[0]
) # (n_free_op, n_dat)
D_all = (
np.einsum("rjk,rjl->jkl", tmp, tmp, optimize="optimal")
/ posterior_df.shape[0]
) # (n_free_op, n_dat, n_dat)
best_fit_point = self.best_fit_point[self.free_parameters].values.flatten()

# symmeterise the quadratic corrections s.t. each off diagonal component carries half ot the total

Check notice on line 107 in src/smefit/analyze/fisher.py

View check run for this annotation

codefactor.io / CodeFactor

src/smefit/analyze/fisher.py#L107

Line too long (106/100) (line-too-long)
quad_symmetrised = 0.5 * (
np.einsum("ij...->ij...", self.new_QuadraticCorrections)
+ np.einsum("ij...->ji...", self.new_QuadraticCorrections)
)
covmat = self.datasets.CovMat

deltaT = self.new_LinearCorrections + 2 * np.einsum(
"l, ilm -> im", best_fit_point, quad_symmetrised
)

quad_fisher = []
cnt = 0
for ndat in track(
self.datasets.NdataExp,
description="[green]Computing quadratic Fisher information per dataset...",
):
# slice the big matrices

# this neglects correlations across datasets
for ndat in self.datasets.NdataExp:
idxs = slice(cnt, cnt + ndat)
quad_corr = diag_corr[idxs, :].T
lin_corr = self.new_LinearCorrections[:, idxs]
inv_corr = self.datasets.InvCovMat[idxs, idxs]
delta = delta_th[idxs]
A = A_all[:, idxs]
B = B_all[:, idxs]
D = D_all[:, idxs, idxs]

# (n_free_op)
fisher_row = (
-quad_corr @ inv_corr @ delta.T
- delta @ inv_corr @ quad_corr.T
+ lin_corr @ inv_corr @ A.T
+ A @ inv_corr @ lin_corr.T
+ 2
* c_mean
@ (
lin_corr @ inv_corr @ quad_corr.T
+ quad_corr @ inv_corr @ lin_corr.T
)
+ 2 * (B @ inv_corr @ quad_corr.T + quad_corr @ inv_corr @ B.T)
+ 4 * c2_mean @ quad_corr @ inv_corr @ quad_corr.T
+ np.einsum("ikl,kl -> i", D, inv_corr, optimize="optimal")
invcovmat_dataset = np.linalg.inv(covmat[idxs, idxs])
fisher_dataset = np.einsum(
"im, mn, jn", deltaT[:, idxs], invcovmat_dataset, deltaT[:, idxs]
)
quad_fisher.append(np.diag(fisher_row))
quad_fisher.append(np.diag(fisher_dataset))
cnt += ndat
# the full fisher is instead given by
# fisher_quad_all = np.einsum("im, mn, jn", A, covmat, A)

self.quad_fisher = pd.DataFrame(
quad_fisher + self.lin_fisher.values,
quad_fisher,
index=self.datasets.ExpNames,
columns=self.free_parameters,
)
Expand Down Expand Up @@ -485,7 +454,7 @@
label=labels[0],
),
mpatches.Polygon(
[[0.5, -0.5], [0.5, 0.5], [0.5, 0.5]],
[[0.5, -0.5], [0.5, 0.5], [-0.5, 0.5]],
closed=True,
fc="none",
edgecolor="black",
Expand Down Expand Up @@ -575,6 +544,11 @@
if column_names is not None:
custom_ordering = [list(column.keys())[0] for column in column_names]
fisher_dfs = [fisher_df.loc[custom_ordering] for fisher_df in fisher_dfs]
if quad_fisher_df is not None:
quad_fisher_dfs = [
quad_fisher_df.loc[custom_ordering]
for quad_fisher_df in quad_fisher_dfs
]
x_labels = [list(column.values())[0] for column in column_names]
else:
x_labels = [
Expand Down Expand Up @@ -607,6 +581,7 @@
ax.set_title(r"\rm Linear", fontsize=20, y=-0.08)
cax1 = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.5)
colour_bar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax1)
colour_bar.ax.tick_params(labelsize=22)

if quad_fisher_df is not None:
ax = fig.add_subplot(122)
Expand All @@ -624,6 +599,7 @@
colour_bar = fig.colorbar(
mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax1
)
colour_bar.ax.tick_params(labelsize=22)

fig.subplots_adjust(top=0.9)

Expand Down
11 changes: 7 additions & 4 deletions src/smefit/analyze/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,12 @@ def fisher(
fishers = {}
for fit in fit_list:
compute_quad = fit.config["use_quad"]
fisher_cal = FisherCalculator(fit.coefficients, fit.datasets, compute_quad)
fisher_cal = FisherCalculator(
fit.coefficients,
fit.datasets,
fit.results["best_fit_point"],
compute_quad,
)
fisher_cal.compute_linear()
fisher_cal.lin_fisher = fisher_cal.normalize(
fisher_cal.lin_fisher, norm=norm, log=log
Expand All @@ -535,9 +540,7 @@ def fisher(

# if necessary compute the quadratic Fisher
if compute_quad:
fisher_cal.compute_quadratic(
fit.results["samples"], fit.smeft_predictions
)
fisher_cal.compute_quadratic()
fisher_cal.quad_fisher = fisher_cal.normalize(
fisher_cal.quad_fisher, norm=norm, log=log
)
Expand Down
2 changes: 2 additions & 0 deletions src/smefit/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"ExpNames",
"NdataExp",
"InvCovMat",
"CovMat",
"ThCovMat",
"Luminosity",
"Replica",
Expand Down Expand Up @@ -737,6 +738,7 @@ def load_datasets(
np.array(exp_name),
np.array(n_data_exp),
np.linalg.inv(fit_covmat),
fit_covmat,
theory_covariance,
np.array(lumi_exp),
replica,
Expand Down
Loading