diff --git a/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template b/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template index 01e170951e..4d54d5636a 100644 --- a/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template +++ b/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template @@ -134,7 +134,7 @@ To train the model using PyTorch Lightning, we only need to extend the class wit ```python -import pytorch_lightning as pl +import lightning.pytorch as pl ``` diff --git a/requirements/requirements-pytorch.txt b/requirements/requirements-pytorch.txt index 16a40f64a3..03f4e997ab 100644 --- a/requirements/requirements-pytorch.txt +++ b/requirements/requirements-pytorch.txt @@ -1,5 +1,7 @@ torch>=1.9,<3 -pytorch-lightning>=1.5,<3 +lightning>=1.8,<2.2 +# Capping `lightning` does not cap `pytorch_lightning`, so we cap manually +pytorch_lightning>=1.8,<2.2 # Need to pin protobuf (for now) # See: https://github.com/PyTorchLightning/pytorch-lightning/issues/13159 protobuf~=3.19.0 diff --git a/src/gluonts/torch/model/d_linear/estimator.py b/src/gluonts/torch/model/d_linear/estimator.py index 4f58caf4a0..e3f428db86 100644 --- a/src/gluonts/torch/model/d_linear/estimator.py +++ b/src/gluonts/torch/model/d_linear/estimator.py @@ -14,7 +14,7 @@ from typing import Optional, Iterable, Dict, Any import torch -import pytorch_lightning as pl +import lightning.pytorch as pl from gluonts.core.component import validated from gluonts.dataset.common import Dataset diff --git a/src/gluonts/torch/model/d_linear/lightning_module.py b/src/gluonts/torch/model/d_linear/lightning_module.py index 28dccf1b97..bd081b45dd 100644 --- a/src/gluonts/torch/model/d_linear/lightning_module.py +++ b/src/gluonts/torch/model/d_linear/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated diff --git a/src/gluonts/torch/model/deepar/lightning_module.py b/src/gluonts/torch/model/deepar/lightning_module.py index fc676dfab3..8d190e2329 100644 --- a/src/gluonts/torch/model/deepar/lightning_module.py +++ b/src/gluonts/torch/model/deepar/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from torch.optim.lr_scheduler import ReduceLROnPlateau diff --git a/src/gluonts/torch/model/estimator.py b/src/gluonts/torch/model/estimator.py index 9311b474d0..9f41282ebb 100644 --- a/src/gluonts/torch/model/estimator.py +++ b/src/gluonts/torch/model/estimator.py @@ -15,7 +15,7 @@ import logging import numpy as np -import pytorch_lightning as pl +import lightning.pytorch as pl import torch.nn as nn from gluonts.core.component import validated @@ -209,10 +209,15 @@ def train_model( ckpt_path=ckpt_path, ) - logger.info(f"Loading best model from {checkpoint.best_model_path}") - best_model = training_network.load_from_checkpoint( - checkpoint.best_model_path - ) + if checkpoint.best_model_path != "": + logger.info( + f"Loading best model from {checkpoint.best_model_path}" + ) + best_model = training_network.__class__.load_from_checkpoint( + checkpoint.best_model_path + ) + else: + best_model = training_network return TrainOutput( transformation=transformation, diff --git a/src/gluonts/torch/model/lag_tst/estimator.py b/src/gluonts/torch/model/lag_tst/estimator.py index 27bfd253b2..c3ae48237a 100644 --- a/src/gluonts/torch/model/lag_tst/estimator.py +++ b/src/gluonts/torch/model/lag_tst/estimator.py @@ -14,7 +14,7 @@ from typing import Optional, Iterable, Dict, Any, List import torch -import pytorch_lightning as pl +import lightning.pytorch as pl from gluonts.core.component import validated from gluonts.dataset.common import Dataset diff --git a/src/gluonts/torch/model/lag_tst/lightning_module.py b/src/gluonts/torch/model/lag_tst/lightning_module.py index 2510944cfa..5c9e70e9e4 100644 --- a/src/gluonts/torch/model/lag_tst/lightning_module.py +++ b/src/gluonts/torch/model/lag_tst/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated diff --git a/src/gluonts/torch/model/lightning_util.py b/src/gluonts/torch/model/lightning_util.py index 6742c8c7cf..73e2396140 100644 --- a/src/gluonts/torch/model/lightning_util.py +++ b/src/gluonts/torch/model/lightning_util.py @@ -13,7 +13,7 @@ from packaging import version -import pytorch_lightning as pl +import lightning.pytorch as pl def has_validation_loop(trainer: pl.Trainer): diff --git a/src/gluonts/torch/model/mqf2/lightning_module.py b/src/gluonts/torch/model/mqf2/lightning_module.py index 6dc824beb4..16916c3c41 100644 --- a/src/gluonts/torch/model/mqf2/lightning_module.py +++ b/src/gluonts/torch/model/mqf2/lightning_module.py @@ -13,7 +13,7 @@ from typing import Dict -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from torch.optim.lr_scheduler import ReduceLROnPlateau diff --git a/src/gluonts/torch/model/patch_tst/estimator.py b/src/gluonts/torch/model/patch_tst/estimator.py index c84576916f..860c18e193 100644 --- a/src/gluonts/torch/model/patch_tst/estimator.py +++ b/src/gluonts/torch/model/patch_tst/estimator.py @@ -14,7 +14,7 @@ from typing import Optional, Iterable, Dict, Any import torch -import pytorch_lightning as pl +import lightning.pytorch as pl from gluonts.core.component import validated from gluonts.dataset.common import Dataset diff --git a/src/gluonts/torch/model/patch_tst/lightning_module.py b/src/gluonts/torch/model/patch_tst/lightning_module.py index f5e95158b2..d80137ae05 100644 --- a/src/gluonts/torch/model/patch_tst/lightning_module.py +++ b/src/gluonts/torch/model/patch_tst/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated diff --git a/src/gluonts/torch/model/simple_feedforward/estimator.py b/src/gluonts/torch/model/simple_feedforward/estimator.py index 5f82640e8b..05f34a9fa2 100644 --- a/src/gluonts/torch/model/simple_feedforward/estimator.py +++ b/src/gluonts/torch/model/simple_feedforward/estimator.py @@ -14,7 +14,7 @@ from typing import List, Optional, Iterable, Dict, Any import torch -import pytorch_lightning as pl +import lightning.pytorch as pl from gluonts.core.component import validated from gluonts.dataset.common import Dataset diff --git a/src/gluonts/torch/model/simple_feedforward/lightning_module.py b/src/gluonts/torch/model/simple_feedforward/lightning_module.py index b7cf9a529a..f03473e78d 100644 --- a/src/gluonts/torch/model/simple_feedforward/lightning_module.py +++ b/src/gluonts/torch/model/simple_feedforward/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated diff --git a/src/gluonts/torch/model/tft/lightning_module.py b/src/gluonts/torch/model/tft/lightning_module.py index f6f7daa335..4647d740fd 100644 --- a/src/gluonts/torch/model/tft/lightning_module.py +++ b/src/gluonts/torch/model/tft/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated from gluonts.itertools import select