 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from copy import deepcopy
 from unittest import mock
 from unittest.mock import Mock
 
 import pytest
 import torch
+from torch.utils.data import DataLoader
+
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import ModelCheckpoint
-from lightning.pytorch.demos.boring_classes import BoringModel
+from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities.migration.utils import _set_version
+from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector
 
 
 def test_preloaded_checkpoint_lifecycle(tmpdir):
@@ -217,3 +221,44 @@ def test_stateful_trainer_ckpt_path_support(tmp_path):
     assert not trainer._checkpoint_connector._user_managed
     trainer.test()
     assert trainer.ckpt_path == best_path
+
+
+class StatefulDataLoader(DataLoader):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._counter = 0
+
+    def state_dict(self):
+        return {"counter": self._counter}
+
+    def load_state_dict(self, state_dict):
+        self._counter = state_dict["counter"]
+
+
+@pytest.mark.parametrize(("train_dataloaders", "expected_states"), [
+    ([], None),
+    (StatefulDataLoader(RandomDataset(32, 64)), [{"counter": 0}]),
+])
+def test_train_dataloaders_restore(train_dataloaders, expected_states, tmp_path):
+
+    class TestModel(BoringModel):
+        def train_dataloader(self):
+            return train_dataloaders
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        accelerator="cpu",
+        max_steps=1,
+        enable_checkpointing=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+        logger=False,
+        num_sanity_val_steps=0,
+    )
+    trainer.fit(model)
+    checkpoint = trainer._checkpoint_connector.dump_checkpoint()
+    if expected_states is None:
+        assert "train_dataloaders" not in checkpoint
+    else:
+        assert checkpoint["train_dataloaders"] == expected_states
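
Below is a minimal illustrative sketch (not part of the diff above) of the save/load round trip the new test exercises. It reuses only the StatefulDataLoader and RandomDataset already shown; the counter value is made up for demonstration:

# Illustrative only: the dataloader state protocol that dump_checkpoint() is expected to capture.
loader = StatefulDataLoader(RandomDataset(32, 64))
loader._counter = 3                       # pretend three batches were consumed
state = loader.state_dict()               # {"counter": 3}

restored = StatefulDataLoader(RandomDataset(32, 64))
restored.load_state_dict(state)           # carry the counter over to a fresh loader
assert restored._counter == 3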