-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
enable tuning postprocessors #966
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,6 @@ | ||
|
||
predict_model <- function(new_data, orig_rows, workflow, grid, metrics, | ||
submodels = NULL, metrics_info, eval_time = NULL) { | ||
|
||
model <- extract_fit_parsnip(workflow) | ||
|
||
forged <- forge_from_workflow(new_data, workflow) | ||
|
@@ -260,6 +259,22 @@ finalize_workflow_preprocessor <- function(workflow, grid_preprocessor) { | |
workflow | ||
} | ||
|
||
#' @export | ||
#' @rdname tune-internal-functions | ||
finalize_workflow_postprocessor <- function(workflow, grid_postprocessor) { | ||
# Already finalized, nothing to tune | ||
if (ncol(grid_postprocessor) == 0L) { | ||
return(workflow) | ||
} | ||
|
||
postprocessor <- workflows::extract_postprocessor(workflow) | ||
postprocessor <- merge(postprocessor, grid_postprocessor)$x[[1]] | ||
|
||
workflow <- set_workflow_tailor(workflow, postprocessor) | ||
|
||
workflow | ||
} | ||
|
||
# ------------------------------------------------------------------------------ | ||
|
||
# For any type of tuning, and for fit-resamples, we generate a unified | ||
|
@@ -310,16 +325,20 @@ compute_grid_info <- function(workflow, grid) { | |
grid <- tibble::as_tibble(grid) | ||
|
||
parameters <- hardhat::extract_parameter_set_dials(workflow) | ||
parameters_model <- dplyr::filter(parameters, source == "model_spec") | ||
|
||
parameters_preprocessor <- dplyr::filter(parameters, source == "recipe") | ||
parameters_model <- dplyr::filter(parameters, source == "model_spec") | ||
parameters_postprocessor <- dplyr::filter(parameters, source == "tailor") | ||
|
||
any_parameters_model <- nrow(parameters_model) > 0 | ||
any_parameters_preprocessor <- nrow(parameters_preprocessor) > 0 | ||
|
||
res <- min_grid(extract_spec_parsnip(workflow), grid) | ||
any_parameters_model <- nrow(parameters_model) > 0 | ||
any_parameters_postprocessor <- nrow(parameters_postprocessor) > 0 | ||
|
||
syms_pre <- rlang::syms(parameters_preprocessor$id) | ||
syms_mod <- rlang::syms(parameters_model$id) | ||
syms_post <- rlang::syms(parameters_postprocessor$id) | ||
|
||
res <- min_grid(extract_spec_parsnip(workflow), grid) | ||
|
||
# ---------------------------------------------------------------------------- | ||
# Create an order of execution to train the preprocessor (if any). This will | ||
|
@@ -340,7 +359,7 @@ compute_grid_info <- function(workflow, grid) { | |
res$.lab_pre <- "Preprocessor1" | ||
} | ||
|
||
# Make the label shown in the grid and in loggining | ||
# Make the label shown in the grid and in logging | ||
res$.msg_preprocessor <- | ||
new_msgs_preprocessor( | ||
res$.iter_preprocessor, | ||
|
@@ -351,6 +370,17 @@ compute_grid_info <- function(workflow, grid) { | |
# Now make a similar iterator across models. Conditioning on each unique | ||
# preprocessing candidate set, make an iterator for the model candidate sets | ||
# (if any) | ||
if (any_parameters_postprocessor) { | ||
# Ensure that the submodel trick kicks in by temporarily nesting the | ||
# postprocessor parameters while iterating in the model grid | ||
# TODO: will this introduce issues when there are matching postprocessor | ||
# values across models? | ||
# ... i think we actually want to (temporarily?) situate these as submodels | ||
res <- tidyr::nest( | ||
res, | ||
.data_post = all_of(parameters_postprocessor$id) | ||
) | ||
} | ||
|
||
res <- | ||
res %>% | ||
|
@@ -370,9 +400,28 @@ compute_grid_info <- function(workflow, grid) { | |
n = res$.num_models, | ||
res$.msg_preprocessor) | ||
|
||
res %>% | ||
res <- res %>% | ||
dplyr::select(-.num_models) %>% | ||
dplyr::relocate(dplyr::starts_with(".msg")) | ||
|
||
# ---------------------------------------------------------------------------- | ||
# Finally, iterate across postprocessors. Conditioning on an .iter_config, | ||
# make an iterator for each postprocessing candidate set (if any). | ||
if (!any_parameters_postprocessor) { | ||
return(res) | ||
} | ||
|
||
res <- | ||
res %>% | ||
dplyr::group_nest(.iter_config, keep = TRUE) %>% | ||
dplyr::mutate( | ||
data = purrr::map(data, make_iter_postprocessor) | ||
) %>% | ||
tidyr::unnest(cols = data) %>% | ||
dplyr::relocate(dplyr::starts_with(".iter"), dplyr::starts_with(".msg")) %>% | ||
tidyr::unnest(.data_post) | ||
|
||
res | ||
} | ||
|
||
make_iter_config <- function(dat) { | ||
|
@@ -385,6 +434,32 @@ make_iter_config <- function(dat) { | |
tibble::tibble(.iter_config = .iter_config) | ||
} | ||
|
||
make_iter_postprocessor <- function(data) { | ||
data %>% | ||
mutate( | ||
.iter_postprocessor = seq_len(nrow(data)), | ||
.msg_postprocessor = new_msgs_postprocessor( | ||
i = .iter_postprocessor, | ||
n = max(.iter_postprocessor), | ||
msgs_model = .msg_model | ||
), | ||
.iter_config_post = purrr::map2( | ||
.iter_config, | ||
.iter_postprocessor, | ||
make_iter_config_post | ||
) | ||
) %>% | ||
select(-.iter_config) | ||
} | ||
|
||
make_iter_config_post <- function(iter_config, iter_postprocessor) { | ||
paste0( | ||
iter_config, | ||
"_Postprocessor", | ||
iter_postprocessor | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Still needs a |
||
) | ||
} | ||
|
||
# This generates a "dummy" grid_info object that has the same | ||
# structure as a grid-info object with no tunable recipe parameters | ||
# and no tunable model parameters. | ||
|
@@ -420,6 +495,9 @@ new_msgs_preprocessor <- function(i, n) { | |
new_msgs_model <- function(i, n, msgs_preprocessor) { | ||
paste0(msgs_preprocessor, ", model ", i, "/", n) | ||
} | ||
new_msgs_postprocessor <- function(i, n, msgs_model) { | ||
paste0(msgs_model, ", postprocessor ", i, "/", n) | ||
} | ||
|
||
# c(1, 10) -> c("01", "10") | ||
format_with_padding <- function(x) { | ||
|
@@ -467,3 +545,8 @@ set_workflow_recipe <- function(workflow, recipe) { | |
workflow$pre$actions$recipe$recipe <- recipe | ||
workflow | ||
} | ||
|
||
set_workflow_tailor <- function(workflow, tailor) { | ||
workflow$post$actions$tailor$tailor <- tailor | ||
workflow | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -326,3 +326,13 @@ min_grid.pls <- fit_max_value | |
#' @export min_grid.poisson_reg | ||
#' @rdname min_grid | ||
min_grid.poisson_reg <- fit_max_value | ||
|
||
|
||
# When `min_grid()` is applied to grids with additional columns for | ||
# postprocessors, we need to nest the postprocessor columns into | ||
# .submodels to effectively enable the submodel trick. | ||
# See: https://gist.github.com/simonpcouch/28d984cdcc3fc6d22ff776ed8740004e | ||
nest_min_grid <- function(min_grid, post_params) { | ||
# TODO | ||
min_grid | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thiiink this will be our entry point; patch There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The postprocessing object will only be available from within a workflow. We could make a workflow method for It might be possible to execute There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if a join, like you suggest, would be the more effective way to make this happen. Right now, my theory is that we can take the output from This is gnaaaarly, though. |
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My first pass at enabling the submodel trick—nest by the postprocessor and unnest later—but that doesn't quite do the job.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you expand on this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See the linked notes above!