Skip to content

Commit

Permalink
augment.mvgam tidying method added
Browse files Browse the repository at this point in the history
Implemented an `augment` method for mvgam objects, one of the "tidying" methods. Largely diverged from the set of recommendations in [generics](https://web.archive.org/web/20240520145739/https://www.tidymodels.org/learn/develop/broom/#implementing-the-augment-method).
  • Loading branch information
swpease committed Nov 8, 2024
1 parent 322dde6 commit 95a9825
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 5 deletions.
9 changes: 5 additions & 4 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ Imports:
dplyr,
magrittr,
Matrix,
rlang
rlang,
generics,
tibble (>= 3.0.0)
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
Expand All @@ -56,9 +58,8 @@ Suggests:
usethis,
testthat
Enhances:
gratia (>= 0.9.0),
tibble (>= 3.0.0),
tidyr
gratia (>= 0.9.0),
tidyr
Additional_repositories: https://mc-stan.org/r-packages/
LinkingTo: Rcpp, RcppArmadillo
VignetteBuilder: knitr
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ S3method(as_draws_df,mvgam)
S3method(as_draws_list,mvgam)
S3method(as_draws_matrix,mvgam)
S3method(as_draws_rvars,mvgam)
S3method(augment,mvgam)
S3method(coef,mvgam)
S3method(conditional_effects,mvgam)
S3method(ensemble,mvgam_forecast)
Expand Down Expand Up @@ -87,6 +88,7 @@ export(as_draws_df)
export(as_draws_list)
export(as_draws_matrix)
export(as_draws_rvars)
export(augment)
export(bernoulli)
export(beta_binomial)
export(betar)
Expand Down Expand Up @@ -196,6 +198,7 @@ importFrom(brms,set_prior)
importFrom(brms,stancode)
importFrom(brms,standata)
importFrom(brms,student)
importFrom(generics,augment)
importFrom(ggplot2,aes)
importFrom(ggplot2,facet_wrap)
importFrom(ggplot2,geom_bar)
Expand Down Expand Up @@ -334,6 +337,7 @@ importFrom(stats,quantile)
importFrom(stats,rbeta)
importFrom(stats,rbinom)
importFrom(stats,reformulate)
importFrom(stats,residuals)
importFrom(stats,rgamma)
importFrom(stats,rlnorm)
importFrom(stats,rnbinom)
Expand Down
61 changes: 61 additions & 0 deletions R/tidier_methods.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#' @importFrom generics augment
#' @export
generics::augment

#' Augment an mvgam object's data
#'
#' Implements the generic `augment` from the package {broom}.
#' Add fits and residuals to the data, returning a `tibble`.
#'
#' @param x An object of class `mvgam`.
#' @param robust If `FALSE` (the default) the mean is used as the measure of
#' central tendency and the standard deviation as the measure of variability.
#' If `TRUE`, the median and the median absolute deviation (MAD)
#' are applied instead.
#' @param probs The percentiles to be computed by the quantile function.
#' @param ... Unused, included for generic consistency only.
#' @returns A `tibble` combining:
#' * The data supplied to `mvgam()`.
#' * The fitted backcasts, along with their variability and credible bounds.
#' * The residuals, along with their variability and credible bounds.
#'
#' @examples
#' \dontrun{
#' set.seed(0)
#' dat <- sim_mvgam(T = 80,
#' n_series = 3,
#' mu = 2,
#' trend_model = AR(p = 1),
#' prop_missing = 0.1,
#' prop_trend = 0.6)
#' mod1 <- mvgam(formula = y ~ s(season, bs = 'cc', k = 6),
#' data = dat$data_train,
#' trend_model = AR(),
#' family = poisson(),
#' noncentred = TRUE)
#' augment(mod1, robust = TRUE, probs = c(0.25, 0.75))
#' }
#'
#' @importFrom stats residuals
#' @export
augment.mvgam <- function(x,
robust = FALSE,
probs = c(0.025, 0.975),
...) {
obs_data = tibble::as_tibble(x$obs_data) %>%
dplyr::mutate(.observed = y) %>%
dplyr::select(!dplyr::any_of(c("index..orig..order", "index..time..index")))
resids = residuals(x, robust = robust, probs = probs) %>%
tibble::as_tibble()
fits = fitted(x, robust = robust, probs = probs) %>%
tibble::as_tibble()
hc_fits = fits %>%
dplyr::slice_head(n = NROW(resids)) # fits can include fcs

colnames(resids) <- c(".resid", ".resid.variability", ".resid.cred.low", ".resid.cred.high")
colnames(hc_fits) <- c(".fitted", ".fit.variability", ".fit.cred.low", ".fit.cred.high")

augmented = dplyr::bind_cols(obs_data, hc_fits, resids)

augmented
}
50 changes: 50 additions & 0 deletions man/augment.mvgam.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion man/reexports.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions tests/testthat/test-tidier_methods.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
test_that("augment doesn't error", {
expect_no_error(augment(mvgam_example1))
})

0 comments on commit 95a9825

Please sign in to comment.