document should_use_sparsity()
EmilHvitfeldt committed Jan 17, 2025
1 parent 96e1349 commit 3836124
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions R/sparsevctrs.R
@@ -34,8 +34,27 @@ allow_sparse <- function(x) {
   all(res$allow_sparse_x[res$engine == x$engine])
 }
 
-should_use_sparsity <- function(sparsity, model, n_rows) {
-  if (is.null(model) || model == "ranger") {
+# This function was created from the output of a MARS model fit on the
+# simulation data generated in `analysis/time_analysis.R`
+# https://github.com/tidymodels/benchmark-sparsity-threshold
+#
+# The model was extracted using {tidypredict} and hand-tuned for speed.
+#
+# The model was fit on `sparsity`, `engine` and `n_rows` and the outcome was
+# `log_fold`, which is defined as
+# `log(time to fit with dense data / time to fit with sparse data)`.
+# Values above 0 therefore reflect longer fit times for dense data,
+# in which case we want to use sparse data.
+#
+# At this time the only engines that support sparse data are glmnet, LiblineaR,
+# ranger, and xgboost, which is why they are the only ones listed here.
+# This is fine, as this code will only run if `allow_sparse()` returns `TRUE`,
+# which only happens for these engines.
+#
+# Ranger is hard-coded to always return "no" since it appears to use the same
+# algorithm for sparse and dense data, resulting in identical times.
+should_use_sparsity <- function(sparsity, engine, n_rows) {
+  if (is.null(engine) || engine == "ranger") {
     return("no")
   }

@@ -53,7 +72,7 @@ should_use_sparsity <- function(sparsity, model, n_rows) {
     ifelse(n_rows < 8000, 8000 - n_rows, 0) *
       -0.000798307404212627
 
-  if (model == "xgboost") {
+  if (engine == "xgboost") {
     log_fold <- log_fold +
       ifelse(sparsity < 0.984615384615385, 0.984615384615385 - sparsity, 0) *
         0.113098025073806 +
@@ -64,7 +83,7 @@ should_use_sparsity <- function(sparsity, model, n_rows) {
         0.913457808326756
   }
 
-  if (model == "LiblineaR") {
+  if (engine == "LiblineaR") {
     log_fold <- log_fold +
       ifelse(sparsity > 0.836601307189543, sparsity - 0.836601307189543, 0) *
         -5.39592564852111
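
The comment block added in this commit defines `log_fold` as `log(dense fit time / sparse fit time)`, and the fitted formula is built from hinge terms written with `ifelse()`. The following is a minimal R sketch of both ideas, not code from the package: the helper names `log_fold_from_times()` and `hinge_below()` and all timings are hypothetical and chosen only for illustration.

```r
# Hypothetical helper: log fold change between dense and sparse fit times,
# following the definition given in the code comment.
log_fold_from_times <- function(time_dense, time_sparse) {
  log(time_dense / time_sparse)
}

# Hypothetical helper: the hinge term max(0, knot - x), written the same way
# as the ifelse() calls in the diff.
hinge_below <- function(x, knot) {
  ifelse(x < knot, knot - x, 0)
}

# Made-up timings (seconds): dense is slower, so log_fold is positive
# and sparse data would be preferred.
log_fold_from_times(time_dense = 12, time_sparse = 3) > 0

# Made-up timings: dense is faster, so log_fold is negative.
log_fold_from_times(time_dense = 2, time_sparse = 5) < 0

# The hinge is zero at or above the knot and grows linearly below it.
hinge_below(c(1000, 8000, 20000), knot = 8000)

# The same term written with pmax().
pmax(8000 - c(1000, 8000, 20000), 0)
```

Each `ifelse(x < knot, knot - x, 0) * coefficient` line in the function contributes nothing on one side of its knot and a linear adjustment on the other, which is why the formula can be evaluated without loading a fitted model object.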
