diff --git a/DESCRIPTION b/DESCRIPTION index 4da4cdbd..ba74167d 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: workflows Title: Modeling Workflows -Version: 1.1.3.9000 +Version: 1.1.3.9001 Authors@R: c( person("Davis", "Vaughan", , "davis@posit.co", role = "aut"), person("Simon", "Couch", , "simon.couch@posit.co", role = c("aut", "cre"), @@ -24,7 +24,7 @@ Imports: hardhat (>= 1.2.0), lifecycle (>= 1.0.3), modelenv (>= 0.1.0), - parsnip (>= 1.0.3), + parsnip (>= 1.1.0.9001), rlang (>= 1.0.3), tidyselect (>= 1.2.0), vctrs (>= 0.4.1) @@ -38,6 +38,8 @@ Suggests: recipes (>= 1.0.0), rmarkdown, testthat (>= 3.0.0) +Remotes: + tidymodels/parsnip#955 VignetteBuilder: knitr Config/Needs/website: diff --git a/NAMESPACE b/NAMESPACE index f1bac48f..5b8d8d5f 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -17,6 +17,7 @@ S3method(print,workflow) S3method(tidy,workflow) S3method(tunable,workflow) S3method(tune_args,workflow) +S3method(weight_propensity,workflow) export(.fit_finalize) export(.fit_model) export(.fit_pre) @@ -69,4 +70,5 @@ importFrom(hardhat,extract_recipe) importFrom(hardhat,extract_spec_parsnip) importFrom(lifecycle,deprecated) importFrom(parsnip,fit_xy) +importFrom(parsnip,weight_propensity) importFrom(stats,predict) diff --git a/R/weight_propensity.R b/R/weight_propensity.R new file mode 100644 index 00000000..8e3d11d0 --- /dev/null +++ b/R/weight_propensity.R @@ -0,0 +1,43 @@ +#' Helper for bridging two-stage causal fits +#' +#' @inherit parsnip::weight_propensity.model_fit description +#' +#' @inheritParams parsnip::weight_propensity.model_fit +#' +#' @inherit parsnip::weight_propensity.model_fit return +#' +#' @inherit parsnip::weight_propensity.model_fit references +#' +#' @importFrom parsnip weight_propensity +#' @method weight_propensity workflow +#' @export +weight_propensity.workflow <- function(object, + wt_fn, + .treated = extract_fit_parsnip(object)$lvl[2], + ..., + data) { + if (rlang::is_missing(wt_fn) || !is.function(wt_fn)) { + abort("`wt_fn` must be a function.") + } + + if (rlang::is_missing(data) || !is.data.frame(data)) { + abort("`data` must be the data supplied as the data argument to `fit()`.") + } + + if (!is_trained_workflow(object)) { + abort("`weight_propensity()` is not well-defined for an unfitted workflow.") + } + + outcome_name <- names(object$pre$mold$outcomes) + + preds <- predict(object, data, type = "prob") + preds <- preds[[paste0(".pred_", .treated)]] + + data$.wts <- + hardhat::importance_weights( + wt_fn(preds, data[[outcome_name]], .treated = .treated, ...) + ) + + data +} + diff --git a/man/weight_propensity.workflow.Rd b/man/weight_propensity.workflow.Rd new file mode 100644 index 00000000..f4c0db27 --- /dev/null +++ b/man/weight_propensity.workflow.Rd @@ -0,0 +1,53 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/weight_propensity.R +\name{weight_propensity.workflow} +\alias{weight_propensity.workflow} +\title{Helper for bridging two-stage causal fits} +\usage{ +\method{weight_propensity}{workflow}( + object, + wt_fn, + .treated = extract_fit_parsnip(object)$lvl[2], + ..., + data +) +} +\arguments{ +\item{object}{The object containing the model fit(s) that will generate +predictions used to calculate propensity weights. Currently, either a +\link[parsnip:fit.model_spec]{parsnip model fit}, fitted +\link[workflows:workflow]{workflow}, or +tuning results (\code{?tune::fit_resamples}) object. If a tuning result, the +object must have been generated with the control argument +(\code{?tune::control_resamples}) \code{extract = identity}.} + +\item{wt_fn}{A function used to calculate the propensity weights. The first +argument gives the predicted probability of exposure, the true value for +which is provided in the second argument. See \code{?propensity::wt_ate()} for +an example.} + +\item{.treated}{The level of the exposure corresponding to the treatment, as +a string. Additionally passed as \code{.treated} to \code{wt_fn}.} + +\item{...}{Additional arguments passed to \code{wt_fn}.} + +\item{data}{The data supplied as the \code{data} argument to \code{fit()} the \code{object}. +This argument is only required for the \code{model_fit} and \code{workflow} methods---the +needed data for the \code{tune_results} method lives inside of \code{object}.} +} +\value{ +For \code{model_fit} and fitted \code{workflow} input, a modified version of the data +set supplied in \code{data} that contains a \code{.wts} column with class +\code{importance_weights}. For \code{tune_results} input, a modified version of the +resampling object underlying the tuning results containing a new \code{.wts} column +with propensity values corresponding to each element of the analysis set. +} +\description{ +\code{weight_propensity()} is a helper function to more easily link the +propensity and outcome models in causal workflows. \strong{The main documentation +for this function lives in the tune package at} \code{?tune::weight_propensity}. +} +\references{ +Barrett M & D'Agostino McGowan L (forthcoming). +\emph{Causal Inference in R}. \url{https://www.r-causal.org/} +} diff --git a/tests/testthat/_snaps/weight_propensity.md b/tests/testthat/_snaps/weight_propensity.md new file mode 100644 index 00000000..9f18d07f --- /dev/null +++ b/tests/testthat/_snaps/weight_propensity.md @@ -0,0 +1,50 @@ +# errors informatively with bad input + + Code + weight_propensity(wf, silly_wt_fn, data = two_class_dat) + Condition + Error in `weight_propensity()`: + ! `weight_propensity()` is not well-defined for an unfitted workflow. + +--- + + Code + weight_propensity(wf_fit, data = two_class_dat) + Condition + Error in `weight_propensity()`: + ! `wt_fn` must be a function. + +--- + + Code + weight_propensity(wf_fit, "boop", data = two_class_dat) + Condition + Error in `weight_propensity()`: + ! `wt_fn` must be a function. + +--- + + Code + weight_propensity(wf_fit, function(...) { + -1L + }, data = two_class_dat) + Condition + Error in `hardhat::importance_weights()`: + ! `x` can't contain negative weights. + +--- + + Code + weight_propensity(wf_fit, silly_wt_fn) + Condition + Error in `weight_propensity()`: + ! `data` must be the data supplied as the data argument to `fit()`. + +--- + + Code + weight_propensity(wf_fit, silly_wt_fn, data = "boop") + Condition + Error in `weight_propensity()`: + ! `data` must be the data supplied as the data argument to `fit()`. + diff --git a/tests/testthat/test-weight_propensity.R b/tests/testthat/test-weight_propensity.R new file mode 100644 index 00000000..0f3cedfa --- /dev/null +++ b/tests/testthat/test-weight_propensity.R @@ -0,0 +1,63 @@ +test_that("basic functionality", { + skip_if_not_installed("modeldata") + library(modeldata) + library(parsnip) + + silly_wt_fn <- function(.propensity, .exposure = NULL, ...) { + seq(1, 2, length.out = length(.propensity)) + } + + lr_fit <- fit(workflow(Class ~ A + B, logistic_reg()), two_class_dat) + + lr_res1 <- weight_propensity(lr_fit, silly_wt_fn, data = two_class_dat) + expect_s3_class(lr_res1, "tbl_df") + expect_true(all(names(lr_res1) %in% c(names(two_class_dat), ".wts"))) + expect_equal(lr_res1$.wts, importance_weights(seq(1, 2, length.out = nrow(two_class_dat)))) +}) + +test_that("errors informatively with bad input", { + skip_if_not_installed("modeldata") + library(modeldata) + library(parsnip) + + silly_wt_fn <- function(.propensity, .exposure = NULL, ...) { + seq(1, 2, length.out = length(.propensity)) + } + + # untrained workflow + wf <- workflow(Class ~ A + B, logistic_reg()) + + expect_snapshot( + error = TRUE, + weight_propensity(wf, silly_wt_fn, data = two_class_dat) + ) + + # bad `wt_fn` + wf_fit <- fit(wf, two_class_dat) + + expect_snapshot( + error = TRUE, + weight_propensity(wf_fit, data = two_class_dat) + ) + + expect_snapshot( + error = TRUE, + weight_propensity(wf_fit, "boop", data = two_class_dat) + ) + + expect_snapshot( + error = TRUE, + weight_propensity(wf_fit, function(...) {-1L}, data = two_class_dat) + ) + + # bad `data` + expect_snapshot( + error = TRUE, + weight_propensity(wf_fit, silly_wt_fn) + ) + + expect_snapshot( + error = TRUE, + weight_propensity(wf_fit, silly_wt_fn, data = "boop") + ) +})