@@ -14,7 +14,7 @@
 import logging
 import os
 import re
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, List

 import torch
 from fsspec.core import url_to_fs
@@ -24,7 +24,7 @@
 import lightning.pytorch as pl
 from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
 from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem
-from lightning.fabric.utilities.types import _PATH
+from lightning.fabric.utilities.types import _PATH, _Stateful
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.plugins.precision import MixedPrecision
 from lightning.pytorch.trainer import call
@@ -292,8 +292,8 @@ def restore_training_state(self) -> None:

         assert self.trainer.state.fn is not None
         if self.trainer.state.fn == TrainerFn.FITTING:
-            # restore optimizers and schedulers state
             self.restore_optimizers_and_schedulers()
+            self._restore_train_dataloaders()

     def restore_precision_plugin_state(self) -> None:
         """Restore the precision plugin state from the pre-loaded checkpoint."""
@@ -390,6 +390,18 @@ def restore_lr_schedulers(self) -> None:
         for config, lrs_state in zip(self.trainer.lr_scheduler_configs, lr_schedulers):
             config.scheduler.load_state_dict(lrs_state)

+    def _restore_train_dataloaders(self) -> None:
+        if not self._loaded_checkpoint:
+            return
+
+        state_dicts = self._loaded_checkpoint.get("train_dataloaders")
+        if not state_dicts:
+            return
+
+        for train_dataloader, state_dict in zip(self.trainer.train_dataloader, state_dicts):
+            if isinstance(train_dataloader, _Stateful):
+                train_dataloader.load_state_dict(state_dict)
+
     def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
         # restore modules after setup
         self.resume_start(checkpoint_path)
@@ -413,6 +425,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
                 'optimizer_states': "PT optim's state_dict"[]   # if not weights_only
                 'lr_schedulers': "PT sched's state_dict"[]   # if not weights_only
                 'state_dict': Model's state_dict (e.g. network weights)
+                'train_dataloaders': List of states of the training dataloader(s), if any of them are stateful
                 precision_plugin.__class__.__qualname__: precision plugin state_dict # if not weights_only
                 CHECKPOINT_HYPER_PARAMS_NAME:
                 CHECKPOINT_HYPER_PARAMS_KEY:
@@ -460,6 +473,10 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
                 checkpoint[prec_plugin.__class__.__qualname__] = prec_plugin_state_dict
             prec_plugin.on_save_checkpoint(checkpoint)

+        # training dataloader(s)
+        if train_dataloaders := self._get_dataloader_state_dicts():
+            checkpoint["train_dataloaders"] = train_dataloaders
+
         if _OMEGACONF_AVAILABLE:
             from omegaconf import Container

@@ -502,6 +519,12 @@ def _get_loops_state_dict(self) -> Dict[str, Any]:
             "predict_loop": self.trainer.predict_loop.state_dict(),
         }

+    def _get_dataloader_state_dicts(self) -> List[Dict[str, Any]]:
+        return [
+            train_dataloader.state_dict() for train_dataloader in (self.trainer.train_dataloader or [])
+            if isinstance(train_dataloader, _Stateful)
+        ]
+
     @staticmethod
     def __max_ckpt_version_in_folder(dir_path: _PATH, name_key: str = "ckpt_") -> Optional[int]:
         """List up files in `dir_path` with `name_key`, then yield maximum suffix number.