Skip to content

Commit 299184c

Browse files
add ensemble and standata methods; more tests
1 parent 57ce61f commit 299184c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+1362
-144
lines changed

NAMESPACE

+4
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ S3method(as_draws_matrix,mvgam)
1414
S3method(as_draws_rvars,mvgam)
1515
S3method(coef,mvgam)
1616
S3method(conditional_effects,mvgam)
17+
S3method(ensemble,mvgam_forecast)
1718
S3method(find_predictors,mvgam)
1819
S3method(find_predictors,mvgam_prefit)
1920
S3method(fitted,mvgam)
@@ -59,6 +60,7 @@ S3method(smooth.construct,mod.smooth.spec)
5960
S3method(smooth.construct,moi.smooth.spec)
6061
S3method(stancode,mvgam)
6162
S3method(stancode,mvgam_prefit)
63+
S3method(standata,mvgam_prefit)
6264
S3method(summary,mvgam)
6365
S3method(summary,mvgam_prefit)
6466
S3method(update,mvgam)
@@ -76,6 +78,7 @@ export(code)
7678
export(compare_mvgams)
7779
export(drawDotmvgam)
7880
export(dynamic)
81+
export(ensemble)
7982
export(eval_mvgam)
8083
export(eval_smoothDothilbertDotsmooth)
8184
export(eval_smoothDotmodDotsmooth)
@@ -135,6 +138,7 @@ importFrom(brms,qstudent_t)
135138
importFrom(brms,rbeta_binomial)
136139
importFrom(brms,rstudent_t)
137140
importFrom(brms,stancode)
141+
importFrom(brms,standata)
138142
importFrom(brms,student)
139143
importFrom(ggplot2,scale_colour_discrete)
140144
importFrom(ggplot2,scale_fill_discrete)

NEWS.md

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# mvgam 1.1.3 (development version; not yet on CRAN)
22
## New functionalities
3-
* Allow intercepts to be included in process models when `trend_formula` is supplied. This breaks the assumption that the process has to be zero-centred, adding flexibility but also potentially inducing nonidentifiabilities with respect to any observation model intercepts. Thoughtful priors are a must for these models
4-
* Added `stancode.mvgam` and `stancode.mvgam_prefit` methods
3+
* Allow intercepts to be included in process models when `trend_formula` is supplied. This breaks the assumption that the process has to be zero-centred, adding more modelling flexibility but also potentially inducing nonidentifiabilities with respect to any observation model intercepts. Thoughtful priors are a must for these models
4+
* Added `standata.mvgam_prefit`, `stancode.mvgam` and `stancode.mvgam_prefit` methods for better alignment with 'brms' workflows
55
* Added 'gratia' to *Enhancements* to allow popular methods such as `draw()` to be used for 'mvgam' models if 'gratia' is already installed
6+
* Added an `ensemble.mvgam_forecast` method to generate evenly weighted combinations of probabilistic forecast distributions
67

78
## Deprecations
89
* The `drift` argument has been deprecated. It is now recommended for users to include parametric fixed effects of "time" in their respective GAM formulae to capture any expected drift effects

R/ensemble.R

