diff --git a/pipelines/azure_command_center.py b/pipelines/azure_command_center.py index 0479a013..ebc86d7f 100644 --- a/pipelines/azure_command_center.py +++ b/pipelines/azure_command_center.py @@ -17,6 +17,8 @@ from pipelines.postprocess_forecast_batches import main as postprocess DEFAULT_RNG_KEY = 12345 +DEFAULT_TRAINING_DAYS = 150 +DEFAULT_EXCLUDE_LAST_N_DAYS = 1 load_dotenv() console = Console() @@ -54,8 +56,8 @@ def setup_job_append_id( additional_forecast_letters: str = "", container_image_name: str = "pyrenew-hew", container_image_version: str = "latest", - n_training_days: int = 150, - exclude_last_n_days: int = 1, + n_training_days: int = DEFAULT_TRAINING_DAYS, + exclude_last_n_days: int = DEFAULT_EXCLUDE_LAST_N_DAYS, rng_key: int = DEFAULT_RNG_KEY, locations_include: list[str] | None = None, locations_exclude: list[str] | None = None, @@ -166,33 +168,52 @@ def ask_about_reruns(): "How many days to exclude for H signal?", default=1 ) rng_key = IntPrompt.ask("RNG seed for reproducibility?", default=DEFAULT_RNG_KEY) + n_training_days = IntPrompt.ask( + "Number of training days?", default=DEFAULT_TRAINING_DAYS + ) return { "locations_include": locations_include, "e_exclude_last_n_days": e_exclude_last_n_days, "h_exclude_last_n_days": h_exclude_last_n_days, "rng_key": rng_key, + "n_training_days": n_training_days, } -def compute_skips(e_exclude_last_n_days: int, h_exclude_last_n_days: int, rng_key: int): - skip_e = e_exclude_last_n_days == 1 and rng_key == DEFAULT_RNG_KEY - skip_h = h_exclude_last_n_days == 1 and rng_key == DEFAULT_RNG_KEY - skip_he = ( - max(e_exclude_last_n_days, h_exclude_last_n_days) == 1 - and rng_key == DEFAULT_RNG_KEY +def compute_skips( + e_exclude_last_n_days: int, + h_exclude_last_n_days: int, + rng_key: int, + n_training_days: int, +): + run_due_to_param_change = ( + n_training_days != DEFAULT_TRAINING_DAYS or rng_key != DEFAULT_RNG_KEY ) + if run_due_to_param_change: + skip_e = False + skip_h = False + skip_he = False + else: + skip_e = e_exclude_last_n_days == DEFAULT_EXCLUDE_LAST_N_DAYS + skip_h = h_exclude_last_n_days == DEFAULT_EXCLUDE_LAST_N_DAYS + skip_he = skip_e and skip_h return {"skip_e": skip_e, "skip_h": skip_h, "skip_he": skip_he} def do_timeseries_reruns( locations_include: list[str] | None = None, - e_exclude_last_n_days: int = 1, - h_exclude_last_n_days: int = 1, + e_exclude_last_n_days: int = DEFAULT_EXCLUDE_LAST_N_DAYS, + h_exclude_last_n_days: int = DEFAULT_EXCLUDE_LAST_N_DAYS, rng_key: int = DEFAULT_RNG_KEY, # not used, but kept for interface consistency append_id: str = "", + n_training_days: int = DEFAULT_TRAINING_DAYS, ): - skips = compute_skips(e_exclude_last_n_days, h_exclude_last_n_days, rng_key) + skips = compute_skips( + e_exclude_last_n_days, h_exclude_last_n_days, rng_key, n_training_days + ) + he_exclude_last_n_days = max(e_exclude_last_n_days, h_exclude_last_n_days) + he_exclude_covered_by_e = he_exclude_last_n_days == e_exclude_last_n_days if skips["skip_e"]: print("Skipping Timeseries-E re-fitting due to E") @@ -201,26 +222,31 @@ def do_timeseries_reruns( append_id=append_id, locations_include=locations_include, exclude_last_n_days=e_exclude_last_n_days, + n_training_days=n_training_days, ) - if skips["skip_h"]: - print("Skipping Timeseries-E re-fitting due to H") + if skips["skip_he"] or he_exclude_covered_by_e: + print("Skipping Timeseries-E re-fitting due to HE*") else: fit_timeseries_e( append_id=append_id, locations_include=locations_include, - exclude_last_n_days=h_exclude_last_n_days, + exclude_last_n_days=he_exclude_last_n_days, + n_training_days=n_training_days, ) def do_pyrenew_reruns( locations_include: list[str] | None = None, - e_exclude_last_n_days: int = 1, - h_exclude_last_n_days: int = 1, + e_exclude_last_n_days: int = DEFAULT_EXCLUDE_LAST_N_DAYS, + h_exclude_last_n_days: int = DEFAULT_EXCLUDE_LAST_N_DAYS, rng_key: int = DEFAULT_RNG_KEY, append_id: str = "", + n_training_days: int = DEFAULT_TRAINING_DAYS, ): he_exclude_last_n_days = max(e_exclude_last_n_days, h_exclude_last_n_days) - skips = compute_skips(e_exclude_last_n_days, h_exclude_last_n_days, rng_key) + skips = compute_skips( + e_exclude_last_n_days, h_exclude_last_n_days, rng_key, n_training_days + ) if skips["skip_e"]: print("Skipping PyRenew-E re-fitting") @@ -230,6 +256,7 @@ def do_pyrenew_reruns( locations_include=locations_include, exclude_last_n_days=e_exclude_last_n_days, rng_key=rng_key, + n_training_days=n_training_days, ) if skips["skip_h"]: @@ -240,6 +267,7 @@ def do_pyrenew_reruns( locations_include=locations_include, exclude_last_n_days=h_exclude_last_n_days, rng_key=rng_key, + n_training_days=n_training_days, ) # fit_pyrenew_hw( # append_id=append_id, @@ -256,6 +284,7 @@ def do_pyrenew_reruns( locations_include=locations_include, exclude_last_n_days=he_exclude_last_n_days, rng_key=rng_key, + n_training_days=n_training_days, ) # fit_pyrenew_hew( @@ -429,10 +458,16 @@ def ask_integer_choice(choices): # fit_pyrenew_hew(append_id=current_time) elif selected_choice == "Rerun Timeseries Models": ask_about_reruns_input = ask_about_reruns() - do_timeseries_reruns(append_id=current_time, **ask_about_reruns_input) + do_timeseries_reruns( + append_id=current_time, + **ask_about_reruns_input, + ) elif selected_choice == "Rerun PyRenew Models": ask_about_reruns_input = ask_about_reruns() - do_pyrenew_reruns(append_id=current_time, **ask_about_reruns_input) + do_pyrenew_reruns( + append_id=current_time, + **ask_about_reruns_input, + ) elif selected_choice == "Postprocess Forecast Batches": skip_existing = Confirm.ask( "Skip processing for model batch directories that already have figures?", diff --git a/pipelines/postprocess_forecast_batches.py b/pipelines/postprocess_forecast_batches.py index 6bcad792..fa56a6f2 100644 --- a/pipelines/postprocess_forecast_batches.py +++ b/pipelines/postprocess_forecast_batches.py @@ -53,7 +53,7 @@ def process_model_batch_dir(model_batch_dir_path: Path, plot_ext: str = "pdf") - def main( base_forecast_dir: Path | str, - diseases: list[str] = ["COVID-19", "Influenza", "RSV"], + diseases: list[str] | set[str] = ["COVID-19", "Influenza", "RSV"], skip_existing: bool = True, ) -> None: logging.basicConfig(level=logging.INFO)