Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Redesign/datasets ICA addition #56

Open
wants to merge 17 commits into
base: redesign/datasets
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions benchmarks/MOABB/dataio/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

from torch.utils.data import Dataset

from .ica import ICAProcessor


class RawEEGSample(TypedDict, total=False):
"""Default dictionary keys provided by `~RawEEGDataset`.
Expand Down Expand Up @@ -94,10 +96,12 @@ def __init__(
data,
preload=False,
verbose=None,
ica_processor: Optional[ICAProcessor] = None,
dynamic_items=(),
output_keys=(),
):
self.verbose = verbose
self.ica_processor = ica_processor
dynamic_items = [self._make_load_raw_dynamic_item(preload)] + list(
dynamic_items
)
Expand Down Expand Up @@ -297,6 +301,9 @@ def _make_load_raw_dynamic_item(self, preload: bool):
@provides("info", "raw")
def _load_raw(fpath: str):
raw = self._read_raw_bids_cached(fpath, preload)

if self.ica_processor is not None:
raw = self.ica_processor.process(raw, fpath)

yield raw.info
yield raw
Expand Down
93 changes: 93 additions & 0 deletions benchmarks/MOABB/dataio/ica.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from pathlib import Path
from typing import Union, Optional, Dict, Any

import mne
from mne.preprocessing import ICA
from mne_bids import get_bids_path_from_fname, BIDSPath


class ICAProcessor:
"""Handles ICA computation and application for EEG data.

Arguments
---------
n_components : int | float | None
Number of components to keep during ICA decomposition
method : str
The ICA method to use. Can be 'fastica', 'infomax' or 'picard'.
Defaults to 'fastica'.
random_state : int | None
Random state for reproducibility
fit_params : dict | None
Additional parameters to pass to the ICA fit method.
See mne.preprocessing.ICA for details.
filter_params : dict | None
Parameters for the high-pass filter applied before ICA.
Defaults to {'l_freq': 1.0, 'h_freq': None}
"""

def __init__(
self,
n_components=None,
method='fastica',
random_state=42,
fit_params: Optional[Dict[str, Any]] = None,
filter_params: Optional[Dict[str, Any]] = None,
):
self.n_components = n_components
self.method = method
self.random_state = random_state
self.fit_params = fit_params or {}
self.filter_params = filter_params or {'l_freq': 1.0, 'h_freq': None}

def get_ica_path(self, raw_path: Union[str, Path]) -> Path:
"""Generate path where ICA solution should be stored.

