Commit 8da7b49

Authored Feb 2, 2025
Merge pull request #339 from kozistr/feature/exadam-optimizer
[Feature] Implement `EXAdam` optimizer
2 parents aca76b6 + 5e62c4c · commit 8da7b49

15 files changed: +164 −12 lines
 

‎README.md

+2 −1

@@ -10,7 +10,7 @@
 
 ## The reasons why you use `pytorch-optimizer`.
 
-* Wide range of supported optimizers. Currently, **94 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Wide range of supported optimizers. Currently, **95 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 * Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centrailiaztion`
 * Easy to use, clean, and tested codes
 * Active maintenance
@@ -202,6 +202,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | TAM | *Torque-Aware Momentum* | | <https://arxiv.org/abs/2412.18790> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241218790M/exportcitation) |
 | FOCUS | *First Order Concentrated Updating Scheme* | [github](https://github.com/liuyz0/FOCUS) | <https://arxiv.org/abs/2501.12243> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250112243M/exportcitation) |
 | PSGD | *Preconditioned Stochastic Gradient Descent* | [github](https://github.com/lixilinx/psgd_torch) | <https://arxiv.org/abs/1512.04202> | [cite](https://github.com/lixilinx/psgd_torch?tab=readme-ov-file#resources) |
+| EXAdam | *The Power of Adaptive Cross-Moments* | [github](https://github.com/AhmedMostafa16/EXAdam) | <https://arxiv.org/abs/2412.20302> | [cite](https://github.com/AhmedMostafa16/EXAdam?tab=readme-ov-file#citation) |
 
 ## Supported LR Scheduler
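For readers of this change: the new entry is used like any other optimizer in the package. A minimal sketch (not part of the commit), assuming the package version from this commit and an arbitrary toy model; the constructor arguments mirror the defaults added in `pytorch_optimizer/optimizer/exadam.py` further down.

```python
import torch
import torch.nn.functional as F

from pytorch_optimizer import EXAdam  # exported from pytorch_optimizer/__init__.py in this commit

model = torch.nn.Linear(10, 1)  # toy model, for illustration only

# defaults follow the new EXAdam.__init__: lr=1e-3, betas=(0.9, 0.999), decoupled weight decay
optimizer = EXAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.0)

x, y = torch.randn(8, 10), torch.randn(8, 1)
loss = F.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```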

‎docs/changelogs/v3.4.0.md

+3 −1

@@ -4,8 +4,10 @@
 
 * Implement `FOCUS` optimizer. (#330, #331)
     * [First Order Concentrated Updating Scheme](https://arxiv.org/abs/2501.12243)
-* Implement `PSGD Kron`. (#337)
+* Implement `PSGD Kron` optimizer. (#336, #337)
     * [preconditioned stochastic gradient descent w/ Kron pre-conditioner](https://arxiv.org/abs/1512.04202)
+* Implement `EXAdam` optimizer. (#338, #339)
+    * [The Power of Adaptive Cross-Moments](https://arxiv.org/abs/2412.20302)
 
 ### Update

‎docs/index.md

+2 −1

@@ -10,7 +10,7 @@
 
 ## The reasons why you use `pytorch-optimizer`.
 
-* Wide range of supported optimizers. Currently, **94 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Wide range of supported optimizers. Currently, **95 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 * Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centrailiaztion`
 * Easy to use, clean, and tested codes
 * Active maintenance
@@ -202,6 +202,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | TAM | *Torque-Aware Momentum* | | <https://arxiv.org/abs/2412.18790> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241218790M/exportcitation) |
 | FOCUS | *First Order Concentrated Updating Scheme* | [github](https://github.com/liuyz0/FOCUS) | <https://arxiv.org/abs/2501.12243> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250112243M/exportcitation) |
 | PSGD | *Preconditioned Stochastic Gradient Descent* | [github](https://github.com/lixilinx/psgd_torch) | <https://arxiv.org/abs/1512.04202> | [cite](https://github.com/lixilinx/psgd_torch?tab=readme-ov-file#resources) |
+| EXAdam | *The Power of Adaptive Cross-Moments* | [github](https://github.com/AhmedMostafa16/EXAdam) | <https://arxiv.org/abs/2412.20302> | [cite](https://github.com/AhmedMostafa16/EXAdam?tab=readme-ov-file#citation) |
 
 ## Supported LR Scheduler

‎docs/optimizer.md

+4

@@ -164,6 +164,10 @@
     :docstring:
     :members:
 
+::: pytorch_optimizer.EXAdam
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.DynamicLossScaler
     :docstring:
     :members:

‎docs/visualization.md

+16

@@ -150,6 +150,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_DiffGrad.png)
 
+### EXAdam
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_EXAdam.png)
+
 ### FAdam
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_FAdam.png)
@@ -186,6 +190,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Kate.png)
 
+### Kron
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Kron.png)
+
 ### Lamb
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Lamb.png)
@@ -496,6 +504,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_DiffGrad.png)
 
+### EXAdam
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_EXAdam.png)
+
 ### FAdam
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_FAdam.png)
@@ -532,6 +544,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Kate.png)
 
+### Kron
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Kron.png)
+
 ### Lamb
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Lamb.png)
4 binary image files (184 KB, 424 KB, 342 KB, 502 KB) — image previews omitted.

‎pyproject.toml

+8 −8

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "3.3.4"
+version = "3.4.0"
 description = "optimizer & lr scheduler & objective function collections in PyTorch"
 license = "Apache-2.0"
 authors = ["kozistr <kozistr@gmail.com>"]
@@ -14,13 +14,13 @@ keywords = [
     "AdaDelta", "AdaFactor", "AdaMax", "AdamG", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdEMAMix", "ADOPT",
     "AdaHessian", "Adai", "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos",
     "Apollo", "APOLLO", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD",
-    "DAdaptLion", "DeMo", "DiffGrad", "FAdam", "FOCUS", "Fromage", "FTRL", "GaLore", "Grams", "Gravity", "GrokFast",
-    "GSAM", "Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MARS", "MSVAG", "Muno", "Nero",
-    "NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "PSGD", "QHAdam", "QHM", "RAdam", "Ranger",
-    "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo",
-    "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SPAM", "SRMM", "StableAdamW", "SWATS", "TAM",
-    "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard",
-    "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
+    "DAdaptLion", "DeMo", "DiffGrad", "EXAdam", "FAdam", "FOCUS", "Fromage", "FTRL", "GaLore", "Grams", "Gravity",
+    "GrokFast", "GSAM", "Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MARS", "MSVAG",
+    "Muno", "Nero", "NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "PSGD", "QHAdam", "QHM",
+    "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam",
+    "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SPAM", "SRMM", "StableAdamW",
+    "SWATS", "TAM", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice",
+    "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",

‎pytorch_optimizer/__init__.py

+1

@@ -106,6 +106,7 @@
     DeMo,
     DiffGrad,
     DynamicLossScaler,
+    EXAdam,
     FAdam,
     Fromage,
     GaLore,

‎pytorch_optimizer/optimizer/__init__.py

+2

@@ -40,6 +40,7 @@
 from pytorch_optimizer.optimizer.dadapt import DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptLion, DAdaptSGD
 from pytorch_optimizer.optimizer.demo import DeMo
 from pytorch_optimizer.optimizer.diffgrad import DiffGrad
+from pytorch_optimizer.optimizer.exadam import EXAdam
 from pytorch_optimizer.optimizer.experimental.ranger25 import Ranger25
 from pytorch_optimizer.optimizer.fadam import FAdam
 from pytorch_optimizer.optimizer.focus import FOCUS
@@ -295,6 +296,7 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
     Grams,
     SPAM,
     Kron,
+    EXAdam,
     Ranger25,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
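Because `EXAdam` is appended to `OPTIMIZER_LIST`, and the registry keys are the lowercased class names (see the dictionary comprehension above), the new optimizer should also be loadable by name. A small sketch under that assumption, not taken from the repository itself; the `torch.nn.Linear` module is a placeholder:

```python
import torch

from pytorch_optimizer import load_optimizer

opt_class = load_optimizer('exadam')  # registry key: str(EXAdam.__name__).lower()

model = torch.nn.Linear(4, 2)  # placeholder module for illustration
optimizer = opt_class(model.parameters(), lr=1e-3)
print(optimizer)  # EXAdam.__str__ returns 'EXAdam'
```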

‎pytorch_optimizer/optimizer/exadam.py

+123

@@ -0,0 +1,123 @@
+import numpy as np
+import torch
+
+from pytorch_optimizer.base.exception import NoSparseGradientError
+from pytorch_optimizer.base.optimizer import BaseOptimizer
+from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+
+
+class EXAdam(BaseOptimizer):
+    r"""The Power of Adaptive Cross-Moments.
+
+    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
+    :param lr: float. learning rate.
+    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
+    :param weight_decay: float. weight decay (L2 penalty).
+    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
+    :param fixed_decay: bool. fix weight decay.
+    :param eps: float. term added to the denominator to improve numerical stability.
+    """
+
+    def __init__(
+        self,
+        params: PARAMETERS,
+        lr: float = 1e-3,
+        betas: BETAS = (0.9, 0.999),
+        weight_decay: float = 0.0,
+        weight_decouple: bool = True,
+        fixed_decay: bool = False,
+        eps: float = 1e-8,
+        **kwargs,
+    ):
+        self.validate_learning_rate(lr)
+        self.validate_betas(betas)
+        self.validate_non_negative(weight_decay, 'weight_decay')
+        self.validate_non_negative(eps, 'eps')
+
+        self.sq2: float = np.sqrt(2)
+
+        defaults: DEFAULTS = {
+            'lr': lr,
+            'betas': betas,
+            'weight_decay': weight_decay,
+            'weight_decouple': weight_decouple,
+            'fixed_decay': fixed_decay,
+            'eps': eps,
+        }
+
+        super().__init__(params, defaults)
+
+    def __str__(self) -> str:
+        return 'EXAdam'
+
+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            group['step'] = 0
+            for p in group['params']:
+                state = self.state[p]
+
+                state['exp_avg'] = torch.zeros_like(p)
+                state['exp_avg_sq'] = torch.zeros_like(p)
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
+            beta1, beta2 = group['betas']
+
+            bias_correction1: float = self.debias(beta1, group['step'])
+            bias_correction2: float = self.debias(beta2, group['step'])
+
+            step_size: float = group['lr'] * np.log(np.sqrt(group['step'] + 1) * self.sq2)
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                if grad.is_sparse:
+                    raise NoSparseGradientError(str(self))
+
+                state = self.state[p]
+                if len(state) == 0:
+                    state['exp_avg'] = torch.zeros_like(p)
+                    state['exp_avg_sq'] = torch.zeros_like(p)
+
+                self.apply_weight_decay(
+                    p=p,
+                    grad=grad,
+                    lr=group['lr'],
+                    weight_decay=group['weight_decay'],
+                    weight_decouple=group['weight_decouple'],
+                    fixed_decay=group['fixed_decay'],
+                )
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
+
+                d1 = 1.0 + exp_avg_sq.div(exp_avg_sq.add(group['eps'])) * (1.0 - bias_correction2)
+
+                exp_avg_p2 = exp_avg.pow(2)
+                d2 = 1.0 + exp_avg_p2.div(exp_avg_p2.add(group['eps'])) * (1.0 - bias_correction1)
+
+                m_tilde = exp_avg.div(bias_correction1) * d1
+                v_tilde = exp_avg_sq.div(bias_correction2) * d2
+
+                g_tilde = grad.div(bias_correction1) * d1
+
+                update = (m_tilde + g_tilde) / v_tilde.sqrt().add_(group['eps'])
+
+                p.add_(update, alpha=-step_size)
+
+        return loss
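As a reading aid (not part of the committed file), the per-parameter update that `step` implements can be written out. This assumes the library's `debias(β, t)` is the usual Adam bias-correction term 1 − βᵗ (a property of `BaseOptimizer` that is not shown in this diff), and it ignores the separate decoupled weight-decay step handled by `apply_weight_decay`:

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, &
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2, \\
d_1 &= 1 + \frac{v_t}{v_t + \epsilon}\,\beta_2^{\,t}, &
d_2 &= 1 + \frac{m_t^2}{m_t^2 + \epsilon}\,\beta_1^{\,t}, \\
\tilde m_t &= \frac{m_t}{1 - \beta_1^{\,t}}\, d_1, &
\tilde v_t &= \frac{v_t}{1 - \beta_2^{\,t}}\, d_2, \\
\tilde g_t &= \frac{g_t}{1 - \beta_1^{\,t}}\, d_1, &
\theta_{t+1} &= \theta_t - \eta\,\ln\!\big(\sqrt{2\,(t + 1)}\big)\;\frac{\tilde m_t + \tilde g_t}{\sqrt{\tilde v_t} + \epsilon}.
\end{aligned}
$$

The ε terms correspond line for line to the tensor operations above: d₁ scales both the first moment and the raw gradient, d₂ scales the second moment, and the logarithmic factor is the `step_size` computed once per group.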

‎tests/constants.py

+2

@@ -56,6 +56,7 @@
     DAdaptLion,
     DAdaptSGD,
     DiffGrad,
+    EXAdam,
     FAdam,
     Fromage,
     GaLore,
@@ -559,6 +560,7 @@
     (AdaTAM, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
     (FOCUS, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
     (Kron, {'lr': 1e0, 'weight_decay': 1e-3}, 3),
+    (EXAdam, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
     (Ranger25, {'lr': 5e0}, 2),
     (Ranger25, {'lr': 5e0, 't_alpha_beta3': 5}, 2),
     (Ranger25, {'lr': 2e-1, 'stable_adamw': False, 'orthograd': False, 'eps': None}, 3),

‎tests/test_load_modules.py

+1 −1

@@ -34,7 +34,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):
 
 
 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 92
+    assert len(get_supported_optimizers()) == 93
     assert len(get_supported_optimizers('adam*')) == 7
     assert len(get_supported_optimizers(['adam*', 'ranger*'])) == 10
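For a quick, informal sanity check outside pytest, the same calls can be run directly; the expected values below are taken from the updated assertions above:

```python
from pytorch_optimizer import get_supported_optimizers

print(len(get_supported_optimizers()))                      # expected: 93 after this commit
print(len(get_supported_optimizers('adam*')))               # expected: 7 (unchanged by this commit)
print(len(get_supported_optimizers(['adam*', 'ranger*'])))  # expected: 10 (unchanged by this commit)
```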
