Skip to content

Commit fc2b415

Browse files
authored
Add new MPI_Op -> MPI_LOR and Add wrappers for Logical datatype for MPI_Allreduce (#134)
1 parent b9fd50e commit fc2b415

File tree

4 files changed

+71
-6
lines changed

4 files changed

+71
-6
lines changed

src/mpi.f90

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ module mpi
2121
real(8), parameter :: MPI_IN_PLACE = -1002
2222
integer, parameter :: MPI_SUM = -2300
2323
integer, parameter :: MPI_MAX = -2301
24+
integer, parameter :: MPI_LOR = -2302
2425
integer, parameter :: MPI_INFO_NULL = -2000
2526
integer, parameter :: MPI_STATUS_SIZE = 5
2627
integer :: MPI_STATUS_IGNORE = 0
@@ -99,6 +100,7 @@ module mpi
99100
module procedure MPI_Allreduce_1D_recv_proc
100101
module procedure MPI_Allreduce_1D_real_proc
101102
module procedure MPI_Allreduce_1D_int_proc
103+
module procedure MPI_Allreduce_scalar_logical_proc
102104
end interface
103105

104106
interface MPI_Gatherv
@@ -172,14 +174,16 @@ module mpi
172174
contains
173175

174176
integer(kind=MPI_HANDLE_KIND) function handle_mpi_op_f2c(op_f) result(c_op)
175-
use mpi_c_bindings, only: c_mpi_op_f2c, c_mpi_sum, c_mpi_max
177+
use mpi_c_bindings, only: c_mpi_op_f2c, c_mpi_sum, c_mpi_max, c_mpi_lor
176178
integer, intent(in) :: op_f
177179
if (op_f == MPI_SUM) then
178180
c_op = c_mpi_sum
179181
else if (op_f == MPI_MAX) then
180182
c_op = c_MPI_MAX
183+
else if (op_f == MPI_LOR) then
184+
c_op = c_mpi_lor
181185
else
182-
c_op = c_mpi_op_f2c(op_f)
186+
c_op = c_mpi_op_f2c(op_f) ! For other operations, use the C binding
183187
end if
184188
end function
185189

@@ -841,6 +845,35 @@ subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm
841845
end if
842846
end subroutine MPI_Allreduce_1D_int_proc
843847

848+
subroutine MPI_Allreduce_scalar_logical_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror)
849+
use iso_c_binding, only: c_int, c_ptr, c_loc
850+
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_comm_f2c
851+
logical, intent(in), target :: sendbuf
852+
logical, intent(out), target :: recvbuf
853+
integer, intent(in) :: count, datatype, op, comm
854+
integer, intent(out), optional :: ierror
855+
type(c_ptr) :: sendbuf_ptr, recvbuf_ptr
856+
integer(kind=MPI_HANDLE_KIND) :: c_datatype, c_op, c_comm
857+
integer(c_int) :: local_ierr
858+
859+
sendbuf_ptr = c_loc(sendbuf)
860+
recvbuf_ptr = c_loc(recvbuf)
861+
c_datatype = handle_mpi_datatype_f2c(datatype)
862+
c_op = handle_mpi_op_f2c(op)
863+
864+
c_comm = handle_mpi_comm_f2c(comm)
865+
866+
local_ierr = c_mpi_allreduce(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm)
867+
868+
if (present(ierror)) then
869+
ierror = local_ierr
870+
else
871+
if (local_ierr /= MPI_SUCCESS) then
872+
print *, "MPI_Allreduce_1D_recv_proc failed with error code: ", local_ierr
873+
end if
874+
end if
875+
end subroutine MPI_Allreduce_scalar_logical_proc
876+
844877
function MPI_Wtime_proc() result(time)
845878
use mpi_c_bindings, only: c_mpi_wtime
846879
real(8) :: time

src/mpi_c_bindings.f90

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,21 @@ module mpi_c_bindings
1212
type(c_ptr), bind(C, name="c_MPI_STATUSES_IGNORE") :: c_mpi_statuses_ignore
1313
type(c_ptr), bind(C, name="c_MPI_IN_PLACE") :: c_mpi_in_place
1414
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_INFO_NULL") :: c_mpi_info_null
15+
1516
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_DOUBLE") :: c_mpi_double
1617
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_FLOAT") :: c_mpi_float
1718
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_REAL") :: c_mpi_real
1819
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_INT") :: c_mpi_int
19-
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_WORLD") :: c_mpi_comm_world
20-
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_NULL") :: c_mpi_comm_null
21-
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_SUM") :: c_mpi_sum
22-
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_MAX") :: c_mpi_max
2320
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_LOGICAL") :: c_mpi_logical
2421
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_CHARACTER") :: c_mpi_character
2522

23+
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_SUM") :: c_mpi_sum
24+
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_MAX") :: c_mpi_max
25+
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_LOR") :: c_mpi_lor
26+
27+
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_WORLD") :: c_mpi_comm_world
28+
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_NULL") :: c_mpi_comm_null
29+
2630
interface
2731

2832
function c_mpi_comm_f2c(comm_f) bind(C, name="MPI_Comm_f2c")

src/mpi_constants.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ MPI_Op c_MPI_SUM = MPI_SUM;
2626

2727
MPI_Op c_MPI_MAX = MPI_MAX;
2828

29+
MPI_Op c_MPI_LOR = MPI_LOR;
30+
2931
// Communicators Declarations
3032

3133
MPI_Comm c_MPI_COMM_NULL = MPI_COMM_NULL;

tests/allreduce_lor.f90

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
program mre_mpi_lor_allreduce
2+
use mpi
3+
implicit none
4+
5+
integer :: ierr, rank, size
6+
logical :: local_flag, global_flag
7+
8+
call MPI_INIT(ierr)
9+
if (ierr /= MPI_SUCCESS) error stop "MPI_INIT failed"
10+
11+
call MPI_COMM_RANK(MPI_COMM_WORLD, rank, ierr)
12+
call MPI_COMM_SIZE(MPI_COMM_WORLD, size, ierr)
13+
14+
! Initialize the local flag: True if this is the 0th rank, False otherwise
15+
local_flag = (rank == 0)
16+
17+
! Perform logical OR reduction across all processes
18+
call MPI_ALLREDUCE(local_flag, global_flag, 1, MPI_LOGICAL, MPI_LOR, MPI_COMM_WORLD, ierr)
19+
if (global_flag .neqv. .true.) error stop "MPI_ALLREDUCE failed"
20+
21+
print *, 'Rank', rank, ': global_flag =', global_flag
22+
23+
call MPI_FINALIZE(ierr)
24+
if (ierr /= MPI_SUCCESS) error stop "MPI_FINALIZE failed"
25+
26+
end program mre_mpi_lor_allreduce

0 commit comments

Comments
 (0)