Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,4 @@ venv.bak/
dmypy.json

# Pyre type checker
.pyre/
.pyre/
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
5 changes: 2 additions & 3 deletions benchmark/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ Each test is tagged with categories describing specific optimization challenges:
- Magnitude: Exploding/vanishing gradients
- Quality: Weak/noisy/sparse/delayed signals
- Direction: Adversarial/misleading patterns

- **Landscape Navigation**
- Topology: Valleys, plateaus, saddle points
- Complexity: Multimodality, nonconvexity
- Dynamics: Moving targets, shifting optima

- **Numerical Challenges**
- Conditioning: Ill-conditioned problems
- Scaling: Parameter scale variations
Expand All @@ -40,7 +40,6 @@ Each test is tagged with categories describing specific optimization challenges:
| Gradient Delay | Tests async update handling | Delayed Gradients & Async Updates | ✓ |
| Gradient Noise Scale | Tests noise level adaptation | Variable Noise & Scaling | ✓ |
| Grokking | Tests sudden learning after memorization | Phase Transitions & Memory | |
| Ill-Conditioned | Tests poor conditioning optimization | Conditioning & Convergence | ✓ |
| Layer-wise Scale | Tests multi-layer gradient scaling | Layer Variation & Balance | ✓ |
| Loss Contour | Tests complex landscape navigation | Surface Complexity & Visualization | |
| Momentum Utilization | Tests momentum in oscillating landscapes | Oscillations & Momentum | ✓ |
Expand Down
39 changes: 30 additions & 9 deletions benchmark/adversarial_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import typer
from torch import nn

from benchmark.utils import trial, param_norm_win_condition
from benchmark.utils import param_norm_win_condition, trial
from heavyball.utils import set_torch

app = typer.Typer(pretty_exceptions_enable=False)
Expand All @@ -16,7 +16,7 @@ class Model(nn.Module):
def __init__(self, size=1024):
super().__init__()
self.param = nn.Parameter(torch.randn(size))
self.register_buffer('step', torch.zeros(1))
self.register_buffer("step", torch.zeros(1))

def forward(self):
"""Test optimizer's robustness to adversarial gradient patterns."""
Expand All @@ -28,20 +28,41 @@ def forward(self):


@app.command()
def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'),
dtype: List[str] = typer.Option(['float32'], help='Data type to use'), steps: int = 100,
weight_decay: float = 0, opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'),
trials: int = 100, win_condition_multiplier: float = 1.0, ):
def main(
method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"),
dtype: List[str] = typer.Option(["float32"], help="Data type to use"),
steps: int = 100,
weight_decay: float = 0,
opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"),
trials: int = 100,
win_condition_multiplier: float = 1.0,
):
dtype = [getattr(torch, d) for d in dtype]
model = Model().cuda().double()

def data():
return None, None

# More lenient condition due to adversarial component
trial(model, data, None, param_norm_win_condition(win_condition_multiplier * 1e-3, 0), steps, opt[0], dtype[0], 1, 1,
weight_decay, method[0], 1, 1, failure_threshold=7, base_lr=1e-3, trials=trials) # More attempts for adversarial case
trial(
model,
data,
None,
param_norm_win_condition(win_condition_multiplier * 1e-3, 0),
steps,
opt[0],
dtype[0],
1,
1,
weight_decay,
method[0],
1,
1,
failure_threshold=7,
base_lr=1e-3,
trials=trials,
) # More attempts for adversarial case


if __name__ == '__main__':
if __name__ == "__main__":
app()
44 changes: 33 additions & 11 deletions benchmark/batch_size_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import torch
import torch.backends.opt_einsum
import typer
from benchmark.utils import trial, param_norm_win_condition, Validator
from heavyball.utils import set_torch
from torch import nn

from benchmark.utils import param_norm_win_condition, trial
from heavyball.utils import set_torch

app = typer.Typer(pretty_exceptions_enable=False)
set_torch()

