Skip to content

Commit

Permalink
Merge pull request #142 from knutfrode/dev
Browse files Browse the repository at this point in the history
Making attribute names more precise: obsdim->obs_dimname, timedim->time_varname
  • Loading branch information
knutfrode authored Nov 13, 2024
2 parents dc6f354 + b42c6c4 commit 391f3e2
Show file tree
Hide file tree
Showing 9 changed files with 162 additions and 170 deletions.
2 changes: 1 addition & 1 deletion examples/example_find_positions_at_obs_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
def gridwaves(tds):
t = tds[['lat', 'lon',
'time']].traj.gridtime(tds['time_waves_imu'].squeeze())
return t.traj.to_2d(obsdim='obs_waves_imu')
return t.traj.to_2d(obs_dim='obs_waves_imu')


dsw = ds.groupby('trajectory').map(gridwaves)
Expand Down
6 changes: 5 additions & 1 deletion examples/example_opendrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
ds = ds.where(ds.status>=0) # only active particles

#%%
# Displaying a basic plot of trajectories
# Displaying some basic information about this dataset
print(ds.traj)

#%%
# Making a basic plot of trajectories
ds.traj.plot()
plt.title('Basic trajectory plot')
plt.show()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_read_sfy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ def test_interpret_sfy(test_data):
ds = xr.open_dataset(test_data / 'bug32.nc')
print(ds)

assert ds.traj.obsdim == 'package'
assert ds.traj.timedim == 'position_time'
assert ds.traj.obs_dim == 'package'
assert ds.traj.time_varname == 'position_time'

assert ds.traj.is_2d()

Expand Down
13 changes: 13 additions & 0 deletions tests/test_repr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import xarray as xr
import trajan as _

def test_repr_1d(opendrift_sim):
    """Check the text summary produced by ``str(ds.traj)`` for the 1D case.

    ``opendrift_sim`` is a pytest fixture provided outside this file
    (presumably an OpenDrift simulation dataset -- confirm in conftest).
    The expected substrings cover the start time, the time step and the
    description of the shared time axis.
    """
    # Named ``summary`` so that the builtin ``repr`` is not shadowed.
    summary = str(opendrift_sim.traj)
    assert '2015-11-16T00:00' in summary
    assert 'Timestep: 1:00:00' in summary
    assert "67 timesteps time['time'] (1D)" in summary

def test_repr_2d(test_data):
    """Check the text summary produced by ``str(ds.traj)`` for ``bug32.nc``.

    ``test_data`` is a pytest fixture provided outside this file; it is used
    here as a directory path that supports the ``/`` operator.
    """
    ds = xr.open_dataset(test_data / 'bug32.nc')
    # Named ``summary`` so that the builtin ``repr`` is not shadowed.
    summary = str(ds.traj)
    assert '2023-10-19T15:46:53.514499520' in summary
68 changes: 28 additions & 40 deletions trajan/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,20 @@

logger = logging.getLogger(__name__)

from .traj import Traj
from .traj import Traj, detect_tx_variable
from .traj1d import Traj1d
from .traj2d import Traj2d
from .ragged import ContiguousRagged


def detect_tx_dim(ds):
if 'lon' in ds:
return ds.lon
elif 'longitude' in ds:
return ds.longitude
elif 'x' in ds:
return ds.x
elif 'X' in ds:
return ds.X
else:
raise ValueError("Could not determine x / lon variable")


def detect_time_dim(ds, obsdim):
logger.debug(f'Detecting time-dimension for "{obsdim}"..')
def detect_time_variable(ds, obs_dim):
logger.debug(f'Detecting time-variable for "{obs_dim}"..')
# TODO: should use cf-xarray here
for v in ds.variables:
if obsdim in ds[v].dims and 'time' in v:
if obs_dim in ds[v].dims and 'time' in v:
return v

raise ValueError("no time dimension detected")
raise ValueError("No time variable detected")


@xr.register_dataset_accessor("traj")
Expand All @@ -47,10 +35,10 @@ def __new__(cls, ds):
ds = ds.expand_dims({'trajectory': 1})
ds['trajectory'].attrs['cf_role'] = 'trajectory_id'

obsdim = None
timedim = None
obs_dim = None
time_varname = None

tx = detect_tx_dim(ds)
tx = detect_tx_variable(ds)

# if we have a 1D dims, this is most likely some contiguous data
# there may be a few exceptions though, so be ready to default to the classical 2D parser below
Expand All @@ -62,7 +50,7 @@ def __new__(cls, ds):
# NOTE: this is probably not standard; something to point to the CF conventions?
# NOTE: for now, there is no discovery of the "index" dim, this is hardcoded; any way to do better?
if "index" in tx.dims:
obsdim = "index"
obs_dim = "index"

# discover the timecoord variable name #######################
# find all variables with standard_name "time"
Expand Down Expand Up @@ -90,63 +78,63 @@ def __new__(cls, ds):
else:
raise ValueError(f"cannot deduce rowsizevar; we have the following candidates: {with_dim_trajectory = }")
# sanity check
if not np.sum(ds[rowsizevar].to_numpy()) == len(ds[obsdim]):
if not np.sum(ds[rowsizevar].to_numpy()) == len(ds[obs_dim]):
raise ValueError("mismatch between the index length and the sum of the deduced trajectory lengths")

logger.debug(
f"1D storage dataset; detected: {obsdim = }, {timecoord = }, {trajectorycoord = }, {rowsizevar}"
f"1D storage dataset; detected: {obs_dim = }, {timecoord = }, {trajectorycoord = }, {rowsizevar}"
)

return ocls(ds, obsdim, timecoord, trajectorycoord, rowsizevar)
return ocls(ds, obs_dim, timecoord, trajectorycoord, rowsizevar)

