diff --git a/hewr/R/process_loc_forecast.R b/hewr/R/process_loc_forecast.R index 73e4a108..f936ff1c 100644 --- a/hewr/R/process_loc_forecast.R +++ b/hewr/R/process_loc_forecast.R @@ -479,8 +479,8 @@ process_pyrenew_model <- function( #' @param model_run_dir Model run directory #' @param n_forecast_days An integer specifying the number of days to forecast. #' @param model_name Name of directory containing model outputs. -#' If provided, uses new S3 dispatch interface (overrides -#' pyrenew_model_name/timeseries_model_name). +#' If provided, uses new S3 dispatch interface with auto-detection +#' (overrides pyrenew_model_name/timeseries_model_name). #' @param pyrenew_model_name Name of directory containing pyrenew #' model outputs (legacy interface) #' @param timeseries_model_name Name of directory containing timeseries @@ -506,7 +506,7 @@ process_loc_forecast <- function( ci_widths = c(0.5, 0.8, 0.95), save = TRUE ) { - # New interface: delegate to process_forecast + # New interface: delegate to process_forecast with auto-detection if (!is.na(model_name)) { return(process_forecast( model_run_dir = model_run_dir, @@ -521,8 +521,8 @@ process_loc_forecast <- function( # Legacy interface validation if (is.na(pyrenew_model_name) && is.na(timeseries_model_name)) { stop( - "Either `model_name` or `pyrenew_model_name`/", - "`timeseries_model_name` must be provided." + "At least one of `model_name`, `pyrenew_model_name`, ", + "or `timeseries_model_name` must be provided." ) } model_name <- dplyr::if_else( diff --git a/hewr/tests/testthat/test_process_loc_forecast.R b/hewr/tests/testthat/test_process_loc_forecast.R index 434c4c11..0acc27ba 100644 --- a/hewr/tests/testthat/test_process_loc_forecast.R +++ b/hewr/tests/testthat/test_process_loc_forecast.R @@ -231,7 +231,7 @@ test_that("process_loc_forecast delegates correctly", { pyrenew_model_name = NA, timeseries_model_name = NA ), - "Either `model_name` or `pyrenew_model_name`" + "At least one of" ) }) @@ -246,6 +246,6 @@ test_that("process_loc_forecast validates legacy interface parameters", { pyrenew_model_name = NA, timeseries_model_name = NA ), - "Either `model_name` or `pyrenew_model_name`" + "At least one of" ) }) diff --git a/pipelines/common_utils.py b/pipelines/common_utils.py index a7634526..c8b4e7b4 100644 --- a/pipelines/common_utils.py +++ b/pipelines/common_utils.py @@ -329,6 +329,7 @@ def plot_and_save_loc_forecast( n_forecast_days: int, pyrenew_model_name: str = None, timeseries_model_name: str = None, + model_name: str = None, ) -> None: """Plot and save location forecast using R script. @@ -339,9 +340,12 @@ def plot_and_save_loc_forecast( n_forecast_days : int Number of days to forecast. pyrenew_model_name : str, optional - Name of the PyRenew model. + Name of the PyRenew model (legacy). timeseries_model_name : str, optional - Name of the timeseries model. + Name of the timeseries model (legacy). + model_name : str, optional + Generic model name. When provided, auto-detects model type + and dispatches to appropriate processing method. Returns ------- @@ -366,6 +370,13 @@ def plot_and_save_loc_forecast( f"{timeseries_model_name}", ] ) + if model_name is not None: + args.extend( + [ + "--model-name", + f"{model_name}", + ] + ) run_r_script( "pipelines/plot_and_save_loc_forecast.R", diff --git a/pipelines/plot_and_save_loc_forecast.R b/pipelines/plot_and_save_loc_forecast.R index 7536409b..6b178831 100644 --- a/pipelines/plot_and_save_loc_forecast.R +++ b/pipelines/plot_and_save_loc_forecast.R @@ -23,12 +23,13 @@ save_forecast_figures <- function( model_run_dir, n_forecast_days, pyrenew_model_name = NA, - timeseries_model_name = NA + timeseries_model_name = NA, + model_name = NA ) { - if (is.na(pyrenew_model_name) && is.na(timeseries_model_name)) { + if (all(is.na(c(pyrenew_model_name, timeseries_model_name, model_name)))) { stop( - "Either `pyrenew_model_name` or `timeseries_model_name`", - "must be provided." + "At least one of `pyrenew_model_name`, `timeseries_model_name`, ", + "or `model_name` must be provided." ) } @@ -52,13 +53,15 @@ save_forecast_figures <- function( str_replace_all("_+", "_") } - model_name <- dplyr::if_else( - is.na(pyrenew_model_name), - timeseries_model_name, - pyrenew_model_name + # Determine which model name to use (prioritize in order) + final_model_name <- dplyr::case_when( + !is.na(model_name) ~ model_name, + !is.na(pyrenew_model_name) ~ pyrenew_model_name, + !is.na(timeseries_model_name) ~ timeseries_model_name, + TRUE ~ NA_character_ ) - model_dir <- fs::path(model_run_dir, model_name) + model_dir <- fs::path(model_run_dir, final_model_name) figure_dir <- fs::path(model_dir, "figures") data_dir <- fs::path(model_dir, "data") dir_create(figure_dir) @@ -69,6 +72,7 @@ save_forecast_figures <- function( n_forecast_days = n_forecast_days, pyrenew_model_name = pyrenew_model_name, timeseries_model_name = timeseries_model_name, + model_name = model_name, save = TRUE ) @@ -128,7 +132,7 @@ save_forecast_figures <- function( ) |> mutate( file_name = create_file_name( - model_name, + final_model_name, .data$.variable, .data$resolution, .data$aggregated_numerator, @@ -168,6 +172,10 @@ p <- arg_parser("Generate forecast figures") |> add_argument( "--n-forecast-days", help = "Number of days to forecast" + ) |> + add_argument( + "--model-name", + help = "Name of directory with model outputs (auto-detects type)" ) argv <- parse_args(p) @@ -176,10 +184,12 @@ model_run_dir <- path(argv$model_run_dir) n_forecast_days <- as.numeric(argv$n_forecast_days) pyrenew_model_name <- argv$pyrenew_model_name timeseries_model_name <- argv$timeseries_model_name +model_name <- argv$model_name save_forecast_figures( model_run_dir, n_forecast_days, pyrenew_model_name, - timeseries_model_name + timeseries_model_name, + model_name )