
Commit 294c100

mpi: inline small messages
In this PR we add support for inlining the body of small messages (for a definition of small) inside the message struct itself. This increases the size of _all_ messages being moved around, but hopefully spares the need to malloc/free a separate buffer for small messages.
1 parent c3dbe3b commit 294c100
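The core of the change is a tail union on the message struct: small payloads are copied into a fixed in-struct array, large ones keep using a heap pointer. A minimal standalone sketch of the idea (the `Msg` struct and `setPayload` helper are illustrative stand-ins, not the actual faabric types):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>

// Mirrors MPI_MAX_INLINE_SEND in the diff below
constexpr size_t kMaxInline = 256;

struct Msg
{
    int32_t count;    // number of elements in the payload
    int32_t typeSize; // size of one element in bytes
    union
    {
        void* buffer;                  // heap payload (large messages)
        uint8_t inlineMsg[kMaxInline]; // payload stored in the struct itself
    };
};

// Sender-side decision: copy small payloads into the struct, heap-allocate
// the rest
void setPayload(Msg& msg, const void* data, size_t dataSize)
{
    if (dataSize <= kMaxInline) {
        std::memcpy(msg.inlineMsg, data, dataSize);
    } else {
        msg.buffer = std::malloc(dataSize);
        std::memcpy(msg.buffer, data, dataSize);
    }
}
```

The union is also what makes every message bigger: `sizeof(Msg)` always pays for the 256-byte slot, whether or not it is used.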

7 files changed (+146 −47 lines)

include/faabric/mpi/MpiMessage.h

+18-2
@@ -3,6 +3,11 @@
 #include <cstdint>
 #include <vector>
 
+// Constant copied from OpenMPI's SM implementation. It indicates the maximum
+// number of bytes that we may inline in a message (rather than malloc-ing)
+// https://github.com/open-mpi/ompi/blob/main/opal/mca/btl/sm/btl_sm_component.c#L153
+#define MPI_MAX_INLINE_SEND 256
+
 namespace faabric::mpi {
 
 enum MpiMessageType : int32_t
@@ -49,7 +54,11 @@ struct MpiMessage
     // struct 8-aligned
     int32_t requestId;
     MpiMessageType messageType;
-    void* buffer;
+    union
+    {
+        void* buffer;
+        uint8_t inlineMsg[MPI_MAX_INLINE_SEND];
+    };
 };
 static_assert((sizeof(MpiMessage) % 8) == 0, "MPI message must be 8-aligned!");

@@ -60,7 +69,14 @@ inline size_t payloadSize(const MpiMessage& msg)
 
 inline size_t msgSize(const MpiMessage& msg)
 {
-    return sizeof(MpiMessage) + payloadSize(msg);
+    size_t payloadSz = payloadSize(msg);
+
+    // If we can inline the message, we do not need to add anything else
+    if (payloadSz <= MPI_MAX_INLINE_SEND) {
+        return sizeof(MpiMessage);
+    }
+
+    return sizeof(MpiMessage) + payloadSz;
 }
 
 void serializeMpiMsg(std::vector<uint8_t>& buffer, const MpiMessage& msg);
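One consequence of the new `msgSize` above: an inline payload adds nothing to the serialised size, because its bytes already live inside the struct. A quick sketch of the arithmetic (`kHeaderSize` is a made-up placeholder for `sizeof(MpiMessage)`):

```cpp
#include <cassert>
#include <cstddef>

constexpr size_t kMaxInline = 256;  // mirrors MPI_MAX_INLINE_SEND
constexpr size_t kHeaderSize = 296; // placeholder for sizeof(MpiMessage)

// Mirrors the msgSize() logic above
size_t msgSizeFor(size_t payloadSz)
{
    if (payloadSz <= kMaxInline) {
        return kHeaderSize; // inline payload: nothing appended
    }
    return kHeaderSize + payloadSz; // heap payload: appended on the wire
}

int main()
{
    assert(msgSizeFor(0) == kHeaderSize);         // empty message
    assert(msgSizeFor(28) == kHeaderSize);        // seven ints, fully inlined
    assert(msgSizeFor(300) == kHeaderSize + 300); // spills past the inline cap
    return 0;
}
```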

src/mpi/MpiMessage.cpp

+7-6
@@ -12,24 +12,25 @@ void parseMpiMsg(const std::vector<uint8_t>& bytes, MpiMessage* msg)
     assert(msg != nullptr);
     assert(bytes.size() >= sizeof(MpiMessage));
     std::memcpy(msg, bytes.data(), sizeof(MpiMessage));
-    size_t thisPayloadSize = bytes.size() - sizeof(MpiMessage);
-    assert(thisPayloadSize == payloadSize(*msg));
+    size_t thisPayloadSize = payloadSize(*msg);
 
     if (thisPayloadSize == 0) {
         msg->buffer = nullptr;
         return;
     }
 
-    msg->buffer = faabric::util::malloc(thisPayloadSize);
-    std::memcpy(
-      msg->buffer, bytes.data() + sizeof(MpiMessage), thisPayloadSize);
+    if (thisPayloadSize > MPI_MAX_INLINE_SEND) {
+        msg->buffer = faabric::util::malloc(thisPayloadSize);
+        std::memcpy(
+          msg->buffer, bytes.data() + sizeof(MpiMessage), thisPayloadSize);
+    }
 }
 
 void serializeMpiMsg(std::vector<uint8_t>& buffer, const MpiMessage& msg)
 {
     std::memcpy(buffer.data(), &msg, sizeof(MpiMessage));
     size_t payloadSz = payloadSize(msg);
-    if (payloadSz > 0 && msg.buffer != nullptr) {
+    if (payloadSz > MPI_MAX_INLINE_SEND && msg.buffer != nullptr) {
         std::memcpy(buffer.data() + sizeof(MpiMessage), msg.buffer, payloadSz);
     }
 }
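With both halves in place, serialising and parsing a small message never touches the allocator: the payload rides along inside the struct bytes. A sketch of the round trip (assumes the faabric headers and build; it mirrors the unit tests further down):

```cpp
#include <faabric/mpi/MpiMessage.h>

#include <cstring>
#include <vector>

using namespace faabric::mpi;

void roundTripSmallMessage()
{
    std::vector<int> nums = { 1, 2, 3 };
    MpiMessage msg = { .typeSize = sizeof(int),
                       .count = (int32_t)nums.size() };
    std::memcpy(msg.inlineMsg, nums.data(), nums.size() * sizeof(int));

    // Payload fits inline, so the wire buffer is exactly one struct
    std::vector<uint8_t> wire(msgSize(msg));
    serializeMpiMsg(wire, msg);

    // parseMpiMsg() recovers the payload from the struct copy: no malloc
    MpiMessage parsed{};
    parseMpiMsg(wire, &parsed);
}
```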

src/mpi/MpiWorld.cpp

+52-18
@@ -589,6 +589,8 @@ void MpiWorld::send(int sendRank,
                     MpiMessageType messageType)
 {
     // Sanity-check input parameters
+    // TODO: should we just make these assertions and wait for something else
+    // to seg-fault down the line?
     checkRanksRange(sendRank, recvRank);
     if (getHostForRank(sendRank) != thisHost) {
         SPDLOG_ERROR("Trying to send message from a non-local rank: {}",
@@ -609,34 +611,45 @@ void MpiWorld::send(int sendRank,
                        .recvRank = recvRank,
                        .typeSize = dataType->size,
                        .count = count,
-                       .messageType = messageType,
-                       .buffer = nullptr };
+                       .messageType = messageType };
 
     // Mock the message sending in tests
+    // TODO: can we get rid of this atomic in the hot path?
     if (faabric::util::isMockMode()) {
         mpiMockedMessages[sendRank].push_back(msg);
         return;
     }
 
-    bool mustSendData = count > 0 && buffer != nullptr;
+    size_t dataSize = count * dataType->size;
+    bool mustSendData = dataSize > 0 && buffer != nullptr;
 
     // Dispatch the message locally or globally
     if (isLocal) {
         // Take control over the buffer data if we are gonna move it to
         // the in-memory queues for local messaging
         if (mustSendData) {
-            void* bufferPtr = faabric::util::malloc(count * dataType->size);
-            std::memcpy(bufferPtr, buffer, count * dataType->size);
+            if (dataSize <= MPI_MAX_INLINE_SEND) {
+                std::memcpy(msg.inlineMsg, buffer, count * dataType->size);
+            } else {
+                void* bufferPtr = faabric::util::malloc(count * dataType->size);
+                std::memcpy(bufferPtr, buffer, count * dataType->size);
 
-            msg.buffer = bufferPtr;
+                msg.buffer = bufferPtr;
+            }
+        } else {
+            msg.buffer = nullptr;
         }
 
         SPDLOG_TRACE(
           "MPI - send {} -> {} ({})", sendRank, recvRank, messageType);
         getLocalQueue(sendRank, recvRank)->enqueue(msg);
     } else {
         if (mustSendData) {
-            msg.buffer = (void*)buffer;
+            if (dataSize <= MPI_MAX_INLINE_SEND) {
+                std::memcpy(msg.inlineMsg, buffer, count * dataType->size);
+            } else {
+                msg.buffer = (void*)buffer;
+            }
         }
 
         SPDLOG_TRACE(
@@ -704,17 +717,25 @@ void MpiWorld::doRecv(const MpiMessage& m,
     }
     assert(m.messageType == messageType);
     assert(m.count <= count);
+    size_t dataSize = m.count * dataType->size;
 
     // We must copy the data into the application-provided buffer
-    if (m.count > 0 && m.buffer != nullptr) {
+    if (dataSize > 0) {
         // Make sure we do not overflow the recipient buffer
         auto bytesToCopy =
           std::min<size_t>(m.count * dataType->size, count * dataType->size);
-        std::memcpy(buffer, m.buffer, bytesToCopy);
 
-        // This buffer has been malloc-ed either as part of a local `send`
-        // or as part of a remote `parseMpiMsg`
-        faabric::util::free((void*)m.buffer);
+        if (dataSize > MPI_MAX_INLINE_SEND) {
+            assert(m.buffer != nullptr);
+
+            std::memcpy(buffer, m.buffer, bytesToCopy);
+
+            // This buffer has been malloc-ed either as part of a local `send`
+            // or as part of a remote `parseMpiMsg`
+            faabric::util::free((void*)m.buffer);
+        } else {
+            std::memcpy(buffer, m.inlineMsg, bytesToCopy);
+        }
     }
 
     // Set status values if required
@@ -1886,21 +1907,34 @@ MpiMessage MpiWorld::recvBatchReturnLast(int sendRank,
             // Copy the request id so that it is not overwritten
             int tmpRequestId = itr->requestId;
 
-            // Copy into current slot in the list, but keep a copy to the
-            // app-provided buffer to read data into
+            // Copy the app-provided buffer to recv data into so that it is
+            // not overwritten too. Note that, irrespective of whether the
+            // message is inlined or not, we always use the buffer pointer to
+            // point to the app-provided recv-buffer
             void* providedBuffer = itr->buffer;
+
+            // Copy into current slot in the list
             *itr = getLocalQueue(sendRank, recvRank)->dequeue();
             itr->requestId = tmpRequestId;
 
-            if (itr->buffer != nullptr) {
+            // If we have sent a non-inlined message, copy the data into the
+            // provided buffer and free the one in the queue
+            size_t dataSize = itr->count * itr->typeSize;
+            if (dataSize > MPI_MAX_INLINE_SEND) {
+                assert(itr->buffer != nullptr);
                 assert(providedBuffer != nullptr);
-                // If buffers are not null, we must have a non-zero size
-                assert((itr->count * itr->typeSize) > 0);
                 std::memcpy(
                   providedBuffer, itr->buffer, itr->count * itr->typeSize);
+
                 faabric::util::free(itr->buffer);
+
+                itr->buffer = providedBuffer;
+            } else if (dataSize > 0) {
+                std::memcpy(
+                  providedBuffer, itr->inlineMsg, itr->count * itr->typeSize);
+            } else {
+                itr->buffer = providedBuffer;
             }
-            itr->buffer = providedBuffer;
         }
         assert(itr->messageType != MpiMessageType::UNACKED_MPI_MESSAGE);
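The invariant these send/recv paths maintain: a heap payload has exactly one owner and is freed exactly once, by the receiving side, while inline payloads never touch the allocator. A simplified receive-side sketch (same illustrative `Msg` as the sketch at the top, not the faabric code):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>

constexpr size_t kMaxInline = 256; // mirrors MPI_MAX_INLINE_SEND

struct Msg
{
    size_t payloadSize;
    union
    {
        void* buffer;
        uint8_t inlineMsg[kMaxInline];
    };
};

// Mirrors the shape of doRecv(): copy the payload out to the caller, and
// free only on the heap path. The buffer was malloc-ed either by a local
// send() or by parseMpiMsg() on the receiving host.
void receiveInto(const Msg& msg, void* userBuffer)
{
    if (msg.payloadSize == 0) {
        return;
    }

    if (msg.payloadSize > kMaxInline) {
        std::memcpy(userBuffer, msg.buffer, msg.payloadSize);
        std::free(msg.buffer); // freed exactly once, by the receiver
    } else {
        std::memcpy(userBuffer, msg.inlineMsg, msg.payloadSize);
    }
}
```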

tests/dist/mpi/examples/mpi_isendrecv.cpp

-3
@@ -41,9 +41,6 @@ int iSendRecv()
     }
     printf("Rank %i - async working properly\n", rank);
 
-    delete sendRequest;
-    delete recvRequest;
-
     MPI_Finalize();
 
     return 0;
tests/dist/mpi/examples/mpi_send_sync_async.cpp

-3
@@ -22,7 +22,6 @@ int sendSyncAsync()
             MPI_Send(&r, 1, MPI_INT, r, 0, MPI_COMM_WORLD);
             MPI_Wait(&sendRequest, MPI_STATUS_IGNORE);
         }
-        delete sendRequest;
     } else {
         // Asynchronously receive twice from rank 0
         int recvValue1 = -1;
@@ -47,8 +46,6 @@ int sendSyncAsync()
               rank);
             return 1;
         }
-        delete recvRequest1;
-        delete recvRequest2;
     }
     printf("Rank %i - send sync and async working properly\n", rank);

tests/test/mpi/test_mpi_message.cpp

+41-12
@@ -18,9 +18,9 @@ bool areMpiMsgEqual(const MpiMessage& msgA, const MpiMessage& msgB)
         return false;
     }
 
-    // First, compare the message body (excluding the pointer, which we
-    // know is at the end)
-    if (std::memcmp(&msgA, &msgB, sizeof(MpiMessage) - sizeof(void*)) != 0) {
+    // First, compare the message body (excluding the union at the end)
+    size_t unionSize = sizeof(uint8_t) * MPI_MAX_INLINE_SEND;
+    if (std::memcmp(&msgA, &msgB, sizeof(MpiMessage) - unionSize) != 0) {
         return false;
     }

@@ -35,7 +35,11 @@ bool areMpiMsgEqual(const MpiMessage& msgA, const MpiMessage& msgB)
     // Assert, as this should pass given the previous comparisons
     assert(payloadSizeA == payloadSizeB);
 
-    return std::memcmp(msgA.buffer, msgB.buffer, payloadSizeA) == 0;
+    if (payloadSizeA > MPI_MAX_INLINE_SEND) {
+        return std::memcmp(msgA.buffer, msgB.buffer, payloadSizeA) == 0;
+    }
+
+    return std::memcmp(msgA.inlineMsg, msgB.inlineMsg, payloadSizeA) == 0;
 }
 
 TEST_CASE("Test getting a message size", "[mpi]")
@@ -59,11 +63,23 @@ TEST_CASE("Test getting a message size", "[mpi]")
         expectedPayloadSize = 0;
     }
 
-    SECTION("Non-empty message")
+    SECTION("Non-empty (small) message")
     {
         std::vector<int> nums = { 1, 2, 3, 4, 5, 6, 6 };
         msg.count = nums.size();
         msg.typeSize = sizeof(int);
+        std::memcpy(msg.inlineMsg, nums.data(), nums.size() * sizeof(int));
+
+        expectedPayloadSize = sizeof(int) * nums.size();
+        expectedMsgSize = sizeof(MpiMessage);
+    }
+
+    SECTION("Non-empty (large) message")
+    {
+        int32_t maxNumInts = MPI_MAX_INLINE_SEND / sizeof(int32_t);
+        std::vector<int32_t> nums(maxNumInts + 3, 3);
+        msg.count = nums.size();
+        msg.typeSize = sizeof(int);
         msg.buffer = faabric::util::malloc(msg.count * msg.typeSize);
         std::memcpy(msg.buffer, nums.data(), nums.size() * sizeof(int));
 

@@ -74,7 +90,7 @@ TEST_CASE("Test getting a message size", "[mpi]")
     REQUIRE(expectedMsgSize == msgSize(msg));
     REQUIRE(expectedPayloadSize == payloadSize(msg));
 
-    if (msg.buffer != nullptr) {
+    if (expectedPayloadSize > MPI_MAX_INLINE_SEND && msg.buffer != nullptr) {
         faabric::util::free(msg.buffer);
     }
 }
@@ -95,11 +111,22 @@ TEST_CASE("Test (de)serialising an MPI message", "[mpi]")
         msg.buffer = nullptr;
     }
 
-    SECTION("Non-empty message")
+    SECTION("Non-empty (small) message")
     {
         std::vector<int> nums = { 1, 2, 3, 4, 5, 6, 6 };
         msg.count = nums.size();
         msg.typeSize = sizeof(int);
+        std::memcpy(msg.inlineMsg, nums.data(), nums.size() * sizeof(int));
+    }
+
+    SECTION("Non-empty (large) message")
+    {
+        // Make sure we send more ints than the maximum we can inline
+        int32_t maxNumInts = MPI_MAX_INLINE_SEND / sizeof(int32_t);
+        std::vector<int32_t> nums(maxNumInts + 3, 3);
+        msg.count = nums.size();
+        msg.typeSize = sizeof(int);
+        REQUIRE(payloadSize(msg) > MPI_MAX_INLINE_SEND);
         msg.buffer = faabric::util::malloc(msg.count * msg.typeSize);
         std::memcpy(msg.buffer, nums.data(), nums.size() * sizeof(int));
     }
@@ -113,11 +140,13 @@ TEST_CASE("Test (de)serialising an MPI message", "[mpi]")
 
     REQUIRE(areMpiMsgEqual(msg, parsedMsg));
 
-    if (msg.buffer != nullptr) {
-        faabric::util::free(msg.buffer);
-    }
-    if (parsedMsg.buffer != nullptr) {
-        faabric::util::free(parsedMsg.buffer);
+    if (msg.count * msg.typeSize > MPI_MAX_INLINE_SEND) {
+        if (msg.buffer != nullptr) {
+            faabric::util::free(msg.buffer);
+        }
+        if (parsedMsg.buffer != nullptr) {
+            faabric::util::free(parsedMsg.buffer);
+        }
     }
 }
 }

tests/test/mpi/test_mpi_world.cpp

+28-3
@@ -239,11 +239,17 @@ TEST_CASE_METHOD(MpiTestFixture, "Test send and recv on same host", "[mpi]")
     int rankA2 = 1;
     std::vector<int> messageData;
 
-    SECTION("Non-empty message")
+    SECTION("Non-empty (small) message")
     {
         messageData = { 0, 1, 2 };
     }
 
+    SECTION("Non-empty (large) message")
+    {
+        int32_t maxNumInts = MPI_MAX_INLINE_SEND / sizeof(int32_t);
+        messageData = std::vector<int>(maxNumInts + 3, 3);
+    }
+
     SECTION("Empty message")
     {
         messageData = {};
@@ -273,8 +279,27 @@ TEST_CASE_METHOD(MpiTestFixture, "Test sendrecv", "[mpi]")
     int rankA = 1;
     int rankB = 2;
     MPI_Status status{};
-    std::vector<int> messageDataAB = { 0, 1, 2 };
-    std::vector<int> messageDataBA = { 3, 2, 1, 0 };
+    std::vector<int> messageDataAB;
+    std::vector<int> messageDataBA;
+
+    SECTION("Empty messages")
+    {
+        messageDataAB = {};
+        messageDataBA = {};
+    }
+
+    SECTION("Small messages")
+    {
+        messageDataAB = { 0, 1, 2 };
+        messageDataBA = { 3, 2, 1, 0 };
+    }
+
+    SECTION("Large messages")
+    {
+        int32_t maxNumInts = MPI_MAX_INLINE_SEND / sizeof(int32_t);
+        messageDataAB = std::vector<int>(maxNumInts + 3, 3);
+        messageDataBA = std::vector<int>(maxNumInts + 4, 4);
+    }
 
     // Results
     std::vector<int> recvBufferA(messageDataBA.size(), 0);
