feat: Add manual dataloader reloading feature (Closes #21448) #21473
Conversation
…ameter loading

Fixes Lightning-AI#21255

This commit adds the `adapt_checkpoint_hparams()` public method to `LightningCLI`, allowing users to customize hyperparameters loaded from checkpoints before they are used to instantiate model classes. This is particularly useful when using checkpoints from a TrainingModule with a different InferenceModule class that has different `__init__` parameters.

Problem: When loading a checkpoint trained with `TrainingModule(lr=1e-3)` into an `InferenceModule()` that doesn't accept `lr` as a parameter, the CLI would fail during instantiation because it tries to pass all checkpoint hyperparameters to the new module class.

Solution: Added an `adapt_checkpoint_hparams()` hook that is called in `_parse_ckpt_path()` after loading checkpoint hyperparameters but before applying them. Users can override this method to:

- Remove training-specific hyperparameters (e.g., `lr`, `weight_decay`)
- Modify `_class_path` for subclass mode
- Transform hyperparameter names/values
- Completely disable checkpoint hyperparameters by returning `{}`

Example usage:

    class MyCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, checkpoint_hparams):
            checkpoint_hparams.pop('lr', None)
            checkpoint_hparams.pop('weight_decay', None)
            return checkpoint_hparams

This approach is preferable to:

- Disabling checkpoint loading entirely (loses valuable hyperparameter info)
- Adding CLI arguments (deviates from the Trainer parameter pattern)
- Modifying private methods (breaks encapsulation)

The hook provides maximum flexibility while maintaining backward compatibility (the default implementation returns the hyperparameters unchanged).
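To illustrate where the hook sits in the flow described above, here is a rough sketch in plain Python. It is not the actual `_parse_ckpt_path()` code; the `"hyper_parameters"` checkpoint key and the helper name are assumptions for illustration only.

    from typing import Any

    def _apply_hook_to_checkpoint_hparams(cli: Any, checkpoint: dict[str, Any]) -> dict[str, Any]:
        # Read the hyperparameters stored in the checkpoint (key name assumed).
        hparams = checkpoint.get("hyper_parameters", {})
        # Call the user hook; the default implementation returns them unchanged.
        # (A later commit in this PR extends the hook with a `subcommand` argument.)
        hparams = cli.adapt_checkpoint_hparams(hparams)
        # The adapted hyperparameters are then used to instantiate the model class.
        return hparams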
for more information, see https://pre-commit.ci
…ook and add tests

- Update `adapt_checkpoint_hparams` signature to include a `subcommand` parameter, allowing context-aware customization of checkpoint hyperparameters
- Change type annotations to use lowercase `dict` (Python 3.9+ style)
- Update docstring with `subcommand` parameter documentation
- Add example showing conditional logic based on subcommand
- Add comprehensive unit tests:
  - `test_adapt_checkpoint_hparams_hook`: Tests that the hook is called and modifications are applied
  - `test_adapt_checkpoint_hparams_hook_empty_dict`: Tests disabling checkpoint hparams loading (a minimal sketch follows below)
- Tests cover both regular and subclass modes
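For the empty-dict case referenced in the test list above, a minimal override could look like this. The class name is illustrative; the hook signature is the one this commit introduces.

    from typing import Any
    from lightning.pytorch.cli import LightningCLI

    class NoCkptHparamsCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
            # Returning an empty dict disables checkpoint hyperparameters for every subcommand.
            return {}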
for more information, see https://pre-commit.ci
- Split method signature across multiple lines to stay within the 120-char limit
- Improves code readability in the documentation example
… size mismatch in tests
for more information, see https://pre-commit.ci
Removed redundant method implementations since BoringModel provides them.
The test was asserting hidden_dim==3 but only passing out_dim=3. Since hidden_dim defaults to 16 and there's no argument linking, the assertion failed. Now we explicitly pass --model.hidden_dim=6.
for more information, see https://pre-commit.ci
Pull request overview
This PR implements manual dataloader reloading functionality for PyTorch Lightning, addressing issue #21448. The feature allows dynamic DataLoader reconfiguration during training without exiting the fit loop, enabling use cases like curriculum learning and adaptive training strategies.
Key changes:
- Added `trainer.reload_dataloaders()` method for manual dataloader reloading during training
- Added comprehensive test suite covering various reload scenarios
- Added documentation and examples for the new feature
Note: The PR also includes an unrelated CLI feature (adapt_checkpoint_hparams hook) which appears to be a separate enhancement.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| src/lightning/pytorch/trainer/trainer.py | Implements the reload_dataloaders() method that allows manual triggering of dataloader reload during training by setting internal epoch tracking variables |
| tests/tests_pytorch/trainer/test_reload_dataloaders.py | Comprehensive test suite for the manual reload feature covering various scenarios including single/multiple reloads, parameter updates, and interaction with existing auto-reload functionality |
| src/lightning/pytorch/cli.py | Adds adapt_checkpoint_hparams() hook for customizing checkpoint hyperparameters (unrelated to dataloader reloading) |
| tests/tests_pytorch/test_cli.py | Test coverage for the new CLI hook and adds AdaptHparamsModel test fixture |
| docs/source-pytorch/data/access.rst | Documentation for both automatic and manual dataloader reloading mechanisms with practical examples |
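To make the distinction in the documentation row above concrete, here is a hedged sketch contrasting the existing automatic mechanism (`reload_dataloaders_every_n_epochs`, an existing `Trainer` argument) with the manual method this PR adds. The callback class and the `sequence_length` attribute are illustrative, not part of the PR.

    import lightning.pytorch as pl

    # Automatic: Lightning rebuilds the train dataloader on a fixed schedule.
    trainer = pl.Trainer(max_epochs=20, reload_dataloaders_every_n_epochs=5)

    # Manual: rebuild only when a condition of your choosing is met, e.g. from a callback.
    class CurriculumCallback(pl.Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            if trainer.current_epoch == 5:
                trainer.datamodule.sequence_length += 10  # illustrative datamodule attribute
                trainer.reload_dataloaders(train=True, val=True)  # method added by this PR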
From `src/lightning/pytorch/trainer/trainer.py`:

    def reload_dataloaders(self, train: bool = True, val: bool = False) -> None:
        """Manually trigger a reload of dataloaders during training.

        This method allows dynamic reconfiguration of DataLoaders without exiting the ``fit()`` loop.
        It's useful for curriculum learning, adaptive training strategies, or any scenario where
        DataLoader parameters need to change based on training metrics or progress.

        The reload will occur at the start of the next epoch during training.

        Args:
            train: If ``True``, reload the train dataloader. Default: ``True``.
            val: If ``True``, reload the validation dataloader. Default: ``False``.

        Example::

            # In a callback
            def on_train_epoch_end(self, trainer, pl_module):
                if trainer.current_epoch == 5:
                    # Update datamodule parameters
                    trainer.datamodule.sequence_length += 10
                    # Trigger reload for next epoch
                    trainer.reload_dataloaders(train=True, val=True)

            # In a LightningModule
            def on_train_batch_end(self, outputs, batch, batch_idx):
                if self.trainer.callback_metrics.get('train_loss', 1.0) < 0.5:
                    self.trainer.datamodule.unroll_steps += 1
                    self.trainer.reload_dataloaders()

        Raises:
            RuntimeError: If called outside of a ``fit()`` call.

        .. note::
            The actual reload happens at the beginning of the next training epoch,
            not immediately when this method is called. This ensures training state
            consistency and proper synchronization in distributed settings.
        """
        if not self.training:
            raise RuntimeError(
                "`trainer.reload_dataloaders()` can only be called during training (inside `trainer.fit()`)."
            )

        if train:
            # Setting to -inf ensures _should_reload_train_dl returns True
            self.fit_loop._last_train_dl_reload_epoch = float("-inf")
            rank_zero_info("Train dataloader will be reloaded at the start of the next epoch.")

        if val:
            # Setting to -inf ensures _should_reload_val_dl returns True
            self.fit_loop.epoch_loop.val_loop._last_val_dl_reload_epoch = float("-inf")
            rank_zero_info("Validation dataloader will be reloaded at the next validation check.")
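The reload only has an effect if the dataloader hooks rebuild their loaders from mutable state. Below is a minimal sketch of a datamodule written that way; `GrowingBatchDataModule` and its `batch_size` attribute are illustrative names, not part of the PR.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import lightning.pytorch as pl

    class GrowingBatchDataModule(pl.LightningDataModule):
        def __init__(self):
            super().__init__()
            self.batch_size = 32  # mutated during training; picked up when the loader is rebuilt

        def train_dataloader(self):
            dataset = TensorDataset(torch.randn(1024, 16), torch.randint(0, 2, (1024,)))
            return DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

    # Inside a LightningModule or callback hook, as in the docstring example above:
    #     self.trainer.datamodule.batch_size *= 2
    #     self.trainer.reload_dataloaders()  # the new loader is built at the start of the next epoch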
Copilot AI (Jan 6, 2026):
The method should validate that at least one of train or val is True. Currently, if both are False, the method will execute without performing any action or providing feedback to the user, which could be confusing.
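One way this review comment could be addressed (a sketch, not the PR's code) is to warn and return early when both flags are false. The snippet assumes it is placed near the top of the method body; `rank_zero_warn` is an existing Lightning utility.

    from lightning.pytorch.utilities import rank_zero_warn

    # Hypothetical guard inside reload_dataloaders():
    if not train and not val:
        rank_zero_warn(
            "`trainer.reload_dataloaders()` was called with `train=False` and `val=False`; "
            "no dataloaders will be reloaded."
        )
        return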
From `src/lightning/pytorch/cli.py`:

    def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
        """Adapt checkpoint hyperparameters before instantiating the model class.

        This method allows for customization of hyperparameters loaded from a checkpoint when
        using a different model class than the one used for training. For example, when loading
        a checkpoint from a TrainingModule to use with an InferenceModule that has different
        ``__init__`` parameters, you can remove or modify incompatible hyperparameters.

        Args:
            subcommand: The subcommand being executed (e.g., 'fit', 'validate', 'test', 'predict').
                This allows you to apply different hyperparameter adaptations depending on the context.
            checkpoint_hparams: Dictionary of hyperparameters loaded from the checkpoint.

        Returns:
            Dictionary of adapted hyperparameters to be used for model instantiation.

        Example::

            class MyCLI(LightningCLI):
                def adapt_checkpoint_hparams(
                    self, subcommand: str, checkpoint_hparams: dict[str, Any]
                ) -> dict[str, Any]:
                    # Only remove training-specific hyperparameters for non-fit subcommands
                    if subcommand != "fit":
                        checkpoint_hparams.pop("lr", None)
                        checkpoint_hparams.pop("weight_decay", None)
                    return checkpoint_hparams

        Note:
            If subclass module mode is enabled and ``_class_path`` is present in the checkpoint
            hyperparameters, you may need to modify it as well to point to your new module class.
        """
        return checkpoint_hparams
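For context, here is a hedged end-to-end sketch of how a user might apply this hook. `InferenceCLI`, `InferenceModule`, and the dropped hyperparameter names mirror the scenario from the commit message but are illustrative, not part of the PR.

    from typing import Any
    from lightning.pytorch.cli import LightningCLI

    class InferenceCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
            # The checkpoint was written by a training module that accepted `lr` and
            # `weight_decay`; the inference module does not, so drop them before instantiation.
            for key in ("lr", "weight_decay"):
                checkpoint_hparams.pop(key, None)
            return checkpoint_hparams

    # Typical invocation (illustrative):
    #     python predict.py predict --ckpt_path=path/to/checkpoint.ckpt
    # where the script constructs InferenceCLI(InferenceModule, SomeDataModule).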
Copilot AI (Jan 6, 2026):
This PR introduces the adapt_checkpoint_hparams hook in the CLI, which appears to be an unrelated feature to manual dataloader reloading. Consider splitting this into a separate PR to keep changes focused and make review/maintenance easier. The PR title and description only mention dataloader reloading but don't mention the CLI checkpoint hyperparameters adaptation feature.
This PR implements manual dataloader reloading for dynamic DataLoader reconfiguration during training without exiting the fit loop. Addresses issue #21448.
📚 Documentation preview 📚: https://pytorch-lightning--21473.org.readthedocs.build/en/21473/