Commit aee5fc4

Merge pull request #304 from kozistr/feature/optimizers
[Feature] Implement `ScheduleFreeRAdam`, `LaProp` optimizers and lots of things
2 parents: a980dc0 + 5326483

15 files changed: +438 −35 lines

README.md

+8 −3

@@ -8,9 +8,13 @@
 | Status | [![PyPi download](https://static.pepy.tech/badge/pytorch-optimizer)](https://pepy.tech/project/pytorch-optimizer) [![PyPi month download](https://static.pepy.tech/badge/pytorch-optimizer/month)](https://pepy.tech/project/pytorch-optimizer) |
 | License | [![apache](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) |

-**pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
-I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-Currently, **81 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+## The reasons why you use `pytorch-optimizer`.
+
+1. Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+2. Including many variants such as `Cautious`, `AdamD`, `Gradient Centralization`
+3. Easy to use, clean, and tested codes
+4. Active maintenance
+5. Somewhat more optimized compared to the original implementation

 Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

@@ -187,6 +191,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | DeMo | *Decoupled Momentum Optimization* | [github](https://github.com/bloc97/DeMo) | <https://arxiv.org/abs/2411.19870> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241119870P/exportcitation) |
 | MicroAdam | *Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence* | [github](https://github.com/IST-DASLab/MicroAdam) | <https://arxiv.org/abs/2405.15593> | [cite](https://github.com/IST-DASLab/MicroAdam?tab=readme-ov-file#citing) |
 | Muon | *MomentUm Orthogonalized by Newton-schulz* | [github](https://github.com/KellerJordan/Muon) | <https://x.com/kellerjordan0/status/1842300916864844014> | [cite](https://github.com/KellerJordan/Muon) |
+| LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | <https://arxiv.org/abs/2002.04839> | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) |

 ## Supported LR Scheduler

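For orientation, here is a minimal usage sketch of the two optimizers this PR introduces. It assumes the post-merge package layout shown in the diffs below (top-level exports of `LaProp` and `ScheduleFreeRAdam`); the hyperparameter values are illustrative, not defaults taken from this commit.

```python
import torch
from torch import nn

# Both classes are added to the top-level exports in this PR
# (see pytorch_optimizer/__init__.py below); lr values are illustrative.
from pytorch_optimizer import LaProp, ScheduleFreeRAdam

model = nn.Linear(10, 2)

# LaProp: Adam-like update that normalizes the gradient by the second-moment
# estimate *before* accumulating momentum.
optimizer = LaProp(model.parameters(), lr=1e-3)

# ScheduleFreeRAdam: RAdam combined with schedule-free weight averaging,
# intended to be run without a separate learning-rate scheduler.
# optimizer = ScheduleFreeRAdam(model.parameters(), lr=1e-3)

x, y = torch.randn(8, 10), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```
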
docs/changelogs/v3.3.1.md

+4

@@ -6,3 +6,7 @@
 * [Decoupled Momentum Optimization](https://arxiv.org/abs/2411.19870)
 * Implement `Muon` optimizer. (#302)
 * [MomentUm Orthogonalized by Newton-schulz](https://github.com/KellerJordan/Muon)
+* Implement `ScheduleFreeRAdam` optimizer. (#304)
+* Implement `LaProp` optimizer. (#304)
+* [Separating Momentum and Adaptivity in Adam](https://arxiv.org/abs/2002.04839)
+* Support `Cautious` variant to `LaProp`, `AdamP`, `Adopt` optimizers. (#304)

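Since the changelog links the LaProp paper, a single-tensor sketch of the rule it describes may help: where Adam accumulates momentum of the raw gradient and divides by `sqrt(v)` afterwards, LaProp divides first and accumulates momentum of the normalized gradient. This is a paraphrase of arXiv:2002.04839, not the `pytorch_optimizer.LaProp` code added in this commit; names and defaults below are illustrative.

```python
import torch


def laprop_step(p, grad, m, v, step, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
    """One LaProp update for a single tensor (sketch of the paper, not this repo's code)."""
    beta1, beta2 = betas

    # second-moment estimate, identical to Adam
    v.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
    de_nom = (v / (1.0 - beta2 ** step)).sqrt_().add_(eps)

    # momentum is taken over the *normalized* gradient -- the key difference from Adam
    m.mul_(beta1).add_(grad / de_nom, alpha=1.0 - beta1)

    # bias-corrected parameter update
    p.add_(m, alpha=-lr / (1.0 - beta1 ** step))


p = torch.zeros(4)
grad = torch.randn(4)
m, v = torch.zeros_like(p), torch.zeros_like(p)
laprop_step(p, grad, m, v, step=1)
```
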
docs/index.md

+8 −3

@@ -8,9 +8,13 @@
 | Status | [![PyPi download](https://static.pepy.tech/badge/pytorch-optimizer)](https://pepy.tech/project/pytorch-optimizer) [![PyPi month download](https://static.pepy.tech/badge/pytorch-optimizer/month)](https://pepy.tech/project/pytorch-optimizer) |
 | License | [![apache](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) |

-**pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
-I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-Currently, **81 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+## The reasons why you use `pytorch-optimizer`.
+
+1. Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+2. Including many variants such as `Cautious`, `AdamD`, `Gradient Centralization`
+3. Easy to use, clean, and tested codes
+4. Active maintenance
+5. Somewhat more optimized compared to the original implementation

 Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

@@ -187,6 +191,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | DeMo | *Decoupled Momentum Optimization* | [github](https://github.com/bloc97/DeMo) | <https://arxiv.org/abs/2411.19870> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241119870P/exportcitation) |
 | MicroAdam | *Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence* | [github](https://github.com/IST-DASLab/MicroAdam) | <https://arxiv.org/abs/2405.15593> | [cite](https://github.com/IST-DASLab/MicroAdam?tab=readme-ov-file#citing) |
 | Muon | *MomentUm Orthogonalized by Newton-schulz* | [github](https://github.com/KellerJordan/Muon) | <https://x.com/kellerjordan0/status/1842300916864844014> | [cite](https://github.com/KellerJordan/Muon) |
+| LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | <https://arxiv.org/abs/2002.04839> | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) |

 ## Supported LR Scheduler

docs/optimizer.md

+8

@@ -204,6 +204,10 @@
     :docstring:
     :members:

+::: pytorch_optimizer.LaProp
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.LARS
     :docstring:
     :members:
@@ -296,6 +300,10 @@
     :docstring:
     :members:

+::: pytorch_optimizer.ScheduleFreeRAdam
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.StableAdamW
     :docstring:
     :members:

pyproject.toml

+7 −7

@@ -14,13 +14,13 @@ keywords = [
 "AdaDelta", "AdaFactor", "AdaMax", "AdamG", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdEMAMix", "ADOPT",
 "AdaHessian", "Adai", "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos",
 "Apollo", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion",
-"DeMo", "DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LARS",
-"Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Muno", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM",
-"Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD",
-"ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SRMM",
-"StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1",
-"Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
-"QGaLore",
+"DeMo", "DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LaProp",
+"LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Muno", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID",
+"PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD",
+"ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP",
+"SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal",
+"FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge",
+"bitsandbytes", "WSD", "QGaLore",
 ]
 classifiers = [
 "License :: OSI Approved :: Apache Software License",

pytorch_optimizer/__init__.py

+2

@@ -107,6 +107,7 @@
 GrokFastAdamW,
 Kate,
 Lamb,
+LaProp,
 Lion,
 Lookahead,
 Muon,
@@ -123,6 +124,7 @@
 SafeFP16Optimizer,
 ScalableShampoo,
 ScheduleFreeAdamW,
+ScheduleFreeRAdam,
 ScheduleFreeSGD,
 Shampoo,
 SignSGD,

pytorch_optimizer/optimizer/__init__.py

+4 −1

@@ -50,6 +50,7 @@
 from pytorch_optimizer.optimizer.grokfast import GrokFastAdamW
 from pytorch_optimizer.optimizer.kate import Kate
 from pytorch_optimizer.optimizer.lamb import Lamb
+from pytorch_optimizer.optimizer.laprop import LaProp
 from pytorch_optimizer.optimizer.lars import LARS
 from pytorch_optimizer.optimizer.lion import Lion
 from pytorch_optimizer.optimizer.lomo import LOMO, AdaLOMO
@@ -71,7 +72,7 @@
 from pytorch_optimizer.optimizer.ranger21 import Ranger21
 from pytorch_optimizer.optimizer.rotograd import RotoGrad
 from pytorch_optimizer.optimizer.sam import BSAM, GSAM, SAM, WSAM
-from pytorch_optimizer.optimizer.schedulefree import ScheduleFreeAdamW, ScheduleFreeSGD
+from pytorch_optimizer.optimizer.schedulefree import ScheduleFreeAdamW, ScheduleFreeRAdam, ScheduleFreeSGD
 from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SignSGD
 from pytorch_optimizer.optimizer.sgdp import SGDP
 from pytorch_optimizer.optimizer.shampoo import ScalableShampoo, Shampoo
@@ -275,6 +276,8 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
 FTRL,
 DeMo,
 Muon,
+ScheduleFreeRAdam,
+LaProp,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}

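Because the registry above keys each class by `str(cls.__name__).lower()`, the new optimizers should also be reachable by name through `load_optimizer`. A small sketch, assuming the string keys follow directly from the class names shown in `OPTIMIZER_LIST`:

```python
from torch import nn

from pytorch_optimizer import load_optimizer

# Keys in OPTIMIZERS are lowercased class names, so these lookups should
# resolve to the classes registered above.
laprop_cls = load_optimizer('laprop')               # -> LaProp
sf_radam_cls = load_optimizer('schedulefreeradam')  # -> ScheduleFreeRAdam

model = nn.Linear(4, 1)
optimizer = laprop_cls(model.parameters(), lr=1e-3)
```
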
pytorch_optimizer/optimizer/adalite.py

+1 −1

@@ -41,7 +41,7 @@ def __init__(
 self.validate_betas(betas)
 self.validate_non_negative(weight_decay, 'weight_decay')
 self.validate_non_negative(eps1, 'eps1')
-self.validate_non_negative(eps2, 'eps1')
+self.validate_non_negative(eps2, 'eps2')

 defaults: DEFAULTS = {
 'lr': lr,

pytorch_optimizer/optimizer/adamp.py

+6

@@ -22,6 +22,7 @@ class AdamP(BaseOptimizer):
 :param wd_ratio: float. relative weight decay applied on scale-invariant parameters compared to that applied
     on scale-variant parameters.
 :param use_gc: bool. use gradient centralization.
+:param cautious: bool. whether to use the Cautious variant.
 :param nesterov: bool. enables Nesterov momentum.
 :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
 :param adanorm: bool. whether to use the AdaNorm variant.
@@ -40,6 +41,7 @@
 delta: float = 0.1,
 wd_ratio: float = 0.1,
 use_gc: bool = False,
+cautious: bool = False,
 nesterov: bool = False,
 r: float = 0.95,
 adanorm: bool = False,
@@ -54,6 +56,7 @@
 self.validate_non_negative(eps, 'eps')

 self.use_gc = use_gc
+self.cautious = cautious

 defaults: DEFAULTS = {
 'lr': lr,
@@ -170,6 +173,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 bias_correction1=bias_correction1,
 )

+if self.cautious:
+    self.apply_cautious(perturb, grad)
+
 p.add_(perturb, alpha=-step_size)

 return loss

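`apply_cautious` itself lives on `BaseOptimizer` and is not part of this diff. As a rough reference (a sketch, not the library's implementation), the Cautious rule masks out the components of the update whose sign disagrees with the current gradient and rescales the survivors so the overall update magnitude stays comparable:

```python
import torch


def cautious_mask(update: torch.Tensor, grad: torch.Tensor) -> None:
    """Sketch of the Cautious masking rule; the library's apply_cautious may differ in detail."""
    # keep only components where the momentum-based update agrees in sign with the gradient
    mask = (update * grad > 0).to(grad.dtype)
    # rescale so the surviving components keep a comparable average magnitude
    mask.div_(mask.mean().clamp_(min=1e-3))
    update.mul_(mask)


perturb, grad = torch.randn(5), torch.randn(5)
cautious_mask(perturb, grad)
```

In `AdamP.step` the mask is applied to `perturb` in place just before `p.add_(perturb, alpha=-step_size)`, so the variant needs no extra optimizer state.
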
pytorch_optimizer/optimizer/adopt.py

+10 −1

@@ -17,6 +17,7 @@ class ADOPT(BaseOptimizer):
 :param weight_decay: float. weight decay (L2 penalty).
 :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
 :param fixed_decay: bool. fix weight decay.
+:param cautious: bool. whether to use the Cautious variant.
 :param eps: float. term added to the denominator to improve numerical stability.
 """

@@ -29,6 +30,7 @@ def __init__(
 weight_decay: float = 0.0,
 weight_decouple: bool = False,
 fixed_decay: bool = False,
+cautious: bool = False,
 eps: float = 1e-6,
 **kwargs,
 ):
@@ -38,6 +40,7 @@
 self.validate_non_negative(eps, 'eps')

 self.clip_lambda = clip_lambda
+self.cautious = cautious

 defaults: DEFAULTS = {
 'lr': lr,
@@ -118,6 +121,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:

 exp_avg.lerp_(normed_grad, weight=1.0 - beta1)

-p.add_(exp_avg, alpha=-group['lr'])
+if self.cautious:
+    update = exp_avg.clone()
+    self.apply_cautious(update, normed_grad)
+else:
+    update = exp_avg
+
+p.add_(update, alpha=-group['lr'])

 return loss

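Unlike AdamP, ADOPT masks a cloned copy of `exp_avg`, so the cautious mask never touches the momentum buffer kept in optimizer state. A hedged usage sketch of the new flag (values illustrative, assuming `ADOPT` is exported at the package top level):

```python
import torch
from torch import nn

from pytorch_optimizer import ADOPT

model = nn.Linear(8, 1)
# cautious=False (the default) keeps the original behaviour of stepping with exp_avg directly;
# cautious=True applies the sign-agreement mask to a clone of exp_avg before the update.
optimizer = ADOPT(model.parameters(), lr=1e-3, cautious=True)

loss = model(torch.randn(2, 8)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```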