Skip to content

Commit 678d003

Browse files
author
Nicholas Clark
committed
array updates for newest Cmdstan; more piecewise tests
1 parent d4e89f2 commit 678d003

File tree

144 files changed

+931
-555
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

144 files changed

+931
-555
lines changed

R/add_nmixture.R

+27-4
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,36 @@ add_nmixture = function(model_file,
1515
stop('Max abundances must be supplied as a variable named "cap" for N-mixture models',
1616
call. = FALSE)
1717
}
18-
cap <- data_train$cap
18+
19+
if(inherits(data_train, 'data.frame')){
20+
cap = data_train %>%
21+
dplyr::arrange(series, time) %>%
22+
dplyr::pull(cap)
23+
} else {
24+
cap = data.frame(series = data_train$series,
25+
cap = data_train$cap,
26+
time = data_train$time)%>%
27+
dplyr::arrange(series, time) %>%
28+
dplyr::pull(cap)
29+
}
30+
1931
if(!is.null(data_test)){
2032
if(!(exists('cap', where = data_test))) {
2133
stop('Max abundances must be supplied in test data as a variable named "cap" for N-mixture models',
2234
call. = FALSE)
2335
}
24-
cap <- c(cap, data_test$cap)
36+
if(inherits(data_test, 'data.frame')){
37+
captest = data_test %>%
38+
dplyr::arrange(series, time) %>%
39+
dplyr::pull(cap)
40+
} else {
41+
captest = data.frame(series = data_test$series,
42+
cap = data_test$cap,
43+
time = data_test$time)%>%
44+
dplyr::arrange(series, time) %>%
45+
dplyr::pull(cap)
46+
}
47+
cap <- c(cap, captest)
2548
}
2649

2750
validate_pos_integers(cap)
@@ -33,7 +56,7 @@ add_nmixture = function(model_file,
3356

3457
model_data$cap <- as.vector(cap)
3558

36-
if(any(model_data$cap < model_data$y)){
59+
if(any(model_data$cap[model_data$obs_ind] < model_data$flat_ys)){
3760
stop(paste0('Some "cap" terms are < the observed counts. This is not allowed'),
3861
call. = FALSE)
3962
}
@@ -243,7 +266,7 @@ add_nmixture = function(model_file,
243266
'array[n, n_series] int latent_ypred;\n',
244267
'array[total_obs] int latent_truncpred;\n',
245268
'vector[n_nonmissing] flat_ps;\n',
246-
'int flat_caps[n_nonmissing];',
269+
'int flat_caps[n_nonmissing];\n',
247270
'vector[total_obs] flat_trends;\n',
248271
'vector[n_nonmissing] flat_trends_nonmis;\n',
249272
'vector[total_obs] detprob;\n',

R/get_mvgam_priors.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ get_mvgam_priors = function(formula,
462462
}
463463

464464
# Remove sigma prior if this is an N-mixture with no dynamics
465-
if(add_nmix){
465+
if(add_nmix & trend_model == 'None'){
466466
out <- out[-grep('vector<lower=0>[n_lv] sigma;',
467467
out$param_name,
468468
fixed = TRUE),]

R/gp.R

+9-4
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,15 @@ make_gp_additions = function(gp_details, data,
158158

159159
# Add coefficient indices to attribute table and to Stan data
160160
for(covariate in seq_along(gp_att_table)){
161-
# coef_indices <- grep(gp_att_table[[covariate]]$name,
162-
# names(coef(mgcv_model)), fixed = TRUE)
163-
coef_indices <- which(grepl(gp_att_table[[covariate]]$name,
164-
names(coef(mgcv_model)), fixed = TRUE) &
161+
# coef_indices <- which(grepl(gp_att_table[[covariate]]$name,
162+
# names(coef(mgcv_model)), fixed = TRUE) &
163+
# !grepl(paste0(gp_att_table[[covariate]]$name,':'),
164+
# names(coef(mgcv_model)), fixed = TRUE) == TRUE)
165+
166+
coef_indices <- which(grepl(paste0(gsub("([()])","\\\\\\1",
167+
gp_att_table[[covariate]]$name),
168+
'\\.+[0-9]'),
169+
names(coef(mgcv_model)), fixed = FALSE) &
165170
!grepl(paste0(gp_att_table[[covariate]]$name,':'),
166171
names(coef(mgcv_model)), fixed = TRUE) == TRUE)
167172

R/mvgam.R

+2-3
Original file line numberDiff line numberDiff line change
@@ -1718,8 +1718,7 @@ mvgam = function(formula,
17181718
# Auto-format the model file
17191719
if(autoformat){
17201720
if(requireNamespace('cmdstanr') & cmdstanr::cmdstan_version() >= "2.29.0") {
1721-
tmp_file <- cmdstanr::write_stan_file(vectorised$model_file)
1722-
vectorised$model_file <- .autoformat(tmp_file,
1721+
vectorised$model_file <- .autoformat(vectorised$model_file,
17231722
overwrite_file = FALSE)
17241723
}
17251724
vectorised$model_file <- readLines(textConnection(vectorised$model_file),
@@ -1908,7 +1907,7 @@ mvgam = function(formula,
19081907
cpp_options = list(stan_threads = TRUE))
19091908
} else {
19101909
cmd_mod <- cmdstanr::cmdstan_model(cmdstanr::write_stan_file(vectorised$model_file),
1911-
stanc_options = list('O1'))
1910+
stanc_options = list('O1'),)
19121911
}
19131912

19141913
} else {

0 commit comments

Comments
 (0)