Skip to content

Commit

Permalink
make correct_for_ps() code more readable
Browse files Browse the repository at this point in the history
  • Loading branch information
malcolmbarrett committed Feb 4, 2025
1 parent 8eef786 commit 6151303
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions R/ipw.R
Original file line number Diff line number Diff line change
Expand Up @@ -415,11 +415,25 @@ t_tcrossprod_over_rows <- function(mat) {
}

correct_for_ps <- function(exposure, exposure_actual = exposure, outcome, ps, mu, n_group, weight_matrix, weight_derivatives, correction_mat, n) {
drop(
n / n_group *
rbind(colSums(weight_derivatives * exposure_actual * (outcome - mu)) / n) %*%
(correction_mat %*% t((exposure - ps) * weight_matrix))
) |> unname()
# first, compute partial-derivative sums over subjects (averaged by n)
partial_derivative_sums <- colSums(
weight_derivatives * exposure_actual * (outcome - mu)
) / n

# then build the transformation matrix for correction
transformation_mat <- correction_mat %*% t((exposure - ps) * weight_matrix)

# and then apply the partial-derivative sums to that transformation
correction_contrib <- rbind(partial_derivative_sums) %*% transformation_mat

# rescale by (n / n_group)
scaling_factor <- n / n_group
correction_contrib <- correction_contrib * scaling_factor

# and reduce to vector and unname
correction_contrib |>
drop() |>
unname()
}

estimate_marginal_means <- function(outcome_mod, wts, exposure, exposure_name, .df = NULL) {
Expand Down

0 comments on commit 6151303

Please sign in to comment.