Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Reduce compute time for multivariate coherency methods #184

Merged
merged 10 commits into from
May 30, 2024
124 changes: 62 additions & 62 deletions mne_connectivity/spectral/epochs_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from typing import Optional

import numpy as np
import scipy as sp
from mne.epochs import BaseEpochs
from mne.parallel import parallel_func
from mne.utils import ProgressBar, logger
Expand Down Expand Up @@ -287,31 +286,71 @@ def _compute_con_daughter(
"""

def _compute_t(self, C_r, n_seeds):
"""Compute transformation matrix, T, for frequencies (in parallel).
"""Compute transformation matrix, T, for frequencies (& times).

Eq. 3 of Ewald et al.; part of Eq. 9 of Vidaurre et al.
"""
parallel, parallel_invsqrtm, _ = parallel_func(
_invsqrtm, self.n_jobs, verbose=False
)

# imag. part of T filled when data is rank-deficient
T = np.zeros(C_r.shape, dtype=np.complex128)
for block_i in ProgressBar(range(self.n_steps), mesg="frequency blocks"):
freqs = self._get_block_indices(block_i, self.n_freqs)
T[:, freqs] = np.array(
parallel(parallel_invsqrtm(C_r[:, f], T[:, f], n_seeds) for f in freqs)
).transpose(1, 0, 2, 3)

if not np.isreal(T).all() or not np.isfinite(T).all():
try:
return self._invsqrtm(C_r, n_seeds)
except np.linalg.LinAlgError as error:
raise RuntimeError(
"the transformation matrix of the data must be real-valued "
"and contain no NaN or infinity values; check that you are "
"using full rank data or specify an appropriate rank for the "
"seeds and targets that is less than or equal to their ranks"
)

return np.real(T) # make T real if check passes
"the transformation matrix of the data could not be computed "
"from the cross-spectral density; check that you are using "
"full rank data or specify an appropriate rank for the seeds "
"and targets that is less than or equal to their ranks"
) from error

def _invsqrtm(self, C_r, n_seeds):
"""Compute inverse sqrt of CSD over frequencies and times.

Parameters
----------
C_r : np.ndarray, shape=(n_freqs, n_times, n_channels, n_channels)
Real part of the CSD. Expected to be symmetric and non-singular.
n_seeds : int
Number of seed channels for the connection.

Returns
-------
T : np.ndarray, shape=(n_freqs, n_times, n_channels, n_channels)
Inverse square root of the real-valued CSD. Name comes from Ewald
et al. (2012).

Notes
-----
This approach is a workaround for computing the inverse square root of
an ND array. SciPy has dedicated functions for this purpose, e.g.
`sp.linalg.fractional_matrix_power(A, -0.5)` or `sp.linalg.inv(
sp.linalg.sqrtm(A))`, however these only work with 2D arrays, meaning
frequencies and times must be looped over which is very slow. There are
no equivalent functions in NumPy for working with ND arrays (as of
v1.26).

The data array is expected to be symmetric and non-singular, otherwise
a LinAlgError is raised.

See Eq. 3 of Ewald et al. (2012). NeuroImage. DOI:
10.1016/j.neuroimage.2011.11.084.
"""
T = np.zeros_like(C_r, dtype=np.float64)

# seeds
eigvals, eigvects = np.linalg.eigh(C_r[:, :, :n_seeds, :n_seeds])
if (eigvals == 0).any(): # sign of non-full rank data
raise np.linalg.LinAlgError()
T[:, :, :n_seeds, :n_seeds] = (
eigvects * np.expand_dims(1.0 / np.sqrt(eigvals), (2))
) @ eigvects.transpose(0, 1, 3, 2)

# targets
eigvals, eigvects = np.linalg.eigh(C_r[:, :, n_seeds:, n_seeds:])
if (eigvals == 0).any(): # sign of non-full rank data
raise np.linalg.LinAlgError()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (eigvals == 0).any(): # sign of non-full rank data
raise np.linalg.LinAlgError()
n_zero = (eigvals == 0).sum()
if n_zero: # sign of non-full rank data
raise np.linalg.LinAlgError(
"Cannot compute inverse square root of rank-deficient matrix "
f"with {n_zero}/{len(eigvals)} zero eigenvalue(s)"
)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, added for both the seed and target eigvals.

T[:, :, n_seeds:, n_seeds:] = (
eigvects * np.expand_dims(1.0 / np.sqrt(eigvals), (2))
) @ eigvects.transpose(0, 1, 3, 2)

return T

def reshape_results(self):
"""Remove time dimension from results, if necessary."""
Expand All @@ -321,43 +360,6 @@ def reshape_results(self):
self.patterns = self.patterns[..., 0]


