
Conversation

@arrdel (Contributor) commented on Jan 6, 2026

This PR adds manual dataloader reloading, allowing DataLoaders to be reconfigured dynamically during training without exiting the fit loop. Addresses issue #21448.


📚 Documentation preview 📚: https://pytorch-lightning--21473.org.readthedocs.build/en/21473/

arrdel and others added 10 commits December 5, 2025 21:58
…ameter loading

Fixes Lightning-AI#21255

This commit adds the adapt_checkpoint_hparams() public method to LightningCLI,
allowing users to customize hyperparameters loaded from checkpoints before they
are used to instantiate model classes. This is particularly useful when loading a
checkpoint produced by a TrainingModule into an InferenceModule class whose
__init__ takes different parameters.

Problem:
When loading a checkpoint trained with TrainingModule(lr=1e-3) into an
InferenceModule() that doesn't accept 'lr' as a parameter, the CLI would fail
during instantiation because it tries to pass all checkpoint hyperparameters
to the new module class.

Solution:
Added adapt_checkpoint_hparams() hook that is called in _parse_ckpt_path()
after loading checkpoint hyperparameters but before applying them. Users can
override this method to:
- Remove training-specific hyperparameters (e.g., lr, weight_decay)
- Modify _class_path for subclass mode
- Transform hyperparameter names/values
- Completely disable checkpoint hyperparameters by returning {}

Example usage:
    class MyCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, checkpoint_hparams):
            checkpoint_hparams.pop('lr', None)
            checkpoint_hparams.pop('weight_decay', None)
            return checkpoint_hparams

This approach is preferable to:
- Disabling checkpoint loading entirely (loses valuable hyperparameter info)
- Adding CLI arguments (deviates from Trainer parameter pattern)
- Modifying private methods (breaks encapsulation)

The hook provides maximum flexibility while maintaining backward compatibility
(default implementation returns hyperparameters unchanged).
…ook and add tests

- Update adapt_checkpoint_hparams signature to include subcommand parameter
  allowing context-aware customization of checkpoint hyperparameters
- Change type annotations to use lowercase dict (Python 3.9+ style)
- Update docstring with subcommand parameter documentation
- Add example showing conditional logic based on subcommand
- Add comprehensive unit tests:
  - test_adapt_checkpoint_hparams_hook: Tests that hook is called and modifications applied
  - test_adapt_checkpoint_hparams_hook_empty_dict: Tests disabling checkpoint hparams loading
- Tests cover both regular and subclass modes
- Split method signature across multiple lines to stay within 120 char limit
- Improves code readability in documentation example
Removed redundant method implementations since BoringModel provides them.
The test was asserting hidden_dim==3 but only passing out_dim=3.
Since hidden_dim defaults to 16 and there's no argument linking,
the assertion failed. Now we explicitly pass --model.hidden_dim=6.
Copilot AI review requested due to automatic review settings on January 6, 2026 at 01:17
github-actions bot added the docs (Documentation related) and pl (Generic label for PyTorch Lightning package) labels on Jan 6, 2026
Copilot AI left a comment

Pull request overview

This PR implements manual dataloader reloading functionality for PyTorch Lightning, addressing issue #21448. The feature allows dynamic DataLoader reconfiguration during training without exiting the fit loop, enabling use cases like curriculum learning and adaptive training strategies.

Key changes:

  • Added trainer.reload_dataloaders() method for manual dataloader reloading during training
  • Added comprehensive test suite covering various reload scenarios
  • Added documentation and examples for the new feature

Note: The PR also includes an unrelated CLI feature (adapt_checkpoint_hparams hook) which appears to be a separate enhancement.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.

Summary per file:

  • src/lightning/pytorch/trainer/trainer.py: implements the reload_dataloaders() method, which lets users manually trigger a dataloader reload during training by resetting internal epoch-tracking variables.
  • tests/tests_pytorch/trainer/test_reload_dataloaders.py: comprehensive test suite for the manual reload feature, covering single and multiple reloads, parameter updates, and interaction with the existing automatic reload functionality.
  • src/lightning/pytorch/cli.py: adds the adapt_checkpoint_hparams() hook for customizing checkpoint hyperparameters (unrelated to dataloader reloading).
  • tests/tests_pytorch/test_cli.py: test coverage for the new CLI hook, plus the AdaptHparamsModel test fixture.
  • docs/source-pytorch/data/access.rst: documentation for both the automatic and manual dataloader reloading mechanisms, with practical examples.


Comment on lines +1209 to +1262
def reload_dataloaders(self, train: bool = True, val: bool = False) -> None:
    """Manually trigger a reload of dataloaders during training.

    This method allows dynamic reconfiguration of DataLoaders without exiting the ``fit()`` loop.
    It's useful for curriculum learning, adaptive training strategies, or any scenario where
    DataLoader parameters need to change based on training metrics or progress.

    The reload will occur at the start of the next epoch during training.

    Args:
        train: If ``True``, reload the train dataloader. Default: ``True``.
        val: If ``True``, reload the validation dataloader. Default: ``False``.

    Example::

        # In a callback
        def on_train_epoch_end(self, trainer, pl_module):
            if trainer.current_epoch == 5:
                # Update datamodule parameters
                trainer.datamodule.sequence_length += 10
                # Trigger reload for next epoch
                trainer.reload_dataloaders(train=True, val=True)

        # In a LightningModule
        def on_train_batch_end(self, outputs, batch, batch_idx):
            if self.trainer.callback_metrics.get('train_loss', 1.0) < 0.5:
                self.trainer.datamodule.unroll_steps += 1
                self.trainer.reload_dataloaders()

    Raises:
        RuntimeError: If called outside of a ``fit()`` call.

    .. note::
        The actual reload happens at the beginning of the next training epoch,
        not immediately when this method is called. This ensures training state
        consistency and proper synchronization in distributed settings.
    """
    if not self.training:
        raise RuntimeError(
            "`trainer.reload_dataloaders()` can only be called during training (inside `trainer.fit()`)."
        )

    if train:
        # Setting to -inf ensures _should_reload_train_dl returns True
        self.fit_loop._last_train_dl_reload_epoch = float("-inf")
        rank_zero_info("Train dataloader will be reloaded at the start of the next epoch.")

    if val:
        # Setting to -inf ensures _should_reload_val_dl returns True
        self.fit_loop.epoch_loop.val_loop._last_val_dl_reload_epoch = float("-inf")
        rank_zero_info("Validation dataloader will be reloaded at the next validation check.")

Copilot AI commented on Jan 6, 2026

The method should validate that at least one of train or val is True. Currently, if both are False, the method will execute without performing any action or providing feedback to the user, which could be confusing.
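
A minimal sketch of the suggested guard (hypothetical, not part of the PR as reviewed), which would sit near the top of reload_dataloaders():

    if not train and not val:
        # Hypothetical guard: fail fast instead of silently doing nothing
        raise ValueError(
            "`trainer.reload_dataloaders()` was called with `train=False` and `val=False`; "
            "set at least one of them to `True`."
        )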

Comment on lines +563 to +596
def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
    """Adapt checkpoint hyperparameters before instantiating the model class.

    This method allows for customization of hyperparameters loaded from a checkpoint when
    using a different model class than the one used for training. For example, when loading
    a checkpoint from a TrainingModule to use with an InferenceModule that has different
    ``__init__`` parameters, you can remove or modify incompatible hyperparameters.

    Args:
        subcommand: The subcommand being executed (e.g., 'fit', 'validate', 'test', 'predict').
            This allows you to apply different hyperparameter adaptations depending on the context.
        checkpoint_hparams: Dictionary of hyperparameters loaded from the checkpoint.

    Returns:
        Dictionary of adapted hyperparameters to be used for model instantiation.

    Example::

        class MyCLI(LightningCLI):
            def adapt_checkpoint_hparams(
                self, subcommand: str, checkpoint_hparams: dict[str, Any]
            ) -> dict[str, Any]:
                # Only remove training-specific hyperparameters for non-fit subcommands
                if subcommand != "fit":
                    checkpoint_hparams.pop("lr", None)
                    checkpoint_hparams.pop("weight_decay", None)
                return checkpoint_hparams

    Note:
        If subclass module mode is enabled and ``_class_path`` is present in the checkpoint
        hyperparameters, you may need to modify it as well to point to your new module class.
    """
    return checkpoint_hparams
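
As a usage sketch of the hook: the InferenceModule below and the checkpoint path are hypothetical, and only the adapt_checkpoint_hparams() override mirrors the API shown above.

from typing import Any

import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.cli import LightningCLI


class InferenceModule(LightningModule):
    """Hypothetical module whose __init__ does not accept training-only hparams such as lr."""

    def __init__(self, hidden_dim: int = 16):
        super().__init__()
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        return self.layer(x)


class InferenceCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
        # Drop training-only hyperparameters the checkpoint may carry but InferenceModule cannot accept
        for key in ("lr", "weight_decay"):
            checkpoint_hparams.pop(key, None)
        return checkpoint_hparams


if __name__ == "__main__":
    # e.g. python cli.py predict --ckpt_path path/to/checkpoint.ckpt  (path hypothetical)
    InferenceCLI(InferenceModule)
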
Copilot AI commented on Jan 6, 2026

This PR introduces the adapt_checkpoint_hparams hook in the CLI, which appears to be a feature unrelated to manual dataloader reloading. Consider splitting it into a separate PR to keep the changes focused and make review and maintenance easier. The PR title and description mention only dataloader reloading, not the CLI checkpoint hyperparameter adaptation feature.
