@@ -266,6 +266,7 @@ def _make_dpss(
266
266
The wavelets time series.
267
267
"""
268
268
Ws = list ()
269
+ Cs = list ()
269
270
270
271
freqs = np .array (freqs )
271
272
if np .any (freqs <= 0 ):
@@ -281,6 +282,7 @@ def _make_dpss(
281
282
282
283
for m in range (n_taps ):
283
284
Wm = list ()
285
+ Cm = list ()
284
286
for k , f in enumerate (freqs ):
285
287
if len (n_cycles ) != 1 :
286
288
this_n_cycles = n_cycles [k ]
@@ -302,12 +304,15 @@ def _make_dpss(
302
304
real_offset = Wk .mean ()
303
305
Wk -= real_offset
304
306
Wk /= np .sqrt (0.5 ) * np .linalg .norm (Wk .ravel ())
307
+ Ck = np .sqrt (conc [m ])
305
308
306
309
Wm .append (Wk )
310
+ Cm .append (Ck )
307
311
308
312
Ws .append (Wm )
313
+ Cs .append (Cm )
309
314
if return_weights :
310
- return Ws , conc
315
+ return Ws , Cs
311
316
return Ws
312
317
313
318
@@ -529,15 +534,18 @@ def _compute_tfr(
529
534
if method == "morlet" :
530
535
W = morlet (sfreq , freqs , n_cycles = n_cycles , zero_mean = zero_mean )
531
536
Ws = [W ] # to have same dimensionality as the 'multitaper' case
537
+ weights = None # no tapers for Morlet estimates
532
538
533
539
elif method == "multitaper" :
534
- Ws = _make_dpss (
540
+ Ws , weights = _make_dpss (
535
541
sfreq ,
536
542
freqs ,
537
543
n_cycles = n_cycles ,
538
544
time_bandwidth = time_bandwidth ,
539
545
zero_mean = zero_mean ,
546
+ return_weights = True , # required for converting complex → power
540
547
)
548
+ weights = np .asarray (weights )
541
549
542
550
# Check wavelets
543
551
if len (Ws [0 ][0 ]) > epoch_data .shape [2 ]:
@@ -560,7 +568,7 @@ def _compute_tfr(
560
568
if ("avg_" in output ) or ("itc" in output ):
561
569
out = np .empty ((n_chans , n_freqs , n_times ), dtype )
562
570
elif output in ["complex" , "phase" ] and method == "multitaper" :
563
- out = np .empty ((n_chans , n_tapers , n_epochs , n_freqs , n_times ), dtype )
571
+ out = np .empty ((n_chans , n_epochs , n_tapers , n_freqs , n_times ), dtype )
564
572
else :
565
573
out = np .empty ((n_chans , n_epochs , n_freqs , n_times ), dtype )
566
574
@@ -571,7 +579,7 @@ def _compute_tfr(
571
579
572
580
# Parallelization is applied across channels.
573
581
tfrs = parallel (
574
- my_cwt (channel , Ws , output , use_fft , "same" , decim , method )
582
+ my_cwt (channel , Ws , output , use_fft , "same" , decim , weights )
575
583
for channel in epoch_data .transpose (1 , 0 , 2 )
576
584
)
577
585
@@ -581,10 +589,8 @@ def _compute_tfr(
581
589
582
590
if ("avg_" not in output ) and ("itc" not in output ):
583
591
# This is to enforce that the first dimension is for epochs
584
- if output in ["complex" , "phase" ] and method == "multitaper" :
585
- out = out .transpose (2 , 0 , 1 , 3 , 4 )
586
- else :
587
- out = out .transpose (1 , 0 , 2 , 3 )
592
+ out = np .moveaxis (out , 1 , 0 )
593
+
588
594
return out
589
595
590
596
@@ -658,7 +664,7 @@ def _check_tfr_param(
658
664
return freqs , sfreq , zero_mean , n_cycles , time_bandwidth , decim
659
665
660
666
661
- def _time_frequency_loop (X , Ws , output , use_fft , mode , decim , method = None ):
667
+ def _time_frequency_loop (X , Ws , output , use_fft , mode , decim , weights = None ):
662
668
"""Aux. function to _compute_tfr.
663
669
664
670
Loops time-frequency transform across wavelets and epochs.
@@ -685,9 +691,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
685
691
See numpy.convolve.
686
692
decim : slice
687
693
The decimation slice: e.g. power[:, decim]
688
- method : str | None
689
- Used only for multitapering to create tapers dimension in the output
690
- if ``output in ['complex', 'phase']``.
694
+ weights : array, shape (n_tapers, n_wavelets) | None
695
+ Concentration weights for each taper in the wavelets, if present.
691
696
"""
692
697
# Set output type
693
698
dtype = np .float64
@@ -701,10 +706,12 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
701
706
n_freqs = len (Ws [0 ])
702
707
if ("avg_" in output ) or ("itc" in output ):
703
708
tfrs = np .zeros ((n_freqs , n_times ), dtype = dtype )
704
- elif output in ["complex" , "phase" ] and method == "multitaper" :
705
- tfrs = np .zeros ((n_tapers , n_epochs , n_freqs , n_times ), dtype = dtype )
709
+ elif output in ["complex" , "phase" ] and weights is not None :
710
+ tfrs = np .zeros ((n_epochs , n_tapers , n_freqs , n_times ), dtype = dtype )
706
711
else :
707
712
tfrs = np .zeros ((n_epochs , n_freqs , n_times ), dtype = dtype )
713
+ if weights is not None :
714
+ weights = np .expand_dims (weights , axis = - 1 ) # add singleton time dimension
708
715
709
716
# Loops across tapers.
710
717
for taper_idx , W in enumerate (Ws ):
@@ -719,6 +726,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
719
726
# Loop across epochs
720
727
for epoch_idx , tfr in enumerate (coefs ):
721
728
# Transform complex values
729
+ if output not in ["complex" , "phase" ] and weights is not None :
730
+ tfr = weights [taper_idx ] * tfr # weight each taper estimate
722
731
if output in ["power" , "avg_power" ]:
723
732
tfr = (tfr * tfr .conj ()).real # power
724
733
elif output == "phase" :
@@ -734,8 +743,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
734
743
# Stack or add
735
744
if ("avg_" in output ) or ("itc" in output ):
736
745
tfrs += tfr
737
- elif output in ["complex" , "phase" ] and method == "multitaper" :
738
- tfrs [taper_idx , epoch_idx ] += tfr
746
+ elif output in ["complex" , "phase" ] and weights is not None :
747
+ tfrs [epoch_idx , taper_idx ] += tfr
739
748
else :
740
749
tfrs [epoch_idx ] += tfr
741
750
@@ -749,9 +758,14 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
749
758
if ("avg_" in output ) or ("itc" in output ):
750
759
tfrs /= n_epochs
751
760
752
- # Normalization by number of taper
753
- if n_tapers > 1 and output not in ["complex" , "phase" ]:
754
- tfrs /= n_tapers
761
+ # Normalization by taper weights
762
+ if n_tapers > 1 and output not in ["complex" , "phase" , "itc" ]:
763
+ if "avg_" not in output : # add singleton epochs dimension to weights
764
+ weights = np .expand_dims (weights , axis = 0 )
765
+ tfrs .real *= 2 / (weights * weights .conj ()).real .sum (axis = - 3 )
766
+ if output == "avg_power_itc" : # weight itc by the number of tapers
767
+ tfrs .imag = tfrs .imag / n_tapers
768
+
755
769
return tfrs
756
770
757
771
0 commit comments