Skip to content

Commit

Permalink
[ENH] Reduce compute time for multivariate coherency methods (#184)
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns authored May 30, 2024
1 parent cc5bfcb commit 67df38b
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 65 deletions.
128 changes: 68 additions & 60 deletions mne_connectivity/spectral/epochs_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from typing import Optional

import numpy as np
import scipy as sp
from mne.epochs import BaseEpochs
from mne.parallel import parallel_func
from mne.utils import ProgressBar, logger
Expand Down Expand Up @@ -287,31 +286,79 @@ def _compute_con_daughter(
"""

def _compute_t(self, C_r, n_seeds):
"""Compute transformation matrix, T, for frequencies (in parallel).
"""Compute transformation matrix, T, for frequencies (& times).
Eq. 3 of Ewald et al.; part of Eq. 9 of Vidaurre et al.
"""
parallel, parallel_invsqrtm, _ = parallel_func(
_invsqrtm, self.n_jobs, verbose=False
)

# imag. part of T filled when data is rank-deficient
T = np.zeros(C_r.shape, dtype=np.complex128)
for block_i in ProgressBar(range(self.n_steps), mesg="frequency blocks"):
freqs = self._get_block_indices(block_i, self.n_freqs)
T[:, freqs] = np.array(
parallel(parallel_invsqrtm(C_r[:, f], T[:, f], n_seeds) for f in freqs)
).transpose(1, 0, 2, 3)

if not np.isreal(T).all() or not np.isfinite(T).all():
try:
return self._invsqrtm(C_r, n_seeds)
except np.linalg.LinAlgError as error:
raise RuntimeError(
"the transformation matrix of the data must be real-valued "
"and contain no NaN or infinity values; check that you are "
"using full rank data or specify an appropriate rank for the "
"seeds and targets that is less than or equal to their ranks"
"the transformation matrix of the data could not be computed "
"from the cross-spectral density; check that you are using "
"full rank data or specify an appropriate rank for the seeds "
"and targets that is less than or equal to their ranks"
) from error

def _invsqrtm(self, C_r, n_seeds):
"""Compute inverse sqrt of CSD over frequencies and times.
Parameters
----------
C_r : np.ndarray, shape=(n_freqs, n_times, n_channels, n_channels)
Real part of the CSD. Expected to be symmetric and non-singular.
n_seeds : int
Number of seed channels for the connection.
Returns
-------
T : np.ndarray, shape=(n_freqs, n_times, n_channels, n_channels)
Inverse square root of the real-valued CSD. Name comes from Ewald
et al. (2012).
Notes
-----
This approach is a workaround for computing the inverse square root of
an ND array. SciPy has dedicated functions for this purpose, e.g.
`sp.linalg.fractional_matrix_power(A, -0.5)` or `sp.linalg.inv(
sp.linalg.sqrtm(A))`, however these only work with 2D arrays, meaning
frequencies and times must be looped over which is very slow. There are
no equivalent functions in NumPy for working with ND arrays (as of
v1.26).
The data array is expected to be symmetric and non-singular, otherwise
a LinAlgError is raised.
See Eq. 3 of Ewald et al. (2012). NeuroImage. DOI:
10.1016/j.neuroimage.2011.11.084.
"""
T = np.zeros_like(C_r, dtype=np.float64)

# seeds
eigvals, eigvects = np.linalg.eigh(C_r[:, :, :n_seeds, :n_seeds])
n_zero = (eigvals == 0).sum()
if n_zero: # sign of non-full rank data
raise np.linalg.LinAlgError(
"Cannot compute inverse square root of rank-deficient matrix "
f"with {n_zero}/{len(eigvals)} zero eigenvalue(s)"
)
T[:, :, :n_seeds, :n_seeds] = (
eigvects * np.expand_dims(1.0 / np.sqrt(eigvals), (2))
) @ eigvects.transpose(0, 1, 3, 2)

# targets
eigvals, eigvects = np.linalg.eigh(C_r[:, :, n_seeds:, n_seeds:])
n_zero = (eigvals == 0).sum()
if n_zero: # sign of non-full rank data
raise np.linalg.LinAlgError(
"Cannot compute inverse square root of rank-deficient matrix "
f"with {n_zero}/{len(eigvals)} zero eigenvalue(s)"
)
T[:, :, n_seeds:, n_seeds:] = (
eigvects * np.expand_dims(1.0 / np.sqrt(eigvals), (2))
) @ eigvects.transpose(0, 1, 3, 2)

return np.real(T) # make T real if check passes
return T

def reshape_results(self):
"""Remove time dimension from results, if necessary."""
Expand All @@ -321,43 +368,6 @@ def reshape_results(self):
self.patterns = self.patterns[..., 0]


def _invsqrtm(C, T, n_seeds):
"""Compute inverse sqrt of CSD over times (used for CaCoh, MIC, & MIM).
Parameters
----------
C : np.ndarray, shape=(n_times, n_channels, n_channels)
CSD for a single frequency and all times (n_times=1 if the mode is not
time-frequency resolved, e.g. multitaper).
T : np.ndarray, shape=(n_times, n_channels, n_channels)
Empty array to store the inverse square root of the CSD in.
n_seeds : int
Number of seed channels for the connection.
Returns
-------
T : np.ndarray, shape=(n_times, n_channels, n_channels)
Inverse square root of the CSD. Name comes from Ewald et al. (2012).
Notes
-----
Kept as a standalone function to allow for parallelisation over CSD
frequencies.
See Eq. 3 of Ewald et al. (2012). NeuroImage. DOI:
10.1016/j.neuroimage.2011.11.084.
"""
for time_i in range(C.shape[0]):
T[time_i, :n_seeds, :n_seeds] = sp.linalg.fractional_matrix_power(
C[time_i, :n_seeds, :n_seeds], -0.5
)
T[time_i, n_seeds:, n_seeds:] = sp.linalg.fractional_matrix_power(
C[time_i, n_seeds:, n_seeds:], -0.5
)

return T


class _MultivariateImCohEstBase(_MultivariateCohEstBase):
"""Base estimator for multivariate imag. part of coherency methods.
Expand Down Expand Up @@ -844,9 +854,7 @@ def _whittle_lwr_recursion(self, G):
) # forward autocov
G_b = np.reshape(
np.flip(G[:, 1:, :, :], 1).transpose(0, 3, 2, 1), (t, n, qn), order="F"
).transpose(
0, 2, 1
) # backward autocov
).transpose(0, 2, 1) # backward autocov

A_f = np.zeros((t, n, qn)) # forward coefficients
A_b = np.zeros((t, n, qn)) # backward coefficients
Expand Down
12 changes: 7 additions & 5 deletions mne_connectivity/spectral/tests/test_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def test_spectral_connectivity(method, mode):
con.get_data(output="dense")[1, 0, gidx[0] : gidx[1]], lower_t
)
assert_array_less(con.get_data(output="dense")[1, 0, : bidx[0]], lower_t)
assert_array_less(con.get_data(output="dense")[1, 0, bidx[1] :], lower_t),
assert_array_less(con.get_data(output="dense")[1, 0, bidx[1] :], lower_t)
assert np.all(
con.get_data(output="dense")[1, 0, bidx[1] :] < lower_t
), con.get_data()[1, 0, bidx[1] :].max()
Expand Down Expand Up @@ -469,10 +469,12 @@ def test_spectral_connectivity(method, mode):

_gc_marks = []
if platform.system() == "Darwin" and platform.processor() == "arm":
_gc_marks.extend([
pytest.mark.filterwarnings("ignore:divide by zero encountered in det:"),
pytest.mark.filterwarnings("ignore:invalid value encountered in det:"),
])
_gc_marks.extend(
[
pytest.mark.filterwarnings("ignore:divide by zero encountered in det:"),
pytest.mark.filterwarnings("ignore:invalid value encountered in det:"),
]
)
_gc = pytest.param("gc", marks=_gc_marks, id="gc")
_gc_tr = pytest.param("gc_tr", marks=_gc_marks, id="gc_tr")

Expand Down

0 comments on commit 67df38b

Please sign in to comment.