Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions src/smefit/external_chi2_smelli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""External χ² interface for smelli."""

from __future__ import annotations

from typing import Iterable, List

import numpy as np

from smefit.rge.rge import RGE


class SmelliChi2:
"""Compute a smelli likelihood contribution in a given SMEFiT fit."""

def __init__(
self,
coefficients,
rge_dict=None,
include_likelihoods: Iterable[str] | None = None,
initial_scale: float | None = None,
):
self.names = list(coefficients.name)

# Build smelli objects lazily to avoid hard import side effects.
from smelli import GlobalLikelihood
from wilson import Wilson

self._Wilson = Wilson
self.include_likelihoods = (
["likelihood_ewpt.yaml"]
if include_likelihoods is None
else list(include_likelihoods)
)

# Initial scale where SMEFT coefficients are defined.
if rge_dict is not None and "init_scale" in rge_dict:
self.scale = (
float(initial_scale)
if initial_scale is not None
else float(rge_dict.get("init_scale", 1e3))
)

# Translate SMEFiT basis to Warsaw basis.
self.translation_basis = RGE(
self.names,
self.scale,
rge_dict.get("smeft_accuracy", "integrate"),
rge_dict.get("adm_QCD", False),
rge_dict.get("yukawa", "top"),
).RGEbasis
else:
self.scale = float(initial_scale) if initial_scale is not None else 1e3
self.translation_basis = {op: {op: 1.0} for op in self.names}

self.gl = GlobalLikelihood(
eft="SMEFT", basis="Warsaw", include_likelihoods=self.include_likelihoods
)

def _build_wilson_dict(self, coefficient_values):
coeff_values = np.asarray(coefficient_values, dtype=float)
if coeff_values.shape[0] != len(self.names):
raise ValueError(
"Coefficient vector length does not match fitted coefficients "
f"({coeff_values.shape[0]} != {len(self.names)})."
)

wc_dict = {}
for op, c in zip(self.names, coeff_values):
for key, val in self.translation_basis[op].items():
wc_dict[key] = wc_dict.get(key, 0.0) + float(val) * c
return {k: v for k, v in wc_dict.items() if v != 0.0}

def compute_chi2(self, coefficient_values):
wc_dict = self._build_wilson_dict(coefficient_values)
w = self._Wilson(wc_dict, self.scale, eft="SMEFT", basis="Warsaw")
pt = self.gl.parameter_point(w)
return -2.0 * float(pt.log_likelihood_global())


# Backward-compatible class name for older runcards.
smelli_chi2 = SmelliChi2
39 changes: 36 additions & 3 deletions src/smefit/optimize/ultranest.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def produce_all_params(self, params):
return all_params

@partial(jax.jit, static_argnames=["self"])
def gaussian_loglikelihood(self, params):
def gaussian_loglikelihood_jit(self, params):
"""Multi gaussian log likelihood function.

Parameters
Expand All @@ -358,6 +358,38 @@ def gaussian_loglikelihood(self, params):

return -0.5 * self.chi2_func_ns(all_params)

def gaussian_loglikelihood(self, params):
"""Gaussian log likelihood wrapper.

This keeps the JIT path for pure SMEFiT χ², while allowing external
likelihoods (e.g. smelli) to run in eager mode.
"""

if self.chi2_ext is None:
return self.gaussian_loglikelihood_jit(params)

all_params = self.produce_all_params(params)
return -0.5 * self.chi2_func_ext_ns(all_params)

def chi2_func_ext_ns(self, params):
"""Compatibility χ² path for non-JAX external likelihood contributions."""
if self.loaded_datasets is not None:
chi2_tot = chi2.compute_chi2(
self.loaded_datasets,
params,
self.use_quad,
self.use_multiplicative_prescription,
use_replica=False,
)
else:
chi2_tot = 0

for chi2_ext in self.chi2_ext:
chi2_ext_i = chi2_ext(params)
chi2_tot += chi2_ext_i

return chi2_tot

@partial(jax.jit, static_argnames=["self"])
def flat_prior(self, hypercube):
"""Update the prior function.
Expand All @@ -383,12 +415,13 @@ def run_sampling(self):
if self.store_raw:
log_dir = self.results_path / "ultranest_log"

if self.vectorized:
if self.vectorized and self.chi2_ext is None:
loglikelihood = jax.vmap(self.gaussian_loglikelihood)
flat_prior = jax.vmap(self.flat_prior)
else:
loglikelihood = self.gaussian_loglikelihood
flat_prior = self.flat_prior
ns_vectorized = self.vectorized and self.chi2_ext is None

_logger.info(f"Running fit with backend: {jbackend.get_backend().platform}")
t1 = time.time()
Expand All @@ -398,7 +431,7 @@ def run_sampling(self):
flat_prior,
log_dir=log_dir,
resume=True,
vectorized=self.vectorized,
vectorized=ns_vectorized,
)
if self.npar > 10:
# set up step sampler. Here, we use a differential evolution slice sampler:
Expand Down
Loading