Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Overlap-add processing for maxwell filter #13080

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
78b83d9
MAINT: Refactor internals
larsoner Apr 21, 2020
b0c7fae
ENH: Add smooth options to maxwell_filter
larsoner Apr 21, 2020
f66265a
Merge remote-tracking branch 'upstream/main' into cola4
larsoner Nov 28, 2022
9d6b31d
FIX: Fix test
larsoner Nov 28, 2022
7c44d72
Merge remote-tracking branch 'upstream/main' into cola5
larsoner Apr 25, 2023
5997a93
TST: Improve test
larsoner Apr 25, 2023
9d143ae
WIP: Black
larsoner May 18, 2023
954a990
Merge remote-tracking branch 'upstream/main' into cola5
larsoner May 18, 2023
5bebc83
WIP: Fix time logging with start, stop
larsoner May 18, 2023
062e101
WIP: Test
larsoner May 18, 2023
b252a86
FIX: n_positions
larsoner May 18, 2023
c57e5bf
FIX: Str
larsoner May 24, 2023
e28f34a
WIP: TDD
larsoner Jun 22, 2023
1fc9871
FIX: Closer
larsoner Jun 22, 2023
119dac7
FIX: Very close
larsoner Jun 22, 2023
70be6f8
FIX: WORKING
larsoner Jun 23, 2023
24790d4
Merge remote-tracking branch 'upstream/main' into cola5
larsoner Jun 23, 2023
b6581a8
FIX: Uncomment
larsoner Jun 23, 2023
22145f2
FIX: Uncomment
larsoner Jun 23, 2023
48394dd
Merge remote-tracking branch 'origin/cola5' into cola5
larsoner Jul 3, 2023
57701e5
FIX: OTP
larsoner Jul 3, 2023
2725a52
Merge remote-tracking branch 'upstream/main' into cola5
larsoner Jul 10, 2023
0adefdf
Merge remote-tracking branch 'upstream/main' into cola5
larsoner Dec 6, 2024
46c193d
Merge remote-tracking branch 'upstream/main' into cola5
larsoner Jan 24, 2025
67f8d33
FIX: Ver
larsoner Jan 24, 2025
4b390ce
DOC: Document it
larsoner Jan 24, 2025
ed40542
FIX: Compat
larsoner Jan 24, 2025
58500ed
FIX: Filt
larsoner Jan 24, 2025
8db31da
FIX: Maybe
larsoner Jan 24, 2025
6d65b33
Apply suggestions from code review
larsoner Jan 31, 2025
8cfd708
Merge branch 'main' into cola5
larsoner Jan 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/devel/13080.apichange.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The backward-compatible defaults in :func:`mne.preprocessing.maxwell_filter` of ``st_overlap=False`` and ```mc_interp=None`` will change to their smooth variants ``True`` and ``"hann"``, respectively, in 1.11, by `Eric Larson`_.
1 change: 1 addition & 0 deletions doc/changes/devel/13080.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add smooth processing of tSSS windows (using overlap-add) and movement compensation (using smooth interpolation of head positions) in :func:`mne.preprocessing.maxwell_filter` via ``st_overlap`` and ```mc_interp`` options, respectively, by `Eric Larson`_.
63 changes: 39 additions & 24 deletions mne/_ola.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from scipy.signal import get_window

from .utils import _ensure_int, logger, verbose
from .utils import _ensure_int, _validate_type, logger, verbose

###############################################################################
# Class for interpolation between adjacent points
Expand Down Expand Up @@ -42,7 +42,7 @@ class _Interp2:

