Skip to content

Commit 071749e

Browse files
authored
fit calibrators at fit.container() (#12)
1 parent 71ed887 commit 071749e

13 files changed

+222
-96
lines changed

R/adjust-numeric-calibration.R

+33-19
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
#' Re-calibrate numeric predictions
22
#'
33
#' @param x A [container()].
4-
#' @param calibrator A pre-trained calibration method from the \pkg{probably}
5-
#' package, such as [probably::cal_estimate_linear()].
4+
#' @param type Character. One of `"linear"`, `"isotonic"`, or
5+
#' `"isotonic_boot"`, corresponding to the function from the \pkg{probably}
6+
#' package [probably::cal_estimate_linear()],
7+
#' [probably::cal_estimate_isotonic()], or
8+
#' [probably::cal_estimate_isotonic_boot()], respectively.
69
#' @examples
710
#' library(modeldata)
811
#' library(probably)
@@ -14,27 +17,24 @@
1417
#'
1518
#' dat
1619
#'
17-
#' # calibrate numeric predictions
18-
#' reg_cal <- cal_estimate_linear(dat, truth = y, estimate = y_pred)
19-
#'
2020
#' # specify calibration
2121
#' reg_ctr <-
2222
#' container(mode = "regression") %>%
23-
#' adjust_numeric_calibration(reg_cal)
23+
#' adjust_numeric_calibration(type = "linear")
2424
#'
25-
#' # "train" container
25+
#' # train container
2626
#' reg_ctr_trained <- fit(reg_ctr, dat, outcome = y, estimate = y_pred)
2727
#'
28-
#' predict(reg_ctr, dat)
28+
#' predict(reg_ctr_trained, dat)
2929
#' @export
30-
adjust_numeric_calibration <- function(x, calibrator) {
31-
check_container(x)
32-
check_required(calibrator)
33-
if (!inherits(calibrator, "cal_regression")) {
34-
cli_abort(
35-
"{.arg calibrator} should be a \\
36-
{.help [<cal_regression> object](probably::cal_estimate_linear)}, \\
37-
not {.obj_type_friendly {calibrator}}."
30+
adjust_numeric_calibration <- function(x, type = NULL) {
31+
# to-do: add argument specifying `prop` in initial_split
32+
check_container(x, calibration_type = "numeric")
33+
# wait to `check_type()` until `fit()` time
34+
if (!is.null(type)) {
35+
arg_match0(
36+
type,
37+
c("linear", "isotonic", "isotonic_boot")
3838
)
3939
}
4040

@@ -43,7 +43,7 @@ adjust_numeric_calibration <- function(x, calibrator) {
4343
"numeric_calibration",
4444
inputs = "numeric",
4545
outputs = "numeric",
46-
arguments = list(calibrator = calibrator),
46+
arguments = list(type = type),
4747
results = list(),
4848
trained = FALSE
4949
)
@@ -67,19 +67,33 @@ print.numeric_calibration <- function(x, ...) {
6767

6868
#' @export
6969
fit.numeric_calibration <- function(object, data, container = NULL, ...) {
70+
type <- check_type(object$type, container$type)
71+
# todo: adjust_numeric_calibration() should take arguments to pass to
72+
# cal_estimate_* via dots
73+
fit <-
74+
eval_bare(
75+
call2(
76+
paste0("cal_estimate_", type),
77+
.data = data,
78+
truth = container$columns$outcome,
79+
estimate = container$columns$estimate,
80+
.ns = "probably"
81+
)
82+
)
83+
7084
new_operation(
7185
class(object),
7286
inputs = object$inputs,
7387
outputs = object$outputs,
7488
arguments = object$arguments,
75-
results = list(),
89+
results = list(fit = fit),
7690
trained = TRUE
7791
)
7892
}
7993

8094
#' @export
8195
predict.numeric_calibration <- function(object, new_data, container, ...) {
82-
probably::cal_apply(new_data, object$argument$calibrator)
96+
probably::cal_apply(new_data, object$results$fit)
8397
}
8498

8599
# todo probably needs required_pkgs methods for cal objects

R/adjust-probability-calibration.R

+31-14
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
#' Re-calibrate classification probability predictions
22
#'
33
#' @param x A [container()].
4-
#' @param calibrator A pre-trained calibration method from the \pkg{probably}
5-
#' package, such as [probably::cal_estimate_logistic()].
4+
#' @param type Character. One of `"logistic"`, `"multinomial"`,
5+
#' `"beta"`, `"isotonic"`, or `"isotonic_boot"`, corresponding to the
6+
#' function from the \pkg{probably} package [probably::cal_estimate_logistic()],
7+
#' [probably::cal_estimate_multinomial()], etc., respectively.
68
#' @export
7-
adjust_probability_calibration <- function(x, calibrator) {
8-
check_container(x)
9-
cls <- c("cal_binary", "cal_multinomial")
10-
check_required(calibrator)
11-
if (!inherits_any(calibrator, cls)) {
12-
cli_abort(
13-
"{.arg calibrator} should be a \\
14-
{.help [<cal_binary> or <cal_multinomial> object](probably::cal_estimate_logistic)}, \\
15-
not {.obj_type_friendly {calibrator}}."
9+
adjust_probability_calibration <- function(x, type = NULL) {
10+
# to-do: add argument specifying `prop` in initial_split
11+
check_container(x, calibration_type = "probability")
12+
# wait to `check_type()` until `fit()` time
13+
if (!is.null(type)) {
14+
arg_match(
15+
type,
16+
c("logistic", "multinomial", "beta", "isotonic", "isotonic_boot")
1617
)
1718
}
1819

@@ -21,7 +22,7 @@ adjust_probability_calibration <- function(x, calibrator) {
2122
"probability_calibration",
2223
inputs = "probability",
2324
outputs = "probability_class",
24-
arguments = list(calibrator = calibrator),
25+
arguments = list(type = type),
2526
results = list(),
2627
trained = FALSE
2728
)
@@ -45,19 +46,35 @@ print.probability_calibration <- function(x, ...) {
4546

4647
#' @export
4748
fit.probability_calibration <- function(object, data, container = NULL, ...) {
49+
type <- check_type(object$type, container$type)
50+
# todo: adjust_probability_calibration() should take arguments to pass to
51+
# cal_estimate_* via dots
52+
# to-do: add argument specifying `prop` in initial_split
53+
fit <-
54+
eval_bare(
55+
call2(
56+
paste0("cal_estimate_", type),
57+
.data = data,
58+
# todo: make getters for the entries in `columns`
59+
truth = container$columns$outcome,
60+
estimate = container$columns$estimate,
61+
.ns = "probably"
62+
)
63+
)
64+
4865
new_operation(
4966
class(object),
5067
inputs = object$inputs,
5168
outputs = object$outputs,
5269
arguments = object$arguments,
53-
results = list(),
70+
results = list(fit = fit),
5471
trained = TRUE
5572
)
5673
}
5774

5875
#' @export
5976
predict.probability_calibration <- function(object, new_data, container, ...) {
60-
probably::cal_apply(new_data, object$argument$calibrator)
77+
probably::cal_apply(new_data, object$results$fit)
6178
}
6279

6380
# todo probably needs required_pkgs methods for cal objects

R/container.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ fit.container <- function(object, .data, outcome, estimate, probabilities = c(),
130130

131131
num_oper <- length(object$operations)
132132
for (op in seq_len(num_oper)) {
133-
object$operations[[op]] <- fit(object$operations[[op]], data, object)
133+
object$operations[[op]] <- fit(object$operations[[op]], .data, object)
134134
.data <- predict(object$operations[[op]], .data, object)
135135
}
136136

R/utils.R

+83-2
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,95 @@ is_container <- function(x) {
4949
}
5050

5151
# ad-hoc checking --------------------------------------------------------------
52-
check_container <- function(x, call = caller_env(), arg = caller_arg(x)) {
52+
check_container <- function(x, calibration_type = NULL, call = caller_env(), arg = caller_arg(x)) {
5353
if (!is_container(x)) {
54-
cli::cli_abort(
54+
cli_abort(
5555
"{.arg {arg}} should be a {.help [{.cls container}](container::container)}, \\
5656
not {.obj_type_friendly {x}}.",
5757
call = call
5858
)
5959
}
6060

61+
# check that the type of calibration ("numeric" or "probability") is
62+
# compatible with the container type
63+
if (!is.null(calibration_type)) {
64+
container_type <- x$type
65+
switch(
66+
container_type,
67+
regression =
68+
check_calibration_type(calibration_type, "numeric", container_type, call = call),
69+
binary = , multinomial =
70+
check_calibration_type(calibration_type, "probability", container_type, call = call)
71+
)
72+
}
73+
6174
invisible()
6275
}
76+
77+
check_calibration_type <- function(calibration_type, calibration_type_expected,
78+
container_type, call) {
79+
if (!identical(calibration_type, calibration_type_expected)) {
80+
cli_abort(
81+
"A {.field {container_type}} container is incompatible with the operation \\
82+
{.fun {paste0('adjust_', calibration_type, '_calibration')}}.",
83+
call = call
84+
)
85+
}
86+
}
87+
88+
types_regression <- c("linear", "isotonic", "isotonic_boot")
89+
types_binary <- c("logistic", "beta", "isotonic", "isotonic_boot")
90+
types_multiclass <- c("multinomial", "beta", "isotonic", "isotonic_boot")
91+
# a check function to be called when a container is being `fit()`ted.
92+
# by the time a container is fitted, we have:
93+
# * `adjust_type`, the `type` argument passed to an `adjust_*` function
94+
# * this argument has already been checked to agree with the kind of
95+
# `adjust_*()` function via `arg_match0()`.
96+
# * `container_type`, the `type` argument either specified in `container()`
97+
# or inferred in `fit.container()`.
98+
check_type <- function(adjust_type,
99+
container_type,
100+
arg = caller_arg(adjust_type),
101+
call = caller_env()) {
102+
# if no `adjust_type` was supplied, infer a reasonable one based on the
103+
# `container_type`
104+
if (is.null(adjust_type)) {
105+
switch(
106+
container_type,
107+
regression = return("linear"),
108+
binary = return("logistic"),
109+
multiclass = return("multinomial")
110+
)
111+
}
112+
113+
switch(
114+
container_type,
115+
regression = arg_match0(
116+
adjust_type,
117+
types_regression,
118+
arg_nm = arg,
119+
error_call = call
120+
),
121+
binary = arg_match0(
122+
adjust_type,
123+
types_binary,
124+
arg_nm = arg,
125+
error_call = call
126+
),
127+
multiclass = arg_match0(
128+
adjust_type,
129+
types_multiclass,
130+
arg_nm = arg,
131+
error_call = call
132+
),
133+
arg_match0(
134+
adjust_type,
135+
unique(c(types_regression, types_binary, types_multiclass)),
136+
arg_nm = arg,
137+
error_call = call
138+
)
139+
)
140+
141+
adjust_type
142+
}
143+

man/adjust_numeric_calibration.Rd

+9-9
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/adjust_probability_calibration.Rd

+5-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,36 @@
11
# adjustment printing
22

33
Code
4-
ctr_reg %>% adjust_numeric_calibration(dummy_reg_cal)
4+
ctr_reg %>% adjust_numeric_calibration()
55
Message
66
77
-- Container -------------------------------------------------------------------
8-
A postprocessor with 1 operation:
8+
A regression postprocessor with 1 operation:
99
1010
* Re-calibrate numeric predictions.
1111

1212
# errors informatively with bad input
1313

1414
Code
15-
adjust_numeric_calibration(ctr_reg)
15+
adjust_numeric_calibration(ctr_reg, "boop")
1616
Condition
1717
Error in `adjust_numeric_calibration()`:
18-
! `calibrator` is absent but must be supplied.
18+
! `type` must be one of "linear", "isotonic", or "isotonic_boot", not "boop".
1919

2020
---
2121

2222
Code
23-
adjust_numeric_calibration(ctr_reg, "boop")
23+
container("classification", "binary") %>% adjust_numeric_calibration("linear")
2424
Condition
2525
Error in `adjust_numeric_calibration()`:
26-
! `calibrator` should be a <cal_regression> object (`?probably::cal_estimate_linear()`), not a string.
26+
! A binary container is incompatible with the operation `adjust_numeric_calibration()`.
2727

2828
---
2929

3030
Code
31-
adjust_numeric_calibration(ctr_cls, dummy_cls_cal)
31+
container("regression", "regression") %>% adjust_numeric_calibration("binary")
3232
Condition
3333
Error in `adjust_numeric_calibration()`:
34-
! `calibrator` should be a <cal_regression> object (`?probably::cal_estimate_linear()`), not a <cal_binary> object.
34+
! `type` must be one of "linear", "isotonic", or "isotonic_boot", not "binary".
35+
i Did you mean "linear"?
3536

0 commit comments

Comments
 (0)