3131# ' page in the references below). This enables the user to compute performance
3232# ' metrics in the \pkg{yardstick} package.
3333# '
34+ # ' ## Quantile Regression
35+ # '
36+ # ' For quantile regression models, a `.pred_quantile` column is added that
37+ # ' contains the quantile predictions for each row. This column has a special
38+ # ' class `"quantile_pred"` and can be unnested using [tidyr::unnest()]
39+ # '
3440# ' @param new_data A data frame or matrix.
3541# ' @param ... Not currently used.
3642# ' @rdname augment
7884# ' augment(cls_xy, cls_tst)
7985# ' augment(cls_xy, cls_tst[, -3])
8086# '
87+ # ' # ------------------------------------------------------------------------------
88+ # '
89+ # ' # Quantile regression example
90+ # ' qr_form <-
91+ # ' linear_reg() |>
92+ # ' set_engine("quantreg") |>
93+ # ' set_mode("quantile regression", quantile_levels = c(0.25, 0.5, 0.75)) |>
94+ # ' fit(mpg ~ ., data = car_trn)
95+ # '
96+ # ' augment(qr_form, car_tst)
97+ # ' augment(qr_form, car_tst[, -1])
98+ # '
8199augment.model_fit <- function (x , new_data , eval_time = NULL , ... ) {
82100 new_data <- tibble :: new_tibble(new_data )
83101 res <-
84102 switch (
85103 x $ spec $ mode ,
86- " regression" = augment_regression(x , new_data ),
87- " classification" = augment_classification(x , new_data ),
88- " censored regression" = augment_censored(x , new_data , eval_time = eval_time ),
104+ " regression" = augment_regression(x , new_data ),
105+ " classification" = augment_classification(x , new_data ),
106+ " censored regression" = augment_censored(
107+ x ,
108+ new_data ,
109+ eval_time = eval_time
110+ ),
111+ " quantile regression" = augment_quantile_regression(x , new_data ),
89112 cli :: cli_abort(
90113 c(
91114 " Unknown mode {.val {x$spec$mode}}." ,
@@ -106,7 +129,11 @@ augment_regression <- function(x, new_data) {
106129 ret <- dplyr :: mutate(ret , .resid = !! rlang :: sym(y_nm ) - .pred )
107130 }
108131 }
109- dplyr :: relocate(ret , dplyr :: starts_with(" .pred" ), dplyr :: starts_with(" .resid" ))
132+ dplyr :: relocate(
133+ ret ,
134+ dplyr :: starts_with(" .pred" ),
135+ dplyr :: starts_with(" .resid" )
136+ )
110137}
111138
112139augment_classification <- function (x , new_data ) {
@@ -117,11 +144,15 @@ augment_classification <- function(x, new_data) {
117144 }
118145
119146 if (spec_has_pred_type(x , " class" )) {
120- ret <- dplyr :: bind_cols(predict(x , new_data = new_data , type = " class" ), ret )
147+ ret <- dplyr :: bind_cols(
148+ predict(x , new_data = new_data , type = " class" ),
149+ ret
150+ )
121151 }
122152 ret
123153}
124154
155+
125156# nocov start
126157# tested in tidymodels/extratests#
127158augment_censored <- function (x , new_data , eval_time = NULL ) {
@@ -145,7 +176,8 @@ augment_censored <- function(x, new_data, eval_time = NULL) {
145176 .filter_eval_time(eval_time )
146177 ret <- dplyr :: bind_cols(
147178 predict(x , new_data = new_data , type = " survival" , eval_time = eval_time ),
148- ret )
179+ ret
180+ )
149181 # Add inverse probability weights when the outcome is present in new_data
150182 y_col <- .find_surv_col(new_data , fail = FALSE )
151183 if (length(y_col ) != 0 ) {
@@ -155,3 +187,10 @@ augment_censored <- function(x, new_data, eval_time = NULL) {
155187 ret
156188}
157189# nocov end
190+
191+ augment_quantile_regression <- function (x , new_data ) {
192+ ret <- new_data
193+ check_spec_pred_type(x , " quantile" )
194+ ret <- dplyr :: bind_cols(predict(x , new_data = new_data ), ret )
195+ dplyr :: relocate(ret , dplyr :: starts_with(" .pred" ))
196+ }
0 commit comments