Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 35 additions & 8 deletions src/parcels/_core/utils/sgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(
topology_dimension: Literal[2],
node_dimensions: tuple[Dim, Dim],
face_dimensions: tuple[DimDimPadding, DimDimPadding],
node_coordinates: None | tuple[Dim, Dim] = None,
vertical_dimensions: None | tuple[DimDimPadding] = None,
):
if cf_role != "grid_topology":
Expand All @@ -76,6 +77,14 @@ def __init__(
):
raise ValueError("face_dimensions must be a tuple of 2 DimDimPadding for a 2D grid")

if node_coordinates is not None:
if not (
isinstance(node_coordinates, tuple)
and len(node_coordinates) == 2
and all(isinstance(nd, str) for nd in node_coordinates)
):
raise ValueError("node_coordinates must be a tuple of 2 dimensions for a 2D grid")

if vertical_dimensions is not None:
if not (
isinstance(vertical_dimensions, tuple)
Expand All @@ -90,21 +99,21 @@ def __init__(
self.node_dimensions = node_dimensions
self.face_dimensions = face_dimensions

#! Optional attributes aren't really important to Parcels, can be added later if needed
# Optional attributes
self.node_coordinates = node_coordinates
self.vertical_dimensions = vertical_dimensions

#! Some optional attributes aren't really important to Parcels, can be added later if needed
# Optional attributes
# # With defaults (set in init)
# edge1_dimensions: tuple[Dim, DimDimPadding]
# edge2_dimensions: tuple[DimDimPadding, Dim]

# # Without defaults
# node_coordinates: None | Any = None
# edge1_coordinates: None | Any = None
# edge2_coordinates: None | Any = None
# face_coordinate: None | Any = None

#! Important optional attribute for 2D grids with vertical layering
self.vertical_dimensions = vertical_dimensions

def __repr__(self) -> str:
return repr_from_dunder_dict(self)

Expand All @@ -121,6 +130,7 @@ def from_attrs(cls, attrs):
topology_dimension=attrs["topology_dimension"],
node_dimensions=load_mappings(attrs["node_dimensions"]),
face_dimensions=load_mappings(attrs["face_dimensions"]),
node_coordinates=maybe_load_mappings(attrs.get("node_coordinates")),
vertical_dimensions=maybe_load_mappings(attrs.get("vertical_dimensions")),
)
except Exception as e:
Expand All @@ -133,6 +143,8 @@ def to_attrs(self) -> dict[str, str | int]:
node_dimensions=dump_mappings(self.node_dimensions),
face_dimensions=dump_mappings(self.face_dimensions),
)
if self.node_coordinates is not None:
d["node_coordinates"] = dump_mappings(self.node_coordinates)
if self.vertical_dimensions is not None:
d["vertical_dimensions"] = dump_mappings(self.vertical_dimensions)
return d
Expand All @@ -148,6 +160,7 @@ def __init__(
topology_dimension: Literal[3],
node_dimensions: tuple[Dim, Dim, Dim],
volume_dimensions: tuple[DimDimPadding, DimDimPadding, DimDimPadding],
node_coordinates: None | tuple[Dim, Dim, Dim] = None,
):
if cf_role != "grid_topology":
raise ValueError(f"cf_role must be 'grid_topology', got {cf_role!r}")
Expand All @@ -169,13 +182,24 @@ def __init__(
):
raise ValueError("face_dimensions must be a tuple of 2 DimDimPadding for a 2D grid")

if node_coordinates is not None:
if not (
isinstance(node_coordinates, tuple)
and len(node_coordinates) == 3
and all(isinstance(nd, str) for nd in node_coordinates)
):
raise ValueError("node_coordinates must be a tuple of 3 dimensions for a 3D grid")

# Required attributes
self.cf_role = cf_role
self.topology_dimension = topology_dimension
self.node_dimensions = node_dimensions
self.volume_dimensions = volume_dimensions

# ! Optional attributes aren't really important to Parcels, can be added later if needed
# Optional attributes
self.node_coordinates = node_coordinates

