Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/nicholasjclark/mvgam
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholas Clark committed Dec 10, 2024
2 parents 5bd546b + 841ed79 commit 415b79c
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 9 deletions.
29 changes: 24 additions & 5 deletions R/gp.R
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ make_gp_additions = function(gp_details,
k = k_gps[[x]],
def_rho = gp_details$def_rho[x],
def_rho_2 = gp_details$def_rho_2[x],
def_rho_3 = gp_details$def_rho_3[x],
def_rho_4 = gp_details$def_rho_4[x],
def_alpha = gp_details$def_alpha[x],
eigenvalues = eigenvals[[x]])

Expand Down Expand Up @@ -525,14 +527,20 @@ get_gp_attributes = function(formula, data, family = gaussian()){
def_alpha <- 'student_t(3, 0, 2.5);'
}
if(length(def_rho) > 1L){
def_rho_1 <- def_rho[1]
def_rho_2 <- def_rho[2]
def_rho <- def_rho[1]
out <- data.frame(def_rho = def_rho,
out <- data.frame(def_rho = def_rho_1,
def_rho_2 = def_rho_2,
def_alpha = def_alpha)
def_rho_3 = NA,
def_rho_4 = NA,
def_alpha = def_alpha)
if(length(def_rho) > 2L) out$def_rho_3 <- def_rho[3]
if(length(def_rho) > 3L) out$def_rho_4 <- def_rho[4]
} else {
out <- data.frame(def_rho = def_rho,
def_rho_2 = NA,
def_rho_3 = NA,
def_rho_4 = NA,
def_alpha = def_alpha)
}
out
Expand Down Expand Up @@ -566,7 +574,9 @@ get_gp_attributes = function(formula, data, family = gaussian()){
level = NA,
def_alpha = gp_def_priors$def_alpha,
def_rho = gp_def_priors$def_rho,
def_rho_2 = gp_def_priors$def_rho_2)
def_rho_2 = gp_def_priors$def_rho_2,
def_rho_3 = gp_def_priors$def_rho_3,
def_rho_4 = gp_def_priors$def_rho_4)
attr(ret_dat, 'gp_formula') <- gp_formula

# Return as a data.frame
Expand Down Expand Up @@ -605,6 +615,10 @@ add_gp_model_file = function(model_file, model_data,
use.names = FALSE)
rho_2_priors <- unlist(purrr::map(gp_additions$gp_att_table, 'def_rho_2'),
use.names = FALSE)
rho_3_priors <- unlist(purrr::map(gp_additions$gp_att_table, 'def_rho_3'),
use.names = FALSE)
rho_4_priors <- unlist(purrr::map(gp_additions$gp_att_table, 'def_rho_4'),
use.names = FALSE)
alpha_priors <- unlist(purrr::map(gp_additions$gp_att_table, 'def_alpha'),
use.names = FALSE)

Expand All @@ -631,6 +645,11 @@ add_gp_model_file = function(model_file, model_data,
gp_names_clean <- clean_gpnames(gp_names)
s_to_remove <- list()
for(i in seq_along(gp_names)){
i_rho_priors <- c(rho_priors[i],
rho_2_priors[i],
rho_3_priors[i],
rho_4_priors[i])
i_rho_priors <- i_rho_priors[!is.na(i_rho_priors)]
s_name <- gsub(' ', '', orig_names[i])
to_replace <- grep(paste0('// prior for ', s_name, '...'),
model_file, fixed = TRUE) + 1
Expand Down Expand Up @@ -664,7 +683,7 @@ add_gp_model_file = function(model_file, model_data,
if(gp_isos[i]){
rho_priors[i]
} else {
c(rho_priors[i], rho_2_priors[i])
i_rho_priors
},
';\n'),
collapse = '\n'
Expand Down
21 changes: 19 additions & 2 deletions R/stan_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -2780,12 +2780,29 @@ add_trend_predictors = function(trend_formula,
}
}

all_gp_prior_lines = function(model_file,
prior_line,
max_break = 10){
last <- prior_line + max_break
for(i in prior_line:(prior_line + max_break)){
if(!grepl('b_raw[', model_file[i],
fixed = TRUE)){
} else {
last <- i
break
}
}
(prior_line + 1) : last
}

if(any(grepl('// prior for gp', trend_model_file))){
starts <- grep('// prior for gp', trend_model_file, fixed = TRUE) + 1
starts <- grep('// prior for gp', trend_model_file, fixed = TRUE)
ends <- grep('// prior for gp', trend_model_file, fixed = TRUE) + 4
for(i in seq_along(starts)){
spline_coef_lines <- c(spline_coef_lines,
paste(trend_model_file[starts[i]:ends[i]],
paste(trend_model_file[all_gp_prior_lines(trend_model_file,
starts[i],
max_break = 10)],
collapse = '\n'))
}
}
Expand Down
4 changes: 3 additions & 1 deletion R/summary.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,8 @@ gp_param_summary = function(object,
use.names = FALSE)
gp_isos <- unlist(purrr::map(attr(mgcv_model, 'gp_att_table'), 'iso'),
use.names = FALSE)
gp_dims <- unlist(purrr::map(attr(mgcv_model, 'gp_att_table'), 'dim'),
use.names = FALSE)

# Create full list of rho parameter names
full_names <- vector(mode = 'list', length = length(gp_names))
Expand All @@ -925,7 +927,7 @@ gp_param_summary = function(object,
} else {
full_names[[i]] <- paste0(gp_names[i],
'[',
1:2,
1:gp_dims[i],
']')
}
}
Expand Down
Binary file modified tests/testthat/Rplots.pdf
Binary file not shown.
11 changes: 10 additions & 1 deletion tests/testthat/test-gp.R
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ test_that("unidimensional gp for process models working properly", {
test_that("multidimensional gp for process models working properly", {
mod <- mvgam(y ~ s(series, bs = 're'),
trend_formula = ~
gp(time, season, k = 10),
gp(time, season, k = 10, iso = FALSE),
data = beta_data$data_train,
family = betar(),
trend_model = AR(),
Expand All @@ -292,6 +292,15 @@ test_that("multidimensional gp for process models working properly", {
expect_true(any(grepl("b_trend[b_trend_idx_gp_timeby_season_] = sqrt(spd_gp_exp_quad(",
mod$model_file, fixed = TRUE)))

expect_true(any(grepl("array[1] vector<lower=0>[2] rho_gp_trend_timeby_season_;",
mod$model_file, fixed = TRUE)))

expect_true(any(grepl("rho_gp_trend_timeby_season_[1][1] ~ inv_gamma",
mod$model_file, fixed = TRUE)))

expect_true(any(grepl("rho_gp_trend_timeby_season_[1][2] ~ inv_gamma",
mod$model_file, fixed = TRUE)))

# Gp data structures should be in the model_data
expect_true("l_gp_trend_timeby_season_" %in% names(mod$model_data))
expect_true("b_trend_idx_gp_timeby_season_" %in% names(mod$model_data))
Expand Down

0 comments on commit 415b79c

Please sign in to comment.