Skip to content

Commit 5bf88c6

Browse files
author
Nicholas Clark
committed
catch normal glm for older versions of Stan
1 parent e546475 commit 5bf88c6

File tree

6 files changed

+67
-42
lines changed

6 files changed

+67
-42
lines changed

DESCRIPTION

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ Depends:
2323
insight (>= 0.19.1),
2424
methods
2525
Imports:
26-
rstan (>= 2.19.2),
26+
rstan (>= 2.29.0),
2727
posterior (>= 1.0.0),
2828
loo (>= 2.3.1),
2929
rstantools (>= 2.1.1),
@@ -46,7 +46,7 @@ LazyData: true
4646
Roxygen: list(markdown = TRUE)
4747
RoxygenNote: 7.2.3
4848
Suggests:
49-
cmdstanr (>= 0.4.0),
49+
cmdstanr (>= 0.5.0),
5050
tweedie,
5151
splines2,
5252
extraDistr,

R/mvgam.R

+15-7
Original file line numberDiff line numberDiff line change
@@ -1674,16 +1674,24 @@ mvgam = function(formula,
16741674

16751675
# Auto-format the model file
16761676
if(autoformat){
1677-
if(requireNamespace('cmdstanr') & cmdstanr::cmdstan_version() >= "2.29.0") {
1677+
if(requireNamespace('cmdstanr') &
1678+
cmdstanr::cmdstan_version() >= "2.29.0") {
16781679
vectorised$model_file <- .autoformat(vectorised$model_file,
1679-
overwrite_file = FALSE)
1680+
overwrite_file = FALSE,
1681+
backend = 'cmdstanr')
16801682
}
16811683
vectorised$model_file <- readLines(textConnection(vectorised$model_file),
16821684
n = -1)
16831685
}
16841686

16851687
} else {
1686-
1688+
if(autoformat){
1689+
vectorised$model_file <- .autoformat(vectorised$model_file,
1690+
overwrite_file = FALSE,
1691+
backend = 'rstan')
1692+
vectorised$model_file <- readLines(textConnection(vectorised$model_file),
1693+
n = -1)
1694+
}
16871695
# Replace new syntax if this is an older version of Stan
16881696
if(rstan::stan_version() < "2.26"){
16891697
warning('Your version of rstan is out of date. Some features of mvgam may not work')
@@ -1842,8 +1850,8 @@ mvgam = function(formula,
18421850
if(use_cmdstan){
18431851
message('Using cmdstanr as the backend')
18441852
message()
1845-
if(cmdstanr::cmdstan_version() < "2.24.0"){
1846-
warning('Your version of Cmdstan is < 2.24.0; some mvgam models may not work properly!')
1853+
if(cmdstanr::cmdstan_version() < "2.26.0"){
1854+
warning('Your version of Cmdstan is < 2.26.0; some mvgam models may not work properly!')
18471855
}
18481856

18491857
if(algorithm == 'pathfinder'){
@@ -1995,8 +2003,8 @@ mvgam = function(formula,
19952003
message('Using rstan as the backend')
19962004
message()
19972005

1998-
if(rstan::stan_version() < "2.24.0"){
1999-
warning('Your version of Stan is < 2.24.0; some mvgam models may not work properly!')
2006+
if(rstan::stan_version() < "2.26.0"){
2007+
warning('Your version of Stan is < 2.26.0; some mvgam models may not work properly!')
20002008
}
20012009

20022010
if(algorithm == 'pathfinder'){

R/stan_utils.R

+50-33
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ remove_likelihood = function(model_file){
3636
}
3737

3838
#' @noRd
39-
.autoformat <- function(stan_file, overwrite_file = TRUE){
39+
.autoformat <- function(stan_file, overwrite_file = TRUE,
40+
backend = 'cmdstanr'){
4041

4142
# No need to fill lv_coefs in each iteration if this is a
4243
# trend_formula model
@@ -52,27 +53,6 @@ remove_likelihood = function(model_file){
5253
stan_file, fixed = TRUE)] <-
5354
'trend[i, s] = dot_product(Z[s,], LV[i,]);'
5455

55-
# if(any(grepl('// derived latent states',
56-
# stan_file, fixed = TRUE))){
57-
# starts <- grep('// derived latent states',
58-
# stan_file, fixed = TRUE) + 1
59-
# ends <- starts + 4
60-
# stan_file <- stan_file[-c(starts:ends)]
61-
# stan_file[grep('// derived latent states',
62-
# stan_file, fixed = TRUE)] <-
63-
# paste0('// derived latent states\n',
64-
# "trend = LV * Z';")
65-
# } else {
66-
# starts <- grep('// derived latent trends',
67-
# stan_file, fixed = TRUE) + 1
68-
# ends <- starts + 4
69-
# stan_file <- stan_file[-c(starts:ends)]
70-
# stan_file[grep('// derived latent trends',
71-
# stan_file, fixed = TRUE)] <-
72-
# paste0('// derived latent trends\n',
73-
# "trend = LV * Z';")
74-
# }
75-
7656
stan_file[grep('// posterior predictions',
7757
stan_file, fixed = TRUE)-1] <-
7858
paste0(stan_file[grep('// posterior predictions',
@@ -81,13 +61,43 @@ remove_likelihood = function(model_file){
8161
'matrix[n_series, n_lv] lv_coefs = Z;')
8262
stan_file <- readLines(textConnection(stan_file), n = -1)
8363
}
64+
65+
if(backend == 'rstan' & rstan::stan_version() < '2.29.0'){
66+
# normal_id_glm became available in 2.29.0; this needs to be replaced
67+
# with the older non-glm version
68+
if(any(grepl('normal_id_glm',
69+
stan_file, fixed = TRUE))){
70+
if(any(grepl("flat_ys ~ normal_id_glm(flat_xs,",
71+
stan_file, fixed = TRUE))){
72+
start <- grep("flat_ys ~ normal_id_glm(flat_xs,",
73+
stan_file, fixed = TRUE)
74+
end <- start + 2
75+
stan_file <- stan_file[-c((start + 1):(start + 2))]
76+
stan_file[start] <- 'flat_ys ~ normal(flat_xs * b, flat_sigma_obs);'
77+
}
78+
}
79+
}
80+
8481
# Old ways of specifying arrays have been converted to errors in
85-
# the latest version of Cmdstan (2.34.0); this coincides with
82+
# the latest version of Cmdstan (2.32.0); this coincides with
8683
# a decision to stop automatically replacing these deprecations with
8784
# the canonicalizer, so we have no choice but to replace the old
88-
# syntax with this ugly bit of code:
89-
if(requireNamespace('cmdstanr') & cmdstanr::cmdstan_version() >= "2.33.0"){
85+
# syntax with this ugly bit of code
9086

87+
# rstan dependency in Description should mean that updates should
88+
# always happen (mvgam depends on rstan >= 2.29.0)
89+
update_code <- TRUE
90+
91+
# Tougher if using cmdstanr
92+
if(backend == 'cmdstanr'){
93+
if(cmdstanr::cmdstan_version() < "2.32.0"){
94+
# If the autoformat options from cmdstanr are available,
95+
# make use of them to update any deprecated array syntax
96+
update_code <- FALSE
97+
}
98+
}
99+
100+
if(update_code){
91101
# Data modifications
92102
stan_file[grep("int<lower=0> ytimes[n, n_series]; // time-ordered matrix (which col in X belongs to each [time, series] observation?)",
93103
stan_file, fixed = TRUE)] <-
@@ -426,14 +436,21 @@ remove_likelihood = function(model_file){
426436
}
427437
}
428438

429-
stan_file <- cmdstanr::write_stan_file(stan_file)
430-
cmdstan_mod <- cmdstanr::cmdstan_model(stan_file, compile = FALSE)
431-
out <- utils::capture.output(
432-
cmdstan_mod$format(
433-
max_line_length = 80,
434-
canonicalize = TRUE,
435-
overwrite_file = overwrite_file, backup = FALSE))
436-
paste0(out, collapse = "\n")
439+
if(backend == 'rstan'){
440+
options(stanc.allow_optimizations = TRUE,
441+
stanc.auto_format = TRUE)
442+
out <- rstan::stanc(model_code = stan_file)$model_code
443+
} else {
444+
stan_file <- cmdstanr::write_stan_file(stan_file)
445+
cmdstan_mod <- cmdstanr::cmdstan_model(stan_file, compile = FALSE)
446+
out <- utils::capture.output(
447+
cmdstan_mod$format(
448+
max_line_length = 80,
449+
canonicalize = TRUE,
450+
overwrite_file = overwrite_file, backup = FALSE))
451+
out <- paste0(out, collapse = "\n")
452+
}
453+
return(out)
437454
}
438455

439456
#### Replacement for MCMCvis functions to remove dependence on rstan for working

src/RcppExports.o

7.46 KB
Binary file not shown.

src/mvgam.dll

1 KB
Binary file not shown.

src/trend_funs.o

1.62 KB
Binary file not shown.

0 commit comments

Comments
 (0)