From 3c6a054093d305a98757a97398e5e34988a3aced Mon Sep 17 00:00:00 2001 From: Qian Chu <97355086+qian-chu@users.noreply.github.com> Date: Tue, 21 Jan 2025 19:30:23 +0100 Subject: [PATCH 1/5] [BUG] Correct annotation onset for exportation to EDF and EEGLAB (#12656) --- doc/changes/devel/12656.bugfix.rst | 1 + mne/export/_brainvision.py | 7 +++ mne/export/_edf.py | 5 +- mne/export/_eeglab.py | 16 ++++-- mne/export/_export.py | 8 ++- mne/export/tests/test_export.py | 89 +++++++++++++++++++++++++++--- 6 files changed, 112 insertions(+), 14 deletions(-) create mode 100644 doc/changes/devel/12656.bugfix.rst diff --git a/doc/changes/devel/12656.bugfix.rst b/doc/changes/devel/12656.bugfix.rst new file mode 100644 index 00000000000..b3a0c62539a --- /dev/null +++ b/doc/changes/devel/12656.bugfix.rst @@ -0,0 +1 @@ +Fix bug where :func:`mne.export.export_raw` does not correct for recording start time (`raw.first_time`) when exporting Raw instances to EDF or EEGLAB formats, by `Qian Chu`_. \ No newline at end of file diff --git a/mne/export/_brainvision.py b/mne/export/_brainvision.py index ba64ba010ce..6503c540f41 100644 --- a/mne/export/_brainvision.py +++ b/mne/export/_brainvision.py @@ -107,6 +107,13 @@ def _export_mne_raw(*, raw, fname, events=None, overwrite=False): def _mne_annots2pybv_events(raw): """Convert mne Annotations to pybv events.""" + # check that raw.annotations.orig_time is the same as raw.info["meas_date"] + # so that onsets are relative to the first sample + # (after further correction for first_time) + if raw.annotations and raw.info["meas_date"] != raw.annotations.orig_time: + raise ValueError( + "Annotations must have the same orig_time as raw.info['meas_date']" + ) events = [] for annot in raw.annotations: # handle onset and duration: seconds to sample, relative to diff --git a/mne/export/_edf.py b/mne/export/_edf.py index ef870692014..e50b05f7056 100644 --- a/mne/export/_edf.py +++ b/mne/export/_edf.py @@ -7,6 +7,7 @@ import numpy as np +from ..annotations import _sync_onset from ..utils import _check_edfio_installed, warn _check_edfio_installed() @@ -204,7 +205,9 @@ def _export_raw(fname, raw, physical_range, add_ch_type): for desc, onset, duration, ch_names in zip( raw.annotations.description, - raw.annotations.onset, + # subtract raw.first_time because EDF marks events starting from the first + # available data point and ignores raw.first_time + _sync_onset(raw, raw.annotations.onset, inverse=False), raw.annotations.duration, raw.annotations.ch_names, ): diff --git a/mne/export/_eeglab.py b/mne/export/_eeglab.py index 3c8f896164a..459207f0616 100644 --- a/mne/export/_eeglab.py +++ b/mne/export/_eeglab.py @@ -4,6 +4,7 @@ import numpy as np +from ..annotations import _sync_onset from ..utils import _check_eeglabio_installed _check_eeglabio_installed() @@ -24,11 +25,16 @@ def _export_raw(fname, raw): ch_names = [ch for ch in raw.ch_names if ch not in drop_chs] cart_coords = _get_als_coords_from_chs(raw.info["chs"], drop_chs) - annotations = [ - raw.annotations.description, - raw.annotations.onset, - raw.annotations.duration, - ] + if raw.annotations: + annotations = [ + raw.annotations.description, + # subtract raw.first_time because EEGLAB marks events starting from + # the first available data point and ignores raw.first_time + _sync_onset(raw, raw.annotations.onset, inverse=False), + raw.annotations.duration, + ] + else: + annotations = None eeglabio.raw.export_set( fname, data=raw.get_data(picks=ch_names), diff --git a/mne/export/_export.py b/mne/export/_export.py index 6e63064bf7c..835eb9f8513 100644 --- a/mne/export/_export.py +++ b/mne/export/_export.py @@ -22,9 +22,15 @@ def export_raw( """Export Raw to external formats. %(export_fmt_support_raw)s - %(export_warning)s + .. warning:: + When exporting ``Raw`` with annotations, ``raw.info["meas_date"]`` must be the + same as ``raw.annotations.orig_time``. This guarantees that the annotations are + in the same reference frame as the samples. + When `Raw.first_time` is not zero (e.g., after cropping), the onsets are + automatically corrected so that onsets are always relative to the first sample. + Parameters ---------- %(fname_export_params)s diff --git a/mne/export/tests/test_export.py b/mne/export/tests/test_export.py index 191e91b1eed..6f712923c7d 100644 --- a/mne/export/tests/test_export.py +++ b/mne/export/tests/test_export.py @@ -122,6 +122,49 @@ def test_export_raw_eeglab(tmp_path): raw.export(temp_fname, overwrite=True) +@pytest.mark.parametrize("tmin", (0, 1, 5, 10)) +def test_export_raw_eeglab_annotations(tmp_path, tmin): + """Test annotations in the exported EEGLAB file. + + All annotations should be preserved and onset corrected. + """ + pytest.importorskip("eeglabio") + raw = read_raw_fif(fname_raw, preload=True) + raw.apply_proj() + annotations = Annotations( + onset=[0.01, 0.05, 0.90, 1.05], + duration=[0, 1, 0, 0], + description=["test1", "test2", "test3", "test4"], + ch_names=[["MEG 0113"], ["MEG 0113", "MEG 0132"], [], ["MEG 0143"]], + ) + raw.set_annotations(annotations) + raw.crop(tmin) + + # export + temp_fname = tmp_path / "test.set" + raw.export(temp_fname) + + # read in the file + with pytest.warns(RuntimeWarning, match="is above the 99th percentile"): + raw_read = read_raw_eeglab(temp_fname, preload=True, montage_units="m") + assert raw_read.first_time == 0 # exportation resets first_time + valid_annot = ( + raw.annotations.onset >= tmin + ) # only annotations in the cropped range gets exported + + # compare annotations before and after export + assert_array_almost_equal( + raw.annotations.onset[valid_annot] - raw.first_time, + raw_read.annotations.onset, + ) + assert_array_equal( + raw.annotations.duration[valid_annot], raw_read.annotations.duration + ) + assert_array_equal( + raw.annotations.description[valid_annot], raw_read.annotations.description + ) + + def _create_raw_for_edf_tests(stim_channel_index=None): rng = np.random.RandomState(12345) ch_types = [ @@ -154,6 +197,7 @@ def test_double_export_edf(tmp_path): """Test exporting an EDF file multiple times.""" raw = _create_raw_for_edf_tests(stim_channel_index=2) raw.info.set_meas_date("2023-09-04 14:53:09.000") + raw.set_annotations(Annotations(onset=[1], duration=[0], description=["test"])) # include subject info and measurement date raw.info["subject_info"] = dict( @@ -258,8 +302,12 @@ def test_edf_padding(tmp_path, pad_width): @edfio_mark() -def test_export_edf_annotations(tmp_path): - """Test that exporting EDF preserves annotations.""" +@pytest.mark.parametrize("tmin", (0, 0.005, 0.03, 1)) +def test_export_edf_annotations(tmp_path, tmin): + """Test annotations in the exported EDF file. + + All annotations should be preserved and onset corrected. + """ raw = _create_raw_for_edf_tests() annotations = Annotations( onset=[0.01, 0.05, 0.90, 1.05], @@ -268,17 +316,44 @@ def test_export_edf_annotations(tmp_path): ch_names=[["0"], ["0", "1"], [], ["1"]], ) raw.set_annotations(annotations) + raw.crop(tmin) + assert raw.first_time == tmin + + if raw.n_times % raw.info["sfreq"] == 0: + expectation = nullcontext() + else: + expectation = pytest.warns( + RuntimeWarning, match="EDF format requires equal-length data blocks" + ) # export temp_fname = tmp_path / "test.edf" - raw.export(temp_fname) + with expectation: + raw.export(temp_fname) # read in the file raw_read = read_raw_edf(temp_fname, preload=True) - assert_array_equal(raw.annotations.onset, raw_read.annotations.onset) - assert_array_equal(raw.annotations.duration, raw_read.annotations.duration) - assert_array_equal(raw.annotations.description, raw_read.annotations.description) - assert_array_equal(raw.annotations.ch_names, raw_read.annotations.ch_names) + assert raw_read.first_time == 0 # exportation resets first_time + bad_annot = raw_read.annotations.description == "BAD_ACQ_SKIP" + if bad_annot.any(): + raw_read.annotations.delete(bad_annot) + valid_annot = ( + raw.annotations.onset >= tmin + ) # only annotations in the cropped range gets exported + + # compare annotations before and after export + assert_array_almost_equal( + raw.annotations.onset[valid_annot] - raw.first_time, raw_read.annotations.onset + ) + assert_array_equal( + raw.annotations.duration[valid_annot], raw_read.annotations.duration + ) + assert_array_equal( + raw.annotations.description[valid_annot], raw_read.annotations.description + ) + assert_array_equal( + raw.annotations.ch_names[valid_annot], raw_read.annotations.ch_names + ) @edfio_mark() From 27386d7bc8240500efcfc618e2fa57f0bcea1ace Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 21 Jan 2025 20:23:39 +0000 Subject: [PATCH 2/5] Bump autofix-ci/action from ff86a557419858bb967097bfc916833f5647fa8c to 551dded8c6cc8a1054039c8bc0b8b48c51dfc6ef in the actions group (#13071) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/autofix.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index d8a99200783..18543b854d0 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -21,4 +21,4 @@ jobs: - run: pip install --upgrade towncrier pygithub gitpython numpy - run: python ./.github/actions/rename_towncrier/rename_towncrier.py - run: python ./tools/dev/ensure_headers.py - - uses: autofix-ci/action@ff86a557419858bb967097bfc916833f5647fa8c + - uses: autofix-ci/action@551dded8c6cc8a1054039c8bc0b8b48c51dfc6ef From f97a916bc79942df1cc5578ed98cddbcf1aef907 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 21 Jan 2025 17:18:35 -0500 Subject: [PATCH 3/5] MAINT: Add Numba to 3.13 test (#13075) --- doc/changes/devel/12656.bugfix.rst | 2 +- mne/export/_export.py | 8 +++++--- mne/forward/tests/test_make_forward.py | 2 +- mne/preprocessing/tests/test_fine_cal.py | 2 +- mne/utils/docs.py | 13 ++++++++----- tools/github_actions_dependencies.sh | 2 +- tools/github_actions_env_vars.sh | 2 +- 7 files changed, 18 insertions(+), 13 deletions(-) diff --git a/doc/changes/devel/12656.bugfix.rst b/doc/changes/devel/12656.bugfix.rst index b3a0c62539a..3f32dbd23e5 100644 --- a/doc/changes/devel/12656.bugfix.rst +++ b/doc/changes/devel/12656.bugfix.rst @@ -1 +1 @@ -Fix bug where :func:`mne.export.export_raw` does not correct for recording start time (`raw.first_time`) when exporting Raw instances to EDF or EEGLAB formats, by `Qian Chu`_. \ No newline at end of file +Fix bug where :func:`mne.export.export_raw` does not correct for recording start time (:attr:`raw.first_time `) when exporting Raw instances to EDF or EEGLAB formats, by `Qian Chu`_. \ No newline at end of file diff --git a/mne/export/_export.py b/mne/export/_export.py index 835eb9f8513..4b93fda917e 100644 --- a/mne/export/_export.py +++ b/mne/export/_export.py @@ -22,14 +22,16 @@ def export_raw( """Export Raw to external formats. %(export_fmt_support_raw)s + %(export_warning)s .. warning:: When exporting ``Raw`` with annotations, ``raw.info["meas_date"]`` must be the same as ``raw.annotations.orig_time``. This guarantees that the annotations are - in the same reference frame as the samples. - When `Raw.first_time` is not zero (e.g., after cropping), the onsets are - automatically corrected so that onsets are always relative to the first sample. + in the same reference frame as the samples. When + :attr:`Raw.first_time ` is not zero (e.g., after + cropping), the onsets are automatically corrected so that onsets are always + relative to the first sample. Parameters ---------- diff --git a/mne/forward/tests/test_make_forward.py b/mne/forward/tests/test_make_forward.py index 37ec6e041b5..a357c5779c9 100644 --- a/mne/forward/tests/test_make_forward.py +++ b/mne/forward/tests/test_make_forward.py @@ -482,7 +482,7 @@ def test_make_forward_solution_openmeeg(n_layers): eeg_atol=100, meg_corr_tol=0.98, eeg_corr_tol=0.98, - meg_rdm_tol=0.1, + meg_rdm_tol=0.11, eeg_rdm_tol=0.2, ) diff --git a/mne/preprocessing/tests/test_fine_cal.py b/mne/preprocessing/tests/test_fine_cal.py index 45971620db5..8b45208e848 100644 --- a/mne/preprocessing/tests/test_fine_cal.py +++ b/mne/preprocessing/tests/test_fine_cal.py @@ -231,7 +231,7 @@ def test_fine_cal_systems(system, tmp_path): err_limit = 6000 n_ref = 28 corrs = (0.19, 0.41, 0.49) - sfs = [0.5, 0.7, 0.9, 1.5] + sfs = [0.5, 0.7, 0.9, 1.55] corr_tol = 0.55 elif system == "fil": raw = read_raw_fil(fil_fname, verbose="error") diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 683704c4bc6..54cc6845e58 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1494,19 +1494,22 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["export_fmt_support_epochs"] = """\ Supported formats: - - EEGLAB (``.set``, uses :mod:`eeglabio`) + +- EEGLAB (``.set``, uses :mod:`eeglabio`) """ docdict["export_fmt_support_evoked"] = """\ Supported formats: - - MFF (``.mff``, uses :func:`mne.export.export_evokeds_mff`) + +- MFF (``.mff``, uses :func:`mne.export.export_evokeds_mff`) """ docdict["export_fmt_support_raw"] = """\ Supported formats: - - BrainVision (``.vhdr``, ``.vmrk``, ``.eeg``, uses `pybv `_) - - EEGLAB (``.set``, uses :mod:`eeglabio`) - - EDF (``.edf``, uses `edfio `_) + +- BrainVision (``.vhdr``, ``.vmrk``, ``.eeg``, uses `pybv `_) +- EEGLAB (``.set``, uses :mod:`eeglabio`) +- EDF (``.edf``, uses `edfio `_) """ # noqa: E501 docdict["export_warning"] = """\ diff --git a/tools/github_actions_dependencies.sh b/tools/github_actions_dependencies.sh index cebd2caefa7..d47d9070f8b 100755 --- a/tools/github_actions_dependencies.sh +++ b/tools/github_actions_dependencies.sh @@ -23,7 +23,7 @@ if [ ! -z "$CONDA_ENV" ]; then elif [[ "${MNE_CI_KIND}" == "pip" ]]; then # Only used for 3.13 at the moment, just get test deps plus a few extras # that we know are available - INSTALL_ARGS="nibabel scikit-learn numpydoc PySide6 mne-qt-browser pandas h5io mffpy defusedxml" + INSTALL_ARGS="nibabel scikit-learn numpydoc PySide6 mne-qt-browser pandas h5io mffpy defusedxml numba" INSTALL_KIND="test" else test "${MNE_CI_KIND}" == "pip-pre" diff --git a/tools/github_actions_env_vars.sh b/tools/github_actions_env_vars.sh index 8accf72a11a..9f424ae5f48 100755 --- a/tools/github_actions_env_vars.sh +++ b/tools/github_actions_env_vars.sh @@ -28,7 +28,7 @@ else # conda-like echo "MNE_LOGGING_LEVEL=warning" | tee -a $GITHUB_ENV echo "MNE_QT_BACKEND=PySide6" | tee -a $GITHUB_ENV # TODO: Also need "|unreliable on GitHub Actions conda" on macOS, but omit for now to make sure the failure actually shows up - echo "MNE_TEST_ALLOW_SKIP=.*(Requires (spm|brainstorm) dataset|CUDA not|PySide6 causes segfaults).*" | tee -a $GITHUB_ENV + echo "MNE_TEST_ALLOW_SKIP=.*(Requires (spm|brainstorm) dataset|CUDA not|PySide6 causes segfaults|Accelerate|Flakey verbose behavior).*" | tee -a $GITHUB_ENV fi fi set +x From 99e985845759005c2d809c705241918589aa2a0e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 Jan 2025 01:49:51 +0000 Subject: [PATCH 4/5] [pre-commit.ci] pre-commit autoupdate (#13073) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Larson --- .github/workflows/check_changelog.yml | 3 +++ .github/workflows/circle_artifacts.yml | 3 +++ .pre-commit-config.yaml | 4 ++-- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/check_changelog.yml b/.github/workflows/check_changelog.yml index cc85b591977..6995c399b34 100644 --- a/.github/workflows/check_changelog.yml +++ b/.github/workflows/check_changelog.yml @@ -5,6 +5,9 @@ on: # yamllint disable-line rule:truthy types: [opened, synchronize, labeled, unlabeled] branches: ["main"] +permissions: + contents: read + jobs: changelog_checker: name: Check towncrier entry in doc/changes/devel/ diff --git a/.github/workflows/circle_artifacts.yml b/.github/workflows/circle_artifacts.yml index fa32e1ce80c..301c6234eb5 100644 --- a/.github/workflows/circle_artifacts.yml +++ b/.github/workflows/circle_artifacts.yml @@ -1,4 +1,7 @@ on: [status] # yamllint disable-line rule:truthy +permissions: + contents: read + statuses: write jobs: circleci_artifacts_redirector_job: if: "${{ startsWith(github.event.context, 'ci/circleci: build_docs') }}" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cb769988655..34b3ce9b130 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: # Ruff mne - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.1 + rev: v0.9.2 hooks: - id: ruff name: ruff lint mne @@ -82,7 +82,7 @@ repos: # zizmor - repo: https://github.com/woodruffw/zizmor-pre-commit - rev: v1.1.1 + rev: v1.2.2 hooks: - id: zizmor From 5f2b7f1a33d42c5a110f67e098f9efcf92be7fff Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 22 Jan 2025 12:22:24 -0500 Subject: [PATCH 5/5] BUG: Improve sklearn compliance (#13065) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel McCloy --- doc/changes/devel/13065.bugfix.rst | 7 + examples/decoding/linear_model_patterns.py | 2 +- mne/cov.py | 4 +- mne/decoding/base.py | 11 +- mne/decoding/csp.py | 99 ++++----- mne/decoding/ems.py | 25 ++- mne/decoding/search_light.py | 81 ++++--- mne/decoding/ssd.py | 138 ++++++------ mne/decoding/tests/test_base.py | 14 +- mne/decoding/tests/test_csp.py | 46 ++-- mne/decoding/tests/test_ems.py | 7 + mne/decoding/tests/test_search_light.py | 30 +-- mne/decoding/tests/test_ssd.py | 54 ++++- mne/decoding/tests/test_time_frequency.py | 23 +- mne/decoding/tests/test_transformer.py | 88 ++++++-- mne/decoding/time_frequency.py | 32 ++- mne/decoding/transformer.py | 240 +++++++++++---------- mne/time_frequency/multitaper.py | 9 +- mne/time_frequency/tfr.py | 3 +- mne/utils/numerics.py | 5 +- tools/vulture_allowlist.py | 2 + 21 files changed, 571 insertions(+), 349 deletions(-) create mode 100644 doc/changes/devel/13065.bugfix.rst diff --git a/doc/changes/devel/13065.bugfix.rst b/doc/changes/devel/13065.bugfix.rst new file mode 100644 index 00000000000..bbaa07ae127 --- /dev/null +++ b/doc/changes/devel/13065.bugfix.rst @@ -0,0 +1,7 @@ +Improved sklearn class compatibility and compliance, which resulted in some parameters of classes having an underscore appended to their name during ``fit``, such as: + +- :class:`mne.decoding.FilterEstimator` parameter ``picks`` passed to the initializer is set as ``est.picks_`` +- :class:`mne.decoding.UnsupervisedSpatialFilter` parameter ``estimator`` passed to the initializer is set as ``est.estimator_`` + +Unused ``verbose`` class parameters (that had no effect) were removed from :class:`~mne.decoding.PSDEstimator`, :class:`~mne.decoding.TemporalFilter`, and :class:`~mne.decoding.FilterEstimator` as well. +Changes by `Eric Larson`_. diff --git a/examples/decoding/linear_model_patterns.py b/examples/decoding/linear_model_patterns.py index c1390cbb0d3..7373c0a18b3 100644 --- a/examples/decoding/linear_model_patterns.py +++ b/examples/decoding/linear_model_patterns.py @@ -79,7 +79,7 @@ # Extract and plot spatial filters and spatial patterns for name, coef in (("patterns", model.patterns_), ("filters", model.filters_)): - # We fitted the linear model onto Z-scored data. To make the filters + # We fit the linear model on Z-scored data. To make the filters # interpretable, we must reverse this normalization step coef = scaler.inverse_transform([coef])[0] diff --git a/mne/cov.py b/mne/cov.py index 94239472fa2..694c836d0cd 100644 --- a/mne/cov.py +++ b/mne/cov.py @@ -1226,7 +1226,7 @@ def _compute_rank_raw_array( from .io import RawArray return _compute_rank( - RawArray(data, info, copy=None, verbose=_verbose_safe_false()), + RawArray(data, info, copy="auto", verbose=_verbose_safe_false()), rank, scalings, info, @@ -1405,7 +1405,7 @@ def _compute_covariance_auto( # project back cov = np.dot(eigvec.T, np.dot(cov, eigvec)) # undo bias - cov *= data.shape[0] / (data.shape[0] - 1) + cov *= data.shape[0] / max(data.shape[0] - 1, 1) # undo scaling _undo_scaling_cov(cov, picks_list, scalings) method_ = method[ei] diff --git a/mne/decoding/base.py b/mne/decoding/base.py index a291416bb17..f73cd976fe3 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -19,7 +19,7 @@ from sklearn.linear_model import LogisticRegression from sklearn.metrics import check_scoring from sklearn.model_selection import KFold, StratifiedKFold, check_cv -from sklearn.utils import check_array, indexable +from sklearn.utils import check_array, check_X_y, indexable from ..parallel import parallel_func from ..utils import _pl, logger, verbose, warn @@ -76,9 +76,9 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator): ) def __init__(self, model=None): + # TODO: We need to set this to get our tag checking to work properly if model is None: model = LogisticRegression(solver="liblinear") - self.model = model def __sklearn_tags__(self): @@ -122,7 +122,11 @@ def fit(self, X, y, **fit_params): self : instance of LinearModel Returns the modified instance. """ - X = check_array(X, input_name="X") + if y is not None: + X = check_array(X) + else: + X, y = check_X_y(X, y) + self.n_features_in_ = X.shape[1] if y is not None: y = check_array(y, dtype=None, ensure_2d=False, input_name="y") if y.ndim > 2: @@ -133,6 +137,7 @@ def fit(self, X, y, **fit_params): # fit the Model self.model.fit(X, y, **fit_params) + self.model_ = self.model # for better sklearn compat # Computes patterns using Haufe's trick: A = Cov_X . W . Precision_Y diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index 9e12335cdbe..ea38fd58ca3 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -6,7 +6,8 @@ import numpy as np from scipy.linalg import eigh -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator +from sklearn.utils.validation import check_is_fitted from .._fiff.meas_info import create_info from ..cov import _compute_rank_raw_array, _regularized_covariance, _smart_eigh @@ -19,10 +20,11 @@ fill_doc, pinv, ) +from .transformer import MNETransformerMixin @fill_doc -class CSP(TransformerMixin, BaseEstimator): +class CSP(MNETransformerMixin, BaseEstimator): """M/EEG signal decomposition using the Common Spatial Patterns (CSP). This class can be used as a supervised decomposition to estimate spatial @@ -112,49 +114,44 @@ def __init__( component_order="mutual_info", ): # Init default CSP - if not isinstance(n_components, int): - raise ValueError("n_components must be an integer.") self.n_components = n_components self.rank = rank self.reg = reg - - # Init default cov_est - if not (cov_est == "concat" or cov_est == "epoch"): - raise ValueError("unknown covariance estimation method") self.cov_est = cov_est - - # Init default transform_into - self.transform_into = _check_option( - "transform_into", transform_into, ["average_power", "csp_space"] - ) - - # Init default log - if transform_into == "average_power": - if log is not None and not isinstance(log, bool): - raise ValueError( - 'log must be a boolean if transform_into == "average_power".' - ) - else: - if log is not None: - raise ValueError('log must be a None if transform_into == "csp_space".') + self.transform_into = transform_into self.log = log - - _validate_type(norm_trace, bool, "norm_trace") self.norm_trace = norm_trace self.cov_method_params = cov_method_params - self.component_order = _check_option( - "component_order", component_order, ("mutual_info", "alternate") + self.component_order = component_order + + def _validate_params(self, *, y): + _validate_type(self.n_components, int, "n_components") + if hasattr(self, "cov_est"): + _validate_type(self.cov_est, str, "cov_est") + _check_option("cov_est", self.cov_est, ("concat", "epoch")) + if hasattr(self, "norm_trace"): + _validate_type(self.norm_trace, bool, "norm_trace") + _check_option( + "transform_into", self.transform_into, ["average_power", "csp_space"] ) - - def _check_Xy(self, X, y=None): - """Check input data.""" - if not isinstance(X, np.ndarray): - raise ValueError(f"X should be of type ndarray (got {type(X)}).") - if y is not None: - if len(X) != len(y) or len(y) < 1: - raise ValueError("X and y must have the same length.") - if X.ndim < 3: - raise ValueError("X must have at least 3 dimensions.") + if self.transform_into == "average_power": + _validate_type( + self.log, + (bool, None), + "log", + extra="when transform_into is 'average_power'", + ) + else: + _validate_type( + self.log, None, "log", extra="when transform_into is 'csp_space'" + ) + _check_option( + "component_order", self.component_order, ("mutual_info", "alternate") + ) + self.classes_ = np.unique(y) + n_classes = len(self.classes_) + if n_classes < 2: + raise ValueError(f"n_classes must be >= 2, but got {n_classes} class") def fit(self, X, y): """Estimate the CSP decomposition on epochs. @@ -171,12 +168,9 @@ def fit(self, X, y): self : instance of CSP Returns the modified instance. """ - self._check_Xy(X, y) - - self._classes = np.unique(y) - n_classes = len(self._classes) - if n_classes < 2: - raise ValueError("n_classes must be >= 2.") + X, y = self._check_data(X, y=y, fit=True, return_y=True) + self._validate_params(y=y) + n_classes = len(self.classes_) if n_classes > 2 and self.component_order == "alternate": raise ValueError( "component_order='alternate' requires two classes, but data contains " @@ -225,13 +219,8 @@ def transform(self, X): If self.transform_into == 'csp_space' then returns the data in CSP space and shape is (n_epochs, n_components, n_times). """ - if not isinstance(X, np.ndarray): - raise ValueError(f"X should be of type ndarray (got {type(X)}).") - if self.filters_ is None: - raise RuntimeError( - "No filters available. Please first fit CSP decomposition." - ) - + check_is_fitted(self, "filters_") + X = self._check_data(X) pick_filters = self.filters_[: self.n_components] X = np.asarray([np.dot(pick_filters, epoch) for epoch in X]) @@ -577,7 +566,7 @@ def _compute_covariance_matrices(self, X, y): covs = [] sample_weights = [] - for ci, this_class in enumerate(self._classes): + for ci, this_class in enumerate(self.classes_): cov, weight = cov_estimator( X[y == this_class], cov_kind=f"class={this_class}", @@ -689,7 +678,7 @@ def _normalize_eigenvectors(self, eigen_vectors, covs, sample_weights): def _order_components( self, covs, sample_weights, eigen_vectors, eigen_values, component_order ): - n_classes = len(self._classes) + n_classes = len(self.classes_) if component_order == "mutual_info" and n_classes > 2: mutual_info = self._compute_mutual_info(covs, sample_weights, eigen_vectors) ix = np.argsort(mutual_info)[::-1] @@ -889,10 +878,8 @@ def fit(self, X, y): self : instance of SPoC Returns the modified instance. """ - self._check_Xy(X, y) - - if len(np.unique(y)) < 2: - raise ValueError("y must have at least two distinct values.") + X, y = self._check_data(X, y=y, fit=True, return_y=True) + self._validate_params(y=y) # The following code is directly copied from pyRiemann diff --git a/mne/decoding/ems.py b/mne/decoding/ems.py index b3e72a30e21..5c7557798ef 100644 --- a/mne/decoding/ems.py +++ b/mne/decoding/ems.py @@ -5,15 +5,16 @@ from collections import Counter import numpy as np -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from .._fiff.pick import _picks_to_idx, pick_info, pick_types from ..parallel import parallel_func from ..utils import logger, verbose from .base import _set_cv +from .transformer import MNETransformerMixin -class EMS(TransformerMixin, BaseEstimator): +class EMS(MNETransformerMixin, BaseEstimator): """Transformer to compute event-matched spatial filters. This version of EMS :footcite:`SchurgerEtAl2013` operates on the entire @@ -37,6 +38,16 @@ class EMS(TransformerMixin, BaseEstimator): .. footbibliography:: """ + def __sklearn_tags__(self): + """Return sklearn tags.""" + from sklearn.utils import ClassifierTags + + tags = super().__sklearn_tags__() + if tags.classifier_tags is None: + tags.classifier_tags = ClassifierTags() + tags.classifier_tags.multi_class = False + return tags + def __repr__(self): # noqa: D105 if hasattr(self, "filters_"): return ( @@ -64,11 +75,12 @@ def fit(self, X, y): self : instance of EMS Returns self. """ - classes = np.unique(y) - if len(classes) != 2: + X, y = self._check_data(X, y=y, fit=True, return_y=True) + classes, y = np.unique(y, return_inverse=True) + if len(classes) > 2: raise ValueError("EMS only works for binary classification.") self.classes_ = classes - filters = X[y == classes[0]].mean(0) - X[y == classes[1]].mean(0) + filters = X[y == 0].mean(0) - X[y == 1].mean(0) filters /= np.linalg.norm(filters, axis=0)[None, :] self.filters_ = filters return self @@ -86,13 +98,14 @@ def transform(self, X): X : array, shape (n_epochs, n_times) The input data transformed by the spatial filters. """ + X = self._check_data(X) Xt = np.sum(X * self.filters_, axis=1) return Xt @verbose def compute_ems( - epochs, conditions=None, picks=None, n_jobs=None, cv=None, verbose=None + epochs, conditions=None, picks=None, n_jobs=None, cv=None, *, verbose=None ): """Compute event-matched spatial filter on epochs. diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index e3059a3e959..8bd96781185 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -5,18 +5,25 @@ import logging import numpy as np -from sklearn.base import BaseEstimator, MetaEstimatorMixin, TransformerMixin, clone +from sklearn.base import BaseEstimator, MetaEstimatorMixin, clone from sklearn.metrics import check_scoring from sklearn.preprocessing import LabelEncoder -from sklearn.utils import check_array +from sklearn.utils.validation import check_is_fitted from ..parallel import parallel_func -from ..utils import ProgressBar, _parse_verbose, array_split_idx, fill_doc, verbose +from ..utils import ( + ProgressBar, + _parse_verbose, + _verbose_safe_false, + array_split_idx, + fill_doc, +) from .base import _check_estimator +from .transformer import MNETransformerMixin @fill_doc -class SlidingEstimator(MetaEstimatorMixin, TransformerMixin, BaseEstimator): +class SlidingEstimator(MetaEstimatorMixin, MNETransformerMixin, BaseEstimator): """Search Light. Fit, predict and score a series of models to each subset of the dataset @@ -38,7 +45,6 @@ class SlidingEstimator(MetaEstimatorMixin, TransformerMixin, BaseEstimator): List of fitted scikit-learn estimators (one per task). """ - @verbose def __init__( self, base_estimator, @@ -49,7 +55,6 @@ def __init__( allow_2d=False, verbose=None, ): - _check_estimator(base_estimator) self.base_estimator = base_estimator self.n_jobs = n_jobs self.scoring = scoring @@ -102,9 +107,13 @@ def fit(self, X, y, **fit_params): self : object Return self. """ - X = self._check_Xy(X, y) + _check_estimator(self.base_estimator) + X, _ = self._check_Xy(X, y, fit=True) parallel, p_func, n_jobs = parallel_func( - _sl_fit, self.n_jobs, max_jobs=X.shape[-1], verbose=False + _sl_fit, + self.n_jobs, + max_jobs=X.shape[-1], + verbose=_verbose_safe_false(), ) self.estimators_ = list() self.fit_params_ = fit_params @@ -153,14 +162,19 @@ def fit_transform(self, X, y, **fit_params): def _transform(self, X, method): """Aux. function to make parallel predictions/transformation.""" - X = self._check_Xy(X) + X, is_nd = self._check_Xy(X) + orig_method = method + check_is_fitted(self) method = _check_method(self.base_estimator, method) if X.shape[-1] != len(self.estimators_): raise ValueError("The number of estimators does not match X.shape[-1]") # For predictions/transforms the parallelization is across the data and # not across the estimators to avoid memory load. parallel, p_func, n_jobs = parallel_func( - _sl_transform, self.n_jobs, max_jobs=X.shape[-1], verbose=False + _sl_transform, + self.n_jobs, + max_jobs=X.shape[-1], + verbose=_verbose_safe_false(), ) X_splits = np.array_split(X, n_jobs, axis=-1) @@ -174,6 +188,10 @@ def _transform(self, X, method): ) y_pred = np.concatenate(y_pred, axis=1) + if orig_method == "transform": + y_pred = y_pred.astype(X.dtype) + if orig_method == "predict_proba" and not is_nd: + y_pred = y_pred[:, 0, :] return y_pred def transform(self, X): @@ -196,7 +214,7 @@ def transform(self, X): Xt : array, shape (n_samples, n_estimators) The transformed values generated by each estimator. """ # noqa: E501 - return self._transform(X, "transform").astype(X.dtype) + return self._transform(X, "transform") def predict(self, X): """Predict each data slice/task with a series of independent estimators. @@ -265,15 +283,12 @@ def decision_function(self, X): """ # noqa: E501 return self._transform(X, "decision_function") - def _check_Xy(self, X, y=None): + def _check_Xy(self, X, y=None, fit=False): """Aux. function to check input data.""" # Once we require sklearn 1.1+ we should do something like: - X = check_array(X, ensure_2d=False, allow_nd=True, input_name="X") - if y is not None: - y = check_array(y, dtype=None, ensure_2d=False, input_name="y") - if len(X) != len(y) or len(y) < 1: - raise ValueError("X and y must have the same length.") - if X.ndim < 3: + X = self._check_data(X, y=y, atleast_3d=False, fit=fit) + is_nd = X.ndim >= 3 + if not is_nd: err = None if not self.allow_2d: err = 3 @@ -282,7 +297,7 @@ def _check_Xy(self, X, y=None): if err: raise ValueError(f"X must have at least {err} dimensions.") X = X[..., np.newaxis] - return X + return X, is_nd def score(self, X, y): """Score each estimator on each task. @@ -307,7 +322,7 @@ def score(self, X, y): score : array, shape (n_samples, n_estimators) Score for each estimator/task. """ # noqa: E501 - X = self._check_Xy(X, y) + X, _ = self._check_Xy(X, y) if X.shape[-1] != len(self.estimators_): raise ValueError("The number of estimators does not match X.shape[-1]") @@ -317,7 +332,10 @@ def score(self, X, y): # For predictions/transforms the parallelization is across the data and # not across the estimators to avoid memory load. parallel, p_func, n_jobs = parallel_func( - _sl_score, self.n_jobs, max_jobs=X.shape[-1], verbose=False + _sl_score, + self.n_jobs, + max_jobs=X.shape[-1], + verbose=_verbose_safe_false(), ) X_splits = np.array_split(X, n_jobs, axis=-1) est_splits = np.array_split(self.estimators_, n_jobs) @@ -483,11 +501,16 @@ def __repr__(self): # noqa: D105 def _transform(self, X, method): """Aux. function to make parallel predictions/transformation.""" - X = self._check_Xy(X) + X, is_nd = self._check_Xy(X) + check_is_fitted(self) + orig_method = method method = _check_method(self.base_estimator, method) parallel, p_func, n_jobs = parallel_func( - _gl_transform, self.n_jobs, max_jobs=X.shape[-1], verbose=False + _gl_transform, + self.n_jobs, + max_jobs=X.shape[-1], + verbose=_verbose_safe_false(), ) context = _create_progressbar_context(self, X, "Transforming") @@ -500,6 +523,10 @@ def _transform(self, X, method): ) y_pred = np.concatenate(y_pred, axis=2) + if orig_method == "transform": + y_pred = y_pred.astype(X.dtype) + if orig_method == "predict_proba" and not is_nd: + y_pred = y_pred[:, 0, 0, :] return y_pred def transform(self, X): @@ -518,6 +545,7 @@ def transform(self, X): Xt : array, shape (n_samples, n_estimators, n_slices) The transformed values generated by each estimator. """ + check_is_fitted(self) return self._transform(X, "transform") def predict(self, X): @@ -603,11 +631,14 @@ def score(self, X, y): score : array, shape (n_samples, n_estimators, n_slices) Score for each estimator / data slice couple. """ # noqa: E501 - X = self._check_Xy(X, y) + X, _ = self._check_Xy(X, y) # For predictions/transforms the parallelization is across the data and # not across the estimators to avoid memory load. parallel, p_func, n_jobs = parallel_func( - _gl_score, self.n_jobs, max_jobs=X.shape[-1], verbose=False + _gl_score, + self.n_jobs, + max_jobs=X.shape[-1], + verbose=_verbose_safe_false(), ) scoring = check_scoring(self.base_estimator, self.scoring) y = _fix_auc(scoring, y) diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 8bc0036d315..111ded9f274 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -4,8 +4,10 @@ import numpy as np from scipy.linalg import eigh -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator +from sklearn.utils.validation import check_is_fitted +from .._fiff.meas_info import Info, create_info from .._fiff.pick import _picks_to_idx from ..cov import Covariance, _regularized_covariance from ..defaults import _handle_default @@ -13,17 +15,17 @@ from ..rank import compute_rank from ..time_frequency import psd_array_welch from ..utils import ( - _check_option, _time_mask, _validate_type, _verbose_safe_false, fill_doc, logger, ) +from .transformer import MNETransformerMixin @fill_doc -class SSD(TransformerMixin, BaseEstimator): +class SSD(MNETransformerMixin, BaseEstimator): """ Signal decomposition using the Spatio-Spectral Decomposition (SSD). @@ -64,7 +66,7 @@ class SSD(TransformerMixin, BaseEstimator): If sort_by_spectral_ratio is set to True, then the SSD sources will be sorted according to their spectral ratio which is calculated based on :func:`mne.time_frequency.psd_array_welch`. The n_fft parameter sets the - length of FFT used. + length of FFT used. The default (None) will use 1 second of data. See :func:`mne.time_frequency.psd_array_welch` for more information. cov_method_params : dict | None (default None) As in :class:`mne.decoding.SPoC` @@ -104,7 +106,25 @@ def __init__( rank=None, ): """Initialize instance.""" - dicts = {"signal": filt_params_signal, "noise": filt_params_noise} + self.info = info + self.filt_params_signal = filt_params_signal + self.filt_params_noise = filt_params_noise + self.reg = reg + self.n_components = n_components + self.picks = picks + self.sort_by_spectral_ratio = sort_by_spectral_ratio + self.return_filtered = return_filtered + self.n_fft = n_fft + self.cov_method_params = cov_method_params + self.rank = rank + + def _validate_params(self, X): + if isinstance(self.info, float): # special case, mostly for testing + self.sfreq_ = self.info + else: + _validate_type(self.info, Info, "info") + self.sfreq_ = self.info["sfreq"] + dicts = {"signal": self.filt_params_signal, "noise": self.filt_params_noise} for param, dd in [("l", 0), ("h", 0), ("l", 1), ("h", 1)]: key = ("signal", "noise")[dd] if param + "_freq" not in dicts[key]: @@ -116,48 +136,47 @@ def __init__( _validate_type(val, ("numeric",), f"{key} {param}_freq") # check freq bands if ( - filt_params_noise["l_freq"] > filt_params_signal["l_freq"] - or filt_params_signal["h_freq"] > filt_params_noise["h_freq"] + self.filt_params_noise["l_freq"] > self.filt_params_signal["l_freq"] + or self.filt_params_signal["h_freq"] > self.filt_params_noise["h_freq"] ): raise ValueError( "Wrongly specified frequency bands!\n" "The signal band-pass must be within the noise " "band-pass!" ) - self.picks = picks - del picks - self.info = info - self.freqs_signal = (filt_params_signal["l_freq"], filt_params_signal["h_freq"]) - self.freqs_noise = (filt_params_noise["l_freq"], filt_params_noise["h_freq"]) - self.filt_params_signal = filt_params_signal - self.filt_params_noise = filt_params_noise - # check if boolean - if not isinstance(sort_by_spectral_ratio, (bool)): - raise ValueError("sort_by_spectral_ratio must be boolean") - self.sort_by_spectral_ratio = sort_by_spectral_ratio - if n_fft is None: - self.n_fft = int(self.info["sfreq"]) - else: - self.n_fft = int(n_fft) - # check if boolean - if not isinstance(return_filtered, (bool)): - raise ValueError("return_filtered must be boolean") - self.return_filtered = return_filtered - self.reg = reg - self.n_components = n_components - self.rank = rank - self.cov_method_params = cov_method_params + self.freqs_signal_ = ( + self.filt_params_signal["l_freq"], + self.filt_params_signal["h_freq"], + ) + self.freqs_noise_ = ( + self.filt_params_noise["l_freq"], + self.filt_params_noise["h_freq"], + ) + _validate_type(self.sort_by_spectral_ratio, (bool,), "sort_by_spectral_ratio") + _validate_type(self.n_fft, ("numeric", None), "n_fft") + self.n_fft_ = min( + int(self.n_fft if self.n_fft is not None else self.sfreq_), + X.shape[-1], + ) + _validate_type(self.return_filtered, (bool,), "return_filtered") + if isinstance(self.info, Info): + ch_types = self.info.get_channel_types(picks=self.picks, unique=True) + if len(ch_types) > 1: + raise ValueError( + "At this point SSD only supports fitting " + f"single channel types. Your info has {len(ch_types)} types." + ) - def _check_X(self, X): + def _check_X(self, X, *, y=None, fit=False): """Check input data.""" - _validate_type(X, np.ndarray, "X") - _check_option("X.ndim", X.ndim, (2, 3)) + X = self._check_data(X, y=y, fit=fit, atleast_3d=False) n_chan = X.shape[-2] - if n_chan != self.info["nchan"]: + if isinstance(self.info, Info) and n_chan != self.info["nchan"]: raise ValueError( "Info must match the input data." f"Found {n_chan} channels but expected {self.info['nchan']}." ) + return X def fit(self, X, y=None): """Estimate the SSD decomposition on raw or epoched data. @@ -176,18 +195,17 @@ def fit(self, X, y=None): self : instance of SSD Returns the modified instance. """ - ch_types = self.info.get_channel_types(picks=self.picks, unique=True) - if len(ch_types) > 1: - raise ValueError( - "At this point SSD only supports fitting " - f"single channel types. Your info has {len(ch_types)} types." - ) - self.picks_ = _picks_to_idx(self.info, self.picks, none="data", exclude="bads") - self._check_X(X) + X = self._check_X(X, y=y, fit=True) + self._validate_params(X) + if isinstance(self.info, Info): + info = self.info + else: + info = create_info(X.shape[-2], self.sfreq_, ch_types="eeg") + self.picks_ = _picks_to_idx(info, self.picks, none="data", exclude="bads") X_aux = X[..., self.picks_, :] - X_signal = filter_data(X_aux, self.info["sfreq"], **self.filt_params_signal) - X_noise = filter_data(X_aux, self.info["sfreq"], **self.filt_params_noise) + X_signal = filter_data(X_aux, self.sfreq_, **self.filt_params_signal) + X_noise = filter_data(X_aux, self.sfreq_, **self.filt_params_noise) X_noise -= X_signal if X.ndim == 3: X_signal = np.hstack(X_signal) @@ -199,19 +217,19 @@ def fit(self, X, y=None): reg=self.reg, method_params=self.cov_method_params, rank="full", - info=self.info, + info=info, ) cov_noise = _regularized_covariance( X_noise, reg=self.reg, method_params=self.cov_method_params, rank="full", - info=self.info, + info=info, ) # project cov to rank subspace cov_signal, cov_noise, rank_proj = _dimensionality_reduction( - cov_signal, cov_noise, self.info, self.rank + cov_signal, cov_noise, info, self.rank ) eigvals_, eigvects_ = eigh(cov_signal, cov_noise) @@ -226,10 +244,10 @@ def fit(self, X, y=None): # than the initial ordering. This ordering should be also learned when # fitting. X_ssd = self.filters_.T @ X[..., self.picks_, :] - sorter_spec = Ellipsis + sorter_spec = slice(None) if self.sort_by_spectral_ratio: _, sorter_spec = self.get_spectral_ratio(ssd_sources=X_ssd) - self.sorter_spec = sorter_spec + self.sorter_spec_ = sorter_spec logger.info("Done.") return self @@ -248,17 +266,13 @@ def transform(self, X): X_ssd : array, shape ([n_epochs, ]n_components, n_times) The processed data. """ - self._check_X(X) - if self.filters_ is None: - raise RuntimeError("No filters available. Please first call fit") + check_is_fitted(self, "filters_") + X = self._check_X(X) if self.return_filtered: X_aux = X[..., self.picks_, :] - X = filter_data(X_aux, self.info["sfreq"], **self.filt_params_signal) + X = filter_data(X_aux, self.sfreq_, **self.filt_params_signal) X_ssd = self.filters_.T @ X[..., self.picks_, :] - if X.ndim == 2: - X_ssd = X_ssd[self.sorter_spec][: self.n_components] - else: - X_ssd = X_ssd[:, self.sorter_spec, :][:, : self.n_components, :] + X_ssd = X_ssd[..., self.sorter_spec_, :][..., : self.n_components, :] return X_ssd def fit_transform(self, X, y=None, **fit_params): @@ -308,11 +322,9 @@ def get_spectral_ratio(self, ssd_sources): ---------- .. footbibliography:: """ - psd, freqs = psd_array_welch( - ssd_sources, sfreq=self.info["sfreq"], n_fft=self.n_fft - ) - sig_idx = _time_mask(freqs, *self.freqs_signal) - noise_idx = _time_mask(freqs, *self.freqs_noise) + psd, freqs = psd_array_welch(ssd_sources, sfreq=self.sfreq_, n_fft=self.n_fft_) + sig_idx = _time_mask(freqs, *self.freqs_signal_) + noise_idx = _time_mask(freqs, *self.freqs_noise_) if psd.ndim == 3: mean_sig = psd[:, :, sig_idx].mean(axis=2).mean(axis=0) mean_noise = psd[:, :, noise_idx].mean(axis=2).mean(axis=0) @@ -352,7 +364,7 @@ def apply(self, X): The processed data. """ X_ssd = self.transform(X) - pick_patterns = self.patterns_[self.sorter_spec][: self.n_components].T + pick_patterns = self.patterns_[self.sorter_spec_][: self.n_components].T X = pick_patterns @ X_ssd return X diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py index 6d915dd24f9..504e309d53c 100644 --- a/mne/decoding/tests/test_base.py +++ b/mne/decoding/tests/test_base.py @@ -86,6 +86,8 @@ def _make_data(n_samples=1000, n_features=5, n_targets=3): X = Y.dot(A.T) X += np.random.randn(n_samples, n_features) # add noise X += np.random.rand(n_features) # Put an offset + if n_targets == 1: + Y = Y[:, 0] return X, Y, A @@ -95,7 +97,7 @@ def test_get_coef(): """Test getting linear coefficients (filters/patterns) from estimators.""" lm_classification = LinearModel() assert hasattr(lm_classification, "__sklearn_tags__") - print(lm_classification.__sklearn_tags__) + print(lm_classification.__sklearn_tags__()) assert is_classifier(lm_classification.model) assert is_classifier(lm_classification) assert not is_regressor(lm_classification.model) @@ -273,7 +275,12 @@ def test_get_coef_multiclass(n_features, n_targets): """Test get_coef on multiclass problems.""" # Check patterns with more than 1 regressor X, Y, A = _make_data(n_samples=30000, n_features=n_features, n_targets=n_targets) - lm = LinearModel(LinearRegression()).fit(X, Y) + lm = LinearModel(LinearRegression()) + assert not hasattr(lm, "model_") + lm.fit(X, Y) + # TODO: modifying non-underscored `model` is a sklearn no-no, maybe should be a + # metaestimator? + assert lm.model is lm.model_ assert_array_equal(lm.filters_.shape, lm.patterns_.shape) if n_targets == 1: want_shape = (n_features,) @@ -473,9 +480,8 @@ def test_cross_val_multiscore(): def test_sklearn_compliance(estimator, check): """Test LinearModel compliance with sklearn.""" ignores = ( - "check_n_features_in", # maybe we should add this someday? - "check_estimator_sparse_data", # we densify "check_estimators_overwrite_params", # self.model changes! + "check_dont_overwrite_parameters", "check_parameters_default_constructible", ) if any(ignore in str(check) for ignore in ignores): diff --git a/mne/decoding/tests/test_csp.py b/mne/decoding/tests/test_csp.py index 7a1a83feeaf..e754b6952f9 100644 --- a/mne/decoding/tests/test_csp.py +++ b/mne/decoding/tests/test_csp.py @@ -19,6 +19,7 @@ from sklearn.model_selection import StratifiedKFold, cross_val_score from sklearn.pipeline import Pipeline, make_pipeline from sklearn.svm import SVC +from sklearn.utils.estimator_checks import parametrize_with_checks from mne import Epochs, compute_proj_raw, io, pick_types, read_events from mne.decoding import CSP, LinearModel, Scaler, SPoC, get_coef @@ -139,18 +140,22 @@ def test_csp(): y = epochs.events[:, -1] # Init - pytest.raises(ValueError, CSP, n_components="foo", norm_trace=False) + csp = CSP(n_components="foo") + with pytest.raises(TypeError, match="must be an instance"): + csp.fit(epochs_data, y) for reg in ["foo", -0.1, 1.1]: csp = CSP(reg=reg, norm_trace=False) pytest.raises(ValueError, csp.fit, epochs_data, epochs.events[:, -1]) for reg in ["oas", "ledoit_wolf", 0, 0.5, 1.0]: CSP(reg=reg, norm_trace=False) - for cov_est in ["foo", None]: - pytest.raises(ValueError, CSP, cov_est=cov_est, norm_trace=False) + csp = CSP(cov_est="foo", norm_trace=False) + with pytest.raises(ValueError, match="Invalid value"): + csp.fit(epochs_data, y) + csp = CSP(norm_trace="foo") with pytest.raises(TypeError, match="instance of bool"): - CSP(norm_trace="foo") + csp.fit(epochs_data, y) for cov_est in ["concat", "epoch"]: - CSP(cov_est=cov_est, norm_trace=False) + CSP(cov_est=cov_est, norm_trace=False).fit(epochs_data, y) n_components = 3 # Fit @@ -171,8 +176,8 @@ def test_csp(): # Test data exception pytest.raises(ValueError, csp.fit, epochs_data, np.zeros_like(epochs.events)) - pytest.raises(ValueError, csp.fit, epochs, y) - pytest.raises(ValueError, csp.transform, epochs) + pytest.raises(ValueError, csp.fit, "foo", y) + pytest.raises(ValueError, csp.transform, "foo") # Test plots epochs.pick(picks="mag") @@ -200,7 +205,7 @@ def test_csp(): for cov_est in ["concat", "epoch"]: csp = CSP(n_components=n_components, cov_est=cov_est, norm_trace=False) csp.fit(epochs_data, epochs.events[:, 2]).transform(epochs_data) - assert_equal(len(csp._classes), 4) + assert_equal(len(csp.classes_), 4) assert_array_equal(csp.filters_.shape, [n_channels, n_channels]) assert_array_equal(csp.patterns_.shape, [n_channels, n_channels]) @@ -220,15 +225,17 @@ def test_csp(): # Different normalization return different transform assert np.sum((X_trans["True"] - X_trans["False"]) ** 2) > 1.0 # Check wrong inputs - pytest.raises(ValueError, CSP, transform_into="average_power", log="foo") + csp = CSP(transform_into="average_power", log="foo") + with pytest.raises(TypeError, match="must be an instance of bool"): + csp.fit(epochs_data, epochs.events[:, 2]) # Test csp space transform csp = CSP(transform_into="csp_space", norm_trace=False) assert csp.transform_into == "csp_space" for log in ("foo", True, False): - pytest.raises( - ValueError, CSP, transform_into="csp_space", log=log, norm_trace=False - ) + csp = CSP(transform_into="csp_space", log=log, norm_trace=False) + with pytest.raises(TypeError, match="must be an instance"): + csp.fit(epochs_data, epochs.events[:, 2]) n_components = 2 csp = CSP(n_components=n_components, transform_into="csp_space", norm_trace=False) Xt = csp.fit(epochs_data, epochs.events[:, 2]).transform(epochs_data) @@ -343,8 +350,8 @@ def test_regularized_csp(ch_type, rank, reg): # test init exception pytest.raises(ValueError, csp.fit, epochs_data, np.zeros_like(epochs.events)) - pytest.raises(ValueError, csp.fit, epochs, y) - pytest.raises(ValueError, csp.transform, epochs) + pytest.raises(ValueError, csp.fit, "foo", y) + pytest.raises(ValueError, csp.transform, "foo") csp.n_components = n_components sources = csp.transform(epochs_data) @@ -465,7 +472,9 @@ def test_csp_component_ordering(): """Test that CSP component ordering works as expected.""" x, y = deterministic_toy_data(["class_a", "class_b"]) - pytest.raises(ValueError, CSP, component_order="invalid") + csp = CSP(component_order="invalid") + with pytest.raises(ValueError, match="Invalid value"): + csp.fit(x, y) # component_order='alternate' only works with two classes csp = CSP(component_order="alternate") @@ -480,3 +489,10 @@ def test_csp_component_ordering(): # p_alt arranges them to [0.8, 0.06, 0.5, 0.1] # p_mut arranges them to [0.06, 0.1, 0.8, 0.5] assert_array_almost_equal(p_alt, p_mut[[2, 0, 3, 1]]) + + +@pytest.mark.filterwarnings("ignore:.*Only one sample available.*") +@parametrize_with_checks([CSP(), SPoC()]) +def test_sklearn_compliance(estimator, check): + """Test compliance with sklearn.""" + check(estimator) diff --git a/mne/decoding/tests/test_ems.py b/mne/decoding/tests/test_ems.py index 10774c0681a..dc54303a541 100644 --- a/mne/decoding/tests/test_ems.py +++ b/mne/decoding/tests/test_ems.py @@ -11,6 +11,7 @@ pytest.importorskip("sklearn") from sklearn.model_selection import StratifiedKFold +from sklearn.utils.estimator_checks import parametrize_with_checks from mne import Epochs, io, pick_types, read_events from mne.decoding import EMS, compute_ems @@ -91,3 +92,9 @@ def test_ems(): assert_equal(ems.__repr__(), "") assert_array_almost_equal(filters, np.mean(coefs, axis=0)) assert_array_almost_equal(surrogates, np.vstack(Xt)) + + +@parametrize_with_checks([EMS()]) +def test_sklearn_compliance(estimator, check): + """Test compliance with sklearn.""" + check(estimator) diff --git a/mne/decoding/tests/test_search_light.py b/mne/decoding/tests/test_search_light.py index 7cb3a66dd81..e7abfd9209e 100644 --- a/mne/decoding/tests/test_search_light.py +++ b/mne/decoding/tests/test_search_light.py @@ -41,7 +41,7 @@ def make_data(): return X, y -def test_search_light(): +def test_search_light_basic(): """Test SlidingEstimator.""" # https://github.com/scikit-learn/scikit-learn/issues/27711 if platform.system() == "Windows" and check_version("numpy", "2.0.0.dev0"): @@ -52,7 +52,9 @@ def test_search_light(): X, y = make_data() n_epochs, _, n_time = X.shape # init - pytest.raises(ValueError, SlidingEstimator, "foo") + sl = SlidingEstimator("foo") + with pytest.raises(ValueError, match="must be"): + sl.fit(X, y) sl = SlidingEstimator(Ridge()) assert not is_classifier(sl) sl = SlidingEstimator(LogisticRegression(solver="liblinear")) @@ -69,7 +71,8 @@ def test_search_light(): # transforms pytest.raises(ValueError, sl.predict, X[:, :, :2]) y_trans = sl.transform(X) - assert X.dtype == y_trans.dtype == np.dtype(float) + assert X.dtype == float + assert y_trans.dtype == float y_pred = sl.predict(X) assert y_pred.dtype == np.dtype(int) assert_array_equal(y_pred.shape, [n_epochs, n_time]) @@ -344,22 +347,19 @@ def predict_proba(self, X): @pytest.mark.slowtest -@parametrize_with_checks([SlidingEstimator(LogisticRegression(), allow_2d=True)]) +@parametrize_with_checks( + [ + SlidingEstimator(LogisticRegression(), allow_2d=True), + GeneralizingEstimator(LogisticRegression(), allow_2d=True), + ] +) def test_sklearn_compliance(estimator, check): """Test LinearModel compliance with sklearn.""" ignores = ( - "check_estimator_sparse_data", # we densify - "check_classifiers_one_label_sample_weights", # don't handle singleton - "check_classifiers_classes", # dim mismatch + # TODO: we don't handle singleton right (probably) + "check_classifiers_one_label_sample_weights", + "check_classifiers_classes", "check_classifiers_train", - "check_decision_proba_consistency", - "check_parameters_default_constructible", - # Should probably fix these? - "check_estimators_unfitted", - "check_transformer_data_not_an_array", - "check_n_features_in", - "check_fit2d_predict1d", - "check_do_not_raise_errors_in_init_or_set_params", ) if any(ignore in str(check) for ignore in ignores): return diff --git a/mne/decoding/tests/test_ssd.py b/mne/decoding/tests/test_ssd.py index 8f4d2472803..b6cdfc472c3 100644 --- a/mne/decoding/tests/test_ssd.py +++ b/mne/decoding/tests/test_ssd.py @@ -11,6 +11,7 @@ pytest.importorskip("sklearn") from sklearn.pipeline import Pipeline +from sklearn.utils.estimator_checks import parametrize_with_checks from mne import create_info, io from mne.decoding import CSP @@ -101,8 +102,9 @@ def test_ssd(): l_trans_bandwidth=1, h_trans_bandwidth=1, ) + ssd = SSD(info, filt_params_signal, filt_params_noise) with pytest.raises(TypeError, match="must be an instance "): - ssd = SSD(info, filt_params_signal, filt_params_noise) + ssd.fit(X) # Wrongly specified noise band freq = 2 @@ -115,14 +117,16 @@ def test_ssd(): l_trans_bandwidth=1, h_trans_bandwidth=1, ) + ssd = SSD(info, filt_params_signal, filt_params_noise) with pytest.raises(ValueError, match="Wrongly specified "): - ssd = SSD(info, filt_params_signal, filt_params_noise) + ssd.fit(X) # filt param no dict filt_params_signal = freqs_sig filt_params_noise = freqs_noise + ssd = SSD(info, filt_params_signal, filt_params_noise) with pytest.raises(ValueError, match="must be defined"): - ssd = SSD(info, filt_params_signal, filt_params_noise) + ssd.fit(X) # Data type filt_params_signal = dict( @@ -140,15 +144,18 @@ def test_ssd(): ssd = SSD(info, filt_params_signal, filt_params_noise) raw = io.RawArray(X, info) - pytest.raises(TypeError, ssd.fit, raw) + with pytest.raises(ValueError): + ssd.fit(raw) # check non-boolean return_filtered - with pytest.raises(ValueError, match="return_filtered"): - ssd = SSD(info, filt_params_signal, filt_params_noise, return_filtered=0) + ssd = SSD(info, filt_params_signal, filt_params_noise, return_filtered=0) + with pytest.raises(TypeError, match="return_filtered"): + ssd.fit(X) # check non-boolean sort_by_spectral_ratio - with pytest.raises(ValueError, match="sort_by_spectral_ratio"): - ssd = SSD(info, filt_params_signal, filt_params_noise, sort_by_spectral_ratio=0) + ssd = SSD(info, filt_params_signal, filt_params_noise, sort_by_spectral_ratio=0) + with pytest.raises(TypeError, match="sort_by_spectral_ratio"): + ssd.fit(X) # More than 1 channel type ch_types = np.reshape([["mag"] * 10, ["eeg"] * 10], n_channels) @@ -161,7 +168,8 @@ def test_ssd(): # Number of channels info_3 = create_info(ch_names=n_channels + 1, sfreq=sf, ch_types="eeg") ssd = SSD(info_3, filt_params_signal, filt_params_noise) - pytest.raises(ValueError, ssd.fit, X) + with pytest.raises(ValueError, match="channels but expected"): + ssd.fit(X) # Fit n_components = 10 @@ -381,7 +389,7 @@ def test_sorting(): ssd.fit(Xtr) # check sorters - sorter_in = ssd.sorter_spec + sorter_in = ssd.sorter_spec_ ssd = SSD( info, filt_params_signal, @@ -476,3 +484,29 @@ def test_non_full_rank_data(): if sys.platform == "darwin": pytest.xfail("Unknown linalg bug (Accelerate?)") ssd.fit(X) + + +@pytest.mark.filterwarnings("ignore:.*invalid value encountered in divide.*") +@pytest.mark.filterwarnings("ignore:.*is longer than.*") +@parametrize_with_checks( + [ + SSD( + 100.0, + dict(l_freq=0.0, h_freq=30.0), + dict(l_freq=0.0, h_freq=40.0), + ) + ] +) +def test_sklearn_compliance(estimator, check): + """Test LinearModel compliance with sklearn.""" + ignores = ( + "check_methods_sample_order_invariance", + # Shape stuff + "check_fit_idempotent", + "check_methods_subset_invariance", + "check_transformer_general", + "check_transformer_data_not_an_array", + ) + if any(ignore in str(check) for ignore in ignores): + return + check(estimator) diff --git a/mne/decoding/tests/test_time_frequency.py b/mne/decoding/tests/test_time_frequency.py index 37e7d7d8dc2..638cebda21e 100644 --- a/mne/decoding/tests/test_time_frequency.py +++ b/mne/decoding/tests/test_time_frequency.py @@ -10,18 +10,23 @@ pytest.importorskip("sklearn") from sklearn.base import clone +from sklearn.utils.estimator_checks import parametrize_with_checks from mne.decoding.time_frequency import TimeFrequency -def test_timefrequency(): +def test_timefrequency_basic(): """Test TimeFrequency.""" # Init n_freqs = 3 freqs = [20, 21, 22] tf = TimeFrequency(freqs, sfreq=100) + n_epochs, n_chans, n_times = 10, 2, 100 + X = np.random.rand(n_epochs, n_chans, n_times) for output in ["avg_power", "foo", None]: - pytest.raises(ValueError, TimeFrequency, freqs, output=output) + tf = TimeFrequency(freqs, output=output) + with pytest.raises(ValueError, match="Invalid value"): + tf.fit(X) tf = clone(tf) # Clone estimator @@ -30,9 +35,9 @@ def test_timefrequency(): clone(tf) # Fit - n_epochs, n_chans, n_times = 10, 2, 100 - X = np.random.rand(n_epochs, n_chans, n_times) + assert not hasattr(tf, "fitted_") tf.fit(X, None) + assert tf.fitted_ # Transform tf = TimeFrequency(freqs, sfreq=100) @@ -41,9 +46,15 @@ def test_timefrequency(): Xt = tf.transform(X) assert_array_equal(Xt.shape, [n_epochs, n_chans, n_freqs, n_times]) # 2-D X - Xt = tf.transform(X[:, 0, :]) + Xt = tf.fit_transform(X[:, 0, :]) assert_array_equal(Xt.shape, [n_epochs, n_freqs, n_times]) # 3-D with decim tf = TimeFrequency(freqs, sfreq=100, decim=2) - Xt = tf.transform(X) + Xt = tf.fit_transform(X) assert_array_equal(Xt.shape, [n_epochs, n_chans, n_freqs, n_times // 2]) + + +@parametrize_with_checks([TimeFrequency([300, 400], 1000.0, n_cycles=0.25)]) +def test_sklearn_compliance(estimator, check): + """Test LinearModel compliance with sklearn.""" + check(estimator) diff --git a/mne/decoding/tests/test_transformer.py b/mne/decoding/tests/test_transformer.py index 8dcc3ad74c7..a8afe209d96 100644 --- a/mne/decoding/tests/test_transformer.py +++ b/mne/decoding/tests/test_transformer.py @@ -17,10 +17,14 @@ from sklearn.decomposition import PCA from sklearn.kernel_ridge import KernelRidge +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.utils.estimator_checks import parametrize_with_checks -from mne import Epochs, io, pick_types, read_events +from mne import Epochs, EpochsArray, create_info, io, pick_types, read_events from mne.decoding import ( FilterEstimator, + LinearModel, PSDEstimator, Scaler, TemporalFilter, @@ -36,6 +40,7 @@ data_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = data_dir / "test_raw.fif" event_name = data_dir / "test-eve.fif" +info = create_info(2, 1000.0, "eeg") @pytest.mark.parametrize( @@ -101,9 +106,11 @@ def test_scaler(info, method): assert_array_almost_equal(epochs_data, Xi) # Test init exception - pytest.raises(ValueError, Scaler, None, None) - pytest.raises(TypeError, scaler.fit, epochs, y) - pytest.raises(TypeError, scaler.transform, epochs) + x = Scaler(None, None) + with pytest.raises(ValueError): + x.fit(epochs_data, y) + pytest.raises(ValueError, scaler.fit, "foo", y) + pytest.raises(ValueError, scaler.transform, "foo") epochs_bad = Epochs( raw, events, @@ -164,8 +171,8 @@ def test_filterestimator(): X = filt.fit_transform(epochs_data, y) # Test init exception - pytest.raises(ValueError, filt.fit, epochs, y) - pytest.raises(ValueError, filt.transform, epochs) + pytest.raises(ValueError, filt.fit, "foo", y) + pytest.raises(ValueError, filt.transform, "foo") def test_psdestimator(): @@ -182,14 +189,18 @@ def test_psdestimator(): epochs_data = epochs.get_data(copy=False) psd = PSDEstimator(2 * np.pi, 0, np.inf) y = epochs.events[:, -1] + assert not hasattr(psd, "fitted_") X = psd.fit_transform(epochs_data, y) + assert psd.fitted_ assert X.shape[0] == epochs_data.shape[0] assert_array_equal(psd.fit(epochs_data, y).transform(epochs_data), X) # Test init exception - pytest.raises(ValueError, psd.fit, epochs, y) - pytest.raises(ValueError, psd.transform, epochs) + with pytest.raises(ValueError): + psd.fit("foo", y) + with pytest.raises(ValueError): + psd.transform("foo") def test_vectorizer(): @@ -210,9 +221,16 @@ def test_vectorizer(): assert_equal(vect.fit_transform(data[1:]).shape, (149, 108)) # check if raised errors are working correctly - vect.fit(np.random.rand(105, 12, 3)) - pytest.raises(ValueError, vect.transform, np.random.rand(105, 12, 3, 1)) - pytest.raises(ValueError, vect.inverse_transform, np.random.rand(102, 12, 12)) + X = np.random.default_rng(0).standard_normal((105, 12, 3)) + y = np.arange(X.shape[0]) % 2 + pytest.raises(ValueError, vect.transform, X[..., np.newaxis]) + pytest.raises(ValueError, vect.inverse_transform, X[:, :-1]) + + # And that pipelines work properly + X_arr = EpochsArray(X, create_info(12, 1000.0, "eeg")) + vect.fit(X_arr) + clf = make_pipeline(Vectorizer(), StandardScaler(), LinearModel()) + clf.fit(X_arr, y) def test_unsupervised_spatial_filter(): @@ -235,11 +253,13 @@ def test_unsupervised_spatial_filter(): verbose=False, ) - # Test estimator - pytest.raises(ValueError, UnsupervisedSpatialFilter, KernelRidge(2)) + # Test estimator (must be a transformer) + X = epochs.get_data(copy=False) + usf = UnsupervisedSpatialFilter(KernelRidge(2)) + with pytest.raises(ValueError, match="transform"): + usf.fit(X) # Test fit - X = epochs.get_data(copy=False) n_components = 4 usf = UnsupervisedSpatialFilter(PCA(n_components)) usf.fit(X) @@ -255,7 +275,9 @@ def test_unsupervised_spatial_filter(): # Test with average param usf = UnsupervisedSpatialFilter(PCA(4), average=True) usf.fit_transform(X) - pytest.raises(ValueError, UnsupervisedSpatialFilter, PCA(4), 2) + usf = UnsupervisedSpatialFilter(PCA(4), 2) + with pytest.raises(TypeError, match="average must be"): + usf.fit(X) def test_temporal_filter(): @@ -281,8 +303,8 @@ def test_temporal_filter(): assert X.shape == Xt.shape # Test fit and transform numpy type check - with pytest.raises(ValueError, match="Data to be filtered must be"): - filt.transform([1, 2]) + with pytest.raises(ValueError): + filt.transform("foo") # Test with 2 dimensional data array X = np.random.rand(101, 500) @@ -298,4 +320,36 @@ def test_bad_triage(): filt = TemporalFilter(l_freq=8, h_freq=60, sfreq=160.0) # Used to fail with "ValueError: Effective band-stop frequency (135.0) is # too high (maximum based on Nyquist is 80.0)" + assert not hasattr(filt, "fitted_") filt.fit_transform(np.zeros((1, 1, 481))) + assert filt.fitted_ + + +@pytest.mark.filterwarnings("ignore:.*filter_length.*") +@parametrize_with_checks( + [ + FilterEstimator(info, l_freq=1, h_freq=10), + PSDEstimator(), + Scaler(scalings="mean"), + # Not easy to test Scaler(info) b/c number of channels must match + TemporalFilter(), + UnsupervisedSpatialFilter(PCA()), + Vectorizer(), + ] +) +def test_sklearn_compliance(estimator, check): + """Test LinearModel compliance with sklearn.""" + ignores = [] + if estimator.__class__.__name__ == "FilterEstimator": + ignores += [ + "check_estimators_overwrite_params", # we modify self.info + "check_methods_sample_order_invariance", + ] + if estimator.__class__.__name__.startswith(("PSD", "Temporal")): + ignores += [ + "check_transformers_unfitted", # allow unfitted transform + "check_methods_sample_order_invariance", + ] + if any(ignore in str(check) for ignore in ignores): + return + check(estimator) diff --git a/mne/decoding/time_frequency.py b/mne/decoding/time_frequency.py index de6ec52155b..29232aaeb9f 100644 --- a/mne/decoding/time_frequency.py +++ b/mne/decoding/time_frequency.py @@ -3,14 +3,16 @@ # Copyright the MNE-Python contributors. import numpy as np -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator +from sklearn.utils.validation import check_is_fitted from ..time_frequency.tfr import _compute_tfr -from ..utils import _check_option, fill_doc, verbose +from ..utils import _check_option, fill_doc +from .transformer import MNETransformerMixin @fill_doc -class TimeFrequency(TransformerMixin, BaseEstimator): +class TimeFrequency(MNETransformerMixin, BaseEstimator): """Time frequency transformer. Time-frequency transform of times series along the last axis. @@ -59,7 +61,6 @@ class TimeFrequency(TransformerMixin, BaseEstimator): mne.time_frequency.tfr_multitaper """ - @verbose def __init__( self, freqs, @@ -74,9 +75,6 @@ def __init__( verbose=None, ): """Init TimeFrequency transformer.""" - # Check non-average output - output = _check_option("output", output, ["complex", "power", "phase"]) - self.freqs = freqs self.sfreq = sfreq self.method = method @@ -89,6 +87,16 @@ def __init__( self.n_jobs = n_jobs self.verbose = verbose + def __sklearn_tags__(self): + """Return sklearn tags.""" + out = super().__sklearn_tags__() + from sklearn.utils import TransformerTags + + if out.transformer_tags is None: + out.transformer_tags = TransformerTags() + out.transformer_tags.preserves_dtype = [] # real->complex + return out + def fit_transform(self, X, y=None): """Time-frequency transform of times series along the last axis. @@ -123,6 +131,10 @@ def fit(self, X, y=None): # noqa: D401 self : object Return self. """ + # Check non-average output + _check_option("output", self.output, ["complex", "power", "phase"]) + self._check_data(X, y=y, fit=True) + self.fitted_ = True return self def transform(self, X): @@ -130,16 +142,18 @@ def transform(self, X): Parameters ---------- - X : array, shape (n_samples, n_channels, n_times) + X : array, shape (n_samples, [n_channels, ]n_times) The training data samples. The channel dimension can be zero- or 1-dimensional. Returns ------- - Xt : array, shape (n_samples, n_channels, n_freqs, n_times) + Xt : array, shape (n_samples, [n_channels, ]n_freqs, n_times) The time-frequency transform of the data, where n_channels can be zero- or 1-dimensional. """ + X = self._check_data(X, atleast_3d=False) + check_is_fitted(self, "fitted_") # Ensure 3-dimensional X shape = X.shape[1:-1] if not shape: diff --git a/mne/decoding/transformer.py b/mne/decoding/transformer.py index e475cd22161..6d0c83f42ab 100644 --- a/mne/decoding/transformer.py +++ b/mne/decoding/transformer.py @@ -3,19 +3,72 @@ # Copyright the MNE-Python contributors. import numpy as np -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator, TransformerMixin, check_array, clone +from sklearn.preprocessing import RobustScaler, StandardScaler +from sklearn.utils import check_X_y +from sklearn.utils.validation import check_is_fitted, validate_data from .._fiff.pick import ( _pick_data_channels, _picks_by_type, _picks_to_idx, pick_info, - pick_types, ) from ..cov import _check_scalings_user +from ..epochs import BaseEpochs from ..filter import filter_data from ..time_frequency import psd_array_multitaper -from ..utils import _check_option, _validate_type, fill_doc, verbose +from ..utils import _check_option, _validate_type, fill_doc + + +class MNETransformerMixin(TransformerMixin): + """TransformerMixin plus some helpers.""" + + def _check_data( + self, + epochs_data, + *, + y=None, + atleast_3d=True, + fit=False, + return_y=False, + multi_output=False, + check_n_features=True, + ): + # Sklearn calls asarray under the hood which works, but elsewhere they check for + # __len__ then look at the size of obj[0]... which is an epoch of shape (1, ...) + # rather than what they expect (shape (...)). So we explicitly get the NumPy + # array to make everyone happy. + if isinstance(epochs_data, BaseEpochs): + epochs_data = epochs_data.get_data(copy=False) + kwargs = dict(dtype=np.float64, allow_nd=True, order="C", force_writeable=True) + if hasattr(self, "n_features_in_") and check_n_features: + if y is None: + epochs_data = validate_data( + self, + epochs_data, + **kwargs, + reset=fit, + ) + else: + epochs_data, y = validate_data( + self, + epochs_data, + y, + **kwargs, + reset=fit, + ) + elif y is None: + epochs_data = check_array(epochs_data, **kwargs) + else: + epochs_data, y = check_X_y( + X=epochs_data, y=y, multi_output=multi_output, **kwargs + ) + if fit: + self.n_features_in_ = epochs_data.shape[1] + if atleast_3d: + epochs_data = np.atleast_3d(epochs_data) + return (epochs_data, y) if return_y else epochs_data class _ConstantScaler: @@ -55,8 +108,9 @@ def fit_transform(self, X, y=None): def _sklearn_reshape_apply(func, return_result, X, *args, **kwargs): """Reshape epochs and apply function.""" - if not isinstance(X, np.ndarray): - raise ValueError(f"data should be an np.ndarray, got {type(X)}.") + _validate_type(X, np.ndarray, "X") + if X.size == 0: + return X.copy() if return_result else None orig_shape = X.shape X = np.reshape(X.transpose(0, 2, 1), (-1, orig_shape[1])) X = func(X, *args, **kwargs) @@ -67,7 +121,7 @@ def _sklearn_reshape_apply(func, return_result, X, *args, **kwargs): @fill_doc -class Scaler(TransformerMixin, BaseEstimator): +class Scaler(MNETransformerMixin, BaseEstimator): """Standardize channel data. This class scales data for each channel. It differs from scikit-learn @@ -109,31 +163,6 @@ def __init__(self, info=None, scalings=None, with_mean=True, with_std=True): self.with_std = with_std self.scalings = scalings - if not (scalings is None or isinstance(scalings, dict | str)): - raise ValueError( - f"scalings type should be dict, str, or None, got {type(scalings)}" - ) - if isinstance(scalings, str): - _check_option("scalings", scalings, ["mean", "median"]) - if scalings is None or isinstance(scalings, dict): - if info is None: - raise ValueError( - f'Need to specify "info" if scalings is {type(scalings)}' - ) - self._scaler = _ConstantScaler(info, scalings, self.with_std) - elif scalings == "mean": - from sklearn.preprocessing import StandardScaler - - self._scaler = StandardScaler( - with_mean=self.with_mean, with_std=self.with_std - ) - else: # scalings == 'median': - from sklearn.preprocessing import RobustScaler - - self._scaler = RobustScaler( - with_centering=self.with_mean, with_scaling=self.with_std - ) - def fit(self, epochs_data, y=None): """Standardize data across channels. @@ -149,11 +178,30 @@ def fit(self, epochs_data, y=None): self : instance of Scaler The modified instance. """ - _validate_type(epochs_data, np.ndarray, "epochs_data") - if epochs_data.ndim == 2: - epochs_data = epochs_data[..., np.newaxis] + epochs_data = self._check_data(epochs_data, y=y, fit=True, multi_output=True) assert epochs_data.ndim == 3, epochs_data.shape - _sklearn_reshape_apply(self._scaler.fit, False, epochs_data, y=y) + + _validate_type(self.scalings, (dict, str, type(None)), "scalings") + if isinstance(self.scalings, str): + _check_option( + "scalings", self.scalings, ["mean", "median"], extra="when str" + ) + if self.scalings is None or isinstance(self.scalings, dict): + if self.info is None: + raise ValueError( + f'Need to specify "info" if scalings is {type(self.scalings)}' + ) + self.scaler_ = _ConstantScaler(self.info, self.scalings, self.with_std) + elif self.scalings == "mean": + self.scaler_ = StandardScaler( + with_mean=self.with_mean, with_std=self.with_std + ) + else: # scalings == 'median': + self.scaler_ = RobustScaler( + with_centering=self.with_mean, with_scaling=self.with_std + ) + + _sklearn_reshape_apply(self.scaler_.fit, False, epochs_data, y=y) return self def transform(self, epochs_data): @@ -174,13 +222,14 @@ def transform(self, epochs_data): This function makes a copy of the data before the operations and the memory usage may be large with big data. """ - _validate_type(epochs_data, np.ndarray, "epochs_data") + check_is_fitted(self, "scaler_") + epochs_data = self._check_data(epochs_data, atleast_3d=False) if epochs_data.ndim == 2: # can happen with SlidingEstimator if self.info is not None: assert len(self.info["ch_names"]) == epochs_data.shape[1] epochs_data = epochs_data[..., np.newaxis] assert epochs_data.ndim == 3, epochs_data.shape - return _sklearn_reshape_apply(self._scaler.transform, True, epochs_data) + return _sklearn_reshape_apply(self.scaler_.transform, True, epochs_data) def fit_transform(self, epochs_data, y=None): """Fit to data, then transform it. @@ -226,19 +275,20 @@ def inverse_transform(self, epochs_data): This function makes a copy of the data before the operations and the memory usage may be large with big data. """ + epochs_data = self._check_data(epochs_data, atleast_3d=False) squeeze = False # Can happen with CSP if epochs_data.ndim == 2: squeeze = True epochs_data = epochs_data[..., np.newaxis] assert epochs_data.ndim == 3, epochs_data.shape - out = _sklearn_reshape_apply(self._scaler.inverse_transform, True, epochs_data) + out = _sklearn_reshape_apply(self.scaler_.inverse_transform, True, epochs_data) if squeeze: out = out[..., 0] return out -class Vectorizer(TransformerMixin, BaseEstimator): +class Vectorizer(MNETransformerMixin, BaseEstimator): """Transform n-dimensional array into 2D array of n_samples by n_features. This class reshapes an n-dimensional array into an n_samples * n_features @@ -275,7 +325,7 @@ def fit(self, X, y=None): self : instance of Vectorizer Return the modified instance. """ - X = np.asarray(X) + X = self._check_data(X, y=y, atleast_3d=False, fit=True, check_n_features=False) self.features_shape_ = X.shape[1:] return self @@ -295,7 +345,7 @@ def transform(self, X): X : array, shape (n_samples, n_features) The transformed data. """ - X = np.asarray(X) + X = self._check_data(X, atleast_3d=False) if X.shape[1:] != self.features_shape_: raise ValueError("Shape of X used in fit and transform must be same") return X.reshape(len(X), -1) @@ -334,7 +384,7 @@ def inverse_transform(self, X): The data transformed into shape as used in fit. The first dimension is of length n_samples. """ - X = np.asarray(X) + X = self._check_data(X, atleast_3d=False, check_n_features=False) if X.ndim not in (2, 3): raise ValueError( f"X should be of 2 or 3 dimensions but has shape {X.shape}" @@ -343,7 +393,7 @@ def inverse_transform(self, X): @fill_doc -class PSDEstimator(TransformerMixin, BaseEstimator): +class PSDEstimator(MNETransformerMixin, BaseEstimator): """Compute power spectral density (PSD) using a multi-taper method. Parameters @@ -365,7 +415,6 @@ class PSDEstimator(TransformerMixin, BaseEstimator): n_jobs : int Number of parallel jobs to use (only used if adaptive=True). %(normalization)s - %(verbose)s See Also -------- @@ -375,7 +424,6 @@ class PSDEstimator(TransformerMixin, BaseEstimator): mne.Evoked.compute_psd """ - @verbose def __init__( self, sfreq=2 * np.pi, @@ -386,8 +434,6 @@ def __init__( low_bias=True, n_jobs=None, normalization="length", - *, - verbose=None, ): self.sfreq = sfreq self.fmin = fmin @@ -398,7 +444,7 @@ def __init__( self.n_jobs = n_jobs self.normalization = normalization - def fit(self, epochs_data, y): + def fit(self, epochs_data, y=None): """Compute power spectral density (PSD) using a multi-taper method. Parameters @@ -413,11 +459,8 @@ def fit(self, epochs_data, y): self : instance of PSDEstimator The modified instance. """ - if not isinstance(epochs_data, np.ndarray): - raise ValueError( - f"epochs_data should be of type ndarray (got {type(epochs_data)})." - ) - + self._check_data(epochs_data, y=y, fit=True) + self.fitted_ = True # sklearn compliance return self def transform(self, epochs_data): @@ -433,10 +476,7 @@ def transform(self, epochs_data): psd : array, shape (n_signals, n_freqs) or (n_freqs,) The computed PSD. """ - if not isinstance(epochs_data, np.ndarray): - raise ValueError( - f"epochs_data should be of type ndarray (got {type(epochs_data)})." - ) + epochs_data = self._check_data(epochs_data) psd, _ = psd_array_multitaper( epochs_data, sfreq=self.sfreq, @@ -452,7 +492,7 @@ def transform(self, epochs_data): @fill_doc -class FilterEstimator(TransformerMixin, BaseEstimator): +class FilterEstimator(MNETransformerMixin, BaseEstimator): """Estimator to filter RtEpochs. Applies a zero-phase low-pass, high-pass, band-pass, or band-stop @@ -488,7 +528,6 @@ class FilterEstimator(TransformerMixin, BaseEstimator): See mne.filter.construct_iir_filter for details. If iir_params is None and method="iir", 4th order Butterworth will be used. %(fir_design)s - %(verbose)s See Also -------- @@ -514,13 +553,11 @@ def __init__( method="fir", iir_params=None, fir_design="firwin", - *, - verbose=None, ): self.info = info self.l_freq = l_freq self.h_freq = h_freq - self.picks = _picks_to_idx(info, picks) + self.picks = picks self.filter_length = filter_length self.l_trans_bandwidth = l_trans_bandwidth self.h_trans_bandwidth = h_trans_bandwidth @@ -544,24 +581,11 @@ def fit(self, epochs_data, y): self : instance of FilterEstimator The modified instance. """ - if not isinstance(epochs_data, np.ndarray): - raise ValueError( - f"epochs_data should be of type ndarray (got {type(epochs_data)})." - ) - - if self.picks is None: - self.picks = pick_types( - self.info, meg=True, eeg=True, ref_meg=False, exclude=[] - ) + self.picks_ = _picks_to_idx(self.info, self.picks) + self._check_data(epochs_data, y=y, fit=True) if self.l_freq == 0: self.l_freq = None - if self.h_freq is not None and self.h_freq > (self.info["sfreq"] / 2.0): - self.h_freq = None - if self.l_freq is not None and not isinstance(self.l_freq, float): - self.l_freq = float(self.l_freq) - if self.h_freq is not None and not isinstance(self.h_freq, float): - self.h_freq = float(self.h_freq) if self.info["lowpass"] is None or ( self.h_freq is not None @@ -594,17 +618,12 @@ def transform(self, epochs_data): X : array, shape (n_epochs, n_channels, n_times) The data after filtering. """ - if not isinstance(epochs_data, np.ndarray): - raise ValueError( - f"epochs_data should be of type ndarray (got {type(epochs_data)})." - ) - epochs_data = np.atleast_3d(epochs_data) return filter_data( - epochs_data, + self._check_data(epochs_data), self.info["sfreq"], self.l_freq, self.h_freq, - self.picks, + self.picks_, self.filter_length, self.l_trans_bandwidth, self.h_trans_bandwidth, @@ -617,7 +636,7 @@ def transform(self, epochs_data): ) -class UnsupervisedSpatialFilter(TransformerMixin, BaseEstimator): +class UnsupervisedSpatialFilter(MNETransformerMixin, BaseEstimator): """Use unsupervised spatial filtering across time and samples. Parameters @@ -630,19 +649,6 @@ class UnsupervisedSpatialFilter(TransformerMixin, BaseEstimator): """ def __init__(self, estimator, average=False): - # XXX: Use _check_estimator #3381 - for attr in ("fit", "transform", "fit_transform"): - if not hasattr(estimator, attr): - raise ValueError( - "estimator must be a scikit-learn " - f"transformer, missing {attr} method" - ) - - if not isinstance(average, bool): - raise ValueError( - f"average parameter must be of bool type, got {type(bool)} instead" - ) - self.estimator = estimator self.average = average @@ -661,13 +667,25 @@ def fit(self, X, y=None): self : instance of UnsupervisedSpatialFilter Return the modified instance. """ + # sklearn.utils.estimator_checks.check_estimator(self.estimator) is probably + # too strict for us, given that we don't fully adhere yet, so just check attrs + for attr in ("fit", "transform", "fit_transform"): + if not hasattr(self.estimator, attr): + raise ValueError( + "estimator must be a scikit-learn " + f"transformer, missing {attr} method" + ) + _validate_type(self.average, bool, "average") + X = self._check_data(X, y=y, fit=True) if self.average: X = np.mean(X, axis=0).T else: n_epochs, n_channels, n_times = X.shape # trial as time samples X = np.transpose(X, (1, 0, 2)).reshape((n_channels, n_epochs * n_times)).T - self.estimator.fit(X) + + self.estimator_ = clone(self.estimator) + self.estimator_.fit(X) return self def fit_transform(self, X, y=None): @@ -700,6 +718,8 @@ def transform(self, X): X : array, shape (n_epochs, n_channels, n_times) The transformed data. """ + check_is_fitted(self.estimator_) + X = self._check_data(X) return self._apply_method(X, "transform") def inverse_transform(self, X): @@ -735,7 +755,7 @@ def _apply_method(self, X, method): X = np.transpose(X, [1, 0, 2]) X = np.reshape(X, [n_channels, n_epochs * n_times]).T # apply method - method = getattr(self.estimator, method) + method = getattr(self.estimator_, method) X = method(X) # put it back to n_epochs, n_dimensions X = np.reshape(X.T, [-1, n_epochs, n_times]).transpose([1, 0, 2]) @@ -743,7 +763,7 @@ def _apply_method(self, X, method): @fill_doc -class TemporalFilter(TransformerMixin, BaseEstimator): +class TemporalFilter(MNETransformerMixin, BaseEstimator): """Estimator to filter data array along the last dimension. Applies a zero-phase low-pass, high-pass, band-pass, or band-stop @@ -817,7 +837,6 @@ class TemporalFilter(TransformerMixin, BaseEstimator): attenuation using fewer samples than "firwin2". .. versionadded:: 0.15 - %(verbose)s See Also -------- @@ -826,7 +845,6 @@ class TemporalFilter(TransformerMixin, BaseEstimator): mne.filter.filter_data """ - @verbose def __init__( self, l_freq=None, @@ -840,8 +858,6 @@ def __init__( iir_params=None, fir_window="hamming", fir_design="firwin", - *, - verbose=None, ): self.l_freq = l_freq self.h_freq = h_freq @@ -855,17 +871,12 @@ def __init__( self.fir_window = fir_window self.fir_design = fir_design - if not isinstance(self.n_jobs, int) and self.n_jobs == "cuda": - raise ValueError( - f'n_jobs must be int or "cuda", got {type(self.n_jobs)} instead.' - ) - def fit(self, X, y=None): """Do nothing (for scikit-learn compatibility purposes). Parameters ---------- - X : array, shape (n_epochs, n_channels, n_times) or or shape (n_channels, n_times) + X : array, shape ([n_epochs, ]n_channels, n_times) The data to be filtered over the last dimension. The channels dimension can be zero when passing a 2D array. y : None @@ -875,7 +886,9 @@ def fit(self, X, y=None): ------- self : instance of TemporalFilter The modified instance. - """ # noqa: E501 + """ + self.fitted_ = True # sklearn compliance + self._check_data(X, y=y, atleast_3d=False, fit=True) return self def transform(self, X): @@ -883,7 +896,7 @@ def transform(self, X): Parameters ---------- - X : array, shape (n_epochs, n_channels, n_times) or shape (n_channels, n_times) + X : array, shape ([n_epochs, ]n_channels, n_times) The data to be filtered over the last dimension. The channels dimension can be zero when passing a 2D array. @@ -892,6 +905,7 @@ def transform(self, X): X : array The data after filtering. """ # noqa: E501 + X = self._check_data(X, atleast_3d=False) X = np.atleast_2d(X) if X.ndim > 3: diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index 98705e838c2..1c1a3baf238 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -63,7 +63,14 @@ def dpss_windows(N, half_nbw, Kmax, *, sym=True, norm=None, low_bias=True): ---------- .. footbibliography:: """ - dpss, eigvals = sp_dpss(N, half_nbw, Kmax, sym=sym, norm=norm, return_ratios=True) + # TODO VERSION can be removed with SciPy 1.16 is min, + # workaround for https://github.com/scipy/scipy/pull/22344 + if N <= 1: + dpss, eigvals = np.ones((1, 1)), np.ones(1) + else: + dpss, eigvals = sp_dpss( + N, half_nbw, Kmax, sym=sym, norm=norm, return_ratios=True + ) if low_bias: idx = eigvals > 0.9 if not idx.any(): diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index fc60802f61b..f4a01e87895 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -562,7 +562,8 @@ def _compute_tfr( if len(Ws[0][0]) > epoch_data.shape[2]: raise ValueError( "At least one of the wavelets is longer than the " - "signal. Use a longer signal or shorter wavelets." + f"signal ({len(Ws[0][0])} > {epoch_data.shape[2]} samples). " + "Use a longer signal or shorter wavelets." ) # Initialize output diff --git a/mne/utils/numerics.py b/mne/utils/numerics.py index 5029e8fbeca..11ba0ecb487 100644 --- a/mne/utils/numerics.py +++ b/mne/utils/numerics.py @@ -35,7 +35,7 @@ check_random_state, ) from .docs import fill_doc -from .misc import _empty_hash +from .misc import _empty_hash, _pl def split_list(v, n, idx=False): @@ -479,7 +479,8 @@ def _time_mask( extra = "" if include_tmax else "when include_tmax=False " raise ValueError( f"No samples remain when using tmin={orig_tmin} and tmax={orig_tmax} " - f"{extra}(original time bounds are [{times[0]}, {times[-1]}])" + f"{extra}(original time bounds are [{times[0]}, {times[-1]}] containing " + f"{len(times)} sample{_pl(times)})" ) return mask diff --git a/tools/vulture_allowlist.py b/tools/vulture_allowlist.py index 24bcd9af64a..9d0e215ee80 100644 --- a/tools/vulture_allowlist.py +++ b/tools/vulture_allowlist.py @@ -41,6 +41,8 @@ # Decoding _._more_tags +_.multi_class +_.preserves_dtype deep # Backward compat or rarely used