Skip to content

Commit 848b244

Browse files
committed
Add taper weighting
1 parent 087779c commit 848b244

File tree

2 files changed

+34
-25
lines changed

2 files changed

+34
-25
lines changed

mne/time_frequency/tests/test_tfr.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -255,20 +255,25 @@ def test_tfr_morlet():
255255
# computed within the method.
256256
assert_allclose(epochs_amplitude_2.data**2, epochs_power_picks.data)
257257

258-
# test that averaging power across tapers when multitaper with
258+
# test that aggregating power across tapers when multitaper with
259259
# output='complex' gives the same as output='power'
260260
epoch_data = epochs.get_data()
261261
multitaper_power = tfr_array_multitaper(
262262
epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="power"
263263
)
264-
multitaper_complex = tfr_array_multitaper(
265-
epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="complex"
264+
multitaper_complex, weights = tfr_array_multitaper(
265+
epoch_data,
266+
epochs.info["sfreq"],
267+
freqs,
268+
n_cycles,
269+
output="complex",
270+
return_weights=True,
266271
)
267272

268-
taper_dim = 2
269-
power_from_complex = (multitaper_complex * multitaper_complex.conj()).real.mean(
270-
axis=taper_dim
271-
)
273+
weights = np.expand_dims(weights, axis=(0, 1, -1)) # match shape of complex data
274+
tfr = weights * multitaper_complex
275+
tfr = (tfr * tfr.conj()).real.sum(axis=2)
276+
power_from_complex = tfr * (2 / (weights * weights.conj()).real.sum(axis=2))
272277
assert_allclose(power_from_complex, multitaper_power)
273278

274279
print(itc) # test repr

mne/time_frequency/tfr.py

+22-18
Original file line numberDiff line numberDiff line change
@@ -545,20 +545,18 @@ def _compute_tfr(
545545
if method == "morlet":
546546
W = morlet(sfreq, freqs, n_cycles=n_cycles, zero_mean=zero_mean)
547547
Ws = [W] # to have same dimensionality as the 'multitaper' case
548+
weights = None # no tapers for Morlet estimates
548549

549550
elif method == "multitaper":
550-
out = _make_dpss(
551+
Ws, weights = _make_dpss(
551552
sfreq,
552553
freqs,
553554
n_cycles=n_cycles,
554555
time_bandwidth=time_bandwidth,
555556
zero_mean=zero_mean,
556-
return_weights=return_weights,
557+
return_weights=True, # required for converting complex → power
557558
)
558-
if return_weights:
559-
Ws, weights = out
560-
else:
561-
Ws = out
559+
weights = np.asarray(weights)
562560

563561
# Check wavelets
564562
if len(Ws[0][0]) > epoch_data.shape[2]:
@@ -582,8 +580,6 @@ def _compute_tfr(
582580
out = np.empty((n_chans, n_freqs, n_times), dtype)
583581
elif output in ["complex", "phase"] and method == "multitaper":
584582
out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype)
585-
if return_weights:
586-
weights = np.array(weights)
587583
else:
588584
out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype)
589585

@@ -594,7 +590,7 @@ def _compute_tfr(
594590

595591
# Parallelization is applied across channels.
596592
tfrs = parallel(
597-
my_cwt(channel, Ws, output, use_fft, "same", decim, method)
593+
my_cwt(channel, Ws, output, use_fft, "same", decim, weights)
598594
for channel in epoch_data.transpose(1, 0, 2)
599595
)
600596

@@ -683,7 +679,7 @@ def _check_tfr_param(
683679
return freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim
684680

685681

686-
def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
682+
def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, weights=None):
687683
"""Aux. function to _compute_tfr.
688684
689685
Loops time-frequency transform across wavelets and epochs.
@@ -710,9 +706,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
710706
See numpy.convolve.
711707
decim : slice
712708
The decimation slice: e.g. power[:, decim]
713-
method : str | None
714-
Used only for multitapering to create tapers dimension in the output
715-
if ``output in ['complex', 'phase']``.
709+
weights : array, shape (n_tapers, n_wavelets) | None
710+
Concentration weights for each taper in the wavelets, if present.
716711
"""
717712
# Set output type
718713
dtype = np.float64
@@ -726,10 +721,12 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
726721
n_freqs = len(Ws[0])
727722
if ("avg_" in output) or ("itc" in output):
728723
tfrs = np.zeros((n_freqs, n_times), dtype=dtype)
729-
elif output in ["complex", "phase"] and method == "multitaper":
724+
elif output in ["complex", "phase"] and weights is not None:
730725
tfrs = np.zeros((n_tapers, n_epochs, n_freqs, n_times), dtype=dtype)
731726
else:
732727
tfrs = np.zeros((n_epochs, n_freqs, n_times), dtype=dtype)
728+
if weights is not None:
729+
weights = np.expand_dims(weights, axis=-1) # add singleton time dimension
733730

734731
# Loops across tapers.
735732
for taper_idx, W in enumerate(Ws):
@@ -744,6 +741,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
744741
# Loop across epochs
745742
for epoch_idx, tfr in enumerate(coefs):
746743
# Transform complex values
744+
if output not in ["complex", "phase"] and weights is not None:
745+
tfr = weights[taper_idx] * tfr # weight each taper estimate
747746
if output in ["power", "avg_power"]:
748747
tfr = (tfr * tfr.conj()).real # power
749748
elif output == "phase":
@@ -759,7 +758,7 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
759758
# Stack or add
760759
if ("avg_" in output) or ("itc" in output):
761760
tfrs += tfr
762-
elif output in ["complex", "phase"] and method == "multitaper":
761+
elif output in ["complex", "phase"] and weights is not None:
763762
tfrs[taper_idx, epoch_idx] += tfr
764763
else:
765764
tfrs[epoch_idx] += tfr
@@ -774,9 +773,14 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
774773
if ("avg_" in output) or ("itc" in output):
775774
tfrs /= n_epochs
776775

777-
# Normalization by number of taper
778-
if n_tapers > 1 and output not in ["complex", "phase"]:
779-
tfrs /= n_tapers
776+
# Normalization by taper weights
777+
if n_tapers > 1 and output not in ["complex", "phase", "itc"]:
778+
if "avg_" not in output: # add singleton epochs dimension to weights
779+
weights = np.expand_dims(weights, axis=0)
780+
tfrs.real *= 2 / (weights * weights.conj()).real.sum(axis=-3)
781+
if output == "avg_power_itc": # weight itc by the number of tapers
782+
tfrs.imag = tfrs.imag / n_tapers
783+
780784
return tfrs
781785

782786

0 commit comments

Comments
 (0)