Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/nng'
Browse files Browse the repository at this point in the history
  • Loading branch information
eigenraven committed Dec 1, 2022
2 parents 93bb3b8 + 5f8bea1 commit 8e43276
Show file tree
Hide file tree
Showing 12 changed files with 82 additions and 39 deletions.
4 changes: 2 additions & 2 deletions include/faabric/transport/Message.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#pragma once

#include <nng/nng.h>
#include <span>
#include <string>
#include <vector>

#include <faabric/util/bytes.h>
#include <nng/nng.h>

// The header structure is:
// 1 byte - Message code (uint8_t)
Expand Down Expand Up @@ -106,7 +106,7 @@ class Message final

std::vector<uint8_t> dataCopy() const;

uint8_t getHeader() const
uint8_t getMessageCode() const
{
return nngMsg == nullptr ? 0 : allData().data()[0];
}
Expand Down
9 changes: 5 additions & 4 deletions include/faabric/transport/MessageEndpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,10 @@ class MessageContext final
nng_ctx context = NNG_CTX_INITIALIZER;
};

// Note: sockets must be open-ed and close-ed from the _same_ thread. In a given
// communication group, one socket may bind, and all the rest must connect.
// Order does not matter.
// Note: In a given communication group, one socket may bind, and all the rest
// must connect. The bound socket should be created before the connecting
// sockets, otherwise the first sendMessage call will block, waiting for the
// socket to connect.
class MessageEndpoint
{
public:
Expand All @@ -81,7 +82,7 @@ class MessageEndpoint
MessageEndpoint(const std::string& addressIn, int timeoutMsIn);

// Delete assignment and copy-constructor as we need to be very careful with
// scoping and same-thread instantiation
// scoping.
MessageEndpoint& operator=(const MessageEndpoint&) = delete;

MessageEndpoint(const MessageEndpoint& ctx) = delete;
Expand Down
33 changes: 19 additions & 14 deletions include/faabric/util/concurrent_map.h
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
#pragma once

#include <absl/container/flat_hash_map.h>
#include <boost/type_traits.hpp>
#include <concepts>
#include <cstdint>
#include <functional>
#include <optional>
#include <shared_mutex>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>

#include <faabric/util/locks.h>

#include <absl/container/flat_hash_map.h>
#include <boost/type_traits.hpp>
#include <shared_mutex>
#include <utility>

namespace faabric::util {

namespace detail {
Expand All @@ -29,15 +28,17 @@ struct is_shared_ptr<std::shared_ptr<Pointee>> : std::true_type

}

// A thread-safe wrapper around a hashmap
//
// Supports heterogeneous lookup, e.g. Key==std::string, lookup with
// std::string_view
//
// Most such maps in faasm/faabric don't need to scale to many writers, so for
// simplicity a shared_mutex is simply used instead of a more sophisticated
// lock-free structure, but the underlying map could be swapped in the future if
// needed.
/*
* A thread-safe wrapper around a hashmap
*
* Supports heterogeneous lookup, e.g. Key==std::string, lookup with
* std::string_view
*
* Most such maps in faasm/faabric don't need to scale to many writers, so for
* simplicity a shared_mutex is simply used instead of a more sophisticated
* lock-free structure, but the underlying map could be swapped in the future if
* needed.
*/
template<class Key, class Value>
class ConcurrentMap final
{
Expand Down Expand Up @@ -82,21 +83,25 @@ class ConcurrentMap final
SharedLock lock{ mutex };
return map.empty();
}

size_t size() const
{
SharedLock lock{ mutex };
return map.size();
}

size_t capacity() const
{
SharedLock lock{ mutex };
return map.capacity();
}

void reserve(size_t count)
{
FullLock lock{ mutex };
map.reserve(count);
}

// Rehashes the `flat_hash_map`, setting the number of slots to be at least
// the passed value. Pass 0 to force a simple rehash.
void rehash(size_t count)
Expand Down
4 changes: 2 additions & 2 deletions src/scheduler/FunctionCallServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ FunctionCallServer::FunctionCallServer()

void FunctionCallServer::doAsyncRecv(transport::Message& message)
{
uint8_t header = message.getHeader();
uint8_t header = message.getMessageCode();
switch (header) {
case faabric::scheduler::FunctionCalls::ExecuteFunctions: {
recvExecuteFunctions(message.udata());
Expand All @@ -44,7 +44,7 @@ void FunctionCallServer::doAsyncRecv(transport::Message& message)
std::unique_ptr<google::protobuf::Message> FunctionCallServer::doSyncRecv(
transport::Message& message)
{
uint8_t header = message.getHeader();
uint8_t header = message.getMessageCode();
switch (header) {
case faabric::scheduler::FunctionCalls::Flush: {
return recvFlush(message.udata());
Expand Down
4 changes: 2 additions & 2 deletions src/snapshot/SnapshotServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ SnapshotServer::SnapshotServer()

void SnapshotServer::doAsyncRecv(transport::Message& message)
{
uint8_t header = message.getHeader();
uint8_t header = message.getMessageCode();
switch (header) {
case faabric::snapshot::SnapshotCalls::DeleteSnapshot: {
recvDeleteSnapshot(message.udata());
Expand All @@ -43,7 +43,7 @@ void SnapshotServer::doAsyncRecv(transport::Message& message)
std::unique_ptr<google::protobuf::Message> SnapshotServer::doSyncRecv(
transport::Message& message)
{
uint8_t header = message.getHeader();
uint8_t header = message.getMessageCode();
switch (header) {
case faabric::snapshot::SnapshotCalls::PushSnapshot: {
return recvPushSnapshot(message.udata());
Expand Down
2 changes: 1 addition & 1 deletion src/state/StateServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ void StateServer::doAsyncRecv(transport::Message& message)
std::unique_ptr<google::protobuf::Message> StateServer::doSyncRecv(
transport::Message& message)
{
uint8_t header = message.getHeader();
uint8_t header = message.getMessageCode();
switch (header) {
case faabric::state::StateCalls::Pull: {
return recvPull(message.udata());
Expand Down
8 changes: 5 additions & 3 deletions src/transport/MessageEndpoint.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#include <chrono>
#include <faabric/transport/Message.h>
#include <faabric/transport/MessageEndpoint.h>
#include <faabric/transport/common.h>
Expand All @@ -8,6 +7,7 @@
#include <faabric/util/macros.h>

#include <array>
#include <chrono>
#include <concepts>
#include <optional>
#include <stdexcept>
Expand Down Expand Up @@ -343,11 +343,11 @@ Message MessageEndpoint::recvMessage(bool async, std::optional<nng_ctx> context)
}

SPDLOG_TRACE("Received message with header {} size {} on {}",
msg.getHeader(),
msg.getMessageCode(),
msg.udata().size(),
getAddress());

if (msg.getHeader() == SHUTDOWN_HEADER) {
if (msg.getMessageCode() == SHUTDOWN_HEADER) {
if (std::equal(dataBytes.begin(),
dataBytes.end(),
shutdownPayload.cbegin(),
Expand Down Expand Up @@ -378,6 +378,8 @@ MessageContext MessageEndpoint::createContext()

void MessageEndpoint::close()
{
// Switch depending on the type in the variant, closing the dialer or
// listener depending on which one this socket was opened with.
std::visit(
[](auto&& value) {
// see example at
Expand Down
2 changes: 1 addition & 1 deletion src/transport/MessageEndpointClient.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include "faabric/util/testing.h"
#include <faabric/transport/MessageEndpointClient.h>
#include <faabric/util/testing.h>
#include <optional>

namespace faabric::transport {
Expand Down
8 changes: 6 additions & 2 deletions src/transport/PointToPointBroker.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "faabric/transport/Message.h"
#include "faabric/transport/MessageEndpoint.h"
#include <faabric/proto/faabric.pb.h>
#include <faabric/scheduler/Scheduler.h>
#include <faabric/transport/Message.h>
#include <faabric/transport/MessageEndpoint.h>
#include <faabric/transport/PointToPointBroker.h>
#include <faabric/transport/PointToPointClient.h>
#include <faabric/util/bytes.h>
Expand Down Expand Up @@ -584,6 +584,10 @@ void PointToPointBroker::sendMessage(int groupId,
hostHint);
}

// Gets or creates a pair of inproc endpoints (recv&send) in the endpoints map.
// Ensures the receiving endpoint gets created first. A reference counter is
// also allocated with the pair to keep track of how many threads are using the
// endpoint pair for cleanup later.
auto getEndpointPtrs(const std::string& label)
{
auto maybeEndpoint = endpoints.get(label);
Expand Down
4 changes: 2 additions & 2 deletions src/transport/PointToPointServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ PointToPointServer::PointToPointServer()

void PointToPointServer::doAsyncRecv(transport::Message& message)
{
uint8_t header = message.getHeader();
uint8_t header = message.getMessageCode();
int sequenceNum = message.getSequenceNum();
switch (header) {
case (faabric::transport::PointToPointCall::MESSAGE): {
Expand Down Expand Up @@ -70,7 +70,7 @@ void PointToPointServer::doAsyncRecv(transport::Message& message)
std::unique_ptr<google::protobuf::Message> PointToPointServer::doSyncRecv(
transport::Message& message)
{
uint8_t header = message.getHeader();
uint8_t header = message.getMessageCode();
switch (header) {
case (faabric::transport::PointToPointCall::MAPPING): {
return doRecvMappings(message.udata());
Expand Down
8 changes: 4 additions & 4 deletions tests/test/transport/test_message_endpoint_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ TEST_CASE_METHOD(SchedulerTestFixture,

// Receive message
faabric::transport::Message recvMsg = dst.recv();
REQUIRE(recvMsg.getHeader() == dummyHeader);
REQUIRE(recvMsg.getMessageCode() == dummyHeader);
REQUIRE(recvMsg.data().size() == expectedMsg.size());
std::string actualMsg(recvMsg.data().begin(), recvMsg.data().end());
REQUIRE(actualMsg == expectedMsg);
Expand Down Expand Up @@ -145,7 +145,7 @@ TEST_CASE_METHOD(SchedulerTestFixture,
if ((i % (numMessages / 10)) == 0) {
std::string expectedMsg = baseMsg + std::to_string(i);
REQUIRE(recvMsg.data().size() == expectedMsg.size());
REQUIRE(recvMsg.getHeader() == dummyHeader);
REQUIRE(recvMsg.getMessageCode() == dummyHeader);
std::string actualMsg(recvMsg.data().begin(), recvMsg.data().end());
REQUIRE(actualMsg == expectedMsg);
}
Expand Down Expand Up @@ -187,7 +187,7 @@ TEST_CASE_METHOD(SchedulerTestFixture,
// Check just a subset of the messages
if ((i % (numMessages / 10)) == 0) {
REQUIRE(recvMsg.data().size() == expectedMsg.size());
REQUIRE(recvMsg.getHeader() == dummyHeader);
REQUIRE(recvMsg.getMessageCode() == dummyHeader);
std::string actualMsg(recvMsg.data().begin(), recvMsg.data().end());
REQUIRE(actualMsg == expectedMsg);
}
Expand Down Expand Up @@ -258,7 +258,7 @@ TEST_CASE_METHOD(SchedulerTestFixture, "Test direct messaging", "[transport]")
std::string actual;
faabric::transport::Message recvMsg = receiver.recv();

REQUIRE(recvMsg.getHeader() == dummyHeader);
REQUIRE(recvMsg.getMessageCode() == dummyHeader);

actual = std::string(recvMsg.data().begin(), recvMsg.data().end());
REQUIRE(actual == expected);
Expand Down
35 changes: 33 additions & 2 deletions tests/test/util/test_concurrent_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,21 @@ TEST_CASE("Test basic map operations in a single thread",
REQUIRE(map.capacity() >= initialCapacity);
REQUIRE(map.sortedKvPairs().empty());

map.reserve(2 * initialCapacity);
REQUIRE(map.capacity() >= 2 * initialCapacity);

REQUIRE(map.insert(std::make_pair(1, 10)));
REQUIRE(!map.insert(std::make_pair(1, 20))); // no-op
REQUIRE(map.insert(std::make_pair(3, 30)));

REQUIRE(!map.isEmpty());
REQUIRE(map.size() == 2);
REQUIRE_THAT(
map.sortedKvPairs(),
Equals(std::vector<std::pair<int, int>>{ { 1, 10 }, { 3, 30 } }));

map.rehash(0);

REQUIRE(!map.isEmpty());
REQUIRE(map.size() == 2);
REQUIRE_THAT(
Expand All @@ -58,13 +69,21 @@ TEST_CASE("Test basic map operations in a single thread",
Equals(std::vector<std::pair<int, int>>{
{ 1, 20 }, { 2, 20 }, { 3, 30 }, { 4, 40 }, { 5, 50 } }));

map.erase(1);

REQUIRE(!map.isEmpty());
REQUIRE(map.size() == 4);
REQUIRE_THAT(map.sortedKvPairs(),
Equals(std::vector<std::pair<int, int>>{
{ 2, 20 }, { 3, 30 }, { 4, 40 }, { 5, 50 } }));

int called = 0;
map.tryEmplaceThenMutate(
1,
[&](bool placed, int& val) {
called++;
REQUIRE(!placed);
REQUIRE(val == 20);
REQUIRE(placed);
REQUIRE(val == 0);
val = 10;
},
0);
Expand Down Expand Up @@ -133,9 +152,21 @@ TEST_CASE("Test basic map operations in a single thread",
value /= 10;
});
REQUIRE(called == 5);
REQUIRE(map.size() == 5);
REQUIRE_THAT(map.sortedKvPairs(),
Equals(std::vector<std::pair<int, int>>{
{ 1, 1 }, { 2, 2 }, { 3, 3 }, { 4, 4 }, { 5, 5 } }));

map.eraseIf([](const int& key, const int& value) { return key <= 3; });
REQUIRE(map.size() == 2);
REQUIRE_THAT(
map.sortedKvPairs(),
Equals(std::vector<std::pair<int, int>>{ { 4, 4 }, { 5, 5 } }));

map.clear();
REQUIRE(map.isEmpty());
REQUIRE(map.size() == 0);
REQUIRE(map.sortedKvPairs().empty());
}

TEST_CASE("Test insertion from many threads", "[util][concurrent_map]")
Expand Down

0 comments on commit 8e43276

Please sign in to comment.