diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index fb5af3a11ac3f..22cf899496ed6 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -11,7 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- +- Added method chaining support to `LightningModule.freeze()` and `LightningModule.unfreeze()` by returning `self` ([#21469](https://github.com/Lightning-AI/pytorch-lightning/pull/21469)) ### Deprecated diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index bae7f876c8211..e5279e51887bf 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -1390,21 +1390,24 @@ def optimizer_zero_grad(self, epoch, batch_idx, optimizer): """ optimizer.zero_grad() - def freeze(self) -> None: + def freeze(self) -> Self: r"""Freeze all params for inference. - Example:: + .. code-block:: python model = MyLightningModule(...) model.freeze() + Returns: + :class:`LightningModule` with all parameters frozen. + """ for param in self.parameters(): param.requires_grad = False - self.eval() + return self.eval() - def unfreeze(self) -> None: + def unfreeze(self) -> Self: """Unfreeze all parameters for training. .. code-block:: python @@ -1412,11 +1415,14 @@ def unfreeze(self) -> None: model = MyLightningModule(...) model.unfreeze() + Returns: + :class:`LightningModule` with all parameters unfrozen. 
+ """ for param in self.parameters(): param.requires_grad = True - self.train() + return self.train() def _verify_is_manual_optimization(self, fn_name: str) -> None: if self.automatic_optimization: diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 76860fd82733f..e1143d6a492af 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -386,12 +386,14 @@ def test_model_checkpoint_only_weights(tmp_path): def test_model_freeze_unfreeze(): model = BoringModel() - model.freeze() + freeze_ret = model.freeze() + assert freeze_ret is model assert not model.training for param in model.parameters(): assert not param.requires_grad - model.unfreeze() + unfreeze_ret = model.unfreeze() + assert unfreeze_ret is model assert model.training for param in model.parameters(): assert param.requires_grad