Skip to content

Commit

Permalink
Replace lightning.pytorch as pl with lightning as L
Browse files Browse the repository at this point in the history
  • Loading branch information
tshu-w committed Sep 6, 2024
1 parent b0d16f2 commit dcc53fd
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 15 deletions.
8 changes: 4 additions & 4 deletions src/callbacks/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@
import logging
from pathlib import Path

import lightning.pytorch as pl
import lightning as L
from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
from lightning.pytorch.trainer.states import TrainerFn


class Metric(pl.Callback):
class Metric(L.Callback):
r"""
Save logged metrics to ``Trainer.log_dir``.
"""

def teardown(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
trainer: L.Trainer,
pl_module: L.LightningModule,
stage: str | None = None,
) -> None:
metrics = {}
Expand Down
4 changes: 2 additions & 2 deletions src/datamodules/glue_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import partial
from typing import Literal

import lightning.pytorch as pl
import lightning as L
from datasets import load_dataset
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils.data import DataLoader
Expand All @@ -25,7 +25,7 @@
]


class GLUEDataModule(pl.LightningDataModule):
class GLUEDataModule(L.LightningDataModule):
task_text_field_map = {
"cola": ["sentence"],
"sst2": ["sentence"],
Expand Down
4 changes: 2 additions & 2 deletions src/datamodules/mnist_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import lightning.pytorch as pl
import lightning as L
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import transforms


class MNISTDataModule(pl.LightningDataModule):
class MNISTDataModule(L.LightningDataModule):
def __init__(
self,
data_dir: str = "data/",
Expand Down
4 changes: 2 additions & 2 deletions src/models/glue_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any

import evaluate
import lightning.pytorch as pl
import lightning as L
import torch
from lightning.pytorch.utilities.types import STEP_OUTPUT
from transformers import (
Expand All @@ -13,7 +13,7 @@
)


class GLUETransformer(pl.LightningModule):
class GLUETransformer(L.LightningModule):
def __init__(
self,
task_name: str,
Expand Down
4 changes: 2 additions & 2 deletions src/models/mnist_model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import lightning.pytorch as pl
import lightning as L
import torch
import torch.nn.functional as F
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torchmetrics import Accuracy, MetricCollection


class MNISTModel(pl.LightningModule):
class MNISTModel(L.LightningModule):
def __init__(
self,
input_size: int = 28 * 28,
Expand Down
6 changes: 3 additions & 3 deletions src/utils/loggers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

import lightning.pytorch as pl
import lightning as L
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks import ModelCheckpoint

Expand All @@ -19,10 +19,10 @@ def log_dir(self) -> str:
return dirpath


pl.Trainer.log_dir = log_dir
L.Trainer.log_dir = log_dir


def __resolve_ckpt_dir(self, trainer: pl.Trainer) -> _PATH:
def __resolve_ckpt_dir(self, trainer: L.Trainer) -> _PATH:
"""Determines model checkpoint save directory at runtime. References attributes from the trainer's logger
to determine where to save checkpoints. The base path for saving weights is set in this priority:
1. Checkpoint callback's path (if passed in)
Expand Down

0 comments on commit dcc53fd

Please sign in to comment.