Skip to content

Commit

Permalink
tidy.mvgam() implemented
Browse files Browse the repository at this point in the history
`tidy.mvgam()` snapshot value tests

The snapshots (note: not snapshot values) record what's printed, and the `check()` call led to a different truncation compared to `test_active_file()`, so I'll just use `expect_snapshot_value()` instead, even though it's harder to read.

`tidy.mvgam()` snapshots tests

Covers the majority of main use cases, but is missing a random effects test and a hierarchical correlation test.

`tidy.mvgam()` docs; term name

Wrote documentation for method. Also settled on names for the "term" column contents.

`tidy.mvgam()` trend formula w/o trend model

These models have error terms (sigmas) which were not being included.

`tidy.mvgam()` random effect specific betas

Decided it would be better to include these than not; can be easily filtered out if undesired.
  • Loading branch information
swpease committed Feb 1, 2025
1 parent bd20133 commit 1250b80
Show file tree
Hide file tree
Showing 7 changed files with 708 additions and 1 deletion.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ S3method(stancode,mvgam_prefit)
S3method(standata,mvgam_prefit)
S3method(summary,mvgam)
S3method(summary,mvgam_prefit)
S3method(tidy,mvgam)
S3method(update,jsdgam)
S3method(update,mvgam)
S3method(variables,mvgam)
Expand Down Expand Up @@ -166,6 +167,7 @@ export(student_t)
export(t2)
export(te)
export(ti)
export(tidy)
export(tweedie)
export(variables)
importFrom(Rcpp,evalCpp)
Expand Down Expand Up @@ -203,6 +205,7 @@ importFrom(brms,stancode)
importFrom(brms,standata)
importFrom(brms,student)
importFrom(generics,augment)
importFrom(generics,tidy)
importFrom(ggplot2,aes)
importFrom(ggplot2,facet_wrap)
importFrom(ggplot2,geom_bar)
Expand Down
281 changes: 281 additions & 0 deletions R/tidier_methods.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,286 @@
# Re-export the `tidy` generic from the generics package so that it is part of
# this package's exported namespace.
#' @importFrom generics tidy
#' @export
generics::tidy

# Re-export the `augment` generic from the generics package in the same way.
#' @importFrom generics augment
#' @export
generics::augment


