Skip to content

Commit 9532b78

Browse files
update ppc docs
1 parent 3c4ec4a commit 9532b78

11 files changed

+89
-40
lines changed

R/forecast.mvgam.R

+18-15
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ forecast.mvgam = function(object,
574574
use_lv = object$use_lv,
575575
fit_engine = object$fit_engine,
576576
type = type,
577-
series_names = factor(unique(data_train$series),
577+
series_names = factor(levels(data_train$series),
578578
levels = levels(data_train$series)),
579579
train_observations = series_obs,
580580
train_times = unique(data_train$index..time..index),
@@ -654,32 +654,35 @@ forecast_draws = function(object,
654654
Xp_trend <- trend_Xp_matrix(newdata = sort_data(data_test),
655655
trend_map = object$trend_map,
656656
series = series,
657-
mgcv_model = object$trend_mgcv_model)
657+
mgcv_model = object$trend_mgcv_model,
658+
forecast = TRUE)
658659

659660
# For trend_formula models with autoregressive processes,
660661
# the process model operates as: AR * (process[t - 1] - mu[t-1]])
661662
# We therefore need the values of mu at the end of the training set
662663
# to correctly propagate the process model forward
663664
if(use_lv & attr(object$model_data, 'trend_model') != 'GP'){
664665
# Get the observed trend predictor matrix
665-
Xp_trend_last <- trend_Xp_matrix(newdata = object$obs_data,
666-
trend_map = object$trend_map,
667-
series = series,
668-
mgcv_model = object$trend_mgcv_model)
666+
newdata <- trend_map_data_prep(object$obs_data,
667+
object$trend_map,
668+
forecast = TRUE)
669+
Xp_trend_last <- predict(object$trend_mgcv_model,
670+
newdata = newdata,
671+
type = 'lpmatrix')
669672

670673
# Ensure the last three values are used, in case the obs_data
671674
# was not supplied in order
672-
data.frame(time = object$obs_data$index..time..index,
673-
series = object$obs_data$series,
674-
row_id = 1:length(object$obs_data$index..time..index)) %>%
675+
data.frame(time = newdata$index..time..index,
676+
series = newdata$series,
677+
row_id = 1:length(newdata$index..time..index)) %>%
675678
dplyr::arrange(time, series) %>%
676679
dplyr::pull(row_id) -> sorted_inds
677-
678-
linpred_order <- vector(length = 3 * n_series)
679-
last_rows <- tail(sort(sorted_inds), 3 * n_series)
680-
for(i in seq_along(last_rows)){
681-
linpred_order[i] <- which(sorted_inds == last_rows[i])
682-
}
680+
n_processes <- length(unique(object$trend_map$trend))
681+
linpred_order <- tail(sorted_inds, 3 * n_processes)
682+
# last_rows <- tail(sorted_inds, 3 * n_processes)
683+
# for(i in seq_along(last_rows)){
684+
# linpred_order[i] <- which(sorted_inds == last_rows[i])
685+
# }
683686

684687
# Deal with any offsets
685688
if(!all(attr(Xp_trend_last, 'model.offset') == 0)){

R/get_linear_predictors.R

+41-6
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,10 @@ obs_Xp_matrix = function(newdata, mgcv_model){
7373
return(Xp)
7474
}
7575

76-
77-
#' Function to prepare trend linear predictor matrix, ensuring ordering and
78-
#' indexing is correct with respect to the model structure
76+
#' Function to prepare trend linear predictor matrix in the presence of a
77+
#' trend_map
7978
#' @noRd
80-
trend_Xp_matrix = function(newdata, trend_map, series = 'all',
81-
mgcv_model){
82-
79+
trend_map_data_prep = function(newdata, trend_map, forecast = FALSE){
8380
trend_test <- newdata
8481
trend_indicators <- vector(length = length(trend_test$series))
8582
for(i in 1:length(trend_test$series)){
@@ -92,6 +89,44 @@ trend_Xp_matrix = function(newdata, trend_map, series = 'all',
9289
trend_test$series <- trend_indicators
9390
trend_test$y <- NULL
9491

92+
# Only keep one time observation per trend, in case this is a reduced dimensionality
93+
# State-Space model (with a trend_map) and we are forecasting ahead
94+
if(forecast){
95+
data.frame(series = trend_test$series,
96+
time = trend_test$index..time..index,
97+
row_num = 1:length(trend_test$index..time..index)) %>%
98+
dplyr::group_by(series, time) %>%
99+
dplyr::slice_head(n = 1) %>%
100+
dplyr::pull(row_num) -> inds_keep
101+
inds_keep <- sort(inds_keep)
102+
103+
if(inherits(trend_test, 'list')){
104+
trend_test <- lapply(trend_test, function(x){
105+
if(is.matrix(x)){
106+
matrix(x[inds_keep,], ncol = NCOL(x))
107+
} else {
108+
x[inds_keep]
109+
}
110+
111+
})
112+
} else {
113+
trend_test <- trend_test[inds_keep, ]
114+
}
115+
}
116+
117+
return(trend_test)
118+
}
119+
120+
#' Function to prepare trend linear predictor matrix, ensuring ordering and
121+
#' indexing is correct with respect to the model structure
122+
#' @noRd
123+
trend_Xp_matrix = function(newdata, trend_map, series = 'all',
124+
mgcv_model, forecast = FALSE){
125+
126+
trend_test <- trend_map_data_prep(newdata,
127+
trend_map,
128+
forecast = forecast)
129+
95130
suppressWarnings(Xp_trend <- try(predict(mgcv_model,
96131
newdata = trend_test,
97132
type = 'lpmatrix'),

R/ppc.mvgam.R

+12-8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#'@title Plot mvgam posterior predictive checks for a specified series
1+
#'@title Plot mvgam conditional posterior predictive checks for a specified series
22
#'@importFrom stats quantile density ecdf formula terms
33
#'@importFrom graphics hist abline box rect lines polygon par
44
#'@importFrom grDevices rgb
@@ -26,16 +26,17 @@
2626
#'@param xlab label for x axis.
2727
#'@param ylab label for y axis.
2828
#'@param ... further \code{\link[graphics]{par}} graphical parameters.
29-
#'@details Posterior predictions are drawn from the fitted \code{mvgam} and compared against
29+
#'@details Conditional posterior predictions are drawn from the fitted \code{mvgam} and compared against
3030
#'the empirical distribution of the observed data for a specified series to help evaluate the model's
3131
#'ability to generate unbiased predictions. For all plots apart from `type = 'rootogram'`, posterior predictions
3232
#'can also be compared to out of sample observations as long as these observations were included as
3333
#''data_test' in the original model fit and supplied here. Rootograms are currently only plotted using the
3434
#''hanging' style.
3535
#'\cr
36-
#'Note that the predictions used for these plots are those that have been generated directly within
37-
#'the `mvgam()` model, so they can be misleading if the model included flexible dynamic trend components. For
38-
#'a broader range of posterior checks that are created using "new data" predictions, see
36+
#'Note that the predictions used for these plots are *conditional on the observed data*, i.e. they
37+
#'are those predictions that have been generated directly within
38+
#'the `mvgam()` model. They can be misleading if the model included flexible dynamic trend components. For
39+
#'a broader range of posterior checks that are created using *unconditional* "new data" predictions, see
3940
#'\code{\link{pp_check.mvgam}}
4041
#'@return A base \code{R} graphics plot showing either a posterior rootogram (for \code{type == 'rootogram'}),
4142
#'the predicted vs observed mean for the
@@ -759,7 +760,7 @@ ppc.mvgam = function(object, newdata, data_test, series = 1, type = 'hist',
759760

760761
#' Posterior Predictive Checks for \code{mvgam} Objects
761762
#'
762-
#' Perform posterior predictive checks with the help
763+
#' Perform unconditional posterior predictive checks with the help
763764
#' of the \pkg{bayesplot} package.
764765
#'
765766
#' @aliases pp_check
@@ -777,11 +778,14 @@ ppc.mvgam = function(object, newdata, data_test, series = 1, type = 'hist',
777778
#' @return A ggplot object that can be further
778779
#' customized using the \pkg{ggplot2} package.
779780
#'
780-
#' @details For a detailed explanation of each of the ppc functions,
781+
#' @details Unlike the conditional posterior checks provided by \code{\link{ppc}},
782+
#' This function computes *unconditional* posterior predictive checks (i.e. it generates
783+
#' predictions for fake data without considering the true observations associated with those
784+
#' fake data). For a detailed explanation of each of the ppc functions,
781785
#' see the \code{\link[bayesplot:PPC-overview]{PPC}}
782786
#' documentation of the \pkg{\link[bayesplot:bayesplot-package]{bayesplot}}
783787
#' package.
784-
#' @seealso \code{\link{ppc}} \code{\link{predict.mvgam}}
788+
#' @seealso \code{\link{ppc}}, \code{\link{predict.mvgam}}
785789
#' @examples
786790
#' \dontrun{
787791
#'simdat <- sim_mvgam(seasonality = 'hierarchical')

R/sysdata.rda

-161 KB
Binary file not shown.

man/pp_check.mvgam.Rd

+6-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/ppc.mvgam.Rd

+7-6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/mvgam.dll

0 Bytes
Binary file not shown.

tests/mvgam_examples.R

+4-1
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,13 @@ mvgam_example1 <- mvgam(y ~ s(season, k = 5),
152152
samples = 30,
153153
chains = 1)
154154

155-
# Univariate process with trend_formula and correlated process errors
155+
# Univariate process with trend_formula, trend_map and correlated process errors
156+
trend_map <- data.frame(series = unique(mvgam_examp_dat$data_train$series),
157+
trend = c(1,1,2))
156158
mvgam_example2 <- mvgam(y ~ 1,
157159
trend_formula = ~ s(season, k = 5),
158160
trend_model = RW(cor = TRUE),
161+
trend_map = trend_map,
159162
family = gaussian(),
160163
data = mvgam_examp_dat$data_train,
161164
burnin = 300,

tests/testthat/Rplots.pdf

1.79 KB
Binary file not shown.

tests/testthat/setup.R

+1
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,4 @@ gaus_data <- sim_mvgam(family = gaussian(),
4242
mu = c(-1, 0, 1),
4343
trend_rel = 0.5,
4444
prop_missing = 0.2)
45+

tests/testthat/test-binomial.R

-1
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ test_that("binomial() post-processing works", {
107107
expect_no_error(ppc(mod, type = 'pit'))
108108
expect_no_error(ppc(mod, type = 'cdf'))
109109
expect_no_error(ppc(mod, type = 'rootogram'))
110-
111110
expect_no_error(plot(mod, type = 'residuals'))
112111

113112
expect_no_error(plot_mvgam_series(object = mod))

0 commit comments

Comments
 (0)