71 changes: 51 additions & 20 deletions neuralforecast/core.py
@@ -222,6 +222,7 @@ def __init__(
models: List[Any],
freq: Union[str, int],
local_scaler_type: Optional[str] = None,
local_static_scaler_type: Optional[str] = None,
):
"""The `core.StatsForecast` class allows you to efficiently fit multiple `NeuralForecast` models
for large sets of time series. It operates with pandas DataFrame `df` that identifies series
@@ -232,7 +233,9 @@ def __init__(
models (List[typing.Any]): Instantiated `neuralforecast.models`
see [collection here](./models).
freq (str or int): Frequency of the data. Must be a valid pandas or polars offset alias, or an integer.
local_scaler_type (str, optional): Scaler to apply per-serie to all features before fitting, which is inverted after predicting.
local_scaler_type (str, optional): Scaler to apply per series to temporal features before fitting, which is inverted after predicting.
Can be 'standard', 'robust', 'robust-iqr', 'minmax' or 'boxcox'. Defaults to None.
local_static_scaler_type (str, optional): Scaler to apply to static exogenous features before fitting.
Can be 'standard', 'robust', 'robust-iqr', 'minmax' or 'boxcox'. Defaults to None.

Returns:
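For reference, a minimal usage sketch of the new argument, mirroring the test added below (model hyperparameters are illustrative, not prescribed by this change):

    from neuralforecast import NeuralForecast
    from neuralforecast.models import NHITS
    from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic

    # Scale the static exogenous features ('airline1', 'airline2') with a standard
    # scaler; temporal scaling remains controlled separately by `local_scaler_type`.
    nf = NeuralForecast(
        models=[NHITS(h=12, input_size=24, max_steps=10, stat_exog_list=["airline1", "airline2"])],
        freq="D",
        local_static_scaler_type="standard",
    )
    nf.fit(df=AirPassengersPanel, static_df=AirPassengersStatic)
    preds = nf.predict()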
@@ -247,34 +250,48 @@ def __init__(
self.freq = freq
if local_scaler_type is not None and local_scaler_type not in _type2scaler:
raise ValueError(f"scaler_type must be one of {_type2scaler.keys()}")
if local_static_scaler_type is not None and local_static_scaler_type not in _type2scaler:
raise ValueError(f"static_scaler_type must be one of {_type2scaler.keys()}")
self.local_scaler_type = local_scaler_type
self.local_static_scaler_type = local_static_scaler_type
self.scalers_: Dict
self.static_scalers_: Dict

# Flags and attributes
self._fitted = False
self._reset_models()
self._add_level = False

def _scalers_fit_transform(self, dataset: TimeSeriesDataset) -> None:
self.scalers_ = {}
if self.local_scaler_type is None:
return None
for i, col in enumerate(dataset.temporal_cols):
if col == "available_mask":
continue
ga = GroupedArray(dataset.temporal[:, i].numpy(), dataset.indptr)
self.scalers_[col] = _type2scaler[self.local_scaler_type]().fit(ga)
dataset.temporal[:, i] = torch.from_numpy(self.scalers_[col].transform(ga))
self.scalers_, self.static_scalers_ = {}, {}
if self.local_scaler_type is not None:
for i, col in enumerate(dataset.temporal_cols):
if col == "available_mask":
continue
ga = GroupedArray(dataset.temporal[:, i].numpy(), dataset.indptr)
self.scalers_[col] = _type2scaler[self.local_scaler_type]().fit(ga)
dataset.temporal[:, i] = torch.from_numpy(self.scalers_[col].transform(ga))
if self.local_static_scaler_type is not None and dataset.static is not None:
for i, col in enumerate(dataset.static_cols):
ga = GroupedArray(dataset.static[:, i].numpy(), np.array([0, dataset.static.shape[0]]))
self.static_scalers_[col] = _type2scaler[self.local_static_scaler_type]().fit(ga)
dataset.static[:, i] = torch.from_numpy(self.static_scalers_[col].transform(ga))

def _scalers_transform(self, dataset: TimeSeriesDataset) -> None:
if not self.scalers_:
return None
for i, col in enumerate(dataset.temporal_cols):
scaler = self.scalers_.get(col, None)
if scaler is None:
continue
ga = GroupedArray(dataset.temporal[:, i].numpy(), dataset.indptr)
dataset.temporal[:, i] = torch.from_numpy(scaler.transform(ga))
if self.scalers_:
for i, col in enumerate(dataset.temporal_cols):
scaler = self.scalers_.get(col, None)
if scaler is None:
continue
ga = GroupedArray(dataset.temporal[:, i].numpy(), dataset.indptr)
dataset.temporal[:, i] = torch.from_numpy(scaler.transform(ga))
if self.static_scalers_ and dataset.static is not None:
for i, col in enumerate(dataset.static_cols):
scaler = self.static_scalers_.get(col, None)
if scaler is None:
continue
ga = GroupedArray(dataset.static[:, i].numpy(), np.array([0, dataset.static.shape[0]]))
dataset.static[:, i] = torch.from_numpy(scaler.transform(ga))

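A note on the grouping semantics above: temporal columns are scaled per series via `dataset.indptr`, while each static column is wrapped in a `GroupedArray` with a single group spanning all series (`indptr = [0, n_series]`), so its statistics are pooled across series. A standalone sketch of what the 'standard' case computes (NumPy only; the values are made up for illustration):

    import numpy as np

    static_col = np.array([120.0, 340.0, 80.0], dtype=np.float32)  # one value per series
    mean, std = static_col.mean(), static_col.std()  # statistics pooled across all series
    scaled = (static_col - mean) / std               # what 'standard' scaling produces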
def _scalers_target_inverse_transform(
self, data: np.ndarray, indptr: np.ndarray
@@ -348,6 +365,11 @@ def _prepare_fit_distributed(
"Historic scaling isn't supported in distributed. "
"Please open an issue if this would be valuable to you."
)
if self.local_static_scaler_type is not None:
raise ValueError(
"Static scaling isn't supported in distributed. "
"Please open an issue if this would be valuable to you."
)
temporal_cols = [c for c in df.columns if c not in (id_col, time_col)]
if static_df is not None:
static_cols = [c for c in static_df.columns if c != id_col]
Expand All @@ -357,7 +379,7 @@ def _prepare_fit_distributed(
self.id_col = id_col
self.time_col = time_col
self.target_col = target_col
self.scalers_ = {}
self.scalers_, self.static_scalers_ = {}, {}
num_partitions = distributed_config.num_nodes * distributed_config.devices
df = df.repartitionByRange(num_partitions, id_col)
df.write.parquet(path=distributed_config.partitions_path, mode="overwrite")
@@ -393,11 +415,16 @@ def _prepare_fit_for_local_files(
"Historic scaling isn't supported when the dataset is split between files. "
"Please open an issue if this would be valuable to you."
)
if self.local_static_scaler_type is not None:
raise ValueError(
"Static scaling isn't supported when the dataset is split between files. "
"Please open an issue if this would be valuable to you."
)

self.id_col = id_col
self.time_col = time_col
self.target_col = target_col
self.scalers_ = {}
self.scalers_, self.static_scalers_ = {}, {}

exogs = self._get_needed_exog()
return LocalFilesTimeSeriesDataset.from_data_directories(
@@ -1752,7 +1779,9 @@ def save(
"freq": self.freq,
"_fitted": self._fitted,
"local_scaler_type": self.local_scaler_type,
"local_static_scaler_type": self.local_static_scaler_type,
"scalers_": self.scalers_,
"static_scalers_": self.static_scalers_,
"id_col": self.id_col,
"time_col": self.time_col,
"target_col": self.target_col,
@@ -1855,6 +1884,7 @@ def load(path, verbose=False, **kwargs):
models=models,
freq=config_dict["freq"],
local_scaler_type=config_dict.get("local_scaler_type", default_scalar_type),
local_static_scaler_type=config_dict.get("local_static_scaler_type"),
)

attr_to_default = {"id_col": "unique_id", "time_col": "ds", "target_col": "y"}
Expand All @@ -1879,6 +1909,7 @@ def load(path, verbose=False, **kwargs):
neuralforecast._fitted = config_dict["_fitted"]

neuralforecast.scalers_ = config_dict.get("scalers_", default_scalars_)
neuralforecast.static_scalers_ = config_dict.get("static_scalers_", {})

return neuralforecast

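Because `static_scalers_` is now persisted alongside `scalers_`, a fitted pipeline restored via `load` keeps its static scaling. A hedged round-trip sketch, continuing the usage sketch above (the checkpoint path is illustrative):

    nf.save(path="checkpoints/static_scaling_run", overwrite=True)
    nf_restored = NeuralForecast.load(path="checkpoints/static_scaling_run")
    assert nf_restored.static_scalers_.keys() == nf.static_scalers_.keys()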
120 changes: 119 additions & 1 deletion tests/test_core.py
@@ -48,6 +48,7 @@
PredictionIntervals,
TimesNet,
_insample_times,
_type2scaler,
)
from neuralforecast.losses.pytorch import (
GMM,
@@ -57,6 +58,7 @@
DistributionLoss,
MQLoss,
)
from neuralforecast.tsdataset import TimeSeriesDataset
from neuralforecast.utils import (
AirPassengersPanel,
AirPassengersStatic,
@@ -341,6 +343,118 @@ def test_neural_forecast_boxcox_scaling(setup_airplane_data):
insample_res["y_expected"].values,
rtol=0.7,
)


# Test static exogenous feature scaling

@pytest.fixture
def config() -> dict:
return {
'h': 12,
'input_size': 24,
'max_steps': 10,
}


@pytest.mark.parametrize("scaler", _type2scaler.keys())
def test_neural_forecast_static_scaling(config, scaler):
"""Test static scaling functionality for NeuralForecast models."""
stat_exog = ['airline1', 'airline2']
nf = NeuralForecast(models=[NHITS(**config, stat_exog_list=stat_exog)], freq="D")
nf.fit(AirPassengersPanel, AirPassengersStatic)
without_scaler = nf.predict()

nf = NeuralForecast(models=[NHITS(**config, stat_exog_list=stat_exog)], freq="D", local_static_scaler_type=scaler)
nf.fit(AirPassengersPanel, AirPassengersStatic)
with_scaler = nf.predict()

np.testing.assert_allclose(
without_scaler["NHITS"].values,
with_scaler["NHITS"].values,
rtol=0.2,
)


@pytest.fixture
def data(size=300, n_series=3) -> pd.DataFrame:
return pd.DataFrame({
"unique_id": (['store_1'] * (size // n_series) + ['store_2'] * (size // n_series) + ['store_3'] * (size // n_series)),
"ds": np.tile(pd.date_range(start="2020-01-01", periods=size // n_series, freq="D").to_numpy(), n_series),
"y": np.random.rand(size),
})


@pytest.fixture
def static_data(n_series=3) -> pd.DataFrame:
return pd.DataFrame({
"unique_id": [f'store_{i+1}' for i in range(n_series)],
"size": np.random.randint(50, 500, size=n_series),
"num_days_open": np.random.randint(100, 1000, size=n_series),
})


@pytest.mark.parametrize("scaler", _type2scaler.keys())
def test_normalization_of_static_exog(data, static_data, config, scaler):
models = [TFT(**config, stat_exog_list=["size", "num_days_open"])]
nf = NeuralForecast(models=models, freq="D", local_static_scaler_type=scaler)

dataset = TimeSeriesDataset.from_df(data, static_data)[0]
nf._scalers_fit_transform(dataset)
fit_dataset = dataset.static

dataset = TimeSeriesDataset.from_df(data, static_data)[0]
nf._scalers_transform(dataset)
predict_dataset = dataset.static

assert (fit_dataset == predict_dataset).all()


def test_standard_normalization_of_static_exog(data, static_data, config):
models = [TFT(**config, stat_exog_list=["size", "num_days_open"])]
nf = NeuralForecast(models=models, freq="D", local_static_scaler_type="standard")

dataset = TimeSeriesDataset.from_df(data, static_data)[0]
nf._scalers_fit_transform(dataset)
normalized_static = dataset.static.numpy()

for i, col in enumerate(dataset.static_cols):
col_values = static_data[col].values.reshape(-1, 1).astype(np.float32)
mean = col_values.mean()
std = col_values.std()
expected_normalized = (col_values - mean) / std

np.testing.assert_allclose(nf.static_scalers_[col].stats_[:, 0], mean, rtol=1e-5)
np.testing.assert_allclose(nf.static_scalers_[col].stats_[:, 1], std, rtol=1e-5)
np.testing.assert_allclose(
normalized_static[:, i],
expected_normalized.flatten(),
rtol=1e-5,
)


def test_minmax_normalization_of_static_exog(data, static_data, config):
models = [TFT(**config, stat_exog_list=["size", "num_days_open"])]
nf = NeuralForecast(models=models, freq="D", local_static_scaler_type="minmax")

dataset = TimeSeriesDataset.from_df(data, static_data)[0]
nf._scalers_fit_transform(dataset)
normalized_static = dataset.static.numpy()

for i, col in enumerate(dataset.static_cols):
col_values = static_data[col].values.reshape(-1, 1).astype(np.float32)
min_val = col_values.min()
range_val = col_values.max() - min_val
expected_normalized = (col_values - min_val) / range_val

np.testing.assert_allclose(nf.static_scalers_[col].stats_[:, 0], min_val, rtol=1e-5)
np.testing.assert_allclose(nf.static_scalers_[col].stats_[:, 1], range_val, rtol=1e-5)
np.testing.assert_allclose(
normalized_static[:, i],
expected_normalized.flatten(),
rtol=1e-5,
)


# test futr_df contents
def test_future_df_contents(setup_airplane_data):
AirPassengersPanel_train, AirPassengersPanel_test = setup_airplane_data
@@ -907,7 +1021,11 @@ def test_save_load(setup_airplane_data):
def test_save_load_no_dataset(setup_airplane_data):
AirPassengersPanel_train, AirPassengersPanel_test = setup_airplane_data

shutil.rmtree("examples/debug_run")
try:
shutil.rmtree("examples/debug_run")
except FileNotFoundError:
pass  # nothing to clean up from a previous run

fcst = NeuralForecast(
models=[DilatedRNN(h=12, input_size=-1, encoder_hidden_size=5, max_steps=1)],
freq="M",