Skip to content

Commit

Permalink
add temp loop code
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Dec 5, 2024
1 parent 5c90128 commit 8e51568
Show file tree
Hide file tree
Showing 4 changed files with 371 additions and 108 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,15 @@ export(fit_resamples)
export(forge_from_workflow)
export(get_metric_time)
export(get_tune_colors)
export(get_tune_schedule)
export(initialize_catalog)
export(int_pctl)
export(is_preprocessor)
export(is_recipe)
export(is_workflow)
export(last_fit)
export(load_pkgs)
export(loopy)
export(maybe_choose_eval_time)
export(message_wrap)
export(metrics_info)
Expand Down
108 changes: 0 additions & 108 deletions R/grid_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -425,114 +425,6 @@ format_with_padding <- function(x) {

# ------------------------------------------------------------------------------

get_tune_schedule <- function(wflow, param, grid) {

# ----------------------------------------------------------------------------
# Get information on the parameters associated with the supervised model

model_spec <- extract_spec_parsnip(wflow)
model_type <- class(model_spec)[1]
model_eng <- model_spec$engine

# Which, if any, is a submodel
model_param <-
parsnip::get_from_env(paste0(model_type, "_args")) %>%
dplyr::filter(engine == model_spec$engine) %>%
dplyr::select(name = parsnip, has_submodel)

# Merge the info in with the other parameters
param <-
dplyr::left_join(param, model_param, by = "name") %>%
dplyr::mutate(has_submodel = if_else(is.na(has_submodel), FALSE, has_submodel))

# ------------------------------------------------------------------------------
# Get tuning parameter IDs for each stage of the workflow

if (any(param$source == "recipe")) {
pre_id <- param$id[param$source == "recipe"]
} else {
pre_id <- character(0)
}

if (any(param$source == "model_spec")) {
model_id <- param$id[param$source == "model_spec"]
sub_id <- param$id[param$source == "model_spec" & param$has_submodel]
non_sub_id <- param$id[param$source == "model_spec" & !param$has_submodel]
} else {
model_id <- sub_id <- non_sub_id <- character(0)
}

if (any(param$source == "tailor")) {
post_id <- param$id[param$source == "tailor"]
} else {
post_id <- character(0)
}

ids <- list(
all = param$id,
pre = pre_id,
# All model param
model = model_id,
fits = c(pre_id, non_sub_id),
sub = sub_id,
non_sub = non_sub_id,
post = post_id
)
# convert to symbols
symbs <- purrr::map(ids, syms)

has_submodels <- length(ids$sub) > 0

# ------------------------------------------------------------------------------
# First collapse the submodel parameters (if any)

if (has_submodels) {
sched <-
grid %>%
dplyr::group_nest(!!!symbs$fits, .key = "predict_stage")
# Note: multi_predict() should only be triggered for a submodel parameter if
# there are multiple rows in the `predict_stage` list column.
first_loop_info <- min_grid(model_spec, grid)
} else {
sched <-
grid %>%
dplyr::group_nest(!!!symbs$all, .key = "predict_stage")
first_loop_info <- grid
}

first_loop_info <-
first_loop_info %>%
dplyr::select(!!!c(symbs$pre, symbs$model))

# ------------------------------------------------------------------------------
# Add info an any postprocessing parameters

sched <-
sched %>%
dplyr::mutate(
predict_stage =
purrr::map(
predict_stage,
~ .x %>% dplyr::group_nest(!!!symbs$sub, .key = "post_stage")))

# ------------------------------------------------------------------------------
# Merge in submodel fit value (if any)

loop_names <- names(sched)[names(sched) != "predict_stage"]
sched <- dplyr::full_join(sched, first_loop_info, by = loop_names)

# ------------------------------------------------------------------------------
# Now collapse over the preprocessor for conditional execution

sched <- sched %>% dplyr::group_nest(!!!symbs$pre, .key = "second_stage")

# ------------------------------------------------------------------------------

sched
}

# ------------------------------------------------------------------------------

has_preprocessor <- function(workflow) {
has_preprocessor_recipe(workflow) ||
has_preprocessor_formula(workflow) ||
Expand Down
261 changes: 261 additions & 0 deletions R/loopy.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
# This code is a working version to fingure some things out; it probably
# won't make it into main as-is

no_stage <- function(x) {
stages <- c("model_stage", "predict_stage", "post_stage")
x[, !(names(x) %in% stages)]
}

text_param <- function(x) {
x <- no_stage(x)
x <- as.list(x)
x <- purrr::map_chr(x, ~ format(.x, digits = 3))
x <- paste0(names(x), ": ", x)
cli::format_inline("{x}")
}

has_pre_param <- function(x) {
any(names(x) != "model_stage")
}

has_mod_param <- function(x) {
any(names(x) != "predict_stage")
}

has_sub_param <- function(x) {
not_post_list <- names(x) != "post_stage"
has_param_col <- any(not_post_list)
if (!has_param_col) {
return(FALSE)
}
param_col_nm <- names(x)[not_post_list]
param_col <- x[[param_col_nm]]
two_plus_vals <- length(param_col) > 1
two_plus_vals
}

# from workflows
has_tailor <- function(x) {
"tailor" %in% names(x$post$actions)
}
#
has_tailor_tuned <- function(x) {
if (!has_tailor(x)) {
res <- FALSE
} else {
res <- any(tune_args(x)$source == "tailor")
}
res
}
has_tailor_estimated <- function(x) {
if (!has_tailor(x)) {
res <- FALSE
} else {
post <- extract_postprocessor(x)
res <- tailor::tailor_requires_fit(post)
}
res
}

# ------------------------------------------------------------------------------

pred_post_strategy <- function(x) {
if (has_tailor(x)) {
if (has_tailor_tuned(x) | has_tailor_estimated(x)) {
# There is no way to get around having to estimate/fit the tailor object
# for each tuning combination
res <- "loop over pred and post"
} else {
# For a set of predictions, or a set of submodel predictions, we can
# just apply the tailor object (i.e. predict) to the set(s)
res <- "predict and post at same time"
}
} else {
# Stop at prediction, submodels or not
res <- "just predict"
}
res
}

predict_only <- function(wflow, sched, data_pred, grid) {
outputs <- get_output_columns(wflow, syms = TRUE)

if (has_sub_param(sched$predict_stage[[1]])) {
cli::cli_inform("multipredict only")

# get submodel name and vector; remove col from grid
# loop over types
# move to predict_wrapper
processed_data_pred <- forge_from_workflow(data_pred, wflow)
pred <-
wflow %>%
extract_fit_parsnip() %>%
multi_predict(processed_data_pred$predictors, trees = 1:10) %>%
parsnip::add_rowindex() %>%
tidyr::unnest(.pred) %>%
dplyr::full_join(data_pred %>% add_rowindex(), by = ".row") %>%
dplyr::select(!!!outputs$outcome, !!!outputs$estimate, .row, trees) %>%
cbind(grid %>% dplyr::select(-trees)) %>%
dplyr::arrange(.row) %>%
dplyr::select(-.row)

# multi_predict
} else {
cli::cli_inform("predict only")
pred <-
augment(wflow, data_pred) %>%
dplyr::select(!!!unlist(outputs)) %>%
cbind(grid)
}
pred
}

predict_post_one_shot <- function(wflow, sched, data_pred, grid) {
outputs <- get_output_columns(wflow, syms = TRUE)
cli::cli_inform("predict/post once (not working)")
# mimic what .fit_post does but directly use tailor
# fit just to update the columns names
#
# multi_predict?
# add row numbers
# group nest
# apply tailor
# unnest
# predict
# apply tailor
}

predict_post_loop <- function(wflow, sched, data_pred, grid) {
cli::cli_inform("predict/post looping")
outputs <- get_output_columns(wflow, syms = TRUE)
num_pred_iter <- nrow(sched$predict_stage[[1]])
# TODO pre-allocate space and fill in
for(prd in seq_len(num_pred_iter)) {
current_pred <- sched$predict_stage[[1]][prd,]

num_post_iter <- nrow(current_pred$post_stage[[1]])

for(post in seq_len(num_post_iter)) {
current_post <- current_pred$post_stage[[1]][post,]

current_grid <- dplyr::bind_cols(current_grid, no_stage(current_post))
wflow <- post_update_fit(wflow, current_post, data_fit) # other data needed

predicted <-
augment(wflow, data_pred) %>%
dplyr::select(!!!unlist(outputs)) %>%
cbind(grid)
# bind cols;
}
}
predicted
}


predictions <- function(wflow, sched, data_pred, grid) {
strategy <- pred_post_strategy(wflow)
if (strategy == "just predict") {
pred <- predict_only(wflow, sched, data_pred, grid)
} else if (strategy == "predict and post at same time") {
# not yet implemented
pred <- predict_post_one_shot(wflow, sched, data_pred, grid)
} else {
pred <- predict_post_loop(wflow, sched, data_pred, grid)
}
pred
}

# ------------------------------------------------------------------------------

pre_update_fit <- function(wflow, grid, fit_data) {
pre_proc <- extract_preprocessor(wflow)

if (inherits(pre_proc, "recipe")) {
grid <- no_stage(grid)
pre_proc_param <- extract_parameter_set_dials(pre_proc)
pre_proc_id <- pre_proc_param$id

if (length(pre_proc_id) > 0) {
grid <- grid[, pre_proc_id]
pre_proc <- finalize_recipe(pre_proc, grid)
wflow <- set_workflow_recipe(wflow, pre_proc)
}
}
.fit_pre(wflow, fit_data)
}

model_update_fit <- function(wflow, grid) {
mod_spec <- extract_spec_parsnip(wflow)

grid <- no_stage(grid)
pre_proc_param <- extract_parameter_set_dials(mod_spec)
pre_proc_id <- pre_proc_param$id

if (length(pre_proc_id) > 0) {
grid <- grid[, pre_proc_id]
mod_spec <- finalize_model(mod_spec, grid)
wflow <- set_workflow_spec(wflow, mod_spec)
}

.fit_model(wflow, control_workflow())
}


post_update_fit <- function(wflow, grid, post_data) {
mod_spec <- extract_postprocessor(wflow)

grid <- no_stage(grid)
post_proc_param <- extract_parameter_set_dials(mod_spec)
post_proc_id <- post_proc_param$id

if (length(post_proc_id) > 0) {
grid <- grid[, post_proc_id]
mod_spec <- finalize_tailor(mod_spec, grid)
wflow <- set_workflow_spec(wflow, mod_spec)
}

res <- .fit_post(wflow, post_data)
.fit_finalize(res)
}


rebind_grid <- function(...) {
list(...) %>% purrr::map(no_stage) %>% purrr::list_cbind()
}

get_output_columns <- function(x, syms = FALSE) {
pred_cols <- .get_prediction_column_names(x, syms = TRUE)
res <- c(list(outcome = rlang::syms(outcome_names(x))), res)
res
}

# ------------------------------------------------------------------------------

#' @export
loopy <- function(sched, wflow, data_fit, data_pred) {
num_pre_iter <- nrow(sched)


for (pre in seq_len(num_pre_iter)) {
current_pre <- sched[pre, ]
cli::cli_inform("{pre}/{num_pre_iter} preprocessing: {text_param(current_pre)}")

current_wflow <- pre_update_fit(wflow, current_pre, data_fit)
num_mod_iter <- nrow(current_pre$model_stage[[1]])

# --------------------------------------------------------------------------
for (mod in seq_len(num_mod_iter)) {
current_model <- current_pre$model_stage[[1]][mod,]
cli::cli_inform("├── {mod}/{num_mod_iter} model: {text_param(current_model)}")

current_wflow <- model_update_fit(current_wflow, current_model)

num_pred_iter <- nrow(current_model$predict_stage[[1]])
current_grid <- rebind_grid(current_pre, current_model)

pred <- predictions(current_wflow, current_model, data_pred, current_grid)

# bind rows and/or pre-allocate
} # model loop
} # pre loop
}
Loading

0 comments on commit 8e51568

Please sign in to comment.