diff --git a/DESCRIPTION b/DESCRIPTION index 368bd015..2e2148da 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -26,7 +26,9 @@ Imports: purrr, dplyr, magrittr, - rlang + rlang, + generics, + tibble (>= 3.0.0) Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) @@ -51,9 +53,8 @@ Suggests: usethis, testthat Enhances: - gratia (>= 0.9.0), - tibble (>= 3.0.0), - tidyr + gratia (>= 0.9.0), + tidyr Additional_repositories: https://mc-stan.org/r-packages/ LinkingTo: Rcpp, RcppArmadillo VignetteBuilder: knitr diff --git a/NAMESPACE b/NAMESPACE index 60b9b692..b4343840 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -12,6 +12,7 @@ S3method(as_draws_df,mvgam) S3method(as_draws_list,mvgam) S3method(as_draws_matrix,mvgam) S3method(as_draws_rvars,mvgam) +S3method(augment,mvgam) S3method(coef,mvgam) S3method(conditional_effects,mvgam) S3method(ensemble,mvgam_forecast) @@ -87,6 +88,7 @@ export(as_draws_df) export(as_draws_list) export(as_draws_matrix) export(as_draws_rvars) +export(augment) export(bernoulli) export(beta_binomial) export(betar) @@ -197,6 +199,7 @@ importFrom(brms,set_prior) importFrom(brms,stancode) importFrom(brms,standata) importFrom(brms,student) +importFrom(generics,augment) importFrom(ggplot2,aes) importFrom(ggplot2,facet_wrap) importFrom(ggplot2,geom_bar) @@ -335,6 +338,7 @@ importFrom(stats,quantile) importFrom(stats,rbeta) importFrom(stats,rbinom) importFrom(stats,reformulate) +importFrom(stats,residuals) importFrom(stats,rgamma) importFrom(stats,rlnorm) importFrom(stats,rnbinom) diff --git a/R/tidier_methods.R b/R/tidier_methods.R new file mode 100644 index 00000000..0d778df5 --- /dev/null +++ b/R/tidier_methods.R @@ -0,0 +1,71 @@ +#' @importFrom generics augment +#' @export +generics::augment + +#' Augment an mvgam object's data +#' +#' Add fits and residuals to the data, implementing the generic `augment` from +#' the package {broom}. +#' +#' A `list` is returned if `class(x$obs_data) == 'list'`, otherwise a `tibble` is +#' returned, but the contents of either object is the same. +#' +#' The arguments `robust` and `probs` are applied to both the fit and +#' residuals calls (see [fitted.mvgam()] and [residuals.mvgam()] for details). +#' +#' @param x An object of class `mvgam`. +#' @param robust If `FALSE` (the default) the mean is used as the measure of +#' central tendency and the standard deviation as the measure of variability. +#' If `TRUE`, the median and the median absolute deviation (MAD) +#' are applied instead. +#' @param probs The percentiles to be computed by the quantile function. +#' @param ... Unused, included for generic consistency only. +#' @returns A `list` or `tibble` (see details) combining: +#' * The data supplied to `mvgam()`. +#' * The outcome variable, named as `.observed`. +#' * The fitted backcasts, along with their variability and credible bounds. +#' * The residuals, along with their variability and credible bounds. +#' +#' @examples +#' \dontrun{ +#' set.seed(0) +#' dat <- sim_mvgam(T = 80, +#' n_series = 3, +#' mu = 2, +#' trend_model = AR(p = 1), +#' prop_missing = 0.1, +#' prop_trend = 0.6) +#' mod1 <- mvgam(formula = y ~ s(season, bs = 'cc', k = 6), +#' data = dat$data_train, +#' trend_model = AR(), +#' family = poisson(), +#' noncentred = TRUE) +#' augment(mod1, robust = TRUE, probs = c(0.25, 0.75)) +#' } +#' +#' @importFrom stats residuals +#' @export +augment.mvgam <- function(x, + robust = FALSE, + probs = c(0.025, 0.975), + ...) { + obs_data <- x$obs_data + obs_data$.observed = obs_data$y + obs_data <- purrr::discard_at(obs_data, c("index..orig..order", "index..time..index")) + + resids <- residuals(x, robust = robust, probs = probs) %>% + tibble::as_tibble() + fits <- fitted(x, robust = robust, probs = probs) %>% + tibble::as_tibble() + hc_fits <- fits %>% + dplyr::slice_head(n = NROW(resids)) # fits can include fcs + colnames(resids) <- c(".resid", ".resid.variability", ".resid.cred.low", ".resid.cred.high") + colnames(hc_fits) <- c(".fitted", ".fit.variability", ".fit.cred.low", ".fit.cred.high") + + augmented <- c(obs_data, hc_fits, resids) # coerces to list + if (!identical(class(x$obs_data), "list")) { # data.frame + augmented <- tibble::as_tibble(augmented) + } + + augmented +} diff --git a/man/augment.mvgam.Rd b/man/augment.mvgam.Rd new file mode 100644 index 00000000..be268890 --- /dev/null +++ b/man/augment.mvgam.Rd @@ -0,0 +1,58 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/tidier_methods.R +\name{augment.mvgam} +\alias{augment.mvgam} +\title{Augment an mvgam object's data} +\usage{ +\method{augment}{mvgam}(x, robust = FALSE, probs = c(0.025, 0.975), ...) +} +\arguments{ +\item{x}{An object of class \code{mvgam}.} + +\item{robust}{If \code{FALSE} (the default) the mean is used as the measure of +central tendency and the standard deviation as the measure of variability. +If \code{TRUE}, the median and the median absolute deviation (MAD) +are applied instead.} + +\item{probs}{The percentiles to be computed by the quantile function.} + +\item{...}{Unused, included for generic consistency only.} +} +\value{ +A \code{list} or \code{tibble} (see details) combining: +\itemize{ +\item The data supplied to \code{mvgam()}. +\item The outcome variable, named as \code{.observed}. +\item The fitted backcasts, along with their variability and credible bounds. +\item The residuals, along with their variability and credible bounds. +} +} +\description{ +Add fits and residuals to the data, implementing the generic \code{augment} from +the package {broom}. +} +\details{ +A \code{list} is returned if \code{class(x$obs_data) == 'list'}, otherwise a \code{tibble} is +returned, but the contents of either object is the same. + +The arguments \code{robust} and \code{probs} are applied to both the fit and +residuals calls (see \code{\link[=fitted.mvgam]{fitted.mvgam()}} and \code{\link[=residuals.mvgam]{residuals.mvgam()}} for details). +} +\examples{ +\dontrun{ +set.seed(0) +dat <- sim_mvgam(T = 80, + n_series = 3, + mu = 2, + trend_model = AR(p = 1), + prop_missing = 0.1, + prop_trend = 0.6) +mod1 <- mvgam(formula = y ~ s(season, bs = 'cc', k = 6), + data = dat$data_train, + trend_model = AR(), + family = poisson(), + noncentred = TRUE) +augment(mod1, robust = TRUE, probs = c(0.25, 0.75)) +} + +} diff --git a/man/reexports.Rd b/man/reexports.Rd index a893e63c..da19e2d5 100644 --- a/man/reexports.Rd +++ b/man/reexports.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/as.data.frame.mvgam.R, R/conditional_effects.R, % R/get_mvgam_priors.R, R/loo.mvgam.R, R/marginaleffects.mvgam.R, % R/mcmc_plot.mvgam.R, R/mvgam_formulae.R, R/posterior_epred.mvgam.R, -% R/stan_utils.R +% R/stan_utils.R, R/tidier_methods.R \docType{import} \name{reexports} \alias{reexports} @@ -40,6 +40,7 @@ \alias{posterior_linpred} \alias{stancode} \alias{standata} +\alias{augment} \title{Objects exported from other packages} \keyword{internal} \description{ @@ -49,6 +50,8 @@ below to see their documentation. \describe{ \item{brms}{\code{\link[brms:conditional_effects.brmsfit]{conditional_effects}}, \code{\link[brms]{gp}}, \code{\link[brms:mcmc_plot.brmsfit]{mcmc_plot}}, \code{\link[brms:set_prior]{prior}}, \code{\link[brms:set_prior]{prior_}}, \code{\link[brms:set_prior]{prior_string}}, \code{\link[brms]{set_prior}}, \code{\link[brms]{stancode}}, \code{\link[brms]{standata}}} + \item{generics}{\code{\link[generics]{augment}}} + \item{insight}{\code{\link[insight]{get_data}}} \item{loo}{\code{\link[loo]{loo}}, \code{\link[loo]{loo_compare}}} diff --git a/tests/testthat/test-tidier_methods.R b/tests/testthat/test-tidier_methods.R new file mode 100644 index 00000000..33a3632a --- /dev/null +++ b/tests/testthat/test-tidier_methods.R @@ -0,0 +1,11 @@ +test_that("augment doesn't error", { + expect_no_error(augment(mvgam_example1)) + expect_no_error(augment(mvgam_example5)) +}) + +test_that("augment return types", { + out1 = augment(mvgam_example1) + out5 = augment(mvgam_example5) + expect_equal(class(out1)[[1]], "tbl_df") + expect_equal(class(out5), "list") +})