Creates a derivatives folder to store ICA solutions, following BIDS conventions.
"""
bids_path = get_bids_path_from_fname(raw_path)
# For derivatives, you can put them in a derivatives folder:
bids_path.root = (bids_path.root / ".." / "derivatives" / f"ica-{self.method}")
# Keep the same base entities:
bids_path.update(
suffix='eeg', # override or confirm suffix
extension='.fif',
description='ica', # <-- This sets a desc=ica entity
check=True, # If you do not want BIDSPath to fail on derivative checks
)
# Make sure the folder is created
bids_path.fpath.parent.mkdir(parents=True, exist_ok=True)

return bids_path.fpath

def compute_ica(self, raw: mne.io.RawArray, ica_path: Path) -> ICA:
"""Compute ICA solution and save to disk."""
# High-pass filter for ICA
raw_filtered = raw.copy()
raw_filtered.filter(**self.filter_params)

ica = ICA(
n_components=self.n_components,
method=self.method,
random_state=self.random_state,
**self.fit_params
)
ica.fit(raw_filtered)
ica.save(ica_path)
return ica

def process(self, raw: mne.io.RawArray, raw_path: Union[str, Path]) -> mne.io.RawArray:
"""Process raw data with ICA, computing or loading from cache."""

ica_path = self.get_ica_path(raw_path)

if not ica_path.exists():
ica = self.compute_ica(raw, ica_path)
else:
ica = mne.preprocessing.read_ica(ica_path, verbose='ERROR')

# Create a copy of the raw data before applying ICA
raw_ica = raw.copy()
ica.apply(raw_ica)

return raw_ica
145 changes: 145 additions & 0 deletions benchmarks/MOABB/validate_ica.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import logging
import os
from pathlib import Path
import time
import mne
import moabb
from moabb.datasets import BNCI2014_001
from memory_profiler import profile

from dataio.datasets import EpochedEEGDataset, RawEEGDataset, InMemoryDataset
from dataio.ica import ICAProcessor

# Set up logging
mne.set_log_level(verbose=False)
moabb.set_log_level(level="ERROR")

def test_ica_method(method: str, n_components: int = 15, **kwargs):
"""Test a specific ICA method and return timing results."""
print(f"\nTesting ICA method: {method}")
ica_processor = ICAProcessor(
n_components=n_components,
method=method,
**kwargs
)

dataset = EpochedEEGDataset.from_moabb(
BNCI2014_001(),
f"data/MNE-BIDS-bnci2014-001-epoched-{method}.json",
save_path="data",
tmin=0,
tmax=4.0,
preload=True,
output_keys=["label", "subject", "session", "epoch"],
ica_processor=ica_processor
)

# First run - ICA computation
print("First run (computing ICA):")
start = time.time()
for _ in dataset:
pass
computation_time = time.time() - start
print(f"Time with {method} ICA (first run): {computation_time:.2f}s")

# Second run - using cached ICA
print("\nSecond run (using cached ICA):")
start = time.time()
for _ in dataset:
pass
cached_time = time.time() - start
print(f"Time with {method} ICA (cached): {cached_time:.2f}s")

# Memory-cached version
print("\nTesting with InMemoryDataset wrapper:")
dataset_cached = InMemoryDataset(dataset)
start = time.time()
for _ in dataset_cached:
pass
memory_cached_time = time.time() - start
print(f"Time with {method} ICA (in-memory cache): {memory_cached_time:.2f}s")

return {
'method': method,
'computation_time': computation_time,
'cached_time': cached_time,
'memory_cached_time': memory_cached_time
}

def compare_ica_methods():
# Test without ICA first as baseline
print("\nTesting without ICA (baseline):")
dataset_no_ica = EpochedEEGDataset.from_moabb(
BNCI2014_001(),
"data/MNE-BIDS-bnci2014-001-epoched.json",
save_path="data",
tmin=0,
tmax=4.0,
output_keys=["label", "subject", "session", "epoch"],
)

start = time.time()
for _ in dataset_no_ica:
pass
baseline_time = time.time() - start
print(f"Time without ICA: {baseline_time:.2f}s")

# Test different ICA methods
results = []

# Test Picard
results.append(test_ica_method(
'picard',
n_components=15,
fit_params={'max_iter': 500}
))

# Test Infomax
results.append(test_ica_method(
'infomax',
n_components=15,
fit_params={'max_iter': 1000}
))

# Print comparison
print("\nComparison Summary:")
print("-" * 50)
print(f"Baseline (no ICA): {baseline_time:.2f}s")
print("-" * 50)
for result in results:
print(f"Method: {result['method']}")
print(f" Computation time: {result['computation_time']:.2f}s")
print(f" Cached access time: {result['cached_time']:.2f}s")
print(f" In-memory cached time: {result['memory_cached_time']:.2f}s")
print("-" * 50)

@profile
def profile_memory_usage():
# Profile memory usage for both methods
for method in ['picard', 'infomax']:
print(f"\nProfiling {method} ICA:")
ica_processor = ICAProcessor(
n_components=15,
method=method,
fit_params={'max_iter': 500} if method == 'picard' else {'iteration': 1000}
)
dataset = EpochedEEGDataset.from_moabb(
BNCI2014_001(),
f"data/MNE-BIDS-bnci2014-001-epoched-{method}.json",
save_path="data",
tmin=0,
tmax=4.0,
preload=True,
output_keys=["label", "subject", "session", "epoch"],
ica_processor=ica_processor
)

for _ in dataset:
pass

if __name__ == "__main__":
print("Running ICA method comparison...")
compare_ica_methods()

print("\nRunning memory profile...")
profile_memory_usage()