diff --git a/pyproject.toml b/pyproject.toml index 91c7763c..353c5389 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,6 @@ testpaths = ["tests", "pipelines/tests"] [tool.uv.sources] polarbayes = { git = "https://github.com/CDCgov/polarbayes" } -pyrenew = { git = "https://github.com/CDCgov/PyRenew", rev = "v0.1.5" } +pyrenew = { git = "https://github.com/CDCgov/PyRenew", rev = "v0.1.6" } forecasttools = { git = "https://github.com/cdcgov/forecasttools-py" } azuretools = { git = "https://github.com/cdcgov/cfa-azuretools" } diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index 6fe2b47c..cc2fc5a7 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -7,6 +7,12 @@ import numpy as np import polars as pl from jax.typing import ArrayLike +from pyrenew.time import ( + create_date_time_spine, + get_end_date, + get_n_data_days, + validate_mmwr_dates, +) class PyrenewHEWData: @@ -48,15 +54,7 @@ def __init__( self.right_truncation_offset = right_truncation_offset self.population_size = population_size - if ( - first_hospital_admissions_date is not None - and not first_hospital_admissions_date.astype(dt.datetime).weekday() == 5 - ): - raise ValueError( - "Dates for hospital admissions timeseries must " - "be Saturdays (MMWR epiweek end " - "days)." - ) + validate_mmwr_dates([first_hospital_admissions_date]) self.first_ed_visits_date_ = first_ed_visits_date self.first_hospital_admissions_date_ = first_hospital_admissions_date @@ -146,23 +144,23 @@ def from_json( @property def n_ed_visits_data_days(self): - return self.get_n_data_days( - n_datapoints=self.n_ed_visits_data_days_, + return get_n_data_days( + n_points=self.n_ed_visits_data_days_, date_array=self.dates_observed_ed_visits, ) @property def n_hospital_admissions_data_days(self): - return self.get_n_data_days( - n_datapoints=self.n_hospital_admissions_data_days_, + return get_n_data_days( + n_points=self.n_hospital_admissions_data_days_, date_array=self.dates_observed_hospital_admissions, timestep_days=7, ) @property def n_wastewater_data_days(self): - return self.get_n_data_days( - n_datapoints=self.n_wastewater_data_days_, + return get_n_data_days( + n_points=self.n_wastewater_data_days_, date_array=self.dates_observed_disease_wastewater, ) @@ -203,7 +201,7 @@ def first_hospital_admissions_date(self): @property def last_wastewater_date(self): - return self.get_end_date( + return get_end_date( self.first_wastewater_date, self.n_wastewater_data_days, timestep_days=1, @@ -211,7 +209,7 @@ def last_wastewater_date(self): @property def last_ed_visits_date(self): - return self.get_end_date( + return get_end_date( self.first_ed_visits_date, self.n_ed_visits_data_days, timestep_days=1, @@ -219,7 +217,7 @@ def last_ed_visits_date(self): @property def last_hospital_admissions_date(self): - return self.get_end_date( + return get_end_date( self.first_hospital_admissions_date, self.n_hospital_admissions_data_days, timestep_days=7, @@ -314,21 +312,9 @@ def site_subpop_spine(self): @property def date_time_spine(self): - date_time_spine = ( - pl.DataFrame( - { - "date": pl.date_range( - start=self.first_data_date_overall, - end=self.last_data_date_overall, - interval="1d", - eager=True, - ) - } - ) - .with_row_index("t") - .with_columns(pl.col("t").cast(pl.Int64)) + return create_date_time_spine( + self.first_data_date_overall, self.last_data_date_overall ) - return date_time_spine @property def wastewater_data_extended(self): @@ -439,54 +425,6 @@ def lab_site_to_subpop_map(self): ) return self.lab_site_to_subpop_map_ - def get_end_date( - self, - first_date: np.datetime64, - n_datapoints: int, - timestep_days: int = 1, - ) -> np.datetime64: - """ - Get end date from a first date and a number of datapoints, - with handling of None values and non-daily timeseries - """ - if first_date is None: - if n_datapoints != 0: - raise ValueError( - "Must provide an initial date if " - "n_datapoints is non-zero. " - f"Got n_datapoints = {n_datapoints} " - "but first_date was `None`" - ) - result = None - else: - result = first_date + np.timedelta64( - (n_datapoints - 1) * timestep_days, "D" - ) - return result - - def get_n_data_days( - self, - n_datapoints: int = None, - date_array: ArrayLike = None, - timestep_days: int = 1, - ) -> int: - if n_datapoints is None and date_array is None: - return 0 - elif date_array is not None and n_datapoints is not None: - raise ValueError( - "Must provide at most one out of a " - "number of datapoints to simulate and " - "an array of dates data is observed." - ) - elif date_array is not None: - return ( - (max(date_array) - min(date_array)) - // np.timedelta64((timestep_days), "D") - + 1 - ).item() - else: - return n_datapoints - def to_forecast_data(self, n_forecast_points: int) -> Self: """ Create a new PyrenewHEWData instance for forecasting. diff --git a/tests/test_pyrenew_hew_data.py b/tests/test_pyrenew_hew_data.py index 54da12e7..69aed488 100644 --- a/tests/test_pyrenew_hew_data.py +++ b/tests/test_pyrenew_hew_data.py @@ -295,7 +295,7 @@ def test_hospital_admissions_must_be_saturday(): # 2023-01-02 is a Monday (not a Saturday) monday_date = np.datetime64("2023-01-02") - with pytest.raises(ValueError, match="Saturdays.*MMWR epiweek"): + with pytest.raises(ValueError, match="MMWR dates must be Saturdays"): PyrenewHEWData( n_ed_visits_data_days=10, n_hospital_admissions_data_days=1, diff --git a/uv.lock b/uv.lock index f7a63531..257cc5dd 100644 --- a/uv.lock +++ b/uv.lock @@ -2206,8 +2206,8 @@ wheels = [ [[package]] name = "pyrenew" -version = "0.1.5" -source = { git = "https://github.com/CDCgov/PyRenew?rev=v0.1.5#7f7a4f8791fa860a488a4598ec144bc7a2aee6ef" } +version = "0.1.6" +source = { git = "https://github.com/CDCgov/PyRenew?rev=v0.1.6#5c502bdc4227def615ec33a25e809ddaa43581ff" } dependencies = [ { name = "jax" }, { name = "numpy" }, @@ -2262,7 +2262,7 @@ requires-dist = [ { name = "pyarrow", specifier = ">=18.0.0" }, { name = "pygit2", specifier = ">=1.17.0" }, { name = "pypdf", specifier = ">=5.1.0" }, - { name = "pyrenew", git = "https://github.com/CDCgov/PyRenew?rev=v0.1.5" }, + { name = "pyrenew", git = "https://github.com/CDCgov/PyRenew?rev=v0.1.6" }, { name = "pyyaml", specifier = ">=6.0.2" }, { name = "rich", specifier = ">=14.0.0" }, { name = "tomli-w", specifier = ">=1.1.0" },