Using `"derived"` will compute them for any derived +#' features/predictors, such as dummy indicator columns, etc. +#' @param size How many data points to predict for each iteration. +#' @param times How many iterations to repeat the calculations. +#' @param event_level A single string. Either `"first"` or `"second"` to specify +#' which level of `truth` to consider as the "event". This argument is only +#' applicable when `estimator = "binary"`. +#' @details +#' The function can compute importance at two different levels. +#' +#' - The "original" predictors are the unaltered columns in the source data set. +#' For example, for a categorical predictor used with linear regression, the +#' original predictor is the factor column. +#' - "Derived" predictors are the final versions given to the model. For the +#' categorical predictor example, the derived versions are the binary +#' indicator variables produced from the factor version. +#' +#' This can make a difference when pre-processing/feature engineering is used. +#' This can help us understand _how_ a predictor can be important. +#' +#' Importance scores are computed for each predictor (at the specified level) +#' and each performance metric. If no metric is specified, defaults are used: +#' +#' - Classification: [yardstick::brier_class()], [yardstick::roc_auc()], and +#' [yardstick::accuracy()]. +#' - Regression: [yardstick::rmse()] and [yardstick::rsq()]. +#' - Censored regression: [yardstick::brier_survival()]. +#' +#' For censored data, importance is computed for each evaluation time (when a +#' dynamic metric is specified). +#' @return A tibble with extra classes `"importance_perm"` and either +#' `"original_importance_perm"` or `"derived_importance_perm"`. The columns are: +#' - `.metric`: the name of the performance metric +#' - `predictor`: the predictor +#' - `n`: the number of usable results (should be the same as `times`) +#' - `mean`: the average of the differences in performance. 
#' For each metric, larger values indicate worse performance (i.e., higher
#' importance).
#' - `std_err`: the standard error of the differences.
#' - `importance`: the mean divided by the standard error.
#' For censored regression models, an additional `.eval_time` column may also
#' be included (depending on the metric requested).
#' @examplesIf !is_cran_check()
#' if (rlang::is_installed(c("modeldata", "parsnip", "recipes", "workflows"))) {
#'   library(modeldata)
#'   library(parsnip)
#'   library(recipes)
#'   library(workflows)
#'   library(dplyr)
#'
#'   set.seed(12)
#'   dat_tr <-
#'     sim_logistic(250, ~ .1 + 2 * A - 3 * B + 1 * A * B, corr = .7) |>
#'     dplyr::bind_cols(sim_noise(250, num_vars = 10))
#'
#'   rec <-
#'     recipe(class ~ ., data = dat_tr) |>
#'     step_interact(~ A:B) |>
#'     step_normalize(all_numeric_predictors()) |>
#'     step_pca(contains("noise"), num_comp = 5)
#'
#'   lr_wflow <- workflow(rec, logistic_reg())
#'   lr_fit <- fit(lr_wflow, dat_tr)
#'
#'   set.seed(39)
#'   orig_res <- importance_perm(lr_fit, data = dat_tr, type = "original",
#'                               size = 100, times = 25)
#'   orig_res
#'
#'   set.seed(39)
#'   deriv_res <- importance_perm(lr_fit, data = dat_tr, type = "derived",
#'                                size = 100, times = 25)
#'   deriv_res
#' }
#' @export
importance_perm <- function(wflow, data, metrics = NULL, type = "original",
                            size = 500, times = 10, eval_time = NULL,
                            event_level = "first") {
  if (!workflows::is_trained_workflow(wflow)) {
    cli::cli_abort("The workflow in {.arg wflow} should be trained.")
  }
  type <- rlang::arg_match(type, c("original", "derived"))
  # Fail early with a clear message: with fewer than one iteration there are
  # no results to join or summarize further down.
  if (!is.numeric(times) || length(times) != 1L || is.na(times) || times < 1) {
    cli::cli_abort("{.arg times} should be a single number that is at least 1.")
  }
  metrics <- tune::check_metrics_arg(metrics, wflow)
  pkgs <- required_pkgs(wflow)
  rlang::check_installed(pkgs)

  # ----------------------------------------------------------------------------
  # Pull appropriate source data
  # TODO extract and use case weights

  if (type == "original") {
    extracted_data <- extract_data_original(wflow, data)
  } else {
    extracted_data <- extract_data_derived(wflow, data)
  }

  outcome_nm <- tune::outcome_names(wflow)
  extracted_data_nms <- colnames(extracted_data)
  # Use `%in%` rather than `!=`: `!=` recycles `outcome_nm` when there is more
  # than one outcome column and returns a zero-length index (dropping every
  # predictor) when there is none.
  extracted_data_nms <- extracted_data_nms[!extracted_data_nms %in% outcome_nm]

  # Never score more than 80% of the available rows in one iteration.
  n <- nrow(extracted_data)
  size <- min(floor(n * 0.8), size)

  # ----------------------------------------------------------------------------
  # Prepare for permutations. A large `combos` data frame is created so that
  # every seed/predictor pair is its own unit of work.

  info <- tune::metrics_info(metrics)
  seed_vals <- sample.int(1e6, times)
  combos <- tidyr::crossing(seed = seed_vals, column = extracted_data_nms)

  # One closure carries the arguments shared by the permuted and the baseline
  # runs, so the two sets of calls cannot drift apart.
  run_iter <- function(column, seed) {
    metric_iter(
      column = column,
      seed = seed,
      type = type,
      fitted = wflow,
      dat = extracted_data,
      metrics = metrics,
      size = size,
      outcome = outcome_nm,
      eval_time = eval_time,
      event_level = event_level
    )
  }

  # ----------------------------------------------------------------------------
  # Generate all permutations

  rlang::local_options(doFuture.rng.onMisuse = "ignore")
  res_perms <-
    purrr::map2(combos$column, combos$seed, run_iter) |>
    purrr::list_rbind()

  # ----------------------------------------------------------------------------
  # Get un-permuted performance statistics (per seed value)

  res_bl <-
    purrr::map(seed_vals, \(seed) run_iter(NULL, seed)) |>
    purrr::list_rbind() |>
    dplyr::rename(baseline = .estimate) |>
    dplyr::select(-predictor)

  # ----------------------------------------------------------------------------
  # Combine and summarize results

  has_eval_time <- any(names(res_perms) == ".eval_time")

  join_groups <- c(".metric", ".estimator")
  if (has_eval_time) {
    join_groups <- c(join_groups, ".eval_time")
  }

  # Orient every difference so that a larger value means worse performance
  # after permutation, whatever the direction of the metric.
  res <-
    dplyr::full_join(res_perms, res_bl, by = c(join_groups, "seed")) |>
    dplyr::full_join(info, by = ".metric") |>
    dplyr::mutate(
      # TODO add (log) ratio?
      importance = dplyr::if_else(
        direction == "minimize",
        .estimate - baseline,
        baseline - .estimate
      )
    )

  summarize_groups <- c(".metric", "predictor")
  if (has_eval_time) {
    summarize_groups <- c(summarize_groups, ".eval_time")
  }

  res <-
    res |>
    dplyr::summarize(
      permuted = mean(.estimate, na.rm = TRUE),
      n = sum(!is.na(importance)),
      mean = mean(importance, na.rm = TRUE),
      sd = sd(importance, na.rm = TRUE),
      std_err = sd / sqrt(n),
      # A predictor whose permutation never changes the metric has zero
      # spread; report zero importance instead of dividing by zero.
      importance = dplyr::if_else(sd == 0, 0, mean / std_err),
      .by = c(dplyr::all_of(summarize_groups))
    ) |>
    dplyr::select(-sd, -permuted) |>
    dplyr::arrange(dplyr::desc(importance))

  class(res) <-
    c("importance_perm", paste0(type, "_importance_perm"), class(res))
  res
}

