Skip to content

Commit

Permalink
add framework for TWA head pos as MF destination (#1043)
Browse files Browse the repository at this point in the history
Co-authored-by: Eric Larson <[email protected]>
  • Loading branch information
drammock and larsoner authored Feb 19, 2025
1 parent fe63ba2 commit 3564641
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 10 deletions.
5 changes: 4 additions & 1 deletion mne_bids_pipeline/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@
```
"""

mf_destination: Literal["reference_run"] | FloatArrayLike = "reference_run"
mf_destination: Literal["reference_run", "twa"] | FloatArrayLike = "reference_run"
"""
Despite all possible care to avoid movements in the MEG, the participant
will likely slowly drift down from the Dewar or slightly shift the head
Expand All @@ -666,6 +666,9 @@
from mne.transforms import translation
mf_destination = translation(z=0.04)
```
3. Compute the time-weighted average head position across all runs in a session,
and use that as the destination coordinates for each run. This will result in a
device-to-head transformation that differs between sessions within each subject.
"""

mf_int_order: int = 8
Expand Down
12 changes: 12 additions & 0 deletions mne_bids_pipeline/_config_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,23 @@ def _check_config(config: SimpleNamespace, config_path: PathLike | None) -> None
"head_pos",
"extended_proj",
)
# check `mf_extra_kws` for things that shouldn't be in there
if duplicates := (set(config.mf_extra_kws) & set(mf_reserved_kwargs)):
raise ConfigError(
f"`mf_extra_kws` contains keys {', '.join(sorted(duplicates))} that are "
"handled by dedicated config keys. Please remove them from `mf_extra_kws`."
)
# if `destination="twa"` make sure `mf_mc=True`
if (
isinstance(config.mf_destination, str)
and config.mf_destination == "twa"
and not config.mf_mc
):
raise ConfigError(
"cannot compute time-weighted average head position (mf_destination='twa') "
"without movement compensation. Please set `mf_mc=True` in your config."
)

reject = config.reject
ica_reject = config.ica_reject
if config.spatial_filter == "ica":
Expand Down
137 changes: 136 additions & 1 deletion mne_bids_pipeline/steps/preprocessing/_02_head_pos.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
from types import SimpleNamespace

import mne
from mne_bids import BIDSPath, find_matching_paths

from mne_bids_pipeline._config_utils import get_runs_tasks, get_subjects_sessions
from mne_bids_pipeline._import_data import (
_get_bids_path_in,
_get_run_rest_noise_path,
_import_data_kwargs,
_path_dict,
import_experimental_data,
)
from mne_bids_pipeline._logging import gen_log_kwargs, logger
Expand Down Expand Up @@ -140,6 +143,121 @@ def run_head_pos(
return _prep_out_files(exec_params=exec_params, out_files=out_files)


def get_input_fnames_twa_head_pos(
*,
cfg: SimpleNamespace,
subject: str,
session: str | None,
run: str | None,
task: str | None,
) -> dict[str, BIDSPath | list[BIDSPath]]:
"""Get paths of files required by compute_twa_head_pos function."""
in_files: dict[str, BIDSPath] = dict()
# can't use `_get_run_path()` here because we don't loop over runs/tasks.
# But any run will do, as long as the file exists:
run, _ = get_runs_tasks(
config=cfg, subject=subject, session=session, which=("runs",)
)[0]
bids_path_in = _get_bids_path_in(
cfg=cfg,
subject=subject,
session=session,
run=run,
task=task,
kind="orig",
)
in_files[f"raw_task-{task}"] = _path_dict(
cfg=cfg,
subject=subject,
session=session,
bids_path_in=bids_path_in,
add_bads=False,
allow_missing=False,
kind="orig",
)[f"raw_task-{task}_run-{run}"]
# ideally we'd do the path-finding for `all_runs_raw_bidspaths` and
# `all_runs_headpos_bidspaths` here, but we can't because MBP is strict about only
# returning paths, not lists of paths :(
return in_files


@failsafe_run(
get_input_fnames=get_input_fnames_twa_head_pos,
)
def compute_twa_head_pos(
*,
cfg: SimpleNamespace,
exec_params: SimpleNamespace,
subject: str,
session: str | list[str] | None,
run: str | None,
task: str | None,
in_files: InFilesT,
) -> OutFilesT:
"""Compute time-weighted average head position."""
# logging
want_mc = cfg.mf_mc
dest_is_twa = isinstance(cfg.mf_destination, str) and cfg.mf_destination == "twa"
msg = "Skipping computation of time-weighted average head position"
if not want_mc:
msg += " (no movement compensation requested)"
kwargs = dict(emoji="skip")
elif not dest_is_twa:
msg += ' (mf_destination is not "twa")'
kwargs = dict(emoji="skip")
else:
msg = "Computing time-weighted average head position"
kwargs = dict()
logger.info(**gen_log_kwargs(message=msg, **kwargs))
# maybe bail early
if not want_mc and not dest_is_twa:
return _prep_out_files(exec_params=exec_params, out_files=dict())

# path to (subject+session)-level `destination.fif` in derivatives folder
bids_path_in = in_files.pop(f"raw_task-{task}")
dest_path = bids_path_in.copy().update(
check=False,
description="twa",
extension=".fif",
root=cfg.deriv_root,
run=None,
suffix="destination",
)
# need raw files from all runs
all_runs_raw_bidspaths = find_matching_paths(
root=cfg.bids_root,
subjects=subject,
sessions=session,
tasks=task,
suffixes="meg",
ignore_json=True,
ignore_nosub=True,
check=True,
)
raw_fnames = [bp.fpath for bp in all_runs_raw_bidspaths]
raws = [
mne.io.read_raw_fif(fname, allow_maxshield=True, verbose="ERROR", preload=False)
for fname in raw_fnames
]
# also need headpos files from all runs
all_runs_headpos_bidspaths = find_matching_paths(
root=cfg.deriv_root,
subjects=subject,
sessions=session,
tasks=task,
suffixes="headpos",
extensions=".txt",
check=False,
)
head_poses = [mne.chpi.read_head_pos(bp.fpath) for bp in all_runs_headpos_bidspaths]
# compute time-weighted average head position and save it to disk
destination = mne.preprocessing.compute_average_dev_head_t(raws, head_poses)
mne.write_trans(fname=dest_path.fpath, trans=destination, overwrite=True)
# output
out_files = dict(destination_head_pos=dest_path)
return _prep_out_files(exec_params=exec_params, out_files=out_files)


def get_config(
*,
config: SimpleNamespace,
Expand Down Expand Up @@ -183,5 +301,22 @@ def main(*, config: SimpleNamespace) -> None:
which=("runs", "rest"),
)
)
# compute time-weighted average head position
# within subject+session+task, across runs
parallel, run_func = parallel_func(
compute_twa_head_pos, exec_params=config.exec_params
)
more_logs = parallel(
run_func(
cfg=config,
exec_params=config.exec_params,
subject=subject,
session=session,
task=config.task or None, # default task is ""
run=None,
)
for subject, sessions in get_subjects_sessions(config).items()
for session in sessions
)

save_logs(config=config, logs=logs)
save_logs(config=config, logs=logs + more_logs)
31 changes: 23 additions & 8 deletions mne_bids_pipeline/steps/preprocessing/_03_maxfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,24 @@ def get_input_fnames_maxwell_filter(
subject=subject,
session=session,
)[f"raw_task-{pos_task}_run-{pos_run}"]
in_files[f"{in_key}-pos"] = path.update(
in_files[f"{in_key}-pos"] = path.copy().update(
suffix="headpos",
extension=".txt",
root=cfg.deriv_root,
check=False,
task=pos_task,
run=pos_run,
)
if isinstance(cfg.mf_destination, str) and cfg.mf_destination == "twa":
in_files[f"{in_key}-twa"] = path.update(
description="twa",
suffix="destination",
extension=".fif",
root=cfg.deriv_root,
check=False,
task=pos_task,
run=None,
)

if cfg.mf_esss:
in_files["esss_basis"] = (
Expand Down Expand Up @@ -299,7 +309,7 @@ def run_maxwell_filter(
)
if isinstance(cfg.mf_destination, str):
destination = cfg.mf_destination
assert destination == "reference_run"
assert destination in ("reference_run", "twa")
else:
destination_array = np.array(cfg.mf_destination, float)
assert destination_array.shape == (4, 4)
Expand Down Expand Up @@ -340,9 +350,17 @@ def run_maxwell_filter(
verbose=cfg.read_raw_bids_verbose,
)
bids_path_ref_bads_in = in_files.pop("raw_ref_run-bads", None)
# load head pos
if cfg.mf_mc:
head_pos = mne.chpi.read_head_pos(in_files.pop(f"{in_key}-pos"))
else:
head_pos = None
# triage string-valued destinations
if isinstance(destination, str):
assert destination == "reference_run"
destination = raw.info["dev_head_t"]
if destination == "reference_run":
destination = raw.info["dev_head_t"]
elif destination == "twa":
destination = mne.read_trans(in_files.pop(f"{in_key}-twa"))
del raw
assert isinstance(destination, mne.transforms.Transform), destination

Expand All @@ -354,10 +372,7 @@ def run_maxwell_filter(
else:
apply_msg += "SSS"
if cfg.mf_mc:
extra.append("MC")
head_pos = mne.chpi.read_head_pos(in_files.pop(f"{in_key}-pos"))
else:
head_pos = None
extra.append("MC") # head_pos already loaded above
if cfg.mf_esss:
extra.append("eSSS")
extended_proj = mne.read_proj(in_files.pop("esss_basis"))
Expand Down
1 change: 1 addition & 0 deletions mne_bids_pipeline/tests/configs/config_ds004229.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
z=0.055,
) @ mne.transforms.rotation(x=np.deg2rad(-15))
mf_mc = True
mf_destination = "twa"
mf_st_duration = 10
mf_int_order = 6 # lower for smaller heads
mf_mc_t_step_min = 0.5 # just for speed!
Expand Down
5 changes: 5 additions & 0 deletions mne_bids_pipeline/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ def test_validation(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
_import_config(config_path=config_path)
msg, err = capsys.readouterr()
assert msg == err == "" # no new message
# TWA headpos without movement compensation
bad_text = working_text + "mf_destination = 'twa'\n"
config_path.write_text(bad_text)
with pytest.raises(ConfigError, match="cannot compute time-weighted average head"):
_import_config(config_path=config_path)
# maxfilter extra kwargs
bad_text = working_text + "mf_extra_kws = {'calibration': 'x', 'head_pos': False}\n"
config_path.write_text(bad_text)
Expand Down

0 comments on commit 3564641

Please sign in to comment.