Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
11f5df7
Partition MVP
stephenworsley Dec 17, 2025
8ad6983
lint fixes
stephenworsley Dec 17, 2025
bc29614
lint fixes
stephenworsley Dec 17, 2025
e9fa31a
lint fixes and name changes
stephenworsley Dec 18, 2025
e625856
lint fixes
stephenworsley Dec 18, 2025
60fc675
fix import order
stephenworsley Dec 18, 2025
decc42f
fix import order
stephenworsley Dec 18, 2025
237cef8
add tests, docstrings
stephenworsley Dec 18, 2025
7a8affa
add multidimensional test
stephenworsley Dec 19, 2025
8abfed2
add multidimensional cube support
stephenworsley Dec 19, 2025
9b3fc78
fix test
stephenworsley Dec 19, 2025
78e4479
fix test
stephenworsley Dec 19, 2025
147abbf
fix test
stephenworsley Dec 19, 2025
36a9b05
fix ESMFNearest behaviour
stephenworsley Dec 19, 2025
1dd92e6
add test, improve docstrings
stephenworsley Jan 14, 2026
45d449c
fix test
stephenworsley Jan 15, 2026
9bc0dc9
add test
stephenworsley Jan 16, 2026
004e895
fix test
stephenworsley Jan 16, 2026
f7ce8ca
ruff format
stephenworsley Jan 16, 2026
6f33251
add documentation
stephenworsley Jan 20, 2026
c367325
add to docstrings
stephenworsley Jan 21, 2026
9cbb76e
add docstrings and repr to PartialRegridder
stephenworsley Jan 21, 2026
79b56a1
repr testing
stephenworsley Jan 21, 2026
90968b6
ruff fix
stephenworsley Jan 21, 2026
15c1f5a
fix docs
stephenworsley Jan 21, 2026
84c9482
docs grammar
stephenworsley Jan 21, 2026
b6aca52
attempt benchmarks slowdown fix
stephenworsley Jan 22, 2026
2f9ba4e
tidy unused code
stephenworsley Jan 22, 2026
750bc30
Merge branch 'main' into partition_mvp
stephenworsley Jan 26, 2026
ae5b638
ruff fix
stephenworsley Jan 26, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions docs/src/userguide/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,94 @@ certain regridders. We can do this as follows::
# Use loaded regridder.
result = loaded_regridder(source_mesh_cube)

Partitioning a Regridder
------------------------

If a regridder would be too large to handle in memory, it can be broken down
into smaller regridders which can collectively do the job of the larger regridder.
This is done using a `Partition` object.

.. note:: Currently, it is only possible to partition regridding when the source is
a large grid and the target is small enough to fit in memory.

A `Partition` is made by specifying a source, a target, a list of files, and a way
to divide the source grid into blocks whose regridders are saved to those files::

from iris.util import make_gridcube

from esmf_regrid import ESMFAreaWeighted
from esmf_regrid.experimental.partition import Partition

# Create a large source cube.
source_cube = make_gridcube(nx=800, ny=800)

# Create a small target cube.
target_cube = make_gridcube(nx=100, ny=100)

# Set the regridding scheme.
scheme = AreaWeighted()

# List a collection of file names/paths to save partial regridders to.
files = ["file_1", "file_2", "file_3", "file_4"]

# Set the size of each block of the partition. For the keyword `src_chunks`
# this follows the dask chunking API.
src_chunks = (400, 400)

# Initialise the partition.
partition = Partition(
source_cube,
target_cube,
scheme,
files,
src_chunks=src_chunks
)

Initialising the `Partition` will not generate the files automatically unless
the `auto_generate` keyword is set to `True`. In order for this `Partition` to
function, the regridder files must be generated by calling the `generate_files`
method.

.. note:: Not all files need to be generated at once, if you have a grid which
needs to be split into very many files, it is possible to generate only
a portion of those files within a given session by passing the number
of files to generate as an argument to the regridder. It is possible to
generate the remaining files in a different python session.
::

# Generate partial regridders and save them to the list of files.
partition.generate_files()

