diff --git a/CMakeLists.txt b/CMakeLists.txt index fa50151f8..b03bfc409 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,7 +43,7 @@ ecbuild_find_package( NAME fiat REQUIRED ) ecbuild_add_option( FEATURE MPI DESCRIPTION "Support for MPI distributed memory parallelism" - REQUIRED_PACKAGES "MPI COMPONENTS Fortran CXX" + REQUIRED_PACKAGES "MPI COMPONENTS C Fortran CXX" CONDITION fiat_HAVE_MPI ) ecbuild_add_option( FEATURE OMP diff --git a/src/transi/transi.h b/src/transi/transi.h index d4640212f..0f53a4c5b 100644 --- a/src/transi/transi.h +++ b/src/transi/transi.h @@ -189,6 +189,8 @@ int trans_use_mpi(_bool); */ int trans_init(void); +int trans_set_mpi_comm(const int mpi_user_comm); + int trans_set_read(struct Trans_t*, const char* filepath); int trans_set_write(struct Trans_t*, const char* filepath); int trans_set_cache(struct Trans_t*, const void*, size_t); diff --git a/src/transi/transi_module.F90 b/src/transi/transi_module.F90 index 1e152eb54..b39e7e0d2 100644 --- a/src/transi/transi_module.F90 +++ b/src/transi/transi_module.F90 @@ -39,6 +39,11 @@ module trans_module MPL_END, & MPL_NPROC, & MPL_MYRANK + +use MPL_DATA_MODULE, only: & + MPLUSERCOMM, & + LMPLUSERCOMM + implicit none private :: c_ptr @@ -640,6 +645,32 @@ function trans_init() bind(C,name="trans_init") result(iret) end function trans_init +function trans_set_mpi_comm(mpi_user_comm) bind(C,name="trans_set_mpi_comm") result(iret) + use, intrinsic :: iso_c_binding + integer(c_int) :: iret + integer(c_int), value, intent(in) :: mpi_user_comm + + integer(c_int), save :: last_comm_set = -1 ! -1 indicates no comm has been set + + iret = TRANS_SUCCESS + if (.not. USE_MPI) return + ! If MPL_INIT already setup, and the comm coming in is the same, then skip. + if (is_init .and. mpi_user_comm == last_comm_set) return + + ! Confirm that this is called prior to MPL_INIT, to ensure correct setting of global vars. + if (is_init) then + write(error_unit,'(A)') "trans_set_mpi_comm: ERROR: Must be called prior to trans_init." + iret = TRANS_ERROR + return + end if + + LMPLUSERCOMM = .true. + MPLUSERCOMM = mpi_user_comm + last_comm_set = mpi_user_comm + +end function trans_set_mpi_comm + + function trans_setup(trans) bind(C,name="trans_setup") result(iret) use, intrinsic :: iso_c_binding integer(c_int) :: iret diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index bdbb32408..f7ca8a580 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -409,6 +409,16 @@ if( HAVE_ETRANS ) endif() +if( HAVE_CPU AND HAVE_MPI ) + ecbuild_add_test(TARGET ectrans_test_split_mpi_comm + SOURCES trans/test_split_mpi_comm.F90 + LIBS trans_sp parkind_sp + LINKER_LANGUAGE Fortran + MPI 4 + OMP 1 + ) +endif() + # -------------------------------------------------------------------------------------------------- # Add tests for transi # -------------------------------------------------------------------------------------------------- @@ -555,4 +565,14 @@ if( HAVE_TRANSI ) LINKER_LANGUAGE C CONDITION HAVE_ETRANS ENVIRONMENT TRANS_USE_MPI=0 ) + + if( HAVE_MPI ) + ecbuild_add_test( TARGET ectrans_test_transi_split_comm + SOURCES transi/transi_test_split_comm.c + LIBS ectrans_test MPI::MPI_C + MPI 2 + LINKER_LANGUAGE C + ENVIRONMENT TRANS_USE_MPI=1 ) + endif() + endif() diff --git a/tests/trans/test_split_mpi_comm.F90 b/tests/trans/test_split_mpi_comm.F90 new file mode 100644 index 000000000..a80f85f7f --- /dev/null +++ b/tests/trans/test_split_mpi_comm.F90 @@ -0,0 +1,186 @@ +! (C) Crown Copyright 2025- Met Office. +! +! This software is licensed under the terms of the Apache Licence Version 2.0 +! which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +! + +program test_split_mpi_comm + +USE PARKIND1, ONLY : JPRM, JPIM +USE MPL_MODULE, ONLY : MPL_INIT, MPL_END, MPL_MYRANK, MPL_NPROC, MPL_GATHERV, & + MPL_COMM_SPLIT, MPL_SETDFLT_COMM, MPL_COMM +USE ABORT_TRANS_MOD, ONLY : ABORT_TRANS +USE mpl_data_module, ONLY : MPLUSERCOMM, LMPLUSERCOMM +USE mpl_mpif + +implicit none + +#include "setup_trans0.h" +#include "setup_trans.h" +#include "inv_trans.h" +#include "dir_trans.h" +#include "trans_inq.h" + +integer(kind=JPIM), parameter, dimension(2) :: truncations = [79, 188] + +! MODE +integer(kind=JPIM), parameter, dimension(2) :: Ms = [1, 3] +integer(kind=JPIM), parameter, dimension(2) :: Ns = [2, 4] + + +integer(kind=JPIM) :: num_spectral_elements, num_grid_points +integer(kind=JPIM) :: g_num_spectral_elements, g_num_grid_points ! global +integer(kind=JPIM) :: local_spectral_coefficient_index +integer(kind=JPIM) :: ierror +integer(kind=JPIM) :: i +integer(kind=JPIM) :: split_num_ranks, split_rank +integer(kind=JPIM) :: world_num_ranks, world_rank +integer(kind=JPIM) :: split_colour, split_key +integer(kind=JPIM) :: split_comm, dummy_comm +integer(kind=JPIM) :: truncation +integer(kind=JPIM) :: M, N +integer(kind=JPIM) :: num_latitudes, num_longitudes + +! Book-keeping for MPI gather. +integer(kind=JPIM), allocatable :: displs(:) +integer(kind=JPIM), allocatable :: grid_partition_sizes(:) + +integer(kind=JPIM), allocatable :: spectral_indices(:) + +! Fields +real(kind=JPRM), allocatable :: spectral_field(:,:) +real(kind=JPRM), allocatable :: grid_point_field(:,:,:) + +! NOTE: 1 Dimensional global field used to write to file output. +real(kind=JPRM), allocatable :: g_grid_point_field(:) + +character(len=1024) :: filename + +call MPL_INIT() +world_num_ranks = MPL_NPROC() +world_rank = MPL_MYRANK() + +split_colour = get_split_group(world_rank, world_num_ranks) +split_key = world_rank +call MPL_COMM_SPLIT(MPL_COMM, split_colour, split_key, split_comm, ierror) +call MPL_SETDFLT_COMM(split_comm, dummy_comm) + +split_rank = MPL_MYRANK() +split_num_ranks = MPL_NPROC() + +print*,"=== Rank ", world_rank, ", Setup on group", split_colour, "num ranks = ", split_num_ranks, "===" + +! Assert that the split comm is smaller than WORLD. +if (split_num_ranks >= world_num_ranks) then + print*, "SPLIT = ", split_num_ranks, "TOTAL = ", world_num_ranks + call ABORT_TRANS("ERROR: Split communicator not smaller than MPI_COMM_WORLD.") +end if + +print*,"=== Local rank ", split_rank, ", on group", split_colour, "size", split_num_ranks, "===" + +call setup_trans0(KPRINTLEV=0, LDMPOFF=.false., & +! Split grid NS Split spectral + KPRGPNS=split_num_ranks, KPRTRW=split_num_ranks) + +! Different transform based on the colour. +truncation = truncations(split_colour + 1) + +num_latitudes = 2*(truncation + 1) +num_longitudes = num_latitudes*2 + +if (split_rank == 1) print*,"Colour ", split_colour, & + " GLOBAL NUM LON =>", num_longitudes, "| LAT =>", num_latitudes + +call setup_trans(KSMAX=truncation, KDGL=num_latitudes) + +! Get function space sizes +call trans_inq(KSPEC2=num_spectral_elements, KGPTOT=num_grid_points) +call trans_inq(KSPEC2G=g_num_spectral_elements, KGPTOTG=g_num_grid_points) + +allocate(spectral_field(1, num_spectral_elements)) +allocate(grid_point_field(num_grid_points, 1, 1)) + +! Get spectral indices +allocate(spectral_indices(0:truncation)) +call trans_inq(KASM0=spectral_indices) + +! select mode +M = Ms(split_colour + 1) +N = Ns(split_colour + 1) +local_spectral_coefficient_index = spectral_indices(M) + 2*(N - M) + 1 + +spectral_field(:,:) = 0.0 + +if (local_spectral_coefficient_index > 0) then + spectral_field(1,local_spectral_coefficient_index) = 1.0 +end if + +call inv_trans(PSPSCALAR=spectral_field, PGP=grid_point_field) + +! -------------- Gather the result on the root (0) and write to file ----------------- + +! Get counts from each PE. +allocate(grid_partition_sizes(split_num_ranks)) +grid_partition_sizes = 0 + +call MPL_GATHERV(num_grid_points, KRECVBUF=grid_partition_sizes, KCOMM=split_comm) + +if (split_rank == 1) then + ! Allocate a global field. + allocate(g_grid_point_field(g_num_grid_points)) + + ! Make displacement arrays + allocate(displs(split_num_ranks)) + displs = 0 + do i=2, split_num_ranks + displs(i) = displs(i - 1) + grid_partition_sizes(i - 1) + end do +end if + +call MPL_GATHERV(grid_point_field(:,1,1), PRECVBUF=g_grid_point_field, & + KCOMM=split_comm, & + KRECVCOUNTS=grid_partition_sizes, & + KRECVDISPL=displs) + +if (split_rank == 1) then + ! Write to file. Can then be plotted using a python script + ! such as in the docs: https://sites.ecmwf.int/docs/ectrans/page/usage.html. + + write(filename, "(A22,I0,A4)") "grid_point_field_trunc_", truncation, ".dat" + open(7, file=filename, form="unformatted") + write(7) g_grid_point_field(:) + close(7) + + print*,"Colour", split_colour, "finished and written to file: "//trim(filename) +end if + +call MPL_END() + +CONTAINS + +! Get the colour of comm for this rank. +function get_split_group(rank, world_size) result(group) + implicit none + + integer(kind=JPIM), intent(in) :: rank + integer(kind=JPIM), intent(in) :: world_size + ! return + integer(kind=JPIM) :: group + + real(kind=JPRM) :: rank_ratio + + ! ---------------------------------------------- + ! Uneven splitting based on a ratio 1:3. + ! ---------------------------------------------- + rank_ratio = real(rank, kind=JPRM) / real(world_size, kind=JPRM) + + ! Split X% + if (rank_ratio <= 0.25_jprm) then + group = 0 + else + group = 1 + end if + +end function get_split_group + +end program diff --git a/tests/transi/transi_test.h b/tests/transi/transi_test.h index 5bef5f6bc..2c1747b82 100644 --- a/tests/transi/transi_test.h +++ b/tests/transi/transi_test.h @@ -40,6 +40,15 @@ }\ } while(0) +#define ASSERT_MSG( assertion, msg ) do {\ + if( !(assertion) ) {\ + printf("ERROR: Assertion `%s' failed @%s:%d => %s\n",#assertion,__FILE__,__LINE__,#msg);\ + TRANS_CHECK( trans_delete(&trans) );\ + exit(1);\ + }\ +} while(0) + + #define TRANS_ERROR -1 #define TRANS_NOTIMPL -2 #define TRANS_MISSING_ARG -3 diff --git a/tests/transi/transi_test_split_comm.c b/tests/transi/transi_test_split_comm.c new file mode 100644 index 000000000..f45aa91e9 --- /dev/null +++ b/tests/transi/transi_test_split_comm.c @@ -0,0 +1,80 @@ +/* + * (C) Crown Copyright 2025- Met Office. + * + * This software is licensed under the terms of the Apache Licence Version 2.0 + * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + */ + +#include +#include +#include + +#include "ectrans/transi.h" + +#include "transi_test.h" + +#include + +// ---------------------------------------------------------------------------- + +int getColour(const int world_rank) { return world_rank % 2; } + +// ---------------------------------------------------------------------------- + +int main ( int arc, char **argv ) { + MPI_Init(&arc, &argv); + trans_use_mpi(true); + + setbuf(stdout,NULL); // unbuffered stdout + + int world_size; + MPI_Comm_size(MPI_COMM_WORLD, &world_size); + + int world_rank; + MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); + + // Split world communicator. + const int colour = getColour(world_rank); + MPI_Comm split_comm; + MPI_Comm_split(MPI_COMM_WORLD, colour, world_rank, &split_comm); + + int split_size; + MPI_Comm_size(split_comm, &split_size); + + // Set default fiat MPL comm. + const MPI_Fint split_comm_int = MPI_Comm_c2f(split_comm); + TRANS_CHECK( trans_set_mpi_comm(split_comm_int) ); + + // Initialise trans (+ MPL as a result) with split communicator. + struct Trans_t trans; + TRANS_CHECK( trans_new(&trans) ); + + const int nlon = 320; + const int nlat = 161; + const int nsmax = 159; + TRANS_CHECK( trans_set_resol_lonlat(&trans,nlon,nlat) ); + TRANS_CHECK( trans_set_trunc(&trans,nsmax) ); + TRANS_CHECK( trans_setup(&trans) ); + + printf("World size => %d :: Split size => %d :: Trans size => %d\n", + world_size, split_size, trans.nproc); + + ASSERT_MSG(world_size >= 2, + "ERROR: Number of MPI processes for this test must be greater than or equal to 2."); + ASSERT(trans.nproc == split_size); + ASSERT(trans.nproc < world_size); + ASSERT(trans.nproc <= world_size / 2); + + // Attempt to set up trans on WORLD - should fail, since MPL has already + // been initialised on the split_comm + const MPI_Fint world_comm_int = MPI_Comm_c2f(MPI_COMM_WORLD); + const int ret_code = trans_set_mpi_comm(world_comm_int); + ASSERT_MSG(ret_code != 0, "ERROR: Expected `trans_set_mpi_comm(MPI_COMM_WORLD)` " + "to fail on second setup."); + + TRANS_CHECK( trans_finalize() ); + + MPI_Finalize(); + + return 0; +}