Skip to content
Draft
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
ae88a4d
Add dummy tests
mo-joshuacolclough Oct 13, 2025
28774d7
Remove MPL_SET_DEFAULT_COMM and set default prior to MPL_INIT call
mo-joshuacolclough Oct 13, 2025
281b875
Implement set_mpi_comm. Add transi example with split communicator
mo-joshuacolclough Oct 14, 2025
ae81b10
Add mpi includes in CMake
mo-joshuacolclough Oct 14, 2025
2dc9ebb
Fix test include
mo-joshuacolclough Oct 14, 2025
1947405
Check that if set_mpi_comm is called again, that it is with the same …
mo-joshuacolclough Oct 14, 2025
4a0d487
Ensure logic is correct - only do check if MPL_INIT is called
mo-joshuacolclough Oct 14, 2025
277053a
add MPI to libs
mo-joshuacolclough Oct 15, 2025
306292b
Rename test example
mo-joshuacolclough Oct 17, 2025
9fa3834
Comment on global grid point field
mo-joshuacolclough Oct 17, 2025
793b622
update cmake
mo-joshuacolclough Oct 17, 2025
4ecb953
Tidying
mo-joshuacolclough Oct 21, 2025
9ef8c2c
Remove include <mpi.h>
mo-joshuacolclough Oct 21, 2025
900cbbe
Only import what is used from mpl_data_module
mo-joshuacolclough Oct 21, 2025
8cbdbc4
Comments
mo-joshuacolclough Oct 22, 2025
1fb1f42
Revert transi CMakeLists
mo-joshuacolclough Oct 22, 2025
3e18627
Link mpi to cmake test.
mo-joshuacolclough Nov 6, 2025
5bf579f
Merge develop
mo-joshuacolclough Nov 6, 2025
63111d0
Remove multiple_res test (saved in a separate branch). Get world size
mo-joshuacolclough Nov 6, 2025
892a1cd
Add simple transi split comm test
mo-joshuacolclough Nov 6, 2025
06499d4
Remove transi test_example (was copied from another test)
mo-joshuacolclough Nov 6, 2025
b9f0200
Remove unecessary imports in transi_module.F90
mo-joshuacolclough Nov 7, 2025
e917db1
Use MPL gather routines
mo-joshuacolclough Nov 7, 2025
d99f9ac
Tidy imports in test
mo-joshuacolclough Nov 7, 2025
0e68acf
Use MPL for splitting of communicator following Sam's comments.
mo-joshuacolclough Nov 11, 2025
c189a68
Temporarily disable transi test
mo-joshuacolclough Jan 5, 2026
a99282d
Fix cmake
mo-joshuacolclough Jan 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ ecbuild_find_package( NAME fiat REQUIRED )

ecbuild_add_option( FEATURE MPI
DESCRIPTION "Support for MPI distributed memory parallelism"
REQUIRED_PACKAGES "MPI COMPONENTS Fortran CXX"
REQUIRED_PACKAGES "MPI COMPONENTS Fortran C CXX"
CONDITION fiat_HAVE_MPI )

ecbuild_add_option( FEATURE OMP
Expand Down Expand Up @@ -197,7 +197,7 @@ ectrans_find_lapack()
ecbuild_add_option( FEATURE TESTS
DEFAULT ON
DESCRIPTION "Enable unit testing"
REQUIRED_PACKAGES "MPI COMPONENTS Fortran" )
REQUIRED_PACKAGES "MPI COMPONENTS Fortran C CXX" )

### Add sources
include( ectrans_compile_options )
Expand Down
9 changes: 9 additions & 0 deletions src/transi/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,24 @@ endif()

configure_file( version.c.in version.c )

if( HAVE_MPI )
set(transi_mpi_defs ${MPI_C_DEFINITIONS} )
set(transi_mpi_incs ${MPI_C_INCLUDE_PATH} )
set(transi_mpi_libs ${MPI_C_LIBRARIES} )
endif()

