Brms gps #89

Closed
wants to merge 7 commits
2 changes: 2 additions & 0 deletions NAMESPACE
@@ -173,6 +173,8 @@ importFrom(bayesplot,nuts_params)
importFrom(bayesplot,pp_check)
importFrom(brms,bernoulli)
importFrom(brms,beta_binomial)
importFrom(brms,brm)
importFrom(brms,brmsterms)
importFrom(brms,conditional_effects)
importFrom(brms,dbeta_binomial)
importFrom(brms,do_call)
1 change: 1 addition & 0 deletions NEWS.md
@@ -1,5 +1,6 @@
# mvgam 1.1.4 (development version; not yet on CRAN)
## New functionalities
* Added support for approximate `gp()` effects with more than one covariate and with different kernel functions (#79)
* Added function `jsdgam()` to estimate Joint Species Distribution Models in which both the latent factors and the observation model components can include any of mvgam's complex linear predictor effects. Also added a function `residual_cor()` to compute residual correlation, covariance and precision matrices from `jsdgam` models. See `?mvgam::jsdgam` and `?mvgam::residual_cor` for details
* Added a `stability.mvgam()` method to compute stability metrics from models fit with Vector Autoregressive dynamics (#21 and #76)
* Added functionality to estimate hierarchical error correlations when using multivariate latent process models and when the data are nested among levels of a relevant grouping factor (#75); see `?mvgam::AR` for an example
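The first entry under "New functionalities" above is the one this PR targets: approximate `gp()` effects can now take more than one covariate and a choice of kernel. Below is a minimal sketch of what such a model call might look like, assuming the `gp()` arguments follow `brms::gp()`; the data frame, covariate names and kernel choice are hypothetical, not taken from the PR.

```r
library(mvgam)

# 'model_dat' is a hypothetical data frame with columns y, temp, rainfall, time and series
mod <- mvgam(
  y ~ gp(temp, rainfall,      # two covariates in a single approximate GP
         k = 20,              # number of basis functions
         cov = 'matern32',    # kernel choice; assumed to mirror brms::gp()'s 'cov' argument
         iso = FALSE),        # separate length scales per covariate
  data = model_dat,
  family = poisson(),
  trend_model = AR(p = 1)
)
```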
4 changes: 2 additions & 2 deletions R/conditional_effects.R
@@ -98,8 +98,8 @@
conditional_effects.mvgam = function(x,
effects = NULL,
type = 'response',
points = TRUE,
rug = TRUE,
points = FALSE,
rug = FALSE,
...){

use_def_effects <- is.null(effects)
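With this change the defaults for `points` and `rug` flip from `TRUE` to `FALSE`, so partial-residual points and rug lines are no longer drawn unless requested. A sketch of restoring the old behaviour on a hypothetical fitted model `fit`, assuming the returned object is plotted the same way as before:

```r
# Request overlaid data points and a rug explicitly now that both default to FALSE
ce <- conditional_effects(fit, type = 'response', points = TRUE, rug = TRUE)
plot(ce)  # plotting of the returned object assumed to work as in earlier versions
```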
107 changes: 54 additions & 53 deletions R/get_linear_predictors.R
@@ -1,4 +1,5 @@
#' Function to prepare observation model linear predictor matrix
#' @importFrom brms brmsterms
#' @noRd
obs_Xp_matrix = function(newdata, mgcv_model){
suppressWarnings(Xp <- try(predict(mgcv_model,
@@ -32,41 +33,41 @@ obs_Xp_matrix = function(newdata, mgcv_model){
# Check for any gp() terms and update the design matrix
# accordingly
if(!is.null(attr(mgcv_model, 'gp_att_table'))){
# Compute the eigenfunctions from the supplied attribute table,
# and add them to the Xp matrix

# Compute the gp() eigenfunctions for newdata using the supplied brms_mock object
# Requires a dataframe of all relevant variables for the gp effects
mock_terms <- brms::brmsterms(attr(mgcv_model, 'brms_mock')$formula)
terms_needed <- unique(all.vars(mock_terms$formula)[-1])
newdata_mock <- data.frame(newdata[[terms_needed[1]]])
if(length(terms_needed) > 1L){
for(i in 2:length(terms_needed)){
newdata_mock <- cbind(newdata_mock,
data.frame(newdata[[terms_needed[i]]]))
}
}
colnames(newdata_mock) <- terms_needed
newdata_mock$.fake_gp_y <- rnorm(NROW(newdata_mock))
brms_mock_data <- brms::standata(attr(mgcv_model, 'brms_mock'),
newdata = newdata_mock,
internal = TRUE)

# Extract GP attributes
gp_att_table <- attr(mgcv_model, 'gp_att_table')
gp_covariates <- unlist(purrr::map(gp_att_table, 'covariate'))
by <- unlist(purrr::map(gp_att_table, 'by'))
level <- unlist(purrr::map(gp_att_table, 'level'))
k <- unlist(purrr::map(gp_att_table, 'k'))
scale <- unlist(purrr::map(gp_att_table, 'scale'))
mean <- unlist(purrr::map(gp_att_table, 'mean'))
max_dist <- unlist(purrr::map(gp_att_table, 'max_dist'))
boundary <- unlist(purrr::map(gp_att_table, 'boundary'))
L <- unlist(purrr::map(gp_att_table, 'L'))

# Compute eigenfunctions
test_eigenfunctions <- lapply(seq_along(gp_covariates), function(x){
prep_eigenfunctions(data = newdata,
covariate = gp_covariates[x],
by = by[x],
level = level[x],
k = k[x],
boundary = boundary[x],
L = L[x],
mean = mean[x],
scale = scale[x],
max_dist = max_dist[x])
})
bys <- unlist(purrr::map(gp_att_table, 'by'), use.names = FALSE)
lvls <- unlist(purrr::map(gp_att_table, 'level'), use.names = FALSE)

# Extract eigenfunctions for each gp effect
eigenfuncs <- eigenfunc_list(stan_data = brms_mock_data,
mock_df = newdata_mock,
by = bys,
level = lvls)

# Find indices to replace in the design matrix and replace with
# the computed eigenfunctions
starts <- purrr::map(gp_att_table, 'first_coef')
ends <- purrr::map(gp_att_table, 'last_coef')
for(i in seq_along(starts)){
Xp[,c(starts[[i]]:ends[[i]])] <- test_eigenfunctions[[i]]
Xp[,c(starts[[i]]:ends[[i]])] <- eigenfuncs[[i]]
}
}

@@ -127,7 +128,7 @@ trend_Xp_matrix = function(newdata, trend_map, series = 'all',
trend_map,
forecast = forecast)

suppressWarnings(Xp_trend <- try(predict(mgcv_model,
suppressWarnings(Xp_trend <- try(predict(mgcv_model,
newdata = trend_test,
type = 'lpmatrix'),
silent = TRUE))
@@ -158,41 +159,41 @@ trend_Xp_matrix = function(newdata, trend_map, series = 'all',
# Check for any gp() terms and update the design matrix
# accordingly
if(!is.null(attr(mgcv_model, 'gp_att_table'))){
# Compute the eigenfunctions from the supplied attribute table,
# and add them to the Xp matrix

# Compute the gp() eigenfunctions for newdata using the supplied brms_mock object
# Requires a dataframe of all relevant variables for the gp effects
mock_terms <- brms::brmsterms(attr(mgcv_model, 'brms_mock')$formula)
terms_needed <- unique(all.vars(mock_terms$formula)[-1])
newdata_mock <- data.frame(trend_test[[terms_needed[1]]])
if(length(terms_needed) > 1L){
for(i in 2:length(terms_needed)){
newdata_mock <- cbind(newdata_mock,
data.frame(trend_test[[terms_needed[i]]]))
}
}
colnames(newdata_mock) <- terms_needed
newdata_mock$.fake_gp_y <- rnorm(NROW(newdata_mock))
brms_mock_data <- brms::standata(attr(mgcv_model, 'brms_mock'),
newdata = newdata_mock,
internal = TRUE)

# Extract GP attributes
gp_att_table <- attr(mgcv_model, 'gp_att_table')
gp_covariates <- unlist(purrr::map(gp_att_table, 'covariate'))
by <- unlist(purrr::map(gp_att_table, 'by'))
level <- unlist(purrr::map(gp_att_table, 'level'))
k <- unlist(purrr::map(gp_att_table, 'k'))
scale <- unlist(purrr::map(gp_att_table, 'scale'))
mean <- unlist(purrr::map(gp_att_table, 'mean'))
max_dist <- unlist(purrr::map(gp_att_table, 'max_dist'))
boundary <- unlist(purrr::map(gp_att_table, 'boundary'))
L <- unlist(purrr::map(gp_att_table, 'L'))

# Compute eigenfunctions
test_eigenfunctions <- lapply(seq_along(gp_covariates), function(x){
prep_eigenfunctions(data = trend_test,
covariate = gp_covariates[x],
by = by[x],
level = level[x],
k = k[x],
boundary = boundary[x],
L = L[x],
mean = mean[x],
scale = scale[x],
max_dist = max_dist[x])
})
bys <- unlist(purrr::map(gp_att_table, 'by'), use.names = FALSE)
lvls <- unlist(purrr::map(gp_att_table, 'level'), use.names = FALSE)

# Extract eigenfunctions for each gp effect
eigenfuncs <- eigenfunc_list(stan_data = brms_mock_data,
mock_df = newdata_mock,
by = bys,
level = lvls)

# Find indices to replace in the design matrix and replace with
# the computed eigenfunctions
starts <- purrr::map(gp_att_table, 'first_coef')
ends <- purrr::map(gp_att_table, 'last_coef')
for(i in seq_along(starts)){
Xp_trend[,c(starts[[i]]:ends[[i]])] <- test_eigenfunctions[[i]]
Xp_trend[,c(starts[[i]]:ends[[i]])] <- eigenfuncs[[i]]
}
}

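Both `obs_Xp_matrix()` and `trend_Xp_matrix()` now delegate the `gp()` basis computation to brms rather than rebuilding eigenfunctions with `prep_eigenfunctions()`. A condensed sketch of the shared workflow follows; `eigenfunc_list()` and the `'brms_mock'` / `'gp_att_table'` attributes are internal to mvgam, and the snippet assumes an lpmatrix `Xp` and a `newdata` object already exist.

```r
# Internal workflow sketch, not a public API
gp_table <- attr(mgcv_model, 'gp_att_table')
mock     <- attr(mgcv_model, 'brms_mock')

# Gather every covariate used by the gp() terms into a mock data frame,
# plus a fake response so that brms can build its Stan data
terms_needed <- unique(all.vars(brms::brmsterms(mock$formula)$formula)[-1])
newdata_mock <- as.data.frame(lapply(terms_needed, function(v) newdata[[v]]))
colnames(newdata_mock) <- terms_needed
newdata_mock$.fake_gp_y <- rnorm(NROW(newdata_mock))

# Let brms compute the approximate GP basis for the new covariate values
stan_dat <- brms::standata(mock, newdata = newdata_mock, internal = TRUE)

# Pull out one basis matrix per gp() term and slot it into the
# corresponding columns of the mgcv linear predictor matrix
eigenfuncs <- eigenfunc_list(stan_data = stan_dat,
                             mock_df = newdata_mock,
                             by = unlist(purrr::map(gp_table, 'by'), use.names = FALSE),
                             level = unlist(purrr::map(gp_table, 'level'), use.names = FALSE))
for(i in seq_along(gp_table)){
  cols <- gp_table[[i]]$first_coef:gp_table[[i]]$last_coef
  Xp[, cols] <- eigenfuncs[[i]]
}
```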
78 changes: 66 additions & 12 deletions R/get_mvgam_priors.R
@@ -4,6 +4,7 @@
#'changed for a given `mvgam` model, as well as listing their default distributions
#'
#'@inheritParams mvgam
#'@inheritParams jsdgam
#'@param factor_formula Can be supplied instead of `trend_formula` to match syntax from
#'[jsdgam]
#'@details Users can supply a model formula, prior to fitting the model, so that default priors can be inspected and
@@ -156,7 +157,9 @@ get_mvgam_priors = function(formula,
factor_formula,
data,
data_train,
family = 'poisson',
family = poisson(),
unit = time,
species = series,
knots,
trend_knots,
use_lv = FALSE,
@@ -177,6 +180,19 @@

# Set trend_formula
if(!missing(factor_formula)){
if(missing(n_lv)){
n_lv <- 2
}
validate_pos_integer(n_lv)
unit <- deparse0(substitute(unit))
subgr <- deparse0(substitute(species))
prepped_trend <- prep_jsdgam_trend(unit = unit,
subgr = subgr,
data = data)
trend_model <- 'None'
data_train <- validate_series_time(data = data,
trend_model = prepped_trend)
trend_map <- prep_jsdgam_trendmap(data_train, n_lv)
if(!missing(trend_formula)){
warning('Both "trend_formula" and "factor_formula" supplied\nUsing "factor_formula" as default')
}
@@ -247,8 +263,8 @@
if(trend_model == 'None') trend_model <- 'RW'
validate_trend_formula(trend_formula)
prior_df <- get_mvgam_priors(formula = orig_formula,
data = data,
data_train = data_train,
data = data_train,
#data_train = data_train,
family = family,
use_lv = FALSE,
use_stan = TRUE,
@@ -534,17 +550,55 @@
# Check for gp() terms
if(!is.null(gp_terms)){
gp_additions <- make_gp_additions(gp_details = gp_details,
orig_formula = orig_formula,
data = data_train,
newdata = NULL,
model_data = list(X = t(predict(ss_gam, type = 'lpmatrix'))),
mgcv_model = ss_gam,
gp_terms = gp_terms,
family = family)
gp_names <- unlist(purrr::map(gp_additions$gp_att_table, 'name'))
gp_names <- unlist(purrr::map(gp_additions$gp_att_table, 'name'),
use.names = FALSE)
gp_isos <- unlist(purrr::map(gp_additions$gp_att_table, 'iso'),
use.names = FALSE)
abbv_names <- vector(mode = 'list', length = length(gp_names))
full_names <- vector(mode = 'list', length = length(gp_names))
for(i in seq_len(length(gp_names))){
if(gp_isos[i]){
abbv_names[[i]] <- gp_names[i]
full_names[[i]] <- paste0(gp_names[i],
'[1]')
} else {
abbv_names[[i]] <- paste0(gp_names[i],
'[1][',
1:2,
']')
full_names[[i]] <- paste0(gp_names[i],
'[1][',
1:2,
']')
}
}
full_names <- unlist(full_names, use.names = FALSE)
abbv_names <- unlist(abbv_names, use.names = FALSE)
alpha_priors <- unlist(purrr::map(gp_additions$gp_att_table,
'def_alpha'))
'def_alpha'),
use.names = FALSE)
rho_priors <- unlist(purrr::map(gp_additions$gp_att_table,
'def_rho'))
'def_rho'),
use.names = FALSE)
rho_2_priors <- unlist(purrr::map(gp_additions$gp_att_table,
'def_rho_2'),
use.names = FALSE)
full_priors <- vector(mode = 'list', length = length(gp_names))
for(i in seq_len(length(gp_names))){
if(gp_isos[i]){
full_priors[[i]] <- rho_priors[i]
} else {
full_priors[[i]] <- c(rho_priors[i], rho_2_priors[i])
}
}
full_priors <- unlist(full_priors, use.names = FALSE)
smooth_labs <- smooth_labs %>%
dplyr::filter(!label %in%
gsub('gp(', 's(', gp_names, fixed = TRUE))
@@ -560,14 +614,14 @@
round(runif(length(gp_names), 0.5, 1), 2),
');'))
rho_df <- data.frame(param_name = paste0('real<lower=0> rho_',
gp_names, ';'),
param_length = 1,
param_info = paste(gp_names,
abbv_names, ';'),
param_length = 1,
param_info = paste(abbv_names,
'length scale'),
prior = paste0('rho_', gp_names, ' ~ ', rho_priors, ';'),
example_change = paste0('rho_', gp_names, ' ~ ',
prior = paste0('rho_', full_names, ' ~ ', full_priors, ';'),
example_change = paste0('rho_', full_names, ' ~ ',
'normal(0, ',
round(runif(length(gp_names), 1, 10), 2),
round(runif(length(full_names), 0.5, 1), 2),
');'))
gp_df <- rbind(alpha_df, rho_df)
} else {
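These changes let `get_mvgam_priors()` accept the `jsdgam()`-style arguments (`factor_formula`, `n_lv`, `species`, `unit`) and give non-isotropic `gp()` terms one length-scale prior per covariate. A sketch of inspecting default priors through the new interface; the data frame, its columns and the latent-factor formula are all hypothetical:

```r
# 'dat' is a hypothetical long-format data frame with columns
# count, temp, rain, site and species
priors <- get_mvgam_priors(
  formula = count ~ gp(temp, rain, k = 10, iso = FALSE),  # observation-model formula
  factor_formula = ~ 1,       # hypothetical latent-factor formula, as in jsdgam()
  n_lv = 2,                   # number of latent factors (2 if omitted)
  family = poisson(),
  unit = site,                # observational unit (defaults to time)
  species = species,          # response grouping (defaults to series)
  data = dat
)

# For a two-covariate, non-isotropic gp() the length-scale rows should now
# appear separately, e.g. rho_gp(temp, rain)[1][1] and rho_gp(temp, rain)[1][2]
priors[, c('param_name', 'prior')]
```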
3 changes: 2 additions & 1 deletion R/globals.R
@@ -20,4 +20,5 @@ utils::globalVariables(c("y", "year", "smooth_vals", "smooth_num",
"matches", "time.", "file_name", ".data",
"horizon", "target", "Series", "evd", "mean_evd",
"total_evd", "smooth_label", "by_variable",
"gr", "tot_subgrs", "subgr"))
"gr", "tot_subgrs", "subgr", "lambda",
"level", "sim_hilbert_gp", "trend_model"))