import torch

from pyro.optim import PyroOptim
from torchcontrib.optim import SWA as _SWA


def _swa_constructor(
    param: torch.nn.Parameter,
    base: type,
    base_args: dict,
    swa_args: dict,
) -> torch.optim.Optimizer:
    """
    Construct a ``torchcontrib`` SWA optimizer wrapping a freshly built base
    optimizer for a single parameter.

    :param param: the parameter (or iterable of parameters) to optimize
    :param base: base optimizer class, e.g. ``torch.optim.SGD``
    :param base_args: keyword arguments for the base optimizer constructor
    :param swa_args: keyword arguments for the SWA wrapper
        (e.g. ``swa_start``, ``swa_freq``, ``swa_lr``)
    :returns: the SWA-wrapped optimizer
    """
    # Build the inner optimizer first, then wrap it so averaging is applied
    # on top of the base update rule. A distinct name avoids rebinding the
    # `base` parameter (the original shadowed the class with an instance).
    base_optim = base(param, **base_args)
    return _SWA(base_optim, **swa_args)


def SWA(args: dict) -> PyroOptim:
    """
    Stochastic Weight Averaging (SWA) optimizer. [1]

    References:
        [1] 'Averaging Weights Leads to Wider Optima and Better
        Generalization', Pavel Izmailov, Dmitry Podoprikhin, Timur Garipov,
        Dmitry Vetrov, Andrew Gordon Wilson
        Uncertainty in Artificial Intelligence (UAI), 2018

    :param args: arguments for the SWA optimizer, forwarded to
        :func:`_swa_constructor` (expects ``base``, ``base_args`` and
        ``swa_args`` entries)
    :returns: a :class:`~pyro.optim.PyroOptim` wrapping SWA
    """
    return PyroOptim(_swa_constructor, args)


def swap_swa_sgd(optimizer: PyroOptim) -> None:
    """
    Swap each parameter's current value with its SWA running average, as
    implemented by ``torchcontrib``'s ``SWA.swap_swa_sgd``.

    :param optimizer: a :class:`~pyro.optim.PyroOptim` created by :func:`SWA`
    """
    # PyroOptim keeps one underlying optimizer per parameter in
    # ``optim_objs``; only the optimizer objects (dict values) are needed,
    # so iterate .values() instead of discarding keys from .items().
    for swa_optim in optimizer.optim_objs.values():
        swa_optim.swap_swa_sgd()