Skip to content

Commit 5963321

Browse files
author
Nicholas Clark
committed
test VARs more broadly
1 parent c06a9cc commit 5963321

File tree

6 files changed

+72
-102
lines changed

6 files changed

+72
-102
lines changed

R/stan_utils.R

+7-101
Original file line numberDiff line numberDiff line change
@@ -1133,21 +1133,6 @@ mcmc_chains = function(object,
11331133
#for rstanarm/brms objects - set to NULL by default
11341134
sp_names <- NULL
11351135

1136-
#if from R2jags::jags.parallel
1137-
if (methods::is(object, 'rjags.parallel'))
1138-
{
1139-
x <- object$BUGSoutput
1140-
mclist <- vector('list', x$n.chains)
1141-
mclis <- vector('list', x$n.chains)
1142-
ord <- dimnames(x$sims.array)[[3]]
1143-
for (i in 1:x$n.chains)
1144-
{
1145-
tmp1 <- x$sims.array[, i, ord]
1146-
mclis[[i]] <- coda::mcmc(tmp1, thin = x$n.thin)
1147-
}
1148-
object <- coda::as.mcmc.list(mclis)
1149-
}
1150-
11511136
#if mcmc object (from nimble) - convert to mcmc.list
11521137
if (methods::is(object, 'mcmc'))
11531138
{
@@ -1160,13 +1145,6 @@ mcmc_chains = function(object,
11601145
object <- coda::mcmc.list(lapply(object, function(x) coda::mcmc(x)))
11611146
}
11621147

1163-
#if from rstanarm::stan_glm
1164-
if (methods::is(object, 'stanreg'))
1165-
{
1166-
object <- object$stanfit
1167-
sp_names <- object@sim$fnames_oi
1168-
}
1169-
11701148
if (coda::is.mcmc.list(object) != TRUE &
11711149
!methods::is(object, 'matrix') &
11721150
!methods::is(object, 'mcmc') &
@@ -1180,25 +1158,6 @@ mcmc_chains = function(object,
11801158
stop('Invalid object type. Input must be stanfit object (rstan), CmdStanMCMC object (cmdstanr), stanreg object (rstanarm), brmsfit object (brms), mcmc.list object (coda/rjags), mcmc object (coda/nimble), list object (nimble), rjags object (R2jags), jagsUI object (jagsUI), or matrix with MCMC chains.')
11811159
}
11821160

1183-
#if from brms::brm
1184-
if (methods::is(object, 'brmsfit'))
1185-
{
1186-
#extract stanfit portion of object
1187-
object <- object$fit
1188-
#Stan names
1189-
sp_names_p <- names(object@sim$samples[[1]])
1190-
#remove b_ and r_
1191-
st_nm <- substr(sp_names_p, start = 1, stop = 2)
1192-
sp_names <- rep(NA, length(sp_names_p))
1193-
b_idx <- which(st_nm == 'b_')
1194-
r_idx <- which(st_nm == 'r_')
1195-
ot_idx <- which(st_nm != 'b_' & st_nm != 'r_')
1196-
#fill names vec with b_ and r_ removed
1197-
sp_names[b_idx] <- gsub('b_', '', sp_names_p[b_idx])
1198-
sp_names[r_idx] <- gsub('r_', '', sp_names_p[r_idx])
1199-
sp_names[ot_idx] <- sp_names_p[ot_idx]
1200-
}
1201-
12021161
#NAME SORTING BLOCK
12031162
if (methods::is(object, 'stanfit'))
12041163
{
@@ -4271,66 +4230,13 @@ check_rhat <- function(fit, quiet=FALSE, fit_summary) {
42714230
#' @param quiet Logical (verbose or not?)
42724231
#' @details Utility function written by Michael Betancourt (https://betanalpha.github.io/)
42734232
#' @noRd
4274-
check_all_diagnostics <- function(fit, quiet=FALSE, max_treedepth = 10) {
4233+
check_all_diagnostics <- function(fit, max_treedepth = 10) {
42754234
sampler_params <- rstan::get_sampler_params(fit, inc_warmup=FALSE)
42764235
fit_summary <- rstan::summary(fit, probs = c(0.5))$summary
4277-
if (!quiet) {
4278-
check_n_eff(fit, fit_summary = fit_summary)
4279-
check_rhat(fit, fit_summary = fit_summary)
4280-
check_div(fit, sampler_params = sampler_params)
4281-
check_treedepth(fit, max_depth = max_treedepth,
4282-
sampler_params = sampler_params)
4283-
check_energy(fit, sampler_params = sampler_params)
4284-
} else {
4285-
warning_code <- 0
4286-
4287-
if (!check_n_eff(fit, quiet=TRUE, fit_summary = fit_summary))
4288-
warning_code <- bitwOr(warning_code, bitwShiftL(1, 0))
4289-
if (!check_rhat(fit, quiet=TRUE, fit_summary = fit_summary))
4290-
warning_code <- bitwOr(warning_code, bitwShiftL(1, 1))
4291-
if (!check_div(fit, quiet=TRUE, sampler_params = sampler_params))
4292-
warning_code <- bitwOr(warning_code, bitwShiftL(1, 2))
4293-
if (!check_treedepth(fit, quiet=TRUE, sampler_params = sampler_params))
4294-
warning_code <- bitwOr(warning_code, bitwShiftL(1, 3))
4295-
if (!check_energy(fit, quiet=TRUE, sampler_params = sampler_params))
4296-
warning_code <- bitwOr(warning_code, bitwShiftL(1, 4))
4297-
4298-
return(warning_code)
4299-
}
4300-
}
4301-
4302-
#' Parse warnings
4303-
#' @param warning_code Type of warning code to generate
4304-
#' @details Utility function written by Michael Betancourt (https://betanalpha.github.io/)
4305-
#' @noRd
4306-
parse_warning_code <- function(warning_code) {
4307-
if (bitwAnd(warning_code, bitwShiftL(1, 0)))
4308-
cat("n_eff / iteration warning")
4309-
if (bitwAnd(warning_code, bitwShiftL(1, 1)))
4310-
cat("rhat warning")
4311-
if (bitwAnd(warning_code, bitwShiftL(1, 2)))
4312-
cat("divergence warning")
4313-
if (bitwAnd(warning_code, bitwShiftL(1, 3)))
4314-
cat("treedepth warning")
4315-
if (bitwAnd(warning_code, bitwShiftL(1, 4)))
4316-
cat("energy warning")
4317-
}
4318-
4319-
#' Return parameter arrays separated into divergent and non-divergent transitions
4320-
#' @param fit A stanfit object
4321-
#' @details Utility function written by Michael Betancourt (https://betanalpha.github.io/)
4322-
#' @noRd
4323-
partition_div <- function(fit) {
4324-
nom_params <- rstan::extract(fit, permuted=FALSE)
4325-
n_chains <- dim(nom_params)[2]
4326-
params <- as.data.frame(do.call(rbind, lapply(1:n_chains, function(n) nom_params[,n,])))
4327-
4328-
sampler_params <- rstan::get_sampler_params(fit, inc_warmup=FALSE)
4329-
divergent <- do.call(rbind, sampler_params)[,'divergent__']
4330-
params$divergent <- divergent
4331-
4332-
div_params <- params[params$divergent == 1,]
4333-
nondiv_params <- params[params$divergent == 0,]
4334-
4335-
return(list(div_params, nondiv_params))
4236+
check_n_eff(fit, fit_summary = fit_summary)
4237+
check_rhat(fit, fit_summary = fit_summary)
4238+
check_div(fit, sampler_params = sampler_params)
4239+
check_treedepth(fit, max_depth = max_treedepth,
4240+
sampler_params = sampler_params)
4241+
check_energy(fit, sampler_params = sampler_params)
43364242
}

tests/testthat/Rplots.pdf

205 KB
Binary file not shown.

tests/testthat/test-RW.R

+17-1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,22 @@ test_that("ma and cor options should work for trends other than VAR", {
6565
})
6666

6767
test_that("VARMAs are set up correctly", {
68+
var <- mvgam(y ~ s(series, bs = 're') +
69+
s(season, bs = 'cc') - 1,
70+
trend_model = VAR(),
71+
data = gaus_data$data_train,
72+
family = gaussian(),
73+
run_model = FALSE)
74+
expect_true(inherits(var, 'mvgam_prefit'))
75+
76+
var <- mvgam(y ~ s(series, bs = 're') +
77+
gp(time, c = 5/4, k = 20) - 1,
78+
trend_model = VAR(),
79+
data = gaus_data$data_train,
80+
family = gaussian(),
81+
run_model = FALSE)
82+
expect_true(inherits(var, 'mvgam_prefit'))
83+
6884
varma <- mvgam(y ~ s(series, bs = 're') +
6985
s(season, bs = 'cc') - 1,
7086
trend_model = 'VARMA',
@@ -76,7 +92,7 @@ test_that("VARMAs are set up correctly", {
7692
varma$model_file, fixed = TRUE)))
7793

7894
varma <- mvgam(y ~ s(series, bs = 're'),
79-
trend_formula = ~ s(season, bs = 'cc'),
95+
trend_formula = ~ gp(time, by = trend, c = 5/4),
8096
trend_model = VAR(ma = TRUE),
8197
data = gaus_data$data_train,
8298
family = gaussian(),

tests/testthat/test-example_processing.R

+19
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,25 @@ test_that("fitted() gives correct dimensions", {
1414
NROW(fitted(mvgam:::mvgam_example4)))
1515
})
1616

17+
test_that("residuals() gives correct dimensions", {
18+
expect_equal(NROW(mvgam:::mvgam_examp_dat$data_train),
19+
NROW(residuals(mvgam:::mvgam_example1)))
20+
21+
expect_equal(NROW(mvgam:::mvgam_examp_dat$data_train),
22+
NROW(residuals(mvgam:::mvgam_example2)))
23+
24+
expect_equal(NROW(mvgam:::mvgam_examp_dat$data_train),
25+
NROW(residuals(mvgam:::mvgam_example3)))
26+
27+
expect_equal(NROW(mvgam:::mvgam_examp_dat$data_train),
28+
NROW(residuals(mvgam:::mvgam_example4,
29+
robust = TRUE)))
30+
31+
expect_equal(NROW(mvgam:::mvgam_examp_dat$data_train),
32+
NROW(residuals(mvgam:::mvgam_example5,
33+
summary = FALSE)))
34+
})
35+
1736
test_that("variable extraction works correctly", {
1837
expect_true(inherits(as.matrix(mvgam:::mvgam_example4,
1938
'A', regex = TRUE),

tests/testthat/test-mvgam-methods.R

+22
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,28 @@ test_that("inverse links working", {
66
expect_true(is(mvgam:::family_invlinks('beta_binomial'), 'function'))
77
})
88

9+
test_that("series_to_mvgam working", {
10+
skip_on_cran()
11+
data("sunspots")
12+
series <- cbind(sunspots, sunspots)
13+
colnames(series) <- c('blood', 'bone')
14+
expect_true(inherits(series_to_mvgam(series,
15+
frequency(series),
16+
0.85),
17+
'list'))
18+
19+
# An xts object example
20+
dates <- seq(as.Date("2001-05-01"), length=30, by="quarter")
21+
data <- cbind(c(gas = rpois(30, cumprod(1+rnorm(30, mean = 0.01, sd = 0.001)))),
22+
c(oil = rpois(30, cumprod(1+rnorm(30, mean = 0.01, sd = 0.001)))))
23+
series <- xts::xts(x = data, order.by = dates)
24+
colnames(series) <- c('gas', 'oil')
25+
expect_true(inherits(series_to_mvgam(series,
26+
freq = 4,
27+
train_prop = 0.85),
28+
'list'))
29+
})
30+
931
test_that("add_residuals working properly", {
1032
mod <- mvgam:::mvgam_example1
1133
oldresids <- mod$resids

tests/testthat/test-mvgam.R

+7
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ test_that("JAGS setups work", {
6363
run_model = FALSE)
6464
expect_true(inherits(mod, 'mvgam_prefit'))
6565

66+
expect_true(inherits(get_mvgam_priors(y ~ s(season),
67+
trend_model = 'RW',
68+
drift = TRUE,
69+
data = simdat$data_train,
70+
family = tweedie(),
71+
use_stan = FALSE),
72+
'data.frame'))
6673
mod <- mvgam(y ~ s(season),
6774
trend_model = 'RW',
6875
drift = TRUE,

0 commit comments

Comments
 (0)