diff --git a/examples/bumphunt_example/run_example.py b/examples/bumphunt_example/run_example.py index f1eb4c1..21d912e 100644 --- a/examples/bumphunt_example/run_example.py +++ b/examples/bumphunt_example/run_example.py @@ -123,11 +123,11 @@ def main(): model = DiphotonSoftmax( n_cats=n_cats, temperature=1.0, mass_sigma=args.mass_sigma ) - optimizer = tf.keras.optimizers.RMSprop(0.1) + optimizer = tf.keras.optimizers.RMSprop(0.05) lr_scheduler = LearningRateScheduler( optimizer, - lr_initial=0.1, - lr_final=0.002, + lr_initial=0.05, + lr_final=0.001, total_epochs=args.epochs, mode="cosine", ) @@ -202,6 +202,7 @@ def main(): list(range(n_cats)), boundary_fname, resolution=600, + annotation=f"Epoch {epoch}", ) boundary_frames.append(boundary_fname) print( diff --git a/examples/three_class_softmax_example/run_example.py b/examples/three_class_softmax_example/run_example.py index 569bebf..6a2c0a7 100644 --- a/examples/three_class_softmax_example/run_example.py +++ b/examples/three_class_softmax_example/run_example.py @@ -190,7 +190,7 @@ def train_step(model, tdata, opt, lamY, lamU, thrY, thrU): optimizer = tf.keras.optimizers.RMSprop(0.1) lr_scheduler = LearningRateScheduler( optimizer, - lr_initial=0.1, + lr_initial=0.05, lr_final=0.001, total_epochs=args.epochs, mode="cosine", @@ -264,7 +264,8 @@ def train_step(model, tdata, opt, lamY, lamU, thrY, thrU): model, [i for i in range(n_cats)], boundary_fname, - resolution=500 + resolution=500, + annotation=f"Epoch {ep}", ) boundary_frames.append(boundary_fname) loss_history.append(loss.numpy()) diff --git a/pyproject.toml b/pyproject.toml index b487e8d..572dad8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "ml_dtypes>=0.4.1", "mplhep", "matplotlib>=3.7", + "pillow>=10.0", "scipy", "pandas", "hist", @@ -66,4 +67,4 @@ gpu = [ Homepage = "https://github.com/FloMau/gato-hep" Repository = "https://github.com/FloMau/gato-hep" Documentation = "https://gato-hep.readthedocs.io/en/latest/" -Issues = "https://github.com/FloMau/gato-hep/issues" \ No newline at end of file +Issues = "https://github.com/FloMau/gato-hep/issues" diff --git a/src/gatohep/plotting_utils.py b/src/gatohep/plotting_utils.py index 64308d4..d779c01 100644 --- a/src/gatohep/plotting_utils.py +++ b/src/gatohep/plotting_utils.py @@ -5,9 +5,9 @@ import os import tensorflow as tf import tensorflow_probability as tfp -import matplotlib.animation as animation from matplotlib.colors import BoundaryNorm, ListedColormap from matplotlib.patches import Ellipse +from PIL import Image from gatohep.utils import build_mass_histograms @@ -647,7 +647,7 @@ def plot_bin_boundaries_2D( path_plot, *, resolution: int = 1000, - reduce: bool = False, # kept for API compatibility, ignored for dim==2 + annotation: str | None = None, ): """ Plot hard-bin regions of a *2-D* GMM on the 2-simplex face @@ -664,6 +664,9 @@ def plot_bin_boundaries_2D( Grid resolution per axis. Default 500. reduce : bool, optional Ignored (for backward compatibility with older callers). + annotation : str, optional + Text rendered in the upper-left corner of the axes (useful for + tagging epochs in GIF frames). Default is ``None``. """ if model.dim != 2: raise ValueError("This helper expects a 2D model.") @@ -738,8 +741,20 @@ def plot_bin_boundaries_2D( ax.set_xlabel("Discriminant dim. 0", fontsize=24) ax.set_ylabel("Discriminant dim. 1", fontsize=24) + if annotation: + ax.text( + 0.1, + 0.96, + annotation, + transform=ax.transAxes, + ha="left", + va="top", + fontsize=18, + weight="bold", + ) + plt.tight_layout() - plt.savefig(path_plot) + plt.savefig(path_plot, bbox_inches="tight", pad_inches=0.05) plt.close(fig) @@ -917,21 +932,32 @@ def plot_gmm_1d(model, output_filename, x_range=(0.0, 1.0), n_points=10_000): def make_gif(frame_files, out_name, interval=800): - fig = plt.figure(figsize=(6, 4)) - plt.axis("off") + """ + Assemble a set of pre-rendered PNGs into an animated GIF without extra margins. + + Parameters + ---------- + frame_files : Sequence[str] + Ordered list of frame file paths. + out_name : str + Destination GIF path. + interval : int, optional + Frame duration in milliseconds. Default is 800 ms. + """ + if not frame_files: + raise ValueError("No frames provided for GIF creation.") - ims = [] + images: list[Image.Image] = [] for fname in frame_files: - img = plt.imread(fname) - im = plt.imshow(img, animated=True) - ims.append([im]) - - ani = animation.ArtistAnimation( - fig, ims, - interval=interval, - blit=True, - repeat_delay=1000 + with Image.open(fname) as img: + images.append(img.copy()) + + first, *rest = images + first.save( + out_name, + save_all=True, + append_images=rest, + duration=interval, + loop=0, + disposal=2, ) - # This requires that pillow is available (it's a dependency of matplotlib) - ani.save(out_name, writer="pillow") - plt.close(fig)