Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
2972159
Merge branch 'main' of github-bf06:CDCgov/pyrenew-hew
cdc-mitzimorris Sep 25, 2025
692ede3
Merge branch 'main' of github-bf06:CDCgov/pyrenew-hew
cdc-mitzimorris Sep 26, 2025
c341d63
simplify pyrenew_hew_data, update unit tests
cdc-mitzimorris Sep 29, 2025
0994a08
checkpointing
cdc-mitzimorris Sep 29, 2025
a04df6f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2025
ae65cbf
Merge branch 'mem_refactor_datetime_indexing' of github-bf06:CDCgov/p…
cdc-mitzimorris Sep 30, 2025
adebd0c
Refactor time handling to use PyRenew utilities
cdc-mitzimorris Oct 1, 2025
8cf26f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 1, 2025
4cd149e
Apply comprehensive naming changes from naming guide
cdc-mitzimorris Oct 2, 2025
d8076f1
merge fix
cdc-mitzimorris Oct 2, 2025
6054010
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 2, 2025
8ac35ae
Revert pyrenew_hew_data.py and pyrenew_hew_model.py to commit 8cf26f1
cdc-mitzimorris Oct 3, 2025
d6f2883
merge fix
cdc-mitzimorris Oct 3, 2025
4362190
Revert test_pyrenew_hew_data.py to commit 8cf26f1
cdc-mitzimorris Oct 3, 2025
8218ded
Merge branch 'main' of github-bf06:CDCgov/pyrenew-hew into mem_refact…
cdc-mitzimorris Oct 3, 2025
ae2d05f
not changing names (yet)
cdc-mitzimorris Oct 3, 2025
bd0f9dd
use convert_date everywhere, more time tests
cdc-mitzimorris Oct 8, 2025
12027c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2025
2919c5f
Merge branch 'main' of github-bf06:CDCgov/pyrenew-hew into mem_refact…
cdc-mitzimorris Oct 8, 2025
36d29a2
Merge branch 'mem_refactor_datetime_indexing' of github-bf06:CDCgov/p…
cdc-mitzimorris Oct 8, 2025
3b79b3b
update uv.lock
cdc-mitzimorris Oct 8, 2025
63a2b3d
Merge branch 'main' of github-bf06:CDCgov/pyrenew-hew into mem_refact…
cdc-mitzimorris Oct 8, 2025
ed259b0
changes per copilot code review
cdc-mitzimorris Oct 8, 2025
f6a992a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2025
6dbb3db
Revert "changes per copilot code review"
cdc-mitzimorris Nov 17, 2025
cd55ac2
logic fix for MMWR data handling, test
cdc-mitzimorris Nov 17, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 53 additions & 102 deletions pyrenew_hew/pyrenew_hew_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import datetime as dt
import json
from pathlib import Path
from typing import Self
Expand All @@ -7,6 +6,14 @@
import numpy as np
import polars as pl
from jax.typing import ArrayLike
from pyrenew.time import (
align_observation_times,
convert_date,
create_date_time_spine,
get_end_date,
get_n_data_days,
validate_mmwr_dates,
)


