diff --git a/src/cloudcasting/dataset.py b/src/cloudcasting/dataset.py
index ad6fa68..552ff87 100644
--- a/src/cloudcasting/dataset.py
+++ b/src/cloudcasting/dataset.py
@@ -102,6 +102,7 @@ def __init__(
         variables: list[str] | str | None = None,
         preshuffle: bool = False,
         nan_to_num: bool = False,
+        return_time_features: bool = False,
     ):
         """A torch Dataset for loading past and future satellite data
 
@@ -115,6 +116,10 @@ def __init__(
             variables: The variables to load from the satellite data (defaults to all)
             preshuffle: Whether to shuffle the data - useful for validation
             nan_to_num: Whether to convert NaNs to -1.
+            return_time_features: If True, will calculate features relating to time of
+                day, time of year etc, and return those as a third
+                return value of the dataset (X, y, time_features) as
+                opposed to (X, y).
         """
 
         # Load the sat zarr file or list of files and slice the data to the given period
@@ -144,6 +149,7 @@ def __init__(
         self.forecast_mins = forecast_mins
         self.sample_freq_mins = sample_freq_mins
         self.nan_to_num = nan_to_num
+        self.return_time_features = return_time_features
 
     @staticmethod
     def _find_t0_times(
@@ -154,14 +160,52 @@ def _find_t0_times(
     def __len__(self) -> int:
         return len(self.t0_times)
 
-    def _get_datetime(self, t0: datetime) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
-        ds_sel = self.ds.sel(
-            time=slice(
-                t0 - timedelta(minutes=self.history_mins),
-                t0 + timedelta(minutes=self.forecast_mins),
-            )
+    def _get_time_features(
+        self, dt_range: tuple[pd.Timestamp, pd.Timestamp]
+    ) -> NDArray[np.float32]:
+        """Compute cyclical time-of-day and time-of-year features.
+
+        Args:
+            dt_range: Inclusive (start, end) timestamps; one feature row is produced for
+                every step of sample_freq_mins minutes in this range.
+
+        Returns:
+            Array of shape (n_steps, 4) with columns
+            [cos(hour), sin(hour), cos(day), sin(day)].
+        """
+        dates = pd.date_range(*dt_range, freq=timedelta(minutes=self.sample_freq_mins))
+
+        # Fraction of the day elapsed, in [0, 1)
+        hours = np.asarray(dates.hour + dates.minute / 60 + dates.second / 3600) / 24
+        # Fraction of the year elapsed; day_of_year is 1-indexed and 366 allows for leap years
+        days = (np.asarray(dates.day_of_year) - 1 + hours) / 366
+
+        # use sin/cos features for the hours and days -- trying to capture how far along
+        # the cycle of time we are. should correlate with the darkening of visible channels,
+        # and capture seasonal variation.
+        time_features = np.stack(
+            (
+                np.cos(2 * np.pi * hours),
+                np.sin(2 * np.pi * hours),
+                np.cos(2 * np.pi * days),
+                np.sin(2 * np.pi * days),
+            ),
+            axis=-1,
         )
 
+        return time_features.astype(np.float32)
+
+    def _get_datetime(
+        self, t0: datetime
+    ) -> (
+        tuple[NDArray[np.float32], NDArray[np.float32]]
+        | tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]
+    ):
+        history_start = t0 - timedelta(minutes=self.history_mins)
+        forecast_end = t0 + timedelta(minutes=self.forecast_mins)
+
+        ds_sel = self.ds.sel(time=slice(history_start, forecast_end))
+
         # Load the data eagerly so that the same chunks aren't loaded multiple times after we split
         # further
         ds_sel = ds_sel.compute(scheduler="single-threaded")
@@ -180,9 +224,19 @@ def _get_datetime(self, t0: datetime) -> tuple[NDArray[np.float32], NDArray[np.f
             X = np.nan_to_num(X, nan=-1)
             y = np.nan_to_num(y, nan=-1)
 
+        if self.return_time_features:
+            # Time features cover the history window only (t0 - history_mins up to t0)
+            time_features = self._get_time_features((history_start, t0))
+            return X.astype(np.float32), y.astype(np.float32), time_features
+
         return X.astype(np.float32), y.astype(np.float32)
 
-    def __getitem__(self, key: DataIndex) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
+    def __getitem__(
+        self, key: DataIndex
+    ) -> (
+        tuple[NDArray[np.float32], NDArray[np.float32]]
+        | tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]
+    ):
         if isinstance(key, int):
             t0 = self.t0_times[key]
 
@@ -299,6 +353,7 @@ def __init__(
         nan_to_num: bool = False,
         pin_memory: bool = False,
         persistent_workers: bool = False,
+        return_time_features: bool = False,
     ):
         """A lightning DataModule for loading past and future satellite data
 
@@ -341,6 +396,7 @@ def __init__(
         self.history_mins = history_mins
         self.forecast_mins = forecast_mins
         self.sample_freq_mins = sample_freq_mins
+        self.return_time_features = return_time_features
 
         self._common_dataloader_kwargs = DataloaderArgs(
             batch_size=batch_size,
@@ -371,6 +427,7 @@ def _make_dataset(
             preshuffle=preshuffle,
             nan_to_num=self.nan_to_num,
             variables=self.variables,
+            return_time_features=self.return_time_features,
         )
 
     def train_dataloader(self) -> DataLoader[tuple[NDArray[np.float32], NDArray[np.float32]]]:
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index 6190f5d..82ddda9 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -196,3 +196,56 @@ def test_validation_dataset_raises_error(sat_zarr_path):
             forecast_mins=FORECAST_HORIZON_MINUTES,
             sample_freq_mins=DATA_INTERVAL_SPACING_MINUTES,
         )
+
+
+def test_time_feature_validity(sat_zarr_path):
+    dataset = SatelliteDataset(
+        zarr_path=sat_zarr_path,
+        start_time=None,
+        end_time=None,
+        history_mins=0,
+        forecast_mins=15,
+        sample_freq_mins=5,
+        nan_to_num=True,
+        return_time_features=True,
+    )
+
+    # Test access
+    data = dataset["2023-01-01 00:00:00"]
+    assert len(data) == 3
+
+    # Test midnight (hour = 0)
+    midnight = (
+        pd.Timestamp("2023-01-01 00:00:00"),
+        pd.Timestamp("2023-01-01 00:15:00"),
+    )
+    features = dataset._get_time_features(midnight)
+    # 4 steps (00:00 to 00:15 inclusive) x [cos_hour, sin_hour, cos_day, sin_day]
+    assert features.shape == (4, 4)
+    assert features.dtype == np.float32
+
+    # At midnight: cos(hour)=1, sin(hour)=0
+    np.testing.assert_allclose(features[0, 0], 1, atol=1e-6)  # cos(hour)
+    np.testing.assert_allclose(features[0, 1], 0, atol=1e-6)  # sin(hour)
+
+    # Test noon (hour = 12)
+    noon = (
+        pd.Timestamp("2023-01-01 12:00:00"),
+        pd.Timestamp("2023-01-01 12:15:00"),
+    )
+    features = dataset._get_time_features(noon)
+
+    # At noon: cos(hour)=-1, sin(hour)≈0
+    np.testing.assert_allclose(features[0, 0], -1, atol=1e-6)  # cos(hour)
+    np.testing.assert_allclose(features[0, 1], 0, atol=1e-6)  # sin(hour)
+
+    # Test start of year
+    new_year = (
+        pd.Timestamp("2023-01-01 00:00:00"),
+        pd.Timestamp("2023-01-01 00:15:00"),
+    )
+    features = dataset._get_time_features(new_year)
+
+    # At start of year: cos(day)=1, sin(day)=0
+    np.testing.assert_allclose(features[0, 2], 1, atol=1e-6)  # cos(day)
+    np.testing.assert_allclose(features[0, 3], 0, atol=1e-6)  # sin(day)