Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 52 additions & 11 deletions pipelines/prep_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import os
import tempfile
from datetime import datetime
from datetime import date, datetime
from logging import Logger
from pathlib import Path

Expand All @@ -22,6 +22,52 @@
_inverse_disease_map = {v: k for k, v in _disease_map.items()}


def clean_nssp_data(
    data: pl.DataFrame,
    disease: str,
    data_type: str,
    last_data_date: date | None = None,
) -> pl.DataFrame:
    """
    Filter, reformat, and annotate a raw `pl.DataFrame` of NSSP data,
    yielding a `pl.DataFrame` in the format expected by
    `combine_surveillance_data`.

    Parameters
    ----------
    data
        Long-format data to clean. Must contain the columns
        ``"date"``, ``"disease"`` and ``"ed_visits"``, and must
        include rows whose ``"disease"`` value is ``"Total"``.

    disease
        Name of the disease for which to prep data. Must match
        a value of the ``"disease"`` column exactly.

    data_type
        Value for the data_type annotation column in the
        output dataframe.

    last_data_date
        If provided, filter the dataset to include only dates
        on or before this date (the bound is inclusive).
        Default `None` (no filter).

    Returns
    -------
    pl.DataFrame
        One row per remaining combination of the non-pivoted
        columns (e.g. ``"date"``), sorted by ``"date"``, with:

        - ``"observed_ed_visits"``: ED visits for `disease`
        - ``"other_ed_visits"``: ``"Total"`` visits minus
          ``"observed_ed_visits"``
        - ``"data_type"``: the literal `data_type` value

        The ``"Total"`` column itself is dropped.

    Raises
    ------
    pl.exceptions.ColumnNotFoundError
        If no rows for `disease` remain after filtering (for
        example a misspelled disease name, or no data on or
        before `last_data_date`), since the pivot then yields
        no column named `disease` to rename.
    """
    if last_data_date is not None:
        # Inclusive upper bound: rows dated exactly last_data_date are kept.
        data = data.filter(pl.col("date") <= last_data_date)

    return (
        # Keep only the target disease and the all-cause total, which is
        # needed to derive the "other" visit count below.
        data.filter(pl.col("disease").is_in([disease, "Total"]))
        # No explicit `index`: every column other than "disease" and
        # "ed_visits" is used to identify output rows.
        .pivot(
            on="disease",
            values="ed_visits",
        )
        .rename({disease: "observed_ed_visits"})
        .with_columns(
            other_ed_visits=pl.col("Total") - pl.col("observed_ed_visits"),
            data_type=pl.lit(data_type),
        )
        .drop("Total")
        .sort("date")
    )