# Once the files have been generated, they can be used for regridding.
result = partition.apply_regridders(source_cube)

Once the files for a regridder have been generated, they can be used to reconstruct
the partition object in a later session. This is done by passing in the list of
files which have already been generated::

# Use the same arguments which constructed the original partition.
source_cube = make_gridcube(nx=800, ny=800)
target_cube = make_gridcube(nx=100, ny=100)
scheme = AreaWeighted()
files = ["file_1", "file_2", "file_3", "file_4"]
src_chunks = (400, 400)

# List the files which have already been generated.
saved_files = ["file_1", "file_2", "file_3", "file_4"]

# Reconstruct Partition from pre-generated files.
partition = Partition(
source_cube,
target_cube,
scheme,
files,
src_chunks=src_chunks
saved_files=saved_files # Pass in the list of saved files.
)

# The new Partition can now be used without the need for generating files.
result = partition.apply_regridders(source_cube)

.. todo:
Add more examples.

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ ignore = [
"ANN202",
"ANN204",

"B905", # Zip strictness should be explicit
"D104", # Misssing docstring
"E501", # Line too long
"ERA001", # Commented out code
Expand Down Expand Up @@ -298,5 +299,5 @@ convention = "numpy"
[tool.ruff.lint.pylint]
# TODO: refactor to reduce complexity, if possible
max-args = 10
max-branches = 21
max-branches = 22
max-statements = 110
70 changes: 44 additions & 26 deletions src/esmf_regrid/esmf_regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,46 @@ def _out_dtype(self, in_dtype):
).dtype
return out_dtype

def _gen_weights_and_data(self, src_array):
extra_shape = src_array.shape[: -self.src.dims]

if self.method == Constants.Method.NEAREST:
weight_matrix = self.weight_matrix.astype(src_array.dtype)
else:
weight_matrix = self.weight_matrix

flat_src = self.src._array_to_matrix(ma.filled(src_array, 0.0))
flat_tgt = weight_matrix @ flat_src

src_inverted_mask = self.src._array_to_matrix(~ma.getmaskarray(src_array))
weight_sums = weight_matrix @ src_inverted_mask
return weight_sums, flat_tgt, extra_shape

def _regrid_from_weights_and_data(
self,
tgt_weights,
tgt_data,
extra,
norm_type=Constants.NormType.FRACAREA,
mdtol=1,
):
# Set the minimum mdtol to be slightly higher than 0 to account for rounding
# errors.
mdtol = max(mdtol, 1e-8)
tgt_mask = tgt_weights > 1 - mdtol
normalisations = np.ones_like(tgt_data)
if self.method != Constants.Method.NEAREST:
masked_weight_sums = tgt_weights * tgt_mask
if norm_type == Constants.NormType.FRACAREA:
normalisations[tgt_mask] /= masked_weight_sums[tgt_mask]
elif norm_type == Constants.NormType.DSTAREA:
pass
normalisations = ma.array(normalisations, mask=np.logical_not(tgt_mask))

tgt_array = tgt_data * normalisations
tgt_array = self.tgt._matrix_to_array(tgt_array, extra)
return tgt_array

