Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hewr/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Description: What the package does (one paragraph).
License: Apache License (>= 2)
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.2
RoxygenNote: 7.3.3
Imports:
argparser,
cowplot,
Expand Down
4 changes: 4 additions & 0 deletions hewr/NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Generated by roxygen2: do not edit by hand

S3method(process_model_samples,pyrenew)
S3method(process_model_samples,timeseries)
export(epiweekly_samples_from_daily)
export(format_timeseries_output)
export(generate_exp_growth_pois)
Expand All @@ -14,7 +16,9 @@ export(parse_model_run_dir_path)
export(parse_pyrenew_model_name)
export(parse_variable_name)
export(path_up_to)
export(process_forecast)
export(process_loc_forecast)
export(process_model_samples)
export(prop_from_timeseries)
export(read_and_combine_data)
export(to_tidy_draws_timeseries)
Expand Down
271 changes: 254 additions & 17 deletions hewr/R/process_loc_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,101 @@ to_tidy_draws_timeseries <- function(
}


#' Process model samples based on model type
#'
#' S3 generic function that dispatches to appropriate method based on
#' model type. This allows extensibility for new model types without
#' changing existing code.
#'
#' @param model_type Character string indicating model type
#' ("pyrenew", "timeseries", "epiautogp")
#' @param ... Additional arguments passed to methods. See specific methods
#' for details: [process_model_samples.pyrenew()],
#' [process_model_samples.timeseries()]
#' @return Tibble of model samples
#' @export
process_model_samples <- function(model_type, ...) {
UseMethod("process_model_samples", structure(list(), class = model_type))
}

#' Process PyRenew model samples
#'
#' @param model_type Character string indicating model type
#' @param model_run_dir Model run directory
#' @param model_name Name of directory containing model outputs
#' @param ts_samples Timeseries samples (if available)
#' @param required_columns_e Required columns for output
#' @param n_forecast_days Number of forecast days
#' @param ... Additional arguments (unused)
#' @return Tibble of PyRenew model samples
#' @exportS3Method
process_model_samples.pyrenew <- function(
model_type,
model_run_dir,
model_name,
ts_samples = NULL,
required_columns_e,
n_forecast_days,
...
) {
process_pyrenew_model(
model_run_dir = model_run_dir,
pyrenew_model_name = model_name,
ts_samples = ts_samples,
required_columns_e = required_columns_e,
n_forecast_days = n_forecast_days
)
}

#' Process timeseries model samples
#'
#' @param model_type Character string indicating model type
#' @param model_run_dir Model run directory
#' @param model_name Name of directory containing model outputs
#' @param ts_samples Timeseries samples (required for this method)
#' @param required_columns_e Required columns for output
#' @param n_forecast_days Number of forecast days
#' @param ... Additional arguments (unused)
#' @return Tibble of timeseries model samples
#' @exportS3Method
process_model_samples.timeseries <- function(
model_type,
model_run_dir,
model_name,
ts_samples = NULL,
required_columns_e,
n_forecast_days,
...
) {
# For timeseries, ts_samples should already be loaded
# This is essentially a pass-through
if (is.null(ts_samples)) {
stop("ts_samples must be provided for timeseries model type")
}
ts_samples
}

#' Detect model type from model name
#'
#' Internal helper function to infer model type from naming conventions
#'
#' @param model_name Character string with model name
#' @return Character string with model type ("pyrenew", "timeseries", etc.)
#' @keywords internal
detect_model_type <- function(model_name) {
if (
grepl("^ts_", model_name) ||
grepl("timeseries", model_name, ignore.case = TRUE)
) {
return("timeseries")
} else if (grepl("epiautogp", model_name, ignore.case = TRUE)) {
return("epiautogp")
} else {
# Default to pyrenew for backward compatibility
return("pyrenew")
}
}