# ! Some optional attributes aren't really important to Parcels, can be added later if needed
# Optional attributes
# # With defaults (set in init)
# edge1_dimensions: tuple[DimDimPadding, Dim, Dim]
Expand All @@ -186,7 +210,6 @@ def __init__(
# face3_dimensions: tuple[DimDimPadding, DimDimPadding, Dim]

# # Without defaults
# node_coordinates
# edge *i_coordinates*
# face *i_coordinates*
# volume_coordinates
Expand All @@ -207,17 +230,21 @@ def from_attrs(cls, attrs):
topology_dimension=attrs["topology_dimension"],
node_dimensions=load_mappings(attrs["node_dimensions"]),
volume_dimensions=load_mappings(attrs["volume_dimensions"]),
node_coordinates=maybe_load_mappings(attrs.get("node_coordinates")),
)
except Exception as e:
raise SGridParsingException(f"Failed to parse Grid3DMetadata from {attrs=!r}") from e

def to_attrs(self) -> dict[str, str | int]:
return dict(
d = dict(
cf_role=self.cf_role,
topology_dimension=self.topology_dimension,
node_dimensions=dump_mappings(self.node_dimensions),
volume_dimensions=dump_mappings(self.volume_dimensions),
)
if self.node_coordinates is not None:
d["node_coordinates"] = dump_mappings(self.node_coordinates)
return d

def rename_dims(self, dims_dict: dict[str, str]) -> Self:
return _metadata_rename_dims(self, dims_dict)
Expand Down
29 changes: 25 additions & 4 deletions tests/strategies/sgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

@st.composite
def grid2Dmetadata(draw) -> sgrid.Grid2DMetadata:
N = 6
N = 8
names = draw(st.lists(dimension_name, min_size=N, max_size=N, unique=True))
node_dimension1 = names[0]
node_dimension2 = names[1]
Expand All @@ -37,11 +37,20 @@ def grid2Dmetadata(draw) -> sgrid.Grid2DMetadata:
padding_type1 = draw(padding)
padding_type2 = draw(padding)

vertical_dimensions_dim1 = names[4]
vertical_dimensions_dim2 = names[5]
node_coordinates_var1 = names[4]
node_coordinates_var2 = names[5]
has_node_coordinates = draw(st.booleans())

vertical_dimensions_dim1 = names[6]
vertical_dimensions_dim2 = names[7]
vertical_dimensions_padding = draw(padding)
has_vertical_dimensions = draw(st.booleans())

if has_node_coordinates:
node_coordinates = (node_coordinates_var1, node_coordinates_var2)
else:
node_coordinates = None

if has_vertical_dimensions:
vertical_dimensions = (
sgrid.DimDimPadding(vertical_dimensions_dim1, vertical_dimensions_dim2, vertical_dimensions_padding),
Expand All @@ -57,13 +66,14 @@ def grid2Dmetadata(draw) -> sgrid.Grid2DMetadata:
sgrid.DimDimPadding(face_dimension1, node_dimension1, padding_type1),
sgrid.DimDimPadding(face_dimension2, node_dimension2, padding_type2),
),
node_coordinates=node_coordinates,
vertical_dimensions=vertical_dimensions,
)


@st.composite
def grid3Dmetadata(draw) -> sgrid.Grid3DMetadata:
N = 6
N = 9
names = draw(st.lists(dimension_name, min_size=N, max_size=N, unique=True))
node_dimension1 = names[0]
node_dimension2 = names[1]
Expand All @@ -75,6 +85,16 @@ def grid3Dmetadata(draw) -> sgrid.Grid3DMetadata:
padding_type2 = draw(padding)
padding_type3 = draw(padding)

node_coordinates_var1 = names[6]
node_coordinates_var2 = names[7]
node_coordinates_dim3 = names[8]
has_node_coordinates = draw(st.booleans())

if has_node_coordinates:
node_coordinates = (node_coordinates_var1, node_coordinates_var2, node_coordinates_dim3)
else:
node_coordinates = None

