Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 45 additions & 7 deletions src/cloudcasting/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(
variables: list[str] | str | None = None,
preshuffle: bool = False,
nan_to_num: bool = False,
return_time_features: bool = False,
):
"""A torch Dataset for loading past and future satellite data

Expand All @@ -115,6 +116,10 @@ def __init__(
variables: The variables to load from the satellite data (defaults to all)
preshuffle: Whether to shuffle the data - useful for validation
nan_to_num: Whether to convert NaNs to -1.
return_time_features: If True, calculate features relating to the time of day and
time of year, and return them as a third element of each
sample, i.e. (X, y, time_features) instead of (X, y).
"""

# Load the sat zarr file or list of files and slice the data to the given period
Expand Down Expand Up @@ -144,6 +149,7 @@ def __init__(
self.forecast_mins = forecast_mins
self.sample_freq_mins = sample_freq_mins
self.nan_to_num = nan_to_num
self.return_time_features = return_time_features

@staticmethod
def _find_t0_times(
Expand All @@ -154,14 +160,39 @@ def _find_t0_times(
def __len__(self) -> int:
return len(self.t0_times)

def _get_datetime(self, t0: datetime) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
ds_sel = self.ds.sel(
time=slice(
t0 - timedelta(minutes=self.history_mins),
t0 + timedelta(minutes=self.forecast_mins),
)
def _get_time_features(self, dt_range: tuple[pd.Timestamp, pd.Timestamp]) -> NDArray[np.float32]:
# get dates in the requested range [from history_mins to forecast_mins centered on t0]
dates = pd.date_range(*dt_range, freq=timedelta(minutes=self.sample_freq_mins))

# calculate numerical values for the times and days, normalized to [0, 1]
hours = np.array(
[dt.hour + dt.minute/60 + dt.second/3600 for dt in dates]
) / 24
days = np.array(
[dt.day_of_year - 1 + h for dt, h in zip(dates, hours)] # -1 since it's 1-indexed
) / 366 # leap years

# use sin/cos features for the hours and days -- trying to capture how far along
# the cycle of time we are. should correlate with the darkening of visible channels,
# and capture seasonal variation.
time_features = np.stack(
(
np.cos(2 * np.pi * hours),
np.sin(2 * np.pi * hours),
np.cos(2 * np.pi * days),
np.sin(2 * np.pi * days),
),
axis = -1
)

return time_features

def _get_datetime(self, t0: datetime) -> tuple[NDArray[np.float32], NDArray[np.float32]] | tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:

t_range = (t0 - timedelta(minutes=self.history_mins), t0 + timedelta(minutes=self.forecast_mins))

ds_sel = self.ds.sel(time=slice(*t_range))

# Load the data eagerly so that the same chunks aren't loaded multiple times after we split
# further
ds_sel = ds_sel.compute(scheduler="single-threaded")
Expand All @@ -180,9 +211,13 @@ def _get_datetime(self, t0: datetime) -> tuple[NDArray[np.float32], NDArray[np.f
X = np.nan_to_num(X, nan=-1)
y = np.nan_to_num(y, nan=-1)

if self.return_time_features:
time_features = self._get_time_features((t0 - timedelta(minutes=self.history_mins), t0)).astype(np.float32)
return X.astype(np.float32), y.astype(np.float32), time_features

return X.astype(np.float32), y.astype(np.float32)

def __getitem__(self, key: DataIndex) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
def __getitem__(self, key: DataIndex) -> tuple[NDArray[np.float32], NDArray[np.float32]] | tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
if isinstance(key, int):
t0 = self.t0_times[key]

Expand Down Expand Up @@ -299,6 +334,7 @@ def __init__(
nan_to_num: bool = False,
pin_memory: bool = False,
persistent_workers: bool = False,
return_time_features: bool = False,
):
"""A lightning DataModule for loading past and future satellite data

Expand Down Expand Up @@ -341,6 +377,7 @@ def __init__(
self.history_mins = history_mins
self.forecast_mins = forecast_mins
self.sample_freq_mins = sample_freq_mins
self.return_time_features = return_time_features

self._common_dataloader_kwargs = DataloaderArgs(
batch_size=batch_size,
Expand Down Expand Up @@ -371,6 +408,7 @@ def _make_dataset(
preshuffle=preshuffle,
nan_to_num=self.nan_to_num,
variables=self.variables,
return_time_features=self.return_time_features,
)

def train_dataloader(self) -> DataLoader[tuple[NDArray[np.float32], NDArray[np.float32]]]:
Expand Down
50 changes: 50 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,53 @@ def test_validation_dataset_raises_error(sat_zarr_path):
forecast_mins=FORECAST_HORIZON_MINUTES,
sample_freq_mins=DATA_INTERVAL_SPACING_MINUTES,
)

def test_time_feature_validity(sat_zarr_path):
    """Check the sin/cos time features at midnight, at noon and at the start of the year."""
    tolerance = 1e-6

    dataset = SatelliteDataset(
        zarr_path=sat_zarr_path,
        start_time=None,
        end_time=None,
        history_mins=0,
        forecast_mins=15,
        sample_freq_mins=5,
        nan_to_num=True,
        return_time_features=True,
    )

    def features_between(start: str, end: str):
        # Time features for every sample step from `start` to `end`
        return dataset._get_time_features((pd.Timestamp(start), pd.Timestamp(end)))

    # With return_time_features=True a sample has three elements
    sample = dataset["2023-01-01 00:00:00"]
    assert len(sample) == 3

    # Midnight (hour = 0): cos(hour)=1, sin(hour)=0
    midnight_features = features_between("2023-01-01 00:00:00", "2023-01-01 00:15:00")
    assert midnight_features.shape[1] == 4  # [cos_hour, sin_hour, cos_day, sin_day]
    cos_hour, sin_hour = midnight_features[0, 0], midnight_features[0, 1]
    np.testing.assert_allclose(cos_hour, 1, atol=tolerance)
    np.testing.assert_allclose(sin_hour, 0, atol=tolerance)

    # Noon (hour = 12): cos(hour)=-1, sin(hour)≈0
    noon_features = features_between("2023-01-01 12:00:00", "2023-01-01 12:15:00")
    cos_hour, sin_hour = noon_features[0, 0], noon_features[0, 1]
    np.testing.assert_allclose(cos_hour, -1, atol=tolerance)
    np.testing.assert_allclose(sin_hour, 0, atol=tolerance)

    # Start of the year: cos(day)=1, sin(day)=0
    new_year_features = features_between("2023-01-01 00:00:00", "2023-01-01 00:15:00")
    cos_day, sin_day = new_year_features[0, 2], new_year_features[0, 3]
    np.testing.assert_allclose(cos_day, 1, atol=tolerance)
    np.testing.assert_allclose(sin_day, 0, atol=tolerance)
Loading