Skip to content

Commit

Permalink
Better initialization for implicit biases (#66)
Browse files (browse the repository at this point in the history)
* better initialization for implicit biases

* remove redundant calculations

* avoid unneeded initialization

* fix failing tests

* fix bias init for rows/columns with all-missing values

* remove unneeded header

* correct formula for bias initialization

* spacing

* another correction
  • Loading branch information
david-cortes authored May 25, 2021
1 parent 84690b8 commit e9e5da0
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 54 deletions.
8 changes: 4 additions & 4 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ als_implicit_float <- function(m_csc_r, X_, Y_, XtX_, lambda, n_threads, solver,
.Call(`_rsparse_als_implicit_float`, m_csc_r, X_, Y_, XtX_, lambda, n_threads, solver, cg_steps, with_biases, is_x_bias_last_row, global_bias, global_bias_base_, initialize_bias_base)
}

initialize_biases_double <- function(m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias = FALSE, is_explicit_feedback = FALSE, initialize_item_biases = FALSE) {
.Call(`_rsparse_initialize_biases_double`, m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias, is_explicit_feedback, initialize_item_biases)
initialize_biases_double <- function(m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias = FALSE, is_explicit_feedback = FALSE) {
.Call(`_rsparse_initialize_biases_double`, m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias, is_explicit_feedback)
}

initialize_biases_float <- function(m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias = FALSE, is_explicit_feedback = FALSE, initialize_item_biases = FALSE) {
.Call(`_rsparse_initialize_biases_float`, m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias, is_explicit_feedback, initialize_item_biases)
initialize_biases_float <- function(m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias = FALSE, is_explicit_feedback = FALSE) {
.Call(`_rsparse_initialize_biases_float`, m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias, is_explicit_feedback)
}

3 changes: 1 addition & 2 deletions R/model_WRMF.R
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,7 @@ WRMF = R6::R6Class(
initialize_biases_double,
initialize_biases_float)
FUN(c_ui, c_iu, user_bias, item_bias, private$lambda, private$dynamic_lambda,
private$non_negative, private$with_global_bias, feedback == "explicit",
private$solver_code != 1)
private$non_negative, private$with_global_bias, feedback == "explicit")
}

self$components = init
Expand Down
99 changes: 68 additions & 31 deletions inst/include/wrmf_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ double initialize_biases_explicit(dMappedCSC& ConfCSC, // modified in place
}
item_bias[col] /=
lambda_use + static_cast<T>(ConfCSC.col_ptrs[col + 1] - ConfCSC.col_ptrs[col]);
if (non_negative) item_bias[col] = std::fmax(0., item_bias[col]);
if (non_negative) item_bias[col] = std::fmax((T)0, item_bias[col]);
}

