diff --git a/include/mpicpp-lite/impl/Communicator.h b/include/mpicpp-lite/impl/Communicator.h index 1b4059a..4283bc7 100644 --- a/include/mpicpp-lite/impl/Communicator.h +++ b/include/mpicpp-lite/impl/Communicator.h @@ -13,6 +13,7 @@ #include "Operation.h" #include "Error.h" #include "Group.h" +#include "Tag.h" namespace mpicpp_lite { @@ -45,6 +46,7 @@ class Communicator { /// Communicator create(const Group & group, int tag = 0) const; + Communicator create(const Group & group, Tag tag) const; /// Makes a new communicator to which topology information has been attached /// @@ -74,6 +76,8 @@ class Communicator { /// @param value Value to send template void send(int dest, int tag, const T & value) const; + template + void send(int dest, Tag tag, const T & value) const; /// Send data to another process /// @@ -84,6 +88,8 @@ class Communicator { /// @param n Number of values to send template void send(int dest, int tag, const T * values, int n) const; + template + void send(int dest, Tag tag, const T * values, int n) const; /// Send `std::vector` of data to another process /// @@ -93,12 +99,15 @@ class Communicator { /// @param value Vector of `T` to send template void send(int dest, int tag, const std::vector & value) const; + template + void send(int dest, Tag tag, const std::vector & value) const; /// Send a message to another process without any data /// /// @param dest Destination rank /// @param tag Message tag void send(int dest, int tag) const; + void send(int dest, Tag tag) const; /// Receive data from a remote process /// @@ -109,6 +118,8 @@ class Communicator { /// @return `Status` of the operation template Status recv(int source, int tag, T & value) const; + template + Status recv(int source, Tag tag, T & value) const; /// Receive data from a remote process /// @@ -120,6 +131,8 @@ class Communicator { /// @return `Status` of the operation template Status recv(int source, int tag, T * values, int n) const; + template + Status recv(int source, Tag tag, T * values, int n) const; /// Receive std::vector of data from a remote process /// @@ -130,6 +143,8 @@ class Communicator { /// @return `Status` of the operation template Status recv(int source, int tag, std::vector & value) const; + template + Status recv(int source, Tag tag, std::vector & value) const; /// Receive a message from a remote process without any data /// @@ -137,6 +152,7 @@ class Communicator { /// @param tag Message tag /// @return `Status` of the operation Status recv(int source, int tag) const; + Status recv(int source, Tag tag) const; /// Send a message to a remote process without blocking /// @@ -147,6 +163,8 @@ class Communicator { /// @return Communication `Request` template Request isend(int dest, int tag, const T & value) const; + template + Request isend(int dest, Tag tag, const T & value) const; /// Send a std::vector of values to a remote process without blocking /// @@ -157,6 +175,8 @@ class Communicator { /// @return Communication `Request` template Request isend(int dest, int tag, const std::vector & values) const; + template + Request isend(int dest, Tag tag, const std::vector & values) const; /// Send a message to a remote process without blocking /// @@ -168,6 +188,8 @@ class Communicator { /// @return Communication `Request` template Request isend(int dest, int tag, const T * values, int n) const; + template + Request isend(int dest, Tag tag, const T * values, int n) const; /// Receive a message from a remote process without blocking /// @@ -178,6 +200,8 @@ class Communicator { /// @return Communication `Request` template Request irecv(int source, int tag, T & value) const; + template + Request irecv(int source, Tag tag, T & value) const; /// Receive a message from a remote process without blocking /// @@ -188,6 +212,8 @@ class Communicator { /// @param n Number of values to receive template Request irecv(int source, int tag, T * values, int n) const; + template + Request irecv(int source, Tag tag, T * values, int n) const; /// Nonblocking test for a message /// @@ -196,6 +222,7 @@ class Communicator { /// @return `true` if a message with the specified source, and tag is available, `false` /// otherwise bool iprobe(int source, int tag) const; + bool iprobe(int source, Tag tag) const; /// Nonblocking test for a message /// @@ -205,6 +232,7 @@ class Communicator { /// @return `true` if a message with the specified source, and tag is available, `false` /// otherwise bool iprobe(int source, int tag, Status & status) const; + bool iprobe(int source, Tag tag, Status & status) const; /// Wait for all processes within a communicator to reach the barrier. void barrier() const; @@ -700,6 +728,14 @@ Communicator::create(const Group & group, int tag) const return new_comm; } +inline Communicator +Communicator::create(const Group & group, Tag tag) const +{ + Communicator new_comm; + MPI_CHECK_SELF(MPI_Comm_create_group(this->comm_, group.group_, tag.value(), &new_comm.comm_)); + return new_comm; +} + inline CartesianCommunicator Communicator::create_cartesian(const std::vector & dims, bool reorder) const { @@ -741,6 +777,13 @@ Communicator::send(int dest, int tag, const T & value) const send(dest, tag, &value, 1); } +template +inline void +Communicator::send(int dest, Tag tag, const T & value) const +{ + send(dest, tag, &value, 1); +} + template inline void Communicator::send(int dest, int tag, const T * values, int n) const @@ -748,6 +791,14 @@ Communicator::send(int dest, int tag, const T * values, int n) const MPI_CHECK_SELF(MPI_Send(const_cast(values), n, mpi_datatype(), dest, tag, this->comm_)); } +template +inline void +Communicator::send(int dest, Tag tag, const T * values, int n) const +{ + MPI_CHECK_SELF( + MPI_Send(const_cast(values), n, mpi_datatype(), dest, tag.value(), this->comm_)); +} + template inline void Communicator::send(int dest, int tag, const std::vector & value) const @@ -756,12 +807,26 @@ Communicator::send(int dest, int tag, const std::vector & value) const send(dest, tag, value.data(), size); } +template +inline void +Communicator::send(int dest, Tag tag, const std::vector & value) const +{ + typename std::vector::size_type size = value.size(); + send(dest, tag, value.data(), size); +} + inline void Communicator::send(int dest, int tag) const { MPI_CHECK_SELF(MPI_Send(MPI_BOTTOM, 0, MPI_PACKED, dest, tag, this->comm_)); } +inline void +Communicator::send(int dest, Tag tag) const +{ + MPI_CHECK_SELF(MPI_Send(MPI_BOTTOM, 0, MPI_PACKED, dest, tag.value(), this->comm_)); +} + template <> inline void Communicator::send(int dest, int tag, const std::string & value) const @@ -771,6 +836,15 @@ Communicator::send(int dest, int tag, const std::string & value) const send(dest, tag, value.data(), value.size()); } +template <> +inline void +Communicator::send(int dest, Tag tag, const std::string & value) const +{ + if (size() < 2) + return; + send(dest, tag, value.data(), value.size()); +} + // Recv template @@ -780,6 +854,13 @@ Communicator::recv(int source, int tag, T & value) const return recv(source, tag, &value, 1); } +template +inline Status +Communicator::recv(int source, Tag tag, T & value) const +{ + return recv(source, tag, &value, 1); +} + template inline Status Communicator::recv(int source, int tag, T * values, int n) const @@ -795,6 +876,21 @@ Communicator::recv(int source, int tag, T * values, int n) const return status; } +template +inline Status +Communicator::recv(int source, Tag tag, T * values, int n) const +{ + Status status; + MPI_CHECK_SELF(MPI_Recv(const_cast(values), + n, + mpi_datatype(), + source, + tag.value(), + this->comm_, + &status.native())); + return status; +} + template inline Status Communicator::recv(int source, int tag, std::vector & values) const @@ -806,6 +902,17 @@ Communicator::recv(int source, int tag, std::vector & values) const return recv(source, tag, values.data(), size); } +template +inline Status +Communicator::recv(int source, Tag tag, std::vector & values) const +{ + Status status; + MPI_CHECK_SELF(MPI_Probe(source, tag.value(), this->comm_, &status.native())); + auto size = status.count(); + values.resize(size); + return recv(source, tag, values.data(), size); +} + inline Status Communicator::recv(int source, int tag) const { @@ -814,6 +921,15 @@ Communicator::recv(int source, int tag) const return status; } +inline Status +Communicator::recv(int source, Tag tag) const +{ + Status status; + MPI_CHECK_SELF( + MPI_Recv(MPI_BOTTOM, 0, MPI_PACKED, source, tag.value(), this->comm_, &status.native())); + return status; +} + template <> inline Status Communicator::recv(int source, int tag, std::string & value) const @@ -824,6 +940,16 @@ Communicator::recv(int source, int tag, std::string & value) const return status; } +template <> +inline Status +Communicator::recv(int source, Tag tag, std::string & value) const +{ + std::vector str; + auto status = recv(source, tag, str); + value.assign(str.begin(), str.end()); + return status; +} + // Isend template @@ -833,6 +959,13 @@ Communicator::isend(int dest, int tag, const T & value) const return isend(dest, tag, &value, 1); } +template +inline Request +Communicator::isend(int dest, Tag tag, const T & value) const +{ + return isend(dest, tag, &value, 1); +} + template inline Request Communicator::isend(int dest, int tag, const std::vector & values) const @@ -840,6 +973,13 @@ Communicator::isend(int dest, int tag, const std::vector & values) const return isend(dest, tag, values.data(), values.size()); } +template +inline Request +Communicator::isend(int dest, Tag tag, const std::vector & values) const +{ + return isend(dest, tag, values.data(), values.size()); +} + template inline Request Communicator::isend(int dest, int tag, const T * values, int n) const @@ -856,6 +996,22 @@ Communicator::isend(int dest, int tag, const T * values, int n) const return request; } +template +inline Request +Communicator::isend(int dest, Tag tag, const T * values, int n) const +{ + assert(values != nullptr); + Request request; + MPI_CHECK_SELF(MPI_Isend(const_cast(values), + n, + mpi_datatype(), + dest, + tag.value(), + this->comm_, + &request.native())); + return request; +} + // Irecv template @@ -865,6 +1021,13 @@ Communicator::irecv(int source, int tag, T & value) const return irecv(source, tag, &value, 1); } +template +inline Request +Communicator::irecv(int source, Tag tag, T & value) const +{ + return irecv(source, tag, &value, 1); +} + template inline Request Communicator::irecv(int source, int tag, T * values, int n) const @@ -881,6 +1044,22 @@ Communicator::irecv(int source, int tag, T * values, int n) const return request; } +template +inline Request +Communicator::irecv(int source, Tag tag, T * values, int n) const +{ + assert(values != nullptr); + Request request; + MPI_CHECK_SELF(MPI_Irecv(const_cast(values), + n, + mpi_datatype(), + source, + tag.value(), + this->comm_, + &request.native())); + return request; +} + inline bool Communicator::iprobe(int source, int tag) const { @@ -889,6 +1068,14 @@ Communicator::iprobe(int source, int tag) const return flag != 0; } +inline bool +Communicator::iprobe(int source, Tag tag) const +{ + int flag; + MPI_CHECK_SELF(MPI_Iprobe(source, tag.value(), this->comm_, &flag, MPI_STATUS_IGNORE)); + return flag != 0; +} + inline bool Communicator::iprobe(int source, int tag, Status & status) const { @@ -897,6 +1084,14 @@ Communicator::iprobe(int source, int tag, Status & status) const return flag != 0; } +inline bool +Communicator::iprobe(int source, Tag tag, Status & status) const +{ + int flag; + MPI_CHECK_SELF(MPI_Iprobe(source, tag.value(), this->comm_, &flag, &status.native())); + return flag != 0; +} + // Barrier inline void diff --git a/include/mpicpp-lite/impl/Tag.h b/include/mpicpp-lite/impl/Tag.h new file mode 100644 index 0000000..8ec28f7 --- /dev/null +++ b/include/mpicpp-lite/impl/Tag.h @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: 2026 David Andrs +// SPDX-License-Identifier: MIT + +#pragma once + +#include "mpi.h" + +namespace mpicpp_lite { + +class Tag { +public: + constexpr Tag() : value_(0) {} + explicit constexpr Tag(int tag) : value_(tag) {} + + constexpr int + value() const + { + return this->value_; + } + + constexpr bool + operator<(Tag other) const + { + return this->value_ < other.value_; + } + + constexpr bool + operator==(int other) const + { + return this->value_ == other; + } + + constexpr bool + operator!=(int other) const + { + return this->value_ != other; + } + +private: + int value_; +}; + +} // namespace mpicpp_lite diff --git a/tests/MPI_test.cpp b/tests/MPI_test.cpp index 37e02ae..34cd659 100644 --- a/tests/MPI_test.cpp +++ b/tests/MPI_test.cpp @@ -72,13 +72,13 @@ TEST(MPITest, send_recv_int) if (n_mpis == 1) return; - int tag = 1234; + Tag tag(1234); if (comm.rank() == 0) { for (int i = 1; i < n_mpis; i++) { int val; auto status = comm.recv(i, tag, val); EXPECT_EQ(val, i * 4); - EXPECT_EQ(status.tag(), tag); + EXPECT_EQ(status.tag(), tag.value()); EXPECT_EQ(status.source(), i); EXPECT_EQ(status.error(), 0); } @@ -96,13 +96,13 @@ TEST(MPITest, send_recv_bool) if (n_mpis == 1) return; - int tag = 1234; + Tag tag(1234); if (comm.rank() == 0) { for (int i = 1; i < n_mpis; i++) { bool val; auto status = comm.recv(i, tag, val); EXPECT_EQ(val, i % 2 == 0); - EXPECT_EQ(status.tag(), tag); + EXPECT_EQ(status.tag(), tag.value()); EXPECT_EQ(status.source(), i); EXPECT_EQ(status.error(), 0); } @@ -120,12 +120,12 @@ TEST(MPITest, send_recv) if (n_mpis == 1) return; - int tag = 0; + Tag tag; if (comm.rank() == 0) { for (int i = 1; i < n_mpis; i++) { auto status = comm.recv(i, tag); EXPECT_EQ(status.source(), i); - EXPECT_EQ(status.tag(), tag); + EXPECT_EQ(status.tag(), tag.value()); } } else @@ -139,7 +139,7 @@ TEST(MPITest, send_recv_arr_int) if (n_mpis == 1) return; - int tag = 0; + Tag tag; if (comm.rank() == 0) { for (int i = 1; i < n_mpis; i++) { std::vector arr; @@ -149,7 +149,7 @@ TEST(MPITest, send_recv_arr_int) for (int j = 0; j < sz; j++) EXPECT_EQ(arr[j], 2 * j); EXPECT_EQ(status.source(), i); - EXPECT_EQ(status.tag(), tag); + EXPECT_EQ(status.tag(), tag.value()); } } else { @@ -169,12 +169,12 @@ TEST(MPITest, send_recv_empty_std_vector) if (n_mpis == 1) return; - int tag = 1234; + Tag tag(1234); if (comm.rank() == 0) { for (int i = 1; i < n_mpis; i++) { std::vector arr; auto status = comm.recv(i, tag, arr); - EXPECT_EQ(status.tag(), tag); + EXPECT_EQ(status.tag(), tag.value()); EXPECT_EQ(status.source(), i); EXPECT_EQ(status.error(), 0); } @@ -192,7 +192,7 @@ TEST(MPITest, send_recv_std_str) if (n_mpis == 1) return; - int tag = 101; + Tag tag(101); std::string str; if (comm.rank() == 0) { str = "ahoi"; @@ -670,7 +670,7 @@ TEST(MPITest, iprobe) if (comm.size() == 1) return; - int tag = 1; + Tag tag(1); if (comm.rank() == 0) { for (int i = 1; i < comm.size(); i++) { int num = i * 5; @@ -700,7 +700,7 @@ TEST(MPITest, iprobe_w_status) if (comm.size() == 1) return; - int tag = 1; + Tag tag(1); if (comm.rank() == 0) { for (int i = 1; i < comm.size(); i++) { int num = i * 5; @@ -731,7 +731,7 @@ TEST(MPITest, isend_irecv_wait) if (comm.size() == 1) return; - int tag = 1; + Tag tag(1); if (comm.rank() == 0) { for (int i = 1; i < comm.size(); i++) { int num = i * 5; @@ -753,7 +753,7 @@ TEST(MPITest, isend_irecv_wait_w_status) if (comm.size() == 1) return; - int tag = 1; + Tag tag(1); if (comm.rank() == 0) { for (int i = 1; i < comm.size(); i++) { int num = i * 5; @@ -769,7 +769,7 @@ TEST(MPITest, isend_irecv_wait_w_status) wait(request, status); EXPECT_EQ(val, comm.rank() * 5); EXPECT_EQ(status.source(), 0); - EXPECT_EQ(status.tag(), tag); + EXPECT_EQ(status.tag(), tag.value()); } } @@ -779,7 +779,7 @@ TEST(MPITest, isend_irecv_waitall) if (comm.size() == 1) return; - int tag = 1; + Tag tag(1); if (comm.rank() == 0) { int n = comm.size() - 1; std::vector vals; @@ -805,7 +805,7 @@ TEST(MPITest, isend_irecv_waitany) if (comm.size() == 1) return; - int tag = 1; + Tag tag(1); if (comm.rank() == 0) { int n = comm.size() - 1; std::vector vals; @@ -835,7 +835,7 @@ TEST(MPITest, test_all) if (comm.size() < 2) return; - int tag = 1; + Tag tag(1); if (comm.rank() == 0) { int n = comm.size() - 1; std::vector vals(n); @@ -860,7 +860,7 @@ TEST(MPITest, test_any) if (comm.size() < 2) return; - int tag = 1; + Tag tag(1); if (comm.rank() == 0) { int n = comm.size() - 1; std::vector vals(n);