else:
logging.warning(f"{ds} has {tx.dims = } which is of dimension 1 but is not index; this is a bit unusual; try to parse with Traj1d or Traj2d")

# we have a ds where 2D arrays are used to store data, this is either Traj1d or Traj2d
# there may also be some slightly unusual cases where these Traj1d and Traj2d classes will be used on data with 1D arrays
if 'obs' in tx.dims:
obsdim = 'obs'
timedim = detect_time_dim(ds, obsdim)
obs_dim = 'obs'
time_varname = detect_time_variable(ds, obs_dim)

elif 'index' in tx.dims:
obsdim = 'obs'
timedim = detect_time_dim(ds, obsdim)
obs_dim = 'obs'
time_varname = detect_time_variable(ds, obs_dim)

elif 'time' in tx.dims:
obsdim = 'time'
timedim = 'time'
obs_dim = 'time'
time_varname = 'time'

else:
for d in tx.dims:
if not ds[d].attrs.get(
'cf_role',
None) == 'trajectory_id' and not 'traj' in d:

obsdim = d
timedim = detect_time_dim(ds, obsdim)
obs_dim = d
time_varname = detect_time_variable(ds, obs_dim)

break

if obsdim is None:
if obs_dim is None:
logger.warning('No time or obs dimension detected.')

logger.debug(
f"Detected obs-dim: {obsdim}, detected time-dim: {timedim}.")
f"Detected obs-dim: {obs_dim}, detected time-variable: {time_varname}.")

if obsdim is None:
if obs_dim is None:
ocls = Traj1d

elif len(ds[timedim].shape) <= 1:
elif len(ds[time_varname].shape) <= 1:
logger.debug('Detected structured (1D) trajectory dataset')
ocls = Traj1d

elif len(ds[timedim].shape) == 2:
elif len(ds[time_varname].shape) == 2:
logger.debug('Detected un-structured (2D) trajectory dataset')
ocls = Traj2d

else:
raise ValueError(
f'Time dimension has shape greater than 2: {ds["timedim"].shape}'
f'Time variable has more than two dimensions: {ds[time_varname].shape}'
)

return ocls(ds, obsdim, timedim)
return ocls(ds, obs_dim, time_varname)
22 changes: 11 additions & 11 deletions trajan/ragged.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ class ContiguousRagged(Traj):
trajdim: str
rowvar: str

def __init__(self, ds, obsdim, timedim, trajectorycoord, rowsizevar):
def __init__(self, ds, obs_dim, time_varname, trajectorycoord, rowsizevar):
self.trajdim = trajectorycoord
self.rowvar = rowsizevar
super().__init__(ds, obsdim, timedim)
super().__init__(ds, obs_dim, time_varname)

def to_2d(self, obsdim='obs'):
def to_2d(self, obs_dim='obs'):
"""This actually converts a contiguous ragged xarray Dataset into an xarray Dataset that follows the Traj2d conventions."""
global_attrs = self.ds.attrs

Expand All @@ -45,7 +45,7 @@ def to_2d(self, obsdim='obs'):
self.ds[self.rowvar].to_numpy()):
end_index = start_index + crrt_rowsize
array_time[crrt_index, :crrt_rowsize] = self.ds[
self.timedim][start_index:end_index]
self.time_varname][start_index:end_index]
start_index = end_index

# it seems that we need to build the "backbone" of the Dataset independently first
Expand All @@ -70,7 +70,7 @@ def to_2d(self, obsdim='obs'):

# trajectory vars
'time':
xr.DataArray(dims=["trajectory", obsdim],
xr.DataArray(dims=["trajectory", obs_dim],
data=array_time,
attrs={
"standard_name": "time",
Expand All @@ -81,7 +81,7 @@ def to_2d(self, obsdim='obs'):

# now add all "normal" variables
# NOTE: for now, we only consider scalar vars; if we want to consider more complex vars (e.g., spectra), this will need to be updated
# NOTE: such an update would typically need to look at the dims of the variable, and if there are additional dims to obsdim, create a higher dim variable
# NOTE: such an update would typically need to look at the dims of the variable, and if there are additional dims to obs_dim, create a higher dim variable

for crrt_data_var in self.ds.data_vars:
attrs = self.ds[crrt_data_var].attrs
Expand All @@ -90,9 +90,9 @@ def to_2d(self, obsdim='obs'):
continue

if len(self.ds[crrt_data_var].dims
) != 1 or self.ds[crrt_data_var].dims[0] != self.obsdim:
) != 1 or self.ds[crrt_data_var].dims[0] != self.obs_dim:
raise ValueError(
f"data_vars element {crrt_data_var} has dims {self.ds[crrt_data_var].dims}, expected {(self.obsdim,)}"
f"data_vars element {crrt_data_var} has dims {self.ds[crrt_data_var].dims}, expected {(self.obs_dim,)}"
)

crrt_var = np.full((nbr_trajectories, longest_trajectory), np.nan)
Expand All @@ -119,7 +119,7 @@ def to_2d(self, obsdim='obs'):
crrt_data_var = "lat"

ds_converted_to_traj2d[crrt_data_var] = \
xr.DataArray(dims=["trajectory", obsdim],
xr.DataArray(dims=["trajectory", obs_dim],
data=crrt_var,
attrs=attrs)

Expand All @@ -140,6 +140,6 @@ def plot(self) -> Plot:
def timestep(self, average=np.median):
return self.to_2d().traj.timestep(average)

def gridtime(self, times, timedim=None, round=True):
return self.to_2d().traj.gridtime(times, timedim, round)
def gridtime(self, times, time_varname=None, round=True):
return self.to_2d().traj.gridtime(times, time_varname, round)

Loading

0 comments on commit 391f3e2

Please sign in to comment.