diff --git a/src/mpi.f90 b/src/mpi.f90 index 73bf72e..1903afd 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -128,6 +128,10 @@ module mpi module procedure MPI_Recv_StatusIgnore_proc end interface + interface MPI_Sendrecv + module procedure MPI_Sendrecv_proc + end interface + interface MPI_Waitall module procedure MPI_Waitall_proc end interface @@ -668,7 +672,49 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr print *, "MPI_Irecv failed with error code: ", local_ierr end if end if - end subroutine + end subroutine MPI_Irecv_proc + + subroutine MPI_Sendrecv_proc (sendbuf, sendcount, sendtype, dest, sendtag, & + recvbuf, recvcount, recvtype, source, recvtag, comm, status, ierror) + use iso_c_binding, only: c_int, c_ptr, c_loc + use mpi_c_bindings, only: c_mpi_sendrecv, c_mpi_status_c2f + real(8), dimension(:,:), target, intent(in) :: sendbuf + integer, intent(in) :: sendcount, dest, sendtag + real(8), dimension(:,:), target, intent(out) :: recvbuf + integer, intent(in) :: recvcount, source, recvtag + integer, intent(in) :: comm + integer, intent(in) :: sendtype, recvtype + integer(kind=MPI_HANDLE_KIND) :: c_comm + integer, intent(out) :: status(MPI_STATUS_SIZE) + integer, optional, intent(out) :: ierror + integer(c_int) :: local_ierr, status_ierr + integer(kind=MPI_HANDLE_KIND) :: c_sendtype, c_recvtype + type(c_ptr) :: sendbuf_ptr, recvbuf_ptr, c_status + integer(c_int), dimension(MPI_STATUS_SIZE), target :: tmp_status + + c_comm = handle_mpi_comm_f2c(comm) + + c_sendtype = handle_mpi_datatype_f2c(sendtype) + c_recvtype = handle_mpi_datatype_f2c(recvtype) + sendbuf_ptr = c_loc(sendbuf) + recvbuf_ptr = c_loc(recvbuf) + c_status = c_loc(tmp_status) + + local_ierr = c_mpi_sendrecv(sendbuf_ptr, sendcount, c_sendtype, dest, sendtag, & + recvbuf_ptr, recvcount, c_recvtype, source, recvtag, & + c_comm, c_status) + + if (local_ierr == MPI_SUCCESS) then + ! status_ierr = c_mpi_status_c2f(c_status, status) + end if + + if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Sendrecv failed with error code: ", local_ierr + if (present(ierror)) then + ierror = local_ierr + end if + end if + end subroutine MPI_Sendrecv_proc subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index 05b4f8e..a275034 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -219,6 +219,22 @@ function c_mpi_recv(buf, count, c_dtype, source, tag, c_comm, status) bind(C, na integer(c_int) :: c_mpi_recv end function c_mpi_recv + function c_mpi_sendrecv (sendbuf, sendcount, sendtype, dest, sendtag, & + recvbuf, recvcount, recvtype, source, recvtag, comm, status) bind(C, name="MPI_Sendrecv") + use iso_c_binding, only: c_int, c_ptr + type(c_ptr), value :: sendbuf + integer(c_int), value :: sendcount + integer(kind=MPI_HANDLE_KIND), value :: sendtype + integer(c_int), value :: dest, sendtag + type(c_ptr), value :: recvbuf + integer(c_int), value :: recvcount + integer(kind=MPI_HANDLE_KIND), value :: recvtype + integer(c_int), value :: source, recvtag + integer(kind=MPI_HANDLE_KIND), value :: comm + type(c_ptr), value :: status + integer(c_int) :: c_mpi_sendrecv + end function c_mpi_sendrecv + function c_mpi_waitall(count, requests, statuses) bind(C, name="MPI_Waitall") use iso_c_binding, only: c_int, c_ptr integer(c_int), value :: count diff --git a/tests/sendrecv_1.f90 b/tests/sendrecv_1.f90 new file mode 100644 index 0000000..de60b4e --- /dev/null +++ b/tests/sendrecv_1.f90 @@ -0,0 +1,52 @@ +program sendrecv_1 + use mpi + implicit none + integer :: ierr, rank, size, next, prev + real(8), allocatable :: sendbuf(:,:), recvbuf(:,:) + integer :: status(MPI_STATUS_SIZE) + logical :: error + integer :: i, j, n1, n2 + + n1 = 2 + n2 = 3 + + ! Initialize MPI + call MPI_Init(ierr) + call MPI_Comm_rank(MPI_COMM_WORLD, rank, ierr) + call MPI_Comm_size(MPI_COMM_WORLD, size, ierr) + + ! Set up ring communication + next = mod(rank + 1, size) ! Send to next process + prev = mod(rank - 1 + size, size) ! Receive from previous process + + ! Allocate and initialize send/recv buffers + allocate(sendbuf(n1, n2)) + allocate(recvbuf(n1, n2)) + sendbuf = rank + recvbuf = -1.0d0 + + ! Perform sendrecv + call MPI_Sendrecv(sendbuf, n1*n2, MPI_REAL8, next, 0, & + recvbuf, n1*n2, MPI_REAL8, prev, 0, & + MPI_COMM_WORLD, status, ierr) + + ! Verify result + error = .false. + do i = 1, n1 + do j = 1, n2 + if (recvbuf(i,j) /= real(prev,8)) then + print *, "Rank ", rank, ": Error at (",i,",",j,"): Expected ", prev, ", got ", recvbuf(i,j) + error = .true. + end if + end do + end do + + if (.not. error .and. rank == 0) then + print *, "MPI_Sendrecv test passed: rank ", rank, " received correct data" + end if + + ! Clean up + call MPI_Finalize(ierr) + + if (error) error stop 1 +end program sendrecv_1 \ No newline at end of file