From 9e7be9191fa7a4b81fffdc0c6235868a13bad34f Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Fri, 19 Dec 2025 17:15:36 +0000 Subject: [PATCH 1/3] Add `model_name` parameter --- hewr/R/process_loc_forecast.R | 10 ++--- .../testthat/test_process_loc_forecast.R | 4 +- pipelines/common_utils.py | 15 ++++++- pipelines/plot_and_save_loc_forecast.R | 41 ++++++++++++++----- 4 files changed, 50 insertions(+), 20 deletions(-) diff --git a/hewr/R/process_loc_forecast.R b/hewr/R/process_loc_forecast.R index 73e4a108..f936ff1c 100644 --- a/hewr/R/process_loc_forecast.R +++ b/hewr/R/process_loc_forecast.R @@ -479,8 +479,8 @@ process_pyrenew_model <- function( #' @param model_run_dir Model run directory #' @param n_forecast_days An integer specifying the number of days to forecast. #' @param model_name Name of directory containing model outputs. -#' If provided, uses new S3 dispatch interface (overrides -#' pyrenew_model_name/timeseries_model_name). +#' If provided, uses new S3 dispatch interface with auto-detection +#' (overrides pyrenew_model_name/timeseries_model_name). #' @param pyrenew_model_name Name of directory containing pyrenew #' model outputs (legacy interface) #' @param timeseries_model_name Name of directory containing timeseries @@ -506,7 +506,7 @@ process_loc_forecast <- function( ci_widths = c(0.5, 0.8, 0.95), save = TRUE ) { - # New interface: delegate to process_forecast + # New interface: delegate to process_forecast with auto-detection if (!is.na(model_name)) { return(process_forecast( model_run_dir = model_run_dir, @@ -521,8 +521,8 @@ process_loc_forecast <- function( # Legacy interface validation if (is.na(pyrenew_model_name) && is.na(timeseries_model_name)) { stop( - "Either `model_name` or `pyrenew_model_name`/", - "`timeseries_model_name` must be provided." + "At least one of `model_name`, `pyrenew_model_name`, ", + "or `timeseries_model_name` must be provided." ) } model_name <- dplyr::if_else( diff --git a/hewr/tests/testthat/test_process_loc_forecast.R b/hewr/tests/testthat/test_process_loc_forecast.R index 434c4c11..0acc27ba 100644 --- a/hewr/tests/testthat/test_process_loc_forecast.R +++ b/hewr/tests/testthat/test_process_loc_forecast.R @@ -231,7 +231,7 @@ test_that("process_loc_forecast delegates correctly", { pyrenew_model_name = NA, timeseries_model_name = NA ), - "Either `model_name` or `pyrenew_model_name`" + "At least one of" ) }) @@ -246,6 +246,6 @@ test_that("process_loc_forecast validates legacy interface parameters", { pyrenew_model_name = NA, timeseries_model_name = NA ), - "Either `model_name` or `pyrenew_model_name`" + "At least one of" ) }) diff --git a/pipelines/common_utils.py b/pipelines/common_utils.py index a7634526..c8b4e7b4 100644 --- a/pipelines/common_utils.py +++ b/pipelines/common_utils.py @@ -329,6 +329,7 @@ def plot_and_save_loc_forecast( n_forecast_days: int, pyrenew_model_name: str = None, timeseries_model_name: str = None, + model_name: str = None, ) -> None: """Plot and save location forecast using R script. @@ -339,9 +340,12 @@ def plot_and_save_loc_forecast( n_forecast_days : int Number of days to forecast. pyrenew_model_name : str, optional - Name of the PyRenew model. + Name of the PyRenew model (legacy). timeseries_model_name : str, optional - Name of the timeseries model. + Name of the timeseries model (legacy). + model_name : str, optional + Generic model name. When provided, auto-detects model type + and dispatches to appropriate processing method. Returns ------- @@ -366,6 +370,13 @@ def plot_and_save_loc_forecast( f"{timeseries_model_name}", ] ) + if model_name is not None: + args.extend( + [ + "--model-name", + f"{model_name}", + ] + ) run_r_script( "pipelines/plot_and_save_loc_forecast.R", diff --git a/pipelines/plot_and_save_loc_forecast.R b/pipelines/plot_and_save_loc_forecast.R index 7536409b..d5873fc4 100644 --- a/pipelines/plot_and_save_loc_forecast.R +++ b/pipelines/plot_and_save_loc_forecast.R @@ -23,12 +23,22 @@ save_forecast_figures <- function( model_run_dir, n_forecast_days, pyrenew_model_name = NA, - timeseries_model_name = NA + timeseries_model_name = NA, + model_name = NA ) { - if (is.na(pyrenew_model_name) && is.na(timeseries_model_name)) { + # Count how many model names are provided + model_names_provided <- sum( + !is.na(c( + pyrenew_model_name, + timeseries_model_name, + model_name + )) + ) + + if (model_names_provided == 0) { stop( - "Either `pyrenew_model_name` or `timeseries_model_name`", - "must be provided." + "At least one of `pyrenew_model_name`, `timeseries_model_name`, ", + "or `model_name` must be provided." ) } @@ -52,13 +62,15 @@ save_forecast_figures <- function( str_replace_all("_+", "_") } - model_name <- dplyr::if_else( - is.na(pyrenew_model_name), - timeseries_model_name, - pyrenew_model_name + # Determine which model name to use (prioritize in order) + final_model_name <- dplyr::case_when( + !is.na(model_name) ~ model_name, + !is.na(pyrenew_model_name) ~ pyrenew_model_name, + !is.na(timeseries_model_name) ~ timeseries_model_name, + TRUE ~ NA_character_ ) - model_dir <- fs::path(model_run_dir, model_name) + model_dir <- fs::path(model_run_dir, final_model_name) figure_dir <- fs::path(model_dir, "figures") data_dir <- fs::path(model_dir, "data") dir_create(figure_dir) @@ -69,6 +81,7 @@ save_forecast_figures <- function( n_forecast_days = n_forecast_days, pyrenew_model_name = pyrenew_model_name, timeseries_model_name = timeseries_model_name, + model_name = model_name, save = TRUE ) @@ -128,7 +141,7 @@ save_forecast_figures <- function( ) |> mutate( file_name = create_file_name( - model_name, + final_model_name, .data$.variable, .data$resolution, .data$aggregated_numerator, @@ -168,6 +181,10 @@ p <- arg_parser("Generate forecast figures") |> add_argument( "--n-forecast-days", help = "Number of days to forecast" + ) |> + add_argument( + "--model-name", + help = "Name of directory with model outputs (auto-detects type)" ) argv <- parse_args(p) @@ -176,10 +193,12 @@ model_run_dir <- path(argv$model_run_dir) n_forecast_days <- as.numeric(argv$n_forecast_days) pyrenew_model_name <- argv$pyrenew_model_name timeseries_model_name <- argv$timeseries_model_name +model_name <- argv$model_name save_forecast_figures( model_run_dir, n_forecast_days, pyrenew_model_name, - timeseries_model_name + timeseries_model_name, + model_name ) From 39cee8531fc9cd48236affe4cd08a00aad18b02b Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Fri, 19 Dec 2025 22:12:11 +0000 Subject: [PATCH 2/3] use all rather than sum --- pipelines/plot_and_save_loc_forecast.R | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/pipelines/plot_and_save_loc_forecast.R b/pipelines/plot_and_save_loc_forecast.R index d5873fc4..64f0cace 100644 --- a/pipelines/plot_and_save_loc_forecast.R +++ b/pipelines/plot_and_save_loc_forecast.R @@ -26,16 +26,9 @@ save_forecast_figures <- function( timeseries_model_name = NA, model_name = NA ) { - # Count how many model names are provided - model_names_provided <- sum( - !is.na(c( - pyrenew_model_name, - timeseries_model_name, - model_name - )) - ) - if (model_names_provided == 0) { + + if (all(is.na(c(pyrenew_model_name, timeseries_model_name, model_name)))) { stop( "At least one of `pyrenew_model_name`, `timeseries_model_name`, ", "or `model_name` must be provided." From 34fab85e3eb262d6176ed2bc5f62983b502ad86e Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Fri, 19 Dec 2025 22:13:03 +0000 Subject: [PATCH 3/3] reformat --- pipelines/plot_and_save_loc_forecast.R | 2 -- 1 file changed, 2 deletions(-) diff --git a/pipelines/plot_and_save_loc_forecast.R b/pipelines/plot_and_save_loc_forecast.R index 64f0cace..6b178831 100644 --- a/pipelines/plot_and_save_loc_forecast.R +++ b/pipelines/plot_and_save_loc_forecast.R @@ -26,8 +26,6 @@ save_forecast_figures <- function( timeseries_model_name = NA, model_name = NA ) { - - if (all(is.na(c(pyrenew_model_name, timeseries_model_name, model_name)))) { stop( "At least one of `pyrenew_model_name`, `timeseries_model_name`, ",