Skip to content

Commit

Permalink
adding message type to asynchronous send/recv (faasm#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
csegarragonz authored Nov 20, 2020
1 parent c83395c commit 521fe47
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 16 deletions.
12 changes: 9 additions & 3 deletions include/faabric/scheduler/MpiWorld.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ class MpiWorld
int recvRank,
const uint8_t* buffer,
faabric_datatype_t* dataType,
int count);
int count,
faabric::MPIMessage::MPIMessageType messageType =
faabric::MPIMessage::NORMAL);

void broadcast(int sendRank,
const uint8_t* buffer,
Expand All @@ -82,7 +84,9 @@ class MpiWorld
int recvRank,
uint8_t* buffer,
faabric_datatype_t* dataType,
int count);
int count,
faabric::MPIMessage::MPIMessageType messageType =
faabric::MPIMessage::NORMAL);

void awaitAsyncRequest(int requestId);

Expand Down Expand Up @@ -220,7 +224,9 @@ class MpiWorld
const uint8_t* sendBuffer,
uint8_t* recvBuffer,
faabric_datatype_t* dataType,
int count);
int count,
faabric::MPIMessage::MPIMessageType messageType =
faabric::MPIMessage::NORMAL);

void pushToState();
};
Expand Down
47 changes: 34 additions & 13 deletions src/scheduler/MpiWorld.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,31 +227,47 @@ int MpiWorld::isend(int sendRank,
int recvRank,
const uint8_t* buffer,
faabric_datatype_t* dataType,
int count)
int count,
faabric::MPIMessage::MPIMessageType messageType)
{
return doISendRecv(sendRank, recvRank, buffer, nullptr, dataType, count);
return doISendRecv(
sendRank, recvRank, buffer, nullptr, dataType, count, messageType);
}

int MpiWorld::doISendRecv(int sendRank,
int recvRank,
const uint8_t* sendBuffer,
uint8_t* recvBuffer,
faabric_datatype_t* dataType,
int count)
int count,
faabric::MPIMessage::MPIMessageType messageType)
{

int requestId = (int)faabric::util::generateGid();

// Spawn a thread to do the work
asyncThreadMap.insert(std::pair<int, std::thread>(
requestId,
[this, sendRank, recvRank, sendBuffer, recvBuffer, dataType, count] {
[this,
sendRank,
recvRank,
sendBuffer,
recvBuffer,
dataType,
count,
messageType] {
// Do the operation (i.e. the underlying synchronous send/ receive)
if (recvBuffer == nullptr) {
this->send(sendRank, recvRank, sendBuffer, dataType, count);
this->send(
sendRank, recvRank, sendBuffer, dataType, count, messageType);
} else {
this->recv(
sendRank, recvRank, recvBuffer, dataType, count, nullptr);
this->recv(sendRank,
recvRank,
recvBuffer,
dataType,
count,
nullptr,
messageType);
}
}));

Expand Down Expand Up @@ -335,16 +351,19 @@ void MpiWorld::sendRecv(uint8_t* sendBuffer,
}

// Post async recv
int recvId = irecv(recvRank, sendRank, recvBuffer, recvDataType, recvCount);
int recvId = irecv(recvRank,
sendRank,
recvBuffer,
recvDataType,
recvCount,
faabric::MPIMessage::SENDRECV);
// Then send the message
// TODO change MPIMessage to MPIMessage::SENDRECV. This requires a change
// in the signature of doISendRecv.
send(sendRank,
recvRank,
sendBuffer,
sendDataType,
sendCount,
faabric::MPIMessage::NORMAL);
faabric::MPIMessage::SENDRECV);
// And wait
awaitAsyncRequest(recvId);
}
Expand Down Expand Up @@ -553,9 +572,11 @@ int MpiWorld::irecv(int sendRank,
int recvRank,
uint8_t* buffer,
faabric_datatype_t* dataType,
int count)
int count,
faabric::MPIMessage::MPIMessageType messageType)
{
return doISendRecv(sendRank, recvRank, nullptr, buffer, dataType, count);
return doISendRecv(
sendRank, recvRank, nullptr, buffer, dataType, count, messageType);
}

void MpiWorld::recv(int sendRank,
Expand Down
13 changes: 13 additions & 0 deletions tests/test/scheduler/test_mpi_world.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,19 @@ TEST_CASE("Test send and recv on same host", "[mpi]")
REQUIRE(status.MPI_SOURCE == rankA1);
REQUIRE(status.bytesSize == messageData.size() * sizeof(int));
}

SECTION("Test recv with type missmatch")
{
// Receive a message from a different type
auto buffer = new int[messageData.size()];
REQUIRE_THROWS(world.recv(rankA1,
rankA2,
BYTES(buffer),
MPI_INT,
messageData.size(),
nullptr,
faabric::MPIMessage::SENDRECV));
}
}

TEST_CASE("Test sendrecv", "[mpi]")
Expand Down

0 comments on commit 521fe47

Please sign in to comment.