Skip to content

Commit

Permalink
final updates for expanded gp effects
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasjclark committed Nov 11, 2024
1 parent 0dc837b commit c23c75e
Show file tree
Hide file tree
Showing 58 changed files with 992 additions and 732 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# mvgam 1.1.4 (development version; not yet on CRAN)
## New functionalities
* Added support for approximate `gp()` effects with more than one covariate and with different kernel functions (#79)
* Added function `jsdgam()` to estimate Joint Species Distribution Models in which both the latent factors and the observation model components can include any of mvgam's complex linear predictor effects. Also added a function `residual_cor()` to compute residual correlation, covariance and precision matrices from `jsdgam` models. See `?mvgam::jsdgam` and `?mvgam::residual_cor` for details
* Added a `stability.mvgam()` method to compute stability metrics from models fit with Vector Autoregressive dynamics (#21 and #76)
* Added functionality to estimate hierarchical error correlations when using multivariate latent process models and when the data are nested among levels of a relevant grouping factor (#75); see `?mvgam::AR` for an example
Expand Down
6 changes: 3 additions & 3 deletions R/get_linear_predictors.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ trend_Xp_matrix = function(newdata, trend_map, series = 'all',
trend_map,
forecast = forecast)

suppressWarnings(Xp_trend <- try(predict(mgcv_model,
suppressWarnings(Xp_trend <- try(predict(mgcv_model,
newdata = trend_test,
type = 'lpmatrix'),
silent = TRUE))
Expand Down Expand Up @@ -164,11 +164,11 @@ trend_Xp_matrix = function(newdata, trend_map, series = 'all',
# Requires a dataframe of all relevant variables for the gp effects
mock_terms <- brms::brmsterms(attr(mgcv_model, 'brms_mock')$formula)
terms_needed <- unique(all.vars(mock_terms$formula)[-1])
newdata_mock <- data.frame(newdata[[terms_needed[1]]])
newdata_mock <- data.frame(trend_test[[terms_needed[1]]])
if(length(terms_needed) > 1L){
for(i in 2:length(terms_needed)){
newdata_mock <- cbind(newdata_mock,
data.frame(newdata[[terms_needed[i]]]))
data.frame(trend_test[[terms_needed[i]]]))
}
}
colnames(newdata_mock) <- terms_needed
Expand Down
77 changes: 65 additions & 12 deletions R/get_mvgam_priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#'changed for a given `mvgam` model, as well listing their default distributions
#'
#'@inheritParams mvgam
#'@inheritParams jsdgam
#'@param factor_formula Can be supplied instead `trend_formula` to match syntax from
#'[jsdgam]
#'@details Users can supply a model formula, prior to fitting the model, so that default priors can be inspected and
Expand Down Expand Up @@ -156,7 +157,9 @@ get_mvgam_priors = function(formula,
factor_formula,
data,
data_train,
family = 'poisson',
family = poisson(),
unit = time,
species = series,
knots,
trend_knots,
use_lv = FALSE,
Expand All @@ -177,6 +180,19 @@ get_mvgam_priors = function(formula,

# Set trend_formula
if(!missing(factor_formula)){
if(missing(n_lv)){
n_lv <- 2
}
validate_pos_integer(n_lv)
unit <- deparse0(substitute(unit))
subgr <- deparse0(substitute(species))
prepped_trend <- prep_jsdgam_trend(unit = unit,
subgr = subgr,
data = data)
trend_model <- 'None'
data_train <- validate_series_time(data = data,
trend_model = prepped_trend)
trend_map <- prep_jsdgam_trendmap(data_train, n_lv)
if(!missing(trend_formula)){
warning('Both "trend_formula" and "factor_formula" supplied\nUsing "factor_formula" as default')
}
Expand Down Expand Up @@ -247,8 +263,8 @@ get_mvgam_priors = function(formula,
if(trend_model == 'None') trend_model <- 'RW'
validate_trend_formula(trend_formula)
prior_df <- get_mvgam_priors(formula = orig_formula,
data = data,
data_train = data_train,
data = data_train,
#data_train = data_train,
family = family,
use_lv = FALSE,
use_stan = TRUE,
Expand Down Expand Up @@ -541,11 +557,48 @@ get_mvgam_priors = function(formula,
mgcv_model = ss_gam,
gp_terms = gp_terms,
family = family)
gp_names <- unlist(purrr::map(gp_additions$gp_att_table, 'name'))
gp_names <- unlist(purrr::map(gp_additions$gp_att_table, 'name'),
use.names = FALSE)
gp_isos <- unlist(purrr::map(gp_additions$gp_att_table, 'iso'),
use.names = FALSE)
abbv_names <- vector(mode = 'list', length = length(gp_names))
full_names <- vector(mode = 'list', length = length(gp_names))
for(i in seq_len(length(gp_names))){
if(gp_isos[i]){
abbv_names[[i]] <- gp_names[i]
full_names[[i]] <- paste0(gp_names[i],
'[1]')
} else {
abbv_names[[i]] <- paste0(gp_names[i],
'[1][',
1:2,
']')
full_names[[i]] <- paste0(gp_names[i],
'[1][',
1:2,
']')
}
}
full_names <- unlist(full_names, use.names = FALSE)
abbv_names <- unlist(abbv_names, use.names = FALSE)
alpha_priors <- unlist(purrr::map(gp_additions$gp_att_table,
'def_alpha'))
'def_alpha'),
use.names = FALSE)
rho_priors <- unlist(purrr::map(gp_additions$gp_att_table,
'def_rho'))
'def_rho'),
use.names = FALSE)
rho_2_priors <- unlist(purrr::map(gp_additions$gp_att_table,
'def_rho_2'),
use.names = FALSE)
full_priors <- vector(mode = 'list', length = length(gp_names))
for(i in seq_len(length(gp_names))){
if(gp_isos[i]){
full_priors[[i]] <- rho_priors[i]
} else {
full_priors[[i]] <- c(rho_priors[i], rho_2_priors[i])
}
}
full_priors <- unlist(full_priors, use.names = FALSE)
smooth_labs <- smooth_labs %>%
dplyr::filter(!label %in%
gsub('gp(', 's(', gp_names, fixed = TRUE))
Expand All @@ -561,14 +614,14 @@ get_mvgam_priors = function(formula,
round(runif(length(gp_names), 0.5, 1), 2),
');'))
rho_df <- data.frame(param_name = paste0('real<lower=0> rho_',
gp_names, ';'),
param_length = 1,
param_info = paste(gp_names,
abbv_names, ';'),
param_length = 1,
param_info = paste(abbv_names,
'length scale'),
prior = paste0('rho_', gp_names, ' ~ ', rho_priors, ';'),
example_change = paste0('rho_', gp_names, ' ~ ',
prior = paste0('rho_', full_names, ' ~ ', full_priors, ';'),
example_change = paste0('rho_', full_names, ' ~ ',
'normal(0, ',
round(runif(length(gp_names), 1, 10), 2),
round(runif(length(full_names), 0.5, 1), 2),
');'))
gp_df <- rbind(alpha_df, rho_df)
} else {
Expand Down
31 changes: 23 additions & 8 deletions R/gp.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ make_gp_additions = function(gp_details,
level = gp_details$level[x],
k = k_gps[[x]],
def_rho = gp_details$def_rho[x],
def_rho_2 = gp_details$def_rho_2[x],
def_alpha = gp_details$def_alpha[x],
eigenvalues = eigenvals[[x]])

Expand Down Expand Up @@ -516,16 +517,23 @@ get_gp_attributes = function(formula, data, family = gaussian()){
family = family_to_brmsfam(family),
data = data))
def_gp_prior <- def_gp_prior[def_gp_prior$prior != '',]
def_rho <- def_gp_prior$prior[min(which(def_gp_prior$class == 'lscale'))]
if(def_rho == ''){
def_rho <- 'inv_gamma(1.5, 5);'
}
def_rho <- def_gp_prior$prior[which(def_gp_prior$class == 'lscale')]
def_alpha <- def_gp_prior$prior[min(which(def_gp_prior$class == 'sdgp'))]
if(def_alpha == ''){
def_alpha <- 'student_t(3, 0, 2.5);'
}
data.frame(def_rho = def_rho,
def_alpha = def_alpha)
if(length(def_rho) > 1L){
def_rho_2 <- def_rho[2]
def_rho <- def_rho[1]
out <- data.frame(def_rho = def_rho,
def_rho_2 = def_rho_2,
def_alpha = def_alpha)
} else {
out <- data.frame(def_rho = def_rho,
def_rho_2 = NA,
def_alpha = def_alpha)
}
out
}))

# Extract information necessary to construct the GP terms
Expand Down Expand Up @@ -555,7 +563,8 @@ get_gp_attributes = function(formula, data, family = gaussian()){
by,
level = NA,
def_alpha = gp_def_priors$def_alpha,
def_rho = gp_def_priors$def_rho)
def_rho = gp_def_priors$def_rho,
def_rho_2 = gp_def_priors$def_rho_2)
attr(ret_dat, 'gp_formula') <- gp_formula

# Return as a data.frame
Expand Down Expand Up @@ -592,6 +601,8 @@ add_gp_model_file = function(model_file, model_data,

rho_priors <- unlist(purrr::map(gp_additions$gp_att_table, 'def_rho'),
use.names = FALSE)
rho_2_priors <- unlist(purrr::map(gp_additions$gp_att_table, 'def_rho_2'),
use.names = FALSE)
alpha_priors <- unlist(purrr::map(gp_additions$gp_att_table, 'def_alpha'),
use.names = FALSE)

Expand Down Expand Up @@ -648,7 +659,11 @@ add_gp_model_file = function(model_file, model_data,
},
']',
' ~ ',
rho_priors[i],
if(gp_isos[i]){
rho_priors[i]
} else {
c(rho_priors[i], rho_2_priors[i])
},
';\n'),
collapse = '\n'
)
Expand Down
17 changes: 11 additions & 6 deletions R/jsdgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
#' Defaults to `series` to be consistent with other `mvgam` models
#'@param n_lv \code{integer} the number of latent factors to use for modelling
#'residual associations.
#'Cannot be `< n_species`. Defaults arbitrarily to `1`
#'Cannot be `> n_species`. Defaults arbitrarily to `2`
#'@param threads \code{integer} Experimental option to use multithreading for within-chain
#'parallelisation in \code{Stan}. We recommend its use only if you are experienced with
#'\code{Stan}'s `reduce_sum` function and have a slow running model that cannot be sped
Expand Down Expand Up @@ -198,6 +198,7 @@
#' ggplot(dat, aes(x = lat, y = lon, col = log(count + 1))) +
#' geom_point(size = 2.25) +
#' facet_wrap(~ species, scales = 'free') +
#' scale_color_viridis_c() +
#' theme_classic()
#'
#' # Inspect default priors for a joint species model with spatial factors
Expand All @@ -209,10 +210,13 @@
#'
#' # Each factor estimates a different nonlinear spatial process, using
#' # 'by = trend' as in other mvgam State-Space models
#' factor_formula = ~ te(lat, lon, k = 5, by = trend) - 1,
#' factor_formula = ~ gp(lat, lon, k = 6, by = trend) - 1,
#' n_lv = 4,
#'
#' # The data
#' # The data and grouping variables
#' data = dat,
#' unit = site,
#' species = species,
#'
#' # Poisson observations
#' family = poisson())
Expand All @@ -228,7 +232,7 @@
#'
#' # Each factor estimates a different nonlinear spatial process, using
#' # 'by = trend' as in other mvgam State-Space models
#' factor_formula = ~ te(lat, lon, k = 5, by = trend) - 1,
#' factor_formula = ~ gp(lat, lon, k = 6, by = trend) - 1,
#' n_lv = 4,
#'
#' # Change default priors for fixed effect betas to standard normal
Expand Down Expand Up @@ -274,7 +278,7 @@
#' image(post_cors$cor)
#'
#' # Posterior predictive checks and ELPD-LOO can ascertain model fit
#' pp_check(mod, type = "ecdf_overlay_grouped",
#' pp_check(mod, type = "pit_ecdf_grouped",
#' group = "species", ndraws = 100)
#' loo(mod)
#'
Expand All @@ -296,6 +300,7 @@
#' ggplot(newdata, aes(x = lat, y = lon, col = log_count)) +
#' geom_point(size = 1.5) +
#' facet_wrap(~ species, scales = 'free') +
#' scale_color_viridis_c() +
#' theme_classic()
#'}
#'@export
Expand All @@ -310,7 +315,7 @@ jsdgam = function(formula,
species = series,
share_obs_params = FALSE,
priors,
n_lv = 1,
n_lv = 2,
chains = 4,
burnin = 500,
samples = 500,
Expand Down
12 changes: 6 additions & 6 deletions R/mvgam_formulae.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@
#' \cr
#' The formulae supplied to \code{\link{mvgam}} are exactly like those supplied to
#' \code{\link{glm}} except that smooth terms,
#' \code{\link[mgcv]{s}},
#' \code{\link[mgcv]{te}},
#' \code{\link[mgcv]{ti}} and
#' \code{\link[mgcv]{t2}},
#' time-varying effects using \code{\link{dynamic}},
#' \code{\link[mgcv]{s()}},
#' \code{\link[mgcv]{te()}},
#' \code{\link[mgcv]{ti()}} and
#' \code{\link[mgcv]{t2()}},
#' time-varying effects using \code{\link{dynamic()}},
#' monotonically increasing (using `s(x, bs = 'moi')`)
#' or decreasing splines (using `s(x, bs = 'mod')`;
#' see \code{\link{smooth.construct.moi.smooth.spec}} for
#' details), as well as
#' Gaussian Process functions using \code{\link[brms]{gp}},
#' Gaussian Process functions using \code{\link[brms]{gp()}},
#' can be added to the right hand side (and \code{.} is not supported in \code{mvgam} formulae).
#' \cr
#' \cr
Expand Down
2 changes: 1 addition & 1 deletion R/plot_mvgam_smooth.R
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ plot_mvgam_smooth = function(object,

if(gp_term){
object2$mgcv_model$smooth[[smooth_int]]$label <-
gsub('s\\(|te\\(', 'gp(',
gsub('s\\(|ti\\(', 'gp(',
object2$mgcv_model$smooth[[smooth_int]]$label)
# Check if this is a factor by variable
is_fac <- is.factor(object2$obs_data[[object2$mgcv_model$smooth[[smooth_int]]$by]])
Expand Down
Loading

0 comments on commit c23c75e

Please sign in to comment.