Upgrade to lightning 2.0
tshu-w committed Mar 17, 2023
1 parent 713b15d commit e7711a1
Showing 12 changed files with 74 additions and 82 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
@@ -11,22 +11,22 @@ repos:
       - id: check-executables-have-shebangs
 
   - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: v0.0.230
+    rev: v0.0.256
     hooks:
      - id: ruff
 
   - repo: https://github.com/psf/black
-    rev: 22.12.0
+    rev: 23.1.0
    hooks:
      - id: black
 
   - repo: https://github.com/kynan/nbstripout.git
-    rev: 0.6.0
+    rev: 0.6.1
    hooks:
      - id: nbstripout
 
   - repo: https://github.com/pre-commit/mirrors-prettier
-    rev: v2.7.1
+    rev: v3.0.0-alpha.6
    hooks:
      - id: prettier
        types: [yaml]
2 changes: 1 addition & 1 deletion configs/presets/default.yaml
@@ -4,7 +4,7 @@ trainer:
   callbacks:
     class_path: Metric
   logger:
-    class_path: TensorBoardLogger
+    class_path: CSVLogger
     init_args:
       save_dir: results
   accelerator: auto
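For reference, a hedged Python sketch (not part of the commit) of what the updated preset resolves to when instantiated; the save_dir mirrors the YAML above and everything else is left at defaults:

# Programmatic equivalent of the CSVLogger preset above (sketch only).
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CSVLogger

logger = CSVLogger(save_dir="results")  # "results" mirrors save_dir in the YAML
trainer = Trainer(logger=logger, accelerator="auto")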
11 changes: 6 additions & 5 deletions requirements.txt
@@ -1,11 +1,12 @@
 --find-links https://download.pytorch.org/whl/cu113/torch_stable.html
-torch >= 1.12.0
-torchvision >= 0.13.0
-pytorch-lightning >= 1.9.0
+torch >= 2.0.0
+lightning >= 2.0.0
+torchvision
 jsonargparse[signatures] # for CLI
 
-transformers >= 4.25.0
-datasets >= 2.8.0
+transformers
+datasets
+evaluate
 scikit-learn
 
 # dev tools
3 changes: 1 addition & 2 deletions src/__init__.py
@@ -1,5 +1,4 @@
-from utils import loggers
-
 from . import callbacks, datamodules, models
+from .utils import loggers
 
 __all__ = ["callbacks", "datamodules", "models", "loggers"]
32 changes: 15 additions & 17 deletions src/callbacks/metric.py
@@ -3,22 +3,20 @@
 from pathlib import Path
 from typing import Optional
 
-import pytorch_lightning as pl
-from pytorch_lightning import Callback
-from pytorch_lightning.callbacks import BatchSizeFinder
-from pytorch_lightning.trainer.states import TrainerFn
-from pytorch_lightning.utilities.metrics import metrics_to_scalars
+import lightning.pytorch as pl
+from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
+from lightning.pytorch.trainer.states import TrainerFn
 
 
-class Metric(Callback):
+class Metric(pl.Callback):
     r"""
     Save logged metrics to ``Trainer.log_dir``.
     """
 
     def teardown(
         self,
-        trainer: "pl.Trainer",
-        pl_module: "pl.LightningModule",
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
         stage: Optional[str] = None,
     ) -> None:
         metrics = {}
@@ -29,10 +27,10 @@ def teardown(
         ):
             ckpt_path = trainer.checkpoint_callback.best_model_path
             # inhibit disturbing logging
-            logging.getLogger("pytorch_lightning.utilities.distributed").setLevel(
+            logging.getLogger("lightning.pytorch.utilities.distributed").setLevel(
                 logging.WARNING
             )
-            logging.getLogger("pytorch_lightning.accelerators.gpu").setLevel(
+            logging.getLogger("lightning.pytorch.accelerators.gpu").setLevel(
                 logging.WARNING
             )
 
@@ -43,20 +41,20 @@ def teardown(
             }
 
             val_metrics = {}
-            if trainer._data_connector._val_dataloader_source.is_defined():
-                trainer.callbacks = [BatchSizeFinder()]
+            if trainer.validate_loop._data_source.is_defined():
+                trainer.callbacks = []
                 trainer.validate(**fn_kwargs)
-                val_metrics = metrics_to_scalars(trainer.logged_metrics)
+                val_metrics = convert_tensors_to_scalars(trainer.logged_metrics)
 
             test_metrics = {}
-            if trainer._data_connector._test_dataloader_source.is_defined():
-                trainer.callbacks = [BatchSizeFinder()]
+            if trainer.test_loop._data_source.is_defined():
+                trainer.callbacks = []
                 trainer.test(**fn_kwargs)
-                test_metrics = metrics_to_scalars(trainer.logged_metrics)
+                test_metrics = convert_tensors_to_scalars(trainer.logged_metrics)
 
             metrics = {**val_metrics, **test_metrics}
         else:
-            metrics = metrics_to_scalars(trainer.logged_metrics)
+            metrics = convert_tensors_to_scalars(trainer.logged_metrics)
 
         if metrics:
             metrics_str = json.dumps(metrics, ensure_ascii=False, indent=2)
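For context, a standalone sketch (not the repo's Metric callback) of the Lightning 2.0 pattern the diff adopts: the unified lightning.pytorch namespace plus convert_tensors_to_scalars from lightning.fabric in place of the removed metrics_to_scalars helper. The callback name is illustrative:

# Minimal sketch: dump whatever was logged to a JSON file next to the logs.
import json

import lightning.pytorch as pl
from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars


class DumpMetrics(pl.Callback):  # hypothetical callback, for illustration only
    def teardown(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage=None) -> None:
        # convert logged tensors to plain Python scalars before serializing
        metrics = convert_tensors_to_scalars(trainer.logged_metrics)
        if metrics and trainer.log_dir:
            with open(f"{trainer.log_dir}/metrics.json", "w") as f:
                json.dump(metrics, f, ensure_ascii=False, indent=2)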
6 changes: 3 additions & 3 deletions src/datamodules/glue_datamodule.py
@@ -2,9 +2,9 @@
 from functools import partial
 from typing import Literal, Optional
 
+import lightning.pytorch as pl
 from datasets import load_dataset
-from pytorch_lightning import LightningDataModule
-from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
+from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
 from torch.utils.data import DataLoader
 
 warnings.filterwarnings(
@@ -25,7 +25,7 @@
 ]
 
 
-class GLUEDataModule(LightningDataModule):
+class GLUEDataModule(pl.LightningDataModule):
     task_text_field_map = {
         "cola": ["sentence"],
         "sst2": ["sentence"],
6 changes: 3 additions & 3 deletions src/datamodules/mnist_datamodule.py
@@ -1,13 +1,13 @@
 from typing import Optional
 
-from pytorch_lightning import LightningDataModule
-from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
+import lightning.pytorch as pl
+from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
 from torch.utils.data import DataLoader, random_split
 from torchvision.datasets import MNIST
 from torchvision.transforms import transforms
 
 
-class MNISTDataModule(LightningDataModule):
+class MNISTDataModule(pl.LightningDataModule):
     def __init__(
         self,
         data_dir: str = "data/",
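Both datamodules now subclass pl.LightningDataModule through the single lightning.pytorch import. A hedged, minimal sketch of that shape (toy in-memory data, not the repo's MNIST or GLUE setup):

# Minimal DataModule under the unified lightning.pytorch namespace (sketch).
import lightning.pytorch as pl
import torch
from torch.utils.data import DataLoader, TensorDataset


class RandomDataModule(pl.LightningDataModule):  # illustrative name
    def __init__(self, batch_size: int = 32) -> None:
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage: str) -> None:
        # build an in-memory dataset; real modules would download/split here
        self.dataset = TensorDataset(torch.randn(256, 4), torch.randn(256, 1))

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)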
37 changes: 23 additions & 14 deletions src/models/glue_transformer.py
@@ -1,10 +1,10 @@
 from functools import partial
 from typing import Any, Optional, Union
 
-import datasets
+import evaluate
+import lightning.pytorch as pl
 import torch
-from pytorch_lightning import LightningModule
-from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
+from lightning.pytorch.utilities.types import STEP_OUTPUT
 from transformers import (
     AutoModelForSequenceClassification,
     AutoTokenizer,
@@ -13,7 +13,7 @@
 )
 
 
-class GLUETransformer(LightningModule):
+class GLUETransformer(pl.LightningModule):
     def __init__(
         self,
         task_name: str,
@@ -36,7 +36,10 @@ def __init__(
         self.model = AutoModelForSequenceClassification.from_pretrained(
             model_name_or_path
         )
-        self.metric = datasets.load_metric("glue", task_name)
+        self.metric = evaluate.load("glue", task_name)
+
+        self.validation_step_outputs = []
+        self.test_step_outputs = []
 
     def forward(self, batch):
         return self.model.forward(**batch)
@@ -61,14 +64,18 @@ def training_step(
     def validation_step(
         self, batch, batch_idx: int, dataloader_idx: Optional[int] = None
     ) -> Optional[STEP_OUTPUT]:
-        return self.shared_step(batch)
+        output = self.shared_step(batch)
+        self.validation_step_outputs.append(output)
+        return output
 
     def test_step(
         self, batch, batch_idx: int, dataloader_idx: Optional[int] = None
     ) -> Optional[STEP_OUTPUT]:
-        return self.shared_step(batch)
+        output = self.shared_step(batch)
+        self.test_step_outputs.append(output)
+        return output
 
-    def shared_epoch_end(self, outputs: EPOCH_OUTPUT, step: str) -> None:
+    def shared_epoch_end(self, outputs, step: str) -> None:
         if hasattr(self.trainer.datamodule, f"{step}_splits"):
             splits = getattr(self.trainer.datamodule, f"{step}_splits")
             if len(splits) > 1:
@@ -104,14 +111,16 @@ def shared_epoch_end(self, outputs: EPOCH_OUTPUT, step: str) -> None:
         self.log(f"{step}/loss", loss)
         self.log_dict(metrics, prog_bar=True)
 
-    def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
-        return self.shared_epoch_end(outputs, "train")
+    def on_training_epoch_end(self) -> None:
+        self.shared_epoch_end(self.training_step_outputs, "train")
 
-    def validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
-        return self.shared_epoch_end(outputs, "val")
+    def on_validation_epoch_end(self) -> None:
+        self.shared_epoch_end(self.validation_step_outputs, "val")
+        self.validation_step_outputs.clear()
 
-    def test_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
-        return self.shared_epoch_end(outputs, "test")
+    def on_test_epoch_end(self) -> None:
+        self.shared_epoch_end(self.test_step_outputs, "test")
+        self.test_step_outputs.clear()
 
     def configure_optimizers(self):
         no_decay = ["bias", "LayerNorm.weight"]
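The core 2.0 change in this file is the removal of the *_epoch_end(outputs) hooks: step outputs are accumulated by hand and consumed in on_*_epoch_end, and datasets.load_metric gives way to evaluate.load. A hedged toy sketch of the epoch-end pattern (not the repo's GLUETransformer):

# Toy module showing manual accumulation + on_validation_epoch_end (sketch).
import lightning.pytorch as pl
import torch


class ToyModule(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)
        self.validation_step_outputs: list[torch.Tensor] = []

    def validation_step(self, batch, batch_idx: int) -> torch.Tensor:
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.validation_step_outputs.append(loss)
        return loss

    def on_validation_epoch_end(self) -> None:
        # aggregate what validation_epoch_end used to receive as `outputs`
        self.log("val/loss", torch.stack(self.validation_step_outputs).mean())
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)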
8 changes: 4 additions & 4 deletions src/models/mnist_model.py
@@ -1,14 +1,14 @@
 from typing import Optional
 
+import lightning.pytorch as pl
 import torch
 import torch.nn.functional as F
-from pytorch_lightning import LightningModule
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
+from lightning.pytorch.utilities.types import STEP_OUTPUT
 from torchmetrics import Accuracy, MetricCollection
 
 
-class MNISTModel(LightningModule):
+class MNISTModel(pl.LightningModule):
     def __init__(
         self,
         input_size: int = 28 * 28,
14 changes: 9 additions & 5 deletions src/utils/lit_cli.py
@@ -1,7 +1,8 @@
 import os
 from typing import Iterable
 
-from pytorch_lightning.cli import LightningArgumentParser, LightningCLI
+import torch
+from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
 
 
 class LitCLI(LightningCLI):
@@ -16,10 +17,6 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
     def before_instantiate_classes(self) -> None:
         config = self.config[self.subcommand]
 
-        # HACK: https://github.com/Lightning-AI/lightning/issues/15233
-        if config.trainer.fast_dev_run:
-            config.trainer.logger = None
-
         logger = config.trainer.logger
         if logger and logger is not True:
             loggers = logger if isinstance(logger, Iterable) else [logger]
@@ -35,6 +32,13 @@ def before_instantiate_classes(self) -> None:
                 if hasattr(logger.init_args, "name"):
                     logger.init_args.name = exp_name
 
+    def before_run(self) -> None:
+        if hasattr(torch, "compile"):
+            # https://pytorch.org/get-started/pytorch-2.0/#user-experience
+            torch.compile(self.model)
+
+    before_fit = before_validate = before_test = before_run
+
 
 def lit_cli():
     LitCLI(
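The new before_run hook opts into PyTorch 2.0 compilation when it is available. As a standalone reference (not the repo's CLI), a guarded helper; note that torch.compile returns the optimized module, so this sketch assigns the result rather than discarding it:

# Guard torch.compile by availability (PyTorch < 2.0 has no torch.compile).
# Helper name is illustrative.
import torch


def maybe_compile(model: torch.nn.Module) -> torch.nn.Module:
    if hasattr(torch, "compile"):
        return torch.compile(model)  # returns the optimized wrapper module
    return model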
28 changes: 5 additions & 23 deletions src/utils/loggers.py
@@ -1,10 +1,8 @@
 import os
-from datetime import datetime
 
-from lightning_fabric.utilities.types import _PATH
-from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
+import lightning.pytorch as pl
+from lightning.fabric.utilities.types import _PATH
+from lightning.pytorch.callbacks import ModelCheckpoint
 
 
 @property
@@ -18,10 +16,10 @@ def log_dir(self) -> str:
     return dirpath
 
 
-Trainer.log_dir = log_dir
+pl.Trainer.log_dir = log_dir
 
 
-def __resolve_ckpt_dir(self, trainer: Trainer) -> _PATH:
+def __resolve_ckpt_dir(self, trainer: pl.Trainer) -> _PATH:
     """Determines model checkpoint save directory at runtime. References attributes from the trainer's logger
     to determine where to save checkpoints. The base path for saving weights is set in this priority:
     1. Checkpoint callback's path (if passed in)
@@ -39,19 +37,3 @@ def __resolve_ckpt_dir(self, trainer: Trainer) -> _PATH:
 
 
 ModelCheckpoint._ModelCheckpoint__resolve_ckpt_dir = __resolve_ckpt_dir
-
-
-@property
-def TensorBoardLogger_version(self) -> str:
-    """Get the experiment version.
-
-    Returns:
-        The experiment version if specified else current timestamp.
-    """
-    if self._version is None:
-        self._version = datetime.now().strftime("%m-%dT%H%M%S")
-
-    return self._version
-
-
-TensorBoardLogger.version = TensorBoardLogger_version
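The file keeps its monkeypatching approach under the new namespace: a module-level @property is assigned onto pl.Trainer and ModelCheckpoint's name-mangled __resolve_ckpt_dir is swapped out. A hedged sketch of the property-patching idiom, with an illustrative attribute name that Lightning does not define:

# Attach a computed attribute to Trainer at import time (sketch; `run_dir`
# is an illustrative name, not part of Lightning's API).
import lightning.pytorch as pl


@property
def run_dir(self) -> str:
    # fall back to the trainer's default root directory
    return self.default_root_dir


pl.Trainer.run_dir = run_dir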
1 change: 0 additions & 1 deletion src/utils/tweak_shtab.py
@@ -1,5 +1,4 @@
# HACK for shtab
# https://github.com/omni-us/jsonargparse/issues/127
# https://github.com/iterative/shtab/issues/65

import shtab
