diff --git a/src/bitsea/commons/mask.py b/src/bitsea/commons/mask.py index 1da1fd03..b7c4b220 100644 --- a/src/bitsea/commons/mask.py +++ b/src/bitsea/commons/mask.py @@ -10,6 +10,7 @@ import numpy as np import xarray as xr from numpy.typing import ArrayLike +from scipy.sparse import dok_array from bitsea.commons.bathymetry import Bathymetry from bitsea.commons.geodistances import extend_from_average @@ -185,7 +186,7 @@ def convert_lon_lat_wetpoint_indices( # Remove all the points whose distance is greater than the max_radius cut_mask[distances > max_radius] = False - # If there is no water in the slice, we return (jp, ip) but we + # If there is no water in the slice, we return (jp, ip), but we # also raise a warning if not np.any(cut_mask): warn( @@ -200,7 +201,7 @@ def convert_lon_lat_wetpoint_indices( distances[~cut_mask] = max_radius * max_radius + 1 # We get the index of the minimum value. We need to unravel because - # argmin works on the flatten array + # argmin works on the flattened array local_min = np.unravel_index( np.argmin(distances, axis=None), distances.shape ) @@ -236,7 +237,7 @@ def mask_at_level(self, z: float) -> np.ndarray: def bathymetry_in_cells(self) -> np.ndarray: """ - Returns a 2d map that for each columns associates the number of water + Returns a 2d map that associates for each column the number of water cells that are present on that column. Returns: @@ -247,7 +248,7 @@ def bathymetry_in_cells(self) -> np.ndarray: def rough_bathymetry(self) -> np.ndarray: """ Calculates the bathymetry used by the model - It does not takes in account e3t + It does not take into account e3t Returns: np.ndarray: a 2d numpy array of floats @@ -259,7 +260,7 @@ def rough_bathymetry(self) -> np.ndarray: def bathymetry(self) -> np.ndarray: """ Calculates the bathymetry used by the model - Best evaluation, since it takes in account e3t. + Best evaluation, since it takes into account e3t. Returns: a 2d numpy array of floats @@ -670,7 +671,7 @@ def save_as_netcdf(self, file_path: Union[PathLike, str]): class MaskBathymetry(Bathymetry): """ - This class is a bathymetry, generated starting from a mask, i.e., it + This class is a bathymetry generated starting from a mask, i.e., it returns the z-coordinate of the bottom face of the deepest cell of the column that contains the point (lon, lat). """ @@ -680,7 +681,7 @@ def __init__(self, mask: Mask): self._bathymetry_data = mask.bathymetry() # Fix the bathymetry of the land cells to 0 (to be coherent with the - # behaviour of the bathymetry classes). Otherwise, if we let the land + # behavior of the bathymetry classes). Otherwise, if we let the land # points to have bathymetry = 1e20, they will be in every # BathymetricBasin self._bathymetry_data[np.logical_not(self._mask[0, :, :])] = 0 @@ -712,3 +713,140 @@ def __call__(self, lon, lat): return output return float(output.squeeze().item()) + + +class MaskWithRivers(Mask): + def __init__( + self, + grid: Grid, + zlevels: ArrayLike, + mask_array: ArrayLike, + river_positions: ArrayLike, + allow_broadcast: bool = False, + e3t: Optional[np.ndarray] = None, + ): + rivers = np.asarray(river_positions) + + mask_dim = len(zlevels), grid.shape[0], grid.shape[1] + + try: + mask_data = np.asarray(mask_array, dtype=bool, copy=False) + input_copied = False + except TypeError: + # Old versions of Numpy do not support the "copy=False" syntax; + # In this case, we call the function without specifying the value + # of the "copy" argument, and we must assume that the data has not + # been copied (for safety) + mask_data = np.asarray(mask_array, dtype=bool) + input_copied = False + except ValueError: + # In this case, it has been impossible to avoid the copy. We run + # again the same command, but this time we force the copy + mask_data = np.asarray(mask_array, dtype=bool) + input_copied = True + + # If mask_data is too small to cover all the mask, we try to broadcast + # it. This enforces a copy + if allow_broadcast and mask_data.shape != mask_dim: + mask_data = np.copy(np.broadcast_to(mask_array, mask_dim)) + input_copied = True + + # Ensure that mask_array is writeable + if not input_copied: + mask_data = np.copy(mask_data) + + # Here we save the cells that belong to each river into a list of + # sparse arrays. We must use a list of 2D array because there are no + # 3D sparse arrays implemented inside scipy. + self._river_cells = [] + rivers = dok_array(rivers) + for depth_index in range(mask_data.shape[0]): + found_cells = False + current_river_cells = dok_array(mask_data.shape[1:], dtype=int) + for (i, j), value in rivers.items(): + if not mask_data[depth_index, i, j]: + continue + current_river_cells[i, j] = value + found_cells = True + + if not found_cells: + break + self._river_cells.append(current_river_cells) + + # Remove rivers from mask_array + for i, j in zip(*rivers.nonzero()): + mask_data[:, i, j] = False + + super().__init__( + grid=grid, + zlevels=zlevels, + mask_array=mask_data, + allow_broadcast=allow_broadcast, + e3t=e3t, + ) + + def get_water_cells(self) -> np.ndarray: + output_data = np.copy(self[:]) + for depth_index, river_cells in enumerate(self._river_cells): + lat_indices, lon_indices = river_cells.nonzero() + output_data[depth_index, lat_indices, lon_indices] = True + return output_data + + @classmethod + def from_file_pointer( + cls, + file_pointer: netCDF4.Dataset, + *, + zlevels_var_name: str = "nav_lev", + ylevels_var_name: str = "nav_lat", + xlevels_var_name: str = "nav_lon", + e3t_var_name: Optional[str] = None, + mask_var_name: str = "tmask", + rivers_var_name: str = "rivers", + read_e3t: bool = True, + ): + raw_mask = Mask.from_file_pointer( + file_pointer, + zlevels_var_name=zlevels_var_name, + ylevels_var_name=ylevels_var_name, + xlevels_var_name=xlevels_var_name, + e3t_var_name=e3t_var_name, + mask_var_name=mask_var_name, + read_e3t=read_e3t, + ) + rivers = np.asarray( + file_pointer.variables[rivers_var_name][:], dtype=int + ) + return MaskWithRivers( + grid=raw_mask.grid, + zlevels=raw_mask.zlevels, + mask_array=raw_mask.as_mutable_array(), + river_positions=rivers, + allow_broadcast=False, + e3t=raw_mask.e3t, + ) + + @classmethod + def from_file( + cls, + file_path: PathLike, + *, + zlevels_var_name: str = "nav_lev", + ylevels_var_name: str = "nav_lat", + xlevels_var_name: str = "nav_lon", + e3t_var_name: Optional[str] = None, + mask_var_name: str = "tmask", + rivers_var_name: str = "rivers", + read_e3t: bool = True, + ): + with netCDF4.Dataset(file_path, "r") as f: + return cls.from_file_pointer( + f, + zlevels_var_name=zlevels_var_name, + ylevels_var_name=ylevels_var_name, + xlevels_var_name=xlevels_var_name, + e3t_var_name=e3t_var_name, + mask_var_name=mask_var_name, + rivers_var_name=rivers_var_name, + read_e3t=read_e3t, + ) diff --git a/tests/commons/test_mask.py b/tests/commons/test_mask.py index eca50d59..fb6f1d56 100644 --- a/tests/commons/test_mask.py +++ b/tests/commons/test_mask.py @@ -1,5 +1,6 @@ from itertools import product as cart_prod +import netCDF4 import numpy as np import pytest @@ -7,6 +8,7 @@ from bitsea.commons.mask import FILL_VALUE from bitsea.commons.mask import Mask from bitsea.commons.mask import MaskBathymetry +from bitsea.commons.mask import MaskWithRivers from bitsea.commons.mesh import Mesh @@ -409,3 +411,28 @@ def test_mask_to_xarray(mask): assert np.allclose(xarray_mask.longitude, mask.xlevels) assert np.allclose(xarray_mask.depth, mask.zlevels) assert np.allclose(xarray_mask.tmask, mask) + + +def test_mask_with_rivers_from_file(test_data_dir): + mask_dir = test_data_dir / "masks" + mask_file = mask_dir / "mask_with_rivers.nc" + + rivers_mask = MaskWithRivers.from_file(mask_file) + water_cells = rivers_mask.get_water_cells() + + with netCDF4.Dataset(mask_file, "r") as ds: + original_mask = np.asarray(ds.variables["tmask"]) > 0.5 + rivers = np.asarray(ds.variables["rivers"]) + + assert np.all(original_mask == water_cells) + + # Check that in the current mask every cell assigned to a river is set + # to "False" + for i, j in zip(*np.where(rivers)): + assert not np.any(rivers_mask[:, i, j]) + + # Assert that if a value inside water_cells is True, then it is also True + # inside rivers_mask (rivers_mask is contained inside water_cells) + assert np.all( + np.logical_or(np.logical_and(rivers_mask, water_cells), ~rivers_mask) + )