From 9c2ccc0c21f75ff1496c65bd28343638ceaa7c2d Mon Sep 17 00:00:00 2001 From: Raven Szewczyk Date: Thu, 1 Dec 2022 11:18:52 +0000 Subject: [PATCH 1/2] Address review comments --- include/faabric/transport/Message.h | 4 +-- include/faabric/transport/MessageEndpoint.h | 9 ++--- include/faabric/util/concurrent_map.h | 33 +++++++++++-------- src/scheduler/FunctionCallServer.cpp | 4 +-- src/snapshot/SnapshotServer.cpp | 4 +-- src/state/StateServer.cpp | 2 +- src/transport/MessageEndpoint.cpp | 8 +++-- src/transport/MessageEndpointClient.cpp | 2 +- src/transport/PointToPointBroker.cpp | 8 +++-- src/transport/PointToPointServer.cpp | 4 +-- .../test_message_endpoint_client.cpp | 8 ++--- 11 files changed, 49 insertions(+), 37 deletions(-) diff --git a/include/faabric/transport/Message.h b/include/faabric/transport/Message.h index 65ea26fed..fe0d5a22f 100644 --- a/include/faabric/transport/Message.h +++ b/include/faabric/transport/Message.h @@ -1,11 +1,11 @@ #pragma once +#include #include #include #include #include -#include // The header structure is: // 1 byte - Message code (uint8_t) @@ -106,7 +106,7 @@ class Message final std::vector dataCopy() const; - uint8_t getHeader() const + uint8_t getMessageCode() const { return nngMsg == nullptr ? 0 : allData().data()[0]; } diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index d810866c0..2a486be20 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -70,9 +70,10 @@ class MessageContext final nng_ctx context = NNG_CTX_INITIALIZER; }; -// Note: sockets must be open-ed and close-ed from the _same_ thread. In a given -// communication group, one socket may bind, and all the rest must connect. -// Order does not matter. +// Note: In a given communication group, one socket may bind, and all the rest +// must connect. The bound socket should be created before the connecting +// sockets, otherwise the first sendMessage call will block, waiting for the +// socket to connect. class MessageEndpoint { public: @@ -81,7 +82,7 @@ class MessageEndpoint MessageEndpoint(const std::string& addressIn, int timeoutMsIn); // Delete assignment and copy-constructor as we need to be very careful with - // scoping and same-thread instantiation + // scoping. MessageEndpoint& operator=(const MessageEndpoint&) = delete; MessageEndpoint(const MessageEndpoint& ctx) = delete; diff --git a/include/faabric/util/concurrent_map.h b/include/faabric/util/concurrent_map.h index f113b6456..5c47895d8 100644 --- a/include/faabric/util/concurrent_map.h +++ b/include/faabric/util/concurrent_map.h @@ -1,20 +1,19 @@ #pragma once +#include +#include #include #include #include #include +#include #include #include #include +#include #include -#include -#include -#include -#include - namespace faabric::util { namespace detail { @@ -29,15 +28,17 @@ struct is_shared_ptr> : std::true_type } -// A thread-safe wrapper around a hashmap -// -// Supports heterogeneous lookup, e.g. Key==std::string, lookup with -// std::string_view -// -// Most such maps in faasm/faabric don't need to scale to many writers, so for -// simplicity a shared_mutex is simply used instead of a more sophisticated -// lock-free structure, but the underlying map could be swapped in the future if -// needed. +/* + * A thread-safe wrapper around a hashmap + * + * Supports heterogeneous lookup, e.g. Key==std::string, lookup with + * std::string_view + * + * Most such maps in faasm/faabric don't need to scale to many writers, so for + * simplicity a shared_mutex is simply used instead of a more sophisticated + * lock-free structure, but the underlying map could be swapped in the future if + * needed. + */ template class ConcurrentMap final { @@ -82,21 +83,25 @@ class ConcurrentMap final SharedLock lock{ mutex }; return map.empty(); } + size_t size() const { SharedLock lock{ mutex }; return map.size(); } + size_t capacity() const { SharedLock lock{ mutex }; return map.capacity(); } + void reserve(size_t count) { FullLock lock{ mutex }; map.reserve(count); } + // Rehashes the `flat_hash_map`, setting the number of slots to be at least // the passed value. Pass 0 to force a simple rehash. void rehash(size_t count) diff --git a/src/scheduler/FunctionCallServer.cpp b/src/scheduler/FunctionCallServer.cpp index 787610f6d..2d2a0f45c 100644 --- a/src/scheduler/FunctionCallServer.cpp +++ b/src/scheduler/FunctionCallServer.cpp @@ -19,7 +19,7 @@ FunctionCallServer::FunctionCallServer() void FunctionCallServer::doAsyncRecv(transport::Message& message) { - uint8_t header = message.getHeader(); + uint8_t header = message.getMessageCode(); switch (header) { case faabric::scheduler::FunctionCalls::ExecuteFunctions: { recvExecuteFunctions(message.udata()); @@ -39,7 +39,7 @@ void FunctionCallServer::doAsyncRecv(transport::Message& message) std::unique_ptr FunctionCallServer::doSyncRecv( transport::Message& message) { - uint8_t header = message.getHeader(); + uint8_t header = message.getMessageCode(); switch (header) { case faabric::scheduler::FunctionCalls::Flush: { return recvFlush(message.udata()); diff --git a/src/snapshot/SnapshotServer.cpp b/src/snapshot/SnapshotServer.cpp index 0b14f3e44..6e6a4fdc7 100644 --- a/src/snapshot/SnapshotServer.cpp +++ b/src/snapshot/SnapshotServer.cpp @@ -27,7 +27,7 @@ SnapshotServer::SnapshotServer() void SnapshotServer::doAsyncRecv(transport::Message& message) { - uint8_t header = message.getHeader(); + uint8_t header = message.getMessageCode(); switch (header) { case faabric::snapshot::SnapshotCalls::DeleteSnapshot: { recvDeleteSnapshot(message.udata()); @@ -43,7 +43,7 @@ void SnapshotServer::doAsyncRecv(transport::Message& message) std::unique_ptr SnapshotServer::doSyncRecv( transport::Message& message) { - uint8_t header = message.getHeader(); + uint8_t header = message.getMessageCode(); switch (header) { case faabric::snapshot::SnapshotCalls::PushSnapshot: { return recvPushSnapshot(message.udata()); diff --git a/src/state/StateServer.cpp b/src/state/StateServer.cpp index 1b50664a2..be44388eb 100644 --- a/src/state/StateServer.cpp +++ b/src/state/StateServer.cpp @@ -29,7 +29,7 @@ void StateServer::doAsyncRecv(transport::Message& message) std::unique_ptr StateServer::doSyncRecv( transport::Message& message) { - uint8_t header = message.getHeader(); + uint8_t header = message.getMessageCode(); switch (header) { case faabric::state::StateCalls::Pull: { return recvPull(message.udata()); diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 7406cff36..6576b12ae 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -1,4 +1,3 @@ -#include #include #include #include @@ -8,6 +7,7 @@ #include #include +#include #include #include #include @@ -343,11 +343,11 @@ Message MessageEndpoint::recvMessage(bool async, std::optional context) } SPDLOG_TRACE("Received message with header {} size {} on {}", - msg.getHeader(), + msg.getMessageCode(), msg.udata().size(), getAddress()); - if (msg.getHeader() == SHUTDOWN_HEADER) { + if (msg.getMessageCode() == SHUTDOWN_HEADER) { if (std::equal(dataBytes.begin(), dataBytes.end(), shutdownPayload.cbegin(), @@ -378,6 +378,8 @@ MessageContext MessageEndpoint::createContext() void MessageEndpoint::close() { + // Switch depending on the type in the variant, closing the dialer or + // listener depending on which one this socket was opened with. std::visit( [](auto&& value) { // see example at diff --git a/src/transport/MessageEndpointClient.cpp b/src/transport/MessageEndpointClient.cpp index ec8d736f8..7984c951b 100644 --- a/src/transport/MessageEndpointClient.cpp +++ b/src/transport/MessageEndpointClient.cpp @@ -1,5 +1,5 @@ -#include "faabric/util/testing.h" #include +#include #include namespace faabric::transport { diff --git a/src/transport/PointToPointBroker.cpp b/src/transport/PointToPointBroker.cpp index 525fa3efb..31b598873 100644 --- a/src/transport/PointToPointBroker.cpp +++ b/src/transport/PointToPointBroker.cpp @@ -1,7 +1,7 @@ -#include "faabric/transport/Message.h" -#include "faabric/transport/MessageEndpoint.h" #include #include +#include +#include #include #include #include @@ -584,6 +584,10 @@ void PointToPointBroker::sendMessage(int groupId, hostHint); } +// Gets or creates a pair of inproc endpoints (recv&send) in the endpoints map. +// Ensures the receiving endpoint gets created first. A reference counter is +// also allocated with the pair to keep track of how many threads are using the +// endpoint pair for cleanup later. auto getEndpointPtrs(const std::string& label) { auto maybeEndpoint = endpoints.get(label); diff --git a/src/transport/PointToPointServer.cpp b/src/transport/PointToPointServer.cpp index 6a3e4c0c2..a62a021ec 100644 --- a/src/transport/PointToPointServer.cpp +++ b/src/transport/PointToPointServer.cpp @@ -21,7 +21,7 @@ PointToPointServer::PointToPointServer() void PointToPointServer::doAsyncRecv(transport::Message& message) { - uint8_t header = message.getHeader(); + uint8_t header = message.getMessageCode(); int sequenceNum = message.getSequenceNum(); switch (header) { case (faabric::transport::PointToPointCall::MESSAGE): { @@ -70,7 +70,7 @@ void PointToPointServer::doAsyncRecv(transport::Message& message) std::unique_ptr PointToPointServer::doSyncRecv( transport::Message& message) { - uint8_t header = message.getHeader(); + uint8_t header = message.getMessageCode(); switch (header) { case (faabric::transport::PointToPointCall::MAPPING): { return doRecvMappings(message.udata()); diff --git a/tests/test/transport/test_message_endpoint_client.cpp b/tests/test/transport/test_message_endpoint_client.cpp index 02bcb3b27..2172e34bc 100644 --- a/tests/test/transport/test_message_endpoint_client.cpp +++ b/tests/test/transport/test_message_endpoint_client.cpp @@ -34,7 +34,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, // Receive message faabric::transport::Message recvMsg = dst.recv(); - REQUIRE(recvMsg.getHeader() == dummyHeader); + REQUIRE(recvMsg.getMessageCode() == dummyHeader); REQUIRE(recvMsg.data().size() == expectedMsg.size()); std::string actualMsg(recvMsg.data().begin(), recvMsg.data().end()); REQUIRE(actualMsg == expectedMsg); @@ -145,7 +145,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, if ((i % (numMessages / 10)) == 0) { std::string expectedMsg = baseMsg + std::to_string(i); REQUIRE(recvMsg.data().size() == expectedMsg.size()); - REQUIRE(recvMsg.getHeader() == dummyHeader); + REQUIRE(recvMsg.getMessageCode() == dummyHeader); std::string actualMsg(recvMsg.data().begin(), recvMsg.data().end()); REQUIRE(actualMsg == expectedMsg); } @@ -187,7 +187,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, // Check just a subset of the messages if ((i % (numMessages / 10)) == 0) { REQUIRE(recvMsg.data().size() == expectedMsg.size()); - REQUIRE(recvMsg.getHeader() == dummyHeader); + REQUIRE(recvMsg.getMessageCode() == dummyHeader); std::string actualMsg(recvMsg.data().begin(), recvMsg.data().end()); REQUIRE(actualMsg == expectedMsg); } @@ -258,7 +258,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, "Test direct messaging", "[transport]") std::string actual; faabric::transport::Message recvMsg = receiver.recv(); - REQUIRE(recvMsg.getHeader() == dummyHeader); + REQUIRE(recvMsg.getMessageCode() == dummyHeader); actual = std::string(recvMsg.data().begin(), recvMsg.data().end()); REQUIRE(actual == expected); From 5f8bea122a4e840a748001822fee3bf03eccb532 Mon Sep 17 00:00:00 2001 From: Raven Szewczyk Date: Thu, 1 Dec 2022 11:30:08 +0000 Subject: [PATCH 2/2] Add tests for more map operations --- tests/test/util/test_concurrent_map.cpp | 35 +++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/tests/test/util/test_concurrent_map.cpp b/tests/test/util/test_concurrent_map.cpp index 5bed013a9..dc0aea5e2 100644 --- a/tests/test/util/test_concurrent_map.cpp +++ b/tests/test/util/test_concurrent_map.cpp @@ -29,10 +29,21 @@ TEST_CASE("Test basic map operations in a single thread", REQUIRE(map.capacity() >= initialCapacity); REQUIRE(map.sortedKvPairs().empty()); + map.reserve(2 * initialCapacity); + REQUIRE(map.capacity() >= 2 * initialCapacity); + REQUIRE(map.insert(std::make_pair(1, 10))); REQUIRE(!map.insert(std::make_pair(1, 20))); // no-op REQUIRE(map.insert(std::make_pair(3, 30))); + REQUIRE(!map.isEmpty()); + REQUIRE(map.size() == 2); + REQUIRE_THAT( + map.sortedKvPairs(), + Equals(std::vector>{ { 1, 10 }, { 3, 30 } })); + + map.rehash(0); + REQUIRE(!map.isEmpty()); REQUIRE(map.size() == 2); REQUIRE_THAT( @@ -58,13 +69,21 @@ TEST_CASE("Test basic map operations in a single thread", Equals(std::vector>{ { 1, 20 }, { 2, 20 }, { 3, 30 }, { 4, 40 }, { 5, 50 } })); + map.erase(1); + + REQUIRE(!map.isEmpty()); + REQUIRE(map.size() == 4); + REQUIRE_THAT(map.sortedKvPairs(), + Equals(std::vector>{ + { 2, 20 }, { 3, 30 }, { 4, 40 }, { 5, 50 } })); + int called = 0; map.tryEmplaceThenMutate( 1, [&](bool placed, int& val) { called++; - REQUIRE(!placed); - REQUIRE(val == 20); + REQUIRE(placed); + REQUIRE(val == 0); val = 10; }, 0); @@ -133,9 +152,21 @@ TEST_CASE("Test basic map operations in a single thread", value /= 10; }); REQUIRE(called == 5); + REQUIRE(map.size() == 5); REQUIRE_THAT(map.sortedKvPairs(), Equals(std::vector>{ { 1, 1 }, { 2, 2 }, { 3, 3 }, { 4, 4 }, { 5, 5 } })); + + map.eraseIf([](const int& key, const int& value) { return key <= 3; }); + REQUIRE(map.size() == 2); + REQUIRE_THAT( + map.sortedKvPairs(), + Equals(std::vector>{ { 4, 4 }, { 5, 5 } })); + + map.clear(); + REQUIRE(map.isEmpty()); + REQUIRE(map.size() == 0); + REQUIRE(map.sortedKvPairs().empty()); } TEST_CASE("Test insertion from many threads", "[util][concurrent_map]")