Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions docs/source-pytorch/data/access.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,47 @@ If you are using a :class:`~lightning.pytorch.utilities.CombinedLoader`. A flatt
updated.append(new_dl)
# it also allows you to easily replace the dataloaders
combined_loader.flattened = updated


Reloading DataLoaders During Training
-------------------------------------

Lightning provides two mechanisms for reloading dataloaders during training:

**Automatic reload with** ``reload_dataloaders_every_n_epochs``

Set ``reload_dataloaders_every_n_epochs`` in the Trainer to automatically reload dataloaders at regular intervals:

.. code-block:: python

trainer = Trainer(reload_dataloaders_every_n_epochs=5)

This is useful when your dataset changes periodically, such as in online learning scenarios.

**Manual reload with** ``trainer.reload_dataloaders()``

For dynamic scenarios like curriculum learning or adaptive training strategies, use
:meth:`~lightning.pytorch.trainer.trainer.Trainer.reload_dataloaders` to trigger a reload
based on training metrics or other conditions:

.. code-block:: python

class CurriculumCallback(Callback):
    """Raise the datamodule difficulty and request a dataloader reload once the train loss is low enough."""

    def on_train_epoch_end(self, trainer, pl_module):
        # Falls back to 1.0 when "train_loss" has not been logged, so the condition is not met
        train_loss = trainer.callback_metrics.get("train_loss", 1.0)
        if train_loss < 0.5:
            # Update datamodule parameters first ...
            trainer.datamodule.difficulty_level += 1
            # ... then request the reload so the next epoch picks them up
            trainer.reload_dataloaders(train=True, val=True)

Or directly from your LightningModule:

.. code-block:: python

class MyModel(LightningModule):
    """Lengthen the training sequences and request a dataloader reload once the train loss is low enough."""

    def on_train_epoch_end(self):
        # Checked once per epoch on purpose: inside a per-batch hook the increment below
        # would run again on every batch for which the condition holds, while the reload
        # itself is deferred to the next epoch.
        if self.trainer.callback_metrics.get("train_loss", 1.0) < 0.5:
            self.trainer.datamodule.sequence_length += 10
            self.trainer.reload_dataloaders()

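The reload is deferred rather than immediate: the train dataloader is reloaded at the start of the next epoch and the validation dataloader at the next validation check, ensuring training state consistency.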
41 changes: 41 additions & 0 deletions src/lightning/pytorch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,41 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
else:
self.config = parser.parse_args(args)

def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
    """Hook for rewriting the hyperparameters read from a checkpoint before the model is instantiated.

    Override this when the checkpoint was produced by a different model class than the one being
    instantiated now. A typical case is a checkpoint written by a training module that is loaded into an
    inference module whose ``__init__`` accepts other parameters: incompatible entries can be dropped or
    changed here. The default implementation hands the dictionary back untouched.

    Args:
        subcommand: Name of the subcommand being run (e.g., 'fit', 'validate', 'test', 'predict'), so that
            the adaptation can differ per subcommand.
        checkpoint_hparams: Hyperparameters as loaded from the checkpoint.

    Returns:
        The hyperparameters to use for model instantiation. ``_parse_ckpt_path`` stops processing the
        checkpoint hyperparameters when the returned value is empty.

    Example::

        class MyCLI(LightningCLI):
            def adapt_checkpoint_hparams(
                self, subcommand: str, checkpoint_hparams: dict[str, Any]
            ) -> dict[str, Any]:
                # Training-only hyperparameters are irrelevant outside of ``fit``
                if subcommand != "fit":
                    checkpoint_hparams.pop("lr", None)
                    checkpoint_hparams.pop("weight_decay", None)
                return checkpoint_hparams

    Note:
        When subclass module mode is enabled and the checkpoint hyperparameters contain ``_class_path``,
        that entry may also have to be changed so that it points to the new module class.

    """
    return checkpoint_hparams
Comment on lines +563 to +596
Copy link

Copilot AI Jan 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR introduces the adapt_checkpoint_hparams hook in the CLI, which appears to be an unrelated feature to manual dataloader reloading. Consider splitting this into a separate PR to keep changes focused and make review/maintenance easier. The PR title and description only mention dataloader reloading but don't mention the CLI checkpoint hyperparameters adaptation feature.

Copilot uses AI. Check for mistakes.

def _parse_ckpt_path(self) -> None:
"""If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config."""
if not self.config.get("subcommand"):
Expand All @@ -571,6 +606,12 @@ def _parse_ckpt_path(self) -> None:
hparams.pop("_instantiator", None)
if not hparams:
return

# Allow customization of checkpoint hyperparameters via adapt_checkpoint_hparams hook
hparams = self.adapt_checkpoint_hparams(self.config.subcommand, hparams)
if not hparams:
return

if "_class_path" in hparams:
hparams = {
"class_path": hparams.pop("_class_path"),
Expand Down
54 changes: 54 additions & 0 deletions src/lightning/pytorch/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1206,6 +1206,60 @@ def print(self, *args: Any, **kwargs: Any) -> None:
if self.local_rank == 0:
print(*args, **kwargs)

def reload_dataloaders(self, train: bool = True, val: bool = False) -> None:
    """Manually request a reload of dataloaders during training.

    This allows DataLoaders to be reconfigured without leaving the ``fit()`` loop, e.g. for curriculum
    learning, adaptive training strategies, or any scenario where DataLoader parameters need to change
    based on training metrics or progress.

    Nothing is rebuilt by this call itself. It only resets the "last reload epoch" marker of the fit loop
    (``train``) and/or of the validation loop (``val``) to ``-inf``; the loops are expected to pick this up
    at the start of the next training epoch and at the next validation check, respectively.

    Args:
        train: If ``True``, reload the train dataloader. Default: ``True``.
        val: If ``True``, reload the validation dataloader. Default: ``False``.

    Raises:
        RuntimeError: If ``self.training`` is false when the method is called.

    Example::

        # In a callback
        def on_train_epoch_end(self, trainer, pl_module):
            if trainer.current_epoch == 5:
                # Update datamodule parameters
                trainer.datamodule.sequence_length += 10
                # Request the reload for the next epoch
                trainer.reload_dataloaders(train=True, val=True)

        # In a LightningModule (once per epoch, so the parameter is only bumped once per reload)
        def on_train_epoch_end(self):
            if self.trainer.callback_metrics.get('train_loss', 1.0) < 0.5:
                self.trainer.datamodule.unroll_steps += 1
                self.trainer.reload_dataloaders()

    .. note::

        The actual reload is deferred and does not happen when this method is called. Calling with both
        ``train=False`` and ``val=False`` requests nothing; this is reported instead of passing silently.

    """
    # NOTE(review): the guard is `self.training`, not "inside fit()" -- presumably this also rejects calls
    # made from validation hooks that run during `fit()`; confirm that is intended.
    if not self.training:
        raise RuntimeError(
            "`trainer.reload_dataloaders()` can only be called during training (inside `trainer.fit()`)."
        )

    if not (train or val):
        # Previously a silent no-op, which hides a likely caller mistake
        rank_zero_info(
            "`trainer.reload_dataloaders()` was called with `train=False` and `val=False`;"
            " no dataloader will be reloaded."
        )
        return

    # -inf is meant to make the loops' reload checks (`_should_reload_train_dl` / `_should_reload_val_dl`)
    # succeed regardless of the current epoch.
    # NOTE(review): confirm those checks also fire when `reload_dataloaders_every_n_epochs` is 0, and that
    # a float marker is acceptable where an int epoch is otherwise stored.
    force_reload_marker = float("-inf")

    if train:
        self.fit_loop._last_train_dl_reload_epoch = force_reload_marker
        rank_zero_info("Train dataloader will be reloaded at the start of the next epoch.")

    if val:
        self.fit_loop.epoch_loop.val_loop._last_val_dl_reload_epoch = force_reload_marker
        rank_zero_info("Validation dataloader will be reloaded at the next validation check.")


Comment on lines +1209 to +1262
Copy link

Copilot AI Jan 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method should validate that at least one of train or val is True. Currently, if both are False, the method will execute without performing any action or providing feedback to the user, which could be confusing.

Copilot uses AI. Check for mistakes.
"""
Accelerator properties
"""
Expand Down
71 changes: 71 additions & 0 deletions tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,21 @@ def __init__(self, out_dim: int = 2, hidden_dim: int = 2) -> None:
self.layer = torch.nn.Linear(32, out_dim)


class AdaptHparamsModel(BoringModel):
    """Minimal model used by the ``adapt_checkpoint_hparams`` hook tests.

    ``out_dim`` and ``hidden_dim`` are saved as hyperparameters and kept as plain attributes. No layer is built
    from them, so changing their values between the fit and predict phases is not meant to cause size
    mismatches.

    """

    def __init__(self, out_dim: int = 8, hidden_dim: int = 16) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.out_dim, self.hidden_dim = out_dim, hidden_dim


def test_lightning_cli_ckpt_path_argument_hparams(cleandir):
class CkptPathCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
Expand Down Expand Up @@ -562,6 +577,62 @@ def add_arguments_to_parser(self, parser):
assert cli.model.layer.out_features == 4


def test_adapt_checkpoint_hparams_hook_pop_keys(cleandir):
    """Check that ``adapt_checkpoint_hparams`` is invoked and that its modifications take effect."""

    class AdaptHparamsCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict) -> dict:
            """Drop ``out_dim`` and ``hidden_dim`` unless the subcommand is ``fit``."""
            if subcommand == "fit":
                return checkpoint_hparams
            for key in ("out_dim", "hidden_dim"):
                checkpoint_hparams.pop(key, None)
            return checkpoint_hparams

    # Step 1: run fit so that a checkpoint carrying out_dim=3 / hidden_dim=6 exists
    fit_args = ["fit", "--model.out_dim=3", "--model.hidden_dim=6", "--trainer.max_epochs=1"]
    with mock.patch("sys.argv", ["any.py", *fit_args]):
        fit_cli = AdaptHparamsCLI(AdaptHparamsModel)

    fit_model_cfg = fit_cli.config.fit.model
    assert fit_model_cfg.out_dim == 3
    assert fit_model_cfg.hidden_dim == 6

    ckpt_dir = Path(fit_cli.trainer.log_dir) / "checkpoints"
    ckpt_file = next(ckpt_dir.glob("*.ckpt"))

    # Step 2: run predict from that checkpoint with different values given on the command line
    predict_args = ["predict", f"--ckpt_path={ckpt_file}", "--model.out_dim=5", "--model.hidden_dim=10"]
    with mock.patch("sys.argv", ["any.py", *predict_args]):
        predict_cli = AdaptHparamsCLI(AdaptHparamsModel)

    # The hook removed both keys from the checkpoint hparams, so the command-line values must remain
    predict_model_cfg = predict_cli.config.predict.model
    assert predict_model_cfg.out_dim == 5
    assert predict_model_cfg.hidden_dim == 10


def test_adapt_checkpoint_hparams_hook_empty_dict(cleandir):
    """Check that an empty dict returned by ``adapt_checkpoint_hparams`` turns off checkpoint hparams loading."""

    class AdaptHparamsEmptyCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict) -> dict:
            """Discard every checkpoint hyperparameter."""
            return {}

    # Step 1: run fit with a non-default out_dim to produce a checkpoint
    fit_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"]
    with mock.patch("sys.argv", ["any.py", *fit_args]):
        fit_cli = AdaptHparamsEmptyCLI(AdaptHparamsModel)

    ckpt_dir = Path(fit_cli.trainer.log_dir) / "checkpoints"
    ckpt_file = next(ckpt_dir.glob("*.ckpt"))

    # Step 2: run predict from the checkpoint without any model arguments
    predict_args = ["predict", f"--ckpt_path={ckpt_file}"]
    with mock.patch("sys.argv", ["any.py", *predict_args]):
        predict_cli = AdaptHparamsEmptyCLI(AdaptHparamsModel)

    # With the checkpoint hparams discarded, the model defaults (out_dim=8, hidden_dim=16) are expected
    predict_model = predict_cli.config_init.predict.model
    assert predict_model.out_dim == 8
    assert predict_model.hidden_dim == 16


def test_lightning_cli_submodules(cleandir):
class MainModule(BoringModel):
def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):
Expand Down
Loading
Loading