@@ -545,20 +545,18 @@ def _compute_tfr(
545
545
if method == "morlet" :
546
546
W = morlet (sfreq , freqs , n_cycles = n_cycles , zero_mean = zero_mean )
547
547
Ws = [W ] # to have same dimensionality as the 'multitaper' case
548
+ weights = None # no tapers for Morlet estimates
548
549
549
550
elif method == "multitaper" :
550
- out = _make_dpss (
551
+ Ws , weights = _make_dpss (
551
552
sfreq ,
552
553
freqs ,
553
554
n_cycles = n_cycles ,
554
555
time_bandwidth = time_bandwidth ,
555
556
zero_mean = zero_mean ,
556
- return_weights = return_weights ,
557
+ return_weights = True , # required for converting complex → power
557
558
)
558
- if return_weights :
559
- Ws , weights = out
560
- else :
561
- Ws = out
559
+ weights = np .asarray (weights )
562
560
563
561
# Check wavelets
564
562
if len (Ws [0 ][0 ]) > epoch_data .shape [2 ]:
@@ -582,8 +580,6 @@ def _compute_tfr(
582
580
out = np .empty ((n_chans , n_freqs , n_times ), dtype )
583
581
elif output in ["complex" , "phase" ] and method == "multitaper" :
584
582
out = np .empty ((n_chans , n_tapers , n_epochs , n_freqs , n_times ), dtype )
585
- if return_weights :
586
- weights = np .array (weights )
587
583
else :
588
584
out = np .empty ((n_chans , n_epochs , n_freqs , n_times ), dtype )
589
585
@@ -594,7 +590,7 @@ def _compute_tfr(
594
590
595
591
# Parallelization is applied across channels.
596
592
tfrs = parallel (
597
- my_cwt (channel , Ws , output , use_fft , "same" , decim , method )
593
+ my_cwt (channel , Ws , output , use_fft , "same" , decim , weights )
598
594
for channel in epoch_data .transpose (1 , 0 , 2 )
599
595
)
600
596
@@ -683,7 +679,7 @@ def _check_tfr_param(
683
679
return freqs , sfreq , zero_mean , n_cycles , time_bandwidth , decim
684
680
685
681
686
- def _time_frequency_loop (X , Ws , output , use_fft , mode , decim , method = None ):
682
+ def _time_frequency_loop (X , Ws , output , use_fft , mode , decim , weights = None ):
687
683
"""Aux. function to _compute_tfr.
688
684
689
685
Loops time-frequency transform across wavelets and epochs.
@@ -710,9 +706,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
710
706
See numpy.convolve.
711
707
decim : slice
712
708
The decimation slice: e.g. power[:, decim]
713
- method : str | None
714
- Used only for multitapering to create tapers dimension in the output
715
- if ``output in ['complex', 'phase']``.
709
+ weights : array, shape (n_tapers, n_wavelets) | None
710
+ Concentration weights for each taper in the wavelets, if present.
716
711
"""
717
712
# Set output type
718
713
dtype = np .float64
@@ -726,10 +721,12 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
726
721
n_freqs = len (Ws [0 ])
727
722
if ("avg_" in output ) or ("itc" in output ):
728
723
tfrs = np .zeros ((n_freqs , n_times ), dtype = dtype )
729
- elif output in ["complex" , "phase" ] and method == "multitaper" :
724
+ elif output in ["complex" , "phase" ] and weights is not None :
730
725
tfrs = np .zeros ((n_tapers , n_epochs , n_freqs , n_times ), dtype = dtype )
731
726
else :
732
727
tfrs = np .zeros ((n_epochs , n_freqs , n_times ), dtype = dtype )
728
+ if weights is not None :
729
+ weights = np .expand_dims (weights , axis = - 1 ) # add singleton time dimension
733
730
734
731
# Loops across tapers.
735
732
for taper_idx , W in enumerate (Ws ):
@@ -744,6 +741,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
744
741
# Loop across epochs
745
742
for epoch_idx , tfr in enumerate (coefs ):
746
743
# Transform complex values
744
+ if output not in ["complex" , "phase" ] and weights is not None :
745
+ tfr = weights [taper_idx ] * tfr # weight each taper estimate
747
746
if output in ["power" , "avg_power" ]:
748
747
tfr = (tfr * tfr .conj ()).real # power
749
748
elif output == "phase" :
@@ -759,7 +758,7 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
759
758
# Stack or add
760
759
if ("avg_" in output ) or ("itc" in output ):
761
760
tfrs += tfr
762
- elif output in ["complex" , "phase" ] and method == "multitaper" :
761
+ elif output in ["complex" , "phase" ] and weights is not None :
763
762
tfrs [taper_idx , epoch_idx ] += tfr
764
763
else :
765
764
tfrs [epoch_idx ] += tfr
@@ -774,9 +773,14 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
774
773
if ("avg_" in output ) or ("itc" in output ):
775
774
tfrs /= n_epochs
776
775
777
- # Normalization by number of taper
778
- if n_tapers > 1 and output not in ["complex" , "phase" ]:
779
- tfrs /= n_tapers
776
+ # Normalization by taper weights
777
+ if n_tapers > 1 and output not in ["complex" , "phase" , "itc" ]:
778
+ if "avg_" not in output : # add singleton epochs dimension to weights
779
+ weights = np .expand_dims (weights , axis = 0 )
780
+ tfrs .real *= 2 / (weights * weights .conj ()).real .sum (axis = - 3 )
781
+ if output == "avg_power_itc" : # weight itc by the number of tapers
782
+ tfrs .imag = tfrs .imag / n_tapers
783
+
780
784
return tfrs
781
785
782
786
0 commit comments