ecbuild_add_library( TARGET transi_dp
SOURCES transi_module.F90
transi.h
transi.c
version.h
${CMAKE_CURRENT_BINARY_DIR}/version.c
DEFINITIONS "${transi_mpi_defs}"
HEADER_DESTINATION include/ectrans
PUBLIC_INCLUDES $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<INSTALL_INTERFACE:include>
"${transi_mpi_incs}"
PUBLIC_LIBS "${transi_mpi_libs}"
PRIVATE_LIBS trans_dp
$<${ectrans_HAVE_ETRANS}:etrans_dp>
PRIVATE_DEFINITIONS ECTRANS_HAVE_MPI=${ectrans_HAVE_MPI}
Expand Down
3 changes: 3 additions & 0 deletions src/transi/transi.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
#ifndef ectrans_transi_h
#define ectrans_transi_h

#include <mpi.h>
#include <stddef.h> // size_t

typedef int _bool;
Expand Down Expand Up @@ -166,6 +167,8 @@ int trans_use_mpi(_bool);
*/
int trans_init(void);

int trans_set_mpi_comm(const MPI_Fint mpi_user_comm);

int trans_set_read(struct Trans_t*, const char* filepath);
int trans_set_write(struct Trans_t*, const char* filepath);
int trans_set_cache(struct Trans_t*, const void*, size_t);
Expand Down
31 changes: 31 additions & 0 deletions src/transi/transi_module.F90
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ module trans_module
MPL_END, &
MPL_NPROC, &
MPL_MYRANK

use MPL_DATA_MODULE

use mpl_mpif

implicit none

private :: c_ptr
Expand Down Expand Up @@ -572,6 +577,32 @@ function trans_init() bind(C,name="trans_init") result(iret)
end function trans_init


function trans_set_mpi_comm(mpi_user_comm) bind(C,name="trans_set_mpi_comm") result(iret)
use, intrinsic :: iso_c_binding
integer(c_int) :: iret
integer(c_int), value, intent(in) :: mpi_user_comm

integer(c_int), save :: last_comm_set = -1 ! -1 indicates no comm has been set

iret = TRANS_SUCCESS
if (.not. USE_MPI) return
! If MPL_INIT already setup, and the comm coming in is the same, then skip.
if (is_init .and. mpi_user_comm == last_comm_set) return

! Confirm that this is called prior to MPL_INIT, to ensure correct setting of global vars.
if (is_init) then
write(error_unit,'(A)') "trans_set_mpi_comm: ERROR: Must be called prior to trans_init."
iret = TRANS_ERROR
return
end if

LMPLUSERCOMM = .true.
MPLUSERCOMM = mpi_user_comm
last_comm_set = mpi_user_comm

end function trans_set_mpi_comm


function trans_setup(trans) bind(C,name="trans_setup") result(iret)
use, intrinsic :: iso_c_binding
integer(c_int) :: iret
Expand Down
24 changes: 24 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,22 @@ if( HAVE_ETRANS )

endif()

# EDIT(JC)
ecbuild_add_test(TARGET ectrans_test_example
SOURCES trans/test_example.F90
LIBS trans_sp parkind_sp
LINKER_LANGUAGE Fortran
MPI 1
OMP 1
)
ecbuild_add_test(TARGET ectrans_test_multiple_res
SOURCES trans/test_multiple_res.F90
LIBS trans_sp parkind_sp
LINKER_LANGUAGE Fortran
MPI 1
OMP 1
)

# --------------------------------------------------------------------------------------------------
# Add tests for transi
# --------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -417,6 +433,14 @@ if( HAVE_TRANSI )
LINKER_LANGUAGE C
ENVIRONMENT TRANS_USE_MPI=0 )


ecbuild_add_test( TARGET ectrans_test_transi_example
SOURCES transi/transi_test_example.c
LIBS ectrans_test
MPI 1
LINKER_LANGUAGE C
ENVIRONMENT TRANS_USE_MPI=1 )

