Skip to content

Commit 2cdaf6f

Browse files
authored
feat: Implement wrappers for MPI_Gatherv (#123)
* Add Test for MPI_Gatherv * Add wrappers for MPI_Gatherv
1 parent cd1d4bd commit 2cdaf6f

File tree

3 files changed

+171
-0
lines changed

3 files changed

+171
-0
lines changed

src/mpi.f90

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ module mpi
7272
module procedure MPI_Allreduce_1D_int_proc
7373
end interface
7474

75+
interface MPI_Gatherv
76+
module procedure MPI_Gatherv_int
77+
module procedure MPI_Gatherv_real
78+
end interface MPI_Gatherv
79+
7580
interface MPI_Wtime
7681
module procedure MPI_Wtime_proc
7782
end interface
@@ -765,6 +770,88 @@ subroutine MPI_Recv_StatusIgnore_proc(buf, count, datatype, source, tag, comm, s
765770

766771
end subroutine MPI_Recv_StatusIgnore_proc
767772

773+
subroutine MPI_Gatherv_int(sendbuf, sendcount, sendtype, recvbuf, recvcounts, &
774+
displs, recvtype, root, comm, ierror)
775+
use iso_c_binding, only: c_int, c_ptr, c_loc
776+
use mpi_c_bindings, only: c_mpi_gatherv, c_mpi_in_place
777+
integer, dimension(:), intent(in), target :: sendbuf
778+
integer, intent(in) :: sendcount
779+
integer, intent(in) :: sendtype
780+
integer, dimension(:), intent(out), target :: recvbuf
781+
integer, dimension(:), intent(in) :: recvcounts
782+
integer, dimension(:), intent(in) :: displs
783+
integer, intent(in) :: recvtype
784+
integer, intent(in) :: root
785+
integer, intent(in) :: comm
786+
integer, optional, intent(out) :: ierror
787+
integer(kind=MPI_HANDLE_KIND) :: c_sendtype, c_recvtype, c_comm
788+
type(c_ptr) :: c_sendbuf, c_recvbuf
789+
integer(c_int) :: local_ierr
790+
791+
if (sendbuf(1) == MPI_IN_PLACE) then
792+
c_sendbuf = c_MPI_IN_PLACE
793+
else
794+
c_sendbuf = c_loc(sendbuf)
795+
end if
796+
797+
c_recvbuf = c_loc(recvbuf)
798+
c_sendtype = handle_mpi_datatype_f2c(sendtype)
799+
c_recvtype = handle_mpi_datatype_f2c(recvtype)
800+
c_comm = handle_mpi_comm_f2c(comm)
801+
802+
! Call C MPI_Gatherv
803+
local_ierr = c_mpi_gatherv(c_sendbuf, sendcount, c_sendtype, &
804+
c_recvbuf, recvcounts, displs, c_recvtype, &
805+
root, c_comm)
806+
807+
if (present(ierror)) then
808+
ierror = local_ierr
809+
else if (local_ierr /= MPI_SUCCESS) then
810+
print *, "MPI_Gatherv failed with error code: ", local_ierr
811+
end if
812+
end subroutine MPI_Gatherv_int
813+
814+
subroutine MPI_Gatherv_real(sendbuf, sendcount, sendtype, recvbuf, recvcounts, &
815+
displs, recvtype, root, comm, ierror)
816+
use iso_c_binding, only: c_int, c_ptr, c_loc
817+
use mpi_c_bindings, only: c_mpi_gatherv, c_mpi_in_place
818+
real(8), dimension(:), intent(in), target :: sendbuf
819+
integer, intent(in) :: sendcount
820+
integer, intent(in) :: sendtype
821+
real(8), dimension(:), intent(out), target :: recvbuf
822+
integer, dimension(:), intent(in) :: recvcounts
823+
integer, dimension(:), intent(in) :: displs
824+
integer, intent(in) :: recvtype
825+
integer, intent(in) :: root
826+
integer, intent(in) :: comm
827+
integer, optional, intent(out) :: ierror
828+
integer(kind=MPI_HANDLE_KIND) :: c_sendtype, c_recvtype, c_comm
829+
type(c_ptr) :: c_sendbuf, c_recvbuf
830+
integer(c_int) :: local_ierr
831+
832+
if (sendbuf(1) == MPI_IN_PLACE) then
833+
c_sendbuf = c_MPI_IN_PLACE
834+
else
835+
c_sendbuf = c_loc(sendbuf)
836+
end if
837+
838+
c_recvbuf = c_loc(recvbuf)
839+
c_sendtype = handle_mpi_datatype_f2c(sendtype)
840+
c_recvtype = handle_mpi_datatype_f2c(recvtype)
841+
c_comm = handle_mpi_comm_f2c(comm)
842+
843+
! Call C MPI_Gatherv
844+
local_ierr = c_mpi_gatherv(c_sendbuf, sendcount, c_sendtype, &
845+
c_recvbuf, recvcounts, displs, c_recvtype, &
846+
root, c_comm)
847+
848+
if (present(ierror)) then
849+
ierror = local_ierr
850+
else if (local_ierr /= MPI_SUCCESS) then
851+
print *, "MPI_Gatherv failed with error code: ", local_ierr
852+
end if
853+
end subroutine MPI_Gatherv_real
854+
768855
subroutine MPI_Waitall_proc(count, array_of_requests, array_of_statuses, ierror)
769856
use iso_c_binding, only: c_int, c_ptr
770857
use mpi_c_bindings, only: c_mpi_waitall, c_mpi_request_f2c, c_mpi_request_c2f, c_mpi_status_c2f, c_mpi_statuses_ignore

src/mpi_c_bindings.f90

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,21 @@ function c_mpi_ssend(buf, count, datatype, dest, tag, comm) bind(C, name="MPI_Ss
220220
integer(c_int) :: c_mpi_ssend
221221
end function
222222

223+
function c_mpi_gatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, &
224+
displs, recvtype, root, comm) bind(C, name="MPI_Gatherv")
225+
use iso_c_binding, only: c_int, c_ptr
226+
type(c_ptr), value :: sendbuf
227+
integer(c_int), value :: sendcount
228+
integer(kind=MPI_HANDLE_KIND), value :: sendtype
229+
type(c_ptr), value :: recvbuf
230+
integer(c_int), dimension(*), intent(in) :: recvcounts
231+
integer(c_int), dimension(*), intent(in) :: displs
232+
integer(kind=MPI_HANDLE_KIND), value :: recvtype
233+
integer(c_int), value :: root
234+
integer(kind=MPI_HANDLE_KIND), value :: comm
235+
integer(c_int) :: c_mpi_gatherv
236+
end function c_mpi_gatherv
237+
223238
function c_mpi_cart_create(comm_old, ndims, dims, periods, reorder, comm_cart) bind(C, name="MPI_Cart_create")
224239
use iso_c_binding, only: c_int, c_ptr
225240
integer(kind=MPI_HANDLE_KIND), value :: comm_old

tests/gatherv_1.f90

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
program gatherv_1
2+
use mpi
3+
implicit none
4+
integer :: ierr, rank, size, root
5+
integer, allocatable :: sendbuf(:), recvbuf(:)
6+
integer, allocatable :: recvcounts(:), displs(:)
7+
integer :: sendcount, i, total
8+
logical :: error
9+
10+
! Initialize MPI
11+
call MPI_Init(ierr)
12+
call MPI_Comm_rank(MPI_COMM_WORLD, rank, ierr)
13+
call MPI_Comm_size(MPI_COMM_WORLD, size, ierr)
14+
15+
! Root process
16+
root = 0
17+
18+
! Each process sends 'rank + 1' integers
19+
sendcount = rank + 1
20+
allocate(sendbuf(sendcount))
21+
do i = 1, sendcount
22+
sendbuf(i) = rank * 100 + i ! Unique values per process
23+
end do
24+
25+
! Allocate receive buffers on root
26+
if (rank == root) then
27+
allocate(recvcounts(size))
28+
allocate(displs(size))
29+
total = 0
30+
do i = 1, size
31+
recvcounts(i) = i ! Process i-1 sends i elements
32+
displs(i) = total ! Displacement in recvbuf
33+
total = total + recvcounts(i)
34+
end do
35+
allocate(recvbuf(total))
36+
recvbuf = 0
37+
else
38+
allocate(recvcounts(1), displs(1), recvbuf(1)) ! Dummy allocations for non-root
39+
end if
40+
41+
! Perform gather
42+
call MPI_Gatherv(sendbuf, sendcount, MPI_INTEGER, recvbuf, recvcounts, &
43+
displs, MPI_INTEGER, root, MPI_COMM_WORLD, ierr)
44+
45+
! Verify results on root
46+
error = .false.
47+
if (rank == root) then
48+
do i = 1, size
49+
do sendcount = 1, i
50+
if (recvbuf(displs(i) + sendcount) /= (i-1)*100 + sendcount) then
51+
print *, "Error at rank ", i-1, " index ", sendcount, &
52+
": expected ", (i-1)*100 + sendcount, &
53+
", got ", recvbuf(displs(i) + sendcount)
54+
error = .true.
55+
error stop
56+
end if
57+
end do
58+
end do
59+
if (.not. error) then
60+
print *, "MPI_Gatherv test passed on root"
61+
end if
62+
end if
63+
64+
! Clean up
65+
deallocate(sendbuf, recvbuf, recvcounts, displs)
66+
call MPI_Finalize(ierr)
67+
68+
if (error) stop 1
69+
end program gatherv_1

0 commit comments

Comments
 (0)