Skip to content

Commit

Permalink
updates for shared family-level params; allow multithreading with lat…
Browse files Browse the repository at this point in the history
…er cmdstan versions
  • Loading branch information
Nicholas Clark committed Mar 1, 2024
1 parent b014d32 commit ea715c9
Show file tree
Hide file tree
Showing 11 changed files with 248 additions and 19 deletions.
47 changes: 42 additions & 5 deletions R/families.R
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,9 @@ extract_family_pars = function(object, newdata = NULL){
for(i in 1:length(pars_to_extract)){
out[[i]] <- mcmc_chains(object$model_output,
params = pars_to_extract[i])
if(NCOL(out[[i]]) == 1){
out[[i]] <- as.vector(out[[i]])
}
}

} else {
Expand Down Expand Up @@ -1463,19 +1466,44 @@ dsresids_vec = function(object){
}

if(family == 'student'){
if(NCOL(mcmc_chains(object$model_output, 'sigma_obs')) == 1){
sigma_obs <- family_extracts$sigma_obs
sigma_mat <- matrix(rep(sigma_obs,
NCOL(truth_mat)),
ncol = NCOL(truth_mat))
sigma_obs <- as.vector(sigma_mat)

nu <- family_extracts$nu
nu_mat <- matrix(rep(nu,
NCOL(truth_mat)),
ncol = NCOL(truth_mat))
nu <- as.vector(nu_mat)
} else {
sigma_obs <- family_extracts$sigma_obs
nu <- family_extracts$nu
}
resids <- matrix(ds_resids_student(truth = as.vector(truth_mat),
fitted = as.vector(preds),
draw = 1,
sigma = family_extracts$sigma_obs,
nu = family_extracts$nu),
sigma = sigma_obs,
nu = nu),
nrow = NROW(preds))
}

if(family == 'lognormal'){
if(NCOL(mcmc_chains(object$model_output, 'sigma_obs')) == 1){
sigma_obs <- family_extracts$sigma_obs
sigma_mat <- matrix(rep(sigma_obs,
NCOL(truth_mat)),
ncol = NCOL(truth_mat))
sigma_obs <- as.vector(sigma_mat)
} else {
sigma_obs <- family_extracts$sigma_obs
}
resids <- matrix(ds_resids_lnorm(truth = as.vector(truth_mat),
fitted = as.vector(preds),
draw = 1,
sigma = family_extracts$sigma_obs),
sigma = sigma_obs),
nrow = NROW(preds))
}

Expand All @@ -1495,18 +1523,27 @@ dsresids_vec = function(object){
}

if(family == 'Gamma'){
if(NCOL(mcmc_chains(object$model_output, 'shape')) == 1){
shape <- family_extracts$shape
sigma_mat <- matrix(rep(shape,
NCOL(truth_mat)),
ncol = NCOL(truth_mat))
shape <- as.vector(sigma_mat)
} else {
shape <- family_extracts$shape
}
resids <- matrix(ds_resids_gamma(truth = as.vector(truth_mat),
fitted = as.vector(preds),
draw = 1,
shape = family_extracts$shape),
shape = shape),
nrow = NROW(preds))
}

if(family == 'negative binomial'){
resids <- matrix(ds_resids_nb(truth = as.vector(truth_mat),
fitted = as.vector(preds),
draw = 1,
size = family_extracts$size),
size = family_extracts$phi),
nrow = NROW(preds))
}

