Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor of get_tune_schedule() #978

Open
wants to merge 7 commits into
base: tune-schedule
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion R/0_imports.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ utils::globalVariables(
"rowwise", ".best", "location", "msg", "..object", ".eval_time",
".pred_survival", ".pred_time", ".weight_censored", "nice_time",
"time_metric", ".lower", ".upper", "i", "results", "term", ".alpha",
".method", "old_term", ".lab_pre", ".model", ".num_models", "predict_stage"
".method", "old_term", ".lab_pre", ".model", ".num_models", "model_stage",
"predict_stage"
)
)

Expand Down
199 changes: 88 additions & 111 deletions R/schedule.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,132 +21,109 @@ get_tune_schedule <- function(wflow, param, grid) {
cli::cli_abort("Argument {.arg grid} must be a tibble.")
}

# ----------------------------------------------------------------------------
# Get information on the parameters associated with the supervised model
# Which parameter belongs to which stage and which is a submodel parameter?
param_info <- get_param_info(wflow)

model_spec <- extract_spec_parsnip(wflow)
model_type <- class(model_spec)[1]
model_eng <- model_spec$engine

# Which, if any, is a submodel
model_param <- parsnip::get_from_env(paste0(model_type, "_args")) %>%
dplyr::filter(engine == model_spec$engine) %>%
dplyr::select(name = parsnip, has_submodel)

# Merge the info in with the other parameters
param <- dplyr::left_join(param, model_param, by = "name") %>%
dplyr::mutate(
has_submodel = dplyr::if_else(is.na(has_submodel), FALSE, has_submodel)
)

# ------------------------------------------------------------------------------
# Get tuning parameter IDs for each stage of the workflow

if (any(param$source == "recipe")) {
pre_id <- param$id[param$source == "recipe"]
} else {
pre_id <- character(0)
}
schedule <- schedule_stages(grid, param_info, wflow)

if (any(param$source == "model_spec")) {
model_id <- param$id[param$source == "model_spec"]
sub_id <- param$id[param$source == "model_spec" & param$has_submodel]
non_sub_id <- param$id[param$source == "model_spec" & !param$has_submodel]
} else {
model_id <- sub_id <- non_sub_id <- character(0)
}

if (any(param$source == "tailor")) {
post_id <- param$id[param$source == "tailor"]
og_cls <- class(schedule)
if (nrow(param) == 0) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a little gross to code here but this class structure feels like a good solution (until we know that it is not)

cls <- "resample_schedule"
} else {
post_id <- character(0)
cls <- "grid_schedule"
}

ids <- list(
all = param$id,
pre = pre_id,
# All model param
model = model_id,
fits = c(pre_id, non_sub_id),
sub = sub_id,
non_sub = non_sub_id,
post = post_id
)
# convert to symbols
symbs <- purrr::map(ids, syms)

has_submodels <- length(ids$sub) > 0

# ------------------------------------------------------------------------------
# First collapse the submodel parameters (if any) and postprocessors
# TODO update this with submodels and postproc
if (has_submodels) {
sched <- grid %>%
dplyr::group_nest(!!!symbs$fits, .key = "predict_stage")
# Note 1: multi_predict() should only be triggered for a submodel parameter if
# there are multiple rows in the `predict_stage` list column. i.e. the submodel
# column will always be there but we only multipredict when there are 2+
# values to predict.

# Note 2: The purpose of min_grid() is to determine the minimum grid for
# preprocessing and model parameters to fit. We compute it here and ignore
# any postprocessing tuning parameters (if any). The postprocessing parameters
# will still be in the schedule since we schedule those before the results
# that use min_grid() are merged in. See issue #975 for an example and
# discussion.
first_loop_info <-
min_grid(model_spec,
grid %>%
dplyr::select(-dplyr::any_of(post_id)) %>%
dplyr::distinct())
} else {
sched <- grid %>%
dplyr::group_nest(!!!symbs$fits, .key = "predict_stage")
first_loop_info <- grid %>% dplyr::select(!!!symbs$fits)
if (nrow(grid) == 1) {
cls <- c("single_schedule", cls)
}

first_loop_info <- first_loop_info %>%
dplyr::select(!!!c(symbs$pre, symbs$model)) %>%
dplyr::distinct()
class(schedule) <- c(cls, "schedule", og_cls)

# ------------------------------------------------------------------------------
# Add info on any postprocessing parameters
schedule
}

sched <- sched %>%
dplyr::mutate(
predict_stage = purrr::map(
predict_stage,
~.x %>% dplyr::group_nest(!!!symbs$sub, .key = "post_stage")
)
)
schedule_stages <- function(grid, param_info, wflow) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need param_info as an argument? Since it is created by get_param_info(), we could call that immediately with wflow.

# schedule preprocessing stage and push the rest into a nested tibble
param_pre_stage <- param_info %>%
filter(source == "recipe") %>%
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
filter(source == "recipe") %>%
dplyr::filter(source == "recipe") %>%

pull(id)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pull(id)
dplyr::pull(id)