class PyrenewHEWData:
Expand Down Expand Up @@ -48,15 +55,7 @@ def __init__(
self.right_truncation_offset = right_truncation_offset
self.population_size = population_size

if (
first_hospital_admissions_date is not None
and not first_hospital_admissions_date.astype(dt.datetime).weekday() == 5
):
raise ValueError(
"Dates for hospital admissions timeseries must "
"be Saturdays (MMWR epiweek end "
"days)."
)
validate_mmwr_dates([first_hospital_admissions_date])

self.first_ed_visits_date_ = first_ed_visits_date
self.first_hospital_admissions_date_ = first_hospital_admissions_date
Expand Down Expand Up @@ -146,23 +145,23 @@ def from_json(

@property
def n_ed_visits_data_days(self):
return self.get_n_data_days(
n_datapoints=self.n_ed_visits_data_days_,
return get_n_data_days(
n_points=self.n_ed_visits_data_days_,
date_array=self.dates_observed_ed_visits,
)

@property
def n_hospital_admissions_data_days(self):
return self.get_n_data_days(
n_datapoints=self.n_hospital_admissions_data_days_,
return get_n_data_days(
n_points=self.n_hospital_admissions_data_days_,
date_array=self.dates_observed_hospital_admissions,
timestep_days=7,
)

@property
def n_wastewater_data_days(self):
return self.get_n_data_days(
n_datapoints=self.n_wastewater_data_days_,
return get_n_data_days(
n_points=self.n_wastewater_data_days_,
date_array=self.dates_observed_disease_wastewater,
)

Expand Down Expand Up @@ -197,29 +196,29 @@ def first_ed_visits_date(self):

@property
def first_hospital_admissions_date(self):
if self.data_observed_disease_hospital_admissions is not None:
if self.dates_observed_hospital_admissions is not None:
return self.dates_observed_hospital_admissions.min()
return self.first_hospital_admissions_date_

@property
def last_wastewater_date(self):
return self.get_end_date(
return get_end_date(
self.first_wastewater_date,
self.n_wastewater_data_days,
timestep_days=1,
)

@property
def last_ed_visits_date(self):
return self.get_end_date(
return get_end_date(
self.first_ed_visits_date,
self.n_ed_visits_data_days,
timestep_days=1,
)

@property
def last_hospital_admissions_date(self):
return self.get_end_date(
return get_end_date(
self.first_hospital_admissions_date,
self.n_hospital_admissions_data_days,
timestep_days=7,
Expand Down Expand Up @@ -314,21 +313,9 @@ def site_subpop_spine(self):

@property
def date_time_spine(self):
date_time_spine = (
pl.DataFrame(
{
"date": pl.date_range(
start=self.first_data_date_overall,
end=self.last_data_date_overall,
interval="1d",
eager=True,
)
}
)
.with_row_index("t")
.with_columns(pl.col("t").cast(pl.Int64))
return create_date_time_spine(
self.first_data_date_overall, self.last_data_date_overall
)
return date_time_spine

@property
def wastewater_data_extended(self):
Expand Down Expand Up @@ -374,33 +361,37 @@ def ww_uncensored(self):
@property
def model_t_obs_wastewater(self):
if self.nwss_training_data is not None:
return self.wastewater_data_extended.get_column("t").to_numpy()
observed_dates = self.nwss_training_data.get_column("date").to_numpy()
return align_observation_times(
observed_dates,
self.first_data_date_overall,
aggregation_freq="daily",
)
return None

@property
def model_t_obs_ed_visits(self):
if self.nssp_training_data is not None:
return (
self.nssp_training_data.join(
self.date_time_spine, on="date", how="left"
)
.get_column("t")
.unique()
.to_numpy()
observed_dates = (
self.nssp_training_data.get_column("date").unique().to_numpy()
)
return align_observation_times(
observed_dates,
self.first_data_date_overall,
aggregation_freq="daily",
)
return None

@property
def model_t_obs_hospital_admissions(self):
if self.nhsn_training_data is not None:
return (
self.nhsn_training_data.join(
self.date_time_spine,
left_on="weekendingdate",
right_on="date",
how="left",
)
.get_column("t")
.to_numpy()
observed_dates = (
self.nhsn_training_data.get_column("weekendingdate").unique().to_numpy()
)
return align_observation_times(
observed_dates,
self.first_data_date_overall,
aggregation_freq="daily",
)
return None

Expand Down Expand Up @@ -439,61 +430,21 @@ def lab_site_to_subpop_map(self):
)
return self.lab_site_to_subpop_map_

def get_end_date(
self,
first_date: np.datetime64,
n_datapoints: int,
timestep_days: int = 1,
) -> np.datetime64:
"""
Get end date from a first date and a number of datapoints,
with handling of None values and non-daily timeseries
"""
if first_date is None:
if n_datapoints != 0:
raise ValueError(
"Must provide an initial date if "
"n_datapoints is non-zero. "
f"Got n_datapoints = {n_datapoints} "
"but first_date was `None`"
)
result = None
else:
result = first_date + np.timedelta64(
(n_datapoints - 1) * timestep_days, "D"
)
return result

def get_n_data_days(
self,
n_datapoints: int = None,
date_array: ArrayLike = None,
timestep_days: int = 1,
) -> int:
if n_datapoints is None and date_array is None:
return 0
elif date_array is not None and n_datapoints is not None:
raise ValueError(
"Must provide at most one out of a "
"number of datapoints to simulate and "
"an array of dates data is observed."
)
elif date_array is not None:
return (
(max(date_array) - min(date_array))
// np.timedelta64((timestep_days), "D")
+ 1
).item()
else:
return n_datapoints

def to_forecast_data(self, n_forecast_points: int) -> Self:
n_days = self.n_days_post_init + n_forecast_points
n_weeks = n_days // 7
first_dow = self.first_data_date_overall.astype(dt.datetime).weekday()
to_first_sat = (5 - first_dow) % 7
start_date = convert_date(self.first_data_date_overall)
first_dow = start_date.weekday()

# Calculate offset to next Sunday (or 0 if already Sunday)
# MMWR weeks start on Sunday (weekday=6) and end on Saturday (weekday=5)
offset_to_sunday = (6 - first_dow) % 7

# First complete MMWR week ends 6 days after that Sunday
days_to_first_saturday = offset_to_sunday + 6

first_mmwr_ending_date = self.first_data_date_overall + np.timedelta64(
to_first_sat, "D"
days_to_first_saturday, "D"
)
return PyrenewHEWData(
n_ed_visits_data_days=n_days,
Expand Down
12 changes: 5 additions & 7 deletions pyrenew_hew/pyrenew_hew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pyrenew.observation import NegativeBinomialObservation
from pyrenew.process import ARProcess, DifferencedProcess
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable
from pyrenew.time import daily_to_mmwr_epiweekly
from pyrenew.time import daily_to_mmwr_epiweekly, get_first_week_on_or_after_t0

from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData

Expand Down Expand Up @@ -533,12 +533,10 @@ def calculate_weekly_hosp_indices(
) // 7
else:
which_obs_weekly_hosp_admissions = jnp.arange(n_datapoints)
if model_t_first_pred_admissions < 0:
which_obs_weekly_hosp_admissions = which_obs_weekly_hosp_admissions[
(-model_t_first_pred_admissions - 1) // 7 + 1 :
]
# Truncate to include only the epiweek ending after
# model t0 for posterior prediction
skip_weeks = get_first_week_on_or_after_t0(model_t_first_pred_admissions)
which_obs_weekly_hosp_admissions = which_obs_weekly_hosp_admissions[
skip_weeks:
]

return which_obs_weekly_hosp_admissions

Expand Down
Loading