Skip to content

Commit 3aae9ea

Browse files
author
Nicholas Clark
committed
add check to ensure test data are in correct order when forecasting
1 parent 791f5a0 commit 3aae9ea

File tree

6 files changed

+93
-43
lines changed

6 files changed

+93
-43
lines changed

R/add_nmixture.R

-24
Original file line numberDiff line numberDiff line change
@@ -719,27 +719,3 @@ add_nmix_posterior = function(model_output,
719719

720720
return(model_output)
721721
}
722-
723-
# Get list object into correct order in case it is not already
724-
sort_data = function(obs_data){
725-
if(inherits(obs_data, 'list')){
726-
obs_data_arranged <- obs_data
727-
temp_dat = data.frame(time = obs_data$time,
728-
series = obs_data$series) %>%
729-
dplyr::mutate(index = dplyr::row_number()) %>%
730-
dplyr::arrange(series, time)
731-
obs_data_arranged <- lapply(obs_data, function(x){
732-
if(is.matrix(x)){
733-
matrix(x[temp_dat$index,], ncol = NCOL(x))
734-
} else {
735-
x[temp_dat$index]
736-
}
737-
})
738-
names(obs_data_arranged) <- names(obs_data)
739-
} else {
740-
obs_data_arranged <- obs_data %>%
741-
dplyr::arrange(series, time)
742-
}
743-
744-
return(obs_data_arranged)
745-
}

R/forecast.mvgam.R

+44-14
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ forecast.mvgam = function(object, newdata, data_test,
144144
type = type,
145145
series = series,
146146
data_test = data_test,
147-
n_cores = n_cores)
147+
n_cores = n_cores,
148+
...)
148149