+175
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
#' Combine mvgam forecasts into evenly weighted ensembles
2+
#'
3+
#' Generate evenly weighted ensemble forecast distributions from \code{mvgam_forecast} objects
4+
#'
5+
#'@name ensemble.mvgam_forecast
6+
#'@param object \code{list} object of class \code{mvgam_forecast}. See [forecast.mvgam()]
7+
#'@param ... More \code{mvgam_forecast} objects.
8+
#'@details It is widely recognised in the forecasting literature that combining forecasts
9+
#'from different models often results in improved forecast accuracy. The simplest way to create
10+
#'an ensemble is to use evenly weighted combinations of forecasts from the different models.
11+
#' This is straightforward to do in a Bayesian setting with `mvgam` as the posterior MCMC draws
12+
#' contained in each \code{mvgam_forecast} object will already implicitly capture correlations among
13+
#' the temporal posterior predictions.
14+
#'@return An object of class \code{mvgam_forecast} containing the ensemble predictions. This
15+
#'object can be readily used with the supplied S3 functions \code{plot} and \code{score}
16+
#'@author Nicholas J Clark
17+
#'@seealso \code{\link{plot.mvgam_forecast}}, \code{\link{score.mvgam_forecast}}
18+
#' @examples
19+
#' \donttest{
20+
#' # Simulate some series and fit a few competing dynamic models
21+
#' set.seed(1)
22+
#' simdat <- sim_mvgam(n_series = 1,
23+
#' prop_trend = 0.6,
24+
#' mu = 1)
25+
#'
26+
#' plot_mvgam_series(data = simdat$data_train,
27+
#' newdata = simdat$data_test)
28+
#'
29+
#' m1 <- mvgam(y ~ 1,
30+
#' trend_formula = ~ time +
31+
#' s(season, bs = 'cc', k = 9),
32+
#' trend_model = AR(p = 1),
33+
#' noncentred = TRUE,
34+
#' data = simdat$data_train,
35+
#' newdata = simdat$data_test)
36+
#'
37+
#' m2 <- mvgam(y ~ time,
38+
#' trend_model = RW(),
39+
#' noncentred = TRUE,
40+
#' data = simdat$data_train,
41+
#' newdata = simdat$data_test)
42+
#'
43+
#' # Calculate forecast distributions for each model
44+
#' fc1 <- forecast(m1)
45+
#' fc2 <- forecast(m2)
46+
#'
47+
#' # Generate the ensemble forecast
48+
#' ensemble_fc <- ensemble(fc1, fc2)
49+
#'
50+
#' # Plot forecasts
51+
#' plot(fc1)
52+
#' plot(fc2)
53+
#' plot(ensemble_fc)
54+
#'
55+
#' # Score forecasts
56+
#' score(fc1)
57+
#' score(fc2)
58+
#' score(ensemble_fc)
59+
#' }
60+
#'@export
61+
ensemble <- function(object, ...){
62+
UseMethod("ensemble", object)
63+
}
64+
65+
#'@rdname ensemble.mvgam_forecast
66+
#'@method ensemble mvgam_forecast
67+
#'@param ndraws Positive integer specifying the number of draws to use from each
68+
#'forecast distribution for creating the ensemble. If some of the ensemble members have
69+
#'fewer draws than `ndraws`, their forecast distributions will be resampled with replacement
70+
#'to achieve the correct number of draws
71+
#'@export
72+
ensemble.mvgam_forecast <- function(object, ..., ndraws = 5000){
73+
models <- split_fc_dots(object, ..., model_names = NULL)
74+
n_models <- length(models)
75+
76+
# Check that series names and key dimensions match for all forecasts
77+
allsame <- function(x) length(unique(x)) == 1
78+
if(!allsame(purrr::map(models, 'series_names'))){
79+
stop('Names of series must match for all forecast objects.',
80+
call. = FALSE)
81+
}
82+
83+
if(!allsame(lapply(models, function(x) length(x$forecasts)))){
84+
stop('The number of forecast distributions must match for all forecast objects.',
85+
call. = FALSE)
86+
}
87+
88+
if(!allsame(lapply(models, function(x) length(x$test_observations)))){
89+
stop('Validation data must match for all forecast objects.',
90+
call. = FALSE)
91+
}
92+
93+
if(!allsame(lapply(models, function(x) {
94+
unlist(lapply(x$forecasts, function(y) dim(y)[2]),
95+
use.names = FALSE) }))){
96+
stop('Forecast horizons must match for all forecast objects.',
97+
call. = FALSE)
98+
}
99+
100+
validate_pos_integer(ndraws)
101+
102+
# End of checks; now proceed with ensembling
103+
n_series <- length(models[[1]]$series_names)
104+
105+
# Function to random sample rows of a matrix with
106+
# replacement (in case some forecasts contain fewer draws than others)
107+
subsamp <- function(x, nsamps){
108+
if(NROW(x) < nsamps){
109+
sampinds <- sample(1:NROW(x), nsamps, replace = TRUE)
110+
} else {
111+
sampinds <- 1:nsamps
112+
}
113+
114+
x[sampinds, ]
115+
}
116+
117+
# Create evenly weighted ensemble forecasts
118+
ens_fcs <- lapply(seq_len(n_series), function(series){
119+
all_fcs <- do.call(rbind,
120+
lapply(models,
121+
function(x) x$forecasts[[series]]))
122+
subsamp(all_fcs, ndraws)
123+
})
124+
125+
# Initiate the ensemble forecast object
126+
ens_fc <- models[[1]]
127+
128+
# Add in forecasts
129+
ens_fc$forecasts <- ens_fcs
130+
names(ens_fc$forecasts) <- names(models[[1]]$forecasts)
131+
132+
# Ensure hindcasts have same number of samples
133+
ens_hcs <- lapply(seq_len(n_series), function(series){
134+
subsamp(ens_fc$hindcasts[[series]], ndraws)
135+
})
136+
ens_fc$hindcasts <- ens_hcs
137+
names(ens_fc$hindcasts) <- names(models[[1]]$hindcasts)
138+
139+
# Return
140+
return(ens_fc)
141+
}
142+
143+
144+
#'@noRd
145+
split_fc_dots = function (x, ..., model_names = NULL, other = TRUE) {
146+
147+
dots <- list(x, ...)
148+
names <- substitute(list(x, ...), env = parent.frame())[-1]
149+
names <- ulapply(names, deparse)
150+
151+
if(!is.null(model_names)){
152+
names <- model_names
153+
}
154+
155+
if (length(names)) {
156+
if (!length(names(dots))) {
157+
names(dots) <- names
158+
}
159+
else {
160+
has_no_name <- !nzchar(names(dots))
161+
names(dots)[has_no_name] <- names[has_no_name]
162+
}
163+
}
164+
is_mvgam_fc <- unlist(lapply(dots, function(y) inherits(y, 'mvgam_forecast')))
165+
models <- dots[is_mvgam_fc]
166+
out <- dots[!is_mvgam_fc]
167+
168+
if (length(out)) {
169+
stop("Only mvgam_forecast objects can be passed to '...' for this method.",
170+
call. = FALSE)
171+
}
172+
models
173+
}
174+
175+

