Skip to content

Commit

Permalink
Fix MQF tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Abdul Fatir Ansari committed May 27, 2023
1 parent 8f83698 commit 15204d5
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/gluonts/torch/model/mqf2/lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def _compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
future_time_feat = batch["future_time_feat"]
future_target = batch["future_target"]
past_observed_values = batch["past_observed_values"]
future_observed_values = batch["future_observed_values"]

picnn = self.model.picnn

Expand All @@ -107,6 +108,7 @@ def _compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
past_observed_values,
future_time_feat,
future_target,
future_observed_values,
)

hidden_state = hidden_state[:, : self.model.context_length]
Expand Down
1 change: 1 addition & 0 deletions test/torch/model/test_mqf2_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def test_mqf2_modules(
past_observed_values,
future_time_feat,
future_target,
future_observed_values,
)

hidden_state = hidden_state[:, :context_length]
Expand Down

0 comments on commit 15204d5

Please sign in to comment.