Skip to content

Commit

Permalink
2 chains
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholas Clark committed Apr 30, 2024
1 parent 47bde90 commit 226aadd
Show file tree
Hide file tree
Showing 298 changed files with 6,758 additions and 1,018 deletions.
4 changes: 3 additions & 1 deletion R/RW.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@
#' trend_model = CAR(),
#' data = dat,
#' family = gaussian(),
#' run_model = TRUE)
#' burnin = 300,
#' samples = 300,
#' chains = 2)
#'
#'# View usual summaries and plots
#'summary(mod)
Expand Down
4 changes: 3 additions & 1 deletion R/as.data.frame.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@
#'mod1 <- mvgam(y ~ s(season, bs = 'cc'),
#' trend_model = 'AR1',
#' data = sim$data_train,
#' family = Gamma())
#' family = Gamma(),
#' chains = 2,
#' samples = 300)
#'beta_draws_df <- as.data.frame(mod1, variable = 'betas')
#'head(beta_draws_df)
#'str(beta_draws_df)
Expand Down
93 changes: 93 additions & 0 deletions R/data_grids.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#' Get data objects into correct order in case it is not already
#'@noRd
sort_data = function(data, series_time = FALSE){
if(inherits(data, 'list')){
data_arranged <- data
if(series_time){
temp_dat = data.frame(time = data$index..time..index,
series = data$series) %>%
dplyr::mutate(index = dplyr::row_number()) %>%
dplyr::arrange(series, time)
} else {
temp_dat = data.frame(time = data$index..time..index,
series = data$series) %>%
dplyr::mutate(index = dplyr::row_number()) %>%
dplyr::arrange(time, series)
}

data_arranged <- lapply(data, function(x){
if(is.matrix(x)){
matrix(x[temp_dat$index,], ncol = NCOL(x))
} else {
x[temp_dat$index]
}
})
names(data_arranged) <- names(data)
} else {
if(series_time){
data_arranged <- data %>%
dplyr::arrange(series, index..time..index)
} else {
data_arranged <- data %>%
dplyr::arrange(index..time..index, series)
}
}

return(data_arranged)
}

