Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
25b2bc8
Save parameters of the PDF model to disk
achiefa Dec 8, 2025
5926a45
First implementation of the ntk routine
achiefa Dec 8, 2025
232ae72
Add example runcard
achiefa Dec 9, 2025
ac19ea0
restore example runcard
vschutze-alt Dec 9, 2025
658da29
Store parameters in GradientDescentResult
achiefa Dec 9, 2025
959123f
Add new example runcard
achiefa Dec 9, 2025
d12ec10
enable ntk computation from a separate dedicated runcard
vschutze-alt Dec 9, 2025
042f1b1
Apply black formatting
achiefa Dec 9, 2025
168cd71
add example runcard to compute ntk
vschutze-alt Dec 9, 2025
75df249
add more options to the runcard, add plotting function
vschutze-alt Dec 10, 2025
0cd0214
Documentation
ecole41 Dec 10, 2025
7123ce0
Documentation tree
ecole41 Dec 10, 2025
7ad5e46
populate docs
vschutze-alt Dec 12, 2025
e007004
remove runcards from example, as they are in the docs
vschutze-alt Dec 12, 2025
6aedfb3
populate docs, plots in pdf format
vschutze-alt Dec 12, 2025
8327586
add ntk analysis plots and docs
vschutze-alt Dec 12, 2025
a5a9a58
NTK analysis + plot utilities
achiefa Jan 23, 2026
6cfdc5e
Refactoring ntk analysis
achiefa Jan 26, 2026
8c0226c
mark max_epochs as Optional
achiefa Jan 26, 2026
e92df47
Implementing plot utilities for NTK + eigenvectors + refactoring
achiefa Jan 29, 2026
d6c6a77
Adding example notebook; shows how to use the NTK funcionalities thro…
achiefa Feb 6, 2026
00b7cc6
Workaround to use nnpdf model oob from n3fit + correct misspelling
achiefa Feb 24, 2026
72e2635
Allow additional kwargs (for layer selector in n3fit see https://gith…
achiefa Mar 1, 2026
7c6161b
Add temporary fix for compatibility with nnpdf fits
achiefa Mar 23, 2026
3138708
Use frozensets instead of plain dictionaries
achiefa Mar 24, 2026
1a551d1
Convert frozenset to dict inside function
achiefa Mar 24, 2026
8ca4175
Add getters and make flavor ordering more explicit + allow NTKutils t…
achiefa Apr 1, 2026
df4a23a
Remove unused imports
achiefa Apr 1, 2026
20b16d5
Update __rmatmul__ to match with __matmul__ in NTKStats
achiefa Apr 1, 2026
bcac320
Correct spectrum indexing + serialization NTKGrid
achiefa Apr 8, 2026
c5edcff
Add reshape method to NTKStats
achiefa Apr 8, 2026
6d8011f
Add ntk module
achiefa Apr 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions colibri/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
"colibri.param_initialisation",
"colibri.export_results",
"colibri.closure_test",
"colibri.ntk.eigenvalues",
"colibri.ntk.eigenvector",
"colibri.ntk.plotntk",
"colibri.ntk.ntk",
"reportengine.report",
]

Expand Down
23 changes: 21 additions & 2 deletions colibri/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import logging
import os
import shutil
from pathlib import Path

import jax
import jax.numpy as jnp
Expand All @@ -18,10 +19,11 @@
from colibri.constants import FLAVOUR_TO_ID_MAPPING
from colibri.core import IntegrabilitySettings, PriorSettings
from mpi4py import MPI
from reportengine.configparser import ConfigError, explicit_node
from reportengine.configparser import ConfigError, explicit_node, element_of
from validphys import covmats
from validphys.config import Config, Environment
from validphys.config import Config, Environment, _id_with_label
from validphys.fkparser import load_fktable
from validphys.loader import LoadFailedError

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
Expand Down Expand Up @@ -622,3 +624,20 @@ def produce_pdf_model(self):
Returns None as the pdf_model is not used in the colibri module.
"""
return None

def produce_replicas_path(self, fit):
    """Produce the folder where the fit replicas are stored.

    Creates ``<fit.path>/fit_replicas`` on disk if it does not already
    exist (intermediate directories included) and returns it.

    Parameters
    ----------
    fit :
        A fit object exposing a ``path`` attribute that points at the
        fit's output directory (presumably a validphys fit — confirm).

    Returns
    -------
    pathlib.Path
        The (now existing) replicas directory.
    """
    target = fit.path / "fit_replicas"
    # Idempotent: exist_ok avoids racing with other replicas creating it.
    target.mkdir(parents=True, exist_ok=True)
    return target

@element_of("fits")
@_id_with_label
def parse_fit(self, fit: str):
    """Parse a fit name into a fit object.

    A fit in the results folder, containing at least a valid filter
    result.

    Parameters
    ----------
    fit : str
        Name of the fit to look up via the loader.

    Returns
    -------
    The loaded fit object returned by ``self.loader.check_fit``.

    Raises
    ------
    ConfigError
        If the fit cannot be loaded; the message carries the loader
        error and the list of available fits as suggestions.
    """
    try:
        return self.loader.check_fit(fit)
    except LoadFailedError as e:
        # Chain the original loader error (PEP 3134) so the traceback
        # shows *why* loading failed, not just that it did.
        raise ConfigError(str(e), fit, self.loader.available_fits) from e
6 changes: 6 additions & 0 deletions colibri/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,15 @@ class MonteCarloFit:
Array containing the validation loss.
optimized_parameters: jnp.array
Array containing the optimized parameters.
parameters_by_epoch: jnp.array
Array containing the parameters of the model recorded during training.
"""

