Skip to content

Commit d4f57b2

Browse files
tsbinnslarsoner
andcommitted
Backport PR #13067 on branch maint/1.9 ([BUG] Fix taper weighting in computation of TFR multitaper power)
Co-Authored-By: Eric Larson <[email protected]>
1 parent 672bdf4 commit d4f57b2

File tree

3 files changed

+34
-35
lines changed

3 files changed

+34
-35
lines changed

Diff for: doc/changes/devel/13067.bugfix.rst

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix bug where taper weights were not correctly applied when computing multitaper power with :meth:`mne.Epochs.compute_tfr` and :func:`mne.time_frequency.tfr_array_multitaper`, by `Thomas Binns`_.

Diff for: mne/time_frequency/tests/test_tfr.py

-16
Original file line numberDiff line numberDiff line change
@@ -255,22 +255,6 @@ def test_tfr_morlet():
255255
# computed within the method.
256256
assert_allclose(epochs_amplitude_2.data**2, epochs_power_picks.data)
257257

258-
# test that averaging power across tapers when multitaper with
259-
# output='complex' gives the same as output='power'
260-
epoch_data = epochs.get_data()
261-
multitaper_power = tfr_array_multitaper(
262-
epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="power"
263-
)
264-
multitaper_complex = tfr_array_multitaper(
265-
epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="complex"
266-
)
267-
268-
taper_dim = 2
269-
power_from_complex = (multitaper_complex * multitaper_complex.conj()).real.mean(
270-
axis=taper_dim
271-
)
272-
assert_allclose(power_from_complex, multitaper_power)
273-
274258
print(itc) # test repr
275259
print(itc.ch_names) # test property
276260
itc += power # test add

Diff for: mne/time_frequency/tfr.py

