From 6fde2fd0d6f1e35f574fdcd5b7be90abea6f0609 Mon Sep 17 00:00:00 2001
From: swpease
Date: Fri, 10 Jan 2025 11:28:21 -0800
Subject: [PATCH] `tidy.mvgam()` WIP draft method

---
 R/tidier_methods.R | 205 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 205 insertions(+)

diff --git a/R/tidier_methods.R b/R/tidier_methods.R
index 23657a17..33277ad9 100644
--- a/R/tidier_methods.R
+++ b/R/tidier_methods.R
@@ -1,7 +1,212 @@
+#' @importFrom generics tidy
+#' @export
+generics::tidy
+
 #' @importFrom generics augment
 #' @export
 generics::augment
+
+# List of observation families and any extra parameters.
+# Wrapped in a function for testing purposes.
+# NOTE: kept as plain `#` comments; as roxygen `#'` lines they would likely
+# be merged into the documentation block of `tidy.mvgam()` below.
+# extra_params <- function() {
+#   list(
+#     'negative binomial' = c('phi'),
+#     'beta_binomial' = c('phi'),
+#     'beta' = c('phi'),
+#     'tweedie' = c('phi'),
+#     'gaussian' = c('sigma_obs'),
+#     'student' = c('sigma_obs', 'nu'),
+#     'lognormal' = c('sigma_obs'),
+#     'Gamma' = c('shape'),
+#     'poisson' = c(),
+#     'binomial' = c(),
+#     'bernoulli' = c(),
+#     'nmix' = c()
+#   )
+# }
+
+# TODO: name cols based on `get_mvgam_priors()$param_info`?
+#' Tidy an `mvgam` object's parameter posteriors
+#'
+#' Summarises groups of model parameters with `mcmc_summary()` and stacks the
+#' group summaries into a single tibble.
+#'
+#' @param x The fitted model object to summarise.
+#' @param probs Passed on to `mcmc_summary()` as its `probs` argument.
+#' @param ... Currently unused.
+#' @return A tibble whose first column, `parameter`, holds the row names of
+#'   the stacked summaries, followed by `param_type` (the parameter group)
+#'   and the columns returned by `mcmc_summary()`.
+#' @export
+tidy.mvgam <- function(x, probs = c(0.025, 0.5, 0.975), ...) {
+  object <- x # For consistency with `summary.mvgam()`
+  obj_vars <- variables(object)
+  digits <- 2 # TODO: Let user change?
+
+  # Summarise `params` and label the rows with `param_type`. Returns NULL for
+  # empty `params`; `dplyr::bind_rows()` ignores NULL inputs.
+  summarise_params <- function(params, param_type, aliases = NULL) {
+    if (length(params) == 0) {
+      return(NULL)
+    }
+    res <- mcmc_summary(
+      object$model_output,
+      params = params,
+      ISB = FALSE, # Matches `x[i]`'s rather than `x`.
+      probs = probs,
+      digits = digits,
+      Rhat = FALSE,
+      n.eff = FALSE
+    )
+    if (!is.null(aliases)) {
+      row.names(res) <- aliases
+    }
+    tibble::add_column(res, param_type = param_type, .before = 1)
+  }
+  out <- tibble::tibble()
+
+  # Observation family extra parameters
+  # `grep(value = TRUE)` yields `character(0)`, never NULL, when nothing
+  # matches (including NULL input), so emptiness is tested by length.
+  xp_names <- grep(
+    "vec",
+    obj_vars$observation_pars$orig_name,
+    value = TRUE,
+    invert = TRUE
+  )
+  out <- dplyr::bind_rows(
+    out,
+    summarise_params(xp_names, "obs_fam_extra_param")
+  )
+
+  # obs non-smoother betas
+  if (object$mgcv_model$nsdf > 0) {
+    # df("orig_name", "alias")
+    obs_beta_name_map <- dplyr::slice_head(
+      obj_vars$observation_betas,
+      n = object$mgcv_model$nsdf
+    )
+    out <- dplyr::bind_rows(
+      out,
+      summarise_params(
+        obs_beta_name_map$orig_name,
+        "observation_beta",
+        aliases = obs_beta_name_map$alias
+      )
+    )
+  }
+
+  # random effects
+  # TODO: include specific s(re).[n] intercepts?
+  # TODO: random slopes' names? obj$mgcv_model$smooth$label?
+  re_param_name_map <- obj_vars$observation_re_params
+  out <- dplyr::bind_rows(
+    out,
+    summarise_params(
+      re_param_name_map$orig_name,
+      "random effect (group-level)",
+      aliases = re_param_name_map$alias
+    )
+  )
+
+  # GPs
+  trend_par_names <- obj_vars$trend_pars$orig_name
+  gp_param_names <- grep("^alpha_gp|^rho_gp", trend_par_names, value = TRUE)
+  if (length(gp_param_names) > 0) {
+    # where is GP? can be in formula, trend_formula, or trend_model
+    first_gp_name <- gp_param_names[[1]]
+    gp_param_type <- if (grepl("^(alpha|rho)_gp_trend", first_gp_name)) {
+      "trend_formula_param"
+    } else if (grepl("^(alpha|rho)_gp_", first_gp_name)) { # hmph.
+      "observation_param"
+    } else {
+      "trend_model_param"
+    }
+    out <- dplyr::bind_rows(
+      out,
+      summarise_params(gp_param_names, gp_param_type)
+    )
+  }
+
+  # RW, AR, CAR, VAR, ZMVN and piecewise (PW) trend models
+  # TODO: split out Sigma for hierarchical correlations?
+  # TODO: how to handle piecewise change points?
+  # str vs called obj as arg to mvgam
+  trend_model_name <- if (inherits(object$trend_model, "mvgam_trend")) {
+    object$trend_model$trend_model
+  } else {
+    object$trend_model
+  }
+  # theta = MA terms
+  # alpha_cor = hierarchical corr term
+  # A = VAR auto-regressive matrix
+  # Sigma = correlated errors matrix
+  # sigma = errors
+  trend_model_params <- NULL
+  if (grepl("^VAR", trend_model_name)) {
+    # Prefix match so every "VAR..." name gets its patterns assigned.
+    trend_model_params <- c("^A\\[", "^alpha_cor", "^theta", "^Sigma")
+  } else if (grepl("^CAR|^AR|^RW", trend_model_name)) {
+    # `isTRUE()` tolerates an absent `cor` element.
+    is_cor <- inherits(object$trend_model, "mvgam_trend") &&
+      isTRUE(object$trend_model$cor)
+    sigma_name <- if (is_cor) "^Sigma" else "^sigma"
+    trend_model_params <- c("^ar", "^alpha_cor", "^theta", sigma_name)
+  } else if (grepl("^ZMVN", trend_model_name)) {
+    trend_model_params <- c("^alpha_cor", "^Sigma")
+  } else if (grepl("^PW", trend_model_name)) {
+    trend_model_params <- c("^k_trend", "^m_trend", "^delta_trend")
+  }
+  if (!is.null(trend_model_params)) {
+    tm_param_names <- grep(
+      paste(trend_model_params, collapse = "|"),
+      trend_par_names,
+      value = TRUE
+    )
+    out <- dplyr::bind_rows(
+      out,
+      summarise_params(tm_param_names, "trend_model_param")
+    )
+  }
+
+  # Trend formula betas
+  if (!is.null(object$trend_call) && object$trend_mgcv_model$nsdf > 0) {
+    # df("orig_name", "alias")
+    trend_beta_name_map <- dplyr::slice_head(
+      obj_vars$trend_betas,
+      n = object$trend_mgcv_model$nsdf
+    )
+    out <- dplyr::bind_rows(
+      out,
+      summarise_params(
+        trend_beta_name_map$orig_name,
+        "trend_beta",
+        aliases = trend_beta_name_map$alias
+      )
+    )
+  }
+
+  # trend random effects
+  # TODO: include specific s(re).[n] intercepts?
+  trend_re_param_name_map <- obj_vars$trend_re_params
+  out <- dplyr::bind_rows(
+    out,
+    summarise_params(
+      trend_re_param_name_map$orig_name,
+      "trend random effect (group-level)",
+      aliases = trend_re_param_name_map$alias
+    )
+  )
+
+  # OUTPUT
+  # TODO: might need to put this prior to every bind_rows to avoid rowname dups.
+  tibble::rownames_to_column(out, "parameter")
+}
+
+
 
 #' Augment an mvgam object's data
 #'
 #' Add fits and residuals to the data, implementing the generic `augment` from