Skip to content

Commit 780fea2

Browse files
committed
re-arrange omp parallel region to make more efficient memory allocations. Related to #72
1 parent aa9cf58 commit 780fea2

File tree

2 files changed

+29
-19
lines changed

2 files changed

+29
-19
lines changed

R/model_WRMF.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,6 @@ WRMF = R6::R6Class(
331331

332332
loss_prev_iter = loss
333333
}
334-
335334
if (private$precision == "double")
336335
data.table::setattr(self$components, "dimnames", list(NULL, colnames(x)))
337336
else
@@ -341,7 +340,10 @@ WRMF = R6::R6Class(
341340
rank_ = ifelse(private$with_user_item_bias, private$rank - 1L, private$rank)
342341
ridge = fl(diag(x = private$lambda, nrow = rank_, ncol = rank_))
343342
XX = if (private$with_user_item_bias) self$components[-1L, , drop = FALSE] else self$components
343+
344+
RhpcBLASctl::blas_set_num_threads(RhpcBLASctl::get_num_cores())
344345
private$XtX = tcrossprod(XX) + ridge
346+
RhpcBLASctl::blas_set_num_threads(1)
345347

346348
# call extra transform to ensure results from transform() and fit_transform()
347349
# are the same (due to avoid_cg, etc)
@@ -465,7 +467,9 @@ als_implicit = function(
465467
} else {
466468
XX = X
467469
}
470+
RhpcBLASctl::blas_set_num_threads(RhpcBLASctl::get_num_cores())
468471
XtX = tcrossprod(XX) + ridge
472+
RhpcBLASctl::blas_set_num_threads(1)
469473
}
470474
if (is.null(global_bias_base)) {
471475
global_bias_base = numeric()

inst/include/wrmf_implicit.hpp

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -149,22 +149,28 @@ T als_implicit(const dMappedCSC& Conf, arma::Mat<T>& X, arma::Mat<T>& Y,
149149
// C = 1 (so we omit multiplication on eye matrix)
150150
// rhs = X * eye * (0 - x_biases) = -X * x_biases
151151
rhs_init *= -x_biases;
152-
}
153-
154-
else {
152+
} else {
155153
rhs_init = - (drop_row<T>(X, is_x_bias_last_row) * (x_biases + global_bias));
156154
}
157-
}
158-
159-
else if (global_bias) {
155+
} else if (global_bias) {
160156
rhs_init = arma::Mat<T>(&global_bias_base[0], rank - (int)with_biases, 1, false, true);
161157
}
162158

163159

164160
double loss = 0;
165161
size_t nc = Conf.n_cols;
166162
#ifdef _OPENMP
167-
#pragma omp parallel for num_threads(n_threads) schedule(dynamic, GRAIN_SIZE) reduction(+:loss)
163+
#pragma omp parallel num_threads(n_threads)
164+
#endif
165+
{
166+
arma::Mat<T> X_nnz;
167+
arma::Mat<T> X_nnz_t;
168+
arma::Col<T> init;
169+
arma::Col<T> Y_new;
170+
arma::Mat<T> rhs;
171+
172+
#ifdef _OPENMP
173+
#pragma omp for schedule(dynamic) reduction(+:loss)
168174
#endif
169175
for (size_t i = 0; i < nc; i++) {
170176
arma::uword p1 = Conf.col_ptrs[i];
@@ -175,32 +181,32 @@ T als_implicit(const dMappedCSC& Conf, arma::Mat<T>& X, arma::Mat<T>& Y,
175181
const arma::uvec idx = arma::uvec(&Conf.row_indices[p1], p2 - p1, false, true);
176182
arma::Col<T> confidence =
177183
arma::conv_to<arma::Col<T> >::from(arma::vec(&Conf.values[p1], p2 - p1));
178-
arma::Mat<T> X_nnz = X.cols(idx);
179-
arma::Col<T> init = Y.col(i);
184+
X_nnz = X.cols(idx);
180185
// if is_x_bias_last_row == true
181186
// X_nnz = [1, ...]
182187
// if is_x_bias_last_row == false
183188
// X_nnz = [..., 1]
184189
if (with_biases) {
185190
X_nnz = drop_row<T>(X_nnz, is_x_bias_last_row);
186-
init = drop_row<T>(init, !is_x_bias_last_row);
191+
// init = drop_row<T>(init, !is_x_bias_last_row);
187192
}
188-
arma::Col<T> Y_new;
189193

190194
if (solver == CONJUGATE_GRADIENT) {
195+
init = Y.col(i);
191196
if (!with_biases && !global_bias)
192197
Y_new = cg_solver_implicit<T>(X_nnz, confidence, init, cg_steps, XtX);
193-
else if (with_biases)
198+
else if (with_biases) {
199+
init = drop_row<T>(init, !is_x_bias_last_row);
194200
Y_new = cg_solver_implicit_user_item_bias<T>(X_nnz, confidence, init, cg_steps, XtX,
195201
rhs_init, x_biases(idx), global_bias);
196-
else
202+
} else {
197203
Y_new = cg_solver_implicit_global_bias<T>(X_nnz, confidence, init, cg_steps, XtX,
198204
rhs_init, global_bias);
199-
205+
}
200206
} else {
201207
const arma::Mat<T> lhs =
202-
XtX + X_nnz.each_row() % (confidence.t() - 1) * X_nnz.t();
203-
arma::Mat<T> rhs;
208+
XtX + X_nnz.each_row() % (confidence.t() - 1) * X_nnz.t();
209+
204210
if (with_biases) {
205211
// now we need to update rhs with rhs_init and take into account
206212
// items with interactions (p=1)
@@ -227,7 +233,7 @@ T als_implicit(const dMappedCSC& Conf, arma::Mat<T>& X, arma::Mat<T>& Y,
227233
if (solver == SEQ_COORDINATE_WISE_NNLS) {
228234
Y_new = c_nnls<T>(lhs, rhs, init, SCD_MAX_ITER, SCD_TOL);
229235
} else { // CHOLESKY
230-
Y_new = solve(lhs, rhs, arma::solve_opts::fast);
236+
Y_new = solve(lhs, rhs, arma::solve_opts::fast + arma::solve_opts::likely_sympd);
231237
}
232238
}
233239

@@ -276,7 +282,7 @@ T als_implicit(const dMappedCSC& Conf, arma::Mat<T>& X, arma::Mat<T>& Y,
276282
}
277283
}
278284
}
279-
285+
}
280286
if (lambda > 0) {
281287
if (with_biases) {
282288
// lambda applied to all learned parameters:

0 commit comments

Comments
 (0)