def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1):
"""Perform regridding on an array of data.

Expand Down Expand Up @@ -212,30 +252,8 @@ def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1):
f"got an array with shape ending in {main_shape}."
)
raise ValueError(e_msg)
extra_shape = array_shape[: -self.src.dims]
extra_size = max(1, np.prod(extra_shape))
src_inverted_mask = self.src._array_to_matrix(~ma.getmaskarray(src_array))
weight_matrix = self.weight_matrix
if self.method == Constants.Method.NEAREST:
# force out_dtype := in_dtype
weight_matrix = weight_matrix.astype(src_array.dtype)
weight_sums = weight_matrix @ src_inverted_mask
out_dtype = self._out_dtype(src_array.dtype)
# Set the minimum mdtol to be slightly higher than 0 to account for rounding
# errors.
mdtol = max(mdtol, 1e-8)
tgt_mask = weight_sums > 1 - mdtol
normalisations = np.ones([self.tgt.size, extra_size], dtype=out_dtype)
if self.method != Constants.Method.NEAREST:
masked_weight_sums = weight_sums * tgt_mask
if norm_type == Constants.NormType.FRACAREA:
normalisations[tgt_mask] /= masked_weight_sums[tgt_mask]
elif norm_type == Constants.NormType.DSTAREA:
pass
normalisations = ma.array(normalisations, mask=np.logical_not(tgt_mask))

flat_src = self.src._array_to_matrix(ma.filled(src_array, 0.0))
flat_tgt = weight_matrix @ flat_src
flat_tgt = flat_tgt * normalisations
tgt_array = self.tgt._matrix_to_array(flat_tgt, extra_shape)
tgt_weights, tgt_data, extra = self._gen_weights_and_data(src_array)
tgt_array = self._regrid_from_weights_and_data(
tgt_weights, tgt_data, extra, norm_type=norm_type, mdtol=mdtol
)
return tgt_array
88 changes: 88 additions & 0 deletions src/esmf_regrid/experimental/_partial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Provides a regridder class compatible with Partition."""

import numpy as np

from esmf_regrid.schemes import (
_create_cube,
_ESMFRegridder,
)


class PartialRegridder(_ESMFRegridder):
"""Regridder class designed for use in :class:`~esmf_regrid.experimental._partial.Partial`."""

def __init__(self, src, tgt, src_slice, tgt_slice, weights, scheme, **kwargs):
"""Create a regridder instance for a block of :class:`~esmf_regrid.experimental._partial.Partial`.

Parameters
----------
src : :class:`iris.cube.Cube`
The :class:`~iris.cube.Cube` providing the source.
tgt : :class:`iris.cube.Cube` or :class:`iris.mesh.MeshXY`
The :class:`~iris.cube.Cube` or :class:`~iris.mesh.MeshXY` providing the target.
src_slice : tuple
The upper and lower bounds of the block taken from the original source from which the
``src`` was derived.
tgt_slice : tuple
The upper and lower bounds of the block taken from the original target from which the
``tgt`` was derived.
weights : :class:`scipy.sparse.spmatrix`
The weights to use for regridding.
scheme : :class:`~esmf_regrid.schemes.ESMFAreaWeighted` or :class:`~esmf_regrid.schemes.ESMFBilinear`
The scheme used to construct the regridder.
"""
self.src_slice = src_slice # this will be tuple-like
self.tgt_slice = tgt_slice
self.scheme = scheme

# Pop duplicate kwargs.
for arg in set(kwargs.keys()).intersection(vars(self.scheme)):
kwargs.pop(arg)

self._regridder = scheme.regridder(
src,
tgt,
precomputed_weights=weights,
**kwargs,
)
self.__dict__.update(self._regridder.__dict__)

def __repr__(self):
"""Return a representation of the class."""
result = (
f"PartialRegridder("
f"src_slice={self.src_slice}, "
f"tgt_slice={self.tgt_slice}, "
f"scheme={self.scheme})"
)
return result

def partial_regrid(self, src):
"""Perform the first half of regridding, generating weights and data."""
dims = self._get_cube_dims(src)
num_dims = len(dims)
standard_in_dims = [-1, -2][:num_dims]
data = np.moveaxis(src.data, dims, standard_in_dims)
result = self.regridder._gen_weights_and_data(data)
return result

def finish_regridding(self, src_cube, weights, data, extra):
"""Perform the second half of regridding, combining weights and data."""
dims = self._get_cube_dims(src_cube)

result_data = self.regridder._regrid_from_weights_and_data(weights, data, extra)

num_out_dims = self.regridder.tgt.dims
num_dims = len(dims)
standard_out_dims = [-1, -2][:num_out_dims]
if num_dims == 2 and num_out_dims == 1:
dims = [min(dims)]
if num_dims == 1 and num_out_dims == 2:
dims = [dims[0] + 1, dims[0]]

result_data = np.moveaxis(result_data, standard_out_dims, dims)

result_cube = _create_cube(
result_data, src_cube, dims, self._tgt, len(self._tgt)
)
return result_cube
Loading
Loading