
Commit a9fb8a2

Merge pull request #324 from kozistr/feature/ranger25
[Refactor] flexible and consistent `optimizer` parameters for `Lookahead`, `TRAC`, and `OrthoGrad` optimizers
2 parents 5baa713 + 87e1a60

28 files changed (+496, -130 lines)

README.md (+3, -2)

@@ -10,8 +10,8 @@

 ## The reasons why you use `pytorch-optimizer`.

-* Wide range of supported optimizers. Currently, **89 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
-* Including many variants such as `Cautious`, `AdamD`, `Gradient Centrailiaztion`
+* Wide range of supported optimizers. Currently, **90 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centrailiaztion`
 * Easy to use, clean, and tested codes
 * Active maintenance
 * Somewhat a bit more optimized compared to the original implementation
@@ -198,6 +198,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | Grams | *Gradient Descent with Adaptive Momentum Scaling* | | <https://arxiv.org/abs/2412.17107> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241217107C/exportcitation) |
 | OrthoGrad | *Grokking at the Edge of Numerical Stability* | [github](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability) | <https://arxiv.org/abs/2501.04697> | [cite](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability?tab=readme-ov-file#citation) |
 | Adam-ATAN2 | *Scaling Exponents Across Parameterizations and Optimizers* | | <https://arxiv.org/abs/2407.05872> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240705872E/exportcitation) |
+| SPAM | *Spike-Aware Adam with Momentum Reset for Stable LLM Training* | [github](https://github.com/TianjinYellow/SPAM-Optimizer) | <https://arxiv.org/abs/2501.06842> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250106842H/exportcitation) |

 ## Supported LR Scheduler
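For context, a minimal sketch of resolving the newly listed SPAM optimizer by name, the same way as any other optimizer in this table. The toy model and hyper-parameters are illustrative, not recommendations:

```python
from torch import nn

from pytorch_optimizer import load_optimizer

model = nn.Linear(10, 2)  # toy model for illustration

# load_optimizer returns the optimizer class registered under the given name
optimizer = load_optimizer('spam')(model.parameters(), lr=1e-3)
```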

docs/changelogs/v3.3.4.md (+12)

@@ -0,0 +1,12 @@
+### Change Log
+
+### Feature
+
+* Support `OrthoGrad` feature for `create_optimizer()`. (#324)
+* Enhanced flexibility for the `optimizer` parameter in `Lookahead`, `TRAC`, and `OrthoGrad` optimizers. (#324)
+    * Now supports both torch.optim.Optimizer instances and classes
+    * You can now use `Lookahead` optimizer in two ways.
+        * `Lookahead(AdamW(model.parameters(), lr=1e-3), k=5, alpha=0.5)`
+        * `Lookahead(AdamW, k=5, alpha=0.5, params=model.parameters())`
+* Implement `SPAM` optimizer. (#324)
+    * [Spike-Aware Adam with Momentum Reset for Stable LLM Training](https://arxiv.org/abs/2501.06842)
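For reference, a minimal sketch of the two construction styles described in the changelog above. The toy model is illustrative, and `AdamW` is assumed here to be `torch.optim.AdamW`:

```python
import torch
from torch import nn

from pytorch_optimizer import Lookahead

model = nn.Linear(10, 2)  # toy model for illustration
AdamW = torch.optim.AdamW

# 1) wrap an already-constructed optimizer instance
opt_a = Lookahead(AdamW(model.parameters(), lr=1e-3), k=5, alpha=0.5)

# 2) pass the optimizer class and let Lookahead build it from `params`
opt_b = Lookahead(AdamW, k=5, alpha=0.5, params=model.parameters())
```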

docs/index.md (+3, -2)

@@ -10,8 +10,8 @@

 ## The reasons why you use `pytorch-optimizer`.

-* Wide range of supported optimizers. Currently, **89 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
-* Including many variants such as `Cautious`, `AdamD`, `Gradient Centrailiaztion`
+* Wide range of supported optimizers. Currently, **90 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centrailiaztion`
 * Easy to use, clean, and tested codes
 * Active maintenance
 * Somewhat a bit more optimized compared to the original implementation
@@ -198,6 +198,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | Grams | *Gradient Descent with Adaptive Momentum Scaling* | | <https://arxiv.org/abs/2412.17107> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241217107C/exportcitation) |
 | OrthoGrad | *Grokking at the Edge of Numerical Stability* | [github](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability) | <https://arxiv.org/abs/2501.04697> | [cite](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability?tab=readme-ov-file#citation) |
 | Adam-ATAN2 | *Scaling Exponents Across Parameterizations and Optimizers* | | <https://arxiv.org/abs/2407.05872> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240705872E/exportcitation) |
+| SPAM | *Spike-Aware Adam with Momentum Reset for Stable LLM Training* | [github](https://github.com/TianjinYellow/SPAM-Optimizer) | <https://arxiv.org/abs/2501.06842> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250106842H/exportcitation) |

 ## Supported LR Scheduler

docs/optimizer.md (+4)

@@ -368,6 +368,10 @@
     :docstring:
     :members:

+::: pytorch_optimizer.SPAM
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.SRMM
     :docstring:
     :members:

pyproject.toml (+3, -3)

@@ -18,9 +18,9 @@ keywords = [
     "Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MARS", "MSVAG", "Muno", "Nero",
     "NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger",
     "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo",
-    "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC",
-    "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered",
-    "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
+    "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SPAM", "SRMM", "StableAdamW", "SWATS", "Tiger",
+    "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard",
+    "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",

pytorch_optimizer/__init__.py (+1)

@@ -61,6 +61,7 @@
     SGDW,
     SM3,
     SOAP,
+    SPAM,
     SRMM,
     SWATS,
     TRAC,

pytorch_optimizer/base/scheduler.py (+4, -3)

@@ -1,16 +1,17 @@
 from abc import ABC, abstractmethod
 from typing import List

+from torch.optim import Optimizer
+
 from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError
-from pytorch_optimizer.base.types import OPTIMIZER


 class BaseLinearWarmupScheduler(ABC):
     r"""BaseLinearWarmupScheduler class.

     The LR Scheduler class based on this class has linear warmup strategy.

-    :param optimizer: Optimizer. OPTIMIZER. It will set learning rate to all trainable parameters in optimizer.
+    :param optimizer: Optimizer. It will set learning rate to all trainable parameters in optimizer.
     :param t_max: int. total steps to train.
     :param max_lr: float. maximum lr.
     :param min_lr: float. minimum lr.
@@ -20,7 +21,7 @@ class BaseLinearWarmupScheduler(ABC):

     def __init__(
         self,
-        optimizer: OPTIMIZER,
+        optimizer: Optimizer,
         t_max: int,
         max_lr: float,
         min_lr: float = 0.0,

pytorch_optimizer/base/types.py (+1)

@@ -11,6 +11,7 @@
 PARAMETERS = Optional[Union[Iterable[Dict], Iterable[torch.Tensor]]]
 STATE = Dict
 OPTIMIZER = Type[Optimizer]
+OPTIMIZER_INSTANCE_OR_CLASS = Union[OPTIMIZER, Optimizer]
 SCHEDULER = Type[LRScheduler]

 HUTCHINSON_G = Literal['gaussian', 'rademacher']
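As a hypothetical illustration of how a wrapper such as `Lookahead`, `TRAC`, or `OrthoGrad` might consume the new `OPTIMIZER_INSTANCE_OR_CLASS` alias; the helper name `resolve_optimizer` and its keyword handling are invented for this sketch and are not the library's API:

```python
from typing import Any

from torch.optim import Optimizer

from pytorch_optimizer.base.types import OPTIMIZER_INSTANCE_OR_CLASS, PARAMETERS


def resolve_optimizer(
    optimizer: OPTIMIZER_INSTANCE_OR_CLASS, params: PARAMETERS = None, **kwargs: Any
) -> Optimizer:
    """Return an Optimizer instance whether a class or an instance was given."""
    if isinstance(optimizer, Optimizer):  # already constructed; use as-is
        return optimizer
    if params is None:
        raise ValueError('`params` is required when an optimizer class is passed')
    return optimizer(params, **kwargs)  # build the instance from the class
```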

pytorch_optimizer/lr_scheduler/cosine_anealing.py (+3, -4)

@@ -1,10 +1,9 @@
 import math
 from typing import List, Optional

+from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler

-from pytorch_optimizer.base.types import OPTIMIZER
-

 class CosineAnnealingWarmupRestarts(LRScheduler):
     r"""CosineAnnealingWarmupRestarts.
@@ -21,7 +20,7 @@ class CosineAnnealingWarmupRestarts(LRScheduler):

     def __init__(
         self,
-        optimizer: OPTIMIZER,
+        optimizer: Optimizer,
         first_cycle_steps: int,
         cycle_mult: float = 1.0,
         max_lr: float = 1e-4,
@@ -53,7 +52,7 @@ def __init__(

         self.init_lr()

-    def init_lr(self):
+    def init_lr(self) -> None:
         self.base_lrs = []
         for param_group in self.optimizer.param_groups:
             param_group['lr'] = self.min_lr

pytorch_optimizer/lr_scheduler/rex.py (+3, -4)

@@ -1,9 +1,8 @@
 from typing import List, Optional

+from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler

-from pytorch_optimizer.base.types import OPTIMIZER
-

 class REXScheduler(LRScheduler):
     r"""Revisiting Budgeted Training with an Improved Schedule.
@@ -16,7 +15,7 @@ class REXScheduler(LRScheduler):

     def __init__(
         self,
-        optimizer: OPTIMIZER,
+        optimizer: Optimizer,
         total_steps: int,
         max_lr: float = 1.0,
         min_lr: float = 0.0,
@@ -35,7 +34,7 @@ def __init__(

         self.init_lr()

-    def init_lr(self):
+    def init_lr(self) -> None:
         self.base_lrs = []
         for param_group in self.optimizer.param_groups:
             param_group['lr'] = self.min_lr

pytorch_optimizer/optimizer/__init__.py (+15, -8)

@@ -1,10 +1,10 @@
 import fnmatch
 from importlib.util import find_spec
-from typing import Dict, List, Optional, Sequence, Set, Union
+from typing import Dict, List, Optional, Sequence, Set, Type, Union

 import torch
 from torch import nn
-from torch.optim import AdamW
+from torch.optim import AdamW, Optimizer

 from pytorch_optimizer.base.types import OPTIMIZER, PARAMETERS
 from pytorch_optimizer.optimizer.a2grad import A2Grad
@@ -83,6 +83,7 @@
 from pytorch_optimizer.optimizer.sm3 import SM3
 from pytorch_optimizer.optimizer.soap import SOAP
 from pytorch_optimizer.optimizer.sophia import SophiaH
+from pytorch_optimizer.optimizer.spam import SPAM
 from pytorch_optimizer.optimizer.srmm import SRMM
 from pytorch_optimizer.optimizer.swats import SWATS
 from pytorch_optimizer.optimizer.tiger import Tiger
@@ -286,6 +287,7 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
     MARS,
     SGDSaI,
     Grams,
+    SPAM,
     Ranger25,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
@@ -298,31 +300,36 @@ def create_optimizer(
     weight_decay: float = 0.0,
     wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
     use_lookahead: bool = False,
+    use_orthograd: bool = False,
     **kwargs,
-):
+) -> Optimizer:
     r"""Build optimizer.

     :param model: nn.Module. model.
     :param optimizer_name: str. name of optimizer.
     :param lr: float. learning rate.
     :param weight_decay: float. weight decay.
     :param wd_ban_list: List[str]. weight decay ban list by layer.
-    :param use_lookahead: bool. use lookahead.
+    :param use_lookahead: bool. use Lookahead.
+    :param use_orthograd: bool. use OrthoGrad.
     """
     optimizer_name = optimizer_name.lower()

     parameters = (
         get_optimizer_parameters(model, weight_decay, wd_ban_list) if weight_decay > 0.0 else model.parameters()
     )

-    optimizer = load_optimizer(optimizer_name)
+    optimizer_class: OPTIMIZER = load_optimizer(optimizer_name)

     if optimizer_name == 'alig':
-        optimizer = optimizer(parameters, max_lr=lr, **kwargs)
+        optimizer = optimizer_class(parameters, max_lr=lr, **kwargs)
     elif optimizer_name in {'lomo', 'adalomo', 'adammini'}:
-        optimizer = optimizer(model, lr=lr, **kwargs)
+        optimizer = optimizer_class(model, lr=lr, **kwargs)
     else:
-        optimizer = optimizer(parameters, lr=lr, **kwargs)
+        optimizer = optimizer_class(parameters, lr=lr, **kwargs)
+
+    if use_orthograd:
+        optimizer = OrthoGrad(optimizer, **kwargs)

     if use_lookahead:
         optimizer = Lookahead(
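A brief usage sketch of the updated `create_optimizer` with the new `use_orthograd` flag; the toy model, optimizer name, and hyper-parameters are illustrative:

```python
from torch import nn

from pytorch_optimizer import create_optimizer

model = nn.Linear(10, 2)  # toy model for illustration

# build AdamP by name, then wrap it with OrthoGrad via the new flag;
# use_lookahead=True would additionally wrap the result in Lookahead
optimizer = create_optimizer(model, 'adamp', lr=1e-3, weight_decay=1e-2, use_orthograd=True)
```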

pytorch_optimizer/optimizer/experimental/ranger25.py (+17, -26)

@@ -11,6 +11,8 @@
 class Ranger25(BaseOptimizer):
     r"""Mixin' every fancy optimizer hacks.

+    ADOPT + AdEMAMix + Cautious + StableAdamW + Adam-Atan2
+
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param lr: float. learning rate.
     :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
@@ -19,10 +21,10 @@ class Ranger25(BaseOptimizer):
     :param fixed_decay: bool. fix weight decay.
     :param alpha: float. usually between 4 and 10 would work well.
     :param t_alpha_beta3: Optional[float]. total number of iterations is preferred when needed.
-    :param n_sma_threshold: number of SMA threshold (recommended is 5).
     :param cautious: bool. whether to use the Cautious variant.
     :param stable_adamw: bool. whether to use stable AdamW variant.
-    :param eps: float. term added to the denominator to improve numerical stability.
+    :param eps: Optional[float]. term added to the denominator to improve numerical stability. when eps is None and
+        stable_adamw is False, adam-atan2 feature will be used.
     """

     def __init__(
@@ -35,10 +37,9 @@ def __init__(
         fixed_decay: bool = False,
         alpha: float = 5.0,
         t_alpha_beta3: Optional[float] = None,
-        n_sma_threshold: int = 5,
         cautious: bool = True,
         stable_adamw: bool = True,
-        eps: float = 1e-8,
+        eps: Optional[float] = 1e-8,
         **kwargs,
     ):
         self.validate_learning_rate(lr)
@@ -48,9 +49,8 @@ def __init__(
         self.validate_non_negative(weight_decay, 'weight_decay')
         self.validate_non_negative(eps, 'eps')

-        self.n_sma_threshold = n_sma_threshold
         self.cautious = cautious
-        self.stable_adamw = stable_adamw
+        self.stable_adamw: bool = stable_adamw if isinstance(eps, float) else False

         defaults: DEFAULTS = {
             'lr': lr,
@@ -60,7 +60,7 @@ def __init__(
             'fixed_decay': fixed_decay,
             'alpha': alpha,
             't_alpha_beta3': t_alpha_beta3,
-            'eps': eps,
+            'eps': eps if (eps is not None) or (eps is None and not stable_adamw) else 1e-8,
         }

         super().__init__(params, defaults)
@@ -147,38 +147,29 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 exp_avg, exp_avg_sq, exp_avg_slow = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_slow']

-                de_nom = exp_avg_sq.sqrt().clamp_(min=group['eps'])
-
-                normed_grad = grad.div(de_nom).clamp_(-clip, clip)
+                normed_grad = grad.div(
+                    exp_avg_sq.sqrt().clamp_(min=group['eps'] if group['eps'] is not None else 1e-8)
+                ).clamp_(-clip, clip)

                 exp_avg.mul_(beta1).add_(normed_grad, alpha=1.0 - beta1)
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                 exp_avg_slow.mul_(beta3_t).add_(normed_grad, alpha=1.0 - beta3_t)

-                de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])
-
                 update = exp_avg.clone()
                 if self.cautious:
                     self.apply_cautious(update, grad)

                 if self.stable_adamw:
                     step_size /= self.get_stable_adamw_rms(grad, exp_avg_sq)

-                step_size, n_sma = self.get_rectify_step_size(
-                    is_rectify=True,
-                    step=group['step'],
-                    lr=step_size,
-                    beta2=beta2,
-                    n_sma_threshold=self.n_sma_threshold,
-                    degenerated_to_sgd=False,
-                )
+                update.add_(exp_avg_slow, alpha=alpha_t)
+
+                de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq)

-                update.add_(exp_avg_slow, alpha=alpha_t).div_(de_nom)
+                if group['eps'] is not None:
+                    p.addcdiv_(update, de_nom.add_(group['eps']), value=-step_size)
+                    continue

-                if n_sma >= self.n_sma_threshold:
-                    de_nom = exp_avg_sq.sqrt().add_(group['eps'])
-                    p.addcdiv_(update, de_nom, value=-step_size)
-                else:
-                    p.add_(update, alpha=-step_size)
+                p.add_(update.atan2_(de_nom), alpha=-step_size)

         return loss