Skip to content

Commit

Permalink
dynamically pull out S matrix names for trend predictor models
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasjclark committed Nov 20, 2023
1 parent edc943b commit a9eb613
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 7 deletions.
12 changes: 12 additions & 0 deletions R/get_mvgam_priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,18 @@
#' run_model = FALSE)
#'code(mod2)
#'
#'# The "class = 'b'" shortcut can be used to put the same prior on all
#'# 'fixed' effect coefficients (apart from any intercepts)
#'set.seed(0)
#'dat <- mgcv::gamSim(1, n = 200, scale = 2)
#'dat$time <- 1:NROW(dat)
#'mod <- mvgam(y ~ x0 + x1 + s(x2) + s(x3),
#' priors = prior(normal(0, 0.75), class = 'b'),
#' data = dat,
#' family = gaussian(),
#' run_model = FALSE)
#'code(mod)
#'
#'@export
get_mvgam_priors = function(formula,
trend_formula,
Expand Down
3 changes: 2 additions & 1 deletion R/mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,8 @@ mvgam = function(formula,
n_lv = n_lv,
trend_model = trend_model,
trend_map = trend_map,
drift = drift)
drift = drift,
warnings = TRUE)
}
}

Expand Down
4 changes: 3 additions & 1 deletion R/stan_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -2857,7 +2857,9 @@ add_trend_predictors = function(trend_formula,
paste0("int<lower=0> n_nonmissing; // number of nonmissing observations\n",
paste(S_lines, collapse = '\n'))

S_mats <- trend_mvgam$model_data[paste0('S', 1:length(S_lines))]
# Pull out S matrices (don't always start at 1!)
S_mats <- trend_mvgam$model_data[grepl("S[0-9]",
names(trend_mvgam$model_data))]
names(S_mats) <- gsub('S', 'S_trend', names(S_mats))
model_data <- append(model_data, S_mats)
}
Expand Down
29 changes: 24 additions & 5 deletions R/update_priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,8 @@ adapt_brms_priors = function(priors,
n_lv,
trend_model = 'None',
trend_map,
drift = FALSE){
drift = FALSE,
warnings = FALSE){

# Replace any call to 'Intercept' with '(Intercept)' to match mgcv style
priors[] <- lapply(priors, function(x)
Expand Down Expand Up @@ -318,7 +319,8 @@ adapt_brms_priors = function(priors,
priors_df$prior, fixed = TRUE)] <-
priors$ub[i]

} else if(any(grepl(paste0(priors$coef[i], ' ~ '),
} else if(priors$coef[i] != '' &
any(grepl(paste0(priors$coef[i], ' ~ '),
priors_df$prior, fixed = TRUE))){

# Update the prior distribution
Expand All @@ -335,10 +337,27 @@ adapt_brms_priors = function(priors,
priors_df$prior, fixed = TRUE)] <-
priors$ub[i]

} else if(priors$class[i] == 'b'){
# Update all fixed effect priors
if(any(grepl('fixed effect', priors_df$param_info))){

for(j in 1:NROW(priors_df)){
if(grepl('fixed effect', priors_df$param_info[j])){
priors_df$prior[j] <-
paste0(paste(trimws(
strsplit(priors_df$prior[j],
"[~]")[[1]][1]), '~ '),
priors$prior[i], ';')
}
}
}

} else {
warning(paste0('no match found in model_file for parameter: ',
paste0(priors$class[i], ', ', priors$coef[i])),
call. = FALSE)
if(warnings){
warning(paste0('no match found in model_file for parameter: ',
paste0(priors$class[i], ' ', priors$coef[i])),
call. = FALSE)
}
}
}

Expand Down
12 changes: 12 additions & 0 deletions man/get_mvgam_priors.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file modified src/RcppExports.o
Binary file not shown.
Binary file modified src/mvgam.dll
Binary file not shown.
Binary file modified src/trend_funs.o
Binary file not shown.
Binary file modified tests/testthat/Rplots.pdf
Binary file not shown.

0 comments on commit a9eb613

Please sign in to comment.