diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index b3f30f0c..6fe2b47c 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -488,20 +488,48 @@ def get_n_data_days( return n_datapoints def to_forecast_data(self, n_forecast_points: int) -> Self: + """ + Create a new PyrenewHEWData instance for forecasting. + + This method extends the current data object to include forecast points, + converting from observed data to a structure suitable for forecasting. + + Parameters + ---------- + n_forecast_points : int + Number of additional days to forecast beyond the current data. + + Returns + ------- + PyrenewHEWData + A new instance configured for forecasting with extended time range. + + Notes + ----- + The method handles different temporal resolutions for data streams: + + - ED visits and wastewater data are daily, so they extend by + n_forecast_points days. + - Hospital admissions are weekly (MMWR epiweeks), so the number of + weeks is calculated as total days divided by 7 (integer division). + """ + # Calculate total forecast period n_days = self.n_days_post_init + n_forecast_points n_weeks = n_days // 7 + + # Find the first Saturday on or after first_data_date_overall first_dow = self.first_data_date_overall.astype(dt.datetime).weekday() - to_first_sat = (5 - first_dow) % 7 + to_first_sat = (5 - first_dow) % 7 # Saturday is weekday 5 first_mmwr_ending_date = self.first_data_date_overall + np.timedelta64( to_first_sat, "D" ) + return PyrenewHEWData( n_ed_visits_data_days=n_days, n_hospital_admissions_data_days=n_weeks, n_wastewater_data_days=n_days, first_ed_visits_date=self.first_data_date_overall, first_hospital_admissions_date=first_mmwr_ending_date, - # admissions are MMWR epiweekly first_wastewater_date=self.first_data_date_overall, right_truncation_offset=None, # by default, want forecasts of complete reports n_ww_lab_sites=self.n_ww_lab_sites, diff --git a/tests/test_pyrenew_hew_data.py b/tests/test_pyrenew_hew_data.py index be103594..54da12e7 100644 --- a/tests/test_pyrenew_hew_data.py +++ b/tests/test_pyrenew_hew_data.py @@ -282,3 +282,505 @@ def test_build_pyrenew_hew_data_from_json(mock_data_dir): assert data.data_observed_disease_ed_visits is not None assert data.data_observed_disease_hospital_admissions is not None assert data.data_observed_disease_wastewater_conc is not None + + +def test_hospital_admissions_must_be_saturday(): + """ + Test that hospital admissions dates must be Saturdays (MMWR epiweek ends). + + The MMWR (Morbidity and Mortality Weekly Report) standard requires that + weekly hospital admissions data use Saturdays as the week-ending date. + This test verifies that non-Saturday dates are rejected. + """ + # 2023-01-02 is a Monday (not a Saturday) + monday_date = np.datetime64("2023-01-02") + + with pytest.raises(ValueError, match="Saturdays.*MMWR epiweek"): + PyrenewHEWData( + n_ed_visits_data_days=10, + n_hospital_admissions_data_days=1, + first_ed_visits_date=np.datetime64("2023-01-01"), + first_hospital_admissions_date=monday_date, + ) + + # Verify that a Saturday DOES work (no exception) + saturday_date = np.datetime64("2023-01-07") # Saturday + data = PyrenewHEWData( + n_ed_visits_data_days=10, + n_hospital_admissions_data_days=1, + first_ed_visits_date=np.datetime64("2023-01-01"), + first_hospital_admissions_date=saturday_date, + ) + assert data.first_hospital_admissions_date == saturday_date + + +def test_model_t_conversions(): + """ + Test model_t_obs_* properties calculate correct indices. + + Verifies that observation dates are correctly converted to model time + indices relative to first_data_date_overall. + """ + ed_dates = ["2023-01-05", "2023-01-08"] + hosp_dates = ["2023-01-07", "2023-01-14"] + + nssp_data = pl.DataFrame( + { + "date": ed_dates, + "geo_value": ["CA", "CA"], + "observed_ed_visits": [10, 20], + "other_ed_visits": [100, 200], + "data_type": ["train", "train"], + }, + schema={ + "date": pl.Date, + "geo_value": pl.String, + "observed_ed_visits": pl.Int64, + "other_ed_visits": pl.Int64, + "data_type": pl.String, + }, + ) + + nhsn_data = pl.DataFrame( + { + "weekendingdate": hosp_dates, + "jurisdiction": ["CA", "CA"], + "hospital_admissions": [5, 10], + "data_type": ["train", "train"], + }, + schema={ + "weekendingdate": pl.Date, + "jurisdiction": pl.String, + "hospital_admissions": pl.Int64, + "data_type": pl.String, + }, + ) + + data = PyrenewHEWData( + nssp_training_data=nssp_data, + nhsn_training_data=nhsn_data, + first_ed_visits_date=np.datetime64(ed_dates[0]), + first_hospital_admissions_date=np.datetime64(hosp_dates[0]), + ) + + # first_data_date_overall should be min of all dates + assert data.first_data_date_overall == np.datetime64(ed_dates[0]) + + # model_t should be 0 for first date, 3 for second ED visit + assert data.model_t_obs_ed_visits[0] == 0 + assert data.model_t_obs_ed_visits[1] == 3 + + # Hospital admissions at days 2 and 9 + assert data.model_t_obs_hospital_admissions[0] == 2 + assert data.model_t_obs_hospital_admissions[1] == 9 + + +def test_mixed_data_sources(): + """ + Test with only some data sources present. + + Verifies correct behavior when only ED data is provided + but hospital and wastewater data are absent. + """ + ed_data = pl.DataFrame( + { + "date": ["2023-01-01"], + "geo_value": ["CA"], + "observed_ed_visits": [10], + "other_ed_visits": [100], + "data_type": ["train"], + }, + schema={ + "date": pl.Date, + "geo_value": pl.String, + "observed_ed_visits": pl.Int64, + "other_ed_visits": pl.Int64, + "data_type": pl.String, + }, + ) + + data = PyrenewHEWData( + nssp_training_data=ed_data, first_ed_visits_date=np.datetime64("2023-01-01") + ) + + # ED data should work + assert data.data_observed_disease_ed_visits is not None + assert len(data.data_observed_disease_ed_visits) == 1 + + # Others should be None + assert data.first_hospital_admissions_date is None + assert data.first_wastewater_date is None + assert data.data_observed_disease_hospital_admissions is None + assert data.data_observed_disease_wastewater_conc is None + + +def test_site_subpop_spine_with_auxiliary(): + """ + Test subpopulation creation when WW sites don't cover full population. + + When wastewater sampling sites don't cover the entire population, + an auxiliary subpopulation should be created for the remainder. + """ + ww_data = pl.DataFrame( + { + "date": ["2023-01-01", "2023-01-01"], + "site": ["site1", "site2"], + "site_index": [0, 1], + "site_pop": [200_000, 300_000], + "lab_site_index": [0, 1], + "log_genome_copies_per_ml": [1.0, 2.0], + "log_lod": [0.5, 0.5], + "below_lod": [0, 0], + }, + schema={ + "date": pl.Date, + "site": pl.String, + "site_index": pl.Int64, + "site_pop": pl.Int64, + "lab_site_index": pl.Int64, + "log_genome_copies_per_ml": pl.Float64, + "log_lod": pl.Float64, + "below_lod": pl.Int64, + }, + ) + + data = PyrenewHEWData( + nwss_training_data=ww_data, + population_size=1_000_000, # 500k not covered by WW + ) + + spine = data.site_subpop_spine + # Should have 3 subpops: 2 sites + 1 auxiliary + assert len(spine) == 3 + assert spine.filter(pl.col("site").is_null()).height == 1 + # Auxiliary subpop should have remaining population + assert spine.filter(pl.col("site").is_null())["subpop_pop"][0] == 500_000 + + +def test_site_subpop_spine_no_auxiliary(): + """ + Test when WW sites cover entire population. + + When sampling sites cover the full population, no auxiliary + subpopulation should be created. + """ + ww_data = pl.DataFrame( + { + "date": ["2023-01-01", "2023-01-01"], + "site": ["site1", "site2"], + "site_index": [0, 1], + "site_pop": [400_000, 600_000], + "lab_site_index": [0, 1], + "log_genome_copies_per_ml": [1.0, 2.0], + "log_lod": [0.5, 0.5], + "below_lod": [0, 0], + }, + schema={ + "date": pl.Date, + "site": pl.String, + "site_index": pl.Int64, + "site_pop": pl.Int64, + "lab_site_index": pl.Int64, + "log_genome_copies_per_ml": pl.Float64, + "log_lod": pl.Float64, + "below_lod": pl.Int64, + }, + ) + + data = PyrenewHEWData(nwss_training_data=ww_data, population_size=1_000_000) + + spine = data.site_subpop_spine + # Should have only 2 subpops + assert len(spine) == 2 + assert spine.filter(pl.col("site").is_null()).height == 0 + + +def test_censored_uncensored_split(): + """ + Test correct identification of censored vs uncensored observations. + + Wastewater observations below the limit of detection are censored. + This test verifies correct indexing of censored/uncensored data. + """ + ww_data = pl.DataFrame( + { + "date": ["2023-01-01"] * 4, + "site": ["site1"] * 4, + "site_index": [0] * 4, + "site_pop": [500_000] * 4, + "lab_site_index": [0] * 4, + "log_genome_copies_per_ml": [0.5, 1.5, 0.3, 2.0], + "log_lod": [1.0, 1.0, 1.0, 1.0], + "below_lod": [1, 0, 1, 0], # 2 censored, 2 uncensored + }, + schema={ + "date": pl.Date, + "site": pl.String, + "site_index": pl.Int64, + "site_pop": pl.Int64, + "lab_site_index": pl.Int64, + "log_genome_copies_per_ml": pl.Float64, + "log_lod": pl.Float64, + "below_lod": pl.Int64, + }, + ) + + data = PyrenewHEWData(nwss_training_data=ww_data, population_size=1_000_000) + + assert len(data.ww_censored) == 2 + assert len(data.ww_uncensored) == 2 + # Censored indices should be 0 and 2 + assert np.array_equal(data.ww_censored, [0, 2]) + assert np.array_equal(data.ww_uncensored, [1, 3]) + + +def test_n_days_post_init_single_source(): + """ + Test n_days_post_init with only one data source. + + With a single data source, n_days_post_init should equal + the number of days in that source. + """ + data = PyrenewHEWData( + n_ed_visits_data_days=30, first_ed_visits_date=np.datetime64("2023-01-01") + ) + + # Should be 30 days (Jan 1 to Jan 30 inclusive) + assert data.n_days_post_init == 30 + + +def test_n_days_post_init_multiple_sources(): + """ + Test with multiple overlapping/non-overlapping data sources. + + With multiple data sources, n_days_post_init should span from + the earliest first date to the latest last date. + """ + data = PyrenewHEWData( + n_ed_visits_data_days=20, + n_hospital_admissions_data_days=3, # 3 weeks = 21 days + first_ed_visits_date=np.datetime64("2023-01-01"), + first_hospital_admissions_date=np.datetime64("2023-01-07"), + ) + + # ED: Jan 1-20 (20 days) + # Hosp: Jan 7, 14, 21 (3 weeks ending Jan 21, so 21 days from start) + # Overall: Jan 1 - Jan 21 = 21 days + assert data.n_days_post_init == 21 + + +def test_lab_site_to_subpop_map(): + """ + Test correct mapping from lab sites to subpopulations. + + Multiple labs can sample from the same site. The mapping + should correctly associate each lab with its subpopulation. + """ + ww_data = pl.DataFrame( + { + "date": ["2023-01-01"] * 4, + "site": ["site1", "site1", "site2", "site2"], + "site_index": [0, 0, 1, 1], + "site_pop": [400_000] * 2 + [200_000] * 2, + "lab_site_index": [0, 1, 2, 3], # 4 labs, 2 sites + "log_genome_copies_per_ml": [1.0, 1.5, 2.0, 2.5], + "log_lod": [0.5] * 4, + "below_lod": [0] * 4, + }, + schema={ + "date": pl.Date, + "site": pl.String, + "site_index": pl.Int64, + "site_pop": pl.Int64, + "lab_site_index": pl.Int64, + "log_genome_copies_per_ml": pl.Float64, + "log_lod": pl.Float64, + "below_lod": pl.Int64, + }, + ) + + data = PyrenewHEWData( + nwss_training_data=ww_data, + population_size=1_000_000, # Creates auxiliary subpop 0 + ) + + # First 2 labs map to subpop 1 (site_index 0 + 1 for auxiliary) + # Next 2 labs map to subpop 2 (site_index 1 + 1 for auxiliary) + mapping = data.lab_site_to_subpop_map + assert len(mapping) == 4 + assert mapping[0] == 1 # lab 0 -> subpop 1 + assert mapping[1] == 1 # lab 1 -> subpop 1 + assert mapping[2] == 2 # lab 2 -> subpop 2 + assert mapping[3] == 2 # lab 3 -> subpop 2 + + +def test_date_time_spine(): + """ + Test creation of date-time spine for temporal indexing. + + The date_time_spine should map each date to its model time index. + """ + data = PyrenewHEWData( + n_ed_visits_data_days=10, first_ed_visits_date=np.datetime64("2023-01-01") + ) + + spine = data.date_time_spine + assert len(spine) == 10 + assert spine["t"][0] == 0 + assert spine["t"][9] == 9 + assert spine["date"][0] == dt.date(2023, 1, 1) + assert spine["date"][9] == dt.date(2023, 1, 10) + + +@pytest.mark.parametrize( + ["n_days", "expected_weeks", "has_partial_week"], + [ + (7, 1, False), # Exactly 1 week + (10, 1, True), # 1 week + 3 days + (13, 1, True), # 1 week + 6 days + (14, 2, False), # Exactly 2 weeks + (20, 2, True), # 2 weeks + 6 days + (21, 3, False), # Exactly 3 weeks + (1, 0, True), # Less than 1 week + (6, 0, True), # Less than 1 week + ], +) +def test_to_forecast_data_partial_weeks(n_days, expected_weeks, has_partial_week): + """ + Test to_forecast_data() with data spanning partial weeks. + + Hospital admissions are weekly, so we need to ensure correct + handling when the total number of days doesn't evenly divide by 7. + """ + first_date = np.datetime64("2023-01-01") # Sunday + first_saturday = np.datetime64("2023-01-07") # First Saturday + + data = PyrenewHEWData( + n_ed_visits_data_days=n_days, + n_hospital_admissions_data_days=0, + first_ed_visits_date=first_date, + first_hospital_admissions_date=first_saturday, + ) + + forecast_data = data.to_forecast_data(n_forecast_points=0) + + # Check that the number of weeks is calculated correctly + assert forecast_data.n_hospital_admissions_data_days == expected_weeks, ( + f"For {n_days} days, expected {expected_weeks} weeks, " + f"got {forecast_data.n_hospital_admissions_data_days}" + ) + + +def test_to_forecast_data_less_than_one_week(): + """ + Test to_forecast_data with less than a week of observation data. + + When observation data spans fewer than 7 days, hospital admissions + weeks should be 0 (since n_days // 7 = 0 for n_days < 7). + """ + # Start on Monday with only 5 days of data + data = PyrenewHEWData( + n_ed_visits_data_days=5, + first_ed_visits_date=np.datetime64("2023-01-02"), # Monday + ) + + # Create forecast data with 10 additional days (total 15 days) + forecast_data = data.to_forecast_data(n_forecast_points=10) + + # Verify forecast data structure + assert forecast_data.n_ed_visits_data_days == 15 + assert forecast_data.n_hospital_admissions_data_days == 15 // 7 # 2 weeks + assert forecast_data.first_ed_visits_date == np.datetime64("2023-01-02") + + # First Saturday on or after Jan 2 (Monday) is Jan 7 + assert forecast_data.first_hospital_admissions_date == np.datetime64("2023-01-07") + + # Test with even fewer days - only 3 days of observation + data_short = PyrenewHEWData( + n_ed_visits_data_days=3, + first_ed_visits_date=np.datetime64("2023-01-02"), # Monday + ) + + # No forecast points, just convert existing data + forecast_short = data_short.to_forecast_data(n_forecast_points=0) + + # With only 3 days total, should have 0 weeks of hospital admissions + assert forecast_short.n_ed_visits_data_days == 3 + assert forecast_short.n_hospital_admissions_data_days == 0 + + +def test_date_handling_leap_year(): + """ + Test date handling across leap year boundary (Feb 29). + + Ensures that date arithmetic works correctly with leap years. + """ + # Leap year: 2024 + leap_year_start = np.datetime64("2024-02-28") + leap_year_saturday = np.datetime64("2024-03-02") # Saturday after Feb 29 + + data = PyrenewHEWData( + n_ed_visits_data_days=10, # Spans Feb 28 - Mar 8, includes Feb 29 + n_hospital_admissions_data_days=2, + first_ed_visits_date=leap_year_start, + first_hospital_admissions_date=leap_year_saturday, + ) + + # Verify dates span the leap day + assert data.first_ed_visits_date == leap_year_start + assert data.last_ed_visits_date == leap_year_start + np.timedelta64(9, "D") + + # Feb 29 should exist in 2024 + feb_29_2024 = np.datetime64("2024-02-29") + assert data.first_ed_visits_date < feb_29_2024 < data.last_ed_visits_date + + # Verify n_days_post_init calculates correctly across leap day + # Feb 28 to Mar 2 is 4 days, so n_days_post_init should be the maximum + # span from first_data_date_overall (Feb 28) to last_data_date_overall + # ED: Feb 28 + 9 days = Mar 8 + # Hosp: Mar 2 + (2 weeks * 7 days) - 1 = Mar 2 + 13 = Mar 15 + # So n_days_post_init = Mar 15 - Feb 28 + 1 = 17 + expected_n_days = ( + data.last_data_date_overall - data.first_data_date_overall + ) // np.timedelta64(1, "D") + 1 + assert data.n_days_post_init == expected_n_days + + # Test forecast across leap year + forecast_data = data.to_forecast_data(n_forecast_points=30) + assert forecast_data.n_ed_visits_data_days == data.n_days_post_init + 30 + + +def test_date_handling_year_boundary(): + """ + Test date handling across year boundary (Dec 31 -> Jan 1). + + Ensures that date arithmetic works correctly when spanning + two different years. + """ + # Start late in December + year_end = np.datetime64("2023-12-28") # Thursday + year_end_saturday = np.datetime64("2023-12-30") # Saturday + + data = PyrenewHEWData( + n_ed_visits_data_days=14, # Spans Dec 28, 2023 - Jan 10, 2024 + n_hospital_admissions_data_days=2, + first_ed_visits_date=year_end, + first_hospital_admissions_date=year_end_saturday, + ) + + # Verify dates span the year boundary + assert data.first_ed_visits_date.astype("datetime64[Y]").astype(int) + 1970 == 2023 + assert data.last_ed_visits_date.astype("datetime64[Y]").astype(int) + 1970 == 2024 + + # Verify n_days_post_init is correct + assert data.n_days_post_init == 14 + + # Test that date_time_spine works across year boundary + spine = data.date_time_spine + assert len(spine) == 14 + assert spine["date"][0].year == 2023 + assert spine["date"][13].year == 2024 + + # Test forecast into new year + forecast_data = data.to_forecast_data(n_forecast_points=20) + assert forecast_data.n_ed_visits_data_days == 34