-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor of get_tune_schedule()
#978
base: tune-schedule
Are you sure you want to change the base?
Changes from all commits
2fbcbfb
108d94f
f39824b
88279fe
64aadb3
450a182
ee76649
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -21,132 +21,109 @@ get_tune_schedule <- function(wflow, param, grid) { | |||||
cli::cli_abort("Argument {.arg grid} must be a tibble.") | ||||||
} | ||||||
|
||||||
# ---------------------------------------------------------------------------- | ||||||
# Get information on the parameters associated with the supervised model | ||||||
# Which parameter belongs to which stage and which is a submodel parameter? | ||||||
param_info <- get_param_info(wflow) | ||||||
|
||||||
model_spec <- extract_spec_parsnip(wflow) | ||||||
model_type <- class(model_spec)[1] | ||||||
model_eng <- model_spec$engine | ||||||
|
||||||
# Which, if any, is a submodel | ||||||
model_param <- parsnip::get_from_env(paste0(model_type, "_args")) %>% | ||||||
dplyr::filter(engine == model_spec$engine) %>% | ||||||
dplyr::select(name = parsnip, has_submodel) | ||||||
|
||||||
# Merge the info in with the other parameters | ||||||
param <- dplyr::left_join(param, model_param, by = "name") %>% | ||||||
dplyr::mutate( | ||||||
has_submodel = dplyr::if_else(is.na(has_submodel), FALSE, has_submodel) | ||||||
) | ||||||
|
||||||
# ------------------------------------------------------------------------------ | ||||||
# Get tuning parameter IDs for each stage of the workflow | ||||||
|
||||||
if (any(param$source == "recipe")) { | ||||||
pre_id <- param$id[param$source == "recipe"] | ||||||
} else { | ||||||
pre_id <- character(0) | ||||||
} | ||||||
schedule <- schedule_stages(grid, param_info, wflow) | ||||||
|
||||||
if (any(param$source == "model_spec")) { | ||||||
model_id <- param$id[param$source == "model_spec"] | ||||||
sub_id <- param$id[param$source == "model_spec" & param$has_submodel] | ||||||
non_sub_id <- param$id[param$source == "model_spec" & !param$has_submodel] | ||||||
} else { | ||||||
model_id <- sub_id <- non_sub_id <- character(0) | ||||||
} | ||||||
|
||||||
if (any(param$source == "tailor")) { | ||||||
post_id <- param$id[param$source == "tailor"] | ||||||
og_cls <- class(schedule) | ||||||
if (nrow(param) == 0) { | ||||||
cls <- "resample_schedule" | ||||||
} else { | ||||||
post_id <- character(0) | ||||||
cls <- "grid_schedule" | ||||||
} | ||||||
|
||||||
ids <- list( | ||||||
all = param$id, | ||||||
pre = pre_id, | ||||||
# All model param | ||||||
model = model_id, | ||||||
fits = c(pre_id, non_sub_id), | ||||||
sub = sub_id, | ||||||
non_sub = non_sub_id, | ||||||
post = post_id | ||||||
) | ||||||
# convert to symbols | ||||||
symbs <- purrr::map(ids, syms) | ||||||
|
||||||
has_submodels <- length(ids$sub) > 0 | ||||||
|
||||||
# ------------------------------------------------------------------------------ | ||||||
# First collapse the submodel parameters (if any) and postprocessors | ||||||
# TODO update this will submodels and postproc | ||||||
if (has_submodels) { | ||||||
sched <- grid %>% | ||||||
dplyr::group_nest(!!!symbs$fits, .key = "predict_stage") | ||||||
# Note 1: multi_predict() should only be triggered for a submodel parameter if | ||||||
# there are multiple rows in the `predict_stage` list column. i.e. the submodel | ||||||
# column will always be there but we only multipredict when there are 2+ | ||||||
# values to predict. | ||||||
|
||||||
# Note 2: The purpose of min_grid() is to determine the minimum grid for | ||||||
# preprocessing and model parameters to fit. We compute it here and ignore | ||||||
# any postprocessing tuning parmeters (if any). The postprocessing parameters | ||||||
# will still be in the schedule since we schedule those before the results | ||||||
# that use min_grid() are merged in. See issue #975 for an example and | ||||||
# discussion. | ||||||
first_loop_info <- | ||||||
min_grid(model_spec, | ||||||
grid %>% | ||||||
dplyr::select(-dplyr::any_of(post_id)) %>% | ||||||
dplyr::distinct()) | ||||||
} else { | ||||||
sched <- grid %>% | ||||||
dplyr::group_nest(!!!symbs$fits, .key = "predict_stage") | ||||||
first_loop_info <- grid %>% dplyr::select(!!!symbs$fits) | ||||||
if (nrow(grid) == 1) { | ||||||
cls <- c("single_schedule", cls) | ||||||
} | ||||||
|
||||||
first_loop_info <- first_loop_info %>% | ||||||
dplyr::select(!!!c(symbs$pre, symbs$model)) %>% | ||||||
dplyr::distinct() | ||||||
class(schedule) <- c(cls, "schedule", og_cls) | ||||||
|
||||||
# ------------------------------------------------------------------------------ | ||||||
# Add info an any postprocessing parameters | ||||||
schedule | ||||||
} | ||||||
|
||||||
sched <- sched %>% | ||||||
dplyr::mutate( | ||||||
predict_stage = purrr::map( | ||||||
predict_stage, | ||||||
~.x %>% dplyr::group_nest(!!!symbs$sub, .key = "post_stage") | ||||||
) | ||||||
) | ||||||
schedule_stages <- function(grid, param_info, wflow) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need |
||||||
# schedule preprocessing stage and push the rest into a nested tibble | ||||||
param_pre_stage <- param_info %>% | ||||||
filter(source == "recipe") %>% | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
pull(id) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
schedule <- grid %>% | ||||||
tidyr::nest(.by = all_of(param_pre_stage), .key = "model_stage") | ||||||
|
||||||
# schedule next stages recursively | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Apologies if this feels like a nit, but I struggled to wrap my head around this code a bit longer than I might've otherwise trying to work this comment into my mental model—is there actually any recursion happening in this code? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A) > hit me with your nits B) Not a nit, but a valuable comment! You're right, I supposed it's not quite the right word. What would you call it? Something with "nested"? Or just "schedule next stages within `schedule_model_stage_i()"? I just want to give the reader a heads-up that all stages will be taken care of, even though you can only "see" scheduling the immediate next stage from that point in the code. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I see what you're trying to call out! Maybe "nested iteration"? Or possibly just point out "each model stage will also iterate across predict and post stages" |
||||||
schedule %>% | ||||||
mutate( | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
model_stage = | ||||||
purrr::map( | ||||||
model_stage, | ||||||
schedule_model_stage_i, | ||||||
param_info = param_info, | ||||||
wflow = wflow | ||||||
) | ||||||
) | ||||||
} | ||||||
|
||||||
# ------------------------------------------------------------------------------ | ||||||
# Merge in submodel fit value (if any) | ||||||
schedule_model_stage_i <- function(model_stage, param_info, wflow){ | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is so so smart and i love it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Max worked it out and now I got to let it shine 😄 |
||||||
model_param <- param_info %>% | ||||||
filter(source == "model_spec") %>% | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
pull(id) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
non_submodel_param <- param_info %>% | ||||||
filter(source == "model_spec" & !has_submodel) %>% | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
pull(id) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
# schedule model parameters | ||||||
schedule <- min_model_grid(model_stage, model_param, wflow) | ||||||
|
||||||
# push remaining paramters into the next stage | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
next_stage <- model_stage %>% | ||||||
tidyr::nest(.by = all_of(non_submodel_param), .key = "predict_stage") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok, I think that you get it at this point. This is all just protection for being invoked inside of worker processes. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for all the code suggestions! I find it totally okay for the reviewer to do one example and then hand it back to whoever opened the PR. About the worker process: Do we expect to execute |
||||||
|
||||||
schedule <- schedule %>% | ||||||
dplyr::left_join(next_stage, by = all_of(non_submodel_param)) | ||||||
|
||||||
# schedule next stages recursively | ||||||
schedule %>% | ||||||
mutate( | ||||||
predict_stage = | ||||||
purrr::map(predict_stage, schedule_predict_stage_i, param_info = param_info) | ||||||
) | ||||||
} | ||||||
|
||||||
loop_names <- names(sched)[names(sched) != "predict_stage"] | ||||||
if (length(loop_names) > 0) { | ||||||
# Using `by = character()` to perform a cross join was deprecated | ||||||
sched <- dplyr::full_join(sched, first_loop_info, by = loop_names) | ||||||
} | ||||||
min_model_grid <- function(grid, model_param, wflow){ | ||||||
# work on only the model parameters | ||||||
model_grid <- grid %>% | ||||||
select(all_of(model_param)) %>% | ||||||
dplyr::distinct() | ||||||
|
||||||
min_grid( | ||||||
extract_spec_parsnip(wflow), | ||||||
model_grid | ||||||
) %>% | ||||||
select(all_of(model_param)) | ||||||
} | ||||||
|
||||||
# ------------------------------------------------------------------------------ | ||||||
# Now collapse over the preprocessor for conditional execution | ||||||
schedule_predict_stage_i <- function(predict_stage, param_info) { | ||||||
submodel_param <- param_info %>% | ||||||
filter(source == "model_spec" & has_submodel) %>% | ||||||
pull(id) | ||||||
|
||||||
sched <- sched %>% dplyr::group_nest(!!!symbs$pre, .key = "model_stage") | ||||||
predict_stage %>% | ||||||
tidyr::nest(.by = all_of(submodel_param), .key = "post_stage") | ||||||
} | ||||||
|
||||||
# ------------------------------------------------------------------------------ | ||||||
get_param_info <- function(wflow) { | ||||||
param_info <- tune_args(wflow) %>% | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using |
||||||
select(name, id, source) | ||||||
|
||||||
og_cls <- class(sched) | ||||||
if (nrow(param) == 0) { | ||||||
cls <- "resample_schedule" | ||||||
} else { | ||||||
cls <- "grid_schedule" | ||||||
} | ||||||
model_spec <- extract_spec_parsnip(wflow) | ||||||
model_type <- class(model_spec)[1] | ||||||
model_eng <- model_spec$engine | ||||||
|
||||||
if (nrow(grid) == 1) { | ||||||
cls <- c("single_schedule", cls) | ||||||
} | ||||||
model_param <- parsnip::get_from_env(paste0(model_type, "_args")) %>% | ||||||
dplyr::filter(engine == model_spec$engine) %>% | ||||||
dplyr::select(name = parsnip, has_submodel) | ||||||
|
||||||
class(sched) <- c(cls, "schedule", og_cls) | ||||||
sched | ||||||
param_info <- dplyr::left_join(param_info, model_param, by = "name") | ||||||
|
||||||
param_info | ||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,13 @@ | ||
suppressPackageStartupMessages(library(workflows)) | ||
suppressPackageStartupMessages(library(parsnip)) | ||
suppressPackageStartupMessages(library(recipes)) | ||
suppressPackageStartupMessages(library(dials)) | ||
suppressPackageStartupMessages(library(tailor)) | ||
suppressPackageStartupMessages(library(purrr)) | ||
# suppressPackageStartupMessages(library(workflows)) | ||
# suppressPackageStartupMessages(library(parsnip)) | ||
# suppressPackageStartupMessages(library(recipes)) | ||
# suppressPackageStartupMessages(library(dials)) | ||
# suppressPackageStartupMessages(library(tailor)) | ||
# suppressPackageStartupMessages(library(purrr)) | ||
|
||
|
||
# NOTE namsespacing is required to make this file load properly in the testthat machinery | ||
|
||
|
||
new_rng_snapshots <- utils::compareVersion("3.6.0", as.character(getRversion())) > 0 | ||
|
||
|
@@ -13,18 +17,18 @@ rankdeficient_version <- any(names(formals("predict.lm")) == "rankdeficient") | |
|
||
helper_objects_tune <- function() { | ||
rec_tune_1 <- | ||
recipe(mpg ~ ., data = mtcars) %>% | ||
step_normalize(all_predictors()) %>% | ||
step_pca(all_predictors(), num_comp = tune()) | ||
recipes::recipe(mpg ~ ., data = mtcars) %>% | ||
recipes::step_normalize(all_predictors()) %>% | ||
recipes::step_pca(all_predictors(), num_comp = tune()) | ||
|
||
rec_no_tune_1 <- | ||
recipe(mpg ~ ., data = mtcars) %>% | ||
step_normalize(all_predictors()) | ||
recipes::recipe(mpg ~ ., data = mtcars) %>% | ||
recipes::step_normalize(all_predictors()) | ||
|
||
lm_mod <- linear_reg() %>% set_engine("lm") | ||
lm_mod <- parsnip::linear_reg() %>% parsnip::set_engine("lm") | ||
|
||
svm_mod <- svm_rbf(mode = "regression", cost = tune()) %>% | ||
set_engine("kernlab") | ||
svm_mod <- parsnip::svm_rbf(mode = "regression", cost = tune()) %>% | ||
parsnip::set_engine("kernlab") | ||
|
||
list( | ||
rec_tune_1 = rec_tune_1, | ||
|
@@ -83,33 +87,33 @@ redefer_initialize_catalog <- function(test_env) { | |
|
||
if (rlang::is_installed("splines2")) { | ||
rec_df <- | ||
recipe(mpg ~ ., data = mtcars) %>% | ||
step_corr(all_predictors(), threshold = .1) %>% | ||
step_spline_natural(disp, deg_free = 5) | ||
recipes::recipe(mpg ~ ., data = mtcars) %>% | ||
recipes::step_corr(all_predictors(), threshold = .1) %>% | ||
recipes::step_spline_natural(disp, deg_free = 5) | ||
|
||
rec_tune_thrsh_df <- | ||
recipe(mpg ~ ., data = mtcars) %>% | ||
step_corr(all_predictors(), threshold = tune()) %>% | ||
step_spline_natural(disp, deg_free = tune("disp_df")) | ||
recipes::recipe(mpg ~ ., data = mtcars) %>% | ||
recipes::step_corr(all_predictors(), threshold = tune()) %>% | ||
recipes::step_spline_natural(disp, deg_free = tune("disp_df")) | ||
} | ||
|
||
|
||
|
||
mod_tune_bst <- boost_tree(trees = tune(), min_n = tune(), mode = "regression") | ||
mod_tune_rf <- rand_forest(min_n = tune(), mode = "regression") | ||
mod_tune_bst <- parsnip::boost_tree(trees = tune(), min_n = tune(), mode = "regression") | ||
mod_tune_rf <- parsnip::rand_forest(min_n = tune(), mode = "regression") | ||
|
||
if (rlang::is_installed("probably")) { | ||
|
||
adjust_tune_min <- | ||
tailor() %>% | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given that we usually use |
||
adjust_numeric_range(lower_limit = tune()) | ||
tailor::tailor() %>% | ||
tailor::adjust_numeric_range(lower_limit = tune()) | ||
|
||
adjust_cal_tune_min <- | ||
tailor() %>% | ||
adjust_numeric_calibration(method = "linear") %>% | ||
adjust_numeric_range(lower_limit = tune()) | ||
tailor::tailor() %>% | ||
tailor::adjust_numeric_calibration(method = "linear") %>% | ||
tailor::adjust_numeric_range(lower_limit = tune()) | ||
|
||
adjust_min <- | ||
tailor() %>% | ||
adjust_numeric_range(lower_limit = 0) | ||
tailor::tailor() %>% | ||
tailor::adjust_numeric_range(lower_limit = 0) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a little gross to code here but this class structure feels like a good solution (until we know that it is not)