@@ -415,10 +415,7 @@ def medfilt1(x, L=3):
415
415
>>> L = 103
416
416
>>> xout = medfilt1(x=x, L=L)
417
417
>>> ax = plt.subplot(212)
418
- >>> (
419
- ... l1,
420
- ... l2,
421
- ... ) = ax.plot(
418
+ >>> (l1, l2,) = ax.plot(
422
419
... x
423
420
... ), ax.plot(xout)
424
421
>>> ax.grid(True)
@@ -570,7 +567,7 @@ def md_trenberth(x):
570
567
return y
571
568
572
569
573
- def pl33tn (x , dt = 1.0 , T = 33.0 , mode = "valid" ):
570
+ def pl33tn (x , dt = 1.0 , T = 33.0 , mode = "valid" , t = None ):
574
571
"""
575
572
Computes low-passed series from `x` using pl33 filter, with optional
576
573
sample interval `dt` (hours) and filter half-amplitude period T (hours)
@@ -608,14 +605,25 @@ def pl33tn(x, dt=1.0, T=33.0, mode="valid"):
608
605
"""
609
606
610
607
import cf_xarray # noqa: F401
608
+ import pandas as pd
611
609
import xarray as xr
612
610
613
- if isinstance (x , xr .Dataset ):
614
- raise TypeError ("Input a DataArray not a Dataset." )
611
+ if isinstance (x , ( xr .Dataset , pd . DataFrame ) ):
612
+ raise TypeError ("Input a DataArray not a Dataset, or a Series not a DataFrame ." )
615
613
614
+ if isinstance (x , pd .Series ) and not isinstance (
615
+ x .index ,
616
+ pd .core .indexes .datetimes .DatetimeIndex ,
617
+ ):
618
+ raise TypeError ("Input Series needs to have parsed datetime indices." )
619
+
620
+ # find dt in units of hours
616
621
if isinstance (x , xr .DataArray ):
617
- # find dt in units of hours
618
- dt = (x .cf ["T" ][1 ] - x .cf ["T" ][0 ]) * 1e-9 / 3600
622
+ dt = (x .cf ["T" ][1 ] - x .cf ["T" ][0 ]) / np .timedelta64 (
623
+ 360_000_000_000 ,
624
+ )
625
+ elif isinstance (x , pd .Series ):
626
+ dt = (x .index [1 ] - x .index [0 ]) / pd .Timedelta ("1H" )
619
627
620
628
pl33 = np .array (
621
629
[
@@ -694,18 +702,20 @@ def pl33tn(x, dt=1.0, T=33.0, mode="valid"):
694
702
dt = float (dt ) * (33.0 / T )
695
703
696
704
filter_time = np .arange (0.0 , 33.0 , dt , dtype = "d" )
697
- # N = len(filter_time)
705
+ Nt = len (filter_time )
698
706
filter_time = np .hstack ((- filter_time [- 1 :0 :- 1 ], filter_time ))
699
707
700
708
pl33 = np .interp (filter_time , _dt , pl33 )
701
709
pl33 /= pl33 .sum ()
702
710
703
711
if isinstance (x , xr .DataArray ):
712
+ x = x .interpolate_na (dim = x .cf ["T" ].name )
713
+
704
714
weight = xr .DataArray (pl33 , dims = ["window" ])
705
715
xf = (
706
716
x .rolling ({x .cf ["T" ].name : len (pl33 )}, center = True )
707
717
.construct ({x .cf ["T" ].name : "window" })
708
- .dot (weight )
718
+ .dot (weight , dims = "window" )
709
719
)
710
720
# update attrs
711
721
attrs = {
@@ -715,7 +725,26 @@ def pl33tn(x, dt=1.0, T=33.0, mode="valid"):
715
725
}
716
726
xf .attrs = attrs
717
727
728
+ elif isinstance (x , pd .Series ):
729
+ xf = x .to_frame ().apply (np .convolve , v = pl33 , mode = mode )
730
+
731
+ # nan out edges which are not good values anyway
732
+ if mode == "same" :
733
+ xf [: Nt - 1 ] = np .nan
734
+ xf [- Nt + 2 :] = np .nan
735
+
718
736
else : # use numpy
719
737
xf = np .convolve (x , pl33 , mode = mode )
720
738
739
+ # times to match xf
740
+ if t is not None :
741
+ # Nt = len(filter_time)
742
+ tf = t [Nt - 1 : - Nt + 1 ]
743
+ return xf , tf
744
+
745
+ # nan out edges which are not good values anyway
746
+ if mode == "same" :
747
+ xf [: Nt - 1 ] = np .nan
748
+ xf [- Nt + 2 :] = np .nan
749
+
721
750
return xf
0 commit comments