Skip to content

Commit 1412531

Browse files
authored
simplify concept of outcome type in the package (#14)
* remove `container(mode)`
* rename `adjust_*_calibration(type)` to `adjust_*_calibration(method)`
1 parent 071749e commit 1412531

35 files changed

+248
-177
lines changed

DESCRIPTION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: container
22
Title: Sandbox for a postprocessor object
3-
Version: 0.0.0.9000
3+
Version: 0.0.0.9001
44
Authors@R: c(
55
person("Simon", "Couch", , "[email protected]", role = "aut"),
66
person("Hannah", "Frick", , "[email protected]", role = "aut"),

R/adjust-equivocal-zone.R

+1-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#' library(modeldata)
1010
#'
1111
#' post_obj <-
12-
#' container(mode = "classification") %>%
12+
#' container() %>%
1313
#' adjust_equivocal_zone(value = 1 / 4)
1414
#'
1515
#'
@@ -43,7 +43,6 @@ adjust_equivocal_zone <- function(x, value = 0.1, threshold = 1 / 2) {
4343
)
4444

4545
new_container(
46-
mode = x$mode,
4746
type = x$type,
4847
operations = c(x$operations, list(op)),
4948
columns = x$dat,

R/adjust-numeric-calibration.R

+10-11
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#' Re-calibrate numeric predictions
22
#'
33
#' @param x A [container()].
4-
#' @param type Character. One of `"linear"`, `"isotonic"`, or
4+
#' @param method Character. One of `"linear"`, `"isotonic"`, or
55
#' `"isotonic_boot"`, corresponding to the function from the \pkg{probably}
66
#' package [probably::cal_estimate_linear()],
77
#' [probably::cal_estimate_isotonic()], or
@@ -19,21 +19,21 @@
1919
#'
2020
#' # specify calibration
2121
#' reg_ctr <-
22-
#' container(mode = "regression") %>%
23-
#' adjust_numeric_calibration(type = "linear")
22+
#' container() %>%
23+
#' adjust_numeric_calibration(method = "linear")
2424
#'
2525
#' # train container
2626
#' reg_ctr_trained <- fit(reg_ctr, dat, outcome = y, estimate = y_pred)
2727
#'
2828
#' predict(reg_ctr_trained, dat)
2929
#' @export
30-
adjust_numeric_calibration <- function(x, type = NULL) {
30+
adjust_numeric_calibration <- function(x, method = NULL) {
3131
# to-do: add argument specifying `prop` in initial_split
3232
check_container(x, calibration_type = "numeric")
33-
# wait to `check_type()` until `fit()` time
34-
if (!is.null(type)) {
33+
# wait to `check_method()` until `fit()` time
34+
if (!is.null(method)) {
3535
arg_match0(
36-
type,
36+
method,
3737
c("linear", "isotonic", "isotonic_boot")
3838
)
3939
}
@@ -43,13 +43,12 @@ adjust_numeric_calibration <- function(x, type = NULL) {
4343
"numeric_calibration",
4444
inputs = "numeric",
4545
outputs = "numeric",
46-
arguments = list(type = type),
46+
arguments = list(method = method),
4747
results = list(),
4848
trained = FALSE
4949
)
5050

5151
new_container(
52-
mode = x$mode,
5352
type = x$type,
5453
operations = c(x$operations, list(op)),
5554
columns = x$dat,
@@ -67,13 +66,13 @@ print.numeric_calibration <- function(x, ...) {
6766

6867
#' @export
6968
fit.numeric_calibration <- function(object, data, container = NULL, ...) {
70-
type <- check_type(object$type, container$type)
69+
method <- check_method(object$method, container$type)
7170
# todo: adjust_numeric_calibration() should take arguments to pass to
7271
# cal_estimate_* via dots
7372
fit <-
7473
eval_bare(
7574
call2(
76-
paste0("cal_estimate_", type),
75+
paste0("cal_estimate_", method),
7776
.data = data,
7877
truth = container$columns$outcome,
7978
estimate = container$columns$estimate,

R/adjust-numeric-range.R

-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ adjust_numeric_range <- function(x, lower_limit = -Inf, upper_limit = Inf) {
1919
)
2020

2121
new_container(
22-
mode = x$mode,
2322
type = x$type,
2423
operations = c(x$operations, list(op)),
2524
columns = x$dat,

R/adjust-predictions-custom.R

+1-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#' library(modeldata)
1010
#'
1111
#' post_obj <-
12-
#' container(mode = "classification") %>%
12+
#' container() %>%
1313
#' adjust_equivocal_zone() %>%
1414
#' adjust_predictions_custom(linear_predictor = binomial()$linkfun(Class2))
1515
#'
@@ -39,7 +39,6 @@ adjust_predictions_custom <- function(x, ..., .pkgs = character(0)) {
3939
)
4040

4141
new_container(
42-
mode = x$mode,
4342
type = x$type,
4443
operations = c(x$operations, list(op)),
4544
columns = x$dat,

R/adjust-probability-calibration.R

+8-9
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
#' Re-calibrate classification probability predictions
22
#'
33
#' @param x A [container()].
4-
#' @param type Character. One of `"logistic"`, `"multinomial"`,
4+
#' @param method Character. One of `"logistic"`, `"multinomial"`,
55
#' `"beta"`, `"isotonic"`, or `"isotonic_boot"`, corresponding to the
66
#' function from the \pkg{probably} package [probably::cal_estimate_logistic()],
77
#' [probably::cal_estimate_multinomial()], etc., respectively.
88
#' @export
9-
adjust_probability_calibration <- function(x, type = NULL) {
9+
adjust_probability_calibration <- function(x, method = NULL) {
1010
# to-do: add argument specifying `prop` in initial_split
1111
check_container(x, calibration_type = "probability")
12-
# wait to `check_type()` until `fit()` time
13-
if (!is.null(type)) {
12+
# wait to `check_method()` until `fit()` time
13+
if (!is.null(method)) {
1414
arg_match(
15-
type,
15+
method,
1616
c("logistic", "multinomial", "beta", "isotonic", "isotonic_boot")
1717
)
1818
}
@@ -22,13 +22,12 @@ adjust_probability_calibration <- function(x, type = NULL) {
2222
"probability_calibration",
2323
inputs = "probability",
2424
outputs = "probability_class",
25-
arguments = list(type = type),
25+
arguments = list(method = method),
2626
results = list(),
2727
trained = FALSE
2828
)
2929

3030
new_container(
31-
mode = x$mode,
3231
type = x$type,
3332
operations = c(x$operations, list(op)),
3433
columns = x$dat,
@@ -46,14 +45,14 @@ print.probability_calibration <- function(x, ...) {
4645

4746
#' @export
4847
fit.probability_calibration <- function(object, data, container = NULL, ...) {
49-
type <- check_type(object$type, container$type)
48+
method <- check_method(object$method, container$type)
5049
# todo: adjust_probability_calibration() should take arguments to pass to
5150
# cal_estimate_* via dots
5251
# to-do: add argument specifying `prop` in initial_split
5352
fit <-
5453
eval_bare(
5554
call2(
56-
paste0("cal_estimate_", type),
55+
paste0("cal_estimate_", method),
5756
.data = data,
5857
# todo: make getters for the entries in `columns`
5958
truth = container$columns$outcome,

R/adjust-probability-threshold.R

+1-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#' library(modeldata)
88
#'
99
#' post_obj <-
10-
#' container(mode = "classification") %>%
10+
#' container() %>%
1111
#' adjust_probability_threshold(threshold = .1)
1212
#'
1313
#' two_class_example %>% count(predicted)
@@ -39,7 +39,6 @@ adjust_probability_threshold <- function(x, threshold = 0.5) {
3939
)
4040

4141
new_container(
42-
mode = x$mode,
4342
type = x$type,
4443
operations = c(x$operations, list(op)),
4544
columns = x$dat,

R/container.R

+5-15
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#' Declare post-processing for model predictions
22
#'
3-
#' @param mode The model's mode, one of `"classification"`, or `"regression"`.
4-
#' Modes of `"censored regression"` are not currently supported.
53
#' @param type The model sub-type. Possible values are `"unknown"`, `"regression"`,
64
#' `"binary"`, or `"multiclass"`.
75
#' @param outcome The name of the outcome variable.
@@ -14,9 +12,9 @@
1412
#' @param time The name of the predicted event time. (not yet supported)
1513
#' @examples
1614
#'
17-
#' container(mode = "regression")
15+
#' container()
1816
#' @export
19-
container <- function(mode, type = "unknown", outcome = NULL, estimate = NULL,
17+
container <- function(type = "unknown", outcome = NULL, estimate = NULL,
2018
probabilities = NULL, time = NULL) {
2119
columns <-
2220
list(
@@ -28,7 +26,6 @@ container <- function(mode, type = "unknown", outcome = NULL, estimate = NULL,
2826
)
2927

3028
new_container(
31-
mode,
3229
type,
3330
operations = list(),
3431
columns = columns,
@@ -37,13 +34,7 @@ container <- function(mode, type = "unknown", outcome = NULL, estimate = NULL,
3734
)
3835
}
3936

40-
new_container <- function(mode, type, operations, columns, ptype, call) {
41-
mode <- arg_match0(mode, c("regression", "classification"))
42-
43-
if (mode == "regression") {
44-
type <- "regression"
45-
}
46-
37+
new_container <- function(type, operations, columns, ptype, call) {
4738
type <- arg_match0(type, c("unknown", "regression", "binary", "multiclass"))
4839

4940
if (!is.list(operations)) {
@@ -58,11 +49,11 @@ new_container <- function(mode, type, operations, columns, ptype, call) {
5849
}
5950

6051
# validate operation order and check duplicates
61-
validate_order(operations, mode, call)
52+
validate_order(operations, type, call)
6253

6354
# check columns
6455
res <- list(
65-
mode = mode, type = type, operations = operations,
56+
type = type, operations = operations,
6657
columns = columns, ptype = ptype
6758
)
6859
class(res) <- "container"
@@ -120,7 +111,6 @@ fit.container <- function(object, .data, outcome, estimate, probabilities = c(),
120111
object <- set_container_type(object, .data[[columns$outcome]])
121112

122113
object <- new_container(
123-
object$mode,
124114
object$type,
125115
operations = object$operations,
126116
columns = columns,

R/utils.R

+18-19
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,13 @@ check_container <- function(x, calibration_type = NULL, call = caller_env(), arg
6161
# check that the type of calibration ("numeric" or "probability") is
6262
# compatible with the container type
6363
if (!is.null(calibration_type)) {
64-
container_type <- x$type
64+
type <- x$type
6565
switch(
66-
container_type,
66+
type,
6767
regression =
68-
check_calibration_type(calibration_type, "numeric", container_type, call = call),
69-
binary = , multinomial =
70-
check_calibration_type(calibration_type, "probability", container_type, call = call)
68+
check_calibration_type(calibration_type, "numeric", type, call = call),
69+
binary = , multiclass =
70+
check_calibration_type(calibration_type, "probability", type, call = call)
7171
)
7272
}
7373

@@ -90,54 +90,53 @@ types_binary <- c("logistic", "beta", "isotonic", "isotonic_boot")
9090
types_multiclass <- c("multinomial", "beta", "isotonic", "isotonic_boot")
9191
# a check function to be called when a container is being `fit()`ted.
9292
# by the time a container is fitted, we have:
93-
# * `adjust_type`, the `type` argument passed to an `adjust_*` function
93+
# * `method`, the `method` argument passed to an `adjust_*` function
9494
# * this argument has already been checked to agree with the kind of
9595
# `adjust_*()` function via `arg_match0()` or `arg_match()`.
9696
# * `type`, the container's `type` argument either specified in `container()`
9797
# or inferred in `fit.container()`.
98-
check_type <- function(adjust_type,
99-
container_type,
100-
arg = caller_arg(adjust_type),
98+
check_method <- function(method,
99+
type,
100+
arg = caller_arg(method),
101101
call = caller_env()) {
102-
# if no `adjust_type` was supplied, infer a reasonable one based on the
103-
# `container_type`
104-
if (is.null(adjust_type)) {
102+
# if no `method` was supplied, infer a reasonable one based on the `type`
103+
if (is.null(method)) {
105104
switch(
106-
container_type,
105+
type,
107106
regression = return("linear"),
108107
binary = return("logistic"),
109108
multiclass = return("multinomial")
110109
)
111110
}
112111

113112
switch(
114-
container_type,
113+
type,
115114
regression = arg_match0(
116-
adjust_type,
115+
method,
117116
types_regression,
118117
arg_nm = arg,
119118
error_call = call
120119
),
121120
binary = arg_match0(
122-
adjust_type,
121+
method,
123122
types_binary,
124123
arg_nm = arg,
125124
error_call = call
126125
),
127126
multiclass = arg_match0(
128-
adjust_type,
127+
method,
129128
types_multiclass,
130129
arg_nm = arg,
131130
error_call = call
132131
),
133132
arg_match0(
134-
adjust_type,
133+
method,
135134
unique(c(types_regression, types_binary, types_multiclass)),
136135
arg_nm = arg,
137136
error_call = call
138137
)
139138
)
140139

141-
adjust_type
140+
method
142141
}
143142

R/validation-rules.R

+26-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
validate_order <- function(ops, mode, call) {
1+
validate_order <- function(ops, type, call = caller_env()) {
22
orderings <-
33
tibble::new_tibble(list(
44
name = purrr::map_chr(ops, ~ class(.x)[1]),
@@ -13,12 +13,17 @@ validate_order <- function(ops, mode, call) {
1313
return(invisible(orderings))
1414
}
1515

16-
if (mode == "classification") {
17-
check_classification_order(orderings, call)
18-
} else {
19-
check_regression_order(orderings, call)
16+
if (type == "unknown") {
17+
type <- infer_type(orderings)
2018
}
2119

20+
switch(
21+
type,
22+
regression = check_regression_order(orderings, call),
23+
binary = , multiclass = check_classification_order(orderings, call),
24+
invisible()
25+
)
26+
2227
invisible(orderings)
2328
}
2429

@@ -83,3 +88,19 @@ check_duplicates <- function(x, call) {
8388
}
8489
invisible(x)
8590
}
91+
92+
# Guess a container `type` from the output flags stored in `orderings`.
# `validate_order()` calls this when the supplied `type` is "unknown".
#
# Reads the columns `output_all`, `output_numeric`, `output_prob` and
# `output_class` of `orderings` and returns one of "unknown", "regression"
# or "binary". "multiclass" is never returned; `validate_order()` sends
# "binary" and "multiclass" to the same classification check.
#
# NOTE(review): how these columns are built is not visible in this diff;
# presumably they are logical with one row per operation -- confirm against
# `validate_order()`.
infer_type <- function(orderings) {
93+
# every row has `output_all` set, so the flags do not narrow the type
if (all(orderings$output_all)) {
94+
return("unknown")
95+
}
96+
97+
# every row is flagged numeric (or `output_all`): treat as regression
if (all(orderings$output_numeric | orderings$output_all)) {
98+
return("regression")
99+
}
100+
101+
# every row is flagged probability or class (or `output_all`): report "binary"
if (all(orderings$output_prob | orderings$output_class | orderings$output_all)) {
102+
return("binary")
103+
}
104+
105+
# rows mix numeric with probability/class flags: no single type fits
"unknown"
106+
}

inst/examples/container_regression_example.qmd

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ We could manually use `cal_apply()` to adjust predictions, but instead, we'll ad
101101
#| label: post-1
102102
103103
post_obj <-
104-
container(mode = "regression") %>%
104+
container() %>%
105105
adjust_numeric_calibration(bst_cal)
106106
post_obj
107107
```

0 commit comments

Comments
 (0)