# Compute the metrics for a single seed, optionally with one column permuted.
#
# `column` is the name of the column to shuffle; `NULL` leaves the data intact
# and labels the result `".baseline"`. `seed` is passed to `set.seed()` before
# any random draw. When `size` is not `NULL`, `size` rows are sampled (after
# the shuffle) and only those rows are predicted. The remaining arguments are
# forwarded to `predictions()` and `tune::.estimate_metrics()`.
#
# Returns the output of `tune::.estimate_metrics()` with `predictor` and `seed`
# columns added.
#
# NOTE(review): the row subsample is drawn after the column shuffle, so for the
# same seed the baseline run and a permuted run appear to consume the random
# stream differently and may score different rows -- confirm this is intended.
metric_iter <- function(column = NULL, seed, type, fitted, dat, metrics, size,
                        outcome, eval_time, event_level) {
  info <- tune::metrics_info(metrics)
  set.seed(seed)
  n <- nrow(dat)
  if (!is.null(column)) {
    dat[[column]] <- sample(dat[[column]])
  }
  if (!is.null(size)) {
    ind <- sample.int(n, size)
    dat <- dat[ind, , drop = FALSE]
  }

  # ----------------------------------------------------------------------------
  # Predictions. Use a wrapper because a simple `augment()` works for original
  # predictors but not for derived.
  preds <- predictions(fitted, dat, type, eval_time = eval_time)

  # ----------------------------------------------------------------------------
  # Compute metrics

  res <-
    tune::.estimate_metrics(
      preds,
      metric = metrics,
      param_names = NULL,
      outcome_name = outcome,
      event_level = event_level,
      metrics_info = info
    )

  if (is.null(column)) {
    column <- ".baseline"
  }
  res$predictor <- column
  res$seed <- seed
  res
}

