From f596cb11adf231e6300856d27d0f4e6f25fd7958 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 29 Jan 2024 02:46:25 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../connectors/checkpoint_connector.py       |  5 +--
 .../connectors/test_checkpoint_connector.py  | 37 ++++++++++---------
 2 files changed, 22 insertions(+), 20 deletions(-)

diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py
index 52aaa93fea104..623739cf728cb 100644
--- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py
@@ -14,7 +14,7 @@
 import logging
 import os
 import re
-from typing import Any, Dict, Optional, List
+from typing import Any, Dict, List, Optional
 
 import torch
 from fsspec.core import url_to_fs
@@ -525,8 +525,7 @@ def _get_dataloader_state_dicts(self) -> List[Dict[str, Any]]:
         combined_loader = self.trainer.fit_loop._combined_loader
         iterables = combined_loader.flattened if combined_loader is not None else []
         return [
-            train_dataloader.state_dict() for train_dataloader in iterables
-            if isinstance(train_dataloader, _Stateful)
+            train_dataloader.state_dict() for train_dataloader in iterables if isinstance(train_dataloader, _Stateful)
         ]
 
     @staticmethod
diff --git a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py
index dc8353e57937b..89ddf3c4d0c15 100644
--- a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py
@@ -17,14 +17,13 @@
 
 import pytest
 import torch
-from torch.utils.data import DataLoader
-
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities.migration.utils import _set_version
 from lightning.pytorch.utilities import CombinedLoader
+from lightning.pytorch.utilities.migration.utils import _set_version
+from torch.utils.data import DataLoader
 
 
 def test_preloaded_checkpoint_lifecycle(tmpdir):
@@ -239,22 +238,26 @@ def __init__(self, *args, **kwargs):
         super().__init__(RandomDataset(32, 64), *args, **kwargs)
 
 
-@pytest.mark.parametrize(("train_dataloaders", "expected_states"), [
-    # No dataloader
-    ([], None),
-    # Single stateful DataLoader
-    (StatefulDataLoader(), [{"label": 0}]),
-    # Single, not stateful DataLoader
-    (CombinedLoader(NotStatefulDataLoader()), None),
-    # Single stateful DataLoader
-    (CombinedLoader(StatefulDataLoader()), [{"label": 0}]),
-    # Multiple stateful DataLoaders
-    (CombinedLoader([StatefulDataLoader(3), StatefulDataLoader(1)]), [{"label": 3}, {"label": 1}]),
-    # Mix of stateful and not stateful DataLoaders
-    (CombinedLoader([NotStatefulDataLoader(3), StatefulDataLoader(1), NotStatefulDataLoader(2)]), [{"label": 1}]),
-])
+@pytest.mark.parametrize(
+    ("train_dataloaders", "expected_states"),
+    [
+        # No dataloader
+        ([], None),
+        # Single stateful DataLoader
+        (StatefulDataLoader(), [{"label": 0}]),
+        # Single, not stateful DataLoader
+        (CombinedLoader(NotStatefulDataLoader()), None),
+        # Single stateful DataLoader
+        (CombinedLoader(StatefulDataLoader()), [{"label": 0}]),
+        # Multiple stateful DataLoaders
+        (CombinedLoader([StatefulDataLoader(3), StatefulDataLoader(1)]), [{"label": 3}, {"label": 1}]),
+        # Mix of stateful and not stateful DataLoaders
+        (CombinedLoader([NotStatefulDataLoader(3), StatefulDataLoader(1), NotStatefulDataLoader(2)]), [{"label": 1}]),
+    ],
+)
 def test_train_dataloaders_restore(train_dataloaders, expected_states, tmp_path):
     """Test that the CheckpointConnector saves the state of stateful dataloaders and can reload them."""
+
     class DataLoaderModel(BoringModel):
         def training_step(self, batch, batch_idx):
             if isinstance(batch, list):