Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions hewr/R/process_loc_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,8 @@ process_pyrenew_model <- function(
#' @param model_run_dir Model run directory
#' @param n_forecast_days An integer specifying the number of days to forecast.
#' @param model_name Name of directory containing model outputs.
#' If provided, uses new S3 dispatch interface (overrides
#' pyrenew_model_name/timeseries_model_name).
#' If provided, uses new S3 dispatch interface with auto-detection
#' (overrides pyrenew_model_name/timeseries_model_name).
#' @param pyrenew_model_name Name of directory containing pyrenew
#' model outputs (legacy interface)
#' @param timeseries_model_name Name of directory containing timeseries
Expand All @@ -506,7 +506,7 @@ process_loc_forecast <- function(
ci_widths = c(0.5, 0.8, 0.95),
save = TRUE
) {
# New interface: delegate to process_forecast
# New interface: delegate to process_forecast with auto-detection
if (!is.na(model_name)) {
return(process_forecast(
model_run_dir = model_run_dir,
Expand All @@ -521,8 +521,8 @@ process_loc_forecast <- function(
# Legacy interface validation
if (is.na(pyrenew_model_name) && is.na(timeseries_model_name)) {
stop(
"Either `model_name` or `pyrenew_model_name`/",
"`timeseries_model_name` must be provided."
"At least one of `model_name`, `pyrenew_model_name`, ",
"or `timeseries_model_name` must be provided."
)
}
model_name <- dplyr::if_else(
Expand Down
4 changes: 2 additions & 2 deletions hewr/tests/testthat/test_process_loc_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ test_that("process_loc_forecast delegates correctly", {
pyrenew_model_name = NA,
timeseries_model_name = NA
),
"Either `model_name` or `pyrenew_model_name`"
"At least one of"
)
})

Expand All @@ -246,6 +246,6 @@ test_that("process_loc_forecast validates legacy interface parameters", {
pyrenew_model_name = NA,
timeseries_model_name = NA
),
"Either `model_name` or `pyrenew_model_name`"
"At least one of"
)
})
15 changes: 13 additions & 2 deletions pipelines/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def plot_and_save_loc_forecast(
n_forecast_days: int,
pyrenew_model_name: str = None,
timeseries_model_name: str = None,
model_name: str = None,
) -> None:
"""Plot and save location forecast using R script.

Expand All @@ -339,9 +340,12 @@ def plot_and_save_loc_forecast(
n_forecast_days : int
Number of days to forecast.
pyrenew_model_name : str, optional
Name of the PyRenew model.
Name of the PyRenew model (legacy).
timeseries_model_name : str, optional
Name of the timeseries model.
Name of the timeseries model (legacy).
model_name : str, optional
Generic model name. When provided, auto-detects model type
and dispatches to appropriate processing method.

Returns
-------
Expand All @@ -366,6 +370,13 @@ def plot_and_save_loc_forecast(
f"{timeseries_model_name}",
]
)
if model_name is not None:
args.extend(
[
"--model-name",
f"{model_name}",
]
)

run_r_script(
"pipelines/plot_and_save_loc_forecast.R",
Expand Down
32 changes: 21 additions & 11 deletions pipelines/plot_and_save_loc_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ save_forecast_figures <- function(
model_run_dir,
n_forecast_days,
pyrenew_model_name = NA,
timeseries_model_name = NA
timeseries_model_name = NA,
model_name = NA
) {
if (is.na(pyrenew_model_name) && is.na(timeseries_model_name)) {
if (all(is.na(c(pyrenew_model_name, timeseries_model_name, model_name)))) {
stop(
"Either `pyrenew_model_name` or `timeseries_model_name`",
"must be provided."
"At least one of `pyrenew_model_name`, `timeseries_model_name`, ",
"or `model_name` must be provided."
)
}

Expand All @@ -52,13 +53,15 @@ save_forecast_figures <- function(
str_replace_all("_+", "_")
}

model_name <- dplyr::if_else(
is.na(pyrenew_model_name),
timeseries_model_name,
pyrenew_model_name
# Determine which model name to use (prioritize in order)
final_model_name <- dplyr::case_when(
!is.na(model_name) ~ model_name,
!is.na(pyrenew_model_name) ~ pyrenew_model_name,
!is.na(timeseries_model_name) ~ timeseries_model_name,
TRUE ~ NA_character_
)

model_dir <- fs::path(model_run_dir, model_name)
model_dir <- fs::path(model_run_dir, final_model_name)
figure_dir <- fs::path(model_dir, "figures")
data_dir <- fs::path(model_dir, "data")
dir_create(figure_dir)
Expand All @@ -69,6 +72,7 @@ save_forecast_figures <- function(
n_forecast_days = n_forecast_days,
pyrenew_model_name = pyrenew_model_name,
timeseries_model_name = timeseries_model_name,
model_name = model_name,
save = TRUE
)

Expand Down Expand Up @@ -128,7 +132,7 @@ save_forecast_figures <- function(
) |>
mutate(
file_name = create_file_name(
model_name,
final_model_name,
.data$.variable,
.data$resolution,
.data$aggregated_numerator,
Expand Down Expand Up @@ -168,6 +172,10 @@ p <- arg_parser("Generate forecast figures") |>
add_argument(
"--n-forecast-days",
help = "Number of days to forecast"
) |>
add_argument(
"--model-name",
help = "Name of directory with model outputs (auto-detects type)"
)

argv <- parse_args(p)
Expand All @@ -176,10 +184,12 @@ model_run_dir <- path(argv$model_run_dir)
n_forecast_days <- as.numeric(argv$n_forecast_days)
pyrenew_model_name <- argv$pyrenew_model_name
timeseries_model_name <- argv$timeseries_model_name
model_name <- argv$model_name

save_forecast_figures(
model_run_dir,
n_forecast_days,
pyrenew_model_name,
timeseries_model_name
timeseries_model_name,
model_name
)