Skip to content

Commit

Permalink
Merge pull request #156 from gauteh/sel-seltime
Browse files Browse the repository at this point in the history
traj: change seltime to sel, more generic and works on all coordinates.
  • Loading branch information
gauteh authored Dec 12, 2024
2 parents 485458a + 9d8edc9 commit e95fa4c
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 18 deletions.
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ Methods
Dataset.traj.convex_hull_contains_point
Dataset.traj.get_area_convex_hull
Dataset.traj.gridtime
Dataset.traj.sel
Dataset.traj.seltime
Dataset.traj.iseltime
Dataset.traj.crop
Expand Down
4 changes: 3 additions & 1 deletion tests/test_seltime.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ def test_seltime(barents):
print(barents.time.max(skipna=True))

assert barents.time.min(skipna=True) < pd.to_datetime('2022-10-20')
assert barents.time.max(
skipna=True) > pd.to_datetime('2022-11-01T23:59:59')

ds = barents.traj.seltime('2022-10-20', '2022-11-01')
print(ds.time.min(skipna=True))
print(ds.time.max(skipna=True))

assert ds.time.min(skipna=True) >= pd.to_datetime('2022-10-20')
assert ds.time.max(skipna=True) <= pd.to_datetime('2022-11-01')
assert ds.time.max(skipna=True) <= pd.to_datetime('2022-11-01T23:59:59')


def test_iseltime(barents):
Expand Down
24 changes: 24 additions & 0 deletions trajan/traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def detect_tx_variable(ds):
else:
raise ValueError("Could not determine x / lon variable")

def ensure_time_dim(ds, time_dim):
if not time_dim in ds.dims:
return ds.expand_dims(time_dim)
else:
return ds

class Traj:
ds: xr.Dataset
Expand Down Expand Up @@ -694,6 +699,25 @@ def gridtime(self, times, time_varname=None) -> xr.Dataset:
A new dataset interpolated to the target times. The dataset will be 1D (i.e. gridded) and the time dimension will be named `time`.
"""

@abstractmethod
def sel(self, *args, **kwargs) -> xr.Dataset:
"""Select on each trajectory. On 1D datasets this is just a shortcut for `Dataset.sel`.
Parameters
----------
Anything accepted by `Dataset.sel`.
Returns
-------
ds : Dataset
A dataset with the selected range in each trajectory.
See also
--------
iseltime, sel, isel
"""

@abstractmethod
def seltime(self, t0=None, t1=None) -> xr.Dataset:
"""Select observations in time window between `t0` and `t1` (inclusive). For 1D datasets prefer to use `xarray.Dataset.sel`.
Expand Down
5 changes: 4 additions & 1 deletion trajan/traj1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,11 @@ def skill(self, other, method='liu-weissberg', **kwargs):
coords={self.trajectory_dim: self.ds.trajectory},
attrs={'method': method})

def sel(self, *args, **kwargs):
return self.ds.sel(*args, **kwargs)

def seltime(self, t0=None, t1=None):
return self.ds.sel({self.time_varname: slice(t0, t1)})
self.ds.sel({self.time_varname: slice(t0, t1)})

def iseltime(self, i):
return self.ds.isel({self.time_varname: i})
Expand Down
25 changes: 9 additions & 16 deletions trajan/traj2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pandas as pd
import logging

from .traj import Traj
from .traj import Traj, ensure_time_dim

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -194,20 +194,13 @@ def condense_obs(self) -> xr.Dataset:

return ds

@__require_obs_dim__
def seltime(self, t0=None, t1=None):
if t0 is None:
t0 = np.nanmin(self.ds[self.time_varname].values.ravel())
if t1 is None:
t1 = np.nanmax(self.ds[self.time_varname].values.ravel())
def sel(self, *args, **kwargs):
return self.ds.groupby(
self.trajectory_dim).map(lambda d: ensure_time_dim(d.traj.to_1d(
).sel(*args, **kwargs), self.time_varname).traj.to_2d(self.obs_dim))

t0 = pd.to_datetime(t0)
t1 = pd.to_datetime(t1)

return self.ds.where(
np.logical_and(self.ds[self.time_varname] >= t0,
self.ds[self.time_varname]
<= t1)).dropna(self.obs_dim, how='all')
def seltime(self, t0=None, t1=None):
return self.sel({self.time_varname: slice(t0, t1)})

@__require_obs_dim__
def iseltime(self, i):
Expand Down Expand Up @@ -239,9 +232,9 @@ def to_1d(self):

ds[self.time_varname] = ds[self.time_varname].squeeze(
self.trajectory_dim)
ds = ds.loc[{self.time_varname : ~pd.isna(ds[self.time_varname])}]
ds = ds.loc[{self.time_varname: ~pd.isna(ds[self.time_varname])}]
_, ui = np.unique(ds[self.time_varname], return_index=True)
ds = ds.isel({self.time_varname : ui})
ds = ds.isel({self.time_varname: ui})

return ds

Expand Down

0 comments on commit e95fa4c

Please sign in to comment.