+33-19
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ def _make_dpss(
266266
The wavelets time series.
267267
"""
268268
Ws = list()
269+
Cs = list()
269270

270271
freqs = np.array(freqs)
271272
if np.any(freqs <= 0):
@@ -281,6 +282,7 @@ def _make_dpss(
281282

282283
for m in range(n_taps):
283284
Wm = list()
285+
Cm = list()
284286
for k, f in enumerate(freqs):
285287
if len(n_cycles) != 1:
286288
this_n_cycles = n_cycles[k]
@@ -302,12 +304,15 @@ def _make_dpss(
302304
real_offset = Wk.mean()
303305
Wk -= real_offset
304306
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel())
307+
Ck = np.sqrt(conc[m])
305308

306309
Wm.append(Wk)
310+
Cm.append(Ck)
307311

308312
Ws.append(Wm)
313+
Cs.append(Cm)
309314
if return_weights:
310-
return Ws, conc
315+
return Ws, Cs
311316
return Ws
312317

313318

@@ -529,15 +534,18 @@ def _compute_tfr(
529534
if method == "morlet":
530535
W = morlet(sfreq, freqs, n_cycles=n_cycles, zero_mean=zero_mean)
531536
Ws = [W] # to have same dimensionality as the 'multitaper' case
537+
weights = None # no tapers for Morlet estimates
532538

533539
elif method == "multitaper":
534-
Ws = _make_dpss(
540+
Ws, weights = _make_dpss(
535541
sfreq,
536542
freqs,
537543
n_cycles=n_cycles,
538544
time_bandwidth=time_bandwidth,
539545
zero_mean=zero_mean,
546+
return_weights=True, # required for converting complex → power
540547
)
548+
weights = np.asarray(weights)
541549

542550
# Check wavelets
543551
if len(Ws[0][0]) > epoch_data.shape[2]:
@@ -560,7 +568,7 @@ def _compute_tfr(
560568
if ("avg_" in output) or ("itc" in output):
561569
out = np.empty((n_chans, n_freqs, n_times), dtype)
562570
elif output in ["complex", "phase"] and method == "multitaper":
563-
out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype)
571+
out = np.empty((n_chans, n_epochs, n_tapers, n_freqs, n_times), dtype)
564572
else:
565573
out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype)
566574

@@ -571,7 +579,7 @@ def _compute_tfr(
571579

572580
# Parallelization is applied across channels.
573581
tfrs = parallel(
574-
my_cwt(channel, Ws, output, use_fft, "same", decim, method)
582+
my_cwt(channel, Ws, output, use_fft, "same", decim, weights)
575583
for channel in epoch_data.transpose(1, 0, 2)
576584
)
577585

@@ -581,10 +589,8 @@ def _compute_tfr(
581589

582590
if ("avg_" not in output) and ("itc" not in output):
583591
# This is to enforce that the first dimension is for epochs
584-
if output in ["complex", "phase"] and method == "multitaper":
585-
out = out.transpose(2, 0, 1, 3, 4)
586-
else:
587-
out = out.transpose(1, 0, 2, 3)
592+
out = np.moveaxis(out, 1, 0)
593+
588594
return out
589595

590596

@@ -658,7 +664,7 @@ def _check_tfr_param(
658664
return freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim
659665

660666

661-
def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
667+
def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, weights=None):
662668
"""Aux. function to _compute_tfr.
663669
664670
Loops time-frequency transform across wavelets and epochs.
@@ -685,9 +691,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
685691
See numpy.convolve.
686692
decim : slice
687693
The decimation slice: e.g. power[:, decim]
688-
method : str | None
689-
Used only for multitapering to create tapers dimension in the output
690-
if ``output in ['complex', 'phase']``.
694+
weights : array, shape (n_tapers, n_wavelets) | None
695+
Concentration weights for each taper in the wavelets, if present.
691696
"""
692697
# Set output type
693698
dtype = np.float64
@@ -701,10 +706,12 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
701706
n_freqs = len(Ws[0])
702707
if ("avg_" in output) or ("itc" in output):
703708
tfrs = np.zeros((n_freqs, n_times), dtype=dtype)
704-
elif output in ["complex", "phase"] and method == "multitaper":
705-
tfrs = np.zeros((n_tapers, n_epochs, n_freqs, n_times), dtype=dtype)
709+
elif output in ["complex", "phase"] and weights is not None:
710+
tfrs = np.zeros((n_epochs, n_tapers, n_freqs, n_times), dtype=dtype)
706711
else:
707712
tfrs = np.zeros((n_epochs, n_freqs, n_times), dtype=dtype)
713+
if weights is not None:
714+
weights = np.expand_dims(weights, axis=-1) # add singleton time dimension
708715

709716
# Loops across tapers.
710717
for taper_idx, W in enumerate(Ws):
@@ -719,6 +726,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
719726
# Loop across epochs
720727
for epoch_idx, tfr in enumerate(coefs):
721728
# Transform complex values
729+
if output not in ["complex", "phase"] and weights is not None:
730+
tfr = weights[taper_idx] * tfr # weight each taper estimate
722731
if output in ["power", "avg_power"]:
723732
tfr = (tfr * tfr.conj()).real # power
724733
elif output == "phase":
@@ -734,8 +743,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
734743
# Stack or add
735744
if ("avg_" in output) or ("itc" in output):
736745
tfrs += tfr
737-
elif output in ["complex", "phase"] and method == "multitaper":
738-
tfrs[taper_idx, epoch_idx] += tfr
746+
elif output in ["complex", "phase"] and weights is not None:
747+
tfrs[epoch_idx, taper_idx] += tfr
739748
else:
740749
tfrs[epoch_idx] += tfr
741750

@@ -749,9 +758,14 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
749758
if ("avg_" in output) or ("itc" in output):
750759
tfrs /= n_epochs
751760

752-
# Normalization by number of taper
753-
if n_tapers > 1 and output not in ["complex", "phase"]:
754-
tfrs /= n_tapers
761+
# Normalization by taper weights
762+
if n_tapers > 1 and output not in ["complex", "phase", "itc"]:
763+
if "avg_" not in output: # add singleton epochs dimension to weights
764+
weights = np.expand_dims(weights, axis=0)
765+
tfrs.real *= 2 / (weights * weights.conj()).real.sum(axis=-3)
766+
if output == "avg_power_itc": # weight itc by the number of tapers
767+
tfrs.imag = tfrs.imag / n_tapers
768+
755769
return tfrs
756770

757771

0 commit comments

Comments
 (0)