Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 74 additions & 20 deletions osl_dynamics/analysis/post_hoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,14 @@ def partial_covariances(data, alpha):
return np.squeeze(pcovs)


def hmm_dual_estimation(data, alpha, zero_mean=False, eps=1e-5, n_jobs=1):
def hmm_dual_estimation(
data,
alpha,
zero_mean=False,
diagonal_covariances=False,
eps=1e-5,
n_jobs=1,
):
"""HMM dual estimation of observation model parameters.

Parameters
Expand All @@ -791,6 +798,9 @@ def hmm_dual_estimation(data, alpha, zero_mean=False, eps=1e-5, n_jobs=1):
or (n_subjects, n_samples, n_states).
zero_mean : bool, optional
Should we force the state means to be zero?
diagonal_covariances : bool, optional
If True, estimate diagonal covariance matrices (variances only)
and return them as full matrices with zeros off-diagonal.
eps : float, optional
Small value to add to the diagonal of each state covariance.
n_jobs : int, optional
Expand All @@ -805,6 +815,8 @@ def hmm_dual_estimation(data, alpha, zero_mean=False, eps=1e-5, n_jobs=1):
covariances : np.ndarray or list of np.ndarray
State covariances. Shape is (n_states, n_channels, n_channels)
or (n_subjects, n_states, n_channels, n_channels).
When ``diagonal_covariances=True``, the returned matrices are diagonal
(zeros off-diagonal) and encode per-channel variances only.
"""

# Validation
Expand Down Expand Up @@ -856,23 +868,45 @@ def _calc(a, x):

c = np.zeros([n_states, n_channels, n_channels])
for i in range(n_states):
if x.size <= memory_threshold:
d = x - m[i]
c[i] = (
np.sum(d[:, :, None] * d[:, None, :] * a[:, i, None, None], axis=0)
/ sum_a[i]
)
if diagonal_covariances:
# Diagonal-only (variance) case
if x.size <= memory_threshold:
d = x - m[i]
# Weighted second moment per channel
diag_vals = np.sum((d**2) * a[:, i, None], axis=0) / sum_a[i]
else:
# Chunked version to avoid memory overflow
diag_vals = np.zeros(n_channels)
for start in range(0, n_samples, seq_length):
end = min(start + seq_length, n_samples)
d = x[start:end] - m[i]
diag_vals += np.sum((d**2) * a[start:end, i, None], axis=0)
diag_vals /= sum_a[i]

# Add epsilon only to the diagonal
diag_vals = diag_vals + eps
c[i] = np.diag(diag_vals)

else:
# If the data is too large, calculate in chunks to avoid memory overflow.
for start in range(0, n_samples, seq_length):
end = min(start + seq_length, n_samples)
d = x[start:end] - m[i]
c[i] += np.sum(
d[:, :, None] * d[:, None, :] * a[start:end, i, None, None],
axis=0,
if x.size <= memory_threshold:
d = x - m[i]
c[i] = (
np.sum(
d[:, :, None] * d[:, None, :] * a[:, i, None, None], axis=0
)
/ sum_a[i]
)
c[i] /= sum_a[i]
c[i] += eps * np.eye(n_channels)
else:
# If the data is too large, calculate in chunks to avoid memory overflow.
for start in range(0, n_samples, seq_length):
end = min(start + seq_length, n_samples)
d = x[start:end] - m[i]
c[i] += np.sum(
d[:, :, None] * d[:, None, :] * a[start:end, i, None, None],
axis=0,
)
c[i] /= sum_a[i]
c[i] += eps * np.eye(n_channels)

return m, c

Expand Down Expand Up @@ -905,6 +939,7 @@ def hmm_features(
alpha,
sampling_frequency=None,
zero_mean=False,
diagonal_covariances=False,
eps=1e-5,
use_partial=False,
n_jobs=1,
Expand All @@ -924,11 +959,15 @@ def hmm_features(
are unitless.
zero_mean : bool, optional
Should we force the state means to be zero?
diagonal_covariances : bool, optional
If True, estimate diagonal covariance matrices (variances only)
and return them as full matrices with zeros off-diagonal.
eps : float, optional
Small value to add to the diagonal of each state covariance.
use_partial : bool, optional
Should we use the partial state correlation matrix rather than
the full state covariance matrix?
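the full state covariance matrix? When ``diagonal_covariances=True``, only
the diagonal of the resulting matrix is kept as features. NOTE(review): with
``use_partial=True`` this is the diagonal of the partial correlation matrix,
which is presumably constant rather than a per-channel variance - confirm
that combining the two options is intended.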
n_jobs : int, optional
Number of jobs to run in parallel.

Expand Down Expand Up @@ -977,11 +1016,26 @@ def _calc(a, x):
trans_prob = trans_prob.flatten()

# Observation model parameters
m, c = hmm_dual_estimation(x, a, zero_mean=zero_mean, eps=eps, n_jobs=None)
i, j = np.triu_indices(c.shape[-1])
m, c = hmm_dual_estimation(
x,
a,
zero_mean=zero_mean,
eps=eps,
n_jobs=None,
diagonal_covariances=diagonal_covariances,
)

if use_partial:
c = array_ops.cov2partialcorr(c)
c = c[..., i, j]

if diagonal_covariances:
# Diagonal covariance: use the diagonal elements only
c = np.diagonal(c, axis1=-2, axis2=-1)
else:
# Full covariance: use upper-triangular elements
i, j = np.triu_indices(c.shape[-1])
c = c[..., i, j]

if zero_mean:
obs_mod = c
else:
Expand Down
3 changes: 3 additions & 0 deletions osl_dynamics/models/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,8 @@ def dual_estimation(self, training_data, alpha=None, concatenate=False, n_jobs=1
covariances : np.ndarray
Session-specific covariances.
Shape is (n_sessions, n_states, n_channels, n_channels).
When ``config.diagonal_covariances=True``, the matrices are
diagonal (zeros off-diagonal) and encode per-channel variances only.
"""
if alpha is None:
# Get the posterior
Expand Down Expand Up @@ -442,6 +444,7 @@ def dual_estimation(self, training_data, alpha=None, concatenate=False, n_jobs=1
data,
alpha,
zero_mean=(not self.config.learn_means),
diagonal_covariances=self.config.diagonal_covariances,
eps=self.config.covariances_epsilon,
n_jobs=n_jobs,
)
Expand Down