#' Tidy an mvgam object's parameter posteriors
#'
#' Get parameters' posterior statistics, implementing the generic `tidy` from
#' the package \pkg{generics} (re-exported by \pkg{broom}).
#'
#' The parameters are categorized by the column "type". For instance, the
#' intercept of the observation model (i.e. the "formula" arg to `mvgam()`) has
#' the "type" "observation_beta". The possible "type"s are:
#' * observation_family_extra_param: any extra parameters for your observation
#' model, e.g. sigma for a gaussian observation model. These parameters are
#' not directly derived from the latent trend components (continuing the
#' gaussian example, contrast to mu).
#' * observation_beta: betas from your observation model, excluding any smooths.
#' If your formula was `y ~ x1 + s(x2, bs='cr')`, then your intercept and
#' `x1`'s beta would be categorized as this.
#' * random_effect_group_level: Group-level random effects parameters, i.e.
#' the mean and sd of the distribution from which the specific random
#' intercepts/slopes are considered to be drawn.
#' * random_effect_beta: betas for the individual random intercepts/slopes.
#' * trend_model_param: parameters from your `trend_model`.
#' * trend_beta: analog of "observation_beta", but for any `trend_formula`.
#' * trend_random_effect_group_level: analog of "random_effect_group_level",
#' but for any `trend_formula`.
#' * trend_random_effect_beta: analog of "random_effect_beta",
#' but for any `trend_formula`.
#'
#' Additionally, GP terms can be incorporated in several ways, leading to
#' different "type"s (or absence!):
#' * `s(bs = "gp")`: No parameters returned.
#' * `gp()` in `formula`: "type" of "observation_param".
#' * `gp()` in `trend_formula`: "type" of "trend_formula_param".
#' * `GP()` in `trend_model`: "type" of "trend_model_param".
#'
#'
#' @param x An object of class `mvgam`.
#' @param probs The desired probability levels of the parameters' posteriors.
#' Defaults to `c(0.025, 0.5, 0.975)`, i.e. 2.5%, 50%, and 97.5%.
#' @param ... Unused, included for generic consistency only.
#' @returns A `tibble` containing:
#' * "parameter": The parameter in question.
#' * "type": The component of the model that the parameter belongs to (see details).
#' * "mean": The posterior mean.
#' * "sd": The posterior standard deviation.
#' * percentile(s): Any percentiles of interest from these posteriors.
#'
#' @family tidiers
#'
#' @examples
#' \dontrun{
#' set.seed(0)
#' simdat <- sim_mvgam(T = 100,
#' n_series = 3,
#' trend_model = AR(),
#' prop_trend = 0.75,
#' family = gaussian())
#' simdat$data_train$x = rnorm(nrow(simdat$data_train))
#' simdat$data_train$year_fac = factor(simdat$data_train$year)
#'
#' mod <- mvgam(y ~ - 1 + s(time, by = series, bs = 'cr', k = 20) + x,
#' trend_formula = ~ s(year_fac, bs = 're') - 1,
#' trend_model = AR(cor = TRUE),
#' family = gaussian(),
#' data = simdat$data_train,
#' silent = 2)
#'
#' tidy(mod, probs = c(0.2, 0.5, 0.8))
#' }
#'
#' @export
tidy.mvgam <- function(x, probs = c(0.025, 0.5, 0.975), ...) {
  # Builds one summary table of posterior statistics, section by section
  # (observation parameters, betas, random effects, GPs, trend model
  # parameters, trend formula terms). Each section is summarised with
  # `mcmc_summary()`, labelled with a "type", and appended to `out`. The row
  # names (parameter names or their aliases) become the "parameter" column at
  # the very end.
  object <- x
  obj_vars <- variables(object)
  digits <- 2 # TODO: Let user change?

  # Summarise the posteriors of `params` and tag every row with `type`.
  # `aliases`, when supplied, replace the row names of the summary (which end
  # up in the "parameter" column). Returns NULL when `params` is empty, which
  # `dplyr::bind_rows()` silently ignores, so callers need no empty-match
  # guard of their own.
  summarise_as <- function(params, type, aliases = NULL) {
    if (length(params) == 0) {
      return(NULL)
    }
    res <- mcmc_summary(
      object$model_output,
      params = params,
      ISB = FALSE, # Matches `x[i]`'s rather than `x`.
      probs = probs,
      digits = digits,
      Rhat = FALSE,
      n.eff = FALSE
    )
    if (!is.null(aliases)) {
      row.names(res) <- aliases
    }
    tibble::add_column(res, type = type, .before = 1)
  }

  # Append to `acc` the coefficients belonging to each random-effect smooth in
  # `smooths`. Coefficients are located by matching the smooth's label
  # (as fixed text) against the aliases in `beta_map`, a data frame with
  # columns "orig_name" and "alias".
  append_re_betas <- function(acc, smooths, beta_map, type) {
    for (sm in smooths) {
      if (inherits(sm, "random.effect")) {
        idxs <- grep(sm$label, beta_map$alias, fixed = TRUE)
        re_map <- dplyr::slice(beta_map, idxs)
        acc <- dplyr::bind_rows(
          acc,
          summarise_as(re_map$orig_name, type, re_map$alias)
        )
      }
    }
    acc
  }

  out <- tibble::tibble()

  # Observation family extra parameters --------
  # grep() on NULL gives character(0), which summarise_as() skips
  extra_par_names <- grep(
    "vec",
    obj_vars$observation_pars$orig_name,
    value = TRUE,
    invert = TRUE
  )
  out <- dplyr::bind_rows(
    out,
    summarise_as(extra_par_names, "observation_family_extra_param")
  )
  # END Observation family extra parameters

  # obs non-smoother betas --------
  n_obs_fixed <- object$mgcv_model$nsdf
  if (n_obs_fixed > 0) {
    # df("orig_name", "alias"); the first `nsdf` betas are the parametric ones
    obs_beta_map <- dplyr::slice_head(
      obj_vars$observation_betas,
      n = n_obs_fixed
    )
    out <- dplyr::bind_rows(
      out,
      summarise_as(
        obs_beta_map$orig_name,
        "observation_beta",
        obs_beta_map$alias
      )
    )
  }
  # END obs non-smoother betas

  # random effects --------
  # TODO: names for random slopes
  re_param_map <- obj_vars$observation_re_params
  if (!is.null(re_param_map)) {
    out <- dplyr::bind_rows(
      out,
      summarise_as(
        re_param_map$orig_name,
        "random_effect_group_level",
        re_param_map$alias
      )
    )

    # specific betas
    out <- append_re_betas(
      out,
      smooths = object$mgcv_model$smooth,
      beta_map = obj_vars$observation_betas,
      type = "random_effect_beta"
    )
  }
  # END random effects

  # GPs --------
  # NULL trend_pars -> NULL names -> no matches in any grep() below
  trend_par_names <- obj_vars$trend_pars$orig_name
  gp_names <- grep("^alpha_gp|^rho_gp", trend_par_names, value = TRUE)
  # where is GP? can be in formula, trend_formula, or trend_model.
  # Classify every parameter by its own name (rather than by the first name
  # only) so models with GPs in more than one place are labelled correctly.
  in_trend_formula <- grepl("^(alpha|rho)_gp_trend", gp_names)
  in_observation <- !in_trend_formula &
    grepl("^(alpha|rho)_gp_", gp_names)
  gp_groups <- list(
    trend_formula_param = gp_names[in_trend_formula],
    observation_param = gp_names[in_observation],
    trend_model_param = gp_names[!(in_trend_formula | in_observation)]
  )
  for (gp_type in names(gp_groups)) {
    out <- dplyr::bind_rows(
      out,
      summarise_as(gp_groups[[gp_type]], gp_type)
    )
  }
  # END GPs

  # Trend model parameters (RW, AR, CAR, VAR, ZMVN, PW, None) --------
  # TODO: split out Sigma for hierarchical correlations?
  # TODO: how to handle piecewise change points?
  # `trend_model` is either a string or a called trend object passed to mvgam
  is_trend_obj <- inherits(object$trend_model, "mvgam_trend")
  trend_model_name <- if (is_trend_obj) {
    object$trend_model$trend_model
  } else {
    object$trend_model
  }

  # Regex patterns of the parameters to extract for this trend model:
  # theta = MA terms
  # alpha_cor = hierarchical corr term
  # A = VAR auto-regressive matrix
  # Sigma = correlated errors matrix
  # sigma = errors
  trend_model_patterns <- if (grepl("^VAR", trend_model_name)) {
    # prefix match so that string specs such as "VAR1" are handled too
    c("^A\\[", "^alpha_cor", "^theta", "^Sigma")
  } else if (grepl("^CAR|^AR|^RW", trend_model_name)) {
    has_cor <- is_trend_obj && isTRUE(object$trend_model$cor)
    sigma_pattern <- if (has_cor) "^Sigma" else "^sigma"
    c("^ar", "^alpha_cor", "^theta", sigma_pattern)
  } else if (grepl("^ZMVN", trend_model_name)) {
    c("^alpha_cor", "^Sigma")
  } else if (grepl("^PW", trend_model_name)) {
    c("^k_trend", "^m_trend", "^delta_trend")
  } else if (trend_model_name == "None" && !is.null(object$trend_call)) {
    # a trend formula without a trend model still has error terms (sigmas)
    "sigma"
  } else {
    NULL
  }

  if (!is.null(trend_model_patterns)) {
    tm_param_names <- grep(
      paste(trend_model_patterns, collapse = "|"),
      trend_par_names,
      value = TRUE
    )
    out <- dplyr::bind_rows(
      out,
      summarise_as(tm_param_names, "trend_model_param")
    )
  }
  # END Trend model parameters

  # Trend formula betas --------
  if (!is.null(object$trend_call) && object$trend_mgcv_model$nsdf > 0) {
    # df("orig_name", "alias")
    trend_beta_map <- dplyr::slice_head(
      obj_vars$trend_betas,
      n = object$trend_mgcv_model$nsdf
    )
    out <- dplyr::bind_rows(
      out,
      summarise_as(
        trend_beta_map$orig_name,
        "trend_beta",
        trend_beta_map$alias
      )
    )
  }
  # END Trend formula betas

  # trend random effects --------
  trend_re_param_map <- obj_vars$trend_re_params
  if (!is.null(trend_re_param_map)) {
    out <- dplyr::bind_rows(
      out,
      summarise_as(
        trend_re_param_map$orig_name,
        "trend_random_effect_group_level",
        trend_re_param_map$alias
      )
    )

    # specific betas
    out <- append_re_betas(
      out,
      smooths = object$trend_mgcv_model$smooth,
      beta_map = obj_vars$trend_betas,
      type = "trend_random_effect_beta"
    )
  }
  # END trend random effects

  # OUTPUT --------
  # TODO: might need to put this prior to every bind_rows to avoid rowname dups.
  out <- tibble::rownames_to_column(out, "parameter")
  out
}


#' Augment an mvgam object's data
#'
#' Add fits and residuals to the data, implementing the generic `augment` from
Expand All @@ -27,6 +306,8 @@ generics::augment
#' * The residuals, along with their variability and credible bounds.
#'
#' @seealso \code{\link{residuals.mvgam}}, \code{\link{fitted.mvgam}}
#' @family tidiers
#'
#' @examples
#' \dontrun{
#' set.seed(0)
Expand Down
4 changes: 4 additions & 0 deletions man/augment.mvgam.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/reexports.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 1250b80

Please sign in to comment.