Skip to content
Merged
2 changes: 1 addition & 1 deletion hewr/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Description: What the package does (one paragraph).
License: Apache License (>= 2)
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.2
RoxygenNote: 7.3.3
Imports:
argparser,
cowplot,
Expand Down
4 changes: 4 additions & 0 deletions hewr/NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Generated by roxygen2: do not edit by hand

S3method(process_model_samples,pyrenew)
S3method(process_model_samples,timeseries)
export(epiweekly_samples_from_daily)
export(format_timeseries_output)
export(generate_exp_growth_pois)
Expand All @@ -14,7 +16,9 @@ export(parse_model_run_dir_path)
export(parse_pyrenew_model_name)
export(parse_variable_name)
export(path_up_to)
export(process_forecast)
export(process_loc_forecast)
export(process_model_samples)
export(prop_from_timeseries)
export(read_and_combine_data)
export(to_tidy_draws_timeseries)
Expand Down
271 changes: 254 additions & 17 deletions hewr/R/process_loc_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,101 @@ to_tidy_draws_timeseries <- function(
}


#' Process model samples based on model type
#'
#' S3 generic that dispatches on the *value* of `model_type` rather than on
#' the class of a data object: an empty list whose class attribute is set to
#' `model_type` is created purely to select the method. Supporting a new
#' model type therefore only requires defining a
#' `process_model_samples.<model_type>()` method; existing code does not
#' need to change.
#'
#' @param model_type Character string indicating model type
#'   ("pyrenew", "timeseries", "epiautogp"). Must be a non-empty character
#'   vector without missing or empty values; anything else raises an error
#'   before dispatch.
#' @param ... Additional arguments passed to methods. See specific methods
#'   for details: [process_model_samples.pyrenew()],
#'   [process_model_samples.timeseries()]
#' @return Tibble of model samples
#' @export
process_model_samples <- function(model_type, ...) {
  # Validate up front: without this, NULL silently dispatches on class
  # "list" and NA / non-character values fail inside structure(), both of
  # which produce errors that never mention `model_type`.
  if (
    !is.character(model_type) ||
      length(model_type) == 0L ||
      anyNA(model_type) ||
      !all(nzchar(model_type))
  ) {
    stop(
      "`model_type` must be a non-missing, non-empty character string ",
      "naming a model type (e.g. \"pyrenew\" or \"timeseries\")."
    )
  }
  # NOTE(review): "epiautogp" is documented as an option, but only the
  # pyrenew and timeseries methods are visible alongside this generic --
  # confirm an epiautogp method exists before relying on it.
  UseMethod("process_model_samples", structure(list(), class = model_type))
}

#' Process PyRenew model samples
#'
#' Method of [process_model_samples()] selected when `model_type` is
#' "pyrenew". It is a thin adapter: every argument except `model_type` and
#' `...` is handed to `process_pyrenew_model()`, with `model_name` supplied
#' as that function's `pyrenew_model_name` argument.
#'
#' @param model_type Character string indicating model type. Used for
#'   dispatch only; not forwarded.
#' @param model_run_dir Model run directory
#' @param model_name Name of directory containing model outputs
#' @param ts_samples Timeseries samples, or `NULL` (the default) when none
#'   are available
#' @param required_columns_e Required columns for output
#' @param n_forecast_days Number of forecast days
#' @param ... Additional arguments (unused)
#' @return Tibble of PyRenew model samples, as returned by
#'   `process_pyrenew_model()`
#' @exportS3Method
process_model_samples.pyrenew <- function(
  model_type,
  model_run_dir,
  model_name,
  ts_samples = NULL,
  required_columns_e,
  n_forecast_days,
  ...
) {
  # The generic speaks in terms of `model_name`; the PyRenew worker calls
  # the same thing `pyrenew_model_name`.
  process_pyrenew_model(
    pyrenew_model_name = model_name,
    model_run_dir = model_run_dir,
    n_forecast_days = n_forecast_days,
    required_columns_e = required_columns_e,
    ts_samples = ts_samples
  )
}

#' Process timeseries model samples
#'
#' Method of [process_model_samples()] selected when `model_type` is
#' "timeseries". The caller is expected to have loaded the samples already,
#' so this method only checks that they were supplied and hands them back
#' unchanged.
#'
#' @param model_type Character string indicating model type
#' @param model_run_dir Model run directory (not used by this method)
#' @param model_name Name of directory containing model outputs (not used
#'   by this method)
#' @param ts_samples Timeseries samples. Although the default is `NULL`,
#'   a non-`NULL` value is required; `NULL` raises an error.
#' @param required_columns_e Required columns for output (not used by this
#'   method)
#' @param n_forecast_days Number of forecast days (not used by this method)
#' @param ... Additional arguments (unused)
#' @return `ts_samples`, unmodified
#' @exportS3Method
process_model_samples.timeseries <- function(
  model_type,
  model_run_dir,
  model_name,
  ts_samples = NULL,
  required_columns_e,
  n_forecast_days,
  ...
) {
  # Nothing to compute for this model type: the samples are produced
  # upstream, so the only job here is to refuse a missing input.
  samples_supplied <- !is.null(ts_samples)
  if (!samples_supplied) {
    stop("ts_samples must be provided for timeseries model type")
  }
  ts_samples
}

