diff --git a/DESCRIPTION b/DESCRIPTION index 40b3923b..343968ef 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -9,7 +9,7 @@ BugReports: https://github.com/nicholasjclark/mvgam/issues License: MIT + file LICENSE Depends: R (>= 3.6.0), - brms (>= 2.17) + brms (>= 2.21.0) Imports: methods, mgcv (>= 1.8-13), diff --git a/NAMESPACE b/NAMESPACE index 54cba82d..58ed9bd6 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -136,6 +136,7 @@ importFrom(brms,prior_string) importFrom(brms,pstudent_t) importFrom(brms,qstudent_t) importFrom(brms,rbeta_binomial) +importFrom(brms,read_csv_as_stanfit) importFrom(brms,rstudent_t) importFrom(brms,stancode) importFrom(brms,standata) @@ -176,7 +177,6 @@ importFrom(marginaleffects,get_vcov) importFrom(marginaleffects,plot_predictions) importFrom(marginaleffects,set_coef) importFrom(methods,cbind2) -importFrom(methods,new) importFrom(mgcv,Predict.matrix) importFrom(mgcv,Rrank) importFrom(mgcv,bam) diff --git a/NEWS.md b/NEWS.md index 8d338d38..cb8e9e17 100644 --- a/NEWS.md +++ b/NEWS.md @@ -10,6 +10,7 @@ ## Bug fixes * Added a new check to ensure that exception messages are only suppressed by the `silent` argument if the user's version of 'cmdstanr' is adequate +* Updated dependency for 'brms' to version >= '2.21.0' so that `read_csv_as_stanfit` can be imported, which should future-proof the conversion of 'cmdstanr' models to `stanfit` objects (#70) # mvgam 1.1.2 ## New functionalities diff --git a/R/backends.R b/R/backends.R index afa082f7..c1637a9e 100644 --- a/R/backends.R +++ b/R/backends.R @@ -60,6 +60,7 @@ } #' fit Stan model with cmdstanr using HMC sampling or variational inference +#' @importFrom brms read_csv_as_stanfit #' @param model a compiled Stan model #' @param data named list to be passed to Stan as data #' @return a fitted Stan model @@ -169,11 +170,35 @@ } # Convert model files to stan_fit class for consistency + repair_names <- function(x) { + x <- sub("\\.", "[", x) + x <- gsub("\\.", ",", x) + x[grep("\\[", x)] <- paste0(x[grep("\\[", x)], "]") + x + } + if(save_all_pars){ - out_gam_mod <- read_csv_as_stanfit(out$output_files()) + out_gam_mod <- brms::read_csv_as_stanfit(out$output_files(), + algorithm = algorithm) } else { - out_gam_mod <- read_csv_as_stanfit(out$output_files(), - variables = param) + # Exclude certain pars and transformed_pars that are never needed + # for mvgam post-processing + metadata <- cmdstanr::read_cmdstan_csv(files = out$output_files(), + variables = "", + sampler_diagnostics = "") + all_vars <- metadata$metadata$variables + out_gam_mod <- brms::read_csv_as_stanfit(out$output_files(), + variables = all_vars, + exclude = c('trend_raw', + 'b_raw', + 'eta', + 'phi_vec', + 'nu_vec', + 'sigma_obs_vec', + 'shape_vec', + 'phi_inv', + 'lv_coefs_raw'), + algorithm = algorithm) } out_gam_mod <- repair_stanfit(out_gam_mod) @@ -303,335 +328,6 @@ return(out) } -#' @importFrom methods new -#' @noRd -read_csv_as_stanfit <- function(files, variables = NULL, - sampler_diagnostics = NULL) { - - # Code borrowed from brms: https://github.com/paul-buerkner/brms/R/backends.R#L603 - repair_names <- function(x) { - x <- sub("\\.", "[", x) - x <- gsub("\\.", ",", x) - x[grep("\\[", x)] <- paste0(x[grep("\\[", x)], "]") - x - } - - if(!is.null(variables)){ - # ensure that only relevant variables are read from CSV - metadata <- cmdstanr::read_cmdstan_csv( - files = files, variables = "", sampler_diagnostics = "") - - all_vars <- repair_names(metadata$metadata$variables) - all_vars <- unique(sub("\\[.+", "", all_vars)) - variables <- variables[variables %in% all_vars] - } - - csfit <- cmdstanr::read_cmdstan_csv( - files = files, variables = variables, - sampler_diagnostics = sampler_diagnostics, - format = NULL - ) - - # @model_name - model_name = gsub(".csv", "", basename(files[[1]])) - - # @model_pars - svars <- csfit$metadata$stan_variables - if (!is.null(variables)) { - variables_main <- unique(gsub("\\[.*\\]", "", variables)) - svars <- intersect(variables_main, svars) - } - if ("lp__" %in% svars) { - svars <- c(setdiff(svars, "lp__"), "lp__") - } - pars_oi <- svars - par_names <- csfit$metadata$model_params - - # @par_dims - par_dims <- vector("list", length(svars)) - - names(par_dims) <- svars - par_dims <- lapply(par_dims, function(x) x <- integer(0)) - - pdims_num <- ulapply( - svars, function(x) sum(grepl(paste0("^", x, "\\[.*\\]$"), par_names)) - ) - par_dims[pdims_num != 0] <- - csfit$metadata$stan_variable_sizes[svars][pdims_num != 0] - - # @mode - mode <- 0L - - # @sim - rstan_diagn_order <- c("accept_stat__", "treedepth__", "stepsize__", - "divergent__", "n_leapfrog__", "energy__") - - if (!is.null(sampler_diagnostics)) { - rstan_diagn_order <- rstan_diagn_order[rstan_diagn_order %in% sampler_diagnostics] - } - - res_vars <- c(".chain", ".iteration", ".draw") - if ("post_warmup_draws" %in% names(csfit)) { - # for MCMC samplers - n_chains <- max( - posterior::nchains(csfit$warmup_draws), - posterior::nchains(csfit$post_warmup_draws) - ) - n_iter_warmup <- posterior::niterations(csfit$warmup_draws) - n_iter_sample <- posterior::niterations(csfit$post_warmup_draws) - if (n_iter_warmup > 0) { - csfit$warmup_draws <- posterior::as_draws_df(csfit$warmup_draws) - csfit$warmup_sampler_diagnostics <- - posterior::as_draws_df(csfit$warmup_sampler_diagnostics) - } - if (n_iter_sample > 0) { - csfit$post_warmup_draws <- posterior::as_draws_df(csfit$post_warmup_draws) - csfit$post_warmup_sampler_diagnostics <- - posterior::as_draws_df(csfit$post_warmup_sampler_diagnostics) - } - - # called 'samples' for consistency with rstan - samples <- rbind(csfit$warmup_draws, csfit$post_warmup_draws) - # manage memory - csfit$warmup_draws <- NULL - csfit$post_warmup_draws <- NULL - - # prepare sampler diagnostics - diagnostics <- rbind(csfit$warmup_sampler_diagnostics, - csfit$post_warmup_sampler_diagnostics) - # manage memory - csfit$warmup_sampler_diagnostics <- NULL - csfit$post_warmup_sampler_diagnostics <- NULL - # convert to regular data.frame - diagnostics <- as.data.frame(diagnostics) - diag_chain_ids <- diagnostics$.chain - diagnostics[res_vars] <- NULL - - } else if ("draws" %in% names(csfit)) { - # for variational inference "samplers" - n_chains <- 1 - n_iter_warmup <- 0 - n_iter_sample <- posterior::niterations(csfit$draws) - if (n_iter_sample > 0) { - csfit$draws <- posterior::as_draws_df(csfit$draws) - } - - # called 'samples' for consistency with rstan - samples <- csfit$draws - # manage memory - csfit$draws <- NULL - - # VI has no sampler diagnostics - diag_chain_ids <- rep(1L, nrow(samples)) - diagnostics <- as.data.frame(matrix(nrow = nrow(samples), ncol = 0)) - } - - # convert to regular data.frame - samples <- as.data.frame(samples) - chain_ids <- samples$.chain - samples[res_vars] <- NULL - - move2end <- function(x, last) { - x[c(setdiff(names(x), last), last)] - } - - if ("lp__" %in% colnames(samples)) { - samples <- move2end(samples, "lp__") - } - - fnames_oi <- colnames(samples) - - colnames(samples) <- gsub("\\[", ".", colnames(samples)) - colnames(samples) <- gsub("\\]", "", colnames(samples)) - colnames(samples) <- gsub("\\,", ".", colnames(samples)) - - # split samples into chains - samples <- split(samples, chain_ids) - names(samples) <- NULL - - # split diagnostics into chains - diagnostics <- split(diagnostics, diag_chain_ids) - names(diagnostics) <- NULL - - # @sim$sample: largely 113-130 from rstan::read_stan_csv - values <- list() - values$algorithm <- csfit$metadata$algorithm - values$engine <- csfit$metadata$engine - values$metric <- csfit$metadata$metric - - sampler_t <- NULL - if (!is.null(values$algorithm)) { - if (values$algorithm == "rwm" || values$algorithm == "Metropolis") { - sampler_t <- "Metropolis" - } else if (values$algorithm == "hmc") { - if (values$engine == "static") { - sampler_t <- "HMC" - } else { - if (values$metric == "unit_e") { - sampler_t <- "NUTS(unit_e)" - } else if (values$metric == "diag_e") { - sampler_t <- "NUTS(diag_e)" - } else if (values$metric == "dense_e") { - sampler_t <- "NUTS(dense_e)" - } - } - } - } - - adapt_info <- vector("list", 4) - idx_samples <- (n_iter_warmup + 1):(n_iter_warmup + n_iter_sample) - - for (i in seq_along(samples)) { - m <- colMeans(samples[[i]][idx_samples, , drop=FALSE]) - rownames(samples[[i]]) <- seq_rows(samples[[i]]) - attr(samples[[i]], "sampler_params") <- diagnostics[[i]][rstan_diagn_order] - rownames(attr(samples[[i]], "sampler_params")) <- seq_rows(diagnostics[[i]]) - - # reformat back to text - if (is_equal(sampler_t, "NUTS(dense_e)")) { - mmatrix_txt <- "\n# Elements of inverse mass matrix:\n# " - mmat <- paste0(apply(csfit$inv_metric[[i]], 1, paste0, collapse=", "), - collapse="\n# ") - } else { - mmatrix_txt <- "\n# Diagonal elements of inverse mass matrix:\n# " - mmat <- paste0(csfit$inv_metric[[i]], collapse = ", ") - } - - adapt_info[[i]] <- paste0("# Step size = ", - csfit$step_size[[i]], - mmatrix_txt, - mmat, "\n# ") - - attr(samples[[i]], "adaptation_info") <- adapt_info[[i]] - - attr(samples[[i]], "args") <- list(sampler_t = sampler_t, chain_id = i) - - if (NROW(csfit$metadata$time)) { - time_i <- as.double(csfit$metadata$time[i, c("warmup", "sampling")]) - names(time_i) <- c("warmup", "sample") - attr(samples[[i]], "elapsed_time") <- time_i - } - - attr(samples[[i]], "mean_pars") <- m[-length(m)] - attr(samples[[i]], "mean_lp__") <- m["lp__"] - } - - perm_lst <- lapply(seq_len(n_chains), function(id) sample.int(n_iter_sample)) - - # @sim - sim <- list( - samples = samples, - iter = csfit$metadata$iter_sampling + csfit$metadata$iter_warmup, - thin = csfit$metadata$thin, - warmup = csfit$metadata$iter_warmup, - chains = n_chains, - n_save = rep(n_iter_sample + n_iter_warmup, n_chains), - warmup2 = rep(n_iter_warmup, n_chains), - permutation = perm_lst, - pars_oi = pars_oi, - dims_oi = par_dims, - fnames_oi = fnames_oi, - n_flatnames = length(fnames_oi) - ) - - # @stan_args - sargs <- list( - stan_version_major = as.character(csfit$metadata$stan_version_major), - stan_version_minor = as.character(csfit$metadata$stan_version_minor), - stan_version_patch = as.character(csfit$metadata$stan_version_patch), - model = csfit$metadata$model_name, - start_datetime = gsub(" ", "", csfit$metadata$start_datetime), - method = csfit$metadata$method, - iter = csfit$metadata$iter_sampling + csfit$metadata$iter_warmup, - warmup = csfit$metadata$iter_warmup, - save_warmup = csfit$metadata$save_warmup, - thin = csfit$metadata$thin, - engaged = as.character(csfit$metadata$adapt_engaged), - gamma = csfit$metadata$gamma, - delta = csfit$metadata$adapt_delta, - kappa = csfit$metadata$kappa, - t0 = csfit$metadata$t0, - init_buffer = as.character(csfit$metadata$init_buffer), - term_buffer = as.character(csfit$metadata$term_buffer), - window = as.character(csfit$metadata$window), - algorithm = csfit$metadata$algorithm, - engine = csfit$metadata$engine, - max_depth = csfit$metadata$max_treedepth, - metric = csfit$metadata$metric, - metric_file = character(0), # not stored in metadata - stepsize = NA, # add in loop - stepsize_jitter = csfit$metadata$stepsize_jitter, - num_chains = as.character(csfit$metadata$num_chains), - chain_id = NA, # add in loop - file = character(0), # not stored in metadata - init = NA, # add in loop - seed = as.character(csfit$metadata$seed), - file = NA, # add in loop - diagnostic_file = character(0), # not stored in metadata - refresh = as.character(csfit$metadata$refresh), - sig_figs = as.character(csfit$metadata$sig_figs), - profile_file = csfit$metadata$profile_file, - num_threads = as.character(csfit$metadata$threads_per_chain), - stanc_version = gsub(" ", "", csfit$metadata$stanc_version), - stancflags = character(0), # not stored in metadata - adaptation_info = NA, # add in loop - has_time = is.numeric(csfit$metadata$time$total), - time_info = NA, # add in loop - sampler_t = sampler_t - ) - - sargs_rep <- replicate(n_chains, sargs, simplify = FALSE) - - for (i in seq_along(sargs_rep)) { - sargs_rep[[i]]$chain_id <- i - sargs_rep[[i]]$stepsize <- csfit$metadata$step_size[i] - sargs_rep[[i]]$init <- as.character(csfit$metadata$init[i]) - # two 'file' elements: select the second - file_idx <- which(names(sargs_rep[[i]]) == "file") - sargs_rep[[i]][[file_idx[2]]] <- files[[i]] - - sargs_rep[[i]]$adaptation_info <- adapt_info[[i]] - - if (NROW(csfit$metadata$time)) { - sargs_rep[[i]]$time_info <- paste0( - c("# Elapsed Time: ", "# ", "# ", "# "), - c(csfit$metadata$time[i, c("warmup", "sampling", "total")], ""), - c(" seconds (Warm-up)", " seconds (Sampling)", " seconds (Total)", "") - ) - } - } - - # @stanmodel - null_dso <- new( - "cxxdso", sig = list(character(0)), dso_saved = FALSE, - dso_filename = character(0), modulename = character(0), - system = R.version$system, cxxflags = character(0), - .CXXDSOMISC = new.env(parent = emptyenv()) - ) - null_sm <- new( - "stanmodel", model_name = model_name, model_code = character(0), - model_cpp = list(), dso = null_dso - ) - - # @date - sdate <- do.call(max, lapply(files, function(csv) file.info(csv)$mtime)) - sdate <- format(sdate, "%a %b %d %X %Y") - - new( - "stanfit", - model_name = model_name, - model_pars = svars, - par_dims = par_dims, - mode = mode, - sim = sim, - inits = list(), - stan_args = sargs_rep, - stanmodel = null_sm, - date = sdate, - .MISC = new.env(parent = emptyenv()) - ) -} - #' @noRd .autoformat <- function(stan_file, overwrite_file = TRUE, backend = 'cmdstanr', silent = TRUE){ @@ -1067,7 +763,6 @@ read_csv_as_stanfit <- function(files, variables = NULL, return(out) } - #' @noRd repair_stanfit <- function(x) { if (!length(x@sim$fnames_oi)) { diff --git a/docs/news/index.html b/docs/news/index.html index 2cebee5e..883d5191 100644 --- a/docs/news/index.html +++ b/docs/news/index.html @@ -60,9 +60,10 @@

