Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Round Robin Queue Servicing #24

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions include/thread_pool/thread_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ using ThreadPool = ThreadPoolImpl<FixedFunction<void(), 128>,
*/
template <typename Task, template<typename> class Queue>
class ThreadPoolImpl {

using WorkerVector = std::vector<std::unique_ptr<Worker<Task, Queue>>>;

public:
/**
* @brief ThreadPool Construct and start new thread pool.
Expand Down Expand Up @@ -74,7 +77,7 @@ class ThreadPoolImpl {
private:
Worker<Task, Queue>& getWorker();

std::vector<std::unique_ptr<Worker<Task, Queue>>> m_workers;
WorkerVector m_workers;
std::atomic<size_t> m_next_worker;
};

Expand All @@ -94,9 +97,7 @@ inline ThreadPoolImpl<Task, Queue>::ThreadPoolImpl(

for(size_t i = 0; i < m_workers.size(); ++i)
{
Worker<Task, Queue>* steal_donor =
m_workers[(i + 1) % m_workers.size()].get();
m_workers[i]->start(i, steal_donor);
m_workers[i]->start(i, &m_workers);
}
}

Expand Down Expand Up @@ -131,7 +132,7 @@ template <typename Task, template<typename> class Queue>
template <typename Handler>
inline bool ThreadPoolImpl<Task, Queue>::tryPost(Handler&& handler)
{
return getWorker().post(std::forward<Handler>(handler));
return getWorker().tryPost(std::forward<Handler>(handler));
}

template <typename Task, template<typename> class Queue>
Expand Down
10 changes: 5 additions & 5 deletions include/thread_pool/thread_pool_options.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,28 +47,28 @@ class ThreadPoolOptions

/// Implementation

ThreadPoolOptions::ThreadPoolOptions()
inline ThreadPoolOptions::ThreadPoolOptions()
: m_thread_count(std::max<size_t>(1u, std::thread::hardware_concurrency()))
, m_queue_size(1024u)
{
}

void ThreadPoolOptions::setThreadCount(size_t count)
inline void ThreadPoolOptions::setThreadCount(size_t count)
{
m_thread_count = std::max<size_t>(1u, count);
}

void ThreadPoolOptions::setQueueSize(size_t size)
inline void ThreadPoolOptions::setQueueSize(size_t size)
{
m_queue_size = std::max<size_t>(1u, size);
}

size_t ThreadPoolOptions::threadCount() const
inline size_t ThreadPoolOptions::threadCount() const
{
return m_thread_count;
}

size_t ThreadPoolOptions::queueSize() const
inline size_t ThreadPoolOptions::queueSize() const
{
return m_queue_size;
}
Expand Down
71 changes: 54 additions & 17 deletions include/thread_pool/worker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <atomic>
#include <thread>
#include <limits>

namespace tp
{
Expand All @@ -15,6 +16,8 @@ namespace tp
template <typename Task, template<typename> class Queue>
class Worker
{
using WorkerVector = std::vector<std::unique_ptr<Worker<Task, Queue>>>;

public:
/**
* @brief Worker Constructor.
Expand All @@ -35,9 +38,9 @@ class Worker
/**
* @brief start Create the executing thread and start tasks execution.
* @param id Worker ID.
* @param steal_donor Sibling worker to steal task from it.
* @param workers Sibling workers for performing round robin work stealing.
*/
void start(size_t id, Worker* steal_donor);
void start(size_t id, WorkerVector* workers);

/**
* @brief stop Stop all worker's thread and stealing activity.
Expand All @@ -46,19 +49,19 @@ class Worker
void stop();

/**
* @brief post Post task to queue.
* @brief tryPost Post task to queue.
* @param handler Handler to be executed in executing thread.
* @return true on success.
*/
template <typename Handler>
bool post(Handler&& handler);
bool tryPost(Handler&& handler);

/**
* @brief steal Steal one task from this worker queue.
* @param task Place for stealed task to be stored.
* @brief tryGetLocalTask Get one task from this worker queue.
* @param task Place for the obtained task to be stored.
* @return true on success.
*/
bool steal(Task& task);
bool tryGetLocalTask(Task& task);

/**
* @brief getWorkerIdForCurrentThread Return worker ID associated with
Expand All @@ -68,16 +71,24 @@ class Worker
static size_t getWorkerIdForCurrentThread();

private:
/**
* @brief tryRoundRobinSteal Try stealing a thread from sibling workers in a round-robin fashion.
* @param task Place for the obtained task to be stored.
* @param workers Sibling workers for performing round robin work stealing.
*/
bool tryRoundRobinSteal(Task& task, WorkerVector* workers);

/**
* @brief threadFunc Executing thread function.
* @param id Worker ID to be associated with this thread.
* @param steal_donor Sibling worker to steal task from it.
* @param workers Sibling workers for performing round robin work stealing.
*/
void threadFunc(size_t id, Worker* steal_donor);
void threadFunc(size_t id, WorkerVector* workers);

Queue<Task> m_queue;
std::atomic<bool> m_running_flag;
std::thread m_thread;
size_t m_next_donor;
};


