diff --git a/hewr/DESCRIPTION b/hewr/DESCRIPTION index 8d15ab91..4a451dd7 100644 --- a/hewr/DESCRIPTION +++ b/hewr/DESCRIPTION @@ -7,7 +7,7 @@ Description: What the package does (one paragraph). License: Apache License (>= 2) Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.3 +RoxygenNote: 7.3.2 Imports: argparser, cowplot, diff --git a/hewr/NAMESPACE b/hewr/NAMESPACE index 83d59c73..316f2573 100644 --- a/hewr/NAMESPACE +++ b/hewr/NAMESPACE @@ -1,7 +1,5 @@ # Generated by roxygen2: do not edit by hand -S3method(process_model_samples,pyrenew) -S3method(process_model_samples,timeseries) export(epiweekly_samples_from_daily) export(format_timeseries_output) export(generate_exp_growth_pois) @@ -16,9 +14,7 @@ export(parse_model_run_dir_path) export(parse_pyrenew_model_name) export(parse_variable_name) export(path_up_to) -export(process_forecast) export(process_loc_forecast) -export(process_model_samples) export(prop_from_timeseries) export(read_and_combine_data) export(to_tidy_draws_timeseries) diff --git a/hewr/R/process_loc_forecast.R b/hewr/R/process_loc_forecast.R index 73e4a108..4c2368a9 100644 --- a/hewr/R/process_loc_forecast.R +++ b/hewr/R/process_loc_forecast.R @@ -278,101 +278,6 @@ to_tidy_draws_timeseries <- function( } -#' Process model samples based on model type -#' -#' S3 generic function that dispatches to appropriate method based on -#' model type. This allows extensibility for new model types without -#' changing existing code. -#' -#' @param model_type Character string indicating model type -#' ("pyrenew", "timeseries", "epiautogp") -#' @param ... Additional arguments passed to methods. See specific methods -#' for details: [process_model_samples.pyrenew()], -#' [process_model_samples.timeseries()] -#' @return Tibble of model samples -#' @export -process_model_samples <- function(model_type, ...) { - UseMethod("process_model_samples", structure(list(), class = model_type)) -} - -#' Process PyRenew model samples -#' -#' @param model_type Character string indicating model type -#' @param model_run_dir Model run directory -#' @param model_name Name of directory containing model outputs -#' @param ts_samples Timeseries samples (if available) -#' @param required_columns_e Required columns for output -#' @param n_forecast_days Number of forecast days -#' @param ... Additional arguments (unused) -#' @return Tibble of PyRenew model samples -#' @exportS3Method -process_model_samples.pyrenew <- function( - model_type, - model_run_dir, - model_name, - ts_samples = NULL, - required_columns_e, - n_forecast_days, - ... -) { - process_pyrenew_model( - model_run_dir = model_run_dir, - pyrenew_model_name = model_name, - ts_samples = ts_samples, - required_columns_e = required_columns_e, - n_forecast_days = n_forecast_days - ) -} - -#' Process timeseries model samples -#' -#' @param model_type Character string indicating model type -#' @param model_run_dir Model run directory -#' @param model_name Name of directory containing model outputs -#' @param ts_samples Timeseries samples (required for this method) -#' @param required_columns_e Required columns for output -#' @param n_forecast_days Number of forecast days -#' @param ... Additional arguments (unused) -#' @return Tibble of timeseries model samples -#' @exportS3Method -process_model_samples.timeseries <- function( - model_type, - model_run_dir, - model_name, - ts_samples = NULL, - required_columns_e, - n_forecast_days, - ... -) { - # For timeseries, ts_samples should already be loaded - # This is essentially a pass-through - if (is.null(ts_samples)) { - stop("ts_samples must be provided for timeseries model type") - } - ts_samples -} - -#' Detect model type from model name -#' -#' Internal helper function to infer model type from naming conventions -#' -#' @param model_name Character string with model name -#' @return Character string with model type ("pyrenew", "timeseries", etc.) -#' @keywords internal -detect_model_type <- function(model_name) { - if ( - grepl("^ts_", model_name) || - grepl("timeseries", model_name, ignore.case = TRUE) - ) { - return("timeseries") - } else if (grepl("epiautogp", model_name, ignore.case = TRUE)) { - return("epiautogp") - } else { - # Default to pyrenew for backward compatibility - return("pyrenew") - } -} - process_pyrenew_model <- function( model_run_dir, pyrenew_model_name, @@ -477,52 +382,38 @@ process_pyrenew_model <- function( #' Process loc forecast #' #' @param model_run_dir Model run directory -#' @param n_forecast_days An integer specifying the number of days to forecast. -#' @param model_name Name of directory containing model outputs. -#' If provided, uses new S3 dispatch interface (overrides -#' pyrenew_model_name/timeseries_model_name). #' @param pyrenew_model_name Name of directory containing pyrenew -#' model outputs (legacy interface) +#' model outputs #' @param timeseries_model_name Name of directory containing timeseries -#' model outputs (legacy interface) -#' @param model_type Optional character string specifying model type -#' explicitly. Only used with model_name parameter. -#' If NULL (default), will auto-detect from model_name. -#' Options: "pyrenew", "timeseries", "epiautogp" +#' model outputs +#' @param n_forecast_days An integer specifying the number of days to forecast. #' @param ci_widths Vector of probabilities indicating one or more -#' central credible intervals to compute. Passed as the `.width` -#' argument to [ggdist::median_qi()]. Default `c(0.5, 0.8, 0.95)`. +#' central credible intervals to compute. Passed as the `.width` +#' argument to [ggdist::median_qi()]. Default `c(0.5, 0.8, 0.95)`. #' @param save Boolean indicating whether or not to save the output -#' to parquet files. Default `TRUE`. -#' @return a list of 2 tibbles: `samples` and `ci` +#' to parquet files. Default `TRUE`. +#' @return a list of 8 tibbles: +#' `daily_combined_training_eval_data`, +#' `epiweekly_combined_training_eval_data`, +#' `daily_samples`, +#' `epiweekly_samples`, +#' `epiweekly_with_epiweekly_other_samples`, +#' `daily_ci`, +#' `epiweekly_ci`, +#' `epiweekly_with_epiweekly_other_ci` #' @export process_loc_forecast <- function( model_run_dir, n_forecast_days, - model_name = NA, pyrenew_model_name = NA, timeseries_model_name = NA, - model_type = NULL, ci_widths = c(0.5, 0.8, 0.95), save = TRUE ) { - # New interface: delegate to process_forecast - if (!is.na(model_name)) { - return(process_forecast( - model_run_dir = model_run_dir, - model_name = model_name, - n_forecast_days = n_forecast_days, - model_type = model_type, - ci_widths = ci_widths, - save = save - )) - } - - # Legacy interface validation if (is.na(pyrenew_model_name) && is.na(timeseries_model_name)) { stop( - "Either `model_name` or `pyrenew_model_name`/", - "`timeseries_model_name` must be provided." + "Either `pyrenew_model_name` or `timeseries_model_name`", + "must be provided." ) } model_name <- dplyr::if_else( @@ -627,131 +518,3 @@ process_loc_forecast <- function( return(result) } - -#' Process location forecast with generic model name -#' -#' Simplified version of process_loc_forecast that accepts a single model_name -#' parameter and auto-detects the model type. Uses S3 dispatch internally for -#' extensibility. -#' -#' @param model_run_dir Model run directory -#' @param model_name Name of directory containing model outputs -#' @param n_forecast_days An integer specifying the number of days to forecast. -#' @param model_type Optional character string specifying model type explicitly. -#' If NULL (default), will auto-detect from model_name. -#' Options: "pyrenew", "timeseries", "epiautogp" -#' @param ci_widths Vector of probabilities indicating one or more -#' central credible intervals to compute. Passed as the `.width` -#' argument to [ggdist::median_qi()]. Default `c(0.5, 0.8, 0.95)`. -#' @param save Boolean indicating whether or not to save the output -#' to parquet files. Default `TRUE`. -#' @return a list of 2 tibbles: `samples` and `ci` -#' @export -process_forecast <- function( - model_run_dir, - model_name, - n_forecast_days, - model_type = NULL, - ci_widths = c(0.5, 0.8, 0.95), - save = TRUE -) { - # Auto-detect model type if not specified - if (is.null(model_type)) { - model_type <- detect_model_type(model_name) - } - - model_dir <- fs::path(model_run_dir, model_name) - data_col_types <- readr::cols( - date = readr::col_date(), - geo_value = readr::col_character(), - disease = readr::col_character(), - data_type = readr::col_character(), - .variable = readr::col_character(), - .value = readr::col_double() - ) - - # Load training data - daily_training_dat <- readr::read_tsv( - fs::path(model_dir, "data", "combined_training_data", ext = "tsv"), - col_types = data_col_types - ) - - epiweekly_training_dat <- readr::read_tsv( - fs::path( - model_dir, - "data", - "epiweekly_combined_training_data", - ext = "tsv" - ), - col_types = data_col_types - ) - - required_columns_e <- c( - ".chain", - ".iteration", - ".draw", - "date", - "geo_value", - "disease", - ".variable", - ".value", - "resolution", - "aggregated_numerator", - "aggregated_denominator" - ) - - # Load timeseries samples if needed (for any model that might use them) - ts_samples <- NULL - if (model_type %in% c("timeseries", "pyrenew")) { - # Check if timeseries model exists - ts_model_dirs <- fs::dir_ls( - model_run_dir, - regexp = "(^ts_|timeseries)", - type = "directory" - ) - - if (length(ts_model_dirs) > 0) { - ts_model_name <- fs::path_file(ts_model_dirs[1]) - ts_samples <- load_and_aggregate_ts( - model_run_dir, - ts_model_name, - daily_training_dat, - epiweekly_training_dat, - required_columns = required_columns_e - ) - } - } - - # Dispatch to appropriate S3 method - model_samples_tidy <- process_model_samples( - model_type = model_type, - model_run_dir = model_run_dir, - model_name = model_name, - ts_samples = ts_samples, - required_columns_e = required_columns_e, - n_forecast_days = n_forecast_days - ) - - # Calculate credible intervals - ci <- model_samples_tidy |> - dplyr::select(-tidyselect::any_of(c(".chain", ".iteration", ".draw"))) |> - dplyr::group_by(dplyr::across(-".value")) |> - ggdist::median_qi(.width = ci_widths) - - result <- list( - "samples" = model_samples_tidy, - "ci" = ci - ) - - if (save) { - save_dir <- fs::path(model_run_dir, model_name) - purrr::iwalk(result, \(tab, name) { - forecasttools::write_tabular( - tab, - fs::path(save_dir, name, ext = "parquet") - ) - }) - } - - return(result) -} diff --git a/hewr/man/detect_model_type.Rd b/hewr/man/detect_model_type.Rd deleted file mode 100644 index bb3cb446..00000000 --- a/hewr/man/detect_model_type.Rd +++ /dev/null @@ -1,18 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/process_loc_forecast.R -\name{detect_model_type} -\alias{detect_model_type} -\title{Detect model type from model name} -\usage{ -detect_model_type(model_name) -} -\arguments{ -\item{model_name}{Character string with model name} -} -\value{ -Character string with model type ("pyrenew", "timeseries", etc.) -} -\description{ -Internal helper function to infer model type from naming conventions -} -\keyword{internal} diff --git a/hewr/man/process_forecast.Rd b/hewr/man/process_forecast.Rd deleted file mode 100644 index 58209728..00000000 --- a/hewr/man/process_forecast.Rd +++ /dev/null @@ -1,41 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/process_loc_forecast.R -\name{process_forecast} -\alias{process_forecast} -\title{Process location forecast with generic model name} -\usage{ -process_forecast( - model_run_dir, - model_name, - n_forecast_days, - model_type = NULL, - ci_widths = c(0.5, 0.8, 0.95), - save = TRUE -) -} -\arguments{ -\item{model_run_dir}{Model run directory} - -\item{model_name}{Name of directory containing model outputs} - -\item{n_forecast_days}{An integer specifying the number of days to forecast.} - -\item{model_type}{Optional character string specifying model type explicitly. -If NULL (default), will auto-detect from model_name. -Options: "pyrenew", "timeseries", "epiautogp"} - -\item{ci_widths}{Vector of probabilities indicating one or more -central credible intervals to compute. Passed as the \code{.width} -argument to \code{\link[ggdist:point_interval]{ggdist::median_qi()}}. Default \code{c(0.5, 0.8, 0.95)}.} - -\item{save}{Boolean indicating whether or not to save the output -to parquet files. Default \code{TRUE}.} -} -\value{ -a list of 2 tibbles: \code{samples} and \code{ci} -} -\description{ -Simplified version of process_loc_forecast that accepts a single model_name -parameter and auto-detects the model type. Uses S3 dispatch internally for -extensibility. -} diff --git a/hewr/man/process_loc_forecast.Rd b/hewr/man/process_loc_forecast.Rd index 260b690e..6787fc3a 100644 --- a/hewr/man/process_loc_forecast.Rd +++ b/hewr/man/process_loc_forecast.Rd @@ -7,10 +7,8 @@ process_loc_forecast( model_run_dir, n_forecast_days, - model_name = NA, pyrenew_model_name = NA, timeseries_model_name = NA, - model_type = NULL, ci_widths = c(0.5, 0.8, 0.95), save = TRUE ) @@ -20,20 +18,11 @@ process_loc_forecast( \item{n_forecast_days}{An integer specifying the number of days to forecast.} -\item{model_name}{Name of directory containing model outputs. -If provided, uses new S3 dispatch interface (overrides -pyrenew_model_name/timeseries_model_name).} - \item{pyrenew_model_name}{Name of directory containing pyrenew -model outputs (legacy interface)} +model outputs} \item{timeseries_model_name}{Name of directory containing timeseries -model outputs (legacy interface)} - -\item{model_type}{Optional character string specifying model type -explicitly. Only used with model_name parameter. -If NULL (default), will auto-detect from model_name. -Options: "pyrenew", "timeseries", "epiautogp"} +model outputs} \item{ci_widths}{Vector of probabilities indicating one or more central credible intervals to compute. Passed as the \code{.width} @@ -43,7 +32,15 @@ argument to \code{\link[ggdist:point_interval]{ggdist::median_qi()}}. Default \c to parquet files. Default \code{TRUE}.} } \value{ -a list of 2 tibbles: \code{samples} and \code{ci} +a list of 8 tibbles: +\code{daily_combined_training_eval_data}, +\code{epiweekly_combined_training_eval_data}, +\code{daily_samples}, +\code{epiweekly_samples}, +\code{epiweekly_with_epiweekly_other_samples}, +\code{daily_ci}, +\code{epiweekly_ci}, +\code{epiweekly_with_epiweekly_other_ci} } \description{ Process loc forecast diff --git a/hewr/man/process_model_samples.Rd b/hewr/man/process_model_samples.Rd deleted file mode 100644 index 9941a8a0..00000000 --- a/hewr/man/process_model_samples.Rd +++ /dev/null @@ -1,24 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/process_loc_forecast.R -\name{process_model_samples} -\alias{process_model_samples} -\title{Process model samples based on model type} -\usage{ -process_model_samples(model_type, ...) -} -\arguments{ -\item{model_type}{Character string indicating model type -("pyrenew", "timeseries", "epiautogp")} - -\item{...}{Additional arguments passed to methods. See specific methods -for details: \code{\link[=process_model_samples.pyrenew]{process_model_samples.pyrenew()}}, -\code{\link[=process_model_samples.timeseries]{process_model_samples.timeseries()}}} -} -\value{ -Tibble of model samples -} -\description{ -S3 generic function that dispatches to appropriate method based on -model type. This allows extensibility for new model types without -changing existing code. -} diff --git a/hewr/man/process_model_samples.pyrenew.Rd b/hewr/man/process_model_samples.pyrenew.Rd deleted file mode 100644 index f92c8577..00000000 --- a/hewr/man/process_model_samples.pyrenew.Rd +++ /dev/null @@ -1,37 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/process_loc_forecast.R -\name{process_model_samples.pyrenew} -\alias{process_model_samples.pyrenew} -\title{Process PyRenew model samples} -\usage{ -\method{process_model_samples}{pyrenew}( - model_type, - model_run_dir, - model_name, - ts_samples = NULL, - required_columns_e, - n_forecast_days, - ... -) -} -\arguments{ -\item{model_type}{Character string indicating model type} - -\item{model_run_dir}{Model run directory} - -\item{model_name}{Name of directory containing model outputs} - -\item{ts_samples}{Timeseries samples (if available)} - -\item{required_columns_e}{Required columns for output} - -\item{n_forecast_days}{Number of forecast days} - -\item{...}{Additional arguments (unused)} -} -\value{ -Tibble of PyRenew model samples -} -\description{ -Process PyRenew model samples -} diff --git a/hewr/man/process_model_samples.timeseries.Rd b/hewr/man/process_model_samples.timeseries.Rd deleted file mode 100644 index 40edc46b..00000000 --- a/hewr/man/process_model_samples.timeseries.Rd +++ /dev/null @@ -1,37 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/process_loc_forecast.R -\name{process_model_samples.timeseries} -\alias{process_model_samples.timeseries} -\title{Process timeseries model samples} -\usage{ -\method{process_model_samples}{timeseries}( - model_type, - model_run_dir, - model_name, - ts_samples = NULL, - required_columns_e, - n_forecast_days, - ... -) -} -\arguments{ -\item{model_type}{Character string indicating model type} - -\item{model_run_dir}{Model run directory} - -\item{model_name}{Name of directory containing model outputs} - -\item{ts_samples}{Timeseries samples (required for this method)} - -\item{required_columns_e}{Required columns for output} - -\item{n_forecast_days}{Number of forecast days} - -\item{...}{Additional arguments (unused)} -} -\value{ -Tibble of timeseries model samples -} -\description{ -Process timeseries model samples -} diff --git a/hewr/tests/testthat/test_process_loc_forecast.R b/hewr/tests/testthat/test_process_loc_forecast.R index 434c4c11..48660ee9 100644 --- a/hewr/tests/testthat/test_process_loc_forecast.R +++ b/hewr/tests/testthat/test_process_loc_forecast.R @@ -1,20 +1,3 @@ -# Common test fixtures -fake_dir <- "/fake/dir" -minimal_required_columns <- c("date", ".value") -full_required_columns <- c( - ".chain", - ".iteration", - ".draw", - "date", - "geo_value", - "disease", - ".variable", - ".value", - "resolution", - "aggregated_numerator", - "aggregated_denominator" -) - example_train_dat <- tibble::tibble( geo_value = "CA", disease = "COVID-19", @@ -98,154 +81,3 @@ test_that("to_tidy_draws_timeseries() works as expected", { expect_equal(result, expected) }) - -test_that("detect_model_type() correctly identifies model types", { - # Test timeseries detection with ts_ prefix - expect_equal(detect_model_type("ts_model_v1"), "timeseries") - expect_equal(detect_model_type("ts_ensemble"), "timeseries") - - # Test timeseries detection with name containing "timeseries" - expect_equal(detect_model_type("timeseries_model"), "timeseries") - expect_equal(detect_model_type("my_TimeSeries_model"), "timeseries") - - # Test epiautogp detection - expect_equal(detect_model_type("epiautogp_model"), "epiautogp") - expect_equal(detect_model_type("EpiAutoGP_v2"), "epiautogp") - - # Test default to pyrenew - expect_equal(detect_model_type("pyrenew_hew"), "pyrenew") - expect_equal(detect_model_type("pyrenew_e"), "pyrenew") - expect_equal(detect_model_type("some_other_model"), "pyrenew") -}) - -test_that("process_model_samples S3 dispatch works correctly", { - # Test that the generic function exists and has the right class - expect_true(is.function(process_model_samples)) - - # Test that methods exist for expected classes - expect_true( - "process_model_samples.pyrenew" %in% - methods("process_model_samples") - ) - expect_true( - "process_model_samples.timeseries" %in% - methods("process_model_samples") - ) -}) - -test_that("process_model_samples.timeseries validates ts_samples", { - # Should error when ts_samples is NULL - expect_error( - process_model_samples.timeseries( - model_type = "timeseries", - model_run_dir = fake_dir, - model_name = "ts_model", - ts_samples = NULL, - required_columns_e = minimal_required_columns, - n_forecast_days = 7 - ), - "ts_samples must be provided for timeseries model type" - ) -}) - -test_that("process_model_samples.timeseries returns ts_samples", { - # Create mock ts_samples - mock_ts_samples <- tibble::tibble( - .chain = 1, - .iteration = 1, - .draw = 1, - date = as.Date("2024-01-01"), - geo_value = "US", - disease = "COVID-19", - .variable = "other_ed_visits", - .value = 100, - resolution = "daily", - aggregated_numerator = FALSE, - aggregated_denominator = NA - ) - - result <- process_model_samples.timeseries( - model_type = "timeseries", - model_run_dir = fake_dir, - model_name = "ts_model", - ts_samples = mock_ts_samples, - required_columns_e = minimal_required_columns, - n_forecast_days = 7 - ) - - # Should return the ts_samples unchanged - expect_equal(result, mock_ts_samples) -}) - -test_that("process_model_samples.pyrenew dispatches correctly", { - # This test just verifies the S3 method exists and dispatches - # We expect it to error since we're not providing real data/files - # The key is that it calls the method (for code coverage) - - expect_error( - process_model_samples.pyrenew( - model_type = "pyrenew", - model_run_dir = "any_path", - model_name = "pyrenew_h", - ts_samples = NULL, - required_columns_e = minimal_required_columns, - n_forecast_days = 7 - ) - ) - - # Verify the method exists - expect_true( - "process_model_samples.pyrenew" %in% methods("process_model_samples") - ) -}) - -test_that("process_loc_forecast delegates correctly", { - # Test that process_loc_forecast calls process_forecast when - # model_name is provided by checking that it doesn't use the - # legacy code path - - # Create a simple test: when model_name is provided, function - # should attempt to call process_forecast, which will try to - # read files. When model_name is NA, it uses the legacy path - # with different error message - - # Test with model_name provided (new interface) - expect_error( - process_loc_forecast( - model_run_dir = fake_dir, - n_forecast_days = 7, - model_name = "test_model", - save = FALSE - ), - # This error comes from process_forecast trying to read - # training data - "does not exist" - ) - - # Test with legacy interface - should give different error - expect_error( - process_loc_forecast( - model_run_dir = fake_dir, - n_forecast_days = 7, - model_name = NA, - pyrenew_model_name = NA, - timeseries_model_name = NA - ), - "Either `model_name` or `pyrenew_model_name`" - ) -}) - -test_that("process_loc_forecast validates legacy interface parameters", { - # Should error when neither model_name nor pyrenew/timeseries - # names provided - expect_error( - process_loc_forecast( - model_run_dir = fake_dir, - n_forecast_days = 7, - model_name = NA, - pyrenew_model_name = NA, - timeseries_model_name = NA - ), - "Either `model_name` or `pyrenew_model_name`" - ) -}) diff --git a/hewr/tests/testthat/test_timeseries_utils.R b/hewr/tests/testthat/test_timeseries_utils.R deleted file mode 100644 index e4f54218..00000000 --- a/hewr/tests/testthat/test_timeseries_utils.R +++ /dev/null @@ -1,183 +0,0 @@ -# Common test data fixtures -common_required_columns <- c( - ".draw", - "date", - "geo_value", - "disease", - ".variable", - ".value", - "resolution", - "aggregated_numerator", - "aggregated_denominator" -) - -base_date <- as.Date("2024-01-01") - -test_that("format_timeseries_output formats forecast data correctly", { - # Create minimal test data - forecast_data <- tibble::tibble( - date = base_date + 0:2, - .draw = c(1, 1, 1), - observed_ed_visits = c(10, 15, 20), - other_ed_visits = c(5, 7, 9) - ) - - result <- format_timeseries_output( - forecast_data = forecast_data, - geo_value = "US", - disease = "COVID-19", - resolution = "daily", - output_type_id = ".draw" - ) - - # Check that output has expected structure - expect_s3_class(result, "data.frame") - expect_true(all( - c( - "date", - "geo_value", - "disease", - "resolution", - "aggregated_numerator", - "aggregated_denominator", - ".variable", - ".draw", - ".value" - ) %in% - colnames(result) - )) - - # Check that geo_value and disease are set correctly - expect_true(all(result$geo_value == "US")) - expect_true(all(result$disease == "COVID-19")) - expect_true(all(result$resolution == "daily")) - - # Check aggregation flags - expect_true(all(result$aggregated_numerator == FALSE)) - - # Check that data was pivoted (should have 2 variables x 3 dates = 6 rows) - expect_equal(nrow(result), 6) -}) - -test_that("format_timeseries_output handles proportion variables", { - forecast_data <- tibble::tibble( - date = base_date, - .draw = 1, - prop_disease_ed_visits = 0.5 - ) - - result <- format_timeseries_output( - forecast_data = forecast_data, - geo_value = "CA", - disease = "Influenza", - resolution = "epiweekly", - output_type_id = ".draw" - ) - - # Proportion variables should have aggregated_denominator = FALSE - prop_row <- result[result$.variable == "prop_disease_ed_visits", ] - expect_equal(prop_row$aggregated_denominator, FALSE) -}) - -test_that("prop_from_timeseries calculates proportions correctly", { - e_denominator_samples <- tibble::tibble( - resolution = "daily", - .draw = 1:3, - date = base_date, - geo_value = "US", - disease = "COVID-19", - other_ed_visits = c(5, 10, 15) - ) - - e_numerator_samples <- tibble::tibble( - resolution = "daily", - .draw = 1:3, - date = base_date, - geo_value = "US", - disease = "COVID-19", - observed_ed_visits = c(10, 20, 30), - aggregated_numerator = FALSE - ) - - result <- prop_from_timeseries( - e_denominator_samples, - e_numerator_samples, - common_required_columns - ) - - # Check that proportions are calculated correctly - # prop = {observed \over observed + other} - # For draw 1: 10 / (10 + 5) = 0.6666... - # For draw 2: 20 / (20 + 10) = 0.6666... - # For draw 3: 30 / (30 + 15) = 0.6666... - expect_equal(nrow(result), 3) - expect_true(all(result$.variable == "prop_disease_ed_visits")) - expect_equal(result$.value, rep(2 / 3, 3), tolerance = 1e-10) -}) - -test_that("epiweekly_samples_from_daily aggregates correctly", { - # Create test data - simple smoke test - # Use dates that align with epiweeks (Sunday to Saturday) - daily_samples <- tibble::tibble( - .draw = rep(1, 7), - date = as.Date("2024-01-07") + 0:6, # One complete epiweek - geo_value = "US", - disease = "COVID-19", - .variable = "observed_ed_visits", - .value = rep(10, 7), - resolution = "daily", - aggregated_numerator = FALSE, - aggregated_denominator = NA - ) - - result <- epiweekly_samples_from_daily( - daily_samples = daily_samples, - variables_to_aggregate = "observed_ed_visits", - required_columns = common_required_columns - ) - - # Basic smoke test - should aggregate successfully - expect_s3_class(result, "data.frame") - expect_true(nrow(result) >= 1) - expect_true(all(result$resolution == "epiweekly")) - expect_true(all(result$aggregated_numerator == TRUE)) - expect_true(all(result$.variable == "observed_ed_visits")) -}) - -test_that("to_tidy_draws_timeseries combines forecast and observed", { - # Create minimal forecast data - 3 forecast dates x 2 draws = 6 rows - tidy_forecast <- tibble::tibble( - date = rep(as.Date("2024-01-08") + 0:2, each = 2), - .draw = rep(1:2, 3), - .variable = rep("observed_ed_visits", 6), - .value = c(15, 16, 17, 18, 19, 20), - resolution = "daily" - ) - - observed <- tibble::tibble( - date = base_date + 0:6, - .variable = "observed_ed_visits", - .value = 10:16 - ) - - result <- to_tidy_draws_timeseries( - tidy_forecast = tidy_forecast, - observed = observed - ) - - # Should combine observed (7 days * 2 draws = 14) + forecast (6 rows) - expect_equal(nrow(result), 20) - - # Should have .draw column first - expect_equal(colnames(result)[1], ".draw") - - # All rows should have .draw values - expect_true(all(!is.na(result$.draw))) - - # Check observed dates all have same value across draws - obs_dates <- observed$date - for (d in obs_dates) { - date_rows <- result[result$date == d, ] - expect_equal(length(unique(date_rows$.value)), 1) - } -})