Skip to content

Commit

Permalink
Improve efficiency of compare_mvgams
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasjclark committed Dec 22, 2021
1 parent d6ed714 commit 0f2e700
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 14 deletions.
Binary file removed .DS_Store
Binary file not shown.
Binary file removed NEON_manuscript/.DS_Store
Binary file not shown.
Binary file removed NEON_manuscript/Figures/.DS_Store
Binary file not shown.
98 changes: 98 additions & 0 deletions R/eval_mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ eval_mvgam = function(object,
}

# Run particles forward in time to generate their forecasts
if(n_cores > 1){

cl <- parallel::makePSOCKcluster(n_cores)
setDefaultCluster(cl)
clusterExport(NULL, c('use_lv',
Expand Down Expand Up @@ -259,6 +261,102 @@ eval_mvgam = function(object,
}, cl = cl)
stopCluster(cl)

} else {
# Serial fallback (used when n_cores == 1): generate forecast paths for each
# posterior draw in sample_seq without spinning up a PSOCK cluster. Returns a
# list with one element per draw, each a list of per-series forecast vectors
# of length fc_horizon.
particle_fcs <- lapply(sample_seq, function(x){

  if(use_lv){
    # Dynamic factor model: forecast the latent variables forward, then map
    # them to each series through the factor loadings

    # This draw's index into the posterior sample arrays
    samp_index <- x

    # Sample a last state estimate for the latent variables
    last_lvs <- lapply(seq_along(lvs), function(lv){
      lvs[[lv]][samp_index, ]
    })

    # Sample drift and AR(1-3) parameters for this draw
    phi <- phis[samp_index, ]
    ar1 <- ar1s[samp_index, ]
    ar2 <- ar2s[samp_index, ]
    ar3 <- ar3s[samp_index, ]

    # Sample latent variable precision for this draw
    tau <- taus[samp_index,]

    # Sample factor loadings into an n_series x n_lv matrix. NOTE: this
    # locally shadows the outer lv_coefs list; the RHS still resolves to the
    # outer list because the local binding is only created after the RHS is
    # evaluated
    lv_coefs <- do.call(rbind, lapply(seq_len(n_series), function(series){
      lv_coefs[[series]][samp_index,]
    }))

    # Sample GAM beta coefficients for this draw
    betas <- betas[samp_index, ]

    # Sample a negative binomial size parameter for this draw
    size <- sizes[samp_index, ]

    # Run the latent variables forward fc_horizon timesteps.
    # NOTE(review): tau is passed whole here, whereas the non-lv branch below
    # passes tau[series] - confirm sim_ar3 expects the full vector (or a
    # scalar) in this branch
    lv_preds <- do.call(rbind, lapply(seq_len(n_lv), function(lv){
      sim_ar3(phi = phi[lv],
              ar1 = ar1[lv],
              ar2 = ar2[lv],
              ar3 = ar3[lv],
              tau = tau,
              state = last_lvs[[lv]],
              h = fc_horizon)
    }))

    # Combine the GAM linear predictor with the loading-weighted latent
    # states and draw from the negative binomial observation model
    series_fcs <- lapply(seq_len(n_series), function(series){
      trend_preds <- as.numeric(t(lv_preds) %*% lv_coefs[series,])
      trunc_preds <- rnbinom(fc_horizon,
                             mu = exp(as.vector((Xp[which(as.numeric(data_assim$series) == series),] %*%
                                                   betas)) + (trend_preds)),
                             size = size)
      trunc_preds
    })

  } else {
    # Independent trend per series: run each trend forward fc_horizon
    # timesteps directly

    # This draw's index into the posterior sample arrays
    samp_index <- x

    # Sample GAM beta coefficients for this draw
    betas <- betas[samp_index, ]

    # Sample last state estimates for the series-level trends
    last_trends <- lapply(seq_along(trends), function(trend){
      trends[[trend]][samp_index, ]
    })

    # Sample drift and AR(1-3) parameters for this draw
    phi <- phis[samp_index, ]
    ar1 <- ar1s[samp_index, ]
    ar2 <- ar2s[samp_index, ]
    ar3 <- ar3s[samp_index, ]

    # Sample trend precisions for this draw
    tau <- taus[samp_index,]

    # Sample a negative binomial size parameter for this draw.
    # NOTE(review): size is not indexed by series anywhere in this block -
    # confirm whether sizes holds one shared parameter or one per series
    size <- sizes[samp_index, ]

    # Forecast each series: AR3 trend simulation plus the GAM linear
    # predictor, drawn from the negative binomial observation model
    series_fcs <- lapply(seq_len(n_series), function(series){
      trend_preds <- sim_ar3(phi = phi[series],
                             ar1 = ar1[series],
                             ar2 = ar2[series],
                             ar3 = ar3[series],
                             tau = tau[series],
                             state = last_trends[[series]],
                             h = fc_horizon)
      fc <- rnbinom(fc_horizon,
                    mu = exp(as.vector((Xp[which(as.numeric(data_assim$series) == series),] %*%
                                          betas)) + (trend_preds)),
                    size = size)
      fc
    })
  }

  # Return this particle's per-series forecasts
  series_fcs
})

}

# Final forecast distribution
series_fcs <- lapply(seq_len(n_series), function(series){
indexed_forecasts <- do.call(rbind, lapply(seq_along(particle_fcs), function(x){
Expand Down
22 changes: 19 additions & 3 deletions R/roll_eval_mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,28 @@ roll_eval_mvgam = function(object,
}

# Loop across evaluation sequence and calculate evaluation metrics
evals <- lapply(evaluation_seq, function(timepoint){
cl <- parallel::makePSOCKcluster(n_cores)
setDefaultCluster(cl)
clusterExport(NULL, c('all_timepoints',
'evaluation_seq',
'object',
'n_samples',
'fc_horizon',
'eval_mvgam'),
envir = environment())
parallel::clusterEvalQ(cl, library(mgcv))
parallel::clusterEvalQ(cl, library(coda))

pbapply::pboptions(type = "none")
evals <- pbapply::pblapply(evaluation_seq, function(timepoint){
eval_mvgam(object = object,
n_samples = n_samples,
n_cores = n_cores,
n_cores = 1,
eval_timepoint = timepoint,
fc_horizon = fc_horizon)
})
},
cl = cl)
stopCluster(cl)

# Take sum of DRPS at each evaluation point for multivariate models
sum_or_na = function(x){
Expand All @@ -66,6 +81,7 @@ roll_eval_mvgam = function(object,
sum(x, na.rm = T)
}
}

evals_df <- do.call(rbind, do.call(rbind, evals)) %>%
dplyr::group_by(eval_season, eval_year, eval_horizon) %>%
dplyr::summarise(drps = sum_or_na(drps),
Expand Down
15 changes: 4 additions & 11 deletions test_mvjagam.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ mod <- mvjagam(data_train = fake_data$data_train,
family = 'nb',
use_lv = F,
trend_model = 'AR3',
n.burnin = 5000,
n.burnin = 10000,
auto_update = F)

# Fit a mis-specified model for testing the model comparison functions by
Expand All @@ -30,7 +30,7 @@ fake_data$data_test$fake_cov <- rnorm(NROW(fake_data$data_test))
mod2 <- mvjagam(data_train = fake_data$data_train,
data_test = fake_data$data_test,
formula = y ~ s(fake_cov, k = 3),
family = 'nb',
family = 'poisson',
use_lv = F,
trend_model = 'RW',
n.burnin = 10,
Expand All @@ -39,8 +39,8 @@ mod2 <- mvjagam(data_train = fake_data$data_train,
auto_update = F)

# Compare the models using rolling forecast DRPS evaluation
compare_mvgams(mod, mod2, fc_horizon = 12,
n_evaluations = 15)
compare_mvgams(mod, mod2, fc_horizon = 6,
n_evaluations = 25, n_cores = 4)

# Summary plots and diagnostics for the preferred model (Model 1)
# Check Dunn-Smyth residuals for autocorrelation
Expand All @@ -49,13 +49,6 @@ lines(mod$resids$Air)
acf(mod$resids$Air)
pacf(mod$resids$Air)

test <- gam(y ~ s(season, bs = c('cc')),
data = fake_data$data_train,
family = nb())
autoplot(forecast(ets(residuals(test), model = 'AAN', damped = T)))
autoplot(forecast(Arima(residuals(test), order = c(3,0,0),
include.drift = T)))

# Plot the estimated seasonality smooth function
plot_mvgam_smooth(mod, smooth = 'season')

Expand Down

0 comments on commit 0f2e700

Please sign in to comment.