diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 9f7317c218c30..78beedb72b5bf 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [Unreleased] + +- Added support for MPS mixed-precision autocast ([#20531](https://github.com/Lightning-AI/pytorch-lightning/pull/20531)) + ## [2.5.0] - 2024-12-19 ### Added diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 40ee0eef4de33..079791d2f3fc7 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -520,7 +520,7 @@ def _check_and_init_precision(self) -> Precision: rank_zero_info( f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)" ) - device = "cpu" if self._accelerator_flag == "cpu" else "cuda" + device = "cpu" if self._accelerator_flag == "cpu" else "mps" if self._accelerator_flag == "mps" else "cuda" return MixedPrecision(self._precision_flag, device) # type: ignore[arg-type] raise RuntimeError("No precision set")