@@ -62,7 +62,7 @@ double initialize_biases_explicit(dMappedCSC& ConfCSC, // modified in place
6262 }
6363 item_bias[col] /=
6464 lambda_use + static_cast <T>(ConfCSC.col_ptrs [col + 1 ] - ConfCSC.col_ptrs [col]);
65- if (non_negative) item_bias[col] = std::fmax (0 . , item_bias[col]);
65+ if (non_negative) item_bias[col] = std::fmax ((T) 0 , item_bias[col]);
6666 }
6767
6868 user_bias.zeros ();
@@ -75,7 +75,7 @@ double initialize_biases_explicit(dMappedCSC& ConfCSC, // modified in place
7575 }
7676 user_bias[row] /=
7777 lambda_use + static_cast <T>(ConfCSR.col_ptrs [row + 1 ] - ConfCSR.col_ptrs [row]);
78- if (non_negative) user_bias[row] = std::fmax (0 . , user_bias[row]);
78+ if (non_negative) user_bias[row] = std::fmax ((T) 0 , user_bias[row]);
7979 }
8080 }
8181 return global_bias;
@@ -84,8 +84,7 @@ double initialize_biases_explicit(dMappedCSC& ConfCSC, // modified in place
8484template <class T >
8585double initialize_biases_implicit (dMappedCSC& ConfCSC, dMappedCSC& ConfCSR,
8686 arma::Col<T>& user_bias, arma::Col<T>& item_bias,
87- T lambda, bool calculate_global_bias, bool non_negative,
88- const bool initialize_item_biases)
87+ T lambda, bool calculate_global_bias, bool non_negative)
8988{
9089 double global_bias = 0 ;
9190 if (calculate_global_bias) {
@@ -94,35 +93,74 @@ double initialize_biases_implicit(dMappedCSC& ConfCSC, dMappedCSC& ConfCSR,
9493 }
9594 if (non_negative) global_bias = std::fmax (0 ., global_bias); /* <- should not happen, but just in case */
9695
97- user_bias.zeros ();
98- item_bias.zeros ();
96+ const int n_users = ConfCSR.n_cols ;
97+ const int n_items = ConfCSR.n_rows ;
98+ std::vector<double > user_means (n_users);
99+ std::vector<double > item_means (n_items);
100+ std::vector<double > user_adjustment (n_users);
101+ std::vector<double > item_adjustment (n_items);
102+ for (int row = 0 ; row < n_users; row++) {
103+ if (ConfCSR.col_ptrs [row + 1 ] > ConfCSR.col_ptrs [row]) {
104+ for (int ix = ConfCSR.col_ptrs [row]; ix < ConfCSR.col_ptrs [row + 1 ]; ix++)
105+ user_adjustment[row] += ConfCSR.values [ix];
106+ user_means[row] = user_adjustment[row] / (user_adjustment[row] + (double )(n_items - (ConfCSR.col_ptrs [row + 1 ] - ConfCSR.col_ptrs [row])));
107+ user_adjustment[row] += (double )(n_items - (ConfCSR.col_ptrs [row + 1 ] - ConfCSR.col_ptrs [row]));
108+ user_adjustment[row] /= user_adjustment[row] + lambda;
109+ } else {
110+ user_means[row] = 0 ;
111+ user_adjustment[row] = (double )n_items / ((double )n_items + lambda);
112+ }
113+ }
114+ for (int col = 0 ; col < n_items; col++) {
115+ if (ConfCSC.col_ptrs [col + 1 ] > ConfCSC.col_ptrs [col]) {
116+ for (int ix = ConfCSC.col_ptrs [col]; ix < ConfCSC.col_ptrs [col + 1 ]; ix++)
117+ item_adjustment[col] += ConfCSC.values [ix];
118+ item_means[col] = item_adjustment[col] / (item_adjustment[col] + (double )(n_users - (ConfCSC.col_ptrs [col + 1 ] - ConfCSC.col_ptrs [col])));
119+ item_adjustment[col] += (double )(n_users - (ConfCSC.col_ptrs [col + 1 ] - ConfCSC.col_ptrs [col]));
120+ item_adjustment[col] /= item_adjustment[col] + lambda;
121+ } else {
122+ item_means[col] = 0 ;
123+ item_adjustment[col] = (double )n_users / ((double )n_users + lambda);
124+ }
125+ }
99126
100- double sweight;
101- const double n_items = ConfCSR.n_rows ;
102127
103- for (int row = 0 ; row < ConfCSR.n_cols ; row++) {
104- sweight = 0 ;
105- for (int ix = ConfCSR.col_ptrs [row]; ix < ConfCSR.col_ptrs [row + 1 ]; ix++) {
106- user_bias[row] += ConfCSR.values [ix] + global_bias * (1 . - ConfCSR.values [ix]);
107- sweight += ConfCSR.values [ix] - 1 .;
128+ double bias_mean;
129+ double bias_this;
130+ double wsum;
131+ for (int iter = 0 ; iter < 5 ; iter++) {
132+ /* item biases */
133+ bias_mean = 0 ;
134+ if (iter > 0 ) {
135+ for (int row = 0 ; row < n_users; row++)
136+ bias_mean += (user_bias[row] - bias_mean) / (T)(row + 1 );
108137 }
109- user_bias[row] -= global_bias * n_items;
110- user_bias[row] /= sweight + n_items + lambda;
111- user_bias[row] /= 3 ; /* <- item biases are unaccounted for, don't want to assign everything to the user */
112- if (non_negative) user_bias[row] = std::fmax (0 ., user_bias[row]);
113- }
138+ for (int col = 0 ; col < n_items; col++) {
139+ wsum = n_users;
140+ bias_this = bias_mean;
141+ for (int ix = ConfCSC.col_ptrs [col]; ix < ConfCSC.col_ptrs [col + 1 ]; ix++)
142+ bias_this += ((ConfCSC.values [ix] - 1 ) * (user_bias[ConfCSC.row_indices [ix]] - bias_this)) / (wsum += (ConfCSC.values [ix] - 1 ));
143+ item_bias[col] = (item_means[col] - bias_this - global_bias) * item_adjustment[col];
144+ }
145+
146+ if (non_negative)
147+ for (int col = 0 ; col < n_items; col++) item_bias[col] = std::fmax ((T)0 , item_bias[col]);
114148
115- const double n_users = ConfCSC.n_rows ;
116- for (int col = 0 ; col < ConfCSC.n_cols ; col++) {
117- sweight = 0 ;
118- for (int ix = ConfCSC.col_ptrs [col]; ix < ConfCSC.col_ptrs [col + 1 ]; ix++) {
119- item_bias[col] += ConfCSC.values [ix] + global_bias * (1 . - ConfCSC.values [ix]);
120- sweight += ConfCSC.values [ix] - 1 .;
149+ /* user biases */
150+ bias_mean = 0 ;
151+ for (int col = 0 ; col < n_items; col++)
152+ bias_mean += (item_bias[col] - bias_mean) / (T)(col + 1 );
153+
154+ for (int row = 0 ; row < n_users; row++) {
155+ wsum = n_items;
156+ bias_this = bias_mean;
157+ for (int ix = ConfCSR.col_ptrs [row]; ix < ConfCSR.col_ptrs [row + 1 ]; ix++)
158+ bias_this += ((ConfCSR.values [ix] - 1 ) * (item_bias[ConfCSR.row_indices [ix]] - bias_this)) / (wsum += (ConfCSR.values [ix] - 1 ));
159+ user_bias[row] = (user_means[row] - bias_this - global_bias) * user_adjustment[row];
121160 }
122- item_bias[col] -= global_bias * n_users;
123- item_bias[col] /= sweight + n_users + lambda;
124- item_bias[col] /= 3 ; /* <- user biases are unaccounted for */
125- if (non_negative) item_bias[col] = std::fmax (0 ., item_bias[col]);
161+
162+ if (non_negative)
163+ for (int row = 0 ; row < n_users; row++) user_bias[row] = std::fmax ((T)0 , user_bias[row]);
126164 }
127165
128166 return global_bias;
@@ -134,13 +172,12 @@ double initialize_biases(dMappedCSC& ConfCSC, // modified in place
134172 dMappedCSC& ConfCSR, // modified in place
135173 arma::Col<T>& user_bias, arma::Col<T>& item_bias, T lambda,
136174 bool dynamic_lambda, bool non_negative,
137- bool calculate_global_bias, bool is_explicit_feedback,
138- const bool initialize_item_biases) {
175+ bool calculate_global_bias, bool is_explicit_feedback) {
139176 if (is_explicit_feedback)
140177 return initialize_biases_explicit (ConfCSC, ConfCSR, user_bias, item_bias,
141178 lambda, dynamic_lambda, non_negative,
142179 calculate_global_bias);
143180 else
144181 return initialize_biases_implicit (ConfCSC, ConfCSR, user_bias, item_bias, lambda,
145- calculate_global_bias,non_negative, initialize_item_biases );
182+ calculate_global_bias,non_negative);
146183}
0 commit comments