# TODO silently bad results when an in-line transformation is used with
# add_model(x, formula = log(y) ~ x) _or_ fails due to not finding the outcome
# column when add_formula(log(y) ~ .)
is used diff --git a/R/important-package.R b/R/important-package.R new file mode 100644 index 0000000..027e695 --- /dev/null +++ b/R/important-package.R @@ -0,0 +1,31 @@ +#' @keywords internal +"_PACKAGE" + +## usethis namespace: start +#' @importFrom stats sd predict +#' @importFrom hardhat extract_fit_parsnip extract_postprocessor + +#' @importFrom ggplot2 autoplot +#' @export +ggplot2::autoplot + +#' @importFrom generics required_pkgs +#' @export +generics::required_pkgs + +#' @importFrom generics augment +#' @export +generics::augment + +utils::globalVariables( + c(".estimate", ".metric", "baseline", "direction", "importance", "permuted", + "predictor", "ranking", "std_err") +) +## usethis namespace: end +NULL + +## From workflows; a workflow has a post-processor when `x$post$actions` contains a "tailor" entry +# nocov start +has_postprocessor <- function (x) has_postprocessor_tailor(x) +has_postprocessor_tailor <- function(x) "tailor" %in% names(x$post$actions) +# nocov end diff --git a/R/plots.R b/R/plots.R new file mode 100644 index 0000000..2f27e33 --- /dev/null +++ b/R/plots.R @@ -0,0 +1,77 @@ +#' Visualize importance scores +#' @param object A tibble of results from [importance_perm()]. +#' @param metric A character vector or `NULL` for which metric to plot. By +#' default, all metrics will be shown via facets. Possible options are +#' the entries in the `.metric` column of the object. +#' @param eval_time For censored regression models, a vector of time points at +#' which the survival probability is estimated. +#' @param top An integer for how many terms to show. The rankings of predictors +#' are computed across metrics. +#' @param type A character value. The default is `"importance"` which shows the +#' overall signal-to-noise ratio (i.e., mean divided by standard error). +#' Alternatively, `"direction"` shows the mean difference value with standard +#' error bounds. +#' @param ... ## Code of Conduct

