@@ -406,7 +406,7 @@ class SGDSaI(BaseOptimizer):

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
-    :param momentum: float. momentum factor (0.0 = SignSGD, >0 = Signum).
+    :param momentum: float. coefficient used for computing the running average of the gradient.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param eps: float. term added to the denominator to improve numerical stability.
@@ -415,7 +415,7 @@ class SGDSaI(BaseOptimizer):
    def __init__(
        self,
        params: PARAMETERS,
-        lr: float = 1e-3,
+        lr: float = 1e-2,
        momentum: float = 0.9,
        weight_decay: float = 1e-2,
        weight_decouple: bool = True,
@@ -468,10 +468,11 @@ def warmup_step(self, closure: CLOSURE = None) -> LOSS:
                    raise NoSparseGradientError(str(self))

                sigma = grad.std().nan_to_num_()
-                grad_norm_snr = grad.norm()
-                grad_norm_snr.div_(sigma.add_(group['eps']))
+                grad_norm = grad.norm()

-                self.state[p]['gsnr'] = grad_norm_snr
+                g_snr = grad_norm.div_(sigma.add_(group['eps'])) if sigma != 0.0 else grad_norm
+
+                self.state[p]['gsnr'] = g_snr

        self.has_warmup = True
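For context, a minimal standalone sketch (not part of the patch) of what the new guard does. The helper name g_snr and the example tensors are illustrative assumptions; eps stands in for the optimizer's eps hyper-parameter.

import torch

def g_snr(grad: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # std over all elements; nan_to_num_ maps the NaN std of a one-element tensor to 0
    sigma = grad.std().nan_to_num_()
    grad_norm = grad.norm()
    # guard introduced by this change: with zero std (e.g. a constant gradient),
    # fall back to the raw norm instead of dividing by ~eps
    return grad_norm.div_(sigma.add_(eps)) if sigma != 0.0 else grad_norm

print(g_snr(torch.zeros(4)))             # sigma == 0 -> returns the norm (0.)
print(g_snr(torch.tensor([1.0, 3.0])))   # norm / (std + eps) ≈ 2.24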
@@ -488,7 +489,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                loss = closure()

        for group in self.param_groups:
-            momentum = group['momentum']
+            momentum: float = group['momentum']
            for p in group['params']:
                if p.grad is None:
                    continue
@@ -506,8 +507,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                else:
                    buf = grad

-                step_size = group['lr'] * state['gsnr']
-
                self.apply_weight_decay(
                    p,
                    grad,
@@ -517,6 +516,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                    False,
                )

-                p.add_(buf, alpha=-step_size)
+                p.add_(buf, alpha=-group['lr'] * state['gsnr'])

        return loss
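Below is a hedged usage sketch of SGDSaI after these changes. The import path and the convention of calling warmup_step once, with gradients already populated, before the regular training loop are assumptions inferred from the has_warmup flag, not something this diff specifies.

import torch
from torch import nn
from pytorch_optimizer import SGDSaI  # assumed export path

model = nn.Linear(10, 1)
# defaults per this patch: lr=1e-2, momentum=0.9, weight_decay=1e-2
optimizer = SGDSaI(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-2)

x, y = torch.randn(32, 10), torch.randn(32, 1)
loss_fn = nn.MSELoss()

# one-off warm-up: needs gradients in place so each parameter's 'gsnr' state can be filled
loss_fn(model(x), y).backward()
optimizer.warmup_step()

for _ in range(10):
    optimizer.zero_grad()
    loss_fn(model(x), y).backward()
    optimizer.step()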