diff --git a/inst/include/wrmf_implicit.hpp b/inst/include/wrmf_implicit.hpp index 693da67..35b5680 100644 --- a/inst/include/wrmf_implicit.hpp +++ b/inst/include/wrmf_implicit.hpp @@ -248,8 +248,18 @@ T als_implicit(const dMappedCSC& Conf, arma::Mat& X, arma::Mat& Y, Y.unsafe_col(i) = Y_new; } - loss += dot(square(1 - (Y_new.t() * X_nnz)), confidence) + - lambda * arma::dot(Y_new, Y_new); + if (!global_bias && !with_biases) + loss += dot(square(1 - (Y_new.t() * X_nnz)), confidence) + + lambda * arma::dot(Y_new, Y_new); + else if (global_bias && !with_biases) + loss += dot(square((1 - global_bias) - (Y_new.t() * X_nnz)), confidence) + + lambda * arma::dot(Y_new, Y_new); + else if (!global_bias && with_biases) + loss += dot(square(1 - (Y_new.t() * X_nnz) - x_biases(idx).t()), confidence) + + lambda * arma::dot(Y_new, Y_new); + else + loss += dot(square((1 - global_bias) - (Y_new.t() * X_nnz) - x_biases(idx).t()), confidence) + + lambda * arma::dot(Y_new, Y_new); } else { if (with_biases) {