Skip to content

Commit

Permalink
Merge pull request #17 from nicholasjclark/process_linpreds
Browse files Browse the repository at this point in the history
Process linpreds
  • Loading branch information
nicholasjclark authored Jun 21, 2023
2 parents dd94a5a + cc6fde7 commit 15b4251
Show file tree
Hide file tree
Showing 23 changed files with 1,133 additions and 197 deletions.
28 changes: 23 additions & 5 deletions R/dynamic.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,32 @@ interpret_mvgam = function(formula, N){
newfacs <- facs[!grepl('bs = \"re\"', facs, fixed = TRUE)]
refacs <- facs[grepl('bs = \"re\"', facs, fixed = TRUE)]
int <- attr(terms.formula(formula), 'intercept')
newformula <- as.formula(paste(terms.formula(formula)[[2]], '~',paste(paste(newfacs, collapse = '+'),
'+',
paste(refacs, collapse = '+'),
collapse = '+'),
ifelse(int == 0, ' - 1', '')))

# Preserve offset if included
if(!is.null(attr(terms(formula(formula)), 'offset'))){
newformula <- as.formula(paste(terms.formula(formula)[[2]], '~',
grep('offset', rownames(attr(terms.formula(formula), 'factors')),
value = TRUE),
'+',
paste(paste(newfacs, collapse = '+'),
'+',
paste(refacs, collapse = '+'),
collapse = '+'),
ifelse(int == 0, ' - 1', '')))

} else {
newformula <- as.formula(paste(terms.formula(formula)[[2]], '~',
paste(paste(newfacs, collapse = '+'),
'+',
paste(refacs, collapse = '+'),
collapse = '+'),
ifelse(int == 0, ' - 1', '')))
}

} else {
newformula <- formula
}

attr(newformula, '.Environment') <- attr(formula, '.Environment')

