Skip to content

Commit

Permalink
a bunch of new stuff for working with continuous treatments
Browse files Browse the repository at this point in the history
  • Loading branch information
bcallaway11 committed Jan 30, 2025
1 parent 0bd0233 commit 213c648
Show file tree
Hide file tree
Showing 17 changed files with 677 additions and 97 deletions.
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,21 @@
S3method(print,group_time_att)
S3method(print,pte_results)
S3method(print,summary.pte_results)
S3method(summary,aggte_obj)
S3method(summary,group_time_att)
S3method(summary,pte_emp_boot)
S3method(summary,pte_results)
export(aggte_obj)
export(attgt_if)
export(attgt_noif)
export(attgt_pte_aggregations)
export(compute.pte)
export(compute.pte2)
export(crit_val_checks)
export(did_attgt)
export(dose_obj)
export(ggpte)
export(ggpte_cont)
export(group_time_att)
export(gt_data_frame)
export(keep_all_pretreatment_subset)
Expand Down
59 changes: 52 additions & 7 deletions R/ggpte.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,57 @@
ggpte <- function(pte_results) {
plot_df <- summary(pte_results)$event_study
colnames(plot_df) <- c("e", "att", "se", "cil", "ciu")
plot_df$post <- as.factor(1*(plot_df$e >= 0))
ggplot(plot_df, aes(x=e, y=att)) +
geom_line(aes(color=post)) +
geom_point(aes(color=post)) +
geom_line(aes(y=ciu), linetype="dashed", alpha=0.5) +
geom_line(aes(y=cil), linetype="dashed", alpha=0.5) +
plot_df$post <- as.factor(1 * (plot_df$e >= 0))
ggplot(plot_df, aes(x = e, y = att)) +
geom_line(aes(color = post)) +
geom_point(aes(color = post)) +
geom_line(aes(y = ciu), linetype = "dashed", alpha = 0.5) +
geom_line(aes(y = cil), linetype = "dashed", alpha = 0.5) +
theme_bw() +
theme(legend.position="bottom")
theme(legend.position = "bottom")
}


#' @title ggpte_cont
#'
#' @description a function for plotting results in applications with a continuous treatment
#'
#' @param dose_obj a `dose_obj` that holds results with a continuous treatment
#' @param type whether to plot ATT(d) or ACRT(d), defaults to `att` for
#' plotting ATT(d). For ACRT(d), use "acrt"
#'
#' @export
ggpte_cont <- function(dose_obj, type = "att") {
dose <- dose_obj$dose
if (type == "acrt") {
acrt.d <- dose_obj$acrt.d
acrt.d_se <- dose_obj$acrt.d_se
acrt.d_crit.val <- dose_obj$acrt.d_crit.val
plot_df <- cbind.data.frame(dose, acrt.d, acrt.d_se, acrt.d_crit.val)
ggplot(plot_df, aes(x = dose, y = acrt.d)) +
geom_line(size = 2) +
geom_ribbon(
aes(
ymin = acrt.d - acrt.d_crit.val * acrt.d_se,
ymax = acrt.d + acrt.d_crit.val * acrt.d_se
),
fill = "lightgray", alpha = 0.5
) +
theme_bw()
} else { # att(d) plot
att.d <- dose_obj$att.d
att.d_se <- dose_obj$att.d_se
att.d_crit.val <- dose_obj$att.d_crit.val
plot_df <- cbind.data.frame(dose, att.d, att.d_se, att.d_crit.val)
ggplot(plot_df, aes(x = dose, y = att.d)) +
geom_line(size = 2) +
geom_ribbon(
aes(
ymin = att.d - att.d_crit.val * att.d_se,
ymax = att.d + att.d_crit.val * att.d_se
),
fill = "lightgray", alpha = 0.5
) +
theme_bw()
}
}
162 changes: 159 additions & 3 deletions R/process_dose_gt.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,165 @@
process_dose_gt <- function(gt_results, ptep, ...) {
browser()

# make the call to att, to get same format of results
att_gt <- process_att_gt(gt_results, ptep)
o_weights <- overall_weights(att_gt, ...)

1 + 1
# main dose-specific results are in extra_gt_returns
all_extra_gt_returns <- att_gt$extra_gt_returns
groups <- unlist(BMisc::getListElement(all_extra_gt_returns, "group"))
time.periods <- unlist(BMisc::getListElement(all_extra_gt_returns, "time.period"))

# check that order of groups and time periods matches
if (!all(cbind(groups, time.periods) == o_weights[, c("group", "time.period")])) {
stop("in processing dose results, mismatch between order of groups and time periods")
}

inner_extra_gt_returns <- BMisc::getListElement(all_extra_gt_returns, "extra_gt_returns")
att.d_gt <- BMisc::getListElement(inner_extra_gt_returns, "att.d")
acrt.d_gt <- BMisc::getListElement(inner_extra_gt_returns, "acrt.d")
att.overall_gt <- unlist(BMisc::getListElement(inner_extra_gt_returns, "att.overall"))
acrt.overall_gt <- unlist(BMisc::getListElement(inner_extra_gt_returns, "acrt.overall"))
bet_gt <- BMisc::getListElement(inner_extra_gt_returns, "bet")
bread_gt <- BMisc::getListElement(inner_extra_gt_returns, "bread")
Xe_gt <- BMisc::getListElement(inner_extra_gt_returns, "Xe")

# point estimates of ATT(d) and ACRT(d)
att.d <- weighted_combine_list(att.d_gt, o_weights$overall_weight)
acrt.d <- weighted_combine_list(acrt.d_gt, o_weights$overall_weight)

# values of the dose
dvals <- ptep$dvals
degree <- ptep$degree
knots <- ptep$knots
bs_grid <- splines2::bSpline(dvals, degree = degree, knots = knots)
bs_grid <- cbind(1, bs_grid) # add intercept
bs_deriv <- splines2::dbs(dvals, degree = degree, knots = knots)
bs_deriv <- cbind(0, bs_deriv) # intercept doesn't matter here, just placeholder to get dimensions right

# since we are picking dvals over a grid, the only randomness comes from
# estimating the \beta's
n1_vec <- sapply(Xe_gt, nrow)
acrt_gt_inffunc_mat <- gt_results$inffunc
n <- nrow(acrt_gt_inffunc_mat)
keep_mat <- acrt_gt_inffunc_mat != 0
if (!all(colSums(keep_mat) == n1_vec)) {
stop("something off with overall influence function")
}

att.d_gt_inffunc <- lapply(
1:length(Xe_gt),
function(i) {
out_inffunc <- matrix(data = 0, nrow = n, ncol = length(dvals))
this_inffunc <- (Xe_gt[[i]] %*% bread_gt[[i]] %*% t(bs_grid))
out_inffunc[keep_mat[, i], ] <- (n / n1_vec[i]) * this_inffunc
out_inffunc
}
)

att.d_inffunc <- weighted_combine_list(att.d_gt_inffunc, o_weights$overall_weight)
biters <- ptep$biters
alp <- ptep$alp
cband <- ptep$cband
boot_res <- mboot2(att.d_inffunc, biters = biters, alp = alp)
att.d_se <- boot_res$boot_se
if (cband) {
att.d_crit.val <- boot_res$crit_val
att.d_crit.val <- crit_val_checks(att.d_crit.val, alp)
} else {
att.d_crit.val <- qnorm(1 - alp / 2)
}

# influence function for acrt

# acrt influence function - same as for att.d except use derivative of basis functions
acrt.d_gt_inffunc <- lapply(
1:length(Xe_gt),
function(i) {
out_inffunc <- matrix(data = 0, nrow = n, ncol = length(dvals))
this_inffunc <- (Xe_gt[[i]] %*% bread_gt[[i]] %*% t(bs_deriv))
out_inffunc[keep_mat[, i], ] <- (n / n1_vec[i]) * this_inffunc
out_inffunc
}
)
acrt.d_inffunc <- weighted_combine_list(acrt.d_gt_inffunc, o_weights$overall_weight)
acrt_boot_res <- mboot2(acrt.d_inffunc, biters = biters, alp = alp)
acrt.d_se <- acrt_boot_res$boot_se
if (cband) {
acrt.d_crit.val <- acrt_boot_res$crit_val
acrt.d_crit.val <- crit_val_checks(acrt.d_crit.val, alp)
} else {
acrt.d_crit.val <- qnorm(1 - alp / 2)
}

# placeholder for tracking `call`
call <- NULL

dose_obj(
dose = dvals,
att.d = att.d,
att.d_se = att.d_se,
att.d_crit.val = att.d_crit.val,
att.d_inffunc = att.d_inffunc,
acrt.d = acrt.d,
acrt.d_se = acrt.d_se,
acrt.d_crit.val = acrt.d_crit.val,
acrt.d_inffunc = acrt.d_inffunc,
pte_params = ptep,
call = call
)
}