Please note that the important project is released with a [Contributor Code of Conduct](https://contributor-covenant.org/version/2/1/CODE_OF_CONDUCT.html). By contributing to this project, you agree to abide by its terms. The rankings of predictors +are computed across metrics.} + +\item{metric}{A character vector or \code{NULL} for which metric to plot. By +default, all metrics will be shown via facets. Possible options are +the entries in \code{.metric} column of the object.} + +\item{eval_time}{For censored regression models, a vector of time points at +which the survival probability is estimated.} + +\item{type}{A character value. The default is \code{"importance"} which shows the +overall signal-to-noise ration (i.e., mean divided by standard error). +Alternatively, \code{"direction"} shows the mean difference value with standard +error bounds.} + +\item{...}{Not used.} +} +\value{ +A \code{ggplot2} object. +} +\description{ +Visualize importance scores +} diff --git a/man/figures/README-derived-plot-1.png b/man/figures/README-derived-plot-1.png new file mode 100644 index 0000000..318348d Binary files /dev/null and b/man/figures/README-derived-plot-1.png differ diff --git a/man/figures/README-original-plot-1.png b/man/figures/README-original-plot-1.png new file mode 100644 index 0000000..d030107 Binary files /dev/null and b/man/figures/README-original-plot-1.png differ diff --git a/man/importance_perm.Rd b/man/importance_perm.Rd new file mode 100644 index 0000000..287a219 --- /dev/null +++ b/man/importance_perm.Rd @@ -0,0 +1,121 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/importance_perm.R +\name{importance_perm} +\alias{importance_perm} +\title{Compute permutation-based predictor importance} +\usage{ +importance_perm( + wflow, + data, + metrics = NULL, + type = "original", + size = 500, + times = 10, + eval_time = NULL, + event_level = "first" +) +} +\arguments{ +\item{wflow}{A fitted 
\code{\link[workflows:workflow]{workflows::workflow()}}.} + +\item{data}{A data frame of the data passed to \code{\link[workflows:fit-workflow]{workflows::fit.workflow()}}, +including the outcome and case weights.} + +\item{metrics}{A \code{\link[yardstick:metric_set]{yardstick::metric_set()}} or \code{NULL}.} + +\item{type}{A character string for which \emph{level} of predictors to compute. +A value of \code{"original"} (default) will return values in the same +representation of \code{data}. Using \code{"derived"} will compute them for any derived +features/predictors, such as dummy indicator columns, etc.} + +\item{size}{How many data points to predictor for each iteration.} + +\item{times}{How many iterations to repeat the calculations.} + +\item{eval_time}{For censored regression models, a vector of time points at +which the survival probability is estimated.} + +\item{event_level}{A single string. Either \code{"first"} or \code{"second"} to specify +which level of \code{truth} to consider as the "event". This argument is only +applicable when \code{estimator = "binary"}.} +} +\value{ +A tibble with extra classes \code{"importance_perm"} and either +"\verb{original_importance_perm"} or "\verb{derived_importance_perm"}. The columns are: +\itemize{ +\item \code{.metric} the name of the performance metric: +\item \code{predictor}: the predictor +\item \code{n}: the number of usable results (should be the same as \code{times}) +\item \code{mean}: the average of the differences in performance. For each metric, +larger values indicate worse performance (i.e., higher importance). +\item \code{std_err}: the standard error of the differences. +\item \code{importance}: the mean divided by the standard error. +For censored regression models, an additional \code{.eval_time} column may also +be included (depending on the metric requested). 
+} +} +\description{ +\code{\link[=importance_perm]{importance_perm()}} computes model-agnostic variable importance scores by +permuting individual predictors (one at a time) and measuring how much worse +model performance becomes. +} +\details{ +The function can compute importance at two different levels. +\itemize{ +\item The "original" predictors are the unaltered columns in the source data set. +For example, for a categorical predictor used with linear regression, the +original predictor is the factor column. +\item "Derived" predictors are the final versions given to the model. For the +categorical predictor example, the derived versions are the binary +indicator variables produced from the factor version. +} + +This can make a difference when pre-processing/feature engineering is used. +This can help us understand \emph{how} a predictor can be important. + +Importance scores are computed for each predictor (at the specified level) +and each performance metric. If no metric is specified, defaults are used: +\itemize{ +\item Classification: \code{\link[yardstick:brier_class]{yardstick::brier_class()}}, \code{\link[yardstick:roc_auc]{yardstick::roc_auc()}}, and +\code{\link[yardstick:accuracy]{yardstick::accuracy()}}. +\item Regression: \code{\link[yardstick:rmse]{yardstick::rmse()}} and \code{\link[yardstick:rsq]{yardstick::rsq()}}. +\item Censored regression: \code{\link[yardstick:brier_survival]{yardstick::brier_survival()}}. +} + +For censored data, importance is computed for each evaluation time (when a +dynamic metric is specified). 
+} +\examples{ +\dontshow{if (!is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +if (!rlang::is_installed(c("modeldata", "recipes", "workflows"))) { + library(modeldata) + library(recipes) + library(workflows) + library(dplyr) + + set.seed(12) + dat_tr <- + sim_logistic(250, ~ .1 + 2 * A - 3 * B + 1 * A *B, corr = .7) |> + dplyr::bind_cols(sim_noise(250, num_vars = 10)) + + rec <- + recipe(class ~ ., data = dat_tr) |> + step_interact(~ A:B) |> + step_normalize(all_numeric_predictors()) |> + step_pca(contains("noise"), num_comp = 5) + + lr_wflow <- workflow(rec, logistic_reg()) + lr_fit <- fit(lr_wflow, dat_tr) + + set.seed(39) + orig_res <- importance_perm(lr_fit, data = dat_tr, type = "original", + size = 100, times = 25) + orig_res + + set.seed(39) + deriv_res <- importance_perm(lr_fit, data = dat_tr, type = "derived", + size = 100, times = 25) + deriv_res +} +\dontshow{\}) # examplesIf} +} diff --git a/man/important-package.Rd b/man/important-package.Rd new file mode 100644 index 0000000..9517aee --- /dev/null +++ b/man/important-package.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/important-package.R +\docType{package} +\name{important-package} +\alias{important} +\alias{important-package} +\title{important: Tools for Supervised Feature Selection} +\description{ +Low-level data and interfaces to choose important predictors for predicting an outcome. 
+} +\author{ +\strong{Maintainer}: Max Kuhn \email{max@posit.co} (\href{https://orcid.org/0000-0003-2402-136X}{ORCID}) + +Other contributors: +\itemize{ + \item Posit Software PBC [copyright holder] +} + +} +\keyword{internal} diff --git a/man/is_cran_check.Rd b/man/is_cran_check.Rd new file mode 100644 index 0000000..dfe0f05 --- /dev/null +++ b/man/is_cran_check.Rd @@ -0,0 +1,12 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/aaa.R +\name{is_cran_check} +\alias{is_cran_check} +\title{Internal functions} +\usage{ +is_cran_check() +} +\description{ +Internal functions +} +\keyword{internal} diff --git a/man/reexports.Rd b/man/reexports.Rd new file mode 100644 index 0000000..dede400 --- /dev/null +++ b/man/reexports.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/important-package.R +\docType{import} +\name{reexports} +\alias{reexports} +\alias{autoplot} +\alias{required_pkgs} +\alias{augment} +\title{Objects exported from other packages} +\keyword{internal} +\description{ +These objects are imported from other packages. Follow the links +below to see their documentation. + +\describe{ + \item{generics}{\code{\link[generics]{augment}}, \code{\link[generics]{required_pkgs}}} + + \item{ggplot2}{\code{\link[ggplot2]{autoplot}}} +}} + diff --git a/tests/spelling.R b/tests/spelling.R new file mode 100644 index 0000000..6713838 --- /dev/null +++ b/tests/spelling.R @@ -0,0 +1,3 @@ +if(requireNamespace('spelling', quietly = TRUE)) + spelling::spell_check_test(vignettes = TRUE, error = FALSE, + skip_on_cran = TRUE) diff --git a/tests/testthat.R b/tests/testthat.R new file mode 100644 index 0000000..6c72e1a --- /dev/null +++ b/tests/testthat.R @@ -0,0 +1,12 @@ +# This file is part of the standard setup for testthat. +# It is recommended that you do not modify it. +# +# Where should you do additional test configuration? 
+# Learn more about the roles of various files in: +# * https://r-pkgs.org/testing-design.html#sec-tests-files-overview +# * https://testthat.r-lib.org/articles/special-files.html + +library(testthat) +library(important) + +test_check("important")