#' Detect model type from model name
#'
#' Internal helper that infers the model type from naming conventions.
#' The checks are applied in order and the first match wins:
#'
#' 1. names starting with `ts_` (case-sensitive) or containing
#'    "timeseries" (any case) are "timeseries";
#' 2. names containing "epiautogp" (any case) are "epiautogp";
#' 3. everything else is "pyrenew".
#'
#' @param model_name Character string with model name
#' @return Character string with model type ("pyrenew", "timeseries", etc.)
#' @keywords internal
detect_model_type <- function(model_name) {
  looks_like_timeseries <- grepl("^ts_", model_name) ||
    grepl("timeseries", model_name, ignore.case = TRUE)

  if (looks_like_timeseries) {
    "timeseries"
  } else if (grepl("epiautogp", model_name, ignore.case = TRUE)) {
    "epiautogp"
  } else {
    # Fallback keeps names that predate this convention working as pyrenew.
    "pyrenew"
  }
}

process_pyrenew_model <- function(
model_run_dir,
pyrenew_model_name,
Expand Down Expand Up @@ -382,38 +477,52 @@ process_pyrenew_model <- function(
#' Process loc forecast
#'
#' @param model_run_dir Model run directory
#' @param n_forecast_days An integer specifying the number of days to forecast.
#' @param model_name Name of directory containing model outputs.
#' If provided, uses new S3 dispatch interface (overrides
#' pyrenew_model_name/timeseries_model_name).
#' @param pyrenew_model_name Name of directory containing pyrenew
#' model outputs
#' model outputs (legacy interface)
#' @param timeseries_model_name Name of directory containing timeseries
#' model outputs
#' @param n_forecast_days An integer specifying the number of days to forecast.
#' model outputs (legacy interface)
#' @param model_type Optional character string specifying model type
#' explicitly. Only used with model_name parameter.
#' If NULL (default), will auto-detect from model_name.
#' Options: "pyrenew", "timeseries", "epiautogp"
#' @param ci_widths Vector of probabilities indicating one or more
#' central credible intervals to compute. Passed as the `.width`
#' argument to [ggdist::median_qi()]. Default `c(0.5, 0.8, 0.95)`.
#' central credible intervals to compute. Passed as the `.width`
#' argument to [ggdist::median_qi()]. Default `c(0.5, 0.8, 0.95)`.
#' @param save Boolean indicating whether or not to save the output
#' to parquet files. Default `TRUE`.
#' @return a list of 8 tibbles:
#' `daily_combined_training_eval_data`,
#' `epiweekly_combined_training_eval_data`,
#' `daily_samples`,
#' `epiweekly_samples`,
#' `epiweekly_with_epiweekly_other_samples`,
#' `daily_ci`,
#' `epiweekly_ci`,
#' `epiweekly_with_epiweekly_other_ci`
#' to parquet files. Default `TRUE`.
#' @return a list of 2 tibbles: `samples` and `ci`
#' @export
process_loc_forecast <- function(
model_run_dir,
n_forecast_days,
model_name = NA,
pyrenew_model_name = NA,
timeseries_model_name = NA,
model_type = NULL,
ci_widths = c(0.5, 0.8, 0.95),
save = TRUE
) {
# New interface: delegate to process_forecast
if (!is.na(model_name)) {
return(process_forecast(
model_run_dir = model_run_dir,
model_name = model_name,
n_forecast_days = n_forecast_days,
model_type = model_type,
ci_widths = ci_widths,
save = save
))
}

# Legacy interface validation
if (is.na(pyrenew_model_name) && is.na(timeseries_model_name)) {
stop(
"Either `pyrenew_model_name` or `timeseries_model_name`",
"must be provided."
"Either `model_name` or `pyrenew_model_name`/",
"`timeseries_model_name` must be provided."
)
}
model_name <- dplyr::if_else(
Expand Down Expand Up @@ -518,3 +627,131 @@ process_loc_forecast <- function(

return(result)
}

#' Process location forecast with generic model name
#'
#' Simplified version of process_loc_forecast that accepts a single model_name
#' parameter and auto-detects the model type. Uses S3 dispatch internally for
#' extensibility.
#'
#' The function reads the daily and epiweekly combined training data from
#' `<model_run_dir>/<model_name>/data/`, optionally loads timeseries samples,
#' dispatches to [process_model_samples()] to obtain tidy samples, and
#' summarises them into median / credible-interval rows.
#'
#' Timeseries samples are located as follows:
#' * for `model_type = "timeseries"`, they are loaded from `model_name`
#'   itself;
#' * for `model_type = "pyrenew"`, the first subdirectory of `model_run_dir`
#'   whose *name* starts with `ts_` or contains `timeseries` is used, and
#'   `NULL` is passed on if there is none;
#' * for any other model type no timeseries samples are loaded.
#'
#' @param model_run_dir Model run directory
#' @param model_name Name of directory containing model outputs
#' @param n_forecast_days An integer specifying the number of days to forecast.
#' @param model_type Optional character string specifying model type explicitly.
#'   If NULL (default), will auto-detect from model_name.
#'   Options: "pyrenew", "timeseries", "epiautogp"
#' @param ci_widths Vector of probabilities indicating one or more
#'   central credible intervals to compute. Passed as the `.width`
#'   argument to [ggdist::median_qi()]. Default `c(0.5, 0.8, 0.95)`.
#' @param save Boolean indicating whether or not to save the output
#'   to parquet files. Default `TRUE`. Files are written to
#'   `<model_run_dir>/<model_name>/samples.parquet` and `ci.parquet`.
#' @return a list of 2 tibbles: `samples` and `ci`
#' @export
process_forecast <- function(
  model_run_dir,
  model_name,
  n_forecast_days,
  model_type = NULL,
  ci_widths = c(0.5, 0.8, 0.95),
  save = TRUE
) {
  # Auto-detect model type if not specified
  if (is.null(model_type)) {
    model_type <- detect_model_type(model_name)
  }

  model_dir <- fs::path(model_run_dir, model_name)
  data_col_types <- readr::cols(
    date = readr::col_date(),
    geo_value = readr::col_character(),
    disease = readr::col_character(),
    data_type = readr::col_character(),
    .variable = readr::col_character(),
    .value = readr::col_double()
  )

  # Load training data
  daily_training_dat <- readr::read_tsv(
    fs::path(model_dir, "data", "combined_training_data", ext = "tsv"),
    col_types = data_col_types
  )

  epiweekly_training_dat <- readr::read_tsv(
    fs::path(
      model_dir,
      "data",
      "epiweekly_combined_training_data",
      ext = "tsv"
    ),
    col_types = data_col_types
  )

  required_columns_e <- c(
    ".chain",
    ".iteration",
    ".draw",
    "date",
    "geo_value",
    "disease",
    ".variable",
    ".value",
    "resolution",
    "aggregated_numerator",
    "aggregated_denominator"
  )

  # Work out which directory (if any) holds the timeseries samples.
  ts_model_name <- NULL
  if (identical(model_type, "timeseries")) {
    # The requested model *is* the timeseries model, so its own directory
    # must be used -- not whichever timeseries-looking sibling sorts first.
    ts_model_name <- model_name
  } else if (identical(model_type, "pyrenew")) {
    # fs::dir_ls() applies `regexp` to the full returned paths, so the
    # pattern is matched against the directory *names* here instead;
    # otherwise `^ts_` cannot match under an absolute run directory and a
    # parent path containing "timeseries" would match every subdirectory.
    run_subdirs <- fs::dir_ls(model_run_dir, type = "directory")
    subdir_names <- fs::path_file(run_subdirs)
    ts_dir_names <- subdir_names[grepl("(^ts_|timeseries)", subdir_names)]
    if (length(ts_dir_names) > 0) {
      ts_model_name <- ts_dir_names[1]
    }
  }

  ts_samples <- NULL
  if (!is.null(ts_model_name)) {
    ts_samples <- load_and_aggregate_ts(
      model_run_dir,
      ts_model_name,
      daily_training_dat,
      epiweekly_training_dat,
      required_columns = required_columns_e
    )
  }

  # Dispatch to appropriate S3 method
  model_samples_tidy <- process_model_samples(
    model_type = model_type,
    model_run_dir = model_run_dir,
    model_name = model_name,
    ts_samples = ts_samples,
    required_columns_e = required_columns_e,
    n_forecast_days = n_forecast_days
  )

  # Calculate credible intervals: drop the per-draw index columns, then
  # summarise `.value` within every combination of the remaining columns.
  ci <- model_samples_tidy |>
    dplyr::select(-tidyselect::any_of(c(".chain", ".iteration", ".draw"))) |>
    dplyr::group_by(dplyr::across(-".value")) |>
    ggdist::median_qi(.width = ci_widths)

  result <- list(
    "samples" = model_samples_tidy,
    "ci" = ci
  )

  if (save) {
    # One parquet file per list element, named after the element.
    purrr::iwalk(result, \(tab, name) {
      forecasttools::write_tabular(
        tab,
        fs::path(model_dir, name, ext = "parquet")
      )
    })
  }

  result
}
18 changes: 18 additions & 0 deletions hewr/man/detect_model_type.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 41 additions & 0 deletions hewr/man/process_forecast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading