Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable tuning postprocessors #966

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 56 additions & 25 deletions R/grid_code_paths.R
Original file line number Diff line number Diff line change
Expand Up @@ -375,10 +375,12 @@ tune_grid_loop_iter <- function(split,

model_params <- vctrs::vec_slice(params, params$source == "model_spec")
preprocessor_params <- vctrs::vec_slice(params, params$source == "recipe")
postprocessor_params <- vctrs::vec_slice(params, params$source == "tailor")

param_names <- params$id
model_param_names <- model_params$id
preprocessor_param_names <- preprocessor_params$id
postprocessor_param_names <- postprocessor_params$id

# inline rsample::assessment so that we can pass indices to `predict_model()`
assessment_rows <- as.integer(split, data = "assessment")
Expand Down Expand Up @@ -542,34 +544,62 @@ tune_grid_loop_iter <- function(split,
# if the postprocessor does not require training, then `calibration` will
# be NULL and nothing other than the column names is learned from
# `assessment`.
workflow_with_post <- .fit_post(workflow, calibration %||% assessment)

workflow_with_post <- .fit_finalize(workflow_with_post)
# --------------------------------------------------------------------------
# Postprocessor loop
iter_postprocessors <- iter_grid_info_model[[".iter_postprocessor"]]

# run extract function on workflow with trained postprocessor
elt_extract <- .catch_and_log(
extract_details(workflow_with_post, control$extract),
control,
split_labels,
paste(iter_msg_model, "(extracts)"),
bad_only = TRUE,
notes = out_notes
)
elt_extract <- make_extracts(elt_extract, iter_grid, split_labels, .config = iter_config)
out_extracts <- append_extracts(out_extracts, elt_extract)
workflow_pre_and_model <- workflow

# generate predictions on the assessment set from the model and apply the
# post-processor to those predictions to generate updated predictions
iter_predictions <- .catch_and_log(
predict_model(assessment, assessment_rows, workflow_with_post, iter_grid,
metrics, iter_submodels, metrics_info = metrics_info,
eval_time = eval_time),
control,
split_labels,
paste(iter_msg_model, "(predictions with post-processor)"),
bad_only = TRUE,
notes = out_notes
)
for (iter_postprocessor in iter_postprocessors) {
workflow <- workflow_pre_and_model

iter_grid_info_postprocessor <- vctrs::vec_slice(
iter_grid_info_model,
iter_grid_info_model$.iter_postprocessor == iter_postprocessor
)

iter_grid_postprocessor <- iter_grid_info_postprocessor[, postprocessor_param_names]

iter_msg_postprocessor <- iter_grid_postprocessor[[".msg_postprocessor"]]
iter_config <- iter_grid_info_postprocessor[[".iter_config_post"]][[1L]]

workflow <- finalize_workflow_postprocessor(workflow, iter_grid_postprocessor)

workflow_with_post <- .fit_post(workflow, calibration %||% assessment)

workflow_with_post <- .fit_finalize(workflow_with_post)

iter_grid <- dplyr::bind_cols(
iter_grid_preprocessor,
iter_grid_model,
iter_grid_postprocessor
)

# run extract function on workflow with trained postprocessor
elt_extract <- .catch_and_log(
extract_details(workflow_with_post, control$extract),
control,
split_labels,
paste(iter_msg_model, "(extracts)"),
bad_only = TRUE,
notes = out_notes
)
elt_extract <- make_extracts(elt_extract, iter_grid, split_labels, .config = iter_config)
out_extracts <- append_extracts(out_extracts, elt_extract)

# generate predictions on the assessment set from the model and apply the
# post-processor to those predictions to generate updated predictions
iter_predictions <- .catch_and_log(
predict_model(assessment, assessment_rows, workflow_with_post, iter_grid,
metrics, iter_submodels, metrics_info = metrics_info,
eval_time = eval_time),
control,
split_labels,
paste(iter_msg_postprocessor, "(predictions with post-processor)"),
bad_only = TRUE,
notes = out_notes
)

# now, assess those predictions with performance metrics
}
Expand All @@ -595,6 +625,7 @@ tune_grid_loop_iter <- function(split,
control = control,
.config = iter_config_metrics
)
} # postprocessor loop
} # model loop
} # preprocessor loop

Expand Down
97 changes: 90 additions & 7 deletions R/grid_helpers.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@

