@@ -1,4 +1,3 @@
-import math
 from typing import List

 import torch
@@ -15,8 +14,6 @@ class ScheduleFreeSGD(BaseOptimizer):
     :param lr: float. learning rate.
     :param momentum: float. momentum factor, must be between 0 and 1 exclusive.
     :param weight_decay: float. weight decay (L2 penalty).
-    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
-    :param fixed_decay: bool. fix weight decay.
     :param r: float. use polynomial weighting in the average with power r.
     :param weight_lr_power: float. during warmup, the weights in the average will be equal to lr raised to this power.
         set to 0 for no weighting.
@@ -30,8 +27,6 @@ def __init__(
         lr: float = 1.0,
         momentum: float = 0.9,
         weight_decay: float = 0.0,
-        weight_decouple: bool = True,
-        fixed_decay: bool = False,
         r: float = 0.0,
         weight_lr_power: float = 2.0,
         warmup_steps: int = 0,
@@ -47,8 +42,6 @@ def __init__(
             'lr': lr,
             'momentum': momentum,
             'weight_decay': weight_decay,
-            'weight_decouple': weight_decouple,
-            'fixed_decay': fixed_decay,
             'r': r,
             'weight_lr_power': weight_lr_power,
             'warmup_steps': warmup_steps,
@@ -114,7 +107,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             lr: float = group['lr'] * schedule
             lr_max = group['lr_max'] = max(lr, group['lr_max'])

-            weight = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
+            weight: float = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
             weight_sum = group['weight_sum'] = group['weight_sum'] + weight

             checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0
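As a side note on the weighting arithmetic in this hunk, here is a minimal, self-contained sketch (not library code; the names `r`, `weight_lr_power`, and `lr_max` simply mirror the diff) showing how the checkpoint coefficient evolves under the defaults:

```python
def checkpoint_coefficient(step: int, lr_max: float, r: float, weight_lr_power: float,
                           weight_sum: float) -> tuple:
    # Same arithmetic as in the hunk above, pulled out as a standalone helper.
    weight = (step ** r) * (lr_max ** weight_lr_power)
    weight_sum += weight
    ckp = weight / weight_sum if weight_sum != 0.0 else 0.0
    return ckp, weight_sum

# With the defaults r=0.0, weight_lr_power=2.0 and a constant lr, every step gets
# the same weight, so the interpolation coefficient decays like 1 / step.
ws = 0.0
for step in range(1, 4):
    ckp, ws = checkpoint_coefficient(step, lr_max=1.0, r=0.0, weight_lr_power=2.0, weight_sum=ws)
    print(step, round(ckp, 4))  # 1 1.0 / 2 0.5 / 3 0.3333
```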
@@ -137,8 +130,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     grad=grad,
                     lr=lr,
                     weight_decay=group['weight_decay'],
-                    weight_decouple=group['weight_decouple'],
-                    fixed_decay=group['fixed_decay'],
+                    weight_decouple=False,
+                    fixed_decay=False,
                 )

                 z = state['z']
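With `weight_decouple` and `fixed_decay` pinned to `False` here, the decay path reduces to plain L2 folded into the gradient rather than AdamW-style decoupled shrinkage. A rough sketch of that switch, under the assumption that `apply_weight_decay` behaves like the usual L2-vs-decoupled toggle (this is illustrative, not the actual `BaseOptimizer` helper):

```python
import torch

def sketch_weight_decay(p: torch.Tensor, grad: torch.Tensor, lr: float,
                        weight_decay: float, weight_decouple: bool) -> None:
    """Assumed behaviour of the decay switch; illustrative only."""
    if weight_decouple:
        p.mul_(1.0 - lr * weight_decay)   # AdamW-style: shrink the weights directly
    elif weight_decay > 0.0:
        grad.add_(p, alpha=weight_decay)  # classic L2: fold weight_decay * p into the gradient
```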
@@ -158,8 +151,6 @@ class ScheduleFreeAdamW(BaseOptimizer):
     :param lr: float. learning rate.
     :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
     :param weight_decay: float. weight decay (L2 penalty).
-    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
-    :param fixed_decay: bool. fix weight decay.
     :param r: float. use polynomial weighting in the average with power r.
     :param weight_lr_power: float. during warmup, the weights in the average will be equal to lr raised to this power.
         set to 0 for no weighting.
@@ -174,8 +165,6 @@ def __init__(
         lr: float = 2.5e-3,
         betas: BETAS = (0.9, 0.999),
         weight_decay: float = 0.0,
-        weight_decouple: bool = True,
-        fixed_decay: bool = False,
         r: float = 0.0,
         weight_lr_power: float = 2.0,
         warmup_steps: int = 0,
@@ -192,8 +181,6 @@ def __init__(
             'lr': lr,
             'betas': betas,
             'weight_decay': weight_decay,
-            'weight_decouple': weight_decouple,
-            'fixed_decay': fixed_decay,
             'r': r,
             'weight_lr_power': weight_lr_power,
             'warmup_steps': warmup_steps,
@@ -259,22 +246,16 @@ def step(self, closure: CLOSURE = None) -> LOSS:
259
246
260
247
beta1 , beta2 = group ['betas' ]
261
248
262
- bias_correction2_sq : float = math . sqrt ( 1.0 - beta2 ** group ['step' ])
249
+ bias_correction2 : float = self . debias ( beta2 , group ['step' ])
263
250
264
- lr : float = group ['lr' ] * schedule * bias_correction2_sq
251
+ lr : float = group ['lr' ] * schedule
265
252
lr_max = group ['lr_max' ] = max (lr , group ['lr_max' ])
266
253
267
- weight = (group ['step' ] ** group ['r' ]) * (lr_max ** group ['weight_lr_power' ])
254
+ weight : float = (group ['step' ] ** group ['r' ]) * (lr_max ** group ['weight_lr_power' ])
268
255
weight_sum = group ['weight_sum' ] = group ['weight_sum' ] + weight
269
256
270
257
checkpoint : float = weight / weight_sum if weight_sum != 0.0 else 0.0
271
258
272
- if group ['use_palm' ]:
273
- beta2 : float = 1.0 - group ['step' ] ** - 0.8
274
- debias : float = (1.0 - beta2 ) / (1.0 - beta2 ** group ['step' ])
275
- else :
276
- debias : float = beta2
277
-
278
259
for p in group ['params' ]:
279
260
if p .grad is None :
280
261
continue
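The PaLM-style `debias` branch is gone: the second moment is now accumulated with the raw `beta2`, and the bias correction is applied once when forming the denominator. A worked sketch, under the assumption that `self.debias(beta2, step)` returns `1.0 - beta2 ** step` (the helper's body is not part of this diff):

```python
import torch

beta2, step = 0.999, 10
bias_correction2 = 1.0 - beta2 ** step  # assumed return value of self.debias(beta2, step)

exp_avg_sq = torch.zeros(3)
grad = torch.ones(3)
for _ in range(step):
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

# For a constant gradient, dividing by the bias correction recovers grad ** 2 exactly,
# i.e. Adam's bias-corrected second-moment estimate v_hat.
print(exp_avg_sq.div(bias_correction2))  # ~ tensor([1., 1., 1.])
```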
@@ -289,27 +270,27 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     state['z'] = p.clone()
                     state['exp_avg_sq'] = torch.zeros_like(p)

-                self.apply_weight_decay(
-                    p=p,
-                    grad=grad,
-                    lr=lr,
-                    weight_decay=group['weight_decay'],
-                    weight_decouple=group['weight_decouple'],
-                    fixed_decay=group['fixed_decay'],
-                )
-
                 z, exp_avg_sq = state['z'], state['exp_avg_sq']
-                exp_avg_sq.mul_(debias).addcmul_(grad, grad, value=1.0 - debias)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                 de_nom = self.apply_ams_bound(
                     ams_bound=group['ams_bound'],
-                    exp_avg_sq=exp_avg_sq,
+                    exp_avg_sq=exp_avg_sq.div(bias_correction2),
                     max_exp_avg_sq=state.get('max_exp_avg_sq', None),
                     eps=group['eps'],
                 )

                 grad.div_(de_nom)

+                self.apply_weight_decay(
+                    p=p,
+                    grad=grad,
+                    lr=lr,
+                    weight_decay=group['weight_decay'],
+                    weight_decouple=False,
+                    fixed_decay=False,
+                )
+
                 p.lerp_(z, weight=checkpoint)
                 p.add_(grad, alpha=lr * (beta1 * (1.0 - checkpoint) - 1))

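The decay call is removed from the top of the parameter loop and reinserted after the preconditioning step, so L2 decay now acts on the already-normalized gradient just before the schedule-free interpolation. A small sketch of the resulting update order for one parameter, plus a check that `lerp_` really is the convex combination used as the checkpoint averaging step (illustrative, not library code):

```python
import torch

# Update order after this hunk (grad has already been divided by de_nom):
#   1. grad <- grad + weight_decay * p          (L2, since weight_decouple=False)
#   2. p    <- (1 - ckp) * p + ckp * z          (p.lerp_(z, weight=ckp))
#   3. p    <- p + lr * (beta1 * (1 - ckp) - 1) * grad

p = torch.tensor([1.0, 2.0])
z = torch.tensor([3.0, 4.0])
ckp = 0.25

expected = (1.0 - ckp) * p + ckp * z
p.lerp_(z, weight=ckp)
assert torch.allclose(p, expected)  # lerp_ interpolates p toward z by ckp
```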
@@ -325,12 +306,13 @@ class ScheduleFreeRAdam(BaseOptimizer):
     :param lr: float. learning rate.
     :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
     :param weight_decay: float. weight decay (L2 penalty).
-    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
-    :param fixed_decay: bool. fix weight decay.
-    :param degenerated_to_sgd: float. degenerated to SGD.
     :param r: float. use polynomial weighting in the average with power r.
     :param weight_lr_power: float. during warmup, the weights in the average will be equal to lr raised to this power.
         set to 0 for no weighting.
+    :param silent_sgd_phase: bool. the optimizer will not use the first SGD phase of RAdam. This means that the
+        optimizer will not update model parameters during the early training steps (e.g., < 5 when β_2 = 0.999), but
+        just update the momentum values of the optimizer. This helps stabilize training by ensuring smoother warmup
+        behavior and more reliable calculation of the moving average coefficient (`ckp1`). Recommended to set to True.
     :param eps: float. term added to the denominator to improve numerical stability.
     """

@@ -340,11 +322,9 @@ def __init__(
         lr: float = 2.5e-3,
         betas: BETAS = (0.9, 0.999),
         weight_decay: float = 0.0,
-        weight_decouple: bool = True,
-        fixed_decay: bool = False,
-        degenerated_to_sgd: bool = False,
         r: float = 0.0,
         weight_lr_power: float = 2.0,
+        silent_sgd_phase: bool = True,
         eps: float = 1e-8,
         **kwargs,
     ):
@@ -357,9 +337,7 @@ def __init__(
             'lr': lr,
             'betas': betas,
             'weight_decay': weight_decay,
-            'weight_decouple': weight_decouple,
-            'fixed_decay': fixed_decay,
-            'degenerated_to_sgd': degenerated_to_sgd,
+            'silent_sgd_phase': silent_sgd_phase,
             'r': r,
             'weight_lr_power': weight_lr_power,
             'eps': eps,
@@ -418,32 +396,28 @@ def step(self, closure: CLOSURE = None) -> LOSS:

             beta1, beta2 = group['betas']

-            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])
+            bias_correction2: float = self.debias_beta(beta2, group['step'])

             lr, n_sma = self.get_rectify_step_size(
                 is_rectify=True,
                 step=group['step'],
                 lr=group['lr'],
                 beta2=beta2,
                 n_sma_threshold=4,
-                degenerated_to_sgd=group['degenerated_to_sgd'],
+                degenerated_to_sgd=False,
             )
+            if lr < 0.0:
+                lr = float(not group['silent_sgd_phase'])

             lr_max = group['lr_max'] = max(lr, group['lr_max'])

-            weight = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
+            weight: float = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
             weight_sum = group['weight_sum'] = group['weight_sum'] + weight

             checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0

             adaptive_y_lr: float = lr * (beta1 * (1.0 - checkpoint) - 1.0)

-            if group['use_palm']:
-                beta2: float = 1.0 - group['step'] ** -0.8
-                debias: float = (1.0 - beta2) / (1.0 - beta2 ** group['step'])
-            else:
-                debias: float = beta2
-
             for p in group['params']:
                 if p.grad is None:
                     continue
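With `degenerated_to_sgd` pinned to `False`, the assumption here is that `get_rectify_step_size` signals "rectification not yet available" via a negative step size; the new branch then maps that to `lr = 0.0` (silent, parameters untouched) or `lr = 1.0` (a plain SGD-style step), depending on `silent_sgd_phase`. A tiny standalone sketch of that mapping:

```python
def effective_lr(rectified_lr: float, silent_sgd_phase: bool) -> float:
    # Mirrors the new branch: a negative rectified lr means the RAdam variance
    # rectification is not usable yet at this step.
    if rectified_lr < 0.0:
        return float(not silent_sgd_phase)  # True -> 0.0 (skip update), False -> 1.0 (SGD step)
    return rectified_lr

print(effective_lr(-1.0, silent_sgd_phase=True))    # 0.0
print(effective_lr(-1.0, silent_sgd_phase=False))   # 1.0
print(effective_lr(2.5e-3, silent_sgd_phase=True))  # 0.0025
```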
@@ -459,19 +433,19 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     state['exp_avg_sq'] = torch.zeros_like(p)

                 z, exp_avg_sq = state['z'], state['exp_avg_sq']
-                exp_avg_sq.mul_(debias).addcmul_(grad, grad, value=1.0 - debias)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                 if n_sma > 4.0:
-                    de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])
+                    de_nom = exp_avg_sq.sqrt().div_(bias_correction2).add_(group['eps'])
                     grad.div_(de_nom)

                 self.apply_weight_decay(
                     p=p,
                     grad=grad,
                     lr=lr,
                     weight_decay=group['weight_decay'],
-                    weight_decouple=group['weight_decouple'],
-                    fixed_decay=group['fixed_decay'],
+                    weight_decouple=False,
+                    fixed_decay=False,
                 )

                 p.lerp_(z, weight=checkpoint)