def convert_date(date: Union[dt.datetime, dt.date, np.datetime64]) -> dt.date:
    """Normalize a date-like object to a python ``datetime.date``.

    The function accepts any of the common representations used in this
    codebase and returns a ``datetime.date`` (i.e. without time component).

    Supported input types:
    - ``numpy.datetime64``: converted to date (day precision)
    - ``datetime.datetime``: converted via ``.date()``
    - ``datetime.date``: returned unchanged

    Parameters
    ----------
    date
        A date-like object to normalize.

    Returns
    -------
    datetime.date
        The corresponding date (with no time information).

    Raises
    ------
    TypeError
        If ``date`` is not one of the supported types.
    ValueError
        If ``date`` is ``numpy.datetime64("NaT")``, which has no calendar
        date equivalent.

    Notes
    -----
    - ``numpy.datetime64`` objects are first normalized to day precision
      (``datetime64[D]``) and then converted by computing the integer
      number of days since the UNIX epoch and constructing a
      ``datetime.date``. This is robust across NumPy versions where direct
      conversion to Python datetimes can behave differently.

    - Fails fast for unsupported input types by raising a ``TypeError``
    """
    if isinstance(date, np.datetime64):
        # NaT casts to the minimum int64, which would otherwise surface as an
        # opaque OverflowError from timedelta; reject it explicitly instead.
        if np.isnat(date):
            raise ValueError(
                "convert_date cannot convert numpy.datetime64('NaT') to a "
                "datetime.date"
            )
        # Explicit int64 so the cast width does not depend on the platform's
        # default integer type.
        days_since_epoch = int(date.astype("datetime64[D]").astype("int64"))
        return dt.date(1970, 1, 1) + dt.timedelta(days=days_since_epoch)
    # datetime.datetime is a subclass of datetime.date, so it must be
    # checked first to strip the time component.
    if isinstance(date, dt.datetime):
        return date.date()
    if isinstance(date, dt.date):
        return date
    raise TypeError(
        "convert_date expects a numpy.datetime64, datetime.datetime, or "
        f"datetime.date; got {type(date)}"
    )
def get_observation_indices(
    observed_dates: ArrayLike,
    data_start_date: Union[dt.datetime, np.datetime64],
    freq: str = "mmwr_weekly",
) -> jnp.ndarray:
    """
    Get indices for observed data in aggregated time series.

    Each index is the number of whole weeks between an observation date and
    an anchor date: the first Saturday on or after ``data_start_date`` for
    ``"mmwr_weekly"``, or the first Monday on or after ``data_start_date``
    for ``"weekly"``. Observations falling before the anchor get negative
    indices (floor division).

    Parameters
    ----------
    observed_dates
        Dates of observations
    data_start_date
        Start date of the data series
    freq
        Frequency of aggregated data ("mmwr_weekly" or "weekly")

    Returns
    -------
    jnp.ndarray
        Integer indices for observed data points in aggregated series.
        An empty ``observed_dates`` yields an empty integer array.

    Raises
    ------
    NotImplementedError
        For unsupported frequencies
    """
    data_start_date = convert_date(data_start_date)

    # datetime.weekday(): 0=Monday .. 6=Sunday
    if freq == "mmwr_weekly":
        anchor_weekday = 5  # Saturday (MMWR week end)
    elif freq == "weekly":
        anchor_weekday = 0  # Monday (ISO week start)
    else:
        raise NotImplementedError(f"Frequency '{freq}' not implemented")

    # Days from the start date forward to the first anchor weekday (0 if the
    # start date already falls on it).
    days_to_anchor = (anchor_weekday - data_start_date.weekday()) % 7
    first_anchor = data_start_date + dt.timedelta(days=days_to_anchor)

    indices = [
        (convert_date(obs_date) - first_anchor).days // 7
        for obs_date in observed_dates
    ]
    # Force an integer dtype: jnp.array([]) would otherwise default to float,
    # which is not a valid index array.
    return jnp.array(indices, dtype=int)
