Skip to content

Commit

Permalink
limit the range of numeric predictions (#142)
Browse files Browse the repository at this point in the history
* add a function to constrain numeric predictions

* test cases

* argument name change

* test against tune() value

* Apply suggestions from code review

Co-authored-by: Simon P. Couch <[email protected]>

* cleanup

* rlang type checkers

* re-doc

* Apply suggestions from code review

Co-authored-by: Simon P. Couch <[email protected]>

* update tests

---------

Co-authored-by: ‘topepo’ <‘[email protected]’>
Co-authored-by: Simon P. Couch <[email protected]>
  • Loading branch information
3 people authored Apr 4, 2024
1 parent f36f4b2 commit 5aa7ecd
Show file tree
Hide file tree
Showing 11 changed files with 1,098 additions and 3 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Imports:
hardhat,
pillar,
purrr,
rlang (>= 1.0.4),
rlang (>= 1.1.0),
tidyr (>= 1.3.0),
tidyselect (>= 1.1.2),
tune (>= 1.1.2),
Expand Down Expand Up @@ -59,4 +59,4 @@ Config/testthat/edition: 3
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ export(as.factor)
export(as.ordered)
export(as_class_pred)
export(augment)
export(bound_prediction)
export(cal_apply)
export(cal_estimate_beta)
export(cal_estimate_isotonic)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# probably (development version)

* A new function `bound_prediction()` is available to constrain the values of a numeric prediction (#142).

# probably 1.0.3

* Fixed a bug where the grouping for calibration methods was sensitive to the type of the grouping variables (#127).
Expand Down
42 changes: 42 additions & 0 deletions R/bound_prediction.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#' Truncate a numeric prediction column
#'
#' For user-defined `lower_limit` and/or `upper_limit` bound, ensure that the values in the
#' `.pred` column are coerced to these bounds.
#'
#' @param x A data frame that contains a numeric column named `.pred`.
#' @param lower_limit,upper_limit Single numerics (or `NA`) that define
#' constraints on `.pred`.
#' @param call The call to be displayed in warnings or errors.
#' @return `x` with potentially adjusted values.
#' @examples
#' data(solubility_test, package = "yardstick")
#'
#' names(solubility_test) <- c("solubility", ".pred")
#'
#' bound_prediction(solubility_test, lower_limit = -1)
#' @export
bound_prediction <- function(x, lower_limit = -Inf, upper_limit = Inf,
call = rlang::current_env()) {
check_data_frame(x, call = call)

if (!any(names(x) == ".pred")) {
cli::cli_abort("The argument {.arg x} should have a column named {.code .pred}.",
call = call)
}
if (!is.numeric(x$.pred)) {
cli::cli_abort("Column {.code .pred} should be numeric.", call = call)
}

check_number_decimal(lower_limit, allow_na = TRUE, call = call)
check_number_decimal(upper_limit, allow_na = TRUE, call = call)

if (!is.na(lower_limit)) {
x$.pred <- ifelse(x$.pred < lower_limit, lower_limit, x$.pred)
}

if (!is.na(upper_limit)) {
x$.pred <- ifelse(x$.pred > upper_limit, upper_limit, x$.pred)
}
x
}

Loading

0 comments on commit 5aa7ecd

Please sign in to comment.