Skip to content

Commit

Permalink
Enable QuantileOutput for TiDE model (#3189)
Browse files Browse the repository at this point in the history
*Description of changes:*
- Make the `TiDE` model compatible with arbitrary `Output` objects (not
only `DistributionOutput`).


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.


**Please tag this pr with at least one of these labels to make our
release process faster:** BREAKING, new feature, bug fix, other change,
dev setup
  • Loading branch information
shchur authored May 31, 2024
1 parent 5e30960 commit e3562e4
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 19 deletions.
8 changes: 0 additions & 8 deletions src/gluonts/torch/distributions/distribution_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,6 @@ def loss(
nll = nll * (variance.detach() ** self.beta)
return nll

@property
def event_shape(self) -> Tuple:
r"""
Shape of each individual event contemplated by the distributions that
this object constructs.
"""
raise NotImplementedError()

@property
def event_dim(self) -> int:
r"""
Expand Down
7 changes: 7 additions & 0 deletions src/gluonts/torch/distributions/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ def loss(
"""
raise NotImplementedError()

@property
def event_shape(self) -> Tuple:
r"""
Shape of each individual event compatible with the output object.
"""
raise NotImplementedError()

@property
def forecast_generator(self) -> ForecastGenerator:
raise NotImplementedError()
Expand Down
4 changes: 4 additions & 0 deletions src/gluonts/torch/distributions/quantile_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def __init__(self, quantiles: List[float]) -> None:
def forecast_generator(self) -> ForecastGenerator:
return QuantileForecastGenerator(quantiles=self.quantiles)

@property
def event_shape(self) -> Tuple:
return ()

def domain_map(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
return args

Expand Down
12 changes: 3 additions & 9 deletions src/gluonts/torch/model/tide/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.loader import as_stacked_batches
from gluonts.itertools import Cyclic
from gluonts.model.forecast_generator import DistributionForecastGenerator
from gluonts.time_feature import (
minute_of_hour,
hour_of_day,
Expand Down Expand Up @@ -49,10 +48,7 @@

from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.distributions import (
DistributionOutput,
StudentTOutput,
)
from gluonts.torch.distributions import Output, StudentTOutput

from .lightning_module import TiDELightningModule

Expand Down Expand Up @@ -174,7 +170,7 @@ def __init__(
weight_decay: float = 1e-8,
patience: int = 10,
scaling: Optional[str] = "mean",
distr_output: DistributionOutput = StudentTOutput(),
distr_output: Output = StudentTOutput(),
batch_size: int = 32,
num_batches_per_epoch: int = 50,
trainer_kwargs: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -403,9 +399,7 @@ def create_predictor(
input_transform=transformation + prediction_splitter,
input_names=PREDICTION_INPUT_NAMES,
prediction_net=module,
forecast_generator=DistributionForecastGenerator(
self.distr_output
),
forecast_generator=self.distr_output.forecast_generator,
batch_size=self.batch_size,
prediction_length=self.prediction_length,
device="auto",
Expand Down
4 changes: 2 additions & 2 deletions src/gluonts/torch/model/tide/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from gluonts.core.component import validated
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.model import Input, InputSpec
from gluonts.torch.distributions import DistributionOutput
from gluonts.torch.distributions import Output
from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler
from gluonts.torch.model.simple_feedforward import make_linear_layer
from gluonts.torch.util import weighted_average
Expand Down Expand Up @@ -242,7 +242,7 @@ def __init__(
num_layers_encoder: int,
num_layers_decoder: int,
layer_norm: bool,
distr_output: DistributionOutput,
distr_output: Output,
scaling: str,
) -> None:
super().__init__()
Expand Down
8 changes: 8 additions & 0 deletions test/torch/model/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,14 @@
num_batches_per_epoch=3,
trainer_kwargs=dict(max_epochs=2),
),
lambda dataset: TiDEEstimator(
freq=dataset.metadata.freq,
prediction_length=dataset.metadata.prediction_length,
distr_output=QuantileOutput(quantiles=[0.1, 0.6, 0.85]),
batch_size=4,
num_batches_per_epoch=3,
trainer_kwargs=dict(max_epochs=2),
),
lambda dataset: WaveNetEstimator(
freq=dataset.metadata.freq,
prediction_length=dataset.metadata.prediction_length,
Expand Down

0 comments on commit e3562e4

Please sign in to comment.