diff --git a/.Rbuildignore b/.Rbuildignore new file mode 100644 index 0000000..eefec3f --- /dev/null +++ b/.Rbuildignore @@ -0,0 +1,5 @@ +^LICENSE\.md$ +^CODE_OF_CONDUCT\.md$ +^README\.Rmd$ +^.*\.Rproj$ +^\.Rproj\.user$ diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..08c7087 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,126 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, caste, color, religion, or sexual +identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall + community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of + any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email address, + without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at max@posit.co. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of +actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or permanent +ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.1, available at +. + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder][https://github.com/mozilla/inclusion]. + +For answers to common questions about this code of conduct, see the FAQ at +. Translations are available at . + +[homepage]: https://www.contributor-covenant.org diff --git a/DESCRIPTION b/DESCRIPTION new file mode 100644 index 0000000..51927cd --- /dev/null +++ b/DESCRIPTION @@ -0,0 +1,34 @@ +Package: important +Title: Tools for Supervised Feature Selection +Version: 0.0.0.9000 +Authors@R: c( + person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre"), + comment = c(ORCID = "0000-0003-2402-136X")), + person("Posit Software PBC", role = "cph") + ) +Description: Low-level data and interfaces to choose important predictors + for predicting an outcome. +License: MIT + file LICENSE +Encoding: UTF-8 +Roxygen: list(markdown = TRUE) +RoxygenNote: 7.3.2 +Suggests: + modeldata, + recipes, + spelling, + testthat (>= 3.0.0) +Language: en-US +Imports: + cli, + dplyr, + generics, + ggplot2, + hardhat, + purrr, + rlang, + tibble, + tidyr, + tune, + workflows, + vctrs +Config/testthat/edition: 3 diff --git a/LICENSE b/LICENSE index 888e3e1..6359c28 100644 --- a/LICENSE +++ b/LICENSE @@ -1,21 +1,2 @@ -MIT License - -Copyright (c) 2024 Max Kuhn - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. +YEAR: 2024 +COPYRIGHT HOLDER: Posit Software PBC diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..d6cf5fa --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,21 @@ +# MIT License + +Copyright (c) 2024 Posit Software PBC + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/NAMESPACE b/NAMESPACE new file mode 100644 index 0000000..78e875e --- /dev/null +++ b/NAMESPACE @@ -0,0 +1,15 @@ +# Generated by roxygen2: do not edit by hand + +S3method(autoplot,importance_perm) +export(augment) +export(autoplot) +export(importance_perm) +export(is_cran_check) +export(required_pkgs) +importFrom(generics,augment) +importFrom(generics,required_pkgs) +importFrom(ggplot2,autoplot) +importFrom(hardhat,extract_fit_parsnip) +importFrom(hardhat,extract_postprocessor) +importFrom(stats,predict) +importFrom(stats,sd) diff --git a/R/aaa.R b/R/aaa.R new file mode 100644 index 0000000..6090507 --- /dev/null +++ b/R/aaa.R @@ -0,0 +1,13 @@ +# nocov start +# from tune +#' Internal functions +#' @keywords internal +#' @export +is_cran_check <- function() { + if (identical(Sys.getenv("NOT_CRAN"), "true")) { + FALSE + } else { + Sys.getenv("_R_CHECK_PACKAGE_NAME_", "") != "" + } +} +# nocov end diff --git a/R/extract_data.R b/R/extract_data.R new file mode 100644 index 0000000..004227d --- /dev/null +++ b/R/extract_data.R @@ -0,0 +1,47 @@ + +forge_predictors <- function (new_data, workflow) { + mold <- hardhat::extract_mold(workflow) + forged <- hardhat::forge(new_data, blueprint = mold$blueprint) + forged$predictors +} + + +# TODO case weights +# TODO use original data if not available in workflow +extract_data_original <- function(wflow, data, ...) { + if (!tibble::is_tibble(data)) { + data <- tibble::as_tibble(data) + } + # TODO should we get price or log(price) when log(price) ~ blah is used? + + ptypes <- wflow$pre$mold$blueprint$ptypes + extras <- wflow$pre$mold$blueprint$extra_role_ptypes + for (i in seq_along(extras)) { + ptypes[[2 + i]] <- extras[[i]] + } + ptypes <- purrr::list_cbind(unname(ptypes)) + col_names <- colnames(ptypes) + data <- data[, col_names] + hardhat::scream(data, ptypes) +} + + +extract_data_derived <- function(wflow, data, type = c("predictors", "outcomes"), bind = TRUE) { + type <- rlang::arg_match(type, c("predictors", "outcomes"), multiple = TRUE) + res <- list() + if (any(type == "predictors")) { + res$predictors <- forge_predictors(data, wflow) + } + if (any(type == "outcomes")) { + bp <- wflow |> hardhat::extract_mold() |> purrr::pluck("blueprint") + res$outcomes <- hardhat::forge(data, bp, outcomes = TRUE)$outcomes + } + if (bind) { + # TOD reorder so outcome is first + res <- purrr::list_cbind(unname(res)) + } + if (!tibble::is_tibble(res)) { + res <- tibble::as_tibble(res) + } + res +} diff --git a/R/importance_perm.R b/R/importance_perm.R new file mode 100644 index 0000000..cbf1c92 --- /dev/null +++ b/R/importance_perm.R @@ -0,0 +1,247 @@ +#' Compute permutation-based predictor importance +#' +#' [importance_perm()] computes model-agnostic variable importance scores by +#' permuting individual predictors (one at a time) and measuring how worse +#' model performance becomes. +#' +#' @param wflow A fitted [workflows::workflow()]. +#' @param data A data frame of the data passed to [workflows::fit.workflow()], +#' including the outcome and case weights. +#' @param metrics A [yardstick::metric_set()] or `NULL`. +#' @param eval_time For censored regression models, a vector of time points at +#' which the survival probability is estimated. +#' @param type A character string for which _level_ of predictors to compute. +#' A value of `"original"` (default) will return values in the same +#' representation of `data`. Using `"derived"` will compute them for any derived +#' features/predictors, such as dummy indicator columns, etc. +#' @param size How many data points to predictor for each iteration. +#' @param times How many iterations to repeat the calculations. +#' @param event_level A single string. Either `"first"` or `"second"` to specify +#' which level of `truth` to consider as the "event". This argument is only +#' applicable when `estimator = "binary"`. +#' @details +#' The function can compute importance at two different levels. +#' +#' - The "original" predictors are the unaltered columns in the source data set. +#' For example, for a categorical predictor used with linear regression, the +#' original predictor is the factor column. +#' - "Derived" predictors are the final versions given to the model. For the +#' categorical predictor example, the derived versions are the binary +#' indicator variables produced from the factor version. +#' +#' This can make a difference when pre-processing/feature engineering is used. +#' This can help us understand _how_ a predictor can be important +#' +#' Importance scores are computed for each predictor (at the specified level) +#' and each performance metric. If no metric is specified, defaults are used: +#' +#' - Classification: [yardstick::brier_class()], [yardstick::roc_auc()], and +#' [yardstick::accuracy()]. +#' - Regression: [yardstick::rmse()] and [yardstick::rsq()]. +#' - Censored regression: [yardstick::brier_survival()] +#' +#' For censored data, importance is computed for each evaluation time (when a +#' dynamic metric is specified). +#' @return A tibble with extra classes `"importance_perm"` and either +#' "`original_importance_perm"` or "`derived_importance_perm"`. The columns are: +#' - `.metric` the name of the performance metric: +#' - `predictor`: the predictor +#' - `n`: the number of usable results (should be the same as `times`) +#' - `mean`: the average of the differences in performance. For each metric, +#' larger values indicate worse performance (i.e., higher importance). +#' - `std_err`: the standard error of the differences. +#' - `importance`: the mean divided by the standard error. +#' For censored regression models, an additional `.eval_time` column may also +#' be included (depending on the metric requested). +#' @examplesIf !is_cran_check() +#' if (!rlang::is_installed(c("modeldata", "recipes", "workflows"))) { +#' library(modeldata) +#' library(recipes) +#' library(workflows) +#' library(dplyr) +#' +#' set.seed(12) +#' dat_tr <- +#' sim_logistic(250, ~ .1 + 2 * A - 3 * B + 1 * A *B, corr = .7) |> +#' dplyr::bind_cols(sim_noise(250, num_vars = 10)) +#' +#' rec <- +#' recipe(class ~ ., data = dat_tr) |> +#' step_interact(~ A:B) |> +#' step_normalize(all_numeric_predictors()) |> +#' step_pca(contains("noise"), num_comp = 5) +#' +#' lr_wflow <- workflow(rec, logistic_reg()) +#' lr_fit <- fit(lr_wflow, dat_tr) +#' +#' set.seed(39) +#' orig_res <- importance_perm(lr_fit, data = dat_tr, type = "original", +#' size = 100, times = 25) +#' orig_res +#' +#' set.seed(39) +#' deriv_res <- importance_perm(lr_fit, data = dat_tr, type = "derived", +#' size = 100, times = 25) +#' deriv_res +#' } +#' @export +importance_perm <- function(wflow, data, metrics = NULL, type = "original", size = 500, + times = 10, eval_time = NULL, event_level = "first") { + if (!workflows::is_trained_workflow(wflow)) { + cli::cli_abort("The workflow in {.arg wflow} should be trained.") + } + type <- rlang::arg_match(type, c("original", "derived")) + metrics <- tune::check_metrics_arg(metrics, wflow) + pkgs <- required_pkgs(wflow) + rlang::check_installed(pkgs) + + # ------------------------------------------------------------------------------ + # Pull appropriate source data + # TODO extract and use case weights + + if (type == "original") { + extracted_data <- extract_data_original(wflow, data) + } else { + extracted_data <- extract_data_derived(wflow, data) + } + extracted_data_nms <- colnames(extracted_data) + outcome_nm <- tune::outcome_names(wflow) + extracted_data_nms <- extracted_data_nms[extracted_data_nms != outcome_nm] + n <- nrow(extracted_data) + size <- min(floor(n * 0.8) , size) + + # ------------------------------------------------------------------------------ + # Prepare for permutations. A large `combos` data frame is created to optimize + # how well parallel processing speeds-up computations + + info <- tune::metrics_info(metrics) + seed_vals <- sample.int(1e6, times) + combos <- tidyr::crossing(seed = seed_vals, colunm = extracted_data_nms) + + # ------------------------------------------------------------------------------ + # Generate all permutations + + rlang::local_options(doFuture.rng.onMisuse = "ignore") + res_perms <- purrr::map2( + combos$colunm, + combos$seed, + ~ metric_iter( + column = .x, + .y, + type = type, + fitted = wflow, + dat = extracted_data, + metrics = metrics, + size = size, + outcome = outcome_nm, + eval_time = eval_time, + event_level = event_level + ) + ) |> + purrr::list_rbind() + + # ------------------------------------------------------------------------------ + # Get un-permuted performance statistics (per seed value) + + res_bl <- purrr::map( + seed_vals, + ~ metric_iter( + column = NULL, + .x, + type = type, + fitted = wflow, + dat = extracted_data, + metrics = metrics, + size = size, + outcome = outcome_nm, + eval_time = eval_time, + event_level = event_level + ) + ) |> + purrr::list_rbind() |> + dplyr::rename(baseline = .estimate) |> + dplyr::select(-predictor) + + # ------------------------------------------------------------------------------ + # Combine and summarize results + + has_eval_time <- any(names(res_perms) == ".eval_time") + + join_groups <- c(".metric", ".estimator") + if (has_eval_time) { + join_groups <- c(join_groups, ".eval_time") + } + + res <- + dplyr::full_join(res_perms, res_bl, by = c(join_groups, "seed")) |> + dplyr::full_join(info, by = ".metric") |> + dplyr::mutate( + # TODO add (log) ratio? + importance = dplyr::if_else( + direction == "minimize", + .estimate - baseline, + baseline - .estimate)) + + summarize_groups <- c(".metric", "predictor") + if (has_eval_time) { + summarize_groups <- c(summarize_groups, ".eval_time") + } + + res <- + res |> + dplyr::summarize( + permuted = mean(.estimate, na.rm = TRUE), + n = sum(!is.na(importance)), + mean = mean(importance, na.rm = TRUE), + sd = sd(importance, na.rm = TRUE), + std_err = sd / sqrt(n), + importance = dplyr::if_else(sd == 0, 0, mean / std_err), + .by = c(dplyr::all_of(summarize_groups)) + ) |> + dplyr::select(-sd, -permuted) |> + dplyr::arrange(dplyr::desc(importance)) + class(res) <- c("importance_perm", paste0(type, "_importance_perm"), class(res)) + res +} + +metric_iter <- function(column = NULL, seed, type, fitted, dat, metrics, size, + outcome, eval_time, event_level) { + info <- tune::metrics_info(metrics) + set.seed(seed) + n <- nrow(dat) + if (!is.null(column)) { + dat[[column]] <- sample(dat[[column]]) + } + if (!is.null(size)) { + ind <- sample.int(n, size) + dat <- dat[ind,] + } + + # ------------------------------------------------------------------------------ + # Predictions. Use a wrapper because a simple `augment()` works for original + # predictors but not for derived. + preds <- predictions(fitted, dat, type, eval_time = eval_time) + + # ------------------------------------------------------------------------------ + # Compute metrics + + res <- + tune::.estimate_metrics( + preds, + metric = metrics, + param_names = NULL, + outcome_name = outcome, + event_level = event_level, + metrics_info = info) + + if (is.null(column)) { + column <- ".baseline" + } + res$predictor <- column + res$seed <- seed + res +} + +# TODO silently bad results when an in-line transformation is used with +# add_model(x formula = log(y) ~ x) _or_ fails due to not findnig the outcome +# column when add_formula(log(y) ~ .) is used diff --git a/R/important-package.R b/R/important-package.R new file mode 100644 index 0000000..027e695 --- /dev/null +++ b/R/important-package.R @@ -0,0 +1,31 @@ +#' @keywords internal +"_PACKAGE" + +## usethis namespace: start +#' @importFrom stats sd predict +#' @importFrom hardhat extract_fit_parsnip extract_postprocessor + +#' @importFrom ggplot2 autoplot +#' @export +ggplot2::autoplot + +#' @importFrom generics required_pkgs +#' @export +generics::required_pkgs + +#' @importFrom generics augment +#' @export +generics::augment + +utils::globalVariables( + c(".estimate", ".metric", "baseline", "direction", "importance", "permuted", + "predictor", "ranking", "std_err") +) +## usethis namespace: end +NULL + +## From workflows +# nocov start +has_postprocessor <- function (x) has_postprocessor_tailor(x) +has_postprocessor_tailor <- function(x) "tailor" %in% names(x$post$actions) +# nocov end diff --git a/R/plots.R b/R/plots.R new file mode 100644 index 0000000..2f27e33 --- /dev/null +++ b/R/plots.R @@ -0,0 +1,77 @@ +#' Visualize importance scores +#' @param object A tibble of results from [importance_perm()]. +#' @param metric A character vector or `NULL` for which metric to plot. By +#' default, all metrics will be shown via facets. Possible options are +#' the entries in `.metric` column of the object. +#' @param eval_time For censored regression models, a vector of time points at +#' which the survival probability is estimated. +#' @param top An integer for how many terms to show. The rankings of predictors +#' are computed across metrics. +#' @param type A character value. The default is `"importance"` which shows the +#' overall signal-to-noise ration (i.e., mean divided by standard error). +#' Alternatively, `"direction"` shows the mean difference value with standard +#' error bounds. +#' @param ... Not used. +#' @return A `ggplot2` object. +#' @export +autoplot.importance_perm <- function(object, top = Inf, metric = NULL, + eval_time = NULL, type = "importance", ...) { + type <- rlang::arg_match(type, values = c("importance", "difference")) + if (!is.null(metric)) { + object <- object[object$.metric %in% metric,] + if (nrow(object) == 0) { + cli::cli_abort("No data left when filtering over {.val {metric}}.") + } + } + overall_rank <- + object |> + dplyr::mutate(ranking = rank(-importance)) |> + dplyr::summarize( + ranking = mean(ranking), + .by = c(predictor) + ) |> + dplyr::arrange(ranking) + + num_pred <- vctrs::vec_unique_count(object$predictor) + if (top < num_pred) { + overall_rank <- overall_rank[overall_rank$ranking <= top, ] + object <- object[object$predictor %in% unique(overall_rank$predictor), ] + } + object$predictor <- factor(object$predictor, levels = rev(overall_rank$predictor)) + + p <- + ggplot2::ggplot(object, ggplot2::aes(y = predictor)) + + ggplot2::geom_vline(xintercept = 0, col = "red", lty = 2) + + ggplot2::labs(y = NULL, x = "Permutation Importance Score") + if (length(unique(object$.metric)) > 1) { + p <- p + ggplot2::facet_wrap(~ .metric) + } + if (type == "importance") { + p <- p + ggplot2::geom_point(ggplot2::aes(x = importance)) + } else if (type == "difference") { + # TODO add alpha level + p <- p + + ggplot2::geom_point(ggplot2::aes(x = mean)) + + ggplot2::geom_errorbar( + ggplot2::aes(xmin = mean - 1.96 * std_err, xmax = mean + 1.96 * std_err), + width = 0) + } + p +} + +predictions <- function(wflow, new_data, type, eval_time) { + if (type == "original") { + preds <- augment(wflow, new_data = new_data, eval_time = eval_time) + } else { + preds <- + wflow |> + extract_fit_parsnip() |> + augment(new_data = new_data, eval_time = eval_time) + use_post <- has_postprocessor(wflow) + if (use_post) { + post_proc <- extract_postprocessor(wflow) + preds <- predict(post_proc, preds) + } + } + preds +} diff --git a/README.Rmd b/README.Rmd new file mode 100644 index 0000000..dfaef34 --- /dev/null +++ b/README.Rmd @@ -0,0 +1,116 @@ +--- +output: github_document +--- + + + +```{r, include = FALSE} +knitr::opts_chunk$set( + collapse = TRUE, + comment = "#>", + fig.path = "man/figures/README-", + out.width = "100%" +) +``` + +# important + + + + +The goal of important is to ... + +## Installation + +You can install the development version of important from [GitHub](https://github.com/) with: + +``` r +# install.packages("devtools") +devtools::install_github("topepo/important") +``` + +## Example + +```{r} +#| label: startup-sshh +#| include: false +library(tidymodels) +library(important) +theme_set(theme_bw()) +``` +```{r} +#| label: startup +#| include: false +library(tidymodels) +library(important) +``` + + +```{r} +#| label: chi-data +data(deliveries, package = "modeldata") + +set.seed(991) +delivery_split <- initial_validation_split(deliveries, prop = c(0.6, 0.2), strata = time_to_delivery) +delivery_train <- training(delivery_split) +``` + + +```{r} +#| label: model +delivery_rec <- + recipe(time_to_delivery ~ ., data = delivery_train) %>% + step_dummy(all_factor_predictors()) %>% + step_zv(all_predictors()) %>% + step_spline_natural(hour, distance, deg_free = 10) %>% + step_interact(~ starts_with("hour_"):starts_with("day_")) + +lm_wflow <- workflow(delivery_rec, linear_reg()) +lm_fit <- fit(lm_wflow, delivery_train) +``` + + +```{r} +#| label: derived-importance +set.seed(382) +lm_deriv_imp <- + importance_perm( + lm_fit, + data = delivery_train, + metrics = metric_set(mae, rsq), + times = 50, + type = "derived" + ) +lm_deriv_imp +``` + +```{r} +#| label: derived-plot +#| fig.height: 8 + +autoplot(lm_deriv_imp, top = 100) +``` + +```{r} +#| label: original-importance +set.seed(382) +lm_orig_imp <- + importance_perm( + lm_fit, + data = delivery_train, + metrics = metric_set(mae, rsq), + times = 50 + ) +lm_orig_imp +``` + +```{r} +#| label: original-plot + +autoplot(lm_orig_imp) +``` + + +## Code of Conduct + +Please note that the important project is released with a [Contributor Code of Conduct](https://contributor-covenant.org/version/2/1/CODE_OF_CONDUCT.html). By contributing to this project, you agree to abide by its terms. diff --git a/README.md b/README.md index ed56f59..aa25a47 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,113 @@ + + + # important -Tools for Supervised Feature Selection + + + + +The goal of important is to … + +## Installation + +You can install the development version of important from +[GitHub](https://github.com/) with: + +``` r +# install.packages("devtools") +devtools::install_github("topepo/important") +``` + +## Example + +``` r +data(deliveries, package = "modeldata") + +set.seed(991) +delivery_split <- initial_validation_split(deliveries, prop = c(0.6, 0.2), strata = time_to_delivery) +delivery_train <- training(delivery_split) +``` + +``` r +delivery_rec <- + recipe(time_to_delivery ~ ., data = delivery_train) %>% + step_dummy(all_factor_predictors()) %>% + step_zv(all_predictors()) %>% + step_spline_natural(hour, distance, deg_free = 10) %>% + step_interact(~ starts_with("hour_"):starts_with("day_")) + +lm_wflow <- workflow(delivery_rec, linear_reg()) +lm_fit <- fit(lm_wflow, delivery_train) +``` + +``` r +set.seed(382) +lm_deriv_imp <- + importance_perm( + lm_fit, + data = delivery_train, + metrics = metric_set(mae, rsq), + times = 50, + type = "derived" + ) +lm_deriv_imp +#> # A tibble: 226 × 6 +#> .metric predictor n mean std_err importance +#> +#> 1 rsq distance_10 50 0.531 0.00642 82.7 +#> 2 mae distance_10 50 2.24 0.0308 72.8 +#> 3 mae day_Sat 50 1.09 0.0194 56.3 +#> 4 mae day_Fri 50 0.904 0.0171 53.0 +#> 5 rsq day_Sat 50 0.120 0.00274 43.8 +#> 6 mae distance_09 50 0.783 0.0191 41.0 +#> 7 mae day_Thu 50 0.633 0.0165 38.3 +#> 8 rsq day_Fri 50 0.101 0.00265 37.9 +#> 9 rsq hour_07_x_day_Sat 50 0.140 0.00380 36.8 +#> 10 rsq hour_06_x_day_Sat 50 0.143 0.00403 35.5 +#> # ℹ 216 more rows +``` + +``` r +autoplot(lm_deriv_imp, top = 100) +``` + + + +``` r +set.seed(382) +lm_orig_imp <- + importance_perm( + lm_fit, + data = delivery_train, + metrics = metric_set(mae, rsq), + times = 50 + ) +lm_orig_imp +#> # A tibble: 60 × 6 +#> .metric predictor n mean std_err importance +#> +#> 1 rsq hour 50 0.780 0.00423 184. +#> 2 mae hour 50 4.07 0.0332 123. +#> 3 mae day 50 1.91 0.0250 76.4 +#> 4 mae distance 50 1.49 0.0209 71.2 +#> 5 rsq distance 50 0.289 0.00450 64.3 +#> 6 rsq day 50 0.325 0.00516 63.0 +#> 7 mae item_24 50 0.0587 0.0149 3.93 +#> 8 mae item_03 50 0.0446 0.0146 3.06 +#> 9 mae item_10 50 0.0457 0.0152 3.00 +#> 10 mae item_02 50 0.0398 0.0146 2.72 +#> # ℹ 50 more rows +``` + +``` r +autoplot(lm_orig_imp) +``` + + + +## Code of Conduct + +Please note that the important project is released with a [Contributor +Code of +Conduct](https://contributor-covenant.org/version/2/1/CODE_OF_CONDUCT.html). +By contributing to this project, you agree to abide by its terms. diff --git a/important.Rproj b/important.Rproj new file mode 100644 index 0000000..497f8bf --- /dev/null +++ b/important.Rproj @@ -0,0 +1,20 @@ +Version: 1.0 + +RestoreWorkspace: Default +SaveWorkspace: Default +AlwaysSaveHistory: Default + +EnableCodeIndexing: Yes +UseSpacesForTab: Yes +NumSpacesForTab: 2 +Encoding: UTF-8 + +RnwWeave: Sweave +LaTeX: pdfLaTeX + +AutoAppendNewline: Yes +StripTrailingWhitespace: Yes + +BuildType: Package +PackageUseDevtools: Yes +PackageInstallArgs: --no-multiarch --with-keep.source diff --git a/inst/WORDLIST b/inst/WORDLIST new file mode 100644 index 0000000..50be5b1 --- /dev/null +++ b/inst/WORDLIST @@ -0,0 +1,4 @@ +ORCID +PBC +pre +tibble diff --git a/man/autoplot.importance_perm.Rd b/man/autoplot.importance_perm.Rd new file mode 100644 index 0000000..e8c34a5 --- /dev/null +++ b/man/autoplot.importance_perm.Rd @@ -0,0 +1,41 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/plots.R +\name{autoplot.importance_perm} +\alias{autoplot.importance_perm} +\title{Visualize importance scores} +\usage{ +\method{autoplot}{importance_perm}( + object, + top = Inf, + metric = NULL, + eval_time = NULL, + type = "importance", + ... +) +} +\arguments{ +\item{object}{A tibble of results from \code{\link[=importance_perm]{importance_perm()}}.} + +\item{top}{An integer for how many terms to show. The rankings of predictors +are computed across metrics.} + +\item{metric}{A character vector or \code{NULL} for which metric to plot. By +default, all metrics will be shown via facets. Possible options are +the entries in \code{.metric} column of the object.} + +\item{eval_time}{For censored regression models, a vector of time points at +which the survival probability is estimated.} + +\item{type}{A character value. The default is \code{"importance"} which shows the +overall signal-to-noise ration (i.e., mean divided by standard error). +Alternatively, \code{"direction"} shows the mean difference value with standard +error bounds.} + +\item{...}{Not used.} +} +\value{ +A \code{ggplot2} object. +} +\description{ +Visualize importance scores +} diff --git a/man/figures/README-derived-plot-1.png b/man/figures/README-derived-plot-1.png new file mode 100644 index 0000000..318348d Binary files /dev/null and b/man/figures/README-derived-plot-1.png differ diff --git a/man/figures/README-original-plot-1.png b/man/figures/README-original-plot-1.png new file mode 100644 index 0000000..d030107 Binary files /dev/null and b/man/figures/README-original-plot-1.png differ diff --git a/man/importance_perm.Rd b/man/importance_perm.Rd new file mode 100644 index 0000000..287a219 --- /dev/null +++ b/man/importance_perm.Rd @@ -0,0 +1,121 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/importance_perm.R +\name{importance_perm} +\alias{importance_perm} +\title{Compute permutation-based predictor importance} +\usage{ +importance_perm( + wflow, + data, + metrics = NULL, + type = "original", + size = 500, + times = 10, + eval_time = NULL, + event_level = "first" +) +} +\arguments{ +\item{wflow}{A fitted \code{\link[workflows:workflow]{workflows::workflow()}}.} + +\item{data}{A data frame of the data passed to \code{\link[workflows:fit-workflow]{workflows::fit.workflow()}}, +including the outcome and case weights.} + +\item{metrics}{A \code{\link[yardstick:metric_set]{yardstick::metric_set()}} or \code{NULL}.} + +\item{type}{A character string for which \emph{level} of predictors to compute. +A value of \code{"original"} (default) will return values in the same +representation of \code{data}. Using \code{"derived"} will compute them for any derived +features/predictors, such as dummy indicator columns, etc.} + +\item{size}{How many data points to predictor for each iteration.} + +\item{times}{How many iterations to repeat the calculations.} + +\item{eval_time}{For censored regression models, a vector of time points at +which the survival probability is estimated.} + +\item{event_level}{A single string. Either \code{"first"} or \code{"second"} to specify +which level of \code{truth} to consider as the "event". This argument is only +applicable when \code{estimator = "binary"}.} +} +\value{ +A tibble with extra classes \code{"importance_perm"} and either +"\verb{original_importance_perm"} or "\verb{derived_importance_perm"}. The columns are: +\itemize{ +\item \code{.metric} the name of the performance metric: +\item \code{predictor}: the predictor +\item \code{n}: the number of usable results (should be the same as \code{times}) +\item \code{mean}: the average of the differences in performance. For each metric, +larger values indicate worse performance (i.e., higher importance). +\item \code{std_err}: the standard error of the differences. +\item \code{importance}: the mean divided by the standard error. +For censored regression models, an additional \code{.eval_time} column may also +be included (depending on the metric requested). +} +} +\description{ +\code{\link[=importance_perm]{importance_perm()}} computes model-agnostic variable importance scores by +permuting individual predictors (one at a time) and measuring how worse +model performance becomes. +} +\details{ +The function can compute importance at two different levels. +\itemize{ +\item The "original" predictors are the unaltered columns in the source data set. +For example, for a categorical predictor used with linear regression, the +original predictor is the factor column. +\item "Derived" predictors are the final versions given to the model. For the +categorical predictor example, the derived versions are the binary +indicator variables produced from the factor version. +} + +This can make a difference when pre-processing/feature engineering is used. +This can help us understand \emph{how} a predictor can be important + +Importance scores are computed for each predictor (at the specified level) +and each performance metric. If no metric is specified, defaults are used: +\itemize{ +\item Classification: \code{\link[yardstick:brier_class]{yardstick::brier_class()}}, \code{\link[yardstick:roc_auc]{yardstick::roc_auc()}}, and +\code{\link[yardstick:accuracy]{yardstick::accuracy()}}. +\item Regression: \code{\link[yardstick:rmse]{yardstick::rmse()}} and \code{\link[yardstick:rsq]{yardstick::rsq()}}. +\item Censored regression: \code{\link[yardstick:brier_survival]{yardstick::brier_survival()}} +} + +For censored data, importance is computed for each evaluation time (when a +dynamic metric is specified). +} +\examples{ +\dontshow{if (!is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +if (!rlang::is_installed(c("modeldata", "recipes", "workflows"))) { + library(modeldata) + library(recipes) + library(workflows) + library(dplyr) + + set.seed(12) + dat_tr <- + sim_logistic(250, ~ .1 + 2 * A - 3 * B + 1 * A *B, corr = .7) |> + dplyr::bind_cols(sim_noise(250, num_vars = 10)) + + rec <- + recipe(class ~ ., data = dat_tr) |> + step_interact(~ A:B) |> + step_normalize(all_numeric_predictors()) |> + step_pca(contains("noise"), num_comp = 5) + + lr_wflow <- workflow(rec, logistic_reg()) + lr_fit <- fit(lr_wflow, dat_tr) + + set.seed(39) + orig_res <- importance_perm(lr_fit, data = dat_tr, type = "original", + size = 100, times = 25) + orig_res + + set.seed(39) + deriv_res <- importance_perm(lr_fit, data = dat_tr, type = "derived", + size = 100, times = 25) + deriv_res +} +\dontshow{\}) # examplesIf} +} diff --git a/man/important-package.Rd b/man/important-package.Rd new file mode 100644 index 0000000..9517aee --- /dev/null +++ b/man/important-package.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/important-package.R +\docType{package} +\name{important-package} +\alias{important} +\alias{important-package} +\title{important: Tools for Supervised Feature Selection} +\description{ +Low-level data and interfaces to choose important predictors for predicting an outcome. +} +\author{ +\strong{Maintainer}: Max Kuhn \email{max@posit.co} (\href{https://orcid.org/0000-0003-2402-136X}{ORCID}) + +Other contributors: +\itemize{ + \item Posit Software PBC [copyright holder] +} + +} +\keyword{internal} diff --git a/man/is_cran_check.Rd b/man/is_cran_check.Rd new file mode 100644 index 0000000..dfe0f05 --- /dev/null +++ b/man/is_cran_check.Rd @@ -0,0 +1,12 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/aaa.R +\name{is_cran_check} +\alias{is_cran_check} +\title{Internal functions} +\usage{ +is_cran_check() +} +\description{ +Internal functions +} +\keyword{internal} diff --git a/man/reexports.Rd b/man/reexports.Rd new file mode 100644 index 0000000..dede400 --- /dev/null +++ b/man/reexports.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/important-package.R +\docType{import} +\name{reexports} +\alias{reexports} +\alias{autoplot} +\alias{required_pkgs} +\alias{augment} +\title{Objects exported from other packages} +\keyword{internal} +\description{ +These objects are imported from other packages. Follow the links +below to see their documentation. + +\describe{ + \item{generics}{\code{\link[generics]{augment}}, \code{\link[generics]{required_pkgs}}} + + \item{ggplot2}{\code{\link[ggplot2]{autoplot}}} +}} + diff --git a/tests/spelling.R b/tests/spelling.R new file mode 100644 index 0000000..6713838 --- /dev/null +++ b/tests/spelling.R @@ -0,0 +1,3 @@ +if(requireNamespace('spelling', quietly = TRUE)) + spelling::spell_check_test(vignettes = TRUE, error = FALSE, + skip_on_cran = TRUE) diff --git a/tests/testthat.R b/tests/testthat.R new file mode 100644 index 0000000..6c72e1a --- /dev/null +++ b/tests/testthat.R @@ -0,0 +1,12 @@ +# This file is part of the standard setup for testthat. +# It is recommended that you do not modify it. +# +# Where should you do additional test configuration? +# Learn more about the roles of various files in: +# * https://r-pkgs.org/testing-design.html#sec-tests-files-overview +# * https://testthat.r-lib.org/articles/special-files.html + +library(testthat) +library(important) + +test_check("important")