Commit 46aca6b

add test
1 parent f7545df · commit 46aca6b

1 file changed: +46 −1


tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py (+46 −1)
@@ -12,16 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from copy import deepcopy
 from unittest import mock
 from unittest.mock import Mock
 
 import pytest
 import torch
+from torch.utils.data import DataLoader
+
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import ModelCheckpoint
-from lightning.pytorch.demos.boring_classes import BoringModel
+from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities.migration.utils import _set_version
+from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector
 
 
 def test_preloaded_checkpoint_lifecycle(tmpdir):
@@ -217,3 +221,44 @@ def test_stateful_trainer_ckpt_path_support(tmp_path):
     assert not trainer._checkpoint_connector._user_managed
     trainer.test()
     assert trainer.ckpt_path == best_path
+
+
+class StatefulDataLoader(DataLoader):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._counter = 0
+
+    def state_dict(self):
+        return {"counter": self._counter}
+
+    def load_state_dict(self, state_dict):
+        self._counter = state_dict["counter"]
+
+
+@pytest.mark.parametrize(("train_dataloaders", "expected_states"), [
+    ([], None),
+    (StatefulDataLoader(RandomDataset(32, 64)), [{"counter": 0}]),
+])
+def test_train_dataloaders_restore(train_dataloaders, expected_states, tmp_path):
+
+    class TestModel(BoringModel):
+        def train_dataloader(self):
+            return train_dataloaders
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        accelerator="cpu",
+        max_steps=1,
+        enable_checkpointing=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+        logger=False,
+        num_sanity_val_steps=0,
+    )
+    trainer.fit(model)
+    checkpoint = trainer._checkpoint_connector.dump_checkpoint()
+    if expected_states is None:
+        assert "train_dataloaders" not in checkpoint
+    else:
+        assert checkpoint["train_dataloaders"] == expected_states
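
For context, a minimal standalone sketch of the save/restore contract the new test exercises: a DataLoader that exposes state_dict()/load_state_dict() can have its progress captured and later reloaded. The StatefulDataLoader below is the same helper defined in the diff; the counter value of 3 is made up purely for illustration and is not part of the test.

# Sketch only: illustrates the state_dict()/load_state_dict() round trip
# checked by test_train_dataloaders_restore; the counter value is hypothetical.
from torch.utils.data import DataLoader

from lightning.pytorch.demos.boring_classes import RandomDataset


class StatefulDataLoader(DataLoader):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._counter = 0

    def state_dict(self):
        return {"counter": self._counter}

    def load_state_dict(self, state_dict):
        self._counter = state_dict["counter"]


loader = StatefulDataLoader(RandomDataset(32, 64))
loader._counter = 3                      # pretend some batches were consumed
state = loader.state_dict()              # -> {"counter": 3}

resumed = StatefulDataLoader(RandomDataset(32, 64))
resumed.load_state_dict(state)           # progress carried over to the new loader
assert resumed._counter == 3

In the parametrized test itself nothing ever increments _counter, so the state captured by dump_checkpoint() is the initial {"counter": 0}, which is exactly the value asserted under the "train_dataloaders" key.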
