Skip to content

Commit f865434

Browse files
adit4443yagxyd
andauthored
Implement MPI_Gatherv Wrapper for String Array as DataType and add bindings for MPI_LOGICAL and MPI_CHARACTER (#126)
* Add bindings for MPI_LOGICAL and MPI_CHARACTER datatype * Implement wrapper for MPI_Gatherv for data of type String Array * Add Test for MPI_Gatherv for case of string array as data type being gathered * tests: add another program to test MPI_GatherV --------- Co-authored-by: Gaurav Dhingra <[email protected]>
1 parent 2cdaf6f commit f865434

File tree

4 files changed

+110
-1
lines changed

4 files changed

+110
-1
lines changed

src/mpi.f90

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ module mpi
88
integer, parameter :: MPI_DOUBLE_PRECISION = -10004
99
integer, parameter :: MPI_REAL4 = -10013
1010
integer, parameter :: MPI_REAL8 = -10014
11+
integer, parameter :: MPI_CHARACTER = -10003
12+
integer, parameter :: MPI_LOGICAL = -10005
1113

1214
integer, parameter :: MPI_COMM_TYPE_SHARED = 1
1315
integer, parameter :: MPI_PROC_NULL = -1
@@ -75,6 +77,7 @@ module mpi
7577
interface MPI_Gatherv
7678
module procedure MPI_Gatherv_int
7779
module procedure MPI_Gatherv_real
80+
module procedure MPI_Gatherv_character
7881
end interface MPI_Gatherv
7982

8083
interface MPI_Wtime
@@ -170,14 +173,18 @@ integer(kind=MPI_HANDLE_KIND) function handle_mpi_info_f2c(info_f) result(c_info
170173
end function handle_mpi_info_f2c
171174

172175
integer(kind=MPI_HANDLE_KIND) function handle_mpi_datatype_f2c(datatype_f) result(c_datatype)
173-
use mpi_c_bindings, only: c_mpi_float, c_mpi_double, c_mpi_int
176+
use mpi_c_bindings, only: c_mpi_float, c_mpi_double, c_mpi_int, c_mpi_logical, c_mpi_character
174177
integer, intent(in) :: datatype_f
175178
if (datatype_f == MPI_REAL4) then
176179
c_datatype = c_mpi_float
177180
else if (datatype_f == MPI_REAL8 .OR. datatype_f == MPI_DOUBLE_PRECISION) then
178181
c_datatype = c_mpi_double
179182
else if (datatype_f == MPI_INTEGER) then
180183
c_datatype = c_mpi_int
184+
else if (datatype_f == MPI_CHARACTER) then
185+
c_datatype = c_mpi_character
186+
else if (datatype_f == MPI_LOGICAL) then
187+
c_datatype = c_mpi_logical
181188
end if
182189
end function
183190

@@ -852,6 +859,42 @@ subroutine MPI_Gatherv_real(sendbuf, sendcount, sendtype, recvbuf, recvcounts, &
852859
end if
853860
end subroutine MPI_Gatherv_real
854861

862+
subroutine MPI_Gatherv_character(sendbuf, sendcount, sendtype, recvbuf, recvcounts, &
863+
displs, recvtype, root, comm, ierror)
864+
use iso_c_binding, only: c_int, c_ptr, c_loc
865+
use mpi_c_bindings, only: c_mpi_gatherv
866+
character(len=*), intent(in), target :: sendbuf(*)
867+
integer, intent(in) :: sendcount
868+
integer, intent(in) :: sendtype
869+
character(len=*), intent(out), target :: recvbuf(*)
870+
integer, dimension(:), intent(in) :: recvcounts
871+
integer, dimension(:), intent(in) :: displs
872+
integer, intent(in) :: recvtype
873+
integer, intent(in) :: root
874+
integer, intent(in) :: comm
875+
integer, optional, intent(out) :: ierror
876+
integer(kind=MPI_HANDLE_KIND) :: c_sendtype, c_recvtype, c_comm
877+
type(c_ptr) :: c_sendbuf, c_recvbuf
878+
integer(c_int) :: local_ierr
879+
880+
c_sendbuf = c_loc(sendbuf)
881+
c_recvbuf = c_loc(recvbuf)
882+
c_sendtype = handle_mpi_datatype_f2c(sendtype)
883+
c_recvtype = handle_mpi_datatype_f2c(recvtype)
884+
c_comm = handle_mpi_comm_f2c(comm)
885+
886+
! Call C MPI_Gatherv
887+
local_ierr = c_mpi_gatherv(c_sendbuf, sendcount, c_sendtype, &
888+
c_recvbuf, recvcounts, displs, c_recvtype, &
889+
root, c_comm)
890+
891+
if (present(ierror)) then
892+
ierror = local_ierr
893+
else if (local_ierr /= MPI_SUCCESS) then
894+
print *, "MPI_Gatherv failed with error code: ", local_ierr
895+
end if
896+
end subroutine MPI_Gatherv_character
897+
855898
subroutine MPI_Waitall_proc(count, array_of_requests, array_of_statuses, ierror)
856899
use iso_c_binding, only: c_int, c_ptr
857900
use mpi_c_bindings, only: c_mpi_waitall, c_mpi_request_f2c, c_mpi_request_c2f, c_mpi_status_c2f, c_mpi_statuses_ignore

src/mpi_c_bindings.f90

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ module mpi_c_bindings
1818
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_WORLD") :: c_mpi_comm_world
1919
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_SUM") :: c_mpi_sum
2020
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_MAX") :: c_mpi_max
21+
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_LOGICAL") :: c_mpi_logical
22+
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_CHARACTER") :: c_mpi_character
2123

2224
interface
2325

src/mpi_constants.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,7 @@ void* c_MPI_IN_PLACE = MPI_IN_PLACE;
1717
MPI_Op c_MPI_SUM = MPI_SUM;
1818

1919
MPI_Op c_MPI_MAX = MPI_MAX;
20+
21+
MPI_Datatype c_MPI_LOGICAL = MPI_LOGICAL;
22+
23+
MPI_Datatype c_MPI_CHARACTER = MPI_CHARACTER;

tests/gatherv_3.f90

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
program gatherv_3
2+
use mpi
3+
implicit none
4+
5+
integer :: rank, nprocs, ierr
6+
integer :: local_data(3)
7+
integer, allocatable :: gathered_data(:)
8+
integer, allocatable :: counts(:), displacements(:)
9+
10+
call MPI_INIT(ierr)
11+
call MPI_COMM_RANK(MPI_COMM_WORLD, rank, ierr)
12+
call MPI_COMM_SIZE(MPI_COMM_WORLD, nprocs, ierr)
13+
14+
local_data = (/ rank*3+1, rank*3+2, rank*3+3 /)
15+
16+
allocate(counts(nprocs))
17+
allocate(displacements(nprocs))
18+
19+
call gatherIntegers(local_data, gathered_data, counts, displacements, rank, nprocs, MPI_COMM_WORLD)
20+
21+
deallocate(counts, displacements)
22+
if (rank == 0) deallocate(gathered_data)
23+
24+
call MPI_FINALIZE(ierr)
25+
26+
contains
27+
28+
subroutine gatherIntegers(local_data, gathered_data, counts, displacements, rank, nprocs, comm)
29+
integer, intent(in) :: local_data(:)
30+
integer, allocatable, intent(out) :: gathered_data(:)
31+
integer, intent(out) :: counts(:)
32+
integer, intent(out) :: displacements(:)
33+
integer, intent(in) :: rank, nprocs, comm
34+
35+
integer :: i, total_elements, ierr
36+
integer :: local_size
37+
38+
local_size = size(local_data)
39+
counts = local_size
40+
41+
displacements(1) = 0
42+
do i = 2, nprocs
43+
displacements(i) = displacements(i-1) + counts(i-1)
44+
end do
45+
46+
total_elements = local_size * nprocs
47+
48+
if (rank == 0) then
49+
allocate(gathered_data(total_elements))
50+
else
51+
allocate(gathered_data(1))
52+
end if
53+
54+
call MPI_GatherV(local_data, local_size, MPI_INTEGER, &
55+
gathered_data, counts, displacements, MPI_INTEGER, &
56+
0, comm, ierr)
57+
58+
end subroutine gatherIntegers
59+
60+
end program gatherv_3

0 commit comments

Comments
 (0)