Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Process linpreds #17

Merged
merged 9 commits into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions R/dynamic.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,32 @@ interpret_mvgam = function(formula, N){
newfacs <- facs[!grepl('bs = \"re\"', facs, fixed = TRUE)]
refacs <- facs[grepl('bs = \"re\"', facs, fixed = TRUE)]
int <- attr(terms.formula(formula), 'intercept')
newformula <- as.formula(paste(terms.formula(formula)[[2]], '~',paste(paste(newfacs, collapse = '+'),
'+',
paste(refacs, collapse = '+'),
collapse = '+'),
ifelse(int == 0, ' - 1', '')))

# Preserve offset if included
if(!is.null(attr(terms(formula(formula)), 'offset'))){
newformula <- as.formula(paste(terms.formula(formula)[[2]], '~',
grep('offset', rownames(attr(terms.formula(formula), 'factors')),
value = TRUE),
'+',
paste(paste(newfacs, collapse = '+'),
'+',
paste(refacs, collapse = '+'),
collapse = '+'),
ifelse(int == 0, ' - 1', '')))

} else {
newformula <- as.formula(paste(terms.formula(formula)[[2]], '~',
paste(paste(newfacs, collapse = '+'),
'+',
paste(refacs, collapse = '+'),
collapse = '+'),
ifelse(int == 0, ' - 1', '')))
}

} else {
newformula <- formula
}

attr(newformula, '.Environment') <- attr(formula, '.Environment')

