Commit f7545df

implement stateful dataloader serialization
1 parent 9d35c61 commit f7545df

File tree

1 file changed: +26 −3 lines changed

Diff for: src/lightning/pytorch/trainer/connectors/checkpoint_connector.py

@@ -14,7 +14,7 @@
 import logging
 import os
 import re
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, List

 import torch
 from fsspec.core import url_to_fs
@@ -24,7 +24,7 @@
 import lightning.pytorch as pl
 from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
 from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem
-from lightning.fabric.utilities.types import _PATH
+from lightning.fabric.utilities.types import _PATH, _Stateful
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.plugins.precision import MixedPrecision
 from lightning.pytorch.trainer import call
@@ -292,8 +292,8 @@ def restore_training_state(self) -> None:

         assert self.trainer.state.fn is not None
         if self.trainer.state.fn == TrainerFn.FITTING:
-            # restore optimizers and schedulers state
             self.restore_optimizers_and_schedulers()
+            self._restore_train_dataloaders()

     def restore_precision_plugin_state(self) -> None:
         """Restore the precision plugin state from the pre-loaded checkpoint."""
@@ -390,6 +390,18 @@ def restore_lr_schedulers(self) -> None:
         for config, lrs_state in zip(self.trainer.lr_scheduler_configs, lr_schedulers):
             config.scheduler.load_state_dict(lrs_state)

+    def _restore_train_dataloaders(self) -> None:
+        if not self._loaded_checkpoint:
+            return
+
+        state_dicts = self._loaded_checkpoint.get("train_dataloaders")
+        if not state_dicts:
+            return
+
+        for train_dataloader, state_dict in zip(self.trainer.train_dataloader, state_dicts):
+            if isinstance(train_dataloader, _Stateful):
+                train_dataloader.load_state_dict(state_dict)
+
     def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
         # restore modules after setup
         self.resume_start(checkpoint_path)
@@ -413,6 +425,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
             'optimizer_states': "PT optim's state_dict"[]  # if not weights_only
             'lr_schedulers': "PT sched's state_dict"[]  # if not weights_only
             'state_dict': Model's state_dict (e.g. network weights)
+            'train_dataloaders': List of states of the training dataloader(s), if any of them are stateful
             precision_plugin.__class__.__qualname__: precision plugin state_dict  # if not weights_only
             CHECKPOINT_HYPER_PARAMS_NAME:
             CHECKPOINT_HYPER_PARAMS_KEY:
@@ -460,6 +473,10 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
             checkpoint[prec_plugin.__class__.__qualname__] = prec_plugin_state_dict
             prec_plugin.on_save_checkpoint(checkpoint)

+        # training dataloader(s)
+        if train_dataloaders := self._get_dataloader_state_dicts():
+            checkpoint["train_dataloaders"] = train_dataloaders
+
         if _OMEGACONF_AVAILABLE:
             from omegaconf import Container

@@ -502,6 +519,12 @@ def _get_loops_state_dict(self) -> Dict[str, Any]:
             "predict_loop": self.trainer.predict_loop.state_dict(),
         }

+    def _get_dataloader_state_dicts(self) -> List[Dict[str, Any]]:
+        return [
+            train_dataloader.state_dict() for train_dataloader in (self.trainer.train_dataloader or [])
+            if isinstance(train_dataloader, _Stateful)
+        ]
+
     @staticmethod
     def __max_ckpt_version_in_folder(dir_path: _PATH, name_key: str = "ckpt_") -> Optional[int]:
         """List up files in `dir_path` with `name_key`, then yield maximum suffix number.

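From the user's side (assuming a loader like the sketch above), no extra calls should be needed: dump_checkpoint() runs whenever the Trainer writes a checkpoint while fitting, so the loader's state_dict() ends up in the checkpoint automatically, and resuming with Trainer.fit(..., ckpt_path=...) goes through restore_training_state(), which now hands the saved entries back via load_state_dict().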