diff --git a/docs/api.rst b/docs/api.rst index df323f2f8..a68f91ba9 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -18,6 +18,7 @@ Top Level Functions open_grid open_dataset + open_multigrid open_mfdataset concat diff --git a/test/core/test_api.py b/test/core/test_api.py index bb79ef364..150c2877c 100644 --- a/test/core/test_api.py +++ b/test/core/test_api.py @@ -4,10 +4,13 @@ import pytest import tempfile import xarray as xr +from pathlib import Path from unittest.mock import patch from uxarray.core.utils import _open_dataset_with_fallback import os +TEST_MESHFILES = Path(__file__).resolve().parent.parent / "meshfiles" + def test_open_geoflow_dataset(gridpath, datasetpath): """Loads a single dataset with its grid topology file using uxarray's open_dataset call.""" @@ -163,3 +166,71 @@ def mock_open_dataset(*args, **kwargs): ds_fallback.close() if os.path.exists(tmp_path): os.unlink(tmp_path) + + +def test_list_grid_names_multigrid(gridpath): + """List grids from an OASIS-style multi-grid file.""" + grid_file = gridpath("scrip", "oasis", "grids.nc") + grid_names = ux.list_grid_names(grid_file) + + assert isinstance(grid_names, list) + assert set(grid_names) == {"ocn", "atm"} + + +def test_list_grid_names_single_scrip(): + """List grids from a standard single-grid SCRIP file.""" + grid_path = TEST_MESHFILES / "scrip" / "outCSne8" / "outCSne8.nc" + grid_names = ux.list_grid_names(grid_path) + + assert isinstance(grid_names, list) + assert grid_names == ["grid"] + + +def test_open_multigrid_all_grids(gridpath): + """Open all grids from a multi-grid file.""" + grid_file = gridpath("scrip", "oasis", "grids.nc") + grids = ux.open_multigrid(grid_file) + + assert isinstance(grids, dict) + assert set(grids.keys()) == {"ocn", "atm"} + assert grids["ocn"].n_face == 12 + assert grids["atm"].n_face == 20 + + +def test_open_multigrid_specific_grids(gridpath): + """Open a subset of grids from a multi-grid file.""" + grid_file = gridpath("scrip", "oasis", "grids.nc") + grids = ux.open_multigrid(grid_file, gridnames=["ocn"]) + + assert set(grids.keys()) == {"ocn"} + assert grids["ocn"].n_face == 12 + + +def test_open_multigrid_with_masks(gridpath): + """Open grids with a companion mask file.""" + grid_file = gridpath("scrip", "oasis", "grids.nc") + mask_file = gridpath("scrip", "oasis", "masks.nc") + + grids = ux.open_multigrid(grid_file, mask_filename=mask_file) + + assert grids["ocn"].n_face == 8 + assert grids["atm"].n_face == 20 + + +def test_open_multigrid_mask_zero_faces(gridpath): + """Applying masks that deactivate an entire grid should not fail.""" + grid_file = gridpath("scrip", "oasis", "grids.nc") + mask_file = gridpath("scrip", "oasis", "masks_no_atm.nc") + + grids = ux.open_multigrid(grid_file, mask_filename=mask_file) + + assert grids["ocn"].n_face == 8 + assert grids["atm"].n_face == 0 + + +def test_open_multigrid_missing_grid_error(gridpath): + """Requesting a missing grid should raise.""" + grid_file = gridpath("scrip", "oasis", "grids.nc") + + with pytest.raises(ValueError, match="Grid 'land' not found"): + ux.open_multigrid(grid_file, gridnames=["land"]) diff --git a/test/grid/grid/test_initialization.py b/test/grid/grid/test_initialization.py index 0ab753885..b59a836b3 100644 --- a/test/grid/grid/test_initialization.py +++ b/test/grid/grid/test_initialization.py @@ -1,6 +1,7 @@ import numpy as np import numpy.testing as nt import pytest +import xarray as xr import uxarray as ux from uxarray.constants import INT_FILL_VALUE, ERROR_TOLERANCE @@ -77,3 +78,23 @@ def test_from_topology(): 
face_node_connectivity=face_node_connectivity, fill_value=-1, ) + + +def test_grid_init_handles_empty_longitude_fields(): + """Ensure grids with empty longitude arrays don't error during initialization.""" + empty_lon = np.array([], dtype=np.float64) + ds = xr.Dataset( + { + "node_lon": (("n_node",), empty_lon), + "node_lat": (("n_node",), empty_lon), + "face_node_connectivity": ( + ("n_face", "n_max_face_nodes"), + np.empty((0, 0), dtype=np.int64), + ), + "face_lon": (("n_face",), empty_lon), + } + ) + + uxgrid = ux.Grid(ds, source_grid_spec="UGRID") + + assert uxgrid.n_face == 0 diff --git a/test/io/test_scrip.py b/test/io/test_scrip.py index aecc074c3..57b844802 100644 --- a/test/io/test_scrip.py +++ b/test/io/test_scrip.py @@ -1,11 +1,13 @@ import os import xarray as xr import warnings +import numpy as np import numpy.testing as nt import pytest import uxarray as ux from uxarray.constants import INT_DTYPE, INT_FILL_VALUE +from uxarray.io._scrip import _detect_multigrid def test_read_ugrid(gridpath, mesh_constants): @@ -50,3 +52,69 @@ def test_to_xarray_ugrid(gridpath): reloaded_grid._ds.close() del reloaded_grid os.remove("scrip_ugrid_csne8.nc") + + +def test_oasis_multigrid_format_detection(): + """Detect OASIS-style multi-grid naming.""" + ds = xr.Dataset() + ds["ocn.cla"] = xr.DataArray(np.random.rand(100, 4), dims=["nc_ocn", "nv_ocn"]) + ds["ocn.clo"] = xr.DataArray(np.random.rand(100, 4), dims=["nc_ocn", "nv_ocn"]) + ds["atm.cla"] = xr.DataArray(np.random.rand(200, 4), dims=["nc_atm", "nv_atm"]) + ds["atm.clo"] = xr.DataArray(np.random.rand(200, 4), dims=["nc_atm", "nv_atm"]) + + format_type, grids = _detect_multigrid(ds) + assert format_type == "multi_scrip" + assert set(grids.keys()) == {"ocn", "atm"} + + +def test_open_multigrid_with_masks(gridpath): + """Load OASIS multi-grids with masks applied.""" + grid_file = gridpath("scrip", "oasis", "grids.nc") + mask_file = gridpath("scrip", "oasis", "masks.nc") + + grids = ux.open_multigrid(grid_file, mask_filename=mask_file) + assert grids["ocn"].n_face == 8 + assert grids["atm"].n_face == 20 + + ocean_only = ux.open_multigrid( + grid_file, gridnames=["ocn"], mask_filename=mask_file + ) + assert set(ocean_only.keys()) == {"ocn"} + assert ocean_only["ocn"].n_face == 8 + + grid_names = ux.list_grid_names(grid_file) + assert set(grid_names) == {"ocn", "atm"} + + +def test_open_multigrid_mask_active_value_default(gridpath): + """Default mask semantics keep value==1 active for both grids.""" + grid_file = gridpath("scrip", "oasis", "grids.nc") + mask_file = gridpath("scrip", "oasis", "masks_no_atm.nc") + + grids = ux.open_multigrid(grid_file, mask_filename=mask_file) + + with xr.open_dataset(mask_file) as mask_ds: + expected_ocn = int(mask_ds["ocn.msk"].values.sum()) + expected_atm = int(mask_ds["atm.msk"].values.sum()) + + assert grids["ocn"].n_face == expected_ocn + assert grids["atm"].n_face == expected_atm + + +def test_open_multigrid_mask_active_value_per_grid_override(gridpath): + """Per-grid override supports masks with different active values.""" + grid_file = gridpath("scrip", "oasis", "grids.nc") + mask_file = gridpath("scrip", "oasis", "masks_no_atm.nc") + + grids = ux.open_multigrid( + grid_file, + mask_filename=mask_file, + mask_active_value={"atm": 0, "ocn": 1}, + ) + + with xr.open_dataset(mask_file) as mask_ds: + expected_ocn = int(mask_ds["ocn.msk"].values.sum()) + expected_atm = int((mask_ds["atm.msk"].values == 0).sum()) + + assert grids["ocn"].n_face == expected_ocn + assert grids["atm"].n_face == expected_atm diff 
--git a/test/meshfiles/scrip/oasis/README.md b/test/meshfiles/scrip/oasis/README.md new file mode 100644 index 000000000..13a8a23d9 --- /dev/null +++ b/test/meshfiles/scrip/oasis/README.md @@ -0,0 +1,44 @@ +# OASIS Multi-Grid SCRIP Test Files + +This directory contains small test files for OASIS/YAC multi-grid SCRIP format support in UXarray. + +## Files + +- `grids.nc`: Multi-grid file containing two grids + - `ocn`: Ocean grid with 12 cells (3x4 regular grid) + - `atm`: Atmosphere grid with 20 cells (4x5 regular grid) + +- `masks.nc`: Domain masks for the grids + - `ocn.msk`: Ocean mask (8 ocean cells, 4 land cells) + - `atm.msk`: Atmosphere mask (all 20 cells active) + +## OASIS Format + +OASIS uses a specific naming convention for multi-grid SCRIP files: +- Grid variables are prefixed with grid name: `<grid_name>.<suffix>` +- Corner latitudes: `<grid_name>.cla` +- Corner longitudes: `<grid_name>.clo` +- Dimensions: `nc_<grid_name>` (cells), `nv_<grid_name>` (corners) + +## Usage in Tests + +```python +import uxarray as ux + +# List available grids +grid_names = ux.list_grid_names("grids.nc") +# ['ocn', 'atm'] + +# Load all grids +grids = ux.open_multigrid("grids.nc") + +# Load with masks +masked_grids = ux.open_multigrid("grids.nc", mask_filename="masks.nc") +# Ocean grid will have 8 cells, atmosphere grid will have 20 cells +``` + +## File Sizes + +These files are intentionally small for fast testing: +- `grids.nc`: ~3 KB +- `masks.nc`: ~1 KB diff --git a/test/meshfiles/scrip/oasis/create_oasis_test_files.py b/test/meshfiles/scrip/oasis/create_oasis_test_files.py new file mode 100644 index 000000000..2c6e9e210 --- /dev/null +++ b/test/meshfiles/scrip/oasis/create_oasis_test_files.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python +"""Create small OASIS test files for UXarray tests.""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import xarray as xr + +BASE_DIR = Path(__file__).parent + + +def create_oasis_test_files() -> None: + """Generate deterministic OASIS-style multi-grid SCRIP files.""" + output_dir = BASE_DIR + output_dir.mkdir(parents=True, exist_ok=True) + + grid_path = output_dir / "grids.nc" + mask_path = output_dir / "masks.nc" + + grid_ds = xr.Dataset() + + # Small ocean grid (3 latitude bands x 4 longitude bands = 12 cells) + n_ocean = 12 + grid_ds.coords["nc_ocn"] = np.arange(n_ocean) + grid_ds.coords["nv_ocn"] = np.arange(4) + + ocean_lons = np.array([0, 120, 240, 360]) + ocean_lats = np.array([-60, -20, 20, 60]) + + ocean_clo = np.zeros((n_ocean, 4)) + ocean_cla = np.zeros((n_ocean, 4)) + + idx = 0 + for j in range(3): # latitude bands + for i in range(4): # longitude bands + ocean_clo[idx, 0] = ocean_lons[i] + ocean_clo[idx, 1] = ocean_lons[(i + 1) % 4] + ocean_clo[idx, 2] = ocean_lons[(i + 1) % 4] + ocean_clo[idx, 3] = ocean_lons[i] + + ocean_cla[idx, 0] = ocean_lats[j] + ocean_cla[idx, 1] = ocean_lats[j] + ocean_cla[idx, 2] = ocean_lats[j + 1] + ocean_cla[idx, 3] = ocean_lats[j + 1] + idx += 1 + + grid_ds["ocn.clo"] = xr.DataArray( + ocean_clo, + dims=["nc_ocn", "nv_ocn"], + attrs={ + "units": "degrees_east", + "long_name": "ocean grid corner longitude", + }, + ) + grid_ds["ocn.cla"] = xr.DataArray( + ocean_cla, + dims=["nc_ocn", "nv_ocn"], + attrs={ + "units": "degrees_north", + "long_name": "ocean grid corner latitude", + }, + ) + + # Small atmosphere grid (4 latitude bands x 5 longitude bands = 20 cells) + n_atmos = 20 + grid_ds.coords["nc_atm"] = np.arange(n_atmos) + grid_ds.coords["nv_atm"] = np.arange(4) + + atm_lons = np.array([0, 90, 180, 270, 360]) + atm_lats = np.array([-90, -45,
0, 45, 90]) + + atm_clo = np.zeros((n_atmos, 4)) + atm_cla = np.zeros((n_atmos, 4)) + + idx = 0 + for j in range(4): + for i in range(5): + atm_clo[idx, 0] = atm_lons[i] + atm_clo[idx, 1] = atm_lons[(i + 1) % 5] + atm_clo[idx, 2] = atm_lons[(i + 1) % 5] + atm_clo[idx, 3] = atm_lons[i] + + atm_cla[idx, 0] = atm_lats[j] + atm_cla[idx, 1] = atm_lats[j] + atm_cla[idx, 2] = atm_lats[j + 1] + atm_cla[idx, 3] = atm_lats[j + 1] + idx += 1 + + grid_ds["atm.clo"] = xr.DataArray( + atm_clo, + dims=["nc_atm", "nv_atm"], + attrs={ + "units": "degrees_east", + "long_name": "atmosphere grid corner longitude", + }, + ) + grid_ds["atm.cla"] = xr.DataArray( + atm_cla, + dims=["nc_atm", "nv_atm"], + attrs={ + "units": "degrees_north", + "long_name": "atmosphere grid corner latitude", + }, + ) + + grid_ds.attrs["title"] = "OASIS multi-grid test file" + grid_ds.attrs["description"] = "Small test grids for UXarray OASIS support" + grid_ds.attrs["conventions"] = "SCRIP" + grid_ds.attrs["grid_type"] = "curvilinear" + + grid_ds.to_netcdf(grid_path, engine="scipy") + + mask_ds = xr.Dataset() + + ocean_mask = np.ones(n_ocean, dtype=np.int32) + ocean_mask[8:] = 0 + mask_ds["ocn.msk"] = xr.DataArray( + ocean_mask, + dims=["nc_ocn"], + attrs={ + "long_name": "ocean domain mask", + "valid_values": "0: land, 1: ocean", + }, + ) + + atmos_mask = np.ones(n_atmos, dtype=np.int32) + mask_ds["atm.msk"] = xr.DataArray( + atmos_mask, + dims=["nc_atm"], + attrs={ + "long_name": "atmosphere domain mask", + "valid_values": "0: inactive, 1: active", + }, + ) + + mask_ds.attrs["title"] = "OASIS mask file" + mask_ds.attrs["description"] = "Domain masks for ocean and atmosphere grids" + + mask_ds.to_netcdf(mask_path, engine="scipy") + + print(f"Created {grid_path} and {mask_path}") + + +if __name__ == "__main__": + create_oasis_test_files() diff --git a/test/meshfiles/scrip/oasis/grids.nc b/test/meshfiles/scrip/oasis/grids.nc new file mode 100644 index 000000000..abd27ee8b Binary files /dev/null and b/test/meshfiles/scrip/oasis/grids.nc differ diff --git a/test/meshfiles/scrip/oasis/masks.nc b/test/meshfiles/scrip/oasis/masks.nc new file mode 100644 index 000000000..199f1e4af Binary files /dev/null and b/test/meshfiles/scrip/oasis/masks.nc differ diff --git a/test/meshfiles/scrip/oasis/masks_no_atm.nc b/test/meshfiles/scrip/oasis/masks_no_atm.nc new file mode 100644 index 000000000..294a3ea45 Binary files /dev/null and b/test/meshfiles/scrip/oasis/masks_no_atm.nc differ diff --git a/uxarray/__init__.py b/uxarray/__init__.py index faa08aff3..80b0dcc9e 100644 --- a/uxarray/__init__.py +++ b/uxarray/__init__.py @@ -1,5 +1,12 @@ from .constants import INT_DTYPE, INT_FILL_VALUE -from .core.api import concat, open_dataset, open_grid, open_mfdataset +from .core.api import ( + concat, + list_grid_names, + open_dataset, + open_grid, + open_mfdataset, + open_multigrid, +) from .core.dataarray import UxDataArray from .core.dataset import UxDataset from .grid import Grid @@ -20,6 +27,8 @@ "open_grid", "open_dataset", "open_mfdataset", + "open_multigrid", + "list_grid_names", "concat", "UxDataset", "UxDataArray", diff --git a/uxarray/core/api.py b/uxarray/core/api.py index f357a267e..59def0a52 100644 --- a/uxarray/core/api.py +++ b/uxarray/core/api.py @@ -1,7 +1,17 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Mapping, + Optional, + Sequence, + TypeAlias, + Union, +) from warnings import warn import 
numpy as np @@ -13,10 +23,25 @@ match_chunks_to_ugrid, ) from uxarray.grid import Grid +from uxarray.io._scrip import ( + _detect_multigrid, + _extract_single_grid, + _read_scrip, + _resolve_cell_dims, + _stack_cell_dims, +) if TYPE_CHECKING: from xarray import Dataset +__all__ = [ + "open_grid", + "open_multigrid", + "list_grid_names", + "open_dataset", + "open_mfdataset", +] + def open_grid( grid_filename_or_obj: str | os.PathLike[Any] | dict | Dataset, @@ -123,14 +148,226 @@ def open_grid( return grid +MaskValue: TypeAlias = Any | Sequence[Any] +MaskActiveValue: TypeAlias = MaskValue | Mapping[str, MaskValue] | None + + +def open_multigrid( + grid_filename_or_obj: str | Path | "Dataset", + gridnames: list[str] | None = None, + mask_filename: str | Path | "Dataset" | None = None, + mask_active_value: MaskActiveValue = 1, + **kwargs: dict[str, Any], +) -> dict[str, Grid]: + """Open a multi-grid SCRIP file and construct ``Grid`` objects. + + Parameters + ---------- + grid_filename_or_obj : str, Path or xr.Dataset + Path to the multi-grid SCRIP file or an already opened dataset. + gridnames : list of str, optional + Specific grid names to load. If ``None``, all grids are loaded. + mask_filename : str, Path or xr.Dataset, optional + Optional path to a mask file containing ``.msk`` variables. + Defaults to retaining cells where mask value equals 1. + mask_active_value : scalar, sequence or mapping[str, scalar/sequence], optional + Mask value(s) treated as active. Provide a scalar or sequence to apply to + all grids, or a dict keyed by grid name for per-grid overrides. When a + mapping is provided and a grid name is not found, the fallback is the + mapping's ``"_default"``/``"default"`` entry if present, otherwise 1. + **kwargs : dict, optional + Extra keyword arguments forwarded to :func:`xarray.open_dataset` + when opening ``grid_filename_or_obj``. + + Returns + ------- + dict[str, Grid] + Dictionary mapping grid names to ``Grid`` objects. 
+ """ + import xarray as xr + + grid_ds_opened = False + if isinstance(grid_filename_or_obj, xr.Dataset): + grid_ds = grid_filename_or_obj + else: + grid_ds = xr.open_dataset(grid_filename_or_obj, **kwargs) + grid_ds_opened = True + + mask_ds = None + mask_ds_opened = False + if mask_filename is not None: + if isinstance(mask_filename, xr.Dataset): + mask_ds = mask_filename + else: + mask_ds = xr.open_dataset(mask_filename) + mask_ds_opened = True + + try: + active_value_map: Mapping[str, MaskValue] | None = ( + mask_active_value if isinstance(mask_active_value, Mapping) else None + ) + default_active_value: MaskValue = ( + active_value_map.get("_default", active_value_map.get("default", 1)) + if active_value_map is not None + else (mask_active_value if mask_active_value is not None else 1) + ) + if default_active_value is None: + default_active_value = 1 + + active_value_cache: dict[str, np.ndarray] = {} + + def _normalize_active_values(value: MaskValue | None) -> np.ndarray: + """Normalize mask active values to a 1D numpy array.""" + if value is None: + value = default_active_value + + if isinstance(value, (str, bytes)) or not np.iterable(value): + return np.asarray([value]) + + return np.asarray(list(value)).ravel() + + def _active_mask_values_for_grid(grid_name: str) -> np.ndarray: + """Return the active mask value(s) for a grid as a 1D array.""" + cached = active_value_cache.get(grid_name) + if cached is not None: + return cached + + if active_value_map is not None: + value = active_value_map.get(grid_name, default_active_value) + else: + value = ( + mask_active_value + if mask_active_value is not None + else default_active_value + ) + + active_values = _normalize_active_values(value) + active_value_cache[grid_name] = active_values + return active_values + + format_type, grids_dict = _detect_multigrid(grid_ds) + + if format_type == "single_scrip": + if gridnames is not None and "grid" not in gridnames: + raise ValueError( + f"Requested grids {gridnames} not found. " + "This file contains a single grid named 'grid'." + ) + grid_ds_ugrid, source_dims_dict = _read_scrip(grid_ds) + return { + "grid": Grid( + grid_ds_ugrid, + source_grid_spec="Scrip", + source_dims_dict=source_dims_dict, + ) + } + + if not grids_dict: + raise ValueError(f"No grids detected in file: {grid_filename_or_obj}") + + available_grids = list(grids_dict.keys()) + + if gridnames is None: + grids_to_load = available_grids + else: + if isinstance(gridnames, str): + requested = [gridnames] + else: + requested = list(gridnames) + + grids_to_load = [] + for name in requested: + if name not in grids_dict: + raise ValueError( + f"Grid '{name}' not found. 
Available grids: {available_grids}" + ) + grids_to_load.append(name) + + loaded_grids: dict[str, Grid] = {} + for grid_name in grids_to_load: + metadata = grids_dict[grid_name] + scrip_ds = _extract_single_grid(grid_ds, grid_name, metadata) + grid_ds_ugrid, source_dims_dict = _read_scrip(scrip_ds) + + grid = Grid( + grid_ds_ugrid, + source_grid_spec="Scrip", + source_dims_dict=source_dims_dict, + ) + + if mask_ds is not None: + mask_var = f"{grid_name}.msk" + if mask_var in mask_ds: + mask_da = mask_ds[mask_var] + mask_cell_dims = _resolve_cell_dims(metadata, mask_da.dims) + mask_flat = _stack_cell_dims(mask_da, mask_cell_dims, "grid_size") + mask_values = np.asarray(mask_flat.values) + active_values = _active_mask_values_for_grid(grid_name) + active_mask = np.isin(mask_values, active_values) + active_indices = np.flatnonzero(active_mask) + grid = grid.isel(n_face=active_indices) + else: + warn( + f"Mask variable '{mask_var}' not found in mask file; " + f"grid '{grid_name}' will be returned without masking." + ) + + loaded_grids[grid_name] = grid + + return loaded_grids + finally: + if grid_ds_opened: + grid_ds.close() + if mask_ds is not None and mask_ds_opened: + mask_ds.close() + + +def list_grid_names( + grid_filename_or_obj: str | Path | "Dataset", **kwargs: dict[str, Any] +) -> list[str]: + """List all grid names available within a grid file. + + Parameters + ---------- + grid_filename_or_obj : str, Path or xr.Dataset + Path to the grid file or an already opened dataset. + **kwargs : dict, optional + Additional keyword arguments forwarded to :func:`xarray.open_dataset`. + + Returns + ------- + list[str] + ``['grid']`` for single-grid files or the detected grid names for + multi-grid files. + """ + import xarray as xr + + grid_ds_opened = False + if isinstance(grid_filename_or_obj, xr.Dataset): + grid_ds = grid_filename_or_obj + else: + grid_ds = xr.open_dataset(grid_filename_or_obj, **kwargs) + grid_ds_opened = True + + try: + format_type, grids_dict = _detect_multigrid(grid_ds) + names = list(grids_dict.keys()) + if format_type == "single_scrip": + return names or ["grid"] + return names + finally: + if grid_ds_opened: + grid_ds.close() + + def open_dataset( grid_filename_or_obj: str | os.PathLike[Any] | dict | Dataset, filename_or_obj: str | os.PathLike[Any], chunks=None, chunk_grid: bool = True, - use_dual: Optional[bool] = False, - grid_kwargs: Optional[Dict[str, Any]] = None, - **kwargs: Dict[str, Any], + use_dual: bool | None = False, + grid_kwargs: dict[str, Any] | None = None, + **kwargs: dict[str, Any], ) -> UxDataset: """Wraps ``xarray.open_dataset()`` for loading in a dataset paired with a grid file. 
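Reviewer note: a minimal usage sketch of the two new public entry points added above, exercised against the OASIS fixtures introduced in this PR (paths shown relative to the repository root). The grid-name ordering in the comment is illustrative; the tests only compare sets.

```python
import uxarray as ux

grid_file = "test/meshfiles/scrip/oasis/grids.nc"  # fixture added in this PR
mask_file = "test/meshfiles/scrip/oasis/masks.nc"

# Discover which grids a (possibly multi-grid) SCRIP file contains.
ux.list_grid_names(grid_file)  # e.g. ['ocn', 'atm']; ['grid'] for a single-grid SCRIP file

# Load every grid, or only a subset, optionally dropping cells that are
# inactive in a companion mask file (default: keep cells where mask == 1).
grids = ux.open_multigrid(grid_file, mask_filename=mask_file)
ocn_only = ux.open_multigrid(grid_file, gridnames=["ocn"], mask_filename=mask_file)

# Per-grid mask semantics, with an optional "_default" fallback entry.
mixed = ux.open_multigrid(
    grid_file,
    mask_filename=mask_file,
    mask_active_value={"atm": 0, "_default": 1},
)
```

Internally, each requested grid is extracted into a single-grid SCRIP dataset, read through `_read_scrip`, and then subset with `Grid.isel(n_face=...)` when a mask is supplied.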
diff --git a/uxarray/grid/coordinates.py b/uxarray/grid/coordinates.py index a00e2a002..1a5dbd476 100644 --- a/uxarray/grid/coordinates.py +++ b/uxarray/grid/coordinates.py @@ -695,8 +695,10 @@ def _set_desired_longitude_range(uxgrid): with xr.set_options(keep_attrs=True): for lon_name in ["node_lon", "edge_lon", "face_lon"]: if lon_name in uxgrid._ds: - if uxgrid._ds[lon_name].max() > 180: - da = uxgrid._ds[lon_name] + da = uxgrid._ds[lon_name] + if da.size == 0: + continue + if da.max() > 180: wrapped = (uxgrid._ds[lon_name] + 180) % 360 - 180 wrapped.name = da.name uxgrid._ds[lon_name] = wrapped diff --git a/uxarray/io/_scrip.py b/uxarray/io/_scrip.py index 7db3c5377..d166f1b13 100644 --- a/uxarray/io/_scrip.py +++ b/uxarray/io/_scrip.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, List, Optional, Sequence, Tuple + import numpy as np import polars as pl import xarray as xr @@ -239,12 +241,12 @@ def grid_center_lat_lon(ds): Returns ------- - center_lon : :class:`numpy.ndarray` - The calculated center longitudes of the grid box based on the corner - points - center_lat : :class:`numpy.ndarray` + center_lat : array-like The calculated center latitudes of the grid box based on the corner - points + points. Preserves chunking when inputs are backed by Dask arrays. + center_lon : array-like + The calculated center longitudes of the grid box based on the corner + points. Preserves chunking when inputs are backed by Dask arrays. """ # Calculate and create grid center lat/lon @@ -266,7 +268,257 @@ def grid_center_lat_lon(ds): center_lon = np.rad2deg(np.arctan2(y, x)) center_lat = np.rad2deg(np.arctan2(z, np.sqrt(x**2 + y**2))) - # Make negative lons positive - center_lon[center_lon < 0] += 360 + # Normalize negative longitudes without forcing eager computation + center_lon = center_lon.where(center_lon >= 0, center_lon + 360) + + return center_lat.data, center_lon.data + + +def _detect_multigrid(ds: xr.Dataset) -> Tuple[str, Dict[str, Dict[str, Any]]]: + """Detect whether a dataset follows single-grid or multi-grid SCRIP format. + + Parameters + ---------- + ds : xr.Dataset + Dataset to inspect. + + Returns + ------- + tuple + A tuple of (format_type, grids_dict). ``format_type`` is either + ``\"single_scrip\"`` or ``\"multi_scrip\"``. ``grids_dict`` maps grid + names to their variable metadata when multi-grid files are detected. 
+ """ + + # Quick exit for canonical single-grid SCRIP files + if {"grid_corner_lat", "grid_corner_lon"}.issubset(set(ds.variables)): + return "single_scrip", {"grid": {}} + + # Collect candidate grids from dimension names + grids: Dict[str, Dict[str, Any]] = {} + for dim_name in ds.dims: + if dim_name.startswith("nc_"): + info = grids.setdefault(dim_name[3:], {}) + info.setdefault("cell_dims", []).append(dim_name) + info.setdefault("cell_dim", dim_name) + elif dim_name.startswith("nv_"): + grids.setdefault(dim_name[3:], {})["corner_dim"] = dim_name + + def _infer_corner_dim_from_dims(dims: Sequence[str]) -> Optional[str]: + for dim in dims: + dim_lower = dim.lower() + if dim_lower.startswith(("nv", "nvertex", "corner", "corn", "crn")): + return dim + return dims[-1] if dims else None + + def _update_grid_dim_metadata(info: Dict[str, Any], dims: Sequence[str]) -> None: + if not dims: + return + + corner_dim = info.get("corner_dim") + inferred_corner_dim = _infer_corner_dim_from_dims(dims) + if corner_dim is None or corner_dim not in dims: + corner_dim = inferred_corner_dim + if corner_dim is not None: + info["corner_dim"] = corner_dim + + cell_dims = [dim for dim in dims if dim != corner_dim] + if not cell_dims: + cell_dims = list(dims) + + info["cell_dims"] = list(cell_dims) + info["cell_dim"] = cell_dims[0] + + corner_lat_suffixes = {"cla", "corner_lat", "cornlat"} + corner_lon_suffixes = {"clo", "corner_lon", "cornlon"} + center_lat_suffixes = {"center_lat", "cenlat", "gclat", "clat"} + center_lon_suffixes = {"center_lon", "cenlon", "gclon", "clon"} + + # Parse OASIS-style . names + for var_name in ds.data_vars: + if "." not in var_name: + continue + grid_name, suffix = var_name.split(".", 1) + info = grids.setdefault(grid_name, {}) + suffix_lower = suffix.lower() + + if suffix_lower in corner_lat_suffixes: + info["corner_lat"] = var_name + dims = ds[var_name].dims + if len(dims) >= 2: + _update_grid_dim_metadata(info, dims) + elif suffix_lower in corner_lon_suffixes: + info["corner_lon"] = var_name + dims = ds[var_name].dims + if len(dims) >= 2: + _update_grid_dim_metadata(info, dims) + elif suffix_lower in center_lat_suffixes: + info["center_lat"] = var_name + elif suffix_lower in center_lon_suffixes: + info["center_lon"] = var_name + + # Keep only grids that have the required corner variables + parsed_grids = { + name: meta + for name, meta in grids.items() + if "corner_lat" in meta and "corner_lon" in meta + } + + if parsed_grids: + return "multi_scrip", parsed_grids + + return "single_scrip", {} + + +def _resolve_cell_dims( + metadata: Dict[str, Any], + data_dims: Sequence[str], + corner_dim: Optional[str] = None, +) -> List[str]: + """Determine which dimensions describe cells for a grid variable.""" + + dims_from_meta = metadata.get("cell_dims") + cell_dims: List[str] = [] + if isinstance(dims_from_meta, (list, tuple)): + cell_dims = [dim for dim in dims_from_meta if dim in data_dims] + elif "cell_dim" in metadata and metadata["cell_dim"] in data_dims: + cell_dims = [metadata["cell_dim"]] + if not cell_dims: + cell_dims = [dim for dim in data_dims] + + if corner_dim is not None: + cell_dims = [dim for dim in cell_dims if dim != corner_dim] + + if not cell_dims: + cell_dims = [dim for dim in data_dims if dim != corner_dim] + + if not cell_dims: + raise ValueError("Unable to determine cell dimensions for grid variable.") + + return cell_dims + + +def _stack_cell_dims( + data_array: xr.DataArray, cell_dims: Sequence[str], new_dim: str +) -> xr.DataArray: + """Stack one or more cell 
dimensions into a single new dimension.""" + + dims_in_array = [dim for dim in cell_dims if dim in data_array.dims] + if not dims_in_array: + if new_dim in data_array.dims: + return data_array + raise ValueError( + f"Unable to stack dimensions {cell_dims}; none are present in {data_array.dims}" + ) + + if len(dims_in_array) == 1: + dim = dims_in_array[0] + if dim == new_dim: + return data_array + return data_array.rename({dim: new_dim}) + + stacked = data_array.stack({new_dim: dims_in_array}) + # Remove MultiIndex so grid_size behaves like a standard dimension + stacked = stacked.reset_index(new_dim, drop=True) + # Ensure the new dimension is the leading axis for consistency + remaining_dims = [dim for dim in stacked.dims if dim != new_dim] + return stacked.transpose(new_dim, *remaining_dims) + + +def _extract_single_grid( + ds: xr.Dataset, grid_name: str, metadata: Dict[str, Any] +) -> xr.Dataset: + """Extract a single grid from a multi-grid SCRIP dataset. + + Parameters + ---------- + ds : xr.Dataset + The source multi-grid dataset. + grid_name : str + Name of the grid to extract. + metadata : dict + Mapping that describes the variable and dimension names for the grid. + + Returns + ------- + xr.Dataset + Dataset encoded in standard SCRIP single-grid format. + """ + + if "corner_lat" not in metadata or "corner_lon" not in metadata: + raise ValueError(f"Grid '{grid_name}' is missing corner variables.") + + corner_lat = ds[metadata["corner_lat"]] + corner_lon = ds[metadata["corner_lon"]] + + dims = list(corner_lat.dims) + if len(dims) < 2: + raise ValueError(f"Corner variable for grid '{grid_name}' must be at least 2D.") + + corner_dim = metadata.get("corner_dim", dims[-1]) + cell_dims = _resolve_cell_dims(metadata, dims, corner_dim) + + grid_corner_lat = _stack_cell_dims(corner_lat, cell_dims, "grid_size") + grid_corner_lon = _stack_cell_dims(corner_lon, cell_dims, "grid_size") + + if corner_dim != "grid_corners": + grid_corner_lat = grid_corner_lat.rename({corner_dim: "grid_corners"}) + grid_corner_lon = grid_corner_lon.rename({corner_dim: "grid_corners"}) + + grid_corner_lat = grid_corner_lat.copy() + grid_corner_lon = grid_corner_lon.copy() + + result = xr.Dataset() + result["grid_corner_lat"] = grid_corner_lat + result["grid_corner_lon"] = grid_corner_lon + + n_cells = grid_corner_lat.sizes["grid_size"] + + # Center coordinates: use supplied variables if available, otherwise compute + center_lat = metadata.get("center_lat") + center_lon = metadata.get("center_lon") + + computed_lat_lon = None + + if center_lat and center_lat in ds: + center_lat_da = _stack_cell_dims( + ds[center_lat], + _resolve_cell_dims(metadata, ds[center_lat].dims), + "grid_size", + ).copy() + else: + if computed_lat_lon is None: + computed_lat_lon = grid_center_lat_lon(result) + center_lat_da = xr.DataArray(computed_lat_lon[0], dims=["grid_size"]) + + if center_lon and center_lon in ds: + center_lon_da = _stack_cell_dims( + ds[center_lon], + _resolve_cell_dims(metadata, ds[center_lon].dims), + "grid_size", + ).copy() + else: + if computed_lat_lon is None: + computed_lat_lon = grid_center_lat_lon(result) + center_lon_da = xr.DataArray(computed_lat_lon[1], dims=["grid_size"]) + + result["grid_center_lat"] = center_lat_da + result["grid_center_lon"] = center_lon_da + + # Provide minimal auxiliary variables required by the reader + result["grid_imask"] = xr.DataArray( + np.ones(n_cells, dtype=np.int32), dims=["grid_size"] + ) + result["grid_dims"] = xr.DataArray( + np.array([n_cells], dtype=np.int32), 
dims=["grid_rank"] + ) + + # Provide a placeholder grid area to satisfy downstream checks + result["grid_area"] = xr.DataArray( + np.ones(n_cells, dtype=np.float64), dims=["grid_size"] + ) + + result.attrs.update(ds.attrs) + result.attrs["grid_name"] = grid_name - return center_lat, center_lon + return result