diff --git a/src/gluonts/torch/model/deepar/module.py b/src/gluonts/torch/model/deepar/module.py
index 5f073bc914..3e27c8be02 100644
--- a/src/gluonts/torch/model/deepar/module.py
+++ b/src/gluonts/torch/model/deepar/module.py
@@ -146,7 +146,9 @@ def __init__(
             )
         else:
             self.scaler = NOPScaler(dim=-1, keepdim=True)
-        self.rnn_input_size = len(self.lags_seq) + self._number_of_features
+        self.rnn_input_size = (
+            2 * len(self.lags_seq)
+        ) + self._number_of_features
         self.rnn = nn.LSTM(
             input_size=self.rnn_input_size,
             hidden_size=hidden_size,
@@ -216,23 +218,41 @@ def prepare_rnn_input(
         past_observed_values: torch.Tensor,
         future_time_feat: torch.Tensor,
         future_target: Optional[torch.Tensor] = None,
+        future_observed_values: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,]:
         context = past_target[..., -self.context_length :]
         observed_context = past_observed_values[..., -self.context_length :]
 
         input, _, scale = self.scaler(context, observed_context)
+        observed_input = observed_context
         future_length = future_time_feat.shape[-2]
         if future_length > 1:
-            assert future_target is not None
+            assert (
+                future_target is not None
+                and future_observed_values is not None
+            )
             input = torch.cat(
                 (input, future_target[..., : future_length - 1] / scale),
                 dim=-1,
             )
+            observed_input = torch.cat(
+                (
+                    observed_input,
+                    future_observed_values[..., : future_length - 1],
+                ),
+                dim=-1,
+            )
         prior_input = past_target[..., : -self.context_length] / scale
+        observed_prior_input = past_observed_values[
+            ..., : -self.context_length
+        ]
 
         lags = lagged_sequence_values(
             self.lags_seq, prior_input, input, dim=-1
         )
+        observed_lags = lagged_sequence_values(
+            self.lags_seq, observed_prior_input, observed_input, dim=-1
+        )
 
         time_feat = torch.cat(
             (
@@ -252,8 +272,8 @@ def prepare_rnn_input(
         )
 
         features = torch.cat((expanded_static_feat, time_feat), dim=-1)
-
-        return torch.cat((lags, features), dim=-1), scale, static_feat
+        rnn_input = torch.cat((lags, observed_lags, features), dim=-1)
+        return (rnn_input, scale, static_feat)
 
     def unroll_lagged_rnn(
         self,
@@ -264,6 +284,7 @@ def unroll_lagged_rnn(
         past_observed_values: torch.Tensor,
         future_time_feat: torch.Tensor,
         future_target: Optional[torch.Tensor] = None,
+        future_observed_values: Optional[torch.Tensor] = None,
     ) -> Tuple[
         Tuple[torch.Tensor, ...],
         torch.Tensor,
@@ -297,6 +318,9 @@ def unroll_lagged_rnn(
         future_target
             (Optional) tensor of future target values,
             shape: ``(batch_size, prediction_length)``.
+        future_observed_values
+            (Optional) tensor of future observed values indicators,
+            shape: ``(batch_size, prediction_length)``.
 
         Returns
         -------
@@ -316,6 +340,7 @@ def unroll_lagged_rnn(
             past_observed_values,
             future_time_feat,
             future_target,
+            future_observed_values,
         )
 
         output, new_state = self.rnn(rnn_input)
@@ -409,6 +434,9 @@ def forward(
             past_target.repeat_interleave(repeats=num_parallel_samples, dim=0)
             / repeated_scale
         )
+        repeated_past_observed_values = past_observed_values.repeat_interleave(
+            repeats=num_parallel_samples, dim=0
+        )
         repeated_time_feat = future_time_feat.repeat_interleave(
             repeats=num_parallel_samples, dim=0
         )
@@ -436,13 +464,28 @@ def forward(
             next_lags = lagged_sequence_values(
                 self.lags_seq, repeated_past_target, scaled_next_sample, dim=-1
             )
-            rnn_input = torch.cat((next_lags, next_features), dim=-1)
+            next_observed_lags = lagged_sequence_values(
+                self.lags_seq,
+                repeated_past_observed_values,
+                torch.ones_like(scaled_next_sample),
+                dim=-1,
+            )
+            rnn_input = torch.cat(
+                (next_lags, next_observed_lags, next_features), dim=-1
+            )
 
             output, repeated_state = self.rnn(rnn_input, repeated_state)
 
             repeated_past_target = torch.cat(
                 (repeated_past_target, scaled_next_sample), dim=1
             )
+            repeated_past_observed_values = torch.cat(
+                (
+                    repeated_past_observed_values,
+                    torch.ones_like(scaled_next_sample),
+                ),
+                dim=1,
+            )
 
             params = self.param_proj(output)
             distr = self.output_distribution(params, scale=repeated_scale)
@@ -524,6 +567,7 @@ def loss(
             past_observed_values,
             future_time_feat,
             future_target_reshaped,
+            future_observed_reshaped,
         )
 
         if future_only:
diff --git a/src/gluonts/torch/model/mqf2/lightning_module.py b/src/gluonts/torch/model/mqf2/lightning_module.py
index 6dc824beb4..470eee8d58 100644
--- a/src/gluonts/torch/model/mqf2/lightning_module.py
+++ b/src/gluonts/torch/model/mqf2/lightning_module.py
@@ -96,6 +96,7 @@ def _compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
         future_time_feat = batch["future_time_feat"]
         future_target = batch["future_target"]
         past_observed_values = batch["past_observed_values"]
+        future_observed_values = batch["future_observed_values"]
 
         picnn = self.model.picnn
 
@@ -107,6 +108,7 @@ def _compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
             past_observed_values,
             future_time_feat,
             future_target,
+            future_observed_values,
         )
 
         hidden_state = hidden_state[:, : self.model.context_length]
diff --git a/test/torch/model/test_deepar_modules.py b/test/torch/model/test_deepar_modules.py
index fa488d4264..6a4de8b339 100644
--- a/test/torch/model/test_deepar_modules.py
+++ b/test/torch/model/test_deepar_modules.py
@@ -79,6 +79,7 @@ def test_deepar_modules(
         past_observed_values,
         future_time_feat,
         future_target,
+        future_observed_values,
     )
 
     assert scale.shape == (batch_size, 1)
@@ -231,6 +232,11 @@ def test_rnn_input(
         dtype=torch.float32,
     ).view(1, prediction_length)
 
+    batch["future_observed_values"] = torch.ones(
+        (1, prediction_length),
+        dtype=torch.float32,
+    )
+
     rnn_input, scale, _ = model.prepare_rnn_input(**batch)
 
     assert (scale == 1.0).all()
diff --git a/test/torch/model/test_mqf2_modules.py b/test/torch/model/test_mqf2_modules.py
index 85fa21337f..451a16d890 100644
--- a/test/torch/model/test_mqf2_modules.py
+++ b/test/torch/model/test_mqf2_modules.py
@@ -78,6 +78,7 @@ def test_mqf2_modules(
         past_observed_values,
         future_time_feat,
         future_target,
+        future_observed_values,
     )
 
     hidden_state = hidden_state[:, :context_length]