def aggregate_with_dates(
    daily_data: ArrayLike,
    start_date: Union[dt.datetime, np.datetime64],
    target_freq: str = "mmwr_weekly",
) -> Tuple[jnp.ndarray, dt.date]:
    """
    Aggregate daily data with automatic date handling.

    Parameters
    ----------
    daily_data
        Daily time series
    start_date
        Date of first data point
    target_freq
        Target frequency ("mmwr_weekly" or "weekly")

    Returns
    -------
    Tuple[jnp.ndarray, dt.date]
        Tuple containing (aggregated_data, first_aggregated_date).
        The date is a ``datetime.date``: ``start_date`` is normalized with
        ``convert_date`` before the offset is added. For ``"mmwr_weekly"``
        it is the first Saturday on or after ``start_date``; for
        ``"weekly"`` it is the first Monday on or after ``start_date``.

    Raises
    ------
    ValueError
        For unsupported frequencies

    Notes
    -----
    Python's datetime.weekday uses 0=Monday..6=Sunday
    which matches PyRenew's day-of-week indexing.
    """
    start_date = convert_date(start_date)
    first_dow = start_date.weekday()

    if target_freq == "mmwr_weekly":
        weekly_data = daily_to_mmwr_epiweekly(daily_data, first_dow)
        # Saturday (weekday 5) is the MMWR week end.
        # NOTE(review): this is the first Saturday on or after start_date;
        # whether it is the week-ending date of the first *aggregated* value
        # depends on how daily_to_mmwr_epiweekly treats a leading partial
        # week -- confirm against that function.
        days_to_first_date = (5 - first_dow) % 7
    elif target_freq == "weekly":
        weekly_data = daily_to_weekly(daily_data, first_dow, week_start_dow=0)
        # Monday (weekday 0) is the ISO week start.
        days_to_first_date = (7 - first_dow) % 7
    else:
        raise ValueError(
            f"Unsupported target frequency: {target_freq}"
        )  # pragma: no cover

    first_weekly_date = start_date + dt.timedelta(days=days_to_first_date)
    return weekly_data, first_weekly_date
def get_n_data_days(
    n_points: int = None, date_array: ArrayLike = None, timestep_days: int = 1
) -> int:
    """
    Resolve the length of a data series from a count or from its dates.

    Parameters
    ----------
    n_points
        Explicit number of data points
    date_array
        Array of observation dates
    timestep_days
        Days between consecutive points (only used with ``date_array``)

    Returns
    -------
    int
        ``n_points`` when it is given, the length computed by
        ``get_date_range_length`` when ``date_array`` is given, and 0 when
        neither is given.

    Raises
    ------
    ValueError
        If both n_points and date_array are provided.
    """
    count_given = n_points is not None
    dates_given = date_array is not None

    # The two sources are mutually exclusive.
    if count_given and dates_given:
        raise ValueError("Must provide at most one of n_points and date_array")
    if dates_given:
        return get_date_range_length(date_array, timestep_days)
    if count_given:
        return n_points
    # Nothing supplied: treat as "no data".
    return 0
def get_first_week_on_or_after_t0(
    model_t_first_weekly_value: int, week_interval_days: int = 7
) -> int:
    """
    Locate the first weekly value whose model time is at or after t=0.

    Parameters
    ----------
    model_t_first_weekly_value
        Model time of the first weekly value
        (often negative during initialization period). Represents week-ending date.
    week_interval_days
        Days between consecutive weekly values. Default 7.

    Returns
    -------
    int
        Index of first week ending on or after model t=0.

    Notes
    -----
    Weekly value ``k`` sits at model time
    ``model_t_first_weekly_value + k * week_interval_days``. The result is
    the smallest ``k`` for which that time is non-negative, i.e.
    ``ceil(-model_t_first_weekly_value / week_interval_days)``, evaluated
    with integer arithmetic as ``(-x - 1) // d + 1``.
    """
    if model_t_first_weekly_value >= 0:
        # The very first weekly value is already at or after t=0.
        weeks_to_skip = 0
    else:
        days_before_t0 = -model_t_first_weekly_value
        # Integer ceiling of days_before_t0 / week_interval_days.
        weeks_to_skip = (days_before_t0 - 1) // week_interval_days + 1
    return weeks_to_skip