Expand Down
27 changes: 25 additions & 2 deletions R/mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@
#'Note that only `nb()` and `poisson()` are available if using `JAGS` as the backend.
#'Default is `poisson()`.
#'See \code{\link{mvgam_families}} for more details
#'@param share_obs_params \code{logical}. If \code{TRUE} and the \code{family}
#'has additional family-specific observation parameters (e.g. variance components in
#'`student_t()` or `gaussian()`, or dispersion parameters in `nb()` or `betar()`),
#'these parameters will be shared across all series. This is handy if you have multiple
#'time series that you believe share some properties, such as being from the same
#'species over different spatial units. Default is \code{FALSE}.
#'@param use_lv \code{logical}. If \code{TRUE}, use dynamic factors to estimate series'
#'latent trends in a reduced dimension format. Only available for
#'`RW()`, `AR()` and `GP()` trend models. Defaults to \code{FALSE}
Expand Down Expand Up @@ -124,7 +130,8 @@
#'@param threads \code{integer} Experimental option to use multithreading for within-chain
#'parallelisation in \code{Stan}. We recommend its use only if you are experienced with
#'\code{Stan}'s `reduce_sum` function and have a slow running model that cannot be sped
#'up by any other means. Only available when using \code{Cmdstan} as the backend
#'up by any other means. Only available for some families (`poisson()`, `nb()`, `gaussian()`) and
#'when using \code{Cmdstan} as the backend
#'@param priors An optional \code{data.frame} with prior
#'definitions (in JAGS or Stan syntax). if using Stan, this can also be an object of
#'class `brmsprior` (see. \code{\link[brms]{prior}} for details). See [get_mvgam_priors] and
Expand Down Expand Up @@ -242,7 +249,8 @@
#'*Observation level parameters*: When more than one series is included in \code{data} and an
#'observation family that contains more than one parameter is used, additional observation family parameters
#'(e.g. `phi` for `nb()` or `sigma` for `gaussian()`) are
#'estimated independently for each series.
#'by default estimated independently for each series. But if you wish for the series to share
#'the same observation parameters, set `share_obs_params = TRUE`
#'\cr
#'\cr
#'*Factor regularisation*: When using a dynamic factor model for the trends with `JAGS` factor precisions are given
Expand Down Expand Up @@ -574,6 +582,7 @@ mvgam = function(formula,
prior_simulation = FALSE,
return_model_data = FALSE,
family = 'poisson',
share_obs_params = FALSE,
use_lv = FALSE,
n_lv,
trend_map,
Expand Down Expand Up @@ -709,6 +718,10 @@ mvgam = function(formula,
# Validate the family argument
family <- validate_family(family, use_stan = use_stan)
family_char <- match.arg(arg = family$family, choices = family_char_choices())
if(threads > 1 & !family_char %in% c('poisson', 'negative binomial', 'gaussian')){
warning('multithreading not supported for this family; setting threads = 1')
threads <- 1
}

# Validate the trend arguments
orig_trend_model <- trend_model
Expand Down Expand Up @@ -1724,6 +1737,12 @@ mvgam = function(formula,
'lv_coefs', 'error')]
}

# Updates for sharing of observation params
if(share_obs_params){
vectorised$model_file <- shared_obs_params(vectorised$model_file,
family_char)
}

# Tidy the representation
vectorised$model_file <- sanitise_modelfile(vectorised$model_file)

