diff --git a/examples/1D_example/run_gmm_example.py b/examples/1D_example/run_gmm_example.py index 8327ddf..ed17c47 100644 --- a/examples/1D_example/run_gmm_example.py +++ b/examples/1D_example/run_gmm_example.py @@ -41,6 +41,7 @@ def __init__(self, n_cats, temperature=1.0, mean_norm="sigmoid"): temperature=temperature, mean_norm=mean_norm, mean_range=(0.0, 1.0), + cov_offdiag_damping=0.1, ) # dummy NN output is already in (0,1) def call(self, data_dict): diff --git a/examples/bumphunt_example/run_example.py b/examples/bumphunt_example/run_example.py index c37bdbf..d3ae06b 100644 --- a/examples/bumphunt_example/run_example.py +++ b/examples/bumphunt_example/run_example.py @@ -36,6 +36,7 @@ def __init__(self, n_cats, temperature=0.5, mass_sigma=1.5): dim=2, temperature=temperature, mean_norm="softmax", + cov_offdiag_damping=0.1, name="gato_diphoton", ) self.mass_center = tf.constant(125.0, dtype=tf.float32) diff --git a/examples/three_class_softmax_example/run_example.py b/examples/three_class_softmax_example/run_example.py index 6a2c0a7..1f23734 100644 --- a/examples/three_class_softmax_example/run_example.py +++ b/examples/three_class_softmax_example/run_example.py @@ -50,6 +50,7 @@ def __init__(self, n_cats, temperature=0.3, name="gato_2D"): dim=2, temperature=temperature, mean_norm="softmax", + cov_offdiag_damping=0.1, name=name ) diff --git a/src/gatohep/models.py b/src/gatohep/models.py index 1e09743..b1de923 100644 --- a/src/gatohep/models.py +++ b/src/gatohep/models.py @@ -68,6 +68,7 @@ def __init__( temperature=1.0, mean_norm: str = "softmax", mean_range: tuple | list = (0.0, 1.0), + cov_offdiag_damping: float = 0.1, name="gato_gmm_model", ): """ @@ -81,6 +82,9 @@ def __init__( Dimensionality of the feature space. temperature : float, optional Temperature parameter for the softmax function. Default is 1.0. + cov_offdiag_damping : float, optional + Multiplicative damping applied to the off-diagonal entries of the + Cholesky factors to stabilise learned covariances. Default is 0.1. name : str, optional Name of the model. Default is "gato_gmm_model". """ @@ -88,6 +92,7 @@ def __init__( self.n_cats = n_cats self.dim = dim self.temperature = temperature + self.cov_offdiag_damping = float(cov_offdiag_damping) self.mixture_logits = tf.Variable( tf.random.normal([n_cats], stddev=0.1), @@ -141,7 +146,7 @@ def get_scale_tril(self): """ L_raw = tf.linalg.band_part(self.unconstrained_L, -1, 0) off = L_raw - tf.linalg.diag(tf.linalg.diag_part(L_raw)) - off = 0.1 * off + off = self.cov_offdiag_damping * off raw_diag = tf.linalg.diag_part(L_raw) sigma = self._sigma_base * tf.exp(raw_diag) return tf.linalg.set_diag(off, sigma)