Skip to content

Commit

Permalink
stan support for dynamic factor models
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasjclark committed Jul 15, 2022
1 parent 8501763 commit 20421f3
Show file tree
Hide file tree
Showing 16 changed files with 545 additions and 207 deletions.
Binary file not shown.
Binary file not shown.
Binary file not shown.
60 changes: 25 additions & 35 deletions NEON_manuscript/next_todo.R
Original file line number Diff line number Diff line change
@@ -1,52 +1,42 @@
library(mvgam)
dat <- sim_mvgam(T = 100, n_series=4, n_lv = 1)
mod1 <- mvgam(formula = y ~ s(season) + s(series, bs = 're'),
dat$true_corrs


mod1 <- mvgam(formula = y ~ s(season, bs = 'cc') +
s(series, bs = 're'),
data_train = dat$data_train,
trend_model = 'GP',
family = 'nb',
trend_model = 'AR3',
family = 'poisson',
use_lv = TRUE,
n_lv = 2,
use_stan = TRUE,
run_model = FALSE)
mod1$model_file
run_model = T,
burnin = 10)
summary(mod1)

mod2 <- mvgam(formula = y~year,
# Good for testing model files without compiling
stanc(model_code = mod1$model_file)$model_name
model_file <- mod1$model_file

mod2 <- mvgam(formula = y ~ s(season, bs = 'cc') +
s(series, bs = 're'),
data_train = dat$data_train,
trend_model = 'RW',
family = 'poisson',
use_lv = TRUE,
n_lv = 2,
run_model = TRUE,
use_stan = TRUE)
mod2$model_file

plot(mod2, 'smooths', residuals = TRUE, derivatives = TRUE)
compare_mvgams(model1 = mod1, model2 = mod2, fc_horizon = 6,
n_evaluations = 30, n_cores = 3)
eval_mvgam(object = mod2, n_cores = 1)
plot(mod1, type = 'forecast', realisations = TRUE)
plot(mod1, type = 'trend', realisations = TRUE)

plot_mvgam_smooth(mod1, 1, 'season', realisations = TRUE, n_realisations = 10)
plot_mvgam_fc(object = mod1, series = 1,
realisations = TRUE, n_realisations = 15)
fake <- dat$data_test
fake$y <- NULL
plot_mvgam_fc(object = mod1, series = 1, data_test = fake)
obj <- forecast(mod1, data_test = fake)
dim(obj)


plot_mvgam_trend(object = mod1, series = 1, data_test = fake,
realisations = TRUE)



burnin = 10)


plot(mod1, series = 3, 'forecast', data_test = dat$data_test)
plot(mod2, series = 3, 'forecast', data_test = dat$data_test)

pfilter_mvgam_init(object = mod1, n_particles = 2000,
n_cores = 3, data_assim = model_dat[28,])
plot(mod1, series = 4, 'trend', data_test = dat$data_test)
plot(mod2, series = 4, 'trend', data_test = dat$data_test)


