Skip to content

Commit

Permalink
fix usage of Cache in different interpolation methods
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrdar committed Feb 13, 2025
1 parent 2dcbf88 commit 8def729
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 45 deletions.
6 changes: 3 additions & 3 deletions src/atlas/interpolation/method/knn/GridBoxMethod.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,16 +184,16 @@ void GridBoxMethod::do_setup(const Grid& source, const Grid& target, const Cache
void GridBoxMethod::do_setup(const FunctionSpace& source, const FunctionSpace& target, const Cache& cache) {
ATLAS_TRACE("GridBoxMethod::setup()");

source_ = source;
target_ = target;
if (not matrixFree_ && interpolation::MatrixCache(cache)) {
setMatrix(cache);
source_ = source;
target_ = target;
ATLAS_ASSERT(matrix().rows() == target.size());
ATLAS_ASSERT(matrix().cols() == source.size());
return;
}

// setup only with cache
Log::warning() << "Can not create GridBoxMethod from (FunctionSpace, FunctionSpace, Cache). Use (Grid, Grid, Cache)";
ATLAS_NOTIMPLEMENTED;
}

Expand Down
13 changes: 11 additions & 2 deletions src/atlas/interpolation/method/knn/NearestNeighbour.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,17 @@ MethodBuilder<NearestNeighbour> __builder("nearest-neighbour");
} // namespace

void NearestNeighbour::do_setup(const FunctionSpace& source, const FunctionSpace& target, const Cache& cache) {
setMatrix(cache);
do_setup(source, target);
if (interpolation::MatrixCache(cache)) {
setMatrix(cache);
source_ = source;
target_ = target;
return;
}
if (functionspace::NodeColumns(source) && functionspace::NodeColumns(target)) {
do_setup(source, target);
return;
}
ATLAS_NOTIMPLEMENTED;
}

