diff --git a/src/smefit/analyze/fisher.py b/src/smefit/analyze/fisher.py index 2d813ece..7fdc9dc7 100644 --- a/src/smefit/analyze/fisher.py +++ b/src/smefit/analyze/fisher.py @@ -59,13 +59,15 @@ class FisherCalculator: coefficient manager datasets: smefit.loader.DataTuple DataTuple object with all the data information - + best_fit_point: pandas.DataFrame + best fit point of the coefficients """ - def __init__(self, coefficients, datasets, compute_quad): + def __init__(self, coefficients, datasets, best_fit_point, compute_quad): self.coefficients = coefficients self.free_parameters = self.coefficients.free_parameters.index self.datasets = datasets + self.best_fit_point = best_fit_point # update eft corrections with the constraints if compute_quad: @@ -97,72 +99,39 @@ def compute_linear(self): fisher_tab, index=self.datasets.ExpNames, columns=self.free_parameters ) - def compute_quadratic(self, posterior_df, smeft_predictions): + def compute_quadratic(self): """Compute quadratic Fisher information.""" - quad_fisher = [] - # compute some average values over the replicas - # delta exp - th (n_dat) - delta_th = self.datasets.Commondata - np.mean(smeft_predictions, axis=0) - # c, c**2 mean (n_free_op) - posterior_df = posterior_df[self.free_parameters] - c_mean = np.mean(posterior_df.values, axis=0) - c2_mean = np.mean(posterior_df.values**2, axis=0) - - # squared quad corr - diag_corr = np.diagonal(self.new_QuadraticCorrections, axis1=0, axis2=1) - off_diag_corr = self.new_QuadraticCorrections - diag_index = np.diag_indices(self.free_parameters.size) - off_diag_corr[diag_index[0], diag_index[1], :] = 0 - - # additional tensors - tmp = np.einsum("ri,ijk->rjk", posterior_df, off_diag_corr, optimize="optimal") - A_all = np.mean(tmp, axis=0) # (n_free_op, n_dat) - B_all = ( - np.einsum("rj,rjk->jk", posterior_df, tmp, optimize="optimal") - / posterior_df.shape[0] - ) # (n_free_op, n_dat) - D_all = ( - np.einsum("rjk,rjl->jkl", tmp, tmp, optimize="optimal") - / posterior_df.shape[0] - ) # (n_free_op, n_dat, n_dat) + best_fit_point = self.best_fit_point[self.free_parameters].values.flatten() + + # symmeterise the quadratic corrections s.t. each off diagonal component carries half ot the total + quad_symmetrised = 0.5 * ( + np.einsum("ij...->ij...", self.new_QuadraticCorrections) + + np.einsum("ij...->ji...", self.new_QuadraticCorrections) + ) + covmat = self.datasets.CovMat + deltaT = self.new_LinearCorrections + 2 * np.einsum( + "l, ilm -> im", best_fit_point, quad_symmetrised + ) + + quad_fisher = [] cnt = 0 - for ndat in track( - self.datasets.NdataExp, - description="[green]Computing quadratic Fisher information per dataset...", - ): - # slice the big matrices + + # this neglects correlations across datasets + for ndat in self.datasets.NdataExp: idxs = slice(cnt, cnt + ndat) - quad_corr = diag_corr[idxs, :].T - lin_corr = self.new_LinearCorrections[:, idxs] - inv_corr = self.datasets.InvCovMat[idxs, idxs] - delta = delta_th[idxs] - A = A_all[:, idxs] - B = B_all[:, idxs] - D = D_all[:, idxs, idxs] - - # (n_free_op) - fisher_row = ( - -quad_corr @ inv_corr @ delta.T - - delta @ inv_corr @ quad_corr.T - + lin_corr @ inv_corr @ A.T - + A @ inv_corr @ lin_corr.T - + 2 - * c_mean - @ ( - lin_corr @ inv_corr @ quad_corr.T - + quad_corr @ inv_corr @ lin_corr.T - ) - + 2 * (B @ inv_corr @ quad_corr.T + quad_corr @ inv_corr @ B.T) - + 4 * c2_mean @ quad_corr @ inv_corr @ quad_corr.T - + np.einsum("ikl,kl -> i", D, inv_corr, optimize="optimal") + invcovmat_dataset = np.linalg.inv(covmat[idxs, idxs]) + fisher_dataset = np.einsum( + "im, mn, jn", deltaT[:, idxs], invcovmat_dataset, deltaT[:, idxs] ) - quad_fisher.append(np.diag(fisher_row)) + quad_fisher.append(np.diag(fisher_dataset)) cnt += ndat + # the full fisher is instead given by + # fisher_quad_all = np.einsum("im, mn, jn", A, covmat, A) self.quad_fisher = pd.DataFrame( - quad_fisher + self.lin_fisher.values, + quad_fisher, index=self.datasets.ExpNames, columns=self.free_parameters, ) @@ -485,7 +454,7 @@ def plot_values(ax, dfs, cmap, norm, labels=None): label=labels[0], ), mpatches.Polygon( - [[0.5, -0.5], [0.5, 0.5], [0.5, 0.5]], + [[0.5, -0.5], [0.5, 0.5], [-0.5, 0.5]], closed=True, fc="none", edgecolor="black", @@ -575,6 +544,11 @@ def plot_heatmap( if column_names is not None: custom_ordering = [list(column.keys())[0] for column in column_names] fisher_dfs = [fisher_df.loc[custom_ordering] for fisher_df in fisher_dfs] + if quad_fisher_df is not None: + quad_fisher_dfs = [ + quad_fisher_df.loc[custom_ordering] + for quad_fisher_df in quad_fisher_dfs + ] x_labels = [list(column.values())[0] for column in column_names] else: x_labels = [ @@ -607,6 +581,7 @@ def plot_heatmap( ax.set_title(r"\rm Linear", fontsize=20, y=-0.08) cax1 = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.5) colour_bar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax1) + colour_bar.ax.tick_params(labelsize=22) if quad_fisher_df is not None: ax = fig.add_subplot(122) @@ -624,6 +599,7 @@ def plot_heatmap( colour_bar = fig.colorbar( mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax1 ) + colour_bar.ax.tick_params(labelsize=22) fig.subplots_adjust(top=0.9) diff --git a/src/smefit/analyze/report.py b/src/smefit/analyze/report.py index 4bf8342b..ead33a68 100644 --- a/src/smefit/analyze/report.py +++ b/src/smefit/analyze/report.py @@ -523,7 +523,12 @@ def fisher( fishers = {} for fit in fit_list: compute_quad = fit.config["use_quad"] - fisher_cal = FisherCalculator(fit.coefficients, fit.datasets, compute_quad) + fisher_cal = FisherCalculator( + fit.coefficients, + fit.datasets, + fit.results["best_fit_point"], + compute_quad, + ) fisher_cal.compute_linear() fisher_cal.lin_fisher = fisher_cal.normalize( fisher_cal.lin_fisher, norm=norm, log=log @@ -535,9 +540,7 @@ def fisher( # if necessary compute the quadratic Fisher if compute_quad: - fisher_cal.compute_quadratic( - fit.results["samples"], fit.smeft_predictions - ) + fisher_cal.compute_quadratic() fisher_cal.quad_fisher = fisher_cal.normalize( fisher_cal.quad_fisher, norm=norm, log=log ) diff --git a/src/smefit/loader.py b/src/smefit/loader.py index e1430d6c..10f95bb8 100644 --- a/src/smefit/loader.py +++ b/src/smefit/loader.py @@ -26,6 +26,7 @@ "ExpNames", "NdataExp", "InvCovMat", + "CovMat", "ThCovMat", "Luminosity", "Replica", @@ -737,6 +738,7 @@ def load_datasets( np.array(exp_name), np.array(n_data_exp), np.linalg.inv(fit_covmat), + fit_covmat, theory_covariance, np.array(lumi_exp), replica,