Skip to content

Commit

Permalink
Merge pull request #1093 from tidymodels/cli-check_args
Browse files Browse the repository at this point in the history
switch to {cli} in check_args() functions
  • Loading branch information
EmilHvitfeldt authored Apr 10, 2024
2 parents 7b7e118 + 21c0e91 commit 2f386d2
Show file tree
Hide file tree
Showing 48 changed files with 578 additions and 201 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: parsnip
Title: A Common API to Modeling and Analysis Functions
Version: 1.2.1.9000
Version: 1.2.1.9001
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
person("Davis", "Vaughan", , "[email protected]", role = "aut"),
Expand Down
4 changes: 1 addition & 3 deletions R/bag_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ update.bag_tree <-
# ------------------------------------------------------------------------------

#' @export
check_args.bag_tree <- function(object) {
if (object$engine == "C5.0" && object$mode == "regression")
stop("C5.0 is classification only.", call. = FALSE)
check_args.bag_tree <- function(object, call = rlang::caller_env()) {
invisible(object)
}

Expand Down
20 changes: 6 additions & 14 deletions R/boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -164,23 +164,15 @@ translate.boost_tree <- function(x, engine = x$engine, ...) {
# ------------------------------------------------------------------------------

#' @export
check_args.boost_tree <- function(object) {
check_args.boost_tree <- function(object, call = rlang::caller_env()) {

args <- lapply(object$args, rlang::eval_tidy)

if (is.numeric(args$trees) && args$trees < 0) {
rlang::abort("`trees` should be >= 1.")
}
if (is.numeric(args$sample_size) && (args$sample_size < 0 | args$sample_size > 1)) {
rlang::abort("`sample_size` should be within [0,1].")
}
if (is.numeric(args$tree_depth) && args$tree_depth < 0) {
rlang::abort("`tree_depth` should be >= 1.")
}
if (is.numeric(args$min_n) && args$min_n < 0) {
rlang::abort("`min_n` should be >= 1.")
}

check_number_whole(args$trees, min = 0, allow_null = TRUE, call = call, arg = "trees")
check_number_decimal(args$sample_size, min = 0, max = 1, allow_null = TRUE, call = call, arg = "sample_size")
check_number_whole(args$tree_depth, min = 0, allow_null = TRUE, call = call, arg = "tree_depth")
check_number_whole(args$min_n, min = 0, allow_null = TRUE, call = call, arg = "min_n")

invisible(object)
}

Expand Down
31 changes: 11 additions & 20 deletions R/c5_rules.R
Original file line number Diff line number Diff line change
Expand Up @@ -111,32 +111,23 @@ update.C5_rules <-
# make work in different places

#' @export
check_args.C5_rules <- function(object) {
check_args.C5_rules <- function(object, call = rlang::caller_env()) {

args <- lapply(object$args, rlang::eval_tidy)

if (is.numeric(args$trees)) {
if (length(args$trees) > 1) {
rlang::abort("Only a single value of `trees` is used.")
}
msg <- "The number of trees should be >= 1 and <= 100. Truncating the value."
if (args$trees > 100) {
object$args$trees <-
rlang::new_quosure(100L, env = rlang::empty_env())
rlang::warn(msg)
}
if (args$trees < 1) {
object$args$trees <-
rlang::new_quosure(1L, env = rlang::empty_env())
rlang::warn(msg)
}
check_number_whole(args$min_n, allow_null = TRUE, call = call, arg = "min_n")
check_number_whole(args$tree, allow_null = TRUE, call = call, arg = "tree")

msg <- "The number of trees should be {.code >= 1} and {.code <= 100}"
if (!(is.null(args$trees)) && args$trees > 100) {
object$args$trees <- rlang::new_quosure(100L, env = rlang::empty_env())
cli::cli_warn(c(msg, "Truncating to 100."))
}
if (is.numeric(args$min_n)) {
if (length(args$min_n) > 1) {
rlang::abort("Only a single `min_n`` value is used.")
}
if (!(is.null(args$trees)) && args$trees < 1) {
object$args$trees <- rlang::new_quosure(1L, env = rlang::empty_env())
cli::cli_warn(c(msg, "Truncating to 1."))
}

invisible(object)
}

Expand Down
54 changes: 23 additions & 31 deletions R/cubist_rules.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,44 +135,36 @@ update.cubist_rules <-
# make work in different places

#' @export
check_args.cubist_rules <- function(object) {
check_args.cubist_rules <- function(object, call = rlang::caller_env()) {

args <- lapply(object$args, rlang::eval_tidy)

if (is.numeric(args$committees)) {
if (length(args$committees) > 1) {
rlang::abort("Only a single committee member is used.")
}
msg <- "The number of committees should be >= 1 and <= 100. Truncating the value."
if (args$committees > 100) {
object$args$committees <-
rlang::new_quosure(100L, env = rlang::empty_env())
rlang::warn(msg)
}
if (args$committees < 1) {
object$args$committees <-
rlang::new_quosure(1L, env = rlang::empty_env())
rlang::warn(msg)
}
check_number_whole(args$committees, allow_null = TRUE, call = call, arg = "committees")

}
if (is.numeric(args$neighbors)) {
if (length(args$neighbors) > 1) {
rlang::abort("Only a single neighbors value is used.")
}
msg <- "The number of neighbors should be >= 0 and <= 9. Truncating the value."
if (args$neighbors > 9) {
object$args$neighbors <-
rlang::new_quosure(9L, env = rlang::empty_env())
rlang::warn(msg)
}
if (args$neighbors < 0) {
object$args$neighbors <-
rlang::new_quosure(0L, env = rlang::empty_env())
rlang::warn(msg)
msg <- "The number of committees should be {.code >= 1} and {.code <= 100}."
if (!(is.null(args$committees)) && args$committees > 100) {
object$args$committees <-
rlang::new_quosure(100L, env = rlang::empty_env())
cli::cli_warn(c(msg, "Truncating to 100."))
}
if (!(is.null(args$committees)) && args$committees < 1) {
object$args$committees <-
rlang::new_quosure(1L, env = rlang::empty_env())
cli::cli_warn(c(msg, "Truncating to 1."))
}

check_number_whole(args$neighbors, allow_null = TRUE, call = call, arg = "neighbors")

msg <- "The number of neighbors should be {.code >= 0} and {.code <= 9}."
if (!(is.null(args$neighbors)) && args$neighbors > 9) {
object$args$neighbors <- rlang::new_quosure(9L, env = rlang::empty_env())
cli::cli_warn(c(msg, "Truncating to 9."))
}
if (!(is.null(args$neighbors)) && args$neighbors < 0) {
object$args$neighbors <- rlang::new_quosure(0L, env = rlang::empty_env())
cli::cli_warn(c(msg, "Truncating to 0."))
}

invisible(object)
}

Expand Down
4 changes: 1 addition & 3 deletions R/decision_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,7 @@ translate.decision_tree <- function(x, engine = x$engine, ...) {
# ------------------------------------------------------------------------------

#' @export
check_args.decision_tree <- function(object) {
if (object$engine == "C5.0" && object$mode == "regression")
rlang::abort("C5.0 is classification only.")
check_args.decision_tree <- function(object, call = rlang::caller_env()) {
invisible(object)
}

Expand Down
17 changes: 5 additions & 12 deletions R/discrim_flexible.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,14 @@ update.discrim_flexible <-
# ------------------------------------------------------------------------------

#' @export
check_args.discrim_flexible <- function(object) {
check_args.discrim_flexible <- function(object, call = rlang::caller_env()) {

args <- lapply(object$args, rlang::eval_tidy)

if (is.numeric(args$prod_degree) && args$prod_degree < 0)
stop("`prod_degree` should be >= 1", call. = FALSE)

if (is.numeric(args$num_terms) && args$num_terms < 0)
stop("`num_terms` should be >= 1", call. = FALSE)

if (!is.character(args$prune_method) &&
!is.null(args$prune_method) &&
!is.character(args$prune_method))
stop("`prune_method` should be a single string value", call. = FALSE)

check_number_whole(args$prod_degree, min = 1, allow_null = TRUE, call = call, arg = "prod_degree")
check_number_whole(args$num_terms, min = 1, allow_null = TRUE, call = call, arg = "num_terms")
check_string(args$prune_method, allow_empty = FALSE, allow_null = TRUE, call = call, arg = "prune_method")

invisible(object)
}

Expand Down
6 changes: 2 additions & 4 deletions R/discrim_linear.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,11 @@ update.discrim_linear <-
# ------------------------------------------------------------------------------

#' @export
check_args.discrim_linear <- function(object) {
check_args.discrim_linear <- function(object, call = rlang::caller_env()) {

args <- lapply(object$args, rlang::eval_tidy)

if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) {
stop("The amount of regularization should be >= 0", call. = FALSE)
}
check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty")

invisible(object)
}
Expand Down
13 changes: 4 additions & 9 deletions R/discrim_regularized.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,13 @@ update.discrim_regularized <-
# ------------------------------------------------------------------------------

#' @export
check_args.discrim_regularized <- function(object) {
check_args.discrim_regularized <- function(object, call = rlang::caller_env()) {

args <- lapply(object$args, rlang::eval_tidy)

if (is.numeric(args$frac_common_cov) &&
(args$frac_common_cov < 0 | args$frac_common_cov > 1)) {
stop("The common covariance fraction should be between zero and one", call. = FALSE)
}
if (is.numeric(args$frac_identity) &&
(args$frac_identity < 0 | args$frac_identity > 1)) {
stop("The identity matrix fraction should be between zero and one", call. = FALSE)
}
check_number_decimal(args$frac_common_cov, min = 0, max = 1, allow_null = TRUE, call = call, arg = "frac_common_cov")
check_number_decimal(args$frac_identity, min = 0, max = 1, allow_null = TRUE, call = call, arg = "frac_identity")

invisible(object)
}

Expand Down
18 changes: 12 additions & 6 deletions R/fit_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# data to formula/data objects and so on.

form_form <-
function(object, control, env, ...) {
function(object, control, env, ..., call = rlang::caller_env()) {

if (inherits(env$data, "data.frame")) {
check_outcome(eval_tidy(rlang::f_lhs(env$formula), env$data), object)
Expand Down Expand Up @@ -32,7 +32,7 @@ form_form <-
}

# evaluate quoted args once here to check them
object <- check_args(object)
object <- check_args(object, call = call)

# sub in arguments to actual syntax for corresponding engine
object <- translate(object, engine = object$engine)
Expand Down Expand Up @@ -60,7 +60,12 @@ form_form <-
res
}

xy_xy <- function(object, env, control, target = "none", ...) {
xy_xy <- function(object,
env,
control,
target = "none",
...,
call = rlang::caller_env()) {

if (inherits(env$x, "tbl_spark") | inherits(env$y, "tbl_spark"))
rlang::abort("spark objects can only be used with the formula interface to `fit()`")
Expand All @@ -83,7 +88,7 @@ xy_xy <- function(object, env, control, target = "none", ...) {
}

# evaluate quoted args once here to check them
object <- check_args(object)
object <- check_args(object, call = call)

# sub in arguments to actual syntax for corresponding engine
object <- translate(object, engine = object$engine)
Expand Down Expand Up @@ -114,7 +119,7 @@ xy_xy <- function(object, env, control, target = "none", ...) {
}

form_xy <- function(object, control, env,
target = "none", ...) {
target = "none", ..., call = rlang::caller_env()) {

encoding_info <-
get_encoding(class(object)[1]) %>%
Expand All @@ -138,7 +143,8 @@ form_xy <- function(object, control, env,
object = object,
env = env, #weights!
control = control,
target = target
target = target,
call = call
)
data_obj$y_var <- all.vars(rlang::f_lhs(env$formula))
data_obj$x <- NULL
Expand Down
10 changes: 3 additions & 7 deletions R/linear_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,12 @@ update.linear_reg <-
# ------------------------------------------------------------------------------

#' @export
check_args.linear_reg <- function(object) {
check_args.linear_reg <- function(object, call = rlang::caller_env()) {

args <- lapply(object$args, rlang::eval_tidy)

if (all(is.numeric(args$penalty)) && any(args$penalty < 0))
rlang::abort("The amount of regularization should be >= 0.")
if (is.numeric(args$mixture) && (args$mixture < 0 | args$mixture > 1))
rlang::abort("The mixture proportion should be within [0,1].")
if (is.numeric(args$mixture) && length(args$mixture) > 1)
rlang::abort("Only one value of `mixture` is allowed.")
check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture")
check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty")

invisible(object)
}
35 changes: 21 additions & 14 deletions R/logistic_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,25 +135,32 @@ update.logistic_reg <-
# ------------------------------------------------------------------------------

#' @export
check_args.logistic_reg <- function(object) {
check_args.logistic_reg <- function(object, call = rlang::caller_env()) {

args <- lapply(object$args, rlang::eval_tidy)

if (all(is.numeric(args$penalty)) && any(args$penalty < 0))
rlang::abort("The amount of regularization should be >= 0.")
if (is.numeric(args$mixture) && (args$mixture < 0 | args$mixture > 1))
rlang::abort("The mixture proportion should be within [0,1].")
if (is.numeric(args$mixture) && length(args$mixture) > 1)
rlang::abort("Only one value of `mixture` is allowed.")
check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture")
check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty")

if (object$engine == "LiblineaR") {
if(is.numeric(args$mixture) && !args$mixture %in% 0:1)
rlang::abort(c("For the LiblineaR engine, mixture must be 0 or 1.",
"Choose a pure ridge model with `mixture = 0`.",
"Choose a pure lasso model with `mixture = 1`.",
"The Liblinear engine does not support other values."))
if(all(is.numeric(args$penalty)) && !all(args$penalty > 0))
rlang::abort("For the LiblineaR engine, penalty must be > 0.")
if (is.numeric(args$mixture) && !args$mixture %in% 0:1) {
cli::cli_abort(
c("x" = "For the {.pkg LiblineaR} engine, mixture must be 0 or 1, \\
not {args$mixture}.",
"i" = "Choose a pure ridge model with {.code mixture = 0} or \\
a pure lasso model with {.code mixture = 1}.",
"!" = "The {.pkg Liblinear} engine does not support other values."),
call = call
)
}

if ((!is.null(args$penalty)) && args$penalty == 0) {
cli::cli_abort(
"For the {.pkg LiblineaR} engine, {.arg penalty} must be {.code > 0}, \\
not 0.",
call = call
)
}
}

invisible(object)
Expand Down
15 changes: 4 additions & 11 deletions R/mars.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,20 +105,13 @@ translate.mars <- function(x, engine = x$engine, ...) {
# ------------------------------------------------------------------------------

#' @export
check_args.mars <- function(object) {
check_args.mars <- function(object, call = rlang::caller_env()) {

args <- lapply(object$args, rlang::eval_tidy)

if (is.numeric(args$prod_degree) && args$prod_degree < 0)
rlang::abort("`prod_degree` should be >= 1.")

if (is.numeric(args$num_terms) && args$num_terms < 0)
rlang::abort("`num_terms` should be >= 1.")

if (!is_varying(args$prune_method) &&
!is.null(args$prune_method) &&
!is.character(args$prune_method))
rlang::abort("`prune_method` should be a single string value.")
check_number_whole(args$prod_degree, min = 1, allow_null = TRUE, call = call, arg = "prod_degree")
check_number_whole(args$num_terms, min = 1, allow_null = TRUE, call = call, arg = "num_terms")
check_string(args$prune_method, allow_empty = FALSE, allow_null = TRUE, call = call, arg = "prune_method")

invisible(object)
}
Expand Down
Loading

0 comments on commit 2f386d2

Please sign in to comment.