return sgrid.Grid3DMetadata(
cf_role="grid_topology",
topology_dimension=3,
Expand All @@ -84,6 +104,7 @@ def grid3Dmetadata(draw) -> sgrid.Grid3DMetadata:
sgrid.DimDimPadding(face_dimension2, node_dimension2, padding_type2),
sgrid.DimDimPadding(face_dimension3, node_dimension3, padding_type3),
),
node_coordinates=node_coordinates,
)


Expand Down
107 changes: 46 additions & 61 deletions tests/utils/test_sgrid.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import itertools

import numpy as np
import pytest
import xarray as xr
Expand All @@ -7,29 +9,47 @@
from parcels._core.utils import sgrid
from tests.strategies import sgrid as sgrid_strategies

grid2dmetadata = sgrid.Grid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=("node_dimension1", "node_dimension2"),
face_dimensions=(
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
),
vertical_dimensions=(
sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW),
),
)

grid3dmetadata = sgrid.Grid3DMetadata(
cf_role="grid_topology",
topology_dimension=3,
node_dimensions=("node_dimension1", "node_dimension2", "node_dimension3"),
volume_dimensions=(
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
sgrid.DimDimPadding("face_dimension3", "node_dimension3", sgrid.Padding.LOW),
),
)
def create_example_grid2dmetadata(with_vertical_dimensions: bool, with_node_coordinates: bool):
vertical_dimensions = (
(sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW),)
if with_vertical_dimensions
else None
)
node_coordinates = ("node_coordinates_var1", "node_coordinates_var2") if with_node_coordinates else None

return sgrid.Grid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=("node_dimension1", "node_dimension2"),
face_dimensions=(
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
),
node_coordinates=node_coordinates,
vertical_dimensions=vertical_dimensions,
)


def create_example_grid3dmetadata(with_node_coordinates: bool):
node_coordinates = (
("node_coordinates_var1", "node_coordinates_var2", "node_coordinates_dim3") if with_node_coordinates else None
)
return sgrid.Grid3DMetadata(
cf_role="grid_topology",
topology_dimension=3,
node_dimensions=("node_dimension1", "node_dimension2", "node_dimension3"),
volume_dimensions=(
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
sgrid.DimDimPadding("face_dimension3", "node_dimension3", sgrid.Padding.LOW),
),
node_coordinates=node_coordinates,
)


grid2dmetadata = create_example_grid2dmetadata(with_vertical_dimensions=True, with_node_coordinates=True)
grid3dmetadata = create_example_grid3dmetadata(with_node_coordinates=True)


def dummy_sgrid_ds(grid: sgrid.Grid2DMetadata | sgrid.Grid3DMetadata) -> xr.Dataset:
Expand Down Expand Up @@ -225,45 +245,10 @@ def test_parse_sgrid_3d(grid_metadata: sgrid.Grid3DMetadata):
@pytest.mark.parametrize(
"grid",
[
(
sgrid.Grid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=("node_dimension1", "node_dimension2"),
face_dimensions=(
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
),
vertical_dimensions=(
sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW),
),
)
),
(
sgrid.Grid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=("node_dimension1", "node_dimension2"),
face_dimensions=(
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
),
vertical_dimensions=None,
)
),
(
sgrid.Grid3DMetadata(
cf_role="grid_topology",
topology_dimension=3,
node_dimensions=("node_dimension1", "node_dimension2", "node_dimension3"),
volume_dimensions=(
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
sgrid.DimDimPadding("face_dimension3", "node_dimension3", sgrid.Padding.LOW),
),
)
),
],
create_example_grid2dmetadata(with_node_coordinates=i, with_vertical_dimensions=j)
for i, j in itertools.product([False, True], [False, True])
]
+ [create_example_grid3dmetadata(with_node_coordinates=i) for i in [False, True]],
Comment on lines +248 to +251
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a quick refactor so that we test all combinations

)
def test_rename_dims(grid):
dims = sgrid.get_unique_dim_names(grid)
Expand Down
Loading