From 2fbcbfb64007b701fdee9cd947777cc4dc496452 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Wed, 15 Jan 2025 11:15:37 +0000 Subject: [PATCH 1/7] Refactor `get_tune_schedule()` --- R/schedule.R | 201 +++++++++++++++++++++++---------------------------- 1 file changed, 90 insertions(+), 111 deletions(-) diff --git a/R/schedule.R b/R/schedule.R index 24ff5330..5ecac0ca 100644 --- a/R/schedule.R +++ b/R/schedule.R @@ -21,132 +21,111 @@ get_tune_schedule <- function(wflow, param, grid) { cli::cli_abort("Argument {.arg grid} must be a tibble.") } - # ---------------------------------------------------------------------------- - # Get information on the parameters associated with the supervised model + # Which parameter belongs to which stage and which is a submodel parameter? + param_info <- get_param_info(wflow) - model_spec <- extract_spec_parsnip(wflow) - model_type <- class(model_spec)[1] - model_eng <- model_spec$engine - - # Which, if any, is a submodel - model_param <- parsnip::get_from_env(paste0(model_type, "_args")) %>% - dplyr::filter(engine == model_spec$engine) %>% - dplyr::select(name = parsnip, has_submodel) - - # Merge the info in with the other parameters - param <- dplyr::left_join(param, model_param, by = "name") %>% - dplyr::mutate( - has_submodel = dplyr::if_else(is.na(has_submodel), FALSE, has_submodel) - ) - - # ------------------------------------------------------------------------------ - # Get tuning parameter IDs for each stage of the workflow - - if (any(param$source == "recipe")) { - pre_id <- param$id[param$source == "recipe"] - } else { - pre_id <- character(0) - } + schedule <- schedule_stages(grid, param_info, wflow) - if (any(param$source == "model_spec")) { - model_id <- param$id[param$source == "model_spec"] - sub_id <- param$id[param$source == "model_spec" & param$has_submodel] - non_sub_id <- param$id[param$source == "model_spec" & !param$has_submodel] - } else { - model_id <- sub_id <- non_sub_id <- character(0) - } - - if (any(param$source == "tailor")) { - post_id <- param$id[param$source == "tailor"] + # TODO rework class(es)? + og_cls <- class(schedule) + if (nrow(param) == 0) { + cls <- "resample_schedule" } else { - post_id <- character(0) + cls <- "grid_schedule" } - ids <- list( - all = param$id, - pre = pre_id, - # All model param - model = model_id, - fits = c(pre_id, non_sub_id), - sub = sub_id, - non_sub = non_sub_id, - post = post_id - ) - # convert to symbols - symbs <- purrr::map(ids, syms) - - has_submodels <- length(ids$sub) > 0 - - # ------------------------------------------------------------------------------ - # First collapse the submodel parameters (if any) and postprocessors - # TODO update this will submodels and postproc - if (has_submodels) { - sched <- grid %>% - dplyr::group_nest(!!!symbs$fits, .key = "predict_stage") - # Note 1: multi_predict() should only be triggered for a submodel parameter if - # there are multiple rows in the `predict_stage` list column. i.e. the submodel - # column will always be there but we only multipredict when there are 2+ - # values to predict. - - # Note 2: The purpose of min_grid() is to determine the minimum grid for - # preprocessing and model parameters to fit. We compute it here and ignore - # any postprocessing tuning parmeters (if any). The postprocessing parameters - # will still be in the schedule since we schedule those before the results - # that use min_grid() are merged in. See issue #975 for an example and - # discussion. - first_loop_info <- - min_grid(model_spec, - grid %>% - dplyr::select(-dplyr::any_of(post_id)) %>% - dplyr::distinct()) - } else { - sched <- grid %>% - dplyr::group_nest(!!!symbs$fits, .key = "predict_stage") - first_loop_info <- grid %>% dplyr::select(!!!symbs$fits) + if (nrow(grid) == 1) { + cls <- c("single_schedule", cls) } - first_loop_info <- first_loop_info %>% - dplyr::select(!!!c(symbs$pre, symbs$model)) %>% - dplyr::distinct() + class(schedule) <- c(cls, "schedule", og_cls) - # ------------------------------------------------------------------------------ - # Add info an any postprocessing parameters + schedule +} - sched <- sched %>% - dplyr::mutate( - predict_stage = purrr::map( - predict_stage, - ~.x %>% dplyr::group_nest(!!!symbs$sub, .key = "post_stage") - ) - ) +schedule_stages <- function(grid, param_info, wflow) { + # schedule preprocessing stage and push the rest into a nested tibble + param_pre_stage <- param_info %>% + filter(source == "recipe") %>% + pull(id) + schedule <- grid %>% + tidyr::nest(.by = all_of(param_pre_stage), .key = "model_stage") + + # schedule next stages recursively + schedule %>% + mutate( + model_stage = + purrr::map( + model_stage, + schedule_model_stage_i, + param_info = param_info, + wflow = wflow + ) + ) +} - # ------------------------------------------------------------------------------ - # Merge in submodel fit value (if any) +schedule_model_stage_i <- function(model_stage, param_info, wflow){ + model_param <- param_info %>% + filter(source == "model_spec") %>% + pull(id) + non_submodel_param <- param_info %>% + filter(source == "model_spec" & !has_submodel) %>% + pull(id) + + # schedule model parameters + schedule <- min_model_grid(model_stage, model_param, wflow) + + # push remaining paramters into the next stage + next_stage <- model_stage %>% + tidyr::nest(.by = all_of(non_submodel_param), .key = "predict_stage") + + schedule <- schedule %>% + dplyr::left_join(next_stage, by = all_of(non_submodel_param)) + + # schedule next stages recursively + schedule %>% + mutate( + predict_stage = + purrr::map(predict_stage, schedule_predict_stage_i, param_info = param_info) + ) +} - loop_names <- names(sched)[names(sched) != "predict_stage"] - if (length(loop_names) > 0) { - # Using `by = character()` to perform a cross join was deprecated - sched <- dplyr::full_join(sched, first_loop_info, by = loop_names) - } +min_model_grid <- function(grid, model_param, wflow){ + # work on only the model parameters + model_grid <- grid %>% + select(all_of(model_param)) %>% + dplyr::distinct() + + min_grid( + extract_spec_parsnip(wflow), + model_grid + ) %>% + select(all_of(model_param)) +} - # ------------------------------------------------------------------------------ - # Now collapse over the preprocessor for conditional execution +schedule_predict_stage_i <- function(predict_stage, param_info) { + submodel_param <- param_info %>% + filter(source == "model_spec" & has_submodel) %>% + pull(id) - sched <- sched %>% dplyr::group_nest(!!!symbs$pre, .key = "model_stage") + predict_stage %>% + tidyr::nest(.by = all_of(submodel_param), .key = "post_stage") +} - # ------------------------------------------------------------------------------ +# TODO check if existing tune functionality already covers this +get_param_info <- function(wflow) { + param_info <- tune_args(wflow) %>% + select(name, id, source) - og_cls <- class(sched) - if (nrow(param) == 0) { - cls <- "resample_schedule" - } else { - cls <- "grid_schedule" - } + model_spec <- extract_spec_parsnip(wflow) + model_type <- class(model_spec)[1] + model_eng <- model_spec$engine - if (nrow(grid) == 1) { - cls <- c("single_schedule", cls) - } + model_param <- parsnip::get_from_env(paste0(model_type, "_args")) %>% + dplyr::filter(engine == model_spec$engine) %>% + dplyr::select(name = parsnip, has_submodel) - class(sched) <- c(cls, "schedule", og_cls) - sched + param_info <- dplyr::left_join(param_info, model_param, by = "name") + + param_info } From 108d94f913e429b3f270de358a37df24a0933292 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Wed, 15 Jan 2025 11:17:28 +0000 Subject: [PATCH 2/7] Namespace things (note that otherwise the testing pane in Positron doesn't work) --- tests/testthat/helper-tune-package.R | 62 +++++++++++++++------------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/tests/testthat/helper-tune-package.R b/tests/testthat/helper-tune-package.R index d8f70460..ae0e5312 100644 --- a/tests/testthat/helper-tune-package.R +++ b/tests/testthat/helper-tune-package.R @@ -1,9 +1,13 @@ -suppressPackageStartupMessages(library(workflows)) -suppressPackageStartupMessages(library(parsnip)) -suppressPackageStartupMessages(library(recipes)) -suppressPackageStartupMessages(library(dials)) -suppressPackageStartupMessages(library(tailor)) -suppressPackageStartupMessages(library(purrr)) +# suppressPackageStartupMessages(library(workflows)) +# suppressPackageStartupMessages(library(parsnip)) +# suppressPackageStartupMessages(library(recipes)) +# suppressPackageStartupMessages(library(dials)) +# suppressPackageStartupMessages(library(tailor)) +# suppressPackageStartupMessages(library(purrr)) + + +# NOTE namsespacing is required to make this file load properly in the testthat machinery + new_rng_snapshots <- utils::compareVersion("3.6.0", as.character(getRversion())) > 0 @@ -13,18 +17,18 @@ rankdeficient_version <- any(names(formals("predict.lm")) == "rankdeficient") helper_objects_tune <- function() { rec_tune_1 <- - recipe(mpg ~ ., data = mtcars) %>% - step_normalize(all_predictors()) %>% - step_pca(all_predictors(), num_comp = tune()) + recipes::recipe(mpg ~ ., data = mtcars) %>% + recipes::step_normalize(all_predictors()) %>% + recipes::step_pca(all_predictors(), num_comp = tune()) rec_no_tune_1 <- - recipe(mpg ~ ., data = mtcars) %>% - step_normalize(all_predictors()) + recipes::recipe(mpg ~ ., data = mtcars) %>% + recipes::step_normalize(all_predictors()) - lm_mod <- linear_reg() %>% set_engine("lm") + lm_mod <- parsnip::linear_reg() %>% parsnip::set_engine("lm") - svm_mod <- svm_rbf(mode = "regression", cost = tune()) %>% - set_engine("kernlab") + svm_mod <- parsnip::svm_rbf(mode = "regression", cost = tune()) %>% + parsnip::set_engine("kernlab") list( rec_tune_1 = rec_tune_1, @@ -83,33 +87,33 @@ redefer_initialize_catalog <- function(test_env) { if (rlang::is_installed("splines2")) { rec_df <- - recipe(mpg ~ ., data = mtcars) %>% - step_corr(all_predictors(), threshold = .1) %>% - step_spline_natural(disp, deg_free = 5) + recipes::recipe(mpg ~ ., data = mtcars) %>% + recipes::step_corr(all_predictors(), threshold = .1) %>% + recipes::step_spline_natural(disp, deg_free = 5) rec_tune_thrsh_df <- - recipe(mpg ~ ., data = mtcars) %>% - step_corr(all_predictors(), threshold = tune()) %>% - step_spline_natural(disp, deg_free = tune("disp_df")) + recipes::recipe(mpg ~ ., data = mtcars) %>% + recipes::step_corr(all_predictors(), threshold = tune()) %>% + recipes::step_spline_natural(disp, deg_free = tune("disp_df")) } -mod_tune_bst <- boost_tree(trees = tune(), min_n = tune(), mode = "regression") -mod_tune_rf <- rand_forest(min_n = tune(), mode = "regression") +mod_tune_bst <- parsnip::boost_tree(trees = tune(), min_n = tune(), mode = "regression") +mod_tune_rf <- parsnip::rand_forest(min_n = tune(), mode = "regression") if (rlang::is_installed("probably")) { adjust_tune_min <- - tailor() %>% - adjust_numeric_range(lower_limit = tune()) + tailor::tailor() %>% + tailor::adjust_numeric_range(lower_limit = tune()) adjust_cal_tune_min <- - tailor() %>% - adjust_numeric_calibration(method = "linear") %>% - adjust_numeric_range(lower_limit = tune()) + tailor::tailor() %>% + tailor::adjust_numeric_calibration(method = "linear") %>% + tailor::adjust_numeric_range(lower_limit = tune()) adjust_min <- - tailor() %>% - adjust_numeric_range(lower_limit = 0) + tailor::tailor() %>% + tailor::adjust_numeric_range(lower_limit = 0) } From f39824b404066ef56a1b5321a4a42c33e8a2ddfd Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Wed, 15 Jan 2025 11:20:01 +0000 Subject: [PATCH 3/7] give in temporarily --- tests/testthat/test-schedule.R | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/testthat/test-schedule.R b/tests/testthat/test-schedule.R index 1ee99599..0c81d406 100644 --- a/tests/testthat/test-schedule.R +++ b/tests/testthat/test-schedule.R @@ -3,6 +3,14 @@ # Objects in helper-tune-package.R +suppressPackageStartupMessages(library(workflows)) +suppressPackageStartupMessages(library(parsnip)) +suppressPackageStartupMessages(library(recipes)) +suppressPackageStartupMessages(library(dials)) +suppressPackageStartupMessages(library(tailor)) +suppressPackageStartupMessages(library(purrr)) +suppressPackageStartupMessages(library(dplyr)) + # ------------------------------------------------------------------------------ # No tuning or postprocesing estimation @@ -564,10 +572,10 @@ test_that("grid processing schedule - recipe + model, submodels, irregular grid" grid_model <- grid_pre_model %>% - group_nest(threshold, disp_df) %>% + dplyr::group_nest(threshold, disp_df) %>% mutate( - data = map(data, ~ .x %>% summarize(trees = max(trees), .by = c(min_n))), - data = map(data, ~ .x %>% arrange(min_n)) + data = purrr::map(data, ~ .x %>% dplyr::summarize(trees = max(trees), .by = c(min_n))), + data = purrr::map(data, ~ .x %>% arrange(min_n)) ) # ------------------------------------------------------------------------------ @@ -576,8 +584,8 @@ test_that("grid processing schedule - recipe + model, submodels, irregular grid" expect_named(sched_pre_model, c("threshold", "disp_df", "model_stage")) expect_equal( - sched_pre_model %>% select(-model_stage) %>% as_tibble(), grid_pre %>% arrange(threshold, disp_df) + sched_pre_model %>% select(-model_stage) %>% tibble::as_tibble(), ) for (i in seq_along(sched_pre_model$model_stage)) { From 88279fee52eb71c2fa000c41d52b896ec0cfec6f Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Wed, 15 Jan 2025 11:26:26 +0000 Subject: [PATCH 4/7] submodel parameters don't get move to the end anymore --- tests/testthat/test-schedule.R | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/testthat/test-schedule.R b/tests/testthat/test-schedule.R index 0c81d406..eaf3539f 100644 --- a/tests/testthat/test-schedule.R +++ b/tests/testthat/test-schedule.R @@ -175,7 +175,7 @@ test_that("grid processing schedule - model only, submodels, regular grid", { reg_n <- length(sched_bst$model_stage) for (i in 1:reg_n) { - expect_named(sched_bst$model_stage[[i]], c("min_n", "predict_stage", "trees")) + expect_named(sched_bst$model_stage[[i]], c("trees", "min_n", "predict_stage")) expect_equal( sched_bst$model_stage[[i]] %>% @@ -223,7 +223,7 @@ test_that("grid processing schedule - model only, submodels, SFD grid", { irreg_n <- length(sched_sfd_bst$model_stage) expect_equal(irreg_n, 1L) - expect_named(sched_sfd_bst$model_stage[[1]], c("min_n", "predict_stage", "trees")) + expect_named(sched_sfd_bst$model_stage[[1]], c("trees", "min_n", "predict_stage")) expect_equal( sched_sfd_bst$model_stage[[1]] %>% dplyr::select(-predict_stage) %>% @@ -266,7 +266,7 @@ test_that("grid processing schedule - model only, submodels, irregular design", odd_n <- length(sched_odd_bst$model_stage) expect_equal(odd_n, 1L) - expect_named(sched_odd_bst$model_stage[[1]], c("min_n", "predict_stage", "trees")) + expect_named(sched_odd_bst$model_stage[[1]], c("trees", "min_n", "predict_stage")) expect_equal( sched_odd_bst$model_stage[[1]] %>% dplyr::select(-predict_stage) %>% @@ -307,7 +307,7 @@ test_that("grid processing schedule - model only, submodels, 1 point design", { expect_equal(length(sched_1_pt$model_stage), 1L) expect_named( sched_1_pt$model_stage[[1]], - c("min_n", "predict_stage", "trees") + c("trees", "min_n", "predict_stage") ) expect_equal( @@ -590,7 +590,7 @@ test_that("grid processing schedule - recipe + model, submodels, irregular grid" for (i in seq_along(sched_pre_model$model_stage)) { model_i <- sched_pre_model$model_stage[[i]] - expect_named(model_i, c("min_n", "predict_stage", "trees")) + expect_named(model_i, c("trees", "min_n", "predict_stage")) expect_equal( model_i %>% select(min_n, trees) %>% arrange(min_n), grid_model$data[[i]] @@ -705,7 +705,7 @@ test_that("grid processing schedule - recipe + model + tailor, submodels, irregu for (j in seq_along(sched_pre_model_post$model_stage[[i]]$predict_stage)) { model_ij <- model_i[j,] - expect_named(model_ij, c("min_n", "predict_stage", "trees")) + expect_named(model_ij, c("trees", "min_n", "predict_stage")) predict_j <- model_ij$predict_stage[[1]] expect_named(predict_j, c("trees", "post_stage")) From 64aadb3de62370bb8c2a67f0ea368b689e3c6f7b Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Wed, 15 Jan 2025 11:27:36 +0000 Subject: [PATCH 5/7] schedule now keeps grid ordering --- tests/testthat/test-schedule.R | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/tests/testthat/test-schedule.R b/tests/testthat/test-schedule.R index eaf3539f..4befd07d 100644 --- a/tests/testthat/test-schedule.R +++ b/tests/testthat/test-schedule.R @@ -385,8 +385,7 @@ test_that("grid processing schedule - recipe + postprocessing, regular grid", { grid_pre <- grid_pre_post %>% - distinct(threshold, disp_df) %>% - arrange(threshold, disp_df) + distinct(threshold, disp_df) grid_post <- grid_pre_post %>% distinct(lower_limit) %>% @@ -399,7 +398,7 @@ test_that("grid processing schedule - recipe + postprocessing, regular grid", { expect_named(sched_pre_post, c("threshold", "disp_df", "model_stage")) expect_equal( sched_pre_post %>% select(-model_stage) %>% as_tibble(), - grid_pre %>% arrange(threshold, disp_df) + grid_pre ) for (i in seq_along(sched_pre_post$model_stage)) { @@ -439,8 +438,7 @@ test_that("grid processing schedule - recipe + postprocessing, irregular grid", grid_pre <- grid_pre_post %>% - distinct(threshold, disp_df) %>% - arrange(threshold, disp_df) + distinct(threshold, disp_df) grids_post <- grid_pre_post %>% @@ -454,7 +452,7 @@ test_that("grid processing schedule - recipe + postprocessing, irregular grid", expect_named(sched_pre_post, c("threshold", "disp_df", "model_stage")) expect_equal( sched_pre_post %>% select(-model_stage) %>% as_tibble(), - grid_pre %>% arrange(threshold, disp_df) + grid_pre ) for (i in seq_along(sched_pre_post$model_stage)) { @@ -504,8 +502,7 @@ test_that("grid processing schedule - recipe + model, no submodels, regular grid grid_pre <- grid_pre_model %>% - distinct(threshold, disp_df) %>% - arrange(threshold, disp_df) + distinct(threshold, disp_df) grid_model <- grid_pre_model %>% @@ -519,7 +516,7 @@ test_that("grid processing schedule - recipe + model, no submodels, regular grid expect_named(sched_pre_model, c("threshold", "disp_df", "model_stage")) expect_equal( sched_pre_model %>% select(-model_stage) %>% as_tibble(), - grid_pre %>% arrange(threshold, disp_df) + grid_pre ) for (i in seq_along(sched_pre_model$model_stage)) { @@ -567,8 +564,7 @@ test_that("grid processing schedule - recipe + model, submodels, irregular grid" grid_pre <- grid_pre_model %>% - distinct(threshold, disp_df) %>% - arrange(threshold, disp_df) + distinct(threshold, disp_df) grid_model <- grid_pre_model %>% @@ -584,8 +580,8 @@ test_that("grid processing schedule - recipe + model, submodels, irregular grid" expect_named(sched_pre_model, c("threshold", "disp_df", "model_stage")) expect_equal( - grid_pre %>% arrange(threshold, disp_df) sched_pre_model %>% select(-model_stage) %>% tibble::as_tibble(), + grid_pre ) for (i in seq_along(sched_pre_model$model_stage)) { @@ -652,8 +648,7 @@ test_that("grid processing schedule - recipe + model + tailor, submodels, irregu grid_pre <- grid_pre_model_post %>% - distinct(threshold, disp_df) %>% - arrange(threshold, disp_df) + distinct(threshold, disp_df) grid_model <- grid_pre_model_post %>% @@ -674,7 +669,7 @@ test_that("grid processing schedule - recipe + model + tailor, submodels, irregu expect_named(sched_pre_model_post, c("threshold", "disp_df", "model_stage")) expect_equal( sched_pre_model_post %>% select(-model_stage) %>% as_tibble(), - grid_pre %>% arrange(threshold, disp_df) + grid_pre ) for (i in seq_along(sched_pre_model_post$model_stage)) { From 450a1820a80654b02f816801966c3562fa4c9689 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Fri, 17 Jan 2025 13:59:45 +0000 Subject: [PATCH 6/7] allow 0-row tibbles for "no tuning" --- R/schedule.R | 1 - tests/testthat/test-schedule.R | 33 +++------------------------------ 2 files changed, 3 insertions(+), 31 deletions(-) diff --git a/R/schedule.R b/R/schedule.R index 5ecac0ca..878a128e 100644 --- a/R/schedule.R +++ b/R/schedule.R @@ -26,7 +26,6 @@ get_tune_schedule <- function(wflow, param, grid) { schedule <- schedule_stages(grid, param_info, wflow) - # TODO rework class(es)? og_cls <- class(schedule) if (nrow(param) == 0) { cls <- "resample_schedule" diff --git a/tests/testthat/test-schedule.R b/tests/testthat/test-schedule.R index 4befd07d..d1d35761 100644 --- a/tests/testthat/test-schedule.R +++ b/tests/testthat/test-schedule.R @@ -22,16 +22,7 @@ test_that("grid processing schedule - no parameters", { sched_nada <- get_tune_schedule(wflow_nada, prm_used_nada, grid_nada) expect_named(sched_nada, "model_stage") - expect_equal(nrow(sched_nada), 1) - - # All of the other nested tibbles should be empty - expect_equal( - sched_nada %>% - tidyr::unnest(model_stage) %>% - tidyr::unnest(predict_stage) %>% - tidyr::unnest(post_stage), - grid_nada - ) + expect_equal(nrow(sched_nada), 0) expect_s3_class( sched_nada, @@ -49,16 +40,7 @@ test_that("grid processing schedule - recipe and model", { sched_pre_only <- get_tune_schedule(wflow_pre_only, prm_used_pre_only, grid_pre_only) expect_named(sched_pre_only, c("model_stage")) - expect_equal(nrow(sched_pre_only), max(nrow(grid_pre_only), 1)) - - # All of the other nested tibbles should be empty - expect_equal( - sched_pre_only %>% - tidyr::unnest(model_stage) %>% - tidyr::unnest(predict_stage) %>% - tidyr::unnest(post_stage), - grid_pre_only - ) + expect_equal(nrow(sched_pre_only), 0) expect_s3_class( sched_pre_only, @@ -77,16 +59,7 @@ test_that("grid processing schedule - recipe, model, and post", { sched_three <- get_tune_schedule(wflow_three, prm_used_three, grid_three) expect_named(sched_three, c("model_stage")) - expect_equal(nrow(sched_three), max(nrow(grid_three), 1)) - - # All of the other nested tibbles should be empty - expect_equal( - sched_three %>% - tidyr::unnest(model_stage) %>% - tidyr::unnest(predict_stage) %>% - tidyr::unnest(post_stage), - grid_three - ) + expect_equal(nrow(sched_three), 0) expect_s3_class( sched_three, From ee7664923c370161dc7223b4af0e6241a2d475f2 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Fri, 17 Jan 2025 14:13:51 +0000 Subject: [PATCH 7/7] clean up --- R/0_imports.R | 3 ++- R/schedule.R | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/0_imports.R b/R/0_imports.R index 7739c8b9..21a2c136 100644 --- a/R/0_imports.R +++ b/R/0_imports.R @@ -47,7 +47,8 @@ utils::globalVariables( "rowwise", ".best", "location", "msg", "..object", ".eval_time", ".pred_survival", ".pred_time", ".weight_censored", "nice_time", "time_metric", ".lower", ".upper", "i", "results", "term", ".alpha", - ".method", "old_term", ".lab_pre", ".model", ".num_models", "predict_stage" + ".method", "old_term", ".lab_pre", ".model", ".num_models", "model_stage", + "predict_stage" ) ) diff --git a/R/schedule.R b/R/schedule.R index 878a128e..6e6fbe41 100644 --- a/R/schedule.R +++ b/R/schedule.R @@ -111,7 +111,6 @@ schedule_predict_stage_i <- function(predict_stage, param_info) { tidyr::nest(.by = all_of(submodel_param), .key = "post_stage") } -# TODO check if existing tune functionality already covers this get_param_info <- function(wflow) { param_info <- tune_args(wflow) %>% select(name, id, source)