Skip to content

Commit

Permalink
optimize R code, avoid double work in transform
Browse files Browse the repository at this point in the history
  • Loading branch information
dselivanov committed Nov 21, 2022
1 parent 780fea2 commit c91b676
Showing 1 changed file with 51 additions and 38 deletions.
89 changes: 51 additions & 38 deletions R/model_WRMF.R
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,15 @@ WRMF = R6::R6Class(
RhpcBLASctl::blas_set_num_threads(blas_threads_keep)
})
}

logger$debug("converting input user-item matrix")
c_ui = MatrixExtra::as.csc.matrix(x)
# c_ui = as(x, "CsparseMatrix")
logger$debug("pre-processing input")
c_ui = private$preprocess(c_ui)
c_iu = MatrixExtra::t_shallow(MatrixExtra::as.csr.matrix(x))
logger$debug("creating item-user matrix")
c_iu = MatrixExtra::t_shallow(MatrixExtra::as.csr.matrix(c_ui))
# c_iu = t(c_ui)
logger$debug("created item-user matrix")
# store item_ids in order to use them in predict method
private$item_ids = colnames(c_ui)

Expand All @@ -195,7 +200,7 @@ WRMF = R6::R6Class(
n_user = nrow(c_ui)
n_item = ncol(c_ui)

logger$trace("initializing U")
logger$debug("initializing U")
if (private$precision == "double") {
private$U = large_rand_matrix(private$rank, n_user)
# for item biases
Expand All @@ -210,7 +215,7 @@ WRMF = R6::R6Class(
}

if (is.null(self$components)) {

logger$debug("initializing components")
if (private$solver_code == 1L) { ### <- cholesky
if (private$precision == "double") {
self$components = matrix(0, private$rank, n_item)
Expand Down Expand Up @@ -331,6 +336,8 @@ WRMF = R6::R6Class(

loss_prev_iter = loss
}
logger$debug("solver finished")

if (private$precision == "double")
data.table::setattr(self$components, "dimnames", list(NULL, colnames(x)))
else
Expand All @@ -348,7 +355,8 @@ WRMF = R6::R6Class(
# run an extra transform pass so that the results of transform() and
# fit_transform() are identical (transform solves with avoid_cg = TRUE, etc.)
# this adds some extra computation, but the cost is acceptable
self$transform(x)
# self$transform(x)
private$transform_(c_iu, ...)
},
# project new users into the latent user space - just perform an ALS step with the item matrix held fixed
#' @description create user embeddings for new input
Expand All @@ -368,6 +376,41 @@ WRMF = R6::R6Class(
x = MatrixExtra::t_shallow(x)
}

x = private$preprocess(x)

if (self$global_bias != 0. && private$feedback == "explicit")
x@x = x@x - self$global_bias

private$transform_(x, ...)
}
),
#### private -----
private = list(
solver_code = NULL,
cg_steps = NULL,
scorers = NULL,
lambda = NULL,
dynamic_lambda = FALSE,
rank = NULL,
non_negative = NULL,
cnt_u = NULL,
# user factor matrix = rank * n_users
U = NULL,
# item factor matrix = rank * n_items
I = NULL,
# preprocess - transformation of input matrix before passing it to ALS
# for example we can scale each row or apply log() to values
# this is essentially "confidence" transformation from WRMF article
preprocess = NULL,
feedback = NULL,
precision = NULL,
XtX = NULL,
solver = NULL,
with_user_item_bias = NULL,
with_global_bias = NULL,
init_user_item_bias = NULL,
transform_ = function(x, ...) {
logger$debug('starting transform')
if (private$feedback == "implicit" ) {
logger$trace("WRMF$transform(): calling `RhpcBLASctl::blas_set_num_threads(1)` (to avoid thread contention)")
blas_threads_keep = RhpcBLASctl::blas_get_num_procs()
Expand All @@ -377,11 +420,6 @@ WRMF = R6::R6Class(
RhpcBLASctl::blas_set_num_threads(blas_threads_keep)
})
}

x = private$preprocess(x)
if (self$global_bias != 0. && private$feedback == "explicit")
x@x = x@x - self$global_bias

if (private$precision == "double") {
res = matrix(0, nrow = private$rank, ncol = ncol(x))
} else {
Expand All @@ -391,7 +429,7 @@ WRMF = R6::R6Class(
if (private$with_user_item_bias) {
res[1, ] = if(private$precision == "double") 1.0 else float::fl(1.0)
}

logger$debug('starting transform solver')
loss = private$solver(
x,
self$components,
Expand All @@ -401,42 +439,17 @@ WRMF = R6::R6Class(
cnt_X = private$cnt_u,
avoid_cg = TRUE
)
logger$debug('finished transform solver')

res = t(res)

if (private$precision == "double")
setattr(res, "dimnames", list(colnames(x), NULL))
else
setattr(res@Data, "dimnames", list(colnames(x), NULL))

logger$debug('finished transform')
res
}
),
#### private -----
private = list(
solver_code = NULL,
cg_steps = NULL,
scorers = NULL,
lambda = NULL,
dynamic_lambda = FALSE,
rank = NULL,
non_negative = NULL,
cnt_u = NULL,
# user factor matrix = rank * n_users
U = NULL,
# item factor matrix = rank * n_items
I = NULL,
# preprocess - transformation of input matrix before passing it to ALS
# for example we can scale each row or apply log() to values
# this is essentially "confidence" transformation from WRMF article
preprocess = NULL,
feedback = NULL,
precision = NULL,
XtX = NULL,
solver = NULL,
with_user_item_bias = NULL,
with_global_bias = NULL,
init_user_item_bias = NULL
)
)

Expand Down

0 comments on commit c91b676

Please sign in to comment.