diff --git a/include/thread_pool/thread_pool.hpp b/include/thread_pool/thread_pool.hpp
index 4c79fdbf..e7c7d408 100644
--- a/include/thread_pool/thread_pool.hpp
+++ b/include/thread_pool/thread_pool.hpp
@@ -28,6 +28,9 @@ using ThreadPool = ThreadPoolImpl<FixedFunction<void(), 128>,
  */
 template <typename Task, template<typename> class Queue>
 class ThreadPoolImpl {
+
+    using WorkerVector = std::vector<std::unique_ptr<Worker<Task, Queue>>>;
+
 public:
     /**
      * @brief ThreadPool Construct and start new thread pool.
@@ -74,7 +77,7 @@
 private:
     Worker<Task, Queue>& getWorker();
 
-    std::vector<std::unique_ptr<Worker<Task, Queue>>> m_workers;
+    WorkerVector m_workers;
 
     std::atomic<size_t> m_next_worker;
 };
@@ -94,9 +97,7 @@ inline ThreadPoolImpl<Task, Queue>::ThreadPoolImpl(
 
     for(size_t i = 0; i < m_workers.size(); ++i)
     {
-        Worker<Task, Queue>* steal_donor =
-                m_workers[(i + 1) % m_workers.size()].get();
-        m_workers[i]->start(i, steal_donor);
+        m_workers[i]->start(i, &m_workers);
     }
 }
 
@@ -131,7 +132,7 @@ template <typename Task, template<typename> class Queue>
 template <typename Handler>
 inline bool ThreadPoolImpl<Task, Queue>::tryPost(Handler&& handler)
 {
-    return getWorker().post(std::forward<Handler>(handler));
+    return getWorker().tryPost(std::forward<Handler>(handler));
 }
 
 template <typename Task, template<typename> class Queue>
diff --git a/include/thread_pool/thread_pool_options.hpp b/include/thread_pool/thread_pool_options.hpp
index 88b07f63..c1cde526 100644
--- a/include/thread_pool/thread_pool_options.hpp
+++ b/include/thread_pool/thread_pool_options.hpp
@@ -47,28 +47,28 @@ class ThreadPoolOptions
 
 /// Implementation
 
-ThreadPoolOptions::ThreadPoolOptions()
+inline ThreadPoolOptions::ThreadPoolOptions()
     : m_thread_count(std::max<size_t>(1u, std::thread::hardware_concurrency()))
     , m_queue_size(1024u)
 {
 }
 
-void ThreadPoolOptions::setThreadCount(size_t count)
+inline void ThreadPoolOptions::setThreadCount(size_t count)
 {
     m_thread_count = std::max<size_t>(1u, count);
 }
 
-void ThreadPoolOptions::setQueueSize(size_t size)
+inline void ThreadPoolOptions::setQueueSize(size_t size)
 {
     m_queue_size = std::max<size_t>(1u, size);
 }
 
-size_t ThreadPoolOptions::threadCount() const
+inline size_t ThreadPoolOptions::threadCount() const
 {
     return m_thread_count;
 }
 
-size_t ThreadPoolOptions::queueSize() const
+inline size_t ThreadPoolOptions::queueSize() const
 {
     return m_queue_size;
 }
diff --git a/include/thread_pool/worker.hpp b/include/thread_pool/worker.hpp
index 91e67a37..bca966bf 100644
--- a/include/thread_pool/worker.hpp
+++ b/include/thread_pool/worker.hpp
@@ -2,6 +2,7 @@
 
 #include <atomic>
 #include <thread>
+#include <limits>
 
 namespace tp
 {
@@ -15,6 +16,8 @@ namespace tp
 template <typename Task, template<typename> class Queue>
 class Worker
 {
+    using WorkerVector = std::vector<std::unique_ptr<Worker<Task, Queue>>>;
+
 public:
     /**
      * @brief Worker Constructor.
@@ -35,9 +38,9 @@ class Worker
     /**
      * @brief start Create the executing thread and start tasks execution.
      * @param id Worker ID.
-     * @param steal_donor Sibling worker to steal task from it.
+     * @param workers Sibling workers for performing round-robin work stealing.
      */
-    void start(size_t id, Worker* steal_donor);
+    void start(size_t id, WorkerVector* workers);
 
     /**
      * @brief stop Stop all worker's thread and stealing activity.
@@ -46,19 +49,19 @@ class Worker
     void stop();
 
     /**
-     * @brief post Post task to queue.
+     * @brief tryPost Post task to queue.
      * @param handler Handler to be executed in executing thread.
      * @return true on success.
      */
     template <typename Handler>
-    bool post(Handler&& handler);
+    bool tryPost(Handler&& handler);
 
     /**
-     * @brief steal Steal one task from this worker queue.
-     * @param task Place for stealed task to be stored.
+     * @brief tryGetLocalTask Get one task from this worker queue.
+     * @param task Place for the obtained task to be stored.
      * @return true on success.
      */
-    bool steal(Task& task);
+    bool tryGetLocalTask(Task& task);
 
     /**
      * @brief getWorkerIdForCurrentThread Return worker ID associated with
@@ -68,16 +71,24 @@ class Worker
     static size_t getWorkerIdForCurrentThread();
 
 private:
+    /**
+     * @brief tryRoundRobinSteal Try stealing a task from sibling workers in a round-robin fashion.
+     * @param task Place for the obtained task to be stored.
+     * @param workers Sibling workers for performing round-robin work stealing.
+     */
+    bool tryRoundRobinSteal(Task& task, WorkerVector* workers);
+
     /**
      * @brief threadFunc Executing thread function.
      * @param id Worker ID to be associated with this thread.
-     * @param steal_donor Sibling worker to steal task from it.
+     * @param workers Sibling workers for performing round-robin work stealing.
      */
-    void threadFunc(size_t id, Worker* steal_donor);
+    void threadFunc(size_t id, WorkerVector* workers);
 
     Queue<Task> m_queue;
     std::atomic<bool> m_running_flag;
     std::thread m_thread;
+    size_t m_next_donor;
 };
 
 
@@ -87,7 +98,7 @@ namespace detail
 {
     inline size_t* thread_id()
     {
-        static thread_local size_t tss_id = -1u;
+        static thread_local size_t tss_id = std::numeric_limits<size_t>::max();
         return &tss_id;
     }
 }
@@ -96,6 +107,7 @@ template <typename Task, template<typename> class Queue>
 inline Worker<Task, Queue>::Worker(size_t queue_size)
     : m_queue(queue_size)
     , m_running_flag(true)
+    , m_next_donor(0) // Initialized in threadFunc.
 {
 }
 
@@ -125,9 +137,9 @@ inline void Worker<Task, Queue>::stop()
 }
 
 template <typename Task, template<typename> class Queue>
-inline void Worker<Task, Queue>::start(size_t id, Worker* steal_donor)
+inline void Worker<Task, Queue>::start(size_t id, WorkerVector* workers)
 {
-    m_thread = std::thread(&Worker<Task, Queue>::threadFunc, this, id, steal_donor);
+    m_thread = std::thread(&Worker<Task, Queue>::threadFunc, this, id, workers);
 }
 
 template <typename Task, template<typename> class Queue>
@@ -138,27 +150,52 @@ inline size_t Worker<Task, Queue>::getWorkerIdForCurrentThread()
 
 template <typename Task, template<typename> class Queue>
 template <typename Handler>
-inline bool Worker<Task, Queue>::post(Handler&& handler)
+inline bool Worker<Task, Queue>::tryPost(Handler&& handler)
 {
     return m_queue.push(std::forward<Handler>(handler));
 }
 
 template <typename Task, template<typename> class Queue>
-inline bool Worker<Task, Queue>::steal(Task& task)
+inline bool Worker<Task, Queue>::tryGetLocalTask(Task& task)
 {
     return m_queue.pop(task);
 }
 
 template <typename Task, template<typename> class Queue>
-inline void Worker<Task, Queue>::threadFunc(size_t id, Worker* steal_donor)
+inline bool Worker<Task, Queue>::tryRoundRobinSteal(Task& task, WorkerVector* workers)
+{
+    auto starting_index = m_next_donor;
+
+    // Iterate once through the worker ring, checking for queued work items on each thread.
+    do
+    {
+        // Don't steal from local queue.
+        if (m_next_donor != *detail::thread_id() && workers->at(m_next_donor)->tryGetLocalTask(task))
+        {
+            // Increment before returning so that m_next_donor always points to the worker that has gone the longest
+            // without a steal attempt. This helps enforce fairness in the stealing.
+            ++m_next_donor %= workers->size();
+            return true;
+        }
+
+        ++m_next_donor %= workers->size();
+    } while (m_next_donor != starting_index);
+
+    return false;
+}
+
+template <typename Task, template<typename> class Queue>
+inline void Worker<Task, Queue>::threadFunc(size_t id, WorkerVector* workers)
 {
     *detail::thread_id() = id;
+    m_next_donor = ++id % workers->size();
 
     Task handler;
 
     while (m_running_flag.load(std::memory_order_relaxed))
     {
-        if (m_queue.pop(handler) || steal_donor->steal(handler))
+        // Prioritize local queue, then try stealing from sibling workers.
+        if (tryGetLocalTask(handler) || tryRoundRobinSteal(handler, workers))
         {
             try
            {
@@ -166,7 +203,7 @@ inline void Worker<Task, Queue>::threadFunc(size_t id, Worker* steal_donor)
             }
             catch(...)
             {
-                // suppress all exceptions
+                // Suppress all exceptions.
             }
         }
         else
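
Usage sketch (not part of the patch): after this change a worker whose own queue is empty scans every sibling queue in round-robin order via tryRoundRobinSteal() instead of stealing only from its single neighbour, while the public entry point remains ThreadPoolImpl::tryPost(). The snippet below is a minimal illustration of that API under stated assumptions: it assumes ThreadPoolOptions lives in namespace tp and that the tp::ThreadPool constructor accepts a ThreadPoolOptions value, neither of which is shown in the diff itself; only tryPost(), setThreadCount() and setQueueSize() appear above.

    // Hypothetical usage sketch -- not part of the patch.
    #include <thread_pool/thread_pool.hpp>

    #include <atomic>
    #include <iostream>
    #include <thread>

    int main()
    {
        // setThreadCount()/setQueueSize() are shown in thread_pool_options.hpp above.
        tp::ThreadPoolOptions options;   // assumed to live in namespace tp
        options.setThreadCount(4u);
        options.setQueueSize(1024u);

        tp::ThreadPool pool(options);    // assumed constructor shape

        std::atomic<int> done{0};

        // tryPost() returns false when the chosen worker's bounded queue is full,
        // so the caller decides whether to retry, yield, or shed load.
        for (int i = 0; i < 100; ++i)
        {
            while (!pool.tryPost([&done] { ++done; }))
            {
                std::this_thread::yield();
            }
        }

        // Crude drain before the pool is destroyed; a real application would use
        // futures or another completion signal.
        while (done.load() < 100)
        {
            std::this_thread::yield();
        }

        std::cout << "executed " << done.load() << " tasks\n";
        return 0;
    }

Because tryPost() is non-blocking, the retry loop above is where a caller would observe back-pressure from the bounded per-worker queues; the round-robin stealing introduced by this patch only affects how idle workers find work, not the posting interface.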