schedule <- grid %>%
tidyr::nest(.by = all_of(param_pre_stage), .key = "model_stage")

# schedule next stages recursively
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies if this feels like a nit, but I struggled to wrap my head around this code a bit longer than I might've otherwise trying to work this comment into my mental model—is there actually any recursion happening in this code?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A) > hit me with your nits

B) Not a nit, but a valuable comment! You're right, I suppose it's not quite the right word. What would you call it? Something with "nested"? Or just "schedule next stages within `schedule_model_stage_i()`"? I just want to give the reader a heads-up that all stages will be taken care of, even though you can only "see" scheduling the immediate next stage from that point in the code.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I see what you're trying to call out! Maybe "nested iteration"? Or possibly just point out "each model stage will also iterate across predict and post stages"

schedule %>%
mutate(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
mutate(
dplyr::mutate(

model_stage =
purrr::map(
model_stage,
schedule_model_stage_i,
param_info = param_info,
wflow = wflow
)
)
}

# ------------------------------------------------------------------------------
# Merge in submodel fit value (if any)
schedule_model_stage_i <- function(model_stage, param_info, wflow){
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is so so smart and i love it

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Max worked it out and now I got to let it shine 😄

model_param <- param_info %>%
filter(source == "model_spec") %>%
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
filter(source == "model_spec") %>%
dplyr::filter(source == "model_spec") %>%

pull(id)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pull(id)
dplyr::pull(id)

non_submodel_param <- param_info %>%
filter(source == "model_spec" & !has_submodel) %>%
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
filter(source == "model_spec" & !has_submodel) %>%
dplyr::filter(source == "model_spec" & !has_submodel) %>%

pull(id)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pull(id)
dplyr::pull(id)


# schedule model parameters
schedule <- min_model_grid(model_stage, model_param, wflow)

# push remaining paramters into the next stage
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# push remaining paramters into the next stage
# push remaining parameters into the next stage

next_stage <- model_stage %>%
tidyr::nest(.by = all_of(non_submodel_param), .key = "predict_stage")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
tidyr::nest(.by = all_of(non_submodel_param), .key = "predict_stage")
tidyr::nest(.by = dplyr::all_of(non_submodel_param), .key = "predict_stage")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I think that you get it at this point. This is all just protection for being invoked inside of worker processes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for all the code suggestions! I find it totally okay for the reviewer to do one example and then hand it back to whoever opened the PR.

About the worker process: Do we expect to execute get_tune_schedule() in parallel? My understanding is that we first schedule and then send todos off to workers, i.e. we would not expect to call get_tune_schedule() in parallel, no?


schedule <- schedule %>%
dplyr::left_join(next_stage, by = all_of(non_submodel_param))

# schedule next stages recursively
schedule %>%
mutate(
predict_stage =
purrr::map(predict_stage, schedule_predict_stage_i, param_info = param_info)
)
}

loop_names <- names(sched)[names(sched) != "predict_stage"]
if (length(loop_names) > 0) {
# Using `by = character()` to perform a cross join was deprecated
sched <- dplyr::full_join(sched, first_loop_info, by = loop_names)
}
# Compute the minimal grid of model parameters that needs to be fit.
#
# grid:        a data frame of tuning parameter combinations.
# model_param: character vector naming the model parameter columns of `grid`.
# wflow:       a workflow; its parsnip model specification is passed on to
#              min_grid().
#
# Returns the result of min_grid(), restricted to the `model_param` columns.
#
# All dplyr verbs are namespaced so the function does not depend on dplyr
# being attached in the calling session.
min_model_grid <- function(grid, model_param, wflow) {
  # work on only the model parameters
  model_grid <- grid %>%
    dplyr::select(dplyr::all_of(model_param)) %>%
    dplyr::distinct()

  # keep only the model parameter columns of the min_grid() result
  min_grid(
    extract_spec_parsnip(wflow),
    model_grid
  ) %>%
    dplyr::select(dplyr::all_of(model_param))
}

# ------------------------------------------------------------------------------
# Schedule the predict stage: nest the rows of `predict_stage` by the submodel
# parameter(s) so that everything else ends up in a `post_stage` list column.
#
# predict_stage: a data frame holding the parameters remaining at this stage.
# param_info:    a data frame with (at least) the columns `id`, `source`, and
#                `has_submodel`.
#
# Returns `predict_stage` nested by the submodel parameter ids.
schedule_predict_stage_i <- function(predict_stage, param_info) {
  # ids of the model parameters flagged as submodel parameters
  submodel_param <- param_info %>%
    dplyr::filter(source == "model_spec" & has_submodel) %>%
    dplyr::pull(id)

  predict_stage %>%
    tidyr::nest(.by = dplyr::all_of(submodel_param), .key = "post_stage")
}

# ------------------------------------------------------------------------------
get_param_info <- function(wflow) {
param_info <- tune_args(wflow) %>%
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using tune_args() here instead of a parameter set object, due to considerations I've put in #974 (comment)

select(name, id, source)

og_cls <- class(sched)
if (nrow(param) == 0) {
cls <- "resample_schedule"
} else {
cls <- "grid_schedule"
}
model_spec <- extract_spec_parsnip(wflow)
model_type <- class(model_spec)[1]
model_eng <- model_spec$engine

if (nrow(grid) == 1) {
cls <- c("single_schedule", cls)
}
model_param <- parsnip::get_from_env(paste0(model_type, "_args")) %>%
dplyr::filter(engine == model_spec$engine) %>%
dplyr::select(name = parsnip, has_submodel)

class(sched) <- c(cls, "schedule", og_cls)
sched
param_info <- dplyr::left_join(param_info, model_param, by = "name")

param_info
}
62 changes: 33 additions & 29 deletions tests/testthat/helper-tune-package.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
suppressPackageStartupMessages(library(workflows))
suppressPackageStartupMessages(library(parsnip))
suppressPackageStartupMessages(library(recipes))
suppressPackageStartupMessages(library(dials))
suppressPackageStartupMessages(library(tailor))
suppressPackageStartupMessages(library(purrr))
# suppressPackageStartupMessages(library(workflows))
# suppressPackageStartupMessages(library(parsnip))
# suppressPackageStartupMessages(library(recipes))
# suppressPackageStartupMessages(library(dials))
# suppressPackageStartupMessages(library(tailor))
# suppressPackageStartupMessages(library(purrr))


# NOTE namespacing is required to make this file load properly in the testthat machinery


new_rng_snapshots <- utils::compareVersion("3.6.0", as.character(getRversion())) > 0

Expand All @@ -13,18 +17,18 @@ rankdeficient_version <- any(names(formals("predict.lm")) == "rankdeficient")

helper_objects_tune <- function() {
rec_tune_1 <-
recipe(mpg ~ ., data = mtcars) %>%
step_normalize(all_predictors()) %>%
step_pca(all_predictors(), num_comp = tune())
recipes::recipe(mpg ~ ., data = mtcars) %>%
recipes::step_normalize(all_predictors()) %>%
recipes::step_pca(all_predictors(), num_comp = tune())

rec_no_tune_1 <-
recipe(mpg ~ ., data = mtcars) %>%
step_normalize(all_predictors())
recipes::recipe(mpg ~ ., data = mtcars) %>%
recipes::step_normalize(all_predictors())

lm_mod <- linear_reg() %>% set_engine("lm")
lm_mod <- parsnip::linear_reg() %>% parsnip::set_engine("lm")

svm_mod <- svm_rbf(mode = "regression", cost = tune()) %>%
set_engine("kernlab")
svm_mod <- parsnip::svm_rbf(mode = "regression", cost = tune()) %>%
parsnip::set_engine("kernlab")

list(
rec_tune_1 = rec_tune_1,
Expand Down Expand Up @@ -83,33 +87,33 @@ redefer_initialize_catalog <- function(test_env) {

if (rlang::is_installed("splines2")) {
rec_df <-
recipe(mpg ~ ., data = mtcars) %>%
step_corr(all_predictors(), threshold = .1) %>%
step_spline_natural(disp, deg_free = 5)
recipes::recipe(mpg ~ ., data = mtcars) %>%
recipes::step_corr(all_predictors(), threshold = .1) %>%
recipes::step_spline_natural(disp, deg_free = 5)

rec_tune_thrsh_df <-
recipe(mpg ~ ., data = mtcars) %>%
step_corr(all_predictors(), threshold = tune()) %>%
step_spline_natural(disp, deg_free = tune("disp_df"))
recipes::recipe(mpg ~ ., data = mtcars) %>%
recipes::step_corr(all_predictors(), threshold = tune()) %>%
recipes::step_spline_natural(disp, deg_free = tune("disp_df"))
}



mod_tune_bst <- boost_tree(trees = tune(), min_n = tune(), mode = "regression")
mod_tune_rf <- rand_forest(min_n = tune(), mode = "regression")
mod_tune_bst <- parsnip::boost_tree(trees = tune(), min_n = tune(), mode = "regression")
mod_tune_rf <- parsnip::rand_forest(min_n = tune(), mode = "regression")

if (rlang::is_installed("probably")) {

adjust_tune_min <-
tailor() %>%
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that we usually use rec in the name of recipes objects, I would like to advocate for calling tailor objects something with tailor rather than adjust_.

adjust_numeric_range(lower_limit = tune())
tailor::tailor() %>%
tailor::adjust_numeric_range(lower_limit = tune())

adjust_cal_tune_min <-
tailor() %>%
adjust_numeric_calibration(method = "linear") %>%
adjust_numeric_range(lower_limit = tune())
tailor::tailor() %>%
tailor::adjust_numeric_calibration(method = "linear") %>%
tailor::adjust_numeric_range(lower_limit = tune())

adjust_min <-
tailor() %>%
adjust_numeric_range(lower_limit = 0)
tailor::tailor() %>%
tailor::adjust_numeric_range(lower_limit = 0)
}
Loading
Loading