process_pyrenew_model <- function(
model_run_dir,
pyrenew_model_name,
Expand Down Expand Up @@ -382,38 +477,52 @@ process_pyrenew_model <- function(
#' Process loc forecast
#'
#' @param model_run_dir Model run directory
#' @param n_forecast_days An integer specifying the number of days to forecast.
#' @param model_name Name of directory containing model outputs.
#' If provided, uses new S3 dispatch interface (overrides
#' pyrenew_model_name/timeseries_model_name).
#' @param pyrenew_model_name Name of directory containing pyrenew
#' model outputs
#' model outputs (legacy interface)
#' @param timeseries_model_name Name of directory containing timeseries
#' model outputs
#' @param n_forecast_days An integer specifying the number of days to forecast.
#' model outputs (legacy interface)
#' @param model_type Optional character string specifying model type
#' explicitly. Only used with model_name parameter.
#' If NULL (default), will auto-detect from model_name.
#' Options: "pyrenew", "timeseries", "epiautogp"
#' @param ci_widths Vector of probabilities indicating one or more
#' central credible intervals to compute. Passed as the `.width`
#' argument to [ggdist::median_qi()]. Default `c(0.5, 0.8, 0.95)`.
#' central credible intervals to compute. Passed as the `.width`
#' argument to [ggdist::median_qi()]. Default `c(0.5, 0.8, 0.95)`.
#' @param save Boolean indicating whether or not to save the output
#' to parquet files. Default `TRUE`.
#' @return a list of 8 tibbles:
#' `daily_combined_training_eval_data`,
#' `epiweekly_combined_training_eval_data`,
#' `daily_samples`,
#' `epiweekly_samples`,
#' `epiweekly_with_epiweekly_other_samples`,
#' `daily_ci`,
#' `epiweekly_ci`,
#' `epiweekly_with_epiweekly_other_ci`
#' to parquet files. Default `TRUE`.
#' @return a list of 2 tibbles: `samples` and `ci`
#' @export
process_loc_forecast <- function(
model_run_dir,
n_forecast_days,
model_name = NA,
pyrenew_model_name = NA,
timeseries_model_name = NA,
model_type = NULL,
ci_widths = c(0.5, 0.8, 0.95),
save = TRUE
) {
# New interface: delegate to process_forecast
if (!is.na(model_name)) {
return(process_forecast(
model_run_dir = model_run_dir,
model_name = model_name,
n_forecast_days = n_forecast_days,
model_type = model_type,
ci_widths = ci_widths,
save = save
))
}

# Legacy interface validation
if (is.na(pyrenew_model_name) && is.na(timeseries_model_name)) {
stop(
"Either `pyrenew_model_name` or `timeseries_model_name`",
"must be provided."
"Either `model_name` or `pyrenew_model_name`/",
"`timeseries_model_name` must be provided."
)
}
model_name <- dplyr::if_else(
Expand Down Expand Up @@ -518,3 +627,131 @@ process_loc_forecast <- function(

return(result)
}

#' Process location forecast with generic model name
#'
#' Simplified version of process_loc_forecast that accepts a single model_name
#' parameter and auto-detects the model type. Uses S3 dispatch internally for
#' extensibility.
#'
#' @param model_run_dir Model run directory
#' @param model_name Name of directory containing model outputs
#' @param n_forecast_days An integer specifying the number of days to forecast.
#' @param model_type Optional character string specifying model type explicitly.
#' If NULL (default), will auto-detect from model_name.
#' Options: "pyrenew", "timeseries", "epiautogp"
#' @param ci_widths Vector of probabilities indicating one or more
#' central credible intervals to compute. Passed as the `.width`
#' argument to [ggdist::median_qi()]. Default `c(0.5, 0.8, 0.95)`.
#' @param save Boolean indicating whether or not to save the output
#' to parquet files. Default `TRUE`.
#' @return a list of 2 tibbles: `samples` and `ci`
#' @export
process_forecast <- function(
model_run_dir,
model_name,
n_forecast_days,
model_type = NULL,
ci_widths = c(0.5, 0.8, 0.95),
save = TRUE
) {
# Auto-detect model type if not specified
if (is.null(model_type)) {
model_type <- detect_model_type(model_name)
}

model_dir <- fs::path(model_run_dir, model_name)
data_col_types <- readr::cols(
date = readr::col_date(),
geo_value = readr::col_character(),
disease = readr::col_character(),
data_type = readr::col_character(),
.variable = readr::col_character(),
.value = readr::col_double()
)

# Load training data
daily_training_dat <- readr::read_tsv(
fs::path(model_dir, "data", "combined_training_data", ext = "tsv"),
col_types = data_col_types
)

epiweekly_training_dat <- readr::read_tsv(
fs::path(
model_dir,
"data",
"epiweekly_combined_training_data",
ext = "tsv"
),
col_types = data_col_types
)

required_columns_e <- c(
".chain",
".iteration",
".draw",
"date",
"geo_value",
"disease",
".variable",
".value",
"resolution",
"aggregated_numerator",
"aggregated_denominator"
)

# Load timeseries samples if needed (for any model that might use them)
ts_samples <- NULL
if (model_type %in% c("timeseries", "pyrenew")) {
# Check if timeseries model exists
ts_model_dirs <- fs::dir_ls(
model_run_dir,
regexp = "(^ts_|timeseries)",
type = "directory"
)

if (length(ts_model_dirs) > 0) {
ts_model_name <- fs::path_file(ts_model_dirs[1])
ts_samples <- load_and_aggregate_ts(
model_run_dir,
ts_model_name,
daily_training_dat,
epiweekly_training_dat,
required_columns = required_columns_e
)
}
}

# Dispatch to appropriate S3 method
model_samples_tidy <- process_model_samples(
model_type = model_type,
model_run_dir = model_run_dir,
model_name = model_name,
ts_samples = ts_samples,
required_columns_e = required_columns_e,
n_forecast_days = n_forecast_days
)

# Calculate credible intervals
ci <- model_samples_tidy |>
dplyr::select(-tidyselect::any_of(c(".chain", ".iteration", ".draw"))) |>
dplyr::group_by(dplyr::across(-".value")) |>
ggdist::median_qi(.width = ci_widths)

result <- list(
"samples" = model_samples_tidy,
"ci" = ci
)

if (save) {
save_dir <- fs::path(model_run_dir, model_name)
purrr::iwalk(result, \(tab, name) {
forecasttools::write_tabular(
tab,
fs::path(save_dir, name, ext = "parquet")
)
})
}

return(result)
}
18 changes: 18 additions & 0 deletions hewr/man/detect_model_type.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 41 additions & 0 deletions hewr/man/process_forecast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading