Skip to content
Merged
121 changes: 121 additions & 0 deletions src/se/impl/common.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/**
* Copyright Quadrivium LLC
* All Rights Reserved
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <chrono>
#include <condition_variable>
#include <mutex>
#include <shared_mutex>

namespace jam::se::utils {

template <typename To, typename From>
inline std::shared_ptr<To> reinterpret_pointer_cast(
const std::shared_ptr<From> &ptr) noexcept {
return std::shared_ptr<To>(ptr, reinterpret_cast<To *>(ptr.get()));
}

template <typename T>
inline std::weak_ptr<T> make_weak(const std::shared_ptr<T> &ptr) noexcept {
return ptr;
}

struct NoCopy {
NoCopy(const NoCopy &) = delete;
NoCopy &operator=(const NoCopy &) = delete;
NoCopy() = default;
};

struct NoMove {
NoMove(NoMove &&) = delete;
NoMove &operator=(NoMove &&) = delete;
NoMove() = default;
};

template <typename T, typename M = std::shared_mutex>
struct SafeObject {
using Type = T;

template <typename... Args>
SafeObject(Args &&...args) : t_(std::forward<Args>(args)...) {}

template <typename F>
inline auto exclusiveAccess(F &&f) {
std::unique_lock lock(cs_);
return std::forward<F>(f)(t_);
}

template <typename F>
inline auto try_exclusiveAccess(F &&f) {
std::unique_lock lock(cs_, std::try_to_lock);
using ResultType = decltype(std::forward<F>(f)(t_));
constexpr bool is_void = std::is_void_v<ResultType>;
using OptionalType = std::conditional_t<is_void,
std::optional<std::monostate>,
std::optional<ResultType>>;

if (lock.owns_lock()) {
if constexpr (is_void) {
std::forward<F>(f)(t_);
return OptionalType(std::in_place);
} else {
return OptionalType(std::forward<F>(f)(t_));
}
} else {
return OptionalType();
}
}

template <typename F>
inline auto sharedAccess(F &&f) const {
std::shared_lock lock(cs_);
return std::forward<F>(f)(t_);
}

private:
T t_;
mutable M cs_;
};

template <typename T, typename M = std::shared_mutex>
using ReadWriteObject = SafeObject<T, M>;

class WaitForSingleObject final : NoMove, NoCopy {
std::condition_variable wait_cv_;
std::mutex wait_m_;
bool flag_;

public:
WaitForSingleObject() : flag_{true} {}

bool wait(std::chrono::microseconds wait_timeout) {
std::unique_lock<std::mutex> _lock(wait_m_);
return wait_cv_.wait_for(_lock, wait_timeout, [&]() {
auto prev = !flag_;
flag_ = true;
return prev;
});
}

void wait() {
std::unique_lock<std::mutex> _lock(wait_m_);
wait_cv_.wait(_lock, [&]() {
auto prev = !flag_;
flag_ = true;
return prev;
});
}

void set() {
{
std::unique_lock<std::mutex> _lock(wait_m_);
flag_ = false;
}
wait_cv_.notify_one();
}
};
} // namespace jam::se::utils
149 changes: 149 additions & 0 deletions src/utils/channel.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/**
* Copyright Quadrivium LLC
* All Rights Reserved
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <type_traits>

namespace jam {

template <typename T>
struct Channel {
struct _Receiver;
struct _Sender;

struct _Receiver {
using Other = _Sender;
};
struct _Sender {
using Other = _Receiver;
};

template <typename Opp>
struct Endpoint {
static_assert(std::is_same_v<Opp, _Receiver>
|| std::is_same_v<Opp, _Sender>,
"Incorrect type");
static constexpr bool IsReceiver = std::is_same_v<Opp, _Receiver>;
static constexpr bool IsSender = std::is_same_v<Opp, _Sender>;

void register_opp(Endpoint<typename Opp::Other> &opp) {
context_.exclusiveAccess([&](auto &context) { context.opp_ = &opp; });
}

bool unregister_opp(Endpoint<typename Opp::Other> &opp)
requires(IsReceiver)
{
context_.exclusiveAccess([&](auto &context) {
assert(context.opp_ == &opp);
context.opp_ = nullptr;
});
return true;
}

bool unregister_opp(Endpoint<typename Opp::Other> &opp)
requires(IsSender)
{
return context_
.try_exclusiveAccess([&](auto &context) {
assert(context.opp_ == &opp);
context.opp_ = nullptr;
})
.has_value();
}

~Endpoint()
requires(IsSender)
{
context_.exclusiveAccess([&](auto &context) {
if (context.opp_) {
context.opp_->unregister_opp(*this);
context.opp_->event_.set();
context.opp_ = nullptr;
}
});
}


~Endpoint()
requires(IsReceiver)
{
while (context_.exclusiveAccess([&](auto &context) {
if (context.opp_) {
if (!context.opp_->unregister_opp(*this)) {
return true;
}
context.opp_ = nullptr;
}
return false;
}));
}

void set(T &&t)
requires(IsSender)
{
context_.exclusiveAccess([&](auto &context) {
if (context.opp_) {
context.opp_->context_.exclusiveAccess(
[&](auto &c) { c.data_ = std::move(t); });
context.opp_->event_.set();
}
});
}

void set(T &t)
requires(IsSender)
{
context_.exclusiveAccess([&](auto &context) {
if (context.opp_) {
context.opp_->context_.exclusiveAccess(
[&](auto &c) { c.data_ = t; });
context.opp_->event_.set();
}
});
}

std::optional<T> wait()
requires(IsReceiver)
{
event_.wait();
return context_.exclusiveAccess(
[&](auto &context) { return std::move(context.data_); });
}

private:
friend struct Endpoint<typename Opp::Other>;
struct SafeContext {
std::conditional_t<std::is_same_v<Opp, _Receiver>,
std::optional<T>,
std::monostate>
data_;
Endpoint<typename Opp::Other> *opp_ = nullptr;
};

std::conditional_t<std::is_same_v<Opp, _Receiver>,
jam::se::utils::WaitForSingleObject,
std::monostate>
event_;
jam::se::utils::SafeObject<SafeContext, std::mutex> context_;
};

using Receiver = Endpoint<_Receiver>;
using Sender = Endpoint<_Sender>;

template <typename T>
inline std::pair<Receiver, Sender> create_channel() {
using C = Channel<T>;
C::Receiver r;
C::Sender s;

r.register_opp(s);
s.register_opp(r);
return std::make_pair(std::move(r), std::move(s));
}
};

} // namespace jam
7 changes: 6 additions & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,9 @@
# SPDX-License-Identifier: Apache-2.0
#

message(STATUS "There are no tests yet")
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${PROJECT_SOURCE_DIR}/src
)

# add_subdirectory(utils)
13 changes: 13 additions & 0 deletions tests/utils/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#
# Copyright Quadrivium LLC
# All Rights Reserved
# SPDX-License-Identifier: Apache-2.0
#

# addtest(utils_test
# channel.cpp
# )

# target_link_libraries(utils_test

# )
70 changes: 70 additions & 0 deletions tests/utils/channel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#include "utils/channel.hpp"

#include <gtest/gtest.h>

#include <chrono>
#include <optional>
#include <thread>

using namespace std::chrono_literals;

TEST(ChannelTest, SendAndReceiveValue) {
auto [recv, send] = Channel<int>::create_channel<int>();

send.set(42);
auto value = recv.wait();

ASSERT_TRUE(value.has_value());
EXPECT_EQ(value.value(), 42);
}

TEST(ChannelTest, SendLValue) {
auto [recv, send] = Channel<int>::create_channel<int>();

int x = 123;
send.set(x);
auto value = recv.wait();

ASSERT_TRUE(value.has_value());
EXPECT_EQ(value.value(), 123);
}

TEST(ChannelTest, SenderDestructionNotifiesReceiver) {
std::optional<Channel<int>::Receiver> recv;
std::optional<Channel<int>::Sender> send;

std::tie(recv, send) = Channel<int>::create_channel<int>();

std::optional<int> result;

std::thread t([&]() { result = recv->wait(); });

std::this_thread::sleep_for(50ms);
send.reset();

t.join();

EXPECT_FALSE(result.has_value());
}

TEST(ChannelTest, MultipleSendKeepsOneValue) {
auto [recv, send] = Channel<int>::create_channel<int>();

send.set(1);
send.set(2);

auto value = recv.wait();
ASSERT_TRUE(value.has_value());
EXPECT_TRUE(value.value() == 1 || value.value() == 2);
}

TEST(ChannelTest, ReceiverDestructionUnregistersSender) {
std::optional<Channel<int>::Receiver> recv;
std::optional<Channel<int>::Sender> send;

std::tie(recv, send) = Channel<int>::create_channel<int>();

recv.reset();

EXPECT_NO_THROW(send->set(999));
}
Loading