Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 29, 2024
1 parent 5317545 commit f596cb1
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import logging
import os
import re
from typing import Any, Dict, Optional, List
from typing import Any, Dict, List, Optional

import torch
from fsspec.core import url_to_fs
Expand Down Expand Up @@ -525,8 +525,7 @@ def _get_dataloader_state_dicts(self) -> List[Dict[str, Any]]:
combined_loader = self.trainer.fit_loop._combined_loader
iterables = combined_loader.flattened if combined_loader is not None else []
return [
train_dataloader.state_dict() for train_dataloader in iterables
if isinstance(train_dataloader, _Stateful)
train_dataloader.state_dict() for train_dataloader in iterables if isinstance(train_dataloader, _Stateful)
]

@staticmethod
Expand Down
37 changes: 20 additions & 17 deletions tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@

import pytest
import torch
from torch.utils.data import DataLoader

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.migration.utils import _set_version
from lightning.pytorch.utilities import CombinedLoader
from lightning.pytorch.utilities.migration.utils import _set_version
from torch.utils.data import DataLoader


def test_preloaded_checkpoint_lifecycle(tmpdir):
Expand Down Expand Up @@ -239,22 +238,26 @@ def __init__(self, *args, **kwargs):
super().__init__(RandomDataset(32, 64), *args, **kwargs)


@pytest.mark.parametrize(("train_dataloaders", "expected_states"), [
# No dataloader
([], None),
# Single stateful DataLoader
(StatefulDataLoader(), [{"label": 0}]),
# Single, not stateful DataLoader
(CombinedLoader(NotStatefulDataLoader()), None),
# Single stateful DataLoader
(CombinedLoader(StatefulDataLoader()), [{"label": 0}]),
# Multiple stateful DataLoaders
(CombinedLoader([StatefulDataLoader(3), StatefulDataLoader(1)]), [{"label": 3}, {"label": 1}]),
# Mix of stateful and not stateful DataLoaders
(CombinedLoader([NotStatefulDataLoader(3), StatefulDataLoader(1), NotStatefulDataLoader(2)]), [{"label": 1}]),
])
@pytest.mark.parametrize(
("train_dataloaders", "expected_states"),
[
# No dataloader
([], None),
# Single stateful DataLoader
(StatefulDataLoader(), [{"label": 0}]),
# Single, not stateful DataLoader
(CombinedLoader(NotStatefulDataLoader()), None),
# Single stateful DataLoader
(CombinedLoader(StatefulDataLoader()), [{"label": 0}]),
# Multiple stateful DataLoaders
(CombinedLoader([StatefulDataLoader(3), StatefulDataLoader(1)]), [{"label": 3}, {"label": 1}]),
# Mix of stateful and not stateful DataLoaders
(CombinedLoader([NotStatefulDataLoader(3), StatefulDataLoader(1), NotStatefulDataLoader(2)]), [{"label": 1}]),
],
)
def test_train_dataloaders_restore(train_dataloaders, expected_states, tmp_path):
"""Test that the CheckpointConnector saves the state of stateful dataloaders and can reload them."""

class DataLoaderModel(BoringModel):
def training_step(self, batch, batch_idx):
if isinstance(batch, list):
Expand Down

0 comments on commit f596cb1

Please sign in to comment.