Important
It's recommended to use heavyball.utils.set_torch() for faster training and less memory usage.
A simple package of efficient optimizers
The goal is not to strive for completeness, full maintenance, or abstraction, but to provide a simple,
largely static alternative to torch.optim with more and better optimizers.
Currently (2024-12-07, 1.0.0), the recommended stable optimizer is PrecondSchedulePaLMSOAP (see below). The
recommended experimental optimizer is DelayedPSGDKron (tuning guide).
- Optax-like API: C = heavyball.chainable; grokfast = C.ChainOpt(p, lr, C.exp_avg, C.scale_by_adam)
- Stochastic Rounding: FP32 convergence with BF16 parameters
- Inplace EMA: Same math, but less memory, less compute and higher stability
- Foreach: Fast multi-tensor application (turn it off to save memory via foreach=False)
- PaLM Beta2: Fast initial convergence, stable late convergence
- ScheduleFree: No learning rate schedule, but better convergence
- Preconditioner Schedule: Improved loss-per-step in early convergence, better step-per-second in late convergence (explained below)
- Memory-efficient storage: PSGD supports store_triu_as_line (default: True) and q_dtype to trade off memory usage for memory bandwidth; other optimizers have storage_dtype, supporting lower-precision EMAs at no(?) performance drop via stochastic rounding
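To make the stochastic rounding idea concrete, here is a minimal sketch in plain PyTorch. It is not heavyball's implementation; the function name `stochastic_round_to_bf16` is hypothetical. It relies on BF16 being the upper 16 bits of an FP32 value: adding uniform random noise to the lower 16 bits before truncating them rounds up with a probability proportional to the discarded fraction, so the result is unbiased in expectation.

```python
import torch


def stochastic_round_to_bf16(x: torch.Tensor) -> torch.Tensor:
    """Round an FP32 tensor to BF16, choosing the rounding direction at random.

    Hypothetical helper for illustration only (not part of heavyball).
    """
    # Reinterpret the FP32 bit pattern as int32 (same element size).
    bits = x.detach().to(torch.float32).contiguous().view(torch.int32)
    # BF16 keeps the upper 16 bits of FP32, so add uniform noise to the lower 16.
    noise = torch.randint_like(bits, low=0, high=1 << 16)
    # Clear the lower 16 bits; a carry caused by the noise rounds the magnitude up.
    rounded = (bits + noise) & -65536
    # The lower 16 bits are zero now, so the BF16 conversion is exact.
    return rounded.view(torch.float32).to(torch.bfloat16)


if __name__ == "__main__":
    torch.manual_seed(0)
    value = 1.0 + 2.0 ** -10  # lies between two neighbouring BF16 values
    x = torch.full((1_000_000,), value)
    print("target value:       ", value)
    print("nearest rounding:   ", x.to(torch.bfloat16).double().mean().item())
    print("stochastic rounding:", stochastic_round_to_bf16(x).double().mean().item())
```

Round-to-nearest maps every element to the same BF16 neighbour, so small updates can be lost entirely, while the stochastic variant preserves them on average. That is the intuition behind keeping BF16 parameters or EMAs without giving up FP32-like convergence.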
pip install heavyball
import torch
import heavyball
# Create a model
model = torch.nn.Linear(16, 1)
# Create an optimizer
optimizer = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)
# Random regression data
x = torch.randn(128, 16)
y = torch.randn(128, 1)
# Standard training loop
for _ in range(1000):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
Name | Description | Advantages / Disadvantages |
---|---|---|
AdamW | More efficient (speed, memory) AdamW | + Faster than AdamW + Possibly more (numerically) stable |
LaProp | More efficient (speed, memory) LaProp | + Same cost as AdamW + Marginally better convergence (better proofs) + Higher hyperparameter stability - Not a guaranteed win (can be neutral) - No "Slingshot" |
ADOPT | More efficient (speed, memory) ADOPT | + Same cost as AdamW + Rigorous mathematical convergence proofs, even for challenging models (GANs) - Empirically underperforms LaProp - no bf16 |
SFAdamW | More efficient (speed, memory) ScheduleFree AdamW | + Same cost as AdamW, but better eval perf + Full control over hyperparameters |
PaLMSFAdamW | ForeachSFAdamW with PaLM's beta2 schedule | + Same cost as AdamW, but better eval perf + Less control, but faster early and more stable late convergence + ScheduleFree - slow early convergence |
SOAP | More efficient (speed, memory) SOAP | + Faster convergence (loss-at-step) + Full control over hyperparameters - more memory usage - more hyperparameters - higher overhead than AdamW (can be amortized; better loss-at-second) |
PaLMSOAP | ForeachSOAP with PaLM's beta2 schedule | + Faster convergence (loss-at-step) + Less control, but faster early and more stable late convergence - more memory usage - more hyperparameters - higher overhead than AdamW (can be amortized; better loss-at-second) |
SFPaLMSOAP | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step) + less memory usage than PaLMForeachSOAP (more than AdamW) - slower initial convergence than PaLMForeachSOAP (but allows higher LRs) - higher overhead than AdamW (can be amortized) |
PrecondScheduleSFPaLMSOAP | SFPaLMForeachSOAP with preconditioner schedule, matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP + Significantly faster (sec/it) later + less memory usage than PaLMForeachSOAP (more than AdamW) - slower initial convergence than PaLMForeachSOAP (but allows higher LRs) - higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
PrecondSchedulePaLMSOAP | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence + Significantly faster (sec/it) later + high stability - more memory usage than PrecondScheduleSFPaLMForeachSOAP - higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
PrecondScheduleSOAP | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence + Significantly faster (sec/it) later - more memory usage than PrecondScheduleSFPaLMForeachSOAP - higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
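Several rows in the table rely on PaLM's beta2 schedule. As a rough illustration, the PaLM paper describes a second-moment decay of the form beta2 = 1 - step^(-0.8), which starts near 0 (fast adaptation) and approaches 1 (stable averaging). The sketch below assumes that form; the name `palm_beta2` and the `decay` parameter are hypothetical, and heavyball's exact variant may differ.

```python
def palm_beta2(step: int, decay: float = 0.8) -> float:
    """Return beta2 for a 1-indexed optimizer step, assuming beta2 = 1 - step**-decay."""
    if step < 1:
        raise ValueError("step must be >= 1")
    return 1.0 - step ** -decay


if __name__ == "__main__":
    # beta2 grows monotonically towards 1 as training progresses.
    for step in (1, 10, 100, 1_000, 10_000):
        print(f"step {step:>6}: beta2 = {palm_beta2(step):.4f}")
```

A small beta2 early on lets the second-moment estimate track the quickly changing gradient statistics, while a beta2 close to 1 later averages over a long history, which matches the "faster early and more stable late convergence" entries in the table.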
The default preconditioner schedule (f) would yield the following update intervals:
Steps | Interval, f | Total (schedule) | Total (constant, every 2) | Total (constant, every 16) |
---|---|---|---|---|
10 | 1.00005 | 10 | 5 (0.5x) | 0 (0.0x) |
100 | 1.026 | 99 | 50 (0.5x) | 6 (0.1x) |
1,000 | 2.0 | 738 | 500 (0.7x) | 62 (0.1x) |
10,000 | 14.3 | 2,168 | 5,000 (2.3x) | 625 (0.3x) |
100,000 | 100.2 | 4,049 | 50,000 (12.3x) | 6,250 (1.5x) |
1,000,000 | 513 | 7,245 | 500,000 (69.0x) | 62,500 (8.6x) |
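The constant-interval columns and the relative factors in parentheses can be reproduced with simple arithmetic. The sketch below copies the schedule totals from the table (the schedule function f itself is not reproduced here) and computes how many preconditioner updates a constant interval would perform relative to the schedule. The helper names are hypothetical.

```python
# Total preconditioner updates under the default schedule, copied from the table.
SCHEDULE_TOTALS = {
    10: 10,
    100: 99,
    1_000: 738,
    10_000: 2_168,
    100_000: 4_049,
    1_000_000: 7_245,
}


def constant_total(steps: int, every: int) -> int:
    """Number of preconditioner updates when updating once every `every` steps."""
    return steps // every


def relative_cost(steps: int, every: int) -> float:
    """Updates of a constant interval, relative to the schedule (above 1 means more work)."""
    return constant_total(steps, every) / SCHEDULE_TOTALS[steps]


if __name__ == "__main__":
    for steps in SCHEDULE_TOTALS:
        row = [f"{constant_total(steps, k):>7,} ({relative_cost(steps, k):.1f}x)" for k in (2, 16)]
        print(f"{steps:>9,} steps | " + " | ".join(row))
```

Read this way, the schedule spends more updates than either constant interval early on, where they matter most for convergence, and far fewer later: at 1,000,000 steps, updating every 2 steps performs roughly 69 times as many preconditioner updates as the schedule.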
Second order optimizers make it difficult to estimate memory usage, as it depends on shapes and hyperparameters. To
estimate your memory usage, you may use test/test_memory.py
which attempts to ensure there are no regressions.
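To see why shapes matter so much, consider a back-of-the-envelope estimate. It assumes a Kronecker-factored preconditioner that keeps one square factor per tensor dimension, stored either in full or as its upper triangle (in the spirit of store_triu_as_line). This is a simplification with hypothetical helper names: it ignores momentum buffers, any fallbacks for very large dimensions, and everything else a real optimizer stores.

```python
def kron_factor_elements(shape: tuple[int, ...], triangular: bool = True) -> int:
    """Elements needed for one square factor per dimension of `shape`."""
    total = 0
    for dim in shape:
        # Upper triangle including the diagonal: dim * (dim + 1) / 2 entries.
        total += dim * (dim + 1) // 2 if triangular else dim * dim
    return total


def kron_factor_bytes(shape: tuple[int, ...], bytes_per_element: int = 4,
                      triangular: bool = True) -> int:
    """Rough byte count of the factors; bytes_per_element=4 assumes FP32."""
    return kron_factor_elements(shape, triangular) * bytes_per_element


if __name__ == "__main__":
    # Both shapes hold 1,048,576 parameters, yet their factor sizes differ enormously.
    for shape in ((1024, 1024), (16, 65536)):
        full = kron_factor_bytes(shape, triangular=False) / 2**20
        triu = kron_factor_bytes(shape, triangular=True) / 2**20
        print(f"{shape}: full = {full:,.1f} MiB, triangular = {triu:,.1f} MiB")
```

Two tensors with the same parameter count can differ by orders of magnitude in preconditioner size, which is why measuring on the actual model is more reliable than any closed-form estimate.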
Furthermore, you can find real-world memory usage of a 300M-parameter video diffusion model below:
HeavyBall offers various configurations of PSGD:
- "PSGDKron" is the baseline, equivalent to kron_torch, but with lower compute and memory overhead.
- "PurePSGD" has no momentum, further reducing memory and compute
- "DelayedPSGD" implements SOAP/ADOPT-style off-by-one momentum, which has worse initial convergence but higher stability
To access heavyball.utils, you need to explicitly import heavyball.utils.
It has several handy functions:
- set_torch() sets PyTorch optimization settings (TF32, opt_einsum, benchmark, ...)
- compile_mode, a string passed as-is to torch.compile(mode=compile_mode) in all compiled heavyball calls; compile_mode=None disables torch.compile
- zeroth_power_mode, a string determining whether to use QR, newtonschulz{iterations}, svd, or eigh to approximate the eigenvectors. Eigh has the highest precision and cost
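To illustrate the kind of trade-off behind zeroth_power_mode, here is a generic textbook Newton-Schulz iteration that orthogonalizes a matrix using only matrix multiplications, compared against the exact result from an SVD. This is a sketch under stated assumptions, not heavyball's implementation: the function name is hypothetical, and the library may use different coefficients, iteration counts, or normalization.

```python
import torch


def newton_schulz_orthogonalize(matrix: torch.Tensor, iterations: int = 50) -> torch.Tensor:
    """Approximate the orthogonal factor U @ V^T of matrix = U @ diag(S) @ V^T.

    Hypothetical helper for illustration only (not part of heavyball).
    """
    # Dividing by the Frobenius norm bounds all singular values by 1,
    # which keeps them inside the iteration's convergence region (0, sqrt(3)).
    x = matrix / matrix.norm()
    for _ in range(iterations):
        # Each step pushes every singular value s towards 1 via s -> 1.5*s - 0.5*s**3.
        x = 1.5 * x - 0.5 * x @ x.T @ x
    return x


if __name__ == "__main__":
    torch.manual_seed(0)
    a = torch.eye(4, dtype=torch.float64) + 0.1 * torch.randn(4, 4, dtype=torch.float64)
    approx = newton_schulz_orthogonalize(a)
    u, _, vh = torch.linalg.svd(a)
    print("max deviation from SVD result:", (approx - u @ vh).abs().max().item())
```

More iterations buy more precision at a higher cost, while a decomposition such as SVD or eigh is exact up to floating-point error but more expensive, which mirrors the precision and cost ordering described above.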