From 1d2635f84a55785c3531cfe4027eda3820a7fb31 Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Mon, 13 Jan 2025 20:00:00 +0000 Subject: [PATCH 01/24] [ENH] Add option to store and return TFR taper weights (#12910) Co-authored-by: Daniel McCloy Co-authored-by: Eric Larson --- doc/changes/devel/12910.newfeature.rst | 1 + mne/time_frequency/multitaper.py | 10 + mne/time_frequency/tests/test_tfr.py | 221 +++++++++++++-- mne/time_frequency/tfr.py | 362 +++++++++++++++++-------- mne/utils/docs.py | 12 + mne/utils/numerics.py | 3 + mne/viz/tests/test_topomap.py | 25 +- mne/viz/topomap.py | 14 +- 8 files changed, 507 insertions(+), 141 deletions(-) create mode 100644 doc/changes/devel/12910.newfeature.rst diff --git a/doc/changes/devel/12910.newfeature.rst b/doc/changes/devel/12910.newfeature.rst new file mode 100644 index 00000000000..95605c11017 --- /dev/null +++ b/doc/changes/devel/12910.newfeature.rst @@ -0,0 +1 @@ +Added the option to return taper weights from :func:`mne.time_frequency.tfr_array_multitaper`, and taper weights are now stored in the :class:`mne.time_frequency.BaseTFR` objects, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index 73a3308685d..98705e838c2 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -471,6 +471,7 @@ def tfr_array_multitaper( output="complex", n_jobs=None, *, + return_weights=False, verbose=None, ): """Compute Time-Frequency Representation (TFR) using DPSS tapers. @@ -504,6 +505,11 @@ def tfr_array_multitaper( coherence across trials. %(n_jobs)s The parallelization is implemented across channels. + return_weights : bool, default False + If True, return the taper weights. Only applies if ``output='complex'`` or + ``'phase'``. + + .. versionadded:: 1.10.0 %(verbose)s Returns @@ -520,6 +526,9 @@ def tfr_array_multitaper( If ``output`` is ``'avg_power_itc'``, the real values in ``out`` contain the average power and the imaginary values contain the inter-trial coherence: :math:`out = power_{avg} + i * ITC`. + weights : array of shape (n_tapers, n_freqs) + The taper weights. Only returned if ``output='complex'`` or ``'phase'`` and + ``return_weights=True``. 
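As a usage sketch of the option documented above (random data and arbitrary parameter values; the shapes follow the Returns entries, and only ``return_weights=True`` is new in this patch)::

    import numpy as np
    from mne.time_frequency import tfr_array_multitaper

    rng = np.random.default_rng(0)
    epoch_data = rng.standard_normal((5, 3, 1000))  # (n_epochs, n_channels, n_times)
    freqs = np.arange(5.0, 25.0, 3.0)
    coefs, weights = tfr_array_multitaper(
        epoch_data,
        sfreq=1000.0,
        freqs=freqs,
        n_cycles=freqs / 2.0,
        output="complex",
        return_weights=True,  # the option added here
    )
    # coefs: (n_epochs, n_channels, n_tapers, n_freqs, n_times)
    # weights: (n_tapers, n_freqs)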
See Also -------- @@ -550,6 +559,7 @@ def tfr_array_multitaper( use_fft=use_fft, decim=decim, output=output, + return_weights=return_weights, n_jobs=n_jobs, verbose=verbose, ) diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index e68ea9e6e18..6fa3a833be2 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -432,17 +432,21 @@ def test_tfr_morlet(): def test_dpsswavelet(): """Test DPSS tapers.""" freqs = np.arange(5, 25, 3) - Ws = _make_dpss( - 1000, freqs=freqs, n_cycles=freqs / 2.0, time_bandwidth=4.0, zero_mean=True + Ws, weights = _make_dpss( + 1000, + freqs=freqs, + n_cycles=freqs / 2.0, + time_bandwidth=4.0, + zero_mean=True, + return_weights=True, ) - assert len(Ws) == 3 # 3 tapers expected + assert np.shape(Ws)[:2] == (3, len(freqs)) # 3 tapers expected + assert np.shape(Ws)[:2] == np.shape(weights) # weights of shape (tapers, freqs) # Check that zero mean is true assert np.abs(np.mean(np.real(Ws[0][0]))) < 1e-5 - assert len(Ws[0]) == len(freqs) # As many wavelets as asked for - @pytest.mark.slowtest def test_tfr_multitaper(): @@ -664,6 +668,17 @@ def test_tfr_io(inst, average_tfr, request, tmp_path): with tfr.info._unlock(): tfr.info["meas_date"] = want assert tfr_loaded == tfr + # test with taper dimension and weights + n_tapers = 3 # anything >= 1 should do + weights = np.ones((n_tapers, tfr.shape[2])) # tapers x freqs + state = tfr.__getstate__() + state["data"] = np.repeat(np.expand_dims(tfr.data, 2), n_tapers, axis=2) # add dim + state["weights"] = weights # add weights + state["dims"] = ("epoch", "channel", "taper", "freq", "time") # update dims + tfr = EpochsTFR(inst=state) + tfr.save(fname, overwrite=True) + tfr_loaded = read_tfrs(fname) + assert tfr_loaded == tfr # test overwrite with pytest.raises(OSError, match="Destination file exists."): tfr.save(fname, overwrite=False) @@ -722,17 +737,31 @@ def test_average_tfr_init(full_evoked): AverageTFR(inst=full_evoked, method="stockwell", freqs=freqs_linspace) -def test_epochstfr_init_errors(epochs_tfr): - """Test __init__ for EpochsTFR.""" - state = epochs_tfr.__getstate__() - with pytest.raises(ValueError, match="EpochsTFR data should be 4D, got 3"): - EpochsTFR(inst=state | dict(data=epochs_tfr.data[..., 0])) +@pytest.mark.parametrize("inst", ("raw_tfr", "epochs_tfr", "average_tfr")) +def test_tfr_init_errors(inst, request, average_tfr): + """Test __init__ for {Raw,Epochs,Average}TFR.""" + # Load data + inst = _get_inst(inst, request, average_tfr=average_tfr) + state = inst.__getstate__() + # Prepare for TFRArray object instantiation + inst_name = inst.__class__.__name__ + class_mapping = dict(RawTFR=RawTFR, EpochsTFR=EpochsTFR, AverageTFR=AverageTFR) + ndims_mapping = dict( + RawTFR=("3D or 4D"), EpochsTFR=("4D or 5D"), AverageTFR=("3D or 4D") + ) + TFR = class_mapping[inst_name] + allowed_ndims = ndims_mapping[inst_name] + # Check errors caught + with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"): + TFR(inst=state | dict(data=inst.data[..., 0])) + with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"): + TFR(inst=state | dict(data=np.expand_dims(inst.data, axis=(0, 1)))) with pytest.raises(ValueError, match="Channel axis of data .* doesn't match info"): - EpochsTFR(inst=state | dict(data=epochs_tfr.data[:, :-1])) + TFR(inst=state | dict(data=inst.data[..., :-1, :, :])) with pytest.raises(ValueError, match="Time axis of data.*doesn't match times attr"): - EpochsTFR(inst=state | 
dict(times=epochs_tfr.times[:-1])) + TFR(inst=state | dict(times=inst.times[:-1])) with pytest.raises(ValueError, match="Frequency axis of.*doesn't match freqs attr"): - EpochsTFR(inst=state | dict(freqs=epochs_tfr.freqs[:-1])) + TFR(inst=state | dict(freqs=inst.freqs[:-1])) @pytest.mark.parametrize( @@ -830,6 +859,25 @@ def test_plot(): plt.close("all") +@pytest.mark.parametrize("output", ("complex", "phase")) +def test_plot_multitaper_complex_phase(output): + """Test TFR plotting of data with a taper dimension.""" + # Create example data with a taper dimension + n_chans, n_tapers, n_freqs, n_times = (3, 4, 2, 3) + data = np.random.rand(n_chans, n_tapers, n_freqs, n_times) + if output == "complex": + data = data + np.random.rand(*data.shape) * 1j # add imaginary data + times = np.arange(n_times) + freqs = np.arange(n_freqs) + weights = np.random.rand(n_tapers, n_freqs) + info = mne.create_info(n_chans, 1000.0, "eeg") + tfr = AverageTFRArray( + info=info, data=data, times=times, freqs=freqs, weights=weights + ) + # Check that plotting works + tfr.plot() + + @pytest.mark.parametrize( "timefreqs,title,combine", ( @@ -1154,6 +1202,15 @@ def test_averaging_epochsTFR(): ): power.average(method=np.mean) + # Check it doesn't run for taper spectra + tapered = epochs.compute_tfr( + method="multitaper", freqs=freqs, n_cycles=n_cycles, output="complex" + ) + with pytest.raises( + NotImplementedError, match=r"Averaging multitaper tapers .* is not supported." + ): + tapered.average() + def test_averaging_freqsandtimes_epochsTFR(): """Test that EpochsTFR averaging freqs methods work.""" @@ -1258,12 +1315,15 @@ def test_to_data_frame(): ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"] n_picks = len(ch_names) ch_types = ["eeg"] * n_picks + n_tapers = 2 n_freqs = 5 n_times = 6 - data = np.random.rand(n_epos, n_picks, n_freqs, n_times) - times = np.arange(6) + data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times) + times = np.arange(n_times) srate = 1000.0 - freqs = np.arange(5) + freqs = np.arange(n_freqs) + tapers = np.arange(n_tapers) + weights = np.ones((n_tapers, n_freqs)) events = np.zeros((n_epos, 3), dtype=int) events[:, 0] = np.arange(n_epos) events[:, 2] = np.arange(5, 5 + n_epos) @@ -1276,6 +1336,7 @@ def test_to_data_frame(): freqs=freqs, events=events, event_id=event_id, + weights=weights, ) # test index checking with pytest.raises(ValueError, match="options. 
Valid index options are"): @@ -1287,10 +1348,21 @@ def test_to_data_frame(): # test wide format df_wide = tfr.to_data_frame() assert all(np.isin(tfr.ch_names, df_wide.columns)) - assert all(np.isin(["time", "condition", "freq", "epoch"], df_wide.columns)) + assert all( + np.isin(["time", "condition", "freq", "epoch", "taper"], df_wide.columns) + ) # test long format df_long = tfr.to_data_frame(long_format=True) - expected = ("condition", "epoch", "freq", "time", "channel", "ch_type", "value") + expected = ( + "condition", + "epoch", + "freq", + "time", + "channel", + "ch_type", + "value", + "taper", + ) assert set(expected) == set(df_long.columns) assert set(tfr.ch_names) == set(df_long["channel"]) assert len(df_long) == tfr.data.size @@ -1298,21 +1370,29 @@ def test_to_data_frame(): df_long = tfr.to_data_frame(long_format=True, index=["freq"]) del df_wide, df_long # test whether data is in correct shape - df = tfr.to_data_frame(index=["condition", "epoch", "freq", "time"]) + df = tfr.to_data_frame(index=["condition", "epoch", "taper", "freq", "time"]) data = tfr.data assert_array_equal(df.values[:, 0], data[:, 0, :, :].reshape(1, -1).squeeze()) # compare arbitrary observation: assert ( - df.loc[("he", slice(None), freqs[1], times[2]), ch_names[3]].iat[0] - == data[1, 3, 1, 2] + df.loc[("he", slice(None), tapers[1], freqs[1], times[2]), ch_names[3]].iat[0] + == data[1, 3, 1, 1, 2] ) # Check also for AverageTFR: + # (remove taper dimension before averaging) + state = tfr.__getstate__() + state["data"] = state["data"][:, :, 0] + state["dims"] = ("epoch", "channel", "freq", "time") + state["weights"] = None + tfr = EpochsTFR(inst=state) tfr = tfr.average() with pytest.raises(ValueError, match="options. Valid index options are"): tfr.to_data_frame(index=["epoch", "condition"]) with pytest.raises(ValueError, match='"epoch" is not a valid option'): tfr.to_data_frame(index="epoch") + with pytest.raises(ValueError, match='"taper" is not a valid option'): + tfr.to_data_frame(index="taper") with pytest.raises(TypeError, match="index must be `None` or a string "): tfr.to_data_frame(index=np.arange(400)) # test wide format @@ -1348,11 +1428,13 @@ def test_to_data_frame_index(index): ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"] n_picks = len(ch_names) ch_types = ["eeg"] * n_picks + n_tapers = 2 n_freqs = 5 n_times = 6 - data = np.random.rand(n_epos, n_picks, n_freqs, n_times) - times = np.arange(6) - freqs = np.arange(5) + data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times) + times = np.arange(n_times) + freqs = np.arange(n_freqs) + weights = np.ones((n_tapers, n_freqs)) events = np.zeros((n_epos, 3), dtype=int) events[:, 0] = np.arange(n_epos) events[:, 2] = np.arange(5, 8) @@ -1365,6 +1447,7 @@ def test_to_data_frame_index(index): freqs=freqs, events=events, event_id=event_id, + weights=weights, ) df = tfr.to_data_frame(picks=[0, 2, 3], index=index) # test index order/hierarchy preservation @@ -1372,7 +1455,7 @@ def test_to_data_frame_index(index): index = [index] assert list(df.index.names) == index # test that non-indexed data were present as columns - non_index = list(set(["condition", "time", "freq", "epoch"]) - set(index)) + non_index = list(set(["condition", "time", "freq", "taper", "epoch"]) - set(index)) if len(non_index): assert all(np.isin(non_index, df.columns)) @@ -1538,7 +1621,8 @@ def test_epochs_compute_tfr_stockwell(epochs, freqs, return_itc): def test_epochs_compute_tfr_multitaper_complex_phase(epochs, output): """Test 
Epochs.compute_tfr(output="complex"/"phase").""" tfr = epochs.compute_tfr("multitaper", freqs_linspace, output=output) - assert len(tfr.shape) == 5 + assert len(tfr.shape) == 5 # epoch x channel x taper x freq x time + assert tfr.weights.shape == tfr.shape[2:4] # check weights and coeffs shapes match @pytest.mark.parametrize("copy", (False, True)) @@ -1550,6 +1634,42 @@ def test_epochstfr_iter_evoked(epochs_tfr, copy): assert avgs[0].comment == str(epochs_tfr.events[0, -1]) +@pytest.mark.parametrize("obj_type", ("raw", "epochs", "evoked")) +def test_tfrarray_tapered_spectra(obj_type): + """Test {Raw,Epochs,Average}TFRArray instantiation with tapered spectra.""" + # Create example data with a taper dimension + n_epochs, n_chans, n_tapers, n_freqs, n_times = (5, 3, 4, 2, 6) + data_shape = (n_chans, n_tapers, n_freqs, n_times) + if obj_type == "epochs": + data_shape = (n_epochs,) + data_shape + data = np.random.rand(*data_shape) + times = np.arange(n_times) + freqs = np.arange(n_freqs) + weights = np.random.rand(n_tapers, n_freqs) + info = mne.create_info(n_chans, 1000.0, "eeg") + # Prepare for TFRArray object instantiation + defaults = dict(info=info, data=data, times=times, freqs=freqs) + class_mapping = dict(raw=RawTFRArray, epochs=EpochsTFRArray, evoked=AverageTFRArray) + TFRArray = class_mapping[obj_type] + # Check TFRArray instantiation runs with good data + TFRArray(**defaults, weights=weights) + # Check taper dimension but no weights caught + with pytest.raises( + ValueError, match="Taper dimension in data, but no weights found." + ): + TFRArray(**defaults) + # Check mismatching n_taper in weights caught + with pytest.raises( + ValueError, match=r"Taper axis .* doesn't match weights attribute" + ): + TFRArray(**defaults, weights=weights[:-1]) + # Check mismatching n_freq in weights caught + with pytest.raises( + ValueError, match=r"Frequency axis .* doesn't match weights attribute" + ): + TFRArray(**defaults, weights=weights[:, :-1]) + + def test_tfr_proj(epochs): """Test `compute_tfr(proj=True)`.""" epochs.compute_tfr(method="morlet", freqs=freqs_linspace, proj=True) @@ -1731,3 +1851,52 @@ def test_tfr_plot_topomap(inst, ch_type, full_average_tfr, request): assert re.match( rf"Average over \d{{1,3}} {ch_type} channels\.", popup_fig.axes[0].get_title() ) + + +@pytest.mark.parametrize("output", ("complex", "phase")) +def test_tfr_topo_plotting_multitaper_complex_phase(output, evoked): + """Test plot_joint/topo/topomap() for data with a taper dimension.""" + # Compute TFR with taper dimension + tfr = evoked.compute_tfr( + method="multitaper", freqs=freqs_linspace, n_cycles=4, output=output + ) + # Check that plotting works + tfr.plot_joint(topomap_args=dict(res=8, contours=0, sensors=False)) # for speed + tfr.plot_topo() + tfr.plot_topomap() + + +def test_combine_tfr_error_catch(average_tfr): + """Test combine_tfr() catches errors.""" + # check unrecognised weights string caught + with pytest.raises(ValueError, match='Weights must be .* "nave" or "equal"'): + combine_tfr([average_tfr, average_tfr], weights="foo") + # check bad weights size caught + with pytest.raises(ValueError, match="Weights must be the same size as all_tfr"): + combine_tfr([average_tfr, average_tfr], weights=[1, 1, 1]) + # check different channel names caught + state = average_tfr.__getstate__() + new_info = average_tfr.info.copy() + average_tfr_bad = AverageTFR( + inst=state | dict(info=new_info.rename_channels({new_info.ch_names[0]: "foo"})) + ) + with pytest.raises(AssertionError, match=".* do not contain the same 
channels"): + combine_tfr([average_tfr, average_tfr_bad]) + # check different times caught + average_tfr_bad = AverageTFR(inst=state | dict(times=average_tfr.times + 1)) + with pytest.raises( + AssertionError, match=".* do not contain the same time instants" + ): + combine_tfr([average_tfr, average_tfr_bad]) + # check taper dim caught + n_tapers = 3 # anything >= 1 should do + weights = np.ones((n_tapers, average_tfr.shape[1])) # tapers x freqs + state["data"] = np.repeat(np.expand_dims(average_tfr.data, 1), n_tapers, axis=1) + state["weights"] = weights + state["dims"] = ("channel", "taper", "freq", "time") + average_tfr_taper = AverageTFR(inst=state) + with pytest.raises( + NotImplementedError, + match="Aggregating multitaper tapers across TFR datasets is not supported.", + ): + combine_tfr([average_tfr_taper, average_tfr_taper]) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 12d45d5d572..918fea1a33f 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -264,8 +264,11 @@ def _make_dpss( ------- Ws : list of array The wavelets time series. + Cs : list of array + The concentration weights. Only returned if return_weights=True. """ Ws = list() + Cs = list() freqs = np.array(freqs) if np.any(freqs <= 0): @@ -281,6 +284,7 @@ def _make_dpss( for m in range(n_taps): Wm = list() + Cm = list() for k, f in enumerate(freqs): if len(n_cycles) != 1: this_n_cycles = n_cycles[k] @@ -302,12 +306,15 @@ def _make_dpss( real_offset = Wk.mean() Wk -= real_offset Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel()) + Ck = np.sqrt(conc[m]) Wm.append(Wk) + Cm.append(Ck) Ws.append(Wm) + Cs.append(Cm) if return_weights: - return Ws, conc + return Ws, Cs return Ws @@ -428,6 +435,7 @@ def _compute_tfr( use_fft=True, decim=1, output="complex", + return_weights=False, n_jobs=None, *, verbose=None, @@ -478,7 +486,9 @@ def _compute_tfr( * 'itc' : inter-trial coherence. * 'avg_power_itc' : average of single trial power and inter-trial coherence across trials. - + return_weights : bool, default False + Whether to return the taper weights. Only applies if method='multitaper' and + output='complex' or 'phase'. %(n_jobs)s The number of epochs to process at the same time. The parallelization is implemented across channels. @@ -495,6 +505,9 @@ def _compute_tfr( n_tapers, n_freqs, n_times)``. If output is ``'avg_power_itc'``, the real values in the ``output`` contain average power' and the imaginary values contain the ITC: ``out = avg_power + i * itc``. + weights : array of shape (n_tapers, n_freqs) + The taper weights. Only returned if method='multitaper', output='complex' or + 'phase', and return_weights=True. 
""" # Check data epoch_data = np.asarray(epoch_data) @@ -516,6 +529,9 @@ def _compute_tfr( decim, output, ) + return_weights = ( + return_weights and method == "multitaper" and output in ["complex", "phase"] + ) decim = _ensure_slice(decim) if (freqs > sfreq / 2.0).any(): @@ -531,13 +547,18 @@ def _compute_tfr( Ws = [W] # to have same dimensionality as the 'multitaper' case elif method == "multitaper": - Ws = _make_dpss( + out = _make_dpss( sfreq, freqs, n_cycles=n_cycles, time_bandwidth=time_bandwidth, zero_mean=zero_mean, + return_weights=return_weights, ) + if return_weights: + Ws, weights = out + else: + Ws = out # Check wavelets if len(Ws[0][0]) > epoch_data.shape[2]: @@ -561,6 +582,8 @@ def _compute_tfr( out = np.empty((n_chans, n_freqs, n_times), dtype) elif output in ["complex", "phase"] and method == "multitaper": out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype) + if return_weights: + weights = np.array(weights) else: out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype) @@ -585,6 +608,9 @@ def _compute_tfr( out = out.transpose(2, 0, 1, 3, 4) else: out = out.transpose(1, 0, 2, 3) + + if return_weights: + return out, weights return out @@ -1200,6 +1226,9 @@ def __init__( method_kw.setdefault("output", "power") self._freqs = np.asarray(freqs, dtype=np.float64) del freqs + # always store weights for per-taper outputs + if method == "multitaper" and method_kw.get("output") in ["complex", "phase"]: + method_kw["return_weights"] = True # check validity of kwargs manually to save compute time if any are invalid tfr_funcs = dict( morlet=tfr_array_morlet, @@ -1221,6 +1250,7 @@ def __init__( self._method = method self._inst_type = type(inst) self._baseline = None + self._weights = None self.preload = True # needed for __getitem__, never False for TFRs # self._dims may also get updated by child classes self._dims = ["channel", "freq", "time"] @@ -1379,6 +1409,7 @@ def __getstate__(self): info=self.info, baseline=self._baseline, decim=self._decim, + weights=self._weights, ) def __setstate__(self, state): @@ -1389,7 +1420,6 @@ def __setstate__(self, state): defaults = dict( method="unknown", - dims=("epoch", "channel", "freq", "time")[-state["data"].ndim :], baseline=None, decim=1, data_type="TFR", @@ -1407,12 +1437,13 @@ def __setstate__(self, state): self._decim = defaults["decim"] self.preload = True self._set_times(self._raw_times) + self._weights = state.get("weights") # objs saved before #12910 won't have # Handle instance type. 
Prior to gh-11282, Raw was not a possibility so if # `inst_type_str` is missing it must be Epochs or Evoked unknown_class = Epochs if "epoch" in self._dims else Evoked inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Unknown=unknown_class) self._inst_type = inst_types[defaults["inst_type_str"]] - # sanity check data/freqs/times/info agreement + # sanity check data/freqs/times/info/weights agreement self._check_state() def __repr__(self): @@ -1465,18 +1496,29 @@ def _check_compatibility(self, other): raise RuntimeError(msg.format(problem, extra)) def _check_state(self): - """Check data/freqs/times/info agreement during __setstate__.""" + """Check data/freqs/times/info/weights agreement during __setstate__.""" msg = "{} axis of data ({}) doesn't match {} attribute ({})" n_chan_info = len(self.info["chs"]) n_chan = self._data.shape[self._dims.index("channel")] n_freq = self._data.shape[self._dims.index("freq")] n_time = self._data.shape[self._dims.index("time")] + n_taper = ( + self._data.shape[self._dims.index("taper")] + if "taper" in self._dims + else None + ) + if n_taper is not None and self._weights is None: + raise ValueError("Taper dimension in data, but no weights found.") if n_chan_info != n_chan: msg = msg.format("Channel", n_chan, "info", n_chan_info) elif n_freq != len(self.freqs): msg = msg.format("Frequency", n_freq, "freqs", self.freqs.size) elif n_time != len(self.times): msg = msg.format("Time", n_time, "times", self.times.size) + elif n_taper is not None and n_taper != self._weights.shape[0]: + msg = msg.format("Taper", n_taper, "weights", self._weights.shape[0]) + elif n_taper is not None and n_freq != self._weights.shape[1]: + msg = msg.format("Frequency", n_freq, "weights", self._weights.shape[1]) else: return raise ValueError(msg) @@ -1513,6 +1555,10 @@ def _compute_tfr(self, data, n_jobs, verbose): if self.method == "stockwell": self._data, self._itc, freqs = result assert np.array_equal(self._freqs, freqs) + elif self.method == "multitaper" and self._tfr_func.keywords.get( + "output", "" + ) in ["complex", "phase"]: + self._data, self._weights = result elif self._tfr_func.keywords.get("output", "").endswith("_itc"): self._data, self._itc = result.real, result.imag else: @@ -1613,6 +1659,7 @@ def _onselect( fmax=fmax, baseline=baseline, mode=mode, + taper_weights=self.weights, verbose=verbose, ) # average over times and freqs @@ -1691,6 +1738,11 @@ def times(self): """The time points present in the data (in seconds).""" return self._times_readonly + @property + def weights(self): + """The weights used for each taper in the time-frequency estimates.""" + return self._weights + @fill_doc def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True): """Crop data to a given time interval in place. @@ -1785,6 +1837,7 @@ def get_data( tmax=None, return_times=False, return_freqs=False, + return_tapers=False, ): """Get time-frequency data in NumPy array format. @@ -1800,6 +1853,10 @@ def get_data( return_freqs : bool Whether to return the frequency bin values for the requested frequency range. Default is ``False``. + return_tapers : bool + Whether to return the taper numbers. Default is ``False``. + + .. versionadded:: 1.10.0 Returns ------- @@ -1811,6 +1868,9 @@ def get_data( freqs : array The frequency values for the requested data range. Only returned if ``return_freqs`` is ``True``. + tapers : array | None + The taper numbers. Only returned if ``return_tapers`` is ``True``. Will be + ``None`` if a taper dimension is not present in the data. 
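A usage sketch for the new flag, assuming ``epochs`` is an :class:`mne.Epochs` instance and ``np`` is NumPy (the method and frequency values are illustrative only)::

    tfr = epochs.compute_tfr(
        method="multitaper", freqs=np.arange(8.0, 13.0), output="complex"
    )
    data, times, freqs, tapers = tfr.get_data(
        return_times=True, return_freqs=True, return_tapers=True
    )
    # data: (n_epochs, n_channels, n_tapers, n_freqs, n_times)
    # tapers == np.arange(n_tapers); it would be None without a taper dimension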
Notes ----- @@ -1848,7 +1908,13 @@ def get_data( if return_freqs: freqs = self._freqs[fmin_idx:fmax_idx] out.append(freqs) - if not return_times and not return_freqs: + if return_tapers: + if "taper" in self._dims: + tapers = np.arange(self.shape[self._dims.index("taper")]) + else: + tapers = None + out.append(tapers) + if not return_times and not return_freqs and not return_tapers: return out[0] return tuple(out) @@ -1960,6 +2026,7 @@ def plot( baseline=baseline, mode=mode, dB=dB, + taper_weights=self.weights, verbose=verbose, ) # shape @@ -1970,6 +2037,9 @@ def plot( want_shape[ch_axis] = len(idx_picks) if combine is None else 1 want_shape[freq_axis] = len(freqs) # in case there was fmin/fmax cropping want_shape[time_axis] = len(times) # in case there was tmin/tmax cropping + want_shape = [ + n for dim, n in zip(self._dims, want_shape) if dim != "taper" + ] # tapers must be aggregated over by now want_shape = tuple(want_shape) # combine combine_was_none = combine is None @@ -2313,6 +2383,7 @@ def plot_joint( fmax=_fmax, baseline=baseline, mode=mode, + taper_weights=self.weights, verbose=verbose, ) _data = _data.mean(axis=(-1, -2)) # avg over times and freqs @@ -2461,23 +2532,23 @@ def plot_topo( info, data = _prepare_picks(info, data, picks, axis=0) del picks - # TODO this is the only remaining call to _preproc_tfr; should be refactored - # (to use _prep_data_for_plot?) - data, times, freqs, vmin, vmax = _preproc_tfr( + # baseline, crop, convert complex to power, aggregate tapers, and dB scaling + data, times, freqs = _prep_data_for_plot( data, times, freqs, - tmin, - tmax, - fmin, - fmax, - mode, - baseline, - vmin, - vmax, - dB, - info["sfreq"], + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + baseline=baseline, + mode=mode, + dB=dB, + taper_weights=self.weights, + verbose=verbose, ) + # get vlims + vmin, vmax = _setup_vmin_vmax(data, vmin, vmax) if layout is None: from mne import find_layout @@ -2624,21 +2695,21 @@ def to_data_frame( ): """Export data in tabular structure as a pandas DataFrame. - Channels are converted to columns in the DataFrame. By default, - additional columns ``'time'``, ``'freq'``, ``'epoch'``, and - ``'condition'`` (epoch event description) are added, unless ``index`` - is not ``None`` (in which case the columns specified in ``index`` will - be used to form the DataFrame's index instead). ``'epoch'``, and - ``'condition'`` are not supported for ``AverageTFR``. + Channels are converted to columns in the DataFrame. By default, additional + columns ``'time'``, ``'freq'``, ``'taper'``, ``'epoch'``, and ``'condition'`` + (epoch event description) are added, unless ``index`` is not ``None`` (in which + case the columns specified in ``index`` will be used to form the DataFrame's + index instead). ``'epoch'``, and ``'condition'`` are not supported for + ``AverageTFR``. ``'taper'`` is only supported when a taper dimensions is + present, such as for complex or phase multitaper data. Parameters ---------- %(picks_all)s %(index_df_epo)s - Valid string values are ``'time'``, ``'freq'``, ``'epoch'``, and - ``'condition'`` for ``EpochsTFR`` and ``'time'`` and ``'freq'`` - for ``AverageTFR``. - Defaults to ``None``. + Valid string values are ``'time'``, ``'freq'``, ``'taper'``, ``'epoch'``, + and ``'condition'`` for ``EpochsTFR`` and ``'time'``, ``'freq'``, and + ``'taper'`` for ``AverageTFR``. Defaults to ``None``. 
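Continuing the hypothetical complex multitaper ``tfr`` from the ``get_data`` sketch above, the new ``'taper'`` level behaves like the existing index levels (this mirrors the updated ``to_data_frame`` tests)::

    df = tfr.to_data_frame(index=["condition", "epoch", "taper", "freq", "time"])
    # channels become columns; one row per (epoch, taper, freq, time) combination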
%(long_format_df_epo)s %(time_format_df)s @@ -2651,42 +2722,58 @@ def to_data_frame( """ # check pandas once here, instead of in each private utils function pd = _check_pandas_installed() # noqa + # triage for Epoch-derived or unaggregated spectra + from_epo = isinstance(self, EpochsTFR) + unagg_mt = "taper" in self._dims # arg checking valid_index_args = ["time", "freq"] - if isinstance(self, EpochsTFR): + if from_epo: valid_index_args.extend(["epoch", "condition"]) + if unagg_mt: + valid_index_args.append("taper") valid_time_formats = ["ms", "timedelta"] index = _check_pandas_index_arguments(index, valid_index_args) time_format = _check_time_format(time_format, valid_time_formats) # get data picks = _picks_to_idx(self.info, picks, "all", exclude=()) - data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True) - axis = self._dims.index("channel") - if not isinstance(self, EpochsTFR): + data, times, freqs, tapers = self.get_data( + picks, return_times=True, return_freqs=True, return_tapers=True + ) + ch_axis = self._dims.index("channel") + if not from_epo: data = data[np.newaxis] # add singleton "epochs" axis - axis += 1 - n_epochs, n_picks, n_freqs, n_times = data.shape - # reshape to (epochs*freqs*times) x signals - data = np.moveaxis(data, axis, -1) - data = data.reshape(n_epochs * n_freqs * n_times, n_picks) + ch_axis += 1 + if not unagg_mt: + data = np.expand_dims(data, -3) # add singleton "tapers" axis + n_epochs, n_picks, n_tapers, n_freqs, n_times = data.shape + # reshape to (epochs*tapers*freqs*times) x signals + data = np.moveaxis(data, ch_axis, -1) + data = data.reshape(n_epochs * n_tapers * n_freqs * n_times, n_picks) # prepare extra columns / multiindex mindex = list() + default_index = list() times = _convert_times(times, time_format, self.info["meas_date"]) - times = np.tile(times, n_epochs * n_freqs) - freqs = np.tile(np.repeat(freqs, n_times), n_epochs) + times = np.tile(times, n_epochs * n_freqs * n_tapers) + freqs = np.tile(np.repeat(freqs, n_times), n_epochs * n_tapers) mindex.append(("time", times)) mindex.append(("freq", freqs)) - if isinstance(self, EpochsTFR): - mindex.append(("epoch", np.repeat(self.selection, n_times * n_freqs))) + if from_epo: + mindex.append( + ("epoch", np.repeat(self.selection, n_times * n_freqs * n_tapers)) + ) rev_event_id = {v: k for k, v in self.event_id.items()} conditions = [rev_event_id[k] for k in self.events[:, 2]] - mindex.append(("condition", np.repeat(conditions, n_times * n_freqs))) + mindex.append( + ("condition", np.repeat(conditions, n_times * n_freqs * n_tapers)) + ) + default_index.extend(["condition", "epoch"]) + if unagg_mt: + tapers = np.repeat(np.tile(tapers, n_epochs), n_freqs * n_times) + mindex.append(("taper", tapers)) + default_index.append("taper") + default_index.extend(["freq", "time"]) assert all(len(mdx) == len(mindex[0]) for mdx in mindex[1:]) # build DataFrame - if isinstance(self, EpochsTFR): - default_index = ["condition", "epoch", "freq", "time"] - else: - default_index = ["freq", "time"] df = _build_data_frame( self, data, picks, long_format, mindex, index, default_index=default_index ) @@ -2733,6 +2820,7 @@ class AverageTFR(BaseTFR): %(nave_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -2849,6 +2937,15 @@ def __getstate__(self): def __setstate__(self, state): """Unpack AverageTFR from serialized format.""" + if state["data"].ndim not in [3, 4]: + raise ValueError( + f"RawTFR data should be 3D or 4D, got {state['data'].ndim}." 
+ ) + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("channel",) + if state["data"].ndim == 4: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") super().__setstate__(state) self._comment = state.get("comment", "") self._nave = state.get("nave", 1) @@ -2892,6 +2989,7 @@ class AverageTFRArray(AverageTFR): The number of averaged TFRs. %(comment_averagetfr_attr)s %(method_tfr_array)s + %(weights_tfr_array)s Attributes ---------- @@ -2904,6 +3002,7 @@ class AverageTFRArray(AverageTFR): %(nave_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -2914,12 +3013,22 @@ class AverageTFRArray(AverageTFR): """ def __init__( - self, info, data, times, freqs, *, nave=None, comment=None, method=None + self, + info, + data, + times, + freqs, + *, + nave=None, + comment=None, + method=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) - for name, optional in dict(nave=nave, comment=comment, method=method).items(): - if optional is not None: - state[name] = optional + optional = dict(nave=nave, comment=comment, method=method, weights=weights) + for name, value in optional.items(): + if value is not None: + state[name] = value self.__setstate__(state) @@ -2962,6 +3071,7 @@ class EpochsTFR(BaseTFR, GetEpochsMixin): %(selection_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3041,8 +3151,15 @@ def __getstate__(self): def __setstate__(self, state): """Unpack EpochsTFR from serialized format.""" - if state["data"].ndim != 4: - raise ValueError(f"EpochsTFR data should be 4D, got {state['data'].ndim}.") + if state["data"].ndim not in [4, 5]: + raise ValueError( + f"EpochsTFR data should be 4D or 5D, got {state['data'].ndim}." + ) + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("epoch", "channel") + if state["data"].ndim == 5: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") super().__setstate__(state) self._metadata = state.get("metadata", None) n_epochs = self.shape[0] @@ -3152,7 +3269,16 @@ def average(self, method="mean", *, dim="epochs", copy=False): See discussion here: https://github.com/scipy/scipy/pull/12676#issuecomment-783370228 + + Averaging is not supported for data containing a taper dimension. """ + if "taper" in self._dims: + raise NotImplementedError( + "Averaging multitaper tapers across epochs, frequencies, or times is " + "not supported. If averaging across epochs, consider averaging the " + "epochs before computing the complex/phase spectrum." 
+ ) + _check_option("dim", dim, ("epochs", "freqs", "times")) axis = self._dims.index(dim[:-1]) # self._dims entries aren't plural @@ -3524,6 +3650,7 @@ class EpochsTFRArray(EpochsTFR): %(selection)s %(drop_log)s %(metadata_epochstfr)s + %(weights_tfr_array)s Attributes ---------- @@ -3540,6 +3667,7 @@ class EpochsTFRArray(EpochsTFR): %(selection_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3562,6 +3690,7 @@ def __init__( selection=None, drop_log=None, metadata=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) optional = dict( @@ -3572,6 +3701,7 @@ def __init__( selection=selection, drop_log=drop_log, metadata=metadata, + weights=weights, ) for name, value in optional.items(): if value is not None: @@ -3614,6 +3744,7 @@ class RawTFR(BaseTFR): method : str The method used to compute the spectra (``'morlet'``, ``'multitaper'`` or ``'stockwell'``). + %(weights_tfr_attr)s See Also -------- @@ -3663,6 +3794,19 @@ def __init__( **method_kw, ) + def __setstate__(self, state): + """Unpack RawTFR from serialized format.""" + if state["data"].ndim not in [3, 4]: + raise ValueError( + f"RawTFR data should be 3D or 4D, got {state['data'].ndim}." + ) + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("channel",) + if state["data"].ndim == 4: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") + super().__setstate__(state) + def __getitem__(self, item): """Get RawTFR data. @@ -3728,6 +3872,7 @@ class RawTFRArray(RawTFR): %(times)s %(freqs_tfr_array)s %(method_tfr_array)s + %(weights_tfr_array)s Attributes ---------- @@ -3738,6 +3883,7 @@ class RawTFRArray(RawTFR): %(method_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3755,10 +3901,13 @@ def __init__( freqs, *, method=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) - if method is not None: - state["method"] = method + optional = dict(method=method, weights=weights) + for name, value in optional.items(): + if value is not None: + state[name] = value self.__setstate__(state) @@ -3786,8 +3935,16 @@ def combine_tfr(all_tfr, weights="nave"): Notes ----- + Aggregating multitaper TFR datasets with a taper dimension such as for complex or + phase data is not supported. + .. versionadded:: 0.11.0 """ + if any("taper" in tfr._dims for tfr in all_tfr): + raise NotImplementedError( + "Aggregating multitaper tapers across TFR datasets is not supported." 
+ ) + tfr = all_tfr[0].copy() if isinstance(weights, str): if weights not in ("nave", "equal"): @@ -3861,62 +4018,6 @@ def _centered(arr, newsize): return arr[tuple(myslice)] -def _preproc_tfr( - data, - times, - freqs, - tmin, - tmax, - fmin, - fmax, - mode, - baseline, - vmin, - vmax, - dB, - sfreq, - copy=None, -): - """Aux Function to prepare tfr computation.""" - if copy is None: - copy = baseline is not None - data = rescale(data, times, baseline, mode, copy=copy) - - if np.iscomplexobj(data): - # complex amplitude → real power (for plotting); if data are - # real-valued they should already be power - data = (data * data.conj()).real - - # crop time - itmin, itmax = None, None - idx = np.where(_time_mask(times, tmin, tmax, sfreq=sfreq))[0] - if tmin is not None: - itmin = idx[0] - if tmax is not None: - itmax = idx[-1] + 1 - - times = times[itmin:itmax] - - # crop freqs - ifmin, ifmax = None, None - idx = np.where(_time_mask(freqs, fmin, fmax, sfreq=sfreq))[0] - if fmin is not None: - ifmin = idx[0] - if fmax is not None: - ifmax = idx[-1] + 1 - - freqs = freqs[ifmin:ifmax] - - # crop data - data = data[:, ifmin:ifmax, itmin:itmax] - - if dB: - data = 10 * np.log10(data) - - vmin, vmax = _setup_vmin_vmax(data, vmin, vmax) - return data, times, freqs, vmin, vmax - - def _ensure_slice(decim): """Aux function checking the decim parameter.""" _validate_type(decim, ("int-like", slice), "decim") @@ -4151,6 +4252,7 @@ def _prep_data_for_plot( baseline=None, mode=None, dB=False, + taper_weights=None, verbose=None, ): # baseline @@ -4164,9 +4266,43 @@ def _prep_data_for_plot( freqs = freqs[freq_mask] # crop data data = data[..., freq_mask, :][..., time_mask] - # complex amplitude → real power; real-valued data is already power (or ITC) + # handle unaggregated multitaper (complex or phase multitaper data) + if taper_weights is not None: # assumes a taper dimension + logger.info("Aggregating multitaper estimates before plotting...") + if np.iscomplexobj(data): # complex coefficients → power + data = _tfr_from_mt(data, taper_weights) + else: # tapered phase data → weighted phase data + # channels, tapers, freqs, time + assert data.ndim == 4 + # weights as a function of (tapers, freqs) + assert taper_weights.ndim == 2 + data = (data * taper_weights[np.newaxis, :, :, np.newaxis]).mean(axis=1) + # handle remaining complex amplitude → real power if np.iscomplexobj(data): data = (data * data.conj()).real if dB: data = 10 * np.log10(data) return data, times, freqs + + +def _tfr_from_mt(x_mt, weights): + """Aggregate complex multitaper coefficients over tapers and convert to power. + + Parameters + ---------- + x_mt : array, shape (n_channels, n_tapers, n_freqs, n_times) + The complex-valued multitaper coefficients. + weights : array, shape (n_tapers, n_freqs) + The weights to use to combine the tapered estimates. + + Returns + ------- + tfr : array, shape (n_channels, n_freqs, n_times) + The time-frequency power estimates. + """ + weights = weights[np.newaxis, :, :, np.newaxis] # add singleton channel & time dims + tfr = weights * x_mt + tfr *= tfr.conj() + tfr = tfr.real.sum(axis=1) + tfr *= 2 / (weights * weights.conj()).real.sum(axis=1) + return tfr diff --git a/mne/utils/docs.py b/mne/utils/docs.py index aea0a17fd32..683704c4bc6 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -5014,6 +5014,18 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): solution. 
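For reference, ``_tfr_from_mt`` above implements the usual weighted multitaper power estimate :math:`2 \sum_k |w_k z_k|^2 / \sum_k |w_k|^2` (sum over tapers :math:`k`), the TFR analogue of ``_psd_from_mt``. A quick numerical sketch with random values and illustrative shapes::

    import numpy as np

    rng = np.random.default_rng(1)
    x_mt = rng.standard_normal((2, 3, 4, 5)) + 1j * rng.standard_normal((2, 3, 4, 5))
    weights = rng.random((3, 4))  # (n_tapers, n_freqs)

    w = weights[np.newaxis, :, :, np.newaxis]  # broadcast over channels and times
    tfr = ((w * x_mt) * (w * x_mt).conj()).real.sum(axis=1)  # sum over tapers
    tfr *= 2 / (w * w.conj()).real.sum(axis=1)
    assert tfr.shape == (2, 4, 5)  # (n_channels, n_freqs, n_times), all values >= 0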
""" +docdict["weights_tfr_array"] = """ +weights : array, shape (n_tapers, n_freqs) | None + The weights for each taper. Must be provided if ``data`` has a taper dimension, such + as for complex or phase multitaper data. + + .. versionadded:: 1.10.0 +""" +docdict["weights_tfr_attr"] = """ +weights : array, shape (n_tapers, n_freqs) | None + The weights used for each taper in the time-frequency estimates. +""" + docdict["window_psd"] = """\ window : str | float | tuple Windowing function to use. See :func:`scipy.signal.get_window`. diff --git a/mne/utils/numerics.py b/mne/utils/numerics.py index c287fb42305..4bf8d094f81 100644 --- a/mne/utils/numerics.py +++ b/mne/utils/numerics.py @@ -550,6 +550,9 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True): Notes ----- + Aggregating multitaper TFR datasets with a taper dimension such as for complex or + phase data is not supported. + .. versionadded:: 0.11.0 """ # check if all elements in the given list are evoked data diff --git a/mne/viz/tests/test_topomap.py b/mne/viz/tests/test_topomap.py index afa9341c00e..b87d0d39f89 100644 --- a/mne/viz/tests/test_topomap.py +++ b/mne/viz/tests/test_topomap.py @@ -44,7 +44,7 @@ compute_bridged_electrodes, compute_current_source_density, ) -from mne.time_frequency.tfr import AverageTFRArray +from mne.time_frequency.tfr import AverageTFR, AverageTFRArray from mne.viz import plot_evoked_topomap, plot_projs_topomap, topomap from mne.viz.tests.test_raw import _proj_status from mne.viz.topomap import ( @@ -610,6 +610,29 @@ def test_plot_tfr_topomap(): ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0 ) + # test data with taper dimension (real) + data = np.expand_dims(data, axis=1) + weights = np.random.rand(1, n_freqs) + tfr = AverageTFRArray( + info=info, + data=data, + times=times, + freqs=np.arange(n_freqs), + nave=nave, + weights=weights, + ) + tfr.plot_topomap( + ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0 + ) + # test data with taper dimension (complex) + state = tfr.__getstate__() + tfr = AverageTFR(inst=state | dict(data=data * (1 + 1j))) + tfr.plot_topomap( + ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0 + ) + # remove taper dim before proceeding + data = data[:, 0] + # test real numbers tfr = AverageTFRArray( info=info, data=data, times=times, freqs=np.arange(n_freqs), nave=nave diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index dd63a626683..d83698acbb1 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -1882,7 +1882,7 @@ def plot_tfr_topomap( tfr, ch_type, sphere=sphere ) outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) - data = tfr.data[picks, :, :] + data = tfr.data[picks] # merging grads before rescaling makes ERDs visible if merge_channels: @@ -1890,6 +1890,18 @@ def plot_tfr_topomap( data = rescale(data, tfr.times, baseline, mode, copy=True) + # handle unaggregated multitaper (complex or phase multitaper data) + if tfr.weights is not None: # assumes a taper dimension + logger.info("Aggregating multitaper estimates before plotting...") + weights = tfr.weights[np.newaxis, :, :, np.newaxis] # add channel & time dims + data = weights * data + if np.iscomplexobj(data): # complex coefficients → power + data *= data.conj() + data = data.real.sum(axis=1) + data *= 2 / (weights * weights.conj()).real.sum(axis=1) + else: # tapered phase data → weighted phase data + data = data.mean(axis=1) + # handle remaining complex amplitude → real power if np.iscomplexobj(data): data 
= np.sqrt((data * data.conj()).real) From 2abb7b220ed2580e141158499919300cfa1f6a3b Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 13 Jan 2025 17:37:42 -0500 Subject: [PATCH 02/24] BUG: Fix bug with helium anon (#13056) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- doc/changes/devel/13056.bugfix.rst | 1 + mne/_fiff/meas_info.py | 15 +++++++--- mne/_fiff/tests/test_meas_info.py | 48 +++++++++++++++++++----------- mne/_fiff/write.py | 9 +++--- mne/utils/_testing.py | 4 +-- 5 files changed, 49 insertions(+), 28 deletions(-) create mode 100644 doc/changes/devel/13056.bugfix.rst diff --git a/doc/changes/devel/13056.bugfix.rst b/doc/changes/devel/13056.bugfix.rst new file mode 100644 index 00000000000..2a7919de289 --- /dev/null +++ b/doc/changes/devel/13056.bugfix.rst @@ -0,0 +1 @@ +Fix bug with saving of anonymized data when helium info is present in measurement info, by `Eric Larson`_. diff --git a/mne/_fiff/meas_info.py b/mne/_fiff/meas_info.py index 629d9a4b0ce..ecc93591a05 100644 --- a/mne/_fiff/meas_info.py +++ b/mne/_fiff/meas_info.py @@ -2493,6 +2493,8 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): hi["meas_date"] = _ensure_meas_date_none_or_dt( tuple(int(t) for t in tag.data), ) + if "meas_date" not in hi: + hi["meas_date"] = None info["helium_info"] = hi del hi @@ -2879,7 +2881,8 @@ def write_meas_info(fid, info, data_type=None, reset_range=True): write_float(fid, FIFF.FIFF_HELIUM_LEVEL, hi["helium_level"]) if hi.get("orig_file_guid") is not None: write_string(fid, FIFF.FIFF_ORIG_FILE_GUID, hi["orig_file_guid"]) - write_int(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(hi["meas_date"])) + if hi["meas_date"] is not None: + write_int(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(hi["meas_date"])) end_block(fid, FIFF.FIFFB_HELIUM) del hi @@ -2916,8 +2919,10 @@ def write_meas_info(fid, info, data_type=None, reset_range=True): _write_proc_history(fid, info) -@fill_doc -def write_info(fname, info, data_type=None, reset_range=True): +@verbose +def write_info( + fname, info, *, data_type=None, reset_range=True, overwrite=False, verbose=None +): """Write measurement info in fif file. Parameters @@ -2931,8 +2936,10 @@ def write_info(fname, info, data_type=None, reset_range=True): raw data. reset_range : bool If True, info['chs'][k]['range'] will be set to unity. 
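A sketch of the write/read round trip with the new keyword (the file name is hypothetical; the imports use the private module patched here, and ``mne.io.read_info`` is the public reader)::

    import mne
    from mne._fiff.meas_info import read_info, write_info

    info = mne.create_info(ch_names=["EEG 001"], sfreq=1000.0, ch_types="eeg")
    write_info("info.fif", info)                  # fine while the file is absent
    write_info("info.fif", info, overwrite=True)  # required once the file exists
    assert read_info("info.fif")["sfreq"] == 1000.0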
+ %(overwrite)s + %(verbose)s """ - with start_and_end_file(fname) as fid: + with start_and_end_file(fname, overwrite=overwrite) as fid: start_block(fid, FIFF.FIFFB_MEAS) write_meas_info(fid, info, data_type, reset_range) end_block(fid, FIFF.FIFFB_MEAS) diff --git a/mne/_fiff/tests/test_meas_info.py b/mne/_fiff/tests/test_meas_info.py index 3e3c150573f..d088da2a4a2 100644 --- a/mne/_fiff/tests/test_meas_info.py +++ b/mne/_fiff/tests/test_meas_info.py @@ -306,7 +306,9 @@ def test_read_write_info(tmp_path): gantry_angle = info["gantry_angle"] meas_id = info["meas_id"] - write_info(temp_file, info) + with pytest.raises(FileExistsError, match="Destination file exists"): + write_info(temp_file, info) + write_info(temp_file, info, overwrite=True) info = read_info(temp_file) assert info["proc_history"][0]["creator"] == creator assert info["hpi_meas"][0]["creator"] == creator @@ -348,7 +350,7 @@ def test_read_write_info(tmp_path): info["meas_date"] = datetime(1800, 1, 1, 0, 0, 0, tzinfo=timezone.utc) fname = tmp_path / "test.fif" with pytest.raises(RuntimeError, match="must be between "): - write_info(fname, info) + write_info(fname, info, overwrite=True) @testing.requires_testing_data @@ -377,7 +379,7 @@ def test_io_coord_frame(tmp_path): for ch_type in ("eeg", "seeg", "ecog", "dbs", "hbo", "hbr"): info = create_info(ch_names=["Test Ch"], sfreq=1000.0, ch_types=[ch_type]) info["chs"][0]["loc"][:3] = [0.05, 0.01, -0.03] - write_info(fname, info) + write_info(fname, info, overwrite=True) info2 = read_info(fname) assert info2["chs"][0]["coord_frame"] == FIFF.FIFFV_COORD_HEAD @@ -585,7 +587,7 @@ def test_check_consistency(): info2["subject_info"] = {"height": "bad"} -def _test_anonymize_info(base_info): +def _test_anonymize_info(base_info, tmp_path): """Test that sensitive information can be anonymized.""" pytest.raises(TypeError, anonymize_info, "foo") assert isinstance(base_info, Info) @@ -692,14 +694,25 @@ def _adjust_back(e_i, dt): # exp 4 tests is a supplied daysback delta_t_3 = timedelta(days=223 + 364 * 500) + def _check_equiv(got, want, err_msg): + __tracebackhide__ = True + fname_temp = tmp_path / "test.fif" + assert_object_equal(got, want, err_msg=err_msg) + write_info(fname_temp, got, reset_range=False, overwrite=True) + got = read_info(fname_temp) + # this gets changed on write but that's expected + with got._unlock(): + got["file_id"] = want["file_id"] + assert_object_equal(got, want, err_msg=f"{err_msg} (on I/O round trip)") + new_info = anonymize_info(base_info.copy()) - assert_object_equal(new_info, exp_info, err_msg="anon mismatch") + _check_equiv(new_info, exp_info, err_msg="anon mismatch") new_info = anonymize_info(base_info.copy(), keep_his=True) - assert_object_equal(new_info, exp_info_2, err_msg="anon keep_his mismatch") + _check_equiv(new_info, exp_info_2, err_msg="anon keep_his mismatch") new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days) - assert_object_equal(new_info, exp_info_3, err_msg="anon daysback mismatch") + _check_equiv(new_info, exp_info_3, err_msg="anon daysback mismatch") with pytest.raises(RuntimeError, match="anonymize_info generated"): anonymize_info(base_info.copy(), daysback=delta_t_3.days) @@ -726,7 +739,7 @@ def _adjust_back(e_i, dt): new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days) else: new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days) - assert_object_equal( + _check_equiv( new_info, exp_info_3, err_msg="meas_date=None daysback mismatch", @@ -734,7 +747,7 @@ def _adjust_back(e_i, dt): with 
_record_warnings(): # meas_date is None new_info = anonymize_info(base_info.copy()) - assert_object_equal(new_info, exp_info_3, err_msg="meas_date=None mismatch") + _check_equiv(new_info, exp_info_3, err_msg="meas_date=None mismatch") @pytest.mark.parametrize( @@ -777,8 +790,8 @@ def _complete_info(info): height=2.0, ) info["helium_info"] = dict( - he_level_raw=12.34, - helium_level=45.67, + he_level_raw=np.float32(12.34), + helium_level=np.float32(45.67), meas_date=datetime(2024, 11, 14, 14, 8, 2, tzinfo=timezone.utc), orig_file_guid="e", ) @@ -796,14 +809,13 @@ def _complete_info(info): machid=np.ones(2, int), secs=d[0], usecs=d[1], - date=d, ), experimenter="j", max_info=dict( - max_st=[], - sss_ctc=[], - sss_cal=[], - sss_info=dict(head_pos=None, in_order=8), + max_st=dict(), + sss_ctc=dict(), + sss_cal=dict(), + sss_info=dict(in_order=8), ), date=d, ), @@ -830,8 +842,8 @@ def test_anonymize(tmp_path): # test mne.anonymize_info() events = read_events(event_name) epochs = Epochs(raw, events[:1], 2, 0.0, 0.1, baseline=None) - _test_anonymize_info(raw.info) - _test_anonymize_info(epochs.info) + _test_anonymize_info(raw.info, tmp_path) + _test_anonymize_info(epochs.info, tmp_path) # test instance methods & I/O roundtrip for inst, keep_his in zip((raw, epochs), (True, False)): diff --git a/mne/_fiff/write.py b/mne/_fiff/write.py index 1fc32f0163e..8486ca13121 100644 --- a/mne/_fiff/write.py +++ b/mne/_fiff/write.py @@ -13,7 +13,7 @@ import numpy as np from scipy.sparse import csc_array, csr_array -from ..utils import _file_like, _validate_type, logger +from ..utils import _check_fname, _file_like, _validate_type, logger from ..utils.numerics import _date_to_julian from .constants import FIFF @@ -277,7 +277,7 @@ def end_block(fid, kind): write_int(fid, FIFF.FIFF_BLOCK_END, kind) -def start_file(fname, id_=None): +def start_file(fname, id_=None, *, overwrite=True): """Open a fif file for writing and writes the compulsory header tags. Parameters @@ -294,6 +294,7 @@ def start_file(fname, id_=None): fid = fname fid.seek(0) else: + fname = _check_fname(fname, overwrite=overwrite) fname = str(fname) if op.splitext(fname)[1].lower() == ".gz": logger.debug("Writing using gzip") @@ -311,9 +312,9 @@ def start_file(fname, id_=None): @contextmanager -def start_and_end_file(fname, id_=None): +def start_and_end_file(fname, id_=None, *, overwrite=True): """Start and (if successfully written) close the file.""" - with start_file(fname, id_=id_) as fid: + with start_file(fname, id_=id_, overwrite=overwrite) as fid: yield fid end_file(fid) # we only hit this line if the yield does not err diff --git a/mne/utils/_testing.py b/mne/utils/_testing.py index 323b530a641..63e0d1036b9 100644 --- a/mne/utils/_testing.py +++ b/mne/utils/_testing.py @@ -179,9 +179,9 @@ def assert_and_remove_boundary_annot(annotations, n=1): annotations.delete(idx) -def assert_object_equal(a, b, *, err_msg="Object mismatch"): +def assert_object_equal(a, b, *, err_msg="Object mismatch", allclose=False): """Assert two objects are equal.""" - d = object_diff(a, b) + d = object_diff(a, b, allclose=allclose) assert d == "", f"{err_msg}\n{d}" From f82d3993617d2a34744eb955385448c67672d6ec Mon Sep 17 00:00:00 2001 From: "Thomas S. 
Binns" Date: Mon, 13 Jan 2025 22:43:44 +0000 Subject: [PATCH 03/24] Add `combine_spectrum()` function and allow `grand_average()` to support `Spectrum` data (#13058) Co-authored-by: Daniel McCloy --- doc/api/time_frequency.rst | 1 + doc/changes/devel/13058.newfeature.rst | 1 + mne/time_frequency/__init__.pyi | 2 + mne/time_frequency/spectrum.py | 68 +++++++++++++++++++++++ mne/time_frequency/tests/test_spectrum.py | 55 +++++++++++++++++- mne/time_frequency/tfr.py | 12 ++-- mne/utils/numerics.py | 57 +++++++++++-------- 7 files changed, 165 insertions(+), 31 deletions(-) create mode 100644 doc/changes/devel/13058.newfeature.rst diff --git a/doc/api/time_frequency.rst b/doc/api/time_frequency.rst index 8923920bdba..b66b1b6ca64 100644 --- a/doc/api/time_frequency.rst +++ b/doc/api/time_frequency.rst @@ -31,6 +31,7 @@ Functions that operate on mne-python objects: .. autosummary:: :toctree: ../generated/ + combine_spectrum csd_tfr csd_fourier csd_multitaper diff --git a/doc/changes/devel/13058.newfeature.rst b/doc/changes/devel/13058.newfeature.rst new file mode 100644 index 00000000000..bbd01fa4552 --- /dev/null +++ b/doc/changes/devel/13058.newfeature.rst @@ -0,0 +1 @@ +Add the function :func:`mne.time_frequency.combine_spectrum` for combining data across :class:`mne.time_frequency.Spectrum` objects, and allow :func:`mne.grand_average` to operate on :class:`mne.time_frequency.Spectrum` objects, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/time_frequency/__init__.pyi b/mne/time_frequency/__init__.pyi index 0faeb7263d8..a612c2a850a 100644 --- a/mne/time_frequency/__init__.pyi +++ b/mne/time_frequency/__init__.pyi @@ -11,6 +11,7 @@ __all__ = [ "RawTFRArray", "Spectrum", "SpectrumArray", + "combine_spectrum", "csd_array_fourier", "csd_array_morlet", "csd_array_multitaper", @@ -61,6 +62,7 @@ from .spectrum import ( EpochsSpectrumArray, Spectrum, SpectrumArray, + combine_spectrum, read_spectrum, ) from .tfr import ( diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index a70697fd57c..b1de7f11c0f 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -1643,6 +1643,74 @@ def __init__( ) +def combine_spectrum(all_spectrum, weights="nave"): + """Merge spectral data by weighted addition. + + Create a new :class:`mne.time_frequency.Spectrum` instance, using a combination of + the supplied instances as its data. By default, the mean (weighted by trials) is + used. Subtraction can be performed by passing negative weights (e.g., ``[1, -1]``). + Data must have the same channels and the same frequencies. + + Parameters + ---------- + all_spectrum : list of Spectrum + The Spectrum objects. + weights : list of float | str + The weights to apply to the data of each :class:`~mne.time_frequency.Spectrum` + instance, or a string describing the weighting strategy to apply: 'nave' + computes sum-to-one weights proportional to each object’s nave attribute; + 'equal' weights each :class:`~mne.time_frequency.Spectrum` by + ``1 / len(all_spectrum)``. + + Returns + ------- + spectrum : Spectrum + The new spectral data. + + Notes + ----- + .. 
versionadded:: 1.10.0 + """ + spectrum = all_spectrum[0].copy() + if isinstance(weights, str): + if weights not in ("nave", "equal"): + raise ValueError('Weights must be a list of float, or "nave" or "equal"') + if weights == "nave": + for s_ in all_spectrum: + if s_.nave is None: + raise ValueError(f"The 'nave' attribute is not specified for {s_}") + weights = np.array([e.nave for e in all_spectrum], float) + weights /= weights.sum() + else: # == 'equal' + weights = [1.0 / len(all_spectrum)] * len(all_spectrum) + weights = np.array(weights, float) + if weights.ndim != 1 or weights.size != len(all_spectrum): + raise ValueError("Weights must be the same size as all_spectrum") + + ch_names = spectrum.ch_names + for s_ in all_spectrum[1:]: + assert ( + s_.ch_names == ch_names + ), f"{spectrum} and {s_} do not contain the same channels" + assert ( + np.max(np.abs(s_.freqs - spectrum.freqs)) < 1e-7 + ), f"{spectrum} and {s_} do not contain the same frequencies" + + # use union of bad channels + bads = list( + set(spectrum.info["bads"]).union(*(s_.info["bads"] for s_ in all_spectrum[1:])) + ) + spectrum.info["bads"] = bads + + # combine spectral data + spectrum._data = sum(w * s_.data for w, s_ in zip(weights, all_spectrum)) + if spectrum.nave is not None: + spectrum._nave = max( + int(1.0 / sum(w**2 / s_.nave for w, s_ in zip(weights, all_spectrum))), 1 + ) + return spectrum + + def read_spectrum(fname): """Load a :class:`mne.time_frequency.Spectrum` object from disk. diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py index 162d89b1c25..927c22360c5 100644 --- a/mne/time_frequency/tests/test_spectrum.py +++ b/mne/time_frequency/tests/test_spectrum.py @@ -14,7 +14,11 @@ from mne.io import RawArray from mne.time_frequency import read_spectrum from mne.time_frequency.multitaper import _psd_from_mt -from mne.time_frequency.spectrum import EpochsSpectrumArray, SpectrumArray +from mne.time_frequency.spectrum import ( + EpochsSpectrumArray, + SpectrumArray, + combine_spectrum, +) from mne.utils import _record_warnings @@ -190,6 +194,55 @@ def test_spectrum_copy(raw_spectrum): assert raw_spectrum.freqs is not None +@pytest.mark.parametrize("weights", ["nave", "equal", [1, -1]]) +def test_combine_spectrum(raw_spectrum, weights): + """Test `combine_spectrum()` works.""" + spectrum1 = raw_spectrum.copy() + spectrum2 = raw_spectrum.copy() + if weights == "nave": + spectrum1._nave = 1 + spectrum2._nave = 2 + spectrum2._data *= 2 + new_spectrum = combine_spectrum([spectrum1, spectrum2], weights=weights) + assert_allclose(new_spectrum.data, spectrum1.data * (5 / 3)) + elif weights == "equal": + spectrum2._data *= 2 + new_spectrum = combine_spectrum([spectrum1, spectrum2], weights=weights) + assert_allclose(new_spectrum.data, spectrum1.data * 1.5) + else: + new_spectrum = combine_spectrum([spectrum1, spectrum2], weights=weights) + assert_allclose(new_spectrum.data, 0) + + +def test_combine_spectrum_error_catch(raw_spectrum): + """Test `combine_spectrum()` catches errors.""" + # Test bad weights + with pytest.raises( + ValueError, match='Weights must be a list of float, or "nave" or "equal"' + ): + combine_spectrum([raw_spectrum, raw_spectrum], weights="foo") + with pytest.raises( + ValueError, match="Weights must be the same size as all_spectrum" + ): + combine_spectrum([raw_spectrum, raw_spectrum], weights=[1, 1, 1]) + + # Test bad nave + with pytest.raises(ValueError, match="The 'nave' attribute is not specified"): + combine_spectrum([raw_spectrum, 
raw_spectrum], weights="nave") + + # Test inconsistent channels + raw_spectrum2 = raw_spectrum.copy() + raw_spectrum2.drop_channels(raw_spectrum2.ch_names[0]) + with pytest.raises(AssertionError, match=".* do not contain the same channels"): + combine_spectrum([raw_spectrum, raw_spectrum2], weights="equal") + + # Test inconsistent frequencies + raw_spectrum2 = raw_spectrum.copy() + raw_spectrum2._freqs = raw_spectrum2._freqs + 1 + with pytest.raises(AssertionError, match=".* do not contain the same frequencies"): + combine_spectrum([raw_spectrum, raw_spectrum2], weights="equal") + + def test_spectrum_reject_by_annot(raw): """Test rejecting by annotation. diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 918fea1a33f..b1736f151d2 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -3960,12 +3960,12 @@ def combine_tfr(all_tfr, weights="nave"): ch_names = tfr.ch_names for t_ in all_tfr[1:]: - assert t_.ch_names == ch_names, ValueError( - f"{tfr} and {t_} do not contain the same channels" - ) - assert np.max(np.abs(t_.times - tfr.times)) < 1e-7, ValueError( - f"{tfr} and {t_} do not contain the same time instants" - ) + assert ( + t_.ch_names == ch_names + ), f"{tfr} and {t_} do not contain the same channels" + assert ( + np.max(np.abs(t_.times - tfr.times)) < 1e-7 + ), f"{tfr} and {t_} do not contain the same time instants" # use union of bad channels bads = list(set(tfr.info["bads"]).union(*(t_.info["bads"] for t_ in all_tfr[1:]))) diff --git a/mne/utils/numerics.py b/mne/utils/numerics.py index 4bf8d094f81..eed23998774 100644 --- a/mne/utils/numerics.py +++ b/mne/utils/numerics.py @@ -515,37 +515,42 @@ def _freq_mask(freqs, sfreq, fmin=None, fmax=None, raise_error=True): def grand_average(all_inst, interpolate_bads=True, drop_bads=True): - """Make grand average of a list of Evoked or AverageTFR data. + """Make grand average of a list of Evoked, AverageTFR, or Spectrum data. - For :class:`mne.Evoked` data, the function interpolates bad channels based - on the ``interpolate_bads`` parameter. If ``interpolate_bads`` is True, - the grand average file will contain good channels and the bad channels - interpolated from the good MEG/EEG channels. - For :class:`mne.time_frequency.AverageTFR` data, the function takes the - subset of channels not marked as bad in any of the instances. + For :class:`mne.Evoked` data, the function interpolates bad channels based on the + ``interpolate_bads`` parameter. If ``interpolate_bads`` is True, the grand average + file will contain good channels and the bad channels interpolated from the good + MEG/EEG channels. + For :class:`mne.time_frequency.AverageTFR` and :class:`mne.time_frequency.Spectrum` + data, the function takes the subset of channels not marked as bad in any of the + instances. - The ``grand_average.nave`` attribute will be equal to the number - of evoked datasets used to calculate the grand average. + The ``grand_average.nave`` attribute will be equal to the number of datasets used to + calculate the grand average. - .. note:: A grand average evoked should not be used for source - localization. + .. note:: A grand average evoked should not be used for source localization. Parameters ---------- - all_inst : list of Evoked or AverageTFR - The evoked datasets. + all_inst : list of Evoked, AverageTFR or Spectrum + The datasets. + + .. versionchanged:: 1.10.0 + Added support for :class:`~mne.time_frequency.Spectrum` objects. + interpolate_bads : bool If True, bad MEG and EEG channels are interpolated. 
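A minimal usage sketch for the new ``combine_spectrum`` and the extended ``grand_average`` (not part of the patch; the toy recordings, channel names, sampling rate, and ``fmax`` are illustrative assumptions, and ``create_info``, ``RawArray``, and ``compute_psd`` are standard MNE-Python APIs not introduced here). The weighting options map onto the docstring above: ``"nave"`` requires the ``nave`` attribute on every input, ``"equal"`` weights each input by ``1 / len(all_spectrum)``, and an explicit list such as ``[1, -1]`` performs subtraction. With ``"nave"`` weighting, naves of 1 and 2 give weights 1/3 and 2/3, so doubling the second input's data yields combined data 5/3 times the first input's — exactly what the new test asserts — and the combined nave follows ``max(int(1 / sum(w**2 / nave_i)), 1)``, here ``1 / ((1/9) / 1 + (4/9) / 2) = 3``.

import numpy as np

import mne
from mne.time_frequency import combine_spectrum

# Two toy recordings with matching channels and sampling rate (illustrative)
info = mne.create_info(["EEG 001", "EEG 002"], sfreq=256.0, ch_types="eeg")
rng = np.random.default_rng(0)
raw_a = mne.io.RawArray(rng.standard_normal((2, 2560)) * 1e-5, info)
raw_b = mne.io.RawArray(rng.standard_normal((2, 2560)) * 1e-5, info)

spec_a = raw_a.compute_psd(fmax=40.0)
spec_b = raw_b.compute_psd(fmax=40.0)

avg = combine_spectrum([spec_a, spec_b], weights="equal")   # plain mean
diff = combine_spectrum([spec_a, spec_b], weights=[1, -1])  # subtraction

# grand_average now accepts Spectrum inputs too; it combines with equal
# weights and sets .nave to the number of inputs (here, 2)
ga = mne.grand_average([spec_a, spec_b])

The same weighted-combination pattern backs ``combine_tfr`` for ``AverageTFR`` objects, which the next patch in this series (#13054) promotes to the public API.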
Ignored for - AverageTFR. + :class:`~mne.time_frequency.AverageTFR` and + :class:`~mne.time_frequency.Spectrum` data. drop_bads : bool - If True, drop all bad channels marked as bad in any data set. - If neither interpolate_bads nor drop_bads is True, in the output file, - every channel marked as bad in at least one of the input files will be - marked as bad, but no interpolation or dropping will be performed. + If True, drop all bad channels marked as bad in any data set. If neither + ``interpolate_bads`` nor ``drop_bads`` is `True`, in the output file, every + channel marked as bad in at least one of the input files will be marked as bad, + but no interpolation or dropping will be performed. Returns ------- - grand_average : Evoked | AverageTFR + grand_average : Evoked | AverageTFR | Spectrum The grand average data. Same type as input. Notes @@ -558,15 +563,17 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True): # check if all elements in the given list are evoked data from ..channels.channels import equalize_channels from ..evoked import Evoked - from ..time_frequency import AverageTFR + from ..time_frequency import AverageTFR, Spectrum if not all_inst: - raise ValueError("Please pass a list of Evoked or AverageTFR objects.") + raise ValueError( + "Please pass a list of Evoked, AverageTFR, or Spectrum objects." + ) elif len(all_inst) == 1: warn("Only a single dataset was passed to mne.grand_average().") inst_type = type(all_inst[0]) - _validate_type(all_inst[0], (Evoked, AverageTFR), "All elements") + _validate_type(all_inst[0], (Evoked, AverageTFR, Spectrum), "All elements") for inst in all_inst: _validate_type(inst, inst_type, "All elements", "of the same type") @@ -581,6 +588,8 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True): for inst in all_inst ] from ..evoked import combine_evoked as combine + elif isinstance(all_inst[0], Spectrum): + from ..time_frequency.spectrum import combine_spectrum as combine else: # isinstance(all_inst[0], AverageTFR): from ..time_frequency.tfr import combine_tfr as combine @@ -591,9 +600,9 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True): inst.drop_channels(bads) equalize_channels(all_inst, copy=False) - # make grand_average object using combine_[evoked/tfr] + # make grand_average object using combine_[evoked/tfr/spectrum] grand_average = combine(all_inst, weights="equal") - # change the grand_average.nave to the number of Evokeds + # change the grand_average.nave to the number of datasets grand_average.nave = len(all_inst) # change comment field grand_average.comment = f"Grand average (n = {grand_average.nave})" From 2ae61edccb2af5b5f9f3a89a3131499b5c229c27 Mon Sep 17 00:00:00 2001 From: "Thomas S. 
Binns" Date: Tue, 14 Jan 2025 00:08:54 +0000 Subject: [PATCH 04/24] Add `combine_tfr` to API (#13054) --- doc/api/time_frequency.rst | 1 + doc/changes/devel/13054.newfeature.rst | 1 + mne/time_frequency/__init__.pyi | 2 ++ mne/time_frequency/tfr.py | 8 ++++---- 4 files changed, 8 insertions(+), 4 deletions(-) create mode 100644 doc/changes/devel/13054.newfeature.rst diff --git a/doc/api/time_frequency.rst b/doc/api/time_frequency.rst index b66b1b6ca64..a9ab2c34268 100644 --- a/doc/api/time_frequency.rst +++ b/doc/api/time_frequency.rst @@ -32,6 +32,7 @@ Functions that operate on mne-python objects: :toctree: ../generated/ combine_spectrum + combine_tfr csd_tfr csd_fourier csd_multitaper diff --git a/doc/changes/devel/13054.newfeature.rst b/doc/changes/devel/13054.newfeature.rst new file mode 100644 index 00000000000..3c89290e7fe --- /dev/null +++ b/doc/changes/devel/13054.newfeature.rst @@ -0,0 +1 @@ +Added :func:`mne.time_frequency.combine_tfr` to allow combining TFRs across tapers, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/time_frequency/__init__.pyi b/mne/time_frequency/__init__.pyi index a612c2a850a..6b53c39a98b 100644 --- a/mne/time_frequency/__init__.pyi +++ b/mne/time_frequency/__init__.pyi @@ -12,6 +12,7 @@ __all__ = [ "Spectrum", "SpectrumArray", "combine_spectrum", + "combine_tfr", "csd_array_fourier", "csd_array_morlet", "csd_array_multitaper", @@ -73,6 +74,7 @@ from .tfr import ( EpochsTFRArray, RawTFR, RawTFRArray, + combine_tfr, fwhm, morlet, read_tfrs, diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index b1736f151d2..71dabce6d31 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -3914,10 +3914,10 @@ def __init__( def combine_tfr(all_tfr, weights="nave"): """Merge AverageTFR data by weighted addition. - Create a new AverageTFR instance, using a combination of the supplied - instances as its data. By default, the mean (weighted by trials) is used. - Subtraction can be performed by passing negative weights (e.g., [1, -1]). - Data must have the same channels and the same time instants. + Create a new :class:`mne.time_frequency.AverageTFR` instance, using a combination of + the supplied instances as its data. By default, the mean (weighted by trials) is + used. Subtraction can be performed by passing negative weights (e.g., [1, -1]). Data + must have the same channels and the same time instants. Parameters ---------- From 5fec4e024a963c3f628693ab172d5b77cbafe6db Mon Sep 17 00:00:00 2001 From: Simon Kern <14980558+skjerns@users.noreply.github.com> Date: Tue, 14 Jan 2025 12:46:03 +0100 Subject: [PATCH 05/24] [DOC] extend documentation for add_channels (#13051) --- mne/channels/channels.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 8fbff33c13e..ed6dd8508cc 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -661,17 +661,21 @@ def _pick_projs(self): return self def add_channels(self, add_list, force_update_info=False): - """Append new channels to the instance. + """Append new channels from other MNE objects to the instance. Parameters ---------- add_list : list - A list of objects to append to self. Must contain all the same - type as the current object. + A list of MNE objects to append to the current instance. + The channels contained in the other instances are appended to the + channels of the current instance. Therefore, all other instances + must be of the same type as the current object. 
+ See notes on how to add data coming from an array. force_update_info : bool If True, force the info for objects to be appended to match the - values in ``self``. This should generally only be used when adding - stim channels for which important metadata won't be overwritten. + values of the current instance. This should generally only be + used when adding stim channels for which important metadata won't + be overwritten. .. versionadded:: 0.12 @@ -688,6 +692,12 @@ def add_channels(self, add_list, force_update_info=False): ----- If ``self`` is a Raw instance that has been preloaded into a :obj:`numpy.memmap` instance, the memmap will be resized. + + This function expects an MNE object to be appended (e.g. :class:`~mne.io.Raw`, + :class:`~mne.Epochs`, :class:`~mne.Evoked`). If you simply want to add a + channel based on values of an np.ndarray, you need to create a + :class:`~mne.io.RawArray`. + See `_ """ # avoid circular imports from ..epochs import BaseEpochs From c0da91db9098b92ec3c20d8c0e237d0e02683865 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 14 Jan 2025 14:19:02 -0500 Subject: [PATCH 06/24] BUG: Fix bug with interval calculation (#13062) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- azure-pipelines.yml | 2 +- doc/changes/devel/13062.bugfix.rst | 1 + mne/preprocessing/_fine_cal.py | 15 +++++++++------ mne/preprocessing/tests/test_fine_cal.py | 14 +++++++++++++- 4 files changed, 24 insertions(+), 8 deletions(-) create mode 100644 doc/changes/devel/13062.bugfix.rst diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 2964ac1e07d..d0aa9ea9e76 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -243,7 +243,7 @@ stages: PYTHONIOENCODING: 'utf-8' AZURE_CI_WINDOWS: 'true' PYTHON_ARCH: 'x64' - timeoutInMinutes: 75 + timeoutInMinutes: 80 strategy: maxParallel: 4 matrix: diff --git a/doc/changes/devel/13062.bugfix.rst b/doc/changes/devel/13062.bugfix.rst new file mode 100644 index 00000000000..9e01fc4c835 --- /dev/null +++ b/doc/changes/devel/13062.bugfix.rst @@ -0,0 +1 @@ +Fix computation of time intervals in :func:`mne.preprocessing.compute_fine_calibration` by `Eric Larson`_. diff --git a/mne/preprocessing/_fine_cal.py b/mne/preprocessing/_fine_cal.py index 41d20539ce0..b43983a87eb 100644 --- a/mne/preprocessing/_fine_cal.py +++ b/mne/preprocessing/_fine_cal.py @@ -156,11 +156,12 @@ def compute_fine_calibration( # 1. 
Rotate surface normals using magnetometer information (if present) # cals = np.ones(len(info["ch_names"])) - time_idxs = raw.time_as_index(np.arange(0.0, raw.times[-1], t_window)) - if len(time_idxs) <= 1: - time_idxs = np.array([0, len(raw.times)], int) - else: - time_idxs[-1] = len(raw.times) + end = len(raw.times) + 1 + time_idxs = np.arange(0, end, int(round(t_window * raw.info["sfreq"]))) + if len(time_idxs) == 1: + time_idxs = np.concatenate([time_idxs, [end]]) + if time_idxs[-1] != end: + time_idxs[-1] = end count = 0 locs = np.array([ch["loc"] for ch in info["chs"]]) zs = locs[mag_picks, -3:].copy() @@ -388,9 +389,11 @@ def _adjust_mag_normals(info, data, origin, ext_order, *, angle_limit, err_limit each_err = _data_err(data, S_tot, cals, axis=-1)[picks_mag] n_bad = (each_err > err_limit).sum() if n_bad: + bad_max = np.argmax(each_err) reason.append( f"{n_bad} residual{_pl(n_bad)} > {err_limit:0.1f}% " - f"(max: {each_err.max():0.2f}%)" + f"(max: {each_err[bad_max]:0.2f}% @ " + f"{info['ch_names'][picks_mag[bad_max]]})" ) reason = ", ".join(reason) if reason: diff --git a/mne/preprocessing/tests/test_fine_cal.py b/mne/preprocessing/tests/test_fine_cal.py index 02c596bf4bc..45971620db5 100644 --- a/mne/preprocessing/tests/test_fine_cal.py +++ b/mne/preprocessing/tests/test_fine_cal.py @@ -20,7 +20,7 @@ ) from mne.preprocessing.tests.test_maxwell import _assert_shielding from mne.transforms import _angle_dist_between_rigid -from mne.utils import object_diff +from mne.utils import catch_logging, object_diff # Define fine calibration filepaths data_path = testing.data_path(download=False) @@ -289,3 +289,15 @@ def test_fine_cal_systems(system, tmp_path): got_corrs = np.corrcoef([raw_data, raw_sss_data, raw_sss_cal_data]) got_corrs = got_corrs[np.triu_indices(3, 1)] assert_allclose(got_corrs, corrs, atol=corr_tol) + if system == "fil": + with catch_logging(verbose=True) as log: + compute_fine_calibration( + raw.copy().crop(0, 0.12).pick(raw.ch_names[:12]), + t_window=0.06, # 2 segments + angle_limit=angle_limit, + err_limit=err_limit, + ext_order=2, + verbose=True, + ) + log = log.getvalue() + assert "(averaging over 2 time intervals)" in log, log From d472c268cb39fb6e4bf0dad24c802b17efdd4a33 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 14 Jan 2025 20:39:05 +0000 Subject: [PATCH 07/24] [pre-commit.ci] pre-commit autoupdate (#13060) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Larson --- .pre-commit-config.yaml | 4 ++-- azure-pipelines.yml | 2 +- doc/conf.py | 2 +- doc/sphinxext/credit_tools.py | 8 +++---- doc/sphinxext/related_software.py | 6 ++--- doc/sphinxext/unit_role.py | 3 +-- examples/inverse/vector_mne_solution.py | 2 +- examples/visualization/evoked_whitening.py | 2 +- mne/_fiff/_digitization.py | 3 +-- mne/_fiff/meas_info.py | 7 +++--- mne/_fiff/proj.py | 16 ++++++------- mne/_fiff/reference.py | 6 ++--- mne/_fiff/tag.py | 3 +-- mne/_fiff/tests/test_meas_info.py | 2 +- mne/_fiff/tests/test_pick.py | 2 +- mne/beamformer/_compute_beamformer.py | 10 ++++---- mne/beamformer/tests/test_lcmv.py | 2 +- mne/bem.py | 14 +++++------ mne/channels/channels.py | 2 +- mne/channels/montage.py | 2 +- mne/channels/tests/test_channels.py | 3 +-- mne/channels/tests/test_montage.py | 7 +----- mne/commands/mne_make_scalp_surfaces.py | 3 +-- mne/commands/mne_setup_source_space.py | 3 +-- mne/coreg.py | 9 +++----- mne/cov.py | 10 ++++---- mne/datasets/_fetch.py | 
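A numeric sketch of the interval computation fixed in #13062 above, run with hypothetical values: the 0.12 s crop and ``t_window=0.06`` mirror the new regression test, while the 100 Hz sampling rate is assumed for illustration.

import numpy as np

sfreq, n_times, t_window = 100.0, 13, 0.06  # 13 samples ~ raw.crop(0, 0.12)
end = n_times + 1
time_idxs = np.arange(0, end, int(round(t_window * sfreq)))
if len(time_idxs) == 1:  # recording shorter than a single window
    time_idxs = np.concatenate([time_idxs, [end]])
if time_idxs[-1] != end:  # the last interval always runs to the end
    time_idxs[-1] = end
print(time_idxs)  # [ 0  6 14] -> two averaging intervals, as the test expects

Unlike the old ``raw.time_as_index(np.arange(0.0, raw.times[-1], t_window))`` formulation, the step is computed once from the sampling frequency, so a recording exactly two windows long now splits into two intervals instead of collapsing into one.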
3 +-- mne/datasets/config.py | 9 ++++---- mne/datasets/sleep_physionet/age.py | 5 +--- mne/epochs.py | 16 +++++-------- mne/event.py | 2 +- mne/evoked.py | 12 ++++------ mne/export/_egimff.py | 2 +- mne/export/_export.py | 3 +-- mne/export/tests/test_export.py | 4 ++-- mne/filter.py | 5 ++-- mne/forward/_field_interpolation.py | 5 ++-- mne/forward/_make_forward.py | 11 ++++----- mne/forward/forward.py | 11 ++++----- mne/gui/_coreg.py | 6 ++--- mne/html_templates/_templates.py | 2 +- mne/io/array/__init__.py | 2 +- mne/io/array/{array.py => _array.py} | 0 mne/io/artemis123/tests/test_artemis123.py | 6 ++--- mne/io/base.py | 5 ++-- mne/io/ctf/ctf.py | 2 +- mne/io/ctf/info.py | 3 +-- mne/io/ctf/tests/test_ctf.py | 6 ++--- mne/io/edf/edf.py | 6 ++--- mne/io/egi/egimff.py | 2 +- mne/io/fieldtrip/fieldtrip.py | 2 +- mne/io/fil/tests/test_fil.py | 18 +++++++-------- mne/io/neuralynx/tests/test_neuralynx.py | 12 +++++----- mne/io/nirx/nirx.py | 2 +- mne/io/tests/test_raw.py | 2 +- mne/label.py | 14 ++++------- mne/minimum_norm/inverse.py | 4 ++-- mne/minimum_norm/tests/test_inverse.py | 3 +-- mne/morph.py | 9 +++----- mne/preprocessing/_fine_cal.py | 2 +- mne/preprocessing/artifact_detection.py | 8 +++---- mne/preprocessing/eog.py | 4 ++-- mne/preprocessing/hfc.py | 3 +-- mne/preprocessing/ica.py | 18 +++++++-------- mne/preprocessing/ieeg/_volume.py | 2 +- mne/preprocessing/infomax_.py | 3 +-- mne/preprocessing/maxwell.py | 8 +++---- mne/preprocessing/nirs/_beer_lambert_law.py | 2 +- mne/preprocessing/tests/test_maxwell.py | 6 ++--- mne/preprocessing/xdawn.py | 3 +-- mne/report/report.py | 17 ++++++-------- mne/source_estimate.py | 9 ++++---- mne/source_space/_source_space.py | 10 ++++---- mne/surface.py | 17 +++++++------- mne/tests/test_annotations.py | 3 +-- mne/tests/test_dipole.py | 6 ++--- mne/tests/test_docstring_parameters.py | 3 +-- mne/tests/test_epochs.py | 12 +++++----- mne/tests/test_filter.py | 6 ++--- mne/time_frequency/_stft.py | 3 +-- mne/time_frequency/csd.py | 5 ++-- mne/time_frequency/spectrum.py | 18 +++++++-------- mne/time_frequency/tfr.py | 23 +++++++++---------- mne/utils/_logging.py | 2 +- mne/utils/check.py | 13 ++++------- mne/utils/config.py | 14 +++++------ mne/utils/misc.py | 2 +- mne/viz/_brain/_brain.py | 12 +++++----- mne/viz/_brain/tests/test_brain.py | 12 +++++----- mne/viz/_proj.py | 3 +-- mne/viz/backends/_utils.py | 3 +-- mne/viz/misc.py | 5 ++-- mne/viz/tests/test_3d.py | 2 +- mne/viz/topomap.py | 14 ++++------- mne/viz/utils.py | 2 +- tools/dev/ensure_headers.py | 12 +++++----- tools/hooks/update_environment_file.py | 2 +- tutorials/forward/20_source_alignment.py | 4 ++-- tutorials/forward/30_forward.py | 2 +- tutorials/intro/15_inplace.py | 4 ++-- .../40_artifact_correction_ica.py | 3 +-- .../50_artifact_correction_ssp.py | 2 +- 102 files changed, 276 insertions(+), 355 deletions(-) rename mne/io/array/{array.py => _array.py} (100%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2ba7f294c66..cb769988655 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: # Ruff mne - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.6 + rev: v0.9.1 hooks: - id: ruff name: ruff lint mne @@ -82,7 +82,7 @@ repos: # zizmor - repo: https://github.com/woodruffw/zizmor-pre-commit - rev: v1.0.0 + rev: v1.1.1 hooks: - id: zizmor diff --git a/azure-pipelines.yml b/azure-pipelines.yml index d0aa9ea9e76..7149edac50b 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -112,7 +112,7 @@ stages: - bash: | 
set -e python -m pip install --progress-bar off --upgrade pip - python -m pip install --progress-bar off "mne-qt-browser[opengl] @ git+https://github.com/mne-tools/mne-qt-browser.git@main" pyvista scikit-learn python-picard qtpy nibabel sphinx-gallery "PySide6!=6.8.0,!=6.8.0.1" pandas neo pymatreader antio defusedxml + python -m pip install --progress-bar off "mne-qt-browser[opengl] @ git+https://github.com/mne-tools/mne-qt-browser.git@main" pyvista scikit-learn python-picard qtpy nibabel sphinx-gallery "PySide6!=6.8.0,!=6.8.0.1,!=6.8.1.1" pandas neo pymatreader antio defusedxml python -m pip uninstall -yq mne python -m pip install --progress-bar off --upgrade -e .[test] displayName: 'Install dependencies with pip' diff --git a/doc/conf.py b/doc/conf.py index 96028fb9045..74f66d8f6ae 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -1289,7 +1289,7 @@ def fix_sklearn_inherited_docstrings(app, what, name, obj, options, lines): rst_prolog += f""" .. |{icon}| raw:: html - + """ rst_prolog += """ diff --git a/doc/sphinxext/credit_tools.py b/doc/sphinxext/credit_tools.py index 708dcf00ce8..e22bd0b5530 100644 --- a/doc/sphinxext/credit_tools.py +++ b/doc/sphinxext/credit_tools.py @@ -169,7 +169,7 @@ def generate_credit_rst(app=None, *, verbose=False): if author["e"] is not None: if author["e"] not in name_map: unknown_emails.add( - f'{author["e"].ljust(29)} ' + f"{author['e'].ljust(29)} " "https://github.com/mne-tools/mne-python/pull/" f"{commit}/files" ) @@ -178,9 +178,9 @@ def generate_credit_rst(app=None, *, verbose=False): else: name = author["n"] if name in manual_renames: - assert _good_name( - manual_renames[name] - ), f"Bad manual rename: {name}" + assert _good_name(manual_renames[name]), ( + f"Bad manual rename: {name}" + ) name = manual_renames[name] if " " in name: first, last = name.rsplit(" ", maxsplit=1) diff --git a/doc/sphinxext/related_software.py b/doc/sphinxext/related_software.py index ac1b741b9af..ab159b0fcb4 100644 --- a/doc/sphinxext/related_software.py +++ b/doc/sphinxext/related_software.py @@ -163,9 +163,9 @@ def _get_packages() -> dict[str, str]: assert not dups, f"Duplicates in MANUAL_PACKAGES and PYPI_PACKAGES: {sorted(dups)}" # And the installer and PyPI-only should be disjoint: dups = set(PYPI_PACKAGES) & set(packages) - assert ( - not dups - ), f"Duplicates in PYPI_PACKAGES and installer packages: {sorted(dups)}" + assert not dups, ( + f"Duplicates in PYPI_PACKAGES and installer packages: {sorted(dups)}" + ) for name in PYPI_PACKAGES | set(MANUAL_PACKAGES): if name not in packages: packages.append(name) diff --git a/doc/sphinxext/unit_role.py b/doc/sphinxext/unit_role.py index b52665e8321..bf31ddf76c4 100644 --- a/doc/sphinxext/unit_role.py +++ b/doc/sphinxext/unit_role.py @@ -10,8 +10,7 @@ def unit_role(name, rawtext, text, lineno, inliner, options={}, content=[]): # def pass_error_to_sphinx(rawtext, text, lineno, inliner): msg = inliner.reporter.error( - "The :unit: role requires a space-separated number and unit; " - f"got {text}", + f"The :unit: role requires a space-separated number and unit; got {text}", line=lineno, ) prb = inliner.problematic(rawtext, rawtext, msg) diff --git a/examples/inverse/vector_mne_solution.py b/examples/inverse/vector_mne_solution.py index ca953cd2f24..f6ae788c145 100644 --- a/examples/inverse/vector_mne_solution.py +++ b/examples/inverse/vector_mne_solution.py @@ -79,7 +79,7 @@ # inverse was computed with loose=0.2 print( "Absolute cosine similarity between source normals and directions: " - f'{np.abs(np.sum(directions * 
inv["source_nn"][2::3], axis=-1)).mean()}' + f"{np.abs(np.sum(directions * inv['source_nn'][2::3], axis=-1)).mean()}" ) brain_max = stc_max.plot( initial_time=peak_time, diff --git a/examples/visualization/evoked_whitening.py b/examples/visualization/evoked_whitening.py index ed05ae3ba11..4bcb4bc8c04 100644 --- a/examples/visualization/evoked_whitening.py +++ b/examples/visualization/evoked_whitening.py @@ -85,7 +85,7 @@ print("Covariance estimates sorted from best to worst") for c in noise_covs: - print(f'{c["method"]} : {c["loglik"]}') + print(f"{c['method']} : {c['loglik']}") # %% # Show the evoked data: diff --git a/mne/_fiff/_digitization.py b/mne/_fiff/_digitization.py index e55fd5d2dae..eb8b6bc396a 100644 --- a/mne/_fiff/_digitization.py +++ b/mne/_fiff/_digitization.py @@ -328,8 +328,7 @@ def _get_data_as_dict_from_dig(dig, exclude_ref_channel=True): dig_coord_frames = set([FIFF.FIFFV_COORD_HEAD]) if len(dig_coord_frames) != 1: raise RuntimeError( - "Only single coordinate frame in dig is supported, " - f"got {dig_coord_frames}" + f"Only single coordinate frame in dig is supported, got {dig_coord_frames}" ) dig_ch_pos_location = np.array(dig_ch_pos_location) dig_ch_pos_location.shape = (-1, 3) # empty will be (0, 3) diff --git a/mne/_fiff/meas_info.py b/mne/_fiff/meas_info.py index ecc93591a05..51612824a6a 100644 --- a/mne/_fiff/meas_info.py +++ b/mne/_fiff/meas_info.py @@ -455,7 +455,7 @@ def _check_set(ch, projs, ch_type): for proj in projs: if ch["ch_name"] in proj["data"]["col_names"]: raise RuntimeError( - f'Cannot change channel type for channel {ch["ch_name"]} in ' + f"Cannot change channel type for channel {ch['ch_name']} in " f'projector "{proj["desc"]}"' ) ch["kind"] = new_kind @@ -1867,7 +1867,7 @@ def _check_consistency(self, prepend_error=""): ): raise RuntimeError( f'{prepend_error}info["meas_date"] must be a datetime object in UTC' - f' or None, got {repr(self["meas_date"])!r}' + f" or None, got {repr(self['meas_date'])!r}" ) chs = [ch["ch_name"] for ch in self["chs"]] @@ -3680,8 +3680,7 @@ def _write_ch_infos(fid, chs, reset_range, ch_names_mapping): # only write new-style channel information if necessary if len(ch_names_mapping): logger.info( - " Writing channel names to FIF truncated to 15 characters " - "with remapping" + " Writing channel names to FIF truncated to 15 characters with remapping" ) for ch in chs: start_block(fid, FIFF.FIFFB_CH_INFO) diff --git a/mne/_fiff/proj.py b/mne/_fiff/proj.py index 0376826138a..d6ec108e34d 100644 --- a/mne/_fiff/proj.py +++ b/mne/_fiff/proj.py @@ -76,7 +76,7 @@ def __repr__(self): # noqa: D105 s += f", active : {self['active']}" s += f", n_channels : {len(self['data']['col_names'])}" if self["explained_var"] is not None: - s += f', exp. var : {self["explained_var"] * 100:0.2f}%' + s += f", exp. var : {self['explained_var'] * 100:0.2f}%" return f"" # speed up info copy by taking advantage of mutability @@ -324,8 +324,7 @@ def apply_proj(self, verbose=None): if all(p["active"] for p in self.info["projs"]): logger.info( - "Projections have already been applied. " - "Setting proj attribute to True." + "Projections have already been applied. Setting proj attribute to True." 
) return self @@ -663,9 +662,9 @@ def _read_proj(fid, node, *, ch_names_mapping=None, verbose=None): for proj in projs: misc = "active" if proj["active"] else " idle" logger.info( - f' {proj["desc"]} ' - f'({proj["data"]["nrow"]} x ' - f'{len(proj["data"]["col_names"])}) {misc}' + f" {proj['desc']} " + f"({proj['data']['nrow']} x " + f"{len(proj['data']['col_names'])}) {misc}" ) return projs @@ -795,8 +794,7 @@ def _make_projector(projs, ch_names, bads=(), include_active=True, inplace=False if not p["active"] or include_active: if len(p["data"]["col_names"]) != len(np.unique(p["data"]["col_names"])): raise ValueError( - f"Channel name list in projection item {k}" - " contains duplicate items" + f"Channel name list in projection item {k} contains duplicate items" ) # Get the two selection vectors to pick correct elements from @@ -832,7 +830,7 @@ def _make_projector(projs, ch_names, bads=(), include_active=True, inplace=False ) ): warn( - f'Projection vector {repr(p["desc"])} has been ' + f"Projection vector {repr(p['desc'])} has been " f"reduced to {100 * psize:0.2f}% of its " "original magnitude by subselecting " f"{len(vecsel)}/{orig_n} of the original " diff --git a/mne/_fiff/reference.py b/mne/_fiff/reference.py index e70bf5e36c1..b4c050c096d 100644 --- a/mne/_fiff/reference.py +++ b/mne/_fiff/reference.py @@ -102,7 +102,7 @@ def _check_before_dict_reference(inst, ref_dict): raise TypeError( f"{elem_name.capitalize()}s in the ref_channels dict must be strings. " f"Your dict has {elem_name}s of type " - f'{", ".join(map(lambda x: x.__name__, bad_elem))}.' + f"{', '.join(map(lambda x: x.__name__, bad_elem))}." ) # Check that keys are valid channels and values are lists-of-valid-channels @@ -113,8 +113,8 @@ def _check_before_dict_reference(inst, ref_dict): for elem_name, elem in dict(key=keys, value=values).items(): if bad_elem := elem - ch_set: raise ValueError( - f'ref_channels dict contains invalid {elem_name}(s) ' - f'({", ".join(bad_elem)}) ' + f"ref_channels dict contains invalid {elem_name}(s) " + f"({', '.join(bad_elem)}) " "that are not names of channels in the instance." 
) # Check that values are not bad channels diff --git a/mne/_fiff/tag.py b/mne/_fiff/tag.py index abc7d32036b..3fd36454d58 100644 --- a/mne/_fiff/tag.py +++ b/mne/_fiff/tag.py @@ -70,8 +70,7 @@ def _frombuffer_rows(fid, tag_size, dtype=None, shape=None, rlims=None): have_shape = tag_size // item_size if want_shape != have_shape: raise ValueError( - f"Wrong shape specified, requested {want_shape} but got " - f"{have_shape}" + f"Wrong shape specified, requested {want_shape} but got {have_shape}" ) if not len(rlims) == 2: raise ValueError("rlims must have two elements") diff --git a/mne/_fiff/tests/test_meas_info.py b/mne/_fiff/tests/test_meas_info.py index d088da2a4a2..a38ecaade50 100644 --- a/mne/_fiff/tests/test_meas_info.py +++ b/mne/_fiff/tests/test_meas_info.py @@ -1118,7 +1118,7 @@ def test_channel_name_limit(tmp_path, monkeypatch, fname): meas_info, "_read_extended_ch_info", _read_extended_ch_info ) short_proj_names = [ - f"{name[:13 - bool(len(ref_names))]}-{ni}" + f"{name[: 13 - bool(len(ref_names))]}-{ni}" for ni, name in enumerate(long_proj_names) ] assert raw_read.info["projs"][0]["data"]["col_names"] == short_proj_names diff --git a/mne/_fiff/tests/test_pick.py b/mne/_fiff/tests/test_pick.py index 90830e1d5e5..5d1b24247ab 100644 --- a/mne/_fiff/tests/test_pick.py +++ b/mne/_fiff/tests/test_pick.py @@ -136,7 +136,7 @@ def _channel_type_old(info, idx): else: return t - raise ValueError(f'Unknown channel type for {ch["ch_name"]}') + raise ValueError(f"Unknown channel type for {ch['ch_name']}") def _assert_channel_types(info): diff --git a/mne/beamformer/_compute_beamformer.py b/mne/beamformer/_compute_beamformer.py index bb947cdd757..16bedc2c317 100644 --- a/mne/beamformer/_compute_beamformer.py +++ b/mne/beamformer/_compute_beamformer.py @@ -507,13 +507,13 @@ def __repr__(self): # noqa: D105 n_channels, ) if self["pick_ori"] is not None: - out += f', {self["pick_ori"]} ori' + out += f", {self['pick_ori']} ori" if self["weight_norm"] is not None: - out += f', {self["weight_norm"]} norm' + out += f", {self['weight_norm']} norm" if self.get("inversion") is not None: - out += f', {self["inversion"]} inversion' + out += f", {self['inversion']} inversion" if "rank" in self: - out += f', rank {self["rank"]}' + out += f", rank {self['rank']}" out += ">" return out @@ -531,7 +531,7 @@ def save(self, fname, overwrite=False, verbose=None): """ _, write_hdf5 = _import_h5io_funcs() - ending = f'-{self["kind"].lower()}.h5' + ending = f"-{self['kind'].lower()}.h5" check_fname(fname, self["kind"], (ending,)) csd_orig = None try: diff --git a/mne/beamformer/tests/test_lcmv.py b/mne/beamformer/tests/test_lcmv.py index 957dbaf5284..9ae5473e190 100644 --- a/mne/beamformer/tests/test_lcmv.py +++ b/mne/beamformer/tests/test_lcmv.py @@ -380,7 +380,7 @@ def test_make_lcmv_bem(tmp_path, reg, proj, kind): rank = 17 if proj else 20 assert "LCMV" in repr(filters) assert "unknown subject" not in repr(filters) - assert f'{fwd["nsource"]} vert' in repr(filters) + assert f"{fwd['nsource']} vert" in repr(filters) assert "20 ch" in repr(filters) assert f"rank {rank}" in repr(filters) diff --git a/mne/bem.py b/mne/bem.py index d361272fd49..22aa02d2a0d 100644 --- a/mne/bem.py +++ b/mne/bem.py @@ -91,7 +91,7 @@ class ConductorModel(dict): def __repr__(self): # noqa: D105 if self["is_sphere"]: - center = ", ".join(f"{x * 1000.:.1f}" for x in self["r0"]) + center = ", ".join(f"{x * 1000.0:.1f}" for x in self["r0"]) rad = self.radius if rad is None: # no radius / MEG only extra = f"Sphere (no layers): r0=[{center}] mm" 
@@ -538,7 +538,7 @@ def _assert_complete_surface(surf, incomplete="raise"): prop = tot_angle / (2 * np.pi) if np.abs(prop - 1.0) > 1e-5: msg = ( - f'Surface {_bem_surf_name[surf["id"]]} is not complete (sum of ' + f"Surface {_bem_surf_name[surf['id']]} is not complete (sum of " f"solid angles yielded {prop}, should be 1.)" ) _on_missing(incomplete, msg, name="incomplete", error_klass=RuntimeError) @@ -571,7 +571,7 @@ def _check_surface_size(surf): sizes = surf["rr"].max(axis=0) - surf["rr"].min(axis=0) if (sizes < 0.05).any(): raise RuntimeError( - f'Dimensions of the surface {_bem_surf_name[surf["id"]]} seem too ' + f"Dimensions of the surface {_bem_surf_name[surf['id']]} seem too " f"small ({1000 * sizes.min():9.5f}). Maybe the unit of measure" " is meters instead of mm" ) @@ -599,8 +599,7 @@ def _surfaces_to_bem( # surfs can be strings (filenames) or surface dicts if len(surfs) not in (1, 3) or not (len(surfs) == len(ids) == len(sigmas)): raise ValueError( - "surfs, ids, and sigmas must all have the same " - "number of elements (1 or 3)" + "surfs, ids, and sigmas must all have the same number of elements (1 or 3)" ) for si, surf in enumerate(surfs): if isinstance(surf, str | Path | os.PathLike): @@ -1260,8 +1259,7 @@ def make_watershed_bem( if op.isdir(ws_dir): if not overwrite: raise RuntimeError( - f"{ws_dir} already exists. Use the --overwrite option" - " to recreate it." + f"{ws_dir} already exists. Use the --overwrite option to recreate it." ) else: shutil.rmtree(ws_dir) @@ -2460,7 +2458,7 @@ def check_seghead(surf_path=subj_path / "surf"): logger.info(f"{ii}. Creating {level} tessellation...") logger.info( f"{ii}.1 Decimating the dense tessellation " - f'({len(surf["tris"])} -> {n_tri} triangles)...' + f"({len(surf['tris'])} -> {n_tri} triangles)..." ) points, tris = decimate_surface( points=surf["rr"], triangles=surf["tris"], n_triangles=n_tri diff --git a/mne/channels/channels.py b/mne/channels/channels.py index ed6dd8508cc..bf9e58f2819 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -1382,7 +1382,7 @@ def read_ch_adjacency(fname, picks=None): raise ValueError( f"No built-in channel adjacency matrix found with name: " f"{ch_adj_name}. Valid names are: " - f'{", ".join(get_builtin_ch_adjacencies())}' + f"{', '.join(get_builtin_ch_adjacencies())}" ) ch_adj = [a for a in _BUILTIN_CHANNEL_ADJACENCIES if a.name == ch_adj_name][0] diff --git a/mne/channels/montage.py b/mne/channels/montage.py index b22b9220e14..15cef38dec7 100644 --- a/mne/channels/montage.py +++ b/mne/channels/montage.py @@ -1287,7 +1287,7 @@ def _backcompat_value(pos, ref_pos): f"Not setting position{_pl(extra)} of {len(extra)} {types} " f"channel{_pl(extra)} found in montage:\n{names}\n" "Consider setting the channel types to be of " - f'{docdict["montage_types"]} ' + f"{docdict['montage_types']} " "using inst.set_channel_types before calling inst.set_montage, " "or omit these channels when creating your montage." 
) diff --git a/mne/channels/tests/test_channels.py b/mne/channels/tests/test_channels.py index f51b551a1c8..bb886c51a96 100644 --- a/mne/channels/tests/test_channels.py +++ b/mne/channels/tests/test_channels.py @@ -404,8 +404,7 @@ def test_adjacency_matches_ft(tmp_path): if hash_mne.hexdigest() != hash_ft.hexdigest(): raise ValueError( - f"Hash mismatch between built-in and FieldTrip neighbors " - f"for {fname}" + f"Hash mismatch between built-in and FieldTrip neighbors for {fname}" ) diff --git a/mne/channels/tests/test_montage.py b/mne/channels/tests/test_montage.py index 8add1398409..d9306b5e1bd 100644 --- a/mne/channels/tests/test_montage.py +++ b/mne/channels/tests/test_montage.py @@ -420,12 +420,7 @@ def test_documented(): ), pytest.param( partial(read_dig_hpts, unit="m"), - ( - "eeg Fp1 -95.0 -3. -3.\n" - "eeg AF7 -1 -1 -3\n" - "eeg A3 -2 -2 2\n" - "eeg A 0 0 0" - ), + ("eeg Fp1 -95.0 -3. -3.\neeg AF7 -1 -1 -3\neeg A3 -2 -2 2\neeg A 0 0 0"), make_dig_montage( ch_pos={ "A": [0.0, 0.0, 0.0], diff --git a/mne/commands/mne_make_scalp_surfaces.py b/mne/commands/mne_make_scalp_surfaces.py index 5b7d020b98d..894ede7fa1a 100644 --- a/mne/commands/mne_make_scalp_surfaces.py +++ b/mne/commands/mne_make_scalp_surfaces.py @@ -49,8 +49,7 @@ def run(): "--force", dest="force", action="store_true", - help="Force creation of the surface even if it has " - "some topological defects.", + help="Force creation of the surface even if it has some topological defects.", ) parser.add_option( "-t", diff --git a/mne/commands/mne_setup_source_space.py b/mne/commands/mne_setup_source_space.py index e536a59f90b..273e833b31c 100644 --- a/mne/commands/mne_setup_source_space.py +++ b/mne/commands/mne_setup_source_space.py @@ -62,8 +62,7 @@ def run(): parser.add_option( "--ico", dest="ico", - help="use the recursively subdivided icosahedron " - "to create the source space.", + help="use the recursively subdivided icosahedron to create the source space.", default=None, type="int", ) diff --git a/mne/coreg.py b/mne/coreg.py index f28c6142c96..c7549ee028a 100644 --- a/mne/coreg.py +++ b/mne/coreg.py @@ -876,8 +876,7 @@ def _scale_params(subject_to, subject_from, scale, subjects_dir): subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) if (subject_from is None) != (scale is None): raise TypeError( - "Need to provide either both subject_from and scale " - "parameters, or neither." + "Need to provide either both subject_from and scale parameters, or neither." ) if subject_from is None: @@ -1402,8 +1401,7 @@ def _read_surface(filename, *, on_defects): complete_surface_info(bem, copy=False) except Exception: raise ValueError( - f"Error loading surface from {filename} (see " - "Terminal for details)." + f"Error loading surface from {filename} (see Terminal for details)." 
) return bem @@ -2145,8 +2143,7 @@ def omit_head_shape_points(self, distance): mask = self._orig_hsp_point_distance <= distance n_excluded = np.sum(~mask) logger.info( - "Coregistration: Excluding %i head shape points with " - "distance >= %.3f m.", + "Coregistration: Excluding %i head shape points with distance >= %.3f m.", n_excluded, distance, ) diff --git a/mne/cov.py b/mne/cov.py index 8b86119c1d1..94239472fa2 100644 --- a/mne/cov.py +++ b/mne/cov.py @@ -1293,7 +1293,7 @@ def _compute_covariance_auto( data_ = data.copy() name = method_.__name__ if callable(method_) else method_ logger.info( - f'Estimating {cov_kind + (" " if cov_kind else "")}' + f"Estimating {cov_kind + (' ' if cov_kind else '')}" f"covariance using {name.upper()}" ) mp = method_params[method_] @@ -1712,7 +1712,7 @@ def _get_ch_whitener(A, pca, ch_type, rank): logger.info( f" Setting small {ch_type} eigenvalues to zero " - f'({"using" if pca else "without"} PCA)' + f"({'using' if pca else 'without'} PCA)" ) if pca: # No PCA case. # This line will reduce the actual number of variables in data @@ -2400,7 +2400,7 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): data = tag.data diag = True logger.info( - " %d x %d diagonal covariance (kind = " "%d) found.", + " %d x %d diagonal covariance (kind = %d) found.", dim, dim, cov_kind, @@ -2416,7 +2416,7 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): data.flat[:: dim + 1] /= 2.0 diag = False logger.info( - " %d x %d full covariance (kind = %d) " "found.", + " %d x %d full covariance (kind = %d) found.", dim, dim, cov_kind, @@ -2425,7 +2425,7 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): diag = False data = tag.data logger.info( - " %d x %d sparse covariance (kind = %d)" " found.", + " %d x %d sparse covariance (kind = %d) found.", dim, dim, cov_kind, diff --git a/mne/datasets/_fetch.py b/mne/datasets/_fetch.py index 1e38606f908..8f44459ad97 100644 --- a/mne/datasets/_fetch.py +++ b/mne/datasets/_fetch.py @@ -143,8 +143,7 @@ def fetch_dataset( if auth is not None: if len(auth) != 2: raise RuntimeError( - "auth should be a 2-tuple consisting " - "of a username and password/token." + "auth should be a 2-tuple consisting of a username and password/token." ) # processor to uncompress files diff --git a/mne/datasets/config.py b/mne/datasets/config.py index ccd4babacd9..75eff184cd1 100644 --- a/mne/datasets/config.py +++ b/mne/datasets/config.py @@ -92,8 +92,8 @@ phantom_kit="0.2", ucl_opm_auditory="0.2", ) -TESTING_VERSIONED = f'mne-testing-data-{RELEASES["testing"]}' -MISC_VERSIONED = f'mne-misc-data-{RELEASES["misc"]}' +TESTING_VERSIONED = f"mne-testing-data-{RELEASES['testing']}" +MISC_VERSIONED = f"mne-misc-data-{RELEASES['misc']}" # To update any other dataset besides `testing` or `misc`, upload the new # version of the data archive itself (e.g., to https://osf.io or wherever) and @@ -118,7 +118,7 @@ hash="md5:d94fe9f3abe949a507eaeb865fb84a3f", url=( "https://codeload.github.com/mne-tools/mne-testing-data/" - f'tar.gz/{RELEASES["testing"]}' + f"tar.gz/{RELEASES['testing']}" ), # In case we ever have to resort to osf.io again... 
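Most of this autoupdate's churn appears to track the ruff bump from v0.8.6 to v0.9.1 in the .pre-commit-config.yaml hunk above: the newer formatter reaches inside f-strings (flipping nested quotes to single quotes so the outer string can stay double-quoted) and keeps an assert condition on one line while parenthesizing its message. A hedged before/after sketch of both patterns, with the examples taken from hunks in this patch (the values assigned are illustrative):

# (a) f-string quote normalization (see the evoked_whitening.py hunk)
c = {"method": "shrunk", "loglik": -1234.5}
# before: print(f'{c["method"]} : {c["loglik"]}')
print(f"{c['method']} : {c['loglik']}")

# (b) assert-message parenthesization (see the test_artemis123.py hunk)
dist, dist_tol = 0.001, 0.017
# before:
# assert (
#     dist <= dist_tol
# ), f"{1000 * dist:0.3f} > {1000 * dist_tol:0.3f} mm translation"
assert dist <= dist_tol, (
    f"{1000 * dist:0.3f} > {1000 * dist_tol:0.3f} mm translation"
)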
# archive_name='mne-testing-data.tar.gz', @@ -131,8 +131,7 @@ archive_name=f"{MISC_VERSIONED}.tar.gz", # 'mne-misc-data', hash="md5:e343d3a00cb49f8a2f719d14f4758afe", url=( - "https://codeload.github.com/mne-tools/mne-misc-data/tar.gz/" - f'{RELEASES["misc"]}' + f"https://codeload.github.com/mne-tools/mne-misc-data/tar.gz/{RELEASES['misc']}" ), folder_name="MNE-misc-data", config_key="MNE_DATASETS_MISC_PATH", diff --git a/mne/datasets/sleep_physionet/age.py b/mne/datasets/sleep_physionet/age.py index c14282ed202..b5ea1764946 100644 --- a/mne/datasets/sleep_physionet/age.py +++ b/mne/datasets/sleep_physionet/age.py @@ -122,10 +122,7 @@ def fetch_data( ) _on_missing(on_missing, msg) if 13 in subjects and 2 in recording: - msg = ( - "Requested recording 2 for subject 13, but it is not available " - "in corpus." - ) + msg = "Requested recording 2 for subject 13, but it is not available in corpus." _on_missing(on_missing, msg) fnames = [] diff --git a/mne/epochs.py b/mne/epochs.py index 04b1a288bfe..679643ab969 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -1671,8 +1671,7 @@ def _get_data( # we start out with an empty array, allocate only if necessary data = np.empty((0, len(self.info["ch_names"]), len(self.times))) msg = ( - f"for {n_events} events and {len(self._raw_times)} " - "original time points" + f"for {n_events} events and {len(self._raw_times)} original time points" ) if self._decim > 1: msg += " (prior to decimation)" @@ -2301,8 +2300,7 @@ def save( logger.info(f"Splitting into {n_parts} parts") if n_parts > 100: # This must be an error raise ValueError( - f"Split size {split_size} would result in writing " - f"{n_parts} files" + f"Split size {split_size} would result in writing {n_parts} files" ) if len(self.drop_log) > 100000: @@ -3143,7 +3141,7 @@ def _ensure_list(x): raise ValueError( f"The event names in keep_first and keep_last must " f"be mutually exclusive. Specified in both: " - f'{", ".join(sorted(keep_first_and_last))}' + f"{', '.join(sorted(keep_first_and_last))}" ) del keep_first_and_last @@ -3163,7 +3161,7 @@ def _diff_input_strings_vs_event_id(input_strings, input_name, event_id): if event_name_diff: raise ValueError( f"Present in {input_name}, but missing from event_id: " - f'{", ".join(event_name_diff)}' + f"{', '.join(event_name_diff)}" ) _diff_input_strings_vs_event_id( @@ -3556,8 +3554,7 @@ def __init__( if not isinstance(raw, BaseRaw): raise ValueError( - "The first argument to `Epochs` must be an " - "instance of mne.io.BaseRaw" + "The first argument to `Epochs` must be an instance of mne.io.BaseRaw" ) info = deepcopy(raw.info) annotations = raw.annotations.copy() @@ -4441,8 +4438,7 @@ def _get_epoch_from_raw(self, idx, verbose=None): else: # read the correct subset of the data raise RuntimeError( - "Correct epoch could not be found, please " - "contact mne-python developers" + "Correct epoch could not be found, please contact mne-python developers" ) # the following is equivalent to this, but faster: # diff --git a/mne/event.py b/mne/event.py index 723615ea56a..a19270db1e6 100644 --- a/mne/event.py +++ b/mne/event.py @@ -1649,7 +1649,7 @@ def match_event_names(event_names, keys, *, on_missing="raise"): _on_missing( on_missing=on_missing, msg=f'Event name "{key}" could not be found. 
The following events ' - f'are present in the data: {", ".join(event_names)}', + f"are present in the data: {', '.join(event_names)}", error_klass=KeyError, ) diff --git a/mne/evoked.py b/mne/evoked.py index 5fb09db9d1b..c04f83531e3 100644 --- a/mne/evoked.py +++ b/mne/evoked.py @@ -962,7 +962,7 @@ def __neg__(self): if out.comment is not None and " + " in out.comment: out.comment = f"({out.comment})" # multiple conditions in evoked - out.comment = f'- {out.comment or "unknown"}' + out.comment = f"- {out.comment or 'unknown'}" return out def get_peak( @@ -1053,8 +1053,7 @@ def get_peak( raise ValueError('Channel type must be "grad" for merge_grads') elif mode == "neg": raise ValueError( - "Negative mode (mode=neg) does not make " - "sense with merge_grads=True" + "Negative mode (mode=neg) does not make sense with merge_grads=True" ) meg = eeg = misc = seeg = dbs = ecog = fnirs = False @@ -1650,12 +1649,12 @@ def combine_evoked(all_evoked, weights): if e.comment is not None and " + " in e.comment: # multiple conditions this_comment = f"({e.comment})" else: - this_comment = f'{e.comment or "unknown"}' + this_comment = f"{e.comment or 'unknown'}" # assemble everything if idx == 0: comment += f"{sign}{weight}{multiplier}{this_comment}" else: - comment += f' {sign or "+"} {weight}{multiplier}{this_comment}' + comment += f" {sign or '+'} {weight}{multiplier}{this_comment}" # special-case: combine_evoked([e1, -e2], [1, -1]) evoked.comment = comment.replace(" - - ", " + ") return evoked @@ -1872,8 +1871,7 @@ def _read_evoked(fname, condition=None, kind="average", allow_maxshield=False): if len(chs) != nchan: raise ValueError( - "Number of channels and number of " - "channel definitions are different" + "Number of channels and number of channel definitions are different" ) ch_names_mapping = _read_extended_ch_info(chs, my_evoked, fid) diff --git a/mne/export/_egimff.py b/mne/export/_egimff.py index 3792ea4a6a5..185afb5f558 100644 --- a/mne/export/_egimff.py +++ b/mne/export/_egimff.py @@ -53,7 +53,7 @@ def export_evokeds_mff(fname, evoked, history=None, *, overwrite=False, verbose= info = evoked[0].info if np.round(info["sfreq"]) != info["sfreq"]: raise ValueError( - f'Sampling frequency must be a whole number. sfreq: {info["sfreq"]}' + f"Sampling frequency must be a whole number. sfreq: {info['sfreq']}" ) sampling_rate = int(info["sfreq"]) diff --git a/mne/export/_export.py b/mne/export/_export.py index 490bf986895..6e63064bf7c 100644 --- a/mne/export/_export.py +++ b/mne/export/_export.py @@ -216,7 +216,6 @@ def _infer_check_export_fmt(fmt, fname, supported_formats): supported_str = ", ".join(supported) raise ValueError( - f"Format '{fmt}' is not supported. " - f"Supported formats are {supported_str}." + f"Format '{fmt}' is not supported. Supported formats are {supported_str}." 
) return fmt diff --git a/mne/export/tests/test_export.py b/mne/export/tests/test_export.py index ca0853837fc..191e91b1eed 100644 --- a/mne/export/tests/test_export.py +++ b/mne/export/tests/test_export.py @@ -235,7 +235,7 @@ def test_edf_padding(tmp_path, pad_width): RuntimeWarning, match=( "EDF format requires equal-length data blocks.*" - f"{pad_width/1000:.3g} seconds of edge values were appended.*" + f"{pad_width / 1000:.3g} seconds of edge values were appended.*" ), ): raw.export(temp_fname) @@ -580,7 +580,7 @@ def test_export_to_mff_incompatible_sfreq(): """Test non-whole number sampling frequency throws ValueError.""" pytest.importorskip("mffpy", "0.5.7") evoked = read_evokeds(fname_evoked) - with pytest.raises(ValueError, match=f'sfreq: {evoked[0].info["sfreq"]}'): + with pytest.raises(ValueError, match=f"sfreq: {evoked[0].info['sfreq']}"): export_evokeds("output.mff", evoked) diff --git a/mne/filter.py b/mne/filter.py index ee5b34cd657..a7d7c883e2f 100644 --- a/mne/filter.py +++ b/mne/filter.py @@ -411,8 +411,7 @@ def _prep_for_filtering(x, copy, picks=None): picks = np.tile(picks, n_epochs) + offset elif len(orig_shape) > 3: raise ValueError( - "picks argument is not supported for data with more" - " than three dimensions" + "picks argument is not supported for data with more than three dimensions" ) assert all(0 <= pick < x.shape[0] for pick in picks) # guaranteed by above @@ -2873,7 +2872,7 @@ def design_mne_c_filter( h_width = (int(((n_freqs - 1) * h_trans_bandwidth) / (0.5 * sfreq)) + 1) // 2 h_start = int(((n_freqs - 1) * h_freq) / (0.5 * sfreq)) logger.info( - "filter : %7.3f ... %6.1f Hz bins : %d ... %d of %d " "hpw : %d lpw : %d", + "filter : %7.3f ... %6.1f Hz bins : %d ... %d of %d hpw : %d lpw : %d", l_freq, h_freq, l_start, diff --git a/mne/forward/_field_interpolation.py b/mne/forward/_field_interpolation.py index b505b5e45df..e98a147b560 100644 --- a/mne/forward/_field_interpolation.py +++ b/mne/forward/_field_interpolation.py @@ -96,7 +96,7 @@ def _pinv_trunc(x, miss): varexp /= varexp[-1] n = np.where(varexp >= (1.0 - miss))[0][0] + 1 logger.info( - " Truncating at %d/%d components to omit less than %g " "(%0.2g)", + " Truncating at %d/%d components to omit less than %g (%0.2g)", n, len(s), miss, @@ -111,8 +111,7 @@ def _pinv_tikhonov(x, reg): # _reg_pinv requires square Hermitian, which we have here inv, _, n = _reg_pinv(x, reg=reg, rank=None) logger.info( - f" Truncating at {n}/{len(x)} components and regularizing " - f"with α={reg:0.1e}" + f" Truncating at {n}/{len(x)} components and regularizing with α={reg:0.1e}" ) return inv, n diff --git a/mne/forward/_make_forward.py b/mne/forward/_make_forward.py index 64aadf69fec..6c77f47e312 100644 --- a/mne/forward/_make_forward.py +++ b/mne/forward/_make_forward.py @@ -160,8 +160,7 @@ def _create_meg_coil(coilset, ch, acc, do_es): break else: raise RuntimeError( - "Desired coil definition not found " - f"(type = {ch['coil_type']} acc = {acc})" + f"Desired coil definition not found (type = {ch['coil_type']} acc = {acc})" ) # Apply a coordinate transformation if so desired @@ -295,8 +294,8 @@ def _setup_bem(bem, bem_extra, neeg, mri_head_t, allow_none=False, verbose=None) else: if bem["surfs"][0]["coord_frame"] != FIFF.FIFFV_COORD_MRI: raise RuntimeError( - f'BEM is in {_coord_frame_name(bem["surfs"][0]["coord_frame"])} ' - 'coordinates, should be in MRI' + f"BEM is in {_coord_frame_name(bem['surfs'][0]['coord_frame'])} " + "coordinates, should be in MRI" ) if neeg > 0 and len(bem["surfs"]) == 1: raise RuntimeError( 
@@ -335,7 +334,7 @@ def _prep_meg_channels( del picks # Get channel info and names for MEG channels - logger.info(f'Read {len(info_meg["chs"])} MEG channels from info') + logger.info(f"Read {len(info_meg['chs'])} MEG channels from info") # Get MEG compensation channels compensator = post_picks = None @@ -352,7 +351,7 @@ def _prep_meg_channels( 'channels. Consider using "ignore_ref=True" in ' "calculation" ) - logger.info(f'{len(info["comps"])} compensation data sets in info') + logger.info(f"{len(info['comps'])} compensation data sets in info") # Compose a compensation data set if necessary # adapted from mne_make_ctf_comp() from mne_ctf_comp.c logger.info("Setting up compensation data...") diff --git a/mne/forward/forward.py b/mne/forward/forward.py index e3e5c08d2f8..f1c2c2d11d7 100644 --- a/mne/forward/forward.py +++ b/mne/forward/forward.py @@ -512,7 +512,7 @@ def _merge_fwds(fwds, *, verbose=None): a[k]["row_names"] = a[k]["row_names"] + b[k]["row_names"] a["nchan"] = a["nchan"] + b["nchan"] if len(fwds) > 1: - logger.info(f' Forward solutions combined: {", ".join(combined)}') + logger.info(f" Forward solutions combined: {', '.join(combined)}") return fwd @@ -677,8 +677,7 @@ def read_forward_solution(fname, include=(), exclude=(), *, ordered=True, verbos # Make sure forward solution is in either the MRI or HEAD coordinate frame if fwd["coord_frame"] not in (FIFF.FIFFV_COORD_MRI, FIFF.FIFFV_COORD_HEAD): raise ValueError( - "Only forward solutions computed in MRI or head " - "coordinates are acceptable" + "Only forward solutions computed in MRI or head coordinates are acceptable" ) # Transform each source space to the HEAD or MRI coordinate frame, @@ -1205,8 +1204,7 @@ def _triage_loose(src, loose, fixed="auto"): if fixed is True: if not all(v == 0.0 for v in loose.values()): raise ValueError( - 'When using fixed=True, loose must be 0. or "auto", ' - f"got {orig_loose}" + f'When using fixed=True, loose must be 0. or "auto", got {orig_loose}' ) elif fixed is False: if any(v == 0.0 for v in loose.values()): @@ -1666,8 +1664,7 @@ def apply_forward( for ch_name in fwd["sol"]["row_names"]: if ch_name not in info["ch_names"]: raise ValueError( - f"Channel {ch_name} of forward operator not present in " - "evoked_template." + f"Channel {ch_name} of forward operator not present in evoked_template." 
) # project the source estimate to the sensor space diff --git a/mne/gui/_coreg.py b/mne/gui/_coreg.py index 98e3fbfc0b3..b365a2eed5a 100644 --- a/mne/gui/_coreg.py +++ b/mne/gui/_coreg.py @@ -1611,8 +1611,7 @@ def _configure_dock(self): func=self._set_subjects_dir, is_directory=True, icon=True, - tooltip="Load the path to the directory containing the " - "FreeSurfer subjects", + tooltip="Load the path to the directory containing the FreeSurfer subjects", layout=subjects_dir_layout, ) self._renderer._layout_add_widget( @@ -1741,8 +1740,7 @@ def _configure_dock(self): self._widgets["omit"] = self._renderer._dock_add_button( name="Omit", callback=self._omit_hsp, - tooltip="Exclude the head shape points that are far away from " - "the MRI head", + tooltip="Exclude the head shape points that are far away from the MRI head", layout=omit_hsp_layout_2, ) self._widgets["reset_omit"] = self._renderer._dock_add_button( diff --git a/mne/html_templates/_templates.py b/mne/html_templates/_templates.py index 9427f2d6a25..1f68303a51e 100644 --- a/mne/html_templates/_templates.py +++ b/mne/html_templates/_templates.py @@ -66,7 +66,7 @@ def _format_time_range(inst) -> str: def _format_projs(info) -> list[str]: """Format projectors.""" - projs = [f'{p["desc"]} ({"on" if p["active"] else "off"})' for p in info["projs"]] + projs = [f"{p['desc']} ({'on' if p['active'] else 'off'})" for p in info["projs"]] return projs diff --git a/mne/io/array/__init__.py b/mne/io/array/__init__.py index aea21ef42ce..ad53f7c817f 100644 --- a/mne/io/array/__init__.py +++ b/mne/io/array/__init__.py @@ -4,4 +4,4 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -from .array import RawArray +from ._array import RawArray diff --git a/mne/io/array/array.py b/mne/io/array/_array.py similarity index 100% rename from mne/io/array/array.py rename to mne/io/array/_array.py diff --git a/mne/io/artemis123/tests/test_artemis123.py b/mne/io/artemis123/tests/test_artemis123.py index 039108eb915..610f32ba5da 100644 --- a/mne/io/artemis123/tests/test_artemis123.py +++ b/mne/io/artemis123/tests/test_artemis123.py @@ -35,9 +35,9 @@ def _assert_trans(actual, desired, dist_tol=0.017, angle_tol=5.0): angle = np.rad2deg(_angle_between_quats(quat_est, quat)) dist = np.linalg.norm(trans - trans_est) - assert ( - dist <= dist_tol - ), f"{1000 * dist:0.3f} > {1000 * dist_tol:0.3f} mm translation" + assert dist <= dist_tol, ( + f"{1000 * dist:0.3f} > {1000 * dist_tol:0.3f} mm translation" + ) assert angle <= angle_tol, f"{angle:0.3f} > {angle_tol:0.3f}° rotation" diff --git a/mne/io/base.py b/mne/io/base.py index 4f5f2436bd7..280330367f7 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -1013,8 +1013,7 @@ def get_data( if n_rejected > 0: if reject_by_annotation == "omit": msg = ( - "Omitting {} of {} ({:.2%}) samples, retaining {}" - " ({:.2%}) samples." + "Omitting {} of {} ({:.2%}) samples, retaining {} ({:.2%}) samples." 
) logger.info( msg.format( @@ -2157,7 +2156,7 @@ def append(self, raws, preload=None): for edge_samp in edge_samps: onset = _sync_onset(self, edge_samp / self.info["sfreq"], True) logger.debug( - f"Marking edge at {edge_samp} samples " f"(maps to {onset:0.3f} sec)" + f"Marking edge at {edge_samp} samples (maps to {onset:0.3f} sec)" ) self.annotations.append(onset, 0.0, "BAD boundary") self.annotations.append(onset, 0.0, "EDGE boundary") diff --git a/mne/io/ctf/ctf.py b/mne/io/ctf/ctf.py index 44a4e39adf6..971ac51c2f6 100644 --- a/mne/io/ctf/ctf.py +++ b/mne/io/ctf/ctf.py @@ -267,7 +267,7 @@ def _get_sample_info(fname, res4, system_clock): fid.seek(offset, 0) this_data = np.fromfile(fid, ">i4", res4["nsamp"]) if len(this_data) != res4["nsamp"]: - raise RuntimeError(f"Cannot read data for trial {t+1}.") + raise RuntimeError(f"Cannot read data for trial {t + 1}.") end = np.where(this_data == 0)[0] if len(end) > 0: n_samp = samp_offset + end[0] diff --git a/mne/io/ctf/info.py b/mne/io/ctf/info.py index 1b96d8bd88f..685a20792d3 100644 --- a/mne/io/ctf/info.py +++ b/mne/io/ctf/info.py @@ -50,8 +50,7 @@ def _pick_isotrak_and_hpi_coils(res4, coils, t): if p["coord_frame"] == FIFF.FIFFV_MNE_COORD_CTF_DEVICE: if t is None or t["t_ctf_dev_dev"] is None: raise RuntimeError( - "No coordinate transformation " - "available for HPI coil locations" + "No coordinate transformation available for HPI coil locations" ) d = dict( kind=kind, diff --git a/mne/io/ctf/tests/test_ctf.py b/mne/io/ctf/tests/test_ctf.py index 4a5dd846655..448ea90baba 100644 --- a/mne/io/ctf/tests/test_ctf.py +++ b/mne/io/ctf/tests/test_ctf.py @@ -243,9 +243,9 @@ def test_read_ctf(tmp_path): # Make sure all digitization points are in the MNE head coord frame for p in raw.info["dig"]: - assert ( - p["coord_frame"] == FIFF.FIFFV_COORD_HEAD - ), "dig points must be in FIFF.FIFFV_COORD_HEAD" + assert p["coord_frame"] == FIFF.FIFFV_COORD_HEAD, ( + "dig points must be in FIFF.FIFFV_COORD_HEAD" + ) if fname.endswith("catch-alp-good-f.ds"): # omit points from .pos file with raw.info._unlock(): diff --git a/mne/io/edf/edf.py b/mne/io/edf/edf.py index bb79c46f24a..fadd1b83857 100644 --- a/mne/io/edf/edf.py +++ b/mne/io/edf/edf.py @@ -628,7 +628,7 @@ def _get_info( if len(chs_without_types): msg = ( "Could not determine channel type of the following channels, " - f'they will be set as EEG:\n{", ".join(chs_without_types)}' + f"they will be set as EEG:\n{', '.join(chs_without_types)}" ) logger.info(msg) @@ -712,8 +712,8 @@ def _get_info( if info["highpass"] > info["lowpass"]: warn( - f'Highpass cutoff frequency {info["highpass"]} is greater ' - f'than lowpass cutoff frequency {info["lowpass"]}, ' + f"Highpass cutoff frequency {info['highpass']} is greater " + f"than lowpass cutoff frequency {info['lowpass']}, " "setting values to 0 and Nyquist." 
) info["highpass"] = 0.0 diff --git a/mne/io/egi/egimff.py b/mne/io/egi/egimff.py index b2f08020e15..c3a10fb72cd 100644 --- a/mne/io/egi/egimff.py +++ b/mne/io/egi/egimff.py @@ -106,7 +106,7 @@ def _read_mff_header(filepath): if bad: raise RuntimeError( "EGI epoch first/last samps could not be parsed:\n" - f'{list(epochs["first_samps"])}\n{list(epochs["last_samps"])}' + f"{list(epochs['first_samps'])}\n{list(epochs['last_samps'])}" ) summaryinfo.update(epochs) # index which samples in raw are actually readable from disk (i.e., not diff --git a/mne/io/fieldtrip/fieldtrip.py b/mne/io/fieldtrip/fieldtrip.py index 5d94d3e0a80..c8521722003 100644 --- a/mne/io/fieldtrip/fieldtrip.py +++ b/mne/io/fieldtrip/fieldtrip.py @@ -7,7 +7,7 @@ from ...epochs import EpochsArray from ...evoked import EvokedArray from ...utils import _check_fname, _import_pymatreader_funcs -from ..array.array import RawArray +from ..array._array import RawArray from .utils import ( _create_event_metadata, _create_events, diff --git a/mne/io/fil/tests/test_fil.py b/mne/io/fil/tests/test_fil.py index 06d3d924319..df15dd13353 100644 --- a/mne/io/fil/tests/test_fil.py +++ b/mne/io/fil/tests/test_fil.py @@ -87,9 +87,9 @@ def _fil_megmag(raw_test, raw_mat): mat_list = raw_mat["label"] mat_inds = _match_str(test_list, mat_list) - assert len(mat_inds) == len( - test_inds - ), "Number of magnetometer channels in RAW does not match .mat file!" + assert len(mat_inds) == len(test_inds), ( + "Number of magnetometer channels in RAW does not match .mat file!" + ) a = raw_test._data[test_inds, :] b = raw_mat["trial"][mat_inds, :] * 1e-15 # fT to T @@ -106,9 +106,9 @@ def _fil_stim(raw_test, raw_mat): mat_list = raw_mat["label"] mat_inds = _match_str(test_list, mat_list) - assert len(mat_inds) == len( - test_inds - ), "Number of stim channels in RAW does not match .mat file!" + assert len(mat_inds) == len(test_inds), ( + "Number of stim channels in RAW does not match .mat file!" + ) a = raw_test._data[test_inds, :] b = raw_mat["trial"][mat_inds, :] # fT to T @@ -122,9 +122,9 @@ def _fil_sensorpos(raw_test, raw_mat): grad_list = raw_mat["coil_label"] grad_inds = _match_str(test_list, grad_list) - assert len(grad_inds) == len( - test_inds - ), "Number of channels with position data in RAW does not match .mat file!" + assert len(grad_inds) == len(test_inds), ( + "Number of channels with position data in RAW does not match .mat file!" 
+ ) mat_pos = raw_mat["coil_pos"][grad_inds, :] mat_ori = raw_mat["coil_ori"][grad_inds, :] diff --git a/mne/io/neuralynx/tests/test_neuralynx.py b/mne/io/neuralynx/tests/test_neuralynx.py index ea5cdbccdfb..18578ef4ab7 100644 --- a/mne/io/neuralynx/tests/test_neuralynx.py +++ b/mne/io/neuralynx/tests/test_neuralynx.py @@ -143,9 +143,9 @@ def test_neuralynx(): assert raw.info["meas_date"] == meas_date_utc, "meas_date not set correctly" # test that channel selection worked - assert ( - raw.ch_names == expected_chan_names - ), "labels in raw.ch_names don't match expected channel names" + assert raw.ch_names == expected_chan_names, ( + "labels in raw.ch_names don't match expected channel names" + ) mne_y = raw.get_data() # in V @@ -216,9 +216,9 @@ def test_neuralynx_gaps(): n_expected_gaps = 3 n_expected_missing_samples = 130 assert len(raw.annotations) == n_expected_gaps, "Wrong number of gaps detected" - assert ( - (mne_y[0, :] == 0).sum() == n_expected_missing_samples - ), "Number of true and inferred missing samples differ" + assert (mne_y[0, :] == 0).sum() == n_expected_missing_samples, ( + "Number of true and inferred missing samples differ" + ) # read in .mat files containing original gaps matchans = ["LAHC1_3_gaps.mat", "LAHC2_3_gaps.mat"] diff --git a/mne/io/nirx/nirx.py b/mne/io/nirx/nirx.py index 53a812e7a21..5d9b79b57cc 100644 --- a/mne/io/nirx/nirx.py +++ b/mne/io/nirx/nirx.py @@ -210,7 +210,7 @@ def __init__(self, fname, saturated, *, preload=False, encoding=None, verbose=No ): warn( "Only import of data from NIRScout devices have been " - f'thoroughly tested. You are using a {hdr["GeneralInfo"]["Device"]}' + f"thoroughly tested. You are using a {hdr['GeneralInfo']['Device']}" " device." ) diff --git a/mne/io/tests/test_raw.py b/mne/io/tests/test_raw.py index b559ce07068..8f773533ae4 100644 --- a/mne/io/tests/test_raw.py +++ b/mne/io/tests/test_raw.py @@ -533,7 +533,7 @@ def _test_raw_crop(reader, t_prop, kwargs): n_samp = 50 # crop to this number of samples (per instance) crop_t = n_samp / raw_1.info["sfreq"] t_start = t_prop * crop_t # also crop to some fraction into the first inst - extra = f' t_start={t_start}, preload={kwargs.get("preload", False)}' + extra = f" t_start={t_start}, preload={kwargs.get('preload', False)}" stop = (n_samp - 1) / raw_1.info["sfreq"] raw_1.crop(0, stop) assert len(raw_1.times) == 50 diff --git a/mne/label.py b/mne/label.py index f68144106c3..02bf9dc09c0 100644 --- a/mne/label.py +++ b/mne/label.py @@ -264,8 +264,7 @@ def __init__( if not (len(vertices) == len(values) == len(pos)): raise ValueError( - "vertices, values and pos need to have same " - "length (number of vertices)" + "vertices, values and pos need to have same length (number of vertices)" ) # name @@ -416,7 +415,7 @@ def __sub__(self, other): else: keep = np.arange(len(self.vertices)) - name = f'{self.name or "unnamed"} - {other.name or "unnamed"}' + name = f"{self.name or 'unnamed'} - {other.name or 'unnamed'}" return Label( self.vertices[keep], self.pos[keep], @@ -976,8 +975,7 @@ def _get_label_src(label, src): src = _ensure_src(src) if src.kind != "surface": raise RuntimeError( - "Cannot operate on SourceSpaces that are not " - f"surface type, got {src.kind}" + f"Cannot operate on SourceSpaces that are not surface type, got {src.kind}" ) if label.hemi == "lh": hemi_src = src[0] @@ -1585,8 +1583,7 @@ def stc_to_label( vertno = np.where(src[hemi_idx]["inuse"])[0] if not len(np.setdiff1d(this_vertno, vertno)) == 0: raise RuntimeError( - "stc contains vertices not present " - "in 
source space, did you morph?" + "stc contains vertices not present in source space, did you morph?" ) tmp = np.zeros((len(vertno), this_data.shape[1])) this_vertno_idx = np.searchsorted(vertno, this_vertno) @@ -2151,8 +2148,7 @@ def _read_annot(fname): cands = _read_annot_cands(dir_name) if len(cands) == 0: raise OSError( - f"No such file {fname}, no candidate parcellations " - "found in directory" + f"No such file {fname}, no candidate parcellations found in directory" ) else: raise OSError( diff --git a/mne/minimum_norm/inverse.py b/mne/minimum_norm/inverse.py index e5129a4822f..7c789503ac1 100644 --- a/mne/minimum_norm/inverse.py +++ b/mne/minimum_norm/inverse.py @@ -673,7 +673,7 @@ def prepare_inverse_operator( inv["eigen_leads"]["data"] = sqrt(scale) * inv["eigen_leads"]["data"] logger.info( - " Scaled noise and source covariance from nave = %d to" " nave = %d", + " Scaled noise and source covariance from nave = %d to nave = %d", inv["nave"], nave, ) @@ -2011,7 +2011,7 @@ def make_inverse_operator( logger.info( f" scaling factor to adjust the trace = {trace_GRGT:g} " f"(nchan = {eigen_fields.shape[0]} " - f'nzero = {(noise_cov["eig"] <= 0).sum()})' + f"nzero = {(noise_cov['eig'] <= 0).sum()})" ) # MNE-ify everything for output diff --git a/mne/minimum_norm/tests/test_inverse.py b/mne/minimum_norm/tests/test_inverse.py index aa3f8294027..5b5c941a9ac 100644 --- a/mne/minimum_norm/tests/test_inverse.py +++ b/mne/minimum_norm/tests/test_inverse.py @@ -130,8 +130,7 @@ def _compare(a, b): for k, v in a.items(): if k not in b and k not in skip_types: raise ValueError( - "First one had one second one didn't:\n" - f"{k} not in {b.keys()}" + f"First one had one second one didn't:\n{k} not in {b.keys()}" ) if k not in skip_types: last_keys.pop() diff --git a/mne/morph.py b/mne/morph.py index 9c475bff1e9..a8278731f3c 100644 --- a/mne/morph.py +++ b/mne/morph.py @@ -200,8 +200,7 @@ def compute_source_morph( if kind not in "surface" and xhemi: raise ValueError( - "Inter-hemispheric morphing can only be used " - "with surface source estimates." + "Inter-hemispheric morphing can only be used with surface source estimates." 
) if sparse and kind != "surface": raise ValueError("Only surface source estimates can compute a sparse morph.") @@ -1301,8 +1300,7 @@ def grade_to_vertices(subject, grade, subjects_dir=None, n_jobs=None, verbose=No if isinstance(grade, list): if not len(grade) == 2: raise ValueError( - "grade as a list must have two elements " - "(arrays of output vertices)" + "grade as a list must have two elements (arrays of output vertices)" ) vertices = grade else: @@ -1385,8 +1383,7 @@ def _surf_upsampling_mat(idx_from, e, smooth): smooth = _ensure_int(smooth, "smoothing steps") if smooth <= 0: # == 0 is handled in a shortcut above raise ValueError( - "The number of smoothing operations has to be at least 0, got " - f"{smooth}" + f"The number of smoothing operations has to be at least 0, got {smooth}" ) smooth = smooth - 1 # idx will gradually expand from idx_from -> np.arange(n_tot) diff --git a/mne/preprocessing/_fine_cal.py b/mne/preprocessing/_fine_cal.py index b43983a87eb..06041cd7f8e 100644 --- a/mne/preprocessing/_fine_cal.py +++ b/mne/preprocessing/_fine_cal.py @@ -401,7 +401,7 @@ def _adjust_mag_normals(info, data, origin, ext_order, *, angle_limit, err_limit good = not bool(reason) assert np.allclose(np.linalg.norm(zs, axis=1), 1.0) logger.info(f" Fit mismatch {first_err:0.2f}→{last_err:0.2f}%") - logger.info(f' Data segment {"" if good else "un"}usable{reason}') + logger.info(f" Data segment {'' if good else 'un'}usable{reason}") # Reformat zs and cals to be the n_mags (including bads) assert zs.shape == (len(data), 3) assert cals.shape == (len(data), 1) diff --git a/mne/preprocessing/artifact_detection.py b/mne/preprocessing/artifact_detection.py index 0a4c8b6a24d..8674d6e22b3 100644 --- a/mne/preprocessing/artifact_detection.py +++ b/mne/preprocessing/artifact_detection.py @@ -213,7 +213,7 @@ def annotate_movement( onsets, offsets = hp_ts[onsets], hp_ts[offsets] bad_pct = 100 * (offsets - onsets).sum() / t_tot logger.info( - "Omitting %5.1f%% (%3d segments): " "ω >= %5.1f°/s (max: %0.1f°/s)", + "Omitting %5.1f%% (%3d segments): ω >= %5.1f°/s (max: %0.1f°/s)", bad_pct, len(onsets), rotation_velocity_limit, @@ -233,7 +233,7 @@ def annotate_movement( onsets, offsets = hp_ts[onsets], hp_ts[offsets] bad_pct = 100 * (offsets - onsets).sum() / t_tot logger.info( - "Omitting %5.1f%% (%3d segments): " "v >= %5.4fm/s (max: %5.4fm/s)", + "Omitting %5.1f%% (%3d segments): v >= %5.4fm/s (max: %5.4fm/s)", bad_pct, len(onsets), translation_velocity_limit, @@ -286,7 +286,7 @@ def annotate_movement( onsets, offsets = hp_ts[onsets], hp_ts[offsets] bad_pct = 100 * (offsets - onsets).sum() / t_tot logger.info( - "Omitting %5.1f%% (%3d segments): " "disp >= %5.4fm (max: %5.4fm)", + "Omitting %5.1f%% (%3d segments): disp >= %5.4fm (max: %5.4fm)", bad_pct, len(onsets), mean_distance_limit, @@ -539,7 +539,7 @@ def annotate_break( if ignore: logger.info( f"Ignoring annotations with descriptions starting " - f'with: {", ".join(ignore)}' + f"with: {', '.join(ignore)}" ) else: annotations = annotations_from_events( diff --git a/mne/preprocessing/eog.py b/mne/preprocessing/eog.py index 20e5481f89c..13b6f2ef672 100644 --- a/mne/preprocessing/eog.py +++ b/mne/preprocessing/eog.py @@ -213,12 +213,12 @@ def _get_eog_channel_index(ch_name, inst): if not_found: raise ValueError( f"The specified EOG channel{_pl(not_found)} " - f'cannot be found: {", ".join(not_found)}' + f"cannot be found: {', '.join(not_found)}" ) eog_inds = pick_channels(inst.ch_names, include=ch_names) - logger.info(f'Using EOG channel{_pl(ch_names)}: 
{", ".join(ch_names)}') + logger.info(f"Using EOG channel{_pl(ch_names)}: {', '.join(ch_names)}") return eog_inds diff --git a/mne/preprocessing/hfc.py b/mne/preprocessing/hfc.py index f8a65510a9a..41bf6bbd232 100644 --- a/mne/preprocessing/hfc.py +++ b/mne/preprocessing/hfc.py @@ -68,8 +68,7 @@ def compute_proj_hfc( n_chs = len(coils[5]) if n_chs != info["nchan"]: raise ValueError( - f'Only {n_chs}/{info["nchan"]} picks could be interpreted ' - "as MEG channels." + f"Only {n_chs}/{info['nchan']} picks could be interpreted as MEG channels." ) S = _sss_basis(exp, coils) del coils diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 3ea11e0531e..f35fe24c1ee 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -560,7 +560,7 @@ def __repr__(self): """ICA fit information.""" infos = self._get_infos_for_repr() - s = f'{infos.fit_on or "no"} decomposition, method: {infos.fit_method}' + s = f"{infos.fit_on or 'no'} decomposition, method: {infos.fit_method}" if infos.fit_on is not None: s += ( @@ -568,8 +568,8 @@ def __repr__(self): f"{infos.fit_n_samples} samples), " f"{infos.fit_n_components} ICA components " f"({infos.fit_n_pca_components} PCA components available), " - f'channel types: {", ".join(infos.ch_types)}, ' - f'{len(infos.excludes) or "no"} sources marked for exclusion' + f"channel types: {', '.join(infos.ch_types)}, " + f"{len(infos.excludes) or 'no'} sources marked for exclusion" ) return f"" @@ -698,7 +698,7 @@ def fit( warn( f"The following parameters passed to ICA.fit() will be " f"ignored, as they only affect raw data (and it appears " - f'you passed epochs): {", ".join(ignored_params)}' + f"you passed epochs): {', '.join(ignored_params)}" ) picks = _picks_to_idx( @@ -875,7 +875,7 @@ def _do_proj(self, data, log_suffix=""): logger.info( f" Applying projection operator with {nproj} " f"vector{_pl(nproj)}" - f'{" " if log_suffix else ""}{log_suffix}' + f"{' ' if log_suffix else ''}{log_suffix}" ) if self.noise_cov is None: # otherwise it's in pre_whitener_ data = proj @ data @@ -1162,7 +1162,7 @@ def get_explained_variance_ratio(self, inst, *, components=None, ch_type=None): raise ValueError( f"You requested operation on the channel type " f'"{ch_type}", but only the following channel types are ' - f'supported: {", ".join(allowed_ch_types)}' + f"supported: {', '.join(allowed_ch_types)}" ) del ch_type @@ -2393,8 +2393,7 @@ def _pick_sources(self, data, include, exclude, n_pca_components): unmixing = np.dot(unmixing, pca_components) logger.info( - f" Projecting back using {_n_pca_comp} " - f"PCA component{_pl(_n_pca_comp)}" + f" Projecting back using {_n_pca_comp} PCA component{_pl(_n_pca_comp)}" ) mixing = np.eye(_n_pca_comp) mixing[: self.n_components_, : self.n_components_] = self.mixing_matrix_ @@ -3368,8 +3367,7 @@ def corrmap( is_subject = False else: raise ValueError( - "`template` must be a length-2 tuple or an array the " - "size of the ICA maps." + "`template` must be a length-2 tuple or an array the size of the ICA maps." 
) template_fig, labelled_ics = None, None diff --git a/mne/preprocessing/ieeg/_volume.py b/mne/preprocessing/ieeg/_volume.py index b4997b2e3f8..af2dcf4328b 100644 --- a/mne/preprocessing/ieeg/_volume.py +++ b/mne/preprocessing/ieeg/_volume.py @@ -109,7 +109,7 @@ def _warn_missing_chs(info, dig_image, after_warp=False): if missing_ch: warn( f"Channel{_pl(missing_ch)} " - f'{", ".join(repr(ch) for ch in missing_ch)} not assigned ' + f"{', '.join(repr(ch) for ch in missing_ch)} not assigned " "voxels " + (f" after applying {after_warp}" if after_warp else "") ) diff --git a/mne/preprocessing/infomax_.py b/mne/preprocessing/infomax_.py index f0722ce5267..b445ac7116c 100644 --- a/mne/preprocessing/infomax_.py +++ b/mne/preprocessing/infomax_.py @@ -320,8 +320,7 @@ def infomax( if l_rate > min_l_rate: if verbose: logger.info( - f"... lowering learning rate to {l_rate:g}" - "\n... re-starting..." + f"... lowering learning rate to {l_rate:g}\n... re-starting..." ) else: raise ValueError( diff --git a/mne/preprocessing/maxwell.py b/mne/preprocessing/maxwell.py index 789c8520f05..8c9c0a93957 100644 --- a/mne/preprocessing/maxwell.py +++ b/mne/preprocessing/maxwell.py @@ -507,7 +507,7 @@ def _prep_maxwell_filter( extended_proj_.append(proj["data"]["data"][:, idx]) extended_proj = np.concatenate(extended_proj_) logger.info( - " Extending external SSS basis using %d projection " "vectors", + " Extending external SSS basis using %d projection vectors", len(extended_proj), ) @@ -566,8 +566,8 @@ def _prep_maxwell_filter( dist = np.sqrt(np.sum(_sq(diff))) if dist > 25.0: warn( - f'Head position change is over 25 mm ' - f'({", ".join(f"{x:0.1f}" for x in diff)}) = {dist:0.1f} mm' + f"Head position change is over 25 mm " + f"({', '.join(f'{x:0.1f}' for x in diff)}) = {dist:0.1f} mm" ) # Reconstruct raw file object with spatiotemporal processed data @@ -2579,7 +2579,7 @@ def find_bad_channels_maxwell( freq_loc = "below" if raw.info["lowpass"] < h_freq else "equal to" msg = ( f"The input data has already been low-pass filtered with a " - f'{raw.info["lowpass"]} Hz cutoff frequency, which is ' + f"{raw.info['lowpass']} Hz cutoff frequency, which is " f"{freq_loc} the requested cutoff of {h_freq} Hz. Not " f"applying low-pass filter." 
) diff --git a/mne/preprocessing/nirs/_beer_lambert_law.py b/mne/preprocessing/nirs/_beer_lambert_law.py index 92a2e55b9fb..c17cf31110c 100644 --- a/mne/preprocessing/nirs/_beer_lambert_law.py +++ b/mne/preprocessing/nirs/_beer_lambert_law.py @@ -76,7 +76,7 @@ def beer_lambert_law(raw, ppf=6.0): for ki, kind in zip((ii, jj), ("hbo", "hbr")): ch = raw.info["chs"][ki] ch.update(coil_type=coil_dict[kind], unit=FIFF.FIFF_UNIT_MOL) - new_name = f'{ch["ch_name"].split(" ")[0]} {kind}' + new_name = f"{ch['ch_name'].split(' ')[0]} {kind}" rename[ch["ch_name"]] = new_name raw.rename_channels(rename) diff --git a/mne/preprocessing/tests/test_maxwell.py b/mne/preprocessing/tests/test_maxwell.py index f5e816258f8..002d4555ff8 100644 --- a/mne/preprocessing/tests/test_maxwell.py +++ b/mne/preprocessing/tests/test_maxwell.py @@ -980,9 +980,9 @@ def _assert_shielding(raw_sss, erm_power, min_factor, max_factor=np.inf, meg="ma sss_power = raw_sss[picks][0].ravel() sss_power = np.sqrt(np.sum(sss_power * sss_power)) factor = erm_power / sss_power - assert ( - min_factor <= factor < max_factor - ), f"Shielding factor not {min_factor:0.3f} <= {factor:0.3f} < {max_factor:0.3f}" + assert min_factor <= factor < max_factor, ( + f"Shielding factor not {min_factor:0.3f} <= {factor:0.3f} < {max_factor:0.3f}" + ) @buggy_mkl_svd diff --git a/mne/preprocessing/xdawn.py b/mne/preprocessing/xdawn.py index 0b1132761b1..606b49370df 100644 --- a/mne/preprocessing/xdawn.py +++ b/mne/preprocessing/xdawn.py @@ -198,8 +198,7 @@ def _fit_xdawn( evals, evecs = linalg.eigh(evo_cov, signal_cov) except np.linalg.LinAlgError as exp: raise ValueError( - "Could not compute eigenvalues, ensure " - f"proper regularization ({exp})" + f"Could not compute eigenvalues, ensure proper regularization ({exp})" ) evecs = evecs[:, np.argsort(evals)[::-1]] # sort eigenvectors evecs /= np.apply_along_axis(np.linalg.norm, 0, evecs) diff --git a/mne/report/report.py b/mne/report/report.py index 732c1a5c8b3..852feebc638 100644 --- a/mne/report/report.py +++ b/mne/report/report.py @@ -324,7 +324,7 @@ def _check_tags(tags) -> tuple[str]: raise TypeError( f"All tags must be strings without spaces or special characters, " f"but got the following instead: " - f'{", ".join([str(tag) for tag in bad_tags])}' + f"{', '.join([str(tag) for tag in bad_tags])}" ) # Check for invalid characters @@ -338,7 +338,7 @@ def _check_tags(tags) -> tuple[str]: if bad_tags: raise ValueError( f"The following tags contained invalid characters: " - f'{", ".join(repr(tag) for tag in bad_tags)}' + f"{', '.join(repr(tag) for tag in bad_tags)}" ) return tags @@ -429,8 +429,7 @@ def _fig_to_img( output = BytesIO() dpi = fig.get_dpi() logger.debug( - f"Saving figure with dimension {fig.get_size_inches()} inches with " - f"{dpi} dpi" + f"Saving figure with dimension {fig.get_size_inches()} inches with {dpi} dpi" ) # https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html @@ -913,7 +912,7 @@ def __repr__(self): if len(titles) > 0: titles = [f" {t}" for t in titles] # indent tr = max(len(s), 50) # trim to larger of opening str and 50 - titles = [f"{t[:tr - 2]} …" if len(t) > tr else t for t in titles] + titles = [f"{t[: tr - 2]} …" if len(t) > tr else t for t in titles] # then trim to the max length of all of these tr = max(len(title) for title in titles) tr = max(tr, len(s)) @@ -2761,9 +2760,7 @@ def _init_render(self, verbose=None): if inc_fname.endswith(".js"): include.append( - f'" + f'' ) elif inc_fname.endswith(".css"): include.append(f'') @@ -3649,7 +3646,7 @@ def 
_add_evoked_joint( ) ) - title = f'Time course ({_handle_default("titles")[ch_type]})' + title = f"Time course ({_handle_default('titles')[ch_type]})" self._add_figure( fig=fig, title=title, @@ -4121,7 +4118,7 @@ def _add_epochs( assert "eeg" in ch_type title_start = "ERP image" - title = f'{title_start} ({_handle_default("titles")[ch_type]})' + title = f"{title_start} ({_handle_default('titles')[ch_type]})" self._add_figure( fig=fig, diff --git a/mne/source_estimate.py b/mne/source_estimate.py index 024d630535c..deeb3a43ede 100644 --- a/mne/source_estimate.py +++ b/mne/source_estimate.py @@ -1388,8 +1388,7 @@ def transform(self, func, idx=None, tmin=None, tmax=None, copy=False): ] else: raise ValueError( - "copy must be True if transformed data has " - "more than 2 dimensions" + "copy must be True if transformed data has more than 2 dimensions" ) else: # return new or overwritten stc @@ -3633,7 +3632,7 @@ def _volume_labels(src, labels, mri_resolution): ] nnz = sum(len(v) != 0 for v in vertices) logger.info( - "%d/%d atlas regions had at least one vertex " "in the source space", + "%d/%d atlas regions had at least one vertex in the source space", nnz, len(out_labels), ) @@ -4006,7 +4005,7 @@ def stc_near_sensors( min_dist = pdist(pos).min() * 1000 logger.info( - f' Minimum {"projected " if project else ""}intra-sensor distance: ' + f" Minimum {'projected ' if project else ''}intra-sensor distance: " f"{min_dist:0.1f} mm" ) @@ -4034,7 +4033,7 @@ def stc_near_sensors( if len(missing): warn( f"Channel{_pl(missing)} missing in STC: " - f'{", ".join(evoked.ch_names[mi] for mi in missing)}' + f"{', '.join(evoked.ch_names[mi] for mi in missing)}" ) nz_data = w @ evoked.data diff --git a/mne/source_space/_source_space.py b/mne/source_space/_source_space.py index f5e8b76a1fa..d64989961cf 100644 --- a/mne/source_space/_source_space.py +++ b/mne/source_space/_source_space.py @@ -743,7 +743,7 @@ def export_volume( # generate use warnings for clipping if n_diff > 0: warn( - f'{n_diff} {src["type"]} vertices lay outside of volume ' + f"{n_diff} {src['type']} vertices lay outside of volume " f"space. Consider using a larger volume space." ) # get surface id or use default value @@ -1546,7 +1546,7 @@ def setup_source_space( # pre-load ico/oct surf (once) for speed, if necessary if stype not in ("spacing", "all"): logger.info( - f'Doing the {dict(ico="icosa", oct="octa")[stype]}hedral vertex picking...' + f"Doing the {dict(ico='icosa', oct='octa')[stype]}hedral vertex picking..." 
) for hemi, surf in zip(["lh", "rh"], surfs): logger.info(f"Loading {surf}...") @@ -2916,8 +2916,7 @@ def _get_vertex_map_nn( raise RuntimeError(f"vertex {one} would be used multiple times.") one = one[0] logger.info( - "Source space vertex moved from %d to %d because of " - "double occupation.", + "Source space vertex moved from %d to %d because of double occupation.", was, one, ) @@ -3167,8 +3166,7 @@ def _compare_source_spaces(src0, src1, mode="exact", nearest=True, dist_tol=1.5e assert_array_equal( s["vertno"], np.where(s["inuse"])[0], - f'src{ii}[{si}]["vertno"] != ' - f'np.where(src{ii}[{si}]["inuse"])[0]', + f'src{ii}[{si}]["vertno"] != np.where(src{ii}[{si}]["inuse"])[0]', ) assert_equal(len(s0["vertno"]), len(s1["vertno"])) agreement = np.mean(s0["inuse"] == s1["inuse"]) diff --git a/mne/surface.py b/mne/surface.py index 21432e7edfd..9e24147a080 100644 --- a/mne/surface.py +++ b/mne/surface.py @@ -214,7 +214,7 @@ def get_meg_helmet_surf(info, trans=None, *, verbose=None): ] ) logger.info( - "Getting helmet for system %s (derived from %d MEG " "channel locations)", + "Getting helmet for system %s (derived from %d MEG channel locations)", system, len(rr), ) @@ -733,7 +733,7 @@ def __init__(self, surf, *, mode="old", verbose=None): else: self._init_old() logger.debug( - f'Setting up {mode} interior check for {len(self.surf["rr"])} ' + f"Setting up {mode} interior check for {len(self.surf['rr'])} " f"points took {(time.time() - t0) * 1000:0.1f} ms" ) @@ -761,8 +761,7 @@ def _init_pyvista(self): def __call__(self, rr, n_jobs=None, verbose=None): n_orig = len(rr) logger.info( - f"Checking surface interior status for " - f'{n_orig} point{_pl(n_orig, " ")}...' + f"Checking surface interior status for {n_orig} point{_pl(n_orig, ' ')}..." ) t0 = time.time() if self.mode == "pyvista": @@ -770,7 +769,7 @@ def __call__(self, rr, n_jobs=None, verbose=None): else: inside = self._call_old(rr, n_jobs) n = inside.sum() - logger.info(f' Total {n}/{n_orig} point{_pl(n, " ")} inside the surface') + logger.info(f" Total {n}/{n_orig} point{_pl(n, ' ')} inside the surface") logger.info(f"Interior check completed in {(time.time() - t0) * 1000:0.1f} ms") return inside @@ -792,7 +791,7 @@ def _call_old(self, rr, n_jobs): n = (in_mask).sum() n_pad = str(n).rjust(prec) logger.info( - f' Found {n_pad}/{n_orig} point{_pl(n, " ")} ' + f" Found {n_pad}/{n_orig} point{_pl(n, ' ')} " f"inside an interior sphere of radius " f"{1000 * self.inner_r:6.1f} mm" ) @@ -801,7 +800,7 @@ def _call_old(self, rr, n_jobs): n = (out_mask).sum() n_pad = str(n).rjust(prec) logger.info( - f' Found {n_pad}/{n_orig} point{_pl(n, " ")} ' + f" Found {n_pad}/{n_orig} point{_pl(n, ' ')} " f"outside an exterior sphere of radius " f"{1000 * self.outer_r:6.1f} mm" ) @@ -818,7 +817,7 @@ def _call_old(self, rr, n_jobs): n_pad = str(n).rjust(prec) check_pad = str(len(del_outside)).rjust(prec) logger.info( - f' Found {n_pad}/{check_pad} point{_pl(n, " ")} outside using ' + f" Found {n_pad}/{check_pad} point{_pl(n, ' ')} outside using " "surface Qhull" ) @@ -828,7 +827,7 @@ def _call_old(self, rr, n_jobs): n_pad = str(n).rjust(prec) check_pad = str(len(solid_outside)).rjust(prec) logger.info( - f' Found {n_pad}/{check_pad} point{_pl(n, " ")} outside using ' + f" Found {n_pad}/{check_pad} point{_pl(n, ' ')} outside using " "solid angles" ) inside[idx[solid_outside]] = False diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 6b1356ae107..4d0db170e2a 100644 --- a/mne/tests/test_annotations.py +++ 
b/mne/tests/test_annotations.py @@ -1450,8 +1450,7 @@ def test_repr(): # long annotation repr (> 79 characters, will be shortened) r = repr(Annotations(range(14), [0] * 14, list("abcdefghijklmn"))) assert r == ( - "" + "" ) # empty Annotations diff --git a/mne/tests/test_dipole.py b/mne/tests/test_dipole.py index e93d4031646..f230eaa4256 100644 --- a/mne/tests/test_dipole.py +++ b/mne/tests/test_dipole.py @@ -214,9 +214,9 @@ def test_dipole_fitting(tmp_path): # Sanity check: do our residuals have less power than orig data? data_rms = np.sqrt(np.sum(evoked.data**2, axis=0)) resi_rms = np.sqrt(np.sum(residual.data**2, axis=0)) - assert ( - data_rms > resi_rms * 0.95 - ).all(), f"{(data_rms / resi_rms).min()} (factor: {0.95})" + assert (data_rms > resi_rms * 0.95).all(), ( + f"{(data_rms / resi_rms).min()} (factor: {0.95})" + ) # Compare to original points transform_surface_to(fwd["src"][0], "head", fwd["mri_head_t"]) diff --git a/mne/tests/test_docstring_parameters.py b/mne/tests/test_docstring_parameters.py index c94da5e5ab8..64f80f50b74 100644 --- a/mne/tests/test_docstring_parameters.py +++ b/mne/tests/test_docstring_parameters.py @@ -222,8 +222,7 @@ def test_tabs(): continue source = inspect.getsource(mod) assert "\t" not in source, ( - f'"{modname}" has tabs, please remove them ' - "or add it to the ignore list" + f'"{modname}" has tabs, please remove them or add it to the ignore list' ) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 079a2b53ec9..aa11082238f 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -479,12 +479,12 @@ def test_average_movements(): def _assert_drop_log_types(drop_log): __tracebackhide__ = True assert isinstance(drop_log, tuple), "drop_log should be tuple" - assert all( - isinstance(log, tuple) for log in drop_log - ), "drop_log[ii] should be tuple" - assert all( - isinstance(s, str) for log in drop_log for s in log - ), "drop_log[ii][jj] should be str" + assert all(isinstance(log, tuple) for log in drop_log), ( + "drop_log[ii] should be tuple" + ) + assert all(isinstance(s, str) for log in drop_log for s in log), ( + "drop_log[ii][jj] should be str" + ) def test_reject(): diff --git a/mne/tests/test_filter.py b/mne/tests/test_filter.py index e259ececbce..537f1930f45 100644 --- a/mne/tests/test_filter.py +++ b/mne/tests/test_filter.py @@ -90,9 +90,9 @@ def test_estimate_ringing(): (0.0001, (30000, 60000)), ): # 37993 n_ring = estimate_ringing_samples(butter(3, thresh, output=kind)) - assert ( - lims[0] <= n_ring <= lims[1] - ), f"{kind} {thresh}: {lims[0]} <= {n_ring} <= {lims[1]}" + assert lims[0] <= n_ring <= lims[1], ( + f"{kind} {thresh}: {lims[0]} <= {n_ring} <= {lims[1]}" + ) with pytest.warns(RuntimeWarning, match="properly estimate"): assert estimate_ringing_samples(butter(4, 0.00001)) == 100000 diff --git a/mne/time_frequency/_stft.py b/mne/time_frequency/_stft.py index 8fb80b43fcc..a6b6f23fff7 100644 --- a/mne/time_frequency/_stft.py +++ b/mne/time_frequency/_stft.py @@ -59,8 +59,7 @@ def stft(x, wsize, tstep=None, verbose=None): if (wsize % tstep) or (tstep % 2): raise ValueError( - "The step size must be a multiple of 2 and a " - "divider of the window length." + "The step size must be a multiple of 2 and a divider of the window length." 
) if tstep > wsize / 2: diff --git a/mne/time_frequency/csd.py b/mne/time_frequency/csd.py index c858dd52e57..4ddaa0ac6a3 100644 --- a/mne/time_frequency/csd.py +++ b/mne/time_frequency/csd.py @@ -224,8 +224,7 @@ def sum(self, fmin=None, fmax=None): """ if self._is_sum: raise RuntimeError( - "This CSD matrix already represents a mean or " - "sum across frequencies." + "This CSD matrix already represents a mean or sum across frequencies." ) # Deal with the various ways in which fmin and fmax can be specified @@ -1372,7 +1371,7 @@ def _execute_csd_function( logger.info("[done]") if ch_names is None: - ch_names = [f"SERIES{i+1:03}" for i in range(n_channels)] + ch_names = [f"SERIES{i + 1:03}" for i in range(n_channels)] return CrossSpectralDensity( csds_mean, diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index b1de7f11c0f..03a57010061 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -311,7 +311,7 @@ def __init__( if np.isfinite(fmax) and (fmax > self.sfreq / 2): raise ValueError( f"Requested fmax ({fmax} Hz) must not exceed ½ the sampling " - f'frequency of the data ({0.5 * inst.info["sfreq"]} Hz).' + f"frequency of the data ({0.5 * inst.info['sfreq']} Hz)." ) # method self._inst_type = type(inst) @@ -442,7 +442,7 @@ def _check_values(self): if bad_value.any(): chs = np.array(self.ch_names)[bad_value].tolist() s = _pl(bad_value.sum()) - warn(f'Zero value in spectrum for channel{s} {", ".join(chs)}', UserWarning) + warn(f"Zero value in spectrum for channel{s} {', '.join(chs)}", UserWarning) def _returns_complex_tapers(self, **method_kw): return self.method == "multitaper" and method_kw.get("output") == "complex" @@ -1536,7 +1536,7 @@ def average(self, method="mean"): state["nave"] = state["data"].shape[0] state["data"] = method(state["data"]) state["dims"] = state["dims"][1:] - state["data_type"] = f'Averaged {state["data_type"]}' + state["data_type"] = f"Averaged {state['data_type']}" defaults = dict( method=None, fmin=None, @@ -1689,12 +1689,12 @@ def combine_spectrum(all_spectrum, weights="nave"): ch_names = spectrum.ch_names for s_ in all_spectrum[1:]: - assert ( - s_.ch_names == ch_names - ), f"{spectrum} and {s_} do not contain the same channels" - assert ( - np.max(np.abs(s_.freqs - spectrum.freqs)) < 1e-7 - ), f"{spectrum} and {s_} do not contain the same frequencies" + assert s_.ch_names == ch_names, ( + f"{spectrum} and {s_} do not contain the same channels" + ) + assert np.max(np.abs(s_.freqs - spectrum.freqs)) < 1e-7, ( + f"{spectrum} and {s_} do not contain the same frequencies" + ) # use union of bad channels bads = list( diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 71dabce6d31..42e4075cc22 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -624,8 +624,7 @@ def _check_tfr_param( freqs = np.asarray(freqs, dtype=float) if freqs.ndim != 1: raise ValueError( - f"freqs must be of shape (n_freqs,), got {np.array(freqs.shape)} " - "instead." + f"freqs must be of shape (n_freqs,), got {np.array(freqs.shape)} instead." ) # Check sfreq @@ -1210,8 +1209,8 @@ def __init__( classname = "EpochsTFR" # end TODO raise ValueError( - f'{classname} got unsupported parameter value{_pl(problem)} ' - f'{" and ".join(problem)}.' + f"{classname} got unsupported parameter value{_pl(problem)} " + f"{' and '.join(problem)}." 
) # check method valid_methods = ["morlet", "multitaper"] @@ -1538,7 +1537,7 @@ def _check_values(self, negative_ok=False): s = _pl(negative_values.sum()) warn( f"Negative value in time-frequency decomposition for channel{s} " - f'{", ".join(chs)}', + f"{', '.join(chs)}", UserWarning, ) @@ -3960,12 +3959,12 @@ def combine_tfr(all_tfr, weights="nave"): ch_names = tfr.ch_names for t_ in all_tfr[1:]: - assert ( - t_.ch_names == ch_names - ), f"{tfr} and {t_} do not contain the same channels" - assert ( - np.max(np.abs(t_.times - tfr.times)) < 1e-7 - ), f"{tfr} and {t_} do not contain the same time instants" + assert t_.ch_names == ch_names, ( + f"{tfr} and {t_} do not contain the same channels" + ) + assert np.max(np.abs(t_.times - tfr.times)) < 1e-7, ( + f"{tfr} and {t_} do not contain the same time instants" + ) # use union of bad channels bads = list(set(tfr.info["bads"]).union(*(t_.info["bads"] for t_ in all_tfr[1:]))) @@ -4162,7 +4161,7 @@ def _read_multiple_tfrs(tfr_data, condition=None, *, verbose=None): if len(out) == 0: raise ValueError( f'Cannot find condition "{condition}" in this file. ' - f'The file contains conditions {", ".join(keys)}' + f"The file contains conditions {', '.join(keys)}" ) if len(out) == 1: out = out[0] diff --git a/mne/utils/_logging.py b/mne/utils/_logging.py index 68963feaf61..f4d19655bbf 100644 --- a/mne/utils/_logging.py +++ b/mne/utils/_logging.py @@ -511,7 +511,7 @@ def _frame_info(n): except KeyError: # in our verbose dec pass else: - infos.append(f'{name.lstrip("mne.")}:{frame.f_lineno}') + infos.append(f"{name.lstrip('mne.')}:{frame.f_lineno}") frame = frame.f_back if frame is None: break diff --git a/mne/utils/check.py b/mne/utils/check.py index 21360df9c83..085c51b6996 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -317,8 +317,7 @@ def _check_subject( _validate_type(second, "str", "subject input") if first is not None and first != second: raise ValueError( - f"{first_kind} ({repr(first)}) did not match " - f"{second_kind} ({second})" + f"{first_kind} ({repr(first)}) did not match {second_kind} ({second})" ) return second elif first is not None: @@ -1071,8 +1070,7 @@ def _check_sphere(sphere, info=None, sphere_units="m"): del ch_pos["FPz"] elif "Fpz" not in ch_pos and "Oz" in ch_pos: logger.info( - "Approximating Fpz location by mirroring Oz along " - "the X and Y axes." + "Approximating Fpz location by mirroring Oz along the X and Y axes." ) # This assumes Fpz and Oz have the same Z coordinate ch_pos["Fpz"] = ch_pos["Oz"] * [-1, -1, 1] @@ -1082,7 +1080,7 @@ def _check_sphere(sphere, info=None, sphere_units="m"): msg = ( f'sphere="eeglab" requires digitization points of ' f"the following electrode locations in the data: " - f'{", ".join(horizon_ch_names)}, but could not find: ' + f"{', '.join(horizon_ch_names)}, but could not find: " f"{ch_name}" ) if ch_name == "Fpz": @@ -1263,8 +1261,7 @@ def _to_rgb(*args, name="color", alpha=False): except ValueError: args = args[0] if len(args) == 1 else args raise ValueError( - f'Invalid RGB{"A" if alpha else ""} argument(s) for {name}: ' - f"{repr(args)}" + f"Invalid RGB{'A' if alpha else ''} argument(s) for {name}: {repr(args)}" ) from None @@ -1288,5 +1285,5 @@ def _check_method_kwargs(func, kwargs, msg=None): if msg is None: msg = f'function "{func}"' raise TypeError( - f'Got unexpected keyword argument{s} {", ".join(invalid_kw)} for {msg}.' + f"Got unexpected keyword argument{s} {', '.join(invalid_kw)} for {msg}." 
) diff --git a/mne/utils/config.py b/mne/utils/config.py index a817886c3f0..c28373fcb93 100644 --- a/mne/utils/config.py +++ b/mne/utils/config.py @@ -185,8 +185,7 @@ def set_memmap_min_size(memmap_min_size): "triggers automated memory mapping, e.g., 1M or 0.5G" ), "MNE_REPR_HTML": ( - "bool, represent some of our objects with rich HTML in a notebook " - "environment" + "bool, represent some of our objects with rich HTML in a notebook environment" ), "MNE_SKIP_NETWORK_TESTS": ( "bool, used in a test decorator (@requires_good_network) to skip " @@ -203,8 +202,7 @@ def set_memmap_min_size(memmap_min_size): ), "MNE_USE_CUDA": "bool, use GPU for filtering/resampling", "MNE_USE_NUMBA": ( - "bool, use Numba just-in-time compiler for some of our intensive " - "computations" + "bool, use Numba just-in-time compiler for some of our intensive computations" ), "SUBJECTS_DIR": "path-like, directory of freesurfer MRI files for each subject", } @@ -583,9 +581,9 @@ def _get_numpy_libs(): for pool in pools: if pool["internal_api"] in ("openblas", "mkl"): return ( - f'{rename[pool["internal_api"]]} ' - f'{pool["version"]} with ' - f'{pool["num_threads"]} thread{_pl(pool["num_threads"])}' + f"{rename[pool['internal_api']]} " + f"{pool['version']} with " + f"{pool['num_threads']} thread{_pl(pool['num_threads'])}" ) return bad_lib @@ -874,7 +872,7 @@ def sys_info( pre = "│ " else: pre = " | " - out(f'\n{pre}{" " * ljust}{op.dirname(mod.__file__)}') + out(f"\n{pre}{' ' * ljust}{op.dirname(mod.__file__)}") out("\n") if not mne_version_good: diff --git a/mne/utils/misc.py b/mne/utils/misc.py index bb3e3ee5cab..343761aee24 100644 --- a/mne/utils/misc.py +++ b/mne/utils/misc.py @@ -379,7 +379,7 @@ def _assert_no_instances(cls, when=""): check = False if check: if cls.__name__ == "Brain": - ref.append(f'Brain._cleaned = {getattr(obj, "_cleaned", None)}') + ref.append(f"Brain._cleaned = {getattr(obj, '_cleaned', None)}") rr = gc.get_referrers(obj) count = 0 for r in rr: diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index 247c0840858..778700c99a7 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -4072,28 +4072,28 @@ def _update_monotonic(lims, fmin, fmid, fmax): if fmin is not None: lims["fmin"] = fmin if lims["fmax"] < fmin: - logger.debug(f' Bumping fmax = {lims["fmax"]} to {fmin}') + logger.debug(f" Bumping fmax = {lims['fmax']} to {fmin}") lims["fmax"] = fmin if lims["fmid"] < fmin: - logger.debug(f' Bumping fmid = {lims["fmid"]} to {fmin}') + logger.debug(f" Bumping fmid = {lims['fmid']} to {fmin}") lims["fmid"] = fmin assert lims["fmin"] <= lims["fmid"] <= lims["fmax"] if fmid is not None: lims["fmid"] = fmid if lims["fmin"] > fmid: - logger.debug(f' Bumping fmin = {lims["fmin"]} to {fmid}') + logger.debug(f" Bumping fmin = {lims['fmin']} to {fmid}") lims["fmin"] = fmid if lims["fmax"] < fmid: - logger.debug(f' Bumping fmax = {lims["fmax"]} to {fmid}') + logger.debug(f" Bumping fmax = {lims['fmax']} to {fmid}") lims["fmax"] = fmid assert lims["fmin"] <= lims["fmid"] <= lims["fmax"] if fmax is not None: lims["fmax"] = fmax if lims["fmin"] > fmax: - logger.debug(f' Bumping fmin = {lims["fmin"]} to {fmax}') + logger.debug(f" Bumping fmin = {lims['fmin']} to {fmax}") lims["fmin"] = fmax if lims["fmid"] > fmax: - logger.debug(f' Bumping fmid = {lims["fmid"]} to {fmax}') + logger.debug(f" Bumping fmid = {lims['fmid']} to {fmax}") lims["fmid"] = fmax assert lims["fmin"] <= lims["fmid"] <= lims["fmax"] diff --git a/mne/viz/_brain/tests/test_brain.py 
b/mne/viz/_brain/tests/test_brain.py index fd2ff96579e..5d092c21713 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -867,9 +867,9 @@ def _assert_brain_range(brain, rng): for key, mesh in layerer._overlays.items(): if key == "curv": continue - assert ( - mesh._rng == rng - ), f"_layered_meshes[{repr(hemi)}][{repr(key)}]._rng != {rng}" + assert mesh._rng == rng, ( + f"_layered_meshes[{repr(hemi)}][{repr(key)}]._rng != {rng}" + ) @testing.requires_testing_data @@ -1237,9 +1237,9 @@ def test_brain_scraper(renderer_interactive_pyvistaqt, brain_gc, tmp_path): w = img.shape[1] w0 = size[0] # On Linux+conda we get a width of 624, similar tweak in test_brain_init above - assert np.isclose(w, w0, atol=30) or np.isclose( - w, w0 * 2, atol=30 - ), f"w ∉ {{{w0}, {2 * w0}}}" # HiDPI + assert np.isclose(w, w0, atol=30) or np.isclose(w, w0 * 2, atol=30), ( + f"w ∉ {{{w0}, {2 * w0}}}" + ) # HiDPI @testing.requires_testing_data diff --git a/mne/viz/_proj.py b/mne/viz/_proj.py index 5d21afb0594..6e0cb9a4143 100644 --- a/mne/viz/_proj.py +++ b/mne/viz/_proj.py @@ -90,8 +90,7 @@ def plot_projs_joint( missing = (~used.astype(bool)).sum() if missing: warn( - f"{missing} projector{_pl(missing)} had no channel names " - "present in epochs" + f"{missing} projector{_pl(missing)} had no channel names present in epochs" ) del projs ch_types = list(proj_by_type) # reduce to number we actually need diff --git a/mne/viz/backends/_utils.py b/mne/viz/backends/_utils.py index c415d83e456..467f5cb15e7 100644 --- a/mne/viz/backends/_utils.py +++ b/mne/viz/backends/_utils.py @@ -317,8 +317,7 @@ def _qt_get_stylesheet(theme): file = open(theme) except OSError: warn( - "Requested theme file not found, will use light instead: " - f"{repr(theme)}" + f"Requested theme file not found, will use light instead: {repr(theme)}" ) else: with file as fid: diff --git a/mne/viz/misc.py b/mne/viz/misc.py index ed2636d3961..c83a4dfe717 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -443,7 +443,7 @@ def _plot_mri_contours( if src[0]["coord_frame"] != FIFF.FIFFV_COORD_MRI: raise ValueError( "Source space must be in MRI coordinates, got " - f'{_frame_to_str[src[0]["coord_frame"]]}' + f"{_frame_to_str[src[0]['coord_frame']]}" ) for src_ in src: points = src_["rr"][src_["inuse"].astype(bool)] @@ -708,8 +708,7 @@ def plot_bem( src = read_source_spaces(src) elif src is not None and not isinstance(src, SourceSpaces): raise TypeError( - "src needs to be None, path-like or SourceSpaces instance, " - f"not {repr(src)}" + f"src needs to be None, path-like or SourceSpaces instance, not {repr(src)}" ) if len(surfaces) == 0: diff --git a/mne/viz/tests/test_3d.py b/mne/viz/tests/test_3d.py index 6f109b9490b..34022d59768 100644 --- a/mne/viz/tests/test_3d.py +++ b/mne/viz/tests/test_3d.py @@ -893,7 +893,7 @@ def test_plot_alignment_fnirs(renderer, tmp_path): with catch_logging() as log: fig = plot_alignment(info, **kwargs) log = log.getvalue() - assert f'fnirs_cw_amplitude: {info["nchan"]}' in log + assert f"fnirs_cw_amplitude: {info['nchan']}" in log _assert_n_actors(fig, renderer, info["nchan"]) fig = plot_alignment(info, fnirs=["channels", "sources", "detectors"], **kwargs) diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index d83698acbb1..0b0d6953a66 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -910,8 +910,7 @@ def _get_pos_outlines(info, picks, sphere, to_sphere=True): orig_sphere = sphere sphere, clip_origin = _adjust_meg_sphere(sphere, info, ch_type) logger.debug( - "Generating pos 
outlines with sphere " - f"{sphere} from {orig_sphere} for {ch_type}" + f"Generating pos outlines with sphere {sphere} from {orig_sphere} for {ch_type}" ) pos = _find_topomap_coords( info, picks, ignore_overlap=True, to_sphere=to_sphere, sphere=sphere @@ -1262,7 +1261,7 @@ def _plot_topomap( if len(data) != len(pos): raise ValueError( "Data and pos need to be of same length. Got data of " - f"length {len(data)}, pos of length { len(pos)}" + f"length {len(data)}, pos of length {len(pos)}" ) norm = min(data) >= 0 @@ -1409,8 +1408,7 @@ def _plot_ica_topomap( sphere = _check_sphere(sphere, ica.info) if not isinstance(axes, Axes): raise ValueError( - "axis has to be an instance of matplotlib Axes, " - f"got {type(axes)} instead." + f"axis has to be an instance of matplotlib Axes, got {type(axes)} instead." ) ch_type = _get_plot_ch_type(ica, ch_type, allow_ref_meg=ica.allow_ref_meg) if ch_type == "ref_meg": @@ -2191,8 +2189,7 @@ def plot_evoked_topomap( space = 1 / (2.0 * evoked.info["sfreq"]) if max(times) > max(evoked.times) + space or min(times) < min(evoked.times) - space: raise ValueError( - f"Times should be between {evoked.times[0]:0.3} and " - f"{evoked.times[-1]:0.3}." + f"Times should be between {evoked.times[0]:0.3} and {evoked.times[-1]:0.3}." ) # create axes want_axes = n_times + int(colorbar) @@ -2791,8 +2788,7 @@ def plot_psds_topomap( # convert legacy list-of-tuple input to a dict bands = {band[-1]: band[:-1] for band in bands} logger.info( - "converting legacy list-of-tuples input to a dict for the " - "`bands` parameter" + "converting legacy list-of-tuples input to a dict for the `bands` parameter" ) # upconvert single freqs to band upper/lower edges as needed bin_spacing = np.diff(freqs)[0] diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 00458bf3908..a09da17de7d 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -2356,7 +2356,7 @@ def _gfp(data): except KeyError: raise ValueError( f'"combine" must be None, a callable, or one of "{", ".join(valid)}"; ' - f'got {combine}' + f"got {combine}" ) return combine diff --git a/tools/dev/ensure_headers.py b/tools/dev/ensure_headers.py index b5b425b5900..a4095d82b42 100644 --- a/tools/dev/ensure_headers.py +++ b/tools/dev/ensure_headers.py @@ -156,15 +156,15 @@ def _ensure_copyright(lines, path): lines[insert] = COPYRIGHT_LINE else: lines.insert(insert, COPYRIGHT_LINE) - assert ( - lines.count(COPYRIGHT_LINE) == 1 - ), f"{lines.count(COPYRIGHT_LINE)=} for {path=}" + assert lines.count(COPYRIGHT_LINE) == 1, ( + f"{lines.count(COPYRIGHT_LINE)=} for {path=}" + ) def _ensure_blank(lines, path): - assert ( - lines.count(COPYRIGHT_LINE) == 1 - ), f"{lines.count(COPYRIGHT_LINE)=} for {path=}" + assert lines.count(COPYRIGHT_LINE) == 1, ( + f"{lines.count(COPYRIGHT_LINE)=} for {path=}" + ) insert = lines.index(COPYRIGHT_LINE) + 1 if lines[insert].strip(): # actually has content lines.insert(insert, "") diff --git a/tools/hooks/update_environment_file.py b/tools/hooks/update_environment_file.py index f5e6bb335b0..0b5380a16b5 100755 --- a/tools/hooks/update_environment_file.py +++ b/tools/hooks/update_environment_file.py @@ -80,7 +80,7 @@ def split_dep(dep): pip_section = pip_section if len(pip_deps) else "" # prepare the env file env = f"""\ -# THIS FILE IS AUTO-GENERATED BY {'/'.join(Path(__file__).parts[-3:])} AND WILL BE OVERWRITTEN +# THIS FILE IS AUTO-GENERATED BY {"/".join(Path(__file__).parts[-3:])} AND WILL BE OVERWRITTEN name: mne channels: - conda-forge diff --git a/tutorials/forward/20_source_alignment.py 
b/tutorials/forward/20_source_alignment.py index dd26f610907..c8cf981dce9 100644 --- a/tutorials/forward/20_source_alignment.py +++ b/tutorials/forward/20_source_alignment.py @@ -115,11 +115,11 @@ mne.viz.set_3d_view(fig, 45, 90, distance=0.6, focalpoint=(0.0, 0.0, 0.0)) print( "Distance from head origin to MEG origin: " - f"{1000 * np.linalg.norm(raw.info["dev_head_t"]["trans"][:3, 3]):.1f} mm" + f"{1000 * np.linalg.norm(raw.info['dev_head_t']['trans'][:3, 3]):.1f} mm" ) print( "Distance from head origin to MRI origin: " - f"{1000 * np.linalg.norm(trans["trans"][:3, 3]):.1f} mm" + f"{1000 * np.linalg.norm(trans['trans'][:3, 3]):.1f} mm" ) dists = mne.dig_mri_distances(raw.info, trans, "sample", subjects_dir=subjects_dir) print( diff --git a/tutorials/forward/30_forward.py b/tutorials/forward/30_forward.py index 6c55d0bfe3c..72731982962 100644 --- a/tutorials/forward/30_forward.py +++ b/tutorials/forward/30_forward.py @@ -255,7 +255,7 @@ # or ``inv['src']`` so that this removal is adequately accounted for. print(f"Before: {src}") -print(f'After: {fwd["src"]}') +print(f"After: {fwd['src']}") # %% # We can explore the content of ``fwd`` to access the numpy array that contains diff --git a/tutorials/intro/15_inplace.py b/tutorials/intro/15_inplace.py index 0c68843d4c8..01e8c1f7eb0 100644 --- a/tutorials/intro/15_inplace.py +++ b/tutorials/intro/15_inplace.py @@ -60,9 +60,9 @@ # Another group of methods where data is modified in-place are the # channel-picking methods. For example: -print(f'original data had {original_raw.info["nchan"]} channels.') +print(f"original data had {original_raw.info['nchan']} channels.") original_raw.pick("eeg") # selects only the EEG channels -print(f'after picking, it has {original_raw.info["nchan"]} channels.') +print(f"after picking, it has {original_raw.info['nchan']} channels.") # %% diff --git a/tutorials/preprocessing/40_artifact_correction_ica.py b/tutorials/preprocessing/40_artifact_correction_ica.py index 5eeb7b79d64..257b1f85051 100644 --- a/tutorials/preprocessing/40_artifact_correction_ica.py +++ b/tutorials/preprocessing/40_artifact_correction_ica.py @@ -291,8 +291,7 @@ # This time, print as percentage. 
ratio_percent = round(100 * explained_var_ratio["eeg"]) print( - f"Fraction of variance in EEG signal explained by first component: " - f"{ratio_percent}%" + f"Fraction of variance in EEG signal explained by first component: {ratio_percent}%" ) # %% diff --git a/tutorials/preprocessing/50_artifact_correction_ssp.py b/tutorials/preprocessing/50_artifact_correction_ssp.py index 57be25803d5..530e6fd39d8 100644 --- a/tutorials/preprocessing/50_artifact_correction_ssp.py +++ b/tutorials/preprocessing/50_artifact_correction_ssp.py @@ -520,7 +520,7 @@ evoked_eeg.plot(proj=proj, axes=ax, spatial_colors=True) parts = ax.get_title().split("(") ylabel = ( - f'{parts[0]} ({ax.get_ylabel()})\n{parts[1].replace(")", "")}' + f"{parts[0]} ({ax.get_ylabel()})\n{parts[1].replace(')', '')}" if pi == 0 else "" ) From 087779c3bd5ba84dbcef7f3689a7d70f0b045da7 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen <66060772+ruuskas@users.noreply.github.com> Date: Thu, 16 Jan 2025 19:07:45 +0200 Subject: [PATCH 08/24] Fix evoked topomap colorbars, closes #13050 (#13063) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Larson --- doc/changes/devel/13063.bugfix.rst | 1 + examples/visualization/evoked_topomap.py | 4 ++-- mne/viz/evoked.py | 13 ++++++++++-- mne/viz/topomap.py | 25 +++++++++++++++++++++++- 4 files changed, 38 insertions(+), 5 deletions(-) create mode 100644 doc/changes/devel/13063.bugfix.rst diff --git a/doc/changes/devel/13063.bugfix.rst b/doc/changes/devel/13063.bugfix.rst new file mode 100644 index 00000000000..76eba2032a1 --- /dev/null +++ b/doc/changes/devel/13063.bugfix.rst @@ -0,0 +1 @@ +Fix bug in the colorbars created by :func:`mne.viz.plot_evoked_topomap` by `Santeri Ruuskanen`_. \ No newline at end of file diff --git a/examples/visualization/evoked_topomap.py b/examples/visualization/evoked_topomap.py index 83d1916c6f9..53b7a60dbba 100644 --- a/examples/visualization/evoked_topomap.py +++ b/examples/visualization/evoked_topomap.py @@ -5,8 +5,8 @@ Plotting topographic maps of evoked data ======================================== -Load evoked data and plot topomaps for selected time points using multiple -additional options. +Load evoked data and plot topomaps for selected time points using +multiple additional options. """ # Authors: Christian Brodbeck # Tal Linzen diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py index 10ec5459e02..b047de4ea32 100644 --- a/mne/viz/evoked.py +++ b/mne/viz/evoked.py @@ -27,6 +27,7 @@ _clean_names, _is_numeric, _pl, + _time_mask, _to_rgb, _validate_type, fill_doc, @@ -1988,10 +1989,18 @@ def plot_evoked_joint( contours = topomap_args.get("contours", 6) ch_type = ch_types.pop() # set should only contain one element # Since the data has all the ch_types, we get the limits from the plot. 
- vmin, vmax = ts_ax.get_ylim() + vmin, vmax = (None, None) norm = ch_type == "grad" vmin = 0 if norm else vmin - vmin, vmax = _setup_vmin_vmax(evoked.data, vmin, vmax, norm) + time_idx = [ + np.where( + _time_mask(evoked.times, tmin=t, tmax=None, sfreq=evoked.info["sfreq"]) + )[0][0] + for t in times_sec + ] + scalings = topomap_args["scalings"] if "scalings" in topomap_args else None + scaling = _handle_default("scalings", scalings)[ch_type] + vmin, vmax = _setup_vmin_vmax(evoked.data[:, time_idx] * scaling, vmin, vmax, norm) if not isinstance(contours, list | np.ndarray): locator, contours = _set_contour_locator(vmin, vmax, contours) else: diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index 0b0d6953a66..bb180a3f299 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -2114,6 +2114,22 @@ def plot_evoked_topomap( :ref:`gridspec ` interface to adjust the colorbar size yourself. + The defaults for ``contours`` and ``vlim`` are handled as follows: + + * When neither ``vlim`` nor a list of ``contours`` is passed, MNE sets + ``vlim`` at ± the maximum absolute value of the data and then chooses + contours within those bounds. + + * When ``vlim`` but not a list of ``contours`` is passed, MNE chooses + contours to be within the ``vlim``. + + * When a list of ``contours`` but not ``vlim`` is passed, MNE chooses + ``vlim`` to encompass the ``contours`` and the maximum absolute value of the + data. + + * When both a list of ``contours`` and ``vlim`` are passed, MNE uses them + as-is. + When ``time=="interactive"``, the figure will publish and subscribe to the following UI events: @@ -2296,11 +2312,17 @@ def plot_evoked_topomap( _vlim = [ _setup_vmin_vmax(data[:, i], *vlim, norm=merge_channels) for i in range(n_times) ] - _vlim = (np.min(_vlim), np.max(_vlim)) + _vlim = [np.min(_vlim), np.max(_vlim)] cmap = _setup_cmap(cmap, n_axes=n_times, norm=_vlim[0] >= 0) # set up contours if not isinstance(contours, list | np.ndarray): _, contours = _set_contour_locator(*_vlim, contours) + else: + if vlim[0] is None and np.any(contours < _vlim[0]): + _vlim[0] = contours[0] + if vlim[1] is None and np.any(contours > _vlim[1]): + _vlim[1] = contours[-1] + # prepare for main loop over times kwargs = dict( sensors=sensors, @@ -3348,6 +3370,7 @@ def _set_contour_locator(vmin, vmax, contours): # correct number of bins is equal to contours + 1. locator = ticker.MaxNLocator(nbins=contours + 1) contours = locator.tick_values(vmin, vmax) + contours = contours[1:-1] return locator, contours From bd8c318537ffcabf4c5fadd4347ec5068bb91b67 Mon Sep 17 00:00:00 2001 From: Simon Kern <14980558+skjerns@users.noreply.github.com> Date: Fri, 17 Jan 2025 22:38:59 +0100 Subject: [PATCH 09/24] [FIX] Reading an EDF with preload=False and mixed frequency (#13069) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Larson --- doc/changes/devel/13069.bugfix.rst | 1 + mne/conftest.py | 1 + mne/io/edf/edf.py | 9 ++++++--- mne/io/edf/tests/test_edf.py | 18 ++++++++++++++++++ 4 files changed, 26 insertions(+), 3 deletions(-) create mode 100644 doc/changes/devel/13069.bugfix.rst diff --git a/doc/changes/devel/13069.bugfix.rst b/doc/changes/devel/13069.bugfix.rst new file mode 100644 index 00000000000..7c23221c8df --- /dev/null +++ b/doc/changes/devel/13069.bugfix.rst @@ -0,0 +1 @@ +Fix bug caused by an unnecessary assertion when loading mixed-frequency EDFs without preloading in :func:`mne.io.read_raw_edf`, by `Simon Kern`_.
\ No newline at end of file diff --git a/mne/conftest.py b/mne/conftest.py index 85e3f9d255b..8a4586067b3 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -186,6 +186,7 @@ def pytest_configure(config: pytest.Config): ignore:.*builtin type swigvarlink has no.*:DeprecationWarning # eeglabio ignore:numpy\.core\.records is deprecated.*:DeprecationWarning + ignore:Starting field name with a underscore.*: # joblib ignore:process .* is multi-threaded, use of fork/exec.*:DeprecationWarning """ # noqa: E501 diff --git a/mne/io/edf/edf.py b/mne/io/edf/edf.py index fadd1b83857..09ac24f753e 100644 --- a/mne/io/edf/edf.py +++ b/mne/io/edf/edf.py @@ -436,21 +436,24 @@ def _read_segment_file(data, idx, fi, start, stop, raw_extras, filenames, cals, ones[orig_idx, smp_read : smp_read + len(one_i)] = one_i n_smp_read[orig_idx] += len(one_i) + # resample channels with lower sample frequency # skip if no data was requested, ie. only annotations were read - if sum(n_smp_read) > 0: + if any(n_smp_read) > 0: # expected number of samples, equals maximum sfreq smp_exp = data.shape[-1] - assert max(n_smp_read) == smp_exp # resample data after loading all chunks to prevent edge artifacts resampled = False + for i, smp_read in enumerate(n_smp_read): # nothing read, nothing to resample if smp_read == 0: continue # upsample if n_samples is lower than from highest sfreq if smp_read != smp_exp: - assert (ones[i, smp_read:] == 0).all() # sanity check + # sanity check that we read exactly how much we expected + assert (ones[i, smp_read:] == 0).all() + ones[i, :] = resample( ones[i, :smp_read].astype(np.float64), smp_exp, diff --git a/mne/io/edf/tests/test_edf.py b/mne/io/edf/tests/test_edf.py index b4f0ab33fa5..ce671ca7e81 100644 --- a/mne/io/edf/tests/test_edf.py +++ b/mne/io/edf/tests/test_edf.py @@ -259,6 +259,24 @@ def test_edf_different_sfreqs(stim_channel): assert_allclose(times1, times2) +@testing.requires_testing_data +@pytest.mark.parametrize("stim_channel", (None, False, "auto")) +def test_edf_different_sfreqs_nopreload(stim_channel): + """Test loading smaller sfreq channels without preloading.""" + # load without preloading, then load a channel that has a smaller sfreq + # than other channels; this produced an error, see mne-python/issues/12897 + + for i in range(1, 13): + raw = read_raw_edf(input_fname=edf_reduced, verbose="error", preload=False) + + # this should work for channels of all sfreq, even if larger sfreqs + # are present in the file + x1 = raw.get_data(picks=[f"A{i}"], return_times=False) + # load next ch, which sometimes has a higher and sometimes a lower sfreq + x2 = raw.get_data([f"A{i + 1}"], return_times=False) + assert x1.shape == x2.shape + + def test_edf_data_broken(tmp_path): """Test edf files.""" raw = _test_raw_reader( From 8b9fc973e0bdaca9a5ba0c9333637722ed323633 Mon Sep 17 00:00:00 2001 From: "Thomas S.
Binns" Date: Sun, 19 Jan 2025 22:05:20 +0000 Subject: [PATCH 10/24] [BUG] Fix taper weighting in computation of TFR multitaper power (#13067) Co-authored-by: Eric Larson --- doc/changes/devel/13067.bugfix.rst | 1 + mne/time_frequency/tests/test_tfr.py | 19 +++++++---- mne/time_frequency/tfr.py | 51 ++++++++++++++-------------- 3 files changed, 39 insertions(+), 32 deletions(-) create mode 100644 doc/changes/devel/13067.bugfix.rst diff --git a/doc/changes/devel/13067.bugfix.rst b/doc/changes/devel/13067.bugfix.rst new file mode 100644 index 00000000000..237df7623d5 --- /dev/null +++ b/doc/changes/devel/13067.bugfix.rst @@ -0,0 +1 @@ +Fix bug where taper weights were not correctly applied when computing multitaper power with :meth:`mne.Epochs.compute_tfr` and :func:`mne.time_frequency.tfr_array_multitaper`, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index 6fa3a833be2..6adb4e361e1 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -255,20 +255,25 @@ def test_tfr_morlet(): # computed within the method. assert_allclose(epochs_amplitude_2.data**2, epochs_power_picks.data) - # test that averaging power across tapers when multitaper with + # test that aggregating power across tapers when multitaper with # output='complex' gives the same as output='power' epoch_data = epochs.get_data() multitaper_power = tfr_array_multitaper( epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="power" ) - multitaper_complex = tfr_array_multitaper( - epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="complex" + multitaper_complex, weights = tfr_array_multitaper( + epoch_data, + epochs.info["sfreq"], + freqs, + n_cycles, + output="complex", + return_weights=True, ) - taper_dim = 2 - power_from_complex = (multitaper_complex * multitaper_complex.conj()).real.mean( - axis=taper_dim - ) + weights = np.expand_dims(weights, axis=(0, 1, -1)) # match shape of complex data + tfr = weights * multitaper_complex + tfr = (tfr * tfr.conj()).real.sum(axis=2) + power_from_complex = tfr * (2 / (weights * weights.conj()).real.sum(axis=2)) assert_allclose(power_from_complex, multitaper_power) print(itc) # test repr diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 42e4075cc22..fc60802f61b 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -545,20 +545,18 @@ def _compute_tfr( if method == "morlet": W = morlet(sfreq, freqs, n_cycles=n_cycles, zero_mean=zero_mean) Ws = [W] # to have same dimensionality as the 'multitaper' case + weights = None # no tapers for Morlet estimates elif method == "multitaper": - out = _make_dpss( + Ws, weights = _make_dpss( sfreq, freqs, n_cycles=n_cycles, time_bandwidth=time_bandwidth, zero_mean=zero_mean, - return_weights=return_weights, + return_weights=True, # required for converting complex → power ) - if return_weights: - Ws, weights = out - else: - Ws = out + weights = np.asarray(weights) # Check wavelets if len(Ws[0][0]) > epoch_data.shape[2]: @@ -581,9 +579,7 @@ def _compute_tfr( if ("avg_" in output) or ("itc" in output): out = np.empty((n_chans, n_freqs, n_times), dtype) elif output in ["complex", "phase"] and method == "multitaper": - out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype) - if return_weights: - weights = np.array(weights) + out = np.empty((n_chans, n_epochs, n_tapers, n_freqs, n_times), dtype) else: out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype) @@ -594,7 
+590,7 @@ def _compute_tfr( # Parallelization is applied across channels. tfrs = parallel( - my_cwt(channel, Ws, output, use_fft, "same", decim, method) + my_cwt(channel, Ws, output, use_fft, "same", decim, weights) for channel in epoch_data.transpose(1, 0, 2) ) @@ -604,10 +600,7 @@ def _compute_tfr( if ("avg_" not in output) and ("itc" not in output): # This is to enforce that the first dimension is for epochs - if output in ["complex", "phase"] and method == "multitaper": - out = out.transpose(2, 0, 1, 3, 4) - else: - out = out.transpose(1, 0, 2, 3) + out = np.moveaxis(out, 1, 0) if return_weights: return out, weights @@ -683,7 +676,7 @@ def _check_tfr_param( return freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim -def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): +def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, weights=None): """Aux. function to _compute_tfr. Loops time-frequency transform across wavelets and epochs. @@ -710,9 +703,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): See numpy.convolve. decim : slice The decimation slice: e.g. power[:, decim] - method : str | None - Used only for multitapering to create tapers dimension in the output - if ``output in ['complex', 'phase']``. + weights : array, shape (n_tapers, n_wavelets) | None + Concentration weights for each taper in the wavelets, if present. """ # Set output type dtype = np.float64 @@ -726,10 +718,12 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): n_freqs = len(Ws[0]) if ("avg_" in output) or ("itc" in output): tfrs = np.zeros((n_freqs, n_times), dtype=dtype) - elif output in ["complex", "phase"] and method == "multitaper": - tfrs = np.zeros((n_tapers, n_epochs, n_freqs, n_times), dtype=dtype) + elif output in ["complex", "phase"] and weights is not None: + tfrs = np.zeros((n_epochs, n_tapers, n_freqs, n_times), dtype=dtype) else: tfrs = np.zeros((n_epochs, n_freqs, n_times), dtype=dtype) + if weights is not None: + weights = np.expand_dims(weights, axis=-1) # add singleton time dimension # Loops across tapers. 
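+    # A sketch of the weighting scheme used below: each complex taper estimate
+    # x_k is scaled by its concentration weight w_k, and power is recovered as
+    #   power = 2 * sum_k |w_k * x_k|^2 / sum_k |w_k|^2
+    # (the factor 2 accounting for the one-sided spectrum), matching the
+    # taper-weight normalization applied at the end of this function.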
for taper_idx, W in enumerate(Ws): @@ -744,6 +738,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): # Loop across epochs for epoch_idx, tfr in enumerate(coefs): # Transform complex values + if output not in ["complex", "phase"] and weights is not None: + tfr = weights[taper_idx] * tfr # weight each taper estimate if output in ["power", "avg_power"]: tfr = (tfr * tfr.conj()).real # power elif output == "phase": @@ -759,8 +755,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): # Stack or add if ("avg_" in output) or ("itc" in output): tfrs += tfr - elif output in ["complex", "phase"] and method == "multitaper": - tfrs[taper_idx, epoch_idx] += tfr + elif output in ["complex", "phase"] and weights is not None: + tfrs[epoch_idx, taper_idx] += tfr else: tfrs[epoch_idx] += tfr @@ -774,9 +770,14 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): if ("avg_" in output) or ("itc" in output): tfrs /= n_epochs - # Normalization by number of taper - if n_tapers > 1 and output not in ["complex", "phase"]: - tfrs /= n_tapers + # Normalization by taper weights + if n_tapers > 1 and output not in ["complex", "phase", "itc"]: + if "avg_" not in output: # add singleton epochs dimension to weights + weights = np.expand_dims(weights, axis=0) + tfrs.real *= 2 / (weights * weights.conj()).real.sum(axis=-3) + if output == "avg_power_itc": # weight itc by the number of tapers + tfrs.imag = tfrs.imag / n_tapers + return tfrs From 4f53a3732917dd1dbc91d4725ae79fc1c7ad4661 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Steinn=20Hauser=20Magn=C3=BAsson?= <42544980+steinnhauser@users.noreply.github.com> Date: Sun, 19 Jan 2025 23:12:02 +0100 Subject: [PATCH 11/24] New feature for removing heart artifacts from EEG or ESG data using a Principal Component Analysis - Optimal Basis Sets (PCA-OBS) algorithm (#13037) Co-authored-by: Emma Bailey Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Steinn Magnusson Co-authored-by: Eric Larson Co-authored-by: emma-bailey <93327939+emma-bailey@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel McCloy --- .circleci/config.yml | 9 +- doc/api/datasets.rst | 1 + doc/api/preprocessing.rst | 1 + doc/changes/devel/13037.newfeature.rst | 1 + doc/changes/names.inc | 2 + doc/conf.py | 1 + doc/references.bib | 10 + .../esg_rm_heart_artefact_pcaobs.py | 196 +++++++++++ mne/datasets/__init__.pyi | 2 + mne/datasets/utils.py | 30 +- mne/preprocessing/__init__.pyi | 2 + mne/preprocessing/_pca_obs.py | 333 ++++++++++++++++++ mne/preprocessing/tests/test_pca_obs.py | 107 ++++++ mne/utils/numerics.py | 3 + pyproject.toml | 1 + tools/circleci_dependencies.sh | 2 +- .../50_artifact_correction_ssp.py | 8 + 17 files changed, 706 insertions(+), 3 deletions(-) create mode 100644 doc/changes/devel/13037.newfeature.rst create mode 100755 examples/preprocessing/esg_rm_heart_artefact_pcaobs.py create mode 100755 mne/preprocessing/_pca_obs.py create mode 100644 mne/preprocessing/tests/test_pca_obs.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 4297dc5fedf..26b9f600e3c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -218,6 +218,9 @@ jobs: - restore_cache: keys: - data-cache-phantom-kit + - restore_cache: + keys: + - data-cache-ds004388 - run: name: Get data # This limit could be increased, but this is helpful for finding slow ones @@ -252,7 +255,7 @@ jobs: name: 
Check sphinx log for warnings (which are treated as errors)
          when: always
          command: |
-            ! grep "^.* (WARNING|ERROR): .*$" sphinx_log.txt
+            ! grep "^.*\(WARNING\|ERROR\): " sphinx_log.txt
      - run:
          name: Show profiling output
          when: always
@@ -393,6 +396,10 @@ jobs:
           key: data-cache-phantom-kit
           paths:
             - ~/mne_data/MNE-phantom-KIT-data  # (1 G)
+      - save_cache:
+          key: data-cache-ds004388
+          paths:
+            - ~/mne_data/ds004388  # (1.8 G)
 
   linkcheck:
diff --git a/doc/api/datasets.rst b/doc/api/datasets.rst
index 2b2c92c8654..87730fbd717 100644
--- a/doc/api/datasets.rst
+++ b/doc/api/datasets.rst
@@ -18,6 +18,7 @@ Datasets
    brainstorm.bst_auditory.data_path
    brainstorm.bst_resting.data_path
    brainstorm.bst_raw.data_path
+   default_path
    eegbci.load_data
    eegbci.standardize
    fetch_aparc_sub_parcellation
diff --git a/doc/api/preprocessing.rst b/doc/api/preprocessing.rst
index 86ad3aca910..9fe3f995cc4 100644
--- a/doc/api/preprocessing.rst
+++ b/doc/api/preprocessing.rst
@@ -116,6 +116,7 @@ Projections:
    read_ica_eeglab
    read_fine_calibration
    write_fine_calibration
+   apply_pca_obs
 
 :py:mod:`mne.preprocessing.nirs`:
diff --git a/doc/changes/devel/13037.newfeature.rst b/doc/changes/devel/13037.newfeature.rst
new file mode 100644
index 00000000000..3b28e2294ab
--- /dev/null
+++ b/doc/changes/devel/13037.newfeature.rst
@@ -0,0 +1 @@
+Add PCA-OBS preprocessing for the removal of heart artefacts from EEG or ESG datasets via :func:`mne.preprocessing.apply_pca_obs`, by :newcontrib:`Emma Bailey` and :newcontrib:`Steinn Hauser Magnusson`.
diff --git a/doc/changes/names.inc b/doc/changes/names.inc
index 3ac0b1cd9c9..eb444c5e594 100644
--- a/doc/changes/names.inc
+++ b/doc/changes/names.inc
@@ -73,6 +73,7 @@
 .. _Eberhard Eich: https://github.com/ebeich
 .. _Eduard Ort: https://github.com/eort
 .. _Emily Stephen: https://github.com/emilyps14
+.. _Emma Bailey: https://www.cbs.mpg.de/employees/bailey
 .. _Enrico Varano: https://github.com/enricovara/
 .. _Enzo Altamiranda: https://www.linkedin.com/in/enzoalt
 .. _Eric Larson: https://larsoner.com
@@ -284,6 +285,7 @@
 .. _Stanislas Chambon: https://github.com/Slasnista
 .. _Stefan Appelhoff: https://stefanappelhoff.com
 .. _Stefan Repplinger: https://github.com/stfnrpplngr
+.. _Steinn Hauser Magnusson: https://github.com/steinnhauser
 .. _Steven Bethard: https://github.com/bethard
 .. _Steven Bierer: https://github.com/neurolaunch
 .. _Steven Gutstein: https://github.com/smgutstein
diff --git a/doc/conf.py b/doc/conf.py
index 74f66d8f6ae..f1b771571d6 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -355,6 +355,7 @@
     "n_frequencies",
     "n_tests",
     "n_samples",
+    "n_peaks",
     "n_permutations",
     "nchan",
     "n_points",
diff --git a/doc/references.bib b/doc/references.bib
index a129d2f46a2..e2578ed18f2 100644
--- a/doc/references.bib
+++ b/doc/references.bib
@@ -1335,6 +1335,16 @@ @inproceedings{NdiayeEtAl2016
   year = {2016}
 }
 
+@article{NiazyEtAl2005,
+  author = {Niazy, R. K. and Beckmann, C.F. and Iannetti, G.D. and Brady, J. M. and Smith, S. M.},
+  title = {Removal of FMRI environment artifacts from EEG data using optimal basis sets},
+  journal = {NeuroImage},
+  year = {2005},
+  volume = {28},
+  pages = {720-737},
+  doi = {10.1016/j.neuroimage.2005.06.067}
+}
+
 @article{NicholsHolmes2002,
   author = {Nichols, Thomas E.
and Holmes, Andrew P.}, doi = {10.1002/hbm.1058}, diff --git a/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py new file mode 100755 index 00000000000..a6c6bb3c2ba --- /dev/null +++ b/examples/preprocessing/esg_rm_heart_artefact_pcaobs.py @@ -0,0 +1,196 @@ +""" +.. _ex-pcaobs: + +===================================================================================== +Principal Component Analysis - Optimal Basis Sets (PCA-OBS) removing cardiac artefact +===================================================================================== + +This script shows an example of how to use an adaptation of PCA-OBS +:footcite:`NiazyEtAl2005`. PCA-OBS was originally designed to remove +the ballistocardiographic artefact in simultaneous EEG-fMRI. Here, it +has been adapted to remove the delay between the detected R-peak and the +ballistocardiographic artefact such that the algorithm can be applied to +remove the cardiac artefact in EEG (electroencephalography) and ESG +(electrospinography) data. We will illustrate how it works by applying the +algorithm to ESG data, where the effect of removal is most pronounced. + +See: https://www.biorxiv.org/content/10.1101/2024.09.05.611423v1 +for more details on the dataset and application for ESG data. + +""" + +# Authors: Emma Bailey , +# Steinn Hauser Magnusson +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +import glob + +import numpy as np + +# %% +# Download sample subject data from OpenNeuro if you haven't already. +# This will download simultaneous EEG and ESG data from a single run of a +# single participant after median nerve stimulation of the left wrist. +import openneuro +from matplotlib import pyplot as plt + +import mne +from mne import Epochs, events_from_annotations +from mne.io import read_raw_eeglab +from mne.preprocessing import find_ecg_events, fix_stim_artifact + +# add the path where you want the OpenNeuro data downloaded. Each run is ~2GB of data +ds = "ds004388" +target_dir = mne.datasets.default_path() / ds +run_name = "sub-001/eeg/*median_run-03_eeg*.set" +if not glob.glob(str(target_dir / run_name)): + target_dir.mkdir(exist_ok=True) + openneuro.download(dataset=ds, target_dir=target_dir, include=run_name[:-4]) +block_files = glob.glob(str(target_dir / run_name)) +assert len(block_files) == 1 + +# %% +# Define the esg channels (arranged in two patches over the neck and lower back). + +esg_chans = [ + "S35", + "S24", + "S36", + "Iz", + "S17", + "S15", + "S32", + "S22", + "S19", + "S26", + "S28", + "S9", + "S13", + "S11", + "S7", + "SC1", + "S4", + "S18", + "S8", + "S31", + "SC6", + "S12", + "S16", + "S5", + "S30", + "S20", + "S34", + "S21", + "S25", + "L1", + "S29", + "S14", + "S33", + "S3", + "L4", + "S6", + "S23", +] + +# Interpolation window in seconds for ESG data to remove stimulation artefact +tstart_esg = -7e-3 +tmax_esg = 7e-3 + +# Define timing of heartbeat epochs in seconds relative to R-peaks +iv_baseline = [-400e-3, -300e-3] +iv_epoch = [-400e-3, 600e-3] + +# %% +# Next, we perform minimal preprocessing including removing the +# stimulation artefact, downsampling and filtering. 
+ +raw = read_raw_eeglab(block_files[0], verbose="error") +raw.set_channel_types(dict(ECG="ecg")) +# Isolate the ESG channels (include the ECG channel for R-peak detection) +raw.pick(esg_chans + ["ECG"]) +# Trim duration and downsample (from 10kHz) to improve example speed +raw.crop(0, 60).load_data().resample(2000) + +# Find trigger timings to remove the stimulation artefact +events, event_dict = events_from_annotations(raw) +trigger_name = "Median - Stimulation" + +fix_stim_artifact( + raw, + events=events, + event_id=event_dict[trigger_name], + tmin=tstart_esg, + tmax=tmax_esg, + mode="linear", + stim_channel=None, +) + +# %% +# Find ECG events and add to the raw structure as event annotations. + +ecg_events, ch_ecg, average_pulse = find_ecg_events(raw, ch_name="ECG") +ecg_event_samples = np.asarray( + [[ecg_event[0] for ecg_event in ecg_events]] +) # Samples only + +qrs_event_time = [ + x / raw.info["sfreq"] for x in ecg_event_samples.reshape(-1) +] # Divide by sampling rate to make times +duration = np.repeat(0.0, len(ecg_event_samples)) +description = ["qrs"] * len(ecg_event_samples) + +raw.annotations.append( + qrs_event_time, duration, description, ch_names=[esg_chans] * len(qrs_event_time) +) + +# %% +# Create evoked response around the detected R-peaks +# before and after cardiac artefact correction. + +events, event_ids = events_from_annotations(raw) +event_id_dict = {key: value for key, value in event_ids.items() if key == "qrs"} +epochs = Epochs( + raw, + events, + event_id=event_id_dict, + tmin=iv_epoch[0], + tmax=iv_epoch[1], + baseline=tuple(iv_baseline), +) +evoked_before = epochs.average() + +# Apply function - modifies the data in place. Optionally high-pass filter +# the data before applying PCA-OBS to remove low frequency drifts +raw = mne.preprocessing.apply_pca_obs( + raw, picks=esg_chans, n_jobs=5, qrs_times=raw.times[ecg_event_samples.reshape(-1)] +) + +epochs = Epochs( + raw, + events, + event_id=event_id_dict, + tmin=iv_epoch[0], + tmax=iv_epoch[1], + baseline=tuple(iv_baseline), +) +evoked_after = epochs.average() + +# %% +# Compare evoked responses to assess completeness of artefact removal. + +fig, axes = plt.subplots(1, 1, layout="constrained") +data_before = evoked_before.get_data(units=dict(eeg="uV")).T +data_after = evoked_after.get_data(units=dict(eeg="uV")).T +hs = list() +hs.append(axes.plot(epochs.times, data_before, color="k")[0]) +hs.append(axes.plot(epochs.times, data_after, color="green", label="after")[0]) +axes.set(ylim=[-500, 1000], ylabel="Amplitude (µV)", xlabel="Time (s)") +axes.set(title="ECG artefact removal using PCA-OBS") +axes.legend(hs, ["before", "after"]) +plt.show() + +# %% +# References +# ---------- +# .. footbibliography:: diff --git a/mne/datasets/__init__.pyi b/mne/datasets/__init__.pyi index 44cee84fe7f..2f69a1027e5 100644 --- a/mne/datasets/__init__.pyi +++ b/mne/datasets/__init__.pyi @@ -6,6 +6,7 @@ __all__ = [ "epilepsy_ecog", "erp_core", "eyelink", + "default_path", "fetch_aparc_sub_parcellation", "fetch_dataset", "fetch_fsaverage", @@ -70,6 +71,7 @@ from ._infant import fetch_infant_template from ._phantom.base import fetch_phantom from .utils import ( _download_all_example_data, + default_path, fetch_aparc_sub_parcellation, fetch_hcp_mmp_parcellation, has_dataset, diff --git a/mne/datasets/utils.py b/mne/datasets/utils.py index 452e42cffc7..93aabc0841a 100644 --- a/mne/datasets/utils.py +++ b/mne/datasets/utils.py @@ -2,6 +2,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. 
+import glob
 import importlib
 import inspect
 import logging
@@ -92,6 +93,22 @@ def _dataset_version(path, name):
     return version
 
 
+@verbose
+def default_path(*, verbose=None):
+    """Get the default MNE_DATA path.
+
+    Parameters
+    ----------
+    %(verbose)s
+
+    Returns
+    -------
+    data_path : instance of Path
+        Path to the default MNE_DATA directory.
+    """
+    return _get_path(None, None, None)
+
+
 def _get_path(path, key, name):
     """Get a dataset path."""
     # 1. Input
@@ -113,7 +130,8 @@ def _get_path(path, key, name):
         return path
     # 4. ~/mne_data (but use a fake home during testing so we don't
     # unnecessarily create ~/mne_data)
-    logger.info(f"Using default location ~/mne_data for {name}...")
+    extra = f" for {name}" if name else ""
+    logger.info(f"Using default location ~/mne_data{extra}...")
     path = Path(os.getenv("_MNE_FAKE_HOME_DIR", "~")).expanduser() / "mne_data"
     if not path.is_dir():
         logger.info(f"Creating {path}")
@@ -319,6 +337,8 @@ def _download_all_example_data(verbose=True):
     #
     # verbose=True by default so we get nice status messages.
     # Consider adding datasets from here to CircleCI for PR-auto-build
+    import openneuro
+
     paths = dict()
     for kind in (
         "sample testing misc spm_face somato hf_sef multimodal "
@@ -375,6 +395,14 @@ def _download_all_example_data(verbose=True):
     limo.load_data(subject=1, update_path=True)
     logger.info("[done limo]")
 
+    # for ESG
+    ds = "ds004388"
+    target_dir = default_path() / ds
+    run_name = "sub-001/eeg/*median_run-03_eeg*.set"
+    if not glob.glob(str(target_dir / run_name)):
+        target_dir.mkdir(exist_ok=True)
+        openneuro.download(dataset=ds, target_dir=target_dir, include=run_name[:-4])
+
 
 @verbose
 def fetch_aparc_sub_parcellation(subjects_dir=None, verbose=None):
diff --git a/mne/preprocessing/__init__.pyi b/mne/preprocessing/__init__.pyi
index 54f1c825c13..c54685dba34 100644
--- a/mne/preprocessing/__init__.pyi
+++ b/mne/preprocessing/__init__.pyi
@@ -44,6 +44,7 @@ __all__ = [
     "realign_raw",
     "regress_artifact",
     "write_fine_calibration",
+    "apply_pca_obs",
 ]
 from . import eyetracking, ieeg, nirs
 from ._annotate_amplitude import annotate_amplitude
@@ -56,6 +57,7 @@ from ._fine_cal import (
     write_fine_calibration,
 )
 from ._lof import find_bad_channels_lof
+from ._pca_obs import apply_pca_obs
 from ._peak_finder import peak_finder
 from ._regress import EOGRegression, read_eog_regression, regress_artifact
 from .artifact_detection import (
diff --git a/mne/preprocessing/_pca_obs.py b/mne/preprocessing/_pca_obs.py
new file mode 100755
index 00000000000..be226a73889
--- /dev/null
+++ b/mne/preprocessing/_pca_obs.py
@@ -0,0 +1,333 @@
+"""Principal Component Analysis Optimal Basis Sets (PCA-OBS)."""
+
+# Authors: The MNE-Python contributors.
+# License: BSD-3-Clause
+# Copyright the MNE-Python contributors.
+
+import math
+
+import numpy as np
+from scipy.interpolate import PchipInterpolator as pchip
+from scipy.signal import detrend
+
+from ..io.fiff.raw import Raw
+from ..utils import _PCA, _validate_type, logger, verbose
+
+
+@verbose
+def apply_pca_obs(
+    raw: Raw,
+    picks: list[str],
+    *,
+    qrs_times: np.ndarray,
+    n_components: int = 4,
+    n_jobs: int | None = None,
+    copy: bool = True,
+    verbose: bool | str | int | None = None,
+) -> Raw:
+    """
+    Apply the PCA-OBS algorithm to picks of a Raw object.
+
+    Uses the optimal basis set (OBS) algorithm from :footcite:`NiazyEtAl2005`.
+
+    Parameters
+    ----------
+    raw : instance of Raw
+        The raw data to process.
+    %(picks_all_data_noref)s
+    qrs_times : ndarray, shape (n_peaks,)
+        Array of times in the Raw data of detected R-peaks in ECG channel.
+    n_components : int
+        Number of PCA components to use to form the OBS (default 4).
+    %(n_jobs)s
+    copy : bool
+        If False, modify the Raw instance in-place.
+        If True (default), copy the raw instance before processing.
+    %(verbose)s
+
+    Returns
+    -------
+    raw : instance of Raw
+        The modified raw instance.
+
+    Notes
+    -----
+    .. versionadded:: 1.10
+
+    References
+    ----------
+    .. footbibliography::
+    """
+    # sanity checks
+    _validate_type(qrs_times, np.ndarray, "qrs_times")
+    if len(qrs_times.shape) > 1:
+        raise ValueError("qrs_times must be a 1d array")
+    if qrs_times.dtype not in [int, float]:
+        raise ValueError("qrs_times must be an array of either integers or floats")
+    if np.any(qrs_times < 0):
+        raise ValueError("qrs_times must be strictly positive")
+    if np.any(qrs_times >= raw.times[-1]):
+        logger.warning("some out of bound qrs_times will be ignored...")
+
+    if copy:
+        raw = raw.copy()
+
+    raw.apply_function(
+        _pca_obs,
+        picks=picks,
+        n_jobs=n_jobs,
+        # args sent to _pca_obs; convert times to indices
+        qrs=raw.time_as_index(qrs_times),
+        n_components=n_components,
+    )
+
+    return raw
+
+
+def _pca_obs(
+    data: np.ndarray,
+    qrs: np.ndarray,
+    n_components: int,
+) -> np.ndarray:
+    """Algorithm to remove heart artefact from EEG data (array of length n_times)."""
+    # set to baseline
+    data = data - np.mean(data)
+
+    # Allocate memory for artifact which will be subtracted from the data
+    fitted_art = np.zeros(data.shape)
+
+    # Extract QRS event indexes which are within our data timeframe
+    peak_idx = qrs[qrs < len(data)]
+    peak_count = len(peak_idx)
+
+    ##################################################################
+    # Preparatory work - reserving memory, configure sizes, de-trend #
+    ##################################################################
+    # define peak range based on RR
+    mRR = np.median(np.diff(peak_idx))
+    peak_range = round(mRR / 2)  # Rounds to an integer
+    mid_p = peak_range + 1
+    n_samples_fit = round(
+        peak_range / 8
+    )  # sample fit for interpolation between fitted artifact windows
+
+    # make sure the array is long enough for peak_range (if not, cut off the
+    # last ECG peak)
+    # NOTE: Here we previously checked for the last part of the window to be big enough.
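+    # (each pass of the loop below drops the final detected peak, shrinking
+    # peak_count until the last full fitting window,
+    # peak_idx[peak_count - 1] + peak_range, fits inside the data)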
+    while peak_idx[peak_count - 1] + peak_range > len(data):
+        peak_count = peak_count - 1  # reduce number of QRS complexes detected
+
+    # build PCA matrix (heart-beat-epochs x window-length)
+    pcamat = np.zeros((peak_count - 1, 2 * peak_range + 1))  # [epoch x time]
+    # picking out heartbeat epochs
+    for p in range(1, peak_count):
+        pcamat[p - 1, :] = data[peak_idx[p] - peak_range : peak_idx[p] + peak_range + 1]
+
+    # detrend the matrix (twice)
+    pcamat = detrend(
+        pcamat, type="constant", axis=1
+    )  # [epoch x time] - detrended along the epoch
+    mean_effect: np.ndarray = np.mean(
+        pcamat, axis=0
+    )  # [1 x time], contains the mean over all epochs
+    dpcamat = detrend(pcamat, type="constant", axis=1)  # [epoch x time]
+
+    ############################
+    # Perform PCA with sklearn #
+    ############################
+    # run PCA, perform singular value decomposition (SVD)
+    pca = _PCA()
+    pca.fit(dpcamat)
+    factor_loadings = pca.components_.T * np.sqrt(pca.explained_variance_)
+
+    # define selected number of components using profile likelihood
+
+    #####################################
+    # Make template of the ECG artefact #
+    #####################################
+    mean_effect = mean_effect.reshape(-1, 1)
+    pca_template = np.c_[mean_effect, factor_loadings[:, :n_components]]
+
+    ################
+    # Data Fitting #
+    ################
+    window_start_idx = []
+    window_end_idx = []
+    post_idx_next_peak = None
+
+    for p in range(peak_count):
+        # if the current peak doesn't have enough data in the
+        # start of the peak_range, skip fitting the artifact
+        if peak_idx[p] - peak_range < 0:
+            continue
+
+        # Deals with start portion of data
+        if p == 0:
+            pre_range = peak_range
+            post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2)
+            if post_range > peak_range:
+                post_range = peak_range
+
+            fitted_art, post_idx_next_peak = _fit_ecg_template(
+                data=data,
+                pca_template=pca_template,
+                a_peak_idx=peak_idx[p],
+                peak_range=peak_range,
+                pre_range=pre_range,
+                post_range=post_range,
+                mid_p=mid_p,
+                fitted_art=fitted_art,
+                post_idx_previous_peak=post_idx_next_peak,
+                n_samples_fit=n_samples_fit,
+            )
+            # Appending to list instead of using counter
+            window_start_idx.append(peak_idx[p] - peak_range)
+            window_end_idx.append(peak_idx[p] + peak_range)
+
+        # Deals with last edge of data
+        elif p == peak_count - 1:
+            pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2)
+            post_range = peak_range
+            if pre_range > peak_range:
+                pre_range = peak_range
+            fitted_art, _ = _fit_ecg_template(
+                data=data,
+                pca_template=pca_template,
+                a_peak_idx=peak_idx[p],
+                peak_range=peak_range,
+                pre_range=pre_range,
+                post_range=post_range,
+                mid_p=mid_p,
+                fitted_art=fitted_art,
+                post_idx_previous_peak=post_idx_next_peak,
+                n_samples_fit=n_samples_fit,
+            )
+            window_start_idx.append(peak_idx[p] - peak_range)
+            window_end_idx.append(peak_idx[p] + peak_range)
+
+        # Deals with middle portion of data
+        else:
+            # ---------------- Processing of central data --------------------
+            # cycle through peak artifacts identified by peakplot
+            pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2)
+            post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2)
+            if pre_range >= peak_range:
+                pre_range = peak_range
+            if post_range > peak_range:
+                post_range = peak_range
+
+            a_template = pca_template[
+                mid_p - peak_range - 1 : mid_p + peak_range + 1, :
+            ]
+            fitted_art, post_idx_next_peak = _fit_ecg_template(
+                data=data,
+                pca_template=a_template,
+                a_peak_idx=peak_idx[p],
+                peak_range=peak_range,
+                pre_range=pre_range,
+                post_range=post_range,
+                mid_p=mid_p,
+                fitted_art=fitted_art,
+                post_idx_previous_peak=post_idx_next_peak,
+                n_samples_fit=n_samples_fit,
+            )
+            window_start_idx.append(peak_idx[p] - peak_range)
+            window_end_idx.append(peak_idx[p] + peak_range)
+
+    # Actually subtract the artefact, return needs to be the same shape as input data
+    data -= fitted_art
+    return data
+
+
+def _fit_ecg_template(
+    data: np.ndarray,
+    pca_template: np.ndarray,
+    a_peak_idx: int,
+    peak_range: int,
+    pre_range: int,
+    post_range: int,
+    mid_p: float,
+    fitted_art: np.ndarray,
+    post_idx_previous_peak: int | None,
+    n_samples_fit: int,
+) -> tuple[np.ndarray, int]:
+    """
+    Fit the heartbeat artefact found in the data.
+
+    Returns the fitted artefact and the index of the next peak.
+
+    Parameters
+    ----------
+    data (ndarray): Data from the raw signal for a single channel, shape (n_times,)
+    pca_template (ndarray): Mean heartbeat and first N (default 4)
+        principal components of the heartbeat matrix
+    a_peak_idx (int): Sample index of current R-peak
+    peak_range (int): Half the median RR-interval
+    pre_range (int): Number of samples to fit before the R-peak
+    post_range (int): Number of samples to fit after the R-peak
+    mid_p (float): Sample index marking middle of the median RR interval
+        in the signal. Used to extract relevant part of pca_template.
+    fitted_art (ndarray): The accumulated heartbeat artefact estimate that the
+        current fit is written into before being removed from the data
+    post_idx_previous_peak (optional int): Sample index of previous R-peak
+    n_samples_fit (int): Sample fit for interpolation in fitted artifact
+        windows. Helps reduce sharp edges at end of fitted heartbeat events
+
+    Returns
+    -------
+    tuple[np.ndarray, int]: the fitted artifact and the next peak index
+    """
+    # post_idx_next_peak is passed in from _pca_obs, where it is used here as
+    # post_idx_previous_peak; next_peak is returned at the end and the process repeats
+    # select window of template
+    template = pca_template[mid_p - peak_range - 1 : mid_p + peak_range + 1, :]
+
+    # select window of data and detrend it
+    slice_ = data[a_peak_idx - peak_range : a_peak_idx + peak_range + 1]
+
+    detrended_data = detrend(slice_, type="constant")
+
+    # map the data onto the template and then map it back to the sensor space
+    least_square = np.linalg.lstsq(template, detrended_data, rcond=None)
+    pad_fit = np.dot(template, least_square[0])
+
+    # fit artifact
+    fitted_art[a_peak_idx - pre_range - 1 : a_peak_idx + post_range] = pad_fit[
+        mid_p - pre_range - 1 : mid_p + post_range
+    ].T
+
+    # if this is the first fitted peak there is no previous window to
+    # interpolate from, so return immediately
+    if post_idx_previous_peak is None:
+        return fitted_art, a_peak_idx + post_range
+
+    # interpolate time between peaks
+    intpol_window = np.ceil([post_idx_previous_peak, a_peak_idx - pre_range]).astype(
+        int
+    )  # interpolation window
+
+    if intpol_window[0] < intpol_window[1]:
+        # Piecewise Cubic Hermite Interpolating Polynomial (PCHIP) + replace EEG data
+
+        # You have x_fit, which is two slices on either side of the
+        # interpolation window endpoints
+        # You have y_fit, which is the y values corresponding to the x values above
+        # You have x_interpol, which is the time points between the two slices
+        # in x_fit that you want to interpolate
+        # You have y_interpol, which is values from pchip at the time points
+        # specified in x_interpol
+        # points to be interpolated in pt - the gap between the endpoints of the window
+        x_interpol = np.arange(intpol_window[0], intpol_window[1] + 1, 1)
+        # Entire range of x values in this step (taking some
+        # number of samples before and after the window)
+        x_fit = np.concatenate(
+            [
+                np.arange(intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1),
+                np.arange(intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1),
+            ]
+        )
+        y_fit = fitted_art[x_fit]
+        y_interpol = pchip(x_fit, y_fit)(x_interpol)  # perform interpolation
+
+        # make fitted artefact in the desired range equal to the completed fit above
+        fitted_art[post_idx_previous_peak : a_peak_idx - pre_range + 1] = y_interpol
+
+    return fitted_art, a_peak_idx + post_range
diff --git a/mne/preprocessing/tests/test_pca_obs.py b/mne/preprocessing/tests/test_pca_obs.py
new file mode 100644
index 00000000000..ee2568a2080
--- /dev/null
+++ b/mne/preprocessing/tests/test_pca_obs.py
@@ -0,0 +1,107 @@
+# Authors: The MNE-Python contributors.
+# License: BSD-3-Clause
+# Copyright the MNE-Python contributors.
+
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+from mne.io import read_raw_fif
+from mne.io.fiff.raw import Raw
+from mne.preprocessing import apply_pca_obs
+
+data_path = Path(__file__).parents[2] / "io" / "tests" / "data"
+raw_fname = data_path / "test_raw.fif"
+
+
+@pytest.fixture()
+def short_raw_data():
+    """Create a short, picked raw instance."""
+    return read_raw_fif(raw_fname, preload=True)
+
+
+def test_heart_artifact_removal(short_raw_data: Raw):
+    """Test PCA-OBS analysis and heart artifact removal of ECG datasets."""
+    pd = pytest.importorskip("pandas")
+
+    # copy the original raw. heart artifact is removed in-place
+    orig_df: pd.DataFrame = short_raw_data.to_data_frame().copy(deep=True)
+
+    # fake some evenly spaced qrs events within the window of the raw data
+    # (dropping the first and last points so every fitting window fits)
+    ecg_event_times = np.linspace(0, orig_df["time"].iloc[-1], 20)[1:-1]
+
+    # perform heart artifact removal
+    short_raw_data = apply_pca_obs(
+        raw=short_raw_data, picks=["eeg"], qrs_times=ecg_event_times, n_jobs=1
+    )
+
+    # compare processed df to original df
+    removed_heart_artifact_df: pd.DataFrame = short_raw_data.to_data_frame()
+
+    # ensure all column names remain the same
+    pd.testing.assert_index_equal(
+        orig_df.columns,
+        removed_heart_artifact_df.columns,
+    )
+
+    # ensure every column starting with EEG has been altered
+    altered_cols = [c for c in orig_df.columns if c.startswith("EEG")]
+    for col in altered_cols:
+        with pytest.raises(
+            AssertionError
+        ):  # make sure that an error is raised when we check for equality
+            pd.testing.assert_series_equal(
+                orig_df[col],
+                removed_heart_artifact_df[col],
+            )
+
+    # ensure every column not starting with EEG has not been altered
+    unaltered_cols = [c for c in orig_df.columns if not c.startswith("EEG")]
+    pd.testing.assert_frame_equal(
+        orig_df[unaltered_cols],
+        removed_heart_artifact_df[unaltered_cols],
+    )
+
+
+# test that various nonsensical inputs raise the proper errors
+@pytest.mark.parametrize(
+    ("picks", "qrs_times", "error", "exception"),
+    [
+        (
+            ["eeg"],
+            np.array([[0, 1], [2, 3]]),
+            "qrs_times must be a 1d array",
+            ValueError,
+        ),
+        (
+            ["eeg"],
+            [2, 3, 4],
+            "qrs_times must be an instance of ndarray, got <class 'list'> instead.",
+            TypeError,
+        ),
+        (
+            ["eeg"],
+            np.array([None, "foo", 2]),
+            "qrs_times must be an array of either integers or floats",
+            ValueError,
+        ),
+        (
+            ["eeg"],
+            np.array([-1, 0, 3]),
+            "qrs_times must be strictly positive",
+            ValueError,
+        ),
+    ],
+)
+def test_pca_obs_bad_input(
+    short_raw_data: Raw,
+    picks: list[str],
+    qrs_times: np.ndarray,
+    error: str,
+    exception: type[Exception],
+):
+    """Test if bad input data raises the proper errors in the function sanity checks."""
+    with pytest.raises(exception, match=error):
+        apply_pca_obs(raw=short_raw_data, picks=picks, qrs_times=qrs_times)
diff --git a/mne/utils/numerics.py b/mne/utils/numerics.py
index eed23998774..5029e8fbeca 100644
--- a/mne/utils/numerics.py
+++ b/mne/utils/numerics.py
@@ -871,6 +871,9 @@ def fit_transform(self, X, y=None):
 
         return U
 
+    def fit(self, X):
+        self._fit(X)
+
     def _fit(self, X):
         if self.n_components is None:
             n_components = min(X.shape)
diff --git a/pyproject.toml b/pyproject.toml
index bb56126bc07..f20c495a2bc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -63,6 +63,7 @@ doc = [
   "mne-gui-addons",
   "neo",
   "numpydoc",
+  "openneuro-py",
   "psutil",
   "pydata_sphinx_theme >= 0.15.2",
   "pygments >= 2.13",
diff --git a/tools/circleci_dependencies.sh b/tools/circleci_dependencies.sh
index 2ecc9718ab2..dd3216ebf06 100755
--- a/tools/circleci_dependencies.sh
+++ b/tools/circleci_dependencies.sh
@@ -11,6 +11,6 @@ python -m pip install --upgrade --progress-bar off \
 	alphaCSC autoreject bycycle conpy emd fooof meggie \
 	mne-ari mne-bids-pipeline mne-faster mne-features \
 	mne-icalabel mne-lsl mne-microstates mne-nirs mne-rsa \
-	neurodsp neurokit2 niseq nitime openneuro-py pactools \
+	neurodsp neurokit2 niseq nitime pactools \
 	plotly pycrostates pyprep pyriemann python-picard sesameeg \
 	sleepecg tensorpac yasa meegkit eeg_positions
diff --git a/tutorials/preprocessing/50_artifact_correction_ssp.py b/tutorials/preprocessing/50_artifact_correction_ssp.py
index 530e6fd39d8..28dee357f9a 100644
--- a/tutorials/preprocessing/50_artifact_correction_ssp.py
+++ b/tutorials/preprocessing/50_artifact_correction_ssp.py
@@ -390,6 +390,13 @@
 #
 # See the documentation of each function for further details.
 #
+# .. note::
+#     In situations where only limited electrodes are available for analysis,
+#     removing the cardiac artefact using techniques which rely on the
+#     availability of spatial information (such as SSP) may not be possible. In
+#     these instances, it may be of use to consider algorithms which require
+#     information only regarding heartbeat instances in the time domain, such as
+#     :func:`mne.preprocessing.apply_pca_obs`.
+#
 #
 # Repairing EOG artifacts with SSP
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -535,6 +542,7 @@
 # reduced the amplitude of our signals in sensor space, but that it should not
 # bias the amplitudes in source space.
 #
+#
 # References
 # ^^^^^^^^^^
 #
From 3c6a054093d305a98757a97398e5e34988a3aced Mon Sep 17 00:00:00 2001
From: Qian Chu <97355086+qian-chu@users.noreply.github.com>
Date: Tue, 21 Jan 2025 19:30:23 +0100
Subject: [PATCH 12/24] [BUG] Correct annotation onset for exportation to EDF
 and EEGLAB (#12656)

---
 doc/changes/devel/12656.bugfix.rst |  1 +
 mne/export/_brainvision.py         |  7 +++
 mne/export/_edf.py                 |  5 +-
 mne/export/_eeglab.py              | 16 ++++--
 mne/export/_export.py              |  8 ++-
 mne/export/tests/test_export.py    | 89 +++++++++++++++++++++++++++---
 6 files changed, 112 insertions(+), 14 deletions(-)
 create mode 100644 doc/changes/devel/12656.bugfix.rst

diff --git a/doc/changes/devel/12656.bugfix.rst b/doc/changes/devel/12656.bugfix.rst
new file mode 100644
index 00000000000..b3a0c62539a
--- /dev/null
+++ b/doc/changes/devel/12656.bugfix.rst
@@ -0,0 +1 @@
+Fix bug where :func:`mne.export.export_raw` does not correct for recording start time (`raw.first_time`) when exporting Raw instances to EDF or EEGLAB formats, by `Qian Chu`_.
\ No newline at end of file diff --git a/mne/export/_brainvision.py b/mne/export/_brainvision.py index ba64ba010ce..6503c540f41 100644 --- a/mne/export/_brainvision.py +++ b/mne/export/_brainvision.py @@ -107,6 +107,13 @@ def _export_mne_raw(*, raw, fname, events=None, overwrite=False): def _mne_annots2pybv_events(raw): """Convert mne Annotations to pybv events.""" + # check that raw.annotations.orig_time is the same as raw.info["meas_date"] + # so that onsets are relative to the first sample + # (after further correction for first_time) + if raw.annotations and raw.info["meas_date"] != raw.annotations.orig_time: + raise ValueError( + "Annotations must have the same orig_time as raw.info['meas_date']" + ) events = [] for annot in raw.annotations: # handle onset and duration: seconds to sample, relative to diff --git a/mne/export/_edf.py b/mne/export/_edf.py index ef870692014..e50b05f7056 100644 --- a/mne/export/_edf.py +++ b/mne/export/_edf.py @@ -7,6 +7,7 @@ import numpy as np +from ..annotations import _sync_onset from ..utils import _check_edfio_installed, warn _check_edfio_installed() @@ -204,7 +205,9 @@ def _export_raw(fname, raw, physical_range, add_ch_type): for desc, onset, duration, ch_names in zip( raw.annotations.description, - raw.annotations.onset, + # subtract raw.first_time because EDF marks events starting from the first + # available data point and ignores raw.first_time + _sync_onset(raw, raw.annotations.onset, inverse=False), raw.annotations.duration, raw.annotations.ch_names, ): diff --git a/mne/export/_eeglab.py b/mne/export/_eeglab.py index 3c8f896164a..459207f0616 100644 --- a/mne/export/_eeglab.py +++ b/mne/export/_eeglab.py @@ -4,6 +4,7 @@ import numpy as np +from ..annotations import _sync_onset from ..utils import _check_eeglabio_installed _check_eeglabio_installed() @@ -24,11 +25,16 @@ def _export_raw(fname, raw): ch_names = [ch for ch in raw.ch_names if ch not in drop_chs] cart_coords = _get_als_coords_from_chs(raw.info["chs"], drop_chs) - annotations = [ - raw.annotations.description, - raw.annotations.onset, - raw.annotations.duration, - ] + if raw.annotations: + annotations = [ + raw.annotations.description, + # subtract raw.first_time because EEGLAB marks events starting from + # the first available data point and ignores raw.first_time + _sync_onset(raw, raw.annotations.onset, inverse=False), + raw.annotations.duration, + ] + else: + annotations = None eeglabio.raw.export_set( fname, data=raw.get_data(picks=ch_names), diff --git a/mne/export/_export.py b/mne/export/_export.py index 6e63064bf7c..835eb9f8513 100644 --- a/mne/export/_export.py +++ b/mne/export/_export.py @@ -22,9 +22,15 @@ def export_raw( """Export Raw to external formats. %(export_fmt_support_raw)s - %(export_warning)s + .. warning:: + When exporting ``Raw`` with annotations, ``raw.info["meas_date"]`` must be the + same as ``raw.annotations.orig_time``. This guarantees that the annotations are + in the same reference frame as the samples. + When `Raw.first_time` is not zero (e.g., after cropping), the onsets are + automatically corrected so that onsets are always relative to the first sample. 
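+
+        A minimal sketch of the corrected round-trip (file names here are
+        illustrative, and ``import mne`` is assumed)::
+
+            raw = mne.io.read_raw_fif("raw.fif").crop(tmin=1.0)  # first_time ~ 1.0
+            mne.export.export_raw("out.edf", raw)
+            # onsets in the written file equal raw.annotations.onset minus
+            # raw.first_time, keeping them aligned with the exported samples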
+
     Parameters
     ----------
     %(fname_export_params)s
diff --git a/mne/export/tests/test_export.py b/mne/export/tests/test_export.py
index 191e91b1eed..6f712923c7d 100644
--- a/mne/export/tests/test_export.py
+++ b/mne/export/tests/test_export.py
@@ -122,6 +122,49 @@ def test_export_raw_eeglab(tmp_path):
     raw.export(temp_fname, overwrite=True)
 
 
+@pytest.mark.parametrize("tmin", (0, 1, 5, 10))
+def test_export_raw_eeglab_annotations(tmp_path, tmin):
+    """Test annotations in the exported EEGLAB file.
+
+    All annotations should be preserved and onset corrected.
+    """
+    pytest.importorskip("eeglabio")
+    raw = read_raw_fif(fname_raw, preload=True)
+    raw.apply_proj()
+    annotations = Annotations(
+        onset=[0.01, 0.05, 0.90, 1.05],
+        duration=[0, 1, 0, 0],
+        description=["test1", "test2", "test3", "test4"],
+        ch_names=[["MEG 0113"], ["MEG 0113", "MEG 0132"], [], ["MEG 0143"]],
+    )
+    raw.set_annotations(annotations)
+    raw.crop(tmin)
+
+    # export
+    temp_fname = tmp_path / "test.set"
+    raw.export(temp_fname)
+
+    # read in the file
+    with pytest.warns(RuntimeWarning, match="is above the 99th percentile"):
+        raw_read = read_raw_eeglab(temp_fname, preload=True, montage_units="m")
+    assert raw_read.first_time == 0  # exportation resets first_time
+    valid_annot = (
+        raw.annotations.onset >= tmin
+    )  # only annotations in the cropped range get exported
+
+    # compare annotations before and after export
+    assert_array_almost_equal(
+        raw.annotations.onset[valid_annot] - raw.first_time,
+        raw_read.annotations.onset,
+    )
+    assert_array_equal(
+        raw.annotations.duration[valid_annot], raw_read.annotations.duration
+    )
+    assert_array_equal(
+        raw.annotations.description[valid_annot], raw_read.annotations.description
+    )
+
+
 def _create_raw_for_edf_tests(stim_channel_index=None):
     rng = np.random.RandomState(12345)
     ch_types = [
@@ -154,6 +197,7 @@ def test_double_export_edf(tmp_path):
     """Test exporting an EDF file multiple times."""
     raw = _create_raw_for_edf_tests(stim_channel_index=2)
     raw.info.set_meas_date("2023-09-04 14:53:09.000")
+    raw.set_annotations(Annotations(onset=[1], duration=[0], description=["test"]))
 
     # include subject info and measurement date
     raw.info["subject_info"] = dict(
@@ -258,8 +302,12 @@ def test_edf_padding(tmp_path, pad_width):
 
 
 @edfio_mark()
-def test_export_edf_annotations(tmp_path):
-    """Test that exporting EDF preserves annotations."""
+@pytest.mark.parametrize("tmin", (0, 0.005, 0.03, 1))
+def test_export_edf_annotations(tmp_path, tmin):
+    """Test annotations in the exported EDF file.
+
+    All annotations should be preserved and onset corrected.
+ """ raw = _create_raw_for_edf_tests() annotations = Annotations( onset=[0.01, 0.05, 0.90, 1.05], @@ -268,17 +316,44 @@ def test_export_edf_annotations(tmp_path): ch_names=[["0"], ["0", "1"], [], ["1"]], ) raw.set_annotations(annotations) + raw.crop(tmin) + assert raw.first_time == tmin + + if raw.n_times % raw.info["sfreq"] == 0: + expectation = nullcontext() + else: + expectation = pytest.warns( + RuntimeWarning, match="EDF format requires equal-length data blocks" + ) # export temp_fname = tmp_path / "test.edf" - raw.export(temp_fname) + with expectation: + raw.export(temp_fname) # read in the file raw_read = read_raw_edf(temp_fname, preload=True) - assert_array_equal(raw.annotations.onset, raw_read.annotations.onset) - assert_array_equal(raw.annotations.duration, raw_read.annotations.duration) - assert_array_equal(raw.annotations.description, raw_read.annotations.description) - assert_array_equal(raw.annotations.ch_names, raw_read.annotations.ch_names) + assert raw_read.first_time == 0 # exportation resets first_time + bad_annot = raw_read.annotations.description == "BAD_ACQ_SKIP" + if bad_annot.any(): + raw_read.annotations.delete(bad_annot) + valid_annot = ( + raw.annotations.onset >= tmin + ) # only annotations in the cropped range gets exported + + # compare annotations before and after export + assert_array_almost_equal( + raw.annotations.onset[valid_annot] - raw.first_time, raw_read.annotations.onset + ) + assert_array_equal( + raw.annotations.duration[valid_annot], raw_read.annotations.duration + ) + assert_array_equal( + raw.annotations.description[valid_annot], raw_read.annotations.description + ) + assert_array_equal( + raw.annotations.ch_names[valid_annot], raw_read.annotations.ch_names + ) @edfio_mark() From 27386d7bc8240500efcfc618e2fa57f0bcea1ace Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 21 Jan 2025 20:23:39 +0000 Subject: [PATCH 13/24] Bump autofix-ci/action from ff86a557419858bb967097bfc916833f5647fa8c to 551dded8c6cc8a1054039c8bc0b8b48c51dfc6ef in the actions group (#13071) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/autofix.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index d8a99200783..18543b854d0 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -21,4 +21,4 @@ jobs: - run: pip install --upgrade towncrier pygithub gitpython numpy - run: python ./.github/actions/rename_towncrier/rename_towncrier.py - run: python ./tools/dev/ensure_headers.py - - uses: autofix-ci/action@ff86a557419858bb967097bfc916833f5647fa8c + - uses: autofix-ci/action@551dded8c6cc8a1054039c8bc0b8b48c51dfc6ef From f97a916bc79942df1cc5578ed98cddbcf1aef907 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 21 Jan 2025 17:18:35 -0500 Subject: [PATCH 14/24] MAINT: Add Numba to 3.13 test (#13075) --- doc/changes/devel/12656.bugfix.rst | 2 +- mne/export/_export.py | 8 +++++--- mne/forward/tests/test_make_forward.py | 2 +- mne/preprocessing/tests/test_fine_cal.py | 2 +- mne/utils/docs.py | 13 ++++++++----- tools/github_actions_dependencies.sh | 2 +- tools/github_actions_env_vars.sh | 2 +- 7 files changed, 18 insertions(+), 13 deletions(-) diff --git a/doc/changes/devel/12656.bugfix.rst b/doc/changes/devel/12656.bugfix.rst index b3a0c62539a..3f32dbd23e5 100644 --- a/doc/changes/devel/12656.bugfix.rst +++ 
b/doc/changes/devel/12656.bugfix.rst @@ -1 +1 @@ -Fix bug where :func:`mne.export.export_raw` does not correct for recording start time (`raw.first_time`) when exporting Raw instances to EDF or EEGLAB formats, by `Qian Chu`_. \ No newline at end of file +Fix bug where :func:`mne.export.export_raw` does not correct for recording start time (:attr:`raw.first_time `) when exporting Raw instances to EDF or EEGLAB formats, by `Qian Chu`_. \ No newline at end of file diff --git a/mne/export/_export.py b/mne/export/_export.py index 835eb9f8513..4b93fda917e 100644 --- a/mne/export/_export.py +++ b/mne/export/_export.py @@ -22,14 +22,16 @@ def export_raw( """Export Raw to external formats. %(export_fmt_support_raw)s + %(export_warning)s .. warning:: When exporting ``Raw`` with annotations, ``raw.info["meas_date"]`` must be the same as ``raw.annotations.orig_time``. This guarantees that the annotations are - in the same reference frame as the samples. - When `Raw.first_time` is not zero (e.g., after cropping), the onsets are - automatically corrected so that onsets are always relative to the first sample. + in the same reference frame as the samples. When + :attr:`Raw.first_time ` is not zero (e.g., after + cropping), the onsets are automatically corrected so that onsets are always + relative to the first sample. Parameters ---------- diff --git a/mne/forward/tests/test_make_forward.py b/mne/forward/tests/test_make_forward.py index 37ec6e041b5..a357c5779c9 100644 --- a/mne/forward/tests/test_make_forward.py +++ b/mne/forward/tests/test_make_forward.py @@ -482,7 +482,7 @@ def test_make_forward_solution_openmeeg(n_layers): eeg_atol=100, meg_corr_tol=0.98, eeg_corr_tol=0.98, - meg_rdm_tol=0.1, + meg_rdm_tol=0.11, eeg_rdm_tol=0.2, ) diff --git a/mne/preprocessing/tests/test_fine_cal.py b/mne/preprocessing/tests/test_fine_cal.py index 45971620db5..8b45208e848 100644 --- a/mne/preprocessing/tests/test_fine_cal.py +++ b/mne/preprocessing/tests/test_fine_cal.py @@ -231,7 +231,7 @@ def test_fine_cal_systems(system, tmp_path): err_limit = 6000 n_ref = 28 corrs = (0.19, 0.41, 0.49) - sfs = [0.5, 0.7, 0.9, 1.5] + sfs = [0.5, 0.7, 0.9, 1.55] corr_tol = 0.55 elif system == "fil": raw = read_raw_fil(fil_fname, verbose="error") diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 683704c4bc6..54cc6845e58 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1494,19 +1494,22 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["export_fmt_support_epochs"] = """\ Supported formats: - - EEGLAB (``.set``, uses :mod:`eeglabio`) + +- EEGLAB (``.set``, uses :mod:`eeglabio`) """ docdict["export_fmt_support_evoked"] = """\ Supported formats: - - MFF (``.mff``, uses :func:`mne.export.export_evokeds_mff`) + +- MFF (``.mff``, uses :func:`mne.export.export_evokeds_mff`) """ docdict["export_fmt_support_raw"] = """\ Supported formats: - - BrainVision (``.vhdr``, ``.vmrk``, ``.eeg``, uses `pybv `_) - - EEGLAB (``.set``, uses :mod:`eeglabio`) - - EDF (``.edf``, uses `edfio `_) + +- BrainVision (``.vhdr``, ``.vmrk``, ``.eeg``, uses `pybv `_) +- EEGLAB (``.set``, uses :mod:`eeglabio`) +- EDF (``.edf``, uses `edfio `_) """ # noqa: E501 docdict["export_warning"] = """\ diff --git a/tools/github_actions_dependencies.sh b/tools/github_actions_dependencies.sh index cebd2caefa7..d47d9070f8b 100755 --- a/tools/github_actions_dependencies.sh +++ b/tools/github_actions_dependencies.sh @@ -23,7 +23,7 @@ if [ ! 
-z "$CONDA_ENV" ]; then elif [[ "${MNE_CI_KIND}" == "pip" ]]; then # Only used for 3.13 at the moment, just get test deps plus a few extras # that we know are available - INSTALL_ARGS="nibabel scikit-learn numpydoc PySide6 mne-qt-browser pandas h5io mffpy defusedxml" + INSTALL_ARGS="nibabel scikit-learn numpydoc PySide6 mne-qt-browser pandas h5io mffpy defusedxml numba" INSTALL_KIND="test" else test "${MNE_CI_KIND}" == "pip-pre" diff --git a/tools/github_actions_env_vars.sh b/tools/github_actions_env_vars.sh index 8accf72a11a..9f424ae5f48 100755 --- a/tools/github_actions_env_vars.sh +++ b/tools/github_actions_env_vars.sh @@ -28,7 +28,7 @@ else # conda-like echo "MNE_LOGGING_LEVEL=warning" | tee -a $GITHUB_ENV echo "MNE_QT_BACKEND=PySide6" | tee -a $GITHUB_ENV # TODO: Also need "|unreliable on GitHub Actions conda" on macOS, but omit for now to make sure the failure actually shows up - echo "MNE_TEST_ALLOW_SKIP=.*(Requires (spm|brainstorm) dataset|CUDA not|PySide6 causes segfaults).*" | tee -a $GITHUB_ENV + echo "MNE_TEST_ALLOW_SKIP=.*(Requires (spm|brainstorm) dataset|CUDA not|PySide6 causes segfaults|Accelerate|Flakey verbose behavior).*" | tee -a $GITHUB_ENV fi fi set +x From 99e985845759005c2d809c705241918589aa2a0e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 Jan 2025 01:49:51 +0000 Subject: [PATCH 15/24] [pre-commit.ci] pre-commit autoupdate (#13073) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Larson --- .github/workflows/check_changelog.yml | 3 +++ .github/workflows/circle_artifacts.yml | 3 +++ .pre-commit-config.yaml | 4 ++-- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/check_changelog.yml b/.github/workflows/check_changelog.yml index cc85b591977..6995c399b34 100644 --- a/.github/workflows/check_changelog.yml +++ b/.github/workflows/check_changelog.yml @@ -5,6 +5,9 @@ on: # yamllint disable-line rule:truthy types: [opened, synchronize, labeled, unlabeled] branches: ["main"] +permissions: + contents: read + jobs: changelog_checker: name: Check towncrier entry in doc/changes/devel/ diff --git a/.github/workflows/circle_artifacts.yml b/.github/workflows/circle_artifacts.yml index fa32e1ce80c..301c6234eb5 100644 --- a/.github/workflows/circle_artifacts.yml +++ b/.github/workflows/circle_artifacts.yml @@ -1,4 +1,7 @@ on: [status] # yamllint disable-line rule:truthy +permissions: + contents: read + statuses: write jobs: circleci_artifacts_redirector_job: if: "${{ startsWith(github.event.context, 'ci/circleci: build_docs') }}" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cb769988655..34b3ce9b130 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: # Ruff mne - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.1 + rev: v0.9.2 hooks: - id: ruff name: ruff lint mne @@ -82,7 +82,7 @@ repos: # zizmor - repo: https://github.com/woodruffw/zizmor-pre-commit - rev: v1.1.1 + rev: v1.2.2 hooks: - id: zizmor From 5f2b7f1a33d42c5a110f67e098f9efcf92be7fff Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 22 Jan 2025 12:22:24 -0500 Subject: [PATCH 16/24] BUG: Improve sklearn compliance (#13065) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel McCloy --- doc/changes/devel/13065.bugfix.rst | 7 + examples/decoding/linear_model_patterns.py | 2 +- mne/cov.py | 4 +- mne/decoding/base.py | 11 +- 
mne/decoding/csp.py | 99 ++++----- mne/decoding/ems.py | 25 ++- mne/decoding/search_light.py | 81 ++++--- mne/decoding/ssd.py | 138 ++++++------ mne/decoding/tests/test_base.py | 14 +- mne/decoding/tests/test_csp.py | 46 ++-- mne/decoding/tests/test_ems.py | 7 + mne/decoding/tests/test_search_light.py | 30 +-- mne/decoding/tests/test_ssd.py | 54 ++++- mne/decoding/tests/test_time_frequency.py | 23 +- mne/decoding/tests/test_transformer.py | 88 ++++++-- mne/decoding/time_frequency.py | 32 ++- mne/decoding/transformer.py | 240 +++++++++++---------- mne/time_frequency/multitaper.py | 9 +- mne/time_frequency/tfr.py | 3 +- mne/utils/numerics.py | 5 +- tools/vulture_allowlist.py | 2 + 21 files changed, 571 insertions(+), 349 deletions(-) create mode 100644 doc/changes/devel/13065.bugfix.rst diff --git a/doc/changes/devel/13065.bugfix.rst b/doc/changes/devel/13065.bugfix.rst new file mode 100644 index 00000000000..bbaa07ae127 --- /dev/null +++ b/doc/changes/devel/13065.bugfix.rst @@ -0,0 +1,7 @@ +Improved sklearn class compatibility and compliance, which resulted in some parameters of classes having an underscore appended to their name during ``fit``, such as: + +- :class:`mne.decoding.FilterEstimator` parameter ``picks`` passed to the initializer is set as ``est.picks_`` +- :class:`mne.decoding.UnsupervisedSpatialFilter` parameter ``estimator`` passed to the initializer is set as ``est.estimator_`` + +Unused ``verbose`` class parameters (that had no effect) were removed from :class:`~mne.decoding.PSDEstimator`, :class:`~mne.decoding.TemporalFilter`, and :class:`~mne.decoding.FilterEstimator` as well. +Changes by `Eric Larson`_. diff --git a/examples/decoding/linear_model_patterns.py b/examples/decoding/linear_model_patterns.py index c1390cbb0d3..7373c0a18b3 100644 --- a/examples/decoding/linear_model_patterns.py +++ b/examples/decoding/linear_model_patterns.py @@ -79,7 +79,7 @@ # Extract and plot spatial filters and spatial patterns for name, coef in (("patterns", model.patterns_), ("filters", model.filters_)): - # We fitted the linear model onto Z-scored data. To make the filters + # We fit the linear model on Z-scored data. 
To make the filters # interpretable, we must reverse this normalization step coef = scaler.inverse_transform([coef])[0] diff --git a/mne/cov.py b/mne/cov.py index 94239472fa2..694c836d0cd 100644 --- a/mne/cov.py +++ b/mne/cov.py @@ -1226,7 +1226,7 @@ def _compute_rank_raw_array( from .io import RawArray return _compute_rank( - RawArray(data, info, copy=None, verbose=_verbose_safe_false()), + RawArray(data, info, copy="auto", verbose=_verbose_safe_false()), rank, scalings, info, @@ -1405,7 +1405,7 @@ def _compute_covariance_auto( # project back cov = np.dot(eigvec.T, np.dot(cov, eigvec)) # undo bias - cov *= data.shape[0] / (data.shape[0] - 1) + cov *= data.shape[0] / max(data.shape[0] - 1, 1) # undo scaling _undo_scaling_cov(cov, picks_list, scalings) method_ = method[ei] diff --git a/mne/decoding/base.py b/mne/decoding/base.py index a291416bb17..f73cd976fe3 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -19,7 +19,7 @@ from sklearn.linear_model import LogisticRegression from sklearn.metrics import check_scoring from sklearn.model_selection import KFold, StratifiedKFold, check_cv -from sklearn.utils import check_array, indexable +from sklearn.utils import check_array, check_X_y, indexable from ..parallel import parallel_func from ..utils import _pl, logger, verbose, warn @@ -76,9 +76,9 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator): ) def __init__(self, model=None): + # TODO: We need to set this to get our tag checking to work properly if model is None: model = LogisticRegression(solver="liblinear") - self.model = model def __sklearn_tags__(self): @@ -122,7 +122,11 @@ def fit(self, X, y, **fit_params): self : instance of LinearModel Returns the modified instance. """ - X = check_array(X, input_name="X") + if y is not None: + X = check_array(X) + else: + X, y = check_X_y(X, y) + self.n_features_in_ = X.shape[1] if y is not None: y = check_array(y, dtype=None, ensure_2d=False, input_name="y") if y.ndim > 2: @@ -133,6 +137,7 @@ def fit(self, X, y, **fit_params): # fit the Model self.model.fit(X, y, **fit_params) + self.model_ = self.model # for better sklearn compat # Computes patterns using Haufe's trick: A = Cov_X . W . Precision_Y diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index 9e12335cdbe..ea38fd58ca3 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -6,7 +6,8 @@ import numpy as np from scipy.linalg import eigh -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator +from sklearn.utils.validation import check_is_fitted from .._fiff.meas_info import create_info from ..cov import _compute_rank_raw_array, _regularized_covariance, _smart_eigh @@ -19,10 +20,11 @@ fill_doc, pinv, ) +from .transformer import MNETransformerMixin @fill_doc -class CSP(TransformerMixin, BaseEstimator): +class CSP(MNETransformerMixin, BaseEstimator): """M/EEG signal decomposition using the Common Spatial Patterns (CSP). 
This class can be used as a supervised decomposition to estimate spatial @@ -112,49 +114,44 @@ def __init__( component_order="mutual_info", ): # Init default CSP - if not isinstance(n_components, int): - raise ValueError("n_components must be an integer.") self.n_components = n_components self.rank = rank self.reg = reg - - # Init default cov_est - if not (cov_est == "concat" or cov_est == "epoch"): - raise ValueError("unknown covariance estimation method") self.cov_est = cov_est - - # Init default transform_into - self.transform_into = _check_option( - "transform_into", transform_into, ["average_power", "csp_space"] - ) - - # Init default log - if transform_into == "average_power": - if log is not None and not isinstance(log, bool): - raise ValueError( - 'log must be a boolean if transform_into == "average_power".' - ) - else: - if log is not None: - raise ValueError('log must be a None if transform_into == "csp_space".') + self.transform_into = transform_into self.log = log - - _validate_type(norm_trace, bool, "norm_trace") self.norm_trace = norm_trace self.cov_method_params = cov_method_params - self.component_order = _check_option( - "component_order", component_order, ("mutual_info", "alternate") + self.component_order = component_order + + def _validate_params(self, *, y): + _validate_type(self.n_components, int, "n_components") + if hasattr(self, "cov_est"): + _validate_type(self.cov_est, str, "cov_est") + _check_option("cov_est", self.cov_est, ("concat", "epoch")) + if hasattr(self, "norm_trace"): + _validate_type(self.norm_trace, bool, "norm_trace") + _check_option( + "transform_into", self.transform_into, ["average_power", "csp_space"] ) - - def _check_Xy(self, X, y=None): - """Check input data.""" - if not isinstance(X, np.ndarray): - raise ValueError(f"X should be of type ndarray (got {type(X)}).") - if y is not None: - if len(X) != len(y) or len(y) < 1: - raise ValueError("X and y must have the same length.") - if X.ndim < 3: - raise ValueError("X must have at least 3 dimensions.") + if self.transform_into == "average_power": + _validate_type( + self.log, + (bool, None), + "log", + extra="when transform_into is 'average_power'", + ) + else: + _validate_type( + self.log, None, "log", extra="when transform_into is 'csp_space'" + ) + _check_option( + "component_order", self.component_order, ("mutual_info", "alternate") + ) + self.classes_ = np.unique(y) + n_classes = len(self.classes_) + if n_classes < 2: + raise ValueError(f"n_classes must be >= 2, but got {n_classes} class") def fit(self, X, y): """Estimate the CSP decomposition on epochs. @@ -171,12 +168,9 @@ def fit(self, X, y): self : instance of CSP Returns the modified instance. """ - self._check_Xy(X, y) - - self._classes = np.unique(y) - n_classes = len(self._classes) - if n_classes < 2: - raise ValueError("n_classes must be >= 2.") + X, y = self._check_data(X, y=y, fit=True, return_y=True) + self._validate_params(y=y) + n_classes = len(self.classes_) if n_classes > 2 and self.component_order == "alternate": raise ValueError( "component_order='alternate' requires two classes, but data contains " @@ -225,13 +219,8 @@ def transform(self, X): If self.transform_into == 'csp_space' then returns the data in CSP space and shape is (n_epochs, n_components, n_times). """ - if not isinstance(X, np.ndarray): - raise ValueError(f"X should be of type ndarray (got {type(X)}).") - if self.filters_ is None: - raise RuntimeError( - "No filters available. Please first fit CSP decomposition." 
- ) - + check_is_fitted(self, "filters_") + X = self._check_data(X) pick_filters = self.filters_[: self.n_components] X = np.asarray([np.dot(pick_filters, epoch) for epoch in X]) @@ -577,7 +566,7 @@ def _compute_covariance_matrices(self, X, y): covs = [] sample_weights = [] - for ci, this_class in enumerate(self._classes): + for ci, this_class in enumerate(self.classes_): cov, weight = cov_estimator( X[y == this_class], cov_kind=f"class={this_class}", @@ -689,7 +678,7 @@ def _normalize_eigenvectors(self, eigen_vectors, covs, sample_weights): def _order_components( self, covs, sample_weights, eigen_vectors, eigen_values, component_order ): - n_classes = len(self._classes) + n_classes = len(self.classes_) if component_order == "mutual_info" and n_classes > 2: mutual_info = self._compute_mutual_info(covs, sample_weights, eigen_vectors) ix = np.argsort(mutual_info)[::-1] @@ -889,10 +878,8 @@ def fit(self, X, y): self : instance of SPoC Returns the modified instance. """ - self._check_Xy(X, y) - - if len(np.unique(y)) < 2: - raise ValueError("y must have at least two distinct values.") + X, y = self._check_data(X, y=y, fit=True, return_y=True) + self._validate_params(y=y) # The following code is directly copied from pyRiemann diff --git a/mne/decoding/ems.py b/mne/decoding/ems.py index b3e72a30e21..5c7557798ef 100644 --- a/mne/decoding/ems.py +++ b/mne/decoding/ems.py @@ -5,15 +5,16 @@ from collections import Counter import numpy as np -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from .._fiff.pick import _picks_to_idx, pick_info, pick_types from ..parallel import parallel_func from ..utils import logger, verbose from .base import _set_cv +from .transformer import MNETransformerMixin -class EMS(TransformerMixin, BaseEstimator): +class EMS(MNETransformerMixin, BaseEstimator): """Transformer to compute event-matched spatial filters. This version of EMS :footcite:`SchurgerEtAl2013` operates on the entire @@ -37,6 +38,16 @@ class EMS(TransformerMixin, BaseEstimator): .. footbibliography:: """ + def __sklearn_tags__(self): + """Return sklearn tags.""" + from sklearn.utils import ClassifierTags + + tags = super().__sklearn_tags__() + if tags.classifier_tags is None: + tags.classifier_tags = ClassifierTags() + tags.classifier_tags.multi_class = False + return tags + def __repr__(self): # noqa: D105 if hasattr(self, "filters_"): return ( @@ -64,11 +75,12 @@ def fit(self, X, y): self : instance of EMS Returns self. """ - classes = np.unique(y) - if len(classes) != 2: + X, y = self._check_data(X, y=y, fit=True, return_y=True) + classes, y = np.unique(y, return_inverse=True) + if len(classes) > 2: raise ValueError("EMS only works for binary classification.") self.classes_ = classes - filters = X[y == classes[0]].mean(0) - X[y == classes[1]].mean(0) + filters = X[y == 0].mean(0) - X[y == 1].mean(0) filters /= np.linalg.norm(filters, axis=0)[None, :] self.filters_ = filters return self @@ -86,13 +98,14 @@ def transform(self, X): X : array, shape (n_epochs, n_times) The input data transformed by the spatial filters. """ + X = self._check_data(X) Xt = np.sum(X * self.filters_, axis=1) return Xt @verbose def compute_ems( - epochs, conditions=None, picks=None, n_jobs=None, cv=None, verbose=None + epochs, conditions=None, picks=None, n_jobs=None, cv=None, *, verbose=None ): """Compute event-matched spatial filter on epochs. 
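The reworked ``EMS.fit``/``EMS.transform`` above reduce to a unit-normalized difference of class means applied as a spatial filter. A minimal standalone sketch of that computation, assuming binary labels and ``X`` shaped ``(n_epochs, n_channels, n_times)`` (``ems_filters`` is a hypothetical helper for illustration, not MNE API):

    import numpy as np

    def ems_filters(X, y):
        # Difference of the two class means, unit-normalized per time point,
        # mirroring the EMS.fit shown in the hunk above.
        classes, y = np.unique(y, return_inverse=True)
        filters = X[y == 0].mean(axis=0) - X[y == 1].mean(axis=0)
        filters /= np.linalg.norm(filters, axis=0)[None, :]
        return filters

    rng = np.random.default_rng(0)
    X = rng.standard_normal((20, 4, 50))  # epochs x channels x times
    y = rng.integers(0, 2, 20)
    w = ems_filters(X, y)                 # (n_channels, n_times)
    surrogate = np.sum(X * w, axis=1)     # (n_epochs, n_times), as in EMS.transform
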
diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index e3059a3e959..8bd96781185 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -5,18 +5,25 @@ import logging import numpy as np -from sklearn.base import BaseEstimator, MetaEstimatorMixin, TransformerMixin, clone +from sklearn.base import BaseEstimator, MetaEstimatorMixin, clone from sklearn.metrics import check_scoring from sklearn.preprocessing import LabelEncoder -from sklearn.utils import check_array +from sklearn.utils.validation import check_is_fitted from ..parallel import parallel_func -from ..utils import ProgressBar, _parse_verbose, array_split_idx, fill_doc, verbose +from ..utils import ( + ProgressBar, + _parse_verbose, + _verbose_safe_false, + array_split_idx, + fill_doc, +) from .base import _check_estimator +from .transformer import MNETransformerMixin @fill_doc -class SlidingEstimator(MetaEstimatorMixin, TransformerMixin, BaseEstimator): +class SlidingEstimator(MetaEstimatorMixin, MNETransformerMixin, BaseEstimator): """Search Light. Fit, predict and score a series of models to each subset of the dataset @@ -38,7 +45,6 @@ class SlidingEstimator(MetaEstimatorMixin, TransformerMixin, BaseEstimator): List of fitted scikit-learn estimators (one per task). """ - @verbose def __init__( self, base_estimator, @@ -49,7 +55,6 @@ def __init__( allow_2d=False, verbose=None, ): - _check_estimator(base_estimator) self.base_estimator = base_estimator self.n_jobs = n_jobs self.scoring = scoring @@ -102,9 +107,13 @@ def fit(self, X, y, **fit_params): self : object Return self. """ - X = self._check_Xy(X, y) + _check_estimator(self.base_estimator) + X, _ = self._check_Xy(X, y, fit=True) parallel, p_func, n_jobs = parallel_func( - _sl_fit, self.n_jobs, max_jobs=X.shape[-1], verbose=False + _sl_fit, + self.n_jobs, + max_jobs=X.shape[-1], + verbose=_verbose_safe_false(), ) self.estimators_ = list() self.fit_params_ = fit_params @@ -153,14 +162,19 @@ def fit_transform(self, X, y, **fit_params): def _transform(self, X, method): """Aux. function to make parallel predictions/transformation.""" - X = self._check_Xy(X) + X, is_nd = self._check_Xy(X) + orig_method = method + check_is_fitted(self) method = _check_method(self.base_estimator, method) if X.shape[-1] != len(self.estimators_): raise ValueError("The number of estimators does not match X.shape[-1]") # For predictions/transforms the parallelization is across the data and # not across the estimators to avoid memory load. parallel, p_func, n_jobs = parallel_func( - _sl_transform, self.n_jobs, max_jobs=X.shape[-1], verbose=False + _sl_transform, + self.n_jobs, + max_jobs=X.shape[-1], + verbose=_verbose_safe_false(), ) X_splits = np.array_split(X, n_jobs, axis=-1) @@ -174,6 +188,10 @@ def _transform(self, X, method): ) y_pred = np.concatenate(y_pred, axis=1) + if orig_method == "transform": + y_pred = y_pred.astype(X.dtype) + if orig_method == "predict_proba" and not is_nd: + y_pred = y_pred[:, 0, :] return y_pred def transform(self, X): @@ -196,7 +214,7 @@ def transform(self, X): Xt : array, shape (n_samples, n_estimators) The transformed values generated by each estimator. """ # noqa: E501 - return self._transform(X, "transform").astype(X.dtype) + return self._transform(X, "transform") def predict(self, X): """Predict each data slice/task with a series of independent estimators. 
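For orientation, the per-time-point fitting that ``_sl_fit`` parallelizes across ``n_jobs`` can be written serially in a few lines. A rough sketch under the same shape conventions (``sliding_fit`` is a hypothetical helper; the real estimator additionally handles scoring, 2D inputs, and progress reporting):

    import numpy as np
    from sklearn.base import clone
    from sklearn.linear_model import LogisticRegression

    def sliding_fit(base_estimator, X, y):
        # One independent clone fit per slice of the last (time) axis.
        return [clone(base_estimator).fit(X[..., t], y) for t in range(X.shape[-1])]

    rng = np.random.default_rng(0)
    X = rng.standard_normal((30, 5, 8))  # epochs x channels x times
    y = rng.integers(0, 2, 30)
    ests = sliding_fit(LogisticRegression(), X, y)
    y_pred = np.stack([e.predict(X[..., t]) for t, e in enumerate(ests)], axis=1)
    assert y_pred.shape == (30, 8)       # one prediction per epoch and time point
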
@@ -265,15 +283,12 @@ def decision_function(self, X): """ # noqa: E501 return self._transform(X, "decision_function") - def _check_Xy(self, X, y=None): + def _check_Xy(self, X, y=None, fit=False): """Aux. function to check input data.""" # Once we require sklearn 1.1+ we should do something like: - X = check_array(X, ensure_2d=False, allow_nd=True, input_name="X") - if y is not None: - y = check_array(y, dtype=None, ensure_2d=False, input_name="y") - if len(X) != len(y) or len(y) < 1: - raise ValueError("X and y must have the same length.") - if X.ndim < 3: + X = self._check_data(X, y=y, atleast_3d=False, fit=fit) + is_nd = X.ndim >= 3 + if not is_nd: err = None if not self.allow_2d: err = 3 @@ -282,7 +297,7 @@ def _check_Xy(self, X, y=None): if err: raise ValueError(f"X must have at least {err} dimensions.") X = X[..., np.newaxis] - return X + return X, is_nd def score(self, X, y): """Score each estimator on each task. @@ -307,7 +322,7 @@ def score(self, X, y): score : array, shape (n_samples, n_estimators) Score for each estimator/task. """ # noqa: E501 - X = self._check_Xy(X, y) + X, _ = self._check_Xy(X, y) if X.shape[-1] != len(self.estimators_): raise ValueError("The number of estimators does not match X.shape[-1]") @@ -317,7 +332,10 @@ def score(self, X, y): # For predictions/transforms the parallelization is across the data and # not across the estimators to avoid memory load. parallel, p_func, n_jobs = parallel_func( - _sl_score, self.n_jobs, max_jobs=X.shape[-1], verbose=False + _sl_score, + self.n_jobs, + max_jobs=X.shape[-1], + verbose=_verbose_safe_false(), ) X_splits = np.array_split(X, n_jobs, axis=-1) est_splits = np.array_split(self.estimators_, n_jobs) @@ -483,11 +501,16 @@ def __repr__(self): # noqa: D105 def _transform(self, X, method): """Aux. function to make parallel predictions/transformation.""" - X = self._check_Xy(X) + X, is_nd = self._check_Xy(X) + check_is_fitted(self) + orig_method = method method = _check_method(self.base_estimator, method) parallel, p_func, n_jobs = parallel_func( - _gl_transform, self.n_jobs, max_jobs=X.shape[-1], verbose=False + _gl_transform, + self.n_jobs, + max_jobs=X.shape[-1], + verbose=_verbose_safe_false(), ) context = _create_progressbar_context(self, X, "Transforming") @@ -500,6 +523,10 @@ def _transform(self, X, method): ) y_pred = np.concatenate(y_pred, axis=2) + if orig_method == "transform": + y_pred = y_pred.astype(X.dtype) + if orig_method == "predict_proba" and not is_nd: + y_pred = y_pred[:, 0, 0, :] return y_pred def transform(self, X): @@ -518,6 +545,7 @@ def transform(self, X): Xt : array, shape (n_samples, n_estimators, n_slices) The transformed values generated by each estimator. """ + check_is_fitted(self) return self._transform(X, "transform") def predict(self, X): @@ -603,11 +631,14 @@ def score(self, X, y): score : array, shape (n_samples, n_estimators, n_slices) Score for each estimator / data slice couple. """ # noqa: E501 - X = self._check_Xy(X, y) + X, _ = self._check_Xy(X, y) # For predictions/transforms the parallelization is across the data and # not across the estimators to avoid memory load. 
parallel, p_func, n_jobs = parallel_func( - _gl_score, self.n_jobs, max_jobs=X.shape[-1], verbose=False + _gl_score, + self.n_jobs, + max_jobs=X.shape[-1], + verbose=_verbose_safe_false(), ) scoring = check_scoring(self.base_estimator, self.scoring) y = _fix_auc(scoring, y) diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 8bc0036d315..111ded9f274 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -4,8 +4,10 @@ import numpy as np from scipy.linalg import eigh -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator +from sklearn.utils.validation import check_is_fitted +from .._fiff.meas_info import Info, create_info from .._fiff.pick import _picks_to_idx from ..cov import Covariance, _regularized_covariance from ..defaults import _handle_default @@ -13,17 +15,17 @@ from ..rank import compute_rank from ..time_frequency import psd_array_welch from ..utils import ( - _check_option, _time_mask, _validate_type, _verbose_safe_false, fill_doc, logger, ) +from .transformer import MNETransformerMixin @fill_doc -class SSD(TransformerMixin, BaseEstimator): +class SSD(MNETransformerMixin, BaseEstimator): """ Signal decomposition using the Spatio-Spectral Decomposition (SSD). @@ -64,7 +66,7 @@ class SSD(TransformerMixin, BaseEstimator): If sort_by_spectral_ratio is set to True, then the SSD sources will be sorted according to their spectral ratio which is calculated based on :func:`mne.time_frequency.psd_array_welch`. The n_fft parameter sets the - length of FFT used. + length of FFT used. The default (None) will use 1 second of data. See :func:`mne.time_frequency.psd_array_welch` for more information. cov_method_params : dict | None (default None) As in :class:`mne.decoding.SPoC` @@ -104,7 +106,25 @@ def __init__( rank=None, ): """Initialize instance.""" - dicts = {"signal": filt_params_signal, "noise": filt_params_noise} + self.info = info + self.filt_params_signal = filt_params_signal + self.filt_params_noise = filt_params_noise + self.reg = reg + self.n_components = n_components + self.picks = picks + self.sort_by_spectral_ratio = sort_by_spectral_ratio + self.return_filtered = return_filtered + self.n_fft = n_fft + self.cov_method_params = cov_method_params + self.rank = rank + + def _validate_params(self, X): + if isinstance(self.info, float): # special case, mostly for testing + self.sfreq_ = self.info + else: + _validate_type(self.info, Info, "info") + self.sfreq_ = self.info["sfreq"] + dicts = {"signal": self.filt_params_signal, "noise": self.filt_params_noise} for param, dd in [("l", 0), ("h", 0), ("l", 1), ("h", 1)]: key = ("signal", "noise")[dd] if param + "_freq" not in dicts[key]: @@ -116,48 +136,47 @@ def __init__( _validate_type(val, ("numeric",), f"{key} {param}_freq") # check freq bands if ( - filt_params_noise["l_freq"] > filt_params_signal["l_freq"] - or filt_params_signal["h_freq"] > filt_params_noise["h_freq"] + self.filt_params_noise["l_freq"] > self.filt_params_signal["l_freq"] + or self.filt_params_signal["h_freq"] > self.filt_params_noise["h_freq"] ): raise ValueError( "Wrongly specified frequency bands!\n" "The signal band-pass must be within the noise " "band-pass!" 
) - self.picks = picks - del picks - self.info = info - self.freqs_signal = (filt_params_signal["l_freq"], filt_params_signal["h_freq"]) - self.freqs_noise = (filt_params_noise["l_freq"], filt_params_noise["h_freq"]) - self.filt_params_signal = filt_params_signal - self.filt_params_noise = filt_params_noise - # check if boolean - if not isinstance(sort_by_spectral_ratio, (bool)): - raise ValueError("sort_by_spectral_ratio must be boolean") - self.sort_by_spectral_ratio = sort_by_spectral_ratio - if n_fft is None: - self.n_fft = int(self.info["sfreq"]) - else: - self.n_fft = int(n_fft) - # check if boolean - if not isinstance(return_filtered, (bool)): - raise ValueError("return_filtered must be boolean") - self.return_filtered = return_filtered - self.reg = reg - self.n_components = n_components - self.rank = rank - self.cov_method_params = cov_method_params + self.freqs_signal_ = ( + self.filt_params_signal["l_freq"], + self.filt_params_signal["h_freq"], + ) + self.freqs_noise_ = ( + self.filt_params_noise["l_freq"], + self.filt_params_noise["h_freq"], + ) + _validate_type(self.sort_by_spectral_ratio, (bool,), "sort_by_spectral_ratio") + _validate_type(self.n_fft, ("numeric", None), "n_fft") + self.n_fft_ = min( + int(self.n_fft if self.n_fft is not None else self.sfreq_), + X.shape[-1], + ) + _validate_type(self.return_filtered, (bool,), "return_filtered") + if isinstance(self.info, Info): + ch_types = self.info.get_channel_types(picks=self.picks, unique=True) + if len(ch_types) > 1: + raise ValueError( + "At this point SSD only supports fitting " + f"single channel types. Your info has {len(ch_types)} types." + ) - def _check_X(self, X): + def _check_X(self, X, *, y=None, fit=False): """Check input data.""" - _validate_type(X, np.ndarray, "X") - _check_option("X.ndim", X.ndim, (2, 3)) + X = self._check_data(X, y=y, fit=fit, atleast_3d=False) n_chan = X.shape[-2] - if n_chan != self.info["nchan"]: + if isinstance(self.info, Info) and n_chan != self.info["nchan"]: raise ValueError( "Info must match the input data." f"Found {n_chan} channels but expected {self.info['nchan']}." ) + return X def fit(self, X, y=None): """Estimate the SSD decomposition on raw or epoched data. @@ -176,18 +195,17 @@ def fit(self, X, y=None): self : instance of SSD Returns the modified instance. """ - ch_types = self.info.get_channel_types(picks=self.picks, unique=True) - if len(ch_types) > 1: - raise ValueError( - "At this point SSD only supports fitting " - f"single channel types. Your info has {len(ch_types)} types." 
- ) - self.picks_ = _picks_to_idx(self.info, self.picks, none="data", exclude="bads") - self._check_X(X) + X = self._check_X(X, y=y, fit=True) + self._validate_params(X) + if isinstance(self.info, Info): + info = self.info + else: + info = create_info(X.shape[-2], self.sfreq_, ch_types="eeg") + self.picks_ = _picks_to_idx(info, self.picks, none="data", exclude="bads") X_aux = X[..., self.picks_, :] - X_signal = filter_data(X_aux, self.info["sfreq"], **self.filt_params_signal) - X_noise = filter_data(X_aux, self.info["sfreq"], **self.filt_params_noise) + X_signal = filter_data(X_aux, self.sfreq_, **self.filt_params_signal) + X_noise = filter_data(X_aux, self.sfreq_, **self.filt_params_noise) X_noise -= X_signal if X.ndim == 3: X_signal = np.hstack(X_signal) @@ -199,19 +217,19 @@ def fit(self, X, y=None): reg=self.reg, method_params=self.cov_method_params, rank="full", - info=self.info, + info=info, ) cov_noise = _regularized_covariance( X_noise, reg=self.reg, method_params=self.cov_method_params, rank="full", - info=self.info, + info=info, ) # project cov to rank subspace cov_signal, cov_noise, rank_proj = _dimensionality_reduction( - cov_signal, cov_noise, self.info, self.rank + cov_signal, cov_noise, info, self.rank ) eigvals_, eigvects_ = eigh(cov_signal, cov_noise) @@ -226,10 +244,10 @@ def fit(self, X, y=None): # than the initial ordering. This ordering should be also learned when # fitting. X_ssd = self.filters_.T @ X[..., self.picks_, :] - sorter_spec = Ellipsis + sorter_spec = slice(None) if self.sort_by_spectral_ratio: _, sorter_spec = self.get_spectral_ratio(ssd_sources=X_ssd) - self.sorter_spec = sorter_spec + self.sorter_spec_ = sorter_spec logger.info("Done.") return self @@ -248,17 +266,13 @@ def transform(self, X): X_ssd : array, shape ([n_epochs, ]n_components, n_times) The processed data. """ - self._check_X(X) - if self.filters_ is None: - raise RuntimeError("No filters available. Please first call fit") + check_is_fitted(self, "filters_") + X = self._check_X(X) if self.return_filtered: X_aux = X[..., self.picks_, :] - X = filter_data(X_aux, self.info["sfreq"], **self.filt_params_signal) + X = filter_data(X_aux, self.sfreq_, **self.filt_params_signal) X_ssd = self.filters_.T @ X[..., self.picks_, :] - if X.ndim == 2: - X_ssd = X_ssd[self.sorter_spec][: self.n_components] - else: - X_ssd = X_ssd[:, self.sorter_spec, :][:, : self.n_components, :] + X_ssd = X_ssd[..., self.sorter_spec_, :][..., : self.n_components, :] return X_ssd def fit_transform(self, X, y=None, **fit_params): @@ -308,11 +322,9 @@ def get_spectral_ratio(self, ssd_sources): ---------- .. footbibliography:: """ - psd, freqs = psd_array_welch( - ssd_sources, sfreq=self.info["sfreq"], n_fft=self.n_fft - ) - sig_idx = _time_mask(freqs, *self.freqs_signal) - noise_idx = _time_mask(freqs, *self.freqs_noise) + psd, freqs = psd_array_welch(ssd_sources, sfreq=self.sfreq_, n_fft=self.n_fft_) + sig_idx = _time_mask(freqs, *self.freqs_signal_) + noise_idx = _time_mask(freqs, *self.freqs_noise_) if psd.ndim == 3: mean_sig = psd[:, :, sig_idx].mean(axis=2).mean(axis=0) mean_noise = psd[:, :, noise_idx].mean(axis=2).mean(axis=0) @@ -352,7 +364,7 @@ def apply(self, X): The processed data. 
""" X_ssd = self.transform(X) - pick_patterns = self.patterns_[self.sorter_spec][: self.n_components].T + pick_patterns = self.patterns_[self.sorter_spec_][: self.n_components].T X = pick_patterns @ X_ssd return X diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py index 6d915dd24f9..504e309d53c 100644 --- a/mne/decoding/tests/test_base.py +++ b/mne/decoding/tests/test_base.py @@ -86,6 +86,8 @@ def _make_data(n_samples=1000, n_features=5, n_targets=3): X = Y.dot(A.T) X += np.random.randn(n_samples, n_features) # add noise X += np.random.rand(n_features) # Put an offset + if n_targets == 1: + Y = Y[:, 0] return X, Y, A @@ -95,7 +97,7 @@ def test_get_coef(): """Test getting linear coefficients (filters/patterns) from estimators.""" lm_classification = LinearModel() assert hasattr(lm_classification, "__sklearn_tags__") - print(lm_classification.__sklearn_tags__) + print(lm_classification.__sklearn_tags__()) assert is_classifier(lm_classification.model) assert is_classifier(lm_classification) assert not is_regressor(lm_classification.model) @@ -273,7 +275,12 @@ def test_get_coef_multiclass(n_features, n_targets): """Test get_coef on multiclass problems.""" # Check patterns with more than 1 regressor X, Y, A = _make_data(n_samples=30000, n_features=n_features, n_targets=n_targets) - lm = LinearModel(LinearRegression()).fit(X, Y) + lm = LinearModel(LinearRegression()) + assert not hasattr(lm, "model_") + lm.fit(X, Y) + # TODO: modifying non-underscored `model` is a sklearn no-no, maybe should be a + # metaestimator? + assert lm.model is lm.model_ assert_array_equal(lm.filters_.shape, lm.patterns_.shape) if n_targets == 1: want_shape = (n_features,) @@ -473,9 +480,8 @@ def test_cross_val_multiscore(): def test_sklearn_compliance(estimator, check): """Test LinearModel compliance with sklearn.""" ignores = ( - "check_n_features_in", # maybe we should add this someday? - "check_estimator_sparse_data", # we densify "check_estimators_overwrite_params", # self.model changes! 
+ "check_dont_overwrite_parameters", "check_parameters_default_constructible", ) if any(ignore in str(check) for ignore in ignores): diff --git a/mne/decoding/tests/test_csp.py b/mne/decoding/tests/test_csp.py index 7a1a83feeaf..e754b6952f9 100644 --- a/mne/decoding/tests/test_csp.py +++ b/mne/decoding/tests/test_csp.py @@ -19,6 +19,7 @@ from sklearn.model_selection import StratifiedKFold, cross_val_score from sklearn.pipeline import Pipeline, make_pipeline from sklearn.svm import SVC +from sklearn.utils.estimator_checks import parametrize_with_checks from mne import Epochs, compute_proj_raw, io, pick_types, read_events from mne.decoding import CSP, LinearModel, Scaler, SPoC, get_coef @@ -139,18 +140,22 @@ def test_csp(): y = epochs.events[:, -1] # Init - pytest.raises(ValueError, CSP, n_components="foo", norm_trace=False) + csp = CSP(n_components="foo") + with pytest.raises(TypeError, match="must be an instance"): + csp.fit(epochs_data, y) for reg in ["foo", -0.1, 1.1]: csp = CSP(reg=reg, norm_trace=False) pytest.raises(ValueError, csp.fit, epochs_data, epochs.events[:, -1]) for reg in ["oas", "ledoit_wolf", 0, 0.5, 1.0]: CSP(reg=reg, norm_trace=False) - for cov_est in ["foo", None]: - pytest.raises(ValueError, CSP, cov_est=cov_est, norm_trace=False) + csp = CSP(cov_est="foo", norm_trace=False) + with pytest.raises(ValueError, match="Invalid value"): + csp.fit(epochs_data, y) + csp = CSP(norm_trace="foo") with pytest.raises(TypeError, match="instance of bool"): - CSP(norm_trace="foo") + csp.fit(epochs_data, y) for cov_est in ["concat", "epoch"]: - CSP(cov_est=cov_est, norm_trace=False) + CSP(cov_est=cov_est, norm_trace=False).fit(epochs_data, y) n_components = 3 # Fit @@ -171,8 +176,8 @@ def test_csp(): # Test data exception pytest.raises(ValueError, csp.fit, epochs_data, np.zeros_like(epochs.events)) - pytest.raises(ValueError, csp.fit, epochs, y) - pytest.raises(ValueError, csp.transform, epochs) + pytest.raises(ValueError, csp.fit, "foo", y) + pytest.raises(ValueError, csp.transform, "foo") # Test plots epochs.pick(picks="mag") @@ -200,7 +205,7 @@ def test_csp(): for cov_est in ["concat", "epoch"]: csp = CSP(n_components=n_components, cov_est=cov_est, norm_trace=False) csp.fit(epochs_data, epochs.events[:, 2]).transform(epochs_data) - assert_equal(len(csp._classes), 4) + assert_equal(len(csp.classes_), 4) assert_array_equal(csp.filters_.shape, [n_channels, n_channels]) assert_array_equal(csp.patterns_.shape, [n_channels, n_channels]) @@ -220,15 +225,17 @@ def test_csp(): # Different normalization return different transform assert np.sum((X_trans["True"] - X_trans["False"]) ** 2) > 1.0 # Check wrong inputs - pytest.raises(ValueError, CSP, transform_into="average_power", log="foo") + csp = CSP(transform_into="average_power", log="foo") + with pytest.raises(TypeError, match="must be an instance of bool"): + csp.fit(epochs_data, epochs.events[:, 2]) # Test csp space transform csp = CSP(transform_into="csp_space", norm_trace=False) assert csp.transform_into == "csp_space" for log in ("foo", True, False): - pytest.raises( - ValueError, CSP, transform_into="csp_space", log=log, norm_trace=False - ) + csp = CSP(transform_into="csp_space", log=log, norm_trace=False) + with pytest.raises(TypeError, match="must be an instance"): + csp.fit(epochs_data, epochs.events[:, 2]) n_components = 2 csp = CSP(n_components=n_components, transform_into="csp_space", norm_trace=False) Xt = csp.fit(epochs_data, epochs.events[:, 2]).transform(epochs_data) @@ -343,8 +350,8 @@ def test_regularized_csp(ch_type, 
rank, reg): # test init exception pytest.raises(ValueError, csp.fit, epochs_data, np.zeros_like(epochs.events)) - pytest.raises(ValueError, csp.fit, epochs, y) - pytest.raises(ValueError, csp.transform, epochs) + pytest.raises(ValueError, csp.fit, "foo", y) + pytest.raises(ValueError, csp.transform, "foo") csp.n_components = n_components sources = csp.transform(epochs_data) @@ -465,7 +472,9 @@ def test_csp_component_ordering(): """Test that CSP component ordering works as expected.""" x, y = deterministic_toy_data(["class_a", "class_b"]) - pytest.raises(ValueError, CSP, component_order="invalid") + csp = CSP(component_order="invalid") + with pytest.raises(ValueError, match="Invalid value"): + csp.fit(x, y) # component_order='alternate' only works with two classes csp = CSP(component_order="alternate") @@ -480,3 +489,10 @@ def test_csp_component_ordering(): # p_alt arranges them to [0.8, 0.06, 0.5, 0.1] # p_mut arranges them to [0.06, 0.1, 0.8, 0.5] assert_array_almost_equal(p_alt, p_mut[[2, 0, 3, 1]]) + + +@pytest.mark.filterwarnings("ignore:.*Only one sample available.*") +@parametrize_with_checks([CSP(), SPoC()]) +def test_sklearn_compliance(estimator, check): + """Test compliance with sklearn.""" + check(estimator) diff --git a/mne/decoding/tests/test_ems.py b/mne/decoding/tests/test_ems.py index 10774c0681a..dc54303a541 100644 --- a/mne/decoding/tests/test_ems.py +++ b/mne/decoding/tests/test_ems.py @@ -11,6 +11,7 @@ pytest.importorskip("sklearn") from sklearn.model_selection import StratifiedKFold +from sklearn.utils.estimator_checks import parametrize_with_checks from mne import Epochs, io, pick_types, read_events from mne.decoding import EMS, compute_ems @@ -91,3 +92,9 @@ def test_ems(): assert_equal(ems.__repr__(), "") assert_array_almost_equal(filters, np.mean(coefs, axis=0)) assert_array_almost_equal(surrogates, np.vstack(Xt)) + + +@parametrize_with_checks([EMS()]) +def test_sklearn_compliance(estimator, check): + """Test compliance with sklearn.""" + check(estimator) diff --git a/mne/decoding/tests/test_search_light.py b/mne/decoding/tests/test_search_light.py index 7cb3a66dd81..e7abfd9209e 100644 --- a/mne/decoding/tests/test_search_light.py +++ b/mne/decoding/tests/test_search_light.py @@ -41,7 +41,7 @@ def make_data(): return X, y -def test_search_light(): +def test_search_light_basic(): """Test SlidingEstimator.""" # https://github.com/scikit-learn/scikit-learn/issues/27711 if platform.system() == "Windows" and check_version("numpy", "2.0.0.dev0"): @@ -52,7 +52,9 @@ def test_search_light(): X, y = make_data() n_epochs, _, n_time = X.shape # init - pytest.raises(ValueError, SlidingEstimator, "foo") + sl = SlidingEstimator("foo") + with pytest.raises(ValueError, match="must be"): + sl.fit(X, y) sl = SlidingEstimator(Ridge()) assert not is_classifier(sl) sl = SlidingEstimator(LogisticRegression(solver="liblinear")) @@ -69,7 +71,8 @@ def test_search_light(): # transforms pytest.raises(ValueError, sl.predict, X[:, :, :2]) y_trans = sl.transform(X) - assert X.dtype == y_trans.dtype == np.dtype(float) + assert X.dtype == float + assert y_trans.dtype == float y_pred = sl.predict(X) assert y_pred.dtype == np.dtype(int) assert_array_equal(y_pred.shape, [n_epochs, n_time]) @@ -344,22 +347,19 @@ def predict_proba(self, X): @pytest.mark.slowtest -@parametrize_with_checks([SlidingEstimator(LogisticRegression(), allow_2d=True)]) +@parametrize_with_checks( + [ + SlidingEstimator(LogisticRegression(), allow_2d=True), + GeneralizingEstimator(LogisticRegression(), allow_2d=True), + ] +) def 
test_sklearn_compliance(estimator, check): """Test LinearModel compliance with sklearn.""" ignores = ( - "check_estimator_sparse_data", # we densify - "check_classifiers_one_label_sample_weights", # don't handle singleton - "check_classifiers_classes", # dim mismatch + # TODO: we don't handle singleton right (probably) + "check_classifiers_one_label_sample_weights", + "check_classifiers_classes", "check_classifiers_train", - "check_decision_proba_consistency", - "check_parameters_default_constructible", - # Should probably fix these? - "check_estimators_unfitted", - "check_transformer_data_not_an_array", - "check_n_features_in", - "check_fit2d_predict1d", - "check_do_not_raise_errors_in_init_or_set_params", ) if any(ignore in str(check) for ignore in ignores): return diff --git a/mne/decoding/tests/test_ssd.py b/mne/decoding/tests/test_ssd.py index 8f4d2472803..b6cdfc472c3 100644 --- a/mne/decoding/tests/test_ssd.py +++ b/mne/decoding/tests/test_ssd.py @@ -11,6 +11,7 @@ pytest.importorskip("sklearn") from sklearn.pipeline import Pipeline +from sklearn.utils.estimator_checks import parametrize_with_checks from mne import create_info, io from mne.decoding import CSP @@ -101,8 +102,9 @@ def test_ssd(): l_trans_bandwidth=1, h_trans_bandwidth=1, ) + ssd = SSD(info, filt_params_signal, filt_params_noise) with pytest.raises(TypeError, match="must be an instance "): - ssd = SSD(info, filt_params_signal, filt_params_noise) + ssd.fit(X) # Wrongly specified noise band freq = 2 @@ -115,14 +117,16 @@ def test_ssd(): l_trans_bandwidth=1, h_trans_bandwidth=1, ) + ssd = SSD(info, filt_params_signal, filt_params_noise) with pytest.raises(ValueError, match="Wrongly specified "): - ssd = SSD(info, filt_params_signal, filt_params_noise) + ssd.fit(X) # filt param no dict filt_params_signal = freqs_sig filt_params_noise = freqs_noise + ssd = SSD(info, filt_params_signal, filt_params_noise) with pytest.raises(ValueError, match="must be defined"): - ssd = SSD(info, filt_params_signal, filt_params_noise) + ssd.fit(X) # Data type filt_params_signal = dict( @@ -140,15 +144,18 @@ def test_ssd(): ssd = SSD(info, filt_params_signal, filt_params_noise) raw = io.RawArray(X, info) - pytest.raises(TypeError, ssd.fit, raw) + with pytest.raises(ValueError): + ssd.fit(raw) # check non-boolean return_filtered - with pytest.raises(ValueError, match="return_filtered"): - ssd = SSD(info, filt_params_signal, filt_params_noise, return_filtered=0) + ssd = SSD(info, filt_params_signal, filt_params_noise, return_filtered=0) + with pytest.raises(TypeError, match="return_filtered"): + ssd.fit(X) # check non-boolean sort_by_spectral_ratio - with pytest.raises(ValueError, match="sort_by_spectral_ratio"): - ssd = SSD(info, filt_params_signal, filt_params_noise, sort_by_spectral_ratio=0) + ssd = SSD(info, filt_params_signal, filt_params_noise, sort_by_spectral_ratio=0) + with pytest.raises(TypeError, match="sort_by_spectral_ratio"): + ssd.fit(X) # More than 1 channel type ch_types = np.reshape([["mag"] * 10, ["eeg"] * 10], n_channels) @@ -161,7 +168,8 @@ def test_ssd(): # Number of channels info_3 = create_info(ch_names=n_channels + 1, sfreq=sf, ch_types="eeg") ssd = SSD(info_3, filt_params_signal, filt_params_noise) - pytest.raises(ValueError, ssd.fit, X) + with pytest.raises(ValueError, match="channels but expected"): + ssd.fit(X) # Fit n_components = 10 @@ -381,7 +389,7 @@ def test_sorting(): ssd.fit(Xtr) # check sorters - sorter_in = ssd.sorter_spec + sorter_in = ssd.sorter_spec_ ssd = SSD( info, filt_params_signal, @@ -476,3 +484,29 @@ 
def test_non_full_rank_data(): if sys.platform == "darwin": pytest.xfail("Unknown linalg bug (Accelerate?)") ssd.fit(X) + + +@pytest.mark.filterwarnings("ignore:.*invalid value encountered in divide.*") +@pytest.mark.filterwarnings("ignore:.*is longer than.*") +@parametrize_with_checks( + [ + SSD( + 100.0, + dict(l_freq=0.0, h_freq=30.0), + dict(l_freq=0.0, h_freq=40.0), + ) + ] +) +def test_sklearn_compliance(estimator, check): + """Test LinearModel compliance with sklearn.""" + ignores = ( + "check_methods_sample_order_invariance", + # Shape stuff + "check_fit_idempotent", + "check_methods_subset_invariance", + "check_transformer_general", + "check_transformer_data_not_an_array", + ) + if any(ignore in str(check) for ignore in ignores): + return + check(estimator) diff --git a/mne/decoding/tests/test_time_frequency.py b/mne/decoding/tests/test_time_frequency.py index 37e7d7d8dc2..638cebda21e 100644 --- a/mne/decoding/tests/test_time_frequency.py +++ b/mne/decoding/tests/test_time_frequency.py @@ -10,18 +10,23 @@ pytest.importorskip("sklearn") from sklearn.base import clone +from sklearn.utils.estimator_checks import parametrize_with_checks from mne.decoding.time_frequency import TimeFrequency -def test_timefrequency(): +def test_timefrequency_basic(): """Test TimeFrequency.""" # Init n_freqs = 3 freqs = [20, 21, 22] tf = TimeFrequency(freqs, sfreq=100) + n_epochs, n_chans, n_times = 10, 2, 100 + X = np.random.rand(n_epochs, n_chans, n_times) for output in ["avg_power", "foo", None]: - pytest.raises(ValueError, TimeFrequency, freqs, output=output) + tf = TimeFrequency(freqs, output=output) + with pytest.raises(ValueError, match="Invalid value"): + tf.fit(X) tf = clone(tf) # Clone estimator @@ -30,9 +35,9 @@ def test_timefrequency(): clone(tf) # Fit - n_epochs, n_chans, n_times = 10, 2, 100 - X = np.random.rand(n_epochs, n_chans, n_times) + assert not hasattr(tf, "fitted_") tf.fit(X, None) + assert tf.fitted_ # Transform tf = TimeFrequency(freqs, sfreq=100) @@ -41,9 +46,15 @@ def test_timefrequency(): Xt = tf.transform(X) assert_array_equal(Xt.shape, [n_epochs, n_chans, n_freqs, n_times]) # 2-D X - Xt = tf.transform(X[:, 0, :]) + Xt = tf.fit_transform(X[:, 0, :]) assert_array_equal(Xt.shape, [n_epochs, n_freqs, n_times]) # 3-D with decim tf = TimeFrequency(freqs, sfreq=100, decim=2) - Xt = tf.transform(X) + Xt = tf.fit_transform(X) assert_array_equal(Xt.shape, [n_epochs, n_chans, n_freqs, n_times // 2]) + + +@parametrize_with_checks([TimeFrequency([300, 400], 1000.0, n_cycles=0.25)]) +def test_sklearn_compliance(estimator, check): + """Test LinearModel compliance with sklearn.""" + check(estimator) diff --git a/mne/decoding/tests/test_transformer.py b/mne/decoding/tests/test_transformer.py index 8dcc3ad74c7..a8afe209d96 100644 --- a/mne/decoding/tests/test_transformer.py +++ b/mne/decoding/tests/test_transformer.py @@ -17,10 +17,14 @@ from sklearn.decomposition import PCA from sklearn.kernel_ridge import KernelRidge +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.utils.estimator_checks import parametrize_with_checks -from mne import Epochs, io, pick_types, read_events +from mne import Epochs, EpochsArray, create_info, io, pick_types, read_events from mne.decoding import ( FilterEstimator, + LinearModel, PSDEstimator, Scaler, TemporalFilter, @@ -36,6 +40,7 @@ data_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = data_dir / "test_raw.fif" event_name = data_dir / "test-eve.fif" +info = create_info(2, 1000.0, 
"eeg") @pytest.mark.parametrize( @@ -101,9 +106,11 @@ def test_scaler(info, method): assert_array_almost_equal(epochs_data, Xi) # Test init exception - pytest.raises(ValueError, Scaler, None, None) - pytest.raises(TypeError, scaler.fit, epochs, y) - pytest.raises(TypeError, scaler.transform, epochs) + x = Scaler(None, None) + with pytest.raises(ValueError): + x.fit(epochs_data, y) + pytest.raises(ValueError, scaler.fit, "foo", y) + pytest.raises(ValueError, scaler.transform, "foo") epochs_bad = Epochs( raw, events, @@ -164,8 +171,8 @@ def test_filterestimator(): X = filt.fit_transform(epochs_data, y) # Test init exception - pytest.raises(ValueError, filt.fit, epochs, y) - pytest.raises(ValueError, filt.transform, epochs) + pytest.raises(ValueError, filt.fit, "foo", y) + pytest.raises(ValueError, filt.transform, "foo") def test_psdestimator(): @@ -182,14 +189,18 @@ def test_psdestimator(): epochs_data = epochs.get_data(copy=False) psd = PSDEstimator(2 * np.pi, 0, np.inf) y = epochs.events[:, -1] + assert not hasattr(psd, "fitted_") X = psd.fit_transform(epochs_data, y) + assert psd.fitted_ assert X.shape[0] == epochs_data.shape[0] assert_array_equal(psd.fit(epochs_data, y).transform(epochs_data), X) # Test init exception - pytest.raises(ValueError, psd.fit, epochs, y) - pytest.raises(ValueError, psd.transform, epochs) + with pytest.raises(ValueError): + psd.fit("foo", y) + with pytest.raises(ValueError): + psd.transform("foo") def test_vectorizer(): @@ -210,9 +221,16 @@ def test_vectorizer(): assert_equal(vect.fit_transform(data[1:]).shape, (149, 108)) # check if raised errors are working correctly - vect.fit(np.random.rand(105, 12, 3)) - pytest.raises(ValueError, vect.transform, np.random.rand(105, 12, 3, 1)) - pytest.raises(ValueError, vect.inverse_transform, np.random.rand(102, 12, 12)) + X = np.random.default_rng(0).standard_normal((105, 12, 3)) + y = np.arange(X.shape[0]) % 2 + pytest.raises(ValueError, vect.transform, X[..., np.newaxis]) + pytest.raises(ValueError, vect.inverse_transform, X[:, :-1]) + + # And that pipelines work properly + X_arr = EpochsArray(X, create_info(12, 1000.0, "eeg")) + vect.fit(X_arr) + clf = make_pipeline(Vectorizer(), StandardScaler(), LinearModel()) + clf.fit(X_arr, y) def test_unsupervised_spatial_filter(): @@ -235,11 +253,13 @@ def test_unsupervised_spatial_filter(): verbose=False, ) - # Test estimator - pytest.raises(ValueError, UnsupervisedSpatialFilter, KernelRidge(2)) + # Test estimator (must be a transformer) + X = epochs.get_data(copy=False) + usf = UnsupervisedSpatialFilter(KernelRidge(2)) + with pytest.raises(ValueError, match="transform"): + usf.fit(X) # Test fit - X = epochs.get_data(copy=False) n_components = 4 usf = UnsupervisedSpatialFilter(PCA(n_components)) usf.fit(X) @@ -255,7 +275,9 @@ def test_unsupervised_spatial_filter(): # Test with average param usf = UnsupervisedSpatialFilter(PCA(4), average=True) usf.fit_transform(X) - pytest.raises(ValueError, UnsupervisedSpatialFilter, PCA(4), 2) + usf = UnsupervisedSpatialFilter(PCA(4), 2) + with pytest.raises(TypeError, match="average must be"): + usf.fit(X) def test_temporal_filter(): @@ -281,8 +303,8 @@ def test_temporal_filter(): assert X.shape == Xt.shape # Test fit and transform numpy type check - with pytest.raises(ValueError, match="Data to be filtered must be"): - filt.transform([1, 2]) + with pytest.raises(ValueError): + filt.transform("foo") # Test with 2 dimensional data array X = np.random.rand(101, 500) @@ -298,4 +320,36 @@ def test_bad_triage(): filt = TemporalFilter(l_freq=8, 
h_freq=60, sfreq=160.0) # Used to fail with "ValueError: Effective band-stop frequency (135.0) is # too high (maximum based on Nyquist is 80.0)" + assert not hasattr(filt, "fitted_") filt.fit_transform(np.zeros((1, 1, 481))) + assert filt.fitted_ + + +@pytest.mark.filterwarnings("ignore:.*filter_length.*") +@parametrize_with_checks( + [ + FilterEstimator(info, l_freq=1, h_freq=10), + PSDEstimator(), + Scaler(scalings="mean"), + # Not easy to test Scaler(info) b/c number of channels must match + TemporalFilter(), + UnsupervisedSpatialFilter(PCA()), + Vectorizer(), + ] +) +def test_sklearn_compliance(estimator, check): + """Test LinearModel compliance with sklearn.""" + ignores = [] + if estimator.__class__.__name__ == "FilterEstimator": + ignores += [ + "check_estimators_overwrite_params", # we modify self.info + "check_methods_sample_order_invariance", + ] + if estimator.__class__.__name__.startswith(("PSD", "Temporal")): + ignores += [ + "check_transformers_unfitted", # allow unfitted transform + "check_methods_sample_order_invariance", + ] + if any(ignore in str(check) for ignore in ignores): + return + check(estimator) diff --git a/mne/decoding/time_frequency.py b/mne/decoding/time_frequency.py index de6ec52155b..29232aaeb9f 100644 --- a/mne/decoding/time_frequency.py +++ b/mne/decoding/time_frequency.py @@ -3,14 +3,16 @@ # Copyright the MNE-Python contributors. import numpy as np -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator +from sklearn.utils.validation import check_is_fitted from ..time_frequency.tfr import _compute_tfr -from ..utils import _check_option, fill_doc, verbose +from ..utils import _check_option, fill_doc +from .transformer import MNETransformerMixin @fill_doc -class TimeFrequency(TransformerMixin, BaseEstimator): +class TimeFrequency(MNETransformerMixin, BaseEstimator): """Time frequency transformer. Time-frequency transform of times series along the last axis. @@ -59,7 +61,6 @@ class TimeFrequency(TransformerMixin, BaseEstimator): mne.time_frequency.tfr_multitaper """ - @verbose def __init__( self, freqs, @@ -74,9 +75,6 @@ def __init__( verbose=None, ): """Init TimeFrequency transformer.""" - # Check non-average output - output = _check_option("output", output, ["complex", "power", "phase"]) - self.freqs = freqs self.sfreq = sfreq self.method = method @@ -89,6 +87,16 @@ def __init__( self.n_jobs = n_jobs self.verbose = verbose + def __sklearn_tags__(self): + """Return sklearn tags.""" + out = super().__sklearn_tags__() + from sklearn.utils import TransformerTags + + if out.transformer_tags is None: + out.transformer_tags = TransformerTags() + out.transformer_tags.preserves_dtype = [] # real->complex + return out + def fit_transform(self, X, y=None): """Time-frequency transform of times series along the last axis. @@ -123,6 +131,10 @@ def fit(self, X, y=None): # noqa: D401 self : object Return self. """ + # Check non-average output + _check_option("output", self.output, ["complex", "power", "phase"]) + self._check_data(X, y=y, fit=True) + self.fitted_ = True return self def transform(self, X): @@ -130,16 +142,18 @@ def transform(self, X): Parameters ---------- - X : array, shape (n_samples, n_channels, n_times) + X : array, shape (n_samples, [n_channels, ]n_times) The training data samples. The channel dimension can be zero- or 1-dimensional. 
Returns ------- - Xt : array, shape (n_samples, n_channels, n_freqs, n_times) + Xt : array, shape (n_samples, [n_channels, ]n_freqs, n_times) The time-frequency transform of the data, where n_channels can be zero- or 1-dimensional. """ + X = self._check_data(X, atleast_3d=False) + check_is_fitted(self, "fitted_") # Ensure 3-dimensional X shape = X.shape[1:-1] if not shape: diff --git a/mne/decoding/transformer.py b/mne/decoding/transformer.py index e475cd22161..6d0c83f42ab 100644 --- a/mne/decoding/transformer.py +++ b/mne/decoding/transformer.py @@ -3,19 +3,72 @@ # Copyright the MNE-Python contributors. import numpy as np -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator, TransformerMixin, check_array, clone +from sklearn.preprocessing import RobustScaler, StandardScaler +from sklearn.utils import check_X_y +from sklearn.utils.validation import check_is_fitted, validate_data from .._fiff.pick import ( _pick_data_channels, _picks_by_type, _picks_to_idx, pick_info, - pick_types, ) from ..cov import _check_scalings_user +from ..epochs import BaseEpochs from ..filter import filter_data from ..time_frequency import psd_array_multitaper -from ..utils import _check_option, _validate_type, fill_doc, verbose +from ..utils import _check_option, _validate_type, fill_doc + + +class MNETransformerMixin(TransformerMixin): + """TransformerMixin plus some helpers.""" + + def _check_data( + self, + epochs_data, + *, + y=None, + atleast_3d=True, + fit=False, + return_y=False, + multi_output=False, + check_n_features=True, + ): + # Sklearn calls asarray under the hood which works, but elsewhere they check for + # __len__ then look at the size of obj[0]... which is an epoch of shape (1, ...) + # rather than what they expect (shape (...)). So we explicitly get the NumPy + # array to make everyone happy. + if isinstance(epochs_data, BaseEpochs): + epochs_data = epochs_data.get_data(copy=False) + kwargs = dict(dtype=np.float64, allow_nd=True, order="C", force_writeable=True) + if hasattr(self, "n_features_in_") and check_n_features: + if y is None: + epochs_data = validate_data( + self, + epochs_data, + **kwargs, + reset=fit, + ) + else: + epochs_data, y = validate_data( + self, + epochs_data, + y, + **kwargs, + reset=fit, + ) + elif y is None: + epochs_data = check_array(epochs_data, **kwargs) + else: + epochs_data, y = check_X_y( + X=epochs_data, y=y, multi_output=multi_output, **kwargs + ) + if fit: + self.n_features_in_ = epochs_data.shape[1] + if atleast_3d: + epochs_data = np.atleast_3d(epochs_data) + return (epochs_data, y) if return_y else epochs_data class _ConstantScaler: @@ -55,8 +108,9 @@ def fit_transform(self, X, y=None): def _sklearn_reshape_apply(func, return_result, X, *args, **kwargs): """Reshape epochs and apply function.""" - if not isinstance(X, np.ndarray): - raise ValueError(f"data should be an np.ndarray, got {type(X)}.") + _validate_type(X, np.ndarray, "X") + if X.size == 0: + return X.copy() if return_result else None orig_shape = X.shape X = np.reshape(X.transpose(0, 2, 1), (-1, orig_shape[1])) X = func(X, *args, **kwargs) @@ -67,7 +121,7 @@ def _sklearn_reshape_apply(func, return_result, X, *args, **kwargs): @fill_doc -class Scaler(TransformerMixin, BaseEstimator): +class Scaler(MNETransformerMixin, BaseEstimator): """Standardize channel data. This class scales data for each channel. 
It differs from scikit-learn @@ -109,31 +163,6 @@ def __init__(self, info=None, scalings=None, with_mean=True, with_std=True): self.with_std = with_std self.scalings = scalings - if not (scalings is None or isinstance(scalings, dict | str)): - raise ValueError( - f"scalings type should be dict, str, or None, got {type(scalings)}" - ) - if isinstance(scalings, str): - _check_option("scalings", scalings, ["mean", "median"]) - if scalings is None or isinstance(scalings, dict): - if info is None: - raise ValueError( - f'Need to specify "info" if scalings is {type(scalings)}' - ) - self._scaler = _ConstantScaler(info, scalings, self.with_std) - elif scalings == "mean": - from sklearn.preprocessing import StandardScaler - - self._scaler = StandardScaler( - with_mean=self.with_mean, with_std=self.with_std - ) - else: # scalings == 'median': - from sklearn.preprocessing import RobustScaler - - self._scaler = RobustScaler( - with_centering=self.with_mean, with_scaling=self.with_std - ) - def fit(self, epochs_data, y=None): """Standardize data across channels. @@ -149,11 +178,30 @@ def fit(self, epochs_data, y=None): self : instance of Scaler The modified instance. """ - _validate_type(epochs_data, np.ndarray, "epochs_data") - if epochs_data.ndim == 2: - epochs_data = epochs_data[..., np.newaxis] + epochs_data = self._check_data(epochs_data, y=y, fit=True, multi_output=True) assert epochs_data.ndim == 3, epochs_data.shape - _sklearn_reshape_apply(self._scaler.fit, False, epochs_data, y=y) + + _validate_type(self.scalings, (dict, str, type(None)), "scalings") + if isinstance(self.scalings, str): + _check_option( + "scalings", self.scalings, ["mean", "median"], extra="when str" + ) + if self.scalings is None or isinstance(self.scalings, dict): + if self.info is None: + raise ValueError( + f'Need to specify "info" if scalings is {type(self.scalings)}' + ) + self.scaler_ = _ConstantScaler(self.info, self.scalings, self.with_std) + elif self.scalings == "mean": + self.scaler_ = StandardScaler( + with_mean=self.with_mean, with_std=self.with_std + ) + else: # scalings == 'median': + self.scaler_ = RobustScaler( + with_centering=self.with_mean, with_scaling=self.with_std + ) + + _sklearn_reshape_apply(self.scaler_.fit, False, epochs_data, y=y) return self def transform(self, epochs_data): @@ -174,13 +222,14 @@ def transform(self, epochs_data): This function makes a copy of the data before the operations and the memory usage may be large with big data. """ - _validate_type(epochs_data, np.ndarray, "epochs_data") + check_is_fitted(self, "scaler_") + epochs_data = self._check_data(epochs_data, atleast_3d=False) if epochs_data.ndim == 2: # can happen with SlidingEstimator if self.info is not None: assert len(self.info["ch_names"]) == epochs_data.shape[1] epochs_data = epochs_data[..., np.newaxis] assert epochs_data.ndim == 3, epochs_data.shape - return _sklearn_reshape_apply(self._scaler.transform, True, epochs_data) + return _sklearn_reshape_apply(self.scaler_.transform, True, epochs_data) def fit_transform(self, epochs_data, y=None): """Fit to data, then transform it. @@ -226,19 +275,20 @@ def inverse_transform(self, epochs_data): This function makes a copy of the data before the operations and the memory usage may be large with big data. 
""" + epochs_data = self._check_data(epochs_data, atleast_3d=False) squeeze = False # Can happen with CSP if epochs_data.ndim == 2: squeeze = True epochs_data = epochs_data[..., np.newaxis] assert epochs_data.ndim == 3, epochs_data.shape - out = _sklearn_reshape_apply(self._scaler.inverse_transform, True, epochs_data) + out = _sklearn_reshape_apply(self.scaler_.inverse_transform, True, epochs_data) if squeeze: out = out[..., 0] return out -class Vectorizer(TransformerMixin, BaseEstimator): +class Vectorizer(MNETransformerMixin, BaseEstimator): """Transform n-dimensional array into 2D array of n_samples by n_features. This class reshapes an n-dimensional array into an n_samples * n_features @@ -275,7 +325,7 @@ def fit(self, X, y=None): self : instance of Vectorizer Return the modified instance. """ - X = np.asarray(X) + X = self._check_data(X, y=y, atleast_3d=False, fit=True, check_n_features=False) self.features_shape_ = X.shape[1:] return self @@ -295,7 +345,7 @@ def transform(self, X): X : array, shape (n_samples, n_features) The transformed data. """ - X = np.asarray(X) + X = self._check_data(X, atleast_3d=False) if X.shape[1:] != self.features_shape_: raise ValueError("Shape of X used in fit and transform must be same") return X.reshape(len(X), -1) @@ -334,7 +384,7 @@ def inverse_transform(self, X): The data transformed into shape as used in fit. The first dimension is of length n_samples. """ - X = np.asarray(X) + X = self._check_data(X, atleast_3d=False, check_n_features=False) if X.ndim not in (2, 3): raise ValueError( f"X should be of 2 or 3 dimensions but has shape {X.shape}" @@ -343,7 +393,7 @@ def inverse_transform(self, X): @fill_doc -class PSDEstimator(TransformerMixin, BaseEstimator): +class PSDEstimator(MNETransformerMixin, BaseEstimator): """Compute power spectral density (PSD) using a multi-taper method. Parameters @@ -365,7 +415,6 @@ class PSDEstimator(TransformerMixin, BaseEstimator): n_jobs : int Number of parallel jobs to use (only used if adaptive=True). %(normalization)s - %(verbose)s See Also -------- @@ -375,7 +424,6 @@ class PSDEstimator(TransformerMixin, BaseEstimator): mne.Evoked.compute_psd """ - @verbose def __init__( self, sfreq=2 * np.pi, @@ -386,8 +434,6 @@ def __init__( low_bias=True, n_jobs=None, normalization="length", - *, - verbose=None, ): self.sfreq = sfreq self.fmin = fmin @@ -398,7 +444,7 @@ def __init__( self.n_jobs = n_jobs self.normalization = normalization - def fit(self, epochs_data, y): + def fit(self, epochs_data, y=None): """Compute power spectral density (PSD) using a multi-taper method. Parameters @@ -413,11 +459,8 @@ def fit(self, epochs_data, y): self : instance of PSDEstimator The modified instance. """ - if not isinstance(epochs_data, np.ndarray): - raise ValueError( - f"epochs_data should be of type ndarray (got {type(epochs_data)})." - ) - + self._check_data(epochs_data, y=y, fit=True) + self.fitted_ = True # sklearn compliance return self def transform(self, epochs_data): @@ -433,10 +476,7 @@ def transform(self, epochs_data): psd : array, shape (n_signals, n_freqs) or (n_freqs,) The computed PSD. """ - if not isinstance(epochs_data, np.ndarray): - raise ValueError( - f"epochs_data should be of type ndarray (got {type(epochs_data)})." 
- ) + epochs_data = self._check_data(epochs_data) psd, _ = psd_array_multitaper( epochs_data, sfreq=self.sfreq, @@ -452,7 +492,7 @@ def transform(self, epochs_data): @fill_doc -class FilterEstimator(TransformerMixin, BaseEstimator): +class FilterEstimator(MNETransformerMixin, BaseEstimator): """Estimator to filter RtEpochs. Applies a zero-phase low-pass, high-pass, band-pass, or band-stop @@ -488,7 +528,6 @@ class FilterEstimator(TransformerMixin, BaseEstimator): See mne.filter.construct_iir_filter for details. If iir_params is None and method="iir", 4th order Butterworth will be used. %(fir_design)s - %(verbose)s See Also -------- @@ -514,13 +553,11 @@ def __init__( method="fir", iir_params=None, fir_design="firwin", - *, - verbose=None, ): self.info = info self.l_freq = l_freq self.h_freq = h_freq - self.picks = _picks_to_idx(info, picks) + self.picks = picks self.filter_length = filter_length self.l_trans_bandwidth = l_trans_bandwidth self.h_trans_bandwidth = h_trans_bandwidth @@ -544,24 +581,11 @@ def fit(self, epochs_data, y): self : instance of FilterEstimator The modified instance. """ - if not isinstance(epochs_data, np.ndarray): - raise ValueError( - f"epochs_data should be of type ndarray (got {type(epochs_data)})." - ) - - if self.picks is None: - self.picks = pick_types( - self.info, meg=True, eeg=True, ref_meg=False, exclude=[] - ) + self.picks_ = _picks_to_idx(self.info, self.picks) + self._check_data(epochs_data, y=y, fit=True) if self.l_freq == 0: self.l_freq = None - if self.h_freq is not None and self.h_freq > (self.info["sfreq"] / 2.0): - self.h_freq = None - if self.l_freq is not None and not isinstance(self.l_freq, float): - self.l_freq = float(self.l_freq) - if self.h_freq is not None and not isinstance(self.h_freq, float): - self.h_freq = float(self.h_freq) if self.info["lowpass"] is None or ( self.h_freq is not None @@ -594,17 +618,12 @@ def transform(self, epochs_data): X : array, shape (n_epochs, n_channels, n_times) The data after filtering. """ - if not isinstance(epochs_data, np.ndarray): - raise ValueError( - f"epochs_data should be of type ndarray (got {type(epochs_data)})." - ) - epochs_data = np.atleast_3d(epochs_data) return filter_data( - epochs_data, + self._check_data(epochs_data), self.info["sfreq"], self.l_freq, self.h_freq, - self.picks, + self.picks_, self.filter_length, self.l_trans_bandwidth, self.h_trans_bandwidth, @@ -617,7 +636,7 @@ def transform(self, epochs_data): ) -class UnsupervisedSpatialFilter(TransformerMixin, BaseEstimator): +class UnsupervisedSpatialFilter(MNETransformerMixin, BaseEstimator): """Use unsupervised spatial filtering across time and samples. Parameters @@ -630,19 +649,6 @@ class UnsupervisedSpatialFilter(TransformerMixin, BaseEstimator): """ def __init__(self, estimator, average=False): - # XXX: Use _check_estimator #3381 - for attr in ("fit", "transform", "fit_transform"): - if not hasattr(estimator, attr): - raise ValueError( - "estimator must be a scikit-learn " - f"transformer, missing {attr} method" - ) - - if not isinstance(average, bool): - raise ValueError( - f"average parameter must be of bool type, got {type(bool)} instead" - ) - self.estimator = estimator self.average = average @@ -661,13 +667,25 @@ def fit(self, X, y=None): self : instance of UnsupervisedSpatialFilter Return the modified instance. 
""" + # sklearn.utils.estimator_checks.check_estimator(self.estimator) is probably + # too strict for us, given that we don't fully adhere yet, so just check attrs + for attr in ("fit", "transform", "fit_transform"): + if not hasattr(self.estimator, attr): + raise ValueError( + "estimator must be a scikit-learn " + f"transformer, missing {attr} method" + ) + _validate_type(self.average, bool, "average") + X = self._check_data(X, y=y, fit=True) if self.average: X = np.mean(X, axis=0).T else: n_epochs, n_channels, n_times = X.shape # trial as time samples X = np.transpose(X, (1, 0, 2)).reshape((n_channels, n_epochs * n_times)).T - self.estimator.fit(X) + + self.estimator_ = clone(self.estimator) + self.estimator_.fit(X) return self def fit_transform(self, X, y=None): @@ -700,6 +718,8 @@ def transform(self, X): X : array, shape (n_epochs, n_channels, n_times) The transformed data. """ + check_is_fitted(self.estimator_) + X = self._check_data(X) return self._apply_method(X, "transform") def inverse_transform(self, X): @@ -735,7 +755,7 @@ def _apply_method(self, X, method): X = np.transpose(X, [1, 0, 2]) X = np.reshape(X, [n_channels, n_epochs * n_times]).T # apply method - method = getattr(self.estimator, method) + method = getattr(self.estimator_, method) X = method(X) # put it back to n_epochs, n_dimensions X = np.reshape(X.T, [-1, n_epochs, n_times]).transpose([1, 0, 2]) @@ -743,7 +763,7 @@ def _apply_method(self, X, method): @fill_doc -class TemporalFilter(TransformerMixin, BaseEstimator): +class TemporalFilter(MNETransformerMixin, BaseEstimator): """Estimator to filter data array along the last dimension. Applies a zero-phase low-pass, high-pass, band-pass, or band-stop @@ -817,7 +837,6 @@ class TemporalFilter(TransformerMixin, BaseEstimator): attenuation using fewer samples than "firwin2". .. versionadded:: 0.15 - %(verbose)s See Also -------- @@ -826,7 +845,6 @@ class TemporalFilter(TransformerMixin, BaseEstimator): mne.filter.filter_data """ - @verbose def __init__( self, l_freq=None, @@ -840,8 +858,6 @@ def __init__( iir_params=None, fir_window="hamming", fir_design="firwin", - *, - verbose=None, ): self.l_freq = l_freq self.h_freq = h_freq @@ -855,17 +871,12 @@ def __init__( self.fir_window = fir_window self.fir_design = fir_design - if not isinstance(self.n_jobs, int) and self.n_jobs == "cuda": - raise ValueError( - f'n_jobs must be int or "cuda", got {type(self.n_jobs)} instead.' - ) - def fit(self, X, y=None): """Do nothing (for scikit-learn compatibility purposes). Parameters ---------- - X : array, shape (n_epochs, n_channels, n_times) or or shape (n_channels, n_times) + X : array, shape ([n_epochs, ]n_channels, n_times) The data to be filtered over the last dimension. The channels dimension can be zero when passing a 2D array. y : None @@ -875,7 +886,9 @@ def fit(self, X, y=None): ------- self : instance of TemporalFilter The modified instance. - """ # noqa: E501 + """ + self.fitted_ = True # sklearn compliance + self._check_data(X, y=y, atleast_3d=False, fit=True) return self def transform(self, X): @@ -883,7 +896,7 @@ def transform(self, X): Parameters ---------- - X : array, shape (n_epochs, n_channels, n_times) or shape (n_channels, n_times) + X : array, shape ([n_epochs, ]n_channels, n_times) The data to be filtered over the last dimension. The channels dimension can be zero when passing a 2D array. @@ -892,6 +905,7 @@ def transform(self, X): X : array The data after filtering. 
""" # noqa: E501 + X = self._check_data(X, atleast_3d=False) X = np.atleast_2d(X) if X.ndim > 3: diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index 98705e838c2..1c1a3baf238 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -63,7 +63,14 @@ def dpss_windows(N, half_nbw, Kmax, *, sym=True, norm=None, low_bias=True): ---------- .. footbibliography:: """ - dpss, eigvals = sp_dpss(N, half_nbw, Kmax, sym=sym, norm=norm, return_ratios=True) + # TODO VERSION can be removed with SciPy 1.16 is min, + # workaround for https://github.com/scipy/scipy/pull/22344 + if N <= 1: + dpss, eigvals = np.ones((1, 1)), np.ones(1) + else: + dpss, eigvals = sp_dpss( + N, half_nbw, Kmax, sym=sym, norm=norm, return_ratios=True + ) if low_bias: idx = eigvals > 0.9 if not idx.any(): diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index fc60802f61b..f4a01e87895 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -562,7 +562,8 @@ def _compute_tfr( if len(Ws[0][0]) > epoch_data.shape[2]: raise ValueError( "At least one of the wavelets is longer than the " - "signal. Use a longer signal or shorter wavelets." + f"signal ({len(Ws[0][0])} > {epoch_data.shape[2]} samples). " + "Use a longer signal or shorter wavelets." ) # Initialize output diff --git a/mne/utils/numerics.py b/mne/utils/numerics.py index 5029e8fbeca..11ba0ecb487 100644 --- a/mne/utils/numerics.py +++ b/mne/utils/numerics.py @@ -35,7 +35,7 @@ check_random_state, ) from .docs import fill_doc -from .misc import _empty_hash +from .misc import _empty_hash, _pl def split_list(v, n, idx=False): @@ -479,7 +479,8 @@ def _time_mask( extra = "" if include_tmax else "when include_tmax=False " raise ValueError( f"No samples remain when using tmin={orig_tmin} and tmax={orig_tmax} " - f"{extra}(original time bounds are [{times[0]}, {times[-1]}])" + f"{extra}(original time bounds are [{times[0]}, {times[-1]}] containing " + f"{len(times)} sample{_pl(times)})" ) return mask diff --git a/tools/vulture_allowlist.py b/tools/vulture_allowlist.py index 24bcd9af64a..9d0e215ee80 100644 --- a/tools/vulture_allowlist.py +++ b/tools/vulture_allowlist.py @@ -41,6 +41,8 @@ # Decoding _._more_tags +_.multi_class +_.preserves_dtype deep # Backward compat or rarely used From 6028982a3e34bf843d4694f60565a0fbb821ed2e Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 22 Jan 2025 13:21:23 -0500 Subject: [PATCH 17/24] MAINT: Fix doc build (#13076) --- doc/sphinxext/mne_doc_utils.py | 2 ++ doc/sphinxext/related_software.py | 24 ++++++++++++++++-------- mne/viz/backends/_pyvista.py | 1 - tools/circleci_dependencies.sh | 2 +- tutorials/intro/70_report.py | 10 +++++----- tutorials/inverse/20_dipole_fit.py | 1 + 6 files changed, 25 insertions(+), 15 deletions(-) diff --git a/doc/sphinxext/mne_doc_utils.py b/doc/sphinxext/mne_doc_utils.py index 7df361e4af1..e626838f251 100644 --- a/doc/sphinxext/mne_doc_utils.py +++ b/doc/sphinxext/mne_doc_utils.py @@ -97,6 +97,8 @@ def reset_warnings(gallery_conf, fname): r"numpy\.core is deprecated and has been renamed to numpy\._core", # matplotlib "__array_wrap__ must accept context and return_scalar.*", + # nibabel + "__array__ implementation doesn't accept.*", ): warnings.filterwarnings( # deal with other modules having bad imports "ignore", message=f".*{key}.*", category=DeprecationWarning diff --git a/doc/sphinxext/related_software.py b/doc/sphinxext/related_software.py index ab159b0fcb4..2548725390a 100644 --- 
+++ b/doc/sphinxext/related_software.py
@@ -173,6 +173,7 @@ def _get_packages() -> dict[str, str]:
     packages = sorted(packages, key=lambda x: x.lower())
     packages = [RENAMES.get(package, package) for package in packages]
     out = dict()
+    reasons = []
     for package in status_iterator(
         packages, f"Adding {len(packages)} related software packages: "
     ):
@@ -183,12 +184,17 @@ def _get_packages() -> dict[str, str]:
             else:
                 md = importlib.metadata.metadata(package)
         except importlib.metadata.PackageNotFoundError:
-            pass  # raise a complete error later
+            reasons.append(f"{package}: not found, needs to be installed")
+            continue  # raise a complete error later
         else:  # Every project should really have this
+            do_continue = False
             for key in ("Summary",):
                 if key not in md:
-                    raise ExtensionError(f"Missing {repr(key)} for {package}")
+                    reasons.append(f"{package}: missing {repr(key)}")
+                    do_continue = True
+            if do_continue:
+                continue
         # It is annoying to find the home page
         url = None
         if "Home-page" in md:
@@ -204,15 +210,17 @@ def _get_packages() -> dict[str, str]:
             if url is not None:
                 break
         else:
-            raise RuntimeError(
-                f"Could not find Home-page for {package} in:\n"
-                f"{sorted(set(md))}\nwith Summary:\n{md['Summary']}"
+            reasons.append(
                f"{package}: could not find Home-page in {sorted(md)}"
             )
+            continue
         out[package]["url"] = url
         out[package]["description"] = md["Summary"].replace("\n", "")
-    bad = [package for package in packages if not out[package]]
-    if bad and REQUIRE_METADATA:
-        raise ExtensionError(f"Could not find metadata for:\n{' '.join(bad)}")
+    reason_str = "\n".join(reasons)
+    if reason_str and REQUIRE_METADATA:
+        raise ExtensionError(
+            f"Could not find suitable metadata for related software:\n{reason_str}"
+        )
     return out


diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py
index 8fe4b4bf1d8..ee5b62404d3 100644
--- a/mne/viz/backends/_pyvista.py
+++ b/mne/viz/backends/_pyvista.py
@@ -1331,7 +1331,6 @@ def _is_osmesa(plotter):
     )
     gpu_info = " ".join(gpu_info).lower()
     is_osmesa = "mesa" in gpu_info.split()
-    print(is_osmesa)
     if is_osmesa:
         # Try to warn if it's ancient
         version = re.findall("mesa ([0-9.]+)[ -].*", gpu_info) or re.findall(
diff --git a/tools/circleci_dependencies.sh b/tools/circleci_dependencies.sh
index dd3216ebf06..b306bb528f4 100755
--- a/tools/circleci_dependencies.sh
+++ b/tools/circleci_dependencies.sh
@@ -13,4 +13,4 @@ python -m pip install --upgrade --progress-bar off \
 	mne-icalabel mne-lsl mne-microstates mne-nirs mne-rsa \
 	neurodsp neurokit2 niseq nitime pactools \
 	plotly pycrostates pyprep pyriemann python-picard sesameeg \
-	sleepecg tensorpac yasa meegkit eeg_positions
+	sleepecg tensorpac yasa meegkit eeg_positions wfdb
diff --git a/tutorials/intro/70_report.py b/tutorials/intro/70_report.py
index cc32d02679b..fe87c0f3a44 100644
--- a/tutorials/intro/70_report.py
+++ b/tutorials/intro/70_report.py
@@ -12,11 +12,11 @@
 and after each preprocessing step, epoch rejection statistics, MRI slices with
 overlaid BEM shells, all the way up to plots of estimated cortical activity.

-Compared to a Jupyter notebook, :class:`mne.Report` is easier to deploy, as the
-HTML pages it generates are self-contained and do not require a running Python
-environment. However, it is less flexible as you can't change code and re-run
-something directly within the browser. This tutorial covers the basics of
-building a report.
As usual, we will start by importing the modules and data we need: +Compared to a Jupyter notebook, :class:`mne.Report` is easier to deploy, as the HTML +pages it generates are self-contained and do not require a running Python environment. +However, it is less flexible as you can't change code and re-run something directly +within the browser. This tutorial covers the basics of building a report. As usual, +we will start by importing the modules and data we need: """ # Authors: The MNE-Python contributors. diff --git a/tutorials/inverse/20_dipole_fit.py b/tutorials/inverse/20_dipole_fit.py index 2b640aa8fc2..e72e76dd0fd 100644 --- a/tutorials/inverse/20_dipole_fit.py +++ b/tutorials/inverse/20_dipole_fit.py @@ -87,6 +87,7 @@ # %% # Calculate and visualise magnetic field predicted by dipole with maximum GOF # and compare to the measured data, highlighting the ipsilateral (right) source + fwd, stc = make_forward_dipole(dip, fname_bem, evoked.info, fname_trans) pred_evoked = simulate_evoked(fwd, stc, evoked.info, cov=None, nave=np.inf) From aca49655b10fc17679142e07c5d46659be1099da Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Thu, 23 Jan 2025 17:36:29 +0200 Subject: [PATCH 18/24] Allow lasso selection sensors in a plot_evoked_topo (#12071) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel McCloy --- doc/changes/devel/12071.newfeature.rst | 1 + mne/epochs.py | 2 + mne/evoked.py | 2 + mne/viz/_figure.py | 6 +- mne/viz/_mpl_figure.py | 2 +- mne/viz/evoked.py | 13 +- mne/viz/tests/test_raw.py | 39 ++---- mne/viz/tests/test_topo.py | 36 +++++- mne/viz/tests/test_utils.py | 69 ++++++++++ mne/viz/topo.py | 96 +++++++++++--- mne/viz/ui_events.py | 20 +++ mne/viz/utils.py | 171 ++++++++++++++++--------- 12 files changed, 348 insertions(+), 109 deletions(-) create mode 100644 doc/changes/devel/12071.newfeature.rst diff --git a/doc/changes/devel/12071.newfeature.rst b/doc/changes/devel/12071.newfeature.rst new file mode 100644 index 00000000000..4e7995e3beb --- /dev/null +++ b/doc/changes/devel/12071.newfeature.rst @@ -0,0 +1 @@ +Add new ``select`` parameter to :func:`mne.viz.plot_evoked_topo` and :meth:`mne.Evoked.plot_topo` to toggle lasso selection of sensors, by `Marijn van Vliet`_. diff --git a/mne/epochs.py b/mne/epochs.py index 679643ab969..ee8921d3990 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -1353,6 +1353,7 @@ def plot_topo_image( fig_facecolor="k", fig_background=None, font_color="w", + select=False, show=True, ): return plot_topo_image_epochs( @@ -1371,6 +1372,7 @@ def plot_topo_image( fig_facecolor=fig_facecolor, fig_background=fig_background, font_color=font_color, + select=select, show=show, ) diff --git a/mne/evoked.py b/mne/evoked.py index c04f83531e3..7bd2355e4ee 100644 --- a/mne/evoked.py +++ b/mne/evoked.py @@ -613,6 +613,7 @@ def plot_topo( background_color="w", noise_cov=None, exclude="bads", + select=False, show=True, ): """. 
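For orientation, a minimal usage sketch of the ``select`` option this series adds to
:meth:`mne.Evoked.plot_topo` (the file name and condition index are illustrative, not
part of the patch)::

    import mne

    evoked = mne.read_evokeds("sample_audvis-ave.fif", condition=0)
    fig = evoked.plot_topo(select=True)  # enable the lasso-selection tool
    # After lassoing sensors in the open figure, their names are available as:
    print(fig.lasso.selection)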
@@ -639,6 +640,7 @@ def plot_topo(
             background_color=background_color,
             noise_cov=noise_cov,
             exclude=exclude,
+            select=select,
             show=show,
         )

diff --git a/mne/viz/_figure.py b/mne/viz/_figure.py
index b63d2a395e2..f492c4b7fde 100644
--- a/mne/viz/_figure.py
+++ b/mne/viz/_figure.py
@@ -500,11 +500,11 @@ def _create_ch_location_fig(self, pick):
             show=False,
         )
         # highlight desired channel & disable interactivity
-        inds = np.isin(fig.lasso.ch_names, [ch_name])
+        fig.lasso.selection_inds = np.isin(fig.lasso.names, [ch_name])
         fig.lasso.disconnect()
-        fig.lasso.alpha_other = 0.3
+        fig.lasso.alpha_nonselected = 0.3
         fig.lasso.linewidth_selected = 3
-        fig.lasso.style_sensors(inds)
+        fig.lasso.style_objects()
         return fig

diff --git a/mne/viz/_mpl_figure.py b/mne/viz/_mpl_figure.py
index 2e552bd4012..3987b641dff 100644
--- a/mne/viz/_mpl_figure.py
+++ b/mne/viz/_mpl_figure.py
@@ -1536,7 +1536,7 @@ def _update_selection(self):
     def _update_highlighted_sensors(self):
         """Update the sensor plot to show what is selected."""
         inds = np.isin(
-            self.mne.fig_selection.lasso.ch_names, self.mne.ch_names[self.mne.picks]
+            self.mne.fig_selection.lasso.names, self.mne.ch_names[self.mne.picks]
         ).nonzero()[0]
         self.mne.fig_selection.lasso.select_many(inds)

diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py
index b047de4ea32..96ee0684e6e 100644
--- a/mne/viz/evoked.py
+++ b/mne/viz/evoked.py
@@ -1153,6 +1153,7 @@ def plot_evoked_topo(
     background_color="w",
     noise_cov=None,
     exclude="bads",
+    select=False,
     show=True,
 ):
     """Plot 2D topography of evoked responses.
@@ -1218,6 +1219,12 @@ def plot_evoked_topo(
     exclude : list of str | ``'bads'``
         Channels names to exclude from the plot. If ``'bads'``, the
         bad channels are excluded. By default, exclude is set to ``'bads'``.
+    select : bool
+        Whether to enable the lasso-selection tool to enable the user to select
+        channels. The selected channels will be available in
+        ``fig.lasso.selection``.
+
+        .. versionadded:: 1.10.0
     show : bool
         Show figure if True.

@@ -1274,10 +1284,11 @@ def plot_evoked_topo(
         font_color=font_color,
         merge_channels=merge_grads,
         legend=legend,
+        noise_cov=noise_cov,
         axes=axes,
         exclude=exclude,
+        select=select,
         show=show,
-        noise_cov=noise_cov,
     )


diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py
index 89e0a7c543d..caa09ae4d07 100644
--- a/mne/viz/tests/test_raw.py
+++ b/mne/viz/tests/test_raw.py
@@ -1088,36 +1088,25 @@ def test_plot_sensors(raw):
     pytest.raises(TypeError, plot_sensors, raw)  # needs to be info
     pytest.raises(ValueError, plot_sensors, raw.info, kind="sasaasd")
     plt.close("all")
+
+    # Test lasso selection.
     fig, sels = raw.plot_sensors("select", show_names=True)
     ax = fig.axes[0]
-
-    # Click with no sensors
-    _fake_click(fig, ax, (0.0, 0.0), xform="data")
-    _fake_click(fig, ax, (0, 0.0), xform="data", kind="release")
-    assert fig.lasso.selection == []
-
-    # Lasso with 1 sensor (upper left)
-    _fake_click(fig, ax, (0, 1), xform="ax")
-    fig.canvas.draw()
-    assert fig.lasso.selection == []
-    _fake_click(fig, ax, (0.65, 1), xform="ax", kind="motion")
-    _fake_click(fig, ax, (0.65, 0.7), xform="ax", kind="motion")
-    _fake_keypress(fig, "control")
-    _fake_click(fig, ax, (0, 0.7), xform="ax", kind="release", key="control")
+    # Lasso a single sensor.
+ _fake_click(fig, ax, (-0.13, 0.13), xform="data") + _fake_click(fig, ax, (-0.11, 0.13), xform="data", kind="motion") + _fake_click(fig, ax, (-0.11, 0.06), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.06), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="release") assert fig.lasso.selection == ["MEG 0121"] - # check that point appearance changes - fc = fig.lasso.collection.get_facecolors() - ec = fig.lasso.collection.get_edgecolors() - assert (fc[:, -1] == [0.5, 1.0, 0.5]).all() - assert (ec[:, -1] == [0.25, 1.0, 0.25]).all() - - _fake_click(fig, ax, (0.7, 1), xform="ax", kind="motion", key="control") - xy = ax.collections[0].get_offsets() - _fake_click(fig, ax, xy[2], xform="data", key="control") # single sel + # Add another sensor with a single click. + _fake_keypress(fig, "control") + _fake_click(fig, ax, (-0.1278, 0.0318), xform="data") + _fake_click(fig, ax, (-0.1278, 0.0318), xform="data", kind="release") + _fake_keypress(fig, "control", kind="release") assert fig.lasso.selection == ["MEG 0121", "MEG 0131"] - _fake_click(fig, ax, xy[2], xform="data", key="control") # deselect - assert fig.lasso.selection == ["MEG 0121"] plt.close("all") raw.info["dev_head_t"] = None # like empty room diff --git a/mne/viz/tests/test_topo.py b/mne/viz/tests/test_topo.py index 85b4b43dcf8..48d031739b9 100644 --- a/mne/viz/tests/test_topo.py +++ b/mne/viz/tests/test_topo.py @@ -23,7 +23,7 @@ ) from mne.viz.evoked import _line_plot_onselect from mne.viz.topo import _imshow_tfr, _plot_update_evoked_topo_proj, iter_topography -from mne.viz.utils import _fake_click +from mne.viz.utils import _fake_click, _fake_keypress base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" evoked_fname = base_dir / "test-ave.fif" @@ -231,6 +231,16 @@ def test_plot_topo(): break plt.close("all") + # Test plot_topo with selection of channels enabled. + fig = evoked.plot_topo(select=True) + ax = fig.axes[0] + _fake_click(fig, ax, (0.05, 0.62), xform="data") + _fake_click(fig, ax, (0.2, 0.62), xform="data", kind="motion") + _fake_click(fig, ax, (0.2, 0.7), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.7), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.7), xform="data", kind="release") + assert fig.lasso.selection == ["MEG 0113", "MEG 0112", "MEG 0111"] + def test_plot_topo_nirs(fnirs_evoked): """Test plotting of ERP topography for nirs data.""" @@ -296,6 +306,30 @@ def test_plot_topo_image_epochs(): assert qm_cmap[0] is cmap +def test_plot_topo_select(): + """Test selecting sensors in an ERP topography plot.""" + # Show topography + evoked = _get_epochs().average() + fig = plot_evoked_topo(evoked, select=True) + ax = fig.axes[0] + + # Lasso select 3 out of the 6 sensors. + _fake_click(fig, ax, (0.05, 0.5), xform="data") + _fake_click(fig, ax, (0.2, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.2, 0.6), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.6), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.5), xform="data", kind="release") + assert fig.lasso.selection == ["MEG 0132", "MEG 0133", "MEG 0131"] + + # Add another sensor with a single click. 
+ _fake_keypress(fig, "control") + _fake_click(fig, ax, (0.11, 0.65), xform="data") + _fake_click(fig, ax, (0.21, 0.65), xform="data", kind="release") + _fake_keypress(fig, "control", kind="release") + assert fig.lasso.selection == ["MEG 0111", "MEG 0132", "MEG 0133", "MEG 0131"] + + def test_plot_tfr_topo(): """Test plotting of TFR data.""" epochs = _get_epochs() diff --git a/mne/viz/tests/test_utils.py b/mne/viz/tests/test_utils.py index 59e2976e464..55dc0f1e65c 100644 --- a/mne/viz/tests/test_utils.py +++ b/mne/viz/tests/test_utils.py @@ -16,6 +16,7 @@ from mne.viz import ClickableImage, add_background_image, mne_analyze_colormap from mne.viz.ui_events import ColormapRange, link, subscribe from mne.viz.utils import ( + SelectFromCollection, _compute_scalings, _fake_click, _fake_keypress, @@ -274,3 +275,71 @@ def callback(event): cmap_new1 = fig.axes[0].CB.mappable.get_cmap().name cmap_new2 = fig2.axes[0].CB.mappable.get_cmap().name assert cmap_new1 == cmap_new2 == cmap_want != cmap_old + + +def test_select_from_collection(): + """Test the lasso selector for matplotlib figures.""" + fig, ax = plt.subplots() + collection = ax.scatter([1, 2, 2, 1], [1, 1, 0, 0], color="black", edgecolor="red") + ax.set_xlim(-1, 4) + ax.set_ylim(-1, 2) + lasso = SelectFromCollection(ax, collection, names=["A", "B", "C", "D"]) + assert lasso.selection == [] + + # Make a selection with no patches inside of it. + _fake_click(fig, ax, (0, 0), xform="data") + _fake_click(fig, ax, (0.5, 0), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 1), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 1), xform="data", kind="release") + assert lasso.selection == [] + + # Doing a single click on a patch should not select it. + _fake_click(fig, ax, (1, 1), xform="data") + assert lasso.selection == [] + + # Make a selection with two patches in it. + _fake_click(fig, ax, (0, 0.5), xform="data") + _fake_click(fig, ax, (3, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (3, 1.5), xform="data", kind="motion") + _fake_click(fig, ax, (0, 1.5), xform="data", kind="motion") + _fake_click(fig, ax, (0, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0, 0.5), xform="data", kind="release") + assert lasso.selection == ["A", "B"] + + # Use Control key to lasso an additional patch. + _fake_keypress(fig, "control") + _fake_click(fig, ax, (0.5, -0.5), xform="data") + _fake_click(fig, ax, (1.5, -0.5), xform="data", kind="motion") + _fake_click(fig, ax, (1.5, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 0.5), xform="data", kind="release") + _fake_keypress(fig, "control", kind="release") + assert lasso.selection == ["A", "B", "D"] + + # Use CTRL+SHIFT to remove a patch. + _fake_keypress(fig, "ctrl+shift") + _fake_click(fig, ax, (0.5, 0.5), xform="data") + _fake_click(fig, ax, (1.5, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (1.5, 1.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 1.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 1.5), xform="data", kind="release") + _fake_keypress(fig, "ctrl+shift", kind="release") + assert lasso.selection == ["B", "D"] + + # Check that the two selected patches have a different appearance. + fc = lasso.collection.get_facecolors() + ec = lasso.collection.get_edgecolors() + assert (fc[:, -1] == [0.5, 1.0, 0.5, 1.0]).all() + assert (ec[:, -1] == [0.25, 1.0, 0.25, 1.0]).all() + + # Test adding and removing single channels. 
+ lasso.select_one(2) # should not do anything without modifier keys + assert lasso.selection == ["B", "D"] + _fake_keypress(fig, "control") + lasso.select_one(2) # add to selection + _fake_keypress(fig, "control", kind="release") + assert lasso.selection == ["B", "C", "D"] + _fake_keypress(fig, "ctrl+shift") + lasso.select_one(1) # remove from selection + assert lasso.selection == ["C", "D"] + _fake_keypress(fig, "ctrl+shift", kind="release") diff --git a/mne/viz/topo.py b/mne/viz/topo.py index 3364a455aed..5c43d4de48e 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -13,8 +13,10 @@ from .._fiff.pick import _picks_to_idx, channel_type, pick_types from ..defaults import _handle_default from ..utils import Bunch, _check_option, _clean_names, _is_numeric, _to_rgb, fill_doc +from .ui_events import ChannelsSelect, publish, subscribe from .utils import ( DraggableColorbar, + SelectFromCollection, _check_cov, _check_delayed_ssp, _draw_proj_checkbox, @@ -37,6 +39,7 @@ def iter_topography( axis_spinecolor="k", layout_scale=None, legend=False, + select=False, ): """Create iterator over channel positions. @@ -72,6 +75,12 @@ def iter_topography( If True, an additional axis is created in the bottom right corner that can be used to, e.g., construct a legend. The index of this axis will be -1. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. + + .. versionadded:: 1.10.0 Returns ------- @@ -93,6 +102,7 @@ def iter_topography( axis_spinecolor, layout_scale, legend=legend, + select=select, ) @@ -128,6 +138,7 @@ def _iter_topography( img=False, axes=None, legend=False, + select=False, ): """Iterate over topography. @@ -193,8 +204,11 @@ def format_coord_multiaxis(x, y, ch_name=None): under_ax.set(xlim=[0, 1], ylim=[0, 1]) axs = list() + + shown_ch_names = [] for idx, name in iter_ch: ch_idx = ch_names.index(name) + shown_ch_names.append(name) if not unified: # old, slow way ax = plt.axes(pos[idx]) ax.patch.set_facecolor(axis_facecolor) @@ -226,24 +240,48 @@ def format_coord_multiaxis(x, y, ch_name=None): if unified: under_ax._mne_axs = axs # Create a PolyCollection for the axis backgrounds + sel_pos = pos[[i[0] for i in iter_ch]] verts = np.transpose( [ - pos[:, :2], - pos[:, :2] + pos[:, 2:] * [1, 0], - pos[:, :2] + pos[:, 2:], - pos[:, :2] + pos[:, 2:] * [0, 1], + sel_pos[:, :2], + sel_pos[:, :2] + sel_pos[:, 2:] * [1, 0], + sel_pos[:, :2] + sel_pos[:, 2:], + sel_pos[:, :2] + sel_pos[:, 2:] * [0, 1], ], [1, 0, 2], ) - if not img: - under_ax.add_collection( - collections.PolyCollection( - verts, - facecolor=axis_facecolor, - edgecolor=axis_spinecolor, - linewidth=1.0, + if not img: # Not needed for image plots. + collection = collections.PolyCollection( + verts, + facecolor=axis_facecolor, + edgecolor=axis_spinecolor, + linewidth=1.0, + ) + under_ax.add_collection(collection) + + if select: + # Configure the lasso-selection tool + fig.lasso = SelectFromCollection( + ax=under_ax, + collection=collection, + names=shown_ch_names, + alpha_nonselected=0, + alpha_selected=1, + linewidth_nonselected=0, + linewidth_selected=0.7, ) - ) # Not needed for image plots. 
+ + def on_select(): + publish(fig, ChannelsSelect(ch_names=fig.lasso.selection)) + + def on_channels_select(event): + selection_inds = np.flatnonzero( + np.isin(shown_ch_names, event.ch_names) + ) + fig.lasso.select_many(selection_inds) + + fig.lasso.callbacks.append(on_select) + subscribe(fig, "channels_select", on_channels_select) for ax in axs: yield ax, ax._mne_ch_idx @@ -270,6 +308,7 @@ def _plot_topo( unified=False, img=False, axes=None, + select=False, ): """Plot on sensor layout.""" import matplotlib.pyplot as plt @@ -322,6 +361,7 @@ def _plot_topo( unified=unified, img=img, axes=axes, + select=select, ) for ax, ch_idx in my_topo_plot: @@ -340,8 +380,17 @@ def _plot_topo( def _plot_topo_onpick(event, show_func): """Onpick callback that shows a single channel in a new figure.""" - # make sure that the swipe gesture in OS-X doesn't open many figures orig_ax = event.inaxes + fig = orig_ax.figure + + # If we are doing lasso select, allow it to handle the click instead. + if hasattr(fig, "lasso") and event.key in ["control", "ctrl+shift"]: + return + + # make sure that the swipe gesture in OS-X doesn't open many figures + if fig.canvas._key in ["shift", "alt"]: + return + import matplotlib.pyplot as plt try: @@ -838,9 +887,10 @@ def _plot_evoked_topo( merge_channels=False, legend=True, axes=None, + noise_cov=None, exclude="bads", + select=False, show=True, - noise_cov=None, ): """Plot 2D topography of evoked responses. @@ -912,6 +962,10 @@ def _plot_evoked_topo( exclude : list of str | 'bads' Channels names to exclude from being shown. If 'bads', the bad channels are excluded. By default, exclude is set to 'bads'. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. show : bool Show figure if True. @@ -1091,6 +1145,7 @@ def _plot_evoked_topo( y_label=y_label, unified=True, axes=axes, + select=select, ) add_background_image(fig, fig_background) @@ -1098,7 +1153,10 @@ def _plot_evoked_topo( if legend is not False: legend_loc = 0 if legend is True else legend labels = [e.comment if e.comment else "Unknown" for e in evoked] - handles = fig.axes[0].lines[: len(evoked)] + if select: + handles = fig.axes[0].lines[1 : len(evoked) + 1] + else: + handles = fig.axes[0].lines[: len(evoked)] legend = plt.legend( labels=labels, handles=handles, loc=legend_loc, prop={"size": 10} ) @@ -1157,6 +1215,7 @@ def plot_topo_image_epochs( fig_facecolor="k", fig_background=None, font_color="w", + select=False, show=True, ): """Plot Event Related Potential / Fields image on topographies. @@ -1204,6 +1263,12 @@ def plot_topo_image_epochs( :func:`matplotlib.pyplot.imshow`. Defaults to ``None``. font_color : color The color of tick labels in the colorbar. Defaults to white. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. + + .. versionadded:: 1.10.0 show : bool Whether to show the figure. Defaults to ``True``. 
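The wiring above publishes lasso selections as UI events, so other figures and user
code can stay in sync with the plot; a small sketch of listening for the new
``channels_select`` event (the epochs file name is hypothetical)::

    import mne
    from mne.viz.ui_events import subscribe

    epochs = mne.read_epochs("sample-epo.fif")  # hypothetical file
    fig = mne.viz.plot_topo_image_epochs(epochs, select=True)

    def on_channels_select(event):
        # ``event`` is a ChannelsSelect instance carrying the selected names
        print("Selected channels:", event.ch_names)

    subscribe(fig, "channels_select", on_channels_select)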
@@ -1293,6 +1358,7 @@ def plot_topo_image_epochs( y_label="Epoch", unified=True, img=True, + select=select, ) add_background_image(fig, fig_background) plt_show(show) diff --git a/mne/viz/ui_events.py b/mne/viz/ui_events.py index 256d5741ad3..b8b3fe29a4d 100644 --- a/mne/viz/ui_events.py +++ b/mne/viz/ui_events.py @@ -212,6 +212,26 @@ class Contours(UIEvent): contours: list[str] +@dataclass +@fill_doc +class ChannelsSelect(UIEvent): + """Indicates that the user has selected one or more channels. + + Parameters + ---------- + ch_names : list of str + The names of the channels that were selected. + + Attributes + ---------- + %(ui_event_name_source)s + ch_names : list of str + The names of the channels that were selected. + """ + + ch_names: list[str] + + def _get_event_channel(fig): """Get the event channel associated with a figure. diff --git a/mne/viz/utils.py b/mne/viz/utils.py index a09da17de7d..f9d64c49ec8 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -58,7 +58,7 @@ warn, ) from ..utils.misc import _identity_function -from .ui_events import ColormapRange, publish, subscribe +from .ui_events import ChannelsSelect, ColormapRange, publish, subscribe _channel_type_prettyprint = { "eeg": "EEG channel", @@ -807,12 +807,12 @@ def _fake_click(fig, ax, point, xform="ax", button=1, kind="press", key=None): ) -def _fake_keypress(fig, key): +def _fake_keypress(fig, key, kind="press"): from matplotlib import backend_bases fig.canvas.callbacks.process( - "key_press_event", - backend_bases.KeyEvent(name="key_press_event", canvas=fig.canvas, key=key), + f"key_{kind}_event", + backend_bases.KeyEvent(name=f"key_{kind}_event", canvas=fig.canvas, key=key), ) @@ -952,7 +952,7 @@ def plot_sensors( Whether to plot the sensors as 3d, topomap or as an interactive sensor selection dialog. Available options ``'topomap'``, ``'3d'``, ``'select'``. If ``'select'``, a set of channels can be selected - interactively by using lasso selector or clicking while holding control + interactively by using lasso selector or clicking while holding the control key. The selected channels are returned along with the figure instance. Defaults to ``'topomap'``. ch_type : None | str @@ -1163,10 +1163,10 @@ def _onpick_sensor(event, fig, ax, pos, ch_names, show_names): if event.mouseevent.inaxes != ax: return - if event.mouseevent.key == "control" and fig.lasso is not None: + if fig.lasso is not None and event.mouseevent.key in ["control", "ctrl+shift"]: + # Add the sensor to the selection instead of showing its name. for ind in event.ind: fig.lasso.select_one(ind) - return if show_names: return # channel names already visible @@ -1272,7 +1272,17 @@ def _plot_sensors_2d( lw=linewidth, ) if kind == "select": - fig.lasso = SelectFromCollection(ax, pts, ch_names) + fig.lasso = SelectFromCollection(ax, pts, names=ch_names) + + def on_select(): + publish(fig, ChannelsSelect(ch_names=fig.lasso.selection)) + + def on_channels_select(event): + selection_inds = np.flatnonzero(np.isin(ch_names, event.ch_names)) + fig.lasso.select_many(selection_inds) + + fig.lasso.callbacks.append(on_select) + subscribe(fig, "channels_select", on_channels_select) else: fig.lasso = None @@ -1595,11 +1605,14 @@ def _update(self): class SelectFromCollection: - """Select channels from a matplotlib collection using ``LassoSelector``. + """Select objects from a matplotlib collection using ``LassoSelector``. - Selected channels are saved in the ``selection`` attribute. 
This tool
-    highlights selected points by fading other points out (i.e., reducing their
-    alpha values).
+    The names of the selected objects are saved in the ``selection`` attribute.
+    This tool highlights selected objects by fading other objects out (i.e.,
+    reducing their alpha values).
+
+    Holding down the Control key will add to the current selection, and holding down
+    Control+Shift will remove from the current selection.

     Parameters
     ----------
@@ -1607,112 +1620,144 @@ class SelectFromCollection:
         Axes to interact with.
     collection : instance of matplotlib collection
         Collection you want to select from.
-    alpha_other : 0 <= float <= 1
-        To highlight a selection, this tool sets all selected points to an
-        alpha value of 1 and non-selected points to ``alpha_other``.
-        Defaults to 0.3.
-    linewidth_other : float
-        Linewidth to use for non-selected sensors. Default is 1.
+    names : list of str
+        The names of the objects. The selection is returned as a subset of these names.
+    alpha_selected : float
+        Alpha for selected objects (0=transparent, 1=opaque).
+    alpha_nonselected : float
+        Alpha for non-selected objects (0=transparent, 1=opaque).
+    linewidth_selected : float
+        Linewidth for the borders of selected objects.
+    linewidth_nonselected : float
+        Linewidth for the borders of non-selected objects.

     Notes
     -----
-    This tool selects collection objects based on their *origins*
-    (i.e., ``offsets``). Calls all callbacks in self.callbacks when selection
-    is ready.
+    This tool selects collection objects whose bounding boxes intersect with a lasso
+    path. Calls all callbacks in self.callbacks when selection is ready.
     """

     def __init__(
         self,
         ax,
         collection,
-        ch_names,
-        alpha_other=0.5,
-        linewidth_other=0.5,
+        *,
+        names,
         alpha_selected=1,
+        alpha_nonselected=0.5,
         linewidth_selected=1,
+        linewidth_nonselected=0.5,
+        verbose=None,
     ):
         from matplotlib.widgets import LassoSelector

+        self.fig = ax.figure
         self.canvas = ax.figure.canvas
         self.collection = collection
-        self.ch_names = ch_names
-        self.alpha_other = alpha_other
-        self.linewidth_other = linewidth_other
+        self.names = names
         self.alpha_selected = alpha_selected
+        self.alpha_nonselected = alpha_nonselected
         self.linewidth_selected = linewidth_selected
+        self.linewidth_nonselected = linewidth_nonselected

-        self.xys = collection.get_offsets()
-        self.Npts = len(self.xys)
+        from matplotlib.collections import PolyCollection
+        from matplotlib.path import Path

-        # Ensure that we have separate colors for each object
+        if isinstance(collection, PolyCollection):
+            self.paths = collection.get_paths()
+        else:
+            self.paths = [Path([point]) for point in collection.get_offsets()]
+        self.Npts = len(self.paths)
+        if self.Npts != len(names):
+            raise ValueError(
+                f"Number of names ({len(names)}) does not match the number of objects "
+                f"in the collection ({self.Npts})."
+            )
+
+        # Ensure that we have colors for each object.
         self.fc = collection.get_facecolors()
         self.ec = collection.get_edgecolors()
-        self.lw = collection.get_linewidths()
         if len(self.fc) == 0:
             raise ValueError("Collection must have a facecolor")
         elif len(self.fc) == 1:
             self.fc = np.tile(self.fc, self.Npts).reshape(self.Npts, -1)
+        if len(self.ec) == 0:
+            self.ec = np.zeros((self.Npts, 4))  # all black
+        elif len(self.ec) == 1:
             self.ec = np.tile(self.ec, self.Npts).reshape(self.Npts, -1)
-        self.fc[:, -1] = self.alpha_other  # deselect in the beginning
-        self.ec[:, -1] = self.alpha_other
-        self.lw = np.full(self.Npts, self.linewidth_other)
+        self.lw = np.full(self.Npts, float(self.linewidth_nonselected))

+        # Initialize the lasso selector
         self.lasso = LassoSelector(
             ax, onselect=self.on_select, props=dict(color="red", linewidth=0.5)
         )
         self.selection = list()
+        self.selection_inds = np.array([], dtype="int")
         self.callbacks = list()

+        # Deselect everything in the beginning.
+        self.style_objects()
+
+    # For backwards compatibility
+    @property
+    def ch_names(self):
+        return self.names
+
+    def notify(self):
+        """Notify listeners that a selection has been made."""
+        logger.info(f"Selected channels: {self.selection}")
+        for callback in self.callbacks:
+            callback()
+
     def on_select(self, verts):
         """Select a subset from the collection."""
         from matplotlib.path import Path

-        if len(verts) <= 3:  # Seems to be a good way to exclude single clicks.
+        # Don't respond to single clicks without extra keys being held down.
+        # Figures like plot_evoked_topo want to do something else with them.
+        if len(verts) <= 3 and self.canvas._key not in ["control", "ctrl+shift"]:
             return

         path = Path(verts)
-        inds = np.nonzero([path.contains_point(xy) for xy in self.xys])[0]
+        inds = np.nonzero([path.intersects_path(p) for p in self.paths])[0]
         if self.canvas._key == "control":  # Appending selection.
- sels = [np.where(self.ch_names == c)[0][0] for c in self.selection] - inters = set(inds) - set(sels) - inds = list(inters.union(set(sels) - set(inds))) - - self.selection[:] = np.array(self.ch_names)[inds].tolist() - self.style_sensors(inds) + self.selection_inds = np.union1d(self.selection_inds, inds).astype("int") + elif self.canvas._key == "ctrl+shift": + self.selection_inds = np.setdiff1d(self.selection_inds, inds).astype("int") + else: + self.selection_inds = inds + self.selection = [self.names[i] for i in self.selection_inds] + self.style_objects() self.notify() def select_one(self, ind): """Select or deselect one sensor.""" - ch_name = self.ch_names[ind] - if ch_name in self.selection: - sel_ind = self.selection.index(ch_name) - self.selection.pop(sel_ind) + if self.canvas._key == "control": + self.selection_inds = np.union1d(self.selection_inds, [ind]) + elif self.canvas._key == "ctrl+shift": + self.selection_inds = np.setdiff1d(self.selection_inds, [ind]) else: - self.selection.append(ch_name) - inds = np.isin(self.ch_names, self.selection).nonzero()[0] - self.style_sensors(inds) + return # don't notify() + self.selection = [self.names[i] for i in self.selection_inds] + self.style_objects() self.notify() - def notify(self): - """Notify listeners that a selection has been made.""" - for callback in self.callbacks: - callback() - def select_many(self, inds): """Select many sensors using indices (for predefined selections).""" - self.selection[:] = np.array(self.ch_names)[inds].tolist() - self.style_sensors(inds) + self.selection_inds = inds + self.selection = [self.names[i] for i in self.selection_inds] + self.style_objects() - def style_sensors(self, inds): + def style_objects(self): """Style selected sensors as "active".""" # reset - self.fc[:, -1] = self.alpha_other - self.ec[:, -1] = self.alpha_other / 2 - self.lw[:] = self.linewidth_other + self.fc[:, -1] = self.alpha_nonselected + self.ec[:, -1] = self.alpha_nonselected / 2 + self.lw[:] = self.linewidth_nonselected # style sensors at `inds` - self.fc[inds, -1] = self.alpha_selected - self.ec[inds, -1] = self.alpha_selected - self.lw[inds] = self.linewidth_selected + self.fc[self.selection_inds, -1] = self.alpha_selected + self.ec[self.selection_inds, -1] = self.alpha_selected + self.lw[self.selection_inds] = self.linewidth_selected self.collection.set_facecolors(self.fc) self.collection.set_edgecolors(self.ec) self.collection.set_linewidths(self.lw) From 7daeceef4a2b80f4d849ec55a72a6450020c8c0c Mon Sep 17 00:00:00 2001 From: Roy Eric <139973278+Randomidous@users.noreply.github.com> Date: Fri, 24 Jan 2025 17:05:25 +0100 Subject: [PATCH 19/24] BUGFIX: return events if provided when current = desired sfreq (#13070) Co-authored-by: Eric Larson Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel McCloy --- doc/changes/devel/13070.bugfix.rst | 1 + doc/changes/names.inc | 1 + mne/io/base.py | 5 ++++- mne/io/fiff/tests/test_raw_fiff.py | 10 ++++++++++ 4 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 doc/changes/devel/13070.bugfix.rst diff --git a/doc/changes/devel/13070.bugfix.rst b/doc/changes/devel/13070.bugfix.rst new file mode 100644 index 00000000000..3c6a3c25082 --- /dev/null +++ b/doc/changes/devel/13070.bugfix.rst @@ -0,0 +1 @@ +Return events when requested even when current matches the desired sfreq in :meth:`mne.io.Raw.resample` by :newcontrib:`Roy Eric Wieske`. 
\ No newline at end of file diff --git a/doc/changes/names.inc b/doc/changes/names.inc index eb444c5e594..5a58ac0fa34 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -258,6 +258,7 @@ .. _Roman Goj: https://romanmne.blogspot.co.uk .. _Ross Maddox: https://www.urmc.rochester.edu/labs/maddox-lab.aspx .. _Rotem Falach: https://github.com/Falach +.. _Roy Eric Wieske: https://github.com/Randomidous .. _Sammi Chekroud: https://github.com/schekroud .. _Samu Taulu: https://phys.washington.edu/people/samu-taulu .. _Samuel Deslauriers-Gauthier: https://github.com/sdeslauriers diff --git a/mne/io/base.py b/mne/io/base.py index 280330367f7..b3052b80aff 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -1386,7 +1386,10 @@ def resample( sfreq = float(sfreq) o_sfreq = float(self.info["sfreq"]) if _check_resamp_noop(sfreq, o_sfreq): - return self + if events is not None: + return self, events.copy() + else: + return self # When no event object is supplied, some basic detection of dropped # events is performed to generate a warning. Finding events can fail diff --git a/mne/io/fiff/tests/test_raw_fiff.py b/mne/io/fiff/tests/test_raw_fiff.py index 1ae0cc52901..3ae49189161 100644 --- a/mne/io/fiff/tests/test_raw_fiff.py +++ b/mne/io/fiff/tests/test_raw_fiff.py @@ -23,6 +23,7 @@ concatenate_events, create_info, equalize_channels, + events_from_annotations, find_events, make_fixed_length_epochs, pick_channels, @@ -1318,6 +1319,15 @@ def test_crop(): assert raw.n_times - 1 == raw3.n_times +@testing.requires_testing_data +def test_resample_with_events(): + """Test resampling raws with events.""" + raw = read_raw_fif(fif_fname) + raw.resample(250) # pretend raw is recorded at 250 Hz + events, _ = events_from_annotations(raw) + raw, events = raw.resample(250, events=events) + + @testing.requires_testing_data def test_resample_equiv(): """Test resample (with I/O and multiple files).""" From 3db12ff5357d4d6666f3d2257e91cee877e83234 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 24 Jan 2025 14:16:21 -0500 Subject: [PATCH 20/24] BUG: Fix bug with Mesa 3D detection (#13082) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- doc/changes/devel/13082.bugfix.rst | 1 + examples/preprocessing/movement_detection.py | 1 + mne/conftest.py | 1 + mne/viz/_brain/tests/test_brain.py | 3 --- mne/viz/backends/_pyvista.py | 2 +- 5 files changed, 4 insertions(+), 4 deletions(-) create mode 100644 doc/changes/devel/13082.bugfix.rst diff --git a/doc/changes/devel/13082.bugfix.rst b/doc/changes/devel/13082.bugfix.rst new file mode 100644 index 00000000000..0f5cad3d0af --- /dev/null +++ b/doc/changes/devel/13082.bugfix.rst @@ -0,0 +1 @@ +Fix bug with automated Mesa 3D detection for proper 3D option setting on systems with software rendering, by `Eric Larson`_. 
\ No newline at end of file diff --git a/examples/preprocessing/movement_detection.py b/examples/preprocessing/movement_detection.py index 9bcac562588..dd468feb464 100644 --- a/examples/preprocessing/movement_detection.py +++ b/examples/preprocessing/movement_detection.py @@ -81,6 +81,7 @@ ############################################################################## # After checking the annotated movement artifacts, calculate the new transform # and plot it: + new_dev_head_t = compute_average_dev_head_t(raw, head_pos) raw.info["dev_head_t"] = new_dev_head_t fig = mne.viz.plot_alignment( diff --git a/mne/conftest.py b/mne/conftest.py index 8a4586067b3..2a73c7a1b8e 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -174,6 +174,7 @@ def pytest_configure(config: pytest.Config): # pandas ignore:\n*Pyarrow will become a required dependency of pandas.*:DeprecationWarning ignore:np\.find_common_type is deprecated.*:DeprecationWarning + ignore:Python binding for RankQuantileOptions.*: # pyvista <-> NumPy 2.0 ignore:__array_wrap__ must accept context and return_scalar arguments.*:DeprecationWarning # nibabel <-> NumPy 2.0 diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index 5d092c21713..46406542b5c 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -773,9 +773,6 @@ def test_single_hemi(hemi, renderer_interactive_pyvistaqt, brain_gc): def test_brain_save_movie(tmp_path, renderer, brain_gc, interactive_state): """Test saving a movie of a Brain instance.""" imageio_ffmpeg = pytest.importorskip("imageio_ffmpeg") - # TODO: Figure out why this fails -- some imageio_ffmpeg error - if os.getenv("MNE_CI_KIND", "") == "conda" and platform.system() == "Linux": - pytest.skip("Test broken for unknown reason on conda linux") brain = _create_testing_brain( hemi="lh", time_viewer=False, cortex=["r", "b"] diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py index ee5b62404d3..0bd1ae1d3ca 100644 --- a/mne/viz/backends/_pyvista.py +++ b/mne/viz/backends/_pyvista.py @@ -1344,7 +1344,7 @@ def _is_osmesa(plotter): "surface rendering, consider upgrading to 18.3.6 or " "later." 
) - is_osmesa = "via llvmpipe" in gpu_info + is_osmesa = "llvmpipe" in gpu_info return is_osmesa From 631ddb3e9da67456947e23c6a070aa869d853a26 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Mon, 27 Jan 2025 16:34:16 +0200 Subject: [PATCH 21/24] Fix _close() on MNEAnnotationsFigure and MNESelectionFigure [circle deploy] (#13086) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Richard Höchenberger --- mne/viz/_mpl_figure.py | 4 ++-- tutorials/epochs/60_make_fixed_length_epochs.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mne/viz/_mpl_figure.py b/mne/viz/_mpl_figure.py index 3987b641dff..f3563b454f0 100644 --- a/mne/viz/_mpl_figure.py +++ b/mne/viz/_mpl_figure.py @@ -186,7 +186,7 @@ def _inch_to_rel(self, dim_inches, horiz=True): class MNEAnnotationFigure(MNEFigure): """Interactive dialog figure for annotations.""" - def _close(self, event): + def _close(self, event=None): """Handle close events (via keypress or window [x]).""" parent = self.mne.parent_fig # disable span selector @@ -275,7 +275,7 @@ def _set_active_button(self, idx, *, draw=True): class MNESelectionFigure(MNEFigure): """Interactive dialog figure for channel selections.""" - def _close(self, event): + def _close(self, event=None): """Handle close events.""" self.mne.parent_fig.mne.child_figs.remove(self) self.mne.fig_selection = None diff --git a/tutorials/epochs/60_make_fixed_length_epochs.py b/tutorials/epochs/60_make_fixed_length_epochs.py index 04a4ec87c7d..0920c14d457 100644 --- a/tutorials/epochs/60_make_fixed_length_epochs.py +++ b/tutorials/epochs/60_make_fixed_length_epochs.py @@ -11,7 +11,7 @@ We will also briefly demonstrate how to use these epochs in connectivity analysis. -First, we import necessary modules and read in a sample raw data set. +First, we import the necessary modules and read in a sample raw data set. This data set contains brain activity that is event-related, i.e., synchronized to the onset of auditory stimuli. 
However, rather than creating epochs by segmenting the data around the onset of each stimulus, we will From 4037ead8fe9ec27d7342263c574b99a7bc537104 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 28 Jan 2025 01:08:56 +0200 Subject: [PATCH 22/24] Fix signature of some more _close() methods [circle deploy] (#13087) --- mne/viz/_figure.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/viz/_figure.py b/mne/viz/_figure.py index f492c4b7fde..090c661f633 100644 --- a/mne/viz/_figure.py +++ b/mne/viz/_figure.py @@ -424,7 +424,7 @@ def _redraw(self, update_data=True, annotations=False): if annotations and not self.mne.is_epochs: self._draw_annotations() - def _close(self, event): + def _close(self, event=None): """Handle close events (via keypress or window [x]).""" from matplotlib.pyplot import close From 45fb777fbc53c88888032d40a14940c985079a93 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 27 Jan 2025 23:56:13 +0000 Subject: [PATCH 23/24] [pre-commit.ci] pre-commit autoupdate (#13088) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel McCloy --- .pre-commit-config.yaml | 4 ++-- mne/_fiff/proj.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 34b3ce9b130..fb5a6bd4247 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: # Ruff mne - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.2 + rev: v0.9.3 hooks: - id: ruff name: ruff lint mne @@ -23,7 +23,7 @@ repos: # Codespell - repo: https://github.com/codespell-project/codespell - rev: v2.3.0 + rev: v2.4.0 hooks: - id: codespell additional_dependencies: diff --git a/mne/_fiff/proj.py b/mne/_fiff/proj.py index d6ec108e34d..aa010085904 100644 --- a/mne/_fiff/proj.py +++ b/mne/_fiff/proj.py @@ -1100,7 +1100,7 @@ def _has_eeg_average_ref_proj( def _needs_eeg_average_ref_proj(info): - """Determine if the EEG needs an averge EEG reference. + """Determine if the EEG needs an average EEG reference. This returns True if no custom reference has been applied and no average reference projection is present in the list of projections. From 715540a823ae5dec335bee0b2499f1f7183c19c4 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 28 Jan 2025 13:44:50 -0500 Subject: [PATCH 24/24] MAINT: Fix CircleCI [circle deploy] (#13089) --- mne/viz/utils.py | 2 +- tutorials/epochs/60_make_fixed_length_epochs.py | 9 ++++----- tutorials/evoked/10_evoked_overview.py | 11 +++++------ 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/mne/viz/utils.py b/mne/viz/utils.py index f9d64c49ec8..b9b844b321a 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -1185,7 +1185,7 @@ def _onpick_sensor(event, fig, ax, pos, ch_names, show_names): fig.canvas.draw() -def _close_event(event, fig): +def _close_event(event=None, fig=None): """Listen for sensor plotter close event.""" if getattr(fig, "lasso", None) is not None: fig.lasso.disconnect() diff --git a/tutorials/epochs/60_make_fixed_length_epochs.py b/tutorials/epochs/60_make_fixed_length_epochs.py index 0920c14d457..10b8c12ea19 100644 --- a/tutorials/epochs/60_make_fixed_length_epochs.py +++ b/tutorials/epochs/60_make_fixed_length_epochs.py @@ -5,11 +5,10 @@ ================================================= This tutorial shows how to segment continuous data into a set of epochs spaced -equidistantly in time. 
The epochs will not be created based on experimental -events; instead, the continuous data will be "chunked" into consecutive epochs -(which may be temporally overlapping, adjacent, or separated). -We will also briefly demonstrate how to use these epochs in connectivity -analysis. +equidistantly in time. The epochs will not be created based on experimental events; +instead, the continuous data will be "chunked" into consecutive epochs (which may be +temporally overlapping, adjacent, or separated). We will also briefly demonstrate how +to use these epochs in connectivity analysis. First, we import the necessary modules and read in a sample raw data set. This data set contains brain activity that is event-related, i.e., diff --git a/tutorials/evoked/10_evoked_overview.py b/tutorials/evoked/10_evoked_overview.py index 75e63692bd2..b251a1f8239 100644 --- a/tutorials/evoked/10_evoked_overview.py +++ b/tutorials/evoked/10_evoked_overview.py @@ -5,12 +5,11 @@ The Evoked data structure: evoked/averaged data =============================================== -This tutorial covers the basics of creating and working with :term:`evoked` -data. It introduces the :class:`~mne.Evoked` data structure in detail, -including how to load, query, subset, export, and plot data from an -:class:`~mne.Evoked` object. For details on creating an :class:`~mne.Evoked` -object from (possibly simulated) data in a :class:`NumPy array -`, see :ref:`tut-creating-data-structures`. +This tutorial covers the basics of creating and working with :term:`evoked` data. It +introduces the :class:`~mne.Evoked` data structure in detail, including how to load, +query, subset, export, and plot data from an :class:`~mne.Evoked` object. For details +on creating an :class:`~mne.Evoked` object from (possibly simulated) data in a +:class:`NumPy array `, see :ref:`tut-creating-data-structures`. As usual, we start by importing the modules we need: """
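A minimal sketch of the resampling behavior fixed in PATCH 19 above: when ``events``
are passed, :meth:`mne.io.Raw.resample` now returns them even if the requested rate
already matches the current one (the file name is illustrative)::

    import mne

    raw = mne.io.read_raw_fif("sample_raw.fif", preload=True)
    events, _ = mne.events_from_annotations(raw)
    # A no-op resample now still returns the (raw, events) tuple:
    raw, events = raw.resample(raw.info["sfreq"], events=events)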