Skip to content

Commit 1918de3

Browse files
committed
FIX: Argh
1 parent 88f7b63 commit 1918de3

File tree

3 files changed

+22
-5
lines changed

3 files changed

+22
-5
lines changed

examples/decoding/linear_model_patterns.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979

8080
# Extract and plot spatial filters and spatial patterns
8181
for name, coef in (("patterns", model.patterns_), ("filters", model.filters_)):
82-
# We fitted the linear model onto Z-scored data. To make the filters
82+
# We fit the linear model on Z-scored data. To make the filters
8383
# interpretable, we must reverse this normalization step
8484
coef = scaler.inverse_transform([coef])[0]
8585

mne/decoding/tests/test_transformer.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,14 @@
1717

1818
from sklearn.decomposition import PCA
1919
from sklearn.kernel_ridge import KernelRidge
20+
from sklearn.pipeline import make_pipeline
21+
from sklearn.preprocessing import StandardScaler
2022
from sklearn.utils.estimator_checks import parametrize_with_checks
2123

22-
from mne import Epochs, create_info, io, pick_types, read_events
24+
from mne import Epochs, EpochsArray, create_info, io, pick_types, read_events
2325
from mne.decoding import (
2426
FilterEstimator,
27+
LinearModel,
2528
PSDEstimator,
2629
Scaler,
2730
TemporalFilter,
@@ -218,9 +221,16 @@ def test_vectorizer():
218221
assert_equal(vect.fit_transform(data[1:]).shape, (149, 108))
219222

220223
# check if raised errors are working correctly
221-
vect.fit(np.random.rand(105, 12, 3))
222-
pytest.raises(ValueError, vect.transform, np.random.rand(105, 12, 3, 1))
223-
pytest.raises(ValueError, vect.inverse_transform, np.random.rand(102, 12, 12))
224+
X = np.random.default_rng(0).standard_normal((105, 12, 3))
225+
y = np.arange(X.shape[0]) % 2
226+
pytest.raises(ValueError, vect.transform, X[..., np.newaxis])
227+
pytest.raises(ValueError, vect.inverse_transform, X[:, :-1])
228+
229+
# And that pipelines work properly
230+
X_arr = EpochsArray(X, create_info(12, 1000.0, "eeg"))
231+
vect.fit(X_arr)
232+
clf = make_pipeline(Vectorizer(), StandardScaler(), LinearModel())
233+
clf.fit(X_arr, y)
224234

225235

226236
def test_unsupervised_spatial_filter():

mne/decoding/transformer.py

+7
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
pick_info,
1616
)
1717
from ..cov import _check_scalings_user
18+
from ..epochs import BaseEpochs
1819
from ..filter import filter_data
1920
from ..time_frequency import psd_array_multitaper
2021
from ..utils import _check_option, _validate_type, fill_doc
@@ -34,6 +35,12 @@ def _check_data(
3435
multi_output=False,
3536
check_n_features=True,
3637
):
38+
# Sklearn calls asarray under the hood which works, but elsewhere they check for
39+
# __len__ then look at the size of obj[0]... which is an epoch of shape (1, ...)
40+
# rather than what they expect (shape (...)). So we explicitly get the NumPy
41+
# array to make everyone happy.
42+
if isinstance(epochs_data, BaseEpochs):
43+
epochs_data = epochs_data.get_data(copy=False)
3744
kwargs = dict(dtype=np.float64, allow_nd=True, order="C", force_writeable=True)
3845
if hasattr(self, "n_features_in_") and check_n_features:
3946
if y is None:

0 commit comments

Comments
 (0)