# models with no smooths
# predictions with new data by extending the temporal process forward
# respect upper bounds for forecasts, prediction, particle filtering
trunc_poiss = function(lambda, bound){
out <- vector(length = length(lambda))
Expand Down
108 changes: 107 additions & 1 deletion R/add_base_dgam_lines.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,112 @@
add_base_dgam_lines = function(use_lv, stan = FALSE){

if(stan){
add <- "
if(use_lv){
add <- "
##insert data
transformed data{
// Number of non-zero lower triangular factor loadings
// Ensures identifiability of the model - no rotation of factors
int<lower=1> M;
M = n_lv * (n_series - n_lv) + n_lv * (n_lv - 1) / 2 + n_lv;
}
parameters {
// raw basis coefficients
row_vector<lower=-30,upper=30>[num_basis] b_raw;
// dynamic factors
matrix[n, n_lv] LV;
// dynamic factor lower triangle loading coefficients
vector[M] L;
// smoothing parameters
vector<lower=0.0005>[n_sp] lambda;
}
transformed parameters {
// basis coefficients
row_vector[num_basis] b;
// dynamic factor loading matrix
matrix[n_series, n_lv] lv_coefs;
// constraints allow identifiability of loadings
for (i in 1:(n_lv - 1)) {
for (j in (i + 1):(n_lv)){
lv_coefs[i, j] = 0;
}
}
{
int index;
index = 0;
for (j in 1:n_lv) {
for (i in j:n_series) {
index = index + 1;
lv_coefs[i, j] = L[index];
}
}
}
// derived latent trends
matrix[n, n_series] trend;
for (i in 1:n){;
for (s in 1:n_series){
trend[i, s] = dot_product(lv_coefs[s,], LV[i,]);
}
}
// GAM contribution to expectations (log scale)
vector[total_obs] eta;
eta = to_vector(b * X);
}
model {
##insert smooths
// priors for smoothing parameters
lambda ~ exponential(0.05);
// priors for dynamic factor loading coefficients
L ~ double_exponential(0, 1);
// dynamic factor estimates
for (j in 1:n_lv) {
LV[1, j] ~ normal(0, 1);
}
for (j in 1:n_lv) {
LV[2:n, j] ~ normal(LV[1:(n - 1), j], 1);
}
// likelihood functions
for (i in 1:n) {
for (s in 1:n_series) {
if (y_observed[i, s])
y[i, s] ~ poisson_log(eta[ytimes[i, s]] + trend[i, s]);
}
}
}
generated quantities {
vector[n_sp] rho;
rho = log(lambda);
vector[n_lv] penalty;
penalty = rep_vector(1.0, n_lv);
// posterior predictions
matrix[n, n_series] ypred;
for(i in 1:n){
for(s in 1:n_series){
ypred[i, s] = poisson_log_rng(eta[ytimes[i, s]] + trend[i, s]);
}
}
}
"

} else {
add <- "
##insert data
parameters {
// raw basis coefficients
Expand Down Expand Up @@ -75,6 +180,7 @@ add_base_dgam_lines = function(use_lv, stan = FALSE){
}
}
"
}

} else {
if(use_lv){
Expand Down
14 changes: 13 additions & 1 deletion R/add_stan_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#' @export
#' @param jags_file Prepared JAGS mvgam model file
#' @param stan_file Incomplete Stan model file to be edited
#' @param use_lv logical
#' @param n_lv \code{integer} number of latent dynamic factors (if \code{use_lv = TRUE})
#' @param jags_data Prepared mvgam data for JAGS modelling
#' @param r_prior \code{character} specifying (in Stan syntax) the prior distribution for the Negative Binomial
#'overdispersion parameters. Note that this prior acts on the inverse of \code{r}, which is convenient
Expand All @@ -15,7 +17,8 @@
#' is computationally expensive in \code{JAGS} but can lead to better estimates when true bounds exist. Default is to remove
#' truncation entirely (i.e. there is no upper bound for each series)
#' @return A `list` containing the updated Stan model and model data
add_stan_data = function(jags_file, stan_file,
add_stan_data = function(jags_file, stan_file, use_lv = FALSE,
n_lv,
r_prior,
jags_data, family = 'poisson',
upper_bounds){
Expand Down Expand Up @@ -141,6 +144,13 @@ add_stan_data = function(jags_file, stan_file,
n_sp_data <- NULL
}

# latent variable lines
if(use_lv){
lv_data <- paste0('int<lower=0> n_lv; // number of dynamic factors\n')
} else {
lv_data <- NULL
}

# Search for any non-contiguous indices that sometimes are used by mgcv
if(any(grep('in c\\(', jags_file))){
add_idxs <- TRUE
Expand Down Expand Up @@ -174,6 +184,7 @@ add_stan_data = function(jags_file, stan_file,
paste0(idx_data, collapse = '\n'), '\n',
'int<lower=0> total_obs; // total number of observations\n',
'int<lower=0> n; // number of timepoints per series\n',
lv_data,
n_sp_data,
'int<lower=0> n_series; // number of series\n',
'int<lower=0> num_basis; // total number of basis coefficients\n',
Expand All @@ -194,6 +205,7 @@ add_stan_data = function(jags_file, stan_file,
bounds,
'int<lower=0> total_obs; // total number of observations\n',
'int<lower=0> n; // number of timepoints per series\n',
lv_data,
n_sp_data,
'int<lower=0> n_series; // number of series\n',
'int<lower=0> num_basis; // total number of basis coefficients\n',
Expand Down
Loading

0 comments on commit 20421f3

Please sign in to comment.