149150
# Extract hindcasts and forecasts into the correct format
150151
if(series == 'all'){
@@ -757,7 +758,10 @@ forecast_draws = function(object,
757758
data_test,
758759
n_cores = 1,
759760
n_samples,
760-
ending_time){
761+
ending_time,
762+
b_uncertainty = TRUE,
763+
trend_uncertainty = TRUE,
764+
obs_uncertainty = TRUE){
761765

762766
# Check arguments
763767
validate_pos_integer(n_cores)
@@ -769,9 +773,10 @@ forecast_draws = function(object,
769773
s_name <- levels(data_test$series)[series]
770774
}
771775

772-
# Generate the observation model linear predictor matrix
776+
# Generate the observation model linear predictor matrix,
777+
# ensuring the test data is sorted correctly (by time and then series)
773778
if(inherits(data_test, 'list')){
774-
Xp <- obs_Xp_matrix(newdata = data_test,
779+
Xp <- obs_Xp_matrix(newdata = sort_data(data_test),
775780
mgcv_model = object$mgcv_model)
776781

777782
if(series != 'all'){
@@ -800,15 +805,16 @@ forecast_draws = function(object,
800805
Xp <- obs_Xp_matrix(newdata = series_test,
801806
mgcv_model = object$mgcv_model)
802807
} else {
803-
Xp <- obs_Xp_matrix(newdata = data_test,
808+
Xp <- obs_Xp_matrix(newdata = sort_data(data_test),
804809
mgcv_model = object$mgcv_model)
805810
series_test <- NULL
806811
}
807812
}
808813

809-
# Generate linear predictor matrix from trend mgcv model
814+
# Generate linear predictor matrix from trend mgcv model, ensuring
815+
# the test data is sorted correctly (by time and then series)
810816
if(!is.null(object$trend_call)){
811-
Xp_trend <- trend_Xp_matrix(newdata = data_test,
817+
Xp_trend <- trend_Xp_matrix(newdata = sort_data(data_test),
812818
trend_map = object$trend_map,
813819
series = series,
814820
mgcv_model = object$trend_mgcv_model)
@@ -984,7 +990,10 @@ forecast_draws = function(object,
984990
'series_test',
985991
'Xp',
986992
'Xp_trend',
987-
'fc_horizon'),
993+
'fc_horizon',
994+
'b_uncertainty',
995+
'trend_uncertainty',
996+
'obs_uncertainty'),
988997
envir = environment())
989998
parallel::clusterExport(cl = cl,
990999
unclass(lsf.str(envir = asNamespace("mvgam"),
@@ -998,18 +1007,31 @@ forecast_draws = function(object,
9981007
samp_index <- i
9991008

10001009
# Sample beta coefs
1001-
betas <- betas[samp_index, ]
1010+
if(b_uncertainty){
1011+
betas <- betas[samp_index, ]
1012+
} else {
1013+
betas <- betas[1, ]
1014+
}
10021015

10031016
if(!is.null(betas_trend)){
1004-
betas_trend <- betas_trend[samp_index, ]
1017+
if(b_uncertainty){
1018+
betas_trend <- betas_trend[samp_index, ]
1019+
} else {
1020+
betas_trend <- betas_trend[1, ]
1021+
}
10051022
}
10061023

10071024
# Return predictions
10081025
if(series == 'all'){
10091026

10101027
# Sample general trend-specific parameters
1011-
general_trend_pars <- extract_general_trend_pars(trend_pars = trend_pars,
1012-
samp_index = samp_index)
1028+
if(trend_uncertainty){
1029+
general_trend_pars <- extract_general_trend_pars(trend_pars = trend_pars,
1030+
samp_index = samp_index)
1031+
} else {
1032+
general_trend_pars <- extract_general_trend_pars(trend_pars = trend_pars,
1033+
samp_index = 1)
1034+
}
10131035

10141036
if(use_lv || trend_model %in% c('VAR1', 'PWlinear', 'PWlogistic')){
10151037
if(trend_model == 'PWlogistic'){
@@ -1107,9 +1129,17 @@ forecast_draws = function(object,
11071129
# Family-specific parameters
11081130
family_extracts <- lapply(seq_along(family_pars), function(x){
11091131
if(is.matrix(family_pars[[x]])){
1110-
family_pars[[x]][samp_index, series]
1132+
if(obs_uncertainty){
1133+
family_pars[[x]][samp_index, series]
1134+
} else {
1135+
family_pars[[x]][1, series]
1136+
}
11111137
} else {
1112-
family_pars[[x]][samp_index]
1138+
if(obs_uncertainty){
1139+
family_pars[[x]][samp_index]
1140+
} else {
1141+
family_pars[[x]][1]
1142+
}
11131143
}
11141144
})
11151145
names(family_extracts) <- names(family_pars)

R/sim_mvgam.R

+11-5
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ sim_mvgam = function(T = 100,
106106
use_lv <- FALSE
107107
}
108108

109+
if(trend_model %in% c('RWcor', 'AR1cor', 'AR2cor', 'AR3cor')){
110+
warning(paste0('Simulation of correlated AR or RW trends not yet supported.\n',
111+
'Reverting to uncorrelated trends'))
112+
}
113+
109114
if(missing(trend_rel)){
110115
trend_rel <- prop_trend
111116
}
@@ -219,31 +224,32 @@ sim_mvgam = function(T = 100,
219224
}
220225

221226
# Set trend parameters
222-
if(trend_model == 'RW'){
227+
if(trend_model %in% c('RW', 'RWcor')){
223228
ar1s <- rep(1, n_lv)
224229
ar2s <- rep(0, n_lv)
225230
ar3s <- rep(0, n_lv)
226231
}
227232

228-
if(trend_model == 'AR1'){
233+
if(trend_model %in% c('AR1', 'AR1cor')){
229234
ar1s <- rnorm(n_lv, sd = 0.5)
230235
ar2s <- rep(0, n_lv)
231236
ar3s <- rep(0, n_lv)
232237
}
233238

234-
if(trend_model == 'AR2'){
239+
if(trend_model %in% c('AR2', 'AR2cor')){
235240
ar1s <- rnorm(n_lv, sd = 0.5)
236241
ar2s <- rnorm(n_lv, sd = 0.5)
237242
ar3s <- rep(0, n_lv)
238243
}
239244

240-
if(trend_model == 'AR3'){
245+
if(trend_model %in% c('AR3', 'AR3cor')){
241246
ar1s <- rnorm(n_lv, sd = 0.5)
242247
ar2s <- rnorm(n_lv, sd = 0.5)
243248
ar3s <- rnorm(n_lv, sd = 0.5)
244249
}
245250

246-
if(trend_model %in% c('RW', 'AR1', 'AR2', 'AR3', 'VAR1', 'VAR1cor')){
251+
if(trend_model %in% c('RW', 'AR1', 'AR2', 'AR3',
252+
'VAR1', 'VAR1cor')){
247253
# Sample trend drift terms so they are (hopefully) not too correlated
248254
if(drift){
249255
trend_alphas <- rnorm(n_lv, sd = 0.15)

R/validations.R

+38
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,44 @@ validate_series_time = function(data, name = 'data'){
6464
return(data)
6565
}
6666

67+
#'@noRd
68+
#' Get a data object into the correct order in case it is not already
#'
#' Reorders `data` by its `time` and `series` variables.
#'
#' @param data Either a `list` or an object that `dplyr::arrange()` accepts
#'   (e.g. a `data.frame`); it must contain `time` and `series`.
#' @param series_time Logical. If `TRUE`, order by `series` and then `time`;
#'   if `FALSE` (the default), order by `time` and then `series`.
#' @return The reordered data. For `list` input, a list with the same names
#'   in which every element has been reordered with the same row index.
69+
sort_data = function(data, series_time = FALSE){
70+
if(inherits(data, 'list')){
71+
data_arranged <- data
72+
# Build a time/series lookup that remembers each observation's original
# position (index), then sort it to obtain the required permutation
if(series_time){
73+
temp_dat = data.frame(time = data$time,
74+
series = data$series) %>%
75+
dplyr::mutate(index = dplyr::row_number()) %>%
76+
dplyr::arrange(series, time)
77+
} else {
78+
temp_dat = data.frame(time = data$time,
79+
series = data$series) %>%
80+
dplyr::mutate(index = dplyr::row_number()) %>%
81+
dplyr::arrange(time, series)
82+
}
83+
84+
# Apply the same permutation to every element of the list
data_arranged <- lapply(data, function(x){
85+
if(is.matrix(x)){
86+
# Row-subsetting can drop dimensions, so re-wrap the result as a matrix
# with the original number of columns
matrix(x[temp_dat$index,], ncol = NCOL(x))
87+
} else {
88+
x[temp_dat$index]
89+
}
90+
})
91+
names(data_arranged) <- names(data)
92+
} else {
93+
# Non-list input is sorted directly with dplyr::arrange()
if(series_time){
94+
data_arranged <- data %>%
95+
dplyr::arrange(series, time)
96+
} else {
97+
data_arranged <- data %>%
98+
dplyr::arrange(time, series)
99+
}
100+
}
101+
102+
return(data_arranged)
103+
}
104+
67105
#'@importFrom rlang warn
68106
#'@noRd
69107
validate_family = function(family, use_stan = TRUE){

src/mvgam.dll

0 Bytes
Binary file not shown.

tests/testthat/Rplots.pdf

-5 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)