diff --git a/colibri/app.py b/colibri/app.py index 7da2809d..9292a970 100644 --- a/colibri/app.py +++ b/colibri/app.py @@ -32,6 +32,10 @@ "colibri.param_initialisation", "colibri.export_results", "colibri.closure_test", + "colibri.ntk.eigenvalues", + "colibri.ntk.eigenvector", + "colibri.ntk.plotntk", + "colibri.ntk.ntk", "reportengine.report", ] diff --git a/colibri/config.py b/colibri/config.py index 7e6d9334..920f356a 100644 --- a/colibri/config.py +++ b/colibri/config.py @@ -10,6 +10,7 @@ import logging import os import shutil +from pathlib import Path import jax import jax.numpy as jnp @@ -18,10 +19,11 @@ from colibri.constants import FLAVOUR_TO_ID_MAPPING from colibri.core import IntegrabilitySettings, PriorSettings from mpi4py import MPI -from reportengine.configparser import ConfigError, explicit_node +from reportengine.configparser import ConfigError, explicit_node, element_of from validphys import covmats -from validphys.config import Config, Environment +from validphys.config import Config, Environment, _id_with_label from validphys.fkparser import load_fktable +from validphys.loader import LoadFailedError comm = MPI.COMM_WORLD rank = comm.Get_rank() @@ -622,3 +624,20 @@ def produce_pdf_model(self): Returns None as the pdf_model is not used in the colibri module. """ return None + + def produce_replicas_path(self, fit): + """ + Produces the replicas folder where the fit replicas are stored. + """ + replicas_path = fit.path / "fit_replicas" + replicas_path.mkdir(parents=True, exist_ok=True) + return replicas_path + + @element_of("fits") + @_id_with_label + def parse_fit(self, fit: str): + """A fit in the results folder, containing at least a valid filter result.""" + try: + return self.loader.check_fit(fit) + except LoadFailedError as e: + raise ConfigError(str(e), fit, self.loader.available_fits) diff --git a/colibri/core.py b/colibri/core.py index bd23dbd7..7eeb8efc 100644 --- a/colibri/core.py +++ b/colibri/core.py @@ -167,12 +167,15 @@ class MonteCarloFit: Array containing the validation loss. optimized_parameters: jnp.array Array containing the optimized parameters. + parameters_by_epoch: jnp.array + Array containing the parameters of the model recorded during training. """ monte_carlo_specs: dict training_loss: jnp.array validation_loss: jnp.array optimized_parameters: jnp.array + parameters_by_epoch: jnp.array @dataclass(frozen=True) @@ -189,12 +192,15 @@ class GradientDescentResult: Recorded (epoch) validation losses (sampled according to record_every). specs: dict Dictionary of settings used for the run (epochs, batch size, etc.). + parameters_by_epoch: jnp.array + Array containing the parameters of the model recorded during training. """ optimized_parameters: Any training_loss: jnp.array validation_loss: jnp.array specs: Dict[str, Any] + parameters_by_epoch: jnp.array @dataclass(frozen=True) diff --git a/colibri/doc/sphinx/source/theory/index.rst b/colibri/doc/sphinx/source/theory/index.rst index 1b8bb4d0..d428269a 100644 --- a/colibri/doc/sphinx/source/theory/index.rst +++ b/colibri/doc/sphinx/source/theory/index.rst @@ -18,3 +18,5 @@ This section discusses some relevant theoretical background to Colibri. ./prior_distributions.rst ./inference_methods.rst + + ./ntk_theory.rst diff --git a/colibri/doc/sphinx/source/theory/ntk_theory.rst b/colibri/doc/sphinx/source/theory/ntk_theory.rst new file mode 100644 index 00000000..36b8a875 --- /dev/null +++ b/colibri/doc/sphinx/source/theory/ntk_theory.rst @@ -0,0 +1,6 @@ +.. _ntk_theory: + +=========================== +Neural Tangent Kernel (NTK) +=========================== + diff --git a/colibri/doc/sphinx/source/tutorials/index.rst b/colibri/doc/sphinx/source/tutorials/index.rst index 92f4dece..5b66d6e7 100644 --- a/colibri/doc/sphinx/source/tutorials/index.rst +++ b/colibri/doc/sphinx/source/tutorials/index.rst @@ -17,3 +17,5 @@ such as implementing a custom PDF model or using it to fit data. closure_tests/index scripts/index + + ntk/computing_ntk \ No newline at end of file diff --git a/colibri/doc/sphinx/source/tutorials/ntk/computing_ntk.rst b/colibri/doc/sphinx/source/tutorials/ntk/computing_ntk.rst new file mode 100644 index 00000000..1e31072d --- /dev/null +++ b/colibri/doc/sphinx/source/tutorials/ntk/computing_ntk.rst @@ -0,0 +1,79 @@ +.. _computing_ntk: + +===================================== +Computing Neural Tangent Kernel (NTK) +===================================== + +This section describes how to compute and analyse this Neural Tangent Kernel (NTK) +during the training of a given fit. For information on what the NTK is, please +refer to the :ref:`NTK theory section `. + +The NTK can be computed on any fit where the parameter values have been saved during +training. To produce these saved parameter values during a fit set the +`record_parameters` argument to `True` in the fit runcard. + +.. code-block:: bash + + record_parameters: True + record_every: 100 + +This above example will save the parameter values every 100 epochs. + +Once the parameter values have been saved, the NTK can be computed using the `compute_ntk` +action. An example runcard to compute the NTK is shown below. + +.. code-block:: bash + + meta: + title: Example NTK determination for a Les Houches fit + author: Colibri User + keywords: [example, PDF plots, NTK] + + pdf: {id: "test_ntk_fit"} # generated by the colibri fit + + ntk_plots_settings: + ntk_plots: True + n_top_eigenvalues: 10 # Number of top eigenvalues for which evolution is plotted + y_scale: log # Sets the scale for eigenvalue evolution plot. + plot_n_epochs: 5 # The number of epochs for which the eigenvector plots will be plotted. + # These will be equally spaced among all epochs. Default 6. + plot_n_eigenvectors: [1, 2] # Eigenvector plots will be produced for these. Default 1. + + actions_: + - compute_ntk + +This runcard can be run with the command +``colibri_model_exe compute_ntk.yaml -r replica_n``, where the ``colibri_model_exe`` +is the Colibri executable, and ``n`` is the replica number for which the NTK +should be computed. This will produce a folder called ``compute_ntk``. The NTK values +are stored in npz file format in ``compute_ntk/ntk_replicas/replica_n``. + +``ntk_plots_settings`` +^^^^^^^^^^^^^^^^^^^^^^ +* ``ntk_plots``: + Whether analytic plots are produced with the ``compute_ntk`` command. For details on + the plots produced, see the :ref:`analytic NTK plots section below `. +* ``n_top_eigenvalues``: + This is the number of top eigenvalues that will be plotted when ``ntk_plots`` is + ``True``. If a number is not specified, a default of 1 eigenvalue will be plotted. +* ``x_scale`` / ``y_scale``: + If set to ``log``, the eigenvalue evolution plot will be produced with a logarithmic + scale on the respective axis. Otherwise a linear scale will be used. +* ``plot_n_epochs``: + The number of epochs for which eigenvector plots will be produced. + These will be equally spaced among all epochs. Default 6. +* ``plot_n_eigenvectors``: + The eigenvectors for which eigenvector plots will be produced. Default 1. + +.. _analytic_ntk_plots: + +Analytic NTK plots +================== +If the flag ``ntk_plots`` is set to ``True``, the following analytic plots will be +produced and stored in .pdf format in ``compute_ntk/ntk_plots/``. + +* **NTK eigenvalue evolution:** Value of the NTK eigenvalues for increasing number of recorded epochs. + +* **NTK eigenvector evolution:** Eigenvector evolution for a given number of recorded epochs, for all flavours. + +* **Eigenvector heatmaps:** Density of NTK eigenvectors across PDF flavours and x-regions. \ No newline at end of file diff --git a/colibri/examples/les_houches_example/runcards/lh_fit_closure_test.yaml b/colibri/examples/les_houches_example/runcards/lh_fit_closure_test.yaml index e3b5426e..533b25c8 100644 --- a/colibri/examples/les_houches_example/runcards/lh_fit_closure_test.yaml +++ b/colibri/examples/les_houches_example/runcards/lh_fit_closure_test.yaml @@ -98,4 +98,4 @@ ultranest_settings: actions_: -- run_ultranest_fit # Choose from ultranest_fit, monte_carlo_fit, analytic_fit \ No newline at end of file +- run_ultranest_fit # Choose from ultranest_fit, monte_carlo_fit, analytic_fit diff --git a/colibri/gradient_descent.py b/colibri/gradient_descent.py index 293cd30e..203a536d 100644 --- a/colibri/gradient_descent.py +++ b/colibri/gradient_descent.py @@ -35,6 +35,7 @@ def run_gradient_descent( max_epochs: int, data_batch: colibri.DataBatches = None, record_every: int = 50, + record_parameters: bool = False, ) -> GradientDescentResult: """Generic gradient descent loop. @@ -64,7 +65,11 @@ def run_gradient_descent( Defaults to None, in which case the loss is assumed to not have been batched. record_every : int, default 50 - Record losses every this many epochs. + Record losses every this many epochs. If `record_parameters` is True, + parameters of the model are also recorded. + + record_parameters : bool, default False + Whether to record parameters every `record_every` epochs. """ params = initial_parameters @@ -79,6 +84,7 @@ def _step(p, ostate, batch_idx): train_losses = [] val_losses = [] + parameters_by_epoch = [] if record_parameters else None if data_batch is None: # we simulate a fake batch iterator that just yields a dummy batch index @@ -102,6 +108,7 @@ def batch_gen(): batch_idx = next(batches_iter) params, opt_state, batch_loss = _step(params, opt_state, batch_idx) epoch_train_loss += batch_loss + epoch_val_loss = validation_loss_fn(params) early_stopper = early_stopper.update(epoch_val_loss) @@ -114,6 +121,11 @@ def batch_gen(): train_losses.append(epoch_train_loss) val_losses.append(epoch_val_loss) + if record_parameters: + if epoch % record_every == 0: + log.info(f"Recording parameters at epoch {epoch}") + parameters_by_epoch.append(params) + if early_stopper.should_stop: log.info(f"Early stopping at epoch {epoch}") break @@ -127,4 +139,5 @@ def batch_gen(): "batch_size": batch_size, "record_every": record_every, }, + parameters_by_epoch=jnp.array(parameters_by_epoch), ) diff --git a/colibri/monte_carlo_fit.py b/colibri/monte_carlo_fit.py index 50ee6085..fb741ed7 100644 --- a/colibri/monte_carlo_fit.py +++ b/colibri/monte_carlo_fit.py @@ -30,6 +30,8 @@ def monte_carlo_fit( max_epochs, batch_size=None, batch_seed=1, + record_parameters=False, + record_every=50, ): """ This function performs a Monte Carlo fit. @@ -61,6 +63,13 @@ def monte_carlo_fit( batch_seed: int, optional Seed used to construct the batches. Defaults to 1. + record_parameters: bool, default False + Whether to monitor the parameters during the Monte Carlo fit. + + record_every: int, default 50 + Frequency (in epochs) at which to record losses parameters of + model. If record_parameters is False, only losses are recorded. + Returns ------- MonteCarloFit: The result of the fit with following attributes: @@ -68,7 +77,6 @@ def monte_carlo_fit( training_loss: jnp.array validation_loss: jnp.array """ - len_tr_idx, len_val_idx = len_trval_data @jax.jit @@ -100,7 +108,8 @@ def loss_validation(parameters): early_stopper=early_stopper, max_epochs=max_epochs, data_batch=data_batch, - record_every=50, + record_every=record_every, + record_parameters=record_parameters, ) t1 = time.time() @@ -111,14 +120,18 @@ def loss_validation(parameters): "max_epochs": max_epochs, "batch_size": data_batch.batch_size, "batch_seed": batch_seed, + "record_every": record_every, }, training_loss=gd_result.training_loss, validation_loss=gd_result.validation_loss, optimized_parameters=gd_result.optimized_parameters, + parameters_by_epoch=gd_result.parameters_by_epoch, ) -def run_monte_carlo_fit(monte_carlo_fit, pdf_model, output_path, replica_index, Q0): +def run_monte_carlo_fit( + monte_carlo_fit, pdf_model, output_path, replica_index, Q0, record_parameters=False +): """ Runs the Monte Carlo fit and writes the output to the output directory. @@ -178,3 +191,20 @@ def run_monte_carlo_fit(monte_carlo_fit, pdf_model, output_path, replica_index, index=False, float_format="%.5e", ) + + if record_parameters: + # Save the parameters by epoch if recorded + record_every = mc_fit.monte_carlo_specs["record_every"] + parameters_by_epoch = mc_fit.parameters_by_epoch + params_path = ( + str(output_path) + f"/fit_replicas/replica_{replica_index}/parameters/" + ) + if not os.path.exists(params_path): + os.makedirs(params_path, exist_ok=True) + + for epoch_idx in range(parameters_by_epoch.shape[0]): + epoch = epoch_idx * record_every + jnp.savez( + params_path + f"params_{epoch}.npz", + params=parameters_by_epoch[epoch_idx], + ) diff --git a/colibri/ntk/__init__.py b/colibri/ntk/__init__.py new file mode 100644 index 00000000..82f06a6a --- /dev/null +++ b/colibri/ntk/__init__.py @@ -0,0 +1,34 @@ +""" +colibri.ntk + +NTK (Neural Tangent Kernel) analysis module for colibri. + +This module provides tools for computing and analyzing the Neural Tangent Kernel +of PDF fits, including eigenvalue/eigenvector computation and plotting utilities. +""" + +from colibri.ntk.ntkutils import NTKGrid, NTKStats +from colibri.ntk.eigenvalues import ( + EigenvalueGrid, + eigenvalue_grid, + eigenvalues_ensemble, +) +from colibri.ntk.eigenvector import ( + EigenvectorGrid, + eigenvector_grid, + eigenvectors_ensemble_at_epoch, +) + +__all__ = [ + # Abstract interface + "NTKGrid", + "NTKStats", + # Eigenvalue classes and functions + "EigenvalueGrid", + "eigenvalue_grid", + "eigenvalues_ensemble", + # Eigenvector classes and functions + "EigenvectorGrid", + "eigenvector_grid", + "eigenvectors_ensemble_at_epoch", +] diff --git a/colibri/ntk/api_example_ntk.ipynb b/colibri/ntk/api_example_ntk.ipynb new file mode 100644 index 00000000..5267f546 --- /dev/null +++ b/colibri/ntk/api_example_ntk.ipynb @@ -0,0 +1,366 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "8fac088b", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ['KERAS_BACKEND'] = 'jax'\n", + "\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "from colibri_n3fit.api import API" + ] + }, + { + "cell_type": "markdown", + "id": "f0cfd57d", + "metadata": {}, + "source": [ + "# API examples for NTK utilities\n", + "--------------------------------" + ] + }, + { + "cell_type": "markdown", + "id": "2782cca8", + "metadata": {}, + "source": [ + "## Compute the ensemble of eigenvectors of the NTK at a given epoch\n", + "\n", + "The provider function `eigenvectors_ensemble_at_epoch` allows to compute the ensemble of eigenvectors of the NTK at a given epoch for a specified Monte Carlo fit. The minimal required arguments are the fit identifier and the epoch number (see below). It returns a dictionary with the following items:\n", + "- `eigenvectors_data`: ndarray (n_replicas, n_eigenvectors, n_flav * n_xgrid)\n", + "- `epoch`: the epoch number\n", + "- `ntk_shape`: shape of the NTK matrix before flattening\n", + "- `replica_indices`: list of replica indices included" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2396944f", + "metadata": {}, + "outputs": [], + "source": [ + "result_dict = API.eigenvectors_ensemble_at_epoch(fit=\"260123-ac-nnpdf40-dis-ntk\", epoch=100, max_workers=5)\n", + "print(f\"Shape of raw data {result_dict['eigenvectors_data'].shape}\")\n", + "print(f\"Epoch: {result_dict['epoch']}\")\n", + "print(f\"Shape of the NTK {result_dict['shape']}\")\n", + "print(f\"Number of replicas: {len(result_dict['replica_indices'])}\")" + ] + }, + { + "cell_type": "markdown", + "id": "f94bd4af", + "metadata": {}, + "source": [ + "In addition, the user can also specify:\n", + "- `replica_index_list`: tuple of replica indices to include in the computation. If not provided, all replicas available for the fit will be used.\n", + "- `max_workers`: number of parallel workers to use for the computation. Default is `min(10, n_replicas).\n", + "\n", + "Note that the `replica_index_list` is mainly meant for testing purposes with a reduced number of replicas. The example below shows how to compute the eigenvectors for a subset of the replica ensemble at epoch 100 using 5 parallel workers:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9bf8c3a6", + "metadata": {}, + "outputs": [], + "source": [ + "result_dict = API.eigenvectors_ensemble_at_epoch(fit=\"260123-ac-nnpdf40-dis-ntk\", epoch=100, replica_index_list=(10,1,75,3,23), max_workers=5)\n", + "print(f\"Shape of raw data {result_dict['eigenvectors_data'].shape}\")\n", + "print(f\"Epoch: {result_dict['epoch']}\")\n", + "print(f\"Shape of the NTK {result_dict['shape']}\")\n", + "print(f\"Number of replicas: {len(result_dict['replica_indices'])}\")" + ] + }, + { + "cell_type": "markdown", + "id": "ca8fab46", + "metadata": {}, + "source": [ + "### Plot the eigenvectors by fit and flavour or by rank and flavour" + ] + }, + { + "cell_type": "markdown", + "id": "be55a2fd", + "metadata": {}, + "source": [ + "The eigenvectors of the NTK can be visualised using two different plotting functions: `plot_eigenvectors_by_fit_and_flavour` and `plot_eigenvectors_by_rank_and_flavour`. The first function\n", + "allows to compare a fixed set of eigenvector ranks across different fits for specified flavours. The second function allows to compare different eigenvector ranks within the same fit for specified flavours. Both functions require the following arguments:\n", + "- `fits`: a list of dictionaries, each containing the fit identifier and an optional label for the legend.\n", + "- `rank_indices`: a list of eigenvector ranks to plot.\n", + "- `flavour_mapping`: a list of flavour codes to include in the plots.\n", + "- `epoch`: the epoch number at which to compute the eigenvectors.\n", + "\n", + "In addition, on top of the additional arguments `max_workers` and `replica_index_list` described above, the user can also specify:\n", + "- `error_type`: type of error to display. Available options are:\n", + " - `mean`, that shows the mean and standard deviation across replicas.\n", + " - `median`, that shows the median and 68% confidence interval across replicas. Default is `mean`.\n", + "- `xscale`, `yscale`: scale of the x- and y-axis. Default is `linear.\n", + "- `ymin`, `ymax`: minimum and maximum values for the y-axis. Default is `None`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ea9480d", + "metadata": {}, + "outputs": [], + "source": [ + "res = API.plot_eigenvectors_by_rank_and_flavour(fits=[{\"id\": \"260123-ac-nnpdf40-dis-ntk\", \"label\": r\"$\\textrm{Custom label for fit}$\"}], \n", + " rank_indices=[0, 1, 2], \n", + " flavour_mapping=['g', 'T3'], \n", + " epoch=100,\n", + " error_type='mean',)\n", + "results = {r[1].name : (r[0], r[1]) for r in res}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1528f4a", + "metadata": {}, + "outputs": [], + "source": [ + "results['eigvec_1_g'][1].ax.set_title(\"Eigenvector 1, gluon\")\n", + "results['eigvec_1_g'][0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c030e187", + "metadata": {}, + "outputs": [], + "source": [ + "results['eigvec_2_T3'][1].ax.set_title(\"Eigenvector 1, T3\")\n", + "results['eigvec_2_T3'][0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e2d08d6c", + "metadata": {}, + "outputs": [], + "source": [ + "gen_eigvec_by_fit = API.plot_eigenvectors_by_fit_and_flavour(fits=[{\"id\": \"260123-ac-nnpdf40-dis-ntk\", \"label\": r\"$\\textrm{Custom label for fit}$\"}], \n", + " rank_indices=[0, 1, 2], \n", + " flavour_mapping=['g', 'T3'], \n", + " epoch=100,\n", + " error_type='mean',)\n", + "results_eigvec_by_fit = {r[1].name : (r[0], r[1]) for r in gen_eigvec_by_fit}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62c11bb0", + "metadata": {}, + "outputs": [], + "source": [ + "results_eigvec_by_fit['eigvecs_T3'][0]" + ] + }, + { + "cell_type": "markdown", + "id": "614f56f5", + "metadata": {}, + "source": [ + "# Compute the eigenvalues of the NTK\n", + "\n", + "Similarly to the eigenvectors, the eigenvalue of the NTK can be computed and visualised accordingly. For instance, the user can use the provider function `eigenvalues_ensemble` to compute the ensemble of eigenvalues of the NTK for a specified Monte Carlo fit across **all** specified epochs. The function can be called with the following arguments:\n", + "- `fit`: fit identifier.\n", + "- `replica_index_list`: tuple of replica indices to include in the computation. If not provided, all replicas available for the fit will be used. It is mainly meant for testing purposes with a reduced number of replicas.\n", + "- `max_epoch`: maximum epoch to consider. Replicas that do not reach this epoch will be excluded from the computation. Default is `None`, meaning that the epochs considered will be those in the intersection of all replicas.\n", + "- `max_workers`: number of parallel workers to use for the computation. Default is `min(10, n_replicas).\n", + "- `force_recompute`: whether to force recomputation of the eigenvalues even if cached data is available. Default is `False`.\n", + "\n", + "Note that, contrary to the case of the eigenvectors, the eigenvalue are stored in the folders of each replicas. If the users wants to recompute them using, for instance, a different common epochs rule, the `force_recompute` flag should be set to `True`.\n", + "\n", + "The function returns a dictionary with the following items:\n", + "- `eigenvalues_by_epoch`: dict mapping epoch -> ndarray (n_replicas, n_eigenvalues)\n", + "- `epochs`: list of epochs\n", + "- `ntk_shape`: shape of NTK matrix\n", + "- `replica_indices`: list of replica indices included" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c81020d8", + "metadata": {}, + "outputs": [], + "source": [ + "eigvals_dict = API.eigenvalues_ensemble(fit=\"260123-ac-nnpdf40-dis-ntk\", force_recompute=False)\n", + "print(f\"Shape of raw data {eigvals_dict['eigenvalues_by_epoch'][0].shape}\")\n", + "print(f\"Number of epochs: {len(eigvals_dict['epochs'])} <----\")\n", + "print(f\"Maximum epoch: {max(eigvals_dict['epochs'])}\")\n", + "print(f\"Shape of the NTK {eigvals_dict['ntk_shape']}\")\n", + "print(f\"Number of replicas: {len(eigvals_dict['replica_indices'])}\")" + ] + }, + { + "cell_type": "markdown", + "id": "84853751", + "metadata": {}, + "source": [ + "Note that in this case, we have not specified `max_epoch`, so all epochs in the intersection of all replicas are considered. This results in just 21 epochs considered, with a maximum epoch of 1000.\n", + "\n", + "The eigenvalues can then be visualised using the plotting function `plot_eigvals_by_fit` or `plot_eigvals_by_rank`, which work similarly to the eigenvector plotting functions described above." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c442a0af", + "metadata": {}, + "outputs": [], + "source": [ + "gen_eigval = API.plot_eigvals_by_rank(fits=[{\"id\": \"260123-ac-nnpdf40-dis-ntk\", \"label\": r\"$\\textrm{Custom label for fit}$\"}], rank_indices=[0, 1, 2])\n", + "results_eigval = {r[1].name : (r[0], r[1]) for r in gen_eigval}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13df3d10", + "metadata": {}, + "outputs": [], + "source": [ + "results_eigval['lambda_1'][0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e82e1b74", + "metadata": {}, + "outputs": [], + "source": [ + "gen_eigval_by_fit = API.plot_eigvals_by_fit(fits=[{\"id\": \"260123-ac-nnpdf40-dis-ntk\", \"label\": r\"$\\textrm{Custom label for fit}$\"}], rank_indices=[0, 1, 2], error_type='median', yscale='log')\n", + "results_eigval_by_fit = {r[1].name : (r[0], r[1]) for r in gen_eigval_by_fit}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "705970df", + "metadata": {}, + "outputs": [], + "source": [ + "results_eigval_by_fit['eigvals_$\\\\textrm{Custom label for fit}$'][0]" + ] + }, + { + "cell_type": "markdown", + "id": "5a8cb987", + "metadata": {}, + "source": [ + "In addition, there is an extra utility function, `plot_eigvals_replicas_by_rank`, that shows each replica's eigenvalues for a given fit and rank index across epochs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8f441aed", + "metadata": {}, + "outputs": [], + "source": [ + "generator_eigval_replicas = API.plot_eigvals_replicas_by_rank(fits=[{\"id\": \"260123-ac-nnpdf40-dis-ntk\", \"label\": r\"$\\textrm{Custom label for fit}$\"}], rank_indices=[0])\n", + "results_eigval_replicas = {r[1].name : (r[0], r[1]) for r in generator_eigval_replicas}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d9b61c1e", + "metadata": {}, + "outputs": [], + "source": [ + "results_eigval_replicas['lambda_replicas_1'][1].ax.set_yscale('linear')\n", + "results_eigval_replicas['lambda_replicas_1'][0]" + ] + }, + { + "cell_type": "markdown", + "id": "6a22dada", + "metadata": {}, + "source": [ + "Finally, we can request a higher maximum epoch. This will filter out replicas that do not reach that epoch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7bd647ea", + "metadata": {}, + "outputs": [], + "source": [ + "eigvals_dict = API.eigenvalues_ensemble(fit=\"260123-ac-nnpdf40-dis-ntk\", force_recompute=False, max_epoch=25000)\n", + "print(f\"Shape of raw data {eigvals_dict['eigenvalues_by_epoch'][0].shape}\")\n", + "print(f\"Number of epochs: {len(eigvals_dict['epochs'])} <----\")\n", + "print(f\"Maximum epoch: {max(eigvals_dict['epochs'])}\")\n", + "print(f\"Shape of the NTK {eigvals_dict['ntk_shape']}\")\n", + "print(f\"Number of replicas: {len(eigvals_dict['replica_indices'])}\")" + ] + }, + { + "cell_type": "markdown", + "id": "a96e1fca", + "metadata": {}, + "source": [ + "Note that now we have 501 epochs considered, with a maximum epoch of 25000. This however reduces the number of replicas included in the computation, as some replicas do not reach that epoch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ea25157", + "metadata": {}, + "outputs": [], + "source": [ + "gen_eigval = API.plot_eigvals_by_rank(fits=[{\"id\": \"260123-ac-nnpdf40-dis-ntk\", \"label\": r\"$\\textrm{Custom label for fit}$\"}], rank_indices=[0, 1, 2], max_epoch=25000)\n", + "results_eigval = {r[1].name : (r[0], r[1]) for r in gen_eigval}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b49a1f2f", + "metadata": {}, + "outputs": [], + "source": [ + "results_eigval['lambda_1'][0]" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "colibri-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/colibri/ntk/eigenvalues.py b/colibri/ntk/eigenvalues.py new file mode 100644 index 00000000..0cbb6f78 --- /dev/null +++ b/colibri/ntk/eigenvalues.py @@ -0,0 +1,324 @@ +""" +colibri.ntk.ntk.py + +This module contains the routine that computes the Neural Tangent Kernel (NTK) +for a given PDF model and provides statistical analysis tools for NTK ensembles. + +""" + +import logging +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Dict, List, Optional + +import numpy as np +from tqdm import tqdm + +from reportengine import collect +from validphys.core import FitSpec + +from colibri.ntk.ntkutils import ( + NTKGrid, + NTKStats, + compute_eigenvalues_for_replica, + get_completed_replicas, + get_replica_idx_list, + load_eigenvalues_ensemble, +) + +log = logging.getLogger(__name__) + + +class EigenvalueGrid(NTKGrid): + """ + Container for eigenvalue data from a single fit. + + This class holds eigenvalue statistics across epochs and provides methods + to extract eigenvalue trajectories for plotting. Implements the NTKGrid + interface for use with generic plotting utilities. + + Parameters + ---------- + label : str + Human-readable label for the fit (e.g., "L0", "L1", "NNPDF4.0") + epochs : list of int + List of epoch numbers + eigenvalues_stats : dict + Dictionary mapping epoch -> NTKStats object containing eigenvalues. + Each NTKStats has shape (nreplicas, n_eigenvalues). + + Attributes + ---------- + label : str + Label for this fit + epochs : list of int + Available epochs + n_eigenvalues : int + Number of eigenvalues per replica + nreplicas : int + Number of replicas in the ensemble + """ + + def __init__( + self, + label: str, + epochs: List[int], + eigenvalues_stats: Dict[int, NTKStats], + ): + if not epochs: + raise ValueError("epochs cannot be empty") + if not eigenvalues_stats: + raise ValueError("eigenvalues_stats cannot be empty") + + self._label = label + self.epochs = sorted(epochs) + self._eigenvalues_stats = eigenvalues_stats + + first_epoch = self.epochs[0] + first_stats = self._eigenvalues_stats[first_epoch] + self.nreplicas = first_stats.data.shape[0] + self.n_eigenvalues = first_stats.data.shape[1] + + @property + def label(self) -> str: + """Human-readable label for this grid (e.g., fit name).""" + return self._label + + @property + def n_ranks(self) -> int: + """Number of eigenvalue ranks available.""" + return self.n_eigenvalues + + @property + def xgrid(self) -> np.ndarray: + """X-axis grid for plotting (epochs).""" + return np.array(self.epochs) + + @property + def xlabel(self) -> str: + """Label for x-axis.""" + return r"$\rm Epochs$" + + def get_stat_by_epoch(self, epoch: int) -> NTKStats: + """Get NTKStats for a specific epoch.""" + if epoch not in self._eigenvalues_stats: + raise ValueError(f"Epoch {epoch} not found in eigenvalues_stats") + return self._eigenvalues_stats[epoch] + + def get_plotting_data(self, rank_index: int, **kwargs) -> NTKStats: + """ + Get plotting data for a specific eigenvalue rank. + + Parameters + ---------- + rank_index : int + Index of the eigenvalue (1 = largest eigenvalue) + **kwargs + Ignored for eigenvalues (no additional selection needed) + + Returns + ------- + NTKStats + Statistics for the eigenvalue trajectory, shape (nreplicas, n_epochs) + """ + if rank_index <= 0 or rank_index > self.n_eigenvalues: + raise ValueError( + f"rank_index {rank_index} out of range [1, {self.n_eigenvalues}]" + ) + data_by_epoch = [ + self._eigenvalues_stats[epoch].data[:, rank_index-1] for epoch in self.epochs + ] + + # Stack into (nreplicas, n_epochs) array + combined_data = np.stack(data_by_epoch, axis=1) + return NTKStats(combined_data) + + def get_plotting_label(self, rank_index: int, **kwargs) -> str: + """ + Get legend label for a specific eigenvalue rank. + + Parameters + ---------- + rank_index : int + Index of the eigenvalue (1 = largest eigenvalue) + **kwargs + Ignored for eigenvalues + + Returns + ------- + str + LaTeX-formatted label (e.g., r"$\\lambda^{(1)}$") + """ + return rf"$\lambda^{{({rank_index})}}$" + + def save(self, path: Path): + """Serialize this EigenvalueGrid to disk.""" + epochs = np.array(self.epochs) + # Stack all epoch data into shape (n_epochs, n_replicas, n_eigenvalues) + data = np.stack( + [self._eigenvalues_stats[ep].data for ep in self.epochs], axis=0 + ) + np.savez_compressed( + path, + label=self._label, + epochs=epochs, + eigenvalues=data, + ) + + @classmethod + def load(cls, path: Path) -> "EigenvalueGrid": + """Deserialize an EigenvalueGrid from a .npz file.""" + f = np.load(path, allow_pickle=False) + label = str(f["label"]) + epochs = f["epochs"].tolist() + eigenvalues = f["eigenvalues"] # (n_epochs, n_replicas, n_eigenvalues) + eigenvalues_stats = { + epoch: NTKStats(eigenvalues[i]) for i, epoch in enumerate(epochs) + } + return cls(label=label, epochs=epochs, eigenvalues_stats=eigenvalues_stats) + + + + +def eigenvalues_ensemble( + fit: FitSpec, + replicas_path: Path, + replica_index_list: Optional[tuple] = None, + max_epoch: Optional[int] = None, + force_recompute: bool = False, + max_workers: Optional[int] = None, + name: Optional[str] = None, + kwargs: frozenset = frozenset({}), +): + """ + Compute NTK eigenvalues for all replicas across all specified epochs. + + This function computes eigenvalues immediately after each NTK and discards + the NTK matrix. Each replica is saved to disk as soon as it completes. + + Parameters + ---------- + fit : FitSpec + The fit object containing fit information + replicas_path : Path + Path to the replicas directory + replica_index_list : tuple, optional + Specific replica indices to compute. If None, computes all. Mainly for + testing. + max_epoch : int, optional + Maximum number of epochs to consider per replica. If None, uses all + available epochs. Mainly for testing. + force_recompute : bool, optional + If True, recompute all replicas even if cached. Default is False. + max_workers : int, optional + Maximum number of parallel workers. If None, defaults to min(10, + n_replicas). + name: str, optional + Optional name to include in the filename for clarity when saving results. + kwargs : dict, optional + Additional kwargs to pass to compute_eigenvalues_at_epoch_for_replica + + Returns + ------- + dict + Dictionary with keys: + - `eigenvalues_by_epoch`: dict mapping epoch -> ndarray (n_replicas, n_eigenvalues) + - `epochs`: list of epochs - 'ntk_shape': shape of NTK matrix + - `ntk_shape`: shape of the NTK matrix before flattening + - `replica_indices`: list of replica indices included + """ + + # Determine which replicas to compute + if replica_index_list is None: + replica_index_list = get_replica_idx_list(replicas_path) + + # Check for already completed replicas + if force_recompute: + completed = [] + log.info("Force recompute enabled: ignoring cached replicas") + else: + completed = get_completed_replicas(replicas_path, name) + + pending = sorted([r for r in replica_index_list if r not in completed]) + + if not pending: + log.info(f"All {len(completed)} replicas already computed. Loading from cache.") + return load_eigenvalues_ensemble(replicas_path, max_epoch, name) + + log.info( + f"Computing eigenvalues: {len(pending)} pending, " + f"{len(completed)} already done" + ) + + n_pending = len(pending) + if max_workers is None: + max_workers = min(10, n_pending) + log.info(f"Using max_workers={max_workers} for parallel computation") + + # Compute pending replicas in parallel + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_replica = { + executor.submit( + compute_eigenvalues_for_replica, + fit.name, + replicas_path, + replica_idx, + max_epoch, + name, + **dict(kwargs), + ): replica_idx + for replica_idx in pending + } + + # TODO: Do we want to keep that status bar with tqdm? + # Track progress + futures_iter = as_completed(future_to_replica) + for future in tqdm(futures_iter, total=n_pending, desc="Computing eigenvalues"): + replica_idx = future_to_replica[future] + try: + result = future.result() + if result is None: + log.warning(f"Replica {replica_idx} failed") + except Exception as e: + log.warning(f"Error computing replica {replica_idx}: {e}") + + # Load all results (completed + newly computed) + return load_eigenvalues_ensemble(replicas_path, max_epoch, name) + + +def eigenvalue_grid(fit: FitSpec, eigenvalues_ensemble) -> EigenvalueGrid: + """ + Create an EigenvalueGrid from NTK eigenvalues ensemble data. + + Parameters + ---------- + fit : FitSpec + The fit object containing fit information + eigenvalues_ensemble : dict + Output from eigenvalues_ensemble function + + Returns + ------- + EigenvalueGrid + The constructed EigenvalueGrid object + """ + epochs = eigenvalues_ensemble["epochs"] + eigvals_by_epoch = eigenvalues_ensemble["eigenvalues_by_epoch"] + label = fit.label + + # Wrap numpy arrays in NTKStats objects + eigenvalues_stats = { + epoch: NTKStats(data) for epoch, data in eigvals_by_epoch.items() + } + + return EigenvalueGrid( + label=label, + epochs=epochs, + eigenvalues_stats=eigenvalues_stats, + ) + +def eigenvalues_at_epoch(eigenvalue_grid, epoch: int): + return eigenvalue_grid.get_stat_by_epoch(epoch) + +# Collect eigenvalue grids across fits +eigval_grids_by_fit = collect("eigenvalue_grid", ("fits",)) diff --git a/colibri/ntk/eigenvector.py b/colibri/ntk/eigenvector.py new file mode 100644 index 00000000..1d9ff3c4 --- /dev/null +++ b/colibri/ntk/eigenvector.py @@ -0,0 +1,336 @@ +""" +colibri.ntk.eigenvector.py +""" + +import functools +import logging +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Optional + +import pandas as pd +import numpy as np +from tqdm import tqdm + +from reportengine import collect +from validphys.core import FitSpec + +from colibri.constants import XGRID, FLAVOUR_TO_ID_MAPPING +from colibri.ntk.ntkutils import ( + NTKGrid, + NTKStats, + NTK_ORDERING, + compute_eigenvectors_at_epoch_for_replica, + get_replica_idx_list, +) + +log = logging.getLogger(__name__) + +class EigenvectorGrid(NTKGrid): + """ + Container for eigenvector data from a single fit at a specific epoch. + + This class holds eigenvector statistics and provides methods to extract + eigenvector components for plotting. Implements the NTKGrid interface + for use with generic plotting utilities. + + The eigenvector data has shape (nreplicas, n_eigenvectors, nflavors * n_xgrid). + + Parameters + ---------- + label : str + Human-readable label for the fit (e.g., "L0", "L1") + epoch : int + The epoch at which eigenvectors were computed + shape : tuple[int] + The shape of the NTK matrix before flattening (nflavors, n_xgrid, nflavors, n_xgrid) + eigenvectors_stat : NTKStats + Statistics object with data of shape (nreplicas, n_eigenvectors, nflavors * n_xgrid) + + Attributes + ---------- + label : str + Label for this fit + epoch : int + Epoch number + nflavors : int + Number of flavours + n_eigenvectors : int + Number of eigenvectors + n_xgrid : int + Number of x-grid points + nreplicas : int + Number of replicas + """ + + def __init__( + self, + label: str, + epoch: int, + shape: tuple[int], + eigenvectors_stat: NTKStats, + ): + # Check input dimension + if len(eigenvectors_stat.data.shape) != 3: + raise ValueError( + f"eigenvectors_stat data must be 3D (nreplicas, n_eigenvalues, n_parameters), " + f"got shape {eigenvectors_stat.data.shape}" + ) + + self._label = label + self.epoch = epoch + self.nflavors = shape[0] + self._eigenvectors_stat = eigenvectors_stat + + self.nreplicas = len(eigenvectors_stat.error_members()) + self.n_eigenvectors = shape[0] * shape[1] + self.n_xgrid = shape[1] + + # NTKGrid interface implementation + @property + def label(self) -> str: + """Human-readable label for this grid (e.g., fit name).""" + return self._label + + @property + def n_ranks(self) -> int: + """Number of eigenvector ranks available.""" + return self.n_eigenvectors + + @property + def xgrid(self) -> np.ndarray: + """X-axis grid for plotting (x values).""" + return np.array(XGRID) + + @property + def xlabel(self) -> str: + """Label for x-axis.""" + return r"$x$" + + def get_stat(self) -> NTKStats: + """Get the full eigenvector statistics object.""" + return self._eigenvectors_stat + + def get_plotting_data( + self, rank_index: int, flavour_index: int = 0, **kwargs + ) -> NTKStats: + """ + Get eigenvector component for plotting. + + This method reshapes (nreplicas, nflavors*n_xgrid) -> (nreplicas, + nflavors, n_xgrid) and selects the specified flavour. + + Parameters + ---------- + rank_index : int + Index of the eigenvector (1 = largest eigenvalue's eigenvector) + flavour_index : int, optional + The ID that represents a specific flavour (see + `FLAVOUR_TO_ID_MAPPING` in `colibri.constants`) + **kwargs + Additional kwargs (ignored) + + Returns + ------- + NTKStats + Statistics for the eigenvector component, shape (nreplicas, n_xgrid) + """ + if rank_index <= 0 or rank_index > self.n_eigenvectors: + raise ValueError( + f"rank_index {rank_index} out of range [0, {self.n_eigenvectors}]" + ) + if flavour_index < 0 or flavour_index >= self.nflavors: + raise ValueError( + f"flavour_index {flavour_index} out of range [0, {self.nflavors})" + ) + + # Get eigenvector data for the specified rank: (nreplicas, n_flaovors * n_xgrid) + eigvec_data = self._eigenvectors_stat.data[:, :, rank_index-1] + + # Reshape to (nreplicas, nflavors, n_xgrid) + reshaped = eigvec_data.reshape(self.nreplicas, self.nflavors, self.n_xgrid, order=NTK_ORDERING) + + # Select the specified flavour: (nreplicas, n_xgrid) + flavour_data = reshaped[:, flavour_index, :] + + return NTKStats(flavour_data) + + def get_plotting_label( + self, rank_index: int, flavour_index: int = 0, **kwargs + ) -> str: + """ + Get legend label for a specific eigenvector component. + + Parameters + ---------- + rank_index : int + Index of the eigenvector (1 = largest eigenvalue's eigenvector) + flavour_index : int, optional + Index of the flavour (see `FLAVOUR_TO_ID_MAPPING` in `colibri.constants`) + **kwargs + Additional kwargs (ignored) + + Returns + ------- + str + LaTeX-formatted label (e.g., r"$v^{(1)}_{\rm GLUON}$") + """ + # TODO: choose if to include flavour in label + return rf"$z^{{({rank_index})}}$" + + +@functools.cache +def eigenvectors_ensemble_at_epoch( + fit: FitSpec, + replicas_path: Path, + epoch: int, + replica_index_list: Optional[tuple] = None, + max_workers: Optional[int] = None, + kwargs: frozenset = frozenset({}), +): + """ + Compute NTK eigenvalues for all replicas for a specified epoch. + + Parameters + ---------- + fit : FitSpec + The fit object containing fit information + replicas_path : Path + Path to the replicas directory + epoch : int + The epoch at which to compute eigenvectors + replica_index_list : tuple, optional + Specific replica indices to compute. If None, computes all. Mainly for + testing. + max_workers : int, optional + Maximum number of parallel workers. If None, defaults to min(10, + n_replicas). + kwargs : dict, optional + Additional kwargs to pass to compute_eigenvectors_at_epoch_for_replica + + Returns + ------- + dict + Dictionary with keys: + - `eigenvectors_data`: ndarray (n_replicas, n_flav * n_xgrid, n_eigenvectors) (flavor-major) + - `epoch`: the epoch number + - `ntk_shape`: shape of the NTK matrix before flattening + - `replica_indices`: list of replica indices included + """ + + # Determine which replicas to compute + if replica_index_list is None: + replica_index_list = get_replica_idx_list(replicas_path) + n_replicas = len(replica_index_list) + + if max_workers is None: + max_workers = min(10, n_replicas) + log.info(f"Using max_workers={max_workers} for parallel computation") + + eigenvectors_data = {} + shape = None + + # Compute pending replicas in parallel + lock = threading.Lock() + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_replica = { + executor.submit( + compute_eigenvectors_at_epoch_for_replica, + fit.name, + replicas_path, + replica_idx, + epoch, + **dict(kwargs), + ): replica_idx + for replica_idx in replica_index_list + } + + futures_iter = as_completed(future_to_replica) + for future in tqdm( + futures_iter, total=n_replicas, desc="Computing eigenvectors" + ): + replica_idx = future_to_replica[future] + try: + result = future.result() + if result is None: + log.warning(f"Replica {replica_idx} failed") + eigenvectors_data[replica_idx] = result[2] + with lock: + if shape is None: + shape = result[3] + except Exception as e: + log.warning(f"Error computing replica {replica_idx}: {e}") + + # Check we have at least one replicas + if not eigenvectors_data: + raise RuntimeError("No replicas were successfully computed.") + + # Convert the final dict into a ndarray + eigenvectors_data = np.array( + [ + eigenvectors_data[replica_idx] + for replica_idx in sorted(eigenvectors_data.keys()) + ] + ) # Shape: (nreplicas, n_eigenvectors, n_eigenvectors) + + return { + "eigenvectors_data": eigenvectors_data, + "epoch": epoch, + "shape": shape, + "replica_indices": replica_index_list, + } + + +def eigenvector_grid(fit: FitSpec, eigenvectors_ensemble_at_epoch) -> EigenvectorGrid: + """ + Create an EigenvectorGrid from NTK eigenvectors ensemble data. + + Parameters + ---------- + fit : FitSpec + The fit object containing fit information + eigenvectors_ensemble_at_epoch : dict + Output from eigenvectors_ensemble_at_epoch function + + Returns + ------- + EigenvectorGrid + The constructed EigenvectorGrid object + """ + eigenvectors_array = eigenvectors_ensemble_at_epoch["eigenvectors_data"] + eigenvectors_stat = NTKStats(data=eigenvectors_array) + shape = eigenvectors_ensemble_at_epoch["shape"] + + return EigenvectorGrid( + label=fit.label, + epoch=eigenvectors_ensemble_at_epoch["epoch"], + shape=shape, + eigenvectors_stat=eigenvectors_stat, + ) + +def eigenvectors_at_epoch(eigenvector_grid: EigenvectorGrid, + flavours: list = list(FLAVOUR_TO_ID_MAPPING.keys())) -> NTKStats: + """Returns DataFrame with eigenvector components for specified flavours.""" + eigvec_data = eigenvector_grid.get_stat().data # Shape (nreplicas, nflavors * n_xgrid, n_eigenvectors) + + cl_index = pd.Index(range(eigenvector_grid.n_eigenvectors), name="rank") + + # Index follows NTK_ORDERING: for each flavour, all x-points in sequence + index = pd.MultiIndex.from_tuples( + [(fl, i + 1) for fl in flavours for i in range(len(XGRID))], + names=["flavour", "x"], + ) + if len(flavours) > eigenvector_grid.nflavors: + raise ValueError( + f"flavour_indices {flavours} out of range [0, {eigenvector_grid.nflavors})" + ) + + dfs = [pd.DataFrame( + data=eigvec_data[k], index=index, columns=cl_index) + for k in range(eigenvector_grid.nreplicas) + ] + + return NTKStats(dfs) + +eigvecs_grids_by_fit = collect("eigenvector_grid", ("fits",)) diff --git a/colibri/ntk/ntk.py b/colibri/ntk/ntk.py new file mode 100644 index 00000000..aeb1efe6 --- /dev/null +++ b/colibri/ntk/ntk.py @@ -0,0 +1,147 @@ +from functools import lru_cache +import logging +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Optional + +import jax.numpy as jnp +import numpy as np +from tqdm import tqdm +import pandas as pd + +from validphys.core import FitSpec + +from colibri.constants import XGRID +from colibri.utils import get_pdf_model +from colibri.ntk.ntkutils import ( + get_replica_idx_list, + get_parameters_all_epochs, + compute_ntk, + NTKStats +) + +log = logging.getLogger(__name__) + +@lru_cache(maxsize=None) +def ntk_ensemble_at_epoch( + fit: FitSpec, + replicas_path: Path, + epoch: int, + replica_index_list: Optional[tuple] = None, + max_workers: Optional[int] = None, + kwargs: frozenset = frozenset({}), +): + """ + Compute the NTK ensemble at a specific epoch across multiple replicas. + + Parameters + ---------- + fit : FitSpec + The fit object containing fit information + replicas_path : Path + Path to the replicas directory + epoch : int + The epoch at which to compute eigenvectors + replica_index_list : tuple, optional + Specific replica indices to compute. If None, computes all. Mainly for + testing. + max_workers : int, optional + Maximum number of parallel workers. If None, defaults to min(10, + n_replicas). + kwargs : dict, optional + Additional kwargs to pass to compute_eigenvectors_at_epoch_for_replica + """ + + # Determine which replicas to compute + if replica_index_list is None: + replica_index_list = get_replica_idx_list(replicas_path) + n_replicas = len(replica_index_list) + + if max_workers is None: + max_workers = min(10, n_replicas) + log.info(f"Using max_workers={max_workers} for parallel computation") + + ntk_data = {} + shape = None + + def compute_ntk_at_epoch_for_replica(replica_idx): + pdf_model = get_pdf_model(fit.name, replica_idx=replica_idx) + param_files = get_parameters_all_epochs(replicas_path, replica_idx) + params_at_epoch = param_files.get(epoch, None) + if params_at_epoch is None: + raise ValueError( + f"Epoch {epoch} not found for replica {replica_idx} in fit {fit.name}" + ) + + params = jnp.load(params_at_epoch)["params"] + ntk, shape = compute_ntk(pdf_model, params, **dict(kwargs)) + return (ntk, shape) + + lock = threading.Lock() + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_replica = { + executor.submit( + compute_ntk_at_epoch_for_replica, + replica_idx + ): replica_idx + for replica_idx in replica_index_list + } + + futures_iter = as_completed(future_to_replica) + for future in tqdm( + futures_iter, total=n_replicas, desc="Computing NTKs" + ): + replica_idx = future_to_replica[future] + try: + result = future.result() + if result is None: + log.warning(f"Replica {replica_idx} failed") + ntk_data[replica_idx], _shape = result + with lock: + if shape is None: + shape = _shape + except Exception as e: + log.warning(f"Error computing the NTK for replica {replica_idx}: {e}") + + if not ntk_data: + raise ValueError("No NTK data computed successfully") + + # Convert the final dict into a ndarray + ntk_data_array = np.array( + [ + ntk_data[replica_idx] for replica_idx in sorted(ntk_data.keys()) + ] + ) # Shape: (n_repl, nflav*ngrid, nflav*ngrid) + + return { + "ntk_data": ntk_data_array, + "epoch": epoch, + "shape": shape, + "replica_indices": replica_index_list, + } + +def ntk_at_epoch(ntk_ensemble_at_epoch, flavours): + """ + Returns dataframe with the NTK at a specific epoch. + """ + ntk_ensemble = ntk_ensemble_at_epoch["ntk_data"] + shape = ntk_ensemble_at_epoch["shape"] + n_replicas = len(ntk_ensemble_at_epoch["replica_indices"]) + + # Index follows NTK_ORDERING: for each flavour, all x-points in sequence + index = pd.MultiIndex.from_tuples( + [(fl, i + 1) for fl in flavours for i in range(len(XGRID))], + names=["flavour", "x"], + ) + if len(flavours) > shape[0]: + raise ValueError( + f"flavour_indices {flavours} out of range [0, {shape[0]})" + ) + + dfs = [pd.DataFrame( + data=ntk_ensemble[k], index=index, columns=index) + for k in range(n_replicas) + ] + + return NTKStats(dfs) \ No newline at end of file diff --git a/colibri/ntk/ntkutils.py b/colibri/ntk/ntkutils.py new file mode 100644 index 00000000..0e8dfd01 --- /dev/null +++ b/colibri/ntk/ntkutils.py @@ -0,0 +1,853 @@ +""" +colibri.ntkutils.py + +Module containing several utils for the analysis of the NTK. + +""" + +from __future__ import annotations + +import abc +import functools +import logging +from functools import lru_cache +from pathlib import Path + +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd + +from validphys.core import MCStats + +from colibri.constants import XGRID +from colibri.utils import get_pdf_model + +log = logging.getLogger(__name__) + +NTK_EIGVAL_TOKEN = "ntk_eigenvalues" +NTK_ORDERING = "C" # flavour-major C-order: all x-points per flavour, then next flavour + +def _check_replicas(self, other): + if isinstance(other, NTKStats) and self.data.shape[0] != other.data.shape[0]: + raise ValueError( + f"NTKStats replica count mismatch: {self.data.shape[0]} vs {other.data.shape[0]}" + ) + + +def _check_indices(self, other, op_name): + other_index, other_columns = self._get_index(other) + if op_name == "__matmul__": + if self._df_columns is not None and other_index is not None and not self._df_columns.equals(other_index): + raise ValueError( + f"Index mismatch in matmul: left columns {self._df_columns.tolist()} " + f"do not match right index {other_index.tolist()}" + ) + elif op_name == "__rmatmul__": + if other_columns is not None and self._df_index is not None and not other_columns.equals(self._df_index): + raise ValueError( + f"Index mismatch in matmul: left columns {other_columns.tolist()} " + f"do not match right index {self._df_index.tolist()}" + ) + else: + if self._df_index is not None and other_index is not None and not self._df_index.equals(other_index): + raise ValueError( + f"Index mismatch in {op_name}: left index {self._df_index.tolist()} " + f"does not match right index {other_index.tolist()}" + ) + if self._df_columns is not None and other_columns is not None and not self._df_columns.equals(other_columns): + raise ValueError( + f"Column index mismatch in {op_name}: left columns {self._df_columns.tolist()} " + f"do not match right columns {other_columns.tolist()}" + ) + + +def _checks_ntkstats_compat(method): + """Decorator that applies replica and index compatibility checks.""" + + @functools.wraps(method) + def wrapper(self, other): + _check_replicas(self, other) + _check_indices(self, other, method.__name__) + return method(self, other) + + return wrapper + + + +class NTKStats(MCStats): + """ + Container for NTK statistics across replicas at a single epoch. + + When constructed with a list of DataFrames, the index is preserved and + accessible via the ``frames`` property, while ``data`` remains a numpy + array for all statistical operations. + """ + + # Tell numpy's ufunc dispatch (which backs the @ operator since numpy ≥ 1.16) + # to return NotImplemented rather than coercing this object, so that Python + # can fall through to NTKStats.__rmatmul__ when numpy arrays appear on the left. + __array_ufunc__ = None + + def __init__(self, data): + if isinstance(data, list) and data and isinstance(data[0], pd.DataFrame): + self._df_index = data[0].index + self._df_columns = data[0].columns + super().__init__(np.stack([df.values for df in data])) + else: + self._df_index = None + self._df_columns = None + super().__init__(data) + + self.shape = self.data.shape[1:] + self.ndim = len(self.data.shape) - 1 # Number of dimensions of the observable (e.g., 0 for scalar, 1 for vector, 2 for matrix) + self.nreplica = self.data.shape[0] # Number of replicas (first dimension) + + @property + def frames(self): + """Return data as a list of DataFrames (preserving the original index), or None.""" + if self._df_index is None: + return self.data + return [ + pd.DataFrame(self.data[k], index=self._df_index, columns=self._df_columns) + for k in range(len(self.data)) + ] + + def _with_index(self, data: np.ndarray) -> NTKStats: + """Wrap a numpy array in NTKStats, preserving this instance's index metadata.""" + result = NTKStats(data) + result._df_index = self._df_index + result._df_columns = self._df_columns + return result + + def _other_data(self, other): + return other.data if isinstance(other, NTKStats) else other + + def central_value(self): + cv = self.data.mean(axis=0) + if self._df_index is not None and self._df_columns is not None: + return pd.DataFrame(cv, index=self._df_index, columns=self._df_columns) + return cv + + def error_members(self): + return self.data + + def median(self): + med = np.median(self.data, axis=0) + if self._df_index is not None and self._df_columns is not None: + return pd.DataFrame(med, index=self._df_index, columns=self._df_columns) + return med + + def std_error(self): + std = np.std(self.data, axis=0) + if self._df_index is not None and self._df_columns is not None: + return pd.DataFrame(std, index=self._df_index, columns=self._df_columns) + return std + + + @_checks_ntkstats_compat + def __add__(self, other): + return self._with_index(self.data + self._other_data(other)) + + @_checks_ntkstats_compat + def __radd__(self, other): + return self._with_index(self._other_data(other) + self.data) + + @_checks_ntkstats_compat + def __sub__(self, other): + return self._with_index(self.data - self._other_data(other)) + + @_checks_ntkstats_compat + def __rsub__(self, other): + return self._with_index(self._other_data(other) - self.data) + + @_checks_ntkstats_compat + def __mul__(self, other): + return self._with_index(self.data * self._other_data(other)) + + @_checks_ntkstats_compat + def __rmul__(self, other): + return self._with_index(self._other_data(other) * self.data) + + @_checks_ntkstats_compat + def __truediv__(self, other): + return self._with_index(self.data / self._other_data(other)) + + @_checks_ntkstats_compat + def __rtruediv__(self, other): + return self._with_index(self._other_data(other) / self.data) + + @property + def T(self) -> NTKStats: + """Transpose each replica's matrix; requires 3D data (Nrep, d1, d2).""" + if self.data.ndim != 3: + raise ValueError( + f"Transpose requires matrix observables (Nrep, d1, d2), got {self.data.shape}" + ) + result = NTKStats(self.data.transpose(0, 2, 1)) + result._df_index = self._df_columns + result._df_columns = self._df_index + return result + + def _get_index(self, other): + """Return (row_index, col_index) for other, supporting DataFrame and NTKStats.""" + if isinstance(other, pd.DataFrame): + return other.index, other.columns + if isinstance(other, NTKStats): + return other._df_index, other._df_columns + return None, None + + def set_index(self, index, columns): + """Return a new NTKStats with the given index and columns.""" + self._df_index = index + self._df_columns = columns + + @_checks_ntkstats_compat + def __matmul__(self, other) -> NTKStats: + _, other_cols = self._get_index(other) + + other_data = other.values if isinstance(other, pd.DataFrame) else self._other_data(other) + + if isinstance(other, NTKStats) and other_data.ndim == 2: + # Vector per replica (Nrep, n): treat as (Nrep, n, 1), multiply, then squeeze. + result_data = (self.data @ other_data[:, :, None]).squeeze(-1) + else: + # Plain 2D matrix (no replica dim): add batch dim so numpy broadcasts over replicas. + if not isinstance(other, NTKStats) and isinstance(other_data, np.ndarray) and other_data.ndim == 2: + other_data = other_data[None] + result_data = self.data @ other_data + result = NTKStats(result_data) + result._df_index = self._df_index + result._df_columns = other_cols + return result + + @_checks_ntkstats_compat + def __rmatmul__(self, other) -> NTKStats: + # treat `other` as the left operand: other @ self + other_index, _ = self._get_index(other) + + other_data = other.values if isinstance(other, pd.DataFrame) else self._other_data(other) + + if self.data.ndim == 2: + # Vector per replica (Nrep, n): treat as (Nrep, n, 1), multiply, then squeeze. + if not isinstance(other, NTKStats) and isinstance(other_data, np.ndarray) and other_data.ndim == 2: + other_data = other_data[None] + result_data = (other_data @ self.data[:, :, None]).squeeze(-1) + else: + # Plain 2D matrix (no replica dim): add batch dim so numpy broadcasts over replicas. + if not isinstance(other, NTKStats) and isinstance(other_data, np.ndarray) and other_data.ndim == 2: + other_data = other_data[None] + result_data = other_data @ self.data + + result = NTKStats(result_data) + result._df_index = other_index + result._df_columns = self._df_columns + return result + + def _assert_eigenvalues(self): + if self.data.ndim != 2: + raise ValueError( + f"Operation requires 1D observables per replica (shape (Nrep, n)), " + f"got {self.data.shape}" + ) + + def _as_diag_matrices(self, vals: np.ndarray) -> NTKStats: + """Build (Nrep, n, n) diagonal matrices from (Nrep, n) values.""" + n = vals.shape[1] + return NTKStats(vals[:, :, None] * np.eye(n)) + + def as_diag(self) -> NTKStats: + """Convert (Nrep, n) eigenvalues to (Nrep, n, n) diagonal matrices.""" + self._assert_eigenvalues() + return self._as_diag_matrices(self.data) + + def exp_kernel(self, t: float) -> NTKStats: + """ + Compute ``diag(1 - exp(-t * λ))`` for each replica. + + Parameters + ---------- + t : float + Time parameter controlling the decay rate. + + Returns + ------- + NTKStats + Shape ``(Nrep, n, n)`` — diagonal matrices per replica. + """ + self._assert_eigenvalues() + return self._as_diag_matrices(1.0 - np.exp(-t * self.data)) + + def exp_kernel_decay(self, t: float) -> NTKStats: + """ + Compute ``diag(exp(-t * λ))`` for each replica. + + Parameters + ---------- + t : float + Time parameter controlling the decay rate. + + Returns + ------- + NTKStats + Shape ``(Nrep, n, n)`` — diagonal matrices per replica. + """ + self._assert_eigenvalues() + return self._as_diag_matrices(np.exp(-t * self.data)) + + def reshape(self, new_shape) -> NTKStats: + """Return a new NTKStats with data reshaped to new_shape. + + If the reshape is from data with ndim = 1 to ndim = 2 (vector to + matrix), the original index (if any) is split into equal parts and + assigned to rows and columns of the new shape. For other reshapes, the + original index is discarded and set to None, since it may not be + meaningful after reshaping. + """ + if (self.ndim == 2 and self.shape[-1] == 1) and len(new_shape) == 2: + # Reshaping from vector to matrix: split index if possible + if self._df_index is not None: + n_rows, n_cols = new_shape + if len(self._df_index) != n_rows * n_cols: + raise ValueError( + f"Cannot reshape with index: original length {len(self._df_index)} " + f"does not match new shape {new_shape}" + ) + row_index = self._df_index.droplevel(-1)[::n_cols] + col_index = self._df_index.droplevel(0)[:n_cols] + else: + row_index = None + col_index = None + else: + # For other reshapes, discard index since it may not be meaningful + row_index = None + col_index = None + + + result = NTKStats(self.data.reshape((self.nreplica, *new_shape))) + result._df_index = row_index + result._df_columns = col_index + return result + + +class NTKGrid(abc.ABC): + """ + Abstract base class for NTK data containers that can be plotted. + + This interface allows plotting utilities to work uniformly with both + eigenvalue and eigenvector data. Each implementation must provide: + - A label identifying the data source (e.g., fit name) + - The x-axis grid for plotting (e.g., epochs or XGRID) + - Methods to extract plotting data for specific ranks + """ + + @property + @abc.abstractmethod + def label(self) -> str: + """Human-readable label for this grid (e.g., fit name).""" + pass + + @property + @abc.abstractmethod + def n_ranks(self) -> int: + """Number of eigenvalue/eigenvector ranks available.""" + pass + + @property + @abc.abstractmethod + def xgrid(self) -> np.ndarray: + """X-axis grid for plotting.""" + pass + + @property + @abc.abstractmethod + def xlabel(self) -> str: + """Label for x-axis.""" + pass + + @abc.abstractmethod + def get_plotting_data(self, rank_index: int, **kwargs) -> NTKStats: + """ + Get plotting data (y-values) for a specific rank. + + Parameters + ---------- + rank_index : int + Index of the eigenvalue/eigenvector rank (0 = largest) + **kwargs + Additional selection parameters as needed by different + implementations (e.g., flavour_index for eigenvectors) + + Returns + ------- + NTKStats + Statistics object containing data of shape (nreplicas, n_xgrid) + """ + pass + + @abc.abstractmethod + def get_plotting_label(self, rank_index: int, **kwargs) -> str: + """ + Get legend label for a specific rank. + + Parameters + ---------- + rank_index : int + Index of the eigenvalue/eigenvector rank + **kwargs + Additional selection parameters + + Returns + ------- + str + LaTeX-formatted label for the legend + """ + pass + +def generate_filename(replica_idx: int, name: str = None) -> str: + """ + Generate a filename for saving NTK eigenvalues based on replica index and an optional name. + + Parameters + ---------- + replica_idx : int + Index of the replica + name : str, optional + Optional name to include in the filename for clarity + + Returns + ------- + str + Generated filename string + """ + if name is None: + return f"{NTK_EIGVAL_TOKEN}_{replica_idx}.npz" + else: + return f"{NTK_EIGVAL_TOKEN}_{name}_{replica_idx}.npz" + + +@lru_cache +def get_parameters_all_epochs(replicas_path, replica_index): + """ + Get paths to model parameters files at all epochs for a given replica. + + Parameters + ---------- + replicas_path : Path + Path to the replicas directory + replica_index : int + Index of the replica to retrieve + + Returns + ------- + dict + Dictionary mapping epoch number to parameter file Path + """ + params_folder = replicas_path / f"replica_{replica_index}/parameters" + param_files = list(params_folder.glob("*.npz")) + param_files.sort(key=lambda f: int(f.stem.split("_")[-1])) + + param_epochs_dict = {} + + for param_file in param_files: + epoch = int(param_file.stem.split("_")[-1]) + param_epochs_dict[epoch] = param_file + + return param_epochs_dict + + +def get_replica_idx_list(replicas_path): + """ + Determine the available replica indices by counting + the replica directories. + + Parameters + ---------- + replicas_path : Path + Path to the replicas directory + + Returns + ------- + list + List of replica indices found + """ + replicas_path = Path(replicas_path) + if not replicas_path.exists(): + raise FileNotFoundError(f"Replicas path does not exist: {replicas_path}") + + # Count directories named "replica_*" + replica_dirs = sorted(replicas_path.glob("replica_*")) + rep_list = [int(d.name.split("_")[1]) for d in replica_dirs] + return rep_list + + +def compute_ntk(pdf_model, params, **kwargs): + """ + Compute the NTK matrix given model parameters. + + Parameters + ---------- + pdf_model : PDFModel + The PDF model instance + params : dict + Model parameters + **kwargs + Additional arguments for the pdf_model.grid_values_func (e.g., exclude_layers + for the n3fit model) + + Returns + ------- + ntk : jnp.ndarray + The NTK matrix + ntk_shape : tuple + Shape of the NTK matrix + """ + pdf_func = pdf_model.grid_values_func(XGRID, **kwargs) + jacobian_func = jax.jacfwd(pdf_func) + jacobian = jacobian_func(params) + + # Compute NTK (nf,ng,nf,ng) -> assumes shape from jacobian + ntk = jnp.einsum("ijk,lmk->ijlm", jacobian, jacobian) + + # Flatten to (nflavors * n_xgrid) × (nflavors * n_xgrid) + d1, d2, d3, d4 = ntk.shape # d1=nf, d2=ng, d3=nf, d4=ng + ntk = ntk.reshape(d1 * d2, d3 * d4, order=NTK_ORDERING) + + return ntk, (d1, d2, d3, d4) + + +def compute_eigendecomposition(ntk_matrix, hermitian=True): + """ + Compute eigendecomposition of an NTK matrix. + + Parameters + ---------- + ntk_matrix : ndarray + The NTK matrix to decompose + hermitian : bool, optional + Whether to use hermitian eigendecomposition (default: True) + + Returns + ------- + eigenvalues : ndarray + Eigenvalues in descending order + eigenvectors : ndarray + Corresponding eigenvectors (columns) + """ + if hermitian: + # For symmetric/hermitian matrices, use eigh for better numerical stability + eigenvalues, eigenvectors = np.linalg.eigh(ntk_matrix) + # Sort in descending order + idx = eigenvalues.argsort()[::-1] + eigenvalues = eigenvalues[idx] + eigenvectors = eigenvectors[:, idx] + else: + # Use SVD for general matrices + eigenvectors, eigenvalues, _ = np.linalg.svd(ntk_matrix) + + return eigenvalues, eigenvectors + + +def compute_eigenvalues_for_replica( + fit_name: str, replicas_path: Path, replica_idx: int, max_epoch=None, name: str = None, **kwargs +): + """ + Compute the NTK eigenvalues for a given replica across all epochs. + + Parameters + ---------- + fit_name : str + Name of the fit (used to load pdf_model) + replicas_path : Path + Path to the replicas directory + replica_idx : int + Replica index to compute + max_epoch : int, optional + Maximum epoch number to consider. + name : str, optional + Optional name to include in the filename for clarity when saving results. + + Returns + ------- + tuple or None + (replica_idx, epochs, ntk_shape) on success, None on failure + """ + try: + # Create fresh pdf_model to avoid JAX tracer leaks + pdf_model = get_pdf_model(fit_name, replica_idx=replica_idx) + param_files = get_parameters_all_epochs(replicas_path, replica_idx) + + eigenvalues_list = [] + epochs = [] + ntk_shape = None + + for epoch, param_file in param_files.items(): + if max_epoch is not None and epoch > max_epoch: + continue + params = jnp.load(param_file)["params"] + + ntk, shape = compute_ntk(pdf_model, params, **kwargs) + if ntk_shape is None: + ntk_shape = shape + + eigvals, _ = compute_eigendecomposition(ntk, hermitian=True) + eigenvalues_list.append(eigvals) + epochs.append(epoch) + + # Stack eigenvalues: (n_epochs, n_eigenvalues) + eigenvalues = np.stack(eigenvalues_list, axis=0) + + # Save immediately to disk + save_replica_eigenvalues( + eigenvalues=eigenvalues, + epochs=epochs, + replica_idx=replica_idx, + replicas_path=replicas_path, + ntk_shape=ntk_shape, + name=name + ) + + return (replica_idx, epochs, ntk_shape) + + except FileNotFoundError as e: + log.warning(f"Skipping replica {replica_idx}: {e}") + return None + + +def get_completed_replicas(replicas_path: Path, name: str = None) -> list: + """ + Utility function to get list of replica indices for which + the NTK eigenvalues have already been computed. + + Parameters + ---------- + replicas_path : Path + Directory containing replica folders. + + Returns + ------- + list + List of completed replica indices + """ + replicas_path = Path(replicas_path) + completed = [] + + for replica_folder in replicas_path.glob("replica_*"): + try: + idx = int(replica_folder.stem.split("_")[1]) + filename = generate_filename(idx, name) + replica_file = replica_folder / f"{filename}" + if replica_file.exists(): + completed.append(idx) + except (ValueError, IndexError) as e: + log.debug(f"Skipping folder {replica_folder.name}: {e}") + continue + + return sorted(completed) + + +def save_replica_eigenvalues( + eigenvalues: np.ndarray, + epochs: list, + replica_idx: int, + replicas_path: Path, + ntk_shape: tuple = None, + name: str = None +) -> None: + """ + Save eigenvalues for a single replica to disk. + + Parameters + ---------- + eigenvalues : np.ndarray + Eigenvalues array of shape (n_epochs, n_eigenvalues) + epochs : list + List of epoch numbers + replica_idx : int + Replica index + replicas_path : Path + Directory to save results + ntk_shape : tuple, optional + Shape of the NTK matrix (saved in metadata) + name: str, optional + Optional name to include in the filename for clarity + """ + filename = generate_filename(replica_idx, name) + replica_file = ( + replicas_path / f"replica_{replica_idx}/{filename}" + ) + np.savez_compressed( + replica_file, + eigenvalues=eigenvalues, + epochs=np.array(epochs), + ntk_shape=ntk_shape, + ) + log.debug(f"Saved eigenvalues for replica {replica_idx} to {replica_file}") + + +def load_replica_eigenvalues(replica_idx: int, cache_dir: Path, name: str = None) -> dict: + """ + Load eigenvalues for a single replica from disk. + + Parameters + ---------- + replica_idx : int + Replica index + cache_dir : Path + Directory containing saved results + name: str, optional + Optional name to include in the filename to specify the set of eigenvalues. + + Returns + ------- + dict + Dictionary with 'eigenvalues' (n_epochs, n_eigenvalues) and 'epochs' + """ + filename = generate_filename(replica_idx, name) + replica_file = ( + cache_dir / f"replica_{replica_idx}/{filename}" + ) + + if not replica_file.exists(): + raise FileNotFoundError(f"Replica {replica_idx} not found at {replica_file}") + + data = np.load(replica_file) + return { + "eigenvalues": data["eigenvalues"], + "epochs": data["epochs"].tolist(), + "ntk_shape": data["ntk_shape"], + } + + +def load_eigenvalues_ensemble( + replicas_path: Path, max_epoch=None, name: str = None +) -> dict: + """ + Load all replica eigenvalues into an ensemble format. + + Parameters + ---------- + replicas_path : Path + Path to replica folders. + max_epoch : int, optional + Maximum epoch to consider. It filters out replicas that do not have + data up to this epoch. + name: str, optional + Optional name to include in the filename to specify the set of eigenvalues. + + Returns + ------- + dict + Dictionary with keys: + - 'eigenvalues_by_epoch': dict mapping epoch -> ndarray (n_replicas, n_eigenvalues) + - 'epochs': list of epochs + - 'ntk_shape': shape of NTK matrix + - 'replica_indices': list of replica indices included + """ + completed_replicas = get_completed_replicas(replicas_path, name) + + if not completed_replicas: + raise ValueError(f"No completed replicas found in {replicas_path}") + + # Load all replicas + all_eigenvalues = [] + included_replicas = [] + ntk_shape = None + for replica_idx in completed_replicas: + data = load_replica_eigenvalues(replica_idx, replicas_path, name) + epochs = np.array(data["epochs"]) + eigenvalues = data["eigenvalues"] # (n_epochs, n_eigenvalues) + + if ntk_shape is None: + ntk_shape = data["ntk_shape"] + + # If max_epoch is set, filter epochs and eigenvalues + if max_epoch is not None: + if max_epoch not in epochs: + log.warning( + f"Replica {replica_idx} does not contain epoch {max_epoch}. " + f"Last epoch is {epochs[-1]}. Excluded from ensemble." + ) + continue + mask = [e <= max_epoch for e in epochs] + epochs = epochs[mask] + eigenvalues = eigenvalues[mask] + + all_eigenvalues.append((replica_idx, epochs, eigenvalues)) + included_replicas.append(replica_idx) + + if not all_eigenvalues: + raise ValueError("No replicas have epochs up to the specified max_epoch.") + + # Determine common epochs across included replicas + all_epoch_sets = [set(epochs) for _, epochs, _ in all_eigenvalues] + common_epochs = sorted(set.intersection(*all_epoch_sets)) + if not common_epochs: + raise ValueError("No common epochs found across replicas.") + + eigenvalues_by_epoch = {epoch: [] for epoch in common_epochs} + for replica_idx, epochs, eigenvalues in all_eigenvalues: + epoch_to_idx = {e: i for i, e in enumerate(epochs)} + for epoch in common_epochs: + idx = epoch_to_idx[epoch] + eigenvalues_by_epoch[epoch].append(eigenvalues[idx]) + + # Stack into arrays + for epoch in common_epochs: + eigenvalues_by_epoch[epoch] = np.stack(eigenvalues_by_epoch[epoch], axis=0) + + log.info( + f"Loaded eigenvalues ensemble: {len(included_replicas)} replicas, " + f"{len(common_epochs)} epochs" + ) + + return { + "eigenvalues_by_epoch": eigenvalues_by_epoch, + "epochs": common_epochs, + "ntk_shape": ntk_shape, + "replica_indices": included_replicas, + } + + +def compute_eigenvectors_at_epoch_for_replica( + fit_name: str, + replicas_path: Path, + replica_idx: int, + epoch: int, + **kwargs +): + """ + Compute the eigenvectors of the NTK at a given epoch for a specific replica. + + Parameters + ---------- + fit_name : str + Name of the fit (used to load pdf_model) + replicas_path : Path + Path to the replicas directory + replica_idx : int + Replica index to compute + epoch : int + Epoch number at which to compute eigenvectors + **kwargs + Additional arguments for the pdf_model.grid_values_func (e.g., exclude_layers + for the n3fit model) + """ + try: + pdf_model = get_pdf_model(fit_name, replica_idx=replica_idx) + param_files = get_parameters_all_epochs(replicas_path, replica_idx) + + params_at_epoch = param_files.get(epoch, None) + if params_at_epoch is None: + raise ValueError( + f"Epoch {epoch} not found for replica {replica_idx} in fit {fit_name}" + ) + + params = jnp.load(params_at_epoch)["params"] + ntk, shape = compute_ntk(pdf_model, params, **kwargs) + _, eigvecs = compute_eigendecomposition(ntk, hermitian=True) + + return (replica_idx, epoch, eigvecs, shape) + except Exception as e: + log.warning(f"Skipping replica {replica_idx}: {e}") + return None diff --git a/colibri/ntk/plotntk.py b/colibri/ntk/plotntk.py new file mode 100644 index 00000000..09a20242 --- /dev/null +++ b/colibri/ntk/plotntk.py @@ -0,0 +1,436 @@ +""" +colibri.ntk.plotntk.py + +Plotting utilities for NTK eigenvalues and eigenvectors. + +Design: +- Single `ntk_plot_provider` function handles all cases +- Draw styles: "band" (uncertainty bands) or "replicas" (individual lines) +- Iteration modes: "by_rank" or "by_fit" +""" + +import warnings +from collections import namedtuple +from dataclasses import dataclass, field +from functools import partial +from typing import Any, Iterator, List, Optional, Tuple + +import matplotlib.patches as mpatches +import numpy as np +from matplotlib import rc + +from validphys import plotutils + +from colibri.constants import FLAVOURS_ID_MAPPINGS +from colibri.ntk.ntkutils import NTKGrid, NTKStats + +rc("font", **{"family": "sans-serif", "sans-serif": ["Helvetica"]}) +rc("text", usetex=True) +rc("text.latex", preamble=r"\usepackage{amsmath,amssymb}") + + +HandlerSpec = namedtuple("HandlerSpec", ["color", "alpha"]) + + +@dataclass +class PlotResult: + """Result from plotting a single figure.""" + + fig: Any + ax: Any + name: str + title: str + handles: list = field(default_factory=list) + labels: list = field(default_factory=list) + + +class ComposedHandler: + """Legend handler for plots with uncertainty bands.""" + + def legend_artist(self, legend, orig_handle, fontsize, handlebox): + x0, y0 = handlebox.xdescent, handlebox.ydescent + width, height = handlebox.width, handlebox.height + + patch = mpatches.Rectangle( + [x0, y0], + width, + height, + facecolor=orig_handle.color, + alpha=orig_handle.alpha, + edgecolor="none", + transform=handlebox.get_transform(), + ) + line = mpatches.Rectangle( + [x0, y0 + height / 2 - height * 0.05], + width, + height * 0.1, + facecolor=orig_handle.color, + alpha=1, + edgecolor="none", + transform=handlebox.get_transform(), + ) + handlebox.add_artist(patch) + handlebox.add_artist(line) + return [patch, line] + + +# ============================================================================= +# Drawing functions +# ============================================================================= + + +def draw_band( + ax, + xgrid: np.ndarray, + stats: NTKStats, + label: str, + error_type: str = "mean", + handles=None, + labels=None, +): + """ + Draw data with uncertainty band. + + Returns array of plotted values for axis scaling. + """ + color = ax._get_lines.get_next_color() + alpha = 0.3 + + if error_type == "median": + ax.plot(xgrid, stats.median(), color=color, linewidth=1.5) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + err68down, err68up = stats.errorbar68() + lower, upper = err68up, err68down + elif error_type == "mean": + ax.plot(xgrid, stats.central_value(), color=color, linewidth=1.5) + lower, upper = stats.errorbarstd() + else: + raise ValueError(f"Unknown error_type '{error_type}'") + + ax.fill_between(xgrid, lower, upper, color=color, alpha=alpha, zorder=1) + + if handles is not None and labels is not None: + handles.append(HandlerSpec(color=color, alpha=alpha)) + labels.append(label) + + return np.array([lower, upper]) + + +def draw_replicas(ax, xgrid, stats, label, **kwargs): + """ + Draw individual replica lines with mean overlay. + + Returns array of plotted values for axis scaling. + """ + color = ax._get_lines.get_next_color() + data = stats.data + ax.plot(xgrid, data.T, alpha=0.2, linewidth=0.5, color=color, zorder=1) + ax.plot(xgrid, stats.central_value(), color=color, linewidth=2, label=label) + return data + + +# ============================================================================= +# Iteration utilities +# ============================================================================= + + +def iter_by_rank( + grids: List[NTKGrid], rank_indices: List[int], extra_kwargs: Optional[dict] = None +): + """ + Yield (rank_index, items) where items is list of (stats, label, xgrid) per grid. + """ + extra_kwargs = extra_kwargs or {} + for rank_index in rank_indices: + items = [] + for grid in grids: + stats = grid.get_plotting_data(rank_index, **extra_kwargs) + items.append((stats, grid.label, grid.xgrid)) + yield rank_index, items + + +def iter_by_fit( + grids: List[NTKGrid], rank_indices: List[int], extra_kwargs: Optional[dict] = None +): + """ + Yield (grid, items) where items is list of (stats, label, xgrid) per rank. + """ + extra_kwargs = extra_kwargs or {} + for grid in grids: + items = [] + for rank_index in rank_indices: + stats = grid.get_plotting_data(rank_index, **extra_kwargs) + label = grid.get_plotting_label(rank_index, **extra_kwargs) + items.append((stats, label, grid.xgrid)) + yield grid, items + + +# ============================================================================= +# Main plotting function +# ============================================================================= + + +def ntk_plot_provider( + grids: List[NTKGrid], + rank_indices: Optional[List[int]] = None, + iterator_fn=iter_by_rank, + draw_fn=draw_band, + custom_handler=ComposedHandler, + xscale: Optional[str] = None, + yscale: Optional[str] = None, + ymin: Optional[float] = None, + ymax: Optional[float] = None, + title_fn=None, + name_fn=None, + ylabel_fn=None, +) -> Iterator[Tuple[Any, PlotResult]]: + """ + Unified NTK plotting function for eigenvalues and eigenvectors. + + Parameters + ---------- + grids : list of NTKGrid + Data containers (EigenvalueGrid or EigenvectorGrid) + rank_indices : list of int, optional + Which ranks to plot. Default: first 5. + iterator_fn : callable + Function to iterate over grids and ranks (e.g., iter_by_rank or + iter_by_fit) + draw_fn : callable + Function to draw data (e.g., draw_band or draw_replicas) + custom_handler : callable, optional + Custom legend handler class. If None, uses default legend. + xscale, yscale : str + Axis scales ("linear" or "log") + ymin, ymax : float, optional + Y-axis limits + title_fn : callable, optional + Custom title function: (rank_index, grid) -> str + name_fn : callable, optional + Custom name function: (rank_index, grid) -> str + ylabel_fn : callable, optional + Custom ylabel function: (rank_index, grid) -> str + + Yields + ------ + tuple + (fig, PlotResult) pairs + """ + if not grids: + return + + # Determine rank indices + if rank_indices is None: + max_ranks = min(grid.n_ranks for grid in grids) + rank_indices = list(range(min(5, max_ranks))) + + # Get common xgrid + common_xgrid = grids[0].xgrid + xlabel = grids[0].xlabel + + iterator = iterator_fn(grids, rank_indices) + for grid, items in iterator: + fig, ax = plotutils.subplots(figsize=(8, 6)) + handles, labels_list = [], [] + all_vals = [] + + title = title_fn(grid) + name = name_fn(grid) + ylabel = ylabel_fn(grid) + ax.set_title(title) + + # Draw each item + for stats, label, xgrid in items: + vals = draw_fn( + ax, + xgrid, + stats, + label, + handles=handles, + labels=labels_list, + ) + if vals is not None: + all_vals.append(np.atleast_2d(vals)) + + # Configure axes + if xscale and xscale != "linear": + ax.set_xscale(xscale) + if yscale and yscale != "linear": + ax.set_yscale(yscale) + + if all_vals: + plotutils.frame_center(ax, common_xgrid, np.concatenate(all_vals)) + if ymin is not None: + ax.set_ylim(bottom=ymin) + if ymax is not None: + ax.set_ylim(top=ymax) + + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_xlim(common_xgrid[0], common_xgrid[-1]) + ax.set_axisbelow(True) + ax.grid(True, alpha=0.3) + + # Legend + if custom_handler is None: + ax.legend() + else: + ax.legend(handles, labels_list, handler_map={HandlerSpec: custom_handler()}) + + result = PlotResult( + fig=fig, ax=ax, name=name, title=title, handles=handles, labels=labels_list + ) + yield fig, result + + +# ============================================================================= +# Convenience functions +# ============================================================================= + + +def plot_eigvals_by_rank( + eigval_grids_by_fit, + rank_indices: Optional[list] = None, + error_type: str = "mean", + xscale: Optional[str] = None, + yscale: Optional[str] = None, + ymin: Optional[float] = None, + ymax: Optional[float] = None, +): + """Plot eigenvalues, one figure per rank showing all fits.""" + yield from ntk_plot_provider( + eigval_grids_by_fit, + rank_indices, + draw_fn=partial(draw_band, error_type=error_type), + iterator_fn=iter_by_rank, + title_fn=lambda rank_index: rf"$\lambda^{{({rank_index})}}$", + name_fn=lambda rank_index: f"lambda_{rank_index}", + ylabel_fn=lambda rank_index: rf"$\lambda^{{({rank_index})}}$", + xscale=xscale, + yscale=yscale, + ymin=ymin, + ymax=ymax, + ) + + +def plot_eigvals_by_fit( + eigval_grids_by_fit, + rank_indices: Optional[list] = None, + error_type: str = "mean", + xscale: Optional[str] = None, + yscale: Optional[str] = None, + ymin: Optional[float] = None, + ymax: Optional[float] = None, +): + """Plot eigenvalues, one figure per fit showing multiple ranks.""" + yield from ntk_plot_provider( + eigval_grids_by_fit, + rank_indices, + draw_fn=partial(draw_band, error_type=error_type), + iterator_fn=iter_by_fit, + title_fn=lambda grid: grid.label, + name_fn=lambda grid: f"eigvals_{grid.label}", + ylabel_fn=lambda _: r"$\textrm{NTK eigenvalues}$", + xscale=xscale, + yscale=yscale, + ymin=ymin, + ymax=ymax, + ) + + +def plot_eigvals_replicas_by_rank( + eigval_grids_by_fit, + rank_indices: Optional[list] = None, + xscale: Optional[str] = None, + yscale: Optional[str] = None, + ymin: Optional[float] = None, + ymax: Optional[float] = None, +): + """Plot eigenvalue replicas, one figure per rank.""" + yield from ntk_plot_provider( + eigval_grids_by_fit, + rank_indices, + draw_fn=draw_replicas, + iterator_fn=iter_by_rank, + title_fn=lambda rank_index: rf"$\lambda^{{({rank_index})}}$", + name_fn=lambda rank_index: f"lambda_replicas_{rank_index}", + ylabel_fn=lambda rank_index: rf"$\lambda^{{({rank_index})}}$", + xscale=xscale, + yscale=yscale, + ymin=ymin, + ymax=ymax, + custom_handler=None, + ) + + +def plot_eigenvectors_by_rank_and_flavour( + eigvecs_grids_by_fit, + flavour_indices: list, + error_type: str = "mean", + rank_indices: Optional[list] = None, + xscale: Optional[str] = None, + yscale: Optional[str] = None, + ymin: Optional[float] = None, + ymax: Optional[float] = None, +): + """ + Plot eigenvector components, one figure per (rank, flavour). + """ + if rank_indices is None: + max_ranks = min(grid.n_ranks for grid in eigvecs_grids_by_fit) + rank_indices = list(range(min(5, max_ranks))) + + for flavour_index in flavour_indices: + flavour_name = FLAVOURS_ID_MAPPINGS[flavour_index] + yield from ntk_plot_provider( + eigvecs_grids_by_fit, + rank_indices=rank_indices, + iterator_fn=partial( + iter_by_rank, extra_kwargs={"flavour_index": flavour_index} + ), + draw_fn=partial(draw_band, error_type=error_type), + title_fn=lambda _: f"${flavour_name}$", + name_fn=lambda rank_index: f"eigvec_{rank_index + 1}_{flavour_name}", + ylabel_fn=lambda _: f"${flavour_name}$", + xscale=xscale, + yscale=yscale, + ymin=ymin, + ymax=ymax, + ) + + +def plot_eigenvectors_by_fit_and_flavour( + eigvecs_grids_by_fit, + flavour_indices: list, + error_type: str = "mean", + rank_indices: Optional[list] = None, + xscale: Optional[str] = None, + yscale: Optional[str] = None, + ymin: Optional[float] = None, + ymax: Optional[float] = None, +): + """ + Plot eigenvector components, one figure per (fit, flavour). + """ + if rank_indices is None: + max_ranks = min(grid.n_ranks for grid in eigvecs_grids_by_fit) + rank_indices = list(range(min(5, max_ranks))) + + for flavour_index in flavour_indices: + flavour_name = FLAVOURS_ID_MAPPINGS[flavour_index] + yield from ntk_plot_provider( + eigvecs_grids_by_fit, + rank_indices=rank_indices, + iterator_fn=partial( + iter_by_fit, extra_kwargs={"flavour_index": flavour_index} + ), + draw_fn=partial(draw_band, error_type=error_type), + title_fn=lambda grid: f"{grid.label} - ${flavour_name}$", + name_fn=lambda _: f"eigvecs_{flavour_name}", + ylabel_fn=lambda _: f"${flavour_name}$", + xscale=xscale, + yscale=yscale, + ymin=ymin, + ymax=ymax, + ) diff --git a/colibri/utils.py b/colibri/utils.py index 1e8154ad..105d990c 100644 --- a/colibri/utils.py +++ b/colibri/utils.py @@ -172,10 +172,10 @@ def resample_from_ns_posterior( def get_fit_path(fit): - fit_path = pathlib.Path(sys.prefix) / "share/colibri/results" / fit + fit_path = pathlib.Path(sys.prefix) / "share/NNPDF/results" / fit if not os.path.exists(fit_path): raise FileNotFoundError( - "Could not find a fit " + fit + " in the colibri/results directory." + "Could not find a fit " + fit + " in the NNPDF/results directory." ) return pathlib.Path(fit_path) @@ -210,7 +210,7 @@ def get_full_posterior(colibri_fit): return df -def get_pdf_model(colibri_fit): +def get_pdf_model(colibri_fit, replica_idx=None): """ Given a colibri fit, returns the PDF model. @@ -218,6 +218,10 @@ def get_pdf_model(colibri_fit): ---------- colibri_fit : str The name of the fit to read. + replica_idx : int, optional + Temporary workaround to specify the preprocessing factors for the n3fit model. + If pdf_model is a dictionary, it is automatically converted to an N3FitPDFModel. If + this is not the case, the replica_idx is ignored. Returns @@ -237,6 +241,31 @@ def get_pdf_model(colibri_fit): with open(pdf_model_path, "rb") as file: pdf_model = dill.load(file) + if isinstance(pdf_model, dict) and "_init_args" in pdf_model: + from colibri_n3fit.model import N3FitPDFModel + if pdf_model.get("nnseed", None) is None: + # Fetch nnseed from the runcard + import yaml + with open(fit_path / "filter.yml", "r") as runcard_file: + runcard = yaml.safe_load(runcard_file) + pdf_model['_init_args'].update({"nnseed": runcard.get("nnseed", None)}) + + pdf_model = N3FitPDFModel(**(pdf_model['_init_args'] | {"replica_index": 1})) + + if replica_idx is not None: + # Load preprocessing factors + import json + from n3fit.backends.keras_backend.MetaModel import PREPROCESSING_LAYER_ALL_REPLICAS + prepr_weights = [] + path = fit_path / f"nnfit/replica_{replica_idx}" + with open(path / f"{colibri_fit}.json", "r", encoding="utf-8") as f: + jf = json.load(f) + prep = jf['preprocessing'] + for fl in prep: + prepr_weights.append(np.array([[fl["smallx"]]])) + prepr_weights.append(np.array([[fl["largex"]]])) + pdf_model.n3fit_model.get_layer(PREPROCESSING_LAYER_ALL_REPLICAS).set_weights(prepr_weights) + return pdf_model