monte_carlo_specs: dict
training_loss: jnp.array
validation_loss: jnp.array
optimized_parameters: jnp.array
parameters_by_epoch: jnp.array


@dataclass(frozen=True)
Expand All @@ -189,12 +192,15 @@ class GradientDescentResult:
Recorded (epoch) validation losses (sampled according to record_every).
specs: dict
Dictionary of settings used for the run (epochs, batch size, etc.).
parameters_by_epoch: jnp.array
Array containing the parameters of the model recorded during training.
"""

optimized_parameters: Any
training_loss: jnp.array
validation_loss: jnp.array
specs: Dict[str, Any]
parameters_by_epoch: jnp.array


@dataclass(frozen=True)
Expand Down
2 changes: 2 additions & 0 deletions colibri/doc/sphinx/source/theory/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ This section discusses some relevant theoretical background to Colibri.
./prior_distributions.rst

./inference_methods.rst

./ntk_theory.rst
6 changes: 6 additions & 0 deletions colibri/doc/sphinx/source/theory/ntk_theory.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
.. _ntk_theory:

===========================
Neural Tangent Kernel (NTK)
===========================

2 changes: 2 additions & 0 deletions colibri/doc/sphinx/source/tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ such as implementing a custom PDF model or using it to fit data.
closure_tests/index

scripts/index

ntk/computing_ntk
79 changes: 79 additions & 0 deletions colibri/doc/sphinx/source/tutorials/ntk/computing_ntk.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
.. _computing_ntk:

=====================================
Computing Neural Tangent Kernel (NTK)
=====================================

This section describes how to compute and analyse the Neural Tangent Kernel (NTK)
during the training of a given fit. For information on what the NTK is, please
refer to the :ref:`NTK theory section <ntk_theory>`.

The NTK can be computed on any fit where the parameter values have been saved during
training. To produce these saved parameter values during a fit set the
`record_parameters` argument to `True` in the fit runcard.

.. code-block:: yaml

record_parameters: True
record_every: 100

The example above will save the parameter values every 100 epochs.

Once the parameter values have been saved, the NTK can be computed using the `compute_ntk`
action. An example runcard to compute the NTK is shown below.

.. code-block:: yaml

meta:
title: Example NTK determination for a Les Houches fit
author: Colibri User
keywords: [example, PDF plots, NTK]

pdf: {id: "test_ntk_fit"} # generated by the colibri fit

ntk_plots_settings:
ntk_plots: True
n_top_eigenvalues: 10 # Number of top eigenvalues for which evolution is plotted
y_scale: log # Sets the scale for eigenvalue evolution plot.
plot_n_epochs: 5 # The number of epochs for which the eigenvector plots will be plotted.
# These will be equally spaced among all epochs. Default 6.
plot_n_eigenvectors: [1, 2] # Eigenvector plots will be produced for these. Default 1.

actions_:
- compute_ntk

This runcard can be run with the command
``colibri_model_exe compute_ntk.yaml -r replica_n``, where the ``colibri_model_exe``
is the Colibri executable, and ``n`` is the replica number for which the NTK
should be computed. This will produce a folder called ``compute_ntk``. The NTK values
are stored in npz file format in ``compute_ntk/ntk_replicas/replica_n``.

``ntk_plots_settings``
^^^^^^^^^^^^^^^^^^^^^^
* ``ntk_plots``:
Whether analytic plots are produced with the ``compute_ntk`` command. For details on
the plots produced, see the :ref:`analytic NTK plots section below <analytic_ntk_plots>`.
* ``n_top_eigenvalues``:
This is the number of top eigenvalues that will be plotted when ``ntk_plots`` is
``True``. If a number is not specified, a default of 1 eigenvalue will be plotted.
* ``x_scale`` / ``y_scale``:
If set to ``log``, the eigenvalue evolution plot will be produced with a logarithmic
scale on the respective axis. Otherwise a linear scale will be used.
* ``plot_n_epochs``:
The number of epochs for which eigenvector plots will be produced.
These will be equally spaced among all epochs. Default 6.
* ``plot_n_eigenvectors``:
The eigenvectors for which eigenvector plots will be produced. Default 1.

.. _analytic_ntk_plots:

Analytic NTK plots
==================
If the flag ``ntk_plots`` is set to ``True``, the following analytic plots will be
produced and stored in .pdf format in ``compute_ntk/ntk_plots/``.

* **NTK eigenvalue evolution:** Value of the NTK eigenvalues for increasing number of recorded epochs.

* **NTK eigenvector evolution:** Eigenvector evolution for a given number of recorded epochs, for all flavours.

* **Eigenvector heatmaps:** Density of NTK eigenvectors across PDF flavours and x-regions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,4 @@ ultranest_settings:


actions_:
- run_ultranest_fit # Choose from ultranest_fit, monte_carlo_fit, analytic_fit
- run_ultranest_fit # Choose from ultranest_fit, monte_carlo_fit, analytic_fit
15 changes: 14 additions & 1 deletion colibri/gradient_descent.py
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @achiefa , thanks for starting this. Just from a quick look, may I suggest to not have this write stuff directly, but rather add things in the GradientDescentResult? This way the writing is delegated to other dedicated functions and you don't need to modify much here. Similarly for the MonteCarloFit class.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @LucaMantani, thanks for your comment. Indeed, we have considered this option, which I agree has a more solid design principle. However, I was worried that storing the parameters for all recorded epochs could yield memory issues during training. If instead we use a buffer that is saved on disk and freed at the end of each epoch, then we avoid any potential memory issue. Maybe this is not a problem at all, and we can simply store all parameters in a big array and then add it to GradientDescentResult.

Just to quantify the problem: For a neural network with 763 parameters (float64), a single array is about 0.01 MB. This is then multiplied by the number of epochs for which we want to save the parameters. For instance, if we have 100 epochs, this adds up to ~1MB for one replica. Again, probably we can afford this in favour of a better code design. What do you think?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think 1 MB is nothing, we load in memory several Gb due to the data and FK tables. Even if one had a model with 1000 parameters, saving it 1000 times would be 8 Mb. So I would say memory is far from being an issue?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. Let's put it in GradientDescentResult then.

Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def run_gradient_descent(
max_epochs: int,
data_batch: colibri.DataBatches = None,
record_every: int = 50,
record_parameters: bool = False,
) -> GradientDescentResult:
"""Generic gradient descent loop.