Expand All @@ -17,34 +18,55 @@ class Model(nn.Module):
def __init__(self, size=1024):
super().__init__()
self.param = nn.Parameter(torch.randn(size))
self.register_buffer('batch_sizes', torch.tensor([1, 4, 16, 64]))
self.register_buffer("batch_sizes", torch.tensor([1, 4, 16, 64]))
self.rng = random.Random(0x1238192)

def forward(self):
"""Test optimizer's ability to handle different batch sizes and noise scales."""
batch_size = self.rng.choice(self.batch_sizes)
generator = torch.Generator(device=self.param.device).manual_seed(self.rng.randint(0, 2 ** 31))
generator = torch.Generator(device=self.param.device).manual_seed(self.rng.randint(0, 2**31))
noise = torch.randn(self.param.shape, generator=generator, device=self.param.device)
scale = self.param.norm() / (noise.norm() + 1e-6)
noise *= scale.detach() / math.sqrt(batch_size)
return (self.param + noise).square().mean()


@app.command()
def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'),
dtype: List[str] = typer.Option(['float32'], help='Data type to use'), steps: int = 100,
weight_decay: float = 0, opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'),
trials: int = 100, win_condition_multiplier: float = 1.0, ):
def main(
method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"),
dtype: List[str] = typer.Option(["float32"], help="Data type to use"),
steps: int = 100,
weight_decay: float = 0,
opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"),
trials: int = 100,
win_condition_multiplier: float = 1.0,
):
dtype = [getattr(torch, d) for d in dtype]
model = Model().cuda().double()

def data():
return None, None

# Use a more lenient win condition since we have inherent noise
trial(model, data, None, param_norm_win_condition(win_condition_multiplier * 1e-8, 0), steps, opt[0], dtype[0], 1,
1, weight_decay, method[0], 1, 1, failure_threshold=5, base_lr=1e-3, trials=trials)
trial(
model,
data,
None,
param_norm_win_condition(win_condition_multiplier * 1e-8, 0),
steps,
opt[0],
dtype[0],
1,
1,
weight_decay,
method[0],
1,
1,
failure_threshold=5,
base_lr=1e-3,
trials=trials,
)


if __name__ == '__main__':
if __name__ == "__main__":
app()
50 changes: 33 additions & 17 deletions benchmark/beale.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
import copy
import pathlib
import random
import time
from typing import List

import matplotlib.colors
import matplotlib.pyplot as plt
import torch
import torch.backends.opt_einsum
import typer
from hyperopt import early_stop
from benchmark.utils import Plotter
from torch import nn

from benchmark.utils import trial, loss_win_condition
from benchmark.utils import Plotter, loss_win_condition, trial
from heavyball.utils import set_torch

early_stop.no_progress_loss()
Expand All @@ -24,7 +20,7 @@
def objective(x, y):
x = x + 3
y = y + 0.5
return (1.5 - x + x * y) ** 2 + (2.25 - x + x * y ** 2) ** 2 + (2.625 - x + x * y ** 3) ** 2
return (1.5 - x + x * y) ** 2 + (2.25 - x + x * y**2) ** 2 + (2.625 - x + x * y**3) ** 2


class Model(nn.Module):
Expand All @@ -37,15 +33,21 @@ def forward(self):


@app.command()
def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'),
dtype: List[str] = typer.Option(['float32'], help='Data type to use'), steps: int = 100,
weight_decay: float = 0, opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'),
show_image: bool = False, trials: int = 100, win_condition_multiplier: float = 1.0, ):
def main(
method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"),
dtype: List[str] = typer.Option(["float32"], help="Data type to use"),
steps: int = 100,
weight_decay: float = 0,
opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"),
show_image: bool = False,
trials: int = 100,
win_condition_multiplier: float = 1.0,
):
dtype = [getattr(torch, d) for d in dtype]
coords = (-7, -4)

# Clean up old plots
for path in pathlib.Path('.').glob('beale.png'):
for path in pathlib.Path(".").glob("beale.png"):
path.unlink()

