diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py
index 46e64f66..a8922b48 100644
--- a/pipelines/prep_data.py
+++ b/pipelines/prep_data.py
@@ -3,7 +3,7 @@
 import logging
 import os
 import tempfile
-from datetime import datetime
+from datetime import date, datetime
 from logging import Logger
 from pathlib import Path
 
@@ -22,6 +22,52 @@
 _inverse_disease_map = {v: k for k, v in _disease_map.items()}
 
 
+def clean_nssp_data(
+    data: pl.DataFrame,
+    disease: str,
+    data_type: str,
+    last_data_date: date | None = None,
+) -> pl.DataFrame:
+    """
+    Filter, reformat, and annotate a raw `pl.DataFrame` of NSSP data,
+    yielding a `pl.DataFrame` in the format expected by
+    `combine_surveillance_data`.
+
+    Parameters
+    ----------
+    data
+        Data to clean.
+
+    disease
+        Name of the disease for which to prep data.
+
+    data_type
+        Value for the `data_type` annotation column in the
+        output dataframe.
+
+    last_data_date
+        If provided, filter the dataset to include only dates
+        on or before this date. Default `None` (no filter).
+    """
+    if last_data_date is not None:
+        data = data.filter(pl.col("date") <= last_data_date)
+
+    return (
+        data.filter(pl.col("disease").is_in([disease, "Total"]))
+        .pivot(
+            on="disease",
+            values="ed_visits",
+        )
+        .rename({disease: "observed_ed_visits"})
+        .with_columns(
+            other_ed_visits=pl.col("Total") - pl.col("observed_ed_visits"),
+            data_type=pl.lit(data_type),
+        )
+        .drop(pl.col("Total"))
+        .sort("date")
+    )
+
+
 def get_nhsn(
     start_date: datetime.date,
     end_date: datetime.date,
@@ -493,16 +539,11 @@ def process_and_save_loc_data(
         pl.col("date") < first_facility_level_data_date
     )
 
-    nssp_training_data = (
-        pl.concat([loc_level_data, aggregated_facility_data])
-        .filter(pl.col("date") <= last_training_date)
-        .with_columns(pl.lit("train").alias("data_type"))
-        .pivot(
-            on="disease",
-            values="ed_visits",
-        )
-        .rename({disease: "observed_ed_visits", "Total": "other_ed_visits"})
-        .sort("date")
+    nssp_training_data = clean_nssp_data(
+        data=pl.concat([loc_level_data, aggregated_facility_data]),
+        disease=disease,
+        data_type="train",
+        last_data_date=last_training_date,
     )
 
     nhsn_training_data = (
diff --git a/pipelines/prep_eval_data.py b/pipelines/prep_eval_data.py
index 697b37d5..8c24b73f 100644
--- a/pipelines/prep_eval_data.py
+++ b/pipelines/prep_eval_data.py
@@ -5,6 +5,7 @@
 import polars as pl
 
 from pipelines.prep_data import (
+    clean_nssp_data,
     combine_surveillance_data,
     get_loc_pop_df,
     get_nhsn,
@@ -30,26 +31,19 @@ def save_eval_data(
     logger.info("Reading in truth data...")
     loc_level_nssp_data = pl.scan_parquet(latest_comprehensive_path)
 
-    if last_eval_date is not None:
-        loc_level_nssp_data = loc_level_nssp_data.filter(
-            pl.col("reference_date") <= last_eval_date
-        )
+    raw_nssp_data = process_loc_level_data(
+        loc_level_nssp_data=loc_level_nssp_data,
+        loc_abb=loc,
+        disease=disease,
+        first_training_date=first_training_date,
+        loc_pop_df=get_loc_pop_df(),
+    )
 
-    nssp_data = (
-        process_loc_level_data(
-            loc_level_nssp_data=loc_level_nssp_data,
-            loc_abb=loc,
-            disease=disease,
-            first_training_date=first_training_date,
-            loc_pop_df=get_loc_pop_df(),
-        )
-        .with_columns(data_type=pl.lit("eval"))
-        .pivot(
-            on="disease",
-            values="ed_visits",
-        )
-        .rename({disease: "observed_ed_visits", "Total": "other_ed_visits"})
-        .sort("date")
+    nssp_data = clean_nssp_data(
+        data=raw_nssp_data,
+        disease=disease,
+        data_type="eval",
+        last_data_date=last_eval_date,
     )
 
     nhsn_data = (
diff --git a/pipelines/tests/test_prep_data.py b/pipelines/tests/test_prep_data.py
index 550ac2cb..654d60bb 100644
--- a/pipelines/tests/test_prep_data.py
+++ b/pipelines/tests/test_prep_data.py
@@ -1,5 +1,13 @@
+from contextlib import nullcontext
+from datetime import date
+
+import polars as pl
+import pytest
+
 from pipelines import prep_data
 
+valid_diseases = ["COVID-19", "Influenza", "RSV"]
+
 
 def test_get_loc_pop_df():
     """
@@ -9,5 +17,63 @@
     and expected column names
     """
     df = prep_data.get_loc_pop_df()
-    assert df.height == 58  # 50 US states, 7 other jursidictions, US national
+    assert df.height == 58  # 50 US states, 7 other jurisdictions, US national
     assert set(df.columns) == set(["name", "abb", "population"])
+
+
+@pytest.mark.parametrize(
+    "pivoted_raw_data",
+    [
+        pl.DataFrame(
+            {
+                "COVID-19": [10, 15, 20],
+                "Influenza": [12, 16, 22],
+                "RSV": [0, 2, 0],
+                "Total": [497, 502, 499],
+                "date": [date(2024, 12, 29), date(2025, 1, 1), date(2025, 1, 3)],
+            }
+        )
+    ],
+)
+@pytest.mark.parametrize("disease", valid_diseases + ["Iffluenza", "COVID_19"])
+@pytest.mark.parametrize("data_type", ["train", "eval", "other"])
+@pytest.mark.parametrize(
+    "last_data_date",
+    [date(2025, 12, 12), date(2024, 12, 1), date(2025, 1, 2), date(2024, 12, 29)],
+)
+def test_clean_nssp_data(pivoted_raw_data, disease, data_type, last_data_date):
+    """
+    Confirm that clean_nssp_data works as expected.
+    """
+    raw_data = pivoted_raw_data.unpivot(
+        index="date", variable_name="disease", value_name="ed_visits"
+    )
+    invalid_disease = disease not in valid_diseases
+    no_data_after_last_requested = (
+        last_data_date is not None and last_data_date < pivoted_raw_data["date"].min()
+    )
+    expect_empty_df = invalid_disease or no_data_after_last_requested
+
+    if expect_empty_df:
+        context = pytest.raises(pl.exceptions.ColumnNotFoundError, match=disease)
+    else:
+        context = nullcontext()
+    with context:
+        result = prep_data.clean_nssp_data(raw_data, disease, data_type, last_data_date)
+    if not expect_empty_df:
+        expected = (
+            pivoted_raw_data.select(
+                pl.col("date"),
+                pl.col(disease).alias("observed_ed_visits"),
+                pl.col("Total"),
+            )
+            .with_columns(
+                other_ed_visits=pl.col("Total") - pl.col("observed_ed_visits"),
+                data_type=pl.lit(data_type),
+            )
+            .drop("Total")
+            .sort("date")
+        ).filter(pl.col("date") <= last_data_date)
+        assert result.select(
+            ["date", "observed_ed_visits", "other_ed_visits", "data_type"]
+        ).equals(expected)
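
As a quick illustration of the new clean_nssp_data helper, here is a minimal usage sketch showing the pivot from long format (one row per date and disease, plus a "Total" row) to the wide format with a derived other_ed_visits column. The toy frame and its values below are invented purely for illustration; they are not real NSSP data.

# Illustrative sketch only; the toy values below are made up.
from datetime import date

import polars as pl

from pipelines.prep_data import clean_nssp_data

raw = pl.DataFrame(
    {
        "date": [date(2025, 1, 1)] * 2 + [date(2025, 1, 2)] * 2,
        "disease": ["COVID-19", "Total", "COVID-19", "Total"],
        "ed_visits": [10, 500, 15, 510],
    }
)

cleaned = clean_nssp_data(
    data=raw,
    disease="COVID-19",
    data_type="train",
    last_data_date=date(2025, 1, 1),  # keeps only dates on or before Jan 1
)
# Expected result: a single row with columns
# date, observed_ed_visits (10), other_ed_visits (500 - 10 = 490), data_type ("train")
print(cleaned)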