Skip to content

Commit f56a02c

Browse files
authored
Merge pull request #94 from kthyng/more_filters
fixes and improvements for tidal filter
2 parents 1805869 + a142637 commit f56a02c

File tree

1 file changed

+40
-11
lines changed

1 file changed

+40
-11
lines changed

oceans/filters.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -415,10 +415,7 @@ def medfilt1(x, L=3):
415415
>>> L = 103
416416
>>> xout = medfilt1(x=x, L=L)
417417
>>> ax = plt.subplot(212)
418-
>>> (
419-
... l1,
420-
... l2,
421-
... ) = ax.plot(
418+
>>> (l1, l2,) = ax.plot(
422419
... x
423420
... ), ax.plot(xout)
424421
>>> ax.grid(True)
@@ -570,7 +567,7 @@ def md_trenberth(x):
570567
return y
571568

572569

573-
def pl33tn(x, dt=1.0, T=33.0, mode="valid"):
570+
def pl33tn(x, dt=1.0, T=33.0, mode="valid", t=None):
574571
"""
575572
Computes low-passed series from `x` using pl33 filter, with optional
576573
sample interval `dt` (hours) and filter half-amplitude period T (hours)
@@ -608,14 +605,25 @@ def pl33tn(x, dt=1.0, T=33.0, mode="valid"):
608605
"""
609606

610607
import cf_xarray # noqa: F401
608+
import pandas as pd
611609
import xarray as xr
612610

613-
if isinstance(x, xr.Dataset):
614-
raise TypeError("Input a DataArray not a Dataset.")
611+
if isinstance(x, (xr.Dataset, pd.DataFrame)):
612+
raise TypeError("Input a DataArray not a Dataset, or a Series not a DataFrame.")
615613

614+
if isinstance(x, pd.Series) and not isinstance(
615+
x.index,
616+
pd.core.indexes.datetimes.DatetimeIndex,
617+
):
618+
raise TypeError("Input Series needs to have parsed datetime indices.")
619+
620+
# find dt in units of hours
616621
if isinstance(x, xr.DataArray):
617-
# find dt in units of hours
618-
dt = (x.cf["T"][1] - x.cf["T"][0]) * 1e-9 / 3600
622+
dt = (x.cf["T"][1] - x.cf["T"][0]) / np.timedelta64(
623+
360_000_000_000,
624+
)
625+
elif isinstance(x, pd.Series):
626+
dt = (x.index[1] - x.index[0]) / pd.Timedelta("1H")
619627

620628
pl33 = np.array(
621629
[
@@ -694,18 +702,20 @@ def pl33tn(x, dt=1.0, T=33.0, mode="valid"):
694702
dt = float(dt) * (33.0 / T)
695703

696704
filter_time = np.arange(0.0, 33.0, dt, dtype="d")
697-
# N = len(filter_time)
705+
Nt = len(filter_time)
698706
filter_time = np.hstack((-filter_time[-1:0:-1], filter_time))
699707

700708
pl33 = np.interp(filter_time, _dt, pl33)
701709
pl33 /= pl33.sum()
702710

703711
if isinstance(x, xr.DataArray):
712+
x = x.interpolate_na(dim=x.cf["T"].name)
713+
704714
weight = xr.DataArray(pl33, dims=["window"])
705715
xf = (
706716
x.rolling({x.cf["T"].name: len(pl33)}, center=True)
707717
.construct({x.cf["T"].name: "window"})
708-
.dot(weight)
718+
.dot(weight, dims="window")
709719
)
710720
# update attrs
711721
attrs = {
@@ -715,7 +725,26 @@ def pl33tn(x, dt=1.0, T=33.0, mode="valid"):
715725
}
716726
xf.attrs = attrs
717727

728+
elif isinstance(x, pd.Series):
729+
xf = x.to_frame().apply(np.convolve, v=pl33, mode=mode)
730+
731+
# nan out edges which are not good values anyway
732+
if mode == "same":
733+
xf[: Nt - 1] = np.nan
734+
xf[-Nt + 2 :] = np.nan
735+
718736
else: # use numpy
719737
xf = np.convolve(x, pl33, mode=mode)
720738

739+
# times to match xf
740+
if t is not None:
741+
# Nt = len(filter_time)
742+
tf = t[Nt - 1 : -Nt + 1]
743+
return xf, tf
744+
745+
# nan out edges which are not good values anyway
746+
if mode == "same":
747+
xf[: Nt - 1] = np.nan
748+
xf[-Nt + 2 :] = np.nan
749+
721750
return xf

0 commit comments

Comments
 (0)