def _invsqrtm(C, T, n_seeds):
"""Compute inverse sqrt of CSD over times (used for CaCoh, MIC, & MIM).

Parameters
----------
C : np.ndarray, shape=(n_times, n_channels, n_channels)
CSD for a single frequency and all times (n_times=1 if the mode is not
time-frequency resolved, e.g. multitaper).
T : np.ndarray, shape=(n_times, n_channels, n_channels)
Empty array to store the inverse square root of the CSD in.
n_seeds : int
Number of seed channels for the connection.

Returns
-------
T : np.ndarray, shape=(n_times, n_channels, n_channels)
Inverse square root of the CSD. Name comes from Ewald et al. (2012).

Notes
-----
Kept as a standalone function to allow for parallelisation over CSD
frequencies.

See Eq. 3 of Ewald et al. (2012). NeuroImage. DOI:
10.1016/j.neuroimage.2011.11.084.
"""
for time_i in range(C.shape[0]):
T[time_i, :n_seeds, :n_seeds] = sp.linalg.fractional_matrix_power(
C[time_i, :n_seeds, :n_seeds], -0.5
)
T[time_i, n_seeds:, n_seeds:] = sp.linalg.fractional_matrix_power(
C[time_i, n_seeds:, n_seeds:], -0.5
)

return T


class _MultivariateImCohEstBase(_MultivariateCohEstBase):
"""Base estimator for multivariate imag. part of coherency methods.

Expand Down Expand Up @@ -844,9 +846,7 @@ def _whittle_lwr_recursion(self, G):
) # forward autocov
G_b = np.reshape(
np.flip(G[:, 1:, :, :], 1).transpose(0, 3, 2, 1), (t, n, qn), order="F"
).transpose(
0, 2, 1
) # backward autocov
).transpose(0, 2, 1) # backward autocov

A_f = np.zeros((t, n, qn)) # forward coefficients
A_b = np.zeros((t, n, qn)) # backward coefficients
Expand Down
17 changes: 12 additions & 5 deletions mne_connectivity/spectral/tests/test_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,12 +467,19 @@ def test_spectral_connectivity(method, mode):
assert out_lens[0] == 10


_coh_marks = []
_gc_marks = []
if platform.system() == "Darwin" and platform.processor() == "arm":
_coh_marks.extend([
pytest.mark.filterwarnings("ignore:invalid value encountered in sqrt:")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be removed now that there is a raise LinAlgError in there? Or does it come from another code path?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, the filters for the cohy methods should now be redundant. Have removed and will wait to see if macOS tests pass to be sure.

])
_gc_marks.extend([
pytest.mark.filterwarnings("ignore:divide by zero encountered in det:"),
pytest.mark.filterwarnings("ignore:invalid value encountered in det:"),
])
_cacoh = pytest.param("cacoh", marks=_coh_marks, id="cacoh")
_mic = pytest.param("mic", marks=_coh_marks, id="mic")
_mim = pytest.param("mim", marks=_coh_marks, id="mim")
_gc = pytest.param("gc", marks=_gc_marks, id="gc")
_gc_tr = pytest.param("gc_tr", marks=_gc_marks, id="gc_tr")

Expand Down Expand Up @@ -732,7 +739,7 @@ def test_multivariate_spectral_connectivity_epochs_regression():

@pytest.mark.parametrize(
"method",
["cacoh", "mic", "mim", _gc, _gc_tr, ["cacoh", "mic", "mim", "gc", "gc_tr"]],
[_cacoh, _mic, _mim, _gc, _gc_tr, ["cacoh", "mic", "mim", "gc", "gc_tr"]],
)
@pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"])
def test_multivar_spectral_connectivity_epochs_error_catch(method, mode):
Expand Down Expand Up @@ -913,7 +920,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode):
)


@pytest.mark.parametrize("method", ["cacoh", "mic", "mim", _gc, _gc_tr])
@pytest.mark.parametrize("method", [_cacoh, _mic, _mim, _gc, _gc_tr])
def test_multivar_spectral_connectivity_parallel(method):
"""Test multivar. freq.-domain connectivity methods run in parallel."""
data = make_signals_in_freq_bands(
Expand Down Expand Up @@ -1434,7 +1441,7 @@ def test_spectral_connectivity_time_padding(method, mode, padding):
)


@pytest.mark.parametrize("method", ["cacoh", "mic", "mim", _gc, _gc_tr])
@pytest.mark.parametrize("method", [_cacoh, _mic, _mim, _gc, _gc_tr])
@pytest.mark.parametrize("average", [True, False])
@pytest.mark.parametrize("faverage", [True, False])
def test_multivar_spectral_connectivity_time_shapes(method, average, faverage):
Expand Down Expand Up @@ -1511,7 +1518,7 @@ def test_multivar_spectral_connectivity_time_shapes(method, average, faverage):
assert np.all(np.array(con.indices) == np.array(([[0, 1]], [[2, -1]])))


@pytest.mark.parametrize("method", ["cacoh", "mic", "mim", _gc, _gc_tr])
@pytest.mark.parametrize("method", [_cacoh, _mic, _mim, _gc, _gc_tr])
@pytest.mark.parametrize("mode", ["multitaper", "cwt_morlet"])
def test_multivar_spectral_connectivity_time_error_catch(method, mode):
"""Test error catching for time-resolved multivar. connectivity methods."""
Expand Down Expand Up @@ -1722,7 +1729,7 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices):
assert con.indices is None and read_con.indices is None


@pytest.mark.parametrize("method", ["cacoh", "mic", "mim", _gc, _gc_tr])
@pytest.mark.parametrize("method", [_cacoh, _mic, _mim, _gc, _gc_tr])
@pytest.mark.parametrize("indices", [None, ([[0, 1]], [[2, 3]])])
def test_multivar_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices):
"""Test that indices values and type is maintained after saving.
Expand Down
Loading