diff --git a/hewr/DESCRIPTION b/hewr/DESCRIPTION index 4a451dd7..8d15ab91 100644 --- a/hewr/DESCRIPTION +++ b/hewr/DESCRIPTION @@ -7,7 +7,7 @@ Description: What the package does (one paragraph). License: Apache License (>= 2) Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.2 +RoxygenNote: 7.3.3 Imports: argparser, cowplot, diff --git a/hewr/NAMESPACE b/hewr/NAMESPACE index 316f2573..83d59c73 100644 --- a/hewr/NAMESPACE +++ b/hewr/NAMESPACE @@ -1,5 +1,7 @@ # Generated by roxygen2: do not edit by hand +S3method(process_model_samples,pyrenew) +S3method(process_model_samples,timeseries) export(epiweekly_samples_from_daily) export(format_timeseries_output) export(generate_exp_growth_pois) @@ -14,7 +16,9 @@ export(parse_model_run_dir_path) export(parse_pyrenew_model_name) export(parse_variable_name) export(path_up_to) +export(process_forecast) export(process_loc_forecast) +export(process_model_samples) export(prop_from_timeseries) export(read_and_combine_data) export(to_tidy_draws_timeseries) diff --git a/hewr/R/process_loc_forecast.R b/hewr/R/process_loc_forecast.R index 4c2368a9..73e4a108 100644 --- a/hewr/R/process_loc_forecast.R +++ b/hewr/R/process_loc_forecast.R @@ -278,6 +278,101 @@ to_tidy_draws_timeseries <- function( } +#' Process model samples based on model type +#' +#' S3 generic function that dispatches to appropriate method based on +#' model type. This allows extensibility for new model types without +#' changing existing code. +#' +#' @param model_type Character string indicating model type +#' ("pyrenew", "timeseries", "epiautogp") +#' @param ... Additional arguments passed to methods. See specific methods +#' for details: [process_model_samples.pyrenew()], +#' [process_model_samples.timeseries()] +#' @return Tibble of model samples +#' @export +process_model_samples <- function(model_type, ...) { + UseMethod("process_model_samples", structure(list(), class = model_type)) +} + +#' Process PyRenew model samples +#' +#' @param model_type Character string indicating model type +#' @param model_run_dir Model run directory +#' @param model_name Name of directory containing model outputs +#' @param ts_samples Timeseries samples (if available) +#' @param required_columns_e Required columns for output +#' @param n_forecast_days Number of forecast days +#' @param ... Additional arguments (unused) +#' @return Tibble of PyRenew model samples +#' @exportS3Method +process_model_samples.pyrenew <- function( + model_type, + model_run_dir, + model_name, + ts_samples = NULL, + required_columns_e, + n_forecast_days, + ... +) { + process_pyrenew_model( + model_run_dir = model_run_dir, + pyrenew_model_name = model_name, + ts_samples = ts_samples, + required_columns_e = required_columns_e, + n_forecast_days = n_forecast_days + ) +} + +#' Process timeseries model samples +#' +#' @param model_type Character string indicating model type +#' @param model_run_dir Model run directory +#' @param model_name Name of directory containing model outputs +#' @param ts_samples Timeseries samples (required for this method) +#' @param required_columns_e Required columns for output +#' @param n_forecast_days Number of forecast days +#' @param ... Additional arguments (unused) +#' @return Tibble of timeseries model samples +#' @exportS3Method +process_model_samples.timeseries <- function( + model_type, + model_run_dir, + model_name, + ts_samples = NULL, + required_columns_e, + n_forecast_days, + ... +) { + # For timeseries, ts_samples should already be loaded + # This is essentially a pass-through + if (is.null(ts_samples)) { + stop("ts_samples must be provided for timeseries model type") + } + ts_samples +} + +#' Detect model type from model name +#' +#' Internal helper function to infer model type from naming conventions +#' +#' @param model_name Character string with model name +#' @return Character string with model type ("pyrenew", "timeseries", etc.) +#' @keywords internal +detect_model_type <- function(model_name) { + if ( + grepl("^ts_", model_name) || + grepl("timeseries", model_name, ignore.case = TRUE) + ) { + return("timeseries") + } else if (grepl("epiautogp", model_name, ignore.case = TRUE)) { + return("epiautogp") + } else { + # Default to pyrenew for backward compatibility + return("pyrenew") + } +} + process_pyrenew_model <- function( model_run_dir, pyrenew_model_name, @@ -382,38 +477,52 @@ process_pyrenew_model <- function( #' Process loc forecast #' #' @param model_run_dir Model run directory +#' @param n_forecast_days An integer specifying the number of days to forecast. +#' @param model_name Name of directory containing model outputs. +#' If provided, uses new S3 dispatch interface (overrides +#' pyrenew_model_name/timeseries_model_name). #' @param pyrenew_model_name Name of directory containing pyrenew -#' model outputs +#' model outputs (legacy interface) #' @param timeseries_model_name Name of directory containing timeseries -#' model outputs -#' @param n_forecast_days An integer specifying the number of days to forecast. +#' model outputs (legacy interface) +#' @param model_type Optional character string specifying model type +#' explicitly. Only used with model_name parameter. +#' If NULL (default), will auto-detect from model_name. +#' Options: "pyrenew", "timeseries", "epiautogp" #' @param ci_widths Vector of probabilities indicating one or more -#' central credible intervals to compute. Passed as the `.width` -#' argument to [ggdist::median_qi()]. Default `c(0.5, 0.8, 0.95)`. +#' central credible intervals to compute. Passed as the `.width` +#' argument to [ggdist::median_qi()]. Default `c(0.5, 0.8, 0.95)`. #' @param save Boolean indicating whether or not to save the output -#' to parquet files. Default `TRUE`. -#' @return a list of 8 tibbles: -#' `daily_combined_training_eval_data`, -#' `epiweekly_combined_training_eval_data`, -#' `daily_samples`, -#' `epiweekly_samples`, -#' `epiweekly_with_epiweekly_other_samples`, -#' `daily_ci`, -#' `epiweekly_ci`, -#' `epiweekly_with_epiweekly_other_ci` +#' to parquet files. Default `TRUE`. +#' @return a list of 2 tibbles: `samples` and `ci` #' @export process_loc_forecast <- function( model_run_dir, n_forecast_days, + model_name = NA, pyrenew_model_name = NA, timeseries_model_name = NA, + model_type = NULL, ci_widths = c(0.5, 0.8, 0.95), save = TRUE ) { + # New interface: delegate to process_forecast + if (!is.na(model_name)) { + return(process_forecast( + model_run_dir = model_run_dir, + model_name = model_name, + n_forecast_days = n_forecast_days, + model_type = model_type, + ci_widths = ci_widths, + save = save + )) + } + + # Legacy interface validation if (is.na(pyrenew_model_name) && is.na(timeseries_model_name)) { stop( - "Either `pyrenew_model_name` or `timeseries_model_name`", - "must be provided." + "Either `model_name` or `pyrenew_model_name`/", + "`timeseries_model_name` must be provided." ) } model_name <- dplyr::if_else( @@ -518,3 +627,131 @@ process_loc_forecast <- function( return(result) } + +#' Process location forecast with generic model name +#' +#' Simplified version of process_loc_forecast that accepts a single model_name +#' parameter and auto-detects the model type. Uses S3 dispatch internally for +#' extensibility. +#' +#' @param model_run_dir Model run directory +#' @param model_name Name of directory containing model outputs +#' @param n_forecast_days An integer specifying the number of days to forecast. +#' @param model_type Optional character string specifying model type explicitly. +#' If NULL (default), will auto-detect from model_name. +#' Options: "pyrenew", "timeseries", "epiautogp" +#' @param ci_widths Vector of probabilities indicating one or more +#' central credible intervals to compute. Passed as the `.width` +#' argument to [ggdist::median_qi()]. Default `c(0.5, 0.8, 0.95)`. +#' @param save Boolean indicating whether or not to save the output +#' to parquet files. Default `TRUE`. +#' @return a list of 2 tibbles: `samples` and `ci` +#' @export +process_forecast <- function( + model_run_dir, + model_name, + n_forecast_days, + model_type = NULL, + ci_widths = c(0.5, 0.8, 0.95), + save = TRUE +) { + # Auto-detect model type if not specified + if (is.null(model_type)) { + model_type <- detect_model_type(model_name) + } + + model_dir <- fs::path(model_run_dir, model_name) + data_col_types <- readr::cols( + date = readr::col_date(), + geo_value = readr::col_character(), + disease = readr::col_character(), + data_type = readr::col_character(), + .variable = readr::col_character(), + .value = readr::col_double() + ) + + # Load training data + daily_training_dat <- readr::read_tsv( + fs::path(model_dir, "data", "combined_training_data", ext = "tsv"), + col_types = data_col_types + ) + + epiweekly_training_dat <- readr::read_tsv( + fs::path( + model_dir, + "data", + "epiweekly_combined_training_data", + ext = "tsv" + ), + col_types = data_col_types + ) + + required_columns_e <- c( + ".chain", + ".iteration", + ".draw", + "date", + "geo_value", + "disease", + ".variable", + ".value", + "resolution", + "aggregated_numerator", + "aggregated_denominator" + ) + + # Load timeseries samples if needed (for any model that might use them) + ts_samples <- NULL + if (model_type %in% c("timeseries", "pyrenew")) { + # Check if timeseries model exists + ts_model_dirs <- fs::dir_ls( + model_run_dir, + regexp = "(^ts_|timeseries)", + type = "directory" + ) + + if (length(ts_model_dirs) > 0) { + ts_model_name <- fs::path_file(ts_model_dirs[1]) + ts_samples <- load_and_aggregate_ts( + model_run_dir, + ts_model_name, + daily_training_dat, + epiweekly_training_dat, + required_columns = required_columns_e + ) + } + } + + # Dispatch to appropriate S3 method + model_samples_tidy <- process_model_samples( + model_type = model_type, + model_run_dir = model_run_dir, + model_name = model_name, + ts_samples = ts_samples, + required_columns_e = required_columns_e, + n_forecast_days = n_forecast_days + ) + + # Calculate credible intervals + ci <- model_samples_tidy |> + dplyr::select(-tidyselect::any_of(c(".chain", ".iteration", ".draw"))) |> + dplyr::group_by(dplyr::across(-".value")) |> + ggdist::median_qi(.width = ci_widths) + + result <- list( + "samples" = model_samples_tidy, + "ci" = ci + ) + + if (save) { + save_dir <- fs::path(model_run_dir, model_name) + purrr::iwalk(result, \(tab, name) { + forecasttools::write_tabular( + tab, + fs::path(save_dir, name, ext = "parquet") + ) + }) + } + + return(result) +} diff --git a/hewr/man/detect_model_type.Rd b/hewr/man/detect_model_type.Rd new file mode 100644 index 00000000..bb3cb446 --- /dev/null +++ b/hewr/man/detect_model_type.Rd @@ -0,0 +1,18 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/process_loc_forecast.R +\name{detect_model_type} +\alias{detect_model_type} +\title{Detect model type from model name} +\usage{ +detect_model_type(model_name) +} +\arguments{ +\item{model_name}{Character string with model name} +} +\value{ +Character string with model type ("pyrenew", "timeseries", etc.) +} +\description{ +Internal helper function to infer model type from naming conventions +} +\keyword{internal} diff --git a/hewr/man/process_forecast.Rd b/hewr/man/process_forecast.Rd new file mode 100644 index 00000000..58209728 --- /dev/null +++ b/hewr/man/process_forecast.Rd @@ -0,0 +1,41 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/process_loc_forecast.R +\name{process_forecast} +\alias{process_forecast} +\title{Process location forecast with generic model name} +\usage{ +process_forecast( + model_run_dir, + model_name, + n_forecast_days, + model_type = NULL, + ci_widths = c(0.5, 0.8, 0.95), + save = TRUE +) +} +\arguments{ +\item{model_run_dir}{Model run directory} + +\item{model_name}{Name of directory containing model outputs} + +\item{n_forecast_days}{An integer specifying the number of days to forecast.} + +\item{model_type}{Optional character string specifying model type explicitly. +If NULL (default), will auto-detect from model_name. +Options: "pyrenew", "timeseries", "epiautogp"} + +\item{ci_widths}{Vector of probabilities indicating one or more +central credible intervals to compute. Passed as the \code{.width} +argument to \code{\link[ggdist:point_interval]{ggdist::median_qi()}}. Default \code{c(0.5, 0.8, 0.95)}.} + +\item{save}{Boolean indicating whether or not to save the output +to parquet files. Default \code{TRUE}.} +} +\value{ +a list of 2 tibbles: \code{samples} and \code{ci} +} +\description{ +Simplified version of process_loc_forecast that accepts a single model_name +parameter and auto-detects the model type. Uses S3 dispatch internally for +extensibility. +} diff --git a/hewr/man/process_loc_forecast.Rd b/hewr/man/process_loc_forecast.Rd index 6787fc3a..260b690e 100644 --- a/hewr/man/process_loc_forecast.Rd +++ b/hewr/man/process_loc_forecast.Rd @@ -7,8 +7,10 @@ process_loc_forecast( model_run_dir, n_forecast_days, + model_name = NA, pyrenew_model_name = NA, timeseries_model_name = NA, + model_type = NULL, ci_widths = c(0.5, 0.8, 0.95), save = TRUE ) @@ -18,11 +20,20 @@ process_loc_forecast( \item{n_forecast_days}{An integer specifying the number of days to forecast.} +\item{model_name}{Name of directory containing model outputs. +If provided, uses new S3 dispatch interface (overrides +pyrenew_model_name/timeseries_model_name).} + \item{pyrenew_model_name}{Name of directory containing pyrenew -model outputs} +model outputs (legacy interface)} \item{timeseries_model_name}{Name of directory containing timeseries -model outputs} +model outputs (legacy interface)} + +\item{model_type}{Optional character string specifying model type +explicitly. Only used with model_name parameter. +If NULL (default), will auto-detect from model_name. +Options: "pyrenew", "timeseries", "epiautogp"} \item{ci_widths}{Vector of probabilities indicating one or more central credible intervals to compute. Passed as the \code{.width} @@ -32,15 +43,7 @@ argument to \code{\link[ggdist:point_interval]{ggdist::median_qi()}}. Default \c to parquet files. Default \code{TRUE}.} } \value{ -a list of 8 tibbles: -\code{daily_combined_training_eval_data}, -\code{epiweekly_combined_training_eval_data}, -\code{daily_samples}, -\code{epiweekly_samples}, -\code{epiweekly_with_epiweekly_other_samples}, -\code{daily_ci}, -\code{epiweekly_ci}, -\code{epiweekly_with_epiweekly_other_ci} +a list of 2 tibbles: \code{samples} and \code{ci} } \description{ Process loc forecast diff --git a/hewr/man/process_model_samples.Rd b/hewr/man/process_model_samples.Rd new file mode 100644 index 00000000..9941a8a0 --- /dev/null +++ b/hewr/man/process_model_samples.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/process_loc_forecast.R +\name{process_model_samples} +\alias{process_model_samples} +\title{Process model samples based on model type} +\usage{ +process_model_samples(model_type, ...) +} +\arguments{ +\item{model_type}{Character string indicating model type +("pyrenew", "timeseries", "epiautogp")} + +\item{...}{Additional arguments passed to methods. See specific methods +for details: \code{\link[=process_model_samples.pyrenew]{process_model_samples.pyrenew()}}, +\code{\link[=process_model_samples.timeseries]{process_model_samples.timeseries()}}} +} +\value{ +Tibble of model samples +} +\description{ +S3 generic function that dispatches to appropriate method based on +model type. This allows extensibility for new model types without +changing existing code. +} diff --git a/hewr/man/process_model_samples.pyrenew.Rd b/hewr/man/process_model_samples.pyrenew.Rd new file mode 100644 index 00000000..f92c8577 --- /dev/null +++ b/hewr/man/process_model_samples.pyrenew.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/process_loc_forecast.R +\name{process_model_samples.pyrenew} +\alias{process_model_samples.pyrenew} +\title{Process PyRenew model samples} +\usage{ +\method{process_model_samples}{pyrenew}( + model_type, + model_run_dir, + model_name, + ts_samples = NULL, + required_columns_e, + n_forecast_days, + ... +) +} +\arguments{ +\item{model_type}{Character string indicating model type} + +\item{model_run_dir}{Model run directory} + +\item{model_name}{Name of directory containing model outputs} + +\item{ts_samples}{Timeseries samples (if available)} + +\item{required_columns_e}{Required columns for output} + +\item{n_forecast_days}{Number of forecast days} + +\item{...}{Additional arguments (unused)} +} +\value{ +Tibble of PyRenew model samples +} +\description{ +Process PyRenew model samples +} diff --git a/hewr/man/process_model_samples.timeseries.Rd b/hewr/man/process_model_samples.timeseries.Rd new file mode 100644 index 00000000..40edc46b --- /dev/null +++ b/hewr/man/process_model_samples.timeseries.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/process_loc_forecast.R +\name{process_model_samples.timeseries} +\alias{process_model_samples.timeseries} +\title{Process timeseries model samples} +\usage{ +\method{process_model_samples}{timeseries}( + model_type, + model_run_dir, + model_name, + ts_samples = NULL, + required_columns_e, + n_forecast_days, + ... +) +} +\arguments{ +\item{model_type}{Character string indicating model type} + +\item{model_run_dir}{Model run directory} + +\item{model_name}{Name of directory containing model outputs} + +\item{ts_samples}{Timeseries samples (required for this method)} + +\item{required_columns_e}{Required columns for output} + +\item{n_forecast_days}{Number of forecast days} + +\item{...}{Additional arguments (unused)} +} +\value{ +Tibble of timeseries model samples +} +\description{ +Process timeseries model samples +} diff --git a/hewr/tests/testthat/test_process_loc_forecast.R b/hewr/tests/testthat/test_process_loc_forecast.R index 48660ee9..434c4c11 100644 --- a/hewr/tests/testthat/test_process_loc_forecast.R +++ b/hewr/tests/testthat/test_process_loc_forecast.R @@ -1,3 +1,20 @@ +# Common test fixtures +fake_dir <- "/fake/dir" +minimal_required_columns <- c("date", ".value") +full_required_columns <- c( + ".chain", + ".iteration", + ".draw", + "date", + "geo_value", + "disease", + ".variable", + ".value", + "resolution", + "aggregated_numerator", + "aggregated_denominator" +) + example_train_dat <- tibble::tibble( geo_value = "CA", disease = "COVID-19", @@ -81,3 +98,154 @@ test_that("to_tidy_draws_timeseries() works as expected", { expect_equal(result, expected) }) + +test_that("detect_model_type() correctly identifies model types", { + # Test timeseries detection with ts_ prefix + expect_equal(detect_model_type("ts_model_v1"), "timeseries") + expect_equal(detect_model_type("ts_ensemble"), "timeseries") + + # Test timeseries detection with name containing "timeseries" + expect_equal(detect_model_type("timeseries_model"), "timeseries") + expect_equal(detect_model_type("my_TimeSeries_model"), "timeseries") + + # Test epiautogp detection + expect_equal(detect_model_type("epiautogp_model"), "epiautogp") + expect_equal(detect_model_type("EpiAutoGP_v2"), "epiautogp") + + # Test default to pyrenew + expect_equal(detect_model_type("pyrenew_hew"), "pyrenew") + expect_equal(detect_model_type("pyrenew_e"), "pyrenew") + expect_equal(detect_model_type("some_other_model"), "pyrenew") +}) + +test_that("process_model_samples S3 dispatch works correctly", { + # Test that the generic function exists and has the right class + expect_true(is.function(process_model_samples)) + + # Test that methods exist for expected classes + expect_true( + "process_model_samples.pyrenew" %in% + methods("process_model_samples") + ) + expect_true( + "process_model_samples.timeseries" %in% + methods("process_model_samples") + ) +}) + +test_that("process_model_samples.timeseries validates ts_samples", { + # Should error when ts_samples is NULL + expect_error( + process_model_samples.timeseries( + model_type = "timeseries", + model_run_dir = fake_dir, + model_name = "ts_model", + ts_samples = NULL, + required_columns_e = minimal_required_columns, + n_forecast_days = 7 + ), + "ts_samples must be provided for timeseries model type" + ) +}) + +test_that("process_model_samples.timeseries returns ts_samples", { + # Create mock ts_samples + mock_ts_samples <- tibble::tibble( + .chain = 1, + .iteration = 1, + .draw = 1, + date = as.Date("2024-01-01"), + geo_value = "US", + disease = "COVID-19", + .variable = "other_ed_visits", + .value = 100, + resolution = "daily", + aggregated_numerator = FALSE, + aggregated_denominator = NA + ) + + result <- process_model_samples.timeseries( + model_type = "timeseries", + model_run_dir = fake_dir, + model_name = "ts_model", + ts_samples = mock_ts_samples, + required_columns_e = minimal_required_columns, + n_forecast_days = 7 + ) + + # Should return the ts_samples unchanged + expect_equal(result, mock_ts_samples) +}) + +test_that("process_model_samples.pyrenew dispatches correctly", { + # This test just verifies the S3 method exists and dispatches + # We expect it to error since we're not providing real data/files + # The key is that it calls the method (for code coverage) + + expect_error( + process_model_samples.pyrenew( + model_type = "pyrenew", + model_run_dir = "any_path", + model_name = "pyrenew_h", + ts_samples = NULL, + required_columns_e = minimal_required_columns, + n_forecast_days = 7 + ) + ) + + # Verify the method exists + expect_true( + "process_model_samples.pyrenew" %in% methods("process_model_samples") + ) +}) + +test_that("process_loc_forecast delegates correctly", { + # Test that process_loc_forecast calls process_forecast when + # model_name is provided by checking that it doesn't use the + # legacy code path + + # Create a simple test: when model_name is provided, function + # should attempt to call process_forecast, which will try to + # read files. When model_name is NA, it uses the legacy path + # with different error message + + # Test with model_name provided (new interface) + expect_error( + process_loc_forecast( + model_run_dir = fake_dir, + n_forecast_days = 7, + model_name = "test_model", + save = FALSE + ), + # This error comes from process_forecast trying to read + # training data + "does not exist" + ) + + # Test with legacy interface - should give different error + expect_error( + process_loc_forecast( + model_run_dir = fake_dir, + n_forecast_days = 7, + model_name = NA, + pyrenew_model_name = NA, + timeseries_model_name = NA + ), + "Either `model_name` or `pyrenew_model_name`" + ) +}) + +test_that("process_loc_forecast validates legacy interface parameters", { + # Should error when neither model_name nor pyrenew/timeseries + # names provided + expect_error( + process_loc_forecast( + model_run_dir = fake_dir, + n_forecast_days = 7, + model_name = NA, + pyrenew_model_name = NA, + timeseries_model_name = NA + ), + "Either `model_name` or `pyrenew_model_name`" + ) +}) diff --git a/hewr/tests/testthat/test_timeseries_utils.R b/hewr/tests/testthat/test_timeseries_utils.R new file mode 100644 index 00000000..e4f54218 --- /dev/null +++ b/hewr/tests/testthat/test_timeseries_utils.R @@ -0,0 +1,183 @@ +# Common test data fixtures +common_required_columns <- c( + ".draw", + "date", + "geo_value", + "disease", + ".variable", + ".value", + "resolution", + "aggregated_numerator", + "aggregated_denominator" +) + +base_date <- as.Date("2024-01-01") + +test_that("format_timeseries_output formats forecast data correctly", { + # Create minimal test data + forecast_data <- tibble::tibble( + date = base_date + 0:2, + .draw = c(1, 1, 1), + observed_ed_visits = c(10, 15, 20), + other_ed_visits = c(5, 7, 9) + ) + + result <- format_timeseries_output( + forecast_data = forecast_data, + geo_value = "US", + disease = "COVID-19", + resolution = "daily", + output_type_id = ".draw" + ) + + # Check that output has expected structure + expect_s3_class(result, "data.frame") + expect_true(all( + c( + "date", + "geo_value", + "disease", + "resolution", + "aggregated_numerator", + "aggregated_denominator", + ".variable", + ".draw", + ".value" + ) %in% + colnames(result) + )) + + # Check that geo_value and disease are set correctly + expect_true(all(result$geo_value == "US")) + expect_true(all(result$disease == "COVID-19")) + expect_true(all(result$resolution == "daily")) + + # Check aggregation flags + expect_true(all(result$aggregated_numerator == FALSE)) + + # Check that data was pivoted (should have 2 variables x 3 dates = 6 rows) + expect_equal(nrow(result), 6) +}) + +test_that("format_timeseries_output handles proportion variables", { + forecast_data <- tibble::tibble( + date = base_date, + .draw = 1, + prop_disease_ed_visits = 0.5 + ) + + result <- format_timeseries_output( + forecast_data = forecast_data, + geo_value = "CA", + disease = "Influenza", + resolution = "epiweekly", + output_type_id = ".draw" + ) + + # Proportion variables should have aggregated_denominator = FALSE + prop_row <- result[result$.variable == "prop_disease_ed_visits", ] + expect_equal(prop_row$aggregated_denominator, FALSE) +}) + +test_that("prop_from_timeseries calculates proportions correctly", { + e_denominator_samples <- tibble::tibble( + resolution = "daily", + .draw = 1:3, + date = base_date, + geo_value = "US", + disease = "COVID-19", + other_ed_visits = c(5, 10, 15) + ) + + e_numerator_samples <- tibble::tibble( + resolution = "daily", + .draw = 1:3, + date = base_date, + geo_value = "US", + disease = "COVID-19", + observed_ed_visits = c(10, 20, 30), + aggregated_numerator = FALSE + ) + + result <- prop_from_timeseries( + e_denominator_samples, + e_numerator_samples, + common_required_columns + ) + + # Check that proportions are calculated correctly + # prop = {observed \over observed + other} + # For draw 1: 10 / (10 + 5) = 0.6666... + # For draw 2: 20 / (20 + 10) = 0.6666... + # For draw 3: 30 / (30 + 15) = 0.6666... + expect_equal(nrow(result), 3) + expect_true(all(result$.variable == "prop_disease_ed_visits")) + expect_equal(result$.value, rep(2 / 3, 3), tolerance = 1e-10) +}) + +test_that("epiweekly_samples_from_daily aggregates correctly", { + # Create test data - simple smoke test + # Use dates that align with epiweeks (Sunday to Saturday) + daily_samples <- tibble::tibble( + .draw = rep(1, 7), + date = as.Date("2024-01-07") + 0:6, # One complete epiweek + geo_value = "US", + disease = "COVID-19", + .variable = "observed_ed_visits", + .value = rep(10, 7), + resolution = "daily", + aggregated_numerator = FALSE, + aggregated_denominator = NA + ) + + result <- epiweekly_samples_from_daily( + daily_samples = daily_samples, + variables_to_aggregate = "observed_ed_visits", + required_columns = common_required_columns + ) + + # Basic smoke test - should aggregate successfully + expect_s3_class(result, "data.frame") + expect_true(nrow(result) >= 1) + expect_true(all(result$resolution == "epiweekly")) + expect_true(all(result$aggregated_numerator == TRUE)) + expect_true(all(result$.variable == "observed_ed_visits")) +}) + +test_that("to_tidy_draws_timeseries combines forecast and observed", { + # Create minimal forecast data - 3 forecast dates x 2 draws = 6 rows + tidy_forecast <- tibble::tibble( + date = rep(as.Date("2024-01-08") + 0:2, each = 2), + .draw = rep(1:2, 3), + .variable = rep("observed_ed_visits", 6), + .value = c(15, 16, 17, 18, 19, 20), + resolution = "daily" + ) + + observed <- tibble::tibble( + date = base_date + 0:6, + .variable = "observed_ed_visits", + .value = 10:16 + ) + + result <- to_tidy_draws_timeseries( + tidy_forecast = tidy_forecast, + observed = observed + ) + + # Should combine observed (7 days * 2 draws = 14) + forecast (6 rows) + expect_equal(nrow(result), 20) + + # Should have .draw column first + expect_equal(colnames(result)[1], ".draw") + + # All rows should have .draw values + expect_true(all(!is.na(result$.draw))) + + # Check observed dates all have same value across draws + obs_dates <- observed$date + for (d in obs_dates) { + date_rows <- result[result$date == d, ] + expect_equal(length(unique(date_rows$.value)), 1) + } +}) diff --git a/pipelines/plot_and_save_loc_forecast.R b/pipelines/plot_and_save_loc_forecast.R index ed363f47..7536409b 100644 --- a/pipelines/plot_and_save_loc_forecast.R +++ b/pipelines/plot_and_save_loc_forecast.R @@ -65,10 +65,10 @@ save_forecast_figures <- function( parsed_model_run_dir <- parse_model_run_dir_path(model_run_dir) processed_forecast <- process_loc_forecast( - model_run_dir, - n_forecast_days, - pyrenew_model_name, - timeseries_model_name, + model_run_dir = model_run_dir, + n_forecast_days = n_forecast_days, + pyrenew_model_name = pyrenew_model_name, + timeseries_model_name = timeseries_model_name, save = TRUE )