diff --git a/doc/changes/devel/13080.apichange.rst b/doc/changes/devel/13080.apichange.rst
new file mode 100644
index 00000000000..2c6f5b575b2
--- /dev/null
+++ b/doc/changes/devel/13080.apichange.rst
@@ -0,0 +1 @@
+The backward-compatible defaults in :func:`mne.preprocessing.maxwell_filter` of ``st_overlap=False`` and ``mc_interp="zero"`` will change to their smooth variants ``True`` and ``"hann"``, respectively, in 1.11, by `Eric Larson`_.
diff --git a/doc/changes/devel/13080.newfeature.rst b/doc/changes/devel/13080.newfeature.rst
new file mode 100644
index 00000000000..2e3e4c3cd77
--- /dev/null
+++ b/doc/changes/devel/13080.newfeature.rst
@@ -0,0 +1 @@
+Add smooth processing of tSSS windows (using overlap-add) and movement compensation (using smooth interpolation of head positions) in :func:`mne.preprocessing.maxwell_filter` via ``st_overlap`` and ``mc_interp`` options, respectively, by `Eric Larson`_.
\ No newline at end of file
diff --git a/mne/_ola.py b/mne/_ola.py
index 135ff835da3..e43e7cd3d31 100644
--- a/mne/_ola.py
+++ b/mne/_ola.py
@@ -5,7 +5,7 @@
 import numpy as np
 from scipy.signal import get_window
 
-from .utils import _ensure_int, logger, verbose
+from .utils import _ensure_int, _validate_type, logger, verbose
 
 ###############################################################################
 # Class for interpolation between adjacent points
@@ -42,7 +42,7 @@ class _Interp2:
     """
 
-    def __init__(self, control_points, values, interp="hann"):
+    def __init__(self, control_points, values, interp="hann", *, name="Interp2"):
         # set up interpolation
         self.control_points = np.array(control_points, int).ravel()
         if not np.array_equal(np.unique(self.control_points), self.control_points):
@@ -79,6 +79,7 @@ def val(pt):
         self._position = 0  # start at zero
         self._left_idx = 0
         self._left = self._right = self._use_interp = None
+        self.name = name
         known_types = ("cos2", "linear", "zero", "hann")
         if interp not in known_types:
             raise ValueError(f'interp must be one of {known_types}, got "{interp}"')
@@ -90,10 +91,10 @@ def feed_generator(self, n_pts):
         n_pts = _ensure_int(n_pts, "n_pts")
         original_position = self._position
         stop = self._position + n_pts
-        logger.debug(f"Feed {n_pts} ({self._position}-{stop})")
+        logger.debug(f"  ~ {self.name} Feed {n_pts} ({self._position}-{stop})")
         used = np.zeros(n_pts, bool)
         if self._left is None:  # first one
-            logger.debug(f"  Eval @ 0 ({self.control_points[0]})")
+            logger.debug(f"  ~ {self.name} Eval @ 0 ({self.control_points[0]})")
             self._left = self.values(self.control_points[0])
             if len(self.control_points) == 1:
                 self._right = self._left
@@ -102,7 +103,7 @@ def feed_generator(self, n_pts):
         # Left zero-order hold condition
         if self._position < self.control_points[self._left_idx]:
             n_use = min(self.control_points[self._left_idx] - self._position, n_pts)
-            logger.debug(f"  Left ZOH {n_use}")
+            logger.debug(f"  ~ {self.name} Left ZOH {n_use}")
             this_sl = slice(None, n_use)
             assert used[this_sl].size == n_use
             assert not used[this_sl].any()
@@ -127,7 +128,9 @@
                     self._left_idx += 1
                     self._use_interp = None  # need to recreate it
                 eval_pt = self.control_points[self._left_idx + 1]
-                logger.debug(f"  Eval @ {self._left_idx + 1} ({eval_pt})")
+                logger.debug(
+                    f"  ~ {self.name} Eval @ {self._left_idx + 1} ({eval_pt})"
+                )
                 self._right = self.values(eval_pt)
             assert self._right is not None
             left_point = self.control_points[self._left_idx]
@@ -148,7 +151,8 @@
             n_use = min(stop, right_point) - self._position
             if n_use > 0:
                 logger.debug(
-                    f"    Interp {self._interp} {n_use} ({left_point}-{right_point})"
+                    f"  ~ {self.name} Interp {self._interp} {n_use} "
+                    f"({left_point}-{right_point})"
                 )
                 interp_start = self._position - left_point
                 assert interp_start >= 0
@@ -169,7 +173,7 @@ def feed_generator(self, n_pts):
         if self.control_points[self._left_idx] <= self._position:
             n_use = stop - self._position
             if n_use > 0:
-                logger.debug(f"  Right ZOH {n_use}")
+                logger.debug(f"  ~ {self.name} Right ZOH {n_use}")
                 this_sl = slice(n_pts - n_use, None)
                 assert not used[this_sl].any()
                 used[this_sl] = True
@@ -210,14 +214,13 @@ def feed(self, n_pts):
 
 
 def _check_store(store):
+    _validate_type(store, (np.ndarray, list, tuple, _Storer), "store")
     if isinstance(store, np.ndarray):
         store = [store]
-    if isinstance(store, list | tuple) and all(
-        isinstance(s, np.ndarray) for s in store
-    ):
+    if not isinstance(store, _Storer):
+        if not all(isinstance(s, np.ndarray) for s in store):
+            raise TypeError("All instances must be ndarrays")
         store = _Storer(*store)
-    if not callable(store):
-        raise TypeError(f"store must be callable, got type {type(store)}")
     return store
 
 
@@ -229,10 +232,8 @@ class _COLA:
     process : callable
        A function that takes a chunk of input data with shape
        ``(n_channels, n_samples)`` and processes it.
-    store : callable | ndarray
-        A function that takes a completed chunk of output data.
-        Can also be an ``ndarray``, in which case it is treated as the
-        output data in which to store the results.
+    store : ndarray | list of ndarray | _Storer
+        The output data in which to store the results.
     n_total : int
        The total number of samples.
     n_samples : int
@@ -276,6 +277,7 @@ def __init__(
         window="hann",
         tol=1e-10,
         *,
+        name="COLA",
         verbose=None,
     ):
         n_samples = _ensure_int(n_samples, "n_samples")
@@ -302,6 +304,7 @@ def __init__(
         self._store = _check_store(store)
         self._idx = 0
         self._in_buffers = self._out_buffers = None
+        self.name = name
 
         # Create our window boundaries
         window_name = window if isinstance(window, str) else "custom"
@@ -343,6 +346,7 @@ def feed(self, *datas, verbose=None, **kwargs):
             raise ValueError(
                 f"Got {len(datas)} array(s), needed {len(self._in_buffers)}"
             )
+        current_offset = 0  # should be updated below
         for di, data in enumerate(datas):
             if not isinstance(data, np.ndarray) or data.ndim < 1:
                 raise TypeError(
@@ -363,9 +367,12 @@ def feed(self, *datas, verbose=None, **kwargs):
                     f"shape[:-1]=={self._in_buffers[di].shape[:-1]}, got dtype "
                     f"{data.dtype} shape[:-1]={data.shape[:-1]}"
                 )
+            # This gets updated on first iteration, so store it before it updates
+            if di == 0:
+                current_offset = self._in_offset
             logger.debug(
-                f"    + Appending {self._in_offset:d}->"
-                f"{self._in_offset + data.shape[-1]:d}"
+                f"    + {self.name}[{di}] Appending "
+                f"{current_offset}:{current_offset + data.shape[-1]}"
             )
             self._in_buffers[di] = np.concatenate([self._in_buffers[di], data], -1)
             if self._in_offset > self.stops[-1]:
@@ -388,13 +395,18 @@ def feed(self, *datas, verbose=None, **kwargs):
             if self._idx == 0:
                 for offset in range(self._n_samples - self._step, 0, -self._step):
                     this_window[:offset] += self._window[-offset:]
-            logger.debug(f"    * Processing {start}->{stop}")
             this_proc = [in_[..., :this_len].copy() for in_ in self._in_buffers]
+            logger.debug(
+                f"    * {self.name}[:] Processing {start}:{stop} "
+                f"(e.g., {this_proc[0].flat[[0, -1]]})"
+            )
             if not all(
                 proc.shape[-1] == this_len == this_window.size for proc in this_proc
             ):
                 raise RuntimeError("internal indexing error")
-            outs = self._process(*this_proc, **kwargs)
+            start = self._store.idx
+            stop = 
self._store.idx + this_len + outs = self._process(*this_proc, start=start, stop=stop, **kwargs) if self._out_buffers is None: max_len = np.max(self.stops - self.starts) self._out_buffers = [ @@ -409,9 +421,12 @@ def feed(self, *datas, verbose=None, **kwargs): else: next_start = self.stops[-1] delta = next_start - self.starts[self._idx - 1] + logger.debug( + f" + {self.name}[:] Shifting input and output buffers by " + f"{delta} samples (storing {start}:{stop})" + ) for di in range(len(self._in_buffers)): self._in_buffers[di] = self._in_buffers[di][..., delta:] - logger.debug(f" - Shifting input/output buffers by {delta:d} samples") self._store(*[o[..., :delta] for o in self._out_buffers]) for ob in self._out_buffers: ob[..., :-delta] = ob[..., delta:] @@ -430,8 +445,8 @@ def _check_cola(win, nperseg, step, window_name, tol=1e-10): deviation = np.max(np.abs(binsums - const)) if deviation > tol: raise ValueError( - f"segment length {nperseg:d} with step {step:d} for {window_name} window " - "type does not provide a constant output " + f"segment length {nperseg} with step {step} for {window_name} " + "window type does not provide a constant output " f"({100 * deviation / const:g}% deviation)" ) return const diff --git a/mne/epochs.py b/mne/epochs.py index ee8921d3990..96f247875d9 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -4824,7 +4824,7 @@ def average_movements( del head_pos _check_usable(epochs, ignore_ref) origin = _check_origin(origin, epochs.info, "head") - recon_trans = _check_destination(destination, epochs.info, True) + recon_trans = _check_destination(destination, epochs.info, "head") logger.info(f"Aligning and averaging up to {len(epochs.events)} epochs") if not np.array_equal(epochs.events[:, 0], np.unique(epochs.events[:, 0])): diff --git a/mne/filter.py b/mne/filter.py index a7d7c883e2f..acdf867c63a 100644 --- a/mne/filter.py +++ b/mne/filter.py @@ -1664,24 +1664,16 @@ def _mt_spectrum_remove_win( n_overlap = (n_samples + 1) // 2 x_out = np.zeros_like(x) rm_freqs = list() - idx = [0] # Define how to process a chunk of data - def process(x_): + def process(x_, *, start, stop): out = _mt_spectrum_remove( x_, sfreq, line_freqs, notch_widths, window_fun, threshold, get_thresh ) rm_freqs.append(out[1]) return (out[0],) # must return a tuple - # Define how to store a chunk of fully processed data (it's trivial) - def store(x_): - stop = idx[0] + x_.shape[-1] - x_out[..., idx[0] : stop] += x_ - idx[0] = stop - - _COLA(process, store, n_times, n_samples, n_overlap, sfreq, verbose=False).feed(x) - assert idx[0] == n_times + _COLA(process, x_out, n_times, n_samples, n_overlap, sfreq, verbose=False).feed(x) return x_out, rm_freqs diff --git a/mne/preprocessing/maxwell.py b/mne/preprocessing/maxwell.py index 8c9c0a93957..2e71219c127 100644 --- a/mne/preprocessing/maxwell.py +++ b/mne/preprocessing/maxwell.py @@ -21,6 +21,7 @@ from .._fiff.proj import Projection from .._fiff.tag import _coil_trans_to_loc, _loc_to_coil_trans from .._fiff.write import DATE_NONE, _generate_meas_id +from .._ola import _COLA, _Interp2, _Storer from ..annotations import _annotations_starts_stops from ..bem import _check_origin from ..channels.channels import _get_T1T2_mag_inds, fix_mag_coil_types @@ -52,6 +53,7 @@ _pl, _time_mask, _validate_type, + _verbose_safe_false, logger, use_log_level, verbose, @@ -224,6 +226,8 @@ def maxwell_filter( mag_scale=100.0, skip_by_annotation=("edge", "bad_acq_skip"), extended_proj=(), + st_overlap=None, + mc_interp=None, verbose=None, ): """Maxwell filter data using multipole 
moments. @@ -271,6 +275,13 @@ def maxwell_filter( .. versionadded:: 0.17 %(extended_proj_maxwell)s + st_overlap : bool + If True (default in 1.11), tSSS processing will use a constant + overlap-add method. If False (default in 1.10), then + non-overlapping windows will be used. + + .. versionadded:: 1.10 + %(maxwell_mc_interp)s %(verbose)s Returns @@ -299,7 +310,7 @@ def maxwell_filter( .. warning:: Maxwell filtering in MNE is not designed or certified for clinical use. - Compared to the MEGIN MaxFilter™ software, the MNE Maxwell filtering + Compared to the MEGIN MaxFilter™ 2.2.11 software, the MNE Maxwell filtering routines currently provide the following features: .. table:: @@ -338,6 +349,10 @@ def maxwell_filter( +-----------------------------------------------------------------------------+-----+-----------+ | Head position estimation (:func:`~mne.chpi.compute_head_pos`) | ✓ | ✓ | +-----------------------------------------------------------------------------+-----+-----------+ + | Overlap-add processing for spatio-temporal projections | ✓ | | + +-----------------------------------------------------------------------------+-----+-----------+ + | Smooth interpolation in movement compensation | ✓ | | + +-----------------------------------------------------------------------------+-----+-----------+ | Certified for clinical use | | ✓ | +-----------------------------------------------------------------------------+-----+-----------+ | Extended external basis (eSSS) | ✓ | | @@ -399,6 +414,8 @@ def maxwell_filter( mag_scale=mag_scale, skip_by_annotation=skip_by_annotation, extended_proj=extended_proj, + st_overlap=st_overlap, + mc_interp=mc_interp, ) raw_sss = _run_maxwell_filter(raw, **params) # Update info @@ -429,6 +446,8 @@ def _prep_maxwell_filter( skip_by_annotation=("edge", "bad_acq_skip"), extended_proj=(), reconstruct="in", + st_overlap=False, + mc_interp="zero", verbose=None, ): # There are an absurd number of different possible notations for spherical @@ -446,8 +465,7 @@ def _prep_maxwell_filter( if st_correlation <= 0.0 or st_correlation > 1.0: raise ValueError(f"Need 0 < st_correlation <= 1., got {st_correlation}") _check_option("coord_frame", coord_frame, ["head", "meg"]) - head_frame = True if coord_frame == "head" else False - recon_trans = _check_destination(destination, raw.info, head_frame) + recon_trans = _check_destination(destination, raw.info, coord_frame) if st_duration is not None: st_duration = float(st_duration) st_correlation = float(st_correlation) @@ -466,7 +484,29 @@ def _prep_maxwell_filter( ) if st_only and st_duration is None: raise ValueError("st_duration must not be None if st_only is True") - head_pos = _check_pos(head_pos, head_frame, raw, st_fixed, raw.info["sfreq"]) + if st_overlap is None: + if st_duration is not None: + # TODO VERSION 1.10/1.11 deprecation + warn( + "st_overlap defaults to False in 1.10 but will change to " + "True in 1.11. 
Set it explicitly to avoid this warning.",
+                DeprecationWarning,
+            )
+        st_overlap = False
+    if mc_interp is None:
+        if head_pos is not None:
+            # TODO VERSION 1.10/1.11 deprecation
+            warn(
+                'mc_interp defaults to "zero" in 1.10 but will change to '
+                '"hann" in 1.11. Set it explicitly to avoid this warning.',
+                DeprecationWarning,
+            )
+        mc_interp = "zero"
+    add_channels = (head_pos is not None) and (not st_only)
+    head_pos = _check_pos(head_pos, coord_frame, raw, st_fixed)
+    mc = _MoveComp(head_pos, coord_frame, raw, mc_interp, reconstruct)
     _check_info(
         raw.info,
         sss=not st_only,
@@ -590,11 +630,10 @@ def _prep_maxwell_filter(
         st_correlation = None
         st_when = "never"
     update_kwargs["max_st"] = max_st
-    del st_fixed, max_st
+    del max_st
 
     # Figure out which transforms we need for each tSSS block
     # (and transform pos[1] to times)
-    head_pos[1] = raw.time_as_index(head_pos[1], use_rounding=True)
     # Compute the first bit of pos_data for cHPI reporting
     if info["dev_head_t"] is not None and head_pos[0] is not None:
         this_pos_quat = np.concatenate(
@@ -650,6 +689,10 @@ def _prep_maxwell_filter(
         S_recon=S_recon,
         update_kwargs=update_kwargs,
         ignore_ref=ignore_ref,
+        add_channels=add_channels,
+        st_fixed=st_fixed,
+        st_overlap=st_overlap,
+        mc=mc,
     )
     return params
 
@@ -676,26 +719,33 @@ def _run_maxwell_filter(
     ignore_ref=False,
     reconstruct="in",
     copy=True,
+    add_channels,
+    st_fixed,
+    st_overlap,
+    mc,
 ):
     # Eventually find_bad_channels_maxwell could be sped up by moving this
     # outside the loop (e.g., in the prep function) but regularization depends
     # on which channels are being used, so easier just to include it here.
     # The time it takes to recompute S and pS themselves is roughly on par
     # with the np.dot with the data, so not a huge gain to be made there.
-    S_decomp, S_decomp_full, pS_decomp, reg_moments, n_use_in = _get_this_decomp_trans(
-        info["dev_head_t"], t=0.0
-    )
-    update_kwargs.update(reg_moments=reg_moments.copy())
     if ctc is not None:
         ctc = ctc[good_mask][:, good_mask]
-    add_channels = (head_pos[0] is not None) and (not st_only) and copy
+    add_channels = add_channels and copy
     raw_sss, pos_picks = _copy_preload_add_channels(raw, add_channels, copy, info)
     sfreq = info["sfreq"]
     del raw
     if not st_only:
         # remove MEG projectors, they won't apply now
         _remove_meg_projs_comps(raw_sss, ignore_ref)
+
+    # Figure out smooth overlap-add and interp params
+    if st_fixed and not st_only:
+        these_picks = meg_picks[good_mask]
+    else:
+        these_picks = meg_picks
+
     # Figure out which segments of data we can use
     onsets, ends = _annotations_starts_stops(raw_sss, skip_by_annotation, invert=True)
     max_samps = (ends - onsets).max()
@@ -705,171 +755,233 @@ def _run_maxwell_filter(
             "longest contiguous duration of the data "
             "({max_samps / sfreq:0.1f}s)."
) - # Generate time points to break up data into equal-length windows - starts, stops = list(), list() + + # This must be initialized inside _run_maxwell_filter because + # find_bad_channels_maxwell modifies good_mask + mc.initialize(_get_this_decomp_trans, info["dev_head_t"], S_recon) + update_kwargs.update(reg_moments=mc.reg_moments_0) + + # Process each valid block of data separately for onset, end in zip(onsets, ends): - read_lims = np.arange(onset, end + 1, st_duration) - if len(read_lims) == 1: - read_lims = np.concatenate([read_lims, [end]]) - if read_lims[-1] != end: - read_lims[-1] = end - # fold it into the previous buffer - n_last_buf = read_lims[-1] - read_lims[-2] - if st_correlation is not None and len(read_lims) > 2: - if n_last_buf >= st_duration: - logger.info( - " Spatiotemporal window did not fit evenly into" - "contiguous data segment. " - f"{(n_last_buf - st_duration) / sfreq:0.2f} seconds " - "were lumped into the previous window." - ) - else: - logger.info( - f" Contiguous data segment of duration " - f"{n_last_buf / sfreq:0.2f} " - "seconds is too short to be processed with tSSS " - f"using duration {st_duration / sfreq:0.2f}" - ) + n = end - onset + assert n > 0 + tsss_valid = n >= st_duration + if st_overlap and tsss_valid: + n_overlap = st_duration // 2 + window = "hann" + else: + n_overlap = 0 + window = "boxcar" + if st_fixed and st_correlation is not None: + fun = partial(_do_tSSS_on_avg_trans, mc=mc) + else: + fun = _do_tSSS + tsss = _COLA( + partial( + fun, + st_correlation=st_correlation, + tsss_valid=tsss_valid, + sfreq=sfreq, + ), + _Storer(raw_sss._data[:, onset:end], picks=these_picks), + n, + min(st_duration, n), + n_overlap, + sfreq, + window, + name="tSSS-COLA", + ) + + # Generate time points to break up data into equal-length windows + use_n = int(round(raw_sss.buffer_size_sec * raw_sss.info["sfreq"])) + read_lims = list(range(onset, end, use_n)) + [end] assert len(read_lims) >= 2 assert read_lims[0] == onset and read_lims[-1] == end - starts.extend(read_lims[:-1]) - stops.extend(read_lims[1:]) - del read_lims - st_duration = min(max_samps, st_duration) - - # Loop through buffer windows of data - n_sig = int(np.floor(np.log10(max(len(starts), 0)))) + 1 - logger.info(f" Processing {len(starts)} data chunk{_pl(starts)}") - for ii, (start, stop) in enumerate(zip(starts, stops)): - if start == stop: - continue # Skip zero-length annotations - tsss_valid = (stop - start) >= st_duration - rel_times = raw_sss.times[start:stop] - t_str = f"{rel_times[[0, -1]][0]:8.3f} - {rel_times[[0, -1]][1]:8.3f} s" - t_str += (f"(#{ii + 1}/{len(starts)})").rjust(2 * n_sig + 5) - - # Get original data - orig_data = raw_sss._data[meg_picks[good_mask], start:stop] - # This could just be np.empty if not st_only, but shouldn't be slow - # this way so might as well just always take the original data - out_meg_data = raw_sss._data[meg_picks, start:stop] - # Apply cross-talk correction - if ctc is not None: - orig_data = ctc.dot(orig_data) - out_pos_data = np.empty((len(pos_picks), stop - start)) - - # Figure out which positions to use - t_s_s_q_a = _trans_starts_stops_quats(head_pos, start, stop, this_pos_quat) - n_positions = len(t_s_s_q_a[0]) - - # Set up post-tSSS or do pre-tSSS - if st_correlation is not None: - # If doing tSSS before movecomp... 
- resid = orig_data.copy() # to be safe let's operate on a copy - if st_when == "after": - orig_in_data = np.empty((len(meg_picks), stop - start)) - else: # 'before' - avg_trans = t_s_s_q_a[-1] - if avg_trans is not None: - # if doing movecomp - ( - S_decomp_st, - _, - pS_decomp_st, - _, - n_use_in_st, - ) = _get_this_decomp_trans(avg_trans, t=rel_times[0]) + + # First pass: cross_talk, st_fixed=True + for start, stop in zip(read_lims[:-1], read_lims[1:]): + if start == stop: + continue # Skip zero-length annotations + + # Get original data and apply cross-talk correction + ctc_data = raw_sss._data[meg_picks[good_mask], start:stop] + if ctc is not None: + ctc_data = ctc.dot(ctc_data) + + # Apply the average transform and feed data to the tSSS pre-mc + # operator, which will pass its results to raw_sss._data + if st_fixed and st_correlation is not None: + if st_only: + proc = raw_sss._data[meg_picks, start:stop] else: - S_decomp_st, pS_decomp_st = S_decomp, pS_decomp - n_use_in_st = n_use_in - orig_in_data = np.dot( - np.dot(S_decomp_st[:, :n_use_in_st], pS_decomp_st[:n_use_in_st]), - resid, - ) - resid -= np.dot( - np.dot(S_decomp_st[:, n_use_in_st:], pS_decomp_st[n_use_in_st:]), - resid, - ) - resid -= orig_in_data - # Here we operate on our actual data - proc = out_meg_data if st_only else orig_data - _do_tSSS( + proc = ctc_data + tsss.feed( proc, + ctc_data, + sfreq=info["sfreq"], + ) + else: + raw_sss._data[meg_picks[good_mask], start:stop] = ctc_data + + # Second pass: movement compensation, st_fixed=False + for start, stop in zip(read_lims[:-1], read_lims[1:]): + data, orig_in_data, resid, pos_data, n_positions = mc.feed( + raw_sss._data[meg_picks, start:stop], good_mask, st_only + ) + raw_sss._data[meg_picks, start:stop] = data + if len(pos_picks) > 0: + raw_sss._data[pos_picks, start:stop] = pos_data + if not st_fixed and st_correlation is not None: + tsss.feed( + raw_sss._data[meg_picks, start:stop], orig_in_data, resid, - st_correlation, - n_positions, - t_str, - tsss_valid, + n_positions=n_positions, + sfreq=info["sfreq"], ) - if not st_only or st_when == "after": - # Do movement compensation on the data - for trans, rel_start, rel_stop, this_pos_quat in zip(*t_s_s_q_a[:4]): - # Recalculate bases if necessary (trans will be None iff the - # first position in this interval is the same as last of the - # previous interval) - if trans is not None: - ( - S_decomp, - S_decomp_full, - pS_decomp, - reg_moments, - n_use_in, - ) = _get_this_decomp_trans(trans, t=rel_times[rel_start]) - - # Determine multipole moments for this interval - mm_in = np.dot(pS_decomp[:n_use_in], orig_data[:, rel_start:rel_stop]) - - # Our output data - if not st_only: - if reconstruct == "in": - proj = S_recon.take(reg_moments[:n_use_in], axis=1) - mult = mm_in - else: - assert reconstruct == "orig" - proj = S_decomp_full # already picked reg - mm_out = np.dot( - pS_decomp[n_use_in:], orig_data[:, rel_start:rel_stop] - ) - mult = np.concatenate((mm_in, mm_out)) - out_meg_data[:, rel_start:rel_stop] = np.dot(proj, mult) - if len(pos_picks) > 0: - out_pos_data[:, rel_start:rel_stop] = this_pos_quat[:, np.newaxis] - - # Transform orig_data to store just the residual - if st_when == "after": - # Reconstruct data using original location from external - # and internal spaces and compute residual - rel_resid_data = resid[:, rel_start:rel_stop] - orig_in_data[:, rel_start:rel_stop] = np.dot( - S_decomp[:, :n_use_in], mm_in - ) - rel_resid_data -= np.dot( - np.dot(S_decomp[:, n_use_in:], pS_decomp[n_use_in:]), - 
rel_resid_data, - ) - rel_resid_data -= orig_in_data[:, rel_start:rel_stop] - - # If doing tSSS at the end - if st_when == "after": - _do_tSSS( - out_meg_data, - orig_in_data, - resid, - st_correlation, - n_positions, - t_str, - tsss_valid, + return raw_sss + + +class _MoveComp: + """Perform movement compensation.""" + + def __init__(self, pos, head_frame, raw, interp, reconstruct): + self.pos = pos + self.sfreq = raw.info["sfreq"] + self.interp = interp + assert reconstruct in ("orig", "in") + self.reconstruct = reconstruct + + def get_decomp_by_offset(self, offset): + idx = np.where(self.pos[1] == offset)[0][0] + dev_head_t = self.pos[0][idx] + t = offset / self.sfreq + S_decomp, S_decomp_full, pS_decomp, reg_moments, n_use_in = self.get_decomp( + dev_head_t, t=t + ) + S_recon_reg = self.S_recon.take(reg_moments[:n_use_in], axis=1) + if self.reconstruct == "orig": + op_sss = np.dot(S_decomp_full, pS_decomp) + else: + assert self.reconstruct == "in" + op_sss = np.dot(S_recon_reg, pS_decomp[:n_use_in]) + assert op_sss.shape[1] == self.n_good + op_in = np.dot(S_decomp[:, :n_use_in], pS_decomp[:n_use_in]) + op_resid = np.eye(S_decomp.shape[0]) - op_in + op_resid -= np.dot(S_decomp[:, n_use_in:], pS_decomp[n_use_in:]) + return op_sss, op_in, op_resid + + def initialize(self, get_decomp, dev_head_t, S_recon): + """Secondary initialization.""" + self.smooth = _Interp2( + self.pos[1], + self.get_decomp_by_offset, + interp=self.interp, + name="MC", + ) + _, _, pS_decomp, self.reg_moments_0, _ = get_decomp(dev_head_t, t=0.0) + self.n_good = pS_decomp.shape[1] + self.S_recon = S_recon + self.offset = 0 + self.get_decomp = get_decomp + # For the average passes + self.last_avg_quat = np.nan * np.ones(6) + + def get_avg_op(self, *, start, stop): + """Apply an average transformation over the next interval.""" + n_positions, avg_quat = _trans_lims(self.pos, start, stop)[1:] + if not np.allclose(avg_quat, self.last_avg_quat, atol=1e-7): + self.last_avg_quat = avg_quat + avg_trans = np.vstack( + [ + np.hstack([quat_to_rot(avg_quat[:3]), avg_quat[3:][:, np.newaxis]]), + [[0.0, 0.0, 0.0, 1.0]], + ] ) - elif st_when == "never" and head_pos[0] is not None: - logger.info( - f" Used {n_positions: 2d} head position{_pl(n_positions)} " - f"for {t_str}", + S_decomp_st, _, pS_decomp_st, _, n_use_in_st = self.get_decomp( + avg_trans, t=start / self.sfreq ) - raw_sss._data[meg_picks, start:stop] = out_meg_data - raw_sss._data[pos_picks, start:stop] = out_pos_data - return raw_sss + self.op_in_avg = np.dot( + S_decomp_st[:, :n_use_in_st], pS_decomp_st[:n_use_in_st] + ) + self.op_resid_avg = ( + np.eye(len(self.op_in_avg)) + - self.op_in_avg + - np.dot(S_decomp_st[:, n_use_in_st:], pS_decomp_st[n_use_in_st:]) + ) + return self.op_in_avg, self.op_resid_avg, n_positions + + def feed(self, data, good_mask, st_only): + n_samp = data.shape[1] + pos_data, n_pos = _trans_lims( + self.pos, self.offset, self.offset + data.shape[-1] + )[:2] + self.offset += data.shape[-1] + + # Do movement compensation on the data, with optional smoothing + in_data = resid_data = None + for sl, left, right, l_interp in self.smooth.feed_generator(n_samp): + good_data = data[good_mask, sl] + l_sss, l_in, l_resid = left + assert l_sss.shape[1] == good_data.shape[0] + if in_data is None: + in_data = np.empty((l_in.shape[0], data.shape[1])) + resid_data = np.empty((l_resid.shape[0], data.shape[1])) + r_interp = 1.0 - l_interp if l_interp is not None else None + if not st_only: + data[:, sl] = np.dot(l_sss, good_data) + if l_interp is not None: + 
data[:, sl] *= l_interp + data[:, sl] += r_interp * np.dot(right[0], good_data) + + # Reconstruct data using original location from external + # and internal spaces and compute residual + in_data[:, sl] = np.dot(l_in, good_data) + resid_data[:, sl] = np.dot(l_resid, good_data) + if l_interp is not None: + in_data[:, sl] *= l_interp + resid_data[:, sl] *= l_interp + in_data[:, sl] += r_interp * np.dot(right[1], good_data) + resid_data[:, sl] += r_interp * np.dot(right[2], good_data) + return data, in_data, resid_data, pos_data, n_pos + + +def _trans_lims(pos, start, stop): + """Get all trans and limits we need.""" + pos_idx = np.arange(*np.searchsorted(pos[1], [start, stop])) + used = np.zeros(stop - start, bool) + quats = np.empty((9, stop - start)) + n_positions = len(pos_idx) + for ti in range(-1, len(pos_idx)): + # first iteration for this block of data + if ti < 0: + rel_start = 0 + rel_stop = pos[1][pos_idx[0]] if len(pos_idx) > 0 else stop + rel_stop = rel_stop - start + if rel_start == rel_stop: + continue # our first pos occurs on first time sample + this_quat = pos[2][max(pos_idx[0] - 1 if len(pos_idx) else 0, 0)] + n_positions += 1 + else: + rel_start = pos[1][pos_idx[ti]] - start + if ti == len(pos_idx) - 1: + rel_stop = stop - start + else: + rel_stop = pos[1][pos_idx[ti + 1]] - start + this_quat = pos[2][pos_idx[ti]] + quats[:, rel_start:rel_stop] = this_quat[:, np.newaxis] + assert 0 <= rel_start + assert rel_start < rel_stop + assert rel_stop <= stop - start + assert not used[rel_start:rel_stop].any() + used[rel_start:rel_stop] = True + assert used.all() + quats = np.array(quats) + avg_quat = _average_quats(quats[:3].T) + avg_t = np.mean(quats[3:6], axis=1) + avg_quat = np.concatenate([avg_quat, avg_t]) + return quats, n_positions, avg_quat def _get_coil_scale(meg_picks, mag_picks, grad_picks, mag_scale, info): @@ -935,11 +1047,11 @@ def _remove_meg_projs_comps(inst, ignore_ref): inst.info["comps"] = [] -def _check_destination(destination, info, head_frame): +def _check_destination(destination, info, coord_frame): """Triage our reconstruction trans.""" if destination is None: return info["dev_head_t"] - if not head_frame: + if coord_frame != "head": raise RuntimeError( "destination can only be set if using the head coordinate frame" ) @@ -987,63 +1099,45 @@ def _prep_mf_coils(info, ignore_ref=True, *, accuracy="accurate", verbose=None): return rmags, cosmags, bins, n_coils, mag_mask, slice_map -def _trans_starts_stops_quats(pos, start, stop, this_pos_data): - """Get all trans and limits we need.""" - pos_idx = np.arange(*np.searchsorted(pos[1], [start, stop])) - used = np.zeros(stop - start, bool) - trans = list() - rel_starts = list() - rel_stops = list() - quats = list() - weights = list() - for ti in range(-1, len(pos_idx)): - # first iteration for this block of data - if ti < 0: - rel_start = 0 - rel_stop = pos[1][pos_idx[0]] if len(pos_idx) > 0 else stop - rel_stop = rel_stop - start - if rel_start == rel_stop: - continue # our first pos occurs on first time sample - # Don't calculate S_decomp here, use the last one - trans.append(None) # meaning: use previous - quats.append(this_pos_data) - else: - rel_start = pos[1][pos_idx[ti]] - start - if ti == len(pos_idx) - 1: - rel_stop = stop - start - else: - rel_stop = pos[1][pos_idx[ti + 1]] - start - trans.append(pos[0][pos_idx[ti]]) - quats.append(pos[2][pos_idx[ti]]) - assert 0 <= rel_start - assert rel_start < rel_stop - assert rel_stop <= stop - start - assert not used[rel_start:rel_stop].any() - 
used[rel_start:rel_stop] = True - rel_starts.append(rel_start) - rel_stops.append(rel_stop) - weights.append(rel_stop - rel_start) - assert used.all() - # Use weighted average for average trans over the window - if this_pos_data is None: - avg_trans = None - else: - weights = np.array(weights) - quats = np.array(quats) - weights = weights / weights.sum().astype(float) # int -> float - avg_quat = _average_quats(quats[:, :3], weights) - avg_t = np.dot(weights, quats[:, 3:6]) - avg_trans = np.vstack( - [ - np.hstack([quat_to_rot(avg_quat), avg_t[:, np.newaxis]]), - [[0.0, 0.0, 0.0, 1.0]], - ] - ) - return trans, rel_starts, rel_stops, quats, avg_trans +def _do_tSSS_on_avg_trans( + clean_data, + orig_data, + *, + st_correlation, + tsss_valid, + mc, + start, + stop, + sfreq, +): + # Get the average transformation over the start, stop interval and split data + op_in, op_resid, n_positions = mc.get_avg_op(start=start, stop=stop) + orig_in_data = op_in @ orig_data + resid = op_resid @ orig_data + return _do_tSSS( + clean_data, + orig_in_data, + resid, + st_correlation=st_correlation, + n_positions=n_positions, + tsss_valid=tsss_valid, + start=start, + stop=stop, + sfreq=sfreq, + ) def _do_tSSS( - clean_data, orig_in_data, resid, st_correlation, n_positions, t_str, tsss_valid + clean_data, + orig_in_data, + resid, + st_correlation, + n_positions, + tsss_valid, + *, + start, + stop, + sfreq, ): """Compute and apply SSP-like projection vectors based on min corr.""" if not tsss_valid: @@ -1052,6 +1146,8 @@ def _do_tSSS( np.asarray_chkfinite(resid) t_proj = _overlap_projector(orig_in_data, resid, st_correlation) # Apply projector according to Eq. 12 in :footcite:`TauluSimola2006` + start, stop = start / sfreq, (stop - 1) / sfreq + t_str = f"{start:8.3f} - {stop:8.3f} s" msg = ( f" Projecting {t_proj.shape[1]:2d} intersecting tSSS " f"component{_pl(t_proj.shape[1], ' ')} for {t_str}" @@ -1059,7 +1155,7 @@ def _do_tSSS( if n_positions > 1: msg += f" (across {n_positions:2d} position{_pl(n_positions, ' ')})" logger.info(msg) - clean_data -= np.dot(np.dot(clean_data, t_proj), t_proj.T) + return (clean_data - np.dot(np.dot(clean_data, t_proj), t_proj.T),) def _copy_preload_add_channels(raw, add_channels, copy, info): @@ -1089,7 +1185,7 @@ def _copy_preload_add_channels(raw, add_channels, copy, info): raw._data = out_data else: logger.info(msg + "loading raw data from disk") - with use_log_level(False): + with use_log_level(_verbose_safe_false()): raw._preload_data(out_data[: len(raw.ch_names)]) raw._data = out_data assert raw.preload is True @@ -1127,17 +1223,16 @@ def _copy_preload_add_channels(raw, add_channels, copy, info): return raw, np.array([], int) -def _check_pos(pos, head_frame, raw, st_fixed, sfreq): +def _check_pos(pos, coord_frame, raw, st_fixed): """Check for a valid pos array and transform it to a more usable form.""" _validate_type(pos, (np.ndarray, None), "head_pos") if pos is None: - return [None, np.array([-1])] - if not head_frame: + pos = np.empty((0, 10)) + elif coord_frame != "head": raise ValueError('positions can only be used if coord_frame="head"') if not st_fixed: warn("st_fixed=False is untested, use with caution!") - if not isinstance(pos, np.ndarray): - raise TypeError("pos must be an ndarray") + _validate_type(pos, np.ndarray, "head_pos") if pos.ndim != 2 or pos.shape[1] != 10: raise ValueError("pos must be an array of shape (N, 10)") t = pos[:, 0] @@ -1145,11 +1240,24 @@ def _check_pos(pos, head_frame, raw, st_fixed, sfreq): raise ValueError("Time points must unique and in 
ascending order") # We need an extra 1e-3 (1 ms) here because MaxFilter outputs values # only out to 3 decimal places - if not _time_mask(t, tmin=raw._first_time - 1e-3, tmax=None, sfreq=sfreq).all(): - raise ValueError( - "Head position time points must be greater than " - f"first sample offset, but found {t[0]:0.4f} < {raw._first_time:0.4f}" + if len(pos) > 0: + if not _time_mask( + t, tmin=raw._first_time - 1e-3, tmax=None, sfreq=raw.info["sfreq"] + ).all(): + raise ValueError( + "Head position time points must be greater than " + f"first sample offset, but found {t[0]:0.4f} < {raw._first_time:0.4f}" + ) + t = t - raw._first_time + if len(t) == 0 or t[0] > 0: + # Prepend the existing dev_head_t to make movecomp easier + t = np.concatenate([[0.0], t]) + trans = raw.info["dev_head_t"] + trans = np.eye(4) if trans is None else trans["trans"] + dev_head_pos = np.concatenate( + [t[[0]], rot_to_quat(trans[:3, :3]), trans[:3, 3], [0, 0, 0]] ) + pos = np.concatenate([dev_head_pos[np.newaxis], pos]) max_dist = np.sqrt(np.sum(pos[:, 4:7] ** 2, axis=1)).max() if max_dist > 1.0: warn( @@ -1157,11 +1265,14 @@ def _check_pos(pos, head_frame, raw, st_fixed, sfreq): "origin, positions may be invalid and Maxwell filtering could " "fail" ) + t[0] = 0 dev_head_ts = np.zeros((len(t), 4, 4)) dev_head_ts[:, 3, 3] = 1.0 dev_head_ts[:, :3, 3] = pos[:, 4:7] dev_head_ts[:, :3, :3] = quat_to_rot(pos[:, 1:4]) - pos = [dev_head_ts, t - raw._first_time, pos[:, 1:]] + t = raw.time_as_index(t, use_rounding=True) + pos = [dev_head_ts, t, pos[:, 1:]] + assert all(len(p) == len(pos[0]) for p in pos) return pos @@ -1260,6 +1371,7 @@ def _get_decomp( pS_decomp *= coil_scale[good_mask].T S_decomp /= coil_scale[good_mask] S_decomp_full /= coil_scale + assert pS_decomp.shape[1] == S_decomp.shape[0] == good_mask.sum() return S_decomp, S_decomp_full, pS_decomp, reg_moments, n_use_in @@ -2409,7 +2521,7 @@ def _trans_sss_basis(exp, all_coils, trans=None, coil_scale=100.0): # intentionally omitted: st_duration, st_correlation, destination, st_fixed, -# st_only +# st_only, st_overlap @verbose def find_bad_channels_maxwell( raw, @@ -2431,6 +2543,7 @@ def find_bad_channels_maxwell( skip_by_annotation=("edge", "bad_acq_skip"), h_freq=40.0, extended_proj=(), + mc_interp=None, verbose=None, ): r"""Find bad channels using Maxwell filtering. @@ -2484,6 +2597,7 @@ def find_bad_channels_maxwell( should provide similar results to MaxFilter. If you do not wish to apply a filter, set this to ``None``. 
%(extended_proj_maxwell)s
+    %(maxwell_mc_interp)s
     %(verbose)s
 
     Returns
@@ -2626,6 +2740,7 @@
         head_pos=head_pos,
         mag_scale=mag_scale,
         extended_proj=extended_proj,
+        reconstruct="orig",
     )
     del origin, int_order, ext_order, calibration, cross_talk, coord_frame
     del regularize, ignore_ref, bad_condition, head_pos, mag_scale
@@ -2719,8 +2834,8 @@
             ]
             chunk_raw._data[:] = orig_data
             delta = chunk_raw.get_data(these_picks)
-            with use_log_level(False):
-                _run_maxwell_filter(chunk_raw, reconstruct="orig", copy=False, **params)
+            with use_log_level(_verbose_safe_false()):
+                _run_maxwell_filter(chunk_raw, copy=False, **params)
 
             if n_iter == 1 and len(chunk_flats):
                 logger.info(
diff --git a/mne/preprocessing/otp.py b/mne/preprocessing/otp.py
index f5e6277a7b3..5b3e25d9953 100644
--- a/mne/preprocessing/otp.py
+++ b/mne/preprocessing/otp.py
@@ -109,7 +109,7 @@ def oversampled_temporal_projection(raw, duration=10.0, picks=None, verbose=None
     return raw_otp
 
 
-def _otp(data, picks_good, picks_bad):
+def _otp(data, picks_good, picks_bad, *, start=0, stop=None):
     """Perform OTP on one segment of data."""
     if not np.isfinite(data).all():
         raise RuntimeError("non-finite data (inf or nan) found in raw instance")
diff --git a/mne/preprocessing/tests/test_maxwell.py b/mne/preprocessing/tests/test_maxwell.py
index 002d4555ff8..2f279440d1c 100644
--- a/mne/preprocessing/tests/test_maxwell.py
+++ b/mne/preprocessing/tests/test_maxwell.py
@@ -5,6 +5,7 @@
 import pathlib
 import re
 from contextlib import contextmanager
+from functools import partial
 from pathlib import Path
 
 import numpy as np
@@ -33,9 +34,11 @@
     annotate_movement,
     compute_maxwell_basis,
     find_bad_channels_maxwell,
-    maxwell_filter,
     maxwell_filter_prepare_emptyroom,
 )
+from mne.preprocessing import (
+    maxwell_filter as _maxwell_filter_ola,
+)
 from mne.preprocessing.maxwell import (
     _bases_complex_to_real,
     _bases_real_to_complex,
@@ -187,6 +190,12 @@ def read_crop(fname, lims=(0, None)):
     return read_raw_fif(fname, allow_maxshield="yes").crop(*lims)
 
 
+# For backward compat and to be most like MaxFilter, we make "maxwell_filter"
+# the one that behaves like MaxFilter. _maxwell_filter_ola is left to
+# be the advanced/better one.
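+# (Concretely, in the tests below maxwell_filter(raw, ...) is equivalent to
+# _maxwell_filter_ola(raw, ..., st_overlap=False, mc_interp="zero"), i.e., the
+# non-overlapping-window, zero-order-hold behavior that MaxFilter uses.)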
+maxwell_filter = partial(_maxwell_filter_ola, st_overlap=False, mc_interp="zero")
+
+
 @pytest.mark.slowtest
 @testing.requires_testing_data
 def test_movement_compensation(tmp_path):
@@ -296,6 +305,52 @@
     )
 
 
+@pytest.mark.slowtest
+@testing.requires_testing_data
+def test_movement_compensation_smooth():
+    """Test movement compensation with smooth interpolation."""
+    lims = (0, 10)
+    raw = read_crop(raw_fname, lims).load_data()
+    mag_picks = pick_types(raw.info, meg="mag", exclude=())
+    power = np.sqrt(np.sum(raw[mag_picks][0] ** 2))
+    head_pos = read_head_pos(pos_fname)
+    kwargs = dict(
+        head_pos=head_pos,
+        origin=mf_head_origin,
+        regularize=None,
+        bad_condition="ignore",
+    )
+    # Naive MC increases noise relative to raw
+    raw_sss = maxwell_filter(raw, **kwargs)
+    _assert_shielding(raw_sss, power, 0.258, max_factor=0.259)
+    # OLA MC decreases noise relative to naive MC
+    raw_sss_smooth = _maxwell_filter_ola(raw, mc_interp="hann", **kwargs)
+    _assert_shielding(raw_sss_smooth, raw_sss, 1.01, max_factor=1.02)
+
+    # now with time-varying regularization
+    kwargs["regularize"] = "in"
+    raw_sss = maxwell_filter(raw, **kwargs)
+    _assert_shielding(raw_sss, power, 0.84, max_factor=0.85)
+    raw_sss_smooth = _maxwell_filter_ola(raw, mc_interp="hann", **kwargs)
+    _assert_shielding(raw_sss_smooth, raw_sss, 1.008, max_factor=1.012)
+
+    # now with tSSS
+    kwargs["st_duration"] = 10
+    with catch_logging() as log:
+        raw_tsss = maxwell_filter(raw, verbose=True, **kwargs)
+    log = log.getvalue()
+    want_re = re.compile(".*Projecting 25 intersecting.*across 24 pos.*", re.DOTALL)
+    assert want_re.match(log) is not None, log
+    _assert_shielding(raw_tsss, power, 31.2, max_factor=31.3)
+    with catch_logging() as log:
+        raw_tsss_smooth = _maxwell_filter_ola(
+            raw, mc_interp="hann", st_overlap=True, verbose=True, **kwargs
+        )
+    log = log.getvalue()
+    assert want_re.match(log) is not None, log
+    _assert_shielding(raw_tsss_smooth, power, 31.5, max_factor=31.7)
+
+
 @pytest.mark.slowtest
 def test_other_systems():
     """Test Maxwell filtering on KIT, BTI, and CTF files."""
@@ -655,6 +710,8 @@ def test_spatiotemporal():
     """Test Maxwell filter (tSSS) spatiotemporal processing."""
     # Load raw testing data
     raw = read_crop(raw_fname)
+    mag_picks = pick_types(raw.info, meg="mag", exclude=())
+    power = np.sqrt(np.sum(raw[mag_picks][0] ** 2))
 
     # Test that window is less than length of data
     with pytest.raises(ValueError, match="must be"):
@@ -686,12 +743,31 @@
         assert len(py_st) > 0
         assert py_st["buflen"] == st_duration
         assert py_st["subspcorr"] == 0.98
+        _assert_shielding(raw_tsss, power, 20.8)
 
     # Degenerate cases
     with pytest.raises(ValueError, match="Need 0 < st_correlation"):
         maxwell_filter(raw, st_duration=10.0, st_correlation=0.0)
 
 
+@buggy_mkl_svd
+@testing.requires_testing_data
+def test_st_overlap():
+    """Test st_overlap."""
+    raw = read_crop(raw_fname).crop(0, 1.0)
+    mag_picks = pick_types(raw.info, meg="mag", exclude=())
+    power = np.sqrt(np.sum(raw[mag_picks][0] ** 2))
+    kwargs = dict(
+        origin=mf_head_origin, regularize=None, bad_condition="ignore", st_duration=0.5
+    )
+    raw_tsss = maxwell_filter(raw, **kwargs)
+    assert _compute_rank_int(raw_tsss, proj=False) == 140
+    _assert_shielding(raw_tsss, power, 35.8, max_factor=35.9)
+    raw_tsss = _maxwell_filter_ola(raw, st_overlap=True, **kwargs)
+    assert _compute_rank_int(raw_tsss, proj=False) == 140
+    _assert_shielding(raw_tsss, power, 35.6, max_factor=35.7)
+
+
 @pytest.mark.slowtest
 @testing.requires_testing_data
 def 
test_spatiotemporal_only():
@@ -707,20 +783,31 @@
     raw_tsss = maxwell_filter(raw, st_duration=tmax / 2.0, st_only=True)
     assert len(raw.info["projs"]) == len(raw_tsss.info["projs"])
     assert _compute_rank_int(raw_tsss, proj=False) == len(picks)
-    _assert_shielding(raw_tsss, power, 9)
+    _assert_shielding(raw_tsss, power, 9.2)
     # with movement
     head_pos = read_head_pos(pos_fname)
     raw_tsss = maxwell_filter(
         raw, st_duration=tmax / 2.0, st_only=True, head_pos=head_pos
     )
     assert _compute_rank_int(raw_tsss, proj=False) == len(picks)
-    _assert_shielding(raw_tsss, power, 9)
+    _assert_shielding(raw_tsss, power, 9.2)
     with pytest.warns(RuntimeWarning, match="st_fixed"):
         raw_tsss = maxwell_filter(
             raw, st_duration=tmax / 2.0, st_only=True, head_pos=head_pos, st_fixed=False
         )
     assert _compute_rank_int(raw_tsss, proj=False) == len(picks)
-    _assert_shielding(raw_tsss, power, 9)
+    _assert_shielding(raw_tsss, power, 9.2, max_factor=9.4)
+    # COLA
+    raw_tsss = maxwell_filter(
+        raw,
+        st_duration=tmax / 2.0,
+        st_only=True,
+        head_pos=head_pos,
+        st_overlap=True,
+        mc_interp="hann",
+    )
+    assert _compute_rank_int(raw_tsss, proj=False) == len(picks)
+    _assert_shielding(raw_tsss, power, 9.5, max_factor=9.6)
     # should do nothing
     raw_tsss = maxwell_filter(raw, st_duration=tmax, st_correlation=1.0, st_only=True)
     assert_allclose(raw[:][0], raw_tsss[:][0])
@@ -1768,7 +1855,7 @@ def test_compute_maxwell_basis(regularize, n, int_order):
     assert n_use_in == len(reg_moments) - 15  # no externals removed
     xform = S[:, :n_use_in] @ pS[:n_use_in]
     got = xform @ raw.pick(picks="meg", exclude="bads").get_data()
-    assert_allclose(got, want)
+    assert_allclose(got, want, atol=1e-16)
 
 
 @testing.requires_testing_data
@@ -1886,3 +1973,96 @@ def test_prepare_emptyroom_annot_first_samp(
         raw_er_prepared.get_data([0], reject_by_annotation="nan")
     ).mean()
     assert_allclose(prop_bad, prop_bad_er)
+
+
+@pytest.mark.slowtest
+@testing.requires_testing_data
+@pytest.mark.parametrize("mc_interp", ("zero", "hann", False))
+@pytest.mark.parametrize("st_fixed", (False, True))
+@pytest.mark.parametrize("st_only", (True, False))
+@pytest.mark.filterwarnings("ignore:st_fixed=False is untested.*:RuntimeWarning")
+def test_feed_avg(st_fixed, st_only, mc_interp):
+    """Test that feed_avg gives the correct data for tSSS."""
+    if mc_interp is False:
+        movecomp = False
+        mc_interp = "zero"
+    else:
+        movecomp = True
+    raw = read_crop(raw_fname, (0, 3.0)).load_data()  # 0-1, 0.5-1.5, ...
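+    # (With st_duration=1 on this 3 s crop, non-overlapping tSSS windows cover
+    # 0-1, 1-2, and 2-3 s, while st_overlap=True windows step by 0.5 s -- hence
+    # the window times in the log messages asserted below.)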
+    # Use every third mag just for speed
+    raw.pick("mag")
+    raw.pick(raw.ch_names[::3])
+    if movecomp:
+        head_pos = read_head_pos(pos_fname)
+        # Trim just to make debugging easier
+        head_pos = head_pos[head_pos[:, 0] < head_pos[0, 0] + 5]
+    else:
+        head_pos = None
+    kwargs = dict(
+        int_order=3, st_duration=1, st_fixed=st_fixed, st_only=st_only, verbose="debug"
+    )
+    # These were empirically determined -- the important thing is that they
+    # only change under specific (expected) circumstances, e.g., not dependent
+    # on st_only at all
+    n = 8 if (movecomp and mc_interp == "hann" and not st_fixed) else 4
+    st_0_1 = f"Projecting {n} intersecting tSSS components for 0.000 - 0.999 s"
+    if st_fixed:
+        n = 4
+    else:
+        if movecomp:
+            n = 12 if mc_interp == "hann" else 8
+        else:
+            n = 4
+    st_0p5_1p5 = (
+        f"Projecting {n:2d} intersecting tSSS components for 0.000 - 0.999 s"
+    )
+    if movecomp and st_fixed:
+        st_0p5_1p5 += " (across 2 positions)\n"
+    n = 8 if (movecomp and mc_interp == "hann" and not st_fixed) else 4
+    log_1_2 = (
+        f"Projecting {n} intersecting tSSS components for 1.000 - 1.999 s\n"
+    )
+    with catch_logging() as log:
+        _maxwell_filter_ola(
+            raw, head_pos=head_pos, st_overlap=False, mc_interp=mc_interp, **kwargs
+        )
+    log = log.getvalue()
+    # Leave these print statements in because they'll be captured by pytest but
+    # are valuable during failures
+    print(log)
+    assert st_0_1 in log
+    assert log_1_2 in log
+    assert "Eval @ 0 (0)" in log
+    if movecomp:
+        assert raw.first_time == 9.0
+        this_head_pos = head_pos[np.where(head_pos[:, 0] >= 9.5)[0][0] - 1 :].copy()
+        this_head_pos[0, 0] = 9.5
+        assert this_head_pos[1, 0] > this_head_pos[0, 0]
+    else:
+        this_head_pos = None
+    with catch_logging() as log_crop:
+        _maxwell_filter_ola(
+            raw.copy().crop(0.5, None),
+            head_pos=this_head_pos,
+            st_overlap=False,
+            mc_interp=mc_interp,
+            **kwargs,
+        )
+    log_crop = log_crop.getvalue()
+    print(log_crop)
+    assert st_0p5_1p5 in log_crop
+    # The full / OLA version of this will reflect the actual offset
+    st_0p5_1p5 = st_0p5_1p5.replace("0.000", "0.500").replace("0.999", "1.499")
+    with catch_logging() as log_ola:
+        _maxwell_filter_ola(
+            raw,
+            st_overlap=True,
+            mc_interp=mc_interp,
+            head_pos=head_pos,
+            **kwargs,
+        )
+    log_ola = log_ola.getvalue()
+    print(log_ola)
+    assert st_0_1 in log_ola
+    assert log_1_2 in log_ola
+    assert st_0p5_1p5 in log_ola
diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py
index aa11082238f..88f2d9cdc13 100644
--- a/mne/tests/test_epochs.py
+++ b/mne/tests/test_epochs.py
@@ -441,7 +441,7 @@ def test_average_movements():
     assert_meg_snr(evoked_sss, evoked_move_non, 0.02, 2.6)
     assert_meg_snr(evoked_sss, evoked_stat_all, 0.05, 3.2)
     # these should be close to numerical precision
-    assert_allclose(evoked_sss_stat.data, evoked_stat_all.data, atol=1e-20)
+    assert_allclose(evoked_sss_stat.data, evoked_stat_all.data, atol=1e-14)
 
     # pos[0] > epochs.events[0] uses dev_head_t, so make it equivalent
     destination = deepcopy(epochs.info["dev_head_t"])
diff --git a/mne/tests/test_ola.py b/mne/tests/test_ola.py
index 26ddafd8475..2a1eac13ec0 100644
--- a/mne/tests/test_ola.py
+++ b/mne/tests/test_ola.py
@@ -90,7 +90,7 @@ def test_cola(ndim):
     sfreq = 1000.0
     rng = np.random.RandomState(0)
 
-    def processor(x):
+    def processor(x, *, start, stop):
         return (x / 2.0,)  # halve the signal
 
     for n_total in (999, 1000, 1001):
diff --git a/mne/utils/docs.py b/mne/utils/docs.py
index 54cc6845e58..c1d4fb4595a 100644
--- a/mne/utils/docs.py
+++ b/mne/utils/docs.py
@@ -2551,6 +2551,15 @@ def 
_reflow_param_docstring(docstring, has_first_line=True, width=75): :func:`mne.stats.combine_adjacency`). """ +docdict["maxwell_mc_interp"] = """ +mc_interp : str + Interpolation to use between adjacent time points in movement + compensation. Can be "zero" (default in 1.10; used by MaxFilter), + "linear", or "hann" (default in 1.11). + + .. versionadded:: 1.10 +""" + docdict["measure"] = """ measure : 'zscore' | 'correlation' Which method to use for finding outliers among the components: diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index 9c558d32a51..4a893b7c017 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -190,7 +190,7 @@ def plot_head_positions( _check_option("mode", mode, ["traces", "field"]) _validate_type(totals, bool, "totals") dest_info = dict(dev_head_t=None) if info is None else info - destination = _check_destination(destination, dest_info, head_frame=True) + destination = _check_destination(destination, dest_info, "head") if destination is not None: destination = _ensure_trans(destination, "head", "meg") # probably inv destination = destination["trans"]
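For reference, here is a minimal usage sketch of the two options this diff adds. It is illustrative only: the file paths are placeholders rather than files from this repository, and passing both options explicitly is also what silences the 1.10 deprecation warnings introduced above.

```python
# Sketch: opt in to the smooth variants added by this diff (MNE >= 1.10).
# "raw.fif" and "raw_hp.pos" are placeholder paths, not real files.
import mne

raw = mne.io.read_raw_fif("raw.fif", allow_maxshield="yes").load_data()
head_pos = mne.chpi.read_head_pos("raw_hp.pos")  # cHPI-derived head positions

raw_sss = mne.preprocessing.maxwell_filter(
    raw,
    st_duration=10.0,   # enable tSSS
    head_pos=head_pos,  # enable movement compensation
    st_overlap=True,    # overlap-add tSSS windows (the 1.11 default)
    mc_interp="hann",   # smooth interpolation of head positions (1.11 default)
)
```

The overlap-add path relies on the constant overlap-add (COLA) property that `_check_cola` in `mne/_ola.py` verifies: at 50% overlap, periodic Hann windows sum to a constant, so cross-faded tSSS windows neither gain nor lose signal. A quick standalone check of that property:

```python
import numpy as np
from scipy.signal import get_window

n_samples, step = 1000, 500  # window length and hop (50% overlap)
win = get_window("hann", n_samples, fftbins=True)  # periodic Hann
# Overlapping copies of the window sum to a constant (here, 1.0)
assert np.allclose(win[:step] + win[step:], 1.0)
```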