@@ -515,25 +515,12 @@ namespace threading {
515
515
516
516
#ifdef POCKETFFT_NO_MULTITHREADING
517
517
518
- constexpr inline size_t thread_id () { return 0 ; }
519
- constexpr inline size_t num_threads () { return 1 ; }
520
-
521
518
/// Single-threaded stand-in for thread_map, compiled when
/// POCKETFFT_NO_MULTITHREADING is defined.
///
/// Invokes `f` exactly once on the calling thread with thread index 0 and a
/// thread count of 1. These are the same arguments the multithreaded
/// thread_map passes on its `nthreads == 1` path, so a callable written as
/// `f(thread_index, thread_count)` compiles in both build configurations.
///
/// The first parameter (requested thread count) is accepted only for
/// signature compatibility and is ignored.
template <typename Func>
void thread_map(size_t /* nthreads */, Func f)
  { f(0, 1); }
524
521
525
522
#else
526
523
527
- inline size_t &thread_id ()
528
- {
529
- static thread_local size_t thread_id_=0 ;
530
- return thread_id_;
531
- }
532
- inline size_t &num_threads ()
533
- {
534
- static thread_local size_t num_threads_=1 ;
535
- return num_threads_;
536
- }
537
524
// Upper bound on the number of worker threads. The C++ standard allows
// std::thread::hardware_concurrency() to return 0 when the value cannot be
// determined, so the result is clamped to at least 1.
static const size_t max_threads = std::max(1u , std::thread::hardware_concurrency());
538
525
539
526
class latch
@@ -786,7 +773,7 @@ void thread_map(size_t nthreads, Func f)
786
773
nthreads = max_threads;
787
774
788
775
if (nthreads == 1 )
789
- { f (); return ; }
776
+ { f (0 , 1 ); return ; }
790
777
791
778
auto & pool = get_pool ();
792
779
latch counter (nthreads);
@@ -796,9 +783,7 @@ void thread_map(size_t nthreads, Func f)
796
783
{
797
784
pool.submit (
798
785
[&f, &counter, &ex, &ex_mut, i, nthreads] {
799
- thread_id () = i;
800
- num_threads () = nthreads;
801
- try { f (); }
786
+ try { f (i, nthreads); }
802
787
catch (...)
803
788
{
804
789
std::lock_guard<std::mutex> lock (ex_mut);
@@ -2881,15 +2866,14 @@ template<size_t N> class multi_iter
2881
2866
}
2882
2867
2883
2868
public:
2884
- multi_iter (const arr_info &iarr_, const arr_info &oarr_, size_t idim_)
2869
+ multi_iter (const arr_info &iarr_, const arr_info &oarr_, size_t idim_,
2870
+ size_t nshares, size_t myshare)
2885
2871
: pos(iarr_.ndim(), 0 ), iarr(iarr_), oarr(oarr_), p_ii(0 ),
2886
2872
str_i (iarr.stride(idim_)), p_oi(0 ), str_o(oarr.stride(idim_)),
2887
2873
idim(idim_), rem(iarr.size()/iarr.shape(idim))
2888
2874
{
2889
- auto nshares = threading::num_threads ();
2890
2875
if (nshares==1 ) return ;
2891
2876
if (nshares==0 ) throw std::runtime_error (" can't run with zero threads" );
2892
- auto myshare = threading::thread_id ();
2893
2877
if (myshare>=nshares) throw std::runtime_error (" impossible share requested" );
2894
2878
size_t nbase = rem/nshares;
2895
2879
size_t additional = rem%nshares;
@@ -3134,11 +3118,11 @@ POCKETFFT_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,
3134
3118
3135
3119
threading::thread_map (
3136
3120
util::thread_count (nthreads, in.shape (), axes[iax], VLEN<T>::val),
3137
- [&] {
3121
+ [&] ( size_t tid, size_t nthreads) {
3138
3122
constexpr auto vlen = VLEN<T0>::val;
3139
3123
auto storage = alloc_tmp<T0>(in.shape (), len, sizeof (T));
3140
3124
const auto &tin (iax==0 ? in : out);
3141
- multi_iter<vlen> it (tin, out, axes[iax]);
3125
+ multi_iter<vlen> it (tin, out, axes[iax], nthreads, tid );
3142
3126
#ifndef POCKETFFT_NO_VECTORS
3143
3127
if (vlen>1 )
3144
3128
while (it.remaining ()>=vlen)
@@ -3241,10 +3225,10 @@ template<typename T> POCKETFFT_NOINLINE void general_r2c(
3241
3225
size_t len=in.shape (axis);
3242
3226
threading::thread_map (
3243
3227
util::thread_count (nthreads, in.shape (), axis, VLEN<T>::val),
3244
- [&] {
3228
+ [&] ( size_t tid, size_t nthreads) {
3245
3229
constexpr auto vlen = VLEN<T>::val;
3246
3230
auto storage = alloc_tmp<T>(in.shape (), len, sizeof (T));
3247
- multi_iter<vlen> it (in, out, axis);
3231
+ multi_iter<vlen> it (in, out, axis, nthreads, tid );
3248
3232
#ifndef POCKETFFT_NO_VECTORS
3249
3233
if (vlen>1 )
3250
3234
while (it.remaining ()>=vlen)
@@ -3296,10 +3280,10 @@ template<typename T> POCKETFFT_NOINLINE void general_c2r(
3296
3280
size_t len=out.shape (axis);
3297
3281
threading::thread_map (
3298
3282
util::thread_count (nthreads, in.shape (), axis, VLEN<T>::val),
3299
- [&] {
3283
+ [&] ( size_t tid, size_t nthreads) {
3300
3284
constexpr auto vlen = VLEN<T>::val;
3301
3285
auto storage = alloc_tmp<T>(out.shape (), len, sizeof (T));
3302
- multi_iter<vlen> it (in, out, axis);
3286
+ multi_iter<vlen> it (in, out, axis, nthreads, tid );
3303
3287
#ifndef POCKETFFT_NO_VECTORS
3304
3288
if (vlen>1 )
3305
3289
while (it.remaining ()>=vlen)
0 commit comments