Skip to content

Commit

Permalink
Update examples to use simulation function
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Mar 18, 2024
1 parent fdd532d commit 95cd343
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 170 deletions.
99 changes: 18 additions & 81 deletions examples/cacoh.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@
spatial patterns of the connectivity.
"""

# Authors: Mohammad Orabe <orabe.mhd@gmail.com>
# Thomas S. Binns <t.s.binns@outlook.com>
# Authors: Thomas S. Binns <t.s.binns@outlook.com>
# Mohammad Orabe <orabe.mhd@gmail.com>
# License: BSD (3-clause)

# %%
import numpy as np
from matplotlib import pyplot as plt

import mne
from mne_connectivity import seed_target_indices, spectral_connectivity_epochs
from mne_connectivity import (
make_signals_in_freq_bands,
seed_target_indices,
spectral_connectivity_epochs,
)

###############################################################################
# Background
Expand Down Expand Up @@ -64,95 +67,29 @@
#
# We can consider the seeds and targets to be signals of different modalities,
# e.g. cortical EEG signals and subcortical LFP signals, cortical EEG signals
# and muscular EMG signals, etc.... We use the function below to simulate these
# signals.

# %%


def simulate_connectivity(freq_band: tuple[int, int], rng_seed: int) -> np.ndarray:
"""Simulates signals interacting in a given frequency band.
Parameters
----------
freq_band : tuple of int, int
Frequency band where the connectivity should be simulated, where the
first entry corresponds to the lower frequency, and the second entry to
the higher frequency.
rng_seed : int
Seed to use for the random number generator.
Returns
-------
data : numpy.ndarray
The simulated data stored in an array. The channels are arranged
according to seeds, then targets.
"""
# Define fixed simulation parameters
n_seeds = 5
n_targets = 3
n_epochs = 10
n_times = 200 # samples
sfreq = 100 # Hz
snr = 0.7
trans_bandwidth = 1 # Hz
connection_delay = 1 # sample

np.random.seed(rng_seed)

n_channels = n_seeds + n_targets

# simulate signal source at desired frequency band
signal = np.random.randn(1, n_epochs * n_times + connection_delay)
signal = mne.filter.filter_data(
data=signal,
sfreq=sfreq,
l_freq=freq_band[0],
h_freq=freq_band[1],
l_trans_bandwidth=trans_bandwidth,
h_trans_bandwidth=trans_bandwidth,
fir_design="firwin2",
verbose=False,
)

# simulate noise for each channel
noise = np.random.randn(n_channels, n_epochs * n_times + connection_delay)

# create data by projecting signal into noise
data = (signal * snr) + (noise * (1 - snr))

# shift target data by desired delay
if connection_delay > 0:
# shift target data
data[n_seeds:, connection_delay:] = data[n_seeds:, : n_epochs * n_times]
# remove extra time
data = data[:, : n_epochs * n_times]

# reshape data into epochs
data = data.reshape(n_channels, n_epochs, n_times)
data = data.transpose((1, 0, 2)) # (epochs x channels x times)

return data


###############################################################################
# and muscular EMG signals, etc.... We use the
# :func:`~mne_connectivity.make_signals_in_freq_bands` function to simulate
# these signals.

# %%

# Generate simulated data
data_10_12 = simulate_connectivity(
data_10_12 = make_signals_in_freq_bands(
n_seeds=5,
n_targets=3,
freq_band=(10, 12), # 10-12 Hz interaction
rng_seed=42,
)

data_23_25 = simulate_connectivity(
data_23_25 = make_signals_in_freq_bands(
n_seeds=5,
n_targets=3,
freq_band=(23, 25), # 23-25 Hz interaction
rng_seed=44,
)

# Combine data into a single array
data = np.concatenate((data_10_12, data_23_25), axis=1)
# Combine data into a single object
data = data_10_12.add_channels([data_23_25])

###############################################################################
# Computing CaCoh
Expand Down
114 changes: 25 additions & 89 deletions examples/compare_coherency_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
import numpy as np
from matplotlib import pyplot as plt

import mne
from mne_connectivity import seed_target_indices, spectral_connectivity_epochs
from mne_connectivity import (
make_signals_in_freq_bands,
seed_target_indices,
spectral_connectivity_epochs,
)

###############################################################################
# An introduction to coherency-based connectivity methods
Expand Down Expand Up @@ -86,96 +89,25 @@

# %%


def simulate_connectivity(
freq_band: tuple[int, int], connection_delay: int, rng_seed: int
) -> np.ndarray:
"""Simulates signals interacting in a given frequency band.
Parameters
----------
freq_band : tuple of int, int
Frequency band where the connectivity should be simulated, where the
first entry corresponds to the lower frequency, and the second entry to
the higher frequency.
connection_delay :
Number of timepoints for the delay of connectivity between the seeds
and targets. If > 0, the target data is a delayed form of the seed data
by this many timepoints.
rng_seed : int
Seed to use for the random number generator.
Returns
-------
data : numpy.ndarray
The simulated data stored in an array. The channels are arranged
according to seeds, then targets.
"""
# Define fixed simulation parameters
n_seeds = 3
n_targets = 3
n_epochs = 10
n_times = 200 # samples
sfreq = 100 # Hz
snr = 0.7
trans_bandwidth = 1 # Hz

np.random.seed(rng_seed)

n_channels = n_seeds + n_targets

# simulate signal source at desired frequency band
signal = np.random.randn(1, n_epochs * n_times + connection_delay)
signal = mne.filter.filter_data(
data=signal,
sfreq=sfreq,
l_freq=freq_band[0],
h_freq=freq_band[1],
l_trans_bandwidth=trans_bandwidth,
h_trans_bandwidth=trans_bandwidth,
fir_design="firwin2",
verbose=False,
)

# simulate noise for each channel
noise = np.random.randn(n_channels, n_epochs * n_times + connection_delay)

# create data by projecting signal into noise
data = (signal * snr) + (noise * (1 - snr))

# shift target data by desired delay
if connection_delay > 0:
# shift target data
data[n_seeds:, connection_delay:] = data[n_seeds:, : n_epochs * n_times]
# remove extra time
data = data[:, : n_epochs * n_times]

# reshape data into epochs
data = data.reshape(n_channels, n_epochs, n_times)
data = data.transpose((1, 0, 2)) # (epochs x channels x times)

return data


# %%

# Generate simulated data
data_delay = simulate_connectivity(
data_delay = make_signals_in_freq_bands(
n_seeds=3,
n_targets=3,
freq_band=(10, 12), # 10-12 Hz interaction
connection_delay=2, # samples; non-zero time-lag
rng_seed=42,
)

data_no_delay = simulate_connectivity(
data_no_delay = make_signals_in_freq_bands(
n_seeds=3,
n_targets=3,
freq_band=(23, 25), # 23-25 Hz interaction
connection_delay=0, # samples; zero time-lag
rng_seed=44,
)

# Combine data into a single array
data = np.concatenate((data_delay, data_no_delay), axis=1)
# Combine data into a single object
data = data_delay.add_channels([data_no_delay])

###############################################################################
# We compute the connectivity of these simulated signals using CaCoh (a
Expand Down Expand Up @@ -374,7 +306,7 @@ def plot_connectivity_circle():
#
# **In situations where non-physiological zero time-lag interactions are not
# assumed, methods based on real and imaginary parts of coherency (Cohy, Coh,
# CaCoh) should be used.** Examples of situations include:
# CaCoh) should be used.** An example includes:
#
# - Connectivity between channels of different modalities where different
# references are used.
Expand Down Expand Up @@ -411,20 +343,24 @@ def plot_connectivity_circle():
# %%

# Generate simulated data
data_10_12 = simulate_connectivity(
data_10_12 = make_signals_in_freq_bands(
n_seeds=3,
n_targets=3,
freq_band=(10, 12), # 10-12 Hz interaction
connection_delay=1, # samples
rng_seed=42,
rng_seed=40,
)

data_23_25 = simulate_connectivity(
freq_band=(23, 25), # 10-12 Hz interaction
data_23_25 = make_signals_in_freq_bands(
n_seeds=3,
n_targets=3,
freq_band=(23, 25), # 23-25 Hz interaction
connection_delay=1, # samples
rng_seed=44,
rng_seed=42,
)

# Combine data into a single array
data = np.concatenate((data_10_12, data_23_25), axis=1)
data = data_10_12.add_channels([data_23_25])

# Compute CaCoh & MIC
(cacoh, mic) = spectral_connectivity_epochs(
Expand Down Expand Up @@ -519,7 +455,7 @@ def plot_connectivity_circle():
axis.plot(imcoh.freqs, imcoh_mean_subbed, linewidth=2, label="ImCoh", linestyle="--")
axis.set_xlabel("Frequency (Hz)")
axis.set_ylabel("Mean-corrected connectivity (A.U.)")
axis.annotate("$\pm$45°\ninteraction", xy=(12, 0.25))
axis.annotate("$\pm$45°\ninteraction", xy=(13, 0.25))
axis.annotate("$\pm$90°\ninteraction", xy=(26.5, 0.25))
axis.legend(loc="upper left")
fig.suptitle("Coh vs. ImCoh\n$\pm$45° & $\pm$90° interactions")
Expand Down

0 comments on commit 95cd343

Please sign in to comment.