# Check if any terms use the dynamic wrapper
Expand Down
6 changes: 6 additions & 0 deletions R/get_mvgam_priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ get_mvgam_priors = function(formula,
trend_map,
drift = FALSE){

# Check formula
if(attr(terms(formula), "response") == 0L){
stop('response variable is missing from formula',
call. = FALSE)
}

# Validate the family argument
family <- evaluate_family(family)
family_char <- match.arg(arg = family$family,
Expand Down
6 changes: 5 additions & 1 deletion R/mvgam-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
#' `print`, `score`, `summary` and `update` exist for this class.
#' @details A `mvgam` object contains the following elements:
#'\itemize{
#' \item `call` the original model formula
#' \item `call` the original observation model formula
#' \item `trend_call` If a `trend_formula was supplied`, the original trend model formula is
#' returned. Otherwise `NULL`
#' \item `family` \code{character} description of the observation distribution
#' \item `trend_model` \code{character} description of the latent trend model
#' \item `drift` Logical specifying whether a drift term was used in the trend model
Expand All @@ -32,6 +34,8 @@
#' but these are only used if generating plots of smooth functions that `mvgam` currently cannot handle
#' (such as plots for three-dimensional smooths). This model therefore should not be used for inference.
#' See \code{\link[mgcv]{gamObject}} for details.
#' \item `trend_mgcv_model` If a `trend_formula was supplied`, an object of class `gam` containing
#' the `mgcv` version of the trend model. Otherwise `NULL`
#' \item `ytimes` The `matrix` object used in model fitting for indexing which series and timepoints
#' were observed in each row of the supplied data. Used internally by some downstream plotting
#' and prediction functions
Expand Down
134 changes: 117 additions & 17 deletions R/mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
#'
#'@importFrom parallel clusterExport stopCluster setDefaultCluster
#'@importFrom stats formula terms rnorm update.formula predict
#'@param formula A \code{character} string specifying the GAM formula. These are exactly like the formula
#'@param formula A \code{character} string specifying the GAM observation model formula. These are exactly like the formula
#'for a GLM except that smooth terms, s, te, ti and t2, can be added to the right hand side
#'to specify that the linear predictor depends on smooth functions of predictors (or linear functionals of these).
#'@param trend_formula An optional \code{character} string specifying the GAM process model formula. If
#'supplied, a linear predictor will be modelled for the latent trends to capture process model evolution
#'separately from the observation model. Should not have a response variable specified on the left-hand side
#'of the formula (i.e. a valid option would be `~ season + s(year)`)
#'@param knots An optional \code{list} containing user specified knot values to be used for basis construction.
#'For most bases the user simply supplies the knots to be used, which must match up with the k value supplied
#'(note that the number of knots is not always just k). Different terms can use different numbers of knots,
Expand Down Expand Up @@ -337,7 +341,7 @@
#' # Example showing how to incorporate an offset; simulate some count data
#' # with different means per series
#' set.seed(100)
#' dat <- sim_mvgam(trend_rel = 0, mu_obs = c(4, 8, 8), seasonality = 'hierarchical')
#' dat <- sim_mvgam(trend_rel = 0, mu = c(0, 2, 2), seasonality = 'hierarchical')
#'
#' # Add offset terms to the training and testing data
#' dat$data_train$offset <- 0.5 * as.numeric(dat$data_train$series)
Expand All @@ -346,7 +350,8 @@
#' # Fit a model that includes the offset in the linear predictor as well as
#' # hierarchical seasonal smooths
#' mod1 <- mvgam(formula = y ~ offset(offset) +
#' s(season, bs = 'cc') +
#' s(series, bs = 're') +
#' s(season, bs = 'cc') +
#' s(season, by = series, m = 1, k = 5),
#' data = dat$data_train,
#' trend_model = 'None',
Expand All @@ -359,9 +364,9 @@
#' # Forecasts for the first two series will differ in magnitude
#' layout(matrix(1:2, ncol = 2))
#' plot(mod1, type = 'forecast', series = 1, newdata = dat$data_test,
#' ylim = c(0, 70))
#' ylim = c(0, 75))
#' plot(mod1, type = 'forecast', series = 2, newdata = dat$data_test,
#' ylim = c(0, 70))
#' ylim = c(0, 75))
#' layout(1)
#'
#' # Changing the offset for the testing data should lead to changes in
Expand Down Expand Up @@ -396,6 +401,7 @@
#'@export

mvgam = function(formula,
trend_formula,
knots,
data,
data_train,
Expand Down Expand Up @@ -433,6 +439,11 @@ mvgam = function(formula,
data_train <- data
}

if(attr(terms(formula), "response") == 0L){
stop('response variable is missing from formula',
call. = FALSE)
}

if(!as.character(terms(formula(formula))[[2]]) %in% names(data_train)){
stop(paste0('variable ', terms(formula(formula))[[2]], ' not found in data'),
call. = FALSE)
Expand Down Expand Up @@ -467,6 +478,18 @@ mvgam = function(formula,
# Validate the trend arguments
trend_model <- evaluate_trend_model(trend_model)

if(!missing(trend_formula)){
if(missing(trend_map)){
trend_map <- data.frame(series = unique(data_train$series),
trend = 1:length(unique(data_train$series)))
}

if(trend_model != 'RW'){
stop('only random walk trends currently supported for trend predictor models',
call. = FALSE)
}
}

# Check trend_map is correctly specified
if(!missing(trend_map)){

Expand Down Expand Up @@ -676,6 +699,13 @@ mvgam = function(formula,
warning('No point in latent variables if trend model is None; changing use_lv to FALSE')
}

# Check if there is an offset variable included
if(is.null(attr(terms(formula(formula)), 'offset'))){
offset <- FALSE
} else {
offset <- TRUE
}

# Ensure outcome is labelled 'y' when feeding data to the model for simplicity
orig_formula <- formula
formula <- interpret_mvgam(formula, N = max(data_train$time))
Expand All @@ -695,24 +725,17 @@ mvgam = function(formula,
}
}

# Check if there is an offset variable included
if(is.null(attr(terms(formula(formula)), 'offset'))){
offset <- FALSE
} else {
offset <- TRUE
}

# If there are missing values in y, use predictions from an initial mgcv model to fill
# these in so that initial values to maintain the true size of the training dataset
orig_y <- data_train$y

# Initiate the GAM model using mgcv so that the linear predictor matrix can be easily calculated
# when simulating from the Bayesian model later on;
ss_gam <- mvgam_setup(formula = formula,
family = family_to_mgcvfam(family),
data = data_train,
drop.unused.levels = FALSE,
maxit = 30)
family = family_to_mgcvfam(family),
data = data_train,
drop.unused.levels = FALSE,
maxit = 30)

# Fill in missing observations in data_train so the size of the dataset is correct when
# building the initial JAGS model.
Expand Down Expand Up @@ -966,7 +989,19 @@ mvgam = function(formula,
model_file[grep('eta <- X %*% b', model_file, fixed = TRUE)] <-
"eta <- X %*% b + offset"
if(!missing(data_test) & !prior_simulation){
ss_jagam$jags.data$offset <- c(ss_jagam$jags.data$offset, data_test$offset)

get_offset <- function(model) {
nm1 <- names(attributes(model$terms)$dataClasses)
if('(offset)' %in% nm1) {
deparse(as.list(model$call)$offset)
} else {

sub("offset\\((.*)\\)$", "\\1", grep('offset', nm1, value = TRUE))
}
}

ss_jagam$jags.data$offset <- c(ss_jagam$jags.data$offset,
data_test[[get_offset(ss_gam)]])
}
}

Expand Down Expand Up @@ -1351,6 +1386,42 @@ mvgam = function(formula,
trend_model = trend_model)
vectorised$model_file <- trend_map_setup$model_file
vectorised$model_data <- trend_map_setup$model_data

if(trend_model %in% c('RW', 'AR1', 'AR2', 'AR3')){
param <- c(param, 'sigma')
}

# If trend formula specified, add the predictors for the trend models
if(!missing(trend_formula)){
trend_pred_setup <- add_trend_predictors(
trend_formula = trend_formula,
trend_map = trend_map,
trend_model = trend_model,
data_train = data_train,
data_test = if(missing(data_test)){
NULL
} else {
data_test
},
model_file = vectorised$model_file,
model_data = vectorised$model_data,
drift = drift)

vectorised$model_file <- trend_pred_setup$model_file
vectorised$model_data <- trend_pred_setup$model_data
trend_mgcv_model <- trend_pred_setup$trend_mgcv_model

param <- c(param, 'b_trend')

if(trend_pred_setup$trend_smooths_included){
param <- c(param, 'rho_trend')
}

if(trend_pred_setup$trend_random_included){
param <- c(param, 'mu_raw_trend', 'sigma_raw_trend')
}

}
}

} else {
Expand Down Expand Up @@ -1388,6 +1459,11 @@ mvgam = function(formula,
if(!run_model){
unlink('base_gam.txt')
output <- structure(list(call = orig_formula,
trend_call = if(!missing(trend_formula)){
trend_formula
} else {
NULL
},
family = family_char,
trend_model = trend_model,
drift = drift,
Expand All @@ -1405,6 +1481,11 @@ mvgam = function(formula,
inits = inits,
monitor_pars = param,
mgcv_model = ss_gam,
trend_mgcv_model = if(!missing(trend_formula)){
trend_mgcv_model
} else {
NULL
},
sp_names = rho_names,
ytimes = ytimes,
use_lv = use_lv,
Expand Down Expand Up @@ -1666,10 +1747,24 @@ mvgam = function(formula,
p <- mcmc_summary(out_gam_mod, 'b')[,c(4)]
names(p) <- names(ss_gam$coefficients)
ss_gam$coefficients <- p

if(!missing(trend_formula)){
V <- cov(mcmc_chains(out_gam_mod, 'b_trend'))
trend_mgcv_model$Ve <- trend_mgcv_model$Vp <- trend_mgcv_model$Vc <- V

p <- mcmc_summary(out_gam_mod, 'b_trend')[,c(4)]
names(p) <- names(trend_mgcv_model$coefficients)
trend_mgcv_model$coefficients <- p
}
}

#### Return the output as class mvgam ####
output <- structure(list(call = orig_formula,
trend_call = if(!missing(trend_formula)){
trend_formula
} else {
NULL
},
family = family_char,
trend_model = trend_model,
drift = drift,
Expand Down Expand Up @@ -1697,6 +1792,11 @@ mvgam = function(formula,
},
sp_names = rho_names,
mgcv_model = ss_gam,
trend_mgcv_model = if(!missing(trend_formula)){
trend_mgcv_model
} else {
NULL
},
ytimes = ytimes,
resids = series_resids,
use_lv = use_lv,
Expand Down
Loading