Skip to content

Commit 88f7b63

Browse files
committed
FIX: More [circle full]
1 parent e6270f9 commit 88f7b63

File tree

5 files changed

+15
-2
lines changed

5 files changed

+15
-2
lines changed

mne/decoding/tests/test_base.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,12 @@ def test_get_coef_multiclass(n_features, n_targets):
275275
"""Test get_coef on multiclass problems."""
276276
# Check patterns with more than 1 regressor
277277
X, Y, A = _make_data(n_samples=30000, n_features=n_features, n_targets=n_targets)
278-
lm = LinearModel(LinearRegression()).fit(X, Y)
278+
lm = LinearModel(LinearRegression())
279+
assert not hasattr(lm, "model_")
280+
lm.fit(X, Y)
281+
# TODO: modifying non-underscored `model` is a sklearn no-no, maybe should be a
282+
# metaestimator?
283+
assert lm.model is lm.model_
279284
assert_array_equal(lm.filters_.shape, lm.patterns_.shape)
280285
if n_targets == 1:
281286
want_shape = (n_features,)

mne/decoding/tests/test_time_frequency.py

+2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ def test_timefrequency_basic():
3535
clone(tf)
3636

3737
# Fit
38+
assert not hasattr(tf, "fitted_")
3839
tf.fit(X, None)
40+
assert tf.fitted_
3941

4042
# Transform
4143
tf = TimeFrequency(freqs, sfreq=100)

mne/decoding/tests/test_transformer.py

+4
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,9 @@ def test_psdestimator():
186186
epochs_data = epochs.get_data(copy=False)
187187
psd = PSDEstimator(2 * np.pi, 0, np.inf)
188188
y = epochs.events[:, -1]
189+
assert not hasattr(psd, "fitted_")
189190
X = psd.fit_transform(epochs_data, y)
191+
assert psd.fitted_
190192

191193
assert X.shape[0] == epochs_data.shape[0]
192194
assert_array_equal(psd.fit(epochs_data, y).transform(epochs_data), X)
@@ -308,7 +310,9 @@ def test_bad_triage():
308310
filt = TemporalFilter(l_freq=8, h_freq=60, sfreq=160.0)
309311
# Used to fail with "ValueError: Effective band-stop frequency (135.0) is
310312
# too high (maximum based on Nyquist is 80.0)"
313+
assert not hasattr(filt, "fitted_")
311314
filt.fit_transform(np.zeros((1, 1, 481)))
315+
assert filt.fitted_
312316

313317

314318
@pytest.mark.filterwarnings("ignore:.*filter_length.*")

mne/decoding/transformer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ def fit(self, epochs_data, y=None):
461461
self.fitted_ = True # sklearn compliance
462462
return self
463463

464-
def transform(self, epochs_data, y=None):
464+
def transform(self, epochs_data):
465465
"""Compute power spectral density (PSD) using a multi-taper method.
466466
467467
Parameters

tools/vulture_allowlist.py

+2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141

4242
# Decoding
4343
_._more_tags
44+
_.multi_class
45+
_.preserves_dtype
4446
deep
4547

4648
# Backward compat or rarely used

0 commit comments

Comments
 (0)