Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 54 additions & 19 deletions pipelines/azure_command_center.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from pipelines.postprocess_forecast_batches import main as postprocess

DEFAULT_RNG_KEY = 12345
DEFAULT_TRAINING_DAYS = 150
DEFAULT_EXCLUDE_LAST_N_DAYS = 1
load_dotenv()
console = Console()

Expand Down Expand Up @@ -54,8 +56,8 @@ def setup_job_append_id(
additional_forecast_letters: str = "",
container_image_name: str = "pyrenew-hew",
container_image_version: str = "latest",
n_training_days: int = 150,
exclude_last_n_days: int = 1,
n_training_days: int = DEFAULT_TRAINING_DAYS,
exclude_last_n_days: int = DEFAULT_EXCLUDE_LAST_N_DAYS,
rng_key: int = DEFAULT_RNG_KEY,
locations_include: list[str] | None = None,
locations_exclude: list[str] | None = None,
Expand Down Expand Up @@ -166,33 +168,52 @@ def ask_about_reruns():
"How many days to exclude for H signal?", default=1
)
rng_key = IntPrompt.ask("RNG seed for reproducibility?", default=DEFAULT_RNG_KEY)
n_training_days = IntPrompt.ask(
"Number of training days?", default=DEFAULT_TRAINING_DAYS
)

return {
"locations_include": locations_include,
"e_exclude_last_n_days": e_exclude_last_n_days,
"h_exclude_last_n_days": h_exclude_last_n_days,
"rng_key": rng_key,
"n_training_days": n_training_days,
}


def compute_skips(e_exclude_last_n_days: int, h_exclude_last_n_days: int, rng_key: int):
skip_e = e_exclude_last_n_days == 1 and rng_key == DEFAULT_RNG_KEY
skip_h = h_exclude_last_n_days == 1 and rng_key == DEFAULT_RNG_KEY
skip_he = (
max(e_exclude_last_n_days, h_exclude_last_n_days) == 1
and rng_key == DEFAULT_RNG_KEY
def compute_skips(
e_exclude_last_n_days: int,
h_exclude_last_n_days: int,
rng_key: int,
n_training_days: int,
):
run_due_to_param_change = (
n_training_days != DEFAULT_TRAINING_DAYS or rng_key != DEFAULT_RNG_KEY
)
if run_due_to_param_change:
skip_e = False
skip_h = False
skip_he = False
else:
skip_e = e_exclude_last_n_days == DEFAULT_EXCLUDE_LAST_N_DAYS
skip_h = h_exclude_last_n_days == DEFAULT_EXCLUDE_LAST_N_DAYS
skip_he = skip_e and skip_h
return {"skip_e": skip_e, "skip_h": skip_h, "skip_he": skip_he}


def do_timeseries_reruns(
locations_include: list[str] | None = None,
e_exclude_last_n_days: int = 1,
h_exclude_last_n_days: int = 1,
e_exclude_last_n_days: int = DEFAULT_EXCLUDE_LAST_N_DAYS,
h_exclude_last_n_days: int = DEFAULT_EXCLUDE_LAST_N_DAYS,
rng_key: int = DEFAULT_RNG_KEY, # not used, but kept for interface consistency
append_id: str = "",
n_training_days: int = DEFAULT_TRAINING_DAYS,
):
skips = compute_skips(e_exclude_last_n_days, h_exclude_last_n_days, rng_key)
skips = compute_skips(
e_exclude_last_n_days, h_exclude_last_n_days, rng_key, n_training_days
)
he_exclude_last_n_days = max(e_exclude_last_n_days, h_exclude_last_n_days)
he_exclude_covered_by_e = he_exclude_last_n_days == e_exclude_last_n_days

if skips["skip_e"]:
print("Skipping Timeseries-E re-fitting due to E")
Expand All @@ -201,26 +222,31 @@ def do_timeseries_reruns(
append_id=append_id,
locations_include=locations_include,
exclude_last_n_days=e_exclude_last_n_days,
n_training_days=n_training_days,
)
if skips["skip_h"]:
print("Skipping Timeseries-E re-fitting due to H")
if skips["skip_he"] or he_exclude_covered_by_e:
print("Skipping Timeseries-E re-fitting due to HE*")
else:
fit_timeseries_e(
append_id=append_id,
locations_include=locations_include,
exclude_last_n_days=h_exclude_last_n_days,
exclude_last_n_days=he_exclude_last_n_days,
n_training_days=n_training_days,
)


def do_pyrenew_reruns(
locations_include: list[str] | None = None,
e_exclude_last_n_days: int = 1,
h_exclude_last_n_days: int = 1,
e_exclude_last_n_days: int = DEFAULT_EXCLUDE_LAST_N_DAYS,
h_exclude_last_n_days: int = DEFAULT_EXCLUDE_LAST_N_DAYS,
rng_key: int = DEFAULT_RNG_KEY,
append_id: str = "",
n_training_days: int = DEFAULT_TRAINING_DAYS,
):
he_exclude_last_n_days = max(e_exclude_last_n_days, h_exclude_last_n_days)
skips = compute_skips(e_exclude_last_n_days, h_exclude_last_n_days, rng_key)
skips = compute_skips(
e_exclude_last_n_days, h_exclude_last_n_days, rng_key, n_training_days
)

if skips["skip_e"]:
print("Skipping PyRenew-E re-fitting")
Expand All @@ -230,6 +256,7 @@ def do_pyrenew_reruns(
locations_include=locations_include,
exclude_last_n_days=e_exclude_last_n_days,
rng_key=rng_key,
n_training_days=n_training_days,
)

if skips["skip_h"]:
Expand All @@ -240,6 +267,7 @@ def do_pyrenew_reruns(
locations_include=locations_include,
exclude_last_n_days=h_exclude_last_n_days,
rng_key=rng_key,
n_training_days=n_training_days,
)
# fit_pyrenew_hw(
# append_id=append_id,
Expand All @@ -256,6 +284,7 @@ def do_pyrenew_reruns(
locations_include=locations_include,
exclude_last_n_days=he_exclude_last_n_days,
rng_key=rng_key,
n_training_days=n_training_days,
)

# fit_pyrenew_hew(
Expand Down Expand Up @@ -429,10 +458,16 @@ def ask_integer_choice(choices):
# fit_pyrenew_hew(append_id=current_time)
elif selected_choice == "Rerun Timeseries Models":
ask_about_reruns_input = ask_about_reruns()
do_timeseries_reruns(append_id=current_time, **ask_about_reruns_input)
do_timeseries_reruns(
append_id=current_time,
**ask_about_reruns_input,
)
elif selected_choice == "Rerun PyRenew Models":
ask_about_reruns_input = ask_about_reruns()
do_pyrenew_reruns(append_id=current_time, **ask_about_reruns_input)
do_pyrenew_reruns(
append_id=current_time,
**ask_about_reruns_input,
)
elif selected_choice == "Postprocess Forecast Batches":
skip_existing = Confirm.ask(
"Skip processing for model batch directories that already have figures?",
Expand Down
2 changes: 1 addition & 1 deletion pipelines/postprocess_forecast_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def process_model_batch_dir(model_batch_dir_path: Path, plot_ext: str = "pdf") -

def main(
base_forecast_dir: Path | str,
diseases: list[str] = ["COVID-19", "Influenza", "RSV"],
diseases: list[str] | set[str] = ["COVID-19", "Influenza", "RSV"],
skip_existing: bool = True,
) -> None:
logging.basicConfig(level=logging.INFO)
Expand Down