Commit d18bb4b

Merge pull request #352 from kozistr/feature/scion-optimizer
[Feature] Implement `SCION` optimizer
2 parents 464f4e4 + 0f54f47 commit d18bb4b

12 files changed, +156 -7 lines

README.md (+2 -1)

@@ -10,7 +10,7 @@
 
 ## The reasons why you use `pytorch-optimizer`.
 
-* Wide range of supported optimizers. Currently, **98 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Wide range of supported optimizers. Currently, **99 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 * Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centrailiaztion`
 * Easy to use, clean, and tested codes
 * Active maintenance

@@ -206,6 +206,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | EXAdam | *The Power of Adaptive Cross-Moments* | [github](https://github.com/AhmedMostafa16/EXAdam) | <https://arxiv.org/abs/2412.20302> | [cite](https://github.com/AhmedMostafa16/EXAdam?tab=readme-ov-file#citation) |
 | GCSAM | *Gradient Centralized Sharpness Aware Minimization* | [github](https://github.com/mhassann22/GCSAM) | <https://arxiv.org/abs/2501.11584> | [cite](https://github.com/mhassann22/GCSAM?tab=readme-ov-file#citation) |
 | LookSAM | *Towards Efficient and Scalable Sharpness-Aware Minimization* | [github](https://github.com/rollovd/LookSAM) | <https://arxiv.org/abs/2203.02714> | [cite](https://ui.adsabs.harvard.edu/abs/2022arXiv220302714L/exportcitation) |
+| SCION | *Training Deep Learning Models with Norm-Constrained LMOs* | | <https://arxiv.org/abs/2502.07529> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250207529P/exportcitation) |
 
 ## Supported LR Scheduler
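
As a quick sanity check of the registry change above, the new entry can be looked up from Python with the same `get_supported_optimizers` helper that appears in the hunk header; a minimal sketch, noting that the exact return format is not shown in this diff, so the printed values are indicative only:

```python
from pytorch_optimizer import get_supported_optimizers

# Wildcard filtering as in the README snippet above; 'scion*' should now match
# the newly registered SCION entry.
print(get_supported_optimizers('scion*'))

# Total count after this commit; tests/test_load_modules.py below asserts 96.
print(len(get_supported_optimizers()))
```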

docs/changelogs/v3.4.2.md (new file, +14)

@@ -0,0 +1,14 @@
### Change Log

### Feature

* Implement `SCION` optimizer. (#348, #352)
    * [Training Deep Learning Models with Norm-Constrained LMOs](https://arxiv.org/abs/2502.07529)

### Docs

* Fix `AliG` optimizer visualization. (#350)

### Contributions

thanks to @AidinHamedi

docs/index.md (+2 -1)

@@ -10,7 +10,7 @@
 
 ## The reasons why you use `pytorch-optimizer`.
 
-* Wide range of supported optimizers. Currently, **98 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Wide range of supported optimizers. Currently, **99 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 * Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centrailiaztion`
 * Easy to use, clean, and tested codes
 * Active maintenance

@@ -206,6 +206,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | EXAdam | *The Power of Adaptive Cross-Moments* | [github](https://github.com/AhmedMostafa16/EXAdam) | <https://arxiv.org/abs/2412.20302> | [cite](https://github.com/AhmedMostafa16/EXAdam?tab=readme-ov-file#citation) |
 | GCSAM | *Gradient Centralized Sharpness Aware Minimization* | [github](https://github.com/mhassann22/GCSAM) | <https://arxiv.org/abs/2501.11584> | [cite](https://github.com/mhassann22/GCSAM?tab=readme-ov-file#citation) |
 | LookSAM | *Towards Efficient and Scalable Sharpness-Aware Minimization* | [github](https://github.com/rollovd/LookSAM) | <https://arxiv.org/abs/2203.02714> | [cite](https://ui.adsabs.harvard.edu/abs/2022arXiv220302714L/exportcitation) |
+| SCION | *Training Deep Learning Models with Norm-Constrained LMOs* | | <https://arxiv.org/abs/2502.07529> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250207529P/exportcitation) |
 
 ## Supported LR Scheduler

docs/optimizer.md (+4)

@@ -336,6 +336,10 @@
     :docstring:
     :members:
 
+::: pytorch_optimizer.SCION
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.StableAdamW
     :docstring:
     :members:

pyproject.toml (+4 -4)

@@ -18,10 +18,10 @@ keywords = [
     "GrokFast", "GSAM", "Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MARS", "MSVAG",
     "Muno", "Nero", "NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "PSGD", "QHAdam", "QHM",
     "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "GCSAM", "LookSAM", "ScheduleFreeSGD", "ScheduleFreeAdamW",
-    "ScheduleFreeRAdam", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SPAM",
-    "SRMM", "StableAdamW", "SWATS", "TAM", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine",
-    "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
-    "QGaLore",
+    "ScheduleFreeRAdam", "SCION", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH",
+    "SPAM", "SRMM", "StableAdamW", "SWATS", "TAM", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal",
+    "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge",
+    "bitsandbytes", "WSD", "QGaLore",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",

pytorch_optimizer/__init__.py (+1)

@@ -58,6 +58,7 @@
     PNM,
     QHM,
     SAM,
+    SCION,
     SGDP,
     SGDW,
     SM3,

pytorch_optimizer/optimizer/__init__.py (+2)

@@ -81,6 +81,7 @@
 from pytorch_optimizer.optimizer.rotograd import RotoGrad
 from pytorch_optimizer.optimizer.sam import BSAM, GSAM, SAM, WSAM, LookSAM
 from pytorch_optimizer.optimizer.schedulefree import ScheduleFreeAdamW, ScheduleFreeRAdam, ScheduleFreeSGD
+from pytorch_optimizer.optimizer.scion import SCION
 from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SGDSaI, SignSGD
 from pytorch_optimizer.optimizer.sgdp import SGDP
 from pytorch_optimizer.optimizer.shampoo import ScalableShampoo, Shampoo

@@ -300,6 +301,7 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
     SPAM,
     Kron,
     EXAdam,
+    SCION,
     Ranger25,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
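
With the import and registry entry above in place, the class can also be resolved by name through `load_optimizer`; a minimal sketch, with the toy parameter purely illustrative:

```python
import torch

from pytorch_optimizer import load_optimizer

# Resolve the class from the OPTIMIZERS registry built from OPTIMIZER_LIST above;
# the new test in tests/test_optimizer_parameters.py passes 'SCION' the same way.
scion_cls = load_optimizer('SCION')

param = torch.nn.Parameter(torch.randn(2, 2))
optimizer = scion_cls([param], lr=1e-4)
print(optimizer)  # 'SCION', per SCION.__str__ in scion.py below
```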

pytorch_optimizer/optimizer/scion.py (new file, +113)

@@ -0,0 +1,113 @@
from typing import Literal

import torch

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.type import CLOSURE, DEFAULTS, LOSS, PARAMETERS
from pytorch_optimizer.optimizer.shampoo_utils import zero_power_via_newton_schulz_5

LMO_TYPE = Literal['spectral', 'sign', 'col_norm', 'row_norm']


class SCION(BaseOptimizer):
    r"""Training Deep Learning Models with Norm-Constrained LMOs.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum factor.
    :param constraint: bool. whether to use a constraint SCG or not.
    :param lmo_type: LMO_TYPE. supported LMO types.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-4,
        momentum: float = 0.1,
        constraint: bool = False,
        lmo_type: LMO_TYPE = 'spectral',
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0, '(]')
        self.validate_options(lmo_type, 'lmo_type', ['spectral', 'sign', 'col_norm', 'row_norm'])

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'constraint': constraint,
            'lmo_type': lmo_type,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SCION'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['d'] = torch.zeros_like(p)

    @staticmethod
    def get_lmo_direction(grad: torch.Tensor, lmo_type: str) -> torch.Tensor:
        r"""Get LMO direction."""
        if lmo_type == 'spectral' and grad.ndim == 2:
            return zero_power_via_newton_schulz_5(grad)
        if lmo_type == 'sign':
            return torch.sign(grad)
        if lmo_type == 'col_norm':
            return grad / torch.norm(grad, dim=0, keepdim=True).add_(1e-6)
        if lmo_type == 'row_norm' and grad.ndim == 2:
            return grad / torch.norm(grad, dim=1, keepdim=True).add_(1e-6)
        return torch.sign(grad)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            step_size: float = -group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if 'd' not in state:
                    state['d'] = torch.zeros_like(p)

                d = state['d']
                d.mul_(1.0 - group['momentum']).add_(grad, alpha=group['momentum'])

                update = self.get_lmo_direction(d, group['lmo_type'])

                if not group['constraint']:
                    self.apply_weight_decay(
                        p,
                        grad,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        weight_decouple=group['weight_decouple'],
                        fixed_decay=False,
                    )

                    p.add_(update, alpha=step_size)
                else:
                    p.mul_(1.0 - step_size).add_(update, alpha=step_size)

        return loss
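
A short usage sketch of the class above; the hyper-parameters are the constructor defaults and values shown in this diff, while the toy model and batch are illustrative only:

```python
import torch
from torch import nn

from pytorch_optimizer import SCION

# Toy model: the 2-D weight matrices take the 'spectral' LMO path, while the
# 1-D bias tensors fall back to the sign direction in get_lmo_direction.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))

# Unconstrained variant (constraint=False) with decoupled weight decay.
optimizer = SCION(
    model.parameters(), lr=1e-4, momentum=0.1, lmo_type='spectral', weight_decay=1e-3
)

x, y = torch.randn(32, 8), torch.randint(0, 2, (32,))

optimizer.zero_grad()
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
```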

tests/constants.py (+3)

@@ -15,6 +15,7 @@
     PID,
     PNM,
     QHM,
+    SCION,
     SGDP,
     SGDW,
     SM3,

@@ -563,6 +564,8 @@
     (FOCUS, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
     (Kron, {'lr': 1e0, 'weight_decay': 1e-3}, 3),
     (EXAdam, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
+    (SCION, {'lr': 5e-1, 'constraint': False, 'weight_decay': 1e-3}, 10),
+    (SCION, {'lr': 1e-1, 'constraint': True}, 10),
     (Ranger25, {'lr': 1e-1}, 3),
     (Ranger25, {'lr': 1e-1, 't_alpha_beta3': 5}, 3),
     (Ranger25, {'lr': 5e-2, 'stable_adamw': False, 'orthograd': False, 'eps': None, 'lookahead_merge_time': 2}, 3),

tests/test_general_optimizer_parameters.py (+2)

@@ -57,6 +57,7 @@ def test_epsilon(optimizer_name):
         'focus',
         'kron',
         'sgd',
+        'scion',
     ):
         pytest.skip(f'skip {optimizer_name} optimizer')

@@ -86,6 +87,7 @@ def test_weight_decay(optimizer_name):
         'lomo',
         'ftrl',
         'muon',
+        'scion',
     ):
         pytest.skip(f'skip {optimizer_name} optimizer')

tests/test_load_modules.py (+1 -1)

@@ -34,7 +34,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):
 
 
 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 95
+    assert len(get_supported_optimizers()) == 96
     assert len(get_supported_optimizers('adam*')) == 8
     assert len(get_supported_optimizers(['adam*', 'ranger*'])) == 11

tests/test_optimizer_parameters.py (+8)

@@ -303,3 +303,11 @@ def test_load_wrapper_optimizer(optimizer_instance):
 
     state = optimizer.state_dict()
     optimizer.load_state_dict(state)
+
+
+def test_scion_lmo_direction():
+    x = torch.zeros((1, 1), dtype=torch.float32)
+
+    optimizer_instance = load_optimizer('SCION')
+    for lmo_direction in ('spectral', 'sign', 'col_norm', 'row_norm'):
+        optimizer_instance.get_lmo_direction(x, lmo_direction)
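
For reference, the static `get_lmo_direction` helper exercised by this test can also be called directly on the class; a minimal sketch with a 2-D input chosen so the 'spectral' and 'row_norm' branches are actually taken rather than the sign fallback used for 1-D tensors:

```python
import torch

from pytorch_optimizer import SCION

g = torch.randn(3, 4)  # 2-D, so 'spectral' and 'row_norm' do not fall back to sign
for lmo_type in ('spectral', 'sign', 'col_norm', 'row_norm'):
    direction = SCION.get_lmo_direction(g, lmo_type)
    print(lmo_type, tuple(direction.shape))  # every direction keeps the input shape
```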
