diff --git a/src/parcels/_core/utils/sgrid.py b/src/parcels/_core/utils/sgrid.py index 6ff698ad1..ff87d5081 100644 --- a/src/parcels/_core/utils/sgrid.py +++ b/src/parcels/_core/utils/sgrid.py @@ -54,6 +54,7 @@ def __init__( topology_dimension: Literal[2], node_dimensions: tuple[Dim, Dim], face_dimensions: tuple[DimDimPadding, DimDimPadding], + node_coordinates: None | tuple[Dim, Dim] = None, vertical_dimensions: None | tuple[DimDimPadding] = None, ): if cf_role != "grid_topology": @@ -76,6 +77,14 @@ def __init__( ): raise ValueError("face_dimensions must be a tuple of 2 DimDimPadding for a 2D grid") + if node_coordinates is not None: + if not ( + isinstance(node_coordinates, tuple) + and len(node_coordinates) == 2 + and all(isinstance(nd, str) for nd in node_coordinates) + ): + raise ValueError("node_coordinates must be a tuple of 2 dimensions for a 2D grid") + if vertical_dimensions is not None: if not ( isinstance(vertical_dimensions, tuple) @@ -90,21 +99,21 @@ def __init__( self.node_dimensions = node_dimensions self.face_dimensions = face_dimensions - #! Optional attributes aren't really important to Parcels, can be added later if needed + # Optional attributes + self.node_coordinates = node_coordinates + self.vertical_dimensions = vertical_dimensions + + #! Some optional attributes aren't really important to Parcels, can be added later if needed # Optional attributes # # With defaults (set in init) # edge1_dimensions: tuple[Dim, DimDimPadding] # edge2_dimensions: tuple[DimDimPadding, Dim] # # Without defaults - # node_coordinates: None | Any = None # edge1_coordinates: None | Any = None # edge2_coordinates: None | Any = None # face_coordinate: None | Any = None - #! Important optional attribute for 2D grids with vertical layering - self.vertical_dimensions = vertical_dimensions - def __repr__(self) -> str: return repr_from_dunder_dict(self) @@ -121,6 +130,7 @@ def from_attrs(cls, attrs): topology_dimension=attrs["topology_dimension"], node_dimensions=load_mappings(attrs["node_dimensions"]), face_dimensions=load_mappings(attrs["face_dimensions"]), + node_coordinates=maybe_load_mappings(attrs.get("node_coordinates")), vertical_dimensions=maybe_load_mappings(attrs.get("vertical_dimensions")), ) except Exception as e: @@ -133,6 +143,8 @@ def to_attrs(self) -> dict[str, str | int]: node_dimensions=dump_mappings(self.node_dimensions), face_dimensions=dump_mappings(self.face_dimensions), ) + if self.node_coordinates is not None: + d["node_coordinates"] = dump_mappings(self.node_coordinates) if self.vertical_dimensions is not None: d["vertical_dimensions"] = dump_mappings(self.vertical_dimensions) return d @@ -148,6 +160,7 @@ def __init__( topology_dimension: Literal[3], node_dimensions: tuple[Dim, Dim, Dim], volume_dimensions: tuple[DimDimPadding, DimDimPadding, DimDimPadding], + node_coordinates: None | tuple[Dim, Dim, Dim] = None, ): if cf_role != "grid_topology": raise ValueError(f"cf_role must be 'grid_topology', got {cf_role!r}") @@ -169,13 +182,24 @@ def __init__( ): raise ValueError("face_dimensions must be a tuple of 2 DimDimPadding for a 2D grid") + if node_coordinates is not None: + if not ( + isinstance(node_coordinates, tuple) + and len(node_coordinates) == 3 + and all(isinstance(nd, str) for nd in node_coordinates) + ): + raise ValueError("node_coordinates must be a tuple of 3 dimensions for a 3D grid") + # Required attributes self.cf_role = cf_role self.topology_dimension = topology_dimension self.node_dimensions = node_dimensions self.volume_dimensions = volume_dimensions - # ! Optional attributes aren't really important to Parcels, can be added later if needed + # Optional attributes + self.node_coordinates = node_coordinates + + # ! Some optional attributes aren't really important to Parcels, can be added later if needed # Optional attributes # # With defaults (set in init) # edge1_dimensions: tuple[DimDimPadding, Dim, Dim] @@ -186,7 +210,6 @@ def __init__( # face3_dimensions: tuple[DimDimPadding, DimDimPadding, Dim] # # Without defaults - # node_coordinates # edge *i_coordinates* # face *i_coordinates* # volume_coordinates @@ -207,17 +230,21 @@ def from_attrs(cls, attrs): topology_dimension=attrs["topology_dimension"], node_dimensions=load_mappings(attrs["node_dimensions"]), volume_dimensions=load_mappings(attrs["volume_dimensions"]), + node_coordinates=maybe_load_mappings(attrs.get("node_coordinates")), ) except Exception as e: raise SGridParsingException(f"Failed to parse Grid3DMetadata from {attrs=!r}") from e def to_attrs(self) -> dict[str, str | int]: - return dict( + d = dict( cf_role=self.cf_role, topology_dimension=self.topology_dimension, node_dimensions=dump_mappings(self.node_dimensions), volume_dimensions=dump_mappings(self.volume_dimensions), ) + if self.node_coordinates is not None: + d["node_coordinates"] = dump_mappings(self.node_coordinates) + return d def rename_dims(self, dims_dict: dict[str, str]) -> Self: return _metadata_rename_dims(self, dims_dict) diff --git a/tests/strategies/sgrid.py b/tests/strategies/sgrid.py index eed1d7583..562b242e1 100644 --- a/tests/strategies/sgrid.py +++ b/tests/strategies/sgrid.py @@ -28,7 +28,7 @@ @st.composite def grid2Dmetadata(draw) -> sgrid.Grid2DMetadata: - N = 6 + N = 8 names = draw(st.lists(dimension_name, min_size=N, max_size=N, unique=True)) node_dimension1 = names[0] node_dimension2 = names[1] @@ -37,11 +37,20 @@ def grid2Dmetadata(draw) -> sgrid.Grid2DMetadata: padding_type1 = draw(padding) padding_type2 = draw(padding) - vertical_dimensions_dim1 = names[4] - vertical_dimensions_dim2 = names[5] + node_coordinates_var1 = names[4] + node_coordinates_var2 = names[5] + has_node_coordinates = draw(st.booleans()) + + vertical_dimensions_dim1 = names[6] + vertical_dimensions_dim2 = names[7] vertical_dimensions_padding = draw(padding) has_vertical_dimensions = draw(st.booleans()) + if has_node_coordinates: + node_coordinates = (node_coordinates_var1, node_coordinates_var2) + else: + node_coordinates = None + if has_vertical_dimensions: vertical_dimensions = ( sgrid.DimDimPadding(vertical_dimensions_dim1, vertical_dimensions_dim2, vertical_dimensions_padding), @@ -57,13 +66,14 @@ def grid2Dmetadata(draw) -> sgrid.Grid2DMetadata: sgrid.DimDimPadding(face_dimension1, node_dimension1, padding_type1), sgrid.DimDimPadding(face_dimension2, node_dimension2, padding_type2), ), + node_coordinates=node_coordinates, vertical_dimensions=vertical_dimensions, ) @st.composite def grid3Dmetadata(draw) -> sgrid.Grid3DMetadata: - N = 6 + N = 9 names = draw(st.lists(dimension_name, min_size=N, max_size=N, unique=True)) node_dimension1 = names[0] node_dimension2 = names[1] @@ -75,6 +85,16 @@ def grid3Dmetadata(draw) -> sgrid.Grid3DMetadata: padding_type2 = draw(padding) padding_type3 = draw(padding) + node_coordinates_var1 = names[6] + node_coordinates_var2 = names[7] + node_coordinates_dim3 = names[8] + has_node_coordinates = draw(st.booleans()) + + if has_node_coordinates: + node_coordinates = (node_coordinates_var1, node_coordinates_var2, node_coordinates_dim3) + else: + node_coordinates = None + return sgrid.Grid3DMetadata( cf_role="grid_topology", topology_dimension=3, @@ -84,6 +104,7 @@ def grid3Dmetadata(draw) -> sgrid.Grid3DMetadata: sgrid.DimDimPadding(face_dimension2, node_dimension2, padding_type2), sgrid.DimDimPadding(face_dimension3, node_dimension3, padding_type3), ), + node_coordinates=node_coordinates, ) diff --git a/tests/utils/test_sgrid.py b/tests/utils/test_sgrid.py index 27f1630a8..e73c97865 100644 --- a/tests/utils/test_sgrid.py +++ b/tests/utils/test_sgrid.py @@ -1,3 +1,5 @@ +import itertools + import numpy as np import pytest import xarray as xr @@ -7,29 +9,47 @@ from parcels._core.utils import sgrid from tests.strategies import sgrid as sgrid_strategies -grid2dmetadata = sgrid.Grid2DMetadata( - cf_role="grid_topology", - topology_dimension=2, - node_dimensions=("node_dimension1", "node_dimension2"), - face_dimensions=( - sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW), - sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW), - ), - vertical_dimensions=( - sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW), - ), -) -grid3dmetadata = sgrid.Grid3DMetadata( - cf_role="grid_topology", - topology_dimension=3, - node_dimensions=("node_dimension1", "node_dimension2", "node_dimension3"), - volume_dimensions=( - sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW), - sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW), - sgrid.DimDimPadding("face_dimension3", "node_dimension3", sgrid.Padding.LOW), - ), -) +def create_example_grid2dmetadata(with_vertical_dimensions: bool, with_node_coordinates: bool): + vertical_dimensions = ( + (sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW),) + if with_vertical_dimensions + else None + ) + node_coordinates = ("node_coordinates_var1", "node_coordinates_var2") if with_node_coordinates else None + + return sgrid.Grid2DMetadata( + cf_role="grid_topology", + topology_dimension=2, + node_dimensions=("node_dimension1", "node_dimension2"), + face_dimensions=( + sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW), + sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW), + ), + node_coordinates=node_coordinates, + vertical_dimensions=vertical_dimensions, + ) + + +def create_example_grid3dmetadata(with_node_coordinates: bool): + node_coordinates = ( + ("node_coordinates_var1", "node_coordinates_var2", "node_coordinates_dim3") if with_node_coordinates else None + ) + return sgrid.Grid3DMetadata( + cf_role="grid_topology", + topology_dimension=3, + node_dimensions=("node_dimension1", "node_dimension2", "node_dimension3"), + volume_dimensions=( + sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW), + sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW), + sgrid.DimDimPadding("face_dimension3", "node_dimension3", sgrid.Padding.LOW), + ), + node_coordinates=node_coordinates, + ) + + +grid2dmetadata = create_example_grid2dmetadata(with_vertical_dimensions=True, with_node_coordinates=True) +grid3dmetadata = create_example_grid3dmetadata(with_node_coordinates=True) def dummy_sgrid_ds(grid: sgrid.Grid2DMetadata | sgrid.Grid3DMetadata) -> xr.Dataset: @@ -225,45 +245,10 @@ def test_parse_sgrid_3d(grid_metadata: sgrid.Grid3DMetadata): @pytest.mark.parametrize( "grid", [ - ( - sgrid.Grid2DMetadata( - cf_role="grid_topology", - topology_dimension=2, - node_dimensions=("node_dimension1", "node_dimension2"), - face_dimensions=( - sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW), - sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW), - ), - vertical_dimensions=( - sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW), - ), - ) - ), - ( - sgrid.Grid2DMetadata( - cf_role="grid_topology", - topology_dimension=2, - node_dimensions=("node_dimension1", "node_dimension2"), - face_dimensions=( - sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW), - sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW), - ), - vertical_dimensions=None, - ) - ), - ( - sgrid.Grid3DMetadata( - cf_role="grid_topology", - topology_dimension=3, - node_dimensions=("node_dimension1", "node_dimension2", "node_dimension3"), - volume_dimensions=( - sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW), - sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW), - sgrid.DimDimPadding("face_dimension3", "node_dimension3", sgrid.Padding.LOW), - ), - ) - ), - ], + create_example_grid2dmetadata(with_node_coordinates=i, with_vertical_dimensions=j) + for i, j in itertools.product([False, True], [False, True]) + ] + + [create_example_grid3dmetadata(with_node_coordinates=i) for i in [False, True]], ) def test_rename_dims(grid): dims = sgrid.get_unique_dim_names(grid)