From c91b67633a585b31b5711bb02247c0ab60378ad0 Mon Sep 17 00:00:00 2001 From: Dmitriy Selivanov Date: Mon, 21 Nov 2022 10:36:42 +0800 Subject: [PATCH] optimize R code, avoid double work in transform --- R/model_WRMF.R | 89 +++++++++++++++++++++++++++++--------------------- 1 file changed, 51 insertions(+), 38 deletions(-) diff --git a/R/model_WRMF.R b/R/model_WRMF.R index 5134c90..1ead70c 100644 --- a/R/model_WRMF.R +++ b/R/model_WRMF.R @@ -180,10 +180,15 @@ WRMF = R6::R6Class( RhpcBLASctl::blas_set_num_threads(blas_threads_keep) }) } - + logger$debug("converting input user-item matrix") c_ui = MatrixExtra::as.csc.matrix(x) + # c_ui = as(x, "CsparseMatrix") + logger$debug("pre-processing input") c_ui = private$preprocess(c_ui) - c_iu = MatrixExtra::t_shallow(MatrixExtra::as.csr.matrix(x)) + logger$debug("creating item-user matrix") + c_iu = MatrixExtra::t_shallow(MatrixExtra::as.csr.matrix(c_ui)) + # c_iu = t(c_ui) + logger$debug("created item-user matrix") # store item_ids in order to use them in predict method private$item_ids = colnames(c_ui) @@ -195,7 +200,7 @@ WRMF = R6::R6Class( n_user = nrow(c_ui) n_item = ncol(c_ui) - logger$trace("initializing U") + logger$debug("initializing U") if (private$precision == "double") { private$U = large_rand_matrix(private$rank, n_user) # for item biases @@ -210,7 +215,7 @@ WRMF = R6::R6Class( } if (is.null(self$components)) { - + logger$debug("initializing components") if (private$solver_code == 1L) { ### <- cholesky if (private$precision == "double") { self$components = matrix(0, private$rank, n_item) @@ -331,6 +336,8 @@ WRMF = R6::R6Class( loss_prev_iter = loss } + logger$debug("solver finished") + if (private$precision == "double") data.table::setattr(self$components, "dimnames", list(NULL, colnames(x))) else @@ -348,7 +355,8 @@ WRMF = R6::R6Class( # call extra transform to ensure results from transform() and fit_transform() # are the same (due to avoid_cg, etc) # this adds some extra computation, but not a big deal though - self$transform(x) + # self$transform(x) + private$transform_(c_iu, ...) }, # project new users into latent user space - just make ALS step given fixed items matrix #' @description create user embeddings for new input @@ -368,6 +376,41 @@ WRMF = R6::R6Class( x = MatrixExtra::t_shallow(x) } + x = private$preprocess(x) + + if (self$global_bias != 0. && private$feedback == "explicit") + x@x = x@x - self$global_bias + + private$transform_(x, ...) + } + ), + #### private ----- + private = list( + solver_code = NULL, + cg_steps = NULL, + scorers = NULL, + lambda = NULL, + dynamic_lambda = FALSE, + rank = NULL, + non_negative = NULL, + cnt_u = NULL, + # user factor matrix = rank * n_users + U = NULL, + # item factor matrix = rank * n_items + I = NULL, + # preprocess - transformation of input matrix before passing it to ALS + # for example we can scale each row or apply log() to values + # this is essentially "confidence" transformation from WRMF article + preprocess = NULL, + feedback = NULL, + precision = NULL, + XtX = NULL, + solver = NULL, + with_user_item_bias = NULL, + with_global_bias = NULL, + init_user_item_bias = NULL, + transform_ = function(x, ...) { + logger$debug('starting transform') if (private$feedback == "implicit" ) { logger$trace("WRMF$transform(): calling `RhpcBLASctl::blas_set_num_threads(1)` (to avoid thread contention)") blas_threads_keep = RhpcBLASctl::blas_get_num_procs() @@ -377,11 +420,6 @@ WRMF = R6::R6Class( RhpcBLASctl::blas_set_num_threads(blas_threads_keep) }) } - - x = private$preprocess(x) - if (self$global_bias != 0. && private$feedback == "explicit") - x@x = x@x - self$global_bias - if (private$precision == "double") { res = matrix(0, nrow = private$rank, ncol = ncol(x)) } else { @@ -391,7 +429,7 @@ WRMF = R6::R6Class( if (private$with_user_item_bias) { res[1, ] = if(private$precision == "double") 1.0 else float::fl(1.0) } - + logger$debug('starting transform solver') loss = private$solver( x, self$components, @@ -401,6 +439,7 @@ WRMF = R6::R6Class( cnt_X = private$cnt_u, avoid_cg = TRUE ) + logger$debug('finished transform solver') res = t(res) @@ -408,35 +447,9 @@ WRMF = R6::R6Class( setattr(res, "dimnames", list(colnames(x), NULL)) else setattr(res@Data, "dimnames", list(colnames(x), NULL)) - + logger$debug('finished transform') res } - ), - #### private ----- - private = list( - solver_code = NULL, - cg_steps = NULL, - scorers = NULL, - lambda = NULL, - dynamic_lambda = FALSE, - rank = NULL, - non_negative = NULL, - cnt_u = NULL, - # user factor matrix = rank * n_users - U = NULL, - # item factor matrix = rank * n_items - I = NULL, - # preprocess - transformation of input matrix before passing it to ALS - # for example we can scale each row or apply log() to values - # this is essentially "confidence" transformation from WRMF article - preprocess = NULL, - feedback = NULL, - precision = NULL, - XtX = NULL, - solver = NULL, - with_user_item_bias = NULL, - with_global_bias = NULL, - init_user_item_bias = NULL ) )