predict_model <- function(new_data, orig_rows, workflow, grid, metrics,
submodels = NULL, metrics_info, eval_time = NULL) {

model <- extract_fit_parsnip(workflow)

forged <- forge_from_workflow(new_data, workflow)
Expand Down Expand Up @@ -260,6 +259,22 @@ finalize_workflow_preprocessor <- function(workflow, grid_preprocessor) {
workflow
}

#' @export
#' @rdname tune-internal-functions
finalize_workflow_postprocessor <- function(workflow, grid_postprocessor) {
# Already finalized, nothing to tune
if (ncol(grid_postprocessor) == 0L) {
return(workflow)
}

postprocessor <- workflows::extract_postprocessor(workflow)
postprocessor <- merge(postprocessor, grid_postprocessor)$x[[1]]

workflow <- set_workflow_tailor(workflow, postprocessor)

workflow
}

# ------------------------------------------------------------------------------

# For any type of tuning, and for fit-resamples, we generate a unified
Expand Down Expand Up @@ -310,16 +325,20 @@ compute_grid_info <- function(workflow, grid) {
grid <- tibble::as_tibble(grid)

parameters <- hardhat::extract_parameter_set_dials(workflow)
parameters_model <- dplyr::filter(parameters, source == "model_spec")

parameters_preprocessor <- dplyr::filter(parameters, source == "recipe")
parameters_model <- dplyr::filter(parameters, source == "model_spec")
parameters_postprocessor <- dplyr::filter(parameters, source == "tailor")

any_parameters_model <- nrow(parameters_model) > 0
any_parameters_preprocessor <- nrow(parameters_preprocessor) > 0

res <- min_grid(extract_spec_parsnip(workflow), grid)
any_parameters_model <- nrow(parameters_model) > 0
any_parameters_postprocessor <- nrow(parameters_postprocessor) > 0

syms_pre <- rlang::syms(parameters_preprocessor$id)
syms_mod <- rlang::syms(parameters_model$id)
syms_post <- rlang::syms(parameters_postprocessor$id)

res <- min_grid(extract_spec_parsnip(workflow), grid)

# ----------------------------------------------------------------------------
# Create an order of execution to train the preprocessor (if any). This will
Expand All @@ -340,7 +359,7 @@ compute_grid_info <- function(workflow, grid) {
res$.lab_pre <- "Preprocessor1"
}

# Make the label shown in the grid and in loggining
# Make the label shown in the grid and in logging
res$.msg_preprocessor <-
new_msgs_preprocessor(
res$.iter_preprocessor,
Expand All @@ -351,6 +370,17 @@ compute_grid_info <- function(workflow, grid) {
# Now make a similar iterator across models. Conditioning on each unique
# preprocessing candidate set, make an iterator for the model candidate sets
# (if any)
if (any_parameters_postprocessor) {
# Ensure that the submodel trick kicks in by temporarily nesting the
# postprocessor parameters while iterating in the model grid
# TODO: will this introduce issues when there are matching postprocessor
# values across models?
# ... i think we actually want to (temporarily?) situate these as submodels
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My first pass at enabling the submodel trick—nest by the postprocessor and unnest later—but that doesn't quite do the job.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that doesn't quite do the job.

Can you expand on this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See the linked notes above!

res <- tidyr::nest(
res,
.data_post = all_of(parameters_postprocessor$id)
)
}

res <-
res %>%
Expand All @@ -370,9 +400,28 @@ compute_grid_info <- function(workflow, grid) {
n = res$.num_models,
res$.msg_preprocessor)

res %>%
res <- res %>%
dplyr::select(-.num_models) %>%
dplyr::relocate(dplyr::starts_with(".msg"))

# ----------------------------------------------------------------------------
# Finally, iterate across postprocessors. Conditioning on an .iter_config,
# make an iterator for each postprocessing candidate set (if any).
if (!any_parameters_postprocessor) {
return(res)
}

res <-
res %>%
dplyr::group_nest(.iter_config, keep = TRUE) %>%
dplyr::mutate(
data = purrr::map(data, make_iter_postprocessor)
) %>%
tidyr::unnest(cols = data) %>%
dplyr::relocate(dplyr::starts_with(".iter"), dplyr::starts_with(".msg")) %>%
tidyr::unnest(.data_post)

res
}

make_iter_config <- function(dat) {
Expand All @@ -385,6 +434,32 @@ make_iter_config <- function(dat) {
tibble::tibble(.iter_config = .iter_config)
}

make_iter_postprocessor <- function(data) {
data %>%
mutate(
.iter_postprocessor = seq_len(nrow(data)),
.msg_postprocessor = new_msgs_postprocessor(
i = .iter_postprocessor,
n = max(.iter_postprocessor),
msgs_model = .msg_model
),
.iter_config_post = purrr::map2(
.iter_config,
.iter_postprocessor,
make_iter_config_post
)
) %>%
select(-.iter_config)
}

make_iter_config_post <- function(iter_config, iter_postprocessor) {
paste0(
iter_config,
"_Postprocessor",
iter_postprocessor
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still needs a format_with_padding().

)
}

# This generates a "dummy" grid_info object that has the same
# structure as a grid-info object with no tunable recipe parameters
# and no tunable model parameters.
Expand Down Expand Up @@ -420,6 +495,9 @@ new_msgs_preprocessor <- function(i, n) {
new_msgs_model <- function(i, n, msgs_preprocessor) {
paste0(msgs_preprocessor, ", model ", i, "/", n)
}
new_msgs_postprocessor <- function(i, n, msgs_model) {
paste0(msgs_model, ", postprocessor ", i, "/", n)
}

# c(1, 10) -> c("01", "10")
format_with_padding <- function(x) {
Expand Down Expand Up @@ -467,3 +545,8 @@ set_workflow_recipe <- function(workflow, recipe) {
workflow$pre$actions$recipe$recipe <- recipe
workflow
}

set_workflow_tailor <- function(workflow, tailor) {
workflow$post$actions$tailor$tailor <- tailor
workflow
}
21 changes: 20 additions & 1 deletion R/merge.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ merge.model_spec <- function(x, y, ...) {
merger(x, y, ...)
}

#' @export
#' @rdname merge.recipe
merge.tailor <- function(x, y, ...) {
merger(x, y, ...)
}

update_model <- function(grid, object, pset, step_id, nms, ...) {
for (i in nms) {
param_info <- pset %>% dplyr::filter(id == i & source == "model_spec")
Expand Down Expand Up @@ -108,6 +114,16 @@ update_recipe <- function(grid, object, pset, step_id, nms, ...) {
object
}

update_tailor <- function(grid, object, pset, adjustment_id, nms, ...) {
for (i in nms) {
param_info <- pset %>% dplyr::filter(id == i & source == "tailor")
if (nrow(param_info) == 1) {
idx <- which(adjustment_id == param_info$component_id)
object$adjustments[[idx]][["arguments"]][[param_info$name]] <- grid[[i]]
}
}
object
}

merger <- function(x, y, ...) {
if (!is.data.frame(y)) {
Expand All @@ -127,9 +143,12 @@ merger <- function(x, y, ...) {
if (inherits(x, "recipe")) {
updater <- update_recipe
step_ids <- purrr::map_chr(x$steps, ~ .x$id)
} else {
} else if (inherits(x, "model_spec")) {
updater <- update_model
step_ids <- NULL
} else {
updater <- update_tailor
step_ids <- purrr::map_chr(x$adjustments, ~class(.x)[1])
}

if (!any(grid_name %in% pset$id)) {
Expand Down
10 changes: 10 additions & 0 deletions R/min_grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,13 @@ min_grid.pls <- fit_max_value
#' @export min_grid.poisson_reg
#' @rdname min_grid
min_grid.poisson_reg <- fit_max_value


# When `min_grid()` is applied to grids with additional columns for
# postprocessors, we need to nest the postprocessor columns into
# .submodels to effectively enable the submodel trick.
# See: https://gist.github.com/simonpcouch/28d984cdcc3fc6d22ff776ed8740004e
nest_min_grid <- function(min_grid, post_params) {
# TODO
min_grid
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thiiink this will be our entry point; patch min_grid() output by dropping postprocessor values into the nested .submodel structure.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The postprocessing object will only be available from within a workflow.

We could make a workflow method for min_grid() so that we have all of the information coming from a single object.

It might be possible to execute min_grid() on the model spec then generate a separate (but similar) data structure from the postprocessor, then join the two.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if a join, like you suggest, would be the more effective way to make this happen. Right now, my theory is that we can take the output from model_spec min_grid() methods and do some dplyr/tidyr-fu based on the postprocessor parameter names.

This is gnaaaarly, though.

}
Loading
Loading