Skip to content

Commit

Permalink
Merge pull request #88 from swpease/broomy
Browse files Browse the repository at this point in the history
`augment.mvgam` tidying method added
  • Loading branch information
nicholasjclark authored Nov 21, 2024
2 parents 9868aed + 1b102e3 commit 55a5505
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 5 deletions.
9 changes: 5 additions & 4 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ Imports:
purrr,
dplyr,
magrittr,
rlang
rlang,
generics,
tibble (>= 3.0.0)
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
Expand All @@ -51,9 +53,8 @@ Suggests:
usethis,
testthat
Enhances:
gratia (>= 0.9.0),
tibble (>= 3.0.0),
tidyr
gratia (>= 0.9.0),
tidyr
Additional_repositories: https://mc-stan.org/r-packages/
LinkingTo: Rcpp, RcppArmadillo
VignetteBuilder: knitr
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ S3method(as_draws_df,mvgam)
S3method(as_draws_list,mvgam)
S3method(as_draws_matrix,mvgam)
S3method(as_draws_rvars,mvgam)
S3method(augment,mvgam)
S3method(coef,mvgam)
S3method(conditional_effects,mvgam)
S3method(ensemble,mvgam_forecast)
Expand Down Expand Up @@ -87,6 +88,7 @@ export(as_draws_df)
export(as_draws_list)
export(as_draws_matrix)
export(as_draws_rvars)
export(augment)
export(bernoulli)
export(beta_binomial)
export(betar)
Expand Down Expand Up @@ -197,6 +199,7 @@ importFrom(brms,set_prior)
importFrom(brms,stancode)
importFrom(brms,standata)
importFrom(brms,student)
importFrom(generics,augment)
importFrom(ggplot2,aes)
importFrom(ggplot2,facet_wrap)
importFrom(ggplot2,geom_bar)
Expand Down Expand Up @@ -335,6 +338,7 @@ importFrom(stats,quantile)
importFrom(stats,rbeta)
importFrom(stats,rbinom)
importFrom(stats,reformulate)
importFrom(stats,residuals)
importFrom(stats,rgamma)
importFrom(stats,rlnorm)
importFrom(stats,rnbinom)
Expand Down
71 changes: 71 additions & 0 deletions R/tidier_methods.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#' @importFrom generics augment
#' @export
generics::augment

#' Augment an mvgam object's data
#'
#' Add fits and residuals to the data, implementing the generic `augment` from
#' the package {broom}.
#'
#' A `list` is returned if `class(x$obs_data) == 'list'`, otherwise a `tibble` is
#' returned, but the contents of either object is the same.
#'
#' The arguments `robust` and `probs` are applied to both the fit and
#' residuals calls (see [fitted.mvgam()] and [residuals.mvgam()] for details).
#'
#' @param x An object of class `mvgam`.
#' @param robust If `FALSE` (the default) the mean is used as the measure of
#' central tendency and the standard deviation as the measure of variability.
#' If `TRUE`, the median and the median absolute deviation (MAD)
#' are applied instead.
#' @param probs The percentiles to be computed by the quantile function.
#' @param ... Unused, included for generic consistency only.
#' @returns A `list` or `tibble` (see details) combining:
#' * The data supplied to `mvgam()`.
#' * The outcome variable, named as `.observed`.
#' * The fitted backcasts, along with their variability and credible bounds.
#' * The residuals, along with their variability and credible bounds.
#'
#' @examples
#' \dontrun{
#' set.seed(0)
#' dat <- sim_mvgam(T = 80,
#' n_series = 3,
#' mu = 2,
#' trend_model = AR(p = 1),
#' prop_missing = 0.1,
#' prop_trend = 0.6)
#' mod1 <- mvgam(formula = y ~ s(season, bs = 'cc', k = 6),
#' data = dat$data_train,
#' trend_model = AR(),
#' family = poisson(),
#' noncentred = TRUE)
#' augment(mod1, robust = TRUE, probs = c(0.25, 0.75))
#' }
#'
#' @importFrom stats residuals
#' @export
augment.mvgam <- function(x,
robust = FALSE,
probs = c(0.025, 0.975),
...) {
obs_data <- x$obs_data
obs_data$.observed = obs_data$y
obs_data <- purrr::discard_at(obs_data, c("index..orig..order", "index..time..index"))

resids <- residuals(x, robust = robust, probs = probs) %>%
tibble::as_tibble()
fits <- fitted(x, robust = robust, probs = probs) %>%
tibble::as_tibble()
hc_fits <- fits %>%
dplyr::slice_head(n = NROW(resids)) # fits can include fcs
colnames(resids) <- c(".resid", ".resid.variability", ".resid.cred.low", ".resid.cred.high")
colnames(hc_fits) <- c(".fitted", ".fit.variability", ".fit.cred.low", ".fit.cred.high")

augmented <- c(obs_data, hc_fits, resids) # coerces to list
if (!identical(class(x$obs_data), "list")) { # data.frame
augmented <- tibble::as_tibble(augmented)
}

augmented
}
58 changes: 58 additions & 0 deletions man/augment.mvgam.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion man/reexports.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions tests/testthat/test-tidier_methods.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
test_that("augment doesn't error", {
expect_no_error(augment(mvgam_example1))
expect_no_error(augment(mvgam_example5))
})

test_that("augment return types", {
out1 = augment(mvgam_example1)
out5 = augment(mvgam_example5)
expect_equal(class(out1)[[1]], "tbl_df")
expect_equal(class(out5), "list")
})

0 comments on commit 55a5505

Please sign in to comment.