user_bias.zeros();
Expand All @@ -75,7 +75,7 @@ double initialize_biases_explicit(dMappedCSC& ConfCSC, // modified in place
}
user_bias[row] /=
lambda_use + static_cast<T>(ConfCSR.col_ptrs[row + 1] - ConfCSR.col_ptrs[row]);
if (non_negative) user_bias[row] = std::fmax(0., user_bias[row]);
if (non_negative) user_bias[row] = std::fmax((T)0, user_bias[row]);
}
}
return global_bias;
Expand All @@ -84,8 +84,7 @@ double initialize_biases_explicit(dMappedCSC& ConfCSC, // modified in place
template <class T>
double initialize_biases_implicit(dMappedCSC& ConfCSC, dMappedCSC& ConfCSR,
arma::Col<T>& user_bias, arma::Col<T>& item_bias,
T lambda, bool calculate_global_bias, bool non_negative,
const bool initialize_item_biases)
T lambda, bool calculate_global_bias, bool non_negative)
{
double global_bias = 0;
if (calculate_global_bias) {
Expand All @@ -94,35 +93,74 @@ double initialize_biases_implicit(dMappedCSC& ConfCSC, dMappedCSC& ConfCSR,
}
if (non_negative) global_bias = std::fmax(0., global_bias); /* <- should not happen, but just in case */

user_bias.zeros();
item_bias.zeros();
const int n_users = ConfCSR.n_cols;
const int n_items = ConfCSR.n_rows;
std::vector<double> user_means(n_users);
std::vector<double> item_means(n_items);
std::vector<double> user_adjustment(n_users);
std::vector<double> item_adjustment(n_items);
for (int row = 0; row < n_users; row++) {
if (ConfCSR.col_ptrs[row + 1] > ConfCSR.col_ptrs[row]) {
for (int ix = ConfCSR.col_ptrs[row]; ix < ConfCSR.col_ptrs[row + 1]; ix++)
user_adjustment[row] += ConfCSR.values[ix];
user_means[row] = user_adjustment[row] / (user_adjustment[row] + (double)(n_items - (ConfCSR.col_ptrs[row + 1] - ConfCSR.col_ptrs[row])));
user_adjustment[row] += (double)(n_items - (ConfCSR.col_ptrs[row + 1] - ConfCSR.col_ptrs[row]));
user_adjustment[row] /= user_adjustment[row] + lambda;
} else {
user_means[row] = 0;
user_adjustment[row] = (double)n_items / ((double)n_items + lambda);
}
}
for (int col = 0; col < n_items; col++) {
if (ConfCSC.col_ptrs[col + 1] > ConfCSC.col_ptrs[col]) {
for (int ix = ConfCSC.col_ptrs[col]; ix < ConfCSC.col_ptrs[col + 1]; ix++)
item_adjustment[col] += ConfCSC.values[ix];
item_means[col] = item_adjustment[col] / (item_adjustment[col] + (double)(n_users - (ConfCSC.col_ptrs[col + 1] - ConfCSC.col_ptrs[col])));
item_adjustment[col] += (double)(n_users - (ConfCSC.col_ptrs[col + 1] - ConfCSC.col_ptrs[col]));
item_adjustment[col] /= item_adjustment[col] + lambda;
} else {
item_means[col] = 0;
item_adjustment[col] = (double)n_users / ((double)n_users + lambda);
}
}

double sweight;
const double n_items = ConfCSR.n_rows;

for (int row = 0; row < ConfCSR.n_cols; row++) {
sweight = 0;
for (int ix = ConfCSR.col_ptrs[row]; ix < ConfCSR.col_ptrs[row + 1]; ix++) {
user_bias[row] += ConfCSR.values[ix] + global_bias * (1. - ConfCSR.values[ix]);
sweight += ConfCSR.values[ix] - 1.;
double bias_mean;
double bias_this;
double wsum;
for (int iter = 0; iter < 5; iter++) {
/* item biases */
bias_mean = 0;
if (iter > 0) {
for (int row = 0; row < n_users; row++)
bias_mean += (user_bias[row] - bias_mean) / (T)(row + 1);
}
user_bias[row] -= global_bias * n_items;
user_bias[row] /= sweight + n_items + lambda;
user_bias[row] /= 3; /* <- item biases are unaccounted for, don't want to assign everything to the user */
if (non_negative) user_bias[row] = std::fmax(0., user_bias[row]);
}
for (int col = 0; col < n_items; col++) {
wsum = n_users;
bias_this = bias_mean;
for (int ix = ConfCSC.col_ptrs[col]; ix < ConfCSC.col_ptrs[col + 1]; ix++)
bias_this += ((ConfCSC.values[ix] - 1) * (user_bias[ConfCSC.row_indices[ix]] - bias_this)) / (wsum += (ConfCSC.values[ix] - 1));
item_bias[col] = (item_means[col] - bias_this - global_bias) * item_adjustment[col];
}

if (non_negative)
for (int col = 0; col < n_items; col++) item_bias[col] = std::fmax((T)0, item_bias[col]);

const double n_users = ConfCSC.n_rows;
for (int col = 0; col < ConfCSC.n_cols; col++) {
sweight = 0;
for (int ix = ConfCSC.col_ptrs[col]; ix < ConfCSC.col_ptrs[col + 1]; ix++) {
item_bias[col] += ConfCSC.values[ix] + global_bias * (1. - ConfCSC.values[ix]);
sweight += ConfCSC.values[ix] - 1.;
/* user biases */
bias_mean = 0;
for (int col = 0; col < n_items; col++)
bias_mean += (item_bias[col] - bias_mean) / (T)(col + 1);

for (int row = 0; row < n_users; row++) {
wsum = n_items;
bias_this = bias_mean;
for (int ix = ConfCSR.col_ptrs[row]; ix < ConfCSR.col_ptrs[row + 1]; ix++)
bias_this += ((ConfCSR.values[ix] - 1) * (item_bias[ConfCSR.row_indices[ix]] - bias_this)) / (wsum += (ConfCSR.values[ix] - 1));
user_bias[row] = (user_means[row] - bias_this - global_bias) * user_adjustment[row];
}
item_bias[col] -= global_bias * n_users;
item_bias[col] /= sweight + n_users + lambda;
item_bias[col] /= 3; /* <- user biases are unaccounted for */
if (non_negative) item_bias[col] = std::fmax(0., item_bias[col]);

if (non_negative)
for (int row = 0; row < n_users; row++) user_bias[row] = std::fmax((T)0, user_bias[row]);
}

return global_bias;
Expand All @@ -134,13 +172,12 @@ double initialize_biases(dMappedCSC& ConfCSC, // modified in place
dMappedCSC& ConfCSR, // modified in place
arma::Col<T>& user_bias, arma::Col<T>& item_bias, T lambda,
bool dynamic_lambda, bool non_negative,
bool calculate_global_bias, bool is_explicit_feedback,
const bool initialize_item_biases) {
bool calculate_global_bias, bool is_explicit_feedback) {
if (is_explicit_feedback)
return initialize_biases_explicit(ConfCSC, ConfCSR, user_bias, item_bias,
lambda, dynamic_lambda, non_negative,
calculate_global_bias);
else
return initialize_biases_implicit(ConfCSC, ConfCSR, user_bias, item_bias, lambda,
calculate_global_bias,non_negative, initialize_item_biases);
calculate_global_bias,non_negative);
}
18 changes: 8 additions & 10 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,8 @@ BEGIN_RCPP
END_RCPP
}
// initialize_biases_double
double initialize_biases_double(const Rcpp::S4& m_csc_r, const Rcpp::S4& m_csr_r, arma::Col<double>& user_bias, arma::Col<double>& item_bias, double lambda, bool dynamic_lambda, bool non_negative, bool calculate_global_bias, bool is_explicit_feedback, const bool initialize_item_biases);
RcppExport SEXP _rsparse_initialize_biases_double(SEXP m_csc_rSEXP, SEXP m_csr_rSEXP, SEXP user_biasSEXP, SEXP item_biasSEXP, SEXP lambdaSEXP, SEXP dynamic_lambdaSEXP, SEXP non_negativeSEXP, SEXP calculate_global_biasSEXP, SEXP is_explicit_feedbackSEXP, SEXP initialize_item_biasesSEXP) {
double initialize_biases_double(const Rcpp::S4& m_csc_r, const Rcpp::S4& m_csr_r, arma::Col<double>& user_bias, arma::Col<double>& item_bias, double lambda, bool dynamic_lambda, bool non_negative, bool calculate_global_bias, bool is_explicit_feedback);
RcppExport SEXP _rsparse_initialize_biases_double(SEXP m_csc_rSEXP, SEXP m_csr_rSEXP, SEXP user_biasSEXP, SEXP item_biasSEXP, SEXP lambdaSEXP, SEXP dynamic_lambdaSEXP, SEXP non_negativeSEXP, SEXP calculate_global_biasSEXP, SEXP is_explicit_feedbackSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Expand All @@ -424,14 +424,13 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< bool >::type non_negative(non_negativeSEXP);
Rcpp::traits::input_parameter< bool >::type calculate_global_bias(calculate_global_biasSEXP);
Rcpp::traits::input_parameter< bool >::type is_explicit_feedback(is_explicit_feedbackSEXP);
Rcpp::traits::input_parameter< const bool >::type initialize_item_biases(initialize_item_biasesSEXP);
rcpp_result_gen = Rcpp::wrap(initialize_biases_double(m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias, is_explicit_feedback, initialize_item_biases));
rcpp_result_gen = Rcpp::wrap(initialize_biases_double(m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias, is_explicit_feedback));
return rcpp_result_gen;
END_RCPP
}
// initialize_biases_float
double initialize_biases_float(const Rcpp::S4& m_csc_r, const Rcpp::S4& m_csr_r, Rcpp::S4& user_bias, Rcpp::S4& item_bias, double lambda, bool dynamic_lambda, bool non_negative, bool calculate_global_bias, bool is_explicit_feedback, const bool initialize_item_biases);
RcppExport SEXP _rsparse_initialize_biases_float(SEXP m_csc_rSEXP, SEXP m_csr_rSEXP, SEXP user_biasSEXP, SEXP item_biasSEXP, SEXP lambdaSEXP, SEXP dynamic_lambdaSEXP, SEXP non_negativeSEXP, SEXP calculate_global_biasSEXP, SEXP is_explicit_feedbackSEXP, SEXP initialize_item_biasesSEXP) {
double initialize_biases_float(const Rcpp::S4& m_csc_r, const Rcpp::S4& m_csr_r, Rcpp::S4& user_bias, Rcpp::S4& item_bias, double lambda, bool dynamic_lambda, bool non_negative, bool calculate_global_bias, bool is_explicit_feedback);
RcppExport SEXP _rsparse_initialize_biases_float(SEXP m_csc_rSEXP, SEXP m_csr_rSEXP, SEXP user_biasSEXP, SEXP item_biasSEXP, SEXP lambdaSEXP, SEXP dynamic_lambdaSEXP, SEXP non_negativeSEXP, SEXP calculate_global_biasSEXP, SEXP is_explicit_feedbackSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Expand All @@ -444,8 +443,7 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< bool >::type non_negative(non_negativeSEXP);
Rcpp::traits::input_parameter< bool >::type calculate_global_bias(calculate_global_biasSEXP);
Rcpp::traits::input_parameter< bool >::type is_explicit_feedback(is_explicit_feedbackSEXP);
Rcpp::traits::input_parameter< const bool >::type initialize_item_biases(initialize_item_biasesSEXP);
rcpp_result_gen = Rcpp::wrap(initialize_biases_float(m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias, is_explicit_feedback, initialize_item_biases));
rcpp_result_gen = Rcpp::wrap(initialize_biases_float(m_csc_r, m_csr_r, user_bias, item_bias, lambda, dynamic_lambda, non_negative, calculate_global_bias, is_explicit_feedback));
return rcpp_result_gen;
END_RCPP
}
Expand Down Expand Up @@ -476,8 +474,8 @@ static const R_CallMethodDef CallEntries[] = {
{"_rsparse_als_explicit_float", (DL_FUNC) &_rsparse_als_explicit_float, 11},
{"_rsparse_als_implicit_double", (DL_FUNC) &_rsparse_als_implicit_double, 13},
{"_rsparse_als_implicit_float", (DL_FUNC) &_rsparse_als_implicit_float, 13},
{"_rsparse_initialize_biases_double", (DL_FUNC) &_rsparse_initialize_biases_double, 10},
{"_rsparse_initialize_biases_float", (DL_FUNC) &_rsparse_initialize_biases_float, 10},
{"_rsparse_initialize_biases_double", (DL_FUNC) &_rsparse_initialize_biases_double, 9},
{"_rsparse_initialize_biases_float", (DL_FUNC) &_rsparse_initialize_biases_float, 9},
{NULL, NULL, 0}
};

Expand Down
11 changes: 4 additions & 7 deletions src/wrmf_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,20 @@ double initialize_biases_double(const Rcpp::S4& m_csc_r, const Rcpp::S4& m_csr_r
arma::Col<double>& item_bias, double lambda,
bool dynamic_lambda, bool non_negative,
bool calculate_global_bias = false,
bool is_explicit_feedback = false,
const bool initialize_item_biases = false) {
bool is_explicit_feedback = false) {
dMappedCSC ConfCSC = extract_mapped_csc(m_csc_r);
dMappedCSC ConfCSR = extract_mapped_csc(m_csr_r);
return initialize_biases<double>(ConfCSC, ConfCSR, user_bias, item_bias, lambda,
dynamic_lambda, non_negative, calculate_global_bias,
is_explicit_feedback, initialize_item_biases);
is_explicit_feedback);
}

// [[Rcpp::export]]
double initialize_biases_float(const Rcpp::S4& m_csc_r, const Rcpp::S4& m_csr_r,
Rcpp::S4& user_bias, Rcpp::S4& item_bias, double lambda,
bool dynamic_lambda, bool non_negative,
bool calculate_global_bias = false,
bool is_explicit_feedback = false,
const bool initialize_item_biases = false) {
bool is_explicit_feedback = false) {
dMappedCSC ConfCSC = extract_mapped_csc(m_csc_r);
dMappedCSC ConfCSR = extract_mapped_csc(m_csr_r);

Expand All @@ -32,6 +30,5 @@ double initialize_biases_float(const Rcpp::S4& m_csc_r, const Rcpp::S4& m_csr_r,

return initialize_biases<float>(ConfCSC, ConfCSR, user_bias_arma, item_bias_arma,
lambda, dynamic_lambda, non_negative,
calculate_global_bias, is_explicit_feedback,
initialize_item_biases);
calculate_global_bias, is_explicit_feedback);
}

0 comments on commit e9e5da0

Please sign in to comment.