Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,6 @@ testpaths = ["tests", "pipelines/tests"]

[tool.uv.sources]
polarbayes = { git = "https://github.com/CDCgov/polarbayes" }
pyrenew = { git = "https://github.com/CDCgov/PyRenew", rev = "v0.1.5" }
pyrenew = { git = "https://github.com/CDCgov/PyRenew", rev = "v0.1.6" }
forecasttools = { git = "https://github.com/cdcgov/forecasttools-py" }
azuretools = { git = "https://github.com/cdcgov/cfa-azuretools" }
98 changes: 18 additions & 80 deletions pyrenew_hew/pyrenew_hew_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
import numpy as np
import polars as pl
from jax.typing import ArrayLike
from pyrenew.time import (
create_date_time_spine,
get_end_date,
get_n_data_days,
validate_mmwr_dates,
)


class PyrenewHEWData:
Expand Down Expand Up @@ -48,15 +54,7 @@ def __init__(
self.right_truncation_offset = right_truncation_offset
self.population_size = population_size

if (
first_hospital_admissions_date is not None
and not first_hospital_admissions_date.astype(dt.datetime).weekday() == 5
):
raise ValueError(
"Dates for hospital admissions timeseries must "
"be Saturdays (MMWR epiweek end "
"days)."
)
validate_mmwr_dates([first_hospital_admissions_date])

self.first_ed_visits_date_ = first_ed_visits_date
self.first_hospital_admissions_date_ = first_hospital_admissions_date
Expand Down Expand Up @@ -146,23 +144,23 @@ def from_json(

@property
def n_ed_visits_data_days(self):
return self.get_n_data_days(
n_datapoints=self.n_ed_visits_data_days_,
return get_n_data_days(
n_points=self.n_ed_visits_data_days_,
date_array=self.dates_observed_ed_visits,
)

@property
def n_hospital_admissions_data_days(self):
return self.get_n_data_days(
n_datapoints=self.n_hospital_admissions_data_days_,
return get_n_data_days(
n_points=self.n_hospital_admissions_data_days_,
date_array=self.dates_observed_hospital_admissions,
timestep_days=7,
)

@property
def n_wastewater_data_days(self):
return self.get_n_data_days(
n_datapoints=self.n_wastewater_data_days_,
return get_n_data_days(
n_points=self.n_wastewater_data_days_,
date_array=self.dates_observed_disease_wastewater,
)

Expand Down Expand Up @@ -203,23 +201,23 @@ def first_hospital_admissions_date(self):

@property
def last_wastewater_date(self):
return self.get_end_date(
return get_end_date(
self.first_wastewater_date,
self.n_wastewater_data_days,
timestep_days=1,
)

@property
def last_ed_visits_date(self):
return self.get_end_date(
return get_end_date(
self.first_ed_visits_date,
self.n_ed_visits_data_days,
timestep_days=1,
)

@property
def last_hospital_admissions_date(self):
return self.get_end_date(
return get_end_date(
self.first_hospital_admissions_date,
self.n_hospital_admissions_data_days,
timestep_days=7,
Expand Down Expand Up @@ -314,21 +312,9 @@ def site_subpop_spine(self):

@property
def date_time_spine(self):
date_time_spine = (
pl.DataFrame(
{
"date": pl.date_range(
start=self.first_data_date_overall,
end=self.last_data_date_overall,
interval="1d",
eager=True,
)
}
)
.with_row_index("t")
.with_columns(pl.col("t").cast(pl.Int64))
return create_date_time_spine(
self.first_data_date_overall, self.last_data_date_overall
)
return date_time_spine

@property
def wastewater_data_extended(self):
Expand Down Expand Up @@ -439,54 +425,6 @@ def lab_site_to_subpop_map(self):
)
return self.lab_site_to_subpop_map_

def get_end_date(
self,
first_date: np.datetime64,
n_datapoints: int,
timestep_days: int = 1,
) -> np.datetime64:
"""
Get end date from a first date and a number of datapoints,
with handling of None values and non-daily timeseries
"""
if first_date is None:
if n_datapoints != 0:
raise ValueError(
"Must provide an initial date if "
"n_datapoints is non-zero. "
f"Got n_datapoints = {n_datapoints} "
"but first_date was `None`"
)
result = None
else:
result = first_date + np.timedelta64(
(n_datapoints - 1) * timestep_days, "D"
)
return result

def get_n_data_days(
self,
n_datapoints: int = None,
date_array: ArrayLike = None,
timestep_days: int = 1,
) -> int:
if n_datapoints is None and date_array is None:
return 0
elif date_array is not None and n_datapoints is not None:
raise ValueError(
"Must provide at most one out of a "
"number of datapoints to simulate and "
"an array of dates data is observed."
)
elif date_array is not None:
return (
(max(date_array) - min(date_array))
// np.timedelta64((timestep_days), "D")
+ 1
).item()
else:
return n_datapoints

def to_forecast_data(self, n_forecast_points: int) -> Self:
"""
Create a new PyrenewHEWData instance for forecasting.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pyrenew_hew_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def test_hospital_admissions_must_be_saturday():
# 2023-01-02 is a Monday (not a Saturday)
monday_date = np.datetime64("2023-01-02")

with pytest.raises(ValueError, match="Saturdays.*MMWR epiweek"):
with pytest.raises(ValueError, match="MMWR dates must be Saturdays"):
PyrenewHEWData(
n_ed_visits_data_days=10,
n_hospital_admissions_data_days=1,
Expand Down
6 changes: 3 additions & 3 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.