ecbuild_add_test( TARGET ectrans_test_transi_lonlat_diff_incr
SOURCES transi/transi_test_lonlat_diff_incr.c
LIBS ectrans_test
Expand Down
190 changes: 190 additions & 0 deletions tests/trans/test_example.F90
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
! (C) Crown Copyright 2025- Met Office.
!
! This software is licensed under the terms of the Apache Licence Version 2.0
! which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
!

program test_example

USE PARKIND1, ONLY: JPRM, JPIM
USE MPL_MODULE ,ONLY : MPL_INIT, MPL_END, MPL_BARRIER, MPL_MYRANK, MPL_NPROC, &
MPL_COMM_SPLIT
USE ABORT_TRANS_MOD, ONLY : ABORT_TRANS
use mpl_data_module
use mpl_mpif

implicit none

#include "setup_trans0.h"
#include "setup_trans.h"
#include "inv_trans.h"
#include "dir_trans.h"
#include "trans_inq.h"

integer(kind=JPIM), parameter, dimension(2) :: truncations = [79, 188]

! MODE
integer(kind=JPIM), parameter, dimension(2) :: Ms = [1, 3]
integer(kind=JPIM), parameter, dimension(2) :: Ns = [2, 4]


integer(kind=JPIM) :: num_spectral_elements, num_grid_points
integer(kind=JPIM) :: g_num_spectral_elements, g_num_grid_points ! global
integer(kind=JPIM) :: mode_index
integer(kind=JPIM) :: ierror
integer(kind=JPIM) :: i
integer(kind=JPIM) :: num_ranks, rank
integer(kind=JPIM) :: split_colour, split_key
integer(kind=JPIM) :: split_comm
integer(kind=JPIM) :: truncation
integer(kind=JPIM) :: M, N
integer(kind=JPIM) :: num_latitudes, num_longitudes

! Number of grid points on each rank
integer(kind=JPIM) :: grid_partition_size_local(1)
integer(kind=JPIM), allocatable :: displs(:)
integer(kind=JPIM), allocatable :: grid_partition_sizes(:)

integer(kind=JPIM), allocatable :: spectral_indices(:)

real(kind=JPRM), allocatable :: spectral_field(:,:)
real(kind=JPRM), allocatable :: grid_point_field(:,:,:)

real(kind=JPRM), allocatable :: g_grid_point_field(:)

character(len=1024) :: filename


call MPI_Init(ierror)
call MPI_Comm_rank(MPI_COMM_WORLD, rank, ierror)

split_colour = get_split_group()
split_key = rank
call MPI_Comm_split(MPI_COMM_WORLD, split_colour, split_key, split_comm, ierror)

print*, "=== Rank ", rank, ", Setup on group", split_colour, "==="

! Set MPL comm
LMPLUSERCOMM = .true.
MPLUSERCOMM = split_comm
call MPL_INIT()


rank = MPL_MYRANK()
num_ranks = MPL_NPROC()
print*, "=== Local rank ", rank, ", on group", split_colour, "size", num_ranks, "==="

! Split grid NS Split spectral
call setup_trans0(KPRINTLEV=0, LDMPOFF=.false., KPRGPNS=num_ranks, KPRTRW=num_ranks)

! DIFFERENT TRANSFORM BASED ON COMM GROUP
truncation = truncations(split_colour + 1)
print*, "TRUNCATION = ", truncation

num_latitudes = 2*(truncation + 1)
num_longitudes = num_latitudes*2
print*, ">>> GLOBAL NUM LON =>", num_longitudes, "| LAT =>", num_latitudes
call setup_trans(KSMAX=truncation, KDGL=num_latitudes)

! Get function space sizes
call trans_inq(KSPEC2=num_spectral_elements, KGPTOT=num_grid_points)
call trans_inq(KSPEC2G=g_num_spectral_elements, KGPTOTG=g_num_grid_points)
print*,"Num spec = ", num_spectral_elements, "| Num grid points = ", num_grid_points, g_num_grid_points

allocate(spectral_field(1, num_spectral_elements))
allocate(grid_point_field(num_grid_points, 1, 1))

! Get spectral indices
allocate(spectral_indices(0:truncation))
call trans_inq(KASM0=spectral_indices)

! select mode
M = Ms(split_colour + 1)
N = Ns(split_colour + 1)
mode_index = spectral_indices(M) + 2*(N - M) + 1

spectral_field(:,:) = 0.0

if (mode_index > 0) then
spectral_field(1,mode_index) = 1.0
end if

call inv_trans(PSPSCALAR=spectral_field, PGP=grid_point_field)

! -------------- Gather the result on the root (0) and write to file -----------------

! Get counts from each PE.
grid_partition_size_local(1) = num_grid_points
allocate(grid_partition_sizes(num_ranks))
grid_partition_sizes = 0

call MPI_Gather(grid_partition_size_local, 1, MPI_INT, &
grid_partition_sizes, 1, MPI_INT, &
0, split_comm, ierror)
if (ierror /= 0) then
print*,"MPI ERROR => ", ierror
call ABORT_TRANS("MPI ERROR")
end if

print*, "SIZES => ", grid_partition_sizes(:)

! Allocate a global field.
allocate(g_grid_point_field(g_num_grid_points))

! Make displacement arrays
allocate(displs(num_ranks))
displs = 0
do i=2, num_ranks
displs(i) = displs(i - 1) + grid_partition_sizes(i - 1)
end do
print*,"displs => ", displs(:)

call MPI_Gatherv(grid_point_field(:,1,1), num_grid_points, MPI_FLOAT, &
g_grid_point_field, grid_partition_sizes, displs, MPI_FLOAT, &
0, split_comm, ierror)
if (ierror /= 0) then
print*,"Gatherv MPI ERROR => ", ierror
call ABORT_TRANS("MPI ERROR")
end if

if (rank == 1) then
write(filename, "(A22,I0,A4)") "grid_point_field_group", split_colour, ".dat"
open(7, file=filename, form="unformatted")
write(7) g_grid_point_field(:)
close(7)
end if

call MPL_END()
call MPI_Finalize()

CONTAINS

! Get the colour of comm for this rank.
function get_split_group() result(group)
implicit none
! return
integer(kind=JPIM) :: group

integer(kind=JPIM) :: rank, world_size, ierror
real(kind=JPRM) :: rank_ratio

call MPI_Comm_size(MPI_COMM_WORLD, world_size, ierror)
call MPI_Comm_rank(MPI_COMM_WORLD, rank, ierror)

!group = mod(rank, 2) ! Split comm in half, alternating.

! ----------------------------------------------
! Uneven splitting.
! ----------------------------------------------
rank_ratio = real(rank, kind=JPRM) / real(world_size, kind=JPRM)

! Split X%
if (rank_ratio < 0.3_jprm) then
group = 0
else
group = 1
end if

end function get_split_group

end program
Loading