diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index b3f30f0c..729521df 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -1,4 +1,3 @@ -import datetime as dt import json from pathlib import Path from typing import Self @@ -7,6 +6,14 @@ import numpy as np import polars as pl from jax.typing import ArrayLike +from pyrenew.time import ( + align_observation_times, + convert_date, + create_date_time_spine, + get_end_date, + get_n_data_days, + validate_mmwr_dates, +) class PyrenewHEWData: @@ -48,15 +55,7 @@ def __init__( self.right_truncation_offset = right_truncation_offset self.population_size = population_size - if ( - first_hospital_admissions_date is not None - and not first_hospital_admissions_date.astype(dt.datetime).weekday() == 5 - ): - raise ValueError( - "Dates for hospital admissions timeseries must " - "be Saturdays (MMWR epiweek end " - "days)." - ) + validate_mmwr_dates([first_hospital_admissions_date]) self.first_ed_visits_date_ = first_ed_visits_date self.first_hospital_admissions_date_ = first_hospital_admissions_date @@ -146,23 +145,23 @@ def from_json( @property def n_ed_visits_data_days(self): - return self.get_n_data_days( - n_datapoints=self.n_ed_visits_data_days_, + return get_n_data_days( + n_points=self.n_ed_visits_data_days_, date_array=self.dates_observed_ed_visits, ) @property def n_hospital_admissions_data_days(self): - return self.get_n_data_days( - n_datapoints=self.n_hospital_admissions_data_days_, + return get_n_data_days( + n_points=self.n_hospital_admissions_data_days_, date_array=self.dates_observed_hospital_admissions, timestep_days=7, ) @property def n_wastewater_data_days(self): - return self.get_n_data_days( - n_datapoints=self.n_wastewater_data_days_, + return get_n_data_days( + n_points=self.n_wastewater_data_days_, date_array=self.dates_observed_disease_wastewater, ) @@ -197,13 +196,13 @@ def first_ed_visits_date(self): @property def first_hospital_admissions_date(self): - if self.data_observed_disease_hospital_admissions is not None: + if self.dates_observed_hospital_admissions is not None: return self.dates_observed_hospital_admissions.min() return self.first_hospital_admissions_date_ @property def last_wastewater_date(self): - return self.get_end_date( + return get_end_date( self.first_wastewater_date, self.n_wastewater_data_days, timestep_days=1, @@ -211,7 +210,7 @@ def last_wastewater_date(self): @property def last_ed_visits_date(self): - return self.get_end_date( + return get_end_date( self.first_ed_visits_date, self.n_ed_visits_data_days, timestep_days=1, @@ -219,7 +218,7 @@ def last_ed_visits_date(self): @property def last_hospital_admissions_date(self): - return self.get_end_date( + return get_end_date( self.first_hospital_admissions_date, self.n_hospital_admissions_data_days, timestep_days=7, @@ -314,21 +313,9 @@ def site_subpop_spine(self): @property def date_time_spine(self): - date_time_spine = ( - pl.DataFrame( - { - "date": pl.date_range( - start=self.first_data_date_overall, - end=self.last_data_date_overall, - interval="1d", - eager=True, - ) - } - ) - .with_row_index("t") - .with_columns(pl.col("t").cast(pl.Int64)) + return create_date_time_spine( + self.first_data_date_overall, self.last_data_date_overall ) - return date_time_spine @property def wastewater_data_extended(self): @@ -374,33 +361,37 @@ def ww_uncensored(self): @property def model_t_obs_wastewater(self): if self.nwss_training_data is not None: - return 
self.wastewater_data_extended.get_column("t").to_numpy() + observed_dates = self.nwss_training_data.get_column("date").to_numpy() + return align_observation_times( + observed_dates, + self.first_data_date_overall, + aggregation_freq="daily", + ) + return None @property def model_t_obs_ed_visits(self): if self.nssp_training_data is not None: - return ( - self.nssp_training_data.join( - self.date_time_spine, on="date", how="left" - ) - .get_column("t") - .unique() - .to_numpy() + observed_dates = ( + self.nssp_training_data.get_column("date").unique().to_numpy() + ) + return align_observation_times( + observed_dates, + self.first_data_date_overall, + aggregation_freq="daily", ) return None @property def model_t_obs_hospital_admissions(self): if self.nhsn_training_data is not None: - return ( - self.nhsn_training_data.join( - self.date_time_spine, - left_on="weekendingdate", - right_on="date", - how="left", - ) - .get_column("t") - .to_numpy() + observed_dates = ( + self.nhsn_training_data.get_column("weekendingdate").unique().to_numpy() + ) + return align_observation_times( + observed_dates, + self.first_data_date_overall, + aggregation_freq="daily", ) return None @@ -439,61 +430,21 @@ def lab_site_to_subpop_map(self): ) return self.lab_site_to_subpop_map_ - def get_end_date( - self, - first_date: np.datetime64, - n_datapoints: int, - timestep_days: int = 1, - ) -> np.datetime64: - """ - Get end date from a first date and a number of datapoints, - with handling of None values and non-daily timeseries - """ - if first_date is None: - if n_datapoints != 0: - raise ValueError( - "Must provide an initial date if " - "n_datapoints is non-zero. " - f"Got n_datapoints = {n_datapoints} " - "but first_date was `None`" - ) - result = None - else: - result = first_date + np.timedelta64( - (n_datapoints - 1) * timestep_days, "D" - ) - return result - - def get_n_data_days( - self, - n_datapoints: int = None, - date_array: ArrayLike = None, - timestep_days: int = 1, - ) -> int: - if n_datapoints is None and date_array is None: - return 0 - elif date_array is not None and n_datapoints is not None: - raise ValueError( - "Must provide at most one out of a " - "number of datapoints to simulate and " - "an array of dates data is observed." 
- ) - elif date_array is not None: - return ( - (max(date_array) - min(date_array)) - // np.timedelta64((timestep_days), "D") - + 1 - ).item() - else: - return n_datapoints - def to_forecast_data(self, n_forecast_points: int) -> Self: n_days = self.n_days_post_init + n_forecast_points n_weeks = n_days // 7 - first_dow = self.first_data_date_overall.astype(dt.datetime).weekday() - to_first_sat = (5 - first_dow) % 7 + start_date = convert_date(self.first_data_date_overall) + first_dow = start_date.weekday() + + # Calculate offset to next Sunday (or 0 if already Sunday) + # MMWR weeks start on Sunday (weekday=6) and end on Saturday (weekday=5) + offset_to_sunday = (6 - first_dow) % 7 + + # First complete MMWR week ends 6 days after that Sunday + days_to_first_saturday = offset_to_sunday + 6 + first_mmwr_ending_date = self.first_data_date_overall + np.timedelta64( - to_first_sat, "D" + days_to_first_saturday, "D" ) return PyrenewHEWData( n_ed_visits_data_days=n_days, diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index a0f47532..d356e90d 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -24,7 +24,7 @@ from pyrenew.observation import NegativeBinomialObservation from pyrenew.process import ARProcess, DifferencedProcess from pyrenew.randomvariable import DistributionalVariable, TransformedVariable -from pyrenew.time import daily_to_mmwr_epiweekly +from pyrenew.time import daily_to_mmwr_epiweekly, get_first_week_on_or_after_t0 from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData @@ -533,12 +533,10 @@ def calculate_weekly_hosp_indices( ) // 7 else: which_obs_weekly_hosp_admissions = jnp.arange(n_datapoints) - if model_t_first_pred_admissions < 0: - which_obs_weekly_hosp_admissions = which_obs_weekly_hosp_admissions[ - (-model_t_first_pred_admissions - 1) // 7 + 1 : - ] - # Truncate to include only the epiweek ending after - # model t0 for posterior prediction + skip_weeks = get_first_week_on_or_after_t0(model_t_first_pred_admissions) + which_obs_weekly_hosp_admissions = which_obs_weekly_hosp_admissions[ + skip_weeks: + ] return which_obs_weekly_hosp_admissions diff --git a/tests/test_pyrenew_hew_data.py b/tests/test_pyrenew_hew_data.py index be103594..b7c42208 100644 --- a/tests/test_pyrenew_hew_data.py +++ b/tests/test_pyrenew_hew_data.py @@ -114,24 +114,59 @@ def test_to_forecast_data( assert forecast_data.right_truncation_offset is None assert forecast_data.first_ed_visits_date == data.first_data_date_overall - ## hosp admit date should be the first Saturday + ## hosp admit date should be the first Saturday ending a complete MMWR week assert forecast_data.first_hospital_admissions_date >= data.first_data_date_overall assert ( forecast_data.first_hospital_admissions_date.astype(dt.datetime).weekday() == 5 ) - assert ( + # First complete MMWR week (Sunday-Saturday) ends 6-12 days from start + # 6 days if starting on Sunday, 12 days if starting on Monday + days_diff = ( (forecast_data.first_hospital_admissions_date - data.first_data_date_overall) / np.timedelta64(1, "D") - ).item() <= 6 + ).item() + assert 6 <= days_diff <= 12 assert forecast_data.first_wastewater_date == data.first_data_date_overall assert forecast_data.data_observed_disease_wastewater_conc is None +def test_to_forecast_data_saturday_edge_case(): + """ + Test that to_forecast_data correctly handles the edge case where + first_data_date_overall is already a Saturday. 
+ + Since MMWR weeks run Sunday-Saturday, starting on Saturday means we skip + 1 day to Sunday, then the first complete week ends 6 days later on Saturday + (total: 7 days from the starting Saturday). + """ + # 2025-03-08 is a Saturday + first_date_saturday = np.datetime64("2025-03-08") + + data = PyrenewHEWData( + n_ed_visits_data_days=10, + n_hospital_admissions_data_days=2, + first_ed_visits_date=first_date_saturday, + first_hospital_admissions_date=first_date_saturday, + right_truncation_offset=0, + ) + + forecast_data = data.to_forecast_data(n_forecast_points=14) + + # Verify that first_hospital_admissions_date is 7 days after the start Saturday + expected_hosp_date = first_date_saturday + np.timedelta64(7, "D") + assert forecast_data.first_hospital_admissions_date == expected_hosp_date + + # Verify it's still a Saturday + assert ( + forecast_data.first_hospital_admissions_date.astype(dt.datetime).weekday() == 5 + ) + + def test_pyrenew_wastewater_data(): - first_training_date = np.datetime64("2023-01-01") - last_training_date = np.datetime64("2023-07-23") + first_training_date = dt.date(2023, 1, 1) + last_training_date = dt.date(2023, 7, 23) dates = pl.date_range( first_training_date, last_training_date, @@ -282,3 +317,546 @@ def test_build_pyrenew_hew_data_from_json(mock_data_dir): assert data.data_observed_disease_ed_visits is not None assert data.data_observed_disease_hospital_admissions is not None assert data.data_observed_disease_wastewater_conc is not None + + +# ============================================================================ +# NEW CRITICAL TESTS +# ============================================================================ + + +def test_model_t_conversions(): + """ + Test model_t_obs_* properties calculate correct indices. + + Verifies that observation dates are correctly converted to model time + indices relative to first_data_date_overall. + """ + ed_dates = ["2023-01-05", "2023-01-08"] + hosp_dates = ["2023-01-07", "2023-01-14"] + + nssp_data = pl.DataFrame( + { + "date": ed_dates, + "geo_value": ["CA", "CA"], + "observed_ed_visits": [10, 20], + "other_ed_visits": [100, 200], + "data_type": ["train", "train"], + }, + schema={ + "date": pl.Date, + "geo_value": pl.String, + "observed_ed_visits": pl.Int64, + "other_ed_visits": pl.Int64, + "data_type": pl.String, + }, + ) + + nhsn_data = pl.DataFrame( + { + "weekendingdate": hosp_dates, + "jurisdiction": ["CA", "CA"], + "hospital_admissions": [5, 10], + "data_type": ["train", "train"], + }, + schema={ + "weekendingdate": pl.Date, + "jurisdiction": pl.String, + "hospital_admissions": pl.Int64, + "data_type": pl.String, + }, + ) + + data = PyrenewHEWData( + nssp_training_data=nssp_data, + nhsn_training_data=nhsn_data, + first_ed_visits_date=np.datetime64(ed_dates[0]), + first_hospital_admissions_date=np.datetime64(hosp_dates[0]), + ) + + # first_data_date_overall should be min of all dates + assert data.first_data_date_overall == np.datetime64(ed_dates[0]) + + # model_t should be 0 for first date, 3 for second ED visit + assert data.model_t_obs_ed_visits[0] == 0 + assert data.model_t_obs_ed_visits[1] == 3 + + # Hospital admissions at days 2 and 9 + assert data.model_t_obs_hospital_admissions[0] == 2 + assert data.model_t_obs_hospital_admissions[1] == 9 + + +def test_properties_with_no_data(): + """ + Test behavior when no training data provided. + + Ensures properties return None appropriately when no DataFrames + are provided to the constructor. 
+ """ + data = PyrenewHEWData( + n_ed_visits_data_days=10, first_ed_visits_date=np.datetime64("2023-01-01") + ) + + assert data.data_observed_disease_ed_visits is None + assert data.data_observed_disease_hospital_admissions is None + assert data.data_observed_disease_wastewater_conc is None + assert data.model_t_obs_ed_visits is None + assert data.model_t_obs_hospital_admissions is None + assert data.model_t_obs_wastewater is None + + +def test_mixed_data_sources(): + """ + Test with only some data sources present. + + Verifies correct behavior when only ED data is provided + but hospital and wastewater data are absent. + """ + ed_data = pl.DataFrame( + { + "date": ["2023-01-01"], + "geo_value": ["CA"], + "observed_ed_visits": [10], + "other_ed_visits": [100], + "data_type": ["train"], + }, + schema={ + "date": pl.Date, + "geo_value": pl.String, + "observed_ed_visits": pl.Int64, + "other_ed_visits": pl.Int64, + "data_type": pl.String, + }, + ) + + data = PyrenewHEWData( + nssp_training_data=ed_data, first_ed_visits_date=np.datetime64("2023-01-01") + ) + + # ED data should work + assert data.data_observed_disease_ed_visits is not None + assert len(data.data_observed_disease_ed_visits) == 1 + + # Others should be None + assert data.first_hospital_admissions_date is None + assert data.first_wastewater_date is None + assert data.data_observed_disease_hospital_admissions is None + assert data.data_observed_disease_wastewater_conc is None + + +def test_site_subpop_spine_with_auxiliary(): + """ + Test subpopulation creation when WW sites don't cover full population. + + When wastewater sampling sites don't cover the entire population, + an auxiliary subpopulation should be created for the remainder. + """ + ww_data = pl.DataFrame( + { + "date": ["2023-01-01", "2023-01-01"], + "site": ["site1", "site2"], + "site_index": [0, 1], + "site_pop": [200_000, 300_000], + "lab_site_index": [0, 1], + "log_genome_copies_per_ml": [1.0, 2.0], + "log_lod": [0.5, 0.5], + "below_lod": [0, 0], + }, + schema={ + "date": pl.Date, + "site": pl.String, + "site_index": pl.Int64, + "site_pop": pl.Int64, + "lab_site_index": pl.Int64, + "log_genome_copies_per_ml": pl.Float64, + "log_lod": pl.Float64, + "below_lod": pl.Int64, + }, + ) + + data = PyrenewHEWData( + nwss_training_data=ww_data, + population_size=1_000_000, # 500k not covered by WW + ) + + spine = data.site_subpop_spine + # Should have 3 subpops: 2 sites + 1 auxiliary + assert len(spine) == 3 + assert spine.filter(pl.col("site").is_null()).height == 1 + # Auxiliary subpop should have remaining population + assert spine.filter(pl.col("site").is_null())["subpop_pop"][0] == 500_000 + + +def test_site_subpop_spine_no_auxiliary(): + """ + Test when WW sites cover entire population. + + When sampling sites cover the full population, no auxiliary + subpopulation should be created. 
+ """ + ww_data = pl.DataFrame( + { + "date": ["2023-01-01", "2023-01-01"], + "site": ["site1", "site2"], + "site_index": [0, 1], + "site_pop": [400_000, 600_000], + "lab_site_index": [0, 1], + "log_genome_copies_per_ml": [1.0, 2.0], + "log_lod": [0.5, 0.5], + "below_lod": [0, 0], + }, + schema={ + "date": pl.Date, + "site": pl.String, + "site_index": pl.Int64, + "site_pop": pl.Int64, + "lab_site_index": pl.Int64, + "log_genome_copies_per_ml": pl.Float64, + "log_lod": pl.Float64, + "below_lod": pl.Int64, + }, + ) + + data = PyrenewHEWData(nwss_training_data=ww_data, population_size=1_000_000) + + spine = data.site_subpop_spine + # Should have only 2 subpops + assert len(spine) == 2 + assert spine.filter(pl.col("site").is_null()).height == 0 + + +def test_censored_uncensored_split(): + """ + Test correct identification of censored vs uncensored observations. + + Wastewater observations below the limit of detection are censored. + This test verifies correct indexing of censored/uncensored data. + """ + ww_data = pl.DataFrame( + { + "date": ["2023-01-01"] * 4, + "site": ["site1"] * 4, + "site_index": [0] * 4, + "site_pop": [500_000] * 4, + "lab_site_index": [0] * 4, + "log_genome_copies_per_ml": [0.5, 1.5, 0.3, 2.0], + "log_lod": [1.0, 1.0, 1.0, 1.0], + "below_lod": [1, 0, 1, 0], # 2 censored, 2 uncensored + }, + schema={ + "date": pl.Date, + "site": pl.String, + "site_index": pl.Int64, + "site_pop": pl.Int64, + "lab_site_index": pl.Int64, + "log_genome_copies_per_ml": pl.Float64, + "log_lod": pl.Float64, + "below_lod": pl.Int64, + }, + ) + + data = PyrenewHEWData(nwss_training_data=ww_data, population_size=1_000_000) + + assert len(data.ww_censored) == 2 + assert len(data.ww_uncensored) == 2 + # Censored indices should be 0 and 2 + assert np.array_equal(data.ww_censored, [0, 2]) + assert np.array_equal(data.ww_uncensored, [1, 3]) + + +def test_n_days_post_init_single_source(): + """ + Test n_days_post_init with only one data source. + + With a single data source, n_days_post_init should equal + the number of days in that source. + """ + data = PyrenewHEWData( + n_ed_visits_data_days=30, first_ed_visits_date=np.datetime64("2023-01-01") + ) + + # Should be 30 days (Jan 1 to Jan 30 inclusive) + assert data.n_days_post_init == 30 + + +def test_n_days_post_init_multiple_sources(): + """ + Test with multiple overlapping/non-overlapping data sources. + + With multiple data sources, n_days_post_init should span from + the earliest first date to the latest last date. + """ + data = PyrenewHEWData( + n_ed_visits_data_days=20, + n_hospital_admissions_data_days=3, # 3 weeks = 21 days + first_ed_visits_date=np.datetime64("2023-01-01"), + first_hospital_admissions_date=np.datetime64("2023-01-07"), + ) + + # ED: Jan 1-20 (20 days) + # Hosp: Jan 7, 14, 21 (3 weeks ending Jan 21, so 21 days from start) + # Overall: Jan 1 - Jan 21 = 21 days + assert data.n_days_post_init == 21 + + +def test_lab_site_to_subpop_map(): + """ + Test correct mapping from lab sites to subpopulations. + + Multiple labs can sample from the same site. The mapping + should correctly associate each lab with its subpopulation. 
+ """ + ww_data = pl.DataFrame( + { + "date": ["2023-01-01"] * 4, + "site": ["site1", "site1", "site2", "site2"], + "site_index": [0, 0, 1, 1], + "site_pop": [400_000] * 2 + [200_000] * 2, + "lab_site_index": [0, 1, 2, 3], # 4 labs, 2 sites + "log_genome_copies_per_ml": [1.0, 1.5, 2.0, 2.5], + "log_lod": [0.5] * 4, + "below_lod": [0] * 4, + }, + schema={ + "date": pl.Date, + "site": pl.String, + "site_index": pl.Int64, + "site_pop": pl.Int64, + "lab_site_index": pl.Int64, + "log_genome_copies_per_ml": pl.Float64, + "log_lod": pl.Float64, + "below_lod": pl.Int64, + }, + ) + + data = PyrenewHEWData( + nwss_training_data=ww_data, + population_size=1_000_000, # Creates auxiliary subpop 0 + ) + + # First 2 labs map to subpop 1 (site_index 0 + 1 for auxiliary) + # Next 2 labs map to subpop 2 (site_index 1 + 1 for auxiliary) + mapping = data.lab_site_to_subpop_map + assert len(mapping) == 4 + assert mapping[0] == 1 # lab 0 -> subpop 1 + assert mapping[1] == 1 # lab 1 -> subpop 1 + assert mapping[2] == 2 # lab 2 -> subpop 2 + assert mapping[3] == 2 # lab 3 -> subpop 2 + + +def test_date_time_spine(): + """ + Test creation of date-time spine for temporal indexing. + + The date_time_spine should map each date to its model time index. + """ + data = PyrenewHEWData( + n_ed_visits_data_days=10, first_ed_visits_date=np.datetime64("2023-01-01") + ) + + spine = data.date_time_spine + assert len(spine) == 10 + assert spine["t"][0] == 0 + assert spine["t"][9] == 9 + assert spine["date"][0] == dt.date(2023, 1, 1) + assert spine["date"][9] == dt.date(2023, 1, 10) + + +# ============================================================================ +# ADDITIONAL TESTS FOR mem_refactor_datetime_indexing BRANCH +# ============================================================================ + + +@pytest.mark.parametrize( + ["first_date", "expected_dow", "description"], + [ + (np.datetime64("2023-01-01"), 6, "Sunday"), # Sunday + (np.datetime64("2023-01-02"), 0, "Monday"), # Monday + (np.datetime64("2023-01-03"), 1, "Tuesday"), # Tuesday + (np.datetime64("2023-01-04"), 2, "Wednesday"), # Wednesday + (np.datetime64("2023-01-05"), 3, "Thursday"), # Thursday + (np.datetime64("2023-01-06"), 4, "Friday"), # Friday + (np.datetime64("2023-01-07"), 5, "Saturday"), # Saturday + ], +) +def test_to_forecast_data_different_start_days(first_date, expected_dow, description): + """ + Test to_forecast_data() with first_data_date_overall on different days of week. + + This ensures that the calculation of first_hospital_admissions_date + (which must be a Saturday) works correctly regardless of what day + of the week the data starts on. 
+    """
+    # Use first_date as Saturday for hospital admissions
+    # Calculate the nearest Saturday on or after first_date
+    days_to_saturday = (5 - expected_dow) % 7
+    if days_to_saturday == 0 and expected_dow == 5:
+        # Already a Saturday, use it
+        first_hosp_date = first_date
+    else:
+        # Find next Saturday
+        first_hosp_date = first_date + np.timedelta64(
+            days_to_saturday if days_to_saturday > 0 else 7, "D"
+        )
+
+    data = PyrenewHEWData(
+        n_ed_visits_data_days=14,
+        n_hospital_admissions_data_days=2,
+        first_ed_visits_date=first_date,
+        first_hospital_admissions_date=first_hosp_date,
+    )
+
+    forecast_data = data.to_forecast_data(n_forecast_points=7)
+
+    # Verify the forecast hospital admissions date is a Saturday
+    forecast_hosp_date = forecast_data.first_hospital_admissions_date
+    assert forecast_hosp_date.astype(dt.datetime).weekday() == 5, (
+        f"Expected Saturday (5), got {forecast_hosp_date.astype(dt.datetime).weekday()} "
+        f"for start day {description}"
+    )
+
+    # Verify it's after the first data date
+    assert forecast_hosp_date >= data.first_data_date_overall
+
+    # Verify it ends the first complete MMWR week (6-12 days from start)
+    # 6 days if starting on Sunday, 12 days if starting on Monday
+    days_diff = (forecast_hosp_date - data.first_data_date_overall) / np.timedelta64(
+        1, "D"
+    )
+    assert 6 <= days_diff <= 12
+
+
+@pytest.mark.parametrize(
+    ["n_days", "expected_weeks", "has_partial_week"],
+    [
+        (7, 1, False),  # Exactly 1 week
+        (10, 1, True),  # 1 week + 3 days
+        (13, 1, True),  # 1 week + 6 days
+        (14, 2, False),  # Exactly 2 weeks
+        (20, 2, True),  # 2 weeks + 6 days
+        (21, 3, False),  # Exactly 3 weeks
+        (1, 0, True),  # Less than 1 week
+        (6, 0, True),  # Less than 1 week
+    ],
+)
+def test_to_forecast_data_partial_weeks(n_days, expected_weeks, has_partial_week):
+    """
+    Test to_forecast_data() with data spanning partial weeks.
+
+    Hospital admissions are weekly, so we need to ensure correct
+    handling when the total number of days doesn't evenly divide by 7.
+    """
+    first_date = np.datetime64("2023-01-01")  # Sunday
+    first_saturday = np.datetime64("2023-01-07")  # First Saturday
+
+    data = PyrenewHEWData(
+        n_ed_visits_data_days=n_days,
+        n_hospital_admissions_data_days=0,
+        first_ed_visits_date=first_date,
+        first_hospital_admissions_date=first_saturday,
+    )
+
+    forecast_data = data.to_forecast_data(n_forecast_points=0)
+
+    # Check that the number of weeks is calculated correctly
+    assert forecast_data.n_hospital_admissions_data_days == expected_weeks, (
+        f"For {n_days} days, expected {expected_weeks} weeks, "
+        f"got {forecast_data.n_hospital_admissions_data_days}"
+    )
+
+
+def test_to_forecast_data_zero_forecast_points():
+    """
+    Test that to_forecast_data() handles the edge case of 0 forecast points.
+
+    With zero forecast points, the forecast data should simply cover the
+    existing observation period.
+    """
+    data = PyrenewHEWData(
+        n_ed_visits_data_days=10,
+        n_hospital_admissions_data_days=1,
+        first_ed_visits_date=np.datetime64("2023-01-01"),
+        first_hospital_admissions_date=np.datetime64("2023-01-07"),
+    )
+
+    # 0 forecast points should work
+    forecast_data = data.to_forecast_data(n_forecast_points=0)
+    assert forecast_data.n_ed_visits_data_days == data.n_ed_visits_data_days
+    # n_days_post_init // 7 gives the number of weeks
+    assert forecast_data.n_hospital_admissions_data_days == data.n_days_post_init // 7
+
+
+def test_date_handling_leap_year():
+    """
+    Test date handling across leap year boundary (Feb 29).
+
+    Ensures that date arithmetic works correctly with leap years.
+    """
+    # Leap year: 2024
+    leap_year_start = np.datetime64("2024-02-28")
+    leap_year_saturday = np.datetime64("2024-03-02")  # Saturday after Feb 29
+
+    data = PyrenewHEWData(
+        n_ed_visits_data_days=10,  # Spans Feb 28 - Mar 8, includes Feb 29
+        n_hospital_admissions_data_days=2,
+        first_ed_visits_date=leap_year_start,
+        first_hospital_admissions_date=leap_year_saturday,
+    )
+
+    # Verify dates span the leap day
+    assert data.first_ed_visits_date == leap_year_start
+    assert data.last_ed_visits_date == leap_year_start + np.timedelta64(9, "D")
+
+    # Feb 29 should exist in 2024
+    feb_29_2024 = np.datetime64("2024-02-29")
+    assert data.first_ed_visits_date < feb_29_2024 < data.last_ed_visits_date
+
+    # Verify n_days_post_init calculates correctly across leap day
+    # n_days_post_init spans first_data_date_overall (Feb 28)
+    # through last_data_date_overall
+    # ED: Feb 28 + 9 days = Mar 8
+    # Hosp: Mar 2 + (2 - 1) * 7 days = Mar 9
+    # So n_days_post_init = Mar 9 - Feb 28 + 1 = 11 (the span includes Feb 29)
+    expected_n_days = (
+        data.last_data_date_overall - data.first_data_date_overall
+    ) // np.timedelta64(1, "D") + 1
+    assert data.n_days_post_init == expected_n_days
+
+    # Test forecast across leap year
+    forecast_data = data.to_forecast_data(n_forecast_points=30)
+    assert forecast_data.n_ed_visits_data_days == data.n_days_post_init + 30
+
+
+def test_date_handling_year_boundary():
+    """
+    Test date handling across year boundary (Dec 31 -> Jan 1).
+
+    Ensures that date arithmetic works correctly when spanning
+    two different years.
+    """
+    # Start late in December
+    year_end = np.datetime64("2023-12-28")  # Thursday
+    year_end_saturday = np.datetime64("2023-12-30")  # Saturday
+
+    data = PyrenewHEWData(
+        n_ed_visits_data_days=14,  # Spans Dec 28, 2023 - Jan 10, 2024
+        n_hospital_admissions_data_days=2,
+        first_ed_visits_date=year_end,
+        first_hospital_admissions_date=year_end_saturday,
+    )
+
+    # Verify dates span the year boundary
+    assert data.first_ed_visits_date.astype("datetime64[Y]").astype(int) + 1970 == 2023
+    assert data.last_ed_visits_date.astype("datetime64[Y]").astype(int) + 1970 == 2024
+
+    # Verify n_days_post_init is correct
+    assert data.n_days_post_init == 14
+
+    # Test that date_time_spine works across year boundary
+    spine = data.date_time_spine
+    assert len(spine) == 14
+    assert spine["date"][0].year == 2023
+    assert spine["date"][13].year == 2024
+
+    # Test forecast into new year
+    forecast_data = data.to_forecast_data(n_forecast_points=20)
+    assert forecast_data.n_ed_visits_data_days == 34
diff --git a/uv.lock b/uv.lock
index 323337d6..77fb26a0 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2075,8 +2075,8 @@ wheels = [
 
 [[package]]
 name = "pyrenew"
-version = "0.1.4"
-source = { git = "https://github.com/cdcgov/PyRenew/#468ad5823cdf6f715f91240ddc3a01ec756c9b61" }
+version = "0.1.5"
+source = { git = "https://github.com/cdcgov/PyRenew/#e5ed80475817c11f1c5b232e34c977c3822a4766" }
 dependencies = [
     { name = "jax" },
     { name = "numpy" },