-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ENH] Reduce compute time for multivariate coherency methods #184
Changes from 8 commits
a06353a
22a3275
1a89f25
c755517
7964e0b
e2234d5
434eac3
e746427
eeed5aa
9130e30
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -467,12 +467,19 @@ def test_spectral_connectivity(method, mode): | |
assert out_lens[0] == 10 | ||
|
||
|
||
_coh_marks = [] | ||
_gc_marks = [] | ||
if platform.system() == "Darwin" and platform.processor() == "arm": | ||
_coh_marks.extend([ | ||
pytest.mark.filterwarnings("ignore:invalid value encountered in sqrt:") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this be removed now that there is a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, the filters for the cohy methods should now be redundant. Have removed and will wait to see if macOS tests pass to be sure. |
||
]) | ||
_gc_marks.extend([ | ||
pytest.mark.filterwarnings("ignore:divide by zero encountered in det:"), | ||
pytest.mark.filterwarnings("ignore:invalid value encountered in det:"), | ||
]) | ||
_cacoh = pytest.param("cacoh", marks=_coh_marks, id="cacoh") | ||
_mic = pytest.param("mic", marks=_coh_marks, id="mic") | ||
_mim = pytest.param("mim", marks=_coh_marks, id="mim") | ||
_gc = pytest.param("gc", marks=_gc_marks, id="gc") | ||
_gc_tr = pytest.param("gc_tr", marks=_gc_marks, id="gc_tr") | ||
|
||
|
@@ -732,7 +739,7 @@ def test_multivariate_spectral_connectivity_epochs_regression(): | |
|
||
@pytest.mark.parametrize( | ||
"method", | ||
["cacoh", "mic", "mim", _gc, _gc_tr, ["cacoh", "mic", "mim", "gc", "gc_tr"]], | ||
[_cacoh, _mic, _mim, _gc, _gc_tr, ["cacoh", "mic", "mim", "gc", "gc_tr"]], | ||
) | ||
@pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) | ||
def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): | ||
|
@@ -913,7 +920,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): | |
) | ||
|
||
|
||
@pytest.mark.parametrize("method", ["cacoh", "mic", "mim", _gc, _gc_tr]) | ||
@pytest.mark.parametrize("method", [_cacoh, _mic, _mim, _gc, _gc_tr]) | ||
def test_multivar_spectral_connectivity_parallel(method): | ||
"""Test multivar. freq.-domain connectivity methods run in parallel.""" | ||
data = make_signals_in_freq_bands( | ||
|
@@ -1434,7 +1441,7 @@ def test_spectral_connectivity_time_padding(method, mode, padding): | |
) | ||
|
||
|
||
@pytest.mark.parametrize("method", ["cacoh", "mic", "mim", _gc, _gc_tr]) | ||
@pytest.mark.parametrize("method", [_cacoh, _mic, _mim, _gc, _gc_tr]) | ||
@pytest.mark.parametrize("average", [True, False]) | ||
@pytest.mark.parametrize("faverage", [True, False]) | ||
def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): | ||
|
@@ -1511,7 +1518,7 @@ def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): | |
assert np.all(np.array(con.indices) == np.array(([[0, 1]], [[2, -1]]))) | ||
|
||
|
||
@pytest.mark.parametrize("method", ["cacoh", "mic", "mim", _gc, _gc_tr]) | ||
@pytest.mark.parametrize("method", [_cacoh, _mic, _mim, _gc, _gc_tr]) | ||
@pytest.mark.parametrize("mode", ["multitaper", "cwt_morlet"]) | ||
def test_multivar_spectral_connectivity_time_error_catch(method, mode): | ||
"""Test error catching for time-resolved multivar. connectivity methods.""" | ||
|
@@ -1722,7 +1729,7 @@ def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): | |
assert con.indices is None and read_con.indices is None | ||
|
||
|
||
@pytest.mark.parametrize("method", ["cacoh", "mic", "mim", _gc, _gc_tr]) | ||
@pytest.mark.parametrize("method", [_cacoh, _mic, _mim, _gc, _gc_tr]) | ||
@pytest.mark.parametrize("indices", [None, ([[0, 1]], [[2, 3]])]) | ||
def test_multivar_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices): | ||
"""Test that indices values and type is maintained after saving. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, added for both the seed and target eigvals.