R/families.R

+7-7
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ mvgam_predict = function(Xp,
454454
mu <- ((matrix(Xp, ncol = NCOL(Xp)) %*%
455455
betas)) + attr(Xp, 'model.offset')
456456
sd <- as.vector(family_pars$sigma_obs)
457-
out <- (exp((sd) ^ 2) - 1) * exp((2 * mu + sd ^ 2))
457+
out <- as.vector((exp((sd) ^ 2) - 1) * exp((2 * mu + sd ^ 2)))
458458

459459
} else {
460460
mu <- as.vector((matrix(Xp, ncol = NCOL(Xp)) %*%
@@ -610,10 +610,10 @@ mvgam_predict = function(Xp,
610610
out <- ((n * p) * (1 - p)) * ((alpha + beta + n) / (alpha + beta + 1))
611611

612612
} else {
613-
out <- plogis(((matrix(Xp, ncol = NCOL(Xp)) %*%
613+
out <- as.vector(plogis(((matrix(Xp, ncol = NCOL(Xp)) %*%
614614
betas)) +
615615
attr(Xp, 'model.offset')) *
616-
as.vector(family_pars$trials)
616+
as.vector(family_pars$trials))
617617
}
618618
}
619619

@@ -640,9 +640,9 @@ mvgam_predict = function(Xp,
640640
betas) + attr(Xp, 'model.offset')))
641641
out <- mu + mu^2 / as.vector(family_pars$phi)
642642
} else {
643-
out <- exp(((matrix(Xp, ncol = NCOL(Xp)) %*%
643+
out <- as.vector(exp(((matrix(Xp, ncol = NCOL(Xp)) %*%
644644
betas)) +
645-
attr(Xp, 'model.offset'))
645+
attr(Xp, 'model.offset')))
646646
}
647647
}
648648

@@ -735,9 +735,9 @@ mvgam_predict = function(Xp,
735735
attr(Xp, 'model.offset')) ^ 1.5) *
736736
as.vector(family_pars$phi)
737737
} else {
738-
out <- exp(((matrix(Xp, ncol = NCOL(Xp)) %*%
738+
out <- as.vector(exp(((matrix(Xp, ncol = NCOL(Xp)) %*%
739739
betas)) +
740-
attr(Xp, 'model.offset'))
740+
attr(Xp, 'model.offset')))
741741
}
742742
}
743743
return(out)

R/forecast.mvgam.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#'@details Posterior predictions are drawn from the fitted \code{mvgam} and used to simulate a forecast distribution
2020
#'@return An object of class \code{mvgam_forecast} containing hindcast and forecast distributions.
2121
#'See \code{\link{mvgam_forecast-class}} for details.
22-
#'@seealso \code{\link{hindcast}}, \code{\link{score}}
22+
#'@seealso \code{\link{hindcast}}, \code{\link{score}}, \code{\link{ensemble}}
2323
#'@export
2424
forecast <- function(object, ...){
2525
UseMethod("forecast", object)

R/mvgam.R

+8-3
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,8 @@
366366
#'plot_mvgam_series(data = dat$data_train, series = 'all')
367367
#'
368368
#'# Formulate a model using Stan where series share a cyclic smooth for
369-
#'# seasonality and each series has an independent AR1 temporal process;
369+
#'# seasonality and each series has an independent AR1 temporal process.
370+
#'# Note that 'noncentred = TRUE' will likely give performance gains.
370371
#'# Set run_model = FALSE to inspect the returned objects
371372
#'mod1 <- mvgam(formula = y ~ s(season, bs = 'cc', k = 6),
372373
#' data = dat$data_train,
@@ -377,9 +378,13 @@
377378
#' run_model = FALSE)
378379
#'
379380
#' # View the model code in Stan language
380-
#' code(mod1)
381+
#' stancode(mod1)
381382
#'
382-
#' # Now fit the model, noting that 'noncentred = TRUE' will likely give performance gains
383+
#' # View the data objects needed to fit the model in Stan
384+
#' sdata1 <- standata(mod1)
385+
#' str(sdata1)
386+
#'
387+
#' # Now fit the model
383388
#' mod1 <- mvgam(formula = y ~ s(season, bs = 'cc', k = 6),
384389
#' data = dat$data_train,
385390
#' trend_model = AR(),

R/score.mvgam_forecast.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
#'str(fc_scores)
5050
#'}
5151
#'@method score mvgam_forecast
52-
#'@seealso \code{\link{forecast.mvgam}}
52+
#'@seealso \code{\link{forecast.mvgam}}, \code{\link{ensemble}}
5353
#'@export
5454
score.mvgam_forecast = function(object, score = 'crps',
5555
log = FALSE, weights,

R/stan_utils.R

+20-5
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
#' Stan code for mvgam models
1+
#' Stan code and data objects for mvgam models
22
#'
3-
#' Generate Stan code for \pkg{mvgam} models
3+
#' Generate Stan code and data objects for \pkg{mvgam} models
44
#'
55
#' @param object An object of class `mvgam` or `mvgam_prefit`,
66
#' returned from a call to \code{mvgam}
7-
#' @return A character string containing the fully commented \pkg{Stan} code
8-
#' to fit a \pkg{mvgam} model. It is of class \code{c("character", "brmsmodel")}
9-
#' to facilitate pretty printing.
7+
#' @return Either a character string containing the fully commented \pkg{Stan} code
8+
#' to fit a \pkg{mvgam} model or a named list containing the data objects needed
9+
#' to fit the model in Stan.
1010
#' @export
1111
#' @examples
1212
#' simdat <- sim_mvgam()
@@ -15,8 +15,14 @@
1515
#' family = poisson(),
1616
#' data = simdat$data_train,
1717
#' run_model = FALSE)
18+
#'
19+
#' # View Stan model code
1820
#' stancode(mod)
1921
#'
22+
#' # View Stan model data
23+
#' sdata <- standata(mod)
24+
#' str(sdata)
25+
#'
2026
code = function(object){
2127
if(!class(object) %in% c('mvgam', 'mvgam_prefit')){
2228
stop('argument "object" must be of class "mvgam" or "mvgam_prefit"')
@@ -49,6 +55,15 @@ stancode.mvgam = function(object, ...){
4955
code(object)
5056
}
5157

58+
#' @export
59+
#' @importFrom brms standata
60+
#' @param ... ignored
61+
#' @rdname code
62+
standata.mvgam_prefit = function(object, ...){
63+
64+
object$model_data
65+
}
66+
5267
#' @noRd
5368
remove_likelihood = function(model_file){
5469
like_line <- grep('// likelihood functions',

docs/reference/Rplot001.png

-48 KB
Loading

docs/reference/Rplot002.png

8.87 KB
Loading

docs/reference/Rplot003.png

13.5 KB
Loading

docs/reference/Rplot004.png

-22.5 KB
Loading

docs/reference/Rplot005.png

-9.04 KB
Loading

docs/reference/Rplot006.png

-37.4 KB
Loading

docs/reference/Rplot007.png

2.71 KB
Loading

docs/reference/Rplot008.png

23.8 KB
Loading

docs/reference/Rplot009.png

-7.02 KB
Loading

docs/reference/Rplot010.png

27.6 KB
Loading

docs/reference/Rplot011.png

18.7 KB
Loading

docs/reference/Rplot012.png

-6.31 KB
Loading

docs/reference/Rplot013.png

13.9 KB
Loading

0 commit comments

Comments
 (0)