From 5317545fcad28b947830affc594adb9e047ffe85 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 29 Jan 2024 03:44:57 +0100 Subject: [PATCH] test restore --- .../trainer/connectors/checkpoint_connector.py | 4 +++- .../connectors/test_checkpoint_connector.py | 16 +++++++++++++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py index 959fd493a1f20..52aaa93fea104 100644 --- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py +++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py @@ -398,7 +398,9 @@ def _restore_train_dataloaders(self) -> None: if not state_dicts: return - for train_dataloader, state_dict in zip(self.trainer.train_dataloader, state_dicts): + combined_loader = self.trainer.fit_loop._combined_loader + iterables = combined_loader.flattened if combined_loader is not None else [] + for train_dataloader, state_dict in zip(iterables, state_dicts): if isinstance(train_dataloader, _Stateful): train_dataloader.load_state_dict(state_dict) diff --git a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py index f3b17ac7cbc22..dc8353e57937b 100644 --- a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py @@ -254,7 +254,7 @@ def __init__(self, *args, **kwargs): (CombinedLoader([NotStatefulDataLoader(3), StatefulDataLoader(1), NotStatefulDataLoader(2)]), [{"label": 1}]), ]) def test_train_dataloaders_restore(train_dataloaders, expected_states, tmp_path): - """Test that the CheckpointConnector saves the state of stateful dataloaders and can reloead them.""" + """Test that the CheckpointConnector saves the state of stateful dataloaders and can reload them.""" class DataLoaderModel(BoringModel): def training_step(self, batch, batch_idx): if isinstance(batch, list): @@ -264,8 +264,7 @@ def training_step(self, batch, batch_idx): def train_dataloader(self): return train_dataloaders - model = DataLoaderModel() - trainer = Trainer( + trainer_kwargs = dict( default_root_dir=tmp_path, accelerator="cpu", max_steps=1, @@ -275,6 +274,10 @@ def train_dataloader(self): logger=False, num_sanity_val_steps=0, ) + + model = DataLoaderModel() + trainer = Trainer(**trainer_kwargs) + # Fit to init the state of CheckpointConnector trainer.fit(model) checkpoint = trainer._checkpoint_connector.dump_checkpoint() @@ -283,3 +286,10 @@ def train_dataloader(self): assert "train_dataloaders" not in checkpoint else: assert checkpoint["train_dataloaders"] == expected_states + + torch.save(checkpoint, tmp_path / "checkpoint.ckpt") + + model = DataLoaderModel() + trainer = Trainer(**trainer_kwargs) + trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt")) + # TODO: Test here