
Commit 8f538d4

Merge pull request #316 from kozistr/fix/cautious
[Feature] Implement `SGDSaI` optimizer
2 parents: d16a368 + a5e0894 (commit 8f538d4)

16 files changed: +176 -19 lines

README.md (+2 -1)

@@ -10,7 +10,7 @@
 
 ## The reasons why you use `pytorch-optimizer`.
 
-* Wide range of supported optimizers. Currently, **85 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Wide range of supported optimizers. Currently, **86 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 * Including many variants such as `Cautious`, `AdamD`, `Gradient Centrailiaztion`
 * Easy to use, clean, and tested codes
 * Active maintenance
@@ -194,6 +194,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | <https://arxiv.org/abs/2002.04839> | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) |
 | APOLLO | *SGD-like Memory, AdamW-level Performance* | [github](https://github.com/zhuhanqing/APOLLO) | <https://arxiv.org/abs/2412.05270> | [cite](https://github.com/zhuhanqing/APOLLO?tab=readme-ov-file#-citation) |
 | MARS | *Unleashing the Power of Variance Reduction for Training Large Models* | [github](https://github.com/AGI-Arena/MARS) | <https://arxiv.org/abs/2411.10438> | [cite](https://github.com/AGI-Arena/MARS/tree/main?tab=readme-ov-file#citation) |
+| SGDSaI | *No More Adam: Learning Rate Scaling at Initialization is All You Need* | [github](https://github.com/AnonymousAlethiometer/SGD_SaI) | <https://arxiv.org/abs/2412.11768> | [cite](https://github.com/AnonymousAlethiometer/SGD_SaI?tab=readme-ov-file#citation) |
 
 ## Supported LR Scheduler

docs/changelogs/v3.3.2.md (+10)

@@ -0,0 +1,10 @@
+### Change Log
+
+### Feature
+
+* Implement `SGDSaI` optimizer. (#315, #316)
+* [No More Adam: Learning Rate Scaling at Initialization is All You Need](https://arxiv.org/abs/2412.11768)
+
+### Bug
+
+* Clone `exp_avg` before calling `apply_cautious` not to mask `exp_avg`. (#316)
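The bug entry above refers to the cautious-update path that the `AdaShift`, `AdEMAMix`, and `MARS` hunks below now route through a clone of `exp_avg`, so the mask no longer overwrites the optimizer's running state. Below is a minimal sketch of the pattern, assuming the cautious rule zeroes out (and renormalizes) entries whose sign disagrees with the gradient; `apply_cautious_sketch` is a hypothetical stand-in for the library's `BaseOptimizer.apply_cautious`, which is not part of this diff.

```python
import torch


def apply_cautious_sketch(update: torch.Tensor, grad: torch.Tensor) -> None:
    """Hypothetical stand-in for the library's in-place cautious mask."""
    mask = (update * grad > 0).to(grad.dtype)   # keep entries whose sign agrees with the gradient
    mask.div_(mask.mean().clamp_(min=1e-3))     # renormalize so the overall update scale is preserved
    update.mul_(mask)


exp_avg = torch.randn(4)
grad = torch.randn(4)

# Before the fix: masking exp_avg in place corrupts the running average
# that later steps continue to update.
# apply_cautious_sketch(exp_avg, grad)

# After the fix: mask a clone and leave exp_avg untouched.
update = exp_avg.clone()
apply_cautious_sketch(update, grad)
```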

docs/index.md (+2 -1)

@@ -10,7 +10,7 @@
 
 ## The reasons why you use `pytorch-optimizer`.
 
-* Wide range of supported optimizers. Currently, **85 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Wide range of supported optimizers. Currently, **86 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 * Including many variants such as `Cautious`, `AdamD`, `Gradient Centrailiaztion`
 * Easy to use, clean, and tested codes
 * Active maintenance
@@ -194,6 +194,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | <https://arxiv.org/abs/2002.04839> | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) |
 | APOLLO | *SGD-like Memory, AdamW-level Performance* | [github](https://github.com/zhuhanqing/APOLLO) | <https://arxiv.org/abs/2412.05270> | [cite](https://github.com/zhuhanqing/APOLLO?tab=readme-ov-file#-citation) |
 | MARS | *Unleashing the Power of Variance Reduction for Training Large Models* | [github](https://github.com/AGI-Arena/MARS) | <https://arxiv.org/abs/2411.10438> | [cite](https://github.com/AGI-Arena/MARS/tree/main?tab=readme-ov-file#citation) |
+| SGDSaI | *No More Adam: Learning Rate Scaling at Initialization is All You Need* | [github](https://github.com/AnonymousAlethiometer/SGD_SaI) | <https://arxiv.org/abs/2412.11768> | [cite](https://github.com/AnonymousAlethiometer/SGD_SaI?tab=readme-ov-file#citation) |
 
 ## Supported LR Scheduler

docs/optimizer.md (+4)

@@ -332,6 +332,10 @@
     :docstring:
     :members:
 
+::: pytorch_optimizer.SGDSaI
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.SGDP
     :docstring:
     :members:

docs/visualization.md (+8)

@@ -274,6 +274,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SGDP.png)
 
+### SGDSaI
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SGDSaI.png)
+
 ### SGDW
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SGDW.png)

@@ -592,6 +596,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SGDP.png)
 
+### SGDSaI
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SGDSaI.png)
+
 ### SGDW
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SGDW.png)
Binary files added (720 KB and 353 KB, previews not rendered): the two visualization images referenced above, docs/visualizations/rastrigin_SGDSaI.png and docs/visualizations/rosenbrock_SGDSaI.png.

pyproject.toml (+1 -1)

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "3.3.1"
+version = "3.3.2"
 description = "optimizer & lr scheduler & objective function collections in PyTorch"
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]

pytorch_optimizer/__init__.py (+1)

@@ -128,6 +128,7 @@
     ScheduleFreeAdamW,
     ScheduleFreeRAdam,
     ScheduleFreeSGD,
+    SGDSaI,
     Shampoo,
     SignSGD,
     SophiaH,

pytorch_optimizer/optimizer/__init__.py (+2 -1)

@@ -74,7 +74,7 @@
 from pytorch_optimizer.optimizer.rotograd import RotoGrad
 from pytorch_optimizer.optimizer.sam import BSAM, GSAM, SAM, WSAM
 from pytorch_optimizer.optimizer.schedulefree import ScheduleFreeAdamW, ScheduleFreeRAdam, ScheduleFreeSGD
-from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SignSGD
+from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SGDSaI, SignSGD
 from pytorch_optimizer.optimizer.sgdp import SGDP
 from pytorch_optimizer.optimizer.shampoo import ScalableShampoo, Shampoo
 from pytorch_optimizer.optimizer.sm3 import SM3
@@ -281,6 +281,7 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
     ScheduleFreeRAdam,
     LaProp,
     MARS,
+    SGDSaI,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
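Because `OPTIMIZERS` is keyed by the lowercased class name, the new optimizer should also be reachable through string lookup. A quick sanity-check sketch, assuming only the `load_optimizer` registry shown in this hunk:

```python
from pytorch_optimizer.optimizer import load_optimizer

# Registry keys are lowercased class names, so the string 'sgdsai'
# should resolve to the newly registered class.
optimizer_class = load_optimizer('sgdsai')
print(optimizer_class.__name__)  # expected: 'SGDSaI'
```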

pytorch_optimizer/optimizer/adashift.py (+2 -1)

@@ -110,10 +110,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 exp_avg_sq.mul_(beta2).add_(reduced_grad_sq, alpha=1.0 - beta2)
 
                 update = exp_avg.clone()
-                update.div_(exp_avg_sq.div(bias_correction).sqrt_().add_(group['eps']))
                 if self.cautious:
                     self.apply_cautious(update, grad)
 
+                update.div_(exp_avg_sq.div(bias_correction).sqrt_().add_(group['eps']))
+
                 p.add_(update, alpha=-group['lr'])
 
         return loss

pytorch_optimizer/optimizer/ademamix.py (+3 -2)

@@ -146,10 +146,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])
 
+                update = exp_avg.clone()
                 if self.cautious:
-                    self.apply_cautious(exp_avg, grad)
+                    self.apply_cautious(update, grad)
 
-                update = (exp_avg + alpha_t * exp_avg_slow).div_(de_nom)
+                update.add_(exp_avg_slow, alpha=alpha_t).div_(de_nom)
 
                 p.add_(update, alpha=-step_size)
 

pytorch_optimizer/optimizer/mars.py (+14 -11)

@@ -121,26 +121,27 @@ def optimize_mixed(
 
         exp_avg.mul_(beta1).add_(c_t, alpha=1.0 - beta1)
 
+        update = exp_avg.clone()
         if cautious:
-            self.apply_cautious(exp_avg, grad)
+            self.apply_cautious(update, grad)
 
         if mars_type == 'adamw' or (mars_type == 'shampoo' and not is_grad_2d):
             exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1.0 - beta2)
 
             bias_correction1: float = self.debias(beta1, step)
             bias_correction2_sq: float = math.sqrt(self.debias(beta2, step))
 
-            update = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps)
-            update.div_(bias_correction2_sq).mul_(bias_correction1)
+            de_nom = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps)
+            de_nom.div_(bias_correction2_sq).mul_(bias_correction1)
 
-            return exp_avg.div(update)
+            return update.div_(de_nom)
 
         if mars_type == 'lion':
-            return exp_avg.sign()
+            return update.sign_()
 
-        factor: float = max(1.0, grad.size(0) / grad.size(1)) ** 0.5
+        factor: float = math.sqrt(max(1.0, grad.size(0) / grad.size(1)))
 
-        return zero_power_via_newton_schulz_5(exp_avg.mul(1.0 / (1.0 - beta1)), eps=eps).mul_(factor)
+        return zero_power_via_newton_schulz_5(update.mul_(1.0 / (1.0 - beta1)), eps=eps).mul_(factor)
 
     def optimize_1d(
         self,
@@ -162,13 +163,15 @@ def optimize_1d(
         exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
         exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
 
-        update = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps)
-        update.div_(bias_correction2_sq).mul_(bias_correction1)
+        update = exp_avg.clone()
 
         if cautious:
-            self.apply_cautious(exp_avg, grad)
+            self.apply_cautious(update, grad)
 
-        return exp_avg.div(update)
+        de_nom = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps)
+        de_nom.div_(bias_correction2_sq).mul_(bias_correction1)
+
+        return update.div_(de_nom)
 
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:

pytorch_optimizer/optimizer/sgd.py (+123)

@@ -356,6 +356,9 @@ def __init__(
         }
         super().__init__(params, defaults)
 
+    def __str__(self) -> str:
+        return 'SignSGD'
+
     @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
@@ -396,3 +399,123 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 p.add_(torch.sign(buf), alpha=-group['lr'])
 
         return loss
+
+
+class SGDSaI(BaseOptimizer):
+    r"""No More Adam: Learning Rate Scaling at Initialization is All You Need.
+
+    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
+    :param lr: float. learning rate.
+    :param momentum: float. coefficients used for computing running averages of gradient.
+    :param weight_decay: float. weight decay (L2 penalty).
+    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
+    :param eps: float. term added to the denominator to improve numerical stability.
+    """
+
+    def __init__(
+        self,
+        params: PARAMETERS,
+        lr: float = 1e-2,
+        momentum: float = 0.9,
+        weight_decay: float = 1e-2,
+        weight_decouple: bool = True,
+        eps: float = 1e-8,
+        **kwargs,
+    ):
+        self.validate_learning_rate(lr)
+        self.validate_range(momentum, 'beta', 0.0, 1.0)
+        self.validate_non_negative(weight_decay, 'weight_decay')
+        self.validate_non_negative(eps, 'eps')
+
+        self.has_warmup: bool = False
+
+        defaults: DEFAULTS = {
+            'lr': lr,
+            'momentum': momentum,
+            'weight_decay': weight_decay,
+            'weight_decouple': weight_decouple,
+            'eps': eps,
+        }
+        super().__init__(params, defaults)
+
+    def __str__(self) -> str:
+        return 'SGDSaI'
+
+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            group['step'] = 0
+            for p in group['params']:
+                state = self.state[p]
+
+                if group['momentum'] > 0.0:
+                    state['momentum_buffer'] = torch.zeros_like(p)
+
+    @torch.no_grad()
+    def warmup_step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                if grad.is_sparse:
+                    raise NoSparseGradientError(str(self))
+
+                sigma = grad.std().nan_to_num_()
+                grad_norm = grad.norm()
+
+                g_snr = grad_norm.div_(sigma.add_(group['eps'])) if sigma != 0.0 else grad_norm
+
+                self.state[p]['gsnr'] = g_snr
+
+        self.has_warmup = True
+
+        return loss
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        if not self.has_warmup:
+            self.warmup_step(closure)
+
+        loss: LOSS = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            momentum: float = group['momentum']
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+
+                state = self.state[p]
+
+                if momentum > 0.0:
+                    if 'momentum_buffer' not in state:
+                        state['momentum_buffer'] = grad.clone()
+
+                    buf = state['momentum_buffer']
+                    buf.mul_(momentum).add_(grad, alpha=1.0 - momentum)
+                else:
+                    buf = grad
+
+                self.apply_weight_decay(
+                    p,
+                    grad,
+                    group['lr'],
+                    group['weight_decay'],
+                    group['weight_decouple'],
+                    False,
+                )
+
+                p.add_(buf, alpha=-group['lr'] * state['gsnr'])
+
+        return loss
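The class above follows the usual PyTorch optimizer API, with one twist: the first `step()` call runs `warmup_step()` once to cache a per-parameter gradient signal-to-noise ratio, `gsnr = grad.norm() / (grad.std() + eps)`, and every later update is scaled by that frozen value via `p.add_(buf, alpha=-lr * gsnr)`. A minimal usage sketch based on that API; the toy model, data, and training loop below are illustrative assumptions, not part of the diff:

```python
import torch
from pytorch_optimizer import SGDSaI

model = torch.nn.Linear(10, 1)
optimizer = SGDSaI(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-2)

x, y = torch.randn(32, 10), torch.randn(32, 1)

for _ in range(5):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    # The first call runs warmup_step() once and caches state[p]['gsnr'];
    # later calls reuse that frozen scale and only update the momentum buffer.
    optimizer.step()
```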

tests/constants.py (+3)

@@ -74,6 +74,7 @@
     ScheduleFreeAdamW,
     ScheduleFreeRAdam,
     ScheduleFreeSGD,
+    SGDSaI,
     Shampoo,
     SignSGD,
     SophiaH,
@@ -538,6 +539,8 @@
     (MARS, {'lr': 1e-1, 'weight_decay': 1e-3, 'mars_type': 'lion', 'optimize_1d': True}, 5),
     (MARS, {'lr': 5e-1, 'lr_1d': 5e-1, 'weight_decay': 1e-3, 'mars_type': 'shampoo'}, 5),
     (MARS, {'lr': 5e-1, 'lr_1d': 5e-1, 'weight_decay': 1e-3, 'mars_type': 'adamw', 'ams_bound': True}, 5),
+    (SGDSaI, {'lr': 1e0}, 15),
+    (SGDSaI, {'lr': 1e0, 'momentum': 0.0}, 15),
 ]
 ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),

tests/test_load_modules.py (+1 -1)

@@ -34,7 +34,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):
 
 
 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 84
+    assert len(get_supported_optimizers()) == 85
     assert len(get_supported_optimizers('adam*')) == 7
     assert len(get_supported_optimizers(['adam*', 'ranger*'])) == 9
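The updated count can be checked interactively as well; a small sketch assuming the top-level `get_supported_optimizers` export that the README snippet above already uses:

```python
from pytorch_optimizer import get_supported_optimizers

# Matches the updated assertion in the test above.
assert len(get_supported_optimizers()) == 85

# Wildcard filtering, as in the README example, narrows the list to the SGD family,
# which should now include the SGDSaI entry.
print(get_supported_optimizers('sgd*'))
```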
