diff --git a/R/backends.R b/R/backends.R new file mode 100644 index 00000000..23a9c65d --- /dev/null +++ b/R/backends.R @@ -0,0 +1,1107 @@ +#### Helper functions for preparing and manipulating Stan models #### +# All functions were modified from `brms` source code and so all credit must +# go to the `brms` development team + +#' parse Stan model code with cmdstanr +#' @param model Stan model code +#' @return validated Stan model code +#' @noRd +.model_cmdstanr <- function(model_file, + threads = 1, + silent = 1, + ...) { + + if(silent < 2){ + message('Compiling Stan program using cmdstanr') + message() + } + + if(cmdstanr::cmdstan_version() < "2.26.0"){ + warning('Your version of Cmdstan is < 2.26.0; some mvgam models may not work properly!') + } + + temp_file <- cmdstanr::write_stan_file(model_file) + + if(cmdstanr::cmdstan_version() >= "2.29.0"){ + if(threads > 1){ + out <- eval_silent( + cmdstanr::cmdstan_model(temp_file, + stanc_options = list('O1'), + cpp_options = list(stan_threads = TRUE), + ...), + type = "message", try = TRUE, silent = silent > 0L + ) + } else { + out <- eval_silent( + cmdstanr::cmdstan_model(temp_file, + stanc_options = list('O1'), + ...), + type = "message", try = TRUE, silent = silent > 0L + ) + } + } else { + if(threads > 1){ + out <- eval_silent( + cmdstanr::cmdstan_model(temp_file, + cpp_options = list(stan_threads = TRUE), + ...), + type = "message", try = TRUE, silent = silent + ) + } else { + out <- eval_silent( + cmdstanr::cmdstan_model(temp_file, + ...), + type = "message", try = TRUE, silent = silent + ) + } + } + + return(out) +} + +#' fit Stan model with cmdstanr using HMC sampling or variational inference +#' @param model a compiled Stan model +#' @param data named list to be passed to Stan as data +#' @return a fitted Stan model +#' @noRd +.sample_model_cmdstanr <- function(model, + algorithm = 'sampling', + prior_simulation = FALSE, + data, + inits, + chains = 4, + parallel = TRUE, + silent = 1L, + max_treedepth, + adapt_delta, + threads = 1, + burnin, + samples, + param = param, + save_all_pars = FALSE, + ...) { + + if(algorithm == 'pathfinder'){ + if(cmdstanr::cmdstan_version() < "2.33"){ + stop('Your version of Cmdstan is < 2.33; the "pathfinder" algorithm is not available', + call. = FALSE) + } + + if(utils::packageVersion('cmdstanr') < '0.6.1.9000'){ + stop('Your version of cmdstanr is < 0.6.1.9000; the "pathfinder" algorithm is not available', + call. = FALSE) + } + } + + warn_inits_def <- getOption('cmdstanr_warn_inits') + options(cmdstanr_warn_inits = FALSE) + on.exit(options(cmdstanr_warn_inits = warn_inits_def)) + + # Construct cmdstanr sampling arguments + args <- nlist(data = data) + dots <- list(...) + args[names(dots)] <- dots + + if(prior_simulation){ + burnin <- 200 + } + + # do the actual sampling + if (silent < 2) { + message("Start sampling") + } + + if(algorithm == 'sampling'){ + c(args) <- nlist( + chains = chains, + refresh = 100, + init = inits, + max_treedepth, + adapt_delta, + diagnostics = NULL, + iter_sampling = samples, + iter_warmup = burnin, + show_messages = silent < 2, + show_exceptions = silent == 0) + + if(parallel){ + c(args) <- nlist(parallel_chains = min(c(chains, parallel::detectCores() - 1))) + } + + if(threads > 1){ + c(args) <- nlist(threads_per_chain = threads) + } + + out <- do_call(model$sample, args) + + } else if (algorithm %in% c("fullrank", "meanfield")) { + c(args) <- nlist(algorithm = algorithm, + refresh = 500, + output_samples = samples) + if(threads > 1){ + c(args) <- nlist(threads = threads) + } + + out <- do_call(model$variational, args) + + } else if (algorithm %in% c("laplace")) { + c(args) <- nlist(refresh = 500, + draws = samples) + if(threads > 1){ + c(args) <- nlist(threads = threads) + } + + out <- do_call(model$laplace, args) + + } else if (algorithm %in% c("pathfinder")) { + c(args) <- nlist(refresh = 500, + draws = samples) + if(threads > 1){ + c(args) <- nlist(num_threads = threads) + } + + out <- do_call(model$pathfinder, args) + } else { + stop("Algorithm '", algorithm, "' is not supported.", + call. = FALSE) + } + + if(algorithm %in% c('meanfield', 'fullrank', + 'laplace', 'pathfinder')){ + param <- param[!param %in% 'lp__'] + } + + # Convert model files to stan_fit class for consistency + if(save_all_pars){ + out_gam_mod <- read_csv_as_stanfit(out$output_files()) + } else { + out_gam_mod <- read_csv_as_stanfit(out$output_files(), + variables = param) + } + + out_gam_mod <- repair_stanfit(out_gam_mod) + + if(algorithm %in% c('meanfield', 'fullrank', + 'pathfinder', 'laplace')){ + out_gam_mod@sim$iter <- samples + out_gam_mod@sim$thin <- 1 + out_gam_mod@stan_args[[1]]$method <- 'sampling' + } + + return(out_gam_mod) +} + +#' fit Stan model with rstan +#' @param model a compiled Stan model +#' @param sdata named list to be passed to Stan as data +#' @return a fitted Stan model +#' @noRd +.sample_model_rstan <- function(model, + algorithm = 'sampling', + prior_simulation = FALSE, + data, + inits, + chains = 4, + parallel = TRUE, + silent = 1L, + max_treedepth, + adapt_delta, + threads = 1, + burnin, + samples, + thin, + ...) { + + if(rstan::stan_version() < "2.26.0"){ + warning('Your version of Stan is < 2.26.0; some mvgam models may not work properly!') + } + + if(algorithm == 'pathfinder'){ + stop('The "pathfinder" algorithm is not yet available in rstan', + call. = FALSE) + } + + if(algorithm == 'laplace'){ + stop('The "laplace" algorithm is not yet available in rstan', + call. = FALSE) + } + + # Set up parallel cores + mc_cores_def <- getOption('mc.cores') + options(mc.cores = parallel::detectCores()) + on.exit(options(mc.cores = mc_cores_def)) + + # Fit the model in rstan using custom control parameters + if(threads > 1){ + if(utils::packageVersion("rstan") >= "2.26") { + threads_per_chain_def <- rstan::rstan_options("threads_per_chain") + on.exit(rstan::rstan_options(threads_per_chain = threads_per_chain_def)) + rstan::rstan_options(threads_per_chain = threads) + } else { + stop("Threading is not supported by backend 'rstan' version ", + utils::packageVersion("rstan"), ".", + call. = FALSE) + } + } + + # Compile the model + if(silent < 2L){ + message('Compiling Stan program using rstan') + message() + } + + stan_mod <- eval_silent( + rstan::stan_model(model_code = model, verbose = silent < 1L), + type = "message", try = TRUE, silent = silent >= 1L) + + # Construct rstan sampling arguments + args <- nlist(object = stan_mod, + data = data) + dots <- list(...) + args[names(dots)] <- dots + + if(samples <= burnin){ + samples <- burnin + samples + } + + # do the actual sampling + if (silent < 2) { + message("Start sampling") + } + + if(algorithm %in% c("sampling", "fixed_param")) { + stan_control <- list(max_treedepth = max_treedepth, + adapt_delta = adapt_delta) + if(prior_simulation){ + burnin = 200; samples = 700 + } + if(parallel){ + c(args) <- nlist(cores = min(c(chains, parallel::detectCores() - 1))) + } + + c(args) <- nlist(warmup = burnin, + iter = samples, + chains = chains, + control = stan_control, + show_messages = silent < 1L, + init = inits, + verbose = FALSE, + thin = thin, + pars = NA, + refresh = 100, + save_warmup = FALSE) + + out <- do_call(rstan::sampling, args) + + } else if (algorithm %in% c("fullrank", "meanfield")) { + c(args) <- nlist(algorithm, + output_samples = samples, + pars = NA) + out <- do_call(rstan::vb, args) + + } else { + stop("Algorithm '", algorithm, "' is not supported.", + call. = FALSE) + } + + out <- repair_stanfit(out) + return(out) +} + +#' @importFrom methods new +#' @noRd +read_csv_as_stanfit <- function(files, variables = NULL, + sampler_diagnostics = NULL) { + + # Code borrowed from brms: https://github.com/paul-buerkner/brms/R/backends.R#L603 + repair_names <- function(x) { + x <- sub("\\.", "[", x) + x <- gsub("\\.", ",", x) + x[grep("\\[", x)] <- paste0(x[grep("\\[", x)], "]") + x + } + + if(!is.null(variables)){ + # ensure that only relevant variables are read from CSV + metadata <- cmdstanr::read_cmdstan_csv( + files = files, variables = "", sampler_diagnostics = "") + + all_vars <- repair_names(metadata$metadata$variables) + all_vars <- unique(sub("\\[.+", "", all_vars)) + variables <- variables[variables %in% all_vars] + } + + csfit <- cmdstanr::read_cmdstan_csv( + files = files, variables = variables, + sampler_diagnostics = sampler_diagnostics, + format = NULL + ) + + # @model_name + model_name = gsub(".csv", "", basename(files[[1]])) + + # @model_pars + svars <- csfit$metadata$stan_variables + if (!is.null(variables)) { + variables_main <- unique(gsub("\\[.*\\]", "", variables)) + svars <- intersect(variables_main, svars) + } + if ("lp__" %in% svars) { + svars <- c(setdiff(svars, "lp__"), "lp__") + } + pars_oi <- svars + par_names <- csfit$metadata$model_params + + # @par_dims + par_dims <- vector("list", length(svars)) + + names(par_dims) <- svars + par_dims <- lapply(par_dims, function(x) x <- integer(0)) + + pdims_num <- ulapply( + svars, function(x) sum(grepl(paste0("^", x, "\\[.*\\]$"), par_names)) + ) + par_dims[pdims_num != 0] <- + csfit$metadata$stan_variable_sizes[svars][pdims_num != 0] + + # @mode + mode <- 0L + + # @sim + rstan_diagn_order <- c("accept_stat__", "treedepth__", "stepsize__", + "divergent__", "n_leapfrog__", "energy__") + + if (!is.null(sampler_diagnostics)) { + rstan_diagn_order <- rstan_diagn_order[rstan_diagn_order %in% sampler_diagnostics] + } + + res_vars <- c(".chain", ".iteration", ".draw") + if ("post_warmup_draws" %in% names(csfit)) { + # for MCMC samplers + n_chains <- max( + posterior::nchains(csfit$warmup_draws), + posterior::nchains(csfit$post_warmup_draws) + ) + n_iter_warmup <- posterior::niterations(csfit$warmup_draws) + n_iter_sample <- posterior::niterations(csfit$post_warmup_draws) + if (n_iter_warmup > 0) { + csfit$warmup_draws <- posterior::as_draws_df(csfit$warmup_draws) + csfit$warmup_sampler_diagnostics <- + posterior::as_draws_df(csfit$warmup_sampler_diagnostics) + } + if (n_iter_sample > 0) { + csfit$post_warmup_draws <- posterior::as_draws_df(csfit$post_warmup_draws) + csfit$post_warmup_sampler_diagnostics <- + posterior::as_draws_df(csfit$post_warmup_sampler_diagnostics) + } + + # called 'samples' for consistency with rstan + samples <- rbind(csfit$warmup_draws, csfit$post_warmup_draws) + # manage memory + csfit$warmup_draws <- NULL + csfit$post_warmup_draws <- NULL + + # prepare sampler diagnostics + diagnostics <- rbind(csfit$warmup_sampler_diagnostics, + csfit$post_warmup_sampler_diagnostics) + # manage memory + csfit$warmup_sampler_diagnostics <- NULL + csfit$post_warmup_sampler_diagnostics <- NULL + # convert to regular data.frame + diagnostics <- as.data.frame(diagnostics) + diag_chain_ids <- diagnostics$.chain + diagnostics[res_vars] <- NULL + + } else if ("draws" %in% names(csfit)) { + # for variational inference "samplers" + n_chains <- 1 + n_iter_warmup <- 0 + n_iter_sample <- posterior::niterations(csfit$draws) + if (n_iter_sample > 0) { + csfit$draws <- posterior::as_draws_df(csfit$draws) + } + + # called 'samples' for consistency with rstan + samples <- csfit$draws + # manage memory + csfit$draws <- NULL + + # VI has no sampler diagnostics + diag_chain_ids <- rep(1L, nrow(samples)) + diagnostics <- as.data.frame(matrix(nrow = nrow(samples), ncol = 0)) + } + + # convert to regular data.frame + samples <- as.data.frame(samples) + chain_ids <- samples$.chain + samples[res_vars] <- NULL + + move2end <- function(x, last) { + x[c(setdiff(names(x), last), last)] + } + + if ("lp__" %in% colnames(samples)) { + samples <- move2end(samples, "lp__") + } + + fnames_oi <- colnames(samples) + + colnames(samples) <- gsub("\\[", ".", colnames(samples)) + colnames(samples) <- gsub("\\]", "", colnames(samples)) + colnames(samples) <- gsub("\\,", ".", colnames(samples)) + + # split samples into chains + samples <- split(samples, chain_ids) + names(samples) <- NULL + + # split diagnostics into chains + diagnostics <- split(diagnostics, diag_chain_ids) + names(diagnostics) <- NULL + + # @sim$sample: largely 113-130 from rstan::read_stan_csv + values <- list() + values$algorithm <- csfit$metadata$algorithm + values$engine <- csfit$metadata$engine + values$metric <- csfit$metadata$metric + + sampler_t <- NULL + if (!is.null(values$algorithm)) { + if (values$algorithm == "rwm" || values$algorithm == "Metropolis") { + sampler_t <- "Metropolis" + } else if (values$algorithm == "hmc") { + if (values$engine == "static") { + sampler_t <- "HMC" + } else { + if (values$metric == "unit_e") { + sampler_t <- "NUTS(unit_e)" + } else if (values$metric == "diag_e") { + sampler_t <- "NUTS(diag_e)" + } else if (values$metric == "dense_e") { + sampler_t <- "NUTS(dense_e)" + } + } + } + } + + adapt_info <- vector("list", 4) + idx_samples <- (n_iter_warmup + 1):(n_iter_warmup + n_iter_sample) + + for (i in seq_along(samples)) { + m <- colMeans(samples[[i]][idx_samples, , drop=FALSE]) + rownames(samples[[i]]) <- seq_rows(samples[[i]]) + attr(samples[[i]], "sampler_params") <- diagnostics[[i]][rstan_diagn_order] + rownames(attr(samples[[i]], "sampler_params")) <- seq_rows(diagnostics[[i]]) + + # reformat back to text + if (is_equal(sampler_t, "NUTS(dense_e)")) { + mmatrix_txt <- "\n# Elements of inverse mass matrix:\n# " + mmat <- paste0(apply(csfit$inv_metric[[i]], 1, paste0, collapse=", "), + collapse="\n# ") + } else { + mmatrix_txt <- "\n# Diagonal elements of inverse mass matrix:\n# " + mmat <- paste0(csfit$inv_metric[[i]], collapse = ", ") + } + + adapt_info[[i]] <- paste0("# Step size = ", + csfit$step_size[[i]], + mmatrix_txt, + mmat, "\n# ") + + attr(samples[[i]], "adaptation_info") <- adapt_info[[i]] + + attr(samples[[i]], "args") <- list(sampler_t = sampler_t, chain_id = i) + + if (NROW(csfit$metadata$time)) { + time_i <- as.double(csfit$metadata$time[i, c("warmup", "sampling")]) + names(time_i) <- c("warmup", "sample") + attr(samples[[i]], "elapsed_time") <- time_i + } + + attr(samples[[i]], "mean_pars") <- m[-length(m)] + attr(samples[[i]], "mean_lp__") <- m["lp__"] + } + + perm_lst <- lapply(seq_len(n_chains), function(id) sample.int(n_iter_sample)) + + # @sim + sim <- list( + samples = samples, + iter = csfit$metadata$iter_sampling + csfit$metadata$iter_warmup, + thin = csfit$metadata$thin, + warmup = csfit$metadata$iter_warmup, + chains = n_chains, + n_save = rep(n_iter_sample + n_iter_warmup, n_chains), + warmup2 = rep(n_iter_warmup, n_chains), + permutation = perm_lst, + pars_oi = pars_oi, + dims_oi = par_dims, + fnames_oi = fnames_oi, + n_flatnames = length(fnames_oi) + ) + + # @stan_args + sargs <- list( + stan_version_major = as.character(csfit$metadata$stan_version_major), + stan_version_minor = as.character(csfit$metadata$stan_version_minor), + stan_version_patch = as.character(csfit$metadata$stan_version_patch), + model = csfit$metadata$model_name, + start_datetime = gsub(" ", "", csfit$metadata$start_datetime), + method = csfit$metadata$method, + iter = csfit$metadata$iter_sampling + csfit$metadata$iter_warmup, + warmup = csfit$metadata$iter_warmup, + save_warmup = csfit$metadata$save_warmup, + thin = csfit$metadata$thin, + engaged = as.character(csfit$metadata$adapt_engaged), + gamma = csfit$metadata$gamma, + delta = csfit$metadata$adapt_delta, + kappa = csfit$metadata$kappa, + t0 = csfit$metadata$t0, + init_buffer = as.character(csfit$metadata$init_buffer), + term_buffer = as.character(csfit$metadata$term_buffer), + window = as.character(csfit$metadata$window), + algorithm = csfit$metadata$algorithm, + engine = csfit$metadata$engine, + max_depth = csfit$metadata$max_treedepth, + metric = csfit$metadata$metric, + metric_file = character(0), # not stored in metadata + stepsize = NA, # add in loop + stepsize_jitter = csfit$metadata$stepsize_jitter, + num_chains = as.character(csfit$metadata$num_chains), + chain_id = NA, # add in loop + file = character(0), # not stored in metadata + init = NA, # add in loop + seed = as.character(csfit$metadata$seed), + file = NA, # add in loop + diagnostic_file = character(0), # not stored in metadata + refresh = as.character(csfit$metadata$refresh), + sig_figs = as.character(csfit$metadata$sig_figs), + profile_file = csfit$metadata$profile_file, + num_threads = as.character(csfit$metadata$threads_per_chain), + stanc_version = gsub(" ", "", csfit$metadata$stanc_version), + stancflags = character(0), # not stored in metadata + adaptation_info = NA, # add in loop + has_time = is.numeric(csfit$metadata$time$total), + time_info = NA, # add in loop + sampler_t = sampler_t + ) + + sargs_rep <- replicate(n_chains, sargs, simplify = FALSE) + + for (i in seq_along(sargs_rep)) { + sargs_rep[[i]]$chain_id <- i + sargs_rep[[i]]$stepsize <- csfit$metadata$step_size[i] + sargs_rep[[i]]$init <- as.character(csfit$metadata$init[i]) + # two 'file' elements: select the second + file_idx <- which(names(sargs_rep[[i]]) == "file") + sargs_rep[[i]][[file_idx[2]]] <- files[[i]] + + sargs_rep[[i]]$adaptation_info <- adapt_info[[i]] + + if (NROW(csfit$metadata$time)) { + sargs_rep[[i]]$time_info <- paste0( + c("# Elapsed Time: ", "# ", "# ", "# "), + c(csfit$metadata$time[i, c("warmup", "sampling", "total")], ""), + c(" seconds (Warm-up)", " seconds (Sampling)", " seconds (Total)", "") + ) + } + } + + # @stanmodel + null_dso <- new( + "cxxdso", sig = list(character(0)), dso_saved = FALSE, + dso_filename = character(0), modulename = character(0), + system = R.version$system, cxxflags = character(0), + .CXXDSOMISC = new.env(parent = emptyenv()) + ) + null_sm <- new( + "stanmodel", model_name = model_name, model_code = character(0), + model_cpp = list(), dso = null_dso + ) + + # @date + sdate <- do.call(max, lapply(files, function(csv) file.info(csv)$mtime)) + sdate <- format(sdate, "%a %b %d %X %Y") + + new( + "stanfit", + model_name = model_name, + model_pars = svars, + par_dims = par_dims, + mode = mode, + sim = sim, + inits = list(), + stan_args = sargs_rep, + stanmodel = null_sm, + date = sdate, + .MISC = new.env(parent = emptyenv()) + ) +} + +#' @noRd +.autoformat <- function(stan_file, overwrite_file = TRUE, + backend = 'cmdstanr', silent = TRUE){ + + # No need to fill lv_coefs in each iteration if this is a + # trend_formula model + if(any(grepl('lv_coefs = Z;', + stan_file, fixed = TRUE)) & + !any(grepl('vector[n_lv] LV[n];', + stan_file, fixed = TRUE))){ + stan_file <- stan_file[-grep('lv_coefs = Z;', + stan_file, fixed = TRUE)] + stan_file <- stan_file[-grep('matrix[n_series, n_lv] lv_coefs;', + stan_file, fixed = TRUE)] + stan_file[grep('trend[i, s] = dot_product(lv_coefs[s,], LV[i,]);', + stan_file, fixed = TRUE)] <- + 'trend[i, s] = dot_product(Z[s,], LV[i,]);' + + stan_file[grep('// posterior predictions', + stan_file, fixed = TRUE)-1] <- + paste0(stan_file[grep('// posterior predictions', + stan_file, fixed = TRUE)-1], + '\n', + 'matrix[n_series, n_lv] lv_coefs = Z;') + stan_file <- readLines(textConnection(stan_file), n = -1) + } + + if(backend == 'rstan' & rstan::stan_version() < '2.29.0'){ + # normal_id_glm became available in 2.29.0; this needs to be replaced + # with the older non-glm version + if(any(grepl('normal_id_glm', + stan_file, fixed = TRUE))){ + if(any(grepl("flat_ys ~ normal_id_glm(flat_xs,", + stan_file, fixed = TRUE))){ + start <- grep("flat_ys ~ normal_id_glm(flat_xs,", + stan_file, fixed = TRUE) + end <- start + 2 + stan_file <- stan_file[-c((start + 1):(start + 2))] + stan_file[start] <- 'flat_ys ~ normal(flat_xs * b, flat_sigma_obs);' + } + } + } + + # Old ways of specifying arrays have been converted to errors in + # the latest version of Cmdstan (2.32.0); this coincides with + # a decision to stop automatically replacing these deprecations with + # the canonicalizer, so we have no choice but to replace the old + # syntax with this ugly bit of code + + # rstan dependency in Description should mean that updates should + # always happen (mvgam depends on rstan >= 2.29.0) + update_code <- TRUE + + # Tougher if using cmdstanr + if(backend == 'cmdstanr'){ + if(cmdstanr::cmdstan_version() < "2.32.0"){ + # If the autoformat options from cmdstanr are available, + # make use of them to update any deprecated array syntax + update_code <- FALSE + } + } + + if(update_code){ + # Data modifications + stan_file[grep("int ytimes[n, n_series]; // time-ordered matrix (which col in X belongs to each [time, series] observation?)", + stan_file, fixed = TRUE)] <- + 'array[n, n_series] int ytimes; // time-ordered matrix (which col in X belongs to each [time, series] observation?)' + + stan_file[grep("int flat_ys[n_nonmissing]; // flattened nonmissing observations", + stan_file, fixed = TRUE)] <- + 'array[n_nonmissing] int flat_ys; // flattened nonmissing observations' + + stan_file[grep("int obs_ind[n_nonmissing]; // indices of nonmissing observations", + stan_file, fixed = TRUE)] <- + "array[n_nonmissing] int obs_ind; // indices of nonmissing observations" + + if(any(grepl('int ytimes_trend[n, n_lv]; // time-ordered matrix for latent states', + stan_file, fixed = TRUE))){ + stan_file[grep("int ytimes_trend[n, n_lv]; // time-ordered matrix for latent states", + stan_file, fixed = TRUE)] <- + "array[n, n_lv] int ytimes_trend;" + } + + if(any(grepl('int idx', stan_file) & + grepl('// discontiguous index values', + stan_file, fixed = TRUE))){ + lines_replace <- which(grepl('int idx', stan_file) & + grepl('// discontiguous index values', + stan_file, fixed = TRUE)) + for(i in lines_replace){ + split_line <- strsplit(stan_file[i], ' ')[[1]] + + idxnum <- gsub(';', '', + gsub("\\s*\\[[^\\]+\\]", "", + as.character(split_line[2]))) + idx_length <- gsub("\\]", "", gsub("\\[", "", + regmatches(split_line[2], + gregexpr("\\[.*?\\]", split_line[2]))[[1]])) + + stan_file[i] <- + paste0('array[', + idx_length, + '] int ', + idxnum, + '; // discontiguous index values') + } + } + + if(any(grepl('int cap[total_obs]; // upper limits of latent abundances', + stan_file, fixed = TRUE))){ + stan_file[grep('int cap[total_obs]; // upper limits of latent abundances', + stan_file, fixed = TRUE)] <- + 'array[total_obs] int cap; // upper limits of latent abundances' + + stan_file[grep('int flat_caps[n_nonmissing];', + stan_file, fixed = TRUE)] <- + 'array[n_nonmissing] int flat_caps;' + } + + # Model modifications + if(any(grepl('real flat_phis[n_nonmissing];', + stan_file, fixed = TRUE))){ + stan_file[grep("real flat_phis[n_nonmissing];", + stan_file, fixed = TRUE)] <- + "array[n_nonmissing] real flat_phis;" + } + + # n-mixture modifications + if(any(grepl('real p_ub = poisson_cdf(max_k, lambda);', + stan_file, fixed = TRUE))){ + stan_file[grep('real p_ub = poisson_cdf(max_k, lambda);', + stan_file, fixed = TRUE)] <- + 'real p_ub = poisson_cdf(max_k | lambda);' + } + + # trend_formula modifications + if(any(grepl('int trend_rand_idx', stan_file) & + grepl('// trend random effect indices', + stan_file, fixed = TRUE))){ + lines_replace <- which(grepl('int trend_rand_idx', stan_file) & + grepl('// trend random effect indices', + stan_file, fixed = TRUE)) + for(i in lines_replace){ + split_line <- strsplit(stan_file[i], ' ')[[1]] + + trend_idxnum <- gsub(';', '', + gsub("\\s*\\[[^\\]+\\]", "", + as.character(split_line[2]))) + idx_length <- gsub("\\]", "", gsub("\\[", "", + regmatches(split_line[2], + gregexpr("\\[.*?\\]", split_line[2]))[[1]])) + + stan_file[i] <- + paste0('array[', + idx_length, + '] int ', + trend_idxnum, + '; // trend random effect indices') + } + } + + if(any(grepl('int trend_idx', stan_file) & + grepl('// discontiguous index values', + stan_file, fixed = TRUE))){ + lines_replace <- which(grepl('int trend_idx', stan_file) & + grepl('// discontiguous index values', + stan_file, fixed = TRUE)) + for(i in lines_replace){ + split_line <- strsplit(stan_file[i], ' ')[[1]] + + trend_idxnum <- gsub(';', '', + gsub("\\s*\\[[^\\]+\\]", "", + as.character(split_line[2]))) + idx_length <- gsub("\\]", "", gsub("\\[", "", + regmatches(split_line[2], + gregexpr("\\[.*?\\]", split_line[2]))[[1]])) + + stan_file[i] <- + paste0('array[', + idx_length, + '] int ', + trend_idxnum, + '; // discontiguous index values') + } + } + + if(any(grepl('vector[n_series] trend_raw[n];', + stan_file, fixed = TRUE))){ + stan_file[grep("vector[n_series] trend_raw[n];", + stan_file, fixed = TRUE)] <- + "array[n] vector[n_series] trend_raw;" + } + + if(any(grepl('vector[n_lv] error[n];', + stan_file, fixed = TRUE))){ + stan_file[grep("vector[n_lv] error[n];", + stan_file, fixed = TRUE)] <- + "array[n] vector[n_lv] error;" + } + + if(any(grepl('vector[n_series] error[n];', + stan_file, fixed = TRUE))){ + stan_file[grep("vector[n_series] error[n];", + stan_file, fixed = TRUE)] <- + "array[n] vector[n_series] error;" + } + + if(any(grepl('vector[n_lv] LV[n];', + stan_file, fixed = TRUE))){ + stan_file[grep("vector[n_lv] LV[n];", + stan_file, fixed = TRUE)] <- + "array[n] vector[n_lv] LV;" + } + + if(any(grepl('vector[n_series] mu[n - 1];', + stan_file, fixed = TRUE))){ + stan_file[grep("vector[n_series] mu[n - 1];", + stan_file, fixed = TRUE)] <- + "array[n - 1] vector[n_series] mu;" + } + + if(any(grepl('vector[n_lv] mu[n - 1];', + stan_file, fixed = TRUE))){ + stan_file[grep("vector[n_lv] mu[n - 1];", + stan_file, fixed = TRUE)] <- + "array[n - 1] vector[n_lv] mu;" + } + + if(any(grepl('vector[n_series] mu[n];', + stan_file, fixed = TRUE))){ + stan_file[grep("vector[n_series] mu[n];", + stan_file, fixed = TRUE)] <- + "array[n] vector[n_series] mu;" + } + + if(any(grepl('vector[n_lv] mu[n];', + stan_file, fixed = TRUE))){ + stan_file[grep("vector[n_lv] mu[n];", + stan_file, fixed = TRUE)] <- + "array[n] vector[n_lv] mu;" + } + # Generated quantity modifications + if(any(grepl('real ypred[n, n_series];', + stan_file, fixed = TRUE))){ + stan_file[grep("real ypred[n, n_series];", + stan_file, fixed = TRUE)] <- + "array[n, n_series] real ypred;" + } + + if(any(grepl('real ypred[n, n_series];', + stan_file, fixed = TRUE))){ + stan_file[grep("real ypred[n, n_series];", + stan_file, fixed = TRUE)] <- + "array[n, n_series] real ypred;" + } + + # ARMA model modifications + if(any(grepl('vector[n_series] epsilon[n];', + stan_file, fixed = TRUE))){ + stan_file[grep("vector[n_series] epsilon[n];", + stan_file, fixed = TRUE)] <- + "array[n] vector[n_series] epsilon;" + } + + if(any(grepl('vector[n_lv] epsilon[n];', + stan_file, fixed = TRUE))){ + stan_file[grep("vector[n_lv] epsilon[n];", + stan_file, fixed = TRUE)] <- + "array[n] vector[n_lv] epsilon;" + } + + # VARMA model modifications + if(any(grepl('matrix[n_series, n_series] P[1];', + stan_file, fixed = TRUE))){ + stan_file[grep("matrix[n_series, n_series] P[1];", + stan_file, fixed = TRUE)] <- + "array[1] matrix[n_series, n_series] P;" + + stan_file[grep("matrix[n_series, n_series] phiGamma[2, 1];", + stan_file, fixed = TRUE)] <- + "array[2, 1] matrix[n_series, n_series] phiGamma;" + } + + if(any(grepl('matrix initial_joint_var(matrix Sigma, matrix[] phi, matrix[] theta) {', + stan_file, fixed = TRUE))){ + stan_file[grep("matrix initial_joint_var(matrix Sigma, matrix[] phi, matrix[] theta) {", + stan_file, fixed = TRUE)] <- + "matrix initial_joint_var(matrix Sigma, array[] matrix phi, array[] matrix theta) {" + } + + if(any(grepl('matrix[n_lv, n_lv] P[1];', + stan_file, fixed = TRUE))){ + stan_file[grep("matrix[n_lv, n_lv] P[1];", + stan_file, fixed = TRUE)] <- + "array[1] matrix[n_lv, n_lv] P;" + + stan_file[grep("matrix[n_lv, n_lv] R[1];", + stan_file, fixed = TRUE)] <- + "array[1] matrix[n_lv, n_lv] R;" + + stan_file[grep("matrix[n_lv, n_lv] A_init[1];", + stan_file, fixed = TRUE)] <- + "array[1] matrix[n_lv, n_lv] A_init;" + + stan_file[grep("matrix[n_lv, n_lv] theta_init[1];", + stan_file, fixed = TRUE)] <- + "array[1] matrix[n_lv, n_lv] theta_init;" + } + + if(any(grepl('matrix[n_series, n_series] R[1];', + stan_file, fixed = TRUE))){ + + stan_file[grep("matrix[n_series, n_series] R[1];", + stan_file, fixed = TRUE)] <- + "array[1] matrix[n_series, n_series] R;" + + stan_file[grep("matrix[n_series, n_series] A_init[1];", + stan_file, fixed = TRUE)] <- + "array[1] matrix[n_series, n_series] A_init;" + + stan_file[grep("matrix[n_series, n_series] theta_init[1];", + stan_file, fixed = TRUE)] <- + "array[1] matrix[n_series, n_series] theta_init;" + } + + if(any(grepl('matrix[] rev_mapping(matrix[] P, matrix Sigma) {', + stan_file, fixed = TRUE))){ + stan_file[grep("matrix[] rev_mapping(matrix[] P, matrix Sigma) {", + stan_file, fixed = TRUE)] <- + "array[] matrix rev_mapping(array[] matrix P, matrix Sigma) {" + + stan_file[grep("matrix[m, m] phi_for[p, p]; matrix[m, m] phi_rev[p, p];", + stan_file, fixed = TRUE)] <- + 'array[p, p] matrix[m, m] phi_for; array[p, p] matrix[m, m] phi_rev;' + + stan_file[grep("matrix[m, m] Sigma_for[p+1]; matrix[m, m] Sigma_rev[p+1];", + stan_file, fixed = TRUE)] <- + 'array[p+1] matrix[m, m] Sigma_for; array[p+1] matrix[m, m] Sigma_rev;' + + stan_file[grep("matrix[m, m] S_for_list[p+1];", + stan_file, fixed = TRUE)] <- + 'array[p+1] matrix[m, m] S_for_list;' + } + + # VAR model modifications + if(any(grepl('matrix[n_lv, n_lv] phiGamma[2, 1];', + stan_file, fixed = TRUE))){ + stan_file[grep('matrix[n_lv, n_lv] phiGamma[2, 1];', + stan_file, fixed = TRUE)] <- + 'array[2, 1] matrix[n_lv, n_lv] phiGamma;' + } + + if(any(grepl('matrix[,] rev_mapping(matrix[] P, matrix Sigma) {', + stan_file, fixed = TRUE))){ + stan_file[grep("matrix[,] rev_mapping(matrix[] P, matrix Sigma) {", + stan_file, fixed = TRUE)] <- + "array[,] matrix rev_mapping(array[] matrix P, matrix Sigma) {" + + stan_file[grep("matrix[m, m] phi_for[p, p]; matrix[m, m] phi_rev[p, p];", + stan_file, fixed = TRUE)] <- + 'array[p, p] matrix[m, m] phi_for; array[p, p] matrix[m, m] phi_rev;' + + stan_file[grep("matrix[m, m] Sigma_for[p+1]; matrix[m, m] Sigma_rev[p+1];", + stan_file, fixed = TRUE)] <- + 'array[p+1] matrix[m, m] Sigma_for; array[p+1] matrix[m, m] Sigma_rev;' + + stan_file[grep("matrix[m, m] S_for_list[p+1];", + stan_file, fixed = TRUE)] <- + 'array[p+1] matrix[m, m] S_for_list;' + + stan_file[grep("matrix[m, m] Gamma_trans[p+1];", + stan_file, fixed = TRUE)] <- + 'array[p+1] matrix[m, m] Gamma_trans;' + + stan_file[grep("matrix[m, m] phiGamma[2, p];", + stan_file, fixed = TRUE)] <- + 'array[2, p] matrix[m, m] phiGamma;' + } + + if(any(grepl("real partial_log_lik(int[] seq, int start, int end,", + stan_file, fixed = TRUE))){ + stan_file[grepl("real partial_log_lik(int[] seq, int start, int end,", + stan_file, fixed = TRUE)] <- + "real partial_log_lik(array[] int seq, int start, int end," + } + + if(any(grepl("data vector Y, vector mu, real[] shape) {", + stan_file, fixed = TRUE))){ + stan_file[grepl("data vector Y, vector mu, real[] shape) {" , + stan_file, fixed = TRUE)] <- + "data vector Y, vector mu, array[] real shape) {" + } + + if(any(grepl("int seq[n_nonmissing]; // an integer sequence for reduce_sum slicing", + stan_file, fixed = TRUE))){ + stan_file[grepl("int seq[n_nonmissing]; // an integer sequence for reduce_sum slicing", + stan_file, fixed = TRUE)] <- + "array[n_nonmissing] int seq; // an integer sequence for reduce_sum slicing" + } + } + + if(backend == 'rstan'){ + options(stanc.allow_optimizations = TRUE, + stanc.auto_format = TRUE) + + out <- eval_silent( + rstan::stanc(model_code = stan_file), + type = "message", try = TRUE, silent = silent) + out <- out$model_code + + } else { + stan_file <- cmdstanr::write_stan_file(stan_file) + + cmdstan_mod <- eval_silent( + cmdstanr::cmdstan_model(stan_file, compile = FALSE), + type = "message", try = TRUE, silent = silent) + out <- utils::capture.output( + cmdstan_mod$format( + max_line_length = 80, + canonicalize = TRUE, + overwrite_file = overwrite_file, backup = FALSE)) + out <- paste0(out, collapse = "\n") + } + return(out) +} + + +#' @noRd +repair_stanfit <- function(x) { + if (!length(x@sim$fnames_oi)) { + # nothing to rename + return(x) + } + # the posterior package cannot deal with non-unique parameter names + # this case happens rarely but might happen when sample_prior = "yes" + x@sim$fnames_oi <- make.unique(as.character(x@sim$fnames_oi), "__") + for (i in seq_along(x@sim$samples)) { + # stanfit may have renamed dimension suffixes (#1218) + if (length(x@sim$samples[[i]]) == length(x@sim$fnames_oi)) { + names(x@sim$samples[[i]]) <- x@sim$fnames_oi + } + } + x +} + +#' @noRd +repair_variable_names <- function(x) { + x <- sub("\\.", "[", x) + x <- gsub("\\.", ",", x) + x[grep("\\[", x)] <- paste0(x[grep("\\[", x)], "]") + x +} + +#' @noRd +seq_rows = function (x){ + seq_len(NROW(x)) +} + +#' @noRd +is_equal <- function(x, y, check.attributes = FALSE, ...) { + isTRUE(all.equal(x, y, check.attributes = check.attributes, ...)) +} + + +#' @noRd +ulapply <- function(X, FUN, ..., recursive = TRUE, use.names = TRUE) { + unlist(lapply(X, FUN, ...), recursive, use.names) +} diff --git a/R/mvgam.R b/R/mvgam.R index 57a29009..3b8fa88b 100644 --- a/R/mvgam.R +++ b/R/mvgam.R @@ -198,6 +198,11 @@ #'The step size used by the numerical integrator is a function of `adapt_delta` in that increasing #'`adapt_delta` will result in a smaller step size and fewer divergences. Increasing `adapt_delta` will #'typically result in a slower sampler, but it will always lead to a more robust sampler +#'@param silent Verbosity level between `0` and `2`. If `1` (the default), most of the informational +#'messages of compiler and sampler are suppressed. If `2`, even more messages are suppressed. The +#'actual sampling progress is still printed. Set `refresh = 0` to turn this off as well. If using +#'`backend = "rstan"` you can also set open_progress = FALSE to prevent opening additional +#'progress bars. #'@param jags_path Optional character vector specifying the path to the location of the `JAGS` executable (.exe) to use #'for modelling if `use_stan == FALSE`. If missing, the path will be recovered from a call to \code{\link[runjags]{findjags}} #'@param ... Further arguments passed to Stan. @@ -586,8 +591,9 @@ mvgam = function(formula, algorithm = getOption("brms.algorithm", "sampling"), autoformat = TRUE, save_all_pars = FALSE, - max_treedepth, - adapt_delta, + max_treedepth = 12, + adapt_delta = 0.85, + silent = 1, jags_path, ...){ @@ -618,6 +624,7 @@ mvgam = function(formula, validate_pos_integer(burnin) validate_pos_integer(samples) validate_pos_integer(thin) + validate_silent(silent) # Upper bounds no longer supported as they are fairly useless upper_bounds <- rlang::missing_arg() @@ -1632,7 +1639,8 @@ mvgam = function(formula, cmdstanr::cmdstan_version() >= "2.29.0") { vectorised$model_file <- .autoformat(vectorised$model_file, overwrite_file = FALSE, - backend = 'cmdstanr') + backend = 'cmdstanr', + silent = silent >= 1L) } vectorised$model_file <- readLines(textConnection(vectorised$model_file), n = -1) @@ -1642,10 +1650,12 @@ mvgam = function(formula, if(autoformat){ vectorised$model_file <- .autoformat(vectorised$model_file, overwrite_file = FALSE, - backend = 'rstan') + backend = 'rstan', + silent = silent >= 1L) vectorised$model_file <- readLines(textConnection(vectorised$model_file), n = -1) } + # Replace new syntax if this is an older version of Stan if(rstan::stan_version() < "2.26"){ warning('Your version of rstan is out of date. Some features of mvgam may not work') @@ -1802,261 +1812,47 @@ mvgam = function(formula, } if(use_cmdstan){ - message('Using cmdstanr as the backend') - message() - if(cmdstanr::cmdstan_version() < "2.26.0"){ - warning('Your version of Cmdstan is < 2.26.0; some mvgam models may not work properly!') - } - - if(algorithm == 'pathfinder'){ - if(cmdstanr::cmdstan_version() < "2.33"){ - stop('Your version of Cmdstan is < 2.33; the "pathfinder" algorithm is not available', - call. = FALSE) - } - - if(utils::packageVersion('cmdstanr') < '0.6.1.9000'){ - stop('Your version of cmdstanr is < 0.6.1.9000; the "pathfinder" algorithm is not available', - call. = FALSE) - } - } - - # Prepare threading - if(cmdstanr::cmdstan_version() >= "2.29.0"){ - if(threads > 1){ - cmd_mod <- cmdstanr::cmdstan_model(cmdstanr::write_stan_file(vectorised$model_file), - stanc_options = list('O1'), - cpp_options = list(stan_threads = TRUE)) - } else { - cmd_mod <- cmdstanr::cmdstan_model(cmdstanr::write_stan_file(vectorised$model_file), - stanc_options = list('O1')) - } - - } else { - if(threads > 1){ - cmd_mod <- cmdstanr::cmdstan_model(cmdstanr::write_stan_file(vectorised$model_file), - cpp_options = list(stan_threads = TRUE)) - } else { - cmd_mod <- cmdstanr::cmdstan_model(cmdstanr::write_stan_file(vectorised$model_file)) - } - } - - if(missing(max_treedepth)){ - max_treedepth <- 12 - } - if(missing(adapt_delta)){ - adapt_delta <- 0.85 - } + # Prepare threading and generate the model + cmd_mod <- .model_cmdstanr(vectorised$model_file, + threads = threads, + silent = silent) # Condition the model using Cmdstan - if(algorithm == 'sampling'){ - if(prior_simulation){ - if(parallel){ - fit1 <- cmd_mod$sample(data = model_data, - chains = chains, - parallel_chains = min(c(chains, parallel::detectCores() - 1)), - threads_per_chain = if(threads > 1){ threads } else { NULL }, - refresh = 100, - init = inits, - max_treedepth = 12, - adapt_delta = 0.8, - iter_sampling = samples, - iter_warmup = 200, - show_messages = FALSE, - diagnostics = NULL, - ...) - } else { - fit1 <- cmd_mod$sample(data = model_data, - chains = chains, - threads_per_chain = if(threads > 1){ threads } else { NULL }, - refresh = 100, - init = inits, - max_treedepth = 12, - adapt_delta = 0.8, - iter_sampling = samples, - iter_warmup = 200, - show_messages = FALSE, - diagnostics = NULL, - ...) - } - - } else { - if(parallel){ - fit1 <- cmd_mod$sample(data = model_data, - chains = chains, - parallel_chains = min(c(chains, parallel::detectCores() - 1)), - threads_per_chain = if(threads > 1){ threads } else { NULL }, - refresh = 100, - init = inits, - max_treedepth = max_treedepth, - adapt_delta = adapt_delta, - iter_sampling = samples, - iter_warmup = burnin, - ...) - } else { - fit1 <- cmd_mod$sample(data = model_data, - chains = chains, - threads_per_chain = if(threads > 1){ threads } else { NULL }, - refresh = 100, - init = inits, - max_treedepth = max_treedepth, - adapt_delta = adapt_delta, - iter_sampling = samples, - iter_warmup = burnin, - ...) - } - } - } - - if(algorithm %in% c('meanfield', 'fullrank')){ - param <- param[!param %in% 'lp__'] - fit1 <- cmd_mod$variational(data = model_data, - threads = if(threads > 1){ threads } else { NULL }, - refresh = 500, - output_samples = samples, - algorithm = algorithm, - ...) - } - - if(algorithm %in% c('laplace')){ - param <- param[!param %in% 'lp__'] - fit1 <- cmd_mod$laplace(data = model_data, - threads = if(threads > 1){ threads } else { NULL }, - refresh = 500, - draws = samples, - ...) - } - - if(algorithm %in% c('pathfinder')){ - param <- param[!param %in% 'lp__'] - fit1 <- cmd_mod$pathfinder(data = model_data, - num_threads = if(threads > 1){ threads } else { NULL }, - refresh = 500, - draws = samples, - ...) - } - - # Convert model files to stan_fit class for consistency - if(save_all_pars){ - out_gam_mod <- read_csv_as_stanfit(fit1$output_files()) - } else { - out_gam_mod <- read_csv_as_stanfit(fit1$output_files(), - variables = param) - } - - out_gam_mod <- repair_stanfit(out_gam_mod) - - if(algorithm %in% c('meanfield', 'fullrank', - 'pathfinder', 'laplace')){ - out_gam_mod@sim$iter <- samples - out_gam_mod@sim$thin <- 1 - out_gam_mod@stan_args[[1]]$method <- 'sampling' - } + out_gam_mod <- .sample_model_cmdstanr(model = cmd_mod, + algorithm = algorithm, + prior_simulation = prior_simulation, + data = model_data, + inits = inits, + chains = chains, + parallel = parallel, + silent = silent, + max_treedepth = max_treedepth, + adapt_delta = adapt_delta, + threads = threads, + burnin = burnin, + samples = samples, + param = param, + save_all_pars = save_all_pars, + ...) } else { + # Condition the model using rstan requireNamespace('rstan', quietly = TRUE) - message('Using rstan as the backend') - message() - - if(rstan::stan_version() < "2.26.0"){ - warning('Your version of Stan is < 2.26.0; some mvgam models may not work properly!') - } - - if(algorithm == 'pathfinder'){ - stop('The "pathfinder" algorithm is not yet available in rstan', - call. = FALSE) - } - - if(algorithm == 'laplace'){ - stop('The "laplace" algorithm is not yet available in rstan', - call. = FALSE) - } - options(mc.cores = parallel::detectCores()) - - # Fit the model in rstan using custom control parameters - if(missing(max_treedepth)){ - max_treedepth <- 12 - } - - if(missing(adapt_delta)){ - adapt_delta <- 0.85 - } - - if(threads > 1){ - if(utils::packageVersion("rstan") >= "2.26") { - threads_per_chain_def <- rstan::rstan_options("threads_per_chain") - on.exit(rstan::rstan_options(threads_per_chain = threads_per_chain_def)) - rstan::rstan_options(threads_per_chain = threads) - } else { - stop("Threading is not supported by backend 'rstan' version ", - utils::packageVersion("rstan"), ".", - call. = FALSE) - } - } - - message("Compiling the Stan program...") - message() - stan_mod <- rstan::stan_model(model_code = vectorised$model_file, - verbose = TRUE) - if(samples <= burnin){ - samples <- burnin + samples - } - - if(prior_simulation){ - burnin <- 200 - samples <- 600 - adapt_delta <- 0.8 - max_treedepth <- 12 - } - - stan_control <- list(max_treedepth = max_treedepth, - adapt_delta = adapt_delta) - - if(algorithm == 'sampling'){ - if(parallel){ - fit1 <- rstan::sampling(stan_mod, - iter = samples, - warmup = burnin, - chains = chains, - data = model_data, - cores = min(c(chains, parallel::detectCores() - 1)), - init = inits, - verbose = FALSE, - thin = thin, - control = stan_control, - pars = NA, - refresh = 100, - save_warmup = FALSE, - ...) - } else { - fit1 <- rstan::sampling(stan_mod, - iter = samples, - warmup = burnin, - chains = chains, - data = model_data, - cores = 1, - init = inits, - verbose = FALSE, - thin = thin, - control = stan_control, - pars = NA, - refresh = 100, - save_warmup = FALSE, - ...) - } - } - - if(algorithm %in% c('fullrank', 'meanfield')){ - param <- param[!param %in% 'lp__'] - fit1 <- rstan::vb(stan_mod, - output_samples = samples, - data = model_data, - algorithm = algorithm, - pars = NA, - ...) - } - - out_gam_mod <- fit1 - out_gam_mod <- repair_stanfit(out_gam_mod) + out_gam_mod <- .sample_model_rstan(model = vectorised$model_file, + algorithm = algorithm, + prior_simulation = prior_simulation, + data = model_data, + inits = inits, + chains = chains, + parallel = parallel, + silent = silent, + max_treedepth = max_treedepth, + adapt_delta = adapt_delta, + threads = threads, + burnin = burnin, + samples = samples, + thin = thin, + ...) } } diff --git a/R/stan_utils.R b/R/stan_utils.R index d037f27d..befb9208 100644 --- a/R/stan_utils.R +++ b/R/stan_utils.R @@ -35,423 +35,6 @@ remove_likelihood = function(model_file){ model_file[-(start_remove:end_remove)] } -#' @noRd -.autoformat <- function(stan_file, overwrite_file = TRUE, - backend = 'cmdstanr'){ - - # No need to fill lv_coefs in each iteration if this is a - # trend_formula model - if(any(grepl('lv_coefs = Z;', - stan_file, fixed = TRUE)) & - !any(grepl('vector[n_lv] LV[n];', - stan_file, fixed = TRUE))){ - stan_file <- stan_file[-grep('lv_coefs = Z;', - stan_file, fixed = TRUE)] - stan_file <- stan_file[-grep('matrix[n_series, n_lv] lv_coefs;', - stan_file, fixed = TRUE)] - stan_file[grep('trend[i, s] = dot_product(lv_coefs[s,], LV[i,]);', - stan_file, fixed = TRUE)] <- - 'trend[i, s] = dot_product(Z[s,], LV[i,]);' - - stan_file[grep('// posterior predictions', - stan_file, fixed = TRUE)-1] <- - paste0(stan_file[grep('// posterior predictions', - stan_file, fixed = TRUE)-1], - '\n', - 'matrix[n_series, n_lv] lv_coefs = Z;') - stan_file <- readLines(textConnection(stan_file), n = -1) - } - - if(backend == 'rstan' & rstan::stan_version() < '2.29.0'){ - # normal_id_glm became available in 2.29.0; this needs to be replaced - # with the older non-glm version - if(any(grepl('normal_id_glm', - stan_file, fixed = TRUE))){ - if(any(grepl("flat_ys ~ normal_id_glm(flat_xs,", - stan_file, fixed = TRUE))){ - start <- grep("flat_ys ~ normal_id_glm(flat_xs,", - stan_file, fixed = TRUE) - end <- start + 2 - stan_file <- stan_file[-c((start + 1):(start + 2))] - stan_file[start] <- 'flat_ys ~ normal(flat_xs * b, flat_sigma_obs);' - } - } - } - - # Old ways of specifying arrays have been converted to errors in - # the latest version of Cmdstan (2.32.0); this coincides with - # a decision to stop automatically replacing these deprecations with - # the canonicalizer, so we have no choice but to replace the old - # syntax with this ugly bit of code - - # rstan dependency in Description should mean that updates should - # always happen (mvgam depends on rstan >= 2.29.0) - update_code <- TRUE - - # Tougher if using cmdstanr - if(backend == 'cmdstanr'){ - if(cmdstanr::cmdstan_version() < "2.32.0"){ - # If the autoformat options from cmdstanr are available, - # make use of them to update any deprecated array syntax - update_code <- FALSE - } - } - - if(update_code){ - # Data modifications - stan_file[grep("int ytimes[n, n_series]; // time-ordered matrix (which col in X belongs to each [time, series] observation?)", - stan_file, fixed = TRUE)] <- - 'array[n, n_series] int ytimes; // time-ordered matrix (which col in X belongs to each [time, series] observation?)' - - stan_file[grep("int flat_ys[n_nonmissing]; // flattened nonmissing observations", - stan_file, fixed = TRUE)] <- - 'array[n_nonmissing] int flat_ys; // flattened nonmissing observations' - - stan_file[grep("int obs_ind[n_nonmissing]; // indices of nonmissing observations", - stan_file, fixed = TRUE)] <- - "array[n_nonmissing] int obs_ind; // indices of nonmissing observations" - - if(any(grepl('int ytimes_trend[n, n_lv]; // time-ordered matrix for latent states', - stan_file, fixed = TRUE))){ - stan_file[grep("int ytimes_trend[n, n_lv]; // time-ordered matrix for latent states", - stan_file, fixed = TRUE)] <- - "array[n, n_lv] int ytimes_trend;" - } - - if(any(grepl('int idx', stan_file) & - grepl('// discontiguous index values', - stan_file, fixed = TRUE))){ - lines_replace <- which(grepl('int idx', stan_file) & - grepl('// discontiguous index values', - stan_file, fixed = TRUE)) - for(i in lines_replace){ - split_line <- strsplit(stan_file[i], ' ')[[1]] - - idxnum <- gsub(';', '', - gsub("\\s*\\[[^\\]+\\]", "", - as.character(split_line[2]))) - idx_length <- gsub("\\]", "", gsub("\\[", "", - regmatches(split_line[2], - gregexpr("\\[.*?\\]", split_line[2]))[[1]])) - - stan_file[i] <- - paste0('array[', - idx_length, - '] int ', - idxnum, - '; // discontiguous index values') - } - } - - if(any(grepl('int cap[total_obs]; // upper limits of latent abundances', - stan_file, fixed = TRUE))){ - stan_file[grep('int cap[total_obs]; // upper limits of latent abundances', - stan_file, fixed = TRUE)] <- - 'array[total_obs] int cap; // upper limits of latent abundances' - - stan_file[grep('int flat_caps[n_nonmissing];', - stan_file, fixed = TRUE)] <- - 'array[n_nonmissing] int flat_caps;' - } - - # Model modifications - if(any(grepl('real flat_phis[n_nonmissing];', - stan_file, fixed = TRUE))){ - stan_file[grep("real flat_phis[n_nonmissing];", - stan_file, fixed = TRUE)] <- - "array[n_nonmissing] real flat_phis;" - } - - # n-mixture modifications - if(any(grepl('real p_ub = poisson_cdf(max_k, lambda);', - stan_file, fixed = TRUE))){ - stan_file[grep('real p_ub = poisson_cdf(max_k, lambda);', - stan_file, fixed = TRUE)] <- - 'real p_ub = poisson_cdf(max_k | lambda);' - } - - # trend_formula modifications - if(any(grepl('int trend_rand_idx', stan_file) & - grepl('// trend random effect indices', - stan_file, fixed = TRUE))){ - lines_replace <- which(grepl('int trend_rand_idx', stan_file) & - grepl('// trend random effect indices', - stan_file, fixed = TRUE)) - for(i in lines_replace){ - split_line <- strsplit(stan_file[i], ' ')[[1]] - - trend_idxnum <- gsub(';', '', - gsub("\\s*\\[[^\\]+\\]", "", - as.character(split_line[2]))) - idx_length <- gsub("\\]", "", gsub("\\[", "", - regmatches(split_line[2], - gregexpr("\\[.*?\\]", split_line[2]))[[1]])) - - stan_file[i] <- - paste0('array[', - idx_length, - '] int ', - trend_idxnum, - '; // trend random effect indices') - } - } - - if(any(grepl('int trend_idx', stan_file) & - grepl('// discontiguous index values', - stan_file, fixed = TRUE))){ - lines_replace <- which(grepl('int trend_idx', stan_file) & - grepl('// discontiguous index values', - stan_file, fixed = TRUE)) - for(i in lines_replace){ - split_line <- strsplit(stan_file[i], ' ')[[1]] - - trend_idxnum <- gsub(';', '', - gsub("\\s*\\[[^\\]+\\]", "", - as.character(split_line[2]))) - idx_length <- gsub("\\]", "", gsub("\\[", "", - regmatches(split_line[2], - gregexpr("\\[.*?\\]", split_line[2]))[[1]])) - - stan_file[i] <- - paste0('array[', - idx_length, - '] int ', - trend_idxnum, - '; // discontiguous index values') - } - } - - if(any(grepl('vector[n_series] trend_raw[n];', - stan_file, fixed = TRUE))){ - stan_file[grep("vector[n_series] trend_raw[n];", - stan_file, fixed = TRUE)] <- - "array[n] vector[n_series] trend_raw;" - } - - if(any(grepl('vector[n_lv] error[n];', - stan_file, fixed = TRUE))){ - stan_file[grep("vector[n_lv] error[n];", - stan_file, fixed = TRUE)] <- - "array[n] vector[n_lv] error;" - } - - if(any(grepl('vector[n_series] error[n];', - stan_file, fixed = TRUE))){ - stan_file[grep("vector[n_series] error[n];", - stan_file, fixed = TRUE)] <- - "array[n] vector[n_series] error;" - } - - if(any(grepl('vector[n_lv] LV[n];', - stan_file, fixed = TRUE))){ - stan_file[grep("vector[n_lv] LV[n];", - stan_file, fixed = TRUE)] <- - "array[n] vector[n_lv] LV;" - } - - if(any(grepl('vector[n_series] mu[n - 1];', - stan_file, fixed = TRUE))){ - stan_file[grep("vector[n_series] mu[n - 1];", - stan_file, fixed = TRUE)] <- - "array[n - 1] vector[n_series] mu;" - } - - if(any(grepl('vector[n_lv] mu[n - 1];', - stan_file, fixed = TRUE))){ - stan_file[grep("vector[n_lv] mu[n - 1];", - stan_file, fixed = TRUE)] <- - "array[n - 1] vector[n_lv] mu;" - } - - if(any(grepl('vector[n_series] mu[n];', - stan_file, fixed = TRUE))){ - stan_file[grep("vector[n_series] mu[n];", - stan_file, fixed = TRUE)] <- - "array[n] vector[n_series] mu;" - } - - if(any(grepl('vector[n_lv] mu[n];', - stan_file, fixed = TRUE))){ - stan_file[grep("vector[n_lv] mu[n];", - stan_file, fixed = TRUE)] <- - "array[n] vector[n_lv] mu;" - } - # Generated quantity modifications - if(any(grepl('real ypred[n, n_series];', - stan_file, fixed = TRUE))){ - stan_file[grep("real ypred[n, n_series];", - stan_file, fixed = TRUE)] <- - "array[n, n_series] real ypred;" - } - - if(any(grepl('real ypred[n, n_series];', - stan_file, fixed = TRUE))){ - stan_file[grep("real ypred[n, n_series];", - stan_file, fixed = TRUE)] <- - "array[n, n_series] real ypred;" - } - - # ARMA model modifications - if(any(grepl('vector[n_series] epsilon[n];', - stan_file, fixed = TRUE))){ - stan_file[grep("vector[n_series] epsilon[n];", - stan_file, fixed = TRUE)] <- - "array[n] vector[n_series] epsilon;" - } - - if(any(grepl('vector[n_lv] epsilon[n];', - stan_file, fixed = TRUE))){ - stan_file[grep("vector[n_lv] epsilon[n];", - stan_file, fixed = TRUE)] <- - "array[n] vector[n_lv] epsilon;" - } - - # VARMA model modifications - if(any(grepl('matrix[n_series, n_series] P[1];', - stan_file, fixed = TRUE))){ - stan_file[grep("matrix[n_series, n_series] P[1];", - stan_file, fixed = TRUE)] <- - "array[1] matrix[n_series, n_series] P;" - - stan_file[grep("matrix[n_series, n_series] phiGamma[2, 1];", - stan_file, fixed = TRUE)] <- - "array[2, 1] matrix[n_series, n_series] phiGamma;" - } - - if(any(grepl('matrix initial_joint_var(matrix Sigma, matrix[] phi, matrix[] theta) {', - stan_file, fixed = TRUE))){ - stan_file[grep("matrix initial_joint_var(matrix Sigma, matrix[] phi, matrix[] theta) {", - stan_file, fixed = TRUE)] <- - "matrix initial_joint_var(matrix Sigma, array[] matrix phi, array[] matrix theta) {" - } - - if(any(grepl('matrix[n_lv, n_lv] P[1];', - stan_file, fixed = TRUE))){ - stan_file[grep("matrix[n_lv, n_lv] P[1];", - stan_file, fixed = TRUE)] <- - "array[1] matrix[n_lv, n_lv] P;" - - stan_file[grep("matrix[n_lv, n_lv] R[1];", - stan_file, fixed = TRUE)] <- - "array[1] matrix[n_lv, n_lv] R;" - - stan_file[grep("matrix[n_lv, n_lv] A_init[1];", - stan_file, fixed = TRUE)] <- - "array[1] matrix[n_lv, n_lv] A_init;" - - stan_file[grep("matrix[n_lv, n_lv] theta_init[1];", - stan_file, fixed = TRUE)] <- - "array[1] matrix[n_lv, n_lv] theta_init;" - } - - if(any(grepl('matrix[n_series, n_series] R[1];', - stan_file, fixed = TRUE))){ - - stan_file[grep("matrix[n_series, n_series] R[1];", - stan_file, fixed = TRUE)] <- - "array[1] matrix[n_series, n_series] R;" - - stan_file[grep("matrix[n_series, n_series] A_init[1];", - stan_file, fixed = TRUE)] <- - "array[1] matrix[n_series, n_series] A_init;" - - stan_file[grep("matrix[n_series, n_series] theta_init[1];", - stan_file, fixed = TRUE)] <- - "array[1] matrix[n_series, n_series] theta_init;" - } - - if(any(grepl('matrix[] rev_mapping(matrix[] P, matrix Sigma) {', - stan_file, fixed = TRUE))){ - stan_file[grep("matrix[] rev_mapping(matrix[] P, matrix Sigma) {", - stan_file, fixed = TRUE)] <- - "array[] matrix rev_mapping(array[] matrix P, matrix Sigma) {" - - stan_file[grep("matrix[m, m] phi_for[p, p]; matrix[m, m] phi_rev[p, p];", - stan_file, fixed = TRUE)] <- - 'array[p, p] matrix[m, m] phi_for; array[p, p] matrix[m, m] phi_rev;' - - stan_file[grep("matrix[m, m] Sigma_for[p+1]; matrix[m, m] Sigma_rev[p+1];", - stan_file, fixed = TRUE)] <- - 'array[p+1] matrix[m, m] Sigma_for; array[p+1] matrix[m, m] Sigma_rev;' - - stan_file[grep("matrix[m, m] S_for_list[p+1];", - stan_file, fixed = TRUE)] <- - 'array[p+1] matrix[m, m] S_for_list;' - } - - # VAR model modifications - if(any(grepl('matrix[n_lv, n_lv] phiGamma[2, 1];', - stan_file, fixed = TRUE))){ - stan_file[grep('matrix[n_lv, n_lv] phiGamma[2, 1];', - stan_file, fixed = TRUE)] <- - 'array[2, 1] matrix[n_lv, n_lv] phiGamma;' - } - - if(any(grepl('matrix[,] rev_mapping(matrix[] P, matrix Sigma) {', - stan_file, fixed = TRUE))){ - stan_file[grep("matrix[,] rev_mapping(matrix[] P, matrix Sigma) {", - stan_file, fixed = TRUE)] <- - "array[,] matrix rev_mapping(array[] matrix P, matrix Sigma) {" - - stan_file[grep("matrix[m, m] phi_for[p, p]; matrix[m, m] phi_rev[p, p];", - stan_file, fixed = TRUE)] <- - 'array[p, p] matrix[m, m] phi_for; array[p, p] matrix[m, m] phi_rev;' - - stan_file[grep("matrix[m, m] Sigma_for[p+1]; matrix[m, m] Sigma_rev[p+1];", - stan_file, fixed = TRUE)] <- - 'array[p+1] matrix[m, m] Sigma_for; array[p+1] matrix[m, m] Sigma_rev;' - - stan_file[grep("matrix[m, m] S_for_list[p+1];", - stan_file, fixed = TRUE)] <- - 'array[p+1] matrix[m, m] S_for_list;' - - stan_file[grep("matrix[m, m] Gamma_trans[p+1];", - stan_file, fixed = TRUE)] <- - 'array[p+1] matrix[m, m] Gamma_trans;' - - stan_file[grep("matrix[m, m] phiGamma[2, p];", - stan_file, fixed = TRUE)] <- - 'array[2, p] matrix[m, m] phiGamma;' - } - - if(any(grepl("real partial_log_lik(int[] seq, int start, int end,", - stan_file, fixed = TRUE))){ - stan_file[grepl("real partial_log_lik(int[] seq, int start, int end,", - stan_file, fixed = TRUE)] <- - "real partial_log_lik(array[] int seq, int start, int end," - } - - if(any(grepl("data vector Y, vector mu, real[] shape) {", - stan_file, fixed = TRUE))){ - stan_file[grepl("data vector Y, vector mu, real[] shape) {" , - stan_file, fixed = TRUE)] <- - "data vector Y, vector mu, array[] real shape) {" - } - - if(any(grepl("int seq[n_nonmissing]; // an integer sequence for reduce_sum slicing", - stan_file, fixed = TRUE))){ - stan_file[grepl("int seq[n_nonmissing]; // an integer sequence for reduce_sum slicing", - stan_file, fixed = TRUE)] <- - "array[n_nonmissing] int seq; // an integer sequence for reduce_sum slicing" - } - } - - if(backend == 'rstan'){ - options(stanc.allow_optimizations = TRUE, - stanc.auto_format = TRUE) - out <- rstan::stanc(model_code = stan_file)$model_code - } else { - stan_file <- cmdstanr::write_stan_file(stan_file) - cmdstan_mod <- cmdstanr::cmdstan_model(stan_file, compile = FALSE) - out <- utils::capture.output( - cmdstan_mod$format( - max_line_length = 80, - canonicalize = TRUE, - overwrite_file = overwrite_file, backup = FALSE)) - out <- paste0(out, collapse = "\n") - } - return(out) -} #### Replacement for MCMCvis functions to remove dependence on rstan for working # with stanfit objects #### @@ -3703,381 +3286,6 @@ add_trend_predictors = function(trend_formula, trend_random_included = trend_random_included)) } -#### Helper functions for extracting parameter estimates from cmdstan objects #### -#' All functions were directly copied from `brms` and so all credit must -#' go to the `brms` development team -#' @noRd -repair_variable_names <- function(x) { - x <- sub("\\.", "[", x) - x <- gsub("\\.", ",", x) - x[grep("\\[", x)] <- paste0(x[grep("\\[", x)], "]") - x -} - -#' @noRd -seq_rows = function (x) -{ - seq_len(NROW(x)) -} - -#' @noRd -is_equal <- function(x, y, check.attributes = FALSE, ...) { - isTRUE(all.equal(x, y, check.attributes = check.attributes, ...)) -} - -#' @noRd -repair_stanfit <- function(x) { - if (!length(x@sim$fnames_oi)) { - # nothing to rename - return(x) - } - # the posterior package cannot deal with non-unique parameter names - # this case happens rarely but might happen when sample_prior = "yes" - x@sim$fnames_oi <- make.unique(as.character(x@sim$fnames_oi), "__") - for (i in seq_along(x@sim$samples)) { - # stanfit may have renamed dimension suffixes (#1218) - if (length(x@sim$samples[[i]]) == length(x@sim$fnames_oi)) { - names(x@sim$samples[[i]]) <- x@sim$fnames_oi - } - } - x -} - -#' @importFrom methods new -#' @noRd -read_csv_as_stanfit <- function(files, variables = NULL, - sampler_diagnostics = NULL) { - - # Code borrowed from brms: https://github.com/paul-buerkner/brms/R/backends.R#L603 - repair_names <- function(x) { - x <- sub("\\.", "[", x) - x <- gsub("\\.", ",", x) - x[grep("\\[", x)] <- paste0(x[grep("\\[", x)], "]") - x - } - - if(!is.null(variables)){ - # ensure that only relevant variables are read from CSV - metadata <- cmdstanr::read_cmdstan_csv( - files = files, variables = "", sampler_diagnostics = "") - - all_vars <- repair_names(metadata$metadata$variables) - all_vars <- unique(sub("\\[.+", "", all_vars)) - variables <- variables[variables %in% all_vars] - } - - csfit <- cmdstanr::read_cmdstan_csv( - files = files, variables = variables, - sampler_diagnostics = sampler_diagnostics, - format = NULL - ) - - # @model_name - model_name = gsub(".csv", "", basename(files[[1]])) - - # @model_pars - svars <- csfit$metadata$stan_variables - if (!is.null(variables)) { - variables_main <- unique(gsub("\\[.*\\]", "", variables)) - svars <- intersect(variables_main, svars) - } - if ("lp__" %in% svars) { - svars <- c(setdiff(svars, "lp__"), "lp__") - } - pars_oi <- svars - par_names <- csfit$metadata$model_params - - # @par_dims - par_dims <- vector("list", length(svars)) - - names(par_dims) <- svars - par_dims <- lapply(par_dims, function(x) x <- integer(0)) - - pdims_num <- ulapply( - svars, function(x) sum(grepl(paste0("^", x, "\\[.*\\]$"), par_names)) - ) - par_dims[pdims_num != 0] <- - csfit$metadata$stan_variable_sizes[svars][pdims_num != 0] - - # @mode - mode <- 0L - - # @sim - rstan_diagn_order <- c("accept_stat__", "treedepth__", "stepsize__", - "divergent__", "n_leapfrog__", "energy__") - - if (!is.null(sampler_diagnostics)) { - rstan_diagn_order <- rstan_diagn_order[rstan_diagn_order %in% sampler_diagnostics] - } - - res_vars <- c(".chain", ".iteration", ".draw") - if ("post_warmup_draws" %in% names(csfit)) { - # for MCMC samplers - n_chains <- max( - posterior::nchains(csfit$warmup_draws), - posterior::nchains(csfit$post_warmup_draws) - ) - n_iter_warmup <- posterior::niterations(csfit$warmup_draws) - n_iter_sample <- posterior::niterations(csfit$post_warmup_draws) - if (n_iter_warmup > 0) { - csfit$warmup_draws <- posterior::as_draws_df(csfit$warmup_draws) - csfit$warmup_sampler_diagnostics <- - posterior::as_draws_df(csfit$warmup_sampler_diagnostics) - } - if (n_iter_sample > 0) { - csfit$post_warmup_draws <- posterior::as_draws_df(csfit$post_warmup_draws) - csfit$post_warmup_sampler_diagnostics <- - posterior::as_draws_df(csfit$post_warmup_sampler_diagnostics) - } - - # called 'samples' for consistency with rstan - samples <- rbind(csfit$warmup_draws, csfit$post_warmup_draws) - # manage memory - csfit$warmup_draws <- NULL - csfit$post_warmup_draws <- NULL - - # prepare sampler diagnostics - diagnostics <- rbind(csfit$warmup_sampler_diagnostics, - csfit$post_warmup_sampler_diagnostics) - # manage memory - csfit$warmup_sampler_diagnostics <- NULL - csfit$post_warmup_sampler_diagnostics <- NULL - # convert to regular data.frame - diagnostics <- as.data.frame(diagnostics) - diag_chain_ids <- diagnostics$.chain - diagnostics[res_vars] <- NULL - - } else if ("draws" %in% names(csfit)) { - # for variational inference "samplers" - n_chains <- 1 - n_iter_warmup <- 0 - n_iter_sample <- posterior::niterations(csfit$draws) - if (n_iter_sample > 0) { - csfit$draws <- posterior::as_draws_df(csfit$draws) - } - - # called 'samples' for consistency with rstan - samples <- csfit$draws - # manage memory - csfit$draws <- NULL - - # VI has no sampler diagnostics - diag_chain_ids <- rep(1L, nrow(samples)) - diagnostics <- as.data.frame(matrix(nrow = nrow(samples), ncol = 0)) - } - - # convert to regular data.frame - samples <- as.data.frame(samples) - chain_ids <- samples$.chain - samples[res_vars] <- NULL - - move2end <- function(x, last) { - x[c(setdiff(names(x), last), last)] - } - - if ("lp__" %in% colnames(samples)) { - samples <- move2end(samples, "lp__") - } - - fnames_oi <- colnames(samples) - - colnames(samples) <- gsub("\\[", ".", colnames(samples)) - colnames(samples) <- gsub("\\]", "", colnames(samples)) - colnames(samples) <- gsub("\\,", ".", colnames(samples)) - - # split samples into chains - samples <- split(samples, chain_ids) - names(samples) <- NULL - - # split diagnostics into chains - diagnostics <- split(diagnostics, diag_chain_ids) - names(diagnostics) <- NULL - - # @sim$sample: largely 113-130 from rstan::read_stan_csv - values <- list() - values$algorithm <- csfit$metadata$algorithm - values$engine <- csfit$metadata$engine - values$metric <- csfit$metadata$metric - - sampler_t <- NULL - if (!is.null(values$algorithm)) { - if (values$algorithm == "rwm" || values$algorithm == "Metropolis") { - sampler_t <- "Metropolis" - } else if (values$algorithm == "hmc") { - if (values$engine == "static") { - sampler_t <- "HMC" - } else { - if (values$metric == "unit_e") { - sampler_t <- "NUTS(unit_e)" - } else if (values$metric == "diag_e") { - sampler_t <- "NUTS(diag_e)" - } else if (values$metric == "dense_e") { - sampler_t <- "NUTS(dense_e)" - } - } - } - } - - adapt_info <- vector("list", 4) - idx_samples <- (n_iter_warmup + 1):(n_iter_warmup + n_iter_sample) - - for (i in seq_along(samples)) { - m <- colMeans(samples[[i]][idx_samples, , drop=FALSE]) - rownames(samples[[i]]) <- seq_rows(samples[[i]]) - attr(samples[[i]], "sampler_params") <- diagnostics[[i]][rstan_diagn_order] - rownames(attr(samples[[i]], "sampler_params")) <- seq_rows(diagnostics[[i]]) - - # reformat back to text - if (is_equal(sampler_t, "NUTS(dense_e)")) { - mmatrix_txt <- "\n# Elements of inverse mass matrix:\n# " - mmat <- paste0(apply(csfit$inv_metric[[i]], 1, paste0, collapse=", "), - collapse="\n# ") - } else { - mmatrix_txt <- "\n# Diagonal elements of inverse mass matrix:\n# " - mmat <- paste0(csfit$inv_metric[[i]], collapse = ", ") - } - - adapt_info[[i]] <- paste0("# Step size = ", - csfit$step_size[[i]], - mmatrix_txt, - mmat, "\n# ") - - attr(samples[[i]], "adaptation_info") <- adapt_info[[i]] - - attr(samples[[i]], "args") <- list(sampler_t = sampler_t, chain_id = i) - - if (NROW(csfit$metadata$time)) { - time_i <- as.double(csfit$metadata$time[i, c("warmup", "sampling")]) - names(time_i) <- c("warmup", "sample") - attr(samples[[i]], "elapsed_time") <- time_i - } - - attr(samples[[i]], "mean_pars") <- m[-length(m)] - attr(samples[[i]], "mean_lp__") <- m["lp__"] - } - - perm_lst <- lapply(seq_len(n_chains), function(id) sample.int(n_iter_sample)) - - # @sim - sim <- list( - samples = samples, - iter = csfit$metadata$iter_sampling + csfit$metadata$iter_warmup, - thin = csfit$metadata$thin, - warmup = csfit$metadata$iter_warmup, - chains = n_chains, - n_save = rep(n_iter_sample + n_iter_warmup, n_chains), - warmup2 = rep(n_iter_warmup, n_chains), - permutation = perm_lst, - pars_oi = pars_oi, - dims_oi = par_dims, - fnames_oi = fnames_oi, - n_flatnames = length(fnames_oi) - ) - - # @stan_args - sargs <- list( - stan_version_major = as.character(csfit$metadata$stan_version_major), - stan_version_minor = as.character(csfit$metadata$stan_version_minor), - stan_version_patch = as.character(csfit$metadata$stan_version_patch), - model = csfit$metadata$model_name, - start_datetime = gsub(" ", "", csfit$metadata$start_datetime), - method = csfit$metadata$method, - iter = csfit$metadata$iter_sampling + csfit$metadata$iter_warmup, - warmup = csfit$metadata$iter_warmup, - save_warmup = csfit$metadata$save_warmup, - thin = csfit$metadata$thin, - engaged = as.character(csfit$metadata$adapt_engaged), - gamma = csfit$metadata$gamma, - delta = csfit$metadata$adapt_delta, - kappa = csfit$metadata$kappa, - t0 = csfit$metadata$t0, - init_buffer = as.character(csfit$metadata$init_buffer), - term_buffer = as.character(csfit$metadata$term_buffer), - window = as.character(csfit$metadata$window), - algorithm = csfit$metadata$algorithm, - engine = csfit$metadata$engine, - max_depth = csfit$metadata$max_treedepth, - metric = csfit$metadata$metric, - metric_file = character(0), # not stored in metadata - stepsize = NA, # add in loop - stepsize_jitter = csfit$metadata$stepsize_jitter, - num_chains = as.character(csfit$metadata$num_chains), - chain_id = NA, # add in loop - file = character(0), # not stored in metadata - init = NA, # add in loop - seed = as.character(csfit$metadata$seed), - file = NA, # add in loop - diagnostic_file = character(0), # not stored in metadata - refresh = as.character(csfit$metadata$refresh), - sig_figs = as.character(csfit$metadata$sig_figs), - profile_file = csfit$metadata$profile_file, - num_threads = as.character(csfit$metadata$threads_per_chain), - stanc_version = gsub(" ", "", csfit$metadata$stanc_version), - stancflags = character(0), # not stored in metadata - adaptation_info = NA, # add in loop - has_time = is.numeric(csfit$metadata$time$total), - time_info = NA, # add in loop - sampler_t = sampler_t - ) - - sargs_rep <- replicate(n_chains, sargs, simplify = FALSE) - - for (i in seq_along(sargs_rep)) { - sargs_rep[[i]]$chain_id <- i - sargs_rep[[i]]$stepsize <- csfit$metadata$step_size[i] - sargs_rep[[i]]$init <- as.character(csfit$metadata$init[i]) - # two 'file' elements: select the second - file_idx <- which(names(sargs_rep[[i]]) == "file") - sargs_rep[[i]][[file_idx[2]]] <- files[[i]] - - sargs_rep[[i]]$adaptation_info <- adapt_info[[i]] - - if (NROW(csfit$metadata$time)) { - sargs_rep[[i]]$time_info <- paste0( - c("# Elapsed Time: ", "# ", "# ", "# "), - c(csfit$metadata$time[i, c("warmup", "sampling", "total")], ""), - c(" seconds (Warm-up)", " seconds (Sampling)", " seconds (Total)", "") - ) - } - } - - # @stanmodel - null_dso <- new( - "cxxdso", sig = list(character(0)), dso_saved = FALSE, - dso_filename = character(0), modulename = character(0), - system = R.version$system, cxxflags = character(0), - .CXXDSOMISC = new.env(parent = emptyenv()) - ) - null_sm <- new( - "stanmodel", model_name = model_name, model_code = character(0), - model_cpp = list(), dso = null_dso - ) - - # @date - sdate <- do.call(max, lapply(files, function(csv) file.info(csv)$mtime)) - sdate <- format(sdate, "%a %b %d %X %Y") - - new( - "stanfit", - model_name = model_name, - model_pars = svars, - par_dims = par_dims, - mode = mode, - sim = sim, - inits = list(), - stan_args = sargs_rep, - stanmodel = null_sm, - date = sdate, # not the time of sampling - .MISC = new.env(parent = emptyenv()) - ) -} - -#' @noRd -ulapply <- function(X, FUN, ..., recursive = TRUE, use.names = TRUE) { - unlist(lapply(X, FUN, ...), recursive, use.names) -} - - #### Stan diagnostic checks #### #' Check transitions that ended with a divergence #' @param fit A stanfit object @@ -4240,3 +3448,66 @@ check_all_diagnostics <- function(fit, max_treedepth = 10) { sampler_params = sampler_params) check_energy(fit, sampler_params = sampler_params) } + +#' @noRd +is_try_error = function (x) { + inherits(x, "try-error") +} + +#' evaluate an expression without printing output or messages +#' @param expr expression to be evaluated +#' @param type type of output to be suppressed (see ?sink) +#' @param try wrap evaluation of expr in 'try' and +#' not suppress outputs if evaluation fails? +#' @param silent actually evaluate silently? +#' @noRd +eval_silent <- function(expr, type = "output", try = FALSE, + silent = TRUE, ...) { + try <- as_one_logical(try) + silent <- as_one_logical(silent) + type <- match.arg(type, c("output", "message")) + expr <- substitute(expr) + envir <- parent.frame() + if (silent) { + if (try && type == "message") { + try_out <- try(utils::capture.output( + out <- eval(expr, envir), type = type, ... + )) + if (is_try_error(try_out)) { + # try again without suppressing error messages + out <- eval(expr, envir) + } + } else { + utils::capture.output(out <- eval(expr, envir), type = type, ...) + } + } else { + out <- eval(expr, envir) + } + out +} + + +#' @noRd +nlist = function (...) { + m <- match.call() + dots <- list(...) + no_names <- is.null(names(dots)) + has_name <- if (no_names) + FALSE + else nzchar(names(dots)) + if (all(has_name)) + return(dots) + nms <- as.character(m)[-1] + if (no_names) { + names(dots) <- nms + } + else { + names(dots)[!has_name] <- nms[!has_name] + } + dots +} + +#' @noRd +`c<-` = function (x, value) { + c(x, value) +} diff --git a/R/validations.R b/R/validations.R index b78acd5a..8096a341 100644 --- a/R/validations.R +++ b/R/validations.R @@ -113,6 +113,48 @@ validate_series_time = function(data, name = 'data', return(data) } +#'@noRd +as_one_logical = function (x, allow_na = FALSE) { + s <- substitute(x) + x <- as.logical(x) + if (length(x) != 1L || anyNA(x) && !allow_na) { + s <- deparse0(s, max_char = 100L) + stop("Cannot coerce '", s, "' to a single logical value.", + call. = FALSE) + } + x +} + +#'@noRd +as_one_integer <- function(x, allow_na = FALSE) { + s <- substitute(x) + x <- suppressWarnings(as.integer(x)) + if (length(x) != 1L || anyNA(x) && !allow_na) { + s <- deparse0(s, max_char = 100L) + stop("Cannot coerce '", s, "' to a single integer value.", + call. = FALSE) + } + x +} + +#'@noRd +deparse0 = function (x, max_char = NULL, ...) { + out <- collapse(deparse(x, ...)) + if (isTRUE(max_char > 0)) { + out <- substr(out, 1L, max_char) + } + out +} + +#'@noRd +validate_silent <- function(silent) { + silent <- as_one_integer(silent) + if (silent < 0 || silent > 2) { + stop2("'silent' must be between 0 and 2.") + } + silent +} + #'@importFrom rlang warn #'@noRd validate_family = function(family, use_stan = TRUE){ diff --git a/man/mvgam.Rd b/man/mvgam.Rd index 7ff3a9a8..7090a48f 100644 --- a/man/mvgam.Rd +++ b/man/mvgam.Rd @@ -38,8 +38,9 @@ mvgam( algorithm = getOption("brms.algorithm", "sampling"), autoformat = TRUE, save_all_pars = FALSE, - max_treedepth, - adapt_delta, + max_treedepth = 12, + adapt_delta = 0.85, + silent = 1, jags_path, ... ) @@ -267,6 +268,12 @@ The step size used by the numerical integrator is a function of \code{adapt_delt \code{adapt_delta} will result in a smaller step size and fewer divergences. Increasing \code{adapt_delta} will typically result in a slower sampler, but it will always lead to a more robust sampler} +\item{silent}{Verbosity level between \code{0} and \code{2}. If \code{1} (the default), most of the informational +messages of compiler and sampler are suppressed. If \code{2}, even more messages are suppressed. The +actual sampling progress is still printed. Set \code{refresh = 0} to turn this off as well. If using +\code{backend = "rstan"} you can also set open_progress = FALSE to prevent opening additional +progress bars.} + \item{jags_path}{Optional character vector specifying the path to the location of the \code{JAGS} executable (.exe) to use for modelling if \code{use_stan == FALSE}. If missing, the path will be recovered from a call to \code{\link[runjags]{findjags}}} diff --git a/man/mvgam_marginaleffects.Rd b/man/mvgam_marginaleffects.Rd index 5564cfdb..16e6648e 100644 --- a/man/mvgam_marginaleffects.Rd +++ b/man/mvgam_marginaleffects.Rd @@ -86,6 +86,8 @@ arguments.} \item \code{newdata = datagrid(cyl = c(4, 6))}: \code{cyl} variable equal to 4 and 6 and other regressors fixed at their means or modes. \item See the Examples section and the \code{\link[marginaleffects:datagrid]{datagrid()}} documentation. } +\item \code{\link[=subset]{subset()}} call with a single argument to select a subset of the dataset used to fit the model, ex: \code{newdata = subset(treatment == 1)} +\item \code{\link[dplyr:filter]{dplyr::filter()}} call with a single argument to select a subset of the dataset used to fit the model, ex: \code{newdata = filter(treatment == 1)} \item string: \itemize{ \item "mean": Marginal Effects at the Mean. Slopes when each predictor is held at its mean or mode. diff --git a/src/RcppExports.o b/src/RcppExports.o index 03b01119..69942520 100644 Binary files a/src/RcppExports.o and b/src/RcppExports.o differ diff --git a/src/mvgam.dll b/src/mvgam.dll index 0e5aad8f..85ad6dee 100644 Binary files a/src/mvgam.dll and b/src/mvgam.dll differ diff --git a/src/trend_funs.o b/src/trend_funs.o index 1c0d5a8f..5cb98e5c 100644 Binary files a/src/trend_funs.o and b/src/trend_funs.o differ diff --git a/tests/testthat/Rplots.pdf b/tests/testthat/Rplots.pdf index 318ee0c1..5d737044 100644 Binary files a/tests/testthat/Rplots.pdf and b/tests/testthat/Rplots.pdf differ diff --git a/tests/testthat/test-binomial.R b/tests/testthat/test-binomial.R index 5fd1b3e8..0183a0f1 100644 --- a/tests/testthat/test-binomial.R +++ b/tests/testthat/test-binomial.R @@ -87,7 +87,8 @@ test_that("binomial() post-processing works", { data = dat_train, burnin = 500, samples = 200, - chains = 2)) + chains = 2, + silent = 2)) expect_no_error(capture_output(summary(mod))) expect_no_error(capture_output(code(mod))) expect_no_error(capture_output(print(mod))) @@ -158,7 +159,8 @@ test_that("binomial() post-processing works", { newdata = dat_test, burnin = 200, samples = 200, - chains = 2)) + chains = 2, + silent = 2)) fc <- forecast(mod) expect_true(inherits(fc, 'mvgam_forecast')) expect_no_error(plot_mvgam_uncertainty(mod)) @@ -281,7 +283,8 @@ test_that("bernoulli() post-processing works", { data = dat_train, burnin = 200, samples = 200, - chains = 2)) + chains = 2, + silent = 2)) expect_no_error(capture_output(summary(mod))) expect_no_error(capture_output(print(mod))) diff --git a/tests/testthat/test-nmixture.R b/tests/testthat/test-nmixture.R index dc76e4b5..fc80a80f 100644 --- a/tests/testthat/test-nmixture.R +++ b/tests/testthat/test-nmixture.R @@ -190,7 +190,8 @@ test_that("nmix() post-processing works", { prior(normal(1, 1.5), class = Intercept_trend)), samples = 300, residuals = FALSE, - chains = 2)) + chains = 2, + silent = 2)) expect_no_error(capture_output(summary(mod))) expect_no_error(capture_output(print(mod)))