@@ -2,14 +2,20 @@ use ndarray::ArrayView1;
2
2
3
3
use rayon:: iter:: IndexedParallelIterator ;
4
4
use rayon:: prelude:: * ;
5
+ use std:: thread:: available_parallelism;
5
6
6
7
use super :: types:: Num ;
7
8
use num_traits:: { AsPrimitive , FromPrimitive } ;
8
9
9
10
// ---------------------- Binary search ----------------------
10
11
11
12
// #[inline(always)]
12
- fn binary_search < T : PartialOrd > ( arr : ArrayView1 < T > , value : T , left : usize , right : usize ) -> usize {
13
+ fn binary_search < T : Copy + PartialOrd > (
14
+ arr : ArrayView1 < T > ,
15
+ value : T ,
16
+ left : usize ,
17
+ right : usize ,
18
+ ) -> usize {
13
19
let mut size: usize = right - left;
14
20
let mut left: usize = left;
15
21
let mut right: usize = right;
@@ -27,7 +33,7 @@ fn binary_search<T: PartialOrd>(arr: ArrayView1<T>, value: T, left: usize, right
27
33
}
28
34
29
35
// #[inline(always)]
30
- fn binary_search_with_mid < T : PartialOrd > (
36
+ fn binary_search_with_mid < T : Copy + PartialOrd > (
31
37
arr : ArrayView1 < T > ,
32
38
value : T ,
33
39
left : usize ,
@@ -69,17 +75,17 @@ where
69
75
( arr[ arr. len ( ) - 1 ] . as_ ( ) / nb_bins as f64 ) - ( arr[ 0 ] . as_ ( ) / nb_bins as f64 ) ;
70
76
let idx_step: usize = arr. len ( ) / nb_bins; // used to pre-guess the mid index
71
77
let mut value: f64 = arr[ 0 ] . as_ ( ) ; // Search value
72
- let mut idx = 0 ; // Index of the search value
78
+ let mut idx: usize = 0 ; // Index of the search value
73
79
( 0 ..nb_bins) . map ( move |_| {
74
- let start_idx = idx; // Start index of the bin (previous end index)
80
+ let start_idx: usize = idx; // Start index of the bin (previous end index)
75
81
value += val_step;
76
- let mid = idx + idx_step;
82
+ let mid: usize = idx + idx_step;
77
83
let mid = if mid < arr. len ( ) - 1 {
78
84
mid
79
85
} else {
80
86
arr. len ( ) - 2 // TODO: arr.len() - 1 gives error I thought...
81
87
} ;
82
- let search_value = T :: from_f64 ( value) . unwrap ( ) ;
88
+ let search_value: T = T :: from_f64 ( value) . unwrap ( ) ;
83
89
// Implementation WITHOUT pre-guessing mid is slower!!
84
90
// idx = binary_search(arr, search_value, idx, arr.len()-1);
85
91
idx = binary_search_with_mid ( arr, search_value, idx, arr. len ( ) - 1 , mid) ; // End index of the bin
@@ -102,7 +108,7 @@ fn sequential_add_mul(start_val: f64, add_val: f64, mul: usize) -> f64 {
102
108
pub ( crate ) fn get_equidistant_bin_idx_iterator_parallel < T > (
103
109
arr : ArrayView1 < T > ,
104
110
nb_bins : usize ,
105
- ) -> impl IndexedParallelIterator < Item = ( usize , usize ) > + ' _
111
+ ) -> impl IndexedParallelIterator < Item = impl Iterator < Item = ( usize , usize ) > + ' _ > + ' _
106
112
where
107
113
T : Num + FromPrimitive + AsPrimitive < f64 > + Sync + Send ,
108
114
{
@@ -111,14 +117,35 @@ where
111
117
let val_step: f64 =
112
118
( arr[ arr. len ( ) - 1 ] . as_ ( ) / nb_bins as f64 ) - ( arr[ 0 ] . as_ ( ) / nb_bins as f64 ) ;
113
119
let arr0: f64 = arr[ 0 ] . as_ ( ) ;
114
- ( 0 ..nb_bins) . into_par_iter ( ) . map ( move |i| {
115
- let start_value = sequential_add_mul ( arr0, val_step, i) ;
116
- let end_value = start_value + val_step;
117
- let start_value = T :: from_f64 ( start_value) . unwrap ( ) ;
118
- let end_value = T :: from_f64 ( end_value) . unwrap ( ) ;
119
- let start_idx = binary_search ( arr, start_value, 0 , arr. len ( ) - 1 ) ;
120
- let end_idx = binary_search ( arr, end_value, start_idx, arr. len ( ) - 1 ) ;
121
- ( start_idx, end_idx)
120
+ let nb_threads = available_parallelism ( ) . map ( |x| x. get ( ) ) . unwrap_or ( 1 ) ;
121
+ let nb_threads = if nb_threads > nb_bins {
122
+ nb_bins
123
+ } else {
124
+ nb_threads
125
+ } ;
126
+ let nb_bins_per_thread = nb_bins / nb_threads;
127
+ let nb_bins_last_thread = nb_bins - nb_bins_per_thread * ( nb_threads - 1 ) ;
128
+ // Iterate over the number of threads
129
+ // -> for each thread perform the binary searches with a moving left bound and
130
+ // yield the indices (using the same idea as for the sequential version)
131
+ ( 0 ..nb_threads) . into_par_iter ( ) . map ( move |i| {
132
+ // Search the start of the first bin (of the thread)
133
+ let mut value: f64 = sequential_add_mul ( arr0, val_step, i * nb_bins_per_thread) ; // Search value
134
+ let start_value: T = T :: from_f64 ( value) . unwrap ( ) ;
135
+ let mut idx: usize = binary_search ( arr, start_value, 0 , arr. len ( ) - 1 ) ; // Index of the search value
136
+ let nb_bins_thread = if i == nb_threads - 1 {
137
+ nb_bins_last_thread
138
+ } else {
139
+ nb_bins_per_thread
140
+ } ;
141
+ // Perform sequential binary search for the end of the bins (of the thread)
142
+ ( 0 ..nb_bins_thread) . map ( move |_| {
143
+ let start_idx: usize = idx; // Start index of the bin (previous end index)
144
+ value += val_step;
145
+ let search_value: T = T :: from_f64 ( value) . unwrap ( ) ;
146
+ idx = binary_search ( arr, search_value, idx, arr. len ( ) - 1 ) ; // End index of the bin
147
+ ( start_idx, idx)
148
+ } )
122
149
} )
123
150
}
124
151
@@ -207,7 +234,10 @@ mod tests {
207
234
let bin_idxs = bin_idxs_iter. map ( |x| x. 0 ) . collect :: < Vec < usize > > ( ) ;
208
235
assert_eq ! ( bin_idxs, vec![ 0 , 3 , 6 ] ) ;
209
236
let bin_idxs_iter = get_equidistant_bin_idx_iterator_parallel ( arr. view ( ) , 3 ) ;
210
- let bin_idxs = bin_idxs_iter. map ( |x| x. 0 ) . collect :: < Vec < usize > > ( ) ;
237
+ let bin_idxs = bin_idxs_iter
238
+ . map ( |x| x. map ( |x| x. 0 ) . collect :: < Vec < usize > > ( ) )
239
+ . flatten ( )
240
+ . collect :: < Vec < usize > > ( ) ;
211
241
assert_eq ! ( bin_idxs, vec![ 0 , 3 , 6 ] ) ;
212
242
}
213
243
@@ -225,7 +255,10 @@ mod tests {
225
255
let bin_idxs_iter = get_equidistant_bin_idx_iterator ( arr. view ( ) , nb_bins) ;
226
256
let bin_idxs = bin_idxs_iter. map ( |x| x. 0 ) . collect :: < Vec < usize > > ( ) ;
227
257
let bin_idxs_iter = get_equidistant_bin_idx_iterator_parallel ( arr. view ( ) , nb_bins) ;
228
- let bin_idxs_parallel = bin_idxs_iter. map ( |x| x. 0 ) . collect :: < Vec < usize > > ( ) ;
258
+ let bin_idxs_parallel = bin_idxs_iter
259
+ . map ( |x| x. map ( |x| x. 0 ) . collect :: < Vec < usize > > ( ) )
260
+ . flatten ( )
261
+ . collect :: < Vec < usize > > ( ) ;
229
262
assert_eq ! ( bin_idxs, bin_idxs_parallel) ;
230
263
}
231
264
}
0 commit comments