Commit 55c3553

Merge pull request #325 from kozistr/update/codes
[Feature] Implement `TAM`, `AdaTAM` optimizers
2 parents a9fb8a2 + 59e8736 commit 55c3553

25 files changed (+309, -39 lines)

README.md (+2, -1)

@@ -10,7 +10,7 @@
 
 ## The reasons why you use `pytorch-optimizer`.
 
-* Wide range of supported optimizers. Currently, **90 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Wide range of supported optimizers. Currently, **92 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 * Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centrailiaztion`
 * Easy to use, clean, and tested codes
 * Active maintenance
@@ -199,6 +199,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | OrthoGrad | *Grokking at the Edge of Numerical Stability* | [github](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability) | <https://arxiv.org/abs/2501.04697> | [cite](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability?tab=readme-ov-file#citation) |
 | Adam-ATAN2 | *Scaling Exponents Across Parameterizations and Optimizers* | | <https://arxiv.org/abs/2407.05872> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240705872E/exportcitation) |
 | SPAM | *Spike-Aware Adam with Momentum Reset for Stable LLM Training* | [github](https://github.com/TianjinYellow/SPAM-Optimizer) | <https://arxiv.org/abs/2501.06842> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250106842H/exportcitation) |
+| TAM | *Torque-Aware Momentum* | | <https://arxiv.org/abs/2412.18790> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241218790M/exportcitation) |
 
 ## Supported LR Scheduler
 
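As a quick sanity check against the registry referenced in the hunk header above, the new entry should also be discoverable through `get_supported_optimizers`. A hedged sketch; the lowercase registry names and the wildcard pattern are assumptions, not taken from this diff:

```python
from pytorch_optimizer import get_supported_optimizers

# Registry keys are assumed to be the lowercased class names, as with existing optimizers.
print(get_supported_optimizers(['*tam*']))  # expected to include 'tam' and 'adatam'
```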

docs/changelogs/v3.3.4.md (+2)

@@ -10,3 +10,5 @@
 * `Lookahead(AdamW, k=5, alpha=0.5, params=model.parameters())`
 * Implement `SPAM` optimizer. (#324)
     * [Spike-Aware Adam with Momentum Reset for Stable LLM Training](https://arxiv.org/abs/2501.06842)
+* Implement `TAM`, and `AdaTAM` optimizers. (#325)
+    * [Torque-Aware Momentum](https://arxiv.org/abs/2412.18790)
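For illustration, a minimal usage sketch of the newly added optimizers. The constructors are assumed to follow the library's usual `params` / `lr` signature; the learning-rate values are illustrative only and any other hyperparameters are left at defaults this changelog does not specify:

```python
import torch
from torch import nn

from pytorch_optimizer import TAM, AdaTAM  # exported as of this release

model = nn.Linear(10, 2)
optimizer = TAM(model.parameters(), lr=1e-1)      # SGD-style torque-aware momentum
# optimizer = AdaTAM(model.parameters(), lr=1e-3)  # adaptive (Adam-style) variant

x, y = torch.randn(8, 10), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```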

docs/index.md (+2, -1)

@@ -10,7 +10,7 @@
 
 ## The reasons why you use `pytorch-optimizer`.
 
-* Wide range of supported optimizers. Currently, **90 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Wide range of supported optimizers. Currently, **92 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 * Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centrailiaztion`
 * Easy to use, clean, and tested codes
 * Active maintenance
@@ -199,6 +199,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | OrthoGrad | *Grokking at the Edge of Numerical Stability* | [github](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability) | <https://arxiv.org/abs/2501.04697> | [cite](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability?tab=readme-ov-file#citation) |
 | Adam-ATAN2 | *Scaling Exponents Across Parameterizations and Optimizers* | | <https://arxiv.org/abs/2407.05872> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240705872E/exportcitation) |
 | SPAM | *Spike-Aware Adam with Momentum Reset for Stable LLM Training* | [github](https://github.com/TianjinYellow/SPAM-Optimizer) | <https://arxiv.org/abs/2501.06842> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250106842H/exportcitation) |
+| TAM | *Torque-Aware Momentum* | | <https://arxiv.org/abs/2412.18790> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241218790M/exportcitation) |
 
 ## Supported LR Scheduler
 

docs/optimizer.md (+8)

@@ -380,6 +380,14 @@
     :docstring:
     :members:
 
+::: pytorch_optimizer.TAM
+    :docstring:
+    :members:
+
+::: pytorch_optimizer.AdaTAM
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.Tiger
     :docstring:
     :members:

docs/qa.md (+4)

@@ -7,3 +7,7 @@
 ## Q2) Memory leak happens when using SophiaH, AdaHessian optimizers.
 
 `torch.autograd.grad` with complex gradient flows sometimes leads memory leak issues, and you might encounter OOM issue. [related issue](https://github.com/kozistr/pytorch_optimizer/issues/278)
+
+## Q3) How to run visualizations?
+
+Run `python3 -m examples.visualize_optimizers` on the project root.

docs/visualization.md (+32)

@@ -82,6 +82,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdaSmooth.png)
 
+### AdaTAM
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdaTAM.png)
+
 ### AdEMAMix
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdEMAMix.png)
@@ -254,6 +258,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Ranger21.png)
 
+### Ranger25
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Ranger25.png)
+
 ### ScalableShampoo
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_ScalableShampoo.png)
@@ -306,6 +314,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SophiaH.png)
 
+### SPAM
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SPAM.png)
+
 ### SRMM
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SRMM.png)
@@ -318,6 +330,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SWATS.png)
 
+### TAM
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_TAM.png)
+
 ### Tiger
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Tiger.png)
@@ -408,6 +424,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdaSmooth.png)
 
+### AdaTAM
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdaTAM.png)
+
 ### AdEMAMix
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdEMAMix.png)
@@ -580,6 +600,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Ranger21.png)
 
+### Ranger25
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Ranger25.png)
+
 ### ScalableShampoo
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_ScalableShampoo.png)
@@ -632,6 +656,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SophiaH.png)
 
+### SPAM
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SPAM.png)
+
 ### SRMM
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SRMM.png)
@@ -644,6 +672,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SWATS.png)
 
+### TAM
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_TAM.png)
+
 ### Tiger
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Tiger.png)
Binary files (visualization PNGs): docs/visualizations/rastrigin_TAM.png (722 KB) plus seven further image files (720 KB, 717 KB, 271 KB, 451 KB, 450 KB, 452 KB, 463 KB).

pyproject.toml (+3, -3)

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "3.3.3"
+version = "3.3.4"
 description = "optimizer & lr scheduler & objective function collections in PyTorch"
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]
@@ -18,8 +18,8 @@ keywords = [
     "Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MARS", "MSVAG", "Muno", "Nero",
     "NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger",
     "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo",
-    "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SPAM", "SRMM", "StableAdamW", "SWATS", "Tiger",
-    "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard",
+    "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SPAM", "SRMM", "StableAdamW", "SWATS", "TAM",
+    "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard",
     "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
 ]
 classifiers = [

pytorch_optimizer/__init__.py (+2)

@@ -64,6 +64,7 @@
     SPAM,
     SRMM,
     SWATS,
+    TAM,
     TRAC,
     WSAM,
     A2Grad,
@@ -88,6 +89,7 @@
     AdaPNM,
     AdaShift,
     AdaSmooth,
+    AdaTAM,
     AdEMAMix,
     AggMo,
     Aida,

pytorch_optimizer/optimizer/__init__.py (+3)

@@ -86,6 +86,7 @@
 from pytorch_optimizer.optimizer.spam import SPAM
 from pytorch_optimizer.optimizer.srmm import SRMM
 from pytorch_optimizer.optimizer.swats import SWATS
+from pytorch_optimizer.optimizer.tam import TAM, AdaTAM
 from pytorch_optimizer.optimizer.tiger import Tiger
 from pytorch_optimizer.optimizer.trac import TRAC
 from pytorch_optimizer.optimizer.yogi import Yogi
@@ -252,6 +253,8 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
     SRMM,
     AvaGrad,
     AdaShift,
+    TAM,
+    AdaTAM,
     AdaDelta,
     Amos,
     AdaHessian,
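Since `TAM` and `AdaTAM` are appended to the optimizer list consumed by `load_optimizer`, they should also be loadable by name. A small hedged example; the exact registry key format is assumed to be the lowercased class name:

```python
from pytorch_optimizer import load_optimizer

tam_cls = load_optimizer('tam')        # assumed key: lowercased class name
ada_tam_cls = load_optimizer('adatam')
# Each returned class can then be instantiated like any torch.optim optimizer.
```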

pytorch_optimizer/optimizer/adafactor.py (+2, -2)

@@ -13,8 +13,8 @@ class AdaFactor(BaseOptimizer):
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
-    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared
-        hessian trace. if beta1 is None, first momentum will be skipped.
+    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
+        if beta1 is None, first momentum will be skipped.
     :param decay_rate: float. coefficient used to compute running averages of square gradient.
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.

pytorch_optimizer/optimizer/adamp.py (+4, -3)

@@ -141,6 +141,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 inv_de_nom = exp_avg_sq.rsqrt().add_(group['eps']).mul_(bias_correction2_sq)
 
                 perturb = exp_avg.clone()
+
+                if self.cautious:
+                    self.apply_cautious(perturb, grad)
+
                 if group['nesterov']:
                     perturb.mul_(beta1).addcmul_(grad, inv_de_nom, value=1.0 - beta1)
                 else:
@@ -173,9 +177,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     bias_correction1=bias_correction1,
                 )
 
-                if self.cautious:
-                    self.apply_cautious(perturb, grad)
-
                 p.add_(perturb, alpha=-step_size)
 
         return loss
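The moved block applies the library's `Cautious` variant to the momentum copy before the Nesterov/plain blend rather than after projection. For readers unfamiliar with the flag, here is a rough standalone sketch of cautious masking, based on the Cautious Optimizers idea the library already supports; the real `apply_cautious` may differ in details:

```python
import torch

def cautious_mask_(update: torch.Tensor, grad: torch.Tensor) -> None:
    """Zero update components whose sign disagrees with the gradient, rescaling
    the remainder so the overall step size is roughly preserved (sketch only)."""
    mask = (update * grad > 0).to(update.dtype)
    mask.div_(mask.mean().clamp_(min=1e-3))  # avoid dividing by ~0 when few signs agree
    update.mul_(mask)
```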

pytorch_optimizer/optimizer/adamw.py (+1, -2)

@@ -107,8 +107,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 exp_avg.lerp_(grad, weight=beta1_comp)
                 exp_avg_sq.mul_(beta2_hat).addcmul_(grad, grad, value=1.0 - beta2_hat)
 
-                rms = self.get_stable_adamw_rms(grad, exp_avg_sq, eps=eps_p2)
-                lr = group['lr'] / rms
+                lr: float = group['lr'] / self.get_stable_adamw_rms(grad, exp_avg_sq, eps=eps_p2)
 
                 self.apply_weight_decay(
                     p,
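The two removed lines are folded into a single expression: the learning rate is divided by the StableAdamW RMS term. As a reminder of what that quantity is, here is a hedged sketch following the StableAdamW paper; whether the floor at 1 lives inside `get_stable_adamw_rms` or elsewhere is an assumption here, and `eps_p2` is treated as a small constant:

```python
import torch

def stable_adamw_rms(grad: torch.Tensor, exp_avg_sq: torch.Tensor, eps: float = 1e-16) -> float:
    """RMS of the element-wise Adam update ratio g / sqrt(v), floored at 1 so that
    lr / rms can only shrink the step (sketch, not the library's exact helper)."""
    rms = grad.pow(2).div_(exp_avg_sq.clamp(min=eps)).mean().sqrt().item()
    return max(1.0, rms)
```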

pytorch_optimizer/optimizer/sgd.py (+2, -2)

@@ -406,7 +406,7 @@ class SGDSaI(BaseOptimizer):
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param lr: float. learning rate.
-    :param momentum: float. coefficients used for computing running averages of gradient.
+    :param momentum: float. coefficients used for computing running averages of gradient.
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
     :param eps: float. term added to the denominator to improve numerical stability.
@@ -423,7 +423,7 @@ def __init__(
         **kwargs,
     ):
         self.validate_learning_rate(lr)
-        self.validate_range(momentum, 'beta', 0.0, 1.0)
+        self.validate_range(momentum, 'momentum', 0.0, 1.0)
         self.validate_non_negative(weight_decay, 'weight_decay')
         self.validate_non_negative(eps, 'eps')
 

pytorch_optimizer/optimizer/spam.py (+15, -24)

@@ -81,7 +81,7 @@ def __init__(
         self.validate_non_negative(density, 'density')
         self.validate_non_negative(threshold, 'threshold')
         self.validate_non_negative(grad_accu_steps, 'grad_accu_steps')
-        self.validate_non_negative(update_proj_gap, 'update_proj_gap')
+        self.validate_positive(update_proj_gap, 'update_proj_gap')
         self.validate_non_negative(eps, 'eps')
 
         self.density = density
@@ -91,41 +91,32 @@ def __init__(
         self.update_proj_gap = update_proj_gap
         self.warmup = CosineDecay(0.99, warmup_epoch)
 
-        defaults: DEFAULTS = {
-            'lr': lr,
-            'betas': betas,
-            'weight_decay': weight_decay,
-            'eps': eps,
-            **kwargs,
-        }
+        defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay, 'eps': eps, **kwargs}
         super().__init__(params, defaults)
 
         self.init_masks()
 
         self.state['total_step'] = 0
-        self.state['current_step'] = warmup_epoch + 1
+        self.state['current_step'] = self.warmup_epoch + 1
 
     @staticmethod
-    def initialize_random_rank_boolean_tensor(m: int, n: int, density: float) -> torch.Tensor:
+    def initialize_random_rank_boolean_tensor(m: int, n: int, density: float, device: torch.device) -> torch.Tensor:
         r"""Create an (m x n) boolean tensor with `density` fraction of True entries.
 
         :param m: int. number of rows.
         :param n: int. number of columns.
         :param density: float. fraction of True entries. 1.0 means all True.
+        :param device: torch.device. device.
         """
         total_elements: int = m * n
         non_zero_count: int = int(density * total_elements)
 
-        tensor = torch.zeros((m, n), dtype=torch.bool)
+        tensor = torch.zeros(total_elements, dtype=torch.bool, device=device)
 
-        if non_zero_count == 0:
-            return tensor
+        if non_zero_count > 0:
+            tensor[torch.randperm(total_elements, device=device)[:non_zero_count]] = True
 
-        indices = torch.randperm(total_elements)[:non_zero_count]
-        rows, cols = indices // n, indices % n
-        tensor[rows, cols] = True
-
-        return tensor
+        return tensor.view(m, n)
 
     def update_mask_random(self, density: float, p: torch.Tensor, old_mask: torch.Tensor) -> torch.Tensor:
         r"""Update a random mask.
@@ -164,9 +155,8 @@ def update_masks(self) -> None:
             for p in group['params']:
                 state = self.state[p]
                 if 'mask' in state:
-                    new_mask = self.update_mask_random(self.density, p, state['mask'])
-                    state['mask'] = new_mask
-                    p.mask = new_mask
+                    state['mask'] = self.update_mask_random(self.density, p, state['mask'])
+                    p.mask = state['mask']
 
     def init_masks(self) -> None:
         r"""Initialize random masks for each parameter group that has 'density'."""
@@ -175,10 +165,11 @@ def init_masks(self) -> None:
                 state = self.state[p]
                 if p.dim() == 2 and 'mask' not in state:
                     state['mask'] = self.initialize_random_rank_boolean_tensor(
-                        p.shape[0],
-                        p.shape[1],
+                        m=p.shape[0],
+                        n=p.shape[1],
                         density=self.density,
-                    ).to(p.device)
+                        device=p.device,
+                    )
 
     def __str__(self) -> str:
         return 'SPAM'
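The rewritten helper builds the mask on a flat tensor and reshapes at the end, avoiding the earlier row/column index arithmetic and allocating directly on the target device. A self-contained sketch of the same logic, runnable outside the optimizer:

```python
import torch

def random_boolean_mask(m: int, n: int, density: float,
                        device: torch.device = torch.device('cpu')) -> torch.Tensor:
    """(m x n) boolean mask with roughly `density` fraction of True entries,
    placed uniformly at random (mirrors the refactored SPAM helper above)."""
    total_elements: int = m * n
    non_zero_count: int = int(density * total_elements)

    tensor = torch.zeros(total_elements, dtype=torch.bool, device=device)
    if non_zero_count > 0:
        # choose `non_zero_count` distinct flat indices and flip them to True
        tensor[torch.randperm(total_elements, device=device)[:non_zero_count]] = True

    return tensor.view(m, n)

mask = random_boolean_mask(4, 8, density=0.25)
assert mask.sum().item() == 8  # 25% of the 32 entries
```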
