
Commit 3b2e11d

fix: g_snr
1 parent bc9ab50 commit 3b2e11d

File tree

  • pytorch_optimizer/optimizer/sgd.py

1 file changed: +8 −9 lines


pytorch_optimizer/optimizer/sgd.py

+8 −9
@@ -406,7 +406,7 @@ class SGDSaI(BaseOptimizer):
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param lr: float. learning rate.
-    :param momentum: float. momentum factor (0.0 = SignSGD, >0 = Signum).
+    :param momentum: float. coefficients used for computing running averages of gradient.
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
     :param eps: float. term added to the denominator to improve numerical stability.
@@ -415,7 +415,7 @@ class SGDSaI(BaseOptimizer):
     def __init__(
         self,
         params: PARAMETERS,
-        lr: float = 1e-3,
+        lr: float = 1e-2,
         momentum: float = 0.9,
         weight_decay: float = 1e-2,
         weight_decouple: bool = True,
@@ -468,10 +468,11 @@ def warmup_step(self, closure: CLOSURE = None) -> LOSS:
                     raise NoSparseGradientError(str(self))
 
                 sigma = grad.std().nan_to_num_()
-                grad_norm_snr = grad.norm()
-                grad_norm_snr.div_(sigma.add_(group['eps']))
+                grad_norm = grad.norm()
 
-                self.state[p]['gsnr'] = grad_norm_snr
+                g_snr = grad_norm.div_(sigma.add_(group['eps'])) if sigma != 0.0 else grad_norm
+
+                self.state[p]['gsnr'] = g_snr
 
         self.has_warmup = True
 
@@ -488,7 +489,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 loss = closure()
 
         for group in self.param_groups:
-            momentum = group['momentum']
+            momentum: float = group['momentum']
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -506,8 +507,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 else:
                     buf = grad
 
-                step_size = group['lr'] * state['gsnr']
-
                 self.apply_weight_decay(
                     p,
                     grad,
@@ -517,6 +516,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     False,
                 )
 
-                p.add_(buf, alpha=-step_size)
+                p.add_(buf, alpha=-group['lr'] * state['gsnr'])
 
         return loss
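
For context on the fix itself: the patched warmup_step stores a gradient signal-to-noise ratio, g-SNR = ||g|| / (std(g) + eps), but now falls back to the raw gradient norm when the standard deviation is exactly zero (for example a one-element parameter, where std() yields NaN and is clamped to zero by nan_to_num_, or a constant gradient), instead of always dividing. Below is a minimal standalone sketch of that guard; the helper name compute_g_snr and the eps default are chosen only for illustration and are not part of the library.

import torch

# Hypothetical helper mirroring the patched warmup_step logic:
# g-SNR = ||g|| / (std(g) + eps), falling back to the plain norm when std(g) is zero.
def compute_g_snr(grad: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    sigma = grad.std().nan_to_num_()  # std() is NaN for a 1-element tensor; clamp to 0
    grad_norm = grad.norm()
    # divide only when sigma is non-zero, as in the new line:
    # g_snr = grad_norm.div_(sigma.add_(group['eps'])) if sigma != 0.0 else grad_norm
    return grad_norm.div_(sigma.add_(eps)) if sigma != 0.0 else grad_norm

print(compute_g_snr(torch.tensor([0.5])))   # 1-element gradient: returns the raw norm, no inf/NaN
print(compute_g_snr(torch.randn(100)))      # usual case: norm / (std + eps)

The change in step, by contrast, is purely cosmetic: the separately computed step_size is folded into p.add_(buf, alpha=-group['lr'] * state['gsnr']), so the parameter update itself is unchanged.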