#' Create prediction grids, mostly for simple plotting functions
#'@noRd
data_grid = function(..., newdata){

dots <- list(...)
vars <- names(dots)

# Validate that vars exist in supplied data
for(i in seq_along(vars)){
if(!exists(vars[i], newdata)){
stop(paste0('Variable ', vars[i], ' not found in newdata'),
call. = FALSE)
}
}

# Create sample dummy dataframe to get the prediction grid, ensuring
# factors are preserved
newdat_grid <- data.frame(do.call(cbind.data.frame,
lapply(vars, function(x){
newdata[[x]]
})))
colnames(newdat_grid) <- vars

# Use the supplied conditions for making the datagrid
newdat_grid <- marginaleffects::datagrid(..., newdata = newdat_grid)

# Now replicate the first observation for all other variables
if(inherits(newdata, 'list')){
newdat_full <- lapply(seq_along(newdata), function(x){
if(names(newdata)[x] %in% vars){
newdat_grid[[names(newdata)[x]]]
} else {
if(is.matrix(newdata[[x]])){
t(replicate(NROW(newdat_grid), newdata[[x]][1, ]))
} else {
if(is.factor(newdata[[x]])){
factor(rep(newdata[[x]][1], NROW(newdat_grid)),
levels = levels(newdata[[x]]))
} else {
rep(newdata[[x]][1], NROW(newdat_grid))
}
}
}
})
names(newdat_full) <- names(newdata)
} else {
newdat_full <-
dplyr::bind_cols(newdat_grid,
data.frame(newdata %>%
dplyr::select(!vars) %>%
dplyr::slice_head(n = 1)))
}

return(newdat_full)
}
11 changes: 6 additions & 5 deletions R/dynamic.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,12 @@
#'
#'# Fit a model using the dynamic function
#'mod <- mvgam(out ~
#' # mis-specify the length scale slightly as this
#' # won't be known in practice
#' dynamic(predictor, rho = 8, stationary = TRUE),
#' family = gaussian(),
#' data = data_train)
#' # mis-specify the length scale slightly as this
#' # won't be known in practice
#' dynamic(predictor, rho = 8, stationary = TRUE),
#' family = gaussian(),
#' data = data_train,
#' chains = 2)
#'
#'# Inspect the summary
#'summary(mod)
Expand Down
6 changes: 4 additions & 2 deletions R/evaluate_mvgams.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,16 @@
#' trend_model = AR(p = 2),
#' family = poisson(),
#' data = dat$data_train,
#' newdata = dat$data_test)
#' newdata = dat$data_test,
#' chains = 2)
#'
#'# Fit a less appropriate model
#'mod_rw <- mvgam(y ~ s(season, bs = 'cc'),
#' trend_model = RW(),
#' family = poisson(),
#' data = dat$data_train,
#' newdata = dat$data_test)
#' newdata = dat$data_test,
#' chains = 2)
#'
#'# Compare Discrete Ranked Probability Scores for the testing period
#'fc_ar2 <- forecast(mod_ar2)
Expand Down
3 changes: 2 additions & 1 deletion R/families.R
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ student_t = function(link = 'identity'){
#' priors = c(prior(std_normal(), class = b),
#' prior(normal(1, 1.5), class = Intercept_trend)),
#' burnin = 300,
#' samples = 300)
#' samples = 300,
#' chains = 2)
#'
#' # The usual diagnostics
#' summary(mod)
Expand Down
3 changes: 2 additions & 1 deletion R/forecast.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ forecast <- function(object, ...){
#' trend_model = AR(),
#' data = simdat$data_train,
#' burnin = 300,
#' samples = 300)
#' samples = 300,
#' chains = 2)
#'
#' # Hindcasts on response scale
#' hc <- hindcast(mod)
Expand Down
3 changes: 2 additions & 1 deletion R/hindcast.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ hindcast <- function(object, ...){
#' trend_model = AR(),
#' data = simdat$data_train,
#' burnin = 300,
#' samples = 300)
#' samples = 300,
#' chains = 2)
#'
#' # Hindcasts on response scale
#' hc <- hindcast(mod)
Expand Down
3 changes: 2 additions & 1 deletion R/index-mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ NULL
#' trend_model = AR(),
#' data = simdat$data_train,
#' burnin = 300,
#' samples = 300)
#' samples = 300,
#' chains = 2)
#' variables(mod)
#' }
#' @export
Expand Down
6 changes: 4 additions & 2 deletions R/lfo_cv.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@
#' data = dat$data_train,
#' newdata = dat$data_test,
#' burnin = 300,
#' samples = 300)
#' samples = 300,
#' chains = 2)
#'
#'# Fit a less appropriate model
#'mod_rw <- mvgam(y ~ s(season, bs = 'cc', k = 6),
Expand All @@ -69,7 +70,8 @@
#' data = dat$data_train,
#' newdata = dat$data_test,
#' burnin = 300,
#' samples = 300)
#' samples = 300,
#' chains = 2)
#'
#'# Compare Discrete Ranked Probability Scores for the testing period
#'fc_ar2 <- forecast(mod_ar2)
Expand Down
9 changes: 5 additions & 4 deletions R/logLik.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
#' # Simulate some data and fit a model
#' simdat <- sim_mvgam(n_series = 1, trend_model = 'AR1')
#' mod <- mvgam(y ~ s(season, bs = 'cc', k = 6),
#' trend_model = AR(),
#' data = simdat$data_train,
#' burnin = 300,
#' samples = 300)
#' trend_model = AR(),
#' data = simdat$data_train,
#' burnin = 300,
#' samples = 300,
#' chains = 2)
#'
#'# Extract logLikelihood values
#'lls <- logLik(mod)
Expand Down
9 changes: 6 additions & 3 deletions R/loo.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
#' simdat$data_test),
#' family = gaussian(),
#' burnin = 300,
#' samples = 300)
#' samples = 300,
#' chains = 2)
#'plot(mod1, type = 'smooths')
#'loo(mod1)
#'
Expand All @@ -34,15 +35,17 @@
#' s(season, series, bs = 'fs',
#' xt = list(bs = 'cc'), k = 4),
#' burnin = 300,
#' samples = 300)
#' samples = 300,
#' chains = 2)
#'plot(mod2, type = 'smooths')
#'loo(mod2)
#'
#'# Now add AR1 dynamic errors to mod2
#'mod3 <- update(mod2,
#' trend_model = 'AR1',
#' burnin = 300,
#' samples = 300)
#' samples = 300,
#' chains = 2)
#'plot(mod3, type = 'smooths')
#'plot(mod3, type = 'trend')
#'loo(mod3)
Expand Down
3 changes: 2 additions & 1 deletion R/lv_correlations.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
#' n_lv = 2,
#' data = simdat$data_train,
#' burnin = 300,
#' samples = 300)
#' samples = 300,
#' chains = 2)
#'lvcors <- lv_correlations(mod)
#'names(lvcors)
#'lapply(lvcors, class)
Expand Down
3 changes: 2 additions & 1 deletion R/mcmc_plot.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
#' trend_model = AR(),
#' data = simdat$data_train,
#' burnin = 300,
#' samples = 300)
#' samples = 300,
#' chains = 2)
#' mcmc_plot(mod)
#' mcmc_plot(mod, type = 'neff_hist')
#' mcmc_plot(mod, variable = 'betas', type = 'areas')
Expand Down
6 changes: 4 additions & 2 deletions R/monotonic.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@
#' data = mod_data,
#' family = gaussian(),
#' burnin = 300,
#' samples = 300)
#' samples = 300,
#' chains = 2)
#'
#' plot_predictions(mod2,
#' by = 'x',
Expand Down Expand Up @@ -108,7 +109,8 @@
#' data = mod_data,
#' family = gaussian(),
#' burnin = 300,
#' samples = 300)
#' samples = 300,
#' chains = 2)
#'
#' # Visualise the different monotonic functions
#' plot_predictions(mod, condition = c('x', 'fac', 'fac'),
Expand Down
18 changes: 12 additions & 6 deletions R/mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,8 @@
#' trend_model = RW(),
#' family = poisson(),
#' burnin = 300,
#' samples = 300)
#' samples = 300,
#' chains = 2)
#'
#' # Extract the model summary
#' summary(mod1)
Expand Down Expand Up @@ -399,7 +400,8 @@
#' data = mod_data,
#' return_model_data = TRUE,
#' burnin = 300,
#' samples = 300)
#' samples = 300,
#' chains = 2)
#'
#' # The mapping matrix is now supplied as data to the model in the 'Z' element
#' mod1$model_data$Z
Expand Down Expand Up @@ -443,7 +445,8 @@
#' data = data_train,
#' newdata = data_test,
#' burnin = 300,
#' samples = 300)
#' samples = 300,
#' chains = 2)
#'
#' # Inspect the model summary, forecast and time-varying coefficient distribution
#' summary(mod)
Expand Down Expand Up @@ -475,7 +478,8 @@
#' data = dat$data_train,
#' trend_model = 'None',
#' burnin = 300,
#' samples = 300)
#' samples = 300,
#' chains = 2)
#'
#' # Inspect the model file to see the modification to the linear predictor
#' # (eta)
Expand Down Expand Up @@ -569,7 +573,8 @@
#' family = poisson(),
#' data = mod_data,
#' burnin = 300,
#' samples = 300)
#' samples = 300,
#' chains = 2)
#' summary(mod)
#'
#' # Plot the posterior hindcast
Expand Down Expand Up @@ -610,7 +615,8 @@
#' family = binomial(),
#' data = dat,
#' burnin = 300,
#' samples = 300)
#' samples = 300,
#' chains = 2)
#' summary(mod)
#' pp_check(mod, type = "bars_grouped",
#' group = "series", ndraws = 50)
Expand Down
3 changes: 2 additions & 1 deletion R/mvgam_diagnostics.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
#' trend_model = AR(),
#' data = simdat$data_train,
#' burnin = 300,
#' samples = 300)
#' samples = 300,
#' chains = 2)
#' np <- nuts_params(mod)
#' head(np)
#'
Expand Down
5 changes: 3 additions & 2 deletions R/pairs.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
#' \donttest{
#' simdat <- sim_mvgam(n_series = 1, trend_model = 'AR1')
#' mod <- mvgam(y ~ s(season, bs = 'cc'),
#' trend_model = AR(),
#' data = simdat$data_train)
#' trend_model = AR(),
#' data = simdat$data_train,
#' chains = 2)
#' pairs(mod)
#' pairs(mod, variable = c('ar1', 'sigma'), regex = TRUE)
#' }
Expand Down
Loading

0 comments on commit 226aadd

Please sign in to comment.