mvgam 1.1.3 (development version; not yet on CRAN)

New functionalities

-

Deprecations

@@ -71,6 +72,7 @@

Deprecations

Bug fixes

diff --git a/docs/reference/RW.html b/docs/reference/RW.html index b8202d6a..451e9b24 100644 --- a/docs/reference/RW.html +++ b/docs/reference/RW.html @@ -16,7 +16,7 @@ mvgam - 1.1.2 + 1.1.3
diff --git a/docs/reference/Rplot001.png b/docs/reference/Rplot001.png index 17a35806..79bd0396 100644 Binary files a/docs/reference/Rplot001.png and b/docs/reference/Rplot001.png differ diff --git a/docs/reference/Rplot002.png b/docs/reference/Rplot002.png index f60122ed..16da55c0 100644 Binary files a/docs/reference/Rplot002.png and b/docs/reference/Rplot002.png differ diff --git a/docs/reference/Rplot003.png b/docs/reference/Rplot003.png index 84523f71..633fb5a2 100644 Binary files a/docs/reference/Rplot003.png and b/docs/reference/Rplot003.png differ diff --git a/docs/reference/Rplot004.png b/docs/reference/Rplot004.png index 74561f52..ba602c61 100644 Binary files a/docs/reference/Rplot004.png and b/docs/reference/Rplot004.png differ diff --git a/docs/reference/Rplot005.png b/docs/reference/Rplot005.png index 102c4abc..8b3cbead 100644 Binary files a/docs/reference/Rplot005.png and b/docs/reference/Rplot005.png differ diff --git a/docs/reference/Rplot006.png b/docs/reference/Rplot006.png index 8324296b..371d33a3 100644 Binary files a/docs/reference/Rplot006.png and b/docs/reference/Rplot006.png differ diff --git a/docs/reference/Rplot007.png b/docs/reference/Rplot007.png index 25a6bb1d..567205f0 100644 Binary files a/docs/reference/Rplot007.png and b/docs/reference/Rplot007.png differ diff --git a/docs/reference/Rplot008.png b/docs/reference/Rplot008.png index b9f04a08..f9eb49e4 100644 Binary files a/docs/reference/Rplot008.png and b/docs/reference/Rplot008.png differ diff --git a/docs/reference/Rplot009.png b/docs/reference/Rplot009.png index fe7c1b9c..d865cc3d 100644 Binary files a/docs/reference/Rplot009.png and b/docs/reference/Rplot009.png differ diff --git a/docs/reference/Rplot010.png b/docs/reference/Rplot010.png index bcbb7016..a17fc228 100644 Binary files a/docs/reference/Rplot010.png and b/docs/reference/Rplot010.png differ diff --git a/docs/reference/Rplot011.png b/docs/reference/Rplot011.png index daec4e3b..bcec02ea 100644 Binary files a/docs/reference/Rplot011.png and b/docs/reference/Rplot011.png differ diff --git a/docs/reference/Rplot012.png b/docs/reference/Rplot012.png index f3fc4b72..2f0af5c2 100644 Binary files a/docs/reference/Rplot012.png and b/docs/reference/Rplot012.png differ diff --git a/docs/reference/Rplot013.png b/docs/reference/Rplot013.png index 39a292ce..15c948f5 100644 Binary files a/docs/reference/Rplot013.png and b/docs/reference/Rplot013.png differ diff --git a/docs/reference/add_residuals.mvgam.html b/docs/reference/add_residuals.mvgam.html index f6ed8df9..6a069f42 100644 --- a/docs/reference/add_residuals.mvgam.html +++ b/docs/reference/add_residuals.mvgam.html @@ -10,7 +10,7 @@ mvgam - 1.1.0 + 1.1.3