Expand All @@ -87,7 +98,7 @@ namespace detail
{
inline size_t* thread_id()
{
static thread_local size_t tss_id = -1u;
static thread_local size_t tss_id = std::numeric_limits<size_t>::max();
return &tss_id;
}
}
Expand All @@ -96,6 +107,7 @@ template <typename Task, template<typename> class Queue>
inline Worker<Task, Queue>::Worker(size_t queue_size)
: m_queue(queue_size)
, m_running_flag(true)
, m_next_donor(0) // Initialized in threadFunc.
{
}

Expand Down Expand Up @@ -125,9 +137,9 @@ inline void Worker<Task, Queue>::stop()
}

template <typename Task, template<typename> class Queue>
inline void Worker<Task, Queue>::start(size_t id, Worker* steal_donor)
inline void Worker<Task, Queue>::start(size_t id, WorkerVector* workers)
{
m_thread = std::thread(&Worker<Task, Queue>::threadFunc, this, id, steal_donor);
m_thread = std::thread(&Worker<Task, Queue>::threadFunc, this, id, workers);
}

template <typename Task, template<typename> class Queue>
Expand All @@ -138,35 +150,60 @@ inline size_t Worker<Task, Queue>::getWorkerIdForCurrentThread()

template <typename Task, template<typename> class Queue>
template <typename Handler>
inline bool Worker<Task, Queue>::post(Handler&& handler)
inline bool Worker<Task, Queue>::tryPost(Handler&& handler)
{
return m_queue.push(std::forward<Handler>(handler));
}

template <typename Task, template<typename> class Queue>
inline bool Worker<Task, Queue>::steal(Task& task)
inline bool Worker<Task, Queue>::tryGetLocalTask(Task& task)
{
return m_queue.pop(task);
}

template <typename Task, template<typename> class Queue>
inline void Worker<Task, Queue>::threadFunc(size_t id, Worker* steal_donor)
inline bool Worker<Task, Queue>::tryRoundRobinSteal(Task& task, WorkerVector* workers)
{
auto starting_index = m_next_donor;

// Iterate once through the worker ring, checking for queued work items on each thread.
do
{
// Don't steal from local queue.
if (m_next_donor != *detail::thread_id() && workers->at(m_next_donor)->tryGetLocalTask(task))
{
// Increment before returning so that m_next_donor always points to the worker that has gone the longest
// without a steal attempt. This helps enforce fairness in the stealing.
++m_next_donor %= workers->size();
return true;
}

++m_next_donor %= workers->size();
} while (m_next_donor != starting_index);

return false;
}

template <typename Task, template<typename> class Queue>
inline void Worker<Task, Queue>::threadFunc(size_t id, WorkerVector* workers)
{
*detail::thread_id() = id;
m_next_donor = ++id % workers->size();

Task handler;

while (m_running_flag.load(std::memory_order_relaxed))
{
if (m_queue.pop(handler) || steal_donor->steal(handler))
// Prioritize local queue, then try stealing from sibling workers.
if (tryGetLocalTask(handler) || tryRoundRobinSteal(handler, workers))
{
try
{
handler();
}
catch(...)
{
// suppress all exceptions
// Suppress all exceptions.
}
}
else
Expand Down