Expand Down Expand Up @@ -1816,6 +1835,7 @@ mvgam = function(formula,
NULL
},
family = family_char,
share_obs_params = share_obs_params,
trend_model = orig_trend_model,
trend_map = if(!missing(trend_map)){
trend_map
Expand Down Expand Up @@ -2255,6 +2275,7 @@ mvgam = function(formula,
},
fit_engine = fit_engine,
family = family_char,
share_obs_params = share_obs_params,
obs_data = data_train,
test_data = data_test,
ytimes = ytimes)
Expand Down Expand Up @@ -2296,6 +2317,7 @@ mvgam = function(formula,
},
fit_engine = fit_engine,
family = family_char,
share_obs_params = share_obs_params,
obs_data = data_train,
test_data = data_test,
trend_model = trend_model,
Expand Down Expand Up @@ -2337,6 +2359,7 @@ mvgam = function(formula,
NULL
},
family = family_char,
share_obs_params = share_obs_params,
trend_model = orig_trend_model,
trend_map = if(!missing(trend_map)){
trend_map
Expand Down
120 changes: 120 additions & 0 deletions R/shared_obs_params.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#' Updates for allowing shared observation params across series
#'
#' Edits generated Stan model code so that family-specific observation
#' parameters are declared once as a scalar (`real<lower=0> ...`) instead of
#' as a `vector<lower=0>[n_series] ...` with one entry per series. The
#' statements that referred to the per-series version are rewritten or
#' removed to match.
#'
#' All matching is done on fixed (non-regex) substrings. When a pattern is
#' found, the WHOLE matching element of `model_file` is overwritten with the
#' replacement text (or removed), not just the matching substring. Patterns
#' that are not present in `model_file` are skipped without error.
#'
#' @param model_file `character` vector containing the Stan model code
#' @param family `character` string naming the observation family. Families
#' handled here are `'poisson'` and `'nmix'` (message only, no edits),
#' `'student'`, `'gaussian'`, `'lognormal'`, `'negative binomial'`, `'beta'`
#' and `'Gamma'`. Any other value returns `model_file` unchanged
#' @return A `character` vector containing the updated model code
#' @noRd
shared_obs_params = function(model_file, family){

  # Overwrite every element that contains `pattern` with `replacement`;
  # a no-op when nothing matches
  swap_line <- function(lines, pattern, replacement){
    lines[grepl(pattern, lines, fixed = TRUE)] <- replacement
    lines
  }

  # Remove every element that contains `pattern`. A logical mask is used on
  # purpose: `lines[-grep(...)]` would return character(0) (i.e. wipe the
  # entire model) whenever the pattern is absent, because negating an empty
  # index vector selects nothing
  drop_line <- function(lines, pattern){
    lines[!grepl(pattern, lines, fixed = TRUE)]
  }

  if(family == 'poisson'){
    message('Context share_obs_params: Poisson family has no additional observation params')
  }

  if(family == 'nmix'){
    message('Context share_obs_params: nmix family has no additional observation params')
  }

  if(family %in% c('student', 'gaussian', 'lognormal')){
    model_file <- swap_line(model_file,
                            "vector<lower=0>[n_series] sigma_obs;",
                            "real<lower=0> sigma_obs;")

    # The flattened per-observation copy of sigma_obs is no longer needed
    model_file <- drop_line(model_file,
                            "flat_sigma_obs = rep_each(sigma_obs, n)[obs_ind];")
    model_file <- drop_line(model_file,
                            "vector[n_nonmissing] flat_sigma_obs;")
    model_file <- swap_line(model_file,
                            "flat_sigma_obs);",
                            'sigma_obs);')

    # Only relevant when flat_sigma_obs is passed on as a non-final argument;
    # the partial_log_lik signature and body are then switched to a scalar
    # sigma_obs as well
    if(any(grepl("flat_sigma_obs,", model_file, fixed = TRUE))){
      model_file <- swap_line(model_file,
                              "flat_sigma_obs,",
                              "sigma_obs,")
      model_file <- swap_line(model_file,
                              "data vector Y, matrix X, vector b, vector sigma_obs, real alpha) {",
                              "data vector Y, matrix X, vector b, real sigma_obs, real alpha) {")
      model_file <- swap_line(model_file,
                              "ptarget += normal_id_glm_lpdf(Y[start:end] | X[start:end], alpha, b, sigma_obs[start:end]);",
                              "ptarget += normal_id_glm_lpdf(Y[start:end] | X[start:end], alpha, b, sigma_obs);")
    }

    model_file <- swap_line(model_file,
                            "sigma_obs_vec[1:n,s] = rep_vector(sigma_obs[s], n);",
                            "sigma_obs_vec[1:n,s] = rep_vector(sigma_obs, n);")
  }

  # 'student' also passes through the sigma_obs branch above; here the
  # degrees of freedom parameter gets the same treatment
  if(family == 'student'){
    model_file <- swap_line(model_file,
                            "vector<lower=0>[n_series] nu;",
                            "real<lower=0> nu;")

    model_file <- drop_line(model_file,
                            "flat_nu = rep_each(nu, n)[obs_ind];")
    model_file <- drop_line(model_file,
                            "vector[n_nonmissing] flat_nu;")
    model_file <- swap_line(model_file,
                            "flat_ys ~ student_t(flat_nu,",
                            "flat_ys ~ student_t(nu,")

    model_file <- swap_line(model_file,
                            "nu_vec[1:n,s] = rep_vector(nu[s], n);",
                            "nu_vec[1:n,s] = rep_vector(nu, n);")
  }

  if(family == 'negative binomial'){
    model_file <- swap_line(model_file,
                            'vector<lower=0>[n_series] phi_inv;',
                            'real<lower=0> phi_inv;')

    model_file <- drop_line(model_file,
                            'flat_phis = to_array_1d(rep_each(phi_inv, n)[obs_ind]);')
    model_file <- drop_line(model_file,
                            "real flat_phis[n_nonmissing];")

    model_file <- swap_line(model_file,
                            "inv(flat_phis));",
                            'inv(phi_inv));')

    # phi is still filled with n_series entries, each equal to the single
    # shared value
    model_file <- swap_line(model_file,
                            "phi = inv(phi_inv);",
                            "phi = rep_vector(inv(phi_inv), n_series);")
  }

  if(family == 'beta'){
    model_file <- swap_line(model_file,
                            'vector<lower=0>[n_series] phi;',
                            'real<lower=0> phi;')

    model_file <- drop_line(model_file,
                            'flat_phis = rep_each(phi, n)[obs_ind];')
    model_file <- drop_line(model_file,
                            "vector[n_nonmissing] flat_phis;")

    model_file <- swap_line(model_file,
                            "inv_logit(flat_xs * b) .* flat_phis,",
                            "inv_logit(flat_xs * b) .* phi,")
    model_file <- swap_line(model_file,
                            "(1 - inv_logit(flat_xs * b)) .* flat_phis);",
                            "(1 - inv_logit(flat_xs * b)) .* phi);")

    model_file <- swap_line(model_file,
                            "phi_vec[1:n,s] = rep_vector(phi[s], n);",
                            "phi_vec[1:n,s] = rep_vector(phi, n);")
  }

  if(family == 'Gamma'){
    model_file <- swap_line(model_file,
                            "vector<lower=0>[n_series] shape;",
                            "real<lower=0> shape;")

    model_file <- drop_line(model_file,
                            "flat_shapes = rep_each(shape, n)[obs_ind];")
    model_file <- drop_line(model_file,
                            "vector[n_nonmissing] flat_shapes;")

    model_file <- swap_line(model_file,
                            "flat_shapes, flat_shapes ./ exp(flat_xs * b));",
                            "shape, shape ./ exp(flat_xs * b));")

    model_file <- swap_line(model_file,
                            "shape_vec[1:n,s] = rep_vector(shape[s], n);",
                            "shape_vec[1:n,s] = rep_vector(shape, n);")
  }

  return(model_file)
}
35 changes: 28 additions & 7 deletions R/stan_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,27 @@ remove_likelihood = function(model_file){
stan_file, fixed = TRUE)] <-
'array[2, p] matrix[m, m] phiGamma;'
}