def test_get_observation_indices_mmwr_and_weekly():
    """Test get_observation_indices with MMWR and weekly frequencies."""
    start = dt.datetime(2025, 1, 1)  # Wednesday
    observed = [dt.datetime(2025, 1, 4), np.datetime64("2025-01-11")]  # two Saturdays
    mmwr_idx = ptime.get_observation_indices(observed, start, freq="mmwr_weekly")
    weekly_idx = ptime.get_observation_indices(observed, start, freq="weekly")
    assert isinstance(mmwr_idx, jnp.ndarray)
    assert isinstance(weekly_idx, jnp.ndarray)
    # First Saturday on or after the start is 2025-01-04, so the two
    # observed Saturdays are MMWR weeks 0 and 1.
    assert_array_equal(mmwr_idx, [0, 1])
    # First Monday on or after the start is 2025-01-06; 2025-01-04 precedes
    # it (floor division gives -1) and 2025-01-11 falls in week 0.
    assert_array_equal(weekly_idx, [-1, 0])
def test_get_end_date_and_errors():
    """Test get_end_date with various inputs and error conditions."""
    # None with 0 points returns None
    assert ptime.get_end_date(None, 0) is None
    # None with >0 raises
    with pytest.raises(ValueError):
        ptime.get_end_date(None, 1)
    # negative n_points raises
    with pytest.raises(ValueError):
        ptime.get_end_date(dt.datetime(2025, 1, 1), -1)
    # normal usages
    res_dt = ptime.get_end_date(dt.datetime(2025, 1, 1), 10)
    res_np = ptime.get_end_date(np.datetime64("2025-01-01"), 10)
    assert isinstance(res_dt, np.datetime64)
    assert isinstance(res_np, np.datetime64)


# Defined at module level so that pytest collects it: a function nested
# inside another test is never discovered or executed.
def test_aggregate_with_dates_unsupported_freq_raises():
    """Test aggregate_with_dates raises for unsupported frequency."""
    with pytest.raises(ValueError):
        ptime.aggregate_with_dates(
            jnp.arange(1, 10), dt.datetime(2025, 1, 1), target_freq="monthly"
        )
def test_date_to_model_t_datetime_types():
    """Every datetime / np.datetime64 pairing gives the same offset."""
    start_variants = (dt.datetime(2025, 1, 1), np.datetime64("2025-01-01"))
    target_variants = (dt.datetime(2025, 1, 15), np.datetime64("2025-01-15"))

    # Cartesian product covers all four (target type, start type) pairs.
    for target, start in itertools.product(target_variants, start_variants):
        assert ptime.date_to_model_t(target, start) == 14
def test_model_t_to_date_roundtrip():
    """model_t_to_date followed by date_to_model_t returns the original t."""
    origin = dt.datetime(2025, 1, 1)
    offsets = [-10, 0, 7, 30, 365]

    recovered = [
        ptime.date_to_model_t(ptime.model_t_to_date(offset, origin), origin)
        for offset in offsets
    ]

    assert recovered == offsets
def test_get_end_date_basic():
    """Test get_end_date with various n_points."""
    first_day = dt.datetime(2025, 1, 1)
    # n_points -> expected last date (daily timestep, inclusive of first day)
    expected_end = {
        1: "2025-01-01",
        7: "2025-01-07",
        31: "2025-01-31",
    }
    for n_points, iso_day in expected_end.items():
        assert ptime.get_end_date(first_day, n_points) == np.datetime64(iso_day)
def test_create_date_time_spine_daily():
    """Daily spine over five days has five rows and exactly two columns."""
    spine = ptime.create_date_time_spine(
        dt.datetime(2025, 1, 1), dt.datetime(2025, 1, 5), freq="1d"
    )

    assert isinstance(spine, pl.DataFrame)
    n_rows = spine.shape[0]
    assert n_rows == 5

    column_names = spine.columns
    for required in ("date", "t"):
        assert required in column_names
    assert len(column_names) == 2
def test_create_date_time_spine_date_values():
    """The 'date' column holds consecutive calendar dates from the start."""
    spine = ptime.create_date_time_spine(
        dt.datetime(2025, 1, 1), dt.datetime(2025, 1, 3)
    )

    observed = spine["date"].to_list()
    expected = [dt.date(2025, 1, day) for day in (1, 2, 3)]

    # Compare position by position, as plain datetime.date values.
    for position, expected_date in enumerate(expected):
        assert observed[position] == expected_date