Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ docs/_build/*
docs/_downloads
docs/jupyter_execute/*
docs/.jupyter_cache/*
docs/reference
output

*.log
Expand Down
64 changes: 32 additions & 32 deletions src/parcels/_core/utils/sgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ def to_attrs(self) -> dict[str, str | int]:
d["vertical_dimensions"] = dump_mappings(self.vertical_dimensions)
return d

def rename_dims(self, dims_dict: dict[str, str]) -> Self:
return _metadata_rename_dims(self, dims_dict)
def rename(self, names_dict: dict[str, str]) -> Self:
return _metadata_rename(self, names_dict)


class Grid3DMetadata(AttrsSerializable):
Expand Down Expand Up @@ -248,8 +248,8 @@ def to_attrs(self) -> dict[str, str | int]:
d["node_coordinates"] = dump_mappings(self.node_coordinates)
return d

def rename_dims(self, dims_dict: dict[str, str]) -> Self:
return _metadata_rename_dims(self, dims_dict)
def rename(self, dims_dict: dict[str, str]) -> Self:
return _metadata_rename(self, dims_dict)


@dataclass
Expand Down Expand Up @@ -418,22 +418,22 @@ def parse_sgrid(ds: xr.Dataset):
return (ds, {"coords": xgcm_coords})


def rename_dims(ds: xr.Dataset, dims_dict: dict[str, str]) -> xr.Dataset:
def rename(ds: xr.Dataset, name_dict: dict[str, str]) -> xr.Dataset:
grid_da = get_grid_topology(ds)
if grid_da is None:
raise ValueError(
"No variable found in dataset with 'cf_role' attribute set to 'grid_topology'. This doesn't look to be an SGrid dataset - please make your dataset conforms to SGrid conventions."
)

ds = ds.rename_dims(dims_dict)
ds = ds.rename(name_dict)

# Update the metadata
grid = parse_grid_attrs(grid_da.attrs)
ds[grid_da.name].attrs = grid.rename_dims(dims_dict).to_attrs()
ds[grid_da.name].attrs = grid.rename(name_dict).to_attrs()
return ds


def get_unique_dim_names(grid: Grid2DMetadata | Grid3DMetadata) -> set[str]:
def get_unique_names(grid: Grid2DMetadata | Grid3DMetadata) -> set[str]:
dims = set()
dims.update(set(grid.node_dimensions))

Expand All @@ -453,14 +453,6 @@ def get_unique_dim_names(grid: Grid2DMetadata | Grid3DMetadata) -> set[str]:
return dims


@overload
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These were moved to be directly over the _metadata_rename_dims function (as this is how @overloads work).

Wasn't flagged before since we're not (yet) running mypy

def _metadata_rename_dims(grid: Grid2DMetadata, dims_dict: dict[str, str]) -> Grid2DMetadata: ...


@overload
def _metadata_rename_dims(grid: Grid3DMetadata, dims_dict: dict[str, str]) -> Grid3DMetadata: ...


def _attach_sgrid_metadata(ds, grid: Grid2DMetadata | Grid3DMetadata):
"""Copies the dataset and attaches the SGRID metadata in 'grid' variable. Modifies 'conventions' attribute."""
ds = ds.copy()
Expand All @@ -473,24 +465,32 @@ def _attach_sgrid_metadata(ds, grid: Grid2DMetadata | Grid3DMetadata):
return ds


def _metadata_rename_dims(grid, dims_dict):
@overload
def _metadata_rename(grid: Grid2DMetadata, names_dict: dict[str, str]) -> Grid2DMetadata: ...


@overload
def _metadata_rename(grid: Grid3DMetadata, names_dict: dict[str, str]) -> Grid3DMetadata: ...


def _metadata_rename(grid, names_dict):
"""
Renames dimensions in SGrid metadata.
Renames dimensions and coordinates in SGrid metadata.

Similar in API to xr.Dataset.rename_dims. Renames dimensions according to dims_dict mapping
Similar in API to xr.Dataset.rename . Renames dimensions according to names_dict mapping
of old dimension names to new dimension names.
"""
dims_dict = dims_dict.copy()
assert len(dims_dict) == len(set(dims_dict.values())), "dims_dict contains duplicate target dimension names"
names_dict = names_dict.copy()
assert len(names_dict) == len(set(names_dict.values())), "names_dict contains duplicate target dimension names"

existing_dims = get_unique_dim_names(grid)
for dim in dims_dict.keys():
if dim not in existing_dims:
raise ValueError(f"Dimension {dim!r} not found in SGrid metadata dimensions {existing_dims!r}")
existing_names = get_unique_names(grid)
for name in names_dict.keys():
if name not in existing_names:
raise ValueError(f"Name {name!r} not found in names defined in SGrid metadata {existing_names!r}")

for dim in existing_dims:
if dim not in dims_dict:
dims_dict[dim] = dim # identity mapping for dimensions not being renamed
for name in existing_names:
if name not in names_dict:
names_dict[name] = name # identity mapping for names not being renamed

kwargs = {}
for key, value in grid.__dict__.items():
Expand All @@ -499,14 +499,14 @@ def _metadata_rename_dims(grid, dims_dict):
for item in value:
if isinstance(item, DimDimPadding):
new_item = DimDimPadding(
dim1=dims_dict[item.dim1],
dim2=dims_dict[item.dim2],
dim1=names_dict[item.dim1],
dim2=names_dict[item.dim2],
padding=item.padding,
)
new_value.append(new_item)
else:
assert isinstance(item, str)
new_value.append(dims_dict[item])
new_value.append(names_dict[item])
kwargs[key] = tuple(new_value)
continue

Expand All @@ -515,7 +515,7 @@ def _metadata_rename_dims(grid, dims_dict):
continue

if isinstance(value, str):
kwargs[key] = dims_dict[value]
kwargs[key] = names_dict[value]
continue

raise ValueError(f"Unexpected attribute {key!r} on {grid!r}")
Expand Down
8 changes: 5 additions & 3 deletions src/parcels/_datasets/structured/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
_attach_sgrid_metadata,
)
from parcels._core.utils.sgrid import (
rename_dims as sgrid_rename_dims,
rename as sgrid_rename,
)
from parcels._datasets.utils import _attach_sgrid_metadata

Expand Down Expand Up @@ -258,11 +258,12 @@ def _unrolled_cone_curvilinear_grid():
DimDimPadding("XC", "XG", Padding.HIGH),
DimDimPadding("YC", "YG", Padding.HIGH),
),
node_coordinates=("lon", "lat"),
vertical_dimensions=(DimDimPadding("ZC", "ZG", Padding.HIGH),),
),
)
.pipe(
sgrid_rename_dims,
sgrid_rename,
_COMODO_TO_2D_SGRID,
)
),
Expand All @@ -278,11 +279,12 @@ def _unrolled_cone_curvilinear_grid():
DimDimPadding("XC", "XG", Padding.LOW),
DimDimPadding("YC", "YG", Padding.LOW),
),
node_coordinates=("lon", "lat"),
vertical_dimensions=(DimDimPadding("ZC", "ZG", Padding.LOW),),
),
)
.pipe(
sgrid_rename_dims,
sgrid_rename,
_COMODO_TO_2D_SGRID,
)
),
Expand Down
75 changes: 61 additions & 14 deletions tests/utils/test_sgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def dummy_sgrid_2d_ds(grid: sgrid.Grid2DMetadata) -> xr.Dataset:
ds = dummy_comodo_3d_ds()

# Can't rename dimensions that already exist in the dataset
assume(sgrid.get_unique_dim_names(grid) & set(ds.dims) == set())
assume(sgrid.get_unique_names(grid) & set(ds.dims) == set())

renamings = {}
if grid.vertical_dimensions is None:
Expand All @@ -90,7 +90,7 @@ def dummy_sgrid_3d_ds(grid: sgrid.Grid3DMetadata) -> xr.Dataset:
ds = dummy_comodo_3d_ds()

# Can't rename dimensions that already exist in the dataset
assume(sgrid.get_unique_dim_names(grid) & set(ds.dims) == set())
assume(sgrid.get_unique_names(grid) & set(ds.dims) == set())

renamings = {}
for old, new in zip(["XG", "YG", "ZG"], grid.node_dimensions, strict=True):
Expand Down Expand Up @@ -250,30 +250,77 @@ def test_parse_sgrid_3d(grid_metadata: sgrid.Grid3DMetadata):
]
+ [create_example_grid3dmetadata(with_node_coordinates=i) for i in [False, True]],
)
def test_rename_dims(grid):
dims = sgrid.get_unique_dim_names(grid)
def test_rename(grid):
dims = sgrid.get_unique_names(grid)
dims_dict = {dim: f"new_{dim}" for dim in dims}
dims_dict_inv = {v: k for k, v in dims_dict.items()}

grid_new = grid.rename_dims(dims_dict)
assert dims & set(sgrid.get_unique_dim_names(grid_new)) == set()
grid_new = grid.rename(dims_dict)
assert dims & set(sgrid.get_unique_names(grid_new)) == set()

assert grid == grid_new.rename_dims(dims_dict_inv)
assert grid == grid_new.rename(dims_dict_inv)


def test_rename_dims_errors():
def test_rename_errors():
# Test various error modes of rename_dims
grid = grid2dmetadata
# Non-unique target dimension names
dims_dict = {
names_dict = {
"node_dimension1": "new_node_dimension",
"node_dimension2": "new_node_dimension",
}
with pytest.raises(AssertionError, match="dims_dict contains duplicate target dimension names"):
grid.rename_dims(dims_dict)
with pytest.raises(AssertionError, match="names_dict contains duplicate target dimension names"):
grid.rename(names_dict)
# Unexpected attribute in dims_dict
dims_dict = {
names_dict = {
"unexpected_dimension": "new_unexpected_dimension",
}
with pytest.raises(ValueError, match="Dimension 'unexpected_dimension' not found in SGrid metadata dimensions"):
grid.rename_dims(dims_dict)
with pytest.raises(ValueError, match="Name 'unexpected_dimension' not found in names defined in SGrid metadata"):
grid.rename(names_dict)


@pytest.mark.parametrize(
"ds",
[
xr.Dataset(
{
"data_g": (["time", "ZG", "YG", "XG"], np.random.rand(10, 10, 10, 10)),
"data_c": (["time", "ZC", "YC", "XC"], np.random.rand(10, 10, 10, 10)),
"grid": (
[],
np.array(0),
sgrid.Grid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=("XG", "YG"),
face_dimensions=(
sgrid.DimDimPadding("XC", "XG", sgrid.Padding.HIGH),
sgrid.DimDimPadding("YC", "YG", sgrid.Padding.HIGH),
),
vertical_dimensions=(sgrid.DimDimPadding("ZC", "ZG", sgrid.Padding.HIGH),),
node_coordinates=("lon", "lat"),
).to_attrs(),
),
},
coords={
"lon": (["XG"], 2 * np.pi / 10 * np.arange(0, 10)),
"lat": (["YG"], 2 * np.pi / (10) * np.arange(0, 10)),
"depth": (["ZG"], np.arange(10)),
"time": (["time"], xr.date_range("2000", "2001", 10), {"axis": "T"}),
},
),
],
)
def test_rename_dataset(ds):
# Check renaming works for coordinates
ds_new = sgrid.rename(ds, {"lon": "lon_updated"})
grid_new = sgrid.parse_grid_attrs(ds_new["grid"].attrs)
assert "lon_updated" in ds_new.coords
assert "lon_updated" == grid_new.node_coordinates[0]

# Check renaming works for dim
ds_new = sgrid.rename(ds, {"XC": "XC_updated"})
grid_new = sgrid.parse_grid_attrs(ds_new["grid"].attrs)
assert "XC_updated" in ds_new.dims
assert "XC" not in ds_new.dims
assert "XC_updated" == grid_new.face_dimensions[0].dim1
Loading