|
1 | 1 | # Initialize model environments |
2 | 2 |
|
| 3 | +all_modes <- c("classification", "regression", "censored regression") |
| 4 | + |
3 | 5 | # ------------------------------------------------------------------------------ |
4 | 6 |
|
5 | 7 | ## Rules about model-related information |
|
23 | 25 |
|
24 | 26 | # ------------------------------------------------------------------------------ |
25 | 27 |
|
26 | | - |
27 | 28 | parsnip <- rlang::new_environment() |
28 | 29 | parsnip$models <- NULL |
29 | | -parsnip$modes <- c("regression", "classification", "unknown") |
| 30 | +parsnip$modes <- c(all_modes, "unknown") |
30 | 31 |
|
31 | 32 | # ------------------------------------------------------------------------------ |
32 | 33 |
|
@@ -134,25 +135,119 @@ check_mode_val <- function(mode) { |
134 | 135 | } |
135 | 136 |
|
136 | 137 |
|
137 | | -stop_incompatible_mode <- function(spec_modes) { |
| 138 | +stop_incompatible_mode <- function(spec_modes, eng = NULL, cls = NULL) { |
| 139 | + if (is.null(eng) & is.null(cls)) { |
| 140 | + msg <- "Available modes are: " |
| 141 | + } |
| 142 | + if (!is.null(eng) & is.null(cls)) { |
| 143 | + msg <- glue::glue("Available modes for engine {eng} are: ") |
| 144 | + } |
| 145 | + if (is.null(eng) & !is.null(cls)) { |
| 146 | + msg <- glue::glue("Available modes for model type {cls} are: ") |
| 147 | + } |
| 148 | + if (!is.null(eng) & !is.null(cls)) { |
| 149 | + msg <- glue::glue("Available modes for model type {cls} with engine {eng} are: ") |
| 150 | + } |
| 151 | + |
138 | 152 | msg <- glue::glue( |
139 | | - "Available modes are: ", |
| 153 | + msg, |
140 | 154 | glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ") |
141 | 155 | ) |
142 | 156 | rlang::abort(msg) |
143 | 157 | } |
144 | 158 |
|
145 | | -# check if class and mode are compatible |
146 | | -check_spec_mode_val <- function(cls, mode) { |
147 | | - spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes")) |
| 159 | +stop_incompatible_engine <- function(spec_engs, mode) { |
| 160 | + msg <- glue::glue( |
| 161 | + "Available engines for mode {mode} are: ", |
| 162 | + glue::glue_collapse(glue::glue("'{spec_engs}'"), sep = ", ") |
| 163 | + ) |
| 164 | + rlang::abort(msg) |
| 165 | +} |
| 166 | + |
| 167 | +stop_missing_engine <- function(cls) { |
| 168 | + info <- |
| 169 | + get_from_env(cls) %>% |
| 170 | + dplyr::group_by(mode) %>% |
| 171 | + dplyr::summarize(msg = paste0(unique(mode), " {", |
| 172 | + paste0(unique(engine), collapse = ", "), |
| 173 | + "}"), |
| 174 | + .groups = "drop") |
| 175 | + if (nrow(info) == 0) { |
| 176 | + rlang::abort(paste0("No known engines for `", cls, "()`.")) |
| 177 | + } |
| 178 | + msg <- paste0(info$msg, collapse = ", ") |
| 179 | + msg <- paste("Missing engine. Possible mode/engine combinations are:", msg) |
| 180 | + rlang::abort(msg) |
| 181 | +} |
| 182 | + |
| 183 | + |
| 184 | +# check if class and mode and engine are compatible |
| 185 | +check_spec_mode_engine_val <- function(cls, eng, mode) { |
| 186 | + all_modes <- c("unknown", all_modes) |
| 187 | + if (!(mode %in% all_modes)) { |
| 188 | + rlang::abort(paste0("'", mode, "' is not a known mode.")) |
| 189 | + } |
| 190 | + |
| 191 | + model_info <- rlang::env_get(get_model_env(), cls) |
| 192 | + |
| 193 | + # Cases where the model definition is in parsnip but all of the engines |
| 194 | + # are contained in a different package |
| 195 | + if (nrow(model_info) == 0) { |
| 196 | + check_mode_with_no_engine(cls, mode) |
| 197 | + return(invisible(NULL)) |
| 198 | + } |
| 199 | + |
| 200 | + # ------------------------------------------------------------------------------ |
| 201 | + # First check engine against any mode for the given model class |
| 202 | + |
| 203 | + spec_engs <- model_info$engine |
| 204 | + # engine is allowed to be NULL |
| 205 | + if (!is.null(eng) && !(eng %in% spec_engs)) { |
| 206 | + rlang::abort( |
| 207 | + paste0( |
| 208 | + "Engine '", eng, "' is not supported for `", cls, "()`. See ", |
| 209 | + "`show_engines('", cls, "')`." |
| 210 | + ) |
| 211 | + ) |
| 212 | + } |
| 213 | + |
| 214 | + # ---------------------------------------------------------------------------- |
| 215 | + # Check modes based on model and engine |
| 216 | + |
| 217 | + spec_modes <- model_info$mode |
| 218 | + if (!is.null(eng)) { |
| 219 | + spec_modes <- spec_modes[model_info$engine == eng] |
| 220 | + } |
| 221 | + spec_modes <- unique(c("unknown", spec_modes)) |
| 222 | + |
148 | 223 | if (is.null(mode) || length(mode) > 1) { |
149 | | - stop_incompatible_mode(spec_modes) |
| 224 | + stop_incompatible_mode(spec_modes, eng) |
150 | 225 | } else if (!(mode %in% spec_modes)) { |
151 | | - stop_incompatible_mode(spec_modes) |
| 226 | + stop_incompatible_mode(spec_modes, eng) |
152 | 227 | } |
| 228 | + |
| 229 | + # ---------------------------------------------------------------------------- |
| 230 | + # Check engine based on model and model |
| 231 | + |
| 232 | + # How check for compatibility with the chosen mode (if any) |
| 233 | + if (!is.null(mode) && mode != "unknown") { |
| 234 | + spec_engs <- spec_engs[model_info$mode == mode] |
| 235 | + } |
| 236 | + spec_engs <- unique(spec_engs) |
| 237 | + if (!is.null(eng) && !(eng %in% spec_engs)) { |
| 238 | + stop_incompatible_engine(spec_engs, mode) |
| 239 | + } |
| 240 | + |
153 | 241 | invisible(NULL) |
154 | 242 | } |
155 | 243 |
|
| 244 | +check_mode_with_no_engine <- function(cls, mode) { |
| 245 | + spec_modes <- get_from_env(paste0(cls, "_modes")) |
| 246 | + if (!(mode %in% spec_modes)) { |
| 247 | + stop_incompatible_mode(spec_modes, cls = cls) |
| 248 | + } |
| 249 | +} |
| 250 | + |
156 | 251 | check_engine_val <- function(eng) { |
157 | 252 | if (rlang::is_missing(eng) || length(eng) != 1 || !is.character(eng)) |
158 | 253 | rlang::abort("Please supply a character string for an engine (e.g. `'lm'`).") |
@@ -625,8 +720,7 @@ get_dependency <- function(model) { |
625 | 720 | set_fit <- function(model, mode, eng, value) { |
626 | 721 | check_model_exists(model) |
627 | 722 | check_eng_val(eng) |
628 | | - check_mode_val(mode) |
629 | | - check_engine_val(eng) |
| 723 | + check_spec_mode_engine_val(model, eng, mode) |
630 | 724 | check_fit_info(value) |
631 | 725 |
|
632 | 726 | current <- get_model_env() |
@@ -692,8 +786,7 @@ get_fit <- function(model) { |
692 | 786 | set_pred <- function(model, mode, eng, type, value) { |
693 | 787 | check_model_exists(model) |
694 | 788 | check_eng_val(eng) |
695 | | - check_mode_val(mode) |
696 | | - check_engine_val(eng) |
| 789 | + check_spec_mode_engine_val(model, eng, mode) |
697 | 790 | check_pred_info(value, type) |
698 | 791 |
|
699 | 792 | current <- get_model_env() |
|
0 commit comments