Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid using thread_local to work around mingw bug #3

Open
wants to merge 1 commit into
base: cpp
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 11 additions & 27 deletions pocketfft_hdronly.h
Original file line number Diff line number Diff line change
Expand Up @@ -515,25 +515,12 @@ namespace threading {

#ifdef POCKETFFT_NO_MULTITHREADING

constexpr inline size_t thread_id() { return 0; }
constexpr inline size_t num_threads() { return 1; }

template <typename Func>
void thread_map(size_t /* nthreads */, Func f)
{ f(); }
{ f(0, 1); }

#else

inline size_t &thread_id()
{
static thread_local size_t thread_id_=0;
return thread_id_;
}
inline size_t &num_threads()
{
static thread_local size_t num_threads_=1;
return num_threads_;
}
static const size_t max_threads = std::max(1u, std::thread::hardware_concurrency());

class latch
Expand Down Expand Up @@ -786,7 +773,7 @@ void thread_map(size_t nthreads, Func f)
nthreads = max_threads;

if (nthreads == 1)
{ f(); return; }
{ f(0, 1); return; }

auto & pool = get_pool();
latch counter(nthreads);
Expand All @@ -796,9 +783,7 @@ void thread_map(size_t nthreads, Func f)
{
pool.submit(
[&f, &counter, &ex, &ex_mut, i, nthreads] {
thread_id() = i;
num_threads() = nthreads;
try { f(); }
try { f(i, nthreads); }
catch (...)
{
std::lock_guard<std::mutex> lock(ex_mut);
Expand Down Expand Up @@ -2881,15 +2866,14 @@ template<size_t N> class multi_iter
}

public:
multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_)
multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_,
size_t nshares, size_t myshare)
: pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0),
str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)),
idim(idim_), rem(iarr.size()/iarr.shape(idim))
{
auto nshares = threading::num_threads();
if (nshares==1) return;
if (nshares==0) throw std::runtime_error("can't run with zero threads");
auto myshare = threading::thread_id();
if (myshare>=nshares) throw std::runtime_error("impossible share requested");
size_t nbase = rem/nshares;
size_t additional = rem%nshares;
Expand Down Expand Up @@ -3134,11 +3118,11 @@ POCKETFFT_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,

threading::thread_map(
util::thread_count(nthreads, in.shape(), axes[iax], VLEN<T>::val),
[&] {
[&] (size_t tid, size_t nthreads) {
constexpr auto vlen = VLEN<T0>::val;
auto storage = alloc_tmp<T0>(in.shape(), len, sizeof(T));
const auto &tin(iax==0? in : out);
multi_iter<vlen> it(tin, out, axes[iax]);
multi_iter<vlen> it(tin, out, axes[iax], nthreads, tid);
#ifndef POCKETFFT_NO_VECTORS
if (vlen>1)
while (it.remaining()>=vlen)
Expand Down Expand Up @@ -3241,10 +3225,10 @@ template<typename T> POCKETFFT_NOINLINE void general_r2c(
size_t len=in.shape(axis);
threading::thread_map(
util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val),
[&] {
[&] (size_t tid, size_t nthreads) {
constexpr auto vlen = VLEN<T>::val;
auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));
multi_iter<vlen> it(in, out, axis);
multi_iter<vlen> it(in, out, axis, nthreads, tid);
#ifndef POCKETFFT_NO_VECTORS
if (vlen>1)
while (it.remaining()>=vlen)
Expand Down Expand Up @@ -3296,10 +3280,10 @@ template<typename T> POCKETFFT_NOINLINE void general_c2r(
size_t len=out.shape(axis);
threading::thread_map(
util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val),
[&] {
[&] (size_t tid, size_t nthreads) {
constexpr auto vlen = VLEN<T>::val;
auto storage = alloc_tmp<T>(out.shape(), len, sizeof(T));
multi_iter<vlen> it(in, out, axis);
multi_iter<vlen> it(in, out, axis, nthreads, tid);
#ifndef POCKETFFT_NO_VECTORS
if (vlen>1)
while (it.remaining()>=vlen)
Expand Down