Skip to content

Commit

Permalink
Enable QuantileOutput for TiDE model
Browse files Browse the repository at this point in the history
  • Loading branch information
shchur committed May 31, 2024
1 parent 5e30960 commit 62a9fa8
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 18 deletions.
8 changes: 0 additions & 8 deletions src/gluonts/torch/distributions/distribution_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,6 @@ def loss(
nll = nll * (variance.detach() ** self.beta)
return nll

@property
def event_shape(self) -> Tuple:
r"""
Shape of each individual event contemplated by the distributions that
this object constructs.
"""
raise NotImplementedError()

@property
def event_dim(self) -> int:
r"""
Expand Down
7 changes: 7 additions & 0 deletions src/gluonts/torch/distributions/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ def loss(
"""
raise NotImplementedError()

@property
def event_shape(self) -> Tuple:
r"""
Shape of each individual event compatible with the output object.
"""
raise NotImplementedError()

@property
def forecast_generator(self) -> ForecastGenerator:
raise NotImplementedError()
Expand Down
4 changes: 4 additions & 0 deletions src/gluonts/torch/distributions/quantile_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def __init__(self, quantiles: List[float]) -> None:
def forecast_generator(self) -> ForecastGenerator:
return QuantileForecastGenerator(quantiles=self.quantiles)

@property
def event_shape(self) -> Tuple:
return ()

def domain_map(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
return args

Expand Down
11 changes: 3 additions & 8 deletions src/gluonts/torch/model/tide/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,7 @@

from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.distributions import (
DistributionOutput,
StudentTOutput,
)
from gluonts.torch.distributions import Output, StudentTOutput

from .lightning_module import TiDELightningModule

Expand Down Expand Up @@ -174,7 +171,7 @@ def __init__(
weight_decay: float = 1e-8,
patience: int = 10,
scaling: Optional[str] = "mean",
distr_output: DistributionOutput = StudentTOutput(),
distr_output: Output = StudentTOutput(),
batch_size: int = 32,
num_batches_per_epoch: int = 50,
trainer_kwargs: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -403,9 +400,7 @@ def create_predictor(
input_transform=transformation + prediction_splitter,
input_names=PREDICTION_INPUT_NAMES,
prediction_net=module,
forecast_generator=DistributionForecastGenerator(
self.distr_output
),
forecast_generator=self.distr_output.forecast_generator,
batch_size=self.batch_size,
prediction_length=self.prediction_length,
device="auto",
Expand Down
4 changes: 2 additions & 2 deletions src/gluonts/torch/model/tide/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from gluonts.core.component import validated
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.model import Input, InputSpec
from gluonts.torch.distributions import DistributionOutput
from gluonts.torch.distributions import Output
from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler
from gluonts.torch.model.simple_feedforward import make_linear_layer
from gluonts.torch.util import weighted_average
Expand Down Expand Up @@ -242,7 +242,7 @@ def __init__(
num_layers_encoder: int,
num_layers_decoder: int,
layer_norm: bool,
distr_output: DistributionOutput,
distr_output: Output,
scaling: str,
) -> None:
super().__init__()
Expand Down
8 changes: 8 additions & 0 deletions test/torch/model/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,14 @@
num_batches_per_epoch=3,
trainer_kwargs=dict(max_epochs=2),
),
lambda dataset: TiDEEstimator(
freq=dataset.metadata.freq,
prediction_length=dataset.metadata.prediction_length,
distr_output=QuantileOutput(quantiles=[0.1, 0.6, 0.85]),
batch_size=4,
num_batches_per_epoch=3,
trainer_kwargs=dict(max_epochs=2),
),
lambda dataset: WaveNetEstimator(
freq=dataset.metadata.freq,
prediction_length=dataset.metadata.prediction_length,
Expand Down

0 comments on commit 62a9fa8

Please sign in to comment.