diff --git a/R/gp.R b/R/gp.R index 90479cce..e48c9332 100644 --- a/R/gp.R +++ b/R/gp.R @@ -147,6 +147,8 @@ make_gp_additions = function(gp_details, k = k_gps[[x]], def_rho = gp_details$def_rho[x], def_rho_2 = gp_details$def_rho_2[x], + def_rho_3 = gp_details$def_rho_3[x], + def_rho_4 = gp_details$def_rho_4[x], def_alpha = gp_details$def_alpha[x], eigenvalues = eigenvals[[x]]) @@ -525,14 +527,20 @@ get_gp_attributes = function(formula, data, family = gaussian()){ def_alpha <- 'student_t(3, 0, 2.5);' } if(length(def_rho) > 1L){ + def_rho_1 <- def_rho[1] def_rho_2 <- def_rho[2] - def_rho <- def_rho[1] - out <- data.frame(def_rho = def_rho, + out <- data.frame(def_rho = def_rho_1, def_rho_2 = def_rho_2, - def_alpha = def_alpha) + def_rho_3 = NA, + def_rho_4 = NA, + def_alpha = def_alpha) + if(length(def_rho) > 2L) out$def_rho_3 <- def_rho[3] + if(length(def_rho) > 3L) out$def_rho_4 <- def_rho[4] } else { out <- data.frame(def_rho = def_rho, def_rho_2 = NA, + def_rho_3 = NA, + def_rho_4 = NA, def_alpha = def_alpha) } out @@ -566,7 +574,9 @@ get_gp_attributes = function(formula, data, family = gaussian()){ level = NA, def_alpha = gp_def_priors$def_alpha, def_rho = gp_def_priors$def_rho, - def_rho_2 = gp_def_priors$def_rho_2) + def_rho_2 = gp_def_priors$def_rho_2, + def_rho_3 = gp_def_priors$def_rho_3, + def_rho_4 = gp_def_priors$def_rho_4) attr(ret_dat, 'gp_formula') <- gp_formula # Return as a data.frame @@ -605,6 +615,10 @@ add_gp_model_file = function(model_file, model_data, use.names = FALSE) rho_2_priors <- unlist(purrr::map(gp_additions$gp_att_table, 'def_rho_2'), use.names = FALSE) + rho_3_priors <- unlist(purrr::map(gp_additions$gp_att_table, 'def_rho_3'), + use.names = FALSE) + rho_4_priors <- unlist(purrr::map(gp_additions$gp_att_table, 'def_rho_4'), + use.names = FALSE) alpha_priors <- unlist(purrr::map(gp_additions$gp_att_table, 'def_alpha'), use.names = FALSE) @@ -631,6 +645,11 @@ add_gp_model_file = function(model_file, model_data, gp_names_clean <- clean_gpnames(gp_names) s_to_remove <- list() for(i in seq_along(gp_names)){ + i_rho_priors <- c(rho_priors[i], + rho_2_priors[i], + rho_3_priors[i], + rho_4_priors[i]) + i_rho_priors <- i_rho_priors[!is.na(i_rho_priors)] s_name <- gsub(' ', '', orig_names[i]) to_replace <- grep(paste0('// prior for ', s_name, '...'), model_file, fixed = TRUE) + 1 @@ -664,7 +683,7 @@ add_gp_model_file = function(model_file, model_data, if(gp_isos[i]){ rho_priors[i] } else { - c(rho_priors[i], rho_2_priors[i]) + i_rho_priors }, ';\n'), collapse = '\n' diff --git a/R/stan_utils.R b/R/stan_utils.R index 84970746..10b3f0fb 100644 --- a/R/stan_utils.R +++ b/R/stan_utils.R @@ -2780,12 +2780,29 @@ add_trend_predictors = function(trend_formula, } } + all_gp_prior_lines = function(model_file, + prior_line, + max_break = 10){ + last <- prior_line + max_break + for(i in prior_line:(prior_line + max_break)){ + if(!grepl('b_raw[', model_file[i], + fixed = TRUE)){ + } else { + last <- i + break + } + } + (prior_line + 1) : last + } + if(any(grepl('// prior for gp', trend_model_file))){ - starts <- grep('// prior for gp', trend_model_file, fixed = TRUE) + 1 + starts <- grep('// prior for gp', trend_model_file, fixed = TRUE) ends <- grep('// prior for gp', trend_model_file, fixed = TRUE) + 4 for(i in seq_along(starts)){ spline_coef_lines <- c(spline_coef_lines, - paste(trend_model_file[starts[i]:ends[i]], + paste(trend_model_file[all_gp_prior_lines(trend_model_file, + starts[i], + max_break = 10)], collapse = '\n')) } } diff --git a/R/summary.mvgam.R b/R/summary.mvgam.R index d8bbdba6..88c27218 100644 --- a/R/summary.mvgam.R +++ b/R/summary.mvgam.R @@ -916,6 +916,8 @@ gp_param_summary = function(object, use.names = FALSE) gp_isos <- unlist(purrr::map(attr(mgcv_model, 'gp_att_table'), 'iso'), use.names = FALSE) + gp_dims <- unlist(purrr::map(attr(mgcv_model, 'gp_att_table'), 'dim'), + use.names = FALSE) # Create full list of rho parameter names full_names <- vector(mode = 'list', length = length(gp_names)) @@ -925,7 +927,7 @@ gp_param_summary = function(object, } else { full_names[[i]] <- paste0(gp_names[i], '[', - 1:2, + 1:gp_dims[i], ']') } } diff --git a/tests/testthat/Rplots.pdf b/tests/testthat/Rplots.pdf index b334e65f..ebdffac3 100644 Binary files a/tests/testthat/Rplots.pdf and b/tests/testthat/Rplots.pdf differ diff --git a/tests/testthat/test-gp.R b/tests/testthat/test-gp.R index 0329cb68..1aea1fd5 100644 --- a/tests/testthat/test-gp.R +++ b/tests/testthat/test-gp.R @@ -279,7 +279,7 @@ test_that("unidimensional gp for process models working properly", { test_that("multidimensional gp for process models working properly", { mod <- mvgam(y ~ s(series, bs = 're'), trend_formula = ~ - gp(time, season, k = 10), + gp(time, season, k = 10, iso = FALSE), data = beta_data$data_train, family = betar(), trend_model = AR(), @@ -292,6 +292,15 @@ test_that("multidimensional gp for process models working properly", { expect_true(any(grepl("b_trend[b_trend_idx_gp_timeby_season_] = sqrt(spd_gp_exp_quad(", mod$model_file, fixed = TRUE))) + expect_true(any(grepl("array[1] vector[2] rho_gp_trend_timeby_season_;", + mod$model_file, fixed = TRUE))) + + expect_true(any(grepl("rho_gp_trend_timeby_season_[1][1] ~ inv_gamma", + mod$model_file, fixed = TRUE))) + + expect_true(any(grepl("rho_gp_trend_timeby_season_[1][2] ~ inv_gamma", + mod$model_file, fixed = TRUE))) + # Gp data structures should be in the model_data expect_true("l_gp_trend_timeby_season_" %in% names(mod$model_data)) expect_true("b_trend_idx_gp_timeby_season_" %in% names(mod$model_data))