def get_nhsn(
start_date: datetime.date,
end_date: datetime.date,
Expand Down Expand Up @@ -493,16 +539,11 @@ def process_and_save_loc_data(
pl.col("date") < first_facility_level_data_date
)

nssp_training_data = (
pl.concat([loc_level_data, aggregated_facility_data])
.filter(pl.col("date") <= last_training_date)
.with_columns(pl.lit("train").alias("data_type"))
.pivot(
on="disease",
values="ed_visits",
)
.rename({disease: "observed_ed_visits", "Total": "other_ed_visits"})
.sort("date")
nssp_training_data = clean_nssp_data(
data=pl.concat([loc_level_data, aggregated_facility_data]),
disease=disease,
data_type="train",
last_data_date=last_training_date,
)

nhsn_training_data = (
Expand Down
32 changes: 13 additions & 19 deletions pipelines/prep_eval_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import polars as pl

from pipelines.prep_data import (
clean_nssp_data,
combine_surveillance_data,
get_loc_pop_df,
get_nhsn,
Expand All @@ -30,26 +31,19 @@ def save_eval_data(
logger.info("Reading in truth data...")
loc_level_nssp_data = pl.scan_parquet(latest_comprehensive_path)

if last_eval_date is not None:
loc_level_nssp_data = loc_level_nssp_data.filter(
pl.col("reference_date") <= last_eval_date
)
raw_nssp_data = process_loc_level_data(
loc_level_nssp_data=loc_level_nssp_data,
loc_abb=loc,
disease=disease,
first_training_date=first_training_date,
loc_pop_df=get_loc_pop_df(),
)

nssp_data = (
process_loc_level_data(
loc_level_nssp_data=loc_level_nssp_data,
loc_abb=loc,
disease=disease,
first_training_date=first_training_date,
loc_pop_df=get_loc_pop_df(),
)
.with_columns(data_type=pl.lit("eval"))
.pivot(
on="disease",
values="ed_visits",
)
.rename({disease: "observed_ed_visits", "Total": "other_ed_visits"})
.sort("date")
nssp_data = clean_nssp_data(
data=raw_nssp_data,
disease=disease,
data_type="eval",
last_data_date=last_eval_date,
)

nhsn_data = (
Expand Down
66 changes: 66 additions & 0 deletions pipelines/tests/test_prep_data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from contextlib import nullcontext
from datetime import date

import polars as pl
import pytest

from pipelines import prep_data

valid_diseases = ["COVID-19", "Influenza", "RSV"]


def test_get_loc_pop_df():
"""
Expand All @@ -11,3 +19,61 @@ def test_get_loc_pop_df():
df = prep_data.get_loc_pop_df()
assert df.height == 58 # 50 US states, 7 other jurisdictions, US national
assert set(df.columns) == set(["name", "abb", "population"])


@pytest.mark.parametrize(
    "pivoted_raw_data",
    [
        pl.DataFrame(
            {
                "COVID-19": [10, 15, 20],
                "Influenza": [12, 16, 22],
                "RSV": [0, 2, 0],
                "Total": [497, 502, 499],
                "date": [date(2024, 12, 29), date(2025, 1, 1), date(2025, 1, 3)],
            }
        )
    ],
)
@pytest.mark.parametrize("disease", valid_diseases + ["Iffluenza", "COVID_19"])
@pytest.mark.parametrize("data_type", ["train", "eval", "other"])
@pytest.mark.parametrize(
    "last_data_date",
    [
        # None exercises the default "no date filter" path.
        None,
        # After all data, before all data, between two data dates,
        # and exactly on the first data date (inclusive boundary).
        date(2025, 12, 12),
        date(2024, 12, 1),
        date(2025, 1, 2),
        date(2024, 12, 29),
    ],
)
def test_clean_nssp_data(pivoted_raw_data, disease, data_type, last_data_date):
    """
    Confirm that clean_nssp_data works as expected.

    For a valid disease with at least one row on or before
    ``last_data_date`` (or with no date filter at all), the output
    must match a reference frame built directly from the wide-format
    fixture. Otherwise the call is expected to raise
    ``pl.exceptions.ColumnNotFoundError`` mentioning the disease name.
    """
    # clean_nssp_data expects long-format input, so melt the wide fixture.
    raw_data = pivoted_raw_data.unpivot(
        index="date", variable_name="disease", value_name="ed_visits"
    )
    invalid_disease = disease not in valid_diseases
    no_data_on_or_before_last_requested = (
        last_data_date is not None
        and last_data_date < pivoted_raw_data["date"].min()
    )
    # In both failure modes the pivot produces no column named `disease`,
    # so the subsequent rename is expected to fail.
    expect_error = invalid_disease or no_data_on_or_before_last_requested

    if expect_error:
        context = pytest.raises(pl.exceptions.ColumnNotFoundError, match=disease)
    else:
        context = nullcontext()
    with context:
        result = prep_data.clean_nssp_data(
            raw_data, disease, data_type, last_data_date
        )
    if not expect_error:
        expected = (
            pivoted_raw_data.select(
                pl.col("date"),
                pl.col(disease).alias("observed_ed_visits"),
                pl.col("Total"),
            )
            .with_columns(
                other_ed_visits=pl.col("Total") - pl.col("observed_ed_visits"),
                data_type=pl.lit(data_type),
            )
            .drop("Total")
            .sort("date")
        )
        if last_data_date is not None:
            # Mirror the function's inclusive upper bound; with None the
            # full reference frame is expected back.
            expected = expected.filter(pl.col("date") <= last_data_date)
        assert result.select(
            ["date", "observed_ed_visits", "other_ed_visits", "data_type"]
        ).equals(expected)
Loading