Commit c09d18b

Merge pull request #353 from kozistr/update/schedulefree-optimizers
[Update] ScheduleFree optimizers
2 parents: d18bb4b + dc08ba7

18 files changed: +74 -84 lines

docs/changelogs/v3.4.2.md

+11 -1

@@ -5,10 +5,20 @@
 * Implement `SCION` optimizer. (#348, #352)
 * [Training Deep Learning Models with Norm-Constrained LMOs](https://arxiv.org/abs/2502.07529)
 
+### Update
+
+* Update ScheduleFreeSGD, AdamW, RAdam optimizers with the latest. (#351, #353)
+* Remove `use_palm` variant in ScheduleFree optimizer due to instability. (#353)
+* Ranger25 optimizer. (#353)
+
+### Fix
+
+* Remove `weight decouple` parameter in ScheduleFree optimizers. (#351, #353)
+
 ### Docs
 
 * Fix `AliG` optimizer visualization. (#350)
 
 ### Contributions
 
-thanks to @AidinHamedi
+thanks to @AidinHamedi, @hatonosuke
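
For reference, here is a minimal usage sketch of the ScheduleFree optimizers after these changes. It assumes the classes are still exported from the package root as in earlier releases; `weight_decouple`, `fixed_decay`, `use_palm`, and `degenerated_to_sgd` are no longer constructor arguments, and `ScheduleFreeRAdam` now takes `silent_sgd_phase`.

```python
import torch
from torch import nn

# Assumption: exported from the package root, as in earlier pytorch_optimizer releases.
from pytorch_optimizer import ScheduleFreeAdamW, ScheduleFreeRAdam, ScheduleFreeSGD

model = nn.Linear(4, 2)

# Post-change constructor surface: no `weight_decouple`, `fixed_decay`, `use_palm`,
# or `degenerated_to_sgd`; ScheduleFreeRAdam accepts `silent_sgd_phase` instead.
sgd = ScheduleFreeSGD(model.parameters(), lr=1.0, momentum=0.9, weight_decay=1e-3)
adamw = ScheduleFreeAdamW(model.parameters(), lr=2.5e-3, betas=(0.9, 0.999))
radam = ScheduleFreeRAdam(model.parameters(), lr=2.5e-3, silent_sgd_phase=True)

# Smoke test with one of them.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
adamw.step()
adamw.zero_grad()
```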

docs/visualization.md

+8
@@ -302,6 +302,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_ScheduleFreeSGD.png)
 
+### SCION
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SCION.png)
+
 ### SGD
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SGD.png)
@@ -668,6 +672,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_ScheduleFreeSGD.png)
 
+### SCION
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SCION.png)
+
 ### SGD
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SGD.png)
Several binary image files changed (not rendered).

poetry.lock

+19-19
Some generated files are not rendered by default.

pyproject.toml

+1 -1

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "3.4.1"
+version = "3.4.2"
 description = "optimizer & lr scheduler & objective function collections in PyTorch"
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]

pytorch_optimizer/optimizer/experimental/ranger25.py

+1 -1

@@ -40,7 +40,7 @@ def __init__(
         self,
         params: PARAMETERS,
         lr: float = 1e-3,
-        betas: BETAS = (0.95, 0.98, 0.9999),
+        betas: BETAS = (0.9, 0.98, 0.9999),
         weight_decay: float = 1e-3,
         weight_decouple: bool = True,
         fixed_decay: bool = False,
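
The only change here is a lower default for the first moment coefficient. A minimal usage sketch, assuming the experimental module path shown in this diff (the public export may differ):

```python
import torch
from torch import nn

# Assumption: importing via the file path shown in this diff; Ranger25 is experimental
# and may also (or instead) be exported from the package root.
from pytorch_optimizer.optimizer.experimental.ranger25 import Ranger25

model = nn.Linear(4, 2)

# With this change, `betas` defaults to (0.9, 0.98, 0.9999) instead of (0.95, 0.98, 0.9999).
optimizer = Ranger25(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```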

pytorch_optimizer/optimizer/schedulefree.py

+32 -58

@@ -1,4 +1,3 @@
-import math
 from typing import List
 
 import torch
@@ -15,8 +14,6 @@ class ScheduleFreeSGD(BaseOptimizer):
     :param lr: float. learning rate.
     :param momentum: float. momentum factor, must be between 0 and 1 exclusive.
     :param weight_decay: float. weight decay (L2 penalty).
-    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
-    :param fixed_decay: bool. fix weight decay.
     :param r: float. use polynomial weighting in the average with power r.
     :param weight_lr_power: float. during warmup, the weights in the average will be equal to lr raised to this power.
         set to 0 for no weighting.
@@ -30,8 +27,6 @@ def __init__(
         lr: float = 1.0,
         momentum: float = 0.9,
         weight_decay: float = 0.0,
-        weight_decouple: bool = True,
-        fixed_decay: bool = False,
         r: float = 0.0,
         weight_lr_power: float = 2.0,
         warmup_steps: int = 0,
@@ -47,8 +42,6 @@ def __init__(
             'lr': lr,
             'momentum': momentum,
             'weight_decay': weight_decay,
-            'weight_decouple': weight_decouple,
-            'fixed_decay': fixed_decay,
             'r': r,
             'weight_lr_power': weight_lr_power,
             'warmup_steps': warmup_steps,
@@ -114,7 +107,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             lr: float = group['lr'] * schedule
             lr_max = group['lr_max'] = max(lr, group['lr_max'])
 
-            weight = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
+            weight: float = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
             weight_sum = group['weight_sum'] = group['weight_sum'] + weight
 
             checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0
@@ -137,8 +130,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     grad=grad,
                     lr=lr,
                     weight_decay=group['weight_decay'],
-                    weight_decouple=group['weight_decouple'],
-                    fixed_decay=group['fixed_decay'],
+                    weight_decouple=False,
+                    fixed_decay=False,
                 )
 
                 z = state['z']
@@ -158,8 +151,6 @@ class ScheduleFreeAdamW(BaseOptimizer):
     :param lr: float. learning rate.
     :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
     :param weight_decay: float. weight decay (L2 penalty).
-    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
-    :param fixed_decay: bool. fix weight decay.
     :param r: float. use polynomial weighting in the average with power r.
     :param weight_lr_power: float. during warmup, the weights in the average will be equal to lr raised to this power.
         set to 0 for no weighting.
@@ -174,8 +165,6 @@ def __init__(
         lr: float = 2.5e-3,
         betas: BETAS = (0.9, 0.999),
         weight_decay: float = 0.0,
-        weight_decouple: bool = True,
-        fixed_decay: bool = False,
         r: float = 0.0,
         weight_lr_power: float = 2.0,
         warmup_steps: int = 0,
@@ -192,8 +181,6 @@ def __init__(
             'lr': lr,
             'betas': betas,
             'weight_decay': weight_decay,
-            'weight_decouple': weight_decouple,
-            'fixed_decay': fixed_decay,
             'r': r,
             'weight_lr_power': weight_lr_power,
             'warmup_steps': warmup_steps,
@@ -259,22 +246,16 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
             beta1, beta2 = group['betas']
 
-            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])
+            bias_correction2: float = self.debias(beta2, group['step'])
 
-            lr: float = group['lr'] * schedule * bias_correction2_sq
+            lr: float = group['lr'] * schedule
             lr_max = group['lr_max'] = max(lr, group['lr_max'])
 
-            weight = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
+            weight: float = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
             weight_sum = group['weight_sum'] = group['weight_sum'] + weight
 
             checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0
 
-            if group['use_palm']:
-                beta2: float = 1.0 - group['step'] ** -0.8
-                debias: float = (1.0 - beta2) / (1.0 - beta2 ** group['step'])
-            else:
-                debias: float = beta2
-
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -289,27 +270,27 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     state['z'] = p.clone()
                     state['exp_avg_sq'] = torch.zeros_like(p)
 
-                self.apply_weight_decay(
-                    p=p,
-                    grad=grad,
-                    lr=lr,
-                    weight_decay=group['weight_decay'],
-                    weight_decouple=group['weight_decouple'],
-                    fixed_decay=group['fixed_decay'],
-                )
-
                 z, exp_avg_sq = state['z'], state['exp_avg_sq']
-                exp_avg_sq.mul_(debias).addcmul_(grad, grad, value=1.0 - debias)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
 
                 de_nom = self.apply_ams_bound(
                     ams_bound=group['ams_bound'],
-                    exp_avg_sq=exp_avg_sq,
+                    exp_avg_sq=exp_avg_sq.div(bias_correction2),
                     max_exp_avg_sq=state.get('max_exp_avg_sq', None),
                     eps=group['eps'],
                 )
 
                 grad.div_(de_nom)
 
+                self.apply_weight_decay(
+                    p=p,
+                    grad=grad,
+                    lr=lr,
+                    weight_decay=group['weight_decay'],
+                    weight_decouple=False,
+                    fixed_decay=False,
+                )
+
                 p.lerp_(z, weight=checkpoint)
                 p.add_(grad, alpha=lr * (beta1 * (1.0 - checkpoint) - 1))
 
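
Net effect of the `ScheduleFreeAdamW.step` hunks above: the PaLM-style decaying `beta2` is gone, the second-moment EMA is kept un-corrected and only divided by the bias-correction term where the denominator is built, and weight decay is now applied after the gradient has been normalized. A standalone sketch of the debias-at-read pattern (local helper, illustrative only; the library's own `self.debias` may be formulated differently):

```python
import math

def debias(beta: float, step: int) -> float:
    """Standard Adam-style bias-correction factor (illustrative local helper)."""
    return 1.0 - beta ** step

beta2, eps = 0.999, 1e-8
exp_avg_sq = 0.0  # raw, un-corrected EMA of squared gradients

for step in range(1, 11):
    grad = 1.0  # pretend gradient
    exp_avg_sq = beta2 * exp_avg_sq + (1.0 - beta2) * grad * grad
    # Debias only when building the denominator, mirroring `exp_avg_sq.div(bias_correction2)`.
    de_nom = math.sqrt(exp_avg_sq / debias(beta2, step)) + eps
    print(step, grad / de_nom)  # stays close to 1.0 even at early steps
```

The remaining `schedulefree.py` hunks follow.
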
@@ -325,12 +306,13 @@ class ScheduleFreeRAdam(BaseOptimizer):
     :param lr: float. learning rate.
     :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
     :param weight_decay: float. weight decay (L2 penalty).
-    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
-    :param fixed_decay: bool. fix weight decay.
-    :param degenerated_to_sgd: float. degenerated to SGD.
     :param r: float. use polynomial weighting in the average with power r.
     :param weight_lr_power: float. during warmup, the weights in the average will be equal to lr raised to this power.
         set to 0 for no weighting.
+    :param silent_sgd_phase: bool. the optimizer will not use the first SGD phase of RAdam. This means that the
+        optimizer will not update model parameters during the early training steps (e.g., < 5 when β_2 = 0.999), but
+        just update the momentum values of the optimizer. This helps stabilize training by ensuring smoother warmup
+        behavior and more reliable calculation of the moving average coefficient (`ckp1`). Recommended to set to True.
     :param eps: float. term added to the denominator to improve numerical stability.
     """
 
@@ -340,11 +322,9 @@ def __init__(
         lr: float = 2.5e-3,
         betas: BETAS = (0.9, 0.999),
         weight_decay: float = 0.0,
-        weight_decouple: bool = True,
-        fixed_decay: bool = False,
-        degenerated_to_sgd: bool = False,
         r: float = 0.0,
         weight_lr_power: float = 2.0,
+        silent_sgd_phase: bool = True,
         eps: float = 1e-8,
         **kwargs,
     ):
@@ -357,9 +337,7 @@ def __init__(
             'lr': lr,
             'betas': betas,
             'weight_decay': weight_decay,
-            'weight_decouple': weight_decouple,
-            'fixed_decay': fixed_decay,
-            'degenerated_to_sgd': degenerated_to_sgd,
+            'silent_sgd_phase': silent_sgd_phase,
             'r': r,
             'weight_lr_power': weight_lr_power,
             'eps': eps,
@@ -418,32 +396,28 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
             beta1, beta2 = group['betas']
 
-            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])
+            bias_correction2: float = self.debias_beta(beta2, group['step'])
 
             lr, n_sma = self.get_rectify_step_size(
                 is_rectify=True,
                 step=group['step'],
                 lr=group['lr'],
                 beta2=beta2,
                 n_sma_threshold=4,
-                degenerated_to_sgd=group['degenerated_to_sgd'],
+                degenerated_to_sgd=False,
             )
+            if lr < 0.0:
+                lr = float(not group['silent_sgd_phase'])
 
             lr_max = group['lr_max'] = max(lr, group['lr_max'])
 
-            weight = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
+            weight: float = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
             weight_sum = group['weight_sum'] = group['weight_sum'] + weight
 
             checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0
 
             adaptive_y_lr: float = lr * (beta1 * (1.0 - checkpoint) - 1.0)
 
-            if group['use_palm']:
-                beta2: float = 1.0 - group['step'] ** -0.8
-                debias: float = (1.0 - beta2) / (1.0 - beta2 ** group['step'])
-            else:
-                debias: float = beta2
-
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -459,19 +433,19 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     state['exp_avg_sq'] = torch.zeros_like(p)
 
                 z, exp_avg_sq = state['z'], state['exp_avg_sq']
-                exp_avg_sq.mul_(debias).addcmul_(grad, grad, value=1.0 - debias)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
 
                 if n_sma > 4.0:
-                    de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])
+                    de_nom = exp_avg_sq.sqrt().div_(bias_correction2).add_(group['eps'])
                     grad.div_(de_nom)
 
                 self.apply_weight_decay(
                     p=p,
                     grad=grad,
                     lr=lr,
                     weight_decay=group['weight_decay'],
-                    weight_decouple=group['weight_decouple'],
-                    fixed_decay=group['fixed_decay'],
+                    weight_decouple=False,
+                    fixed_decay=False,
                 )
 
                 p.lerp_(z, weight=checkpoint)
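
As the new docstring explains, `silent_sgd_phase=True` makes `ScheduleFreeRAdam` skip parameter updates while RAdam's rectification term is not yet usable (it only accumulates statistics), whereas `False` falls back to a plain, un-rectified step. A standalone sketch of the `if lr < 0.0` branch added above (illustrative helper, not the library's API):

```python
def effective_lr(rectified_lr: float, silent_sgd_phase: bool) -> float:
    """Mirror the early-step branch: a negative rectified lr means the rectification
    term is not yet usable; replace it with 0.0 (skip the update) or 1.0
    (un-rectified step), depending on `silent_sgd_phase`."""
    if rectified_lr < 0.0:
        return float(not silent_sgd_phase)
    return rectified_lr


assert effective_lr(-1.0, silent_sgd_phase=True) == 0.0       # parameters untouched
assert effective_lr(-1.0, silent_sgd_phase=False) == 1.0      # SGD-like early step
assert effective_lr(2.5e-3, silent_sgd_phase=True) == 2.5e-3  # rectified lr used as-is
```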

requirements-dev.txt

+1 -1

@@ -27,7 +27,7 @@ platformdirs==4.3.6 ; python_version >= "3.8"
 pluggy==1.5.0 ; python_version >= "3.8"
 pytest-cov==5.0.0 ; python_version >= "3.8"
 pytest==8.3.4 ; python_version >= "3.8"
-ruff==0.9.6 ; python_version >= "3.8"
+ruff==0.9.7 ; python_version >= "3.8"
 setuptools==75.8.0 ; python_version >= "3.12"
 sympy==1.13.1 ; python_version >= "3.9"
 sympy==1.13.3 ; python_version < "3.9" and python_version >= "3.8"

tests/constants.py

+1 -3

@@ -503,9 +503,7 @@
     (Adalite, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
     (ScheduleFreeSGD, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
     (ScheduleFreeAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
-    (ScheduleFreeAdamW, {'lr': 1e-2, 'weight_decay': 1e-3, 'use_palm': True}, 5),
-    (ScheduleFreeRAdam, {'lr': 1e0, 'weight_decay': 1e-3, 'degenerated_to_sgd': True}, 5),
-    (ScheduleFreeRAdam, {'lr': 1e0, 'weight_decay': 1e-3, 'use_palm': True, 'degenerated_to_sgd': True}, 5),
+    (ScheduleFreeRAdam, {'lr': 1e0}, 20),
     (FAdam, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
     (GrokFastAdamW, {'lr': 5e0, 'weight_decay': 1e-3, 'grokfast_after_step': 1}, 5),
     (Kate, {'lr': 5e-2}, 10),
