From 666a30036abed78946188b06b9213d98b86846f5 Mon Sep 17 00:00:00 2001
From: Damon Bayer
Date: Thu, 8 Jan 2026 19:25:25 +0000
Subject: [PATCH 1/6] create local copy

---
 pipelines/azure_command_center.py         |  2 +-
 pipelines/postprocess_forecast_batches.py | 68 +++++++++++++++++++----
 2 files changed, 58 insertions(+), 12 deletions(-)

diff --git a/pipelines/azure_command_center.py b/pipelines/azure_command_center.py
index 0479a013..b965d8aa 100644
--- a/pipelines/azure_command_center.py
+++ b/pipelines/azure_command_center.py
@@ -440,7 +440,7 @@ def ask_integer_choice(choices):
         )
         postprocess(
             base_forecast_dir=pyrenew_hew_prod_output_path / output_subdir,
-            diseases=ALL_DISEASES,
+            diseases=list(ALL_DISEASES),
             skip_existing=skip_existing,
         )

diff --git a/pipelines/postprocess_forecast_batches.py b/pipelines/postprocess_forecast_batches.py
index 6bcad792..c89b5025 100644
--- a/pipelines/postprocess_forecast_batches.py
+++ b/pipelines/postprocess_forecast_batches.py
@@ -9,6 +9,7 @@
 import argparse
 import datetime as dt
 import logging
+import shutil
 from pathlib import Path

 import collate_plots as cp
@@ -16,6 +17,8 @@
 from pipelines.common_utils import run_r_script
 from pipelines.utils import get_all_forecast_dirs, parse_model_batch_dir_name

+local_dir = Path.home() / "stf_forecast_fig_share"
+

 def _hubverse_table_filename(report_date: str | dt.date, disease: str) -> None:
     return f"{report_date}-{disease.lower()}-hubverse-table.parquet"
@@ -51,26 +54,64 @@ def process_model_batch_dir(model_batch_dir_path: Path, plot_ext: str = "pdf") -
     combine_hubverse_tables(model_batch_dir_path)


+def model_batch_dir_to_target_path(
+    model_batch_dir: str,
+    max_last_training_date: dt.date,
+    pre_path=local_dir,
+) -> Path:
+    parts = parse_model_batch_dir_name(model_batch_dir)
+    lookback = (parts["last_training_date"] - parts["first_training_date"]).days + 1
+    omit = (max_last_training_date - parts["last_training_date"]).days + 1
+    target_path = Path(
+        pre_path,
+        f"lookback-{lookback}-omit-{omit}",
+        parts["disease"],
+    )
+    return target_path
+
+
 def main(
     base_forecast_dir: Path | str,
     diseases: list[str] = ["COVID-19", "Influenza", "RSV"],
     skip_existing: bool = True,
+    create_local_copy: bool = True,
 ) -> None:
     logging.basicConfig(level=logging.INFO)
     logger = logging.getLogger(__name__)
-    to_process = get_all_forecast_dirs(base_forecast_dir, list(diseases))
+    to_process = get_all_forecast_dirs(base_forecast_dir, diseases)
+    # compute max last training date across all model batch dirs and assume this corresponds to omitting 1 day.
+    max_last_training_date = max(
+        [
+            parse_model_batch_dir_name(model_batch_dir)["last_training_date"]
+            for model_batch_dir in to_process
+        ]
+    )
+    if skip_existing:
+        to_process = [
+            batch_dir
+            for batch_dir in to_process
+            if not bool(
+                list(
+                    Path(base_forecast_dir, batch_dir).glob("*-hubverse-table.parquet")
+                )
+            )
+        ]
+
     for batch_dir in to_process:
         model_batch_dir_path = Path(base_forecast_dir, batch_dir)
-        hubverse_tbl_exists = bool(
-            list(model_batch_dir_path.glob("*-hubverse-table.parquet"))
-        )
-        if hubverse_tbl_exists and skip_existing:
-            logger.info(f"Skipping {batch_dir}, hubverse table already exists.")
-        else:
-            logger.info(f"Processing {batch_dir}...")
-            process_model_batch_dir(model_batch_dir_path)
-            logger.info(f"Finished processing {batch_dir}")
-    logger.info(f"Finished processing {base_forecast_dir}.")
+        logger.info(f"Processing {batch_dir}...")
+        process_model_batch_dir(model_batch_dir_path)
+        logger.info(f"Finished processing {batch_dir}")
+        if create_local_copy:
+            source_dir = Path(base_forecast_dir, batch_dir, "figures")
+            target_dir = model_batch_dir_to_target_path(
+                batch_dir, max_last_training_date, local_dir
+            )
+            logger.info(
+                f"Copying from {source_dir.relative_to(base_forecast_dir)} to {target_dir.relative_to(local_dir)}..."
+            )
+            shutil.copytree(source_dir, target_dir, dirs_exist_ok=True)
+    logger.info(f"Finished processing {base_forecast_dir}.")


 if __name__ == "__main__":
@@ -98,6 +139,11 @@ def main(
         action="store_true",
         help="Skip processing for model batch directories that already have been processed.",
     )
+    parser.add_argument(
+        "--local-copy",
+        action="store_true",
+        help="Create a local copy of the processed files.",
+    )

     args = parser.parse_args()
     args.diseases = args.diseases.split()

From 1c06c0fac29bbe310a362e79e5977895806e92e5 Mon Sep 17 00:00:00 2001
From: Damon Bayer
Date: Mon, 12 Jan 2026 17:16:18 -0600
Subject: [PATCH 2/6] Revert "Merge branch 'main' into dmb_fig_share"

This reverts commit 5cd8f733d6391290e41f2d7bb7967394cbc830c7, reversing
changes made to 666a30036abed78946188b06b9213d98b86846f5.
---
 .github/workflows/containers.yaml         |  2 +-
 .pre-commit-config.yaml                   |  4 +-
 pipelines/azure_command_center.py         | 73 ++++++-----------------
 pipelines/postprocess_forecast_batches.py |  2 +-
 4 files changed, 23 insertions(+), 58 deletions(-)

diff --git a/.github/workflows/containers.yaml b/.github/workflows/containers.yaml
index 9a2cb0ee..0c66aabe 100644
--- a/.github/workflows/containers.yaml
+++ b/.github/workflows/containers.yaml
@@ -52,7 +52,7 @@ jobs:
         uses: docker/setup-buildx-action@v3

       - name: Docker build and push
-        uses: docker/build-push-action@v6
+        uses: docker/build-push-action@v5
         with:
           context: .
           file: Containerfile

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index a0e8c618..956d5fb0 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -13,14 +13,14 @@ repos:
   #####
   # Python
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.14.11
+    rev: v0.14.10
     hooks:
       # Sort imports
       - id: ruff-check
       # Run the formatter
       - id: ruff-format
   - repo: https://github.com/astral-sh/uv-pre-commit
-    rev: 0.9.24
+    rev: 0.9.21
     hooks:
       - id: uv-lock
   #####

diff --git a/pipelines/azure_command_center.py b/pipelines/azure_command_center.py
index 225ca8ba..b965d8aa 100644
--- a/pipelines/azure_command_center.py
+++ b/pipelines/azure_command_center.py
@@ -17,8 +17,6 @@
 from pipelines.postprocess_forecast_batches import main as postprocess

 DEFAULT_RNG_KEY = 12345
-DEFAULT_TRAINING_DAYS = 150
-DEFAULT_EXCLUDE_LAST_N_DAYS = 1

 load_dotenv()
 console = Console()
@@ -56,8 +54,8 @@ def setup_job_append_id(
     additional_forecast_letters: str = "",
     container_image_name: str = "pyrenew-hew",
     container_image_version: str = "latest",
-    n_training_days: int = DEFAULT_TRAINING_DAYS,
-    exclude_last_n_days: int = DEFAULT_EXCLUDE_LAST_N_DAYS,
+    n_training_days: int = 150,
+    exclude_last_n_days: int = 1,
     rng_key: int = DEFAULT_RNG_KEY,
     locations_include: list[str] | None = None,
     locations_exclude: list[str] | None = None,
@@ -168,52 +166,33 @@ def ask_about_reruns():
         "How many days to exclude for H signal?", default=1
     )
     rng_key = IntPrompt.ask("RNG seed for reproducibility?", default=DEFAULT_RNG_KEY)
-    n_training_days = IntPrompt.ask(
-        "Number of training days?", default=DEFAULT_TRAINING_DAYS
-    )
     return {
         "locations_include": locations_include,
         "e_exclude_last_n_days": e_exclude_last_n_days,
         "h_exclude_last_n_days": h_exclude_last_n_days,
         "rng_key": rng_key,
-        "n_training_days": n_training_days,
     }


-def compute_skips(
-    e_exclude_last_n_days: int,
-    h_exclude_last_n_days: int,
-    rng_key: int,
-    n_training_days: int,
-):
-    run_due_to_param_change = (
-        n_training_days != DEFAULT_TRAINING_DAYS or rng_key != DEFAULT_RNG_KEY
+def compute_skips(e_exclude_last_n_days: int, h_exclude_last_n_days: int, rng_key: int):
+    skip_e = e_exclude_last_n_days == 1 and rng_key == DEFAULT_RNG_KEY
+    skip_h = h_exclude_last_n_days == 1 and rng_key == DEFAULT_RNG_KEY
+    skip_he = (
+        max(e_exclude_last_n_days, h_exclude_last_n_days) == 1
+        and rng_key == DEFAULT_RNG_KEY
     )
-    if run_due_to_param_change:
-        skip_e = False
-        skip_h = False
-        skip_he = False
-    else:
-        skip_e = e_exclude_last_n_days == DEFAULT_EXCLUDE_LAST_N_DAYS
-        skip_h = h_exclude_last_n_days == DEFAULT_EXCLUDE_LAST_N_DAYS
-        skip_he = skip_e and skip_h
     return {"skip_e": skip_e, "skip_h": skip_h, "skip_he": skip_he}


 def do_timeseries_reruns(
     locations_include: list[str] | None = None,
-    e_exclude_last_n_days: int = DEFAULT_EXCLUDE_LAST_N_DAYS,
-    h_exclude_last_n_days: int = DEFAULT_EXCLUDE_LAST_N_DAYS,
+    e_exclude_last_n_days: int = 1,
+    h_exclude_last_n_days: int = 1,
     rng_key: int = DEFAULT_RNG_KEY,  # not used, but kept for interface consistency
     append_id: str = "",
-    n_training_days: int = DEFAULT_TRAINING_DAYS,
 ):
-    skips = compute_skips(
-        e_exclude_last_n_days, h_exclude_last_n_days, rng_key, n_training_days
-    )
-    he_exclude_last_n_days = max(e_exclude_last_n_days, h_exclude_last_n_days)
-    he_exclude_covered_by_e = he_exclude_last_n_days == e_exclude_last_n_days
+    skips = compute_skips(e_exclude_last_n_days, h_exclude_last_n_days, rng_key)

     if skips["skip_e"]:
         print("Skipping Timeseries-E re-fitting due to E")
@@ -222,31 +201,26 @@ def do_timeseries_reruns(
             append_id=append_id,
             locations_include=locations_include,
             exclude_last_n_days=e_exclude_last_n_days,
-            n_training_days=n_training_days,
         )

-    if skips["skip_he"] or he_exclude_covered_by_e:
-        print("Skipping Timeseries-E re-fitting due to HE*")
+    if skips["skip_h"]:
+        print("Skipping Timeseries-E re-fitting due to H")
     else:
         fit_timeseries_e(
             append_id=append_id,
             locations_include=locations_include,
-            exclude_last_n_days=he_exclude_last_n_days,
-            n_training_days=n_training_days,
+            exclude_last_n_days=h_exclude_last_n_days,
         )


 def do_pyrenew_reruns(
     locations_include: list[str] | None = None,
-    e_exclude_last_n_days: int = DEFAULT_EXCLUDE_LAST_N_DAYS,
-    h_exclude_last_n_days: int = DEFAULT_EXCLUDE_LAST_N_DAYS,
+    e_exclude_last_n_days: int = 1,
+    h_exclude_last_n_days: int = 1,
     rng_key: int = DEFAULT_RNG_KEY,
     append_id: str = "",
-    n_training_days: int = DEFAULT_TRAINING_DAYS,
 ):
     he_exclude_last_n_days = max(e_exclude_last_n_days, h_exclude_last_n_days)
-    skips = compute_skips(
-        e_exclude_last_n_days, h_exclude_last_n_days, rng_key, n_training_days
-    )
+    skips = compute_skips(e_exclude_last_n_days, h_exclude_last_n_days, rng_key)

     if skips["skip_e"]:
         print("Skipping PyRenew-E re-fitting")
@@ -256,7 +230,6 @@ def do_pyrenew_reruns(
             locations_include=locations_include,
             exclude_last_n_days=e_exclude_last_n_days,
             rng_key=rng_key,
-            n_training_days=n_training_days,
         )

     if skips["skip_h"]:
@@ -267,7 +240,6 @@ def do_pyrenew_reruns(
             locations_include=locations_include,
             exclude_last_n_days=h_exclude_last_n_days,
             rng_key=rng_key,
-            n_training_days=n_training_days,
         )
     # fit_pyrenew_hw(
     #     append_id=append_id,
@@ -284,7 +256,6 @@ def do_pyrenew_reruns(
             locations_include=locations_include,
             exclude_last_n_days=he_exclude_last_n_days,
             rng_key=rng_key,
-            n_training_days=n_training_days,
         )

     # fit_pyrenew_hew(
@@ -458,16 +429,10 @@ def ask_integer_choice(choices):
     # fit_pyrenew_hew(append_id=current_time)
     elif selected_choice == "Rerun Timeseries Models":
         ask_about_reruns_input = ask_about_reruns()
-        do_timeseries_reruns(
-            append_id=current_time,
-            **ask_about_reruns_input,
-        )
+        do_timeseries_reruns(append_id=current_time, **ask_about_reruns_input)
     elif selected_choice == "Rerun PyRenew Models":
         ask_about_reruns_input = ask_about_reruns()
-        do_pyrenew_reruns(
-            append_id=current_time,
-            **ask_about_reruns_input,
-        )
+        do_pyrenew_reruns(append_id=current_time, **ask_about_reruns_input)
     elif selected_choice == "Postprocess Forecast Batches":
         skip_existing = Confirm.ask(
             "Skip processing for model batch directories that already have figures?",

diff --git a/pipelines/postprocess_forecast_batches.py b/pipelines/postprocess_forecast_batches.py
index 52f0671d..c89b5025 100644
--- a/pipelines/postprocess_forecast_batches.py
+++ b/pipelines/postprocess_forecast_batches.py
@@ -72,7 +72,7 @@ def model_batch_dir_to_target_path(

 def main(
     base_forecast_dir: Path | str,
-    diseases: list[str] | set[str] = ["COVID-19", "Influenza", "RSV"],
+    diseases: list[str] = ["COVID-19", "Influenza", "RSV"],
     skip_existing: bool = True,
     create_local_copy: bool = True,
 ) -> None:

From 017d9ed945da64dc6c3fd9e8d60a29f5133a2b15 Mon Sep 17 00:00:00 2001
From: Damon Bayer
Date: Mon, 12 Jan 2026 17:18:15 -0600
Subject: [PATCH 3/6] Reapply "Merge branch 'main' into dmb_fig_share"

This reverts commit 1c06c0fac29bbe310a362e79e5977895806e92e5.
---
 .github/workflows/containers.yaml         |  2 +-
 .pre-commit-config.yaml                   |  4 +-
 pipelines/azure_command_center.py         | 73 +++++++++++++++++------
 pipelines/postprocess_forecast_batches.py |  2 +-
 4 files changed, 58 insertions(+), 23 deletions(-)

diff --git a/.github/workflows/containers.yaml b/.github/workflows/containers.yaml
index 0c66aabe..9a2cb0ee 100644
--- a/.github/workflows/containers.yaml
+++ b/.github/workflows/containers.yaml
@@ -52,7 +52,7 @@ jobs:
         uses: docker/setup-buildx-action@v3

       - name: Docker build and push
-        uses: docker/build-push-action@v5
+        uses: docker/build-push-action@v6
         with:
           context: .
           file: Containerfile

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 956d5fb0..a0e8c618 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -13,14 +13,14 @@ repos:
   #####
   # Python
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.14.10
+    rev: v0.14.11
     hooks:
       # Sort imports
       - id: ruff-check
       # Run the formatter
       - id: ruff-format
   - repo: https://github.com/astral-sh/uv-pre-commit
-    rev: 0.9.21
+    rev: 0.9.24
     hooks:
       - id: uv-lock
   #####

diff --git a/pipelines/azure_command_center.py b/pipelines/azure_command_center.py
index b965d8aa..225ca8ba 100644
--- a/pipelines/azure_command_center.py
+++ b/pipelines/azure_command_center.py
@@ -17,6 +17,8 @@
 from pipelines.postprocess_forecast_batches import main as postprocess

 DEFAULT_RNG_KEY = 12345
+DEFAULT_TRAINING_DAYS = 150
+DEFAULT_EXCLUDE_LAST_N_DAYS = 1

 load_dotenv()
 console = Console()
@@ -54,8 +56,8 @@ def setup_job_append_id(
     additional_forecast_letters: str = "",
     container_image_name: str = "pyrenew-hew",
     container_image_version: str = "latest",
-    n_training_days: int = 150,
-    exclude_last_n_days: int = 1,
+    n_training_days: int = DEFAULT_TRAINING_DAYS,
+    exclude_last_n_days: int = DEFAULT_EXCLUDE_LAST_N_DAYS,
     rng_key: int = DEFAULT_RNG_KEY,
     locations_include: list[str] | None = None,
     locations_exclude: list[str] | None = None,
@@ -166,33 +168,52 @@ def ask_about_reruns():
         "How many days to exclude for H signal?", default=1
     )
     rng_key = IntPrompt.ask("RNG seed for reproducibility?", default=DEFAULT_RNG_KEY)
+    n_training_days = IntPrompt.ask(
+        "Number of training days?", default=DEFAULT_TRAINING_DAYS
+    )
     return {
         "locations_include": locations_include,
         "e_exclude_last_n_days": e_exclude_last_n_days,
         "h_exclude_last_n_days": h_exclude_last_n_days,
         "rng_key": rng_key,
+        "n_training_days": n_training_days,
     }


-def compute_skips(e_exclude_last_n_days: int, h_exclude_last_n_days: int, rng_key: int):
-    skip_e = e_exclude_last_n_days == 1 and rng_key == DEFAULT_RNG_KEY
-    skip_h = h_exclude_last_n_days == 1 and rng_key == DEFAULT_RNG_KEY
-    skip_he = (
-        max(e_exclude_last_n_days, h_exclude_last_n_days) == 1
-        and rng_key == DEFAULT_RNG_KEY
+def compute_skips(
+    e_exclude_last_n_days: int,
+    h_exclude_last_n_days: int,
+    rng_key: int,
+    n_training_days: int,
+):
+    run_due_to_param_change = (
+        n_training_days != DEFAULT_TRAINING_DAYS or rng_key != DEFAULT_RNG_KEY
     )
+    if run_due_to_param_change:
+        skip_e = False
+        skip_h = False
+        skip_he = False
+    else:
+        skip_e = e_exclude_last_n_days == DEFAULT_EXCLUDE_LAST_N_DAYS
+        skip_h = h_exclude_last_n_days == DEFAULT_EXCLUDE_LAST_N_DAYS
+        skip_he = skip_e and skip_h
     return {"skip_e": skip_e, "skip_h": skip_h, "skip_he": skip_he}


 def do_timeseries_reruns(
     locations_include: list[str] | None = None,
-    e_exclude_last_n_days: int = 1,
-    h_exclude_last_n_days: int = 1,
+    e_exclude_last_n_days: int = DEFAULT_EXCLUDE_LAST_N_DAYS,
+    h_exclude_last_n_days: int = DEFAULT_EXCLUDE_LAST_N_DAYS,
     rng_key: int = DEFAULT_RNG_KEY,  # not used, but kept for interface consistency
     append_id: str = "",
+    n_training_days: int = DEFAULT_TRAINING_DAYS,
 ):
-    skips = compute_skips(e_exclude_last_n_days, h_exclude_last_n_days, rng_key)
+    skips = compute_skips(
+        e_exclude_last_n_days, h_exclude_last_n_days, rng_key, n_training_days
+    )
+    he_exclude_last_n_days = max(e_exclude_last_n_days, h_exclude_last_n_days)
+    he_exclude_covered_by_e = he_exclude_last_n_days == e_exclude_last_n_days

     if skips["skip_e"]:
         print("Skipping Timeseries-E re-fitting due to E")
@@ -201,26 +222,31 @@ def do_timeseries_reruns(
             append_id=append_id,
             locations_include=locations_include,
             exclude_last_n_days=e_exclude_last_n_days,
+            n_training_days=n_training_days,
         )

-    if skips["skip_h"]:
-        print("Skipping Timeseries-E re-fitting due to H")
+    if skips["skip_he"] or he_exclude_covered_by_e:
+        print("Skipping Timeseries-E re-fitting due to HE*")
     else:
         fit_timeseries_e(
             append_id=append_id,
             locations_include=locations_include,
-            exclude_last_n_days=h_exclude_last_n_days,
+            exclude_last_n_days=he_exclude_last_n_days,
+            n_training_days=n_training_days,
         )


 def do_pyrenew_reruns(
     locations_include: list[str] | None = None,
-    e_exclude_last_n_days: int = 1,
-    h_exclude_last_n_days: int = 1,
+    e_exclude_last_n_days: int = DEFAULT_EXCLUDE_LAST_N_DAYS,
+    h_exclude_last_n_days: int = DEFAULT_EXCLUDE_LAST_N_DAYS,
     rng_key: int = DEFAULT_RNG_KEY,
     append_id: str = "",
+    n_training_days: int = DEFAULT_TRAINING_DAYS,
 ):
     he_exclude_last_n_days = max(e_exclude_last_n_days, h_exclude_last_n_days)
-    skips = compute_skips(e_exclude_last_n_days, h_exclude_last_n_days, rng_key)
+    skips = compute_skips(
+        e_exclude_last_n_days, h_exclude_last_n_days, rng_key, n_training_days
+    )

     if skips["skip_e"]:
         print("Skipping PyRenew-E re-fitting")
@@ -230,6 +256,7 @@ def do_pyrenew_reruns(
             locations_include=locations_include,
             exclude_last_n_days=e_exclude_last_n_days,
             rng_key=rng_key,
+            n_training_days=n_training_days,
         )

     if skips["skip_h"]:
@@ -240,6 +267,7 @@ def do_pyrenew_reruns(
             locations_include=locations_include,
             exclude_last_n_days=h_exclude_last_n_days,
             rng_key=rng_key,
+            n_training_days=n_training_days,
         )
     # fit_pyrenew_hw(
     #     append_id=append_id,
@@ -256,6 +284,7 @@ def do_pyrenew_reruns(
             locations_include=locations_include,
             exclude_last_n_days=he_exclude_last_n_days,
             rng_key=rng_key,
+            n_training_days=n_training_days,
         )

     # fit_pyrenew_hew(
@@ -429,10 +458,16 @@ def ask_integer_choice(choices):
     # fit_pyrenew_hew(append_id=current_time)
     elif selected_choice == "Rerun Timeseries Models":
         ask_about_reruns_input = ask_about_reruns()
-        do_timeseries_reruns(append_id=current_time, **ask_about_reruns_input)
+        do_timeseries_reruns(
+            append_id=current_time,
+            **ask_about_reruns_input,
+        )
     elif selected_choice == "Rerun PyRenew Models":
         ask_about_reruns_input = ask_about_reruns()
-        do_pyrenew_reruns(append_id=current_time, **ask_about_reruns_input)
+        do_pyrenew_reruns(
+            append_id=current_time,
+            **ask_about_reruns_input,
+        )
     elif selected_choice == "Postprocess Forecast Batches":
         skip_existing = Confirm.ask(
             "Skip processing for model batch directories that already have figures?",

diff --git a/pipelines/postprocess_forecast_batches.py b/pipelines/postprocess_forecast_batches.py
index c89b5025..52f0671d 100644
--- a/pipelines/postprocess_forecast_batches.py
+++ b/pipelines/postprocess_forecast_batches.py
@@ -72,7 +72,7 @@ def model_batch_dir_to_target_path(

 def main(
     base_forecast_dir: Path | str,
-    diseases: list[str] = ["COVID-19", "Influenza", "RSV"],
"Influenza", "RSV"], + diseases: list[str] | set[str] = ["COVID-19", "Influenza", "RSV"], skip_existing: bool = True, create_local_copy: bool = True, ) -> None: From 4c698d608c1203f13b05d8bdf520871835c685f8 Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Mon, 12 Jan 2026 17:26:54 -0600 Subject: [PATCH 4/6] command line option --- pipelines/azure_command_center.py | 8 ++++++++ pipelines/postprocess_forecast_batches.py | 19 +++++++++---------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/pipelines/azure_command_center.py b/pipelines/azure_command_center.py index 225ca8ba..f17aafaa 100644 --- a/pipelines/azure_command_center.py +++ b/pipelines/azure_command_center.py @@ -16,6 +16,8 @@ from pipelines.batch.setup_job import main as setup_job from pipelines.postprocess_forecast_batches import main as postprocess +LOCAL_COPY_DIR = Path.home() / "stf_forecast_fig_share" + DEFAULT_RNG_KEY = 12345 DEFAULT_TRAINING_DAYS = 150 DEFAULT_EXCLUDE_LAST_N_DAYS = 1 @@ -473,10 +475,16 @@ def ask_integer_choice(choices): "Skip processing for model batch directories that already have figures?", default=True, ) + save_local_copy = Confirm.ask( + f"Save a local copy of figures to {LOCAL_COPY_DIR}?", + default=True, + ) + local_copy_dir = LOCAL_COPY_DIR if save_local_copy else "" postprocess( base_forecast_dir=pyrenew_hew_prod_output_path / output_subdir, diseases=list(ALL_DISEASES), skip_existing=skip_existing, + local_copy_dir=local_copy_dir, ) input("Press enter to continue...") diff --git a/pipelines/postprocess_forecast_batches.py b/pipelines/postprocess_forecast_batches.py index 52f0671d..5e4e9ebc 100644 --- a/pipelines/postprocess_forecast_batches.py +++ b/pipelines/postprocess_forecast_batches.py @@ -17,8 +17,6 @@ from pipelines.common_utils import run_r_script from pipelines.utils import get_all_forecast_dirs, parse_model_batch_dir_name -local_dir = Path.home() / "stf_forecast_fig_share" - def _hubverse_table_filename(report_date: str | dt.date, disease: str) -> None: return f"{report_date}-{disease.lower()}-hubverse-table.parquet" @@ -57,7 +55,7 @@ def process_model_batch_dir(model_batch_dir_path: Path, plot_ext: str = "pdf") - def model_batch_dir_to_target_path( model_batch_dir: str, max_last_training_date: dt.date, - pre_path=local_dir, + pre_path: Path | str, ) -> Path: parts = parse_model_batch_dir_name(model_batch_dir) lookback = (parts["last_training_date"] - parts["first_training_date"]).days + 1 @@ -74,7 +72,7 @@ def main( base_forecast_dir: Path | str, diseases: list[str] | set[str] = ["COVID-19", "Influenza", "RSV"], skip_existing: bool = True, - create_local_copy: bool = True, + local_copy_dir: Path | str = "", ) -> None: logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -102,13 +100,13 @@ def main( logger.info(f"Processing {batch_dir}...") process_model_batch_dir(model_batch_dir_path) logger.info(f"Finished processing {batch_dir}") - if create_local_copy: + if local_copy_dir: source_dir = Path(base_forecast_dir, batch_dir, "figures") target_dir = model_batch_dir_to_target_path( - batch_dir, max_last_training_date, local_dir + batch_dir, max_last_training_date, local_copy_dir ) logger.info( - f"Copying from {source_dir.relative_to(base_forecast_dir)} to {target_dir.relative_to(local_dir)}..." + f"Copying from {source_dir.relative_to(base_forecast_dir)} to {target_dir.relative_to(local_copy_dir)}..." 
             )
             shutil.copytree(source_dir, target_dir, dirs_exist_ok=True)
     logger.info(f"Finished processing {base_forecast_dir}.")
@@ -140,9 +138,10 @@ def main(
         help="Skip processing for model batch directories that already have been processed.",
     )
     parser.add_argument(
-        "--local-copy",
-        action="store_true",
-        help="Create a local copy of the processed files.",
+        "--local-copy-dir",
+        type=str,
+        default="",
+        help="Save a local copy of the processed files to this directory, if supplied.",
     )

     args = parser.parse_args()

From ab871bf6dad14139b92eafa345f147b80abd070d Mon Sep 17 00:00:00 2001
From: Damon Bayer
Date: Tue, 13 Jan 2026 14:03:54 -0600
Subject: [PATCH 5/6] Update pipelines/postprocess_forecast_batches.py

---
 pipelines/postprocess_forecast_batches.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/pipelines/postprocess_forecast_batches.py b/pipelines/postprocess_forecast_batches.py
index 5e4e9ebc..e6ee8353 100644
--- a/pipelines/postprocess_forecast_batches.py
+++ b/pipelines/postprocess_forecast_batches.py
@@ -88,11 +88,7 @@ def main(
         to_process = [
             batch_dir
             for batch_dir in to_process
-            if not bool(
-                list(
-                    Path(base_forecast_dir, batch_dir).glob("*-hubverse-table.parquet")
-                )
-            )
+            if not any(Path(base_forecast_dir, batch_dir).glob("*-hubverse-table.parquet"))
         ]

     for batch_dir in to_process:

From b73ebd268c2347677a2284dd5a0fc16088a3d2d6 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 13 Jan 2026 20:04:03 +0000
Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 pipelines/postprocess_forecast_batches.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pipelines/postprocess_forecast_batches.py b/pipelines/postprocess_forecast_batches.py
index e6ee8353..da5bf563 100644
--- a/pipelines/postprocess_forecast_batches.py
+++ b/pipelines/postprocess_forecast_batches.py
@@ -88,7 +88,9 @@ def main(
         to_process = [
             batch_dir
             for batch_dir in to_process
-            if not any(Path(base_forecast_dir, batch_dir).glob("*-hubverse-table.parquet"))
+            if not any(
+                Path(base_forecast_dir, batch_dir).glob("*-hubverse-table.parquet")
+            )
         ]

     for batch_dir in to_process:
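
For reference, a minimal standalone sketch of the share layout produced by
model_batch_dir_to_target_path (introduced in patch 1, parameterized in patch 4).
The parsed dicts below are hypothetical stand-ins for
pipelines.utils.parse_model_batch_dir_name output, assumed to carry "disease",
"first_training_date", and "last_training_date"; only the lookback/omit
arithmetic is taken from the patches.

    import datetime as dt
    from pathlib import Path

    # Hypothetical parse_model_batch_dir_name results for two batch dirs.
    parsed = [
        {"disease": "covid-19", "first_training_date": dt.date(2025, 8, 13), "last_training_date": dt.date(2026, 1, 9)},
        {"disease": "influenza", "first_training_date": dt.date(2025, 8, 13), "last_training_date": dt.date(2026, 1, 8)},
    ]

    # As in main(): the latest last_training_date across all batches is
    # assumed to correspond to omitting 1 day.
    max_last_training_date = max(p["last_training_date"] for p in parsed)

    for parts in parsed:
        lookback = (parts["last_training_date"] - parts["first_training_date"]).days + 1
        omit = (max_last_training_date - parts["last_training_date"]).days + 1
        print(Path(f"lookback-{lookback}-omit-{omit}", parts["disease"]))
    # prints:
    #   lookback-150-omit-1/covid-19
    #   lookback-149-omit-2/influenza

With the final patches applied, this layout is selected by passing the new
--local-copy-dir option to postprocess_forecast_batches.py, or via the
corresponding interactive prompt in azure_command_center.py.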