184 changes: 165 additions & 19 deletions examples/bumphunt_example/run_example.py
@@ -7,17 +7,20 @@
from gatohep.losses import high_bkg_uncertainty_penalty, low_bkg_penalty
from gatohep.models import gato_gmm_model
from gatohep.plotting_utils import (
assign_bins_and_order,
make_gif,
plot_bias_history,
plot_bin_boundaries_2D,
plot_category_mass_spectra,
plot_history,
plot_inclusive_mass,
plot_significance_comparison,
plot_yield_vs_uncertainty,
)
from gatohep.utils import (
LearningRateScheduler,
TemperatureScheduler,
asymptotic_significance,
build_category_mass_maps,
compute_mass_reweight_factors,
convert_mass_data_to_tensors,
@@ -71,16 +74,88 @@ def call(self, data_dict, reweight=None, reweight_processes=None):
return loss, bkg_yield, bkg_sum_w2, z1, z2


@tf.function
def train_step(model, data, opt, reweight, lamY, lamU, thrY, thrU):
with tf.GradientTape() as tape:
loss, B_sig, B_sig_w2, z1, z2 = model.call(data, reweight)
penalty_y = low_bkg_penalty(B_sig, threshold=thrY)
penalty_u = high_bkg_uncertainty_penalty(B_sig_w2, B_sig, rel_threshold=thrU)
total = loss + lamY * penalty_y + lamU * penalty_u
grads = tape.gradient(total, model.trainable_variables)
opt.apply_gradients(zip(grads, model.trainable_variables))
return total, loss, penalty_y, penalty_u, z1, z2
def compute_significances_from_assignments(
assignments, data_dict, n_bins, mass_low, mass_high
):
"""
Sum signal/background yields per bin using a provided assignment map.

Parameters
----------
assignments : dict[str, np.ndarray]
Hard bin indices per process (negative entries ignored).
data_dict : Mapping[str, pandas.DataFrame]
Event tables containing ``mass`` and ``weight`` columns.
n_bins : int
Number of categories / bins.
mass_low, mass_high : float
Higgs-window boundaries.

Returns
-------
tuple[float, float]
Significances for ``signal1`` and ``signal2``.
"""
s1 = np.zeros(n_bins, dtype=np.float64)
s2 = np.zeros_like(s1)
bkg = np.zeros_like(s1)

for proc, assign in assignments.items():
if assign.size == 0:
continue
df = data_dict[proc]
masses = df["mass"].values
weights = df["weight"].values
mask = (
(assign >= 0)
& (masses >= mass_low)
& (masses <= mass_high)
)
if not np.any(mask):
continue
bins = assign[mask]
w = weights[mask]
accum = np.zeros(n_bins, dtype=np.float64)
np.add.at(accum, bins, w)
if proc == "signal1":
s1 += accum
elif proc == "signal2":
s2 += accum
else:
bkg += accum

s1_tf = tf.constant(s1, dtype=tf.float32)
s2_tf = tf.constant(s2, dtype=tf.float32)
bkg_tf = tf.constant(bkg, dtype=tf.float32)

z1_bins = asymptotic_significance(s1_tf, bkg_tf + s2_tf)
z2_bins = asymptotic_significance(s2_tf, bkg_tf + s1_tf)
z1 = float(tf.sqrt(tf.reduce_sum(z1_bins**2)))
z2 = float(tf.sqrt(tf.reduce_sum(z2_bins**2)))
return z1, z2
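
# --- Illustrative sketch (not part of this PR) ---------------------------
# ``gatohep.utils.asymptotic_significance`` is assumed above to apply the
# standard Asimov formula per bin before the quadrature sum. A plain-NumPy
# equivalent of that assumption, for reference:
def asymptotic_significance_np(s, b, eps=1e-9):
    """Per-bin Z = sqrt(2 * ((s + b) * ln(1 + s / b) - s))."""
    b = np.maximum(b, eps)
    z_sq = 2.0 * ((s + b) * np.log1p(s / b) - s)
    return np.sqrt(np.maximum(z_sq, 0.0))
# combined: float(np.sqrt(np.sum(asymptotic_significance_np(s1, bkg + s2) ** 2)))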


def build_argmax_assignments(data_dict, nbins, sig_index):
"""
Produce equidistant bin indices based on a softmax component.

Only events whose argmax equals ``sig_index`` receive a valid bin,
reproducing the baseline used in the three-class example.
"""
edges = np.linspace(0.33, 1.0, nbins + 1, dtype=np.float32)
assignments = {}
for proc, df in data_dict.items():
if df.empty:
assignments[proc] = np.array([], dtype=np.int32)
continue
outputs = np.stack(df["NN_output"].values)
argmax = np.argmax(outputs, axis=1)
values = outputs[:, sig_index]
bins = np.clip(np.digitize(values, edges, right=False) - 1, 0, nbins - 1)
valid = argmax == sig_index
assign = np.where(valid, bins, -1).astype(np.int32)
assignments[proc] = assign
return assignments
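
# --- Illustrative usage on hypothetical toy data (not part of this PR) ---
# Only events with argmax == sig_index get a valid bin; the rest are
# parked at -1 and later dropped by the ``assign >= 0`` mask. Assuming
# pandas as pd and numpy as np:
#
# toy = pd.DataFrame({"NN_output": [np.array([0.2, 0.7, 0.1]),
#                                   np.array([0.6, 0.3, 0.1])]})
# build_argmax_assignments({"bkg": toy}, nbins=5, sig_index=1)["bkg"]
# # -> array([ 2, -1], dtype=int32): only the first event has argmax == 1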


def main():
@@ -113,10 +188,25 @@ def main():
data_2d = slice_to_2d_features(data_full)
tensor_data = convert_mass_data_to_tensors(data_2d)

plot_inclusive_mass(data_2d, path_plots, sig_scales=(50, 250))

sig_low = 125.0 - args.mass_sigma
sig_high = 125.0 + args.mass_sigma
plot_inclusive_mass(data_2d, path_plots, sig_scales=(50, 250))

baseline_bins = [2, 5, 10]
baseline_results = {"signal1": {}, "signal2": {}}
gato_results = {"signal1": {}, "signal2": {}}

for nbins in baseline_bins:
for sig_idx, sig_name in enumerate(("signal1", "signal2")):
assignments = build_argmax_assignments(data_2d, nbins, sig_idx)
z1, z2 = compute_significances_from_assignments(
assignments,
data_2d,
nbins,
sig_low,
sig_high,
)
baseline_results[sig_name][nbins] = z1 if sig_idx == 0 else z2

for n_cats in args.gato_bins:
print(f"\n--- Optimising {n_cats} bins ---")
@@ -139,10 +229,24 @@ def main():
mode="cosine",
)

@tf.function
def train_step(tdata, reweight_tensor, lamY, lamU, thrY, thrU):
with tf.GradientTape() as tape:
loss, B_sig, B_sig_w2, z1, z2 = model.call(tdata, reweight_tensor)
penalty_y = low_bkg_penalty(B_sig, threshold=thrY)
penalty_u = high_bkg_uncertainty_penalty(
B_sig_w2, B_sig, rel_threshold=thrU
)
total = loss + lamY * penalty_y + lamU * penalty_u
grads = tape.gradient(total, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
return total, loss, penalty_y, penalty_u, z1, z2, B_sig

reweight = tf.ones(n_cats, dtype=tf.float32)
loss_history = []
penalty_y_hist = []
penalty_u_hist = []
continuum_history = []
bias_history = []
bias_epochs = []
temp_history = []
@@ -164,10 +268,8 @@ def main():
reweight = tf.constant(factors, dtype=tf.float32)
print(f"Updated reweight factors: {factors}")

_, loss, penY, penU, z1, z2 = train_step(
model,
_, loss, penY, penU, z1, z2, B_bins = train_step(
tensor_data,
optimizer,
reweight,
args.lam_yield,
args.lam_unc,
@@ -180,6 +282,11 @@ def main():
loss_history.append(float(loss.numpy()))
penalty_y_hist.append(float(penY.numpy()))
penalty_u_hist.append(float(penU.numpy()))
reweight_np = reweight.numpy()
B_np = B_bins.numpy()
continuum_history.append(
B_np / np.maximum(reweight_np, 1e-6)
)

if epoch % 10 == 0 or epoch == args.epochs - 1:
lr_value = getattr(optimizer, "learning_rate", getattr(optimizer, "lr", None))
@@ -240,6 +347,23 @@ def main():
y_label="High-uncertainty penalty",
x_label="Epoch",
)
plot_history(
np.array(continuum_history),
os.path.join(path_bins, f"continuum_background_{n_cats}.pdf"),
y_label="Continuum background (100-180 GeV)",
x_label="Epoch",
boundaries=True,
boundary_labels=[f"Cat. {i}" for i in range(n_cats)],
)
plot_history(
np.array(continuum_history),
os.path.join(path_bins, f"continuum_background_{n_cats}_log.pdf"),
y_label="Continuum background (100-180 GeV)",
x_label="Epoch",
boundaries=True,
log_scale=True,
boundary_labels=[f"Cat. {i}" for i in range(n_cats)],
)
plot_bias_history(
bias_history,
os.path.join(path_bins, f"bias_history_{n_cats}.pdf"),
@@ -248,11 +372,24 @@ def main():
temp_label="Temperature",
)

assignments = model.get_bin_indices(
raw_assign = model.get_bin_indices(
{p: {"NN_output": tensor_data[p]["NN_output"]} for p in tensor_data}
)
assign_np = {k: v.numpy() for k, v in assignments.items()}
per_cat_hists = build_category_mass_maps(assign_np, data_2d, n_cats)
raw_assign_np = {k: v.numpy() for k, v in raw_assign.items()}

assign_dict, order, _, inv = assign_bins_and_order(model, data_2d, reduce=False)
assign_np = {k: np.asarray(v) for k, v in assign_dict.items()}

z1_opt, z2_opt = compute_significances_from_assignments(
raw_assign_np,
data_2d,
n_cats,
sig_low,
sig_high,
)
gato_results["signal1"][n_cats] = z1_opt
gato_results["signal2"][n_cats] = z2_opt
per_cat_hists = build_category_mass_maps(raw_assign_np, data_2d, n_cats)
plot_category_mass_spectra(
per_cat_hists,
os.path.join(path_bins, "mass_spectra"),
@@ -261,7 +398,7 @@ def main():

plot_bin_boundaries_2D(
model,
list(range(n_cats)),
order,
os.path.join(path_bins, f"bin_boundaries_{n_cats}_bins.pdf"),
)

@@ -284,6 +421,15 @@ def main():
log=True,
)

remapped_baseline = {
sig: {2 * n + 1: baseline_results[sig][n] for n in baseline_results[sig]}
for sig in baseline_results
}
plot_significance_comparison(
remapped_baseline,
gato_results,
os.path.join(path_plots, "significance_comparison.pdf"),
)

if __name__ == "__main__":
main()
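
# --- Design note with a minimal sketch (not part of this PR) --------------
# ``train_step`` now lives inside the per-``n_cats`` loop so that every new
# model/optimizer pair gets its own ``tf.function``; feeding fresh model
# objects into one module-level ``tf.function`` would retrace it on each
# loop iteration. A factory makes the same pattern explicit:
#
# def make_train_step(model, optimizer):
#     @tf.function
#     def step(batch):
#         with tf.GradientTape() as tape:
#             loss = model(batch)
#         grads = tape.gradient(loss, model.trainable_variables)
#         optimizer.apply_gradients(zip(grads, model.trainable_variables))
#         return loss
#     return step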
24 changes: 18 additions & 6 deletions src/gatohep/plotting_utils.py
@@ -272,6 +272,7 @@ def plot_history(
y_label="Value",
x_label="Epoch",
boundaries=False,
boundary_labels=None,
title=None,
log_scale=False,
):
@@ -299,23 +300,34 @@
-------
None
"""
epochs = np.arange(len(history_data))
values = np.asarray(history_data, dtype=float)
epochs = np.arange(values.shape[0])
fig, ax = plt.subplots(figsize=(8, 6))

if not boundaries: # scalar history
ax.plot(epochs, history_data, marker="o")
else: # matrix (epochs, n_tracks)
values = np.asarray(history_data, dtype=float)
multi_series = boundaries or (values.ndim > 1 and values.shape[1] > 1)

if not multi_series:
ax.plot(epochs, values, marker="o")
else:
if values.ndim == 1:
values = values[:, None]
n_trk = values.shape[1]
cmap = plt.get_cmap("tab20", n_trk)
labels = (
boundary_labels
if boundary_labels is not None and len(boundary_labels) == n_trk
else [f"Series {t + 1}" for t in range(n_trk)]
)
for t in range(n_trk):
ax.plot(
epochs,
values[:, t],
marker="o",
markersize=3,
color=cmap(t),
label=f"Boundary {t + 1}",
label=labels[t],
)

ax.set_xlabel(x_label, fontsize=22)
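
# --- Illustrative usage (not part of this PR) ------------------------------
# With ``boundaries=True`` and the new ``boundary_labels`` argument, a
# (n_epochs, n_series) array is drawn as one labelled curve per column:
#
# arr = np.asarray(continuum_history)   # shape (n_epochs, n_cats)
# plot_history(
#     arr,
#     "continuum_background.pdf",
#     y_label="Continuum background",
#     boundaries=True,
#     boundary_labels=[f"Cat. {i}" for i in range(arr.shape[1])],
# )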
27 changes: 13 additions & 14 deletions src/gatohep/utils.py
@@ -370,13 +370,14 @@ def compute_mass_reweight_factors(
mass_sb_high=180.0,
mass_sig_low=123.5,
mass_sig_high=126.5,
nbins=60,
nbins=10,
):
"""
Fit an exponential to each category's diphoton-mass spectrum and
return per-bin factors that map the continuum yield in the full
sideband (100-180 GeV by default) to the yield expected in the
signal window (125 +/- 1 sigma).

"""

def is_signal(name: str) -> bool:
@@ -438,16 +439,13 @@ def integral_exp(A, B, x1, x2):

edges = hist_obj.axes[0].edges
centers = 0.5 * (edges[:-1] + edges[1:])
errs = np.sqrt(np.maximum(hist_obj.variances(), 1e-12))
p0 = [max(vals[0], 1e-6), -0.03]
try:
(A, B), _ = curve_fit(
exp_func,
centers,
vals,
p0=p0,
sigma=errs,
absolute_sigma=True,
maxfev=2000,
)
pred_sig = integral_exp(A, B, mass_sig_low, mass_sig_high) / bin_width
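# --- Math note (not part of this PR) ---------------------------------------
# For the fitted shape f(m) = A * exp(B * m), the helper evaluates the
# closed form integral_exp(A, B, x1, x2) = (A / B) * (exp(B * x2) - exp(B * x1));
# dividing by ``bin_width`` presumably converts the integral of the per-bin
# fit back into an event yield in the signal window.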
@@ -689,15 +687,16 @@ def build_category_mass_maps(
if cat_ids is None:
continue
mask = cat_ids == k
if not np.any(mask):
continue
proc_hists[proc] = create_hist(
df["mass"].values[mask],
weights=df["weight"].values[mask],
bins=bins,
low=mass_range[0],
high=mass_range[1],
name=axis_name,
)
masses = df["mass"].values[mask]
weights = df["weight"].values[mask]
h = hist.Hist.new.Reg(
bins,
mass_range[0],
mass_range[1],
name=axis_name
).Weight()
if masses.size:
h.fill(masses, weight=weights)
proc_hists[proc] = h
per_cat.append(proc_hists)
return per_cat
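
# --- Illustrative sketch (not part of this PR) ------------------------------
# The inlined histogram construction relies on the ``hist`` library's fluent
# API; the same pattern in isolation:
import hist
import numpy as np

h = hist.Hist.new.Reg(80, 100.0, 180.0, name="mass").Weight()
h.fill(np.array([124.8, 125.3]), weight=np.array([0.7, 1.1]))
print(h.values())     # per-bin weighted yields
print(h.variances())  # per-bin sums of weight**2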