Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tidy.mvgam() WIP #100

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
203 changes: 203 additions & 0 deletions R/tidier_methods.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,210 @@
# Re-export the `tidy` generic from the generics package: the roxygen tags
# below import it into this package's namespace and export it again, and the
# bare `generics::tidy` expression is the object those tags attach to.
#' @importFrom generics tidy
#' @export
generics::tidy

# Same re-export pattern for the `augment` generic.
#' @importFrom generics augment
#' @export
generics::augment


#' List of observation families and any extra parameters.
#' Wrapped in a function for testing purposes.
#' @noRd
# extra_params <- function() {
# list(
# 'negative binomial' = c('phi'),
# 'beta_binomial' = c('phi'),
# 'beta' = c('phi'),
# 'tweedie' = c('phi'),
# 'gaussian' = c('sigma_obs'),
# 'student' = c('sigma_obs', 'nu'),
# 'lognormal' = c('sigma_obs'),
# 'Gamma' = c('shape'),
# 'poisson' = c(),
# 'binomial' = c(),
# 'bernoulli' = c(),
# 'nmix' = c()
# )
# }

# TODO: name cols based on `get_mvgam_priors()$param_info`?
#' Tidy an mvgam object into a tibble of parameter summaries
#'
#' Gathers posterior summaries (via `mcmc_summary()`) for several groups of
#' parameters found in `variables(x)` and stacks them into one tibble. Groups
#' are appended in this order, each only when matching parameters exist:
#' observation-family extra parameters, observation betas, observation random
#' effects, Gaussian-process parameters, latent trend-model parameters
#' (VAR / CAR / AR / RW / ZMVN, then piecewise), trend-formula betas and
#' trend random effects.
#'
#' @param x The fitted model object. The code reads `x$model_output`,
#'   `x$mgcv_model$nsdf`, `x$trend_model`, `x$trend_call` and
#'   `x$trend_mgcv_model$nsdf`.
#' @param probs Numeric vector of probabilities passed on to
#'   `mcmc_summary()` as its `probs` argument.
#' @param ... Not used by this method.
#'
#' @return A tibble with a `parameter` column (built from the row names of
#'   the summaries), a `param_type` column naming the group each row belongs
#'   to, and whatever summary columns `mcmc_summary()` produces.
#' @export
tidy.mvgam <- function(x, probs = c(0.025, 0.5, 0.975), ...) {
  object <- x # For consistency with `summary.mvgam()`
  obj_vars <- variables(object)
  digits <- 2 # TODO: Let user change?

  # `mcmc_summary()` with everything except `params` pre-filled.
  summarise_params <- purrr::partial(
    mcmc_summary,
    object$model_output,
    ... = ,
    ISB = FALSE, # Matches `x[i]`'s rather than `x`.
    probs = probs,
    digits = digits,
    Rhat = FALSE,
    n.eff = FALSE
  )

  # Summarise `params`, optionally relabel the rows with `aliases`, tag the
  # rows with `param_type`, and return `acc` with those rows appended.
  append_summary <- function(acc, params, param_type, aliases) {
    block <- summarise_params(params = params)
    if (!missing(aliases)) {
      row.names(block) <- aliases
    }
    block <- tibble::add_column(block, param_type = param_type, .before = 1)
    dplyr::bind_rows(acc, block)
  }

  out <- tibble::tibble()

  # Observation family extra parameters

  # extra_params <- extra_params()
  # for (xp in extra_params[[object$family]]) {
  #   extra_params_out <- mcmc_summary(object$model_output,
  #                                    params = xp,
  #                                    digits = digits,
  #                                    variational = variational)
  #   out <- dplyr::bind_rows(out, extra_params_out)
  # }

  # Alt implementation of Observation family extra parameters
  xp_names_all <- obj_vars$observation_pars$orig_name
  xp_names <- grep("vec", xp_names_all, value = TRUE, invert = TRUE)
  # `grep(value = TRUE)` yields `character(0)`, never `NULL`, when nothing
  # matches, so test the length rather than `is.null()`.
  if (length(xp_names) > 0) {
    out <- append_summary(out, xp_names, "obs_fam_extra_param")
  }
  # END Alt implementation
  # END Observation family extra parameters


  # obs non-smoother betas
  if (object$mgcv_model$nsdf > 0) {
    # df("orig_name", "alias")
    obs_beta_name_map <- dplyr::slice_head(
      obj_vars$observation_betas,
      n = object$mgcv_model$nsdf
    )
    out <- append_summary(
      out,
      obs_beta_name_map$orig_name,
      "observation_beta",
      aliases = obs_beta_name_map$alias
    )
  }
  # END obs non-smoother betas


  # random effects
  # TODO: include specific s(re).[n] intercepts?
  # TODO: random slopes' names? obj$mgcv_model$smooth$label?
  re_param_name_map <- obj_vars$observation_re_params
  if (!is.null(re_param_name_map)) {
    out <- append_summary(
      out,
      re_param_name_map$orig_name,
      "random effect (group-level)",
      aliases = re_param_name_map$alias
    )
  }
  # END random effects -----------

  # GPs
  if (!is.null(obj_vars$trend_pars)) {
    tm_param_names_all <- obj_vars$trend_pars$orig_name
    gp_param_names <- grep(
      "^alpha_gp|^rho_gp",
      tm_param_names_all,
      value = TRUE
    )
    if (length(gp_param_names) > 0) {
      # where is GP? can be in formula, trend_formula, or trend_model
      if (grepl("^(alpha|rho)_gp_trend", gp_param_names[[1]])) {
        gp_param_type <- "trend_formula_param"
      } else if (grepl("^(alpha|rho)_gp_", gp_param_names[[1]])) { # hmph.
        gp_param_type <- "observation_param"
      } else {
        gp_param_type <- "trend_model_param"
      }
      out <- append_summary(out, gp_param_names, gp_param_type)
    }
  }
  # END GPs --------------

  # RW, AR, CAR, VAR
  # TODO: split out Sigma for hiercor?
  # str vs called obj as arg to mvgam
  # TODO: move trend_model_name up?
  is_trend_obj <- inherits(object$trend_model, "mvgam_trend")
  trend_model_name <- if (is_trend_obj) {
    object$trend_model$trend_model
  } else {
    object$trend_model
  }

  if (grepl("^VAR|^CAR|^AR|^RW|^ZMVN", trend_model_name)) {
    # theta = MA terms
    # alpha_cor = hierarchical corr term
    # A = VAR auto-regressive matrix
    # Sigma = correlated errors matrix
    # sigma = errors

    # setting up the params to extract
    # Match VAR by prefix, like the enclosing condition does; an exact
    # `== "VAR"` left names such as "VAR1" without any pattern set.
    if (grepl("^VAR", trend_model_name)) {
      trend_model_params <- c("^A\\[", "^alpha_cor", "^theta", "^Sigma")
    } else if (grepl("^CAR|^AR|^RW", trend_model_name)) {
      # `isTRUE()` keeps `&&` from erroring when `cor` is absent (NULL).
      cor <- is_trend_obj && isTRUE(object$trend_model$cor)
      sigma_name <- if (cor) "^Sigma" else "^sigma"
      trend_model_params <- c("^ar", "^alpha_cor", "^theta", sigma_name)
    } else {
      # Only "^ZMVN" can reach this branch.
      trend_model_params <- c("^alpha_cor", "^Sigma")
    }

    # extracting the params
    trend_model_params <- paste(trend_model_params, collapse = "|")
    tm_param_names_all <- obj_vars$trend_pars$orig_name
    tm_param_names <- grep(
      trend_model_params,
      tm_param_names_all,
      value = TRUE
    )
    # Skip the summary when no trend parameter matched.
    if (length(tm_param_names) > 0) {
      out <- append_summary(out, tm_param_names, "trend_model_param")
    }
  }
  # END RW, AR, CAR, VAR-----------

  # Piecewise
  # TODO: potentially lump into AR section, above; how to handle change points?
  # to lump in, just add an
  # `else if (grepl("^PW", trend_model_name)`, then
  # `trend_model_params <- c("^k_trend", "^m_trend", "^delta_trend")`
  # and change initial grep(ar car var) call
  if (grepl("^PW", trend_model_name)) {
    trend_model_params <- "^k_trend|^m_trend|^delta_trend"
    tm_param_names_all <- obj_vars$trend_pars$orig_name
    tm_param_names <- grep(
      trend_model_params,
      tm_param_names_all,
      value = TRUE
    )
    # Skip the summary when no piecewise parameter matched.
    if (length(tm_param_names) > 0) {
      out <- append_summary(out, tm_param_names, "trend_model_param")
    }
  }
  # END Piecewise ------------

  # Trend formula betas
  if (!is.null(object$trend_call) && object$trend_mgcv_model$nsdf > 0) {
    # df("orig_name", "alias")
    trend_beta_name_map <- dplyr::slice_head(
      obj_vars$trend_betas,
      n = object$trend_mgcv_model$nsdf
    )
    out <- append_summary(
      out,
      trend_beta_name_map$orig_name,
      "trend_beta",
      aliases = trend_beta_name_map$alias
    )
  }
  # END Trend formula betas ----------

  # trend random effects
  # TODO: include specific s(re).[n] intercepts?
  trend_re_param_name_map <- obj_vars$trend_re_params
  if (!is.null(trend_re_param_name_map)) {
    out <- append_summary(
      out,
      trend_re_param_name_map$orig_name,
      "trend random effect (group-level)",
      aliases = trend_re_param_name_map$alias
    )
  }
  # END trend random effects -----------

  # OUTPUT
  # TODO: might need to put this prior to every bind_rows to avoid rowname dups.
  tibble::rownames_to_column(out, "parameter")
}


#' Augment an mvgam object's data
#'
#' Add fits and residuals to the data, implementing the generic `augment` from
Expand Down
Loading