Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

augment.mvgam tidying method added #88

Merged
merged 3 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ Imports:
purrr,
dplyr,
magrittr,
rlang
rlang,
generics,
tibble (>= 3.0.0)
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
Expand All @@ -51,9 +53,8 @@ Suggests:
usethis,
testthat
Enhances:
gratia (>= 0.9.0),
tibble (>= 3.0.0),
tidyr
gratia (>= 0.9.0),
tidyr
Additional_repositories: https://mc-stan.org/r-packages/
LinkingTo: Rcpp, RcppArmadillo
VignetteBuilder: knitr
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ S3method(as_draws_df,mvgam)
S3method(as_draws_list,mvgam)
S3method(as_draws_matrix,mvgam)
S3method(as_draws_rvars,mvgam)
S3method(augment,mvgam)
S3method(coef,mvgam)
S3method(conditional_effects,mvgam)
S3method(ensemble,mvgam_forecast)
Expand Down Expand Up @@ -87,6 +88,7 @@ export(as_draws_df)
export(as_draws_list)
export(as_draws_matrix)
export(as_draws_rvars)
export(augment)
export(bernoulli)
export(beta_binomial)
export(betar)
Expand Down Expand Up @@ -197,6 +199,7 @@ importFrom(brms,set_prior)
importFrom(brms,stancode)
importFrom(brms,standata)
importFrom(brms,student)
importFrom(generics,augment)
importFrom(ggplot2,aes)
importFrom(ggplot2,facet_wrap)
importFrom(ggplot2,geom_bar)
Expand Down Expand Up @@ -334,6 +337,7 @@ importFrom(stats,quantile)
importFrom(stats,rbeta)
importFrom(stats,rbinom)
importFrom(stats,reformulate)
importFrom(stats,residuals)
importFrom(stats,rgamma)
importFrom(stats,rlnorm)
importFrom(stats,rnbinom)
Expand Down
71 changes: 71 additions & 0 deletions R/tidier_methods.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#' @importFrom generics augment
#' @export
generics::augment

#' Augment an mvgam object's data
#'
#' Add fits and residuals to the data, implementing the generic `augment` from
#' the package {broom}.
#'
#' A `list` is returned if `class(x$obs_data) == 'list'`, otherwise a `tibble` is
#' returned, but the contents of either object is the same.
#'
#' The arguments `robust` and `probs` are applied to both the fit and
#' residuals calls (see [fitted.mvgam()] and [residuals.mvgam()] for details).
#'
#' @param x An object of class `mvgam`.
#' @param robust If `FALSE` (the default) the mean is used as the measure of
#' central tendency and the standard deviation as the measure of variability.
#' If `TRUE`, the median and the median absolute deviation (MAD)
#' are applied instead.
#' @param probs The percentiles to be computed by the quantile function.
#' @param ... Unused, included for generic consistency only.
#' @returns A `list` or `tibble` (see details) combining:
#' * The data supplied to `mvgam()`.
#' * The outcome variable, named as `.observed`.
#' * The fitted backcasts, along with their variability and credible bounds.
#' * The residuals, along with their variability and credible bounds.
#'
#' @examples
#' \dontrun{
#' set.seed(0)
#' dat <- sim_mvgam(T = 80,
#' n_series = 3,
#' mu = 2,
#' trend_model = AR(p = 1),
#' prop_missing = 0.1,
#' prop_trend = 0.6)
#' mod1 <- mvgam(formula = y ~ s(season, bs = 'cc', k = 6),
#' data = dat$data_train,
#' trend_model = AR(),
#' family = poisson(),
#' noncentred = TRUE)
#' augment(mod1, robust = TRUE, probs = c(0.25, 0.75))
#' }
#'
#' @importFrom stats residuals
#' @export
augment.mvgam <- function(x,
robust = FALSE,
probs = c(0.025, 0.975),
...) {
obs_data <- x$obs_data
obs_data$.observed = obs_data$y
obs_data <- purrr::discard_at(obs_data, c("index..orig..order", "index..time..index"))

resids <- residuals(x, robust = robust, probs = probs) %>%
tibble::as_tibble()
fits <- fitted(x, robust = robust, probs = probs) %>%
tibble::as_tibble()
hc_fits <- fits %>%
dplyr::slice_head(n = NROW(resids)) # fits can include fcs
colnames(resids) <- c(".resid", ".resid.variability", ".resid.cred.low", ".resid.cred.high")
colnames(hc_fits) <- c(".fitted", ".fit.variability", ".fit.cred.low", ".fit.cred.high")

augmented <- c(obs_data, hc_fits, resids) # coerces to list
if (!identical(class(x$obs_data), "list")) { # data.frame
augmented <- tibble::as_tibble(augmented)
}

augmented
}
58 changes: 58 additions & 0 deletions man/augment.mvgam.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion man/reexports.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions tests/testthat/test-tidier_methods.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
test_that("augment doesn't error", {
expect_no_error(augment(mvgam_example1))
expect_no_error(augment(mvgam_example5))
})

test_that("augment return types", {
out1 = augment(mvgam_example1)
out5 = augment(mvgam_example5)
expect_equal(class(out1)[[1]], "tbl_df")
expect_equal(class(out5), "list")
})
Loading