Skip to content

Commit

Permalink
re-arrange omp parallel region to make more efficient memory allocations. Related to #72
Browse files Browse the repository at this point in the history
  • Loading branch information
dselivanov committed Nov 19, 2022
1 parent aa9cf58 commit 780fea2
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 19 deletions.
6 changes: 5 additions & 1 deletion R/model_WRMF.R
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,6 @@ WRMF = R6::R6Class(

loss_prev_iter = loss
}

if (private$precision == "double")
data.table::setattr(self$components, "dimnames", list(NULL, colnames(x)))
else
Expand All @@ -341,7 +340,10 @@ WRMF = R6::R6Class(
rank_ = ifelse(private$with_user_item_bias, private$rank - 1L, private$rank)
ridge = fl(diag(x = private$lambda, nrow = rank_, ncol = rank_))
XX = if (private$with_user_item_bias) self$components[-1L, , drop = FALSE] else self$components

RhpcBLASctl::blas_set_num_threads(RhpcBLASctl::get_num_cores())
private$XtX = tcrossprod(XX) + ridge
RhpcBLASctl::blas_set_num_threads(1)

# call extra transform to ensure results from transform() and fit_transform()
# are the same (due to avoid_cg, etc)
Expand Down Expand Up @@ -465,7 +467,9 @@ als_implicit = function(
} else {
XX = X
}
RhpcBLASctl::blas_set_num_threads(RhpcBLASctl::get_num_cores())
XtX = tcrossprod(XX) + ridge
RhpcBLASctl::blas_set_num_threads(1)
}
if (is.null(global_bias_base)) {
global_bias_base = numeric()
Expand Down
42 changes: 24 additions & 18 deletions inst/include/wrmf_implicit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,22 +149,28 @@ T als_implicit(const dMappedCSC& Conf, arma::Mat<T>& X, arma::Mat<T>& Y,
// C = 1 (so we omit multiplication on eye matrix)
// rhs = X * eye * (0 - x_biases) = -X * x_biases
rhs_init *= -x_biases;
}

else {
} else {
rhs_init = - (drop_row<T>(X, is_x_bias_last_row) * (x_biases + global_bias));
}
}

else if (global_bias) {
} else if (global_bias) {
rhs_init = arma::Mat<T>(&global_bias_base[0], rank - (int)with_biases, 1, false, true);
}


double loss = 0;
size_t nc = Conf.n_cols;
#ifdef _OPENMP
#pragma omp parallel for num_threads(n_threads) schedule(dynamic, GRAIN_SIZE) reduction(+:loss)
#pragma omp parallel num_threads(n_threads)
#endif
{
arma::Mat<T> X_nnz;
arma::Mat<T> X_nnz_t;
arma::Col<T> init;
arma::Col<T> Y_new;
arma::Mat<T> rhs;

#ifdef _OPENMP
#pragma omp for schedule(dynamic) reduction(+:loss)
#endif
for (size_t i = 0; i < nc; i++) {
arma::uword p1 = Conf.col_ptrs[i];
Expand All @@ -175,32 +181,32 @@ T als_implicit(const dMappedCSC& Conf, arma::Mat<T>& X, arma::Mat<T>& Y,
const arma::uvec idx = arma::uvec(&Conf.row_indices[p1], p2 - p1, false, true);
arma::Col<T> confidence =
arma::conv_to<arma::Col<T> >::from(arma::vec(&Conf.values[p1], p2 - p1));
arma::Mat<T> X_nnz = X.cols(idx);
arma::Col<T> init = Y.col(i);
X_nnz = X.cols(idx);
// if is_x_bias_last_row == true
// X_nnz = [1, ...]
// if is_x_bias_last_row == false
// X_nnz = [..., 1]
if (with_biases) {
X_nnz = drop_row<T>(X_nnz, is_x_bias_last_row);
init = drop_row<T>(init, !is_x_bias_last_row);
// init = drop_row<T>(init, !is_x_bias_last_row);
}
arma::Col<T> Y_new;

if (solver == CONJUGATE_GRADIENT) {
init = Y.col(i);
if (!with_biases && !global_bias)
Y_new = cg_solver_implicit<T>(X_nnz, confidence, init, cg_steps, XtX);
else if (with_biases)
else if (with_biases) {
init = drop_row<T>(init, !is_x_bias_last_row);
Y_new = cg_solver_implicit_user_item_bias<T>(X_nnz, confidence, init, cg_steps, XtX,
rhs_init, x_biases(idx), global_bias);
else
} else {
Y_new = cg_solver_implicit_global_bias<T>(X_nnz, confidence, init, cg_steps, XtX,
rhs_init, global_bias);

}
} else {
const arma::Mat<T> lhs =
XtX + X_nnz.each_row() % (confidence.t() - 1) * X_nnz.t();
arma::Mat<T> rhs;
XtX + X_nnz.each_row() % (confidence.t() - 1) * X_nnz.t();

if (with_biases) {
// now we need to update rhs with rhs_init and take into account
// items with interactions (p=1)
Expand All @@ -227,7 +233,7 @@ T als_implicit(const dMappedCSC& Conf, arma::Mat<T>& X, arma::Mat<T>& Y,
if (solver == SEQ_COORDINATE_WISE_NNLS) {
Y_new = c_nnls<T>(lhs, rhs, init, SCD_MAX_ITER, SCD_TOL);
} else { // CHOLESKY
Y_new = solve(lhs, rhs, arma::solve_opts::fast);
Y_new = solve(lhs, rhs, arma::solve_opts::fast + arma::solve_opts::likely_sympd);
}
}

Expand Down Expand Up @@ -276,7 +282,7 @@ T als_implicit(const dMappedCSC& Conf, arma::Mat<T>& X, arma::Mat<T>& Y,
}
}
}

}
if (lambda > 0) {
if (with_biases) {
// lambda applied to all learned parameters:
Expand Down

0 comments on commit 780fea2

Please sign in to comment.