diff --git a/R/add_stan_data.R b/R/add_stan_data.R index 38aacd8c..fdeee218 100644 --- a/R/add_stan_data.R +++ b/R/add_stan_data.R @@ -510,21 +510,30 @@ add_stan_data = function(jags_file, stan_file, if(any(grep('## parametric effect priors', jags_file))){ # Get indices of parametric effects - min_paras <- as.numeric(sub('.*(?=.$)', '', - sub("\\:.*", "", - jags_file[grep('## parametric effect', jags_file) + 1]), perl=T)) - max_paras <- as.numeric(substr(sub(".*\\:", "", - jags_file[grep('## parametric effect', jags_file) + 1]), - 1, 1)) - para_indices <- seq(min_paras, max_paras) - - # Get names of parametric terms - int_included <- attr(ss_gam$pterms, 'intercept') == 1L - other_pterms <- attr(ss_gam$pterms, 'term.labels') - all_paras <- other_pterms - if(int_included){ - all_paras <- c('(Intercept)', all_paras) - } + smooth_labs <- do.call(rbind, lapply(seq_along(ss_gam$smooth), function(x){ + data.frame(label = ss_gam$smooth[[x]]$label, + term = paste(ss_gam$smooth[[x]]$term, collapse = ','), + class = class(ss_gam$smooth[[x]])[1]) + })) + lpmat <- predict(ss_gam, type = 'lpmatrix', + exclude = smooth_labs$label) + para_indices <- which(apply(lpmat, 2, function(x) !all(x == 0)) == TRUE) + all_paras <- names(para_indices) + # min_paras <- as.numeric(sub('.*(?=.$)', '', + # sub("\\:.*", "", + # jags_file[grep('## parametric effect', jags_file) + 1]), perl=T)) + # max_paras <- as.numeric(substr(sub(".*\\:", "", + # jags_file[grep('## parametric effect', jags_file) + 1]), + # 1, 1)) + # para_indices <- seq(min_paras, max_paras) + # + # # Get names of parametric terms + # int_included <- attr(ss_gam$pterms, 'intercept') == 1L + # other_pterms <- attr(ss_gam$pterms, 'term.labels') + # all_paras <- other_pterms + # if(int_included){ + # all_paras <- c('(Intercept)', all_paras) + # } # Create prior lines for parametric terms para_lines <- vector() diff --git a/R/stan_utils.R b/R/stan_utils.R index 7b1427bf..91d78e00 100644 --- a/R/stan_utils.R +++ b/R/stan_utils.R @@ -2921,9 +2921,19 @@ add_trend_predictors = function(trend_formula, paste0('// dynamic process models\n', paste0(paste(plines, collapse = '\n'))) } else { - model_file[grep("// dynamic factor estimates", model_file, fixed = TRUE)] <- - paste0('// dynamic process models\n', - paste0(paste(plines, collapse = '\n'))) + if(any(grepl("// dynamic factor estimates", model_file, fixed = TRUE))){ + model_file[grep("// dynamic factor estimates", model_file, fixed = TRUE)] <- + paste0('// dynamic process models\n', + paste0(paste(plines, collapse = '\n'))) + } + + if(any(grepl("// trend means", model_file, fixed = TRUE))){ + model_file[grep("// trend means", model_file, fixed = TRUE)] <- + paste0('// dynamic process models\n', + paste0(paste(plines, collapse = '\n'), + '// trend means')) + } + } } diff --git a/src/RcppExports.o b/src/RcppExports.o index 3738bd53..849c29a4 100644 Binary files a/src/RcppExports.o and b/src/RcppExports.o differ diff --git a/src/mvgam.dll b/src/mvgam.dll index 18b55708..4ee6cc2c 100644 Binary files a/src/mvgam.dll and b/src/mvgam.dll differ diff --git a/src/trend_funs.o b/src/trend_funs.o index 371ea49f..faf098a5 100644 Binary files a/src/trend_funs.o and b/src/trend_funs.o differ diff --git a/tests/testthat/Rplots.pdf b/tests/testthat/Rplots.pdf index 2799ff09..cf2654cc 100644 Binary files a/tests/testthat/Rplots.pdf and b/tests/testthat/Rplots.pdf differ diff --git a/tests/testthat/test-mvgam.R b/tests/testthat/test-mvgam.R index 6dae714d..aff31255 100644 --- a/tests/testthat/test-mvgam.R +++ b/tests/testthat/test-mvgam.R @@ -225,5 +225,133 @@ test_that("trend_formula setup is working properly", { }) +# Check that parametric effect priors are properly incorporated in the +# model for a wide variety of model forms +test_that("parametric effect priors correctly incorporated in models", { + mod_data <- mvgam:::mvgam_examp_dat + mod_data$data_train$x1 <- + rnorm(NROW(mod_data$data_train)) + mod_data$data_train$x2 <- + rnorm(NROW(mod_data$data_train)) + mod_data$data_train$x3 <- + rnorm(NROW(mod_data$data_train)) + # Observation formula; no trend + mod <- mvgam(y ~ s(season) + series:x1 + + series:x2 + series:x3, + trend_model = 'None', + data = mod_data$data_train, + family = gaussian(), + run_model = FALSE) + + expect_true(any(grepl('// prior for seriesseries_3:x1...', + mod$model_file, fixed = TRUE))) + expect_true(any(grepl('// prior for (Intercept)...', + mod$model_file, fixed = TRUE))) + + para_names <- paste0(paste0('// prior for seriesseries_', 1:3, + paste0(':x', 1:3, '...'))) + for(i in seq_along(para_names)){ + expect_true(any(grepl(para_names[i], + mod$model_file, fixed = TRUE))) + } + + priors <- get_mvgam_priors(y ~ s(season) + series:x1 + + series:x2 + series:x3, + trend_model = 'None', + data = mod_data$data_train, + family = gaussian()) + expect_true(any(grepl('seriesseries_1:x2', + priors$param_name))) + expect_true(any(grepl('seriesseries_2:x3', + priors$param_name))) + + + # Observation formula; complex trend + mod <- mvgam(y ~ s(season) + series:x1 + series:x2 + series:x3, + trend_model = 'VARMA', + data = mod_data$data_train, + family = gaussian(), + run_model = FALSE) + + expect_true(any(grepl('// prior for seriesseries_3:x1...', + mod$model_file, fixed = TRUE))) + expect_true(any(grepl('// prior for (Intercept)...', + mod$model_file, fixed = TRUE))) + + para_names <- paste0(paste0('// prior for seriesseries_', 1:3, + paste0(':x', 1:3, '...'))) + for(i in seq_along(para_names)){ + expect_true(any(grepl(para_names[i], + mod$model_file, fixed = TRUE))) + } + + priors <- get_mvgam_priors(y ~ s(season) + series:x1 + + series:x2 + series:x3, + trend_model = 'VARMA', + data = mod_data$data_train, + family = gaussian()) + expect_true(any(grepl('seriesseries_1:x2', + priors$param_name))) + expect_true(any(grepl('seriesseries_2:x3', + priors$param_name))) + + # Trend formula; RW + mod <- mvgam(y ~ 1, + trend_formula = ~ s(season) + trend:x1 + + trend:x2 + trend:x3, + trend_model = 'RW', + data = mod_data$data_train, + family = gaussian(), + run_model = FALSE) + + expect_true(any(grepl('// prior for (Intercept)...', + mod$model_file, fixed = TRUE))) + + para_names <- paste0(paste0('// prior for trendtrend', 1:3, + paste0(':x', 1:3, '_trend...'))) + for(i in seq_along(para_names)){ + expect_true(any(grepl(para_names[i], + mod$model_file, fixed = TRUE))) + } + + priors <- get_mvgam_priors(y ~ 1, + trend_formula = ~ s(season) + trend:x1 + + trend:x2 + trend:x3, + trend_model = 'RW', + data = mod_data$data_train, + family = gaussian()) + expect_true(any(grepl('trendtrend1:x1_trend', + priors$param_name))) + expect_true(any(grepl('trendtrend2:x3_trend', + priors$param_name))) + + # Trend formula; VARMA + mod <- mvgam(y ~ 1, + trend_formula = ~ s(season) + trend:x1 + trend:x2 + trend:x3, + trend_model = 'VARMA', + data = mod_data$data_train, + family = gaussian(), + run_model = FALSE) + + expect_true(any(grepl('// prior for (Intercept)...', + mod$model_file, fixed = TRUE))) + + para_names <- paste0(paste0('// prior for trendtrend', 1:3, + paste0(':x', 1:3, '_trend...'))) + for(i in seq_along(para_names)){ + expect_true(any(grepl(para_names[i], + mod$model_file, fixed = TRUE))) + } + + priors <- get_mvgam_priors(y ~ 1, + trend_formula = ~ s(season) + trend:x1 + trend:x2 + trend:x3, + trend_model = 'RW', + data = mod_data$data_train, + family = gaussian()) + expect_true(any(grepl('trendtrend1:x1_trend', + priors$param_name))) + expect_true(any(grepl('trendtrend2:x3_trend', + priors$param_name))) +})