1- use ndarray:: { Array2 , ArrayView1 , ArrayView2 , Zip } ;
1+ use ndarray:: { Array2 , ArrayView1 , ArrayView2 , ArrayViewMut1 } ;
22use numpy:: { PyArray2 , PyReadonlyArray1 , PyReadonlyArray2 } ;
33use pyo3:: prelude:: * ;
44use rayon:: prelude:: * ;
5- use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
5+ use std:: sync:: atomic:: { AtomicBool , Ordering } ;
66use std:: sync:: Arc ;
77
8- mod internal {
9- pub ( super ) fn sad_converged ( a : & [ f64 ] , b : & [ f64 ] , tol : f64 ) -> bool {
10- a. iter ( ) . zip ( b) . all ( |( & x, & y) | ( x - y) . abs ( ) < tol)
8+
9+ #[ allow( non_snake_case) ]
10+ pub fn demean_impl ( X : & mut Array2 < f64 > , D : ArrayView2 < usize > , weights : ArrayView1 < f64 > , tol : f64 , iterations : usize ) -> bool {
11+ let nsamples = X . nrows ( ) ;
12+ let nfactors = D . ncols ( ) ;
13+ let success = Arc :: new ( AtomicBool :: new ( true ) ) ;
14+ let group_weights = FactorGroupWeights :: new ( & D , & weights) ;
15+
16+ X . axis_iter_mut ( ndarray:: Axis ( 1 ) )
17+ . into_par_iter ( )
18+ . for_each ( |mut column| {
19+ let mut demeaner = ColumnDemeaner :: new ( nsamples, group_weights. width ) ;
20+
21+ for _ in 0 ..iterations {
22+ for i in 0 ..nfactors {
23+ demeaner. demean_column (
24+ & mut column,
25+ & weights,
26+ & D . column ( i) ,
27+ group_weights. factor_weight_slice ( i)
28+ ) ;
29+ }
30+
31+ demeaner. check_convergence ( & column. view ( ) , tol) ;
32+ if demeaner. converged {
33+ break ;
34+ }
35+ }
36+
37+ if !demeaner. converged {
38+ // We can use a relaxed ordering since we only ever go from true to false
39+ // and it doesn't matter how many times we do this.
40+ success. store ( false , Ordering :: Relaxed ) ;
41+ }
42+ } ) ;
43+
44+ success. load ( Ordering :: Relaxed )
45+ }
46+
47+ // The column demeaner is in charge of subtracting group means until convergence.
48+ struct ColumnDemeaner {
49+ converged : bool ,
50+ checkpoint : Vec < f64 > ,
51+ group_sums : Vec < f64 > ,
52+ }
53+
54+ impl ColumnDemeaner {
55+ fn new ( n : usize , k : usize ) -> Self {
56+ Self {
57+ converged : false ,
58+ checkpoint : vec ! [ 0.0 ; n] ,
59+ group_sums : vec ! [ 0.0 ; k] ,
60+ }
1161 }
1262
13- pub ( super ) fn subtract_weighted_group_mean (
14- x : & mut [ f64 ] ,
15- sample_weights : & [ f64 ] ,
16- group_ids : & [ usize ] ,
63+ fn demean_column (
64+ & mut self ,
65+ x : & mut ArrayViewMut1 < f64 > ,
66+ weights : & ArrayView1 < f64 > ,
67+ groups : & ArrayView1 < usize > ,
1768 group_weights : & [ f64 ] ,
18- group_weighted_sums : & mut [ f64 ] ,
1969 ) {
20- group_weighted_sums . fill ( 0.0 ) ;
70+ self . group_sums . fill ( 0.0 ) ;
2171
22- // Accumulate weighted sums per group
23- x. iter ( )
24- . zip ( sample_weights)
25- . zip ( group_ids)
26- . for_each ( |( ( & xi, & wi) , & gid) | {
27- group_weighted_sums[ gid] += wi * xi;
28- } ) ;
72+ // Compute group sums
73+ for ( ( & xi, & wi) , & gid) in x. iter ( ) . zip ( weights) . zip ( groups) {
74+ self . group_sums [ gid] += wi * xi;
75+ }
2976
30- // Compute group means
31- let group_means: Vec < f64 > = group_weighted_sums
32- . iter ( )
33- . zip ( group_weights)
34- . map ( |( & sum, & weight) | sum / weight)
35- . collect ( ) ;
77+ // Convert sums to means
78+ self . group_sums
79+ . iter_mut ( )
80+ . zip ( group_weights. iter ( ) )
81+ . for_each ( |( sum, & weight) | {
82+ * sum /= weight
83+ } ) ;
3684
37- // Subtract means from each sample
38- x. iter_mut ( ) . zip ( group_ids ) . for_each ( | ( xi , & gid ) | {
39- * xi -= group_means [ gid] ;
40- } ) ;
85+ // Subtract group means
86+ for ( xi , & gid ) in x. iter_mut ( ) . zip ( groups ) {
87+ * xi -= self . group_sums [ gid] // Really these are means now
88+ }
4189 }
4290
43- pub ( super ) fn calc_group_weights (
44- sample_weights : & [ f64 ] ,
45- group_ids : & [ usize ] ,
46- n_samples : usize ,
47- n_factors : usize ,
48- n_groups : usize ,
49- ) -> Vec < f64 > {
50- let mut group_weights = vec ! [ 0.0 ; n_factors * n_groups ] ;
51- for i in 0 ..n_samples {
52- let weight = sample_weights [ i ] ;
53- for j in 0 ..n_factors {
54- let id = group_ids [ i * n_factors + j ] ;
55- group_weights [ j * n_groups + id ] += weight ;
56- }
57- }
58- group_weights
91+
92+ // Check elementwise convergence and update checkpoint
93+ fn check_convergence (
94+ & mut self ,
95+ x : & ArrayView1 < f64 > ,
96+ tol : f64 ,
97+ ) {
98+ self . converged = true ; // Innocent until proven guilty
99+ x . iter ( )
100+ . zip ( self . checkpoint . iter_mut ( ) )
101+ . for_each ( | ( & xi , cp ) | {
102+ if ( xi - * cp ) . abs ( ) > tol {
103+ self . converged = false ; // Guilty!
104+ }
105+ * cp = xi ; // Update checkpoint
106+ } ) ;
59107 }
60108}
61109
62- fn demean_impl (
63- x : & ArrayView2 < f64 > ,
64- flist : & ArrayView2 < usize > ,
65- weights : & ArrayView1 < f64 > ,
66- tol : f64 ,
67- maxiter : usize ,
68- ) -> ( Array2 < f64 > , bool ) {
69- let ( n_samples, n_features) = x. dim ( ) ;
70- let n_factors = flist. ncols ( ) ;
71- let n_groups = flist. iter ( ) . cloned ( ) . max ( ) . unwrap ( ) + 1 ;
72-
73- let sample_weights: Vec < f64 > = weights. iter ( ) . cloned ( ) . collect ( ) ;
74- let group_ids: Vec < usize > = flist. iter ( ) . cloned ( ) . collect ( ) ;
75- let group_weights =
76- internal:: calc_group_weights ( & sample_weights, & group_ids, n_samples, n_factors, n_groups) ;
77-
78- let not_converged = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
79-
80- // Precompute slices of group_ids for each factor
81- let group_ids_by_factor: Vec < Vec < usize > > = ( 0 ..n_factors)
82- . map ( |j| {
83- ( 0 ..n_samples)
84- . map ( |i| group_ids[ i * n_factors + j] )
85- . collect ( )
86- } )
87- . collect ( ) ;
88-
89- // Precompute group weight slices
90- let group_weight_slices: Vec < & [ f64 ] > = ( 0 ..n_factors)
91- . map ( |j| & group_weights[ j * n_groups..( j + 1 ) * n_groups] )
92- . collect ( ) ;
93-
94- let process_column = |( k, mut col) : ( usize , ndarray:: ArrayViewMut1 < f64 > ) | {
95- let mut xk_curr: Vec < f64 > = ( 0 ..n_samples) . map ( |i| x[ [ i, k] ] ) . collect ( ) ;
96- let mut xk_prev: Vec < f64 > = xk_curr. iter ( ) . map ( |& v| v - 1.0 ) . collect ( ) ;
97- let mut gw_sums = vec ! [ 0.0 ; n_groups] ;
98-
99- let mut converged = false ;
100- for _ in 0 ..maxiter {
101- for j in 0 ..n_factors {
102- internal:: subtract_weighted_group_mean (
103- & mut xk_curr,
104- & sample_weights,
105- & group_ids_by_factor[ j] ,
106- group_weight_slices[ j] ,
107- & mut gw_sums,
108- ) ;
109- }
110+ // Instead of recomputing the denominators for the weighted group averages every time,
111+ // we'll precompute them and store them in a grid-like structure. The grid will have
112+ // dimensions (m, k) where m is the number of factors and k is the maximum group ID.
113+ struct FactorGroupWeights {
114+ values : Vec < f64 > ,
115+ width : usize ,
116+ }
117+
118+ impl FactorGroupWeights {
119+ fn new ( flist : & ArrayView2 < usize > , weights : & ArrayView1 < f64 > ) -> Self {
120+ let n_samples = flist. nrows ( ) ;
121+ let n_factors = flist. ncols ( ) ;
122+ let width = flist. iter ( ) . max ( ) . unwrap ( ) + 1 ;
110123
111- if internal:: sad_converged ( & xk_curr, & xk_prev, tol) {
112- converged = true ;
113- break ;
124+ let mut values = vec ! [ 0.0 ; n_factors * width] ;
125+ for i in 0 ..n_samples {
126+ let weight = weights[ i] ;
127+ for j in 0 ..n_factors {
128+ let id = flist[ [ i, j] ] ;
129+ values[ j * width + id] += weight;
114130 }
115- xk_prev. copy_from_slice ( & xk_curr) ;
116131 }
117132
118- if !converged {
119- not_converged. fetch_add ( 1 , Ordering :: SeqCst ) ;
133+ Self {
134+ values,
135+ width,
120136 }
121- Zip :: from ( & mut col) . and ( & xk_curr) . for_each ( |col_elm, & val| {
122- * col_elm = val;
123- } ) ;
124- } ;
125-
126- let mut res = Array2 :: < f64 > :: zeros ( ( n_samples, n_features) ) ;
127-
128- res. axis_iter_mut ( ndarray:: Axis ( 1 ) )
129- . into_par_iter ( )
130- . enumerate ( )
131- . for_each ( process_column) ;
137+ }
132138
133- let success = not_converged. load ( Ordering :: SeqCst ) == 0 ;
134- ( res, success)
139+ fn factor_weight_slice ( & self , factor_index : usize ) -> & [ f64 ] {
140+ & self . values [ factor_index * self . width ..( factor_index + 1 ) * self . width ]
141+ }
135142}
136143
137144
@@ -196,7 +203,6 @@ fn demean_impl(
196203/// print(x_demeaned)
197204/// print("Converged:", converged)
198205/// ```
199-
200206#[ pyfunction]
201207#[ pyo3( signature = ( x, flist, weights, tol=1e-8 , maxiter=100_000 ) ) ]
202208pub fn _demean_rs (
@@ -207,13 +213,18 @@ pub fn _demean_rs(
207213 tol : f64 ,
208214 maxiter : usize ,
209215) -> PyResult < ( Py < PyArray2 < f64 > > , bool ) > {
210- let x_arr = x. as_array ( ) ;
211- let flist_arr = flist. as_array ( ) ;
212- let weights_arr = weights. as_array ( ) ;
213-
214- let ( out, success) =
215- py. allow_threads ( || demean_impl ( & x_arr, & flist_arr, & weights_arr, tol, maxiter) ) ;
216-
217- let pyarray = PyArray2 :: from_owned_array ( py, out) ;
218- Ok ( ( pyarray. into_py ( py) , success) )
216+ let mut x_array = x. as_array ( ) . to_owned ( ) ;
217+ let flist_array = flist. as_array ( ) ;
218+ let weights_array = weights. as_array ( ) ;
219+
220+ let converged = demean_impl (
221+ & mut x_array,
222+ flist_array,
223+ weights_array,
224+ tol,
225+ maxiter,
226+ ) ;
227+
228+ let pyarray = PyArray2 :: from_owned_array ( py, x_array) ;
229+ Ok ( ( pyarray. into_py ( py) , converged) )
219230}
0 commit comments