Expand Down Expand Up @@ -64,7 +65,11 @@ def run_gradient_descent(
Defaults to None, in which case the loss is assumed to not have been batched.

record_every : int, default 50
Record losses every this many epochs.
Record losses every this many epochs. If `record_parameters` is True,
parameters of the model are also recorded.

record_parameters : bool, default False
Whether to record parameters every `record_every` epochs.
"""

params = initial_parameters
Expand All @@ -79,6 +84,7 @@ def _step(p, ostate, batch_idx):

train_losses = []
val_losses = []
parameters_by_epoch = [] if record_parameters else None

if data_batch is None:
# we simulate a fake batch iterator that just yields a dummy batch index
Expand All @@ -102,6 +108,7 @@ def batch_gen():
batch_idx = next(batches_iter)
params, opt_state, batch_loss = _step(params, opt_state, batch_idx)
epoch_train_loss += batch_loss

epoch_val_loss = validation_loss_fn(params)

early_stopper = early_stopper.update(epoch_val_loss)
Expand All @@ -114,6 +121,11 @@ def batch_gen():
train_losses.append(epoch_train_loss)
val_losses.append(epoch_val_loss)

if record_parameters:
if epoch % record_every == 0:
log.info(f"Recording parameters at epoch {epoch}")
parameters_by_epoch.append(params)

if early_stopper.should_stop:
log.info(f"Early stopping at epoch {epoch}")
break
Expand All @@ -127,4 +139,5 @@ def batch_gen():
"batch_size": batch_size,
"record_every": record_every,
},
parameters_by_epoch=jnp.array(parameters_by_epoch),
)
36 changes: 33 additions & 3 deletions colibri/monte_carlo_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def monte_carlo_fit(
max_epochs,
batch_size=None,
batch_seed=1,
record_parameters=False,
record_every=50,
):
"""
This function performs a Monte Carlo fit.
Expand Down Expand Up @@ -61,14 +63,20 @@ def monte_carlo_fit(
batch_seed: int, optional
Seed used to construct the batches. Defaults to 1.

record_parameters: bool, default False
Whether to monitor the parameters during the Monte Carlo fit.

record_every: int, default 50
Frequency (in epochs) at which to record the losses and parameters of
the model. If record_parameters is False, only losses are recorded.

Returns
-------
MonteCarloFit: The result of the fit with following attributes:
monte_carlo_specs: dict
training_loss: jnp.array
validation_loss: jnp.array
"""

len_tr_idx, len_val_idx = len_trval_data

@jax.jit
Expand Down Expand Up @@ -100,7 +108,8 @@ def loss_validation(parameters):
early_stopper=early_stopper,
max_epochs=max_epochs,
data_batch=data_batch,
record_every=50,
record_every=record_every,
record_parameters=record_parameters,
)

t1 = time.time()
Expand All @@ -111,14 +120,18 @@ def loss_validation(parameters):
"max_epochs": max_epochs,
"batch_size": data_batch.batch_size,
"batch_seed": batch_seed,
"record_every": record_every,
},
training_loss=gd_result.training_loss,
validation_loss=gd_result.validation_loss,
optimized_parameters=gd_result.optimized_parameters,
parameters_by_epoch=gd_result.parameters_by_epoch,
)


def run_monte_carlo_fit(monte_carlo_fit, pdf_model, output_path, replica_index, Q0):
def run_monte_carlo_fit(
monte_carlo_fit, pdf_model, output_path, replica_index, Q0, record_parameters=False
):
"""
Runs the Monte Carlo fit and writes the output to the output directory.

Expand Down Expand Up @@ -178,3 +191,20 @@ def run_monte_carlo_fit(monte_carlo_fit, pdf_model, output_path, replica_index,
index=False,
float_format="%.5e",
)

if record_parameters:
# Save the parameters by epoch if recorded
record_every = mc_fit.monte_carlo_specs["record_every"]
parameters_by_epoch = mc_fit.parameters_by_epoch
params_path = (
str(output_path) + f"/fit_replicas/replica_{replica_index}/parameters/"
)
if not os.path.exists(params_path):
os.makedirs(params_path, exist_ok=True)

for epoch_idx in range(parameters_by_epoch.shape[0]):
epoch = epoch_idx * record_every
jnp.savez(
params_path + f"params_{epoch}.npz",
params=parameters_by_epoch[epoch_idx],
)
34 changes: 34 additions & 0 deletions colibri/ntk/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
colibri.ntk

NTK (Neural Tangent Kernel) analysis module for colibri.

This module provides tools for computing and analyzing the Neural Tangent Kernel
of PDF fits, including eigenvalue/eigenvector computation and plotting utilities.

The names imported below are re-exported as the module's public API; see
``__all__`` for the canonical list.
"""

# Container/statistics helpers shared by the eigen-analysis submodules.
from colibri.ntk.ntkutils import NTKGrid, NTKStats

# Eigenvalue computation over the NTK.
from colibri.ntk.eigenvalues import (
    EigenvalueGrid,
    eigenvalue_grid,
    eigenvalues_ensemble,
)

# Eigenvector computation over the NTK.
from colibri.ntk.eigenvector import (
    EigenvectorGrid,
    eigenvector_grid,
    eigenvectors_ensemble_at_epoch,
)

# Public API of the package; `from colibri.ntk import *` exposes exactly these.
__all__ = [
    # Abstract interface
    "NTKGrid",
    "NTKStats",
    # Eigenvalue classes and functions
    "EigenvalueGrid",
    "eigenvalue_grid",
    "eigenvalues_ensemble",
    # Eigenvector classes and functions
    "EigenvectorGrid",
    "eigenvector_grid",
    "eigenvectors_ensemble_at_epoch",
]
Loading
Loading