Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions examples/bumphunt_example/run_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,11 @@ def main():
model = DiphotonSoftmax(
n_cats=n_cats, temperature=1.0, mass_sigma=args.mass_sigma
)
optimizer = tf.keras.optimizers.RMSprop(0.1)
optimizer = tf.keras.optimizers.RMSprop(0.05)
lr_scheduler = LearningRateScheduler(
optimizer,
lr_initial=0.1,
lr_final=0.002,
lr_initial=0.05,
lr_final=0.001,
total_epochs=args.epochs,
mode="cosine",
)
Expand Down Expand Up @@ -202,6 +202,7 @@ def main():
list(range(n_cats)),
boundary_fname,
resolution=600,
annotation=f"Epoch {epoch}",
)
boundary_frames.append(boundary_fname)
print(
Expand Down
5 changes: 3 additions & 2 deletions examples/three_class_softmax_example/run_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def train_step(model, tdata, opt, lamY, lamU, thrY, thrU):
optimizer = tf.keras.optimizers.RMSprop(0.1)
lr_scheduler = LearningRateScheduler(
optimizer,
lr_initial=0.1,
lr_initial=0.05,
lr_final=0.001,
total_epochs=args.epochs,
mode="cosine",
Expand Down Expand Up @@ -264,7 +264,8 @@ def train_step(model, tdata, opt, lamY, lamU, thrY, thrU):
model,
[i for i in range(n_cats)],
boundary_fname,
resolution=500
resolution=500,
annotation=f"Epoch {ep}",
)
boundary_frames.append(boundary_fname)
loss_history.append(loss.numpy())
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies = [
"ml_dtypes>=0.4.1",
"mplhep",
"matplotlib>=3.7",
"pillow>=10.0",
"scipy",
"pandas",
"hist",
Expand Down Expand Up @@ -66,4 +67,4 @@ gpu = [
Homepage = "https://github.com/FloMau/gato-hep"
Repository = "https://github.com/FloMau/gato-hep"
Documentation = "https://gato-hep.readthedocs.io/en/latest/"
Issues = "https://github.com/FloMau/gato-hep/issues"
Issues = "https://github.com/FloMau/gato-hep/issues"
62 changes: 44 additions & 18 deletions src/gatohep/plotting_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import os
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.animation as animation
from matplotlib.colors import BoundaryNorm, ListedColormap
from matplotlib.patches import Ellipse
from PIL import Image

from gatohep.utils import build_mass_histograms

Expand Down Expand Up @@ -647,7 +647,7 @@ def plot_bin_boundaries_2D(
path_plot,
*,
resolution: int = 1000,
reduce: bool = False, # kept for API compatibility, ignored for dim==2
annotation: str | None = None,
):
"""
Plot hard-bin regions of a *2-D* GMM on the 2-simplex face
Expand All @@ -664,6 +664,9 @@ def plot_bin_boundaries_2D(
Grid resolution per axis. Default 500.
reduce : bool, optional
Ignored (for backward compatibility with older callers).
annotation : str, optional
Text rendered in the upper-left corner of the axes (useful for
tagging epochs in GIF frames). Default is ``None``.
"""
if model.dim != 2:
raise ValueError("This helper expects a 2D model.")
Expand Down Expand Up @@ -738,8 +741,20 @@ def plot_bin_boundaries_2D(
ax.set_xlabel("Discriminant dim. 0", fontsize=24)
ax.set_ylabel("Discriminant dim. 1", fontsize=24)

if annotation:
ax.text(
0.1,
0.96,
annotation,
transform=ax.transAxes,
ha="left",
va="top",
fontsize=18,
weight="bold",
)

plt.tight_layout()
plt.savefig(path_plot)
plt.savefig(path_plot, bbox_inches="tight", pad_inches=0.05)
plt.close(fig)


Expand Down Expand Up @@ -917,21 +932,32 @@ def plot_gmm_1d(model, output_filename, x_range=(0.0, 1.0), n_points=10_000):


def make_gif(frame_files, out_name, interval=800):
fig = plt.figure(figsize=(6, 4))
plt.axis("off")
"""
Assemble a set of pre-rendered PNGs into an animated GIF without extra margins.

Parameters
----------
frame_files : Sequence[str]
Ordered list of frame file paths.
out_name : str
Destination GIF path.
interval : int, optional
Frame duration in milliseconds. Default is 800 ms.
"""
if not frame_files:
raise ValueError("No frames provided for GIF creation.")

ims = []
images: list[Image.Image] = []
for fname in frame_files:
img = plt.imread(fname)
im = plt.imshow(img, animated=True)
ims.append([im])

ani = animation.ArtistAnimation(
fig, ims,
interval=interval,
blit=True,
repeat_delay=1000
with Image.open(fname) as img:
images.append(img.copy())

first, *rest = images
first.save(
out_name,
save_all=True,
append_images=rest,
duration=interval,
loop=0,
disposal=2,
)
# This requires that pillow is available (it's a dependency of matplotlib)
ani.save(out_name, writer="pillow")
plt.close(fig)