#' @title dose_obj
#'
#' @description Holds results from computing dose-specific treatment effects
#' with a continuous treatment
#'
#' @param dose vector containing the values of the dose used in estimation
#' @param att.d estimates of ATT(d) for each value of `dose`
#' @param att.d_se standard error of ATT(d) for each value of `dose`
#' @param att.d_crt.val critical value to produce pointwise or uniform confidence
#' interval for ATT(d)
#' @param att.d_inffunc matrix containing the influence function from estimating
#' ATT(d)
#' @param acrt.d estimates of ACRT(d) for each value of `dose`
#' @param acrt.d_se standard error of ACRT(d) for each value of `dose`
#' @param acrt.d_crt.val critical value to produce pointwise or uniform confidence
#' interval for ACRT(d)
#' @param acrt.d_inffunc matrix containing the influence function from estimating
#' ACRT(d)
#' @param pte_params a pte_params object containing other parameters passed to the function
#' @param call the original call to the function for computing causal effect parameters
#' with a continuous treatment
#'
#' @return dose_obj
#'
#' @export
dose_obj <- function(
dose,
att.d = NULL,
att.d_se = NULL,
att.d_crit.val = NULL,
att.d_inffunc = NULL,
acrt.d = NULL,
acrt.d_se = NULL,
acrt.d_crit.val = NULL,
acrt.d_inffunc = NULL,
pte_params = NULL,
call = NULL) {
out <- list(
dose = dose,
att.d = att.d,
att.d_se = att.d_se,
att.d_crit.val = att.d_crit.val,
att.d_inffunc = att.d_inffunc,
acrt.d = acrt.d,
acrt.d_se = acrt.d_se,
acrt.d_crit.val = acrt.d_crit.val,
acrt.d_inffunc = acrt.d_inffunc,
pte_params = pte_params,
call = call
)

class(out) <- "dose_obj"

out
}
10 changes: 9 additions & 1 deletion R/pte.R
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,15 @@ pte <- function(yname,
max_e <- ifelse(is.null(dots$max_e), Inf, dots$max_e)
balance_e <- dots$balance_e

event_study <- pte_aggte(att_gt, type = "dynamic", bstrap = TRUE, cband = cband, alp = ptep$alp, min_e = min_e, max_e = max_e, balance_e = balance_e)
event_study <- pte_aggte(att_gt,
type = "dynamic",
bstrap = TRUE,
cband = cband,
alp = ptep$alp,
min_e = min_e,
max_e = max_e,
balance_e = balance_e
)

# output
out <- pte_results(
Expand Down
Loading

0 comments on commit 213c648

Please sign in to comment.