Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] fix(ipc): reimplement I/O loop #17

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 67 additions & 128 deletions IPCConnection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,164 +18,103 @@ extern "C" {

#include "IPCConnection.hpp"

IPCConnection::IPCConnection(std::shared_ptr<UnixSocket> socket):
_socket(std::move(socket)),
_isWorkerTerminated(false),
namespace ulalaca::ipc {

_workerThread(),
_messageId(0),
_ackId(0),
_isGood(true)
{
IPCConnection::IPCConnection(std::shared_ptr<UnixSocket> socket) :
_socket(std::move(socket)),
_isWorkerTerminated(false),

}

IPCConnection::IPCConnection(const std::string &socketPath):
IPCConnection(std::make_shared<UnixSocket>(socketPath))
{

}

FD IPCConnection::descriptor() {
return _socket->descriptor();
}

void IPCConnection::connect() {
_socket->connect();
_workerThread(),
_messageId(0),
_ackId(0),
_isGood(true) {

// enable non-blocking io
auto flags = _socket->fcntl(F_GETFL, 0);
if (_socket->fcntl(F_SETFL, flags | O_NONBLOCK)) {
throw SystemCallException(errno, "fcntl");
}

_workerThread = std::thread(&IPCConnection::workerLoop, this);
}
IPCConnection::IPCConnection(const std::string &socketPath) :
IPCConnection(std::make_shared<UnixSocket>(socketPath)) {

void IPCConnection::disconnect() {
_isWorkerTerminated = true;

if (_workerThread.joinable()) {
_workerThread.join();
}
_socket->close();
}

bool IPCConnection::isGood() const {
return _isGood;
}

std::unique_ptr<ULIPCHeader, IPCConnection::MallocFreeDeleter> IPCConnection::nextHeader() {
return std::move(read<ULIPCHeader>(sizeof(ULIPCHeader)));
}

void IPCConnection::write(const void *pointer, size_t size) {
assert(pointer != nullptr);
assert(size > 0);

std::unique_ptr<uint8_t, MallocFreeDeleter> data(
(uint8_t *) malloc(size),
free
);

std::memcpy(data.get(), pointer, size);

{
std::scoped_lock<std::mutex> scopedWriteTasksLock(_writeTasksLock);
_writeTasks.emplace(size, std::move(data));
FD IPCConnection::descriptor() {
return _socket->descriptor();
}
}

void IPCConnection::workerLoop() {
const size_t MAX_READ_SIZE = 8192;
void IPCConnection::connect() {
_socket->connect();

size_t readPos = 0;
std::unique_ptr<uint8_t> readBuffer;

_isGood = true;

while (!_isWorkerTerminated) {
auto pollFd = _socket->poll(POLLIN | POLLOUT, -1);
// enable non-blocking io
auto flags = _socket->fcntl(F_GETFL, 0);
if (_socket->fcntl(F_SETFL, flags | O_NONBLOCK)) {
throw SystemCallException(errno, "fcntl");
}

bool canRead = (pollFd.revents & POLLIN) > 0;
bool canWrite = (pollFd.revents & POLLOUT) > 0;
_workerThread = std::thread(&IPCConnection::workerLoop, this);
}

if (canWrite && !_writeTasks.empty()) {
_writeTasksLock.lock();
auto writeTask = std::move(_writeTasks.front());
_writeTasks.pop();
_writeTasksLock.unlock();
void IPCConnection::disconnect() {
_isWorkerTerminated = true;

if (_workerThread.joinable()) {
_workerThread.join();
}
_socket->close();
}

if (_socket->write(writeTask.second.get(), writeTask.first) < 0) {
if (errno == EAGAIN) {
continue;
}
bool IPCConnection::isGood() const {
return _isGood;
}

LOG(LOG_LEVEL_ERROR, "write() failed (errno=%d)", errno);
continue;
}
IPCDataBlockPtr IPCConnection::readBlock(size_t size) {
if (size < 0 || !_isGood) {
return nullptr;
}

if (canRead && !_readTasks.empty()) {
auto &readTask = _readTasks.front();
auto readTask = std::make_shared<IPCReadTask>(IPCReadTask {
size,
std::make_shared<IPCReadPromise>(),

auto &contentLength = readTask.first;
auto &promise = readTask.second;
0,
nullptr
});

if (readBuffer == nullptr) {
readPos = 0;
readBuffer = std::unique_ptr<uint8_t>(new uint8_t[contentLength]);
}
{
std::unique_lock<std::shared_mutex> _lock(_readTasksMutex);
_readTasks.emplace(readTask);
}

int readForBytes = std::min(
(size_t) MAX_READ_SIZE,
contentLength - readPos
);
auto future = readTask->promise->get_future();
auto retval = future.get();

size_t retval = _socket->read(readBuffer.get() + readPos, readForBytes);
return std::move(retval);
}

if (retval < 0) {
if (errno == EAGAIN) {
continue;
} else {
throw SystemCallException(errno, "read");
}
}
void IPCConnection::writeBlock(const void *pointer, size_t size) {
assert(pointer != nullptr);
assert(size > 0);

if (_isGood && retval <= 0) {
break;
}
if (!_isGood) {
return;
}

readPos += retval;
auto writeTask = std::make_shared<IPCWriteTask>(IPCWriteTask {
size,
std::shared_ptr<uint8_t>((uint8_t *) malloc(size), free)
});

if (readPos >= contentLength) {
promise->set_value(std::move(readBuffer));
{
std::scoped_lock<std::mutex> scopedReadTasksLock(_readTasksLock);
_readTasks.pop();
}
memcpy(writeTask->data.get(), pointer, size);

readBuffer = nullptr;
readPos = 0;
}
{
_writeTasks.emplace(std::move(writeTask));
}
}

if (pollFd.revents & POLLHUP) {
LOG(LOG_LEVEL_WARNING, "POLLHUP bit set");
_isGood = false;
std::shared_ptr<ULIPCHeader> IPCConnection::nextHeader() {
auto header = std::move(this->readBlock<ULIPCHeader>(sizeof(ULIPCHeader)));

if (_readTasks.empty()) {
LOG(LOG_LEVEL_WARNING, "POLLHUP bit set; closing connection");
break;
}
}
_ackId = header->id;
// TODO: check timestamp or std::max(_ackId, header->id)

if (pollFd.revents & POLLERR) {
LOG(LOG_LEVEL_ERROR, "POLLERR bit set; closing connection");
break;
}
return std::move(header);
}

_isGood = false;
}
115 changes: 58 additions & 57 deletions IPCConnection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,93 +5,94 @@
#ifndef ULALACA_IPCCONNECTION_HPP
#define ULALACA_IPCCONNECTION_HPP

#include <cassert>

#include <memory>
#include <thread>
#include <queue>
#include <future>
#include <cassert>
#include <mutex>
#include <shared_mutex>

#include "UnixSocket.hpp"

#include "messages/projector.h"

#include "ulalaca.hpp"

namespace ulalaca::ipc {

using IPCDataBlockPtr = std::shared_ptr<uint8_t>;
using IPCReadPromise = std::promise<IPCDataBlockPtr>;

struct IPCWriteTask {
size_t size;
IPCDataBlockPtr data;
};

struct IPCReadTask {
size_t size;
std::shared_ptr<IPCReadPromise> promise;

class IPCConnection {
public:
using MallocFreeDeleter = std::function<void(void *)>;
size_t read;
IPCDataBlockPtr buffer;
};

explicit IPCConnection(std::shared_ptr<UnixSocket> socket);
explicit IPCConnection(const std::string &socketPath);
IPCConnection(IPCConnection &) = delete;

FD descriptor();
class IPCConnection {
public:
explicit IPCConnection(std::shared_ptr<UnixSocket> socket);
explicit IPCConnection(const std::string &socketPath);

/**
* @throws SystemCallException
*/
void connect();
void disconnect();
IPCConnection(IPCConnection &) = delete;

bool isGood() const;
FD descriptor();

std::unique_ptr<ULIPCHeader, MallocFreeDeleter> nextHeader();
/** @throws SystemCallException */
void connect();
/** @throws SystemCallException */
void disconnect();

template <typename T>
void writeMessage(uint16_t messageType, T message) {
auto header = ULIPCHeader {
(uint16_t) messageType,
_messageId,
0, // FIXME
0, // FIXME
sizeof(T)
};
bool isGood() const;

write(&header, sizeof(header));
write(&message, sizeof(T));
}
std::shared_ptr<ULIPCHeader> nextHeader();

template<typename T>
std::unique_ptr<T, MallocFreeDeleter> read(size_t size) {
assert(size != 0);
IPCDataBlockPtr readBlock(size_t size);
void writeBlock(const void *pointer, size_t size);

auto promise = std::make_shared<std::promise<std::unique_ptr<uint8_t>>>();
{
std::scoped_lock<std::mutex> scopedReadTasksLock(_readTasksLock);
_readTasks.emplace(size, promise);
}
auto source = promise->get_future().get();
auto destination = std::unique_ptr<T, MallocFreeDeleter>(
(T *) malloc(size),
free
);
template <typename T>
std::shared_ptr<T> readBlock(size_t size);

std::memmove(destination.get(), source.get(), size);
template <typename T>
void writeMessage(const ULIPCHeader &header, const T &message);

return std::move(destination);
}
/** @deprecated use writeMessage(const ULIPCHeader &header, const T &message) instead */
template <typename T>
void writeMessage(uint16_t messageType, const T &message);

void write(const void *pointer, size_t size);
private:
void workerLoop();

private:
void workerLoop();
std::atomic_uint64_t _messageId;
std::atomic_uint64_t _ackId;

std::atomic_uint64_t _messageId;
std::atomic_uint64_t _ackId;
std::shared_ptr<UnixSocket> _socket;
std::thread _workerThread;
bool _isWorkerTerminated;

std::shared_ptr<UnixSocket> _socket;
std::thread _workerThread;
bool _isWorkerTerminated;
bool _isGood;

bool _isGood;
std::shared_mutex _writeTasksMutex;
std::queue<std::shared_ptr<IPCWriteTask>> _writeTasks;

std::mutex _writeTasksLock;
std::mutex _readTasksLock;
std::shared_mutex _readTasksMutex;
std::queue<std::shared_ptr<IPCReadTask>> _readTasks;
};
}

std::queue<std::pair<size_t, std::unique_ptr<uint8_t, MallocFreeDeleter>>> _writeTasks;
std::queue<std::pair<size_t, std::shared_ptr<std::promise<std::unique_ptr<uint8_t>>> >> _readTasks;
};
/** @deprecated use ulalaca::ipc::IPCConnection instead */
using IPCConnection = ulalaca::ipc::IPCConnection;

#include "IPCConnection.template.cpp"

#endif //XRDP_IPCCONNECTION_HPP
Loading