diff --git a/docs/src/userguide/examples.rst b/docs/src/userguide/examples.rst index e66a5a14..9eb15980 100644 --- a/docs/src/userguide/examples.rst +++ b/docs/src/userguide/examples.rst @@ -54,6 +54,94 @@ certain regridders. We can do this as follows:: # Use loaded regridder. result = loaded_regridder(source_mesh_cube) +Partitioning a Regridder +------------------------ + +If a regridder would be too large to handle in memory, it can be broken down +into smaller regridders which can collectively do the job of the larger regridder. +This is done using a `Partition` object. + +.. note:: Currently, it is only possible to partition regridding when the source is + a large grid and the target is small enough to fit in memory. + +A `Partition` is made by specifying a source, a target, a list of files, and a way +to divide the source grid into blocks whose regridders are saved to those files:: + + from iris.util import make_gridcube + + from esmf_regrid import ESMFAreaWeighted + from esmf_regrid.experimental.partition import Partition + + # Create a large source cube. + source_cube = make_gridcube(nx=800, ny=800) + + # Create a small target cube. + target_cube = make_gridcube(nx=100, ny=100) + + # Set the regridding scheme. + scheme = AreaWeighted() + + # List a collection of file names/paths to save partial regridders to. + files = ["file_1", "file_2", "file_3", "file_4"] + + # Set the size of each block of the partition. For the keyword `src_chunks` + # this follows the dask chunking API. + src_chunks = (400, 400) + + # Initialise the partition. + partition = Partition( + source_cube, + target_cube, + scheme, + files, + src_chunks=src_chunks + ) + +Initialising the `Partition` will not generate the files automatically unless +the `auto_generate` keyword is set to `True`. In order for this `Partition` to +function, the regridder files must be generated by calling the `generate_files` +method. + +.. note:: Not all files need to be generated at once, if you have a grid which + needs to be split into very many files, it is possible to generate only + a portion of those files within a given session by passing the number + of files to generate as an argument to the regridder. It is possible to + generate the remaining files in a different python session. +:: + + # Generate partial regridders and save them to the list of files. + partition.generate_files() + + # Once the files have been generated, they can be used for regridding. + result = partition.apply_regridders(source_cube) + +Once the files for a regridder have been generated, they can be used to reconstruct +the partition object in a later session. This is done by passing in the list of +files which have already been generated:: + + # Use the same arguments which constructed the original partition. + source_cube = make_gridcube(nx=800, ny=800) + target_cube = make_gridcube(nx=100, ny=100) + scheme = AreaWeighted() + files = ["file_1", "file_2", "file_3", "file_4"] + src_chunks = (400, 400) + + # List the files which have already been generated. + saved_files = ["file_1", "file_2", "file_3", "file_4"] + + # Reconstruct Partition from pre-generated files. + partition = Partition( + source_cube, + target_cube, + scheme, + files, + src_chunks=src_chunks + saved_files=saved_files # Pass in the list of saved files. + ) + + # The new Partition can now be used without the need for generating files. + result = partition.apply_regridders(source_cube) + .. todo: Add more examples. diff --git a/pyproject.toml b/pyproject.toml index e2cd86ef..e7b22a62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -181,6 +181,7 @@ ignore = [ "ANN202", "ANN204", + "B905", # Zip strictness should be explicit "D104", # Misssing docstring "E501", # Line too long "ERA001", # Commented out code @@ -298,5 +299,5 @@ convention = "numpy" [tool.ruff.lint.pylint] # TODO: refactor to reduce complexity, if possible max-args = 10 -max-branches = 21 +max-branches = 22 max-statements = 110 diff --git a/src/esmf_regrid/esmf_regridder.py b/src/esmf_regrid/esmf_regridder.py index 800fdbd3..c335320c 100644 --- a/src/esmf_regrid/esmf_regridder.py +++ b/src/esmf_regrid/esmf_regridder.py @@ -175,6 +175,46 @@ def _out_dtype(self, in_dtype): ).dtype return out_dtype + def _gen_weights_and_data(self, src_array): + extra_shape = src_array.shape[: -self.src.dims] + + if self.method == Constants.Method.NEAREST: + weight_matrix = self.weight_matrix.astype(src_array.dtype) + else: + weight_matrix = self.weight_matrix + + flat_src = self.src._array_to_matrix(ma.filled(src_array, 0.0)) + flat_tgt = weight_matrix @ flat_src + + src_inverted_mask = self.src._array_to_matrix(~ma.getmaskarray(src_array)) + weight_sums = weight_matrix @ src_inverted_mask + return weight_sums, flat_tgt, extra_shape + + def _regrid_from_weights_and_data( + self, + tgt_weights, + tgt_data, + extra, + norm_type=Constants.NormType.FRACAREA, + mdtol=1, + ): + # Set the minimum mdtol to be slightly higher than 0 to account for rounding + # errors. + mdtol = max(mdtol, 1e-8) + tgt_mask = tgt_weights > 1 - mdtol + normalisations = np.ones_like(tgt_data) + if self.method != Constants.Method.NEAREST: + masked_weight_sums = tgt_weights * tgt_mask + if norm_type == Constants.NormType.FRACAREA: + normalisations[tgt_mask] /= masked_weight_sums[tgt_mask] + elif norm_type == Constants.NormType.DSTAREA: + pass + normalisations = ma.array(normalisations, mask=np.logical_not(tgt_mask)) + + tgt_array = tgt_data * normalisations + tgt_array = self.tgt._matrix_to_array(tgt_array, extra) + return tgt_array + def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1): """Perform regridding on an array of data. @@ -212,30 +252,8 @@ def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1): f"got an array with shape ending in {main_shape}." ) raise ValueError(e_msg) - extra_shape = array_shape[: -self.src.dims] - extra_size = max(1, np.prod(extra_shape)) - src_inverted_mask = self.src._array_to_matrix(~ma.getmaskarray(src_array)) - weight_matrix = self.weight_matrix - if self.method == Constants.Method.NEAREST: - # force out_dtype := in_dtype - weight_matrix = weight_matrix.astype(src_array.dtype) - weight_sums = weight_matrix @ src_inverted_mask - out_dtype = self._out_dtype(src_array.dtype) - # Set the minimum mdtol to be slightly higher than 0 to account for rounding - # errors. - mdtol = max(mdtol, 1e-8) - tgt_mask = weight_sums > 1 - mdtol - normalisations = np.ones([self.tgt.size, extra_size], dtype=out_dtype) - if self.method != Constants.Method.NEAREST: - masked_weight_sums = weight_sums * tgt_mask - if norm_type == Constants.NormType.FRACAREA: - normalisations[tgt_mask] /= masked_weight_sums[tgt_mask] - elif norm_type == Constants.NormType.DSTAREA: - pass - normalisations = ma.array(normalisations, mask=np.logical_not(tgt_mask)) - - flat_src = self.src._array_to_matrix(ma.filled(src_array, 0.0)) - flat_tgt = weight_matrix @ flat_src - flat_tgt = flat_tgt * normalisations - tgt_array = self.tgt._matrix_to_array(flat_tgt, extra_shape) + tgt_weights, tgt_data, extra = self._gen_weights_and_data(src_array) + tgt_array = self._regrid_from_weights_and_data( + tgt_weights, tgt_data, extra, norm_type=norm_type, mdtol=mdtol + ) return tgt_array diff --git a/src/esmf_regrid/experimental/_partial.py b/src/esmf_regrid/experimental/_partial.py new file mode 100644 index 00000000..11e5b90c --- /dev/null +++ b/src/esmf_regrid/experimental/_partial.py @@ -0,0 +1,88 @@ +"""Provides a regridder class compatible with Partition.""" + +import numpy as np + +from esmf_regrid.schemes import ( + _create_cube, + _ESMFRegridder, +) + + +class PartialRegridder(_ESMFRegridder): + """Regridder class designed for use in :class:`~esmf_regrid.experimental._partial.Partial`.""" + + def __init__(self, src, tgt, src_slice, tgt_slice, weights, scheme, **kwargs): + """Create a regridder instance for a block of :class:`~esmf_regrid.experimental._partial.Partial`. + + Parameters + ---------- + src : :class:`iris.cube.Cube` + The :class:`~iris.cube.Cube` providing the source. + tgt : :class:`iris.cube.Cube` or :class:`iris.mesh.MeshXY` + The :class:`~iris.cube.Cube` or :class:`~iris.mesh.MeshXY` providing the target. + src_slice : tuple + The upper and lower bounds of the block taken from the original source from which the + ``src`` was derived. + tgt_slice : tuple + The upper and lower bounds of the block taken from the original target from which the + ``tgt`` was derived. + weights : :class:`scipy.sparse.spmatrix` + The weights to use for regridding. + scheme : :class:`~esmf_regrid.schemes.ESMFAreaWeighted` or :class:`~esmf_regrid.schemes.ESMFBilinear` + The scheme used to construct the regridder. + """ + self.src_slice = src_slice # this will be tuple-like + self.tgt_slice = tgt_slice + self.scheme = scheme + + # Pop duplicate kwargs. + for arg in set(kwargs.keys()).intersection(vars(self.scheme)): + kwargs.pop(arg) + + self._regridder = scheme.regridder( + src, + tgt, + precomputed_weights=weights, + **kwargs, + ) + self.__dict__.update(self._regridder.__dict__) + + def __repr__(self): + """Return a representation of the class.""" + result = ( + f"PartialRegridder(" + f"src_slice={self.src_slice}, " + f"tgt_slice={self.tgt_slice}, " + f"scheme={self.scheme})" + ) + return result + + def partial_regrid(self, src): + """Perform the first half of regridding, generating weights and data.""" + dims = self._get_cube_dims(src) + num_dims = len(dims) + standard_in_dims = [-1, -2][:num_dims] + data = np.moveaxis(src.data, dims, standard_in_dims) + result = self.regridder._gen_weights_and_data(data) + return result + + def finish_regridding(self, src_cube, weights, data, extra): + """Perform the second half of regridding, combining weights and data.""" + dims = self._get_cube_dims(src_cube) + + result_data = self.regridder._regrid_from_weights_and_data(weights, data, extra) + + num_out_dims = self.regridder.tgt.dims + num_dims = len(dims) + standard_out_dims = [-1, -2][:num_out_dims] + if num_dims == 2 and num_out_dims == 1: + dims = [min(dims)] + if num_dims == 1 and num_out_dims == 2: + dims = [dims[0] + 1, dims[0]] + + result_data = np.moveaxis(result_data, standard_out_dims, dims) + + result_cube = _create_cube( + result_data, src_cube, dims, self._tgt, len(self._tgt) + ) + return result_cube diff --git a/src/esmf_regrid/experimental/io.py b/src/esmf_regrid/experimental/io.py index e71d7365..56c5958b 100644 --- a/src/esmf_regrid/experimental/io.py +++ b/src/esmf_regrid/experimental/io.py @@ -10,13 +10,17 @@ import esmf_regrid from esmf_regrid import Constants, _load_context, check_method, esmpy +from esmf_regrid.experimental._partial import PartialRegridder from esmf_regrid.experimental.unstructured_scheme import ( GridToMeshESMFRegridder, MeshToGridESMFRegridder, ) from esmf_regrid.schemes import ( + ESMFAreaWeighted, ESMFAreaWeightedRegridder, + ESMFBilinear, ESMFBilinearRegridder, + ESMFNearest, ESMFNearestRegridder, GridRecord, MeshRecord, @@ -28,6 +32,7 @@ ESMFNearestRegridder, GridToMeshESMFRegridder, MeshToGridESMFRegridder, + PartialRegridder, ] _REGRIDDER_NAME_MAP = {rg_class.__name__: rg_class for rg_class in SUPPORTED_REGRIDDERS} _SOURCE_NAME = "regridder_source_field" @@ -47,6 +52,8 @@ _SOURCE_RESOLUTION = "src_resolution" _TARGET_RESOLUTION = "tgt_resolution" _ESMF_ARGS = "esmf_args" +_SRC_SLICE_NAME = "src_slice" +_TGT_SLICE_NAME = "tgt_slice" _VALID_ESMF_KWARGS = [ "pole_method", "regrid_pole_npoints", @@ -118,54 +125,41 @@ def _clean_var_names(cube): con.var_name = None -def save_regridder(rg, filename): - """Save a regridder scheme instance. +def _standard_grid_cube(grid, name): + if grid[0].ndim == 1: + shape = [coord.points.size for coord in grid] + else: + shape = grid[0].shape + data = np.zeros(shape) + cube = Cube(data, var_name=name, long_name=name) + if grid[0].ndim == 1: + cube.add_dim_coord(grid[0], 0) + cube.add_dim_coord(grid[1], 1) + else: + cube.add_aux_coord(grid[0], [0, 1]) + cube.add_aux_coord(grid[1], [0, 1]) + return cube - Saves any of the regridder classes, i.e. - :class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder`, - :class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`, - :class:`~esmf_regrid.schemes.ESMFAreaWeightedRegridder`, - :class:`~esmf_regrid.schemes.ESMFBilinearRegridder` or - :class:`~esmf_regrid.schemes.ESMFNearestRegridder`. - . - Parameters - ---------- - rg : :class:`~esmf_regrid.schemes._ESMFRegridder` - The regridder instance to save. - filename : str - The file name to save to. - """ - regridder_type = rg.__class__.__name__ +def _standard_mesh_cube(mesh, location, name): + mesh_coords = mesh.to_MeshCoords(location) + data = np.zeros(mesh_coords[0].points.shape[0]) + cube = Cube(data, var_name=name, long_name=name) + for coord in mesh_coords: + cube.add_aux_coord(coord, 0) + return cube - def _standard_grid_cube(grid, name): - if grid[0].ndim == 1: - shape = [coord.points.size for coord in grid] - else: - shape = grid[0].shape - data = np.zeros(shape) - cube = Cube(data, var_name=name, long_name=name) - if grid[0].ndim == 1: - cube.add_dim_coord(grid[0], 0) - cube.add_dim_coord(grid[1], 1) - else: - cube.add_aux_coord(grid[0], [0, 1]) - cube.add_aux_coord(grid[1], [0, 1]) - return cube - - def _standard_mesh_cube(mesh, location, name): - mesh_coords = mesh.to_MeshCoords(location) - data = np.zeros(mesh_coords[0].points.shape[0]) - cube = Cube(data, var_name=name, long_name=name) - for coord in mesh_coords: - cube.add_aux_coord(coord, 0) - return cube +def _generate_src_tgt(regridder_type, rg, allow_partial): if regridder_type in [ "ESMFAreaWeightedRegridder", "ESMFBilinearRegridder", "ESMFNearestRegridder", + "PartialRegridder", ]: + if regridder_type == "PartialRegridder" and not allow_partial: + e_msg = "To save a PartialRegridder, `allow_partial=True` must be set." + raise ValueError(e_msg) src_grid = rg._src if isinstance(src_grid, GridRecord): src_cube = _standard_grid_cube( @@ -210,12 +204,36 @@ def _standard_mesh_cube(mesh, location, name): tgt_grid = (rg.grid_y, rg.grid_x) tgt_cube = _standard_grid_cube(tgt_grid, _TARGET_NAME) _add_mask_to_cube(rg.tgt_mask, tgt_cube, _TARGET_MASK_NAME) + else: - e_msg = ( - f"Expected a regridder of type `GridToMeshESMFRegridder` or " - f"`MeshToGridESMFRegridder`, got type {regridder_type}." - ) + e_msg = f"Unexpected regridder type {regridder_type}." raise TypeError(e_msg) + return src_cube, tgt_cube + + +def save_regridder(rg, filename, allow_partial=False): + """Save a regridder scheme instance. + + Saves any of the regridder classes, i.e. + :class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder`, + :class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`, + :class:`~esmf_regrid.schemes.ESMFAreaWeightedRegridder`, + :class:`~esmf_regrid.schemes.ESMFBilinearRegridder` or + :class:`~esmf_regrid.schemes.ESMFNearestRegridder`. + . + + Parameters + ---------- + rg : :class:`~esmf_regrid.schemes._ESMFRegridder` + The regridder instance to save. + filename : str + The file name to save to. + allow_partial : bool, default=False + If True, allow the saving of :class:`~esmf_regrid.experimental._partial.PartialRegridder` instances. + """ + regridder_type = rg.__class__.__name__ + + src_cube, tgt_cube = _generate_src_tgt(regridder_type, rg, allow_partial) method = str(check_method(rg.method).name) @@ -223,7 +241,7 @@ def _standard_mesh_cube(mesh, location, name): resolution = rg.resolution src_resolution = None tgt_resolution = None - elif regridder_type == "ESMFAreaWeightedRegridder": + elif method == "CONSERVATIVE": resolution = None src_resolution = rg.src_resolution tgt_resolution = rg.tgt_resolution @@ -264,6 +282,22 @@ def _standard_mesh_cube(mesh, location, name): if tgt_resolution is not None: attributes[_TARGET_RESOLUTION] = tgt_resolution + extra_cubes = [] + if regridder_type == "PartialRegridder": + src_slice = rg.src_slice # this slice is described by a tuple + if src_slice is None: + src_slice = [] + src_slice_cube = Cube( + src_slice, long_name=_SRC_SLICE_NAME, var_name=_SRC_SLICE_NAME + ) + tgt_slice = rg.tgt_slice # this slice is described by a tuple + if tgt_slice is None: + tgt_slice = [] + tgt_slice_cube = Cube( + src_slice, long_name=_TGT_SLICE_NAME, var_name=_TGT_SLICE_NAME + ) + extra_cubes = [src_slice_cube, tgt_slice_cube] + weights_cube = Cube(weight_data, var_name=_WEIGHTS_NAME, long_name=_WEIGHTS_NAME) row_coord = AuxCoord( weight_rows, var_name=_WEIGHTS_ROW_NAME, long_name=_WEIGHTS_ROW_NAME @@ -298,7 +332,9 @@ def _standard_mesh_cube(mesh, location, name): # Save cubes while ensuring var_names do not conflict for the sake of consistency. with _managed_var_name(src_cube, tgt_cube): - cube_list = CubeList([src_cube, tgt_cube, weights_cube, weight_shape_cube]) + cube_list = CubeList( + [src_cube, tgt_cube, weights_cube, weight_shape_cube, *extra_cubes] + ) for cube in cube_list: cube.attributes = attributes @@ -306,7 +342,7 @@ def _standard_mesh_cube(mesh, location, name): iris.fileformats.netcdf.save(cube_list, filename) -def load_regridder(filename): +def load_regridder(filename, allow_partial=False): """Load a regridder scheme instance. Loads any of the regridder classes, i.e. @@ -320,6 +356,8 @@ def load_regridder(filename): ---------- filename : str The file name to load from. + allow_partial : bool, default=False + If True, allow the loading of :class:`~esmf_regrid.experimental._partial.PartialRegridder` instances. Returns ------- @@ -343,6 +381,12 @@ def load_regridder(filename): raise TypeError(e_msg) scheme = _REGRIDDER_NAME_MAP[regridder_type] + if regridder_type == "PartialRegridder" and not allow_partial: + e_msg = ( + "PartialRegridder cannot be loaded without setting `allow_partial=True`." + ) + raise ValueError(e_msg) + # Determine the regridding method, allowing for files created when # conservative regridding was the only method. method_string = weights_cube.attributes.get(_METHOD, "CONSERVATIVE") @@ -396,26 +440,48 @@ def load_regridder(filename): elif scheme is MeshToGridESMFRegridder: resolution_keyword = _TARGET_RESOLUTION kwargs = {resolution_keyword: resolution, "method": method, "mdtol": mdtol} - elif scheme is ESMFAreaWeightedRegridder: + elif method is Constants.Method.CONSERVATIVE: kwargs = { _SOURCE_RESOLUTION: src_resolution, _TARGET_RESOLUTION: tgt_resolution, "mdtol": mdtol, } - elif scheme is ESMFBilinearRegridder: + elif method is Constants.Method.BILINEAR: kwargs = {"mdtol": mdtol} else: kwargs = {} - regridder = scheme( - src_cube, - tgt_cube, - precomputed_weights=weight_matrix, - use_src_mask=use_src_mask, - use_tgt_mask=use_tgt_mask, - esmf_args=esmf_args, - **kwargs, - ) + if scheme is PartialRegridder: + src_slice = cubes.extract_cube(_SRC_SLICE_NAME).data.tolist() + if src_slice == []: + src_slice = None + tgt_slice = cubes.extract_cube(_TGT_SLICE_NAME).data.tolist() + if tgt_slice == []: + tgt_slice = None + sub_scheme = { + Constants.Method.CONSERVATIVE: ESMFAreaWeighted, + Constants.Method.BILINEAR: ESMFBilinear, + Constants.Method.NEAREST: ESMFNearest, + }[method] + regridder = scheme( + src_cube, + tgt_cube, + src_slice, + tgt_slice, + weight_matrix, + sub_scheme(), + **kwargs, + ) + else: + regridder = scheme( + src_cube, + tgt_cube, + precomputed_weights=weight_matrix, + use_src_mask=use_src_mask, + use_tgt_mask=use_tgt_mask, + esmf_args=esmf_args, + **kwargs, + ) esmf_version = weights_cube.attributes[_VERSION_ESMF] regridder.regridder.esmf_version = esmf_version diff --git a/src/esmf_regrid/experimental/partition.py b/src/esmf_regrid/experimental/partition.py new file mode 100644 index 00000000..1c73a28e --- /dev/null +++ b/src/esmf_regrid/experimental/partition.py @@ -0,0 +1,260 @@ +"""Provides an interface for splitting up a large regridding task.""" + +import numpy as np + +from esmf_regrid.constants import Constants +from esmf_regrid.experimental._partial import PartialRegridder +from esmf_regrid.experimental.io import load_regridder, save_regridder +from esmf_regrid.schemes import _get_grid_dims + + +def _get_chunk(cube, sl): + if cube.mesh is None: + grid_dims = _get_grid_dims(cube) + else: + grid_dims = (cube.mesh_dim(),) + full_slice = [np.s_[:]] * len(cube.shape) + for s, d in zip(sl, grid_dims): + full_slice[d] = np.s_[s[0] : s[1]] + return cube[*full_slice] + + +def _determine_blocks(shape, chunks, num_chunks, explicit_chunks): + which_inputs = ( + chunks is not None, + num_chunks is not None, + explicit_chunks is not None, + ) + if sum(which_inputs) == 0: + msg = "Partition blocks must must be specified by either chunks, num_chunks, or explicit_chunks." + raise ValueError(msg) + if sum(which_inputs) > 1: + msg = "Potentially conflicting partition block definitions." + raise ValueError(msg) + if num_chunks is not None: + chunks = [s // n for s, n in zip(shape, num_chunks)] + for chunk in chunks: + if chunk == 0: + msg = "`num_chunks` cannot divide a dimension into more blocks than the size of that dimension." + raise ValueError(msg) + if chunks is not None: + if all(isinstance(x, int) for x in chunks): + proper_chunks = [] + for s, c in zip(shape, chunks): + proper_chunk = [c] * (s // c) + if s % c != 0: + proper_chunk += [s % c] + proper_chunks.append(proper_chunk) + chunks = proper_chunks + for s, chunk in zip(shape, chunks): + if sum(chunk) != s: + msg = "Chunks must sum to the size of their respective dimension." + raise ValueError(msg) + bounds = [np.cumsum([0, *chunk]) for chunk in chunks] + if len(bounds) == 1: + explicit_chunks = [ + [[int(lower), int(upper)]] + for lower, upper in zip(bounds[0][:-1], bounds[0][1:]) + ] + elif len(bounds) == 2: + explicit_chunks = [ + [[int(ly), int(uy)], [int(lx), int(ux)]] + for ly, uy in zip(bounds[0][:-1], bounds[0][1:]) + for lx, ux in zip(bounds[1][:-1], bounds[1][1:]) + ] + else: + msg = "Chunks must not exceed two dimensions." + raise ValueError(msg) + return explicit_chunks + + +class Partition: + """Class for breaking down regridding into manageable chunks.""" + + def __init__( + self, + src, + tgt, + scheme, + file_names, + use_dask_src_chunks=False, + src_chunks=None, + num_src_chunks=None, + explicit_src_blocks=None, + auto_generate=False, + saved_files=None, + ): + """Class for breaking down regridding into manageable chunks. + + Note + ---- + Currently, it is only possible to divide the source grid into chunks. + Meshes are not yet supported as a source. + + Parameters + ---------- + src : cube + Source cube. + tgt : cube + Target cube. + scheme : regridding scheme + Regridding scheme to generate regridders, either ESMFAreaWeighted or ESMFBilinear. + file_names : iterable of str + A list of file names to save/load parts of the regridder to/from. + use_dask_src_chunks : bool, default=False + If true, partition using the same chunks from the source cube. + src_chunks : numpy array, tuple of int or tuple of tuple of int, default=None + Specify the size of blocks to use to divide up the cube. Dimensions are specified + in y,x axis order. If `src_chunks` is a tuple of int, each integer describes + the maximum size of a block in that dimension. If `src_chunks` is a tuple of tuples, + each sub-tuple describes the size of each successive block in that dimension. These + block sizes should add up to the total size of that dimension or else an error + is raised. + num_src_chunks : tuple of int + Specify the number of blocks to use to divide up the cube. Dimensions are specified + in y,x axis order. Each integer describes the number of blocks that dimension will + be divided into. + explicit_src_blocks : arraylike NxMx2 + Explicitly specify the bounds of each block in the partition. + auto_generate : bool, default=False + When true, start generating files on initialisation. + saved_files : iterable of str + A list of paths to previously saved files. + """ + if scheme._method == Constants.Method.NEAREST: + msg = "The `Nearest` method is not implemented." + raise NotImplementedError(msg) + if src.mesh is not None: + msg = "Partition does not yet support source meshes." + raise NotImplementedError(msg) + # TODO: Extract a slice of the cube. + self.src = src + if src.mesh is None: + grid_dims = _get_grid_dims(src) + else: + grid_dims = (src.mesh_dim(),) + shape = tuple(src.shape[i] for i in grid_dims) + self.tgt = tgt + self.scheme = scheme + # TODO: consider abstracting away the idea of files + self.file_names = file_names + if use_dask_src_chunks: + if src_chunks is not None: + msg = "Potentially conflicting partition block definitions." + raise ValueError(msg) + if not src.has_lazy_data(): + msg = "If `use_dask_src_chunks=True`, the source cube must be lazy." + raise TypeError(msg) + src_chunks = src.slices(grid_dims).next().lazy_data().chunks + self.src_blocks = _determine_blocks( + shape, src_chunks, num_src_chunks, explicit_src_blocks + ) + if len(self.src_blocks) != len(file_names): + msg = "Number of source blocks does not match number of file names." + raise ValueError(msg) + # This will be controllable in future + tgt_blocks = None + self.tgt_blocks = tgt_blocks + if tgt_blocks is not None: + msg = "Target chunking not yet implemented." + raise NotImplementedError(msg) + + # Note: this may need to become more sophisticated when both src and tgt are large + self.file_block_dict = dict(zip(self.file_names, self.src_blocks)) + + if saved_files is None: + self.saved_files = [] + else: + self.saved_files = saved_files + if auto_generate: + self.generate_files(self.file_names) + + def __repr__(self): + """Return a representation of the class.""" + result = ( + f"Partition(" + f"src={self.src!r}, " + f"tgt={self.tgt!r}, " + f"scheme={self.scheme}, " + f"num file_names={len(self.file_names)}," + f"num saved_files={len(self.saved_files)})" + ) + return result + + @property + def unsaved_files(self): + """List of files not yet generated.""" + files = set(self.file_names) - set(self.saved_files) + return [file for file in self.file_names if file in files] + + def generate_files(self, files_to_generate=None): + """Generate files with regridding information. + + Parameters + ---------- + files_to_generate : int, default=None + Specify the number of files to generate, default behaviour is to generate all files. + """ + if files_to_generate is None: + files = self.unsaved_files + else: + if not isinstance(files_to_generate, int): + msg = "`files_to_generate` must be an integer." + raise ValueError(msg) + files = self.unsaved_files[:files_to_generate] + + for file in files: + src_block = self.file_block_dict[file] + src = _get_chunk(self.src, src_block) + tgt = self.tgt + regridder = self.scheme.regridder(src, tgt) + weights = regridder.regridder.weight_matrix + regridder = PartialRegridder( + src, tgt, src_block, None, weights, self.scheme + ) + save_regridder(regridder, file, allow_partial=True) + self.saved_files.append(file) + + def apply_regridders(self, cube, allow_incomplete=False): + """Apply the saved regridders to a cube. + + Parameters + ---------- + allow_incomplete : bool, default=False + If False, raise an error if not all files have been generated. If True, perform + regridding using the files which have been generated. + """ + # for each target chunk, iterate through each associated regridder + # for now, assume one target chunk + if len(self.saved_files) == 0: + msg = "No files have been generated." + raise OSError(msg) + if not allow_incomplete and len(self.unsaved_files) != 0: + msg = "Not all files have been generated." + raise OSError(msg) + current_result = None + current_weights = None + files = self.saved_files + + for file, chunk in zip(self.file_names, self.src_blocks): + if file in files: + next_regridder = load_regridder(file, allow_partial=True) + cube_chunk = _get_chunk(cube, chunk) + next_weights, next_result, extra = next_regridder.partial_regrid( + cube_chunk + ) + if current_weights is None: + current_weights = next_weights + else: + current_weights += next_weights + if current_result is None: + current_result = next_result + else: + current_result += next_result + + return next_regridder.finish_regridding( + cube_chunk, + current_weights, + current_result, + extra, + ) diff --git a/src/esmf_regrid/schemes.py b/src/esmf_regrid/schemes.py index adfe87b9..25f8c9e1 100644 --- a/src/esmf_regrid/schemes.py +++ b/src/esmf_regrid/schemes.py @@ -606,6 +606,18 @@ def _make_meshinfo(cube_or_mesh, method, mask, src_or_tgt, location=None): return _mesh_to_MeshInfo(mesh, location, mask=mask) +def _get_grid_dims(cube): + src_x = _get_coord(cube, "x") + src_y = _get_coord(cube, "y") + + if len(src_x.shape) == 1: + grid_x_dim = cube.coord_dims(src_x)[0] + grid_y_dim = cube.coord_dims(src_y)[0] + else: + grid_y_dim, grid_x_dim = cube.coord_dims(src_x) + return grid_y_dim, grid_x_dim + + def _regrid_rectilinear_to_rectilinear__prepare( src_grid_cube, tgt_grid_cube, @@ -625,14 +637,8 @@ def _regrid_rectilinear_to_rectilinear__prepare( """ tgt_x = _get_coord(tgt_grid_cube, "x") tgt_y = _get_coord(tgt_grid_cube, "y") - src_x = _get_coord(src_grid_cube, "x") - src_y = _get_coord(src_grid_cube, "y") - if len(src_x.shape) == 1: - grid_x_dim = src_grid_cube.coord_dims(src_x)[0] - grid_y_dim = src_grid_cube.coord_dims(src_y)[0] - else: - grid_y_dim, grid_x_dim = src_grid_cube.coord_dims(src_x) + grid_y_dim, grid_x_dim = _get_grid_dims(src_grid_cube) srcinfo = _make_gridinfo(src_grid_cube, method, src_resolution, src_mask) tgtinfo = _make_gridinfo(tgt_grid_cube, method, tgt_resolution, tgt_mask) @@ -805,8 +811,6 @@ def _regrid_rectilinear_to_unstructured__prepare( The 'regrid info' returned can be reused over many 2d slices. """ - grid_x = _get_coord(src_grid_cube, "x") - grid_y = _get_coord(src_grid_cube, "y") if isinstance(tgt_cube_or_mesh, MeshXY): mesh = tgt_cube_or_mesh location = tgt_location @@ -814,11 +818,7 @@ def _regrid_rectilinear_to_unstructured__prepare( mesh = tgt_cube_or_mesh.mesh location = tgt_cube_or_mesh.location - if grid_x.ndim == 1: - (grid_x_dim,) = src_grid_cube.coord_dims(grid_x) - (grid_y_dim,) = src_grid_cube.coord_dims(grid_y) - else: - grid_y_dim, grid_x_dim = src_grid_cube.coord_dims(grid_x) + grid_y_dim, grid_x_dim = _get_grid_dims(src_grid_cube) meshinfo = _make_meshinfo( tgt_cube_or_mesh, method, tgt_mask, "target", location=tgt_location @@ -1087,6 +1087,7 @@ def __init__( the regridder is saved . """ + self._method = Constants.Method.CONSERVATIVE if not (0 <= mdtol <= 1): msg = "Value for mdtol must be in range 0 - 1, got {}." raise ValueError(msg.format(mdtol)) @@ -1123,6 +1124,7 @@ def regridder( use_tgt_mask=None, tgt_location="face", esmf_args=None, + precomputed_weights=None, ): """Create regridder to perform regridding from ``src_grid`` to ``tgt_grid``. @@ -1191,6 +1193,7 @@ def regridder( use_tgt_mask=use_tgt_mask, tgt_location="face", esmf_args=esmf_args, + precomputed_weights=precomputed_weights, ) @@ -1240,6 +1243,7 @@ def __init__( the regridder is saved . """ + self._method = Constants.Method.BILINEAR if not (0 <= mdtol <= 1): msg = "Value for mdtol must be in range 0 - 1, got {}." raise ValueError(msg.format(mdtol)) @@ -1274,6 +1278,7 @@ def regridder( tgt_location=None, extrapolate_gaps=False, esmf_args=None, + precomputed_weights=None, ): """Create regridder to perform regridding from ``src_grid`` to ``tgt_grid``. @@ -1336,6 +1341,7 @@ def regridder( use_tgt_mask=use_tgt_mask, tgt_location=tgt_location, esmf_args=esmf_args, + precomputed_weights=precomputed_weights, ) @@ -1389,6 +1395,7 @@ def __init__( arguments are recorded as a property of this regridder and are stored when the regridder is saved . """ + self._method = Constants.Method.NEAREST self.use_src_mask = use_src_mask self.use_tgt_mask = use_tgt_mask self.tgt_location = tgt_location @@ -1415,6 +1422,7 @@ def regridder( use_tgt_mask=None, tgt_location=None, esmf_args=None, + precomputed_weights=None, ): """Create regridder to perform regridding from ``src_grid`` to ``tgt_grid``. @@ -1468,6 +1476,7 @@ def regridder( use_tgt_mask=use_tgt_mask, tgt_location=tgt_location, esmf_args=esmf_args, + precomputed_weights=precomputed_weights, ) @@ -1491,9 +1500,9 @@ def __init__( Parameters ---------- src : :class:`iris.cube.Cube` - The rectilinear :class:`~iris.cube.Cube` providing the source grid. + The :class:`~iris.cube.Cube` providing the source grid. tgt : :class:`iris.cube.Cube` or :class:`iris.mesh.MeshXY` - The rectilinear :class:`~iris.cube.Cube` providing the target grid. + The :class:`~iris.cube.Cube` providing the target grid. method : :class:`Constants.Method` The method to be used to calculate weights. mdtol : float, default=None @@ -1566,26 +1575,7 @@ def __init__( else: self._src = GridRecord(_get_coord(src, "x"), _get_coord(src, "y")) - def __call__(self, cube): - """Regrid this :class:`~iris.cube.Cube` onto the target grid of this regridder instance. - - The given :class:`~iris.cube.Cube` must be defined with the same grid as the source - :class:`~iris.cube.Cube` used to create this :class:`_ESMFRegridder` instance. - - Parameters - ---------- - cube : :class:`iris.cube.Cube` - A :class:`~iris.cube.Cube` instance to be regridded. - - Returns - ------- - :class:`iris.cube.Cube` - A :class:`~iris.cube.Cube` defined with the horizontal dimensions of the target - and the other dimensions from this :class:`~iris.cube.Cube`. The data values of - this :class:`~iris.cube.Cube` will be converted to values on the new grid using - regridding via :mod:`esmpy` generated weights. - - """ + def _get_cube_dims(self, cube): if cube.mesh is not None: # TODO: replace temporary hack when iris issues are sorted. # Ignore differences in var_name that might be caused by saving. @@ -1629,6 +1619,29 @@ def __call__(self, cube): else: # Due to structural reasons, the order here must be reversed. dims = cube.coord_dims(new_src_x)[::-1] + return dims + + def __call__(self, cube): + """Regrid this :class:`~iris.cube.Cube` onto the target grid of this regridder instance. + + The given :class:`~iris.cube.Cube` must be defined with the same grid as the source + :class:`~iris.cube.Cube` used to create this :class:`_ESMFRegridder` instance. + + Parameters + ---------- + cube : :class:`iris.cube.Cube` + A :class:`~iris.cube.Cube` instance to be regridded. + + Returns + ------- + :class:`iris.cube.Cube` + A :class:`~iris.cube.Cube` defined with the horizontal dimensions of the target + and the other dimensions from this :class:`~iris.cube.Cube`. The data values of + this :class:`~iris.cube.Cube` will be converted to values on the new grid using + regridding via :mod:`esmpy` generated weights. + + """ + dims = self._get_cube_dims(cube) regrid_info = RegridInfo( dims=dims, @@ -1673,11 +1686,11 @@ def __init__( Parameters ---------- src : :class:`iris.cube.Cube` - The rectilinear :class:`~iris.cube.Cube` providing the source. + The :class:`~iris.cube.Cube` providing the source. If this cube has a grid defined by latitude/longitude coordinates, those coordinates must have bounds. tgt : :class:`iris.cube.Cube` or :class:`iris.mesh.MeshXY` - The unstructured :class:`~iris.cube.Cube`or + The :class:`~iris.cube.Cube`or :class:`~iris.mesh.MeshXY` defining the target. If this cube has a grid defined by latitude/longitude coordinates, those coordinates must have bounds. @@ -1760,9 +1773,9 @@ def __init__( Parameters ---------- src : :class:`iris.cube.Cube` - The rectilinear :class:`~iris.cube.Cube` providing the source. + The :class:`~iris.cube.Cube` providing the source. tgt : :class:`iris.cube.Cube` or :class:`iris.mesh.MeshXY` - The unstructured :class:`~iris.cube.Cube`or + The :class:`~iris.cube.Cube`or :class:`~iris.mesh.MeshXY` defining the target. mdtol : float, default=0 Tolerance of missing data. The value returned in each element of @@ -1834,9 +1847,9 @@ def __init__( Parameters ---------- src : :class:`iris.cube.Cube` - The rectilinear :class:`~iris.cube.Cube` providing the source. + The :class:`~iris.cube.Cube` providing the source. tgt : :class:`iris.cube.Cube` or :class:`iris.mesh.MeshXY` - The unstructured :class:`~iris.cube.Cube`or + The :class:`~iris.cube.Cube`or :class:`~iris.mesh.MeshXY` defining the target. precomputed_weights : :class:`scipy.sparse.spmatrix`, optional If ``None``, :mod:`esmpy` will be used to diff --git a/src/esmf_regrid/tests/unit/experimental/partition/__init__.py b/src/esmf_regrid/tests/unit/experimental/partition/__init__.py new file mode 100644 index 00000000..656fc3a9 --- /dev/null +++ b/src/esmf_regrid/tests/unit/experimental/partition/__init__.py @@ -0,0 +1 @@ +"""Unit tests for :mod:`esmf_regrid.experimental.partition`.""" diff --git a/src/esmf_regrid/tests/unit/experimental/partition/test_PartialRegridder.py b/src/esmf_regrid/tests/unit/experimental/partition/test_PartialRegridder.py new file mode 100644 index 00000000..37ad2ecc --- /dev/null +++ b/src/esmf_regrid/tests/unit/experimental/partition/test_PartialRegridder.py @@ -0,0 +1,25 @@ +"""Unit tests for :mod:`esmf_regrid.experimental.partition`.""" + +from esmf_regrid import ESMFAreaWeighted +from esmf_regrid.experimental._partial import PartialRegridder +from esmf_regrid.tests.unit.schemes.test__cube_to_GridInfo import ( + _grid_cube, +) + + +def test_PartialRegridder_repr(): + """Test repr of PartialRegridder instance.""" + src = _grid_cube(10, 15, (-180, 180), (-90, 90), circular=True) + tgt = _grid_cube(5, 10, (-180, 180), (-90, 90), circular=True) + src_slice = ((10, 20), (15, 30)) + tgt_slice = ((0, 5), (0, 10)) + weights = None + scheme = ESMFAreaWeighted(mdtol=0.5) + + pr = PartialRegridder(src, tgt, src_slice, tgt_slice, weights, scheme) + + expected_repr = ( + "PartialRegridder(src_slice=((10, 20), (15, 30)), tgt_slice=((0, 5), (0, 10)), " + "scheme=ESMFAreaWeighted(mdtol=0.5, use_src_mask=False, use_tgt_mask=False, esmf_args={}))" + ) + assert repr(pr) == expected_repr diff --git a/src/esmf_regrid/tests/unit/experimental/partition/test_Partition.py b/src/esmf_regrid/tests/unit/experimental/partition/test_Partition.py new file mode 100644 index 00000000..0d72b457 --- /dev/null +++ b/src/esmf_regrid/tests/unit/experimental/partition/test_Partition.py @@ -0,0 +1,286 @@ +"""Unit tests for :mod:`esmf_regrid.experimental.partition`.""" + +import dask.array as da +import numpy as np +import pytest + +from esmf_regrid import ESMFAreaWeighted, ESMFNearest +from esmf_regrid.experimental.partition import Partition +from esmf_regrid.tests.unit.schemes.test__cube_to_GridInfo import ( + _curvilinear_cube, + _grid_cube, +) +from esmf_regrid.tests.unit.schemes.test__mesh_to_MeshInfo import ( + _gridlike_mesh_cube, +) +from esmf_regrid.tests.unit.schemes.test_regrid_rectilinear_to_rectilinear import ( + _make_full_cubes, +) + + +def test_Partition(tmp_path): + """Test basic implementation of Partition class.""" + src = _grid_cube(150, 500, (-180, 180), (-90, 90), circular=True) + src.data = np.arange(150 * 500).reshape([500, 150]) + tgt = _grid_cube(16, 36, (-180, 180), (-90, 90), circular=True) + + files = [tmp_path / f"partial_{x}.nc" for x in range(5)] + scheme = ESMFAreaWeighted(mdtol=1) + + blocks = [ + [[0, 100], [0, 150]], + [[100, 200], [0, 150]], + [[200, 300], [0, 150]], + [[300, 400], [0, 150]], + [[400, 500], [0, 150]], + ] + partition = Partition(src, tgt, scheme, files, explicit_src_blocks=blocks) + + partition.generate_files() + + result = partition.apply_regridders(src) + expected = src.regrid(tgt, scheme) + assert np.allclose(result.data, expected.data) + assert result == expected + + +def test_Partition_block_api(tmp_path): + """Test API for controlling block shape for Partition class.""" + src = _grid_cube(150, 500, (-180, 180), (-90, 90), circular=True) + tgt = _grid_cube(16, 36, (-180, 180), (-90, 90), circular=True) + + files = [tmp_path / f"partial_{x}.nc" for x in range(5)] + scheme = ESMFAreaWeighted(mdtol=1) + + num_src_chunks = (5, 1) + partition = Partition(src, tgt, scheme, files, num_src_chunks=num_src_chunks) + + expected_chunks = [ + [[0, 100], [0, 150]], + [[100, 200], [0, 150]], + [[200, 300], [0, 150]], + [[300, 400], [0, 150]], + [[400, 500], [0, 150]], + ] + assert partition.src_blocks == expected_chunks + + src_chunks = (100, 150) + partition = Partition(src, tgt, scheme, files, src_chunks=src_chunks) + + expected_chunks = [ + [[0, 100], [0, 150]], + [[100, 200], [0, 150]], + [[200, 300], [0, 150]], + [[300, 400], [0, 150]], + [[400, 500], [0, 150]], + ] + assert partition.src_blocks == expected_chunks + + src_chunks = ((100, 100, 100, 100, 100), (150,)) + partition = Partition(src, tgt, scheme, files, src_chunks=src_chunks) + + expected_chunks = [ + [[0, 100], [0, 150]], + [[100, 200], [0, 150]], + [[200, 300], [0, 150]], + [[300, 400], [0, 150]], + [[400, 500], [0, 150]], + ] + assert partition.src_blocks == expected_chunks + + src.data = da.from_array(src.data, chunks=src_chunks) + partition = Partition(src, tgt, scheme, files, use_dask_src_chunks=True) + + expected_chunks = [ + [[0, 100], [0, 150]], + [[100, 200], [0, 150]], + [[200, 300], [0, 150]], + [[300, 400], [0, 150]], + [[400, 500], [0, 150]], + ] + assert partition.src_blocks == expected_chunks + + +def test_Partition_mesh_src(tmp_path): + """Test Partition class when the source has a mesh.""" + src = _gridlike_mesh_cube(150, 500) + src.data = np.arange(150 * 500) + tgt = _grid_cube(16, 36, (-180, 180), (-90, 90), circular=True) + + files = [tmp_path / f"partial_{x}.nc" for x in range(5)] + scheme = ESMFAreaWeighted(mdtol=1) + + src_chunks = (15000,) + with pytest.raises(NotImplementedError): + _ = Partition(src, tgt, scheme, files, src_chunks=src_chunks) + + # TODO: when mesh partitioning becomes possible, uncomment. + # expected_src_chunks = [[[0, 15000]], [[15000, 30000]], [[30000, 45000]], [[45000, 60000]], [[60000, 75000]]] + # assert partition.src_blocks == expected_src_chunks + # + # partition.generate_files() + # + # result = partition.apply_regridders(src) + # expected = src.regrid(tgt, scheme) + # assert np.allclose(result.data, expected.data) + # assert result == expected + + +def test_Partition_curv_src(tmp_path): + """Test Partition class when the source has a curvilinear grid.""" + src = _curvilinear_cube(150, 500, (-180, 180), (-90, 90)) + src.data = np.arange(150 * 500).reshape([500, 150]) + tgt = _grid_cube(16, 36, (-180, 180), (-90, 90), circular=True) + + files = [tmp_path / f"partial_{x}.nc" for x in range(5)] + scheme = ESMFAreaWeighted(mdtol=1) + + src_chunks = (100, 150) + partition = Partition(src, tgt, scheme, files, src_chunks=src_chunks) + + expected_src_chunks = [ + [[0, 100], [0, 150]], + [[100, 200], [0, 150]], + [[200, 300], [0, 150]], + [[300, 400], [0, 150]], + [[400, 500], [0, 150]], + ] + assert partition.src_blocks == expected_src_chunks + + partition.generate_files() + + result = partition.apply_regridders(src) + expected = src.regrid(tgt, scheme) + assert np.allclose(result.data, expected.data) + assert result == expected + + +def test_conflicting_chunks(tmp_path): + """Test error handling of Partition class.""" + src = _grid_cube(150, 500, (-180, 180), (-90, 90), circular=True) + tgt = _grid_cube(16, 36, (-180, 180), (-90, 90), circular=True) + + files = [tmp_path / f"partial_{x}.nc" for x in range(5)] + scheme = ESMFAreaWeighted(mdtol=1) + num_src_chunks = (5, 1) + src_chunks = (100, 150) + blocks = [ + [[0, 100], [0, 150]], + [[100, 200], [0, 150]], + [[200, 300], [0, 150]], + [[300, 400], [0, 150]], + [[400, 500], [0, 150]], + ] + + with pytest.raises(ValueError): + _ = Partition( + src, + tgt, + scheme, + files, + src_chunks=src_chunks, + num_src_chunks=num_src_chunks, + ) + with pytest.raises(ValueError): + _ = Partition( + src, tgt, scheme, files, src_chunks=src_chunks, explicit_src_blocks=blocks + ) + with pytest.raises(ValueError): + _ = Partition(src, tgt, scheme, files) + with pytest.raises(TypeError): + _ = Partition(src, tgt, scheme, files, use_dask_src_chunks=True) + with pytest.raises(ValueError): + _ = Partition(src, tgt, scheme, files[:-1], src_chunks=src_chunks) + + +def test_multidimensional_cube(tmp_path): + """Test Partition class when the source has a multidimensional cube.""" + src_cube, tgt_grid, expected_cube = _make_full_cubes() + files = [tmp_path / f"partial_{x}.nc" for x in range(4)] + scheme = ESMFAreaWeighted(mdtol=1) + chunks = (2, 3) + + partition = Partition(src_cube, tgt_grid, scheme, files, src_chunks=chunks) + + partition.generate_files() + + result = partition.apply_regridders(src_cube) + + # Lenient check for data. + assert np.allclose(result.data, expected_cube.data) + + # Check metadata and coords. + result.data = expected_cube.data + assert result == expected_cube + + +def test_save_incomplete(tmp_path): + """Test Partition class when a limited number of files are saved.""" + src = _grid_cube(150, 500, (-180, 180), (-90, 90), circular=True) + tgt = _grid_cube(16, 36, (-180, 180), (-90, 90), circular=True) + + files = [tmp_path / f"partial_{x}.nc" for x in range(5)] + src_chunks = (100, 150) + scheme = ESMFAreaWeighted(mdtol=1) + num_initial_chunks = 3 + expected_files = files[:num_initial_chunks] + + partition = Partition(src, tgt, scheme, files, src_chunks=src_chunks) + with pytest.raises(OSError): + _ = partition.apply_regridders(src, allow_incomplete=True) + + partition.generate_files(files_to_generate=num_initial_chunks) + assert partition.saved_files == expected_files + + expected_array_partial = np.ma.zeros([36, 16]) + expected_array_partial[22:] = np.ma.masked + + with pytest.raises(OSError): + _ = partition.apply_regridders(src) + partial_result = partition.apply_regridders(src, allow_incomplete=True) + assert np.ma.allclose(partial_result.data, expected_array_partial) + + loaded_partition = Partition( + src, tgt, scheme, files, src_chunks=src_chunks, saved_files=expected_files + ) + + with pytest.raises(OSError): + _ = loaded_partition.apply_regridders(src) + partial_result_2 = partition.apply_regridders(src, allow_incomplete=True) + assert np.ma.allclose(partial_result_2.data, expected_array_partial) + + loaded_partition.generate_files() + + result = loaded_partition.apply_regridders(src) + expected_array = np.ma.zeros([36, 16]) + assert np.ma.allclose(result.data, expected_array) + + +def test_nearest_invalid(tmp_path): + """Test Partition class when initialised with an invalid scheme.""" + src_cube, tgt_grid, _ = _make_full_cubes() + files = [tmp_path / f"partial_{x}.nc" for x in range(4)] + scheme = ESMFNearest() + chunks = (2, 3) + + with pytest.raises(NotImplementedError): + _ = Partition(src_cube, tgt_grid, scheme, files, src_chunks=chunks) + + +def test_Partition_repr(tmp_path): + """Test repr of Partition instance.""" + src_cube, tgt_grid, _ = _make_full_cubes() + files = [tmp_path / f"partial_{x}.nc" for x in range(4)] + scheme = ESMFAreaWeighted() + chunks = (2, 3) + + partition = Partition(src_cube, tgt_grid, scheme, files, src_chunks=chunks) + + expected_repr = ( + "Partition(src=, " + "tgt=, " + "scheme=ESMFAreaWeighted(mdtol=0, use_src_mask=False, use_tgt_mask=False, esmf_args={}), " + "num file_names=4,num saved_files=0)" + ) + assert repr(partition) == expected_repr diff --git a/src/esmf_regrid/tests/unit/schemes/test_regrid_rectilinear_to_rectilinear.py b/src/esmf_regrid/tests/unit/schemes/test_regrid_rectilinear_to_rectilinear.py index 548cc1f4..1233b696 100644 --- a/src/esmf_regrid/tests/unit/schemes/test_regrid_rectilinear_to_rectilinear.py +++ b/src/esmf_regrid/tests/unit/schemes/test_regrid_rectilinear_to_rectilinear.py @@ -1,5 +1,7 @@ """Unit tests for :func:`esmf_regrid.schemes.regrid_rectilinear_to_rectilinear`.""" +from functools import partial + import dask.array as da from iris.coord_systems import RotatedGeogCS from iris.coords import AuxCoord, DimCoord @@ -68,12 +70,19 @@ def test_rotated_regridding(): assert np.allclose(expected_data, full_mdtol_result.data) -def test_extra_dims(): - """Test for :func:`esmf_regrid.schemes.regrid_rectilinear_to_rectilinear`. +def _add_metadata(cube): + result = cube.copy() + result.units = "K" + result.attributes = {"a": 1} + result.standard_name = "air_temperature" + scalar_height = AuxCoord([5], units="m", standard_name="height") + scalar_time = DimCoord([10], units="s", standard_name="time") + result.add_aux_coord(scalar_height) + result.add_aux_coord(scalar_time) + return result - Tests the handling of extra dimensions and metadata. Ensures that proper - coordinates, attributes, names and units are copied over. - """ + +def _make_full_cubes(src_rectilinear=True, tgt_rectilinear=True): h = 2 t = 4 e = 6 @@ -86,13 +95,22 @@ def test_extra_dims(): lon_bounds = (-180, 180) lat_bounds = (-90, 90) - src_grid = _grid_cube( + if src_rectilinear: + src_func = partial(_grid_cube, circular=True) + else: + src_func = _curvilinear_cube + if tgt_rectilinear: + tgt_func = partial(_grid_cube, circular=True) + else: + tgt_func = _curvilinear_cube + src_grid = src_func( src_lons, src_lats, lon_bounds, lat_bounds, ) - tgt_grid = _grid_cube( + + tgt_grid = tgt_func( tgt_lons, tgt_lats, lon_bounds, @@ -110,47 +128,56 @@ def test_extra_dims(): ] src_cube = Cube(src_data) + if src_rectilinear: + src_cube.add_dim_coord(src_grid.coord("latitude"), 1) + src_cube.add_dim_coord(src_grid.coord("longitude"), 3) + else: + src_cube.add_aux_coord(src_grid.coord("latitude"), (1, 3)) + src_cube.add_aux_coord(src_grid.coord("longitude"), (1, 3)) src_cube.add_dim_coord(height, 0) - src_cube.add_dim_coord(src_grid.coord("latitude"), 1) src_cube.add_dim_coord(time, 2) - src_cube.add_dim_coord(src_grid.coord("longitude"), 3) src_cube.add_aux_coord(extra, 4) src_cube.add_aux_coord(spanning, [0, 2, 4]) - def _add_metadata(cube): - result = cube.copy() - result.units = "K" - result.attributes = {"a": 1} - result.standard_name = "air_temperature" - scalar_height = AuxCoord([5], units="m", standard_name="height") - scalar_time = DimCoord([10], units="s", standard_name="time") - result.add_aux_coord(scalar_height) - result.add_aux_coord(scalar_time) - return result - src_cube = _add_metadata(src_cube) - result = regrid_rectilinear_to_rectilinear(src_cube, tgt_grid) - expected_data = np.empty([h, tgt_lats, t, tgt_lons, e]) expected_data[:] = np.arange(t * h * e).reshape([h, t, e])[ :, np.newaxis, :, np.newaxis, : ] expected_cube = Cube(expected_data) + if tgt_rectilinear: + expected_cube.add_dim_coord(tgt_grid.coord("latitude"), 1) + expected_cube.add_dim_coord(tgt_grid.coord("longitude"), 3) + else: + expected_cube.add_aux_coord(tgt_grid.coord("latitude"), (1, 3)) + expected_cube.add_aux_coord(tgt_grid.coord("longitude"), (1, 3)) expected_cube.add_dim_coord(height, 0) - expected_cube.add_dim_coord(tgt_grid.coord("latitude"), 1) expected_cube.add_dim_coord(time, 2) - expected_cube.add_dim_coord(tgt_grid.coord("longitude"), 3) expected_cube.add_aux_coord(extra, 4) expected_cube.add_aux_coord(spanning, [0, 2, 4]) expected_cube = _add_metadata(expected_cube) + return src_cube, tgt_grid, expected_cube + + +def test_extra_dims(): + """Test for :func:`esmf_regrid.schemes.regrid_rectilinear_to_rectilinear`. + + Tests the handling of extra dimensions and metadata. Ensures that proper + coordinates, attributes, names and units are copied over. + """ + src_cube, tgt_grid, expected_cube = _make_full_cubes( + src_rectilinear=True, tgt_rectilinear=True + ) + + result = regrid_rectilinear_to_rectilinear(src_cube, tgt_grid) # Lenient check for data. - assert np.allclose(expected_data, result.data) + assert np.allclose(expected_cube.data, result.data) # Check metadata and coords. - result.data = expected_data + result.data = expected_cube.data assert expected_cube == result @@ -266,83 +293,17 @@ def test_extra_dims_curvilinear(): Tests the handling of extra dimensions and metadata. Ensures that proper coordinates, attributes, names and units are copied over. """ - h = 2 - t = 4 - e = 6 - src_lats = 3 - src_lons = 5 - - tgt_lats = 5 - tgt_lons = 3 - - lon_bounds = (-180, 180) - lat_bounds = (-90, 90) - - src_grid = _curvilinear_cube( - src_lons, - src_lats, - lon_bounds, - lat_bounds, - ) - tgt_grid = _curvilinear_cube( - tgt_lons, - tgt_lats, - lon_bounds, - lat_bounds, + src_cube, tgt_grid, expected_cube = _make_full_cubes( + src_rectilinear=False, tgt_rectilinear=False ) - height = DimCoord(np.arange(h), standard_name="height") - time = DimCoord(np.arange(t), standard_name="time") - extra = AuxCoord(np.arange(e), long_name="extra dim") - spanning = AuxCoord(np.ones([h, t, e]), long_name="spanning dim") - - src_data = np.empty([h, src_lats, t, src_lons, e]) - src_data[:] = np.arange(t * h * e).reshape([h, t, e])[ - :, np.newaxis, :, np.newaxis, : - ] - - src_cube = Cube(src_data) - src_cube.add_dim_coord(height, 0) - src_cube.add_aux_coord(src_grid.coord("latitude"), (1, 3)) - src_cube.add_dim_coord(time, 2) - src_cube.add_aux_coord(src_grid.coord("longitude"), (1, 3)) - src_cube.add_aux_coord(extra, 4) - src_cube.add_aux_coord(spanning, [0, 2, 4]) - - def _add_metadata(cube): - result = cube.copy() - result.units = "K" - result.attributes = {"a": 1} - result.standard_name = "air_temperature" - scalar_height = AuxCoord([5], units="m", standard_name="height") - scalar_time = DimCoord([10], units="s", standard_name="time") - result.add_aux_coord(scalar_height) - result.add_aux_coord(scalar_time) - return result - - src_cube = _add_metadata(src_cube) - result = regrid_rectilinear_to_rectilinear(src_cube, tgt_grid) - expected_data = np.empty([h, tgt_lats, t, tgt_lons, e]) - expected_data[:] = np.arange(t * h * e).reshape([h, t, e])[ - :, np.newaxis, :, np.newaxis, : - ] - - expected_cube = Cube(expected_data) - expected_cube.add_dim_coord(height, 0) - expected_cube.add_aux_coord(tgt_grid.coord("latitude"), (1, 3)) - expected_cube.add_dim_coord(time, 2) - expected_cube.add_aux_coord(tgt_grid.coord("longitude"), (1, 3)) - expected_cube.add_aux_coord(extra, 4) - expected_cube.add_aux_coord(spanning, [0, 2, 4]) - expected_cube = _add_metadata(expected_cube) - # Lenient check for data. - assert np.allclose(expected_data, result.data) + assert np.allclose(expected_cube.data, result.data) # Check metadata and coords. - result.data = expected_data + result.data = expected_cube.data assert expected_cube == result @@ -352,83 +313,17 @@ def test_extra_dims_curvilinear_to_rectilinear(): Tests the handling of extra dimensions and metadata. Ensures that proper coordinates, attributes, names and units are copied over. """ - h = 2 - t = 4 - e = 6 - src_lats = 3 - src_lons = 5 - - tgt_lats = 5 - tgt_lons = 3 - - lon_bounds = (-180, 180) - lat_bounds = (-90, 90) - - src_grid = _curvilinear_cube( - src_lons, - src_lats, - lon_bounds, - lat_bounds, - ) - tgt_grid = _grid_cube( - tgt_lons, - tgt_lats, - lon_bounds, - lat_bounds, + src_cube, tgt_grid, expected_cube = _make_full_cubes( + src_rectilinear=False, tgt_rectilinear=True ) - height = DimCoord(np.arange(h), standard_name="height") - time = DimCoord(np.arange(t), standard_name="time") - extra = AuxCoord(np.arange(e), long_name="extra dim") - spanning = AuxCoord(np.ones([h, t, e]), long_name="spanning dim") - - src_data = np.empty([h, src_lats, t, src_lons, e]) - src_data[:] = np.arange(t * h * e).reshape([h, t, e])[ - :, np.newaxis, :, np.newaxis, : - ] - - src_cube = Cube(src_data) - src_cube.add_dim_coord(height, 0) - src_cube.add_aux_coord(src_grid.coord("latitude"), (1, 3)) - src_cube.add_dim_coord(time, 2) - src_cube.add_aux_coord(src_grid.coord("longitude"), (1, 3)) - src_cube.add_aux_coord(extra, 4) - src_cube.add_aux_coord(spanning, [0, 2, 4]) - - def _add_metadata(cube): - result = cube.copy() - result.units = "K" - result.attributes = {"a": 1} - result.standard_name = "air_temperature" - scalar_height = AuxCoord([5], units="m", standard_name="height") - scalar_time = DimCoord([10], units="s", standard_name="time") - result.add_aux_coord(scalar_height) - result.add_aux_coord(scalar_time) - return result - - src_cube = _add_metadata(src_cube) - result = regrid_rectilinear_to_rectilinear(src_cube, tgt_grid) - expected_data = np.empty([h, tgt_lats, t, tgt_lons, e]) - expected_data[:] = np.arange(t * h * e).reshape([h, t, e])[ - :, np.newaxis, :, np.newaxis, : - ] - - expected_cube = Cube(expected_data) - expected_cube.add_dim_coord(height, 0) - expected_cube.add_dim_coord(tgt_grid.coord("latitude"), 1) - expected_cube.add_dim_coord(time, 2) - expected_cube.add_dim_coord(tgt_grid.coord("longitude"), 3) - expected_cube.add_aux_coord(extra, 4) - expected_cube.add_aux_coord(spanning, [0, 2, 4]) - expected_cube = _add_metadata(expected_cube) - # Lenient check for data. - assert np.allclose(expected_data, result.data) + assert np.allclose(expected_cube.data, result.data) # Check metadata and coords. - result.data = expected_data + result.data = expected_cube.data assert expected_cube == result @@ -438,81 +333,15 @@ def test_extra_dims_rectilinear_to_curvilinear(): Tests the handling of extra dimensions and metadata. Ensures that proper coordinates, attributes, names and units are copied over. """ - h = 2 - t = 4 - e = 6 - src_lats = 3 - src_lons = 5 - - tgt_lats = 5 - tgt_lons = 3 - - lon_bounds = (-180, 180) - lat_bounds = (-90, 90) - - src_grid = _grid_cube( - src_lons, - src_lats, - lon_bounds, - lat_bounds, + src_cube, tgt_grid, expected_cube = _make_full_cubes( + src_rectilinear=True, tgt_rectilinear=False ) - tgt_grid = _curvilinear_cube( - tgt_lons, - tgt_lats, - lon_bounds, - lat_bounds, - ) - - height = DimCoord(np.arange(h), standard_name="height") - time = DimCoord(np.arange(t), standard_name="time") - extra = AuxCoord(np.arange(e), long_name="extra dim") - spanning = AuxCoord(np.ones([h, t, e]), long_name="spanning dim") - - src_data = np.empty([h, src_lats, t, src_lons, e]) - src_data[:] = np.arange(t * h * e).reshape([h, t, e])[ - :, np.newaxis, :, np.newaxis, : - ] - - src_cube = Cube(src_data) - src_cube.add_dim_coord(height, 0) - src_cube.add_dim_coord(src_grid.coord("latitude"), 1) - src_cube.add_dim_coord(time, 2) - src_cube.add_dim_coord(src_grid.coord("longitude"), 3) - src_cube.add_aux_coord(extra, 4) - src_cube.add_aux_coord(spanning, [0, 2, 4]) - - def _add_metadata(cube): - result = cube.copy() - result.units = "K" - result.attributes = {"a": 1} - result.standard_name = "air_temperature" - scalar_height = AuxCoord([5], units="m", standard_name="height") - scalar_time = DimCoord([10], units="s", standard_name="time") - result.add_aux_coord(scalar_height) - result.add_aux_coord(scalar_time) - return result - - src_cube = _add_metadata(src_cube) result = regrid_rectilinear_to_rectilinear(src_cube, tgt_grid) - expected_data = np.empty([h, tgt_lats, t, tgt_lons, e]) - expected_data[:] = np.arange(t * h * e).reshape([h, t, e])[ - :, np.newaxis, :, np.newaxis, : - ] - - expected_cube = Cube(expected_data) - expected_cube.add_dim_coord(height, 0) - expected_cube.add_aux_coord(tgt_grid.coord("latitude"), (1, 3)) - expected_cube.add_dim_coord(time, 2) - expected_cube.add_aux_coord(tgt_grid.coord("longitude"), (1, 3)) - expected_cube.add_aux_coord(extra, 4) - expected_cube.add_aux_coord(spanning, [0, 2, 4]) - expected_cube = _add_metadata(expected_cube) - # Lenient check for data. - assert np.allclose(expected_data, result.data) + assert np.allclose(expected_cube.data, result.data) # Check metadata and coords. - result.data = expected_data + result.data = expected_cube.data assert expected_cube == result