Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Reduce compute time for multivariate coherency methods #184

Merged
merged 10 commits into from
May 30, 2024
116 changes: 60 additions & 56 deletions mne_connectivity/spectral/epochs_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from typing import Optional

import numpy as np
import scipy as sp
from mne.epochs import BaseEpochs
from mne.parallel import parallel_func
from mne.utils import ProgressBar, logger
Expand Down Expand Up @@ -287,31 +286,73 @@ def _compute_con_daughter(
"""

def _compute_t(self, C_r, n_seeds):
"""Compute transformation matrix, T, for frequencies (in parallel).
"""Compute transformation matrix, T, for frequencies (& times).

Eq. 3 of Ewald et al.; part of Eq. 9 of Vidaurre et al.
"""
parallel, parallel_invsqrtm, _ = parallel_func(
_invsqrtm, self.n_jobs, verbose=False
)
try:
return self._invsqrtm(C_r, n_seeds)
except np.linalg.LinAlgError as error:
raise RuntimeError(
"the transformation matrix of the data could not be computed "
"from the cross-spectral density; check that you are using "
"full rank data or specify an appropriate rank for the seeds "
"and targets that is less than or equal to their ranks"
) from error

def _invsqrtm(self, C_r, n_seeds):
"""Compute inverse sqrt of CSD over frequencies and times.

Parameters
----------
C_r : np.ndarray, shape=(n_freqs, n_times, n_channels, n_channels)
Real part of the CSD. Expected to be symmetric and non-singular.
n_seeds : int
Number of seed channels for the connection.

Returns
-------
T : np.ndarray, shape=(n_freqs, n_times, n_channels, n_channels)
Inverse square root of the real-valued CSD. Name comes from Ewald
et al. (2012).

Notes
-----
This approach is a workaround for computing the inverse square root of
an ND array. SciPy has dedicated functions for this purpose, e.g.
`sp.linalg.fractional_matrix_power(A, -0.5)` or `sp.linalg.inv(
sp.linalg.sqrtm(A))`, however these only work with 2D arrays, meaning
frequencies and times must be looped over which is very slow. There are
no equivalent functions in NumPy for working with ND arrays (as of
v1.26).

The data array is expected to be symmetric and non-singular, otherwise
a LinAlgError is raised.

See Eq. 3 of Ewald et al. (2012). NeuroImage. DOI:
10.1016/j.neuroimage.2011.11.084.
"""
T = np.zeros_like(C_r, dtype=np.float64)

# imag. part of T filled when data is rank-deficient
T = np.zeros(C_r.shape, dtype=np.complex128)
for block_i in ProgressBar(range(self.n_steps), mesg="frequency blocks"):
freqs = self._get_block_indices(block_i, self.n_freqs)
T[:, freqs] = np.array(
parallel(parallel_invsqrtm(C_r[:, f], T[:, f], n_seeds) for f in freqs)
).transpose(1, 0, 2, 3)
# seeds
eigvals, eigvects = np.linalg.eigh(C_r[:, :, :n_seeds, :n_seeds])
T[:, :, :n_seeds, :n_seeds] = np.linalg.inv(
np.matmul(
(eigvects * np.expand_dims(np.sqrt(eigvals), (2))),
eigvects.transpose(0, 1, 3, 2),
)
)

if not np.isreal(T).all() or not np.isfinite(T).all():
raise RuntimeError(
"the transformation matrix of the data must be real-valued "
"and contain no NaN or infinity values; check that you are "
"using full rank data or specify an appropriate rank for the "
"seeds and targets that is less than or equal to their ranks"
# targets
eigvals, eigvects = np.linalg.eigh(C_r[:, :, n_seeds:, n_seeds:])
T[:, :, n_seeds:, n_seeds:] = np.linalg.inv(
np.matmul(
(eigvects * np.expand_dims(np.sqrt(eigvals), (2))),
eigvects.transpose(0, 1, 3, 2),
)
)

return np.real(T) # make T real if check passes
return T

def reshape_results(self):
"""Remove time dimension from results, if necessary."""
Expand All @@ -321,43 +362,6 @@ def reshape_results(self):
self.patterns = self.patterns[..., 0]


def _invsqrtm(C, T, n_seeds):
"""Compute inverse sqrt of CSD over times (used for CaCoh, MIC, & MIM).

Parameters
----------
C : np.ndarray, shape=(n_times, n_channels, n_channels)
CSD for a single frequency and all times (n_times=1 if the mode is not
time-frequency resolved, e.g. multitaper).
T : np.ndarray, shape=(n_times, n_channels, n_channels)
Empty array to store the inverse square root of the CSD in.
n_seeds : int
Number of seed channels for the connection.

Returns
-------
T : np.ndarray, shape=(n_times, n_channels, n_channels)
Inverse square root of the CSD. Name comes from Ewald et al. (2012).

Notes
-----
Kept as a standalone function to allow for parallelisation over CSD
frequencies.

See Eq. 3 of Ewald et al. (2012). NeuroImage. DOI:
10.1016/j.neuroimage.2011.11.084.
"""
for time_i in range(C.shape[0]):
T[time_i, :n_seeds, :n_seeds] = sp.linalg.fractional_matrix_power(
C[time_i, :n_seeds, :n_seeds], -0.5
)
T[time_i, n_seeds:, n_seeds:] = sp.linalg.fractional_matrix_power(
C[time_i, n_seeds:, n_seeds:], -0.5
)

return T


class _MultivariateImCohEstBase(_MultivariateCohEstBase):
"""Base estimator for multivariate imag. part of coherency methods.

Expand Down
Loading