void NearestNeighbour::do_setup(const Grid& source, const Grid& target, const Cache&) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,21 @@ void StructuredInterpolation2D<Kernel>::do_setup( const Grid& source, const Grid


template <typename Kernel>
void StructuredInterpolation2D<Kernel>::do_setup( const FunctionSpace& source, const FunctionSpace& target, const Cache& ) {
void StructuredInterpolation2D<Kernel>::do_setup( const FunctionSpace& source, const FunctionSpace& target, const Cache& cache) {
ATLAS_TRACE( "StructuredInterpolation2D<" + Kernel::className() + ">::do_setup(FunctionSpace source, FunctionSpace target)" );
do_setup( source, target );
if (interpolation::MatrixCache(cache)) {
setMatrix(cache);
source_ = source;
target_ = target;
ATLAS_ASSERT(matrix().rows() == target.size());
ATLAS_ASSERT(matrix().cols() == source.size());
return;
}
if (functionspace::NodeColumns(source) && functionspace::PointCloud(target)) {
do_setup(source, target);
return;
}
ATLAS_NOTIMPLEMENTED;
}

template <typename Kernel>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,21 @@ void StructuredInterpolation3D<Kernel>::do_setup( const Grid& source, const Grid
}

template <typename Kernel>
void StructuredInterpolation3D<Kernel>::do_setup( const FunctionSpace& source, const FunctionSpace& target, const Cache& ) {
if ( mpi::size() > 1 ) {
ATLAS_NOTIMPLEMENTED;
void StructuredInterpolation3D<Kernel>::do_setup( const FunctionSpace& source, const FunctionSpace& target, const Cache& cache) {
ATLAS_TRACE( "StructuredInterpolation3D<" + Kernel::className() + ">::do_setup(FunctionSpace source, FunctionSpace target, const Cache)" );
if (interpolation::MatrixCache(cache)) {
setMatrix(cache);
source_ = source;
target_ = target;
ATLAS_ASSERT(matrix().rows() == target.size());
ATLAS_ASSERT(matrix().cols() == source.size());
return;
}

do_setup( source, target );
if (functionspace::StructuredColumns(source) && functionspace::PointCloud(target)) {
do_setup( source, target );
return;
}
ATLAS_NOTIMPLEMENTED;
}

template <typename Kernel>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,7 @@ void ConservativeSphericalPolygonInterpolation::do_setup(const FunctionSpace& so
return;
}
}

do_setup(source, target);
}

Expand Down
12 changes: 8 additions & 4 deletions src/atlas/interpolation/method/unstructured/FiniteElement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,20 @@ void FiniteElement::do_setup(const Grid& source, const Grid& target, const Cache
do_setup(make_nodecolumns(source), functionspace::PointCloud{target});
}

void FiniteElement::do_setup(const FunctionSpace& source, const FunctionSpace& target, const Cache& cache) {
source_ = source;
target_ = target;
void FiniteElement::do_setup(const FunctionSpace& source, const FunctionSpace& target, const Cache& cache) {
if (interpolation::MatrixCache(cache)) {
setMatrix(cache);
source_ = source;
target_ = target;
ATLAS_ASSERT(matrix().rows() == target.size());
ATLAS_ASSERT(matrix().cols() == source.size());
return;
}
do_setup(source, target);
if (functionspace::NodeColumns(source) && functionspace::PointCloud(target)) {
do_setup(source, target);
return;
}
ATLAS_NOTIMPLEMENTED;
}

void FiniteElement::do_setup(const FunctionSpace& source, const FunctionSpace& target) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,18 @@ void UnstructuredBilinearLonLat::do_setup(const Grid& source, const Grid& target
void UnstructuredBilinearLonLat::do_setup(const FunctionSpace& source, const FunctionSpace& target, const Cache& cache) {
allow_halo_exchange_ = false;
// no halo_exchange because we don't have any halo with delaunay or 3d structured meshgenerator

if (interpolation::MatrixCache(cache)) {
setMatrix(cache);
source_ = source;
target_ = target;
ATLAS_ASSERT(matrix().rows() == target.size());
ATLAS_ASSERT(matrix().cols() == source.size());
return;
}
if (functionspace::NodeColumns(source) && functionspace::PointCloud(target)) {
do_setup(source, target);
return;
}
ATLAS_NOTIMPLEMENTED;
}

Expand Down
55 changes: 27 additions & 28 deletions src/tests/interpolation/test_interpolation_global_matrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include "tests/AtlasTestEnvironment.h"

using atlas::functionspace::PointCloud;
using atlas::functionspace::NodeColumns;
using atlas::functionspace::StructuredColumns;
using atlas::util::Config;
Expand All @@ -40,31 +41,22 @@ namespace atlas::test {
Config config_scheme(std::string scheme_str) {
Config scheme;
scheme.set("matrix_free", false);
if (scheme_str == "linear") {
scheme.set("type", "structured-linear2D");
// The stencil does not require any halo, but we set it to 1 for pole treatment!
scheme.set("halo", 1);
}
if (scheme_str == "cubic") {
scheme.set("type", "structured-cubic2D");
scheme.set("type", scheme_str);
scheme.set("halo", 1);

if (scheme_str.find("cubic") != std::string::npos) {
scheme.set("halo", 2);
}
if (scheme_str == "quasicubic") {
scheme.set("type", "structured-quasicubic2D");
if (scheme_str == "k-nearest-neighbours") {
scheme.set("k-nearest-neighbours", 4);
scheme.set("halo", 2);
}
if (scheme_str == "conservative") {
scheme.set("type", "conservative-spherical-polygon");
if (scheme_str == "conservative-spherical-polygon") {
scheme.set("halo", 2);
scheme.set("order", 2);
scheme.set("src_cell_data", false);
scheme.set("tgt_cell_data", false);
}
if (scheme_str == "finite-element") {
scheme.set("type", "finite-element");
scheme.set("halo", 1);
}

scheme.set("name", scheme_str);
return scheme;
}

Expand All @@ -79,6 +71,11 @@ Config create_fspaces(const std::string& scheme_str, const Grid& input_grid, con
fs_in = NodeColumns(inmesh, scheme);
fs_out = NodeColumns(outmesh);
}
else if (scheme_type == "unstructured-bilinear-lonlat") {
auto inmesh = Mesh(input_grid);
fs_in = NodeColumns(input_grid, scheme);
fs_out = PointCloud(output_grid, grid::MatchingPartitioner(inmesh));
}
else if (scheme_type == "conservative-spherical-polygon") {
bool src_cell_data = scheme.getBool("src_cell_data");
bool tgt_cell_data = scheme.getBool("tgt_cell_data");
Expand Down Expand Up @@ -185,7 +182,7 @@ using SparseMatrixStorage = atlas::linalg::SparseMatrixStorage;


auto do_assemble_distribute_matrix = [&](const std::string scheme_str, const Grid& input_grid, const Grid& output_grid, const int mpi_root) {
Log::info() << "\tassemble / distribute from " << scheme_str << ", " << input_grid.name() << " - " << output_grid.name() << std::endl;
Log::info() << "[TEST] assemble / distribute from " << scheme_str << ", " << input_grid.name() << " - " << output_grid.name() << std::endl;
FunctionSpace fs_in;
FunctionSpace fs_out;
SparseMatrixStorage gmatrix;
Expand All @@ -201,7 +198,7 @@ using SparseMatrixStorage = atlas::linalg::SparseMatrixStorage;
}

std::vector<double> tgt_data(output_grid.size());
ATLAS_TRACE_SCOPE("Compute the global field from the global matrix") {
ATLAS_TRACE_SCOPE("TEST Compute the global field from the global matrix") {
if (mpi::comm().rank() == mpi_root) {
std::vector<double> src_data(input_grid.size());
auto src = eckit::linalg::Vector(src_data.data(), src_data.size());
Expand All @@ -217,7 +214,7 @@ using SparseMatrixStorage = atlas::linalg::SparseMatrixStorage;

auto tgt_field_global = interpolator.target().createField<double>(option::global(mpi_root));

ATLAS_TRACE_SCOPE("Compute the global field from the distributed interpolation") {
ATLAS_TRACE_SCOPE("TEST Compute the global field from the distributed interpolation") {
auto tgt_field = interpolator.target().createField<double>();
auto field_in = interpolator.source().createField<double>();
auto lonlat_in = array::make_view<double, 2>(interpolator.source().lonlat());
Expand All @@ -231,7 +228,7 @@ using SparseMatrixStorage = atlas::linalg::SparseMatrixStorage;
interpolator.target().gather(tgt_field, tgt_field_global);
}

ATLAS_TRACE_SCOPE("Compare the two global fields") {
ATLAS_TRACE_SCOPE("TEST Compare the two global fields") {
if (mpi::comm().rank() == mpi_root) {
auto tfield_global_v = array::make_view<double, 1>(tgt_field_global);
for (gidx_t i = 0; i < tgt_data.size(); ++i) {
Expand All @@ -256,17 +253,19 @@ using SparseMatrixStorage = atlas::linalg::SparseMatrixStorage;

auto test_matrix_assemble_distribute = [&](const Grid& input_grid, const Grid& output_grid) {
int mpi_root = 0;
// do_assemble_distribute_matrix("linear", input_grid, output_grid, mpi_root);
do_assemble_distribute_matrix("structured-linear2D", input_grid, output_grid, mpi_root);

// mpi_root = mpi::size() - 1;
// do_assemble_distribute_matrix("linear", input_grid, output_grid, mpi_root);
// do_assemble_distribute_matrix("cubic", input_grid, output_grid, mpi_root);
// do_assemble_distribute_matrix("quasicubic", input_grid, output_grid, mpi_root);
mpi_root = mpi::size() - 1;
do_assemble_distribute_matrix("structured-linear2D", input_grid, output_grid, mpi_root);
do_assemble_distribute_matrix("structured-cubic2D", input_grid, output_grid, mpi_root);
do_assemble_distribute_matrix("structured-quasicubic2D", input_grid, output_grid, mpi_root);

do_assemble_distribute_matrix("conservative", input_grid, output_grid, mpi_root);
do_assemble_distribute_matrix("nearest-neighbour", input_grid, output_grid, mpi_root);
do_assemble_distribute_matrix("k-nearest-neighbours", input_grid, output_grid, mpi_root);

mpi_root = mpi::comm().size() - 1;
do_assemble_distribute_matrix("conservative-spherical-polygon", input_grid, output_grid, mpi_root);
do_assemble_distribute_matrix("finite-element", input_grid, output_grid, mpi_root);
do_assemble_distribute_matrix("unstructured-bilinear-lonlat", input_grid, output_grid, mpi_root);
};

test_matrix_assemble_distribute(Grid("O128"), Grid("F128"));
Expand Down

0 comments on commit 8def729

Please sign in to comment.