colors = list(matplotlib.colors.TABLEAU_COLORS.values())
Expand All @@ -61,16 +63,30 @@ def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to us
def data():
return None, None

model = trial(model, data, None, loss_win_condition(win_condition_multiplier * 1e-8 * (not show_image)), steps,
opt[0], dtype[0], 1, 1, weight_decay, method[0], 1, 1, base_lr=1e-4, trials=trials,
return_best=show_image)
model = trial(
model,
data,
None,
loss_win_condition(win_condition_multiplier * 1e-8 * (not show_image)),
steps,
opt[0],
dtype[0],
1,
1,
weight_decay,
method[0],
1,
1,
base_lr=1e-4,
trials=trials,
return_best=show_image,
)

if not show_image:
return

model.plot(save_path='beale.png')
model.plot(save_path="beale.png")



if __name__ == '__main__':
if __name__ == "__main__":
app()
48 changes: 30 additions & 18 deletions benchmark/char_rnn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import datetime
import os
from pathlib import Path
from typing import List

Expand All @@ -9,9 +7,8 @@
import typer
from torch.nn import functional as F

import heavyball
from heavyball.utils import set_torch
from benchmark.utils import loss_win_condition, trial
from heavyball.utils import set_torch

app = typer.Typer(pretty_exceptions_enable=False)
set_torch()
Expand All @@ -30,7 +27,7 @@ def __init__(self, features: int, sequence: int):
nn.Embedding(256, features),
nn.LSTM(features, features, 1, batch_first=True), # Removed dropout since num_layers=1
Take0(),
nn.Linear(features, 256)
nn.Linear(features, 256),
)

def forward(self, inp):
Expand All @@ -39,14 +36,14 @@ def forward(self, inp):

@app.command()
def main(
method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'),
dtype: List[str] = typer.Option(['float32'], help='Data type to use'),
method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"),
dtype: List[str] = typer.Option(["float32"], help="Data type to use"),
features: int = 512,
sequence: int = 256,
batch: int = 16,
steps: int = 100,
weight_decay: float = 0,
opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'),
opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"),
win_condition_multiplier: float = 1.0,
trials: int = 10,
):
Expand All @@ -55,27 +52,42 @@ def main(

# Load text data
benchmark_dir = Path(__file__).parent
with open(benchmark_dir / 'shakespeare.txt', 'rb') as f:
with open(benchmark_dir / "shakespeare.txt", "rb") as f:
text = f.read()
chars = torch.frombuffer(text, dtype=torch.uint8).cuda().long()

# Create holdout set
holdout = chars[:(sequence + 1) * batch].view(batch, sequence + 1)
chars = chars[(sequence + 1) * batch:]
offsets = torch.arange(0, sequence + 1, device='cuda').repeat(batch, 1)
chars = chars[(sequence + 1) * batch :]
offsets = torch.arange(0, sequence + 1, device="cuda").repeat(batch, 1)

def data():
batch_offsets = torch.randint(0, len(chars) - sequence - 1, (batch,), device='cuda')
batch_offsets = torch.randint(0, len(chars) - sequence - 1, (batch,), device="cuda")
batch_offsets = batch_offsets[:, None] + offsets
batch_chars = chars[batch_offsets]
batch_chars = batch_chars.view(batch, sequence + 1)
src = batch_chars[:, :-1]
tgt = batch_chars[:, 1:]
return src, tgt

trial(model, data, F.cross_entropy, loss_win_condition(win_condition_multiplier * 2.0), steps, opt[0], dtype[0], features, batch, weight_decay, method[0], sequence, 1,
failure_threshold=10, base_lr=1e-3, trials=trials)


if __name__ == '__main__':
trial(
model,
data,
F.cross_entropy,
loss_win_condition(win_condition_multiplier * 2.0),
steps,
opt[0],
dtype[0],
features,
batch,
weight_decay,
method[0],
sequence,
1,
failure_threshold=10,
base_lr=1e-3,
trials=trials,
)


if __name__ == "__main__":
app()
Loading