if(any(grepl("real partial_log_lik(int[] seq, int start, int end,",
stan_file, fixed = TRUE))){
stan_file[grepl("real partial_log_lik(int[] seq, int start, int end,",
stan_file, fixed = TRUE)] <-
"real partial_log_lik(array[] int seq, int start, int end,"
}

if(any(grepl("data vector Y, vector mu, real[] shape) {",
stan_file, fixed = TRUE))){
stan_file[grepl("data vector Y, vector mu, real[] shape) {" ,
stan_file, fixed = TRUE)] <-
"data vector Y, vector mu, array[] real shape) {"
}

if(any(grepl("int<lower=1> seq[n_nonmissing]; // an integer sequence for reduce_sum slicing",
stan_file, fixed = TRUE))){
stan_file[grepl("int<lower=1> seq[n_nonmissing]; // an integer sequence for reduce_sum slicing",
stan_file, fixed = TRUE)] <-
"array[n_nonmissing] int<lower=1> seq; // an integer sequence for reduce_sum slicing"
}
}

stan_file <- cmdstanr::write_stan_file(stan_file)
Expand Down Expand Up @@ -1525,8 +1546,8 @@ vectorise_stan_lik = function(model_file, model_data, family = 'poisson',
ifelse(offset, 'data vector Y, matrix X, vector b, vector sigma_obs, vector alpha) {\n',
'data vector Y, matrix X, vector b, vector sigma_obs, real alpha) {\n'),
'real ptarget = 0;\n',
ifelse(offset,'ptarget += normal_id_glm_lpdf(Y[start:end] | X[start:end], alpha[start:end], sigma_obs[start:end]);\n',
'ptarget += normal_id_glm_lpdf(Y[start:end] | X[start:end], alpha, sigma_obs[start:end]);\n'),
ifelse(offset,'ptarget += normal_id_glm_lpdf(Y[start:end] | X[start:end], alpha[start:end], b, sigma_obs[start:end]);\n',
'ptarget += normal_id_glm_lpdf(Y[start:end] | X[start:end], alpha, b, sigma_obs[start:end]);\n'),
'return ptarget;\n',
'}\n')