# Check if any terms use the dynamic wrapper
Expand Down
6 changes: 6 additions & 0 deletions R/get_mvgam_priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ get_mvgam_priors = function(formula,
trend_map,
drift = FALSE){

# Check formula
if(attr(terms(formula), "response") == 0L){
stop('response variable is missing from formula',
call. = FALSE)
}

# Validate the family argument
family <- evaluate_family(family)
family_char <- match.arg(arg = family$family,
Expand Down
6 changes: 5 additions & 1 deletion R/mvgam-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
#' `print`, `score`, `summary` and `update` exist for this class.
#' @details A `mvgam` object contains the following elements:
#'\itemize{
#' \item `call` the original model formula
#' \item `call` the original observation model formula
#' \item `trend_call` If a `trend_formula was supplied`, the original trend model formula is
#' returned. Otherwise `NULL`
#' \item `family` \code{character} description of the observation distribution
#' \item `trend_model` \code{character} description of the latent trend model
#' \item `drift` Logical specifying whether a drift term was used in the trend model
Expand All @@ -32,6 +34,8 @@
#' but these are only used if generating plots of smooth functions that `mvgam` currently cannot handle
#' (such as plots for three-dimensional smooths). This model therefore should not be used for inference.
#' See \code{\link[mgcv]{gamObject}} for details.
#' \item `trend_mgcv_model` If a `trend_formula was supplied`, an object of class `gam` containing
#' the `mgcv` version of the trend model. Otherwise `NULL`
#' \item `ytimes` The `matrix` object used in model fitting for indexing which series and timepoints
#' were observed in each row of the supplied data. Used internally by some downstream plotting
#' and prediction functions
Expand Down
134 changes: 117 additions & 17 deletions R/mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
#'
#'@importFrom parallel clusterExport stopCluster setDefaultCluster
#'@importFrom stats formula terms rnorm update.formula predict
#'@param formula A \code{character} string specifying the GAM formula. These are exactly like the formula
#'@param formula A \code{character} string specifying the GAM observation model formula. These are exactly like the formula
#'for a GLM except that smooth terms, s, te, ti and t2, can be added to the right hand side
#'to specify that the linear predictor depends on smooth functions of predictors (or linear functionals of these).
#'@param trend_formula An optional \code{character} string specifying the GAM process model formula. If
#'supplied, a linear predictor will be modelled for the latent trends to capture process model evolution
#'separately from the observation model. Should not have a response variable specified on the left-hand side
#'of the formula (i.e. a valid option would be `~ season + s(year)`)
#'@param knots An optional \code{list} containing user specified knot values to be used for basis construction.
#'For most bases the user simply supplies the knots to be used, which must match up with the k value supplied
#'(note that the number of knots is not always just k). Different terms can use different numbers of knots,
Expand Down Expand Up @@ -337,7 +341,7 @@
#' # Example showing how to incorporate an offset; simulate some count data
#' # with different means per series
#' set.seed(100)
#' dat <- sim_mvgam(trend_rel = 0, mu_obs = c(4, 8, 8), seasonality = 'hierarchical')
#' dat <- sim_mvgam(trend_rel = 0, mu = c(0, 2, 2), seasonality = 'hierarchical')
#'
#' # Add offset terms to the training and testing data
#' dat$data_train$offset <- 0.5 * as.numeric(dat$data_train$series)
Expand All @@ -346,7 +350,8 @@
#' # Fit a model that includes the offset in the linear predictor as well as
#' # hierarchical seasonal smooths
#' mod1 <- mvgam(formula = y ~ offset(offset) +
#' s(season, bs = 'cc') +
#' s(series, bs = 're') +
#' s(season, bs = 'cc') +
#' s(season, by = series, m = 1, k = 5),
#' data = dat$data_train,
#' trend_model = 'None',
Expand All @@ -359,9 +364,9 @@
#' # Forecasts for the first two series will differ in magnitude
#' layout(matrix(1:2, ncol = 2))
#' plot(mod1, type = 'forecast', series = 1, newdata = dat$data_test,
#' ylim = c(0, 70))
#' ylim = c(0, 75))
#' plot(mod1, type = 'forecast', series = 2, newdata = dat$data_test,
#' ylim = c(0, 70))
#' ylim = c(0, 75))
#' layout(1)
#'
#' # Changing the offset for the testing data should lead to changes in
Expand Down Expand Up @@ -396,6 +401,7 @@
#'@export

mvgam = function(formula,
trend_formula,
knots,
data,
data_train,
Expand Down Expand Up @@ -433,6 +439,11 @@ mvgam = function(formula,
data_train <- data
}

if(attr(terms(formula), "response") == 0L){
stop('response variable is missing from formula',
call. = FALSE)
}

if(!as.character(terms(formula(formula))[[2]]) %in% names(data_train)){
stop(paste0('variable ', terms(formula(formula))[[2]], ' not found in data'),
call. = FALSE)
Expand Down Expand Up @@ -467,6 +478,18 @@ mvgam = function(formula,
# Validate the trend arguments
trend_model <- evaluate_trend_model(trend_model)

if(!missing(trend_formula)){
if(missing(trend_map)){
trend_map <- data.frame(series = unique(data_train$series),
trend = 1:length(unique(data_train$series)))
}

if(trend_model != 'RW'){
stop('only random walk trends currently supported for trend predictor models',
call. = FALSE)
}
}

# Check trend_map is correctly specified
if(!missing(trend_map)){

Expand Down Expand Up @@ -676,6 +699,13 @@ mvgam = function(formula,
warning('No point in latent variables if trend model is None; changing use_lv to FALSE')
}

# Check if there is an offset variable included
if(is.null(attr(terms(formula(formula)), 'offset'))){
offset <- FALSE
} else {
offset <- TRUE
}

# Ensure outcome is labelled 'y' when feeding data to the model for simplicity
orig_formula <- formula
formula <- interpret_mvgam(formula, N = max(data_train$time))
Expand All @@ -695,24 +725,17 @@ mvgam = function(formula,
}
}

# Check if there is an offset variable included
if(is.null(attr(terms(formula(formula)), 'offset'))){
offset <- FALSE
} else {
offset <- TRUE
}

# If there are missing values in y, use predictions from an initial mgcv model to fill
# these in so that initial values to maintain the true size of the training dataset
orig_y <- data_train$y

# Initiate the GAM model using mgcv so that the linear predictor matrix can be easily calculated
# when simulating from the Bayesian model later on;
ss_gam <- mvgam_setup(formula = formula,
family = family_to_mgcvfam(family),
data = data_train,
drop.unused.levels = FALSE,
maxit = 30)
family = family_to_mgcvfam(family),
data = data_train,
drop.unused.levels = FALSE,
maxit = 30)

# Fill in missing observations in data_train so the size of the dataset is correct when
# building the initial JAGS model.
Expand Down Expand Up @@ -966,7 +989,19 @@ mvgam = function(formula,
model_file[grep('eta <- X %*% b', model_file, fixed = TRUE)] <-
"eta <- X %*% b + offset"
if(!missing(data_test) & !prior_simulation){
ss_jagam$jags.data$offset <- c(ss_jagam$jags.data$offset, data_test$offset)

get_offset <- function(model) {
nm1 <- names(attributes(model$terms)$dataClasses)
if('(offset)' %in% nm1) {
deparse(as.list(model$call)$offset)
} else {

sub("offset\\((.*)\\)$", "\\1", grep('offset', nm1, value = TRUE))
}
}

ss_jagam$jags.data$offset <- c(ss_jagam$jags.data$offset,
data_test[[get_offset(ss_gam)]])
}
}

Expand Down Expand Up @@ -1351,6 +1386,42 @@ mvgam = function(formula,
trend_model = trend_model)
vectorised$model_file <- trend_map_setup$model_file
vectorised$model_data <- trend_map_setup$model_data

if(trend_model %in% c('RW', 'AR1', 'AR2', 'AR3')){
param <- c(param, 'sigma')
}

# If trend formula specified, add the predictors for the trend models
if(!missing(trend_formula)){
trend_pred_setup <- add_trend_predictors(
trend_formula = trend_formula,
trend_map = trend_map,
trend_model = trend_model,
data_train = data_train,
data_test = if(missing(data_test)){
NULL
} else {
data_test
},
model_file = vectorised$model_file,
model_data = vectorised$model_data,
drift = drift)

vectorised$model_file <- trend_pred_setup$model_file
vectorised$model_data <- trend_pred_setup$model_data
trend_mgcv_model <- trend_pred_setup$trend_mgcv_model

param <- c(param, 'b_trend')

if(trend_pred_setup$trend_smooths_included){
param <- c(param, 'rho_trend')
}

if(trend_pred_setup$trend_random_included){
param <- c(param, 'mu_raw_trend', 'sigma_raw_trend')
}

}
}

} else {
Expand Down Expand Up @@ -1388,6 +1459,11 @@ mvgam = function(formula,
if(!run_model){
unlink('base_gam.txt')
output <- structure(list(call = orig_formula,
trend_call = if(!missing(trend_formula)){
trend_formula
} else {
NULL
},
family = family_char,
trend_model = trend_model,
drift = drift,
Expand All @@ -1405,6 +1481,11 @@ mvgam = function(formula,
inits = inits,
monitor_pars = param,
mgcv_model = ss_gam,
trend_mgcv_model = if(!missing(trend_formula)){
trend_mgcv_model
} else {
NULL
},
sp_names = rho_names,
ytimes = ytimes,
use_lv = use_lv,
Expand Down Expand Up @@ -1666,10 +1747,24 @@ mvgam = function(formula,
p <- mcmc_summary(out_gam_mod, 'b')[,c(4)]
names(p) <- names(ss_gam$coefficients)
ss_gam$coefficients <- p

if(!missing(trend_formula)){
V <- cov(mcmc_chains(out_gam_mod, 'b_trend'))
trend_mgcv_model$Ve <- trend_mgcv_model$Vp <- trend_mgcv_model$Vc <- V

p <- mcmc_summary(out_gam_mod, 'b_trend')[,c(4)]
names(p) <- names(trend_mgcv_model$coefficients)
trend_mgcv_model$coefficients <- p
}
}

#### Return the output as class mvgam ####
output <- structure(list(call = orig_formula,
trend_call = if(!missing(trend_formula)){
trend_formula
} else {
NULL
},
family = family_char,
trend_model = trend_model,
drift = drift,
Expand Down Expand Up @@ -1697,6 +1792,11 @@ mvgam = function(formula,
},
sp_names = rho_names,
mgcv_model = ss_gam,
trend_mgcv_model = if(!missing(trend_formula)){
trend_mgcv_model
} else {
NULL
},
ytimes = ytimes,
resids = series_resids,
use_lv = use_lv,
Expand Down
Loading

0 comments on commit 15b4251

Please sign in to comment.