Skip to content
Merged
24 changes: 5 additions & 19 deletions src/se/impl/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,11 @@

namespace jam::se::utils {

template <typename To, typename From>
inline std::shared_ptr<To> reinterpret_pointer_cast(
const std::shared_ptr<From> &ptr) noexcept {
return std::shared_ptr<To>(ptr, reinterpret_cast<To *>(ptr.get()));
}

template <typename T>
inline std::weak_ptr<T> make_weak(const std::shared_ptr<T> &ptr) noexcept {
return ptr;
}

struct NoCopy {
NoCopy(const NoCopy &) = delete;
NoCopy &operator=(const NoCopy &) = delete;
NoCopy() = default;
};

struct NoMove {
NoMove(NoMove &&) = delete;
NoMove &operator=(NoMove &&) = delete;
NoMove() = default;
};

template <typename T, typename M = std::shared_mutex>
struct SafeObject {
using Type = T;
Expand Down Expand Up @@ -84,13 +66,17 @@ namespace jam::se::utils {
template <typename T, typename M = std::shared_mutex>
using ReadWriteObject = SafeObject<T, M>;

class WaitForSingleObject final : NoMove, NoCopy {
class WaitForSingleObject final {
std::condition_variable wait_cv_;
std::mutex wait_m_;
bool flag_;

public:
WaitForSingleObject() : flag_{true} {}
WaitForSingleObject(WaitForSingleObject &&) = delete;
WaitForSingleObject(const WaitForSingleObject &) = delete;
WaitForSingleObject &operator=(WaitForSingleObject &&) = delete;
WaitForSingleObject &operator=(const WaitForSingleObject &) = delete;

bool wait(std::chrono::microseconds wait_timeout) {
std::unique_lock<std::mutex> _lock(wait_m_);
Expand Down
113 changes: 63 additions & 50 deletions src/utils/channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include <type_traits>

#include "utils/ctor_limiters.hpp"

namespace jam {

template <typename T>
Expand All @@ -23,94 +25,105 @@ namespace jam {
};

template <typename Opp>
struct Endpoint : se::utils::NoCopy {
struct Endpoint : NonCopyable {
static_assert(std::is_same_v<Opp, _Receiver>
|| std::is_same_v<Opp, _Sender>,
"Incorrect type");
static constexpr bool IsReceiver = std::is_same_v<Opp, _Receiver>;
static constexpr bool IsSender = std::is_same_v<Opp, _Sender>;

Endpoint(Endpoint &&other) requires(IsReceiver) {
Endpoint(Endpoint &&other)
requires(IsReceiver)
{
context_.exclusiveAccess([&](auto &my_context) {
Endpoint<typename Opp::Other> *opp = nullptr;
while (other.context_.exclusiveAccess([&](auto &other_context) {
if (other_context.opp_) {
if (!other_context.opp_->register_opp(*this)) {
return true;
}
opp = other_context.opp_;
other_context.opp_ = nullptr;
}
return false;
}));
my_context.opp_ = opp;
Endpoint<typename Opp::Other> *opp = nullptr;
while (other.context_.exclusiveAccess([&](auto &other_context) {
if (other_context.opp_) {
if (!other_context.opp_->register_opp(*this)) {
return true;
}
opp = other_context.opp_;
other_context.opp_ = nullptr;
}
return false;
}));
my_context.opp_ = opp;
});
}

Endpoint(Endpoint &&other) requires(IsSender) {
Endpoint(Endpoint &&other)
requires(IsSender)
{
context_.exclusiveAccess([&](auto &my_context) {
my_context.opp_ = other.context_.exclusiveAccess([&](auto &other_context) {
my_context.opp_ =
other.context_.exclusiveAccess([&](auto &other_context) {
Endpoint<typename Opp::Other> *opp = nullptr;
if (other_context.opp_) {
other_context.opp_->register_opp(*this);
opp = other_context.opp_;
other_context.opp_ = nullptr;
other_context.opp_->register_opp(*this);
opp = other_context.opp_;
other_context.opp_ = nullptr;
}
return opp;
});
});
});
}

Endpoint &operator=(Endpoint &&other) requires(IsReceiver) {
Endpoint &operator=(Endpoint &&other)
requires(IsReceiver)
{
if (this != &other) {
context_.exclusiveAccess([&](auto &my_context) {
Endpoint<typename Opp::Other> *opp = nullptr;
while (other.context_.exclusiveAccess([&](auto &other_context) {
if (other_context.opp_) {
if (!other_context.opp_->register_opp(*this)) {
return true;
}
opp = other_context.opp_;
other_context.opp_ = nullptr;
}
return false;
}));
my_context.opp_ = opp;
Endpoint<typename Opp::Other> *opp = nullptr;
while (other.context_.exclusiveAccess([&](auto &other_context) {
if (other_context.opp_) {
if (!other_context.opp_->register_opp(*this)) {
return true;
}
opp = other_context.opp_;
other_context.opp_ = nullptr;
}
return false;
}));
my_context.opp_ = opp;
});
}
return *this;
}

Endpoint &operator=(Endpoint &&other) requires(IsSender) {
Endpoint &operator=(Endpoint &&other)
requires(IsSender)
{
if (this != &other) {
context_.exclusiveAccess([&](auto &my_context) {
my_context.opp_ = other.context_.exclusiveAccess([&](auto &other_context) {
my_context.opp_ =
other.context_.exclusiveAccess([&](auto &other_context) {
Endpoint<typename Opp::Other> *opp = nullptr;
if (other_context.opp_) {
other_context.opp_->register_opp(*this);
opp = other_context.opp_;
other_context.opp_ = nullptr;
other_context.opp_->register_opp(*this);
opp = other_context.opp_;
other_context.opp_ = nullptr;
}
return opp;
});
});
});
}
return *this;
}

bool register_opp(Endpoint<typename Opp::Other> &opp) requires(IsReceiver) {
return context_.exclusiveAccess([&](auto &context) {
context.opp_ = &opp;
return true;
});

bool register_opp(Endpoint<typename Opp::Other> &opp)
requires(IsReceiver)
{
return context_.exclusiveAccess([&](auto &context) {
context.opp_ = &opp;
return true;
});
}

bool register_opp(Endpoint<typename Opp::Other> &opp) requires(IsSender) {
bool register_opp(Endpoint<typename Opp::Other> &opp)
requires(IsSender)
{
return context_
.try_exclusiveAccess([&](auto &context) {
context.opp_ = &opp;
})
.try_exclusiveAccess([&](auto &context) { context.opp_ = &opp; })
.has_value();
}

Expand All @@ -120,7 +133,7 @@ namespace jam {
return context_.exclusiveAccess([&](auto &context) {
assert(context.opp_ == &opp);
context.opp_ = nullptr;
return true;
return true;
});
}

Expand Down
2 changes: 1 addition & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ include_directories(
${PROJECT_SOURCE_DIR}/src
)

# add_subdirectory(utils)
# add_subdirectory(utils)
92 changes: 70 additions & 22 deletions tests/utils/channel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,57 @@
using namespace std::chrono_literals;
using namespace jam;

/**
* @file ChannelTest.cpp
* @brief Unit tests for Channel class covering send/receive behavior.
*/

#include <gtest/gtest.h>

#include <optional>
#include <thread>

#include "Channel.h"

/**
* @brief Tests sending and receiving a single integer value through the
* channel.
* @details Creates a channel, sends the value 42, and verifies that the
* receiver obtains it.
*/
TEST(ChannelTest, SendAndReceiveValue) {
auto [recv, send] = Channel<int>::create_channel();
auto [recv, send] =
Channel<int>::create_channel(); ///< Create a channel for int values.

send.set(42);
auto value = recv.wait();
send.set(42); ///< Send the integer value 42.
auto value = recv.wait(); ///< Wait for and retrieve the sent value.

ASSERT_TRUE(value.has_value());
EXPECT_EQ(value.value(), 42);
ASSERT_TRUE(value.has_value()); ///< Verify that a value was received.
EXPECT_EQ(value.value(),
42); ///< Check that the received value is equal to 42.
}

/**
* @brief Tests sending an lvalue through the channel.
* @details Sends a copy of the variable 'x' and ensures the receiver obtains
* the correct value.
*/
TEST(ChannelTest, SendLValue) {
auto [recv, send] = Channel<int>::create_channel();

int x = 123;
send.set(x);
auto value = recv.wait();
int x = 123; ///< Define an integer variable.
send.set(x); ///< Send a copy of x.
auto value = recv.wait(); ///< Receive the value.

ASSERT_TRUE(value.has_value());
EXPECT_EQ(value.value(), 123);
ASSERT_TRUE(value.has_value()); ///< Ensure a value was received.
EXPECT_EQ(value.value(), 123); ///< Validate that the value matches x.
}

/**
* @brief Tests that destroying the sender notifies the receiver.
* @details Starts a waiting thread on the receiver, destroys the sender, and
* expects the receiver to unblock with no value.
*/
TEST(ChannelTest, SenderDestructionNotifiesReceiver) {
std::optional<Channel<int>::Receiver> recv;
std::optional<Channel<int>::Sender> send;
Expand All @@ -38,34 +68,52 @@ TEST(ChannelTest, SenderDestructionNotifiesReceiver) {

std::optional<int> result;

std::thread t([&]() { result = recv->wait(); });
std::thread t([&]() {
result = recv->wait();
}); ///< Thread blocks waiting for a value.

std::this_thread::sleep_for(50ms);
send.reset();
std::this_thread::sleep_for(
std::chrono::milliseconds(50)); ///< Ensure the thread is waiting.
send.reset(); ///< Destroy the sender to signal end-of-transmission.

t.join();
t.join(); ///< Wait for the thread to finish.

EXPECT_FALSE(result.has_value());
EXPECT_FALSE(result.has_value()); ///< The result should be empty since the
///< sender no longer exists.
}

/**
* @brief Tests that multiple sends only allow one value to be received.
* @details Sends two values consecutively; the receiver should get exactly one
* of them.
*/
TEST(ChannelTest, MultipleSendKeepsOneValue) {
auto [recv, send] = Channel<int>::create_channel();

send.set(1);
send.set(2);
send.set(1); ///< First send.
send.set(2); ///< Second send overrides or is ignored; only one value should
///< be kept.

auto value = recv.wait();
ASSERT_TRUE(value.has_value());
EXPECT_TRUE(value.value() == 1 || value.value() == 2);
auto value = recv.wait(); ///< Receive the value.
ASSERT_TRUE(value.has_value()); ///< Confirm a value was received.
EXPECT_TRUE(value.value() == 1
|| value.value() == 2); ///< Value should be either 1 or 2.
}

/**
* @brief Tests that destroying the receiver unregisters it without throwing in
* the sender.
* @details Receiver is destroyed before sending; calling send.set() should not
* throw an exception.
*/
TEST(ChannelTest, ReceiverDestructionUnregistersSender) {
std::optional<Channel<int>::Receiver> recv;
std::optional<Channel<int>::Sender> send;

std::tie(recv, send) = Channel<int>::create_channel();

recv.reset();
recv.reset(); ///< Destroy the receiver prior to sending.

EXPECT_NO_THROW(send->set(999));
EXPECT_NO_THROW(send->set(
999)); ///< Sending after receiver destruction should not throw.
}
Loading