"""

def __init__(self, control_points, values, interp="hann"):
def __init__(self, control_points, values, interp="hann", *, name="Interp2"):
# set up interpolation
self.control_points = np.array(control_points, int).ravel()
if not np.array_equal(np.unique(self.control_points), self.control_points):
Expand Down Expand Up @@ -79,6 +79,7 @@ def val(pt):
self._position = 0 # start at zero
self._left_idx = 0
self._left = self._right = self._use_interp = None
self.name = name
known_types = ("cos2", "linear", "zero", "hann")
if interp not in known_types:
raise ValueError(f'interp must be one of {known_types}, got "{interp}"')
Expand All @@ -90,10 +91,10 @@ def feed_generator(self, n_pts):
n_pts = _ensure_int(n_pts, "n_pts")
original_position = self._position
stop = self._position + n_pts
logger.debug(f"Feed {n_pts} ({self._position}-{stop})")
logger.debug(f" ~ {self.name} Feed {n_pts} ({self._position}-{stop})")
used = np.zeros(n_pts, bool)
if self._left is None: # first one
logger.debug(f" Eval @ 0 ({self.control_points[0]})")
logger.debug(f" ~ {self.name} Eval @ 0 ({self.control_points[0]})")
self._left = self.values(self.control_points[0])
if len(self.control_points) == 1:
self._right = self._left
Expand All @@ -102,7 +103,7 @@ def feed_generator(self, n_pts):
# Left zero-order hold condition
if self._position < self.control_points[self._left_idx]:
n_use = min(self.control_points[self._left_idx] - self._position, n_pts)
logger.debug(f" Left ZOH {n_use}")
logger.debug(f" ~ {self.name} Left ZOH {n_use}")
this_sl = slice(None, n_use)
assert used[this_sl].size == n_use
assert not used[this_sl].any()
Expand All @@ -127,7 +128,9 @@ def feed_generator(self, n_pts):
self._left_idx += 1
self._use_interp = None # need to recreate it
eval_pt = self.control_points[self._left_idx + 1]
logger.debug(f" Eval @ {self._left_idx + 1} ({eval_pt})")
logger.debug(
f" ~ {self.name} Eval @ {self._left_idx + 1} ({eval_pt})"
)
self._right = self.values(eval_pt)
assert self._right is not None
left_point = self.control_points[self._left_idx]
Expand All @@ -148,7 +151,8 @@ def feed_generator(self, n_pts):
n_use = min(stop, right_point) - self._position
if n_use > 0:
logger.debug(
f" Interp {self._interp} {n_use} ({left_point}-{right_point})"
f" ~ {self.name} Interp {self._interp} {n_use} "
f"({left_point}-{right_point})"
)
interp_start = self._position - left_point
assert interp_start >= 0
Expand All @@ -169,7 +173,7 @@ def feed_generator(self, n_pts):
if self.control_points[self._left_idx] <= self._position:
n_use = stop - self._position
if n_use > 0:
logger.debug(f" Right ZOH {n_use}")
logger.debug(f" ~ {self.name} Right ZOH %s" % n_use)
this_sl = slice(n_pts - n_use, None)
assert not used[this_sl].any()
used[this_sl] = True
Expand Down Expand Up @@ -210,14 +214,13 @@ def feed(self, n_pts):


def _check_store(store):
_validate_type(store, (np.ndarray, list, tuple, _Storer), "store")
if isinstance(store, np.ndarray):
store = [store]
if isinstance(store, list | tuple) and all(
isinstance(s, np.ndarray) for s in store
):
if not isinstance(store, _Storer):
if not all(isinstance(s, np.ndarray) for s in store):
raise TypeError("All instances must be ndarrays")
store = _Storer(*store)
if not callable(store):
raise TypeError(f"store must be callable, got type {type(store)}")
return store


Expand All @@ -229,10 +232,8 @@ class _COLA:
process : callable
A function that takes a chunk of input data with shape
``(n_channels, n_samples)`` and processes it.
store : callable | ndarray
A function that takes a completed chunk of output data.
Can also be an ``ndarray``, in which case it is treated as the
output data in which to store the results.
store : ndarray | list of ndarray | _Storer
The output data in which to store the results.
n_total : int
The total number of samples.
n_samples : int
Expand Down Expand Up @@ -276,6 +277,7 @@ def __init__(
window="hann",
tol=1e-10,
*,
name="COLA",
verbose=None,
):
n_samples = _ensure_int(n_samples, "n_samples")
Expand All @@ -302,6 +304,7 @@ def __init__(
self._store = _check_store(store)
self._idx = 0
self._in_buffers = self._out_buffers = None
self.name = name

# Create our window boundaries
window_name = window if isinstance(window, str) else "custom"
Expand Down Expand Up @@ -343,6 +346,7 @@ def feed(self, *datas, verbose=None, **kwargs):
raise ValueError(
f"Got {len(datas)} array(s), needed {len(self._in_buffers)}"
)
current_offset = 0 # should be updated below
for di, data in enumerate(datas):
if not isinstance(data, np.ndarray) or data.ndim < 1:
raise TypeError(
Expand All @@ -363,9 +367,12 @@ def feed(self, *datas, verbose=None, **kwargs):
f"shape[:-1]=={self._in_buffers[di].shape[:-1]}, got dtype "
f"{data.dtype} shape[:-1]={data.shape[:-1]}"
)
# This gets updated on first iteration, so store it before it updates
if di == 0:
current_offset = self._in_offset
logger.debug(
f" + Appending {self._in_offset:d}->"
f"{self._in_offset + data.shape[-1]:d}"
f" + {self.name}[{di}] Appending "
f"{current_offset}:{current_offset + data.shape[-1]}"
)
self._in_buffers[di] = np.concatenate([self._in_buffers[di], data], -1)
if self._in_offset > self.stops[-1]:
Expand All @@ -388,13 +395,18 @@ def feed(self, *datas, verbose=None, **kwargs):
if self._idx == 0:
for offset in range(self._n_samples - self._step, 0, -self._step):
this_window[:offset] += self._window[-offset:]
logger.debug(f" * Processing {start}->{stop}")
this_proc = [in_[..., :this_len].copy() for in_ in self._in_buffers]
logger.debug(
f" * {self.name}[:] Processing {start}:{stop} "
f"(e.g., {this_proc[0].flat[[0, -1]]})"
)
if not all(
proc.shape[-1] == this_len == this_window.size for proc in this_proc
):
raise RuntimeError("internal indexing error")
outs = self._process(*this_proc, **kwargs)
start = self._store.idx
stop = self._store.idx + this_len
outs = self._process(*this_proc, start=start, stop=stop, **kwargs)
if self._out_buffers is None:
max_len = np.max(self.stops - self.starts)
self._out_buffers = [
Expand All @@ -409,9 +421,12 @@ def feed(self, *datas, verbose=None, **kwargs):
else:
next_start = self.stops[-1]
delta = next_start - self.starts[self._idx - 1]
logger.debug(
f" + {self.name}[:] Shifting input and output buffers by "
f"{delta} samples (storing {start}:{stop})"
)
for di in range(len(self._in_buffers)):
self._in_buffers[di] = self._in_buffers[di][..., delta:]
logger.debug(f" - Shifting input/output buffers by {delta:d} samples")
self._store(*[o[..., :delta] for o in self._out_buffers])
for ob in self._out_buffers:
ob[..., :-delta] = ob[..., delta:]
Expand All @@ -430,8 +445,8 @@ def _check_cola(win, nperseg, step, window_name, tol=1e-10):
deviation = np.max(np.abs(binsums - const))
if deviation > tol:
raise ValueError(
f"segment length {nperseg:d} with step {step:d} for {window_name} window "
"type does not provide a constant output "
f"segment length {nperseg} with step {step} for {window_name} "
"window type does not provide a constant output "
f"({100 * deviation / const:g}% deviation)"
)
return const
Expand Down
2 changes: 1 addition & 1 deletion mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4824,7 +4824,7 @@ def average_movements(
del head_pos
_check_usable(epochs, ignore_ref)
origin = _check_origin(origin, epochs.info, "head")
recon_trans = _check_destination(destination, epochs.info, True)
recon_trans = _check_destination(destination, epochs.info, "head")

logger.info(f"Aligning and averaging up to {len(epochs.events)} epochs")
if not np.array_equal(epochs.events[:, 0], np.unique(epochs.events[:, 0])):
Expand Down
12 changes: 2 additions & 10 deletions mne/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1664,24 +1664,16 @@ def _mt_spectrum_remove_win(
n_overlap = (n_samples + 1) // 2
x_out = np.zeros_like(x)
rm_freqs = list()
idx = [0]

# Define how to process a chunk of data
def process(x_):
def process(x_, *, start, stop):
out = _mt_spectrum_remove(
x_, sfreq, line_freqs, notch_widths, window_fun, threshold, get_thresh
)
rm_freqs.append(out[1])
return (out[0],) # must return a tuple

# Define how to store a chunk of fully processed data (it's trivial)
def store(x_):
stop = idx[0] + x_.shape[-1]
x_out[..., idx[0] : stop] += x_
idx[0] = stop

_COLA(process, store, n_times, n_samples, n_overlap, sfreq, verbose=False).feed(x)
assert idx[0] == n_times
_COLA(process, x_out, n_times, n_samples, n_overlap, sfreq, verbose=False).feed(x)
return x_out, rm_freqs


Expand Down
Loading
Loading