Expand All @@ -1538,8 +1559,8 @@ vectorise_stan_lik = function(model_file, model_data, family = 'poisson',
ifelse(offset, 'data vector Y, matrix X, vector b, vector sigma_obs, vector alpha) {\n',
'data vector Y, matrix X, vector b, vector sigma_obs, real alpha) {\n'),
'real ptarget = 0;\n',
ifelse(offset,'ptarget += normal_id_glm_lpdf(Y[start:end] | X[start:end], alpha[start:end], sigma_obs[start:end]);\n',
'ptarget += normal_id_glm_lpdf(Y[start:end] | X[start:end], alpha, sigma_obs[start:end]);\n'),
ifelse(offset,'ptarget += normal_id_glm_lpdf(Y[start:end] | X[start:end], alpha[start:end], b, sigma_obs[start:end]);\n',
'ptarget += normal_id_glm_lpdf(Y[start:end] | X[start:end], alpha, b, sigma_obs[start:end]);\n'),
'return ptarget;\n',
'}\n}\n')
}
Expand Down Expand Up @@ -1677,7 +1698,7 @@ vectorise_stan_lik = function(model_file, model_data, family = 'poisson',
model_file[grep('functions {', model_file, fixed = TRUE)] <-
paste0('functions {\n',
'real partial_log_lik(int[] seq, int start, int end,\n',
'data int[] Y, vector mu, real[] shape) {\n',
'data vector Y, vector mu, real[] shape) {\n',
'real ptarget = 0;\n',
'ptarget += gamma_lpdf(Y[start:end] | shape[start:end], shape[start:end] ./ mu[start:end]);\n',
'return ptarget;\n',
Expand All @@ -1687,7 +1708,7 @@ vectorise_stan_lik = function(model_file, model_data, family = 'poisson',
paste0('// Stan model code generated by package mvgam\n',
'functions {\n',
'real partial_log_lik(int[] seq, int start, int end,\n',
'data int[] Y, vector mu, real[] shape) {\n',
'data vector Y, vector mu, real[] shape) {\n',
'real ptarget = 0;\n',
'ptarget += gamma_lpdf(Y[start:end] | shape[start:end], shape[start:end] ./ mu[start:end]);\n',
'return ptarget;\n',
Expand All @@ -1714,7 +1735,7 @@ vectorise_stan_lik = function(model_file, model_data, family = 'poisson',
'flat_ys,\n',
'append_col(flat_xs, flat_trends),\n',
'append_row(b, 1.0),\n',
'flat_sigma_obs',
'flat_sigma_obs,\n',
ifelse(offset, 'offset[obs_ind],\n);\n}\n',
'0.0);\n}\n}\n'))
} else {
Expand Down
Loading

0 comments on commit ea715c9

Please sign in to comment.