diff --git a/src/smefit/external_chi2_smelli.py b/src/smefit/external_chi2_smelli.py new file mode 100644 index 00000000..482c654b --- /dev/null +++ b/src/smefit/external_chi2_smelli.py @@ -0,0 +1,81 @@ +"""External χ² interface for smelli.""" + +from __future__ import annotations + +from typing import Iterable, List + +import numpy as np + +from smefit.rge.rge import RGE + + +class SmelliChi2: + """Compute a smelli likelihood contribution in a given SMEFiT fit.""" + + def __init__( + self, + coefficients, + rge_dict=None, + include_likelihoods: Iterable[str] | None = None, + initial_scale: float | None = None, + ): + self.names = list(coefficients.name) + + # Build smelli objects lazily to avoid hard import side effects. + from smelli import GlobalLikelihood + from wilson import Wilson + + self._Wilson = Wilson + self.include_likelihoods = ( + ["likelihood_ewpt.yaml"] + if include_likelihoods is None + else list(include_likelihoods) + ) + + # Initial scale where SMEFT coefficients are defined. + if rge_dict is not None and "init_scale" in rge_dict: + self.scale = ( + float(initial_scale) + if initial_scale is not None + else float(rge_dict.get("init_scale", 1e3)) + ) + + # Translate SMEFiT basis to Warsaw basis. + self.translation_basis = RGE( + self.names, + self.scale, + rge_dict.get("smeft_accuracy", "integrate"), + rge_dict.get("adm_QCD", False), + rge_dict.get("yukawa", "top"), + ).RGEbasis + else: + self.scale = float(initial_scale) if initial_scale is not None else 1e3 + self.translation_basis = {op: {op: 1.0} for op in self.names} + + self.gl = GlobalLikelihood( + eft="SMEFT", basis="Warsaw", include_likelihoods=self.include_likelihoods + ) + + def _build_wilson_dict(self, coefficient_values): + coeff_values = np.asarray(coefficient_values, dtype=float) + if coeff_values.shape[0] != len(self.names): + raise ValueError( + "Coefficient vector length does not match fitted coefficients " + f"({coeff_values.shape[0]} != {len(self.names)})." + ) + + wc_dict = {} + for op, c in zip(self.names, coeff_values): + for key, val in self.translation_basis[op].items(): + wc_dict[key] = wc_dict.get(key, 0.0) + float(val) * c + return {k: v for k, v in wc_dict.items() if v != 0.0} + + def compute_chi2(self, coefficient_values): + wc_dict = self._build_wilson_dict(coefficient_values) + w = self._Wilson(wc_dict, self.scale, eft="SMEFT", basis="Warsaw") + pt = self.gl.parameter_point(w) + return -2.0 * float(pt.log_likelihood_global()) + + +# Backward-compatible class name for older runcards. +smelli_chi2 = SmelliChi2 diff --git a/src/smefit/optimize/ultranest.py b/src/smefit/optimize/ultranest.py index 7ce4ecf2..dbc069a0 100644 --- a/src/smefit/optimize/ultranest.py +++ b/src/smefit/optimize/ultranest.py @@ -340,7 +340,7 @@ def produce_all_params(self, params): return all_params @partial(jax.jit, static_argnames=["self"]) - def gaussian_loglikelihood(self, params): + def gaussian_loglikelihood_jit(self, params): """Multi gaussian log likelihood function. Parameters @@ -358,6 +358,38 @@ def gaussian_loglikelihood(self, params): return -0.5 * self.chi2_func_ns(all_params) + def gaussian_loglikelihood(self, params): + """Gaussian log likelihood wrapper. + + This keeps the JIT path for pure SMEFiT χ², while allowing external + likelihoods (e.g. smelli) to run in eager mode. + """ + + if self.chi2_ext is None: + return self.gaussian_loglikelihood_jit(params) + + all_params = self.produce_all_params(params) + return -0.5 * self.chi2_func_ext_ns(all_params) + + def chi2_func_ext_ns(self, params): + """Compatibility χ² path for non-JAX external likelihood contributions.""" + if self.loaded_datasets is not None: + chi2_tot = chi2.compute_chi2( + self.loaded_datasets, + params, + self.use_quad, + self.use_multiplicative_prescription, + use_replica=False, + ) + else: + chi2_tot = 0 + + for chi2_ext in self.chi2_ext: + chi2_ext_i = chi2_ext(params) + chi2_tot += chi2_ext_i + + return chi2_tot + @partial(jax.jit, static_argnames=["self"]) def flat_prior(self, hypercube): """Update the prior function. @@ -383,12 +415,13 @@ def run_sampling(self): if self.store_raw: log_dir = self.results_path / "ultranest_log" - if self.vectorized: + if self.vectorized and self.chi2_ext is None: loglikelihood = jax.vmap(self.gaussian_loglikelihood) flat_prior = jax.vmap(self.flat_prior) else: loglikelihood = self.gaussian_loglikelihood flat_prior = self.flat_prior + ns_vectorized = self.vectorized and self.chi2_ext is None _logger.info(f"Running fit with backend: {jbackend.get_backend().platform}") t1 = time.time() @@ -398,7 +431,7 @@ def run_sampling(self): flat_prior, log_dir=log_dir, resume=True, - vectorized=self.vectorized, + vectorized=ns_vectorized, ) if self.npar > 10: # set up step sampler. Here, we use a differential evolution slice sampler: