From e2d17c010b3cafd42a42e0b99da3de07880d6e05 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 21 Jan 2025 14:36:31 -0800 Subject: [PATCH] pass call to check_for_disaster() --- R/C5.0.R | 4 ++-- R/bridge.R | 8 ++++---- R/cart.R | 4 ++-- R/cost_models.R | 4 ++-- R/mars.R | 4 ++-- R/nnet.R | 4 ++-- R/validate.R | 7 +++++-- tests/testthat/_snaps/validation.md | 2 +- 8 files changed, 20 insertions(+), 17 deletions(-) diff --git a/R/C5.0.R b/R/C5.0.R index 9ded28e..c72ca31 100644 --- a/R/C5.0.R +++ b/R/C5.0.R @@ -1,5 +1,5 @@ -c5_bagger <- function(rs, control, ...) { +c5_bagger <- function(rs, control, ..., call) { opt <- rlang::dots_list(...) mod_spec <- make_c5_spec(opt) @@ -17,7 +17,7 @@ c5_bagger <- function(rs, control, ...) { control = control )) - rs <- check_for_disaster(rs) + rs <- check_for_disaster(rs, call = call) rs <- filter_rs(rs) diff --git a/R/bridge.R b/R/bridge.R index 9283336..b4f0e88 100644 --- a/R/bridge.R +++ b/R/bridge.R @@ -20,10 +20,10 @@ bagger_bridge <- function(processed, weights, base_model, seed, times, control, if (is.null(cost)) { res <- switch( base_model, - CART = cart_bagger(rs, control, ...), - C5.0 = c5_bagger(rs, control, ...), - MARS = mars_bagger(rs, control, ...), - nnet = nnet_bagger(rs, control, ...) + CART = cart_bagger(rs, control, ..., call = call), + C5.0 = c5_bagger(rs, control, ..., call = call), + MARS = mars_bagger(rs, control, ..., call = call), + nnet = nnet_bagger(rs, control, ..., call = call) ) } else { res <- switch( diff --git a/R/cart.R b/R/cart.R index 58a0e4a..e46929b 100644 --- a/R/cart.R +++ b/R/cart.R @@ -1,5 +1,5 @@ -cart_bagger <- function(rs, control, ...) { +cart_bagger <- function(rs, control, ..., call) { opt <- rlang::dots_list(...) is_classif <- is.factor(rs$splits[[1]]$data$.outcome) mod_spec <- make_cart_spec(is_classif, opt) @@ -19,7 +19,7 @@ cart_bagger <- function(rs, control, ...) { ) ) - rs <- check_for_disaster(rs) + rs <- check_for_disaster(rs, call = call) rs <- filter_rs(rs) diff --git a/R/cost_models.R b/R/cost_models.R index eabb19c..7eb53bf 100644 --- a/R/cost_models.R +++ b/R/cost_models.R @@ -37,7 +37,7 @@ cost_sens_cart_bagger <- function(rs, control, cost, ..., call = rlang::caller_e opt$parms <- list(loss = cost) } - cart_bagger(rs = rs, control = control, !!!opt) + cart_bagger(rs = rs, control = control, call = call, !!!opt) } @@ -54,5 +54,5 @@ cost_sens_c5_bagger <- function(rs, control, cost, ..., call = rlang::caller_env # Attach cost matrix to options opt$costs <- cost - c5_bagger(rs = rs, control = control, !!!opt) + c5_bagger(rs = rs, control = control, call = call, !!!opt) } diff --git a/R/mars.R b/R/mars.R index 27afc73..4f5c480 100644 --- a/R/mars.R +++ b/R/mars.R @@ -1,5 +1,5 @@ -mars_bagger <- function(rs, control, ...) { +mars_bagger <- function(rs, control, ..., call) { opt <- rlang::dots_list(...) is_classif <- is.factor(rs$splits[[1]]$data$.outcome) @@ -18,7 +18,7 @@ mars_bagger <- function(rs, control, ...) { control = control )) - rs <- check_for_disaster(rs) + rs <- check_for_disaster(rs, call = call) rs <- filter_rs(rs) diff --git a/R/nnet.R b/R/nnet.R index c1531ac..3a820f5 100644 --- a/R/nnet.R +++ b/R/nnet.R @@ -1,5 +1,5 @@ -nnet_bagger <- function(rs, control, ...) { +nnet_bagger <- function(rs, control, ..., call) { opt <- rlang::dots_list(...) is_classif <- is.factor(rs$splits[[1]]$data$.outcome) mod_spec <- make_nnet_spec(is_classif, opt) @@ -19,7 +19,7 @@ nnet_bagger <- function(rs, control, ...) { ) ) - rs <- check_for_disaster(rs) + rs <- check_for_disaster(rs, call = call) rs <- filter_rs(rs) diff --git a/R/validate.R b/R/validate.R index f842324..616152c 100644 --- a/R/validate.R +++ b/R/validate.R @@ -69,9 +69,12 @@ check_for_disaster <- function(x, call = rlang::caller_env()) { if (!is.na(msg)) { # escape any brackets in the error message msg <- cli::format_error("{msg}") - cli::cli_abort(c("All of the models failed. Example:", "x" = "{msg}")) + cli::cli_abort( + c("All of the models failed. Example:", "x" = "{msg}"), + call = call + ) } else { - cli::cli_abort("All of the models failed.") + cli::cli_abort("All of the models failed.", call = call) } } x diff --git a/tests/testthat/_snaps/validation.md b/tests/testthat/_snaps/validation.md index 87b8374..5f05a69 100644 --- a/tests/testthat/_snaps/validation.md +++ b/tests/testthat/_snaps/validation.md @@ -160,7 +160,7 @@ set.seed(459394) bagger(a ~ ., data = bad_iris, base_model = "CART", times = 3) Condition - Error in `check_for_disaster()`: + Error in `bagger()`: ! All of the models failed. Example: x Error in cbind(yval2, yprob, nodeprob) : number of rows of matrices must match (see arg 2)