Skip to content

Commit

Permalink
BUG: Improve sklearn compliance (#13065)
Browse files Browse the repository at this point in the history
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel McCloy <[email protected]>
  • Loading branch information
3 people authored Jan 22, 2025
1 parent 99e9858 commit 5f2b7f1
Show file tree
Hide file tree
Showing 21 changed files with 571 additions and 349 deletions.
7 changes: 7 additions & 0 deletions doc/changes/devel/13065.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Improved sklearn class compatibility and compliance. As a result, some parameters passed to class initializers are now stored during ``fit`` as attributes with a trailing underscore appended to their name, such as:

- :class:`mne.decoding.FilterEstimator` parameter ``picks`` passed to the initializer is set as ``est.picks_``
- :class:`mne.decoding.UnsupervisedSpatialFilter` parameter ``estimator`` passed to the initializer is set as ``est.estimator_``

Unused ``verbose`` class parameters (which had no effect) were also removed from :class:`~mne.decoding.PSDEstimator`, :class:`~mne.decoding.TemporalFilter`, and :class:`~mne.decoding.FilterEstimator`.
Changes by `Eric Larson`_.
2 changes: 1 addition & 1 deletion examples/decoding/linear_model_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@

# Extract and plot spatial filters and spatial patterns
for name, coef in (("patterns", model.patterns_), ("filters", model.filters_)):
# We fitted the linear model onto Z-scored data. To make the filters
# We fit the linear model on Z-scored data. To make the filters
# interpretable, we must reverse this normalization step
coef = scaler.inverse_transform([coef])[0]

Expand Down
4 changes: 2 additions & 2 deletions mne/cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,7 @@ def _compute_rank_raw_array(
from .io import RawArray

return _compute_rank(
RawArray(data, info, copy=None, verbose=_verbose_safe_false()),
RawArray(data, info, copy="auto", verbose=_verbose_safe_false()),
rank,
scalings,
info,
Expand Down Expand Up @@ -1405,7 +1405,7 @@ def _compute_covariance_auto(
# project back
cov = np.dot(eigvec.T, np.dot(cov, eigvec))
# undo bias
cov *= data.shape[0] / (data.shape[0] - 1)
cov *= data.shape[0] / max(data.shape[0] - 1, 1)
# undo scaling
_undo_scaling_cov(cov, picks_list, scalings)
method_ = method[ei]
Expand Down
11 changes: 8 additions & 3 deletions mne/decoding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import check_scoring
from sklearn.model_selection import KFold, StratifiedKFold, check_cv
from sklearn.utils import check_array, indexable
from sklearn.utils import check_array, check_X_y, indexable

from ..parallel import parallel_func
from ..utils import _pl, logger, verbose, warn
Expand Down Expand Up @@ -76,9 +76,9 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator):
)

def __init__(self, model=None):
# TODO: We need to set this to get our tag checking to work properly
if model is None:
model = LogisticRegression(solver="liblinear")

self.model = model

def __sklearn_tags__(self):
Expand Down Expand Up @@ -122,7 +122,11 @@ def fit(self, X, y, **fit_params):
self : instance of LinearModel
Returns the modified instance.
"""
X = check_array(X, input_name="X")
if y is not None:
X = check_array(X)
else:
X, y = check_X_y(X, y)
self.n_features_in_ = X.shape[1]
if y is not None:
y = check_array(y, dtype=None, ensure_2d=False, input_name="y")
if y.ndim > 2:
Expand All @@ -133,6 +137,7 @@ def fit(self, X, y, **fit_params):

# fit the Model
self.model.fit(X, y, **fit_params)
self.model_ = self.model # for better sklearn compat

# Computes patterns using Haufe's trick: A = Cov_X . W . Precision_Y

Expand Down
99 changes: 43 additions & 56 deletions mne/decoding/csp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import numpy as np
from scipy.linalg import eigh
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_is_fitted

from .._fiff.meas_info import create_info
from ..cov import _compute_rank_raw_array, _regularized_covariance, _smart_eigh
Expand All @@ -19,10 +20,11 @@
fill_doc,
pinv,
)
from .transformer import MNETransformerMixin


@fill_doc
class CSP(TransformerMixin, BaseEstimator):
class CSP(MNETransformerMixin, BaseEstimator):
"""M/EEG signal decomposition using the Common Spatial Patterns (CSP).
This class can be used as a supervised decomposition to estimate spatial
Expand Down Expand Up @@ -112,49 +114,44 @@ def __init__(
component_order="mutual_info",
):
# Init default CSP
if not isinstance(n_components, int):
raise ValueError("n_components must be an integer.")
self.n_components = n_components
self.rank = rank
self.reg = reg

# Init default cov_est
if not (cov_est == "concat" or cov_est == "epoch"):
raise ValueError("unknown covariance estimation method")
self.cov_est = cov_est

# Init default transform_into
self.transform_into = _check_option(
"transform_into", transform_into, ["average_power", "csp_space"]
)

# Init default log
if transform_into == "average_power":
if log is not None and not isinstance(log, bool):
raise ValueError(
'log must be a boolean if transform_into == "average_power".'
)
else:
if log is not None:
raise ValueError('log must be a None if transform_into == "csp_space".')
self.transform_into = transform_into
self.log = log

_validate_type(norm_trace, bool, "norm_trace")
self.norm_trace = norm_trace
self.cov_method_params = cov_method_params
self.component_order = _check_option(
"component_order", component_order, ("mutual_info", "alternate")
self.component_order = component_order

def _validate_params(self, *, y):
_validate_type(self.n_components, int, "n_components")
if hasattr(self, "cov_est"):
_validate_type(self.cov_est, str, "cov_est")
_check_option("cov_est", self.cov_est, ("concat", "epoch"))
if hasattr(self, "norm_trace"):
_validate_type(self.norm_trace, bool, "norm_trace")
_check_option(
"transform_into", self.transform_into, ["average_power", "csp_space"]
)

def _check_Xy(self, X, y=None):
"""Check input data."""
if not isinstance(X, np.ndarray):
raise ValueError(f"X should be of type ndarray (got {type(X)}).")
if y is not None:
if len(X) != len(y) or len(y) < 1:
raise ValueError("X and y must have the same length.")
if X.ndim < 3:
raise ValueError("X must have at least 3 dimensions.")
if self.transform_into == "average_power":
_validate_type(
self.log,
(bool, None),
"log",
extra="when transform_into is 'average_power'",
)
else:
_validate_type(
self.log, None, "log", extra="when transform_into is 'csp_space'"
)
_check_option(
"component_order", self.component_order, ("mutual_info", "alternate")
)
self.classes_ = np.unique(y)
n_classes = len(self.classes_)
if n_classes < 2:
raise ValueError(f"n_classes must be >= 2, but got {n_classes} class")

def fit(self, X, y):
"""Estimate the CSP decomposition on epochs.
Expand All @@ -171,12 +168,9 @@ def fit(self, X, y):
self : instance of CSP
Returns the modified instance.
"""
self._check_Xy(X, y)

self._classes = np.unique(y)
n_classes = len(self._classes)
if n_classes < 2:
raise ValueError("n_classes must be >= 2.")
X, y = self._check_data(X, y=y, fit=True, return_y=True)
self._validate_params(y=y)
n_classes = len(self.classes_)
if n_classes > 2 and self.component_order == "alternate":
raise ValueError(
"component_order='alternate' requires two classes, but data contains "
Expand Down Expand Up @@ -225,13 +219,8 @@ def transform(self, X):
If self.transform_into == 'csp_space' then returns the data in CSP
space and shape is (n_epochs, n_components, n_times).
"""
if not isinstance(X, np.ndarray):
raise ValueError(f"X should be of type ndarray (got {type(X)}).")
if self.filters_ is None:
raise RuntimeError(
"No filters available. Please first fit CSP decomposition."
)

check_is_fitted(self, "filters_")
X = self._check_data(X)
pick_filters = self.filters_[: self.n_components]
X = np.asarray([np.dot(pick_filters, epoch) for epoch in X])

Expand Down Expand Up @@ -577,7 +566,7 @@ def _compute_covariance_matrices(self, X, y):

covs = []
sample_weights = []
for ci, this_class in enumerate(self._classes):
for ci, this_class in enumerate(self.classes_):
cov, weight = cov_estimator(
X[y == this_class],
cov_kind=f"class={this_class}",
Expand Down Expand Up @@ -689,7 +678,7 @@ def _normalize_eigenvectors(self, eigen_vectors, covs, sample_weights):
def _order_components(
self, covs, sample_weights, eigen_vectors, eigen_values, component_order
):
n_classes = len(self._classes)
n_classes = len(self.classes_)
if component_order == "mutual_info" and n_classes > 2:
mutual_info = self._compute_mutual_info(covs, sample_weights, eigen_vectors)
ix = np.argsort(mutual_info)[::-1]
Expand Down Expand Up @@ -889,10 +878,8 @@ def fit(self, X, y):
self : instance of SPoC
Returns the modified instance.
"""
self._check_Xy(X, y)

if len(np.unique(y)) < 2:
raise ValueError("y must have at least two distinct values.")
X, y = self._check_data(X, y=y, fit=True, return_y=True)
self._validate_params(y=y)

# The following code is directly copied from pyRiemann

Expand Down
25 changes: 19 additions & 6 deletions mne/decoding/ems.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
from collections import Counter

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.base import BaseEstimator

from .._fiff.pick import _picks_to_idx, pick_info, pick_types
from ..parallel import parallel_func
from ..utils import logger, verbose
from .base import _set_cv
from .transformer import MNETransformerMixin


class EMS(TransformerMixin, BaseEstimator):
class EMS(MNETransformerMixin, BaseEstimator):
"""Transformer to compute event-matched spatial filters.
This version of EMS :footcite:`SchurgerEtAl2013` operates on the entire
Expand All @@ -37,6 +38,16 @@ class EMS(TransformerMixin, BaseEstimator):
.. footbibliography::
"""

def __sklearn_tags__(self):
"""Return sklearn tags."""
from sklearn.utils import ClassifierTags

tags = super().__sklearn_tags__()
if tags.classifier_tags is None:
tags.classifier_tags = ClassifierTags()
tags.classifier_tags.multi_class = False
return tags

def __repr__(self): # noqa: D105
if hasattr(self, "filters_"):
return (
Expand Down Expand Up @@ -64,11 +75,12 @@ def fit(self, X, y):
self : instance of EMS
Returns self.
"""
classes = np.unique(y)
if len(classes) != 2:
X, y = self._check_data(X, y=y, fit=True, return_y=True)
classes, y = np.unique(y, return_inverse=True)
if len(classes) > 2:
raise ValueError("EMS only works for binary classification.")
self.classes_ = classes
filters = X[y == classes[0]].mean(0) - X[y == classes[1]].mean(0)
filters = X[y == 0].mean(0) - X[y == 1].mean(0)
filters /= np.linalg.norm(filters, axis=0)[None, :]
self.filters_ = filters
return self
Expand All @@ -86,13 +98,14 @@ def transform(self, X):
X : array, shape (n_epochs, n_times)
The input data transformed by the spatial filters.
"""
X = self._check_data(X)
Xt = np.sum(X * self.filters_, axis=1)
return Xt


@verbose
def compute_ems(
epochs, conditions=None, picks=None, n_jobs=None, cv=None, verbose=None
epochs, conditions=None, picks=None, n_jobs=None, cv=None, *, verbose=None
):
"""Compute event-matched spatial filter on epochs.
Expand Down
Loading

0 comments on commit 5f2b7f1

Please sign in to comment.