From 01b2c57b202902b2cee4e2844ce14dd426150d28 Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Tue, 6 Jan 2026 16:16:57 -0600 Subject: [PATCH 01/15] timeseries forecasting working --- hewr/R/timeseries_utils.R | 3 +- pipelines/README.md | 2 - pipelines/batch/setup_job.py | 4 - pipelines/cli_utils.py | 13 -- pipelines/common_utils.py | 129 +------------- pipelines/forecast_pyrenew.py | 79 +++------ pipelines/forecast_timeseries.py | 45 +---- pipelines/forecast_timeseries_ensemble.R | 2 +- pipelines/generate_epiweekly_data.R | 6 +- pipelines/generate_test_data_lib.py | 5 +- pipelines/prep_data.py | 215 ++++++++++------------- pipelines/prep_eval_data.py | 73 -------- pipelines/tests/test_common_utils.py | 19 -- pipelines/tests/test_end_to_end.sh | 3 +- pipelines/tests/test_pyrenew_fit.sh | 2 - pipelines/tests/test_ts_fit.sh | 2 - 16 files changed, 140 insertions(+), 462 deletions(-) delete mode 100644 pipelines/prep_eval_data.py diff --git a/hewr/R/timeseries_utils.R b/hewr/R/timeseries_utils.R index a1b08bdf..6564343e 100644 --- a/hewr/R/timeseries_utils.R +++ b/hewr/R/timeseries_utils.R @@ -22,7 +22,7 @@ #' @export load_training_data <- function( model_dir, - base_data_name = "combined_training_data", + base_data_name = "combined_data", epiweekly = FALSE ) { resolution <- dplyr::if_else(epiweekly, "epiweekly", "daily") @@ -45,6 +45,7 @@ load_training_data <- function( .value = readr::col_double() ) ) |> + dplyr::filter(.data$data_type == "train") |> dplyr::select(-"lab_site_index") |> dplyr::filter(stringr::str_ends(.data$.variable, "ed_visits")) |> tidyr::pivot_wider(names_from = ".variable", values_from = ".value") diff --git a/pipelines/README.md b/pipelines/README.md index 3f3fd84b..7a2eb23b 100644 --- a/pipelines/README.md +++ b/pipelines/README.md @@ -2,11 +2,9 @@ Functions to run Pyrenew-HEW forecasting pipeline. 
Uses the following data directories: - facility-level-nssp-data-dir `nssp-etl/gold` - - state-level-nssp-data-dir `nssp-archival-vintages/gold` - param-data-dir `prod_param_estimates` - nwss-data-dir `nwss_vintages` And the following files: - - eval-data-path `nssp-etl/latest_comprehensive.parquet` - priors-path `pipelines/priors/prod_priors.py` diff --git a/pipelines/batch/setup_job.py b/pipelines/batch/setup_job.py index 9329f8f8..dc78fb6c 100644 --- a/pipelines/batch/setup_job.py +++ b/pipelines/batch/setup_job.py @@ -247,16 +247,12 @@ def main( "--loc {loc} " f"--n-training-days {n_training_days} " "--facility-level-nssp-data-dir nssp-etl/gold " - "--state-level-nssp-data-dir " - "nssp-archival-vintages/gold " "--param-data-dir params " "--output-dir {output_dir} " "--credentials-path config/creds.toml " "--report-date {report_date} " f"--exclude-last-n-days {exclude_last_n_days} " f"--model-letters {model_letters} " - "--eval-data-path " - "nssp-etl/latest_comprehensive.parquet " f"{additional_args}" "'" ) diff --git a/pipelines/cli_utils.py b/pipelines/cli_utils.py index 265dd878..75b5386e 100644 --- a/pipelines/cli_utils.py +++ b/pipelines/cli_utils.py @@ -38,13 +38,6 @@ def add_common_forecast_arguments(parser: argparse.ArgumentParser) -> None: help="Directory in which to look for facility-level NSSP ED visit data.", ) - parser.add_argument( - "--state-level-nssp-data-dir", - type=Path, - default=Path("private_data", "nssp_state_level_gold"), - help="Directory in which to look for state-level NSSP ED visit data.", - ) - parser.add_argument( "--param-data-dir", type=Path, @@ -92,12 +85,6 @@ def add_common_forecast_arguments(parser: argparse.ArgumentParser) -> None: help="Path to a TOML file containing credentials such as API keys.", ) - parser.add_argument( - "--eval-data-path", - type=Path, - help="Path to a parquet file containing comprehensive truth data.", - ) - parser.add_argument( "--nhsn-data-path", type=Path, diff --git a/pipelines/common_utils.py 
b/pipelines/common_utils.py index c8b4e7b4..fe7cf29a 100644 --- a/pipelines/common_utils.py +++ b/pipelines/common_utils.py @@ -44,68 +44,6 @@ def get_available_reports( ] -def parse_and_validate_report_date( - report_date: str, - available_facility_level_reports: list[dt.date], - available_loc_level_reports: list[dt.date], - logger: logging.Logger, -) -> tuple[dt.date, dt.date | None]: - """ - Parse and validate report date, determine location-level report date to use. - - Parameters - ---------- - report_date : str - Report date as string ("latest" or "YYYY-MM-DD" format). - available_facility_level_reports : list[dt.date] - List of available facility-level report dates. - available_loc_level_reports : list[dt.date] - List of available location-level report dates. - logger : logging.Logger - Process logger. - - Returns - ------- - tuple[dt.date, dt.date | None] - Tuple of (report_date, loc_report_date). - - Raises - ------ - ValueError - If report date is invalid or data is missing. - """ - first_available_loc_report = min(available_loc_level_reports) - last_available_loc_report = max(available_loc_level_reports) - - if report_date == "latest": - report_date = max(available_facility_level_reports) - else: - report_date = dt.datetime.strptime(report_date, "%Y-%m-%d").date() - - if report_date in available_loc_level_reports: - loc_report_date = report_date - elif report_date > last_available_loc_report: - loc_report_date = last_available_loc_report - elif report_date > first_available_loc_report: - raise ValueError( - "Dataset appear to be missing some state-level " - f"reports. First entry is {first_available_loc_report}, " - f"last is {last_available_loc_report}, but no entry " - f"for {report_date}" - ) - else: - raise ValueError( - "Requested report date is earlier than the first " - "state-level vintage. 
This is not currently supported" - ) - - logger.info(f"Report date: {report_date}") - if loc_report_date is not None: - logger.info(f"Using location-level data as of: {loc_report_date}") - - return report_date, loc_report_date - - def calculate_training_dates( report_date: dt.date, n_training_days: int, @@ -155,67 +93,6 @@ def calculate_training_dates( return first_training_date, last_training_date -def load_nssp_data( - report_date: dt.date, - loc_report_date: dt.date | None, - available_facility_level_reports: list[dt.date], - available_loc_level_reports: list[dt.date], - facility_level_nssp_data_dir: Path, - state_level_nssp_data_dir: Path, - logger: logging.Logger, -) -> tuple[pl.LazyFrame | None, pl.LazyFrame | None]: - """ - Load facility-level and location-level NSSP data. - - Parameters - ---------- - report_date : dt.date - The report date. - loc_report_date : dt.date | None - The location-level report date to use. - available_facility_level_reports : list[dt.date] - List of available facility-level report dates. - available_loc_level_reports : list[dt.date] - List of available location-level report dates. - facility_level_nssp_data_dir : Path - Directory containing facility-level NSSP data. - state_level_nssp_data_dir : Path - Directory containing state-level NSSP data. - logger : logging.Logger - Logger for informational messages. - - Returns - ------- - tuple[pl.LazyFrame | None, pl.LazyFrame | None] - Tuple of (facility_level_nssp_data, loc_level_nssp_data). - - Raises - ------ - ValueError - If no data is available for the requested report date. 
- """ - facility_level_nssp_data, loc_level_nssp_data = None, None - - if report_date in available_facility_level_reports: - logger.info("Facility level data available for the given report date") - facility_datafile = f"{report_date}.parquet" - facility_level_nssp_data = pl.scan_parquet( - Path(facility_level_nssp_data_dir, facility_datafile) - ) - if loc_report_date in available_loc_level_reports: - logger.info("location-level data available for the given report date.") - loc_datafile = f"{loc_report_date}.parquet" - loc_level_nssp_data = pl.scan_parquet( - Path(state_level_nssp_data_dir, loc_datafile) - ) - if facility_level_nssp_data is None and loc_level_nssp_data is None: - raise ValueError( - f"No data available for the requested report date {report_date}" - ) - - return facility_level_nssp_data, loc_level_nssp_data - - def run_r_script( script_name: str, args: list[str], @@ -386,6 +263,12 @@ def plot_and_save_loc_forecast( return None +def py_scalar_to_r_scalar(py_scalar): + if py_scalar is None: + return "NULL" + return f"'{str(py_scalar)}'" + + def create_hubverse_table(model_fit_path: Path) -> None: """Create hubverse table from model fit using R script. 
diff --git a/pipelines/forecast_pyrenew.py b/pipelines/forecast_pyrenew.py index 922be4b1..69e7c6e9 100644 --- a/pipelines/forecast_pyrenew.py +++ b/pipelines/forecast_pyrenew.py @@ -16,8 +16,6 @@ create_hubverse_table, get_available_reports, load_credentials, - load_nssp_data, - parse_and_validate_report_date, plot_and_save_loc_forecast, run_r_script, ) @@ -26,7 +24,6 @@ generate_and_save_predictions, ) from pipelines.prep_data import process_and_save_loc_data, process_and_save_loc_param -from pipelines.prep_eval_data import save_eval_data from pipelines.prep_ww_data import clean_nwss_data, preprocess_ww_data from pyrenew_hew.utils import ( flags_from_hew_letters, @@ -97,23 +94,20 @@ def generate_epiweekly_data(data_dir: Path) -> None: def main( disease: str, - report_date: str, loc: str, - facility_level_nssp_data_dir: Path | str, - state_level_nssp_data_dir: Path | str, - nwss_data_dir: Path | str, - param_data_dir: Path | str, - priors_path: Path | str, - output_dir: Path | str, + facility_level_nssp_data_dir: Path, + nwss_data_dir: Path, + param_data_dir: Path, + priors_path: Path, + output_dir: Path, n_training_days: int, n_forecast_days: int, n_chains: int, n_warmup: int, n_samples: int, - nhsn_data_path: Path | str = None, + nhsn_data_path: Path | None = None, exclude_last_n_days: int = 0, - eval_data_path: Path = None, - credentials_path: Path = None, + credentials_path: Path | None = None, fit_ed_visits: bool = False, fit_hospital_admissions: bool = False, fit_wastewater: bool = False, @@ -134,7 +128,7 @@ def main( logger.info( "Starting single-location forecasting pipeline for " f"model {pyrenew_model_name}, location {loc}, " - f"and report date {report_date}" + f"and latest NSSP report date." 
) signals = ["ed_visits", "hospital_admissions", "wastewater"] @@ -159,14 +153,8 @@ def main( facility_level_nssp_data_dir ) - available_loc_level_reports = get_available_reports(state_level_nssp_data_dir) - - report_date, loc_report_date = parse_and_validate_report_date( - report_date, - available_facility_level_reports, - available_loc_level_reports, - logger, - ) + report_date = max(available_facility_level_reports) + facility_datafile = f"{report_date}.parquet" first_training_date, last_training_date = calculate_training_dates( report_date, @@ -175,14 +163,8 @@ def main( logger, ) - facility_level_nssp_data, loc_level_nssp_data = load_nssp_data( - report_date, - loc_report_date, - available_facility_level_reports, - available_loc_level_reports, - facility_level_nssp_data_dir, - state_level_nssp_data_dir, - logger, + facility_level_nssp_data = pl.scan_parquet( + Path(facility_level_nssp_data_dir, facility_datafile) ) nwss_data_disease_map = { @@ -235,17 +217,18 @@ def get_available_nwss_reports( data_dir = Path(model_dir, "data") os.makedirs(data_dir, exist_ok=True) - timeseries_model_name = "ts_ensemble_e" if fit_ed_visits else None + if fit_ed_visits: + timeseries_model_name = "ts_ensemble_e" - if fit_ed_visits and not os.path.exists(Path(model_run_dir, timeseries_model_name)): - raise ValueError( - f"{timeseries_model_name} model run not found. " - "Please ensure that the timeseries forecasts " - "for the ED visits (E) signal are generated " - "before fitting Pyrenew models with the E signal. " - "If running a batch job, set the flag --model-family " - "'timeseries' to fit timeseries model." - ) + if not os.path.exists(Path(model_run_dir, timeseries_model_name)): + raise ValueError( + f"{timeseries_model_name} model run not found. " + "Please ensure that the timeseries forecasts " + "for the ED visits (E) signal are generated " + "before fitting Pyrenew models with the E signal. 
" + "If running a batch job, set the flag --model-family " + "'timeseries' to fit timeseries model." + ) logger.info("Recording git info...") record_git_info(model_dir) @@ -258,7 +241,6 @@ def get_available_nwss_reports( loc_abb=loc, disease=disease, facility_level_nssp_data=facility_level_nssp_data, - loc_level_nssp_data=loc_level_nssp_data, loc_level_nwss_data=loc_level_nwss_data, report_date=report_date, first_training_date=first_training_date, @@ -277,21 +259,6 @@ def get_available_nwss_reports( fit_ed_visits=fit_ed_visits, save_dir=data_dir, ) - logger.info("Getting eval data...") - if eval_data_path is None: - raise ValueError("No path to an evaluation dataset provided.") - save_eval_data( - loc=loc, - disease=disease, - first_training_date=first_training_date, - last_training_date=last_training_date, - latest_comprehensive_path=eval_data_path, - output_data_dir=data_dir, - last_eval_date=report_date + dt.timedelta(days=n_forecast_days), - credentials_dict=credentials_dict, - nhsn_data_path=nhsn_data_path, - ) - logger.info("Done getting eval data.") logger.info("Generating epiweekly datasets from daily datasets...") generate_epiweekly_data(data_dir) diff --git a/pipelines/forecast_timeseries.py b/pipelines/forecast_timeseries.py index cd964d8a..8f9a40a9 100644 --- a/pipelines/forecast_timeseries.py +++ b/pipelines/forecast_timeseries.py @@ -4,14 +4,14 @@ import os from pathlib import Path +import polars as pl + from pipelines.cli_utils import add_common_forecast_arguments from pipelines.common_utils import ( calculate_training_dates, create_hubverse_table, get_available_reports, load_credentials, - load_nssp_data, - parse_and_validate_report_date, plot_and_save_loc_forecast, run_r_script, ) @@ -19,7 +19,6 @@ generate_epiweekly_data, ) from pipelines.prep_data import process_and_save_loc_data -from pipelines.prep_eval_data import save_eval_data def timeseries_ensemble_forecasts( @@ -45,17 +44,15 @@ def timeseries_ensemble_forecasts( def main( disease: str, - 
report_date: str, + report_date: dt.date, + param_data_dir: Path, loc: str, facility_level_nssp_data_dir: Path | str, - state_level_nssp_data_dir: Path | str, - param_data_dir: Path | str, output_dir: Path | str, n_training_days: int, n_forecast_days: int, n_samples: int, model_letters: str, - eval_data_path: Path, exclude_last_n_days: int = 0, credentials_path: Path | None = None, nhsn_data_path: Path | None = None, @@ -80,14 +77,9 @@ def main( available_facility_level_reports = get_available_reports( facility_level_nssp_data_dir ) - available_loc_level_reports = get_available_reports(state_level_nssp_data_dir) - report_date, loc_report_date = parse_and_validate_report_date( - report_date, - available_facility_level_reports, - available_loc_level_reports, - logger, - ) + report_date = max(available_facility_level_reports) + facility_datafile = f"{report_date}.parquet" first_training_date, last_training_date = calculate_training_dates( report_date, @@ -96,14 +88,8 @@ def main( logger, ) - facility_level_nssp_data, loc_level_nssp_data = load_nssp_data( - report_date, - loc_report_date, - available_facility_level_reports, - available_loc_level_reports, - facility_level_nssp_data_dir, - state_level_nssp_data_dir, - logger, + facility_level_nssp_data = pl.scan_parquet( + Path(facility_level_nssp_data_dir, facility_datafile) ) model_batch_dir_name = ( @@ -123,7 +109,6 @@ def main( loc_abb=loc, disease=disease, facility_level_nssp_data=facility_level_nssp_data, - loc_level_nssp_data=loc_level_nssp_data, loc_level_nwss_data=None, report_date=report_date, first_training_date=first_training_date, @@ -134,20 +119,6 @@ def main( nhsn_data_path=nhsn_data_path, ) - logger.info("Getting eval data...") - save_eval_data( - loc=loc, - disease=disease, - first_training_date=first_training_date, - last_training_date=last_training_date, - latest_comprehensive_path=eval_data_path, - output_data_dir=Path(ensemble_model_output_dir, "data"), - last_eval_date=report_date + 
dt.timedelta(days=n_forecast_days), - credentials_dict=credentials_dict, - nhsn_data_path=nhsn_data_path, - ) - logger.info("Done getting eval data.") - logger.info("Generating epiweekly datasets from daily datasets...") generate_epiweekly_data(Path(ensemble_model_output_dir, "data")) diff --git a/pipelines/forecast_timeseries_ensemble.R b/pipelines/forecast_timeseries_ensemble.R index a06f9d92..bad6833a 100644 --- a/pipelines/forecast_timeseries_ensemble.R +++ b/pipelines/forecast_timeseries_ensemble.R @@ -112,7 +112,7 @@ main <- function( ) { training_data <- hewr::load_training_data( model_dir, - "combined_training_data", + "combined_data", epiweekly ) target_and_other_data <- training_data$data diff --git a/pipelines/generate_epiweekly_data.R b/pipelines/generate_epiweekly_data.R index aaa4fd6a..72d64018 100644 --- a/pipelines/generate_epiweekly_data.R +++ b/pipelines/generate_epiweekly_data.R @@ -83,11 +83,7 @@ convert_daily_to_epiweekly <- function( main <- function(data_dir) { convert_daily_to_epiweekly( data_dir, - data_name = "combined_training_data.tsv" - ) - convert_daily_to_epiweekly( - data_dir, - data_name = "combined_eval_data.tsv" + data_name = "combined_data.tsv" ) } diff --git a/pipelines/generate_test_data_lib.py b/pipelines/generate_test_data_lib.py index adbfe62d..55d5a555 100644 --- a/pipelines/generate_test_data_lib.py +++ b/pipelines/generate_test_data_lib.py @@ -389,7 +389,6 @@ def simulate_data_from_bootstrap( loc_abb=bootstrap_loc, disease=bootstrap_disease, facility_level_nssp_data=bootstrap_facility_level_nssp_data.lazy(), - loc_level_nssp_data=bootstrap_loc_level_nssp_data.lazy(), loc_level_nwss_data=bootstrap_loc_level_nwss_data, report_date=max_train_date, first_training_date=first_training_date, @@ -465,7 +464,7 @@ def simulate_data_from_bootstrap( bootstrap_disease=bootstrap_disease, ) # Update the TSV file with realistic prior predictive values - tsv_file_path = Path(model_run_dir) / "data" / "combined_training_data.tsv" + 
tsv_file_path = Path(model_run_dir) / "data" / "combined_data.tsv" update_tsv_with_prior_predictive( tsv_file_path=tsv_file_path, idata=idata, @@ -612,7 +611,7 @@ def update_tsv_with_prior_predictive( """Update TSV file with realistic values from prior predictive sampling. Args: - tsv_file_path: Path to the combined_training_data.tsv file to update + tsv_file_path: Path to the combined_data.tsv file to update idata: ArviZ InferenceData containing prior predictive samples state_disease_key: DataFrame mapping draws to state/disease combinations bootstrap_loc: State abbreviation used for bootstrap diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py index a8922b48..a9cdd003 100644 --- a/pipelines/prep_data.py +++ b/pipelines/prep_data.py @@ -3,8 +3,6 @@ import logging import os import tempfile -from datetime import date, datetime -from logging import Logger from pathlib import Path import forecasttools @@ -12,7 +10,7 @@ import polars as pl import polars.selectors as cs -from pipelines.common_utils import run_r_code +from pipelines.common_utils import py_scalar_to_r_scalar, run_r_code from pyrenew_hew.utils import approx_lognorm _disease_map = { @@ -25,8 +23,7 @@ def clean_nssp_data( data: pl.DataFrame, disease: str, - data_type: str, - last_data_date: date | None = None, + last_training_date: dt.date | None = None, ) -> pl.DataFrame: """ Filter, reformat, and annotate a raw `pl.DataFrame` of NSSP data, @@ -41,16 +38,12 @@ def clean_nssp_data( disease Name of the disease for which to prep data. - data_type - Value for the data_type annotation column in the - output dataframe. - - last_data_date - If provided, filter the dataset to include only dates - prior to this date. Default `None` (no filter). + last_training_date + Last date to include in the training data. 
""" - if last_data_date is not None: - data = data.filter(pl.col("date") <= last_data_date) + + if last_training_date is None: + last_training_date = data.get_column("date").max() return ( data.filter(pl.col("disease").is_in([disease, "Total"])) @@ -61,21 +54,23 @@ def clean_nssp_data( .rename({disease: "observed_ed_visits"}) .with_columns( other_ed_visits=pl.col("Total") - pl.col("observed_ed_visits"), - data_type=pl.lit(data_type), + data_type=pl.when(pl.col("date") <= last_training_date) + .then(pl.lit("train")) + .otherwise(pl.lit("eval")), ) - .drop(pl.col("Total")) + .drop("Total") .sort("date") ) def get_nhsn( - start_date: datetime.date, - end_date: datetime.date, + start_date: dt.date | None, + end_date: dt.date | None, disease: str, loc_abb: str, - temp_dir: Path = None, - credentials_dict: dict = None, - local_data_file: Path = None, + temp_dir: Path | None = None, + credentials_dict: dict | None = None, + local_data_file: Path | None = None, ) -> pl.DataFrame: if local_data_file is None: if temp_dir is None: @@ -83,11 +78,6 @@ def get_nhsn( if credentials_dict is None: credentials_dict = dict() - def py_scalar_to_r_scalar(py_scalar): - if py_scalar is None: - return "NULL" - return f"'{str(py_scalar)}'" - disease_nhsn_key = { "COVID-19": "totalconfc19newadm", "Influenza": "totalconfflunewadm", @@ -133,33 +123,41 @@ def py_scalar_to_r_scalar(py_scalar): def combine_surveillance_data( - nssp_data: pl.DataFrame, - nhsn_data: pl.DataFrame, disease: str, - nwss_data: pl.DataFrame = None, + nssp_data: pl.DataFrame | None = None, + nhsn_data: pl.DataFrame | None = None, + nwss_data: pl.DataFrame | None = None, ): - nssp_data_long = nssp_data.unpivot( - on=["observed_ed_visits", "other_ed_visits"], - variable_name=".variable", - index=cs.exclude(["observed_ed_visits", "other_ed_visits"]), - value_name=".value", - ).with_columns(pl.lit(None).alias("lab_site_index")) - - nhsn_data_long = ( - nhsn_data.rename( - { - "weekendingdate": "date", - "jurisdiction": 
"geo_value", - "hospital_admissions": "observed_hospital_admissions", - } - ) - .unpivot( - on="observed_hospital_admissions", - index=cs.exclude("observed_hospital_admissions"), + nssp_data_long = ( + nssp_data.unpivot( + on=["observed_ed_visits", "other_ed_visits"], variable_name=".variable", + index=cs.exclude(["observed_ed_visits", "other_ed_visits"]), value_name=".value", + ).with_columns(pl.lit(None).alias("lab_site_index")) + if nssp_data is not None + else pl.DataFrame() + ) + + nhsn_data_long = ( + ( + nhsn_data.rename( + { + "weekendingdate": "date", + "jurisdiction": "geo_value", + "hospital_admissions": "observed_hospital_admissions", + } + ) + .unpivot( + on="observed_hospital_admissions", + index=cs.exclude("observed_hospital_admissions"), + variable_name=".variable", + value_name=".value", + ) + .with_columns(pl.lit(None).alias("lab_site_index")) ) - .with_columns(pl.lit(None).alias("lab_site_index")) + if nhsn_data is not None + else pl.DataFrame() ) nwss_data_long = ( @@ -169,7 +167,6 @@ def combine_surveillance_data( "location": "geo_value", } ) - .with_columns(pl.lit("train").alias("data_type")) .select( cs.exclude( [ @@ -205,10 +202,10 @@ def combine_surveillance_data( "date", "geo_value", "disease", - "data_type", ".variable", ".value", "lab_site_index", + "data_type", ] ) ) @@ -216,10 +213,10 @@ def combine_surveillance_data( return combined_dat -def aggregate_to_national( +def aggregate_nssp_to_national( data: pl.LazyFrame, - geo_values_to_include: list[str], - first_date_to_include: datetime.date, + geo_values_to_include: pl.Series | list[str], + first_date_to_include: dt.date, national_geo_value="US", ): assert national_geo_value not in geo_values_to_include @@ -233,11 +230,12 @@ def aggregate_to_national( ) -def process_loc_level_data( +# not currently used, but could be used for processing latest_comprehensive +def process_loc_level_nssp_data( loc_level_nssp_data: pl.LazyFrame, loc_abb: str, disease: str, - first_training_date: 
datetime.date, + first_training_date: dt.date, loc_pop_df: pl.DataFrame, ) -> pl.DataFrame: logging.basicConfig(level=logging.INFO) @@ -263,7 +261,7 @@ def process_loc_level_data( .to_list() ) logger.info("Aggregating state-level data to national") - loc_level_nssp_data = aggregate_to_national( + loc_level_nssp_data = aggregate_nssp_to_national( loc_level_nssp_data, locations_to_aggregate, first_training_date, @@ -290,7 +288,7 @@ def process_loc_level_data( disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map), ) .sort(["date", "disease"]) - .collect(engine="streaming") + .collect() ) @@ -298,7 +296,7 @@ def aggregate_facility_level_nssp_to_loc( facility_level_nssp_data: pl.LazyFrame, loc_abb: str, disease: str, - first_training_date: str, + first_training_date: dt.date, loc_pop_df: pl.DataFrame, ) -> pl.DataFrame: logging.basicConfig(level=logging.INFO) @@ -321,7 +319,7 @@ def aggregate_facility_level_nssp_to_loc( locations_to_aggregate = ( loc_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique() ) - facility_level_nssp_data = aggregate_to_national( + facility_level_nssp_data = aggregate_nssp_to_national( facility_level_nssp_data, locations_to_aggregate, first_training_date, @@ -357,10 +355,10 @@ def get_loc_pop_df(): def _validate_and_extract( - df: pl.DataFrame, + df_lazy: pl.LazyFrame, parameter_name: str, ) -> list: - df = df.filter(pl.col("parameter") == parameter_name).collect() + df = df_lazy.filter(pl.col("parameter") == parameter_name).collect() if df.height != 1: error_msg = f"Expected exactly one {parameter_name} parameter row, but found {df.height}" logging.error(error_msg) @@ -374,8 +372,7 @@ def get_pmfs( param_estimates: pl.LazyFrame, loc_abb: str, disease: str, - as_of: dt.date = None, - reference_date: dt.date = None, + as_of: dt.date | None = None, right_truncation_required: bool = True, ) -> dict[str, list]: """ @@ -406,11 +403,6 @@ def get_pmfs( (start_date <= as_of <= end_date). Defaults to the most recent estimates. 
- reference_date : datetime.date, optional - The reference date for right truncation estimates. - Defaults to as_of value. Selects the most recent estimate - with reference_date <= this value. - right_truncation_required : bool, optional If False, allows extraction of other pmfs if right_truncation estimate is missing @@ -438,7 +430,6 @@ def get_pmfs( min_as_of = dt.date(1000, 1, 1) max_as_of = dt.date(3000, 1, 1) as_of = as_of or max_as_of - reference_date = reference_date or as_of filtered_estimates = ( param_estimates.with_columns( @@ -486,27 +477,21 @@ def get_pmfs( def process_and_save_loc_data( loc_abb: str, disease: str, - report_date: datetime.date, - first_training_date: datetime.date, - last_training_date: datetime.date, + report_date: dt.date, + first_training_date: dt.date, + last_training_date: dt.date, + facility_level_nssp_data: pl.LazyFrame, save_dir: Path, - logger: Logger = None, - facility_level_nssp_data: pl.LazyFrame = None, - loc_level_nssp_data: pl.LazyFrame = None, - loc_level_nwss_data: pl.LazyFrame = None, - credentials_dict: dict = None, - nhsn_data_path: Path | str = None, + logger: logging.Logger | None = None, + loc_level_nwss_data: pl.DataFrame | None = None, + credentials_dict: dict | None = None, + nhsn_data_path: Path | str | None = None, ) -> None: logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) os.makedirs(save_dir, exist_ok=True) - if facility_level_nssp_data is None and loc_level_nssp_data is None: - raise ValueError( - "Must provide at least one of facility-level and state-levelNSSP data" - ) - loc_pop_df = get_loc_pop_df() loc_pop = loc_pop_df.filter(pl.col("abb") == loc_abb).item(0, "population") @@ -523,57 +508,49 @@ def process_and_save_loc_data( loc_pop_df=loc_pop_df, ) - loc_level_data = process_loc_level_data( - loc_level_nssp_data=loc_level_nssp_data, - loc_abb=loc_abb, + nssp_full_data = clean_nssp_data( + data=aggregated_facility_data, disease=disease, - 
first_training_date=first_training_date, loc_pop_df=loc_pop_df, ) - if aggregated_facility_data.height > 0: - first_facility_level_data_date = aggregated_facility_data.get_column( - "date" - ).min() - loc_level_data = loc_level_data.filter( - pl.col("date") < first_facility_level_data_date - ) - - nssp_training_data = clean_nssp_data( - data=pl.concat([loc_level_data, aggregated_facility_data]), - disease=disease, - data_type="train", - last_data_date=last_training_date, - ) + nssp_training_data = nssp_full_data.filter(pl.col("data_type") == "train") - nhsn_training_data = ( + nhsn_full_data = ( get_nhsn( start_date=first_training_date, - end_date=last_training_date, + end_date=None, disease=disease, loc_abb=loc_abb, credentials_dict=credentials_dict, local_data_file=nhsn_data_path, ) .filter( - (pl.col("weekendingdate") <= last_training_date) - & (pl.col("weekendingdate") >= first_training_date) - ) # in testing mode, this isn't guaranteed - .with_columns(pl.lit("train").alias("data_type")) + pl.col("weekendingdate") >= first_training_date + ) # in testing mode, this isn't guaranteed + .with_columns( + data_type=pl.when(pl.col("weekendingdate") <= last_training_date) + .then(pl.lit("train")) + .otherwise(pl.lit("eval")), + ) ) + nhsn_training_data = nhsn_full_data.filter(pl.col("data_type") == "train") nhsn_step_size = 7 - nwss_training_data = ( - loc_level_nwss_data.to_dict(as_series=False) - if loc_level_nwss_data is not None - else None - ) + if loc_level_nwss_data is not None: + nwss_training_data = loc_level_nwss_data.filter( + pl.col("date") <= last_training_date + ) + nwss_training_data_dict = nwss_training_data.to_dict(as_series=False) + else: + nwss_training_data = None + nwss_training_data_dict = None data_for_model_fit = { "loc_pop": loc_pop, "right_truncation_offset": right_truncation_offset, - "nwss_training_data": nwss_training_data, + "nwss_training_data": nwss_training_data_dict, "nssp_training_data":
nssp_training_data.to_dict(as_series=False), "nhsn_training_data": nhsn_training_data.to_dict(as_series=False), "nhsn_step_size": nhsn_step_size, @@ -584,9 +561,9 @@ def process_and_save_loc_data( with open(Path(save_dir, "data_for_model_fit.json"), "w") as json_file: json.dump(data_for_model_fit, json_file, default=str) - combined_training_dat = combine_surveillance_data( - nssp_data=nssp_training_data, - nhsn_data=nhsn_training_data, + combined_data = combine_surveillance_data( + nssp_data=nssp_full_data, + nhsn_data=nhsn_full_data, nwss_data=loc_level_nwss_data, disease=disease, ) @@ -594,9 +571,7 @@ def process_and_save_loc_data( if logger is not None: logger.info(f"Saving {loc_abb} to {save_dir}") - combined_training_dat.write_csv( - Path(save_dir, "combined_training_data.tsv"), separator="\t" - ) + combined_data.write_csv(Path(save_dir, "combined_data.tsv"), separator="\t") return None diff --git a/pipelines/prep_eval_data.py b/pipelines/prep_eval_data.py deleted file mode 100644 index 8c24b73f..00000000 --- a/pipelines/prep_eval_data.py +++ /dev/null @@ -1,73 +0,0 @@ -import datetime as dt -import logging -from pathlib import Path - -import polars as pl - -from pipelines.prep_data import ( - clean_nssp_data, - combine_surveillance_data, - get_loc_pop_df, - get_nhsn, - process_loc_level_data, -) - - -def save_eval_data( - loc: str, - disease: str, - first_training_date, - last_training_date, - latest_comprehensive_path: Path | str, - output_data_dir: Path | str, - last_eval_date: dt.date = None, - output_file_name: str = "eval_data.tsv", - credentials_dict: dict = None, - nhsn_data_path: Path | str = None, -): - logging.basicConfig(level=logging.INFO) - logger = logging.getLogger(__name__) - - logger.info("Reading in truth data...") - loc_level_nssp_data = pl.scan_parquet(latest_comprehensive_path) - - raw_nssp_data = process_loc_level_data( - loc_level_nssp_data=loc_level_nssp_data, - loc_abb=loc, - disease=disease, - first_training_date=first_training_date, 
- loc_pop_df=get_loc_pop_df(), - ) - - nssp_data = clean_nssp_data( - data=raw_nssp_data, - disease=disease, - data_type="eval", - last_data_date=last_eval_date, - ) - - nhsn_data = ( - get_nhsn( - start_date=first_training_date, - end_date=None, - disease=disease, - loc_abb=loc, - credentials_dict=credentials_dict, - local_data_file=nhsn_data_path, - ) - .filter( - pl.col("weekendingdate") >= first_training_date - ) # in testing mode, this isn't guaranteed - .with_columns(data_type=pl.lit("eval")) - ) - - combined_eval_dat = combine_surveillance_data( - nssp_data=nssp_data, - nhsn_data=nhsn_data, - disease=disease, - ) - - combined_eval_dat.write_csv( - Path(output_data_dir, "combined_" + output_file_name), separator="\t" - ) - return None diff --git a/pipelines/tests/test_common_utils.py b/pipelines/tests/test_common_utils.py index 7810e748..9be284d6 100644 --- a/pipelines/tests/test_common_utils.py +++ b/pipelines/tests/test_common_utils.py @@ -16,7 +16,6 @@ get_available_reports, load_credentials, load_nssp_data, - parse_and_validate_report_date, ) @@ -59,24 +58,6 @@ def test_load_credentials_with_invalid_extension_raises_error(self, tmp_path): ), ], ) - def test_parse_and_validate_report_date( - self, - input_date, - available_facility, - available_loc, - expected_report, - expected_loc, - ): - """Test parsing report dates with various inputs.""" - logger = logging.getLogger(__name__) - - report_date, loc_report_date = parse_and_validate_report_date( - input_date, available_facility, available_loc, logger - ) - - assert report_date == expected_report - assert loc_report_date == expected_loc - @pytest.mark.parametrize( "n_training_days,exclude_last_n_days,expected_first,expected_last", [ diff --git a/pipelines/tests/test_end_to_end.sh b/pipelines/tests/test_end_to_end.sh index b31d4e00..665f4e58 100755 --- a/pipelines/tests/test_end_to_end.sh +++ b/pipelines/tests/test_end_to_end.sh @@ -3,7 +3,8 @@ BASE_DIR=pipelines/tests/end_to_end_test_output LOCATIONS=(US 
CA MT DC) DISEASES=(Influenza COVID-19 RSV) - +LOCATIONS=(CA) +DISEASES=(COVID-19) echo "TEST-MODE: Running forecast_pyrenew.py in test mode with base directory $BASE_DIR" if [ -d "$BASE_DIR" ]; then diff --git a/pipelines/tests/test_pyrenew_fit.sh b/pipelines/tests/test_pyrenew_fit.sh index 8deaeabf..7f166d68 100755 --- a/pipelines/tests/test_pyrenew_fit.sh +++ b/pipelines/tests/test_pyrenew_fit.sh @@ -15,7 +15,6 @@ python pipelines/forecast_pyrenew.py \ --disease "$disease" \ --loc "$location" \ --facility-level-nssp-data-dir "$BASE_DIR/private_data/nssp_etl_gold" \ - --state-level-nssp-data-dir "$BASE_DIR/private_data/nssp_state_level_gold" \ --priors-path pipelines/priors/prod_priors.py \ --param-data-dir "$BASE_DIR/private_data/prod_param_estimates" \ --nwss-data-dir "$BASE_DIR/private_data/nwss_vintages" \ @@ -27,7 +26,6 @@ python pipelines/forecast_pyrenew.py \ --rng-key 12345 \ --model-letters "$model_letters" \ --additional-forecast-letters "$model_letters" \ - --eval-data-path "$BASE_DIR/private_data/nssp-etl" \ --nhsn-data-path "$BASE_DIR/private_data/nhsn_test_data/${disease}_${location}.parquet" if [ "$?" -ne 0 ]; then echo "TEST-MODE FAIL: Forecasting/postprocessing pipeline failed" diff --git a/pipelines/tests/test_ts_fit.sh b/pipelines/tests/test_ts_fit.sh index 0993b3dd..1c3ea4cd 100644 --- a/pipelines/tests/test_ts_fit.sh +++ b/pipelines/tests/test_ts_fit.sh @@ -15,13 +15,11 @@ python pipelines/forecast_timeseries.py \ --disease "$disease" \ --loc "$location" \ --facility-level-nssp-data-dir "$BASE_DIR/private_data/nssp_etl_gold" \ - --state-level-nssp-data-dir "$BASE_DIR/private_data/nssp_state_level_gold" \ --param-data-dir "$BASE_DIR/private_data/prod_param_estimates" \ --output-dir "$BASE_DIR/2024-12-21_forecasts" \ --n-training-days 90 \ --n-samples 500 \ --model-letters "$model_letters" \ - --eval-data-path "$BASE_DIR/private_data/nssp-etl" \ --nhsn-data-path "$BASE_DIR/private_data/nhsn_test_data/${disease}_${location}.parquet" if [ "$?" 
-ne 0 ]; then echo "TEST-MODE FAIL: Forecasting/postprocessing pipeline failed" From 16f377a8d7f28ddadc0651746a502d7af8f222a5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Jan 2026 22:18:39 +0000 Subject: [PATCH 02/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pipelines/common_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pipelines/common_utils.py b/pipelines/common_utils.py index fe7cf29a..90f5f1bd 100644 --- a/pipelines/common_utils.py +++ b/pipelines/common_utils.py @@ -7,8 +7,6 @@ from pathlib import Path from typing import Any -import polars as pl - from pipelines.cli_utils import run_command From ec3188556c301ae3d1a510e1f463742127caf6d9 Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Tue, 6 Jan 2026 16:48:32 -0600 Subject: [PATCH 03/15] everything running with some known bugs --- hewr/R/process_loc_forecast.R | 17 ++++++++++------- pipelines/batch/setup_job.py | 2 -- pipelines/cli_utils.py | 7 ------- pipelines/forecast_pyrenew.py | 21 ++++++++++----------- pipelines/forecast_timeseries.py | 3 +-- 5 files changed, 21 insertions(+), 29 deletions(-) diff --git a/hewr/R/process_loc_forecast.R b/hewr/R/process_loc_forecast.R index b40e993a..951520c9 100644 --- a/hewr/R/process_loc_forecast.R +++ b/hewr/R/process_loc_forecast.R @@ -164,7 +164,7 @@ read_and_combine_data <- function(model_dir) { dat <- tidyr::expand_grid( epiweekly = c(FALSE, TRUE), - root = c("combined_training_data", "combined_eval_data") + root = c("combined_data") ) |> dplyr::mutate( prefix = ifelse(.data$epiweekly, "epiweekly_", ""), @@ -543,22 +543,24 @@ process_loc_forecast <- function( fs::path( model_dir, "data", - "combined_training_data", + "combined_data", ext = "tsv" ), col_types = data_col_types - ) + ) |> + dplyr::filter(.data$data_type == "train") # Used for augmenting denominator forecasts with training period denominator 
epiweekly_training_dat <- readr::read_tsv( fs::path( model_dir, "data", - "epiweekly_combined_training_data", + "epiweekly_combined_data", ext = "tsv" ), col_types = data_col_types - ) + ) |> + dplyr::filter(.data$data_type == "train") required_columns_e <- c( ".chain", @@ -678,11 +680,12 @@ process_forecast <- function( fs::path( model_dir, "data", - "epiweekly_combined_training_data", + "epiweekly_combined_data", ext = "tsv" ), col_types = data_col_types - ) + ) |> + dplyr::filter(.data$data_type == "train") required_columns_e <- c( ".chain", diff --git a/pipelines/batch/setup_job.py b/pipelines/batch/setup_job.py index dc78fb6c..65dd7077 100644 --- a/pipelines/batch/setup_job.py +++ b/pipelines/batch/setup_job.py @@ -250,7 +250,6 @@ def main( "--param-data-dir params " "--output-dir {output_dir} " "--credentials-path config/creds.toml " - "--report-date {report_date} " f"--exclude-last-n-days {exclude_last_n_days} " f"--model-letters {model_letters} " f"{additional_args}" @@ -340,7 +339,6 @@ def style_bool(val): base_call=base_call.format( loc=loc, disease=disease, - report_date="latest", output_dir=str(Path("output", output_subdir)), ), container_settings=container_settings, diff --git a/pipelines/cli_utils.py b/pipelines/cli_utils.py index 75b5386e..476efe31 100644 --- a/pipelines/cli_utils.py +++ b/pipelines/cli_utils.py @@ -24,13 +24,6 @@ def add_common_forecast_arguments(parser: argparse.ArgumentParser) -> None: ), ) - parser.add_argument( - "--report-date", - type=str, - default="latest", - help="Report date in YYYY-MM-DD format or latest (default: latest).", - ) - parser.add_argument( "--facility-level-nssp-data-dir", type=Path, diff --git a/pipelines/forecast_pyrenew.py b/pipelines/forecast_pyrenew.py index 69e7c6e9..0054a675 100644 --- a/pipelines/forecast_pyrenew.py +++ b/pipelines/forecast_pyrenew.py @@ -217,18 +217,17 @@ def get_available_nwss_reports( data_dir = Path(model_dir, "data") os.makedirs(data_dir, exist_ok=True) - if fit_ed_visits: - 
timeseries_model_name = "ts_ensemble_e" + timeseries_model_name = "ts_ensemble_e" if fit_ed_visits else None - if not os.path.exists(Path(model_run_dir, timeseries_model_name)): - raise ValueError( - f"{timeseries_model_name} model run not found. " - "Please ensure that the timeseries forecasts " - "for the ED visits (E) signal are generated " - "before fitting Pyrenew models with the E signal. " - "If running a batch job, set the flag --model-family " - "'timeseries' to fit timeseries model." - ) + if fit_ed_visits and not os.path.exists(Path(model_run_dir, timeseries_model_name)): + raise ValueError( + f"{timeseries_model_name} model run not found. " + "Please ensure that the timeseries forecasts " + "for the ED visits (E) signal are generated " + "before fitting Pyrenew models with the E signal. " + "If running a batch job, set the flag --model-family " + "'timeseries' to fit timeseries model." + ) logger.info("Recording git info...") record_git_info(model_dir) diff --git a/pipelines/forecast_timeseries.py b/pipelines/forecast_timeseries.py index 8f9a40a9..43f69796 100644 --- a/pipelines/forecast_timeseries.py +++ b/pipelines/forecast_timeseries.py @@ -44,7 +44,6 @@ def timeseries_ensemble_forecasts( def main( disease: str, - report_date: dt.date, param_data_dir: Path, loc: str, facility_level_nssp_data_dir: Path | str, @@ -69,7 +68,7 @@ def main( logger.info( "Starting single-location timeseries forecasting pipeline for " - f"location {loc}, and report date {report_date}" + f"location {loc}, and latest report date." 
) credentials_dict = load_credentials(credentials_path, logger) From 6210446e52c3ab177c0d38c5053dd28e555ba387 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Jan 2026 22:48:52 +0000 Subject: [PATCH 04/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pipelines/forecast_timeseries.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pipelines/forecast_timeseries.py b/pipelines/forecast_timeseries.py index 43f69796..86ec5c27 100644 --- a/pipelines/forecast_timeseries.py +++ b/pipelines/forecast_timeseries.py @@ -1,5 +1,4 @@ import argparse -import datetime as dt import logging import os from pathlib import Path From e564a268cdd6d18df062f568bfe7588372f6599d Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Tue, 6 Jan 2026 17:00:37 -0600 Subject: [PATCH 05/15] fix python tests --- pipelines/tests/test_common_utils.py | 54 ---------------------------- 1 file changed, 54 deletions(-) diff --git a/pipelines/tests/test_common_utils.py b/pipelines/tests/test_common_utils.py index 9be284d6..5af897bf 100644 --- a/pipelines/tests/test_common_utils.py +++ b/pipelines/tests/test_common_utils.py @@ -15,7 +15,6 @@ calculate_training_dates, get_available_reports, load_credentials, - load_nssp_data, ) @@ -39,25 +38,6 @@ def test_load_credentials_with_invalid_extension_raises_error(self, tmp_path): with pytest.raises(ValueError, match="must have the extension '.toml'"): load_credentials(invalid_file, logger) - @pytest.mark.parametrize( - "input_date,available_facility,available_loc,expected_report,expected_loc", - [ - ( - "latest", - [dt.date(2024, 12, 18), dt.date(2024, 12, 19), dt.date(2024, 12, 20)], - [dt.date(2024, 12, 18), dt.date(2024, 12, 19)], - dt.date(2024, 12, 20), - dt.date(2024, 12, 19), - ), - ( - "2024-12-20", - [dt.date(2024, 12, 15), dt.date(2024, 12, 20)], - [dt.date(2024, 12, 15), dt.date(2024, 12, 20)], - dt.date(2024, 12, 20), - 
dt.date(2024, 12, 20), - ), - ], - ) @pytest.mark.parametrize( "n_training_days,exclude_last_n_days,expected_first,expected_last", [ @@ -97,40 +77,6 @@ def test_get_available_reports_with_parquet_files(self, tmp_path): assert dt.date(2024, 12, 15) in result assert dt.date(2024, 12, 20) in result - def test_load_nssp_data_both_available(self, tmp_path): - """Test loading NSSP data when both facility and location data available.""" - facility_dir = tmp_path / "facility" - loc_dir = tmp_path / "location" - facility_dir.mkdir() - loc_dir.mkdir() - - # Create minimal valid parquet files - facility_file = facility_dir / "2024-12-20.parquet" - loc_file = loc_dir / "2024-12-20.parquet" - - df = pl.DataFrame({"col1": [1, 2, 3]}) - df.write_parquet(facility_file) - df.write_parquet(loc_file) - - logger = logging.getLogger(__name__) - report_date = dt.date(2024, 12, 20) - available = [report_date] - - facility_data, loc_data = load_nssp_data( - report_date, - report_date, - available, - available, - facility_dir, - loc_dir, - logger, - ) - - assert facility_data is not None - assert loc_data is not None - assert isinstance(facility_data, pl.LazyFrame) - assert isinstance(loc_data, pl.LazyFrame) - class TestCLIUtils: """Tests for CLI argument parsing utilities.""" From 37b8d81e8f6d9631f9eed2110ebd249bd250df77 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Jan 2026 23:01:17 +0000 Subject: [PATCH 06/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pipelines/tests/test_common_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pipelines/tests/test_common_utils.py b/pipelines/tests/test_common_utils.py index 5af897bf..0e2e9d6b 100644 --- a/pipelines/tests/test_common_utils.py +++ b/pipelines/tests/test_common_utils.py @@ -4,7 +4,6 @@ import datetime as dt import logging -import polars as pl import pytest from pipelines.cli_utils import ( From 
d0c06f5905e03898e3a3604a5ed6eddc0ddd97a8 Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Tue, 6 Jan 2026 17:26:38 -0600 Subject: [PATCH 07/15] tests passing --- pipelines/tests/test_prep_data.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/pipelines/tests/test_prep_data.py b/pipelines/tests/test_prep_data.py index 654d60bb..d2fa20c3 100644 --- a/pipelines/tests/test_prep_data.py +++ b/pipelines/tests/test_prep_data.py @@ -36,12 +36,11 @@ def test_get_loc_pop_df(): ], ) @pytest.mark.parametrize("disease", valid_diseases + ["Iffluenza", "COVID_19"]) -@pytest.mark.parametrize("data_type", ["train", "eval", "other"]) @pytest.mark.parametrize( "last_data_date", [date(2025, 12, 12), date(2024, 12, 1), date(2025, 1, 2), date(2024, 12, 29)], ) -def test_clean_nssp_data(pivoted_raw_data, disease, data_type, last_data_date): +def test_clean_nssp_data(pivoted_raw_data, disease, last_data_date): """ Confirm that clean_nssp_data works as expected. 
""" @@ -49,17 +48,15 @@ def test_clean_nssp_data(pivoted_raw_data, disease, data_type, last_data_date): index="date", variable_name="disease", value_name="ed_visits" ) invalid_disease = disease not in valid_diseases - no_data_after_last_requested = ( - last_data_date is not None and last_data_date < pivoted_raw_data["date"].min() - ) - expect_empty_df = invalid_disease or no_data_after_last_requested + + expect_empty_df = invalid_disease if expect_empty_df: context = pytest.raises(pl.exceptions.ColumnNotFoundError, match=disease) else: context = nullcontext() with context: - result = prep_data.clean_nssp_data(raw_data, disease, data_type, last_data_date) + result = prep_data.clean_nssp_data(raw_data, disease, last_data_date) if not expect_empty_df: expected = ( pivoted_raw_data.select( @@ -69,11 +66,26 @@ def test_clean_nssp_data(pivoted_raw_data, disease, data_type, last_data_date): ) .with_columns( other_ed_visits=pl.col("Total") - pl.col("observed_ed_visits"), - data_type=pl.lit(data_type), + data_type=pl.when(pl.col("date") <= last_data_date) + .then(pl.lit("train")) + .otherwise(pl.lit("eval")), ) .drop("Total") .sort("date") - ).filter(pl.col("date") <= last_data_date) + ) assert result.select( ["date", "observed_ed_visits", "other_ed_visits", "data_type"] ).equals(expected) + + +pivoted_raw_data = pl.DataFrame( + { + "COVID-19": [10, 15, 20], + "Influenza": [12, 16, 22], + "RSV": [0, 2, 0], + "Total": [497, 502, 499], + "date": [date(2024, 12, 29), date(2025, 1, 1), date(2025, 1, 3)], + } +) +last_data_date = date(2024, 12, 1) +disease = "COVID-19" From a9a18c8de9406ac85f894fde2d8cc1bbd085b073 Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Tue, 6 Jan 2026 17:30:34 -0600 Subject: [PATCH 08/15] missed a test --- pipelines/tests/test_common_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pipelines/tests/test_common_utils.py b/pipelines/tests/test_common_utils.py index 0e2e9d6b..bde9b153 100644 --- a/pipelines/tests/test_common_utils.py +++ 
b/pipelines/tests/test_common_utils.py @@ -98,7 +98,6 @@ def test_add_common_forecast_arguments_smoke_test(self): assert args.disease == "COVID-19" assert args.loc == "CA" - assert args.report_date == "latest" # default value assert args.n_training_days == 180 # default value assert args.n_forecast_days == 28 # default value assert args.exclude_last_n_days == 0 # default value From 330dbcfd57e635320bded75b778dc27ea4e61f0d Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Tue, 6 Jan 2026 17:33:35 -0600 Subject: [PATCH 09/15] update hewr docs --- hewr/man/load_training_data.Rd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hewr/man/load_training_data.Rd b/hewr/man/load_training_data.Rd index f9acd86a..fc105c7e 100644 --- a/hewr/man/load_training_data.Rd +++ b/hewr/man/load_training_data.Rd @@ -8,7 +8,7 @@ Load Training Data for Timeseries Forecasting} \usage{ load_training_data( model_dir, - base_data_name = "combined_training_data", + base_data_name = "combined_data", epiweekly = FALSE ) } From 23a14fcd6b251579a099e387cdf90a5a144bde1e Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Thu, 8 Jan 2026 15:23:29 -0600 Subject: [PATCH 10/15] correct treatment of nwss train/eval data --- pipelines/prep_data.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py index a9cdd003..05891616 100644 --- a/pipelines/prep_data.py +++ b/pipelines/prep_data.py @@ -539,12 +539,17 @@ def process_and_save_loc_data( nhsn_step_size = 7 if loc_level_nwss_data is not None: - nwss_training_data = loc_level_nwss_data.filter( - pl.col("date") <= last_training_date + nwss_full_data = loc_level_nwss_data.with_columns( + data_type=pl.when(pl.col("date") <= last_training_date) + .then(pl.lit("train")) + .otherwise(pl.lit("eval")), ) - nwss_training_data_dict = nwss_training_data.to_dict(as_series=False) + nwss_training_data_dict = nwss_full_data.filter( + pl.col("date") <= last_training_date + 
).to_dict(as_series=False) + else: - nwss_training_data = None + nwss_full_data = None nwss_training_data_dict = None data_for_model_fit = { @@ -564,7 +569,7 @@ def process_and_save_loc_data( combined_data = combine_surveillance_data( nssp_data=nssp_full_data, nhsn_data=nhsn_full_data, - nwss_data=loc_level_nwss_data, + nwss_data=nwss_full_data, disease=disease, ) From 26e5b602df84e5b18fd2cf26418269c6a5e642ef Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Thu, 8 Jan 2026 15:26:07 -0600 Subject: [PATCH 11/15] Discard changes to pipelines/tests/test_end_to_end.sh --- pipelines/tests/test_end_to_end.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pipelines/tests/test_end_to_end.sh b/pipelines/tests/test_end_to_end.sh index 665f4e58..b31d4e00 100755 --- a/pipelines/tests/test_end_to_end.sh +++ b/pipelines/tests/test_end_to_end.sh @@ -3,8 +3,7 @@ BASE_DIR=pipelines/tests/end_to_end_test_output LOCATIONS=(US CA MT DC) DISEASES=(Influenza COVID-19 RSV) -LOCATIONS=(CA) -DISEASES=(COVID-19) + echo "TEST-MODE: Running forecast_pyrenew.py in test mode with base directory $BASE_DIR" if [ -d "$BASE_DIR" ]; then From e8e9e317ed8933f171338fe1ee990311739d631e Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Thu, 8 Jan 2026 16:42:10 -0600 Subject: [PATCH 12/15] fix issue where training data spans partial epiweeks --- pipelines/generate_epiweekly_data.R | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/pipelines/generate_epiweekly_data.R b/pipelines/generate_epiweekly_data.R index 72d64018..a67eea6b 100644 --- a/pipelines/generate_epiweekly_data.R +++ b/pipelines/generate_epiweekly_data.R @@ -26,16 +26,13 @@ purrr::walk(script_packages, \(pkg) { #' #' @param strict A logical value indicating whether to enforce strict inclusion #' of only full epiweeks. Default is TRUE. -#' @param day_of_week An integer specifying the day of the week to use for the -#' epiweek date. Default is 1 (Monday). #' #' @return None. 
The function writes the epiweekly data to a CSV file in the #' specified directory. convert_daily_to_epiweekly <- function( data_dir, data_name, - strict = TRUE, - day_of_week = 7 + strict = TRUE ) { data_path <- path(data_dir, data_name) @@ -57,15 +54,22 @@ convert_daily_to_epiweekly <- function( epiweekly_hosp_data <- daily_data |> filter(.variable == "observed_hospital_admissions") + grouping_cols <- c("geo_value", "disease", "data_type", ".variable") + epiweekly_ed_data <- daily_ed_data |> + group_by( + dplyr::across(dplyr::all_of(grouping_cols)), + epiyear = epiyear(date), + epiweek = epiweek(date) + ) |> + mutate(data_type = if_else(all(data_type == "train"), "train", "eval")) |> forecasttools::daily_to_epiweekly( value_col = ".value", weekly_value_name = ".value", - id_cols = c("geo_value", "disease", "data_type", ".variable"), - strict = strict - ) |> - mutate( - date = epiweek_to_date(epiweek, epiyear, day_of_week = day_of_week) + id_cols = c(grouping_cols, "data_type"), + strict = strict, + with_epiweek_end_date = TRUE, + epiweek_end_date_name = "date" ) |> select(date, geo_value, disease, data_type, .variable, .value) From 76708781aa73fea9ab1908eaaded1bae98019cfb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Jan 2026 22:58:41 +0000 Subject: [PATCH 13/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pipelines/common_utils.py | 2 ++ pipelines/tests/test_common_utils.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pipelines/common_utils.py b/pipelines/common_utils.py index b9cf1222..94eb4e1d 100644 --- a/pipelines/common_utils.py +++ b/pipelines/common_utils.py @@ -41,6 +41,7 @@ def get_available_reports( for f in Path(data_dir).glob(glob_pattern) ] + def _parse_single_date(date_str: str) -> tuple[dt.date, dt.date]: """ Parse a single date string into a date range tuple. 
@@ -143,6 +144,7 @@ def parse_exclude_date_ranges( return parsed_ranges + def calculate_training_dates( report_date: dt.date, n_training_days: int, diff --git a/pipelines/tests/test_common_utils.py b/pipelines/tests/test_common_utils.py index 14000646..e10715f6 100644 --- a/pipelines/tests/test_common_utils.py +++ b/pipelines/tests/test_common_utils.py @@ -4,7 +4,6 @@ import datetime as dt import logging -import polars as pl import pytest from pipelines.cli_utils import ( From f343848cbc95ea5defa68e635b759cdebdca3ddd Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Thu, 8 Jan 2026 17:16:30 -0600 Subject: [PATCH 14/15] small simplification --- hewr/R/process_loc_forecast.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hewr/R/process_loc_forecast.R b/hewr/R/process_loc_forecast.R index 951520c9..9a684e1e 100644 --- a/hewr/R/process_loc_forecast.R +++ b/hewr/R/process_loc_forecast.R @@ -162,9 +162,9 @@ read_and_combine_data <- function(model_dir) { ) dat <- - tidyr::expand_grid( + tibble::tibble( epiweekly = c(FALSE, TRUE), - root = c("combined_data") + root = "combined_data", ) |> dplyr::mutate( prefix = ifelse(.data$epiweekly, "epiweekly_", ""), From 569ebffa7c6af7ca319f2c474c08ad638b53a727 Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Wed, 14 Jan 2026 15:03:06 -0600 Subject: [PATCH 15/15] remove stray code --- pipelines/tests/test_prep_data.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/pipelines/tests/test_prep_data.py b/pipelines/tests/test_prep_data.py index d2fa20c3..37b25353 100644 --- a/pipelines/tests/test_prep_data.py +++ b/pipelines/tests/test_prep_data.py @@ -76,16 +76,3 @@ def test_clean_nssp_data(pivoted_raw_data, disease, last_data_date): assert result.select( ["date", "observed_ed_visits", "other_ed_visits", "data_type"] ).equals(expected) - - -pivoted_raw_data = pl.DataFrame( - { - "COVID-19": [10, 15, 20], - "Influenza": [12, 16, 22], - "RSV": [0, 2, 0], - "Total": [497, 502, 499], - "date": [date(2024, 
12, 29), date(2025, 1, 1), date(2025, 1, 3)], - } -) -last_data_date = date(2024, 12, 1) -disease = "COVID-19"