diff --git a/osl_dynamics/analysis/post_hoc.py b/osl_dynamics/analysis/post_hoc.py index 53be2aa9f..c757f0c3e 100644 --- a/osl_dynamics/analysis/post_hoc.py +++ b/osl_dynamics/analysis/post_hoc.py @@ -778,7 +778,14 @@ def partial_covariances(data, alpha): return np.squeeze(pcovs) -def hmm_dual_estimation(data, alpha, zero_mean=False, eps=1e-5, n_jobs=1): +def hmm_dual_estimation( + data, + alpha, + zero_mean=False, + diagonal_covariances=False, + eps=1e-5, + n_jobs=1, +): """HMM dual estimation of observation model parameters. Parameters @@ -791,6 +798,9 @@ def hmm_dual_estimation(data, alpha, zero_mean=False, eps=1e-5, n_jobs=1): or (n_subjects, n_samples, n_states). zero_mean : bool, optional Should we force the state means to be zero? + diagonal_covariances : bool, optional + If True, estimate diagonal covariance matrices (variances only) + and return them as full matrices with zeros off-diagonal. eps : float, optional Small value to add to the diagonal of each state covariance. n_jobs : int, optional @@ -805,6 +815,8 @@ def hmm_dual_estimation(data, alpha, zero_mean=False, eps=1e-5, n_jobs=1): covariances : np.ndarray or list of np.ndarray State covariances. Shape is (n_states, n_channels, n_channels) or (n_subjects, n_states, n_channels, n_channels). + When ``diagonal_covariances=True``, the returned matrices are diagonal + (zeros off-diagonal) and encode per-channel variances only. """ # Validation @@ -856,23 +868,45 @@ def _calc(a, x): c = np.zeros([n_states, n_channels, n_channels]) for i in range(n_states): - if x.size <= memory_threshold: - d = x - m[i] - c[i] = ( - np.sum(d[:, :, None] * d[:, None, :] * a[:, i, None, None], axis=0) - / sum_a[i] - ) + if diagonal_covariances: + # Diagonal-only (variance) case + if x.size <= memory_threshold: + d = x - m[i] + # Weighted second moment per channel + diag_vals = np.sum((d**2) * a[:, i, None], axis=0) / sum_a[i] + else: + # Chunked version to avoid memory overflow + diag_vals = np.zeros(n_channels) + for start in range(0, n_samples, seq_length): + end = min(start + seq_length, n_samples) + d = x[start:end] - m[i] + diag_vals += np.sum((d**2) * a[start:end, i, None], axis=0) + diag_vals /= sum_a[i] + + # Add epsilon only to the diagonal + diag_vals = diag_vals + eps + c[i] = np.diag(diag_vals) + else: - # If the data is too large, calculate in chunks to avoid memory overflow. - for start in range(0, n_samples, seq_length): - end = min(start + seq_length, n_samples) - d = x[start:end] - m[i] - c[i] += np.sum( - d[:, :, None] * d[:, None, :] * a[start:end, i, None, None], - axis=0, + if x.size <= memory_threshold: + d = x - m[i] + c[i] = ( + np.sum( + d[:, :, None] * d[:, None, :] * a[:, i, None, None], axis=0 + ) + / sum_a[i] ) - c[i] /= sum_a[i] - c[i] += eps * np.eye(n_channels) + else: + # If the data is too large, calculate in chunks to avoid memory overflow. + for start in range(0, n_samples, seq_length): + end = min(start + seq_length, n_samples) + d = x[start:end] - m[i] + c[i] += np.sum( + d[:, :, None] * d[:, None, :] * a[start:end, i, None, None], + axis=0, + ) + c[i] /= sum_a[i] + c[i] += eps * np.eye(n_channels) return m, c @@ -905,6 +939,7 @@ def hmm_features( alpha, sampling_frequency=None, zero_mean=False, + diagonal_covariances=False, eps=1e-5, use_partial=False, n_jobs=1, @@ -924,11 +959,15 @@ def hmm_features( are unitless. zero_mean : bool, optional Should we force the state means to be zero? + diagonal_covariances : bool, optional + If True, estimate diagonal covariance matrices (variances only) + and return them as full matrices with zeros off-diagonal. eps : float, optional Small value to add to the diagonal of each state covariance. use_partial : bool, optional Should we use the partial state correlation matrix rather than - the full state covariance matrix? + the full state covariance matrix? For diagonal covariances, this + reduces to per-channel variance terms on the diagonal. n_jobs : int, optional Number of jobs to run in parallel. @@ -977,11 +1016,26 @@ def _calc(a, x): trans_prob = trans_prob.flatten() # Observation model parameters - m, c = hmm_dual_estimation(x, a, zero_mean=zero_mean, eps=eps, n_jobs=None) - i, j = np.triu_indices(c.shape[-1]) + m, c = hmm_dual_estimation( + x, + a, + zero_mean=zero_mean, + eps=eps, + n_jobs=None, + diagonal_covariances=diagonal_covariances, + ) + if use_partial: c = array_ops.cov2partialcorr(c) - c = c[..., i, j] + + if diagonal_covariances: + # Diagonal covariance: use the diagonal elements only + c = np.diagonal(c, axis1=-2, axis2=-1) + else: + # Full covariance: use upper-triangular elements + i, j = np.triu_indices(c.shape[-1]) + c = c[..., i, j] + if zero_mean: obs_mod = c else: diff --git a/osl_dynamics/models/hmm.py b/osl_dynamics/models/hmm.py index 69dc9b713..3026bf9cb 100644 --- a/osl_dynamics/models/hmm.py +++ b/osl_dynamics/models/hmm.py @@ -412,6 +412,8 @@ def dual_estimation(self, training_data, alpha=None, concatenate=False, n_jobs=1 covariances : np.ndarray Session-specific covariances. Shape is (n_sessions, n_states, n_channels, n_channels). + When ``config.diagonal_covariances=True``, the matrices are + diagonal (zeros off-diagonal) and encode per-channel variances only. """ if alpha is None: # Get the posterior @@ -442,6 +444,7 @@ def dual_estimation(self, training_data, alpha=None, concatenate=False, n_jobs=1 data, alpha, zero_mean=(not self.config.learn_means), + diagonal_covariances=self.config.diagonal_covariances, eps=self.config.covariances_epsilon, n_jobs=n_jobs, )