diff --git a/.gitignore b/.gitignore index 20c452a..0b6bf46 100644 --- a/.gitignore +++ b/.gitignore @@ -127,4 +127,4 @@ venv.bak/ dmypy.json # Pyre type checker -.pyre/ \ No newline at end of file +.pyre/ diff --git a/LICENSE b/LICENSE index 5716cb7..adfae61 100644 --- a/LICENSE +++ b/LICENSE @@ -22,4 +22,4 @@ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/benchmark/README.md b/benchmark/README.md index 162c22f..79feb43 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -10,12 +10,12 @@ Each test is tagged with categories describing specific optimization challenges: - Magnitude: Exploding/vanishing gradients - Quality: Weak/noisy/sparse/delayed signals - Direction: Adversarial/misleading patterns - + - **Landscape Navigation** - Topology: Valleys, plateaus, saddle points - Complexity: Multimodality, nonconvexity - Dynamics: Moving targets, shifting optima - + - **Numerical Challenges** - Conditioning: Ill-conditioned problems - Scaling: Parameter scale variations @@ -40,7 +40,6 @@ Each test is tagged with categories describing specific optimization challenges: | Gradient Delay | Tests async update handling | Delayed Gradients & Async Updates | ✓ | | Gradient Noise Scale | Tests noise level adaptation | Variable Noise & Scaling | ✓ | | Grokking | Tests sudden learning after memorization | Phase Transitions & Memory | | -| Ill-Conditioned | Tests poor conditioning optimization | Conditioning & Convergence | ✓ | | Layer-wise Scale | Tests multi-layer gradient scaling | Layer Variation & Balance | ✓ | | Loss Contour | Tests complex landscape navigation | Surface Complexity & Visualization | | | Momentum Utilization | Tests momentum in oscillating landscapes | Oscillations & Momentum | ✓ | diff --git a/benchmark/adversarial_gradient.py b/benchmark/adversarial_gradient.py index 0029585..558c21f 100644 --- a/benchmark/adversarial_gradient.py +++ b/benchmark/adversarial_gradient.py @@ -5,7 +5,7 @@ import typer from torch import nn -from benchmark.utils import trial, param_norm_win_condition +from benchmark.utils import param_norm_win_condition, trial from heavyball.utils import set_torch app = typer.Typer(pretty_exceptions_enable=False) @@ -16,7 +16,7 @@ class Model(nn.Module): def __init__(self, size=1024): super().__init__() self.param = nn.Parameter(torch.randn(size)) - self.register_buffer('step', torch.zeros(1)) + self.register_buffer("step", torch.zeros(1)) def forward(self): """Test optimizer's robustness to adversarial gradient patterns.""" @@ -28,10 +28,15 @@ def forward(self): @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), steps: int = 100, - weight_decay: float = 0, opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), - trials: int = 100, win_condition_multiplier: float = 1.0, ): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + steps: int = 100, + weight_decay: 
float = 0, + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), + trials: int = 100, + win_condition_multiplier: float = 1.0, +): dtype = [getattr(torch, d) for d in dtype] model = Model().cuda().double() @@ -39,9 +44,25 @@ def data(): return None, None # More lenient condition due to adversarial component - trial(model, data, None, param_norm_win_condition(win_condition_multiplier * 1e-3, 0), steps, opt[0], dtype[0], 1, 1, - weight_decay, method[0], 1, 1, failure_threshold=7, base_lr=1e-3, trials=trials) # More attempts for adversarial case + trial( + model, + data, + None, + param_norm_win_condition(win_condition_multiplier * 1e-3, 0), + steps, + opt[0], + dtype[0], + 1, + 1, + weight_decay, + method[0], + 1, + 1, + failure_threshold=7, + base_lr=1e-3, + trials=trials, + ) # More attempts for adversarial case -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/batch_size_scaling.py b/benchmark/batch_size_scaling.py index bb47f8a..22185f7 100644 --- a/benchmark/batch_size_scaling.py +++ b/benchmark/batch_size_scaling.py @@ -5,10 +5,11 @@ import torch import torch.backends.opt_einsum import typer -from benchmark.utils import trial, param_norm_win_condition, Validator -from heavyball.utils import set_torch from torch import nn +from benchmark.utils import param_norm_win_condition, trial +from heavyball.utils import set_torch + app = typer.Typer(pretty_exceptions_enable=False) set_torch() @@ -17,13 +18,13 @@ class Model(nn.Module): def __init__(self, size=1024): super().__init__() self.param = nn.Parameter(torch.randn(size)) - self.register_buffer('batch_sizes', torch.tensor([1, 4, 16, 64])) + self.register_buffer("batch_sizes", torch.tensor([1, 4, 16, 64])) self.rng = random.Random(0x1238192) def forward(self): """Test optimizer's ability to handle different batch sizes and noise scales.""" batch_size = self.rng.choice(self.batch_sizes) - generator = torch.Generator(device=self.param.device).manual_seed(self.rng.randint(0, 2 ** 31)) + generator = torch.Generator(device=self.param.device).manual_seed(self.rng.randint(0, 2**31)) noise = torch.randn(self.param.shape, generator=generator, device=self.param.device) scale = self.param.norm() / (noise.norm() + 1e-6) noise *= scale.detach() / math.sqrt(batch_size) @@ -31,10 +32,15 @@ def forward(self): @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), steps: int = 100, - weight_decay: float = 0, opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), - trials: int = 100, win_condition_multiplier: float = 1.0, ): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + steps: int = 100, + weight_decay: float = 0, + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), + trials: int = 100, + win_condition_multiplier: float = 1.0, +): dtype = [getattr(torch, d) for d in dtype] model = Model().cuda().double() @@ -42,9 +48,25 @@ def data(): return None, None # Use a more lenient win condition since we have inherent noise - trial(model, data, None, param_norm_win_condition(win_condition_multiplier * 1e-8, 0), steps, opt[0], dtype[0], 1, - 1, weight_decay, method[0], 1, 1, failure_threshold=5, base_lr=1e-3, trials=trials) + trial( + model, + data, + None, + param_norm_win_condition(win_condition_multiplier * 1e-8, 
0), + steps, + opt[0], + dtype[0], + 1, + 1, + weight_decay, + method[0], + 1, + 1, + failure_threshold=5, + base_lr=1e-3, + trials=trials, + ) -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/beale.py b/benchmark/beale.py index c1cf6c7..51a81ee 100644 --- a/benchmark/beale.py +++ b/benchmark/beale.py @@ -1,19 +1,15 @@ -import copy import pathlib import random -import time from typing import List import matplotlib.colors -import matplotlib.pyplot as plt import torch import torch.backends.opt_einsum import typer from hyperopt import early_stop -from benchmark.utils import Plotter from torch import nn -from benchmark.utils import trial, loss_win_condition +from benchmark.utils import Plotter, loss_win_condition, trial from heavyball.utils import set_torch early_stop.no_progress_loss() @@ -24,7 +20,7 @@ def objective(x, y): x = x + 3 y = y + 0.5 - return (1.5 - x + x * y) ** 2 + (2.25 - x + x * y ** 2) ** 2 + (2.625 - x + x * y ** 3) ** 2 + return (1.5 - x + x * y) ** 2 + (2.25 - x + x * y**2) ** 2 + (2.625 - x + x * y**3) ** 2 class Model(nn.Module): @@ -37,15 +33,21 @@ def forward(self): @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), steps: int = 100, - weight_decay: float = 0, opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), - show_image: bool = False, trials: int = 100, win_condition_multiplier: float = 1.0, ): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + steps: int = 100, + weight_decay: float = 0, + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), + show_image: bool = False, + trials: int = 100, + win_condition_multiplier: float = 1.0, +): dtype = [getattr(torch, d) for d in dtype] coords = (-7, -4) # Clean up old plots - for path in pathlib.Path('.').glob('beale.png'): + for path in pathlib.Path(".").glob("beale.png"): path.unlink() colors = list(matplotlib.colors.TABLEAU_COLORS.values()) @@ -61,16 +63,30 @@ def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to us def data(): return None, None - model = trial(model, data, None, loss_win_condition(win_condition_multiplier * 1e-8 * (not show_image)), steps, - opt[0], dtype[0], 1, 1, weight_decay, method[0], 1, 1, base_lr=1e-4, trials=trials, - return_best=show_image) + model = trial( + model, + data, + None, + loss_win_condition(win_condition_multiplier * 1e-8 * (not show_image)), + steps, + opt[0], + dtype[0], + 1, + 1, + weight_decay, + method[0], + 1, + 1, + base_lr=1e-4, + trials=trials, + return_best=show_image, + ) if not show_image: return - model.plot(save_path='beale.png') + model.plot(save_path="beale.png") - -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/char_rnn.py b/benchmark/char_rnn.py index b208720..1608801 100644 --- a/benchmark/char_rnn.py +++ b/benchmark/char_rnn.py @@ -1,5 +1,3 @@ -import datetime -import os from pathlib import Path from typing import List @@ -9,9 +7,8 @@ import typer from torch.nn import functional as F -import heavyball -from heavyball.utils import set_torch from benchmark.utils import loss_win_condition, trial +from heavyball.utils import set_torch app = typer.Typer(pretty_exceptions_enable=False) set_torch() @@ -30,7 +27,7 @@ def __init__(self, features: int, sequence: int): 
nn.Embedding(256, features), nn.LSTM(features, features, 1, batch_first=True), # Removed dropout since num_layers=1 Take0(), - nn.Linear(features, 256) + nn.Linear(features, 256), ) def forward(self, inp): @@ -39,14 +36,14 @@ def forward(self, inp): @app.command() def main( - method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), features: int = 512, sequence: int = 256, batch: int = 16, steps: int = 100, weight_decay: float = 0, - opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), win_condition_multiplier: float = 1.0, trials: int = 10, ): @@ -55,17 +52,16 @@ def main( # Load text data benchmark_dir = Path(__file__).parent - with open(benchmark_dir / 'shakespeare.txt', 'rb') as f: + with open(benchmark_dir / "shakespeare.txt", "rb") as f: text = f.read() chars = torch.frombuffer(text, dtype=torch.uint8).cuda().long() # Create holdout set - holdout = chars[:(sequence + 1) * batch].view(batch, sequence + 1) - chars = chars[(sequence + 1) * batch:] - offsets = torch.arange(0, sequence + 1, device='cuda').repeat(batch, 1) + chars = chars[(sequence + 1) * batch :] + offsets = torch.arange(0, sequence + 1, device="cuda").repeat(batch, 1) def data(): - batch_offsets = torch.randint(0, len(chars) - sequence - 1, (batch,), device='cuda') + batch_offsets = torch.randint(0, len(chars) - sequence - 1, (batch,), device="cuda") batch_offsets = batch_offsets[:, None] + offsets batch_chars = chars[batch_offsets] batch_chars = batch_chars.view(batch, sequence + 1) @@ -73,9 +69,25 @@ def data(): tgt = batch_chars[:, 1:] return src, tgt - trial(model, data, F.cross_entropy, loss_win_condition(win_condition_multiplier * 2.0), steps, opt[0], dtype[0], features, batch, weight_decay, method[0], sequence, 1, - failure_threshold=10, base_lr=1e-3, trials=trials) - - -if __name__ == '__main__': + trial( + model, + data, + F.cross_entropy, + loss_win_condition(win_condition_multiplier * 2.0), + steps, + opt[0], + dtype[0], + features, + batch, + weight_decay, + method[0], + sequence, + 1, + failure_threshold=10, + base_lr=1e-3, + trials=trials, + ) + + +if __name__ == "__main__": app() diff --git a/benchmark/discontinuous_gradient.py b/benchmark/discontinuous_gradient.py index 4195d85..24235d8 100644 --- a/benchmark/discontinuous_gradient.py +++ b/benchmark/discontinuous_gradient.py @@ -1,15 +1,11 @@ -import pathlib -import random from typing import List -import matplotlib.colors import torch import torch.backends.opt_einsum import typer -from utils import Plotter from torch import nn -from benchmark.utils import trial, param_norm_win_condition +from benchmark.utils import param_norm_win_condition, trial from heavyball.utils import set_torch app = typer.Typer(pretty_exceptions_enable=False) @@ -18,7 +14,7 @@ def objective(x): """Tests optimizer robustness to non-smooth landscapes with discontinuous gradients.""" - return torch.where(x < 0, x**2, 2*x).mean() # Discontinuous gradient at x=0 + return torch.where(x < 0, x**2, 2 * x).mean() # Discontinuous gradient at x=0 class Model(nn.Module): @@ -31,19 +27,40 @@ def forward(self): @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] 
= typer.Option(['float32'], help='Data type to use'), steps: int = 100, - weight_decay: float = 0, opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), - trials: int = 100, win_condition_multiplier: float = 1.0, ): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + steps: int = 100, + weight_decay: float = 0, + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), + trials: int = 100, + win_condition_multiplier: float = 1.0, +): dtype = [getattr(torch, d) for d in dtype] model = Model().cuda().double() def data(): return None, None - trial(model, data, None, param_norm_win_condition(win_condition_multiplier * 1e-4, 0), steps, opt[0], dtype[0], 1, 1, - weight_decay, method[0], 1, 1, failure_threshold=3, base_lr=1e-3, trials=trials) + trial( + model, + data, + None, + param_norm_win_condition(win_condition_multiplier * 1e-4, 0), + steps, + opt[0], + dtype[0], + 1, + 1, + weight_decay, + method[0], + 1, + 1, + failure_threshold=3, + base_lr=1e-3, + trials=trials, + ) -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/dynamic_landscape.py b/benchmark/dynamic_landscape.py index 822541e..22c6bbf 100644 --- a/benchmark/dynamic_landscape.py +++ b/benchmark/dynamic_landscape.py @@ -15,6 +15,7 @@ import torch import torch.nn as nn import typer + from benchmark.utils import loss_win_condition, trial from heavyball.utils import set_torch @@ -37,10 +38,16 @@ def forward(self): @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(["float32"], help='Data type to use'), dim: int = 16384, steps: int = 500, - weight_decay: float = 0, opt: List[str] = typer.Option(['adamw'], help='Optimizers to use'), - win_condition_multiplier: float = 1.0, trials: int = 3): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + dim: int = 16384, + steps: int = 500, + weight_decay: float = 0, + opt: List[str] = typer.Option(["adamw"], help="Optimizers to use"), + win_condition_multiplier: float = 1.0, + trials: int = 3, +): """Run dynamic landscape benchmark with specified parameters.""" dtype = [getattr(torch, d) for d in dtype] @@ -53,9 +60,24 @@ def data(): return None, None # Win condition: average squared error should be small (parameters close to target) - trial(model, data, None, loss_win_condition(0.01 * win_condition_multiplier, 0), steps, [o], [d], 1, 1, - wd, m, 1, 1, base_lr=0.1, trials=trials) + trial( + model, + data, + None, + loss_win_condition(0.01 * win_condition_multiplier, 0), + steps, + [o], + [d], + 1, + 1, + wd, + m, + 1, + 1, + base_lr=0.1, + trials=trials, + ) -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/exploding_gradient.py b/benchmark/exploding_gradient.py index 3fab9af..174137c 100644 --- a/benchmark/exploding_gradient.py +++ b/benchmark/exploding_gradient.py @@ -15,8 +15,8 @@ import torch.nn as nn import typer -from heavyball.utils import set_torch from benchmark.utils import param_norm_win_condition, trial +from heavyball.utils import set_torch app = typer.Typer(pretty_exceptions_enable=False) set_torch() @@ -27,7 +27,7 @@ def __init__(self, dim): super().__init__() self.param = nn.Parameter(torch.randn(dim)) self.scale = 5.0 # Controls how quickly 
gradients grow - + def forward(self): # Creates exponentially growing gradients # Gradient will be scale * exp(|param|) * sign(param) @@ -35,20 +35,22 @@ def forward(self): @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(["float32"], help='Data type to use'), - dim: int = 512, - steps: int = 500, - weight_decay: float = 0, - opt: List[str] = typer.Option(['adamw'], help='Optimizers to use'), - win_condition_multiplier: float = 1.0, - trials: int = 3): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + dim: int = 512, + steps: int = 500, + weight_decay: float = 0, + opt: List[str] = typer.Option(["adamw"], help="Optimizers to use"), + win_condition_multiplier: float = 1.0, + trials: int = 3, +): """Run exploding gradient benchmark with specified parameters.""" dtype = [getattr(torch, d) for d in dtype] for args in itertools.product(method, dtype, [dim], opt, [weight_decay]): m, d, dim, o, wd = args - + model = ExplodingGradient(dim) def data(): @@ -56,13 +58,24 @@ def data(): # Win condition: loss should be close to 1.0 (exp(0) = 1) # Using 1.1 as threshold since perfect convergence is hard - trial(model, data, None, - param_norm_win_condition(0.01 * win_condition_multiplier, 0), - steps, [o], [d], 1, 1, - wd, m, 1, 1, - base_lr=0.001, # Lower learning rate due to large gradients - trials=trials) + trial( + model, + data, + None, + param_norm_win_condition(0.01 * win_condition_multiplier, 0), + steps, + [o], + [d], + 1, + 1, + wd, + m, + 1, + 1, + base_lr=0.001, # Lower learning rate due to large gradients + trials=trials, + ) -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/format_results.py b/benchmark/format_results.py index 74fd2e9..b72abee 100644 --- a/benchmark/format_results.py +++ b/benchmark/format_results.py @@ -1,55 +1,57 @@ -import colorsys import re -from datetime import datetime import matplotlib.pyplot as plt import numpy as np import pandas as pd -import seaborn as sns import typer -from matplotlib.colors import LinearSegmentedColormap, to_rgba -from matplotlib.patches import Rectangle, FancyBboxPatch +from matplotlib.patches import Rectangle def parse_loss(loss_str): - if loss_str.strip() == 'inf': - return float('inf') + if loss_str.strip() == "inf": + return float("inf") try: - if 'e' in loss_str: - base, exp = loss_str.split('e') + if "e" in loss_str: + base, exp = loss_str.split("e") return float(base) * 10 ** float(exp) return float(loss_str) except (ValueError, IndexError): - return float('nan') + return float("nan") def read_benchmark_results(file_path): - with open(file_path, 'r') as f: + with open(file_path, "r") as f: content = f.read() - details_section = re.search(r'## Details\n\n(.*?)(?=\n\n|$)', content, re.DOTALL) + details_section = re.search(r"## Details\n\n(.*?)(?=\n\n|$)", content, re.DOTALL) if not details_section: raise ValueError("Could not find Details section") - lines = details_section.group(1).strip().split('\n')[2:] + lines = details_section.group(1).strip().split("\n")[2:] data = [] for line in lines: if not line.strip(): continue - parts = [p.strip() for p in line.split('|')[1:-1]] + parts = [p.strip() for p in line.split("|")[1:-1]] if len(parts) < 8: continue - data.append( - {'benchmark': parts[0], 'optimizer': parts[1], 'cautious': parts[2] == 'Yes', 'mars': parts[3] == 'Yes', - 
'success': parts[4] == '✓', 'runtime': float(parts[5].replace('s', '')), 'loss': parse_loss(parts[6]), - 'attempts': int(parts[7])}) + data.append({ + "benchmark": parts[0], + "optimizer": parts[1], + "cautious": parts[2] == "Yes", + "mars": parts[3] == "Yes", + "success": parts[4] == "✓", + "runtime": float(parts[5].replace("s", "")), + "loss": parse_loss(parts[6]), + "attempts": int(parts[7]), + }) return pd.DataFrame(data) def create_result_matrix(df): - benchmarks = sorted(df['benchmark'].unique()) - optimizers = sorted(df['optimizer'].unique()) + benchmarks = sorted(df["benchmark"].unique()) + optimizers = sorted(df["optimizer"].unique()) success_matrix = pd.DataFrame(index=benchmarks, columns=optimizers) attempts_matrix = pd.DataFrame(index=benchmarks, columns=optimizers) @@ -57,10 +59,10 @@ def create_result_matrix(df): loss_matrix = pd.DataFrame(index=benchmarks, columns=optimizers) for _, row in df.iterrows(): - success_matrix.loc[row['benchmark'], row['optimizer']] = row['success'] - attempts_matrix.loc[row['benchmark'], row['optimizer']] = row['attempts'] - runtime_matrix.loc[row['benchmark'], row['optimizer']] = row['runtime'] - loss_matrix.loc[row['benchmark'], row['optimizer']] = row['loss'] + success_matrix.loc[row["benchmark"], row["optimizer"]] = row["success"] + attempts_matrix.loc[row["benchmark"], row["optimizer"]] = row["attempts"] + runtime_matrix.loc[row["benchmark"], row["optimizer"]] = row["runtime"] + loss_matrix.loc[row["benchmark"], row["optimizer"]] = row["loss"] return success_matrix, attempts_matrix, runtime_matrix, loss_matrix @@ -68,7 +70,7 @@ def create_result_matrix(df): def normalize_row_attempts(row_attempts, row_success): """Normalize attempts within a row, considering only successful runs""" # Convert to boolean and handle NaN - success_mask = row_success.fillna(False).astype(bool) + success_mask = to_bool(row_success) successful_attempts = row_attempts[success_mask] if len(successful_attempts) == 0: @@ -91,7 +93,7 @@ def normalize_row_attempts(row_attempts, row_success): def get_color_for_cell(normalized_value, success, best_in_row=False): """Generate color for a cell based on normalized value and success""" if pd.isna(normalized_value) or not success: - return '#FF3B30' # Failure color (red) + return "#FF3B30" # Failure color (red) # Create a gradient from light green to dark blue light_green = np.array([0.7, 1.0, 0.7]) # Light green @@ -108,12 +110,24 @@ def get_color_for_cell(normalized_value, success, best_in_row=False): return tuple(color) +def to_bool(x): + return x.fillna(False).astype(bool) + + def create_visual_matrix(success_matrix, attempts_matrix, runtime_matrix, loss_matrix): - plt.style.use('default') - fig = plt.figure(figsize=(20, 15), facecolor='white') + plt.style.use("default") + fig = plt.figure(figsize=(20, 15), facecolor="white") # Create grid for multiple panels with adjusted width ratios - gs = plt.GridSpec(2, 4, figure=fig, width_ratios=[4, 0.1, 1.5, 1.5], height_ratios=[1, 1], wspace=0.3, hspace=0.3) + gs = plt.GridSpec( + 2, + 4, + figure=fig, + width_ratios=[4, 0.1, 1.5, 1.5], + height_ratios=[1, 1], + wspace=0.3, + hspace=0.3, + ) # Main heatmap main_ax = fig.add_subplot(gs[:, 0]) @@ -139,7 +153,7 @@ def create_visual_matrix(success_matrix, attempts_matrix, runtime_matrix, loss_m row_runtime = runtime_matrix.loc[idx] # Find best performer (successful with minimum attempts, then minimum runtime) - successful_mask = row_success == True + successful_mask = to_bool(row_success) if successful_mask.any(): min_attempts = 
row_attempts[successful_mask].min() min_attempts_mask = (row_attempts == min_attempts) & successful_mask @@ -159,72 +173,90 @@ def create_visual_matrix(success_matrix, attempts_matrix, runtime_matrix, loss_m attempts = attempts_matrix.iloc[i, j] runtime = runtime_matrix.iloc[i, j] normalized = normalized_attempts.iloc[i, j] - is_best = best_performers.iloc[i, j] == True + is_best = to_bool(best_performers.iloc[i, j]) # Get cell color color = get_color_for_cell(normalized, success, is_best) # Create cell rectangle - rect = Rectangle((j - 0.5, i - 0.5), 1, 1, facecolor=color, alpha=1.0, edgecolor='white', linewidth=1) + rect = Rectangle( + (j - 0.5, i - 0.5), + 1, + 1, + facecolor=color, + alpha=1.0, + edgecolor="white", + linewidth=1, + ) main_ax.add_patch(rect) if pd.notna(attempts): # Add attempt count and runtime - attempts_text = f'{int(attempts)}' - runtime_text = f'{runtime:.1f}s' + attempts_text = f"{int(attempts)}" + runtime_text = f"{runtime:.1f}s" # Determine text color based on background brightness if isinstance(color, tuple): brightness = 0.299 * color[0] + 0.587 * color[1] + 0.114 * color[2] - text_color = 'white' if brightness < 0.65 else 'black' + text_color = "white" if brightness < 0.65 else "black" else: - text_color = 'white' if color == '#FF3B30' else 'black' + text_color = "white" if color == "#FF3B30" else "black" # Add text with better formatting - main_ax.text(j, i - 0.15, attempts_text, ha='center', va='center', color=text_color, fontsize=9, - fontweight='bold') - main_ax.text(j, i + 0.15, runtime_text, ha='center', va='center', color=text_color, fontsize=8) + main_ax.text( + j, + i - 0.15, + attempts_text, + ha="center", + va="center", + color=text_color, + fontsize=9, + fontweight="bold", + ) + main_ax.text( + j, + i + 0.15, + runtime_text, + ha="center", + va="center", + color=text_color, + fontsize=8, + ) # Add star for best performer if is_best: - main_ax.text(j - 0.4, i - 0.4, '★', ha='center', va='center', color='#FFD700', fontsize=14, - fontweight='bold') + main_ax.text( + j - 0.4, + i - 0.4, + "★", + ha="center", + va="center", + color="#FFD700", + fontsize=14, + fontweight="bold", + ) # Add grid lines for i in range(success_matrix.shape[0] + 1): - main_ax.axhline(y=i - 0.5, color='#DDD', linewidth=0.5, alpha=0.5) + main_ax.axhline(y=i - 0.5, color="#DDD", linewidth=0.5, alpha=0.5) for j in range(success_matrix.shape[1] + 1): - main_ax.axvline(x=j - 0.5, color='#DDD', linewidth=0.5, alpha=0.5) + main_ax.axvline(x=j - 0.5, color="#DDD", linewidth=0.5, alpha=0.5) # Format axis labels main_ax.set_xticks(range(len(success_matrix.columns))) main_ax.set_yticks(range(len(success_matrix.index))) - main_ax.set_xticklabels(success_matrix.columns, rotation=45, ha='right', fontsize=10, fontweight='bold') - main_ax.set_yticklabels(success_matrix.index, fontsize=10, fontweight='bold') + main_ax.set_xticklabels(success_matrix.columns, rotation=45, ha="right", fontsize=10, fontweight="bold") + main_ax.set_yticklabels(success_matrix.index, fontsize=10, fontweight="bold") # Create statistics panels def create_stats_panel(ax, title, data, is_percentage=False, cmap=plt.cm.RdYlGn): ax.clear() - ax.set_title(title, fontsize=10, fontweight='bold', pad=10) - - # Normalize data for color mapping (if not percentage) - if not is_percentage: - data_min = data.min() - data_max = data.max() - if data_max == data_min: # Handle case where all values are the same - normalized_data = pd.Series(0.5, index=data.index) - else: - normalized_data = (data - data_min) / (data_max - data_min) - 
else: - normalized_data = data # Already normalized for percentages - - # Plot bars - bars = ax.barh(range(len(data)), data.values, color=[cmap(x) for x in normalized_data]) + ax.set_title(title, fontsize=10, fontweight="bold", pad=10) # Add value labels for i, v in enumerate(data.values): - text = f'{v:.1%}' if is_percentage else f'{v:.1f}' - ax.text(v + max(data.values) * 0.02, i, text, va='center', fontsize=8) + text = f"{v:.1%}" if is_percentage else f"{v:.1f}" + ax.text(v + max(data.values) * 0.02, i, text, va="center", fontsize=8) # Format axis ax.set_yticks(range(len(data))) @@ -232,49 +264,69 @@ def create_stats_panel(ax, title, data, is_percentage=False, cmap=plt.cm.RdYlGn) ax.set_xlim(0, max(data.values) * 1.15) # Remove frame - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) return ax # Calculate and plot optimizer success rates success_rates = success_matrix.mean() - create_stats_panel(stats_ax1, 'Optimizer Success Rates (↑)', success_rates.sort_values(ascending=True), - is_percentage=True, cmap=plt.cm.Greens) + create_stats_panel( + stats_ax1, + "Optimizer Success Rates (↑)", + success_rates.sort_values(ascending=True), + is_percentage=True, + cmap=plt.cm.Greens, + ) # Calculate and plot average attempts for successful runs avg_attempts = pd.Series(index=success_matrix.columns, dtype=float) for col in success_matrix.columns: - success_mask = success_matrix[col].fillna(False).astype(bool) + success_mask = to_bool(success_matrix[col]) successful_attempts = attempts_matrix[success_mask][col] avg_attempts[col] = successful_attempts.mean() if len(successful_attempts) > 0 else np.nan - create_stats_panel(stats_ax2, 'Avg Attempts Needed (↓)', avg_attempts.sort_values(ascending=False), - cmap=plt.cm.GnBu) + create_stats_panel( + stats_ax2, + "Avg Attempts Needed (↓)", + avg_attempts.sort_values(ascending=False), + cmap=plt.cm.GnBu, + ) # Calculate and plot average runtime for successful runs avg_runtime = pd.Series(index=success_matrix.columns, dtype=float) for col in success_matrix.columns: - success_mask = success_matrix[col].fillna(False).astype(bool) + success_mask = to_bool(success_matrix[col]) successful_runtime = runtime_matrix[success_mask][col] avg_runtime[col] = successful_runtime.mean() if len(successful_runtime) > 0 else np.nan - create_stats_panel(stats_ax3, 'Avg Runtime Needed (↓)', avg_runtime.sort_values(ascending=False), - cmap=plt.cm.YlOrBr) + create_stats_panel( + stats_ax3, + "Avg Runtime Needed (↓)", + avg_runtime.sort_values(ascending=False), + cmap=plt.cm.YlOrBr, + ) # Calculate and plot average loss for successful runs avg_best = pd.Series(index=runtime_matrix.columns, dtype=float) avg_best[:] = 0 - for (_, ru), (_, su), (_, at) in zip(runtime_matrix.iterrows(), success_matrix.iterrows(), attempts_matrix.iterrows()): + for (_, ru), (_, su), (_, at) in zip( + runtime_matrix.iterrows(), success_matrix.iterrows(), attempts_matrix.iterrows() + ): score = ru + at * 1000 # minimize attempt count. 
if tie, use runtime - score = score - su * 1e12 # only count successful runs + score = score - su * 1e12 # only count successful runs avg_best[runtime_matrix.columns[score.argmin()]] += 1 avg_best /= avg_best.sum() / 100 - create_stats_panel(stats_ax4, 'Best Optimizer% (↑)', avg_best.sort_values(ascending=True), cmap=plt.cm.YlGn) + create_stats_panel(stats_ax4, "Best Optimizer% (↑)", avg_best.sort_values(ascending=True), cmap=plt.cm.YlGn) # Add title and subtitle - plt.suptitle('Optimizer Performance Matrix', y=0.98, fontsize=16, fontweight='bold') - fig.text(0.25, 0.94, - 'Color intensity shows relative number of attempts per benchmark (row-normalized)\n' + '★ indicates best performer per benchmark', - ha='center', fontsize=11) + plt.suptitle("Optimizer Performance Matrix", y=0.98, fontsize=16, fontweight="bold") + fig.text( + 0.25, + 0.94, + "Color intensity shows relative number of attempts per benchmark (row-normalized)\n" + + "★ indicates best performer per benchmark", + ha="center", + fontsize=11, + ) # Adjust layout plt.tight_layout(rect=[0, 0.02, 1, 0.92]) @@ -282,15 +334,15 @@ def create_stats_panel(ax, title, data, is_percentage=False, cmap=plt.cm.RdYlGn) return fig -def main(file: str = 'benchmark_results.md'): +def main(file: str = "benchmark_results.md"): df = read_benchmark_results(file) success_matrix, attempts_matrix, runtime_matrix, loss_matrix = create_result_matrix(df) # Create the enhanced visual matrix - fig = create_visual_matrix(success_matrix, attempts_matrix, runtime_matrix, loss_matrix) + _fig = create_visual_matrix(success_matrix, attempts_matrix, runtime_matrix, loss_matrix) # Save with high quality - plt.savefig('benchmark_matrix.png', dpi=300, bbox_inches='tight', facecolor='white', pad_inches=0.5) + plt.savefig("benchmark_matrix.png", dpi=300, bbox_inches="tight", facecolor="white", pad_inches=0.5) plt.close() # Print text summary diff --git a/benchmark/gradient_delay.py b/benchmark/gradient_delay.py index fb8966a..8516beb 100644 --- a/benchmark/gradient_delay.py +++ b/benchmark/gradient_delay.py @@ -6,7 +6,7 @@ import typer from torch import nn -from benchmark.utils import trial, loss_win_condition +from benchmark.utils import loss_win_condition, trial from heavyball.utils import set_torch app = typer.Typer(pretty_exceptions_enable=False) @@ -16,38 +16,41 @@ class Model(nn.Module): def __init__(self, num_params=16, param_size=256): super().__init__() - self.params = nn.ParameterList([ - nn.Parameter(torch.randn(param_size)) for _ in range(num_params) - ]) + self.params = nn.ParameterList([nn.Parameter(torch.randn(param_size)) for _ in range(num_params)]) # Different update frequencies for each parameter self.delays = [i for i in range(num_params)] self.step = 0 - self.grad_queues = [deque(maxlen=i+1) for i in self.delays] + self.grad_queues = [deque(maxlen=i + 1) for i in self.delays] def forward(self): """Test optimizer's ability to handle delayed gradients.""" total_loss = 0 self.step += 1 - + for param, delay, queue in zip(self.params, self.delays, self.grad_queues): # Current loss for this parameter loss = param.square().mean() - + # Store the gradient in the queue queue.append(loss) - + # Only add to total loss when we have enough history if len(queue) == queue.maxlen and self.step % (delay + 1) == 0: total_loss = total_loss + queue[0] # Use oldest gradient - + return total_loss / len(self.params) @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], 
help='Data type to use'), steps: int = 100, - weight_decay: float = 0, opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), - trials: int = 100, win_condition_multiplier: float = 1.0, ): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + steps: int = 100, + weight_decay: float = 0, + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), + trials: int = 100, + win_condition_multiplier: float = 1.0, +): dtype = [getattr(torch, d) for d in dtype] model = Model().cuda().double() @@ -55,9 +58,25 @@ def data(): return None, None # More lenient win condition and more steps due to delayed updates - trial(model, data, None, loss_win_condition(win_condition_multiplier * 1e-4), steps * 2, opt[0], dtype[0], 1, 1, - weight_decay, method[0], 1, 1, failure_threshold=5, base_lr=1e-3, trials=trials) # Double steps, more attempts + trial( + model, + data, + None, + loss_win_condition(win_condition_multiplier * 1e-4), + steps * 2, + opt[0], + dtype[0], + 1, + 1, + weight_decay, + method[0], + 1, + 1, + failure_threshold=5, + base_lr=1e-3, + trials=trials, + ) # Double steps, more attempts -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/gradient_noise_scale.py b/benchmark/gradient_noise_scale.py index ec544a2..194755c 100644 --- a/benchmark/gradient_noise_scale.py +++ b/benchmark/gradient_noise_scale.py @@ -5,7 +5,7 @@ import typer from torch import nn -from benchmark.utils import trial, param_norm_win_condition +from benchmark.utils import param_norm_win_condition, trial from heavyball.utils import set_torch app = typer.Typer(pretty_exceptions_enable=False) @@ -16,7 +16,7 @@ class Model(nn.Module): def __init__(self, size=4096): super().__init__() self.param = nn.Parameter(torch.randn(size)) - self.register_buffer('step', torch.zeros(1)) + self.register_buffer("step", torch.zeros(1)) def forward(self): """Test optimizer's ability to handle changing noise levels during training.""" @@ -28,10 +28,15 @@ def forward(self): @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), steps: int = 100, - weight_decay: float = 0, opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), - trials: int = 100, win_condition_multiplier: float = 1.0, ): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + steps: int = 100, + weight_decay: float = 0, + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), + trials: int = 100, + win_condition_multiplier: float = 1.0, +): dtype = [getattr(torch, d) for d in dtype] model = Model().cuda().double() @@ -39,9 +44,25 @@ def data(): return None, None # Lenient initial condition due to high initial noise - trial(model, data, None, param_norm_win_condition(win_condition_multiplier * 1e-3, 0), steps, opt[0], dtype[0], 1, 1, - weight_decay, method[0], 1, 1, failure_threshold=5, base_lr=1e-3, trials=trials) + trial( + model, + data, + None, + param_norm_win_condition(win_condition_multiplier * 1e-3, 0), + steps, + opt[0], + dtype[0], + 1, + 1, + weight_decay, + method[0], + 1, + 1, + failure_threshold=5, + base_lr=1e-3, + trials=trials, + ) -if __name__ == '__main__': +if __name__ == "__main__": app() diff 
--git a/benchmark/grokking.py b/benchmark/grokking.py index 32ffbb8..b7d5766 100644 --- a/benchmark/grokking.py +++ b/benchmark/grokking.py @@ -1,20 +1,21 @@ -import random -import heavyball +import copy import itertools -from typing import List +import random from collections import defaultdict -import copy +from pathlib import Path +from typing import List + +import matplotlib.pyplot as plt +import numpy as np import torch import torch.backends.opt_einsum import torch.nn as nn import typer -from torch.utils.data import TensorDataset, DataLoader -import matplotlib.pyplot as plt -import numpy as np -from pathlib import Path +from torch.utils.data import DataLoader +import heavyball +from benchmark.utils import get_optim from heavyball.utils import set_torch -from benchmark.utils import trial, get_optim app = typer.Typer(pretty_exceptions_enable=False) set_torch() @@ -32,12 +33,13 @@ def __init__(self, numbers, p, hidden_dim): nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.LeakyReLU(), - nn.Linear(hidden_dim, p) + nn.Linear(hidden_dim, p), ) - + def forward(self, x): return self.net(x) + class ModuloDataset(torch.utils.data.Dataset): def __init__(self, p, numbers, min_idx, length, batch_size): length = length // batch_size @@ -75,28 +77,27 @@ def evaluate(model, loader, device): def plot_results(train_losses, test_accs, steps_to_grok=None, save_path=None): """Plot training curves""" - fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True) - + _fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True) + # Plot training loss - ax1.plot(train_losses, label='Training Loss') - ax1.set_yscale('log') - ax1.set_ylabel('Loss') - ax1.set_title('Training Loss Over Time') + ax1.plot(train_losses, label="Training Loss") + ax1.set_yscale("log") + ax1.set_ylabel("Loss") + ax1.set_title("Training Loss Over Time") ax1.grid(True) - + # Plot test accuracy eval_steps = np.arange(0, len(train_losses), len(train_losses) // len(test_accs)) - ax2.plot(eval_steps, test_accs, label='Test Accuracy', color='orange') - ax2.axhline(y=0.9, color='r', linestyle='--', label='Grokking Threshold') - ax2.set_ylabel('Accuracy') - ax2.set_xlabel('Steps') - ax2.set_title('Test Accuracy Over Time') + ax2.plot(eval_steps, test_accs, label="Test Accuracy", color="orange") + ax2.axhline(y=0.9, color="r", linestyle="--", label="Grokking Threshold") + ax2.set_ylabel("Accuracy") + ax2.set_xlabel("Steps") + ax2.set_title("Test Accuracy Over Time") ax2.grid(True) - + if steps_to_grok is not None: - ax2.axvline(x=steps_to_grok, color='g', linestyle='--', - label=f'Grokking Step ({steps_to_grok})') - + ax2.axvline(x=steps_to_grok, color="g", linestyle="--", label=f"Grokking Step ({steps_to_grok})") + ax1.legend() ax2.legend() plt.tight_layout() @@ -106,65 +107,67 @@ def plot_results(train_losses, test_accs, steps_to_grok=None, save_path=None): @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(["float32"], help='Data type to use'), - opt: List[str] = typer.Option(['ForeachSOAP', 'PaLMForeachSOAP', 'PrecondScheduleForeachSOAP'], help='Optimizers to use'), - steps: int = 100, - batch_size: int = 32, - hidden_dim: int = 32, - p: int = 257, - numbers: int = 4, - weight_decay: float = 0, - lr: float = 1e-4, - train_percent: float = 0.1, - eval_samples: int = 1024, - printervall: int = 1000,): - +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = 
typer.Option(["float32"], help="Data type to use"), + opt: List[str] = typer.Option( + ["ForeachSOAP", "PaLMForeachSOAP", "PrecondScheduleForeachSOAP"], help="Optimizers to use" + ), + steps: int = 100, + batch_size: int = 32, + hidden_dim: int = 32, + p: int = 257, + numbers: int = 4, + weight_decay: float = 0, + lr: float = 1e-4, + train_percent: float = 0.1, + eval_samples: int = 1024, + printervall: int = 1000, +): dtype = [getattr(torch, d) for d in dtype] - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # Clean up old plots - plot_dir = Path('.') - for path in plot_dir.glob('grokking_*.png'): + plot_dir = Path(".") + for path in plot_dir.glob("grokking_*.png"): path.unlink() - + # Pre-generate datasets - unique_samples = p ** numbers + unique_samples = p**numbers train_data = ModuloDataset(p, numbers, 0, int(unique_samples * train_percent), batch_size) test_data = ModuloDataset(p, numbers, train_data.max_idx, eval_samples, eval_samples) - + print(f"Training on {train_data.n_samples * batch_size:,} samples - {train_percent * 100}%") print(f"Testing on {eval_samples:,} samples") - train_loader = DataLoader( - train_data, + train_data, collate_fn=lambda x: x[0], - batch_size=1, + batch_size=1, shuffle=False, - pin_memory=True, + pin_memory=True, num_workers=4, drop_last=True, prefetch_factor=16, - persistent_workers=True + persistent_workers=True, ) - + test_loader = DataLoader( - test_data, + test_data, collate_fn=lambda x: x[0], batch_size=1, shuffle=False, pin_memory=True, num_workers=4, drop_last=True, - prefetch_factor=32 + prefetch_factor=32, ) test_loader = list(test_loader) test_loader = [[x.pin_memory() for x in i] for i in test_loader] - + train_iter = iter(train_loader) history = defaultdict(list) - + def data(): """Get next batch from the dataloader""" nonlocal train_iter @@ -174,71 +177,71 @@ def data(): train_iter = iter(train_loader) x, y = next(train_iter) return x.to(device), y.to(device) - + criterion = nn.CrossEntropyLoss() - + def win_condition(model, loss_hist): """Check if model has achieved grokking""" if not isinstance(loss_hist, float): loss = loss_hist else: loss = loss_hist - - history['loss'].append(loss) - + + history["loss"].append(loss) + if loss > 0.1: # Not converged yet return False, {} - + # If loss is low, check test accuracy acc = evaluate(model, test_loader, device) - history['test_acc'].append(acc) - return acc > 0.9, {'test_acc': acc} - + history["test_acc"].append(acc) + return acc > 0.9, {"test_acc": acc} + global_model = ModularMLP(numbers, p, hidden_dim).to(device) - global_model = torch.compile(global_model, mode='max-autotune-no-cudagraphs') + global_model = torch.compile(global_model, mode="max-autotune-no-cudagraphs") for d, o in itertools.product(dtype, opt): print(f"\nRunning {o} with {d}") model = copy.deepcopy(global_model) model.to(dtype=d) - + history.clear() - + # Get optimizer class optimizer_class = getattr(heavyball, o) optimizer = get_optim(optimizer_class, model.parameters(), lr=lr, weight_decay=weight_decay) - + loss_hist = torch.empty(steps) # Training loop for step in range(steps): model.train() x, y = data() - + optimizer.zero_grad() out = model(x) loss = criterion(out, y) loss.backward() optimizer.step() - + with torch.no_grad(): loss_hist[step] = loss.detach() - + if step % printervall == 0: lh = loss_hist[:step][-printervall:].mean().item() acc = evaluate(model, test_loader, device).item() - history['test_acc'].append(acc) 
+ history["test_acc"].append(acc) print(f"Step {step}: Loss = {lh:.4f}, Test Acc = {acc:.4f}") - + # Plot results plot_name = plot_dir / f"grokking_{o}_{d}_lr{lr}_h{hidden_dim}_p{p}.png" plot_results( loss_hist.cpu().numpy(), - history['test_acc'], - next((i for i, acc in enumerate(history['test_acc']) if acc > 0.9), None), - plot_name + history["test_acc"], + next((i for i, acc in enumerate(history["test_acc"]) if acc > 0.9), None), + plot_name, ) print(f"Training curves saved to {plot_name}") -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/layer_wise_scale.py b/benchmark/layer_wise_scale.py index 07a183f..09c895c 100644 --- a/benchmark/layer_wise_scale.py +++ b/benchmark/layer_wise_scale.py @@ -5,7 +5,7 @@ import typer from torch import nn -from benchmark.utils import trial, loss_win_condition +from benchmark.utils import loss_win_condition, trial from heavyball.utils import set_torch app = typer.Typer(pretty_exceptions_enable=False) @@ -17,22 +17,27 @@ def __init__(self, size=1024): super().__init__() # Simulate different layer scales in deep networks self.layer1 = nn.Parameter(torch.randn(size)) # Small gradients - self.layer2 = nn.Parameter(torch.randn(size)) # Medium gradients - self.layer3 = nn.Parameter(torch.randn(size)) # Large gradients + self.layer2 = nn.Parameter(torch.randn(size)) # Medium gradients + self.layer3 = nn.Parameter(torch.randn(size)) # Large gradients def forward(self): """Test optimizer's ability to handle different gradient scales across layers.""" # Each layer contributes equally to the loss but has very different scales - return (self.layer1.square().mean() * 1e-3 + - self.layer2.square().mean() + - self.layer3.square().mean() * 1e3) / 3 + return ( + self.layer1.square().mean() * 1e-3 + self.layer2.square().mean() + self.layer3.square().mean() * 1e3 + ) / 3 @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), steps: int = 100, - weight_decay: float = 0, opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), - trials: int = 100, win_condition_multiplier: float = 1.0, ): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + steps: int = 100, + weight_decay: float = 0, + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), + trials: int = 100, + win_condition_multiplier: float = 1.0, +): dtype = [getattr(torch, d) for d in dtype] model = Model().cuda().double() @@ -40,9 +45,25 @@ def data(): return None, None # More lenient win condition due to vastly different scales - trial(model, data, None, loss_win_condition(win_condition_multiplier * 1e-4), steps, opt[0], dtype[0], 1, 1, - weight_decay, method[0], 1, 1, failure_threshold=5, base_lr=1e-4, trials=trials) # Lower learning rate and more attempts + trial( + model, + data, + None, + loss_win_condition(win_condition_multiplier * 1e-4), + steps, + opt[0], + dtype[0], + 1, + 1, + weight_decay, + method[0], + 1, + 1, + failure_threshold=5, + base_lr=1e-4, + trials=trials, + ) # Lower learning rate and more attempts -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/loss_contour.py b/benchmark/loss_contour.py index bf27e99..d3f975b 100644 --- a/benchmark/loss_contour.py +++ b/benchmark/loss_contour.py @@ -9,7 +9,7 @@ import heavyball -device = 'cuda' +device = 
"cuda" heavyball.utils.compile_mode = None heavyball.utils.dynamic = True heavyball.utils.set_torch() @@ -33,20 +33,20 @@ class DatasetNorm(nn.Module): def __init__(self, features: int, momentum: float = 0.99): super().__init__() self.weight = nn.Parameter(torch.stack([torch.ones(features), torch.zeros(features)], 1)) - self.register_buffer('stats', torch.zeros(features * 2)) - self.register_buffer('step', torch.tensor(0)) + self.register_buffer("stats", torch.zeros(features * 2)) + self.register_buffer("step", torch.tensor(0)) self.momentum = momentum def forward(self, x): if True: with torch.no_grad(): - mean, sq_mean = x.mean(dim=0), (x ** 2).mean(dim=0) + mean, sq_mean = x.mean(dim=0), (x**2).mean(dim=0) stats = torch.cat([mean, sq_mean]) self.step.add_(1) self.stats.lerp_(stats, 1 - heavyball.utils.beta_debias(self.momentum, self.step)) # self.stats.lerp_(stats, self.step == 1) mean, sq_mean = self.stats.chunk(2) - std = (sq_mean - mean ** 2).clamp_min_(1e-6).sqrt() + std = (sq_mean - mean**2).clamp_min_(1e-6).sqrt() else: std, mean = 1, 0 weight, bias = self.weight.unbind(1) @@ -60,10 +60,16 @@ def __init__(self, in_shape, out_shape, width, depth, act=Sine(), expanded: int layers.append(nn.Linear(in_shape, width)) for _ in range(depth - 1): - layers.append(Residual(nn.Sequential(nn.Linear(width, expanded), # - act, # - DatasetNorm(expanded), # - nn.Linear(expanded, width)))) + layers.append( + Residual( + nn.Sequential( + nn.Linear(width, expanded), # + act, # + DatasetNorm(expanded), # + nn.Linear(expanded, width), + ) + ) + ) layers.append(DatasetNorm(width)) layers.append(nn.Linear(width, out_shape)) self.model = nn.Sequential(*layers) @@ -96,18 +102,33 @@ def generate_two_moons_torch(n_samples=1000, noise=0.1, random_state=None): return X, y -def train_and_generate_frames(model, X_train, y_train, domain, epochs, lr, filename="training_video", - resolution: int = 128, subsample: int = 1, train_samples: int = 1024): +def train_and_generate_frames( + model, + X_train, + y_train, + domain, + epochs, + lr, + filename="training_video", + resolution: int = 128, + subsample: int = 1, + train_samples: int = 1024, +): X_train = X_train.to(device).float() y_train = y_train.view(-1, 1).to(device).float() - optimizers = {'ForeachSOAP': heavyball.ForeachSOAP(model.parameters(), lr=lr), - 'PaLMForeachSOAP': heavyball.PaLMForeachSOAP(model.parameters(), lr=lr), - 'PrecondScheduleForeachSOAP': heavyball.PrecondScheduleForeachSOAP(model.parameters(), lr=lr)} + optimizers = { + "ForeachSOAP": heavyball.ForeachSOAP(model.parameters(), lr=lr), + "PaLMForeachSOAP": heavyball.PaLMForeachSOAP(model.parameters(), lr=lr), + "PrecondScheduleForeachSOAP": heavyball.PrecondScheduleForeachSOAP(model.parameters(), lr=lr), + } criterion = nn.BCEWithLogitsLoss() - xx, yy = torch.meshgrid(torch.linspace(domain[0][0], domain[1][0], resolution, device=device), - torch.linspace(domain[0][1], domain[1][1], resolution, device=device), indexing="xy") + xx, yy = torch.meshgrid( + torch.linspace(domain[0][0], domain[1][0], resolution, device=device), + torch.linspace(domain[0][1], domain[1][1], resolution, device=device), + indexing="xy", + ) grid_points = torch.stack((xx.ravel(), yy.ravel()), dim=1).float() base_model = copy.deepcopy(model) @@ -117,7 +138,7 @@ def train_and_generate_frames(model, X_train, y_train, domain, epochs, lr, filen print(f"\nTraining with {optimizer_name}") model.train() - os.makedirs('frames', exist_ok=True) + os.makedirs("frames", exist_ok=True) for epoch in tqdm.tqdm(range(epochs)): outputs = 
model(X_train) @@ -133,10 +154,10 @@ def train_and_generate_frames(model, X_train, y_train, domain, epochs, lr, filen Z = model(grid_points).reshape(resolution, resolution) plt.figure(figsize=(10, 8)) plt.contourf(xx.cpu(), yy.cpu(), Z.cpu(), levels=20) - plt.colorbar(label='Model Output') - plt.scatter(X_train[:, 0].cpu(), X_train[:, 1].cpu(), c=y_train.cpu(), cmap='coolwarm') - plt.title(f'{optimizer_name} - Epoch {epoch}, Loss: {loss.item():.4f}') - plt.savefig(f'frames/{optimizer_name}_epoch_{epoch:05d}.png') + plt.colorbar(label="Model Output") + plt.scatter(X_train[:, 0].cpu(), X_train[:, 1].cpu(), c=y_train.cpu(), cmap="coolwarm") + plt.title(f"{optimizer_name} - Epoch {epoch}, Loss: {loss.item():.4f}") + plt.savefig(f"frames/{optimizer_name}_epoch_{epoch:05d}.png") plt.close() model.train() @@ -144,10 +165,12 @@ def train_and_generate_frames(model, X_train, y_train, domain, epochs, lr, filen if __name__ == "__main__": X, y = generate_two_moons_torch(n_samples=1024, noise=0.05, random_state=42) - domain = np.array( - [[X[:, 0].min().item() - 1, X[:, 1].min().item() - 1], [X[:, 0].max().item() + 1, X[:, 1].max().item() + 1]]) + domain = np.array([ + [X[:, 0].min().item() - 1, X[:, 1].min().item() - 1], + [X[:, 0].max().item() + 1, X[:, 1].max().item() + 1], + ]) - model = torch.compile(MLP(in_shape=2, out_shape=1, width=2, depth=32), mode='max-autotune-no-cudagraphs').to(device) + model = torch.compile(MLP(in_shape=2, out_shape=1, width=2, depth=32), mode="max-autotune-no-cudagraphs").to(device) epochs = 100 lr = 1e-4 diff --git a/benchmark/minimax.py b/benchmark/minimax.py index 208e712..3866a88 100644 --- a/benchmark/minimax.py +++ b/benchmark/minimax.py @@ -3,11 +3,12 @@ import torch import torch.backends.opt_einsum import typer -from benchmark.utils import trial, param_norm_win_condition -from heavyball.utils import set_torch from torch import nn from torch.nn import functional as F +from benchmark.utils import param_norm_win_condition, trial +from heavyball.utils import set_torch + app = typer.Typer(pretty_exceptions_enable=False) set_torch() @@ -23,22 +24,44 @@ def forward(self, inp): @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), size: int = 1024, depth: int = 4, - batch: int = 16, steps: int = 10, weight_decay: float = 0, - opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), win_condition_multiplier: float = 1.0, - trials: int = 10, ): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + size: int = 1024, + depth: int = 4, + batch: int = 16, + steps: int = 10, + weight_decay: float = 0, + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), + win_condition_multiplier: float = 1.0, + trials: int = 10, +): dtype = [getattr(torch, d) for d in dtype] model = Model(size).cuda() def data(): - inp = torch.randn((batch, size), device='cuda', dtype=dtype[0]) + inp = torch.randn((batch, size), device="cuda", dtype=dtype[0]) return inp, inp.cumsum(1) - trial(model, data, F.mse_loss, param_norm_win_condition(1e-7 * win_condition_multiplier, model.target), steps, - opt[0], dtype[0], size, batch, weight_decay, method[0], 1, depth, failure_threshold=depth * 2, base_lr=1e-3, - trials=trials) + trial( + model, + data, + F.mse_loss, + param_norm_win_condition(1e-7 * win_condition_multiplier, 
model.target), + steps, + opt[0], + dtype[0], + size, + batch, + weight_decay, + method[0], + 1, + depth, + failure_threshold=depth * 2, + base_lr=1e-3, + trials=trials, + ) -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/momentum_utilization.py b/benchmark/momentum_utilization.py index a38e75d..e476156 100644 --- a/benchmark/momentum_utilization.py +++ b/benchmark/momentum_utilization.py @@ -1,5 +1,3 @@ -import pathlib -import random from typing import List import torch @@ -7,7 +5,7 @@ import typer from torch import nn -from benchmark.utils import trial, loss_win_condition +from benchmark.utils import loss_win_condition, trial from heavyball.utils import set_torch app = typer.Typer(pretty_exceptions_enable=False) @@ -18,29 +16,50 @@ class Model(nn.Module): def __init__(self, size=1024): super().__init__() self.param = nn.Parameter(torch.randn(size)) - self.register_buffer('t', torch.zeros(1)) + self.register_buffer("t", torch.zeros(1)) def forward(self): """Tests effective use of momentum for oscillating landscapes.""" self.t += 0.1 x = self.param - return (x.square() + 0.1*torch.sin(10*x)*torch.cos(self.t)).mean() + return (x.square() + 0.1 * torch.sin(10 * x) * torch.cos(self.t)).mean() @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), steps: int = 100, - weight_decay: float = 0, opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), - trials: int = 100, win_condition_multiplier: float = 1.0, ): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + steps: int = 100, + weight_decay: float = 0, + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), + trials: int = 100, + win_condition_multiplier: float = 1.0, +): dtype = [getattr(torch, d) for d in dtype] model = Model().cuda().double() def data(): return None, None - trial(model, data, None, loss_win_condition(win_condition_multiplier * 1e-6), steps, opt[0], dtype[0], 1, 1, - weight_decay, method[0], 1, 1, failure_threshold=3, base_lr=1e-3, trials=trials) + trial( + model, + data, + None, + loss_win_condition(win_condition_multiplier * 1e-6), + steps, + opt[0], + dtype[0], + 1, + 1, + weight_decay, + method[0], + 1, + 1, + failure_threshold=3, + base_lr=1e-3, + trials=trials, + ) -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/noisy_matmul.py b/benchmark/noisy_matmul.py index a7ba408..7a94123 100644 --- a/benchmark/noisy_matmul.py +++ b/benchmark/noisy_matmul.py @@ -1,4 +1,3 @@ -import itertools from typing import List import torch @@ -7,9 +6,8 @@ from torch import nn from torch.nn import functional as F -import heavyball +from benchmark.utils import param_norm_win_condition, trial from heavyball.utils import set_torch -from benchmark.utils import trial, param_norm_win_condition app = typer.Typer(pretty_exceptions_enable=False) set_torch() @@ -25,20 +23,21 @@ def forward(self, inp): y = None y0 = self.param.view(1, -1).expand(inp.size(0), -1) + self.offset # offset, so weight decay doesnt help for i in inp.unbind(1): - y = torch.einsum('bi,bik->bk', y0, i) + y = torch.einsum("bi,bik->bk", y0, i) y0 = F.leaky_relu(y, 0.1) return y + @app.command() def main( - method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = 
typer.Option(['float32'], help='Data type to use'), + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), size: int = 64, depth: int = 4, batch: int = 128, steps: int = 10, weight_decay: float = 0, - opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), win_condition_multiplier: float = 1.0, trials: int = 10, ): @@ -46,11 +45,28 @@ def main( model = Model(size).cuda() def data(): - inp = torch.randn((batch, depth, size, size), device='cuda', dtype=dtype[0]) / size ** 0.5 - return inp, torch.zeros((batch, size), device='cuda', dtype=dtype[0]) + inp = torch.randn((batch, depth, size, size), device="cuda", dtype=dtype[0]) / size**0.5 + return inp, torch.zeros((batch, size), device="cuda", dtype=dtype[0]) + + trial( + model, + data, + F.mse_loss, + param_norm_win_condition(1e-7 * win_condition_multiplier, model.offset), + steps, + opt[0], + dtype[0], + size, + batch, + weight_decay, + method[0], + 1, + depth, + failure_threshold=depth * 2, + base_lr=1e-3, + trials=trials, + ) - trial(model, data, F.mse_loss, param_norm_win_condition(1e-7 * win_condition_multiplier, model.offset), steps, opt[0], dtype[0], size, batch, weight_decay, method[0], 1, depth, - failure_threshold=depth * 2, base_lr=1e-3, trials=trials) -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/parameter_scale.py b/benchmark/parameter_scale.py index 2158e31..1ef2113 100644 --- a/benchmark/parameter_scale.py +++ b/benchmark/parameter_scale.py @@ -5,7 +5,7 @@ import typer from torch import nn -from benchmark.utils import trial, loss_win_condition +from benchmark.utils import loss_win_condition, trial from heavyball.utils import set_torch app = typer.Typer(pretty_exceptions_enable=False) @@ -17,22 +17,25 @@ def __init__(self, size=1024): super().__init__() # Simulate different layer scales in deep networks self.layer1 = nn.Parameter(torch.randn(size) * 1e-3) # Small gradients - self.layer2 = nn.Parameter(torch.randn(size)) # Medium gradients - self.layer3 = nn.Parameter(torch.randn(size) * 1e3) # Large gradients + self.layer2 = nn.Parameter(torch.randn(size)) # Medium gradients + self.layer3 = nn.Parameter(torch.randn(size) * 1e3) # Large gradients def forward(self): """Test optimizer's ability to handle different gradient scales across layers.""" # Each layer contributes equally to the loss but has very different scales - return (self.layer1.square().mean() + - self.layer2.square().mean() + - self.layer3.square().mean()) / 3 + return (self.layer1.square().mean() + self.layer2.square().mean() + self.layer3.square().mean()) / 3 @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), steps: int = 100, - weight_decay: float = 0, opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), - trials: int = 100, win_condition_multiplier: float = 1.0, ): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + steps: int = 100, + weight_decay: float = 0, + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), + trials: int = 100, + win_condition_multiplier: float = 1.0, +): dtype = [getattr(torch, d) for d in dtype] model = 
Model().cuda().double() @@ -40,9 +43,25 @@ def data(): return None, None # More lenient win condition due to vastly different scales - trial(model, data, None, loss_win_condition(win_condition_multiplier * 1e-4), steps, opt[0], dtype[0], 1, 1, - weight_decay, method[0], 1, 1, failure_threshold=5, base_lr=1e-4, trials=trials) # Lower learning rate and more attempts + trial( + model, + data, + None, + loss_win_condition(win_condition_multiplier * 1e-4), + steps, + opt[0], + dtype[0], + 1, + 1, + weight_decay, + method[0], + 1, + 1, + failure_threshold=5, + base_lr=1e-4, + trials=trials, + ) # Lower learning rate and more attempts -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/plateau_navigation.py b/benchmark/plateau_navigation.py index f4f650f..a3ea142 100644 --- a/benchmark/plateau_navigation.py +++ b/benchmark/plateau_navigation.py @@ -1,3 +1,4 @@ +import math import pathlib import random from typing import List @@ -6,11 +7,10 @@ import torch import torch.backends.opt_einsum import typer -from utils import Plotter from torch import nn -import math +from utils import Plotter -from benchmark.utils import trial, loss_win_condition +from benchmark.utils import loss_win_condition, trial from heavyball.utils import set_torch app = typer.Typer(pretty_exceptions_enable=False) @@ -19,7 +19,7 @@ def objective(x, y, scale: float = 10): """Tests optimizer's ability to handle regions with very small gradients and sharp plateaus.""" - output = 1/(1 + torch.exp((x**2 + y**2 - 1) * -scale)) + output = 1 / (1 + torch.exp((x**2 + y**2 - 1) * -scale)) minimum = 1 / (1 + math.exp(scale)) return output - minimum # ensure the minimum is at 0 @@ -34,26 +34,36 @@ def forward(self): @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), steps: int = 100, - weight_decay: float = 0, opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), - show_image: bool = False, trials: int = 100, win_condition_multiplier: float = 1.0, ): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + steps: int = 100, + weight_decay: float = 0, + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), + show_image: bool = False, + trials: int = 100, + win_condition_multiplier: float = 1.0, +): dtype = [getattr(torch, d) for d in dtype] coords = (1.5, 1.5) # Start outside the plateau # Clean up old plots - for path in pathlib.Path('.').glob('plateau_navigation.png'): + for path in pathlib.Path(".").glob("plateau_navigation.png"): path.unlink() - img = None colors = list(matplotlib.colors.TABLEAU_COLORS.values()) - stride = max(1, steps // 20) rng = random.Random(0x1239121) rng.shuffle(colors) if show_image: - model = Plotter(lambda *x: objective(*x).log(), coords=coords, xlim=(-2, 2), ylim=(-2, 2), normalize=8, - after_step=torch.exp) + model = Plotter( + lambda *x: objective(*x).log(), + coords=coords, + xlim=(-2, 2), + ylim=(-2, 2), + normalize=8, + after_step=torch.exp, + ) else: model = Model(coords) model.double() @@ -61,9 +71,25 @@ def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to us def data(): return None, None - trial(model, data, None, loss_win_condition(win_condition_multiplier * 1e-4), steps, opt[0], dtype[0], 1, 1, - weight_decay, method[0], 1, 1, failure_threshold=3, 
base_lr=1e-3, trials=trials) - - -if __name__ == '__main__': + trial( + model, + data, + None, + loss_win_condition(win_condition_multiplier * 1e-4), + steps, + opt[0], + dtype[0], + 1, + 1, + weight_decay, + method[0], + 1, + 1, + failure_threshold=3, + base_lr=1e-3, + trials=trials, + ) + + +if __name__ == "__main__": app() diff --git a/benchmark/powers.py b/benchmark/powers.py index b48f67a..b330632 100644 --- a/benchmark/powers.py +++ b/benchmark/powers.py @@ -1,15 +1,12 @@ -import itertools from typing import List import torch import torch.backends.opt_einsum import torch.nn as nn import typer -from torch.nn import functional as F -import heavyball -from heavyball.utils import set_torch from benchmark.utils import loss_win_condition, trial +from heavyball.utils import set_torch app = typer.Typer(pretty_exceptions_enable=False) set_torch() @@ -20,7 +17,7 @@ def __init__(self, size, powers, target): super().__init__() self.target = target self.param = nn.Parameter(torch.rand(powers, size) * 2) - self.register_buffer('scale', torch.arange(powers).float().add(1)) + self.register_buffer("scale", torch.arange(powers).float().add(1)) def forward(self): x = self.param - self.target @@ -30,14 +27,14 @@ def forward(self): @app.command() def main( - method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), size: int = 64, powers: int = 8, steps: int = 10, target: float = 1.0, weight_decay: float = 0, - opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), win_condition_multiplier: float = 1.0, trials: int = 10, ): @@ -46,9 +43,26 @@ def main( def data(): return None, None - trial(model, data, None, loss_win_condition(win_condition_multiplier * 1e-8), steps, opt[0], dtype[0], 1, 1, weight_decay, method[0], 1, 1, - failure_threshold=3, base_lr=1e-3, trials=trials) + + trial( + model, + data, + None, + loss_win_condition(win_condition_multiplier * 1e-8), + steps, + opt[0], + dtype[0], + 1, + 1, + weight_decay, + method[0], + 1, + 1, + failure_threshold=3, + base_lr=1e-3, + trials=trials, + ) -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/powers_varying_target.py b/benchmark/powers_varying_target.py index 9efb22c..03c2777 100644 --- a/benchmark/powers_varying_target.py +++ b/benchmark/powers_varying_target.py @@ -1,15 +1,12 @@ -import itertools from typing import List import torch import torch.backends.opt_einsum import torch.nn as nn import typer -from torch.nn import functional as F -import heavyball -from heavyball.utils import set_torch from benchmark.utils import loss_win_condition, trial +from heavyball.utils import set_torch app = typer.Typer(pretty_exceptions_enable=False) set_torch() @@ -18,9 +15,11 @@ class Model(nn.Module): def __init__(self, size, powers, target_mult): super().__init__() - self.target = nn.Buffer(torch.arange(powers * size).view(size, powers).transpose(0, 1).float() * target_mult / powers / size) + self.target = nn.Buffer( + torch.arange(powers * size).view(size, powers).transpose(0, 1).float() * target_mult / powers / size + ) self.param = nn.Parameter(torch.rand(powers, size) * 2) - self.register_buffer('scale', torch.arange(powers).float().add(1)) + 
self.register_buffer("scale", torch.arange(powers).float().add(1)) def forward(self): x = self.param - self.target @@ -30,14 +29,14 @@ def forward(self): @app.command() def main( - method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), size: int = 64, powers: int = 8, steps: int = 10, target_mult: float = 1.0, weight_decay: float = 0, - opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), win_condition_multiplier: float = 1.0, trials: int = 10, ): @@ -46,9 +45,26 @@ def main( def data(): return None, None - trial(model, data, None, loss_win_condition(win_condition_multiplier * 1e-6), steps, opt[0], dtype[0], 1, 1, weight_decay, method[0], 1, 1, - failure_threshold=3, base_lr=1e-3, trials=trials) + + trial( + model, + data, + None, + loss_win_condition(win_condition_multiplier * 1e-6), + steps, + opt[0], + dtype[0], + 1, + 1, + weight_decay, + method[0], + 1, + 1, + failure_threshold=3, + base_lr=1e-3, + trials=trials, + ) -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/quadratic_varying_scale.py b/benchmark/quadratic_varying_scale.py index bc2da4d..d451b4b 100644 --- a/benchmark/quadratic_varying_scale.py +++ b/benchmark/quadratic_varying_scale.py @@ -1,11 +1,10 @@ -import itertools from typing import List -import heavyball import torch import torch.nn as nn import torch.nn.functional as F import typer + from benchmark.utils import param_norm_win_condition, trial from heavyball.utils import set_torch @@ -17,19 +16,23 @@ class Model(nn.Module): def __init__(self, size): super().__init__() self.param = nn.Parameter(torch.randn(size)) - self.register_buffer('scale', F.normalize(torch.arange(1, 1 + size).float(), dim=0, p=1)) + self.register_buffer("scale", F.normalize(torch.arange(1, 1 + size).float(), dim=0, p=1)) def forward(self): return self.param.square() @ self.scale @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), size: int = 1024, batch: int = 256, - steps: int = 100, weight_decay: float = 0, - opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), trials: int = 10, - win_condition_multiplier: float = 1.0, - +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + size: int = 1024, + batch: int = 256, + steps: int = 100, + weight_decay: float = 0, + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), + trials: int = 10, + win_condition_multiplier: float = 1.0, ): dtype = [getattr(torch, d) for d in dtype] model = Model(size).cuda() @@ -37,9 +40,25 @@ def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to us def data(): return None, None - trial(model, data, None, param_norm_win_condition(win_condition_multiplier * 1e-7, 0), steps, opt[0], dtype[0], - size, batch, weight_decay, method[0], 1, 1, failure_threshold=2, base_lr=1e-3, trials=trials) + trial( + model, + data, + None, + param_norm_win_condition(win_condition_multiplier * 1e-7, 0), + steps, + opt[0], + dtype[0], + size, + 
batch, + weight_decay, + method[0], + 1, + 1, + failure_threshold=2, + base_lr=1e-3, + trials=trials, + ) -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/quadratic_varying_target.py b/benchmark/quadratic_varying_target.py index fc6873a..96458ec 100644 --- a/benchmark/quadratic_varying_target.py +++ b/benchmark/quadratic_varying_target.py @@ -1,4 +1,3 @@ -import itertools from typing import List import torch @@ -6,31 +5,32 @@ import torch.nn.functional as F import typer -import heavyball -from heavyball.utils import set_torch from benchmark.utils import param_norm_win_condition, trial +from heavyball.utils import set_torch app = typer.Typer(pretty_exceptions_enable=False) set_torch() + class Model(nn.Module): def __init__(self, size): super().__init__() self.param = nn.Parameter(torch.randn(size)) - self.register_buffer('target', F.normalize(torch.arange(size).float().add(1).square(), dim=0)) + self.register_buffer("target", F.normalize(torch.arange(size).float().add(1).square(), dim=0)) def forward(self): return (self.param - self.target).square().mean() + @app.command() def main( - method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), size: int = 1024, batch: int = 256, steps: int = 100, weight_decay: float = 0, - opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), trials: int = 10, win_condition_multiplier: float = 1.0, ): @@ -40,9 +40,25 @@ def main( def data(): return None, None + trial( + model, + data, + None, + param_norm_win_condition(win_condition_multiplier * 1e-8, -model.target), + steps, + opt[0], + dtype[0], + size, + batch, + weight_decay, + method[0], + 1, + 1, + failure_threshold=2, + base_lr=1e-3, + trials=trials, + ) - trial(model, data, None, param_norm_win_condition(win_condition_multiplier * 1e-8, -model.target), steps, opt[0], dtype[0], size, batch, weight_decay, method[0], 1, 1, - failure_threshold=2, base_lr=1e-3, trials=trials) -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/rastrigin.py b/benchmark/rastrigin.py index 623a550..a66aff9 100644 --- a/benchmark/rastrigin.py +++ b/benchmark/rastrigin.py @@ -1,9 +1,7 @@ -import copy +import math import pathlib import random -import time from typing import List -import math import matplotlib.colors import matplotlib.pyplot as plt @@ -11,10 +9,10 @@ import torch.backends.opt_einsum import typer from hyperopt import early_stop -from utils import Plotter from torch import nn +from utils import Plotter -from benchmark.utils import trial, loss_win_condition +from benchmark.utils import loss_win_condition, trial from heavyball.utils import set_torch early_stop.no_progress_loss() @@ -23,7 +21,8 @@ def _formula(x, A): - return x ** 2 + A * (1 - torch.cos(2 * math.pi * x)) + return x**2 + A * (1 - torch.cos(2 * math.pi * x)) + def objective(*args, A=10): if len(args) == 1: @@ -31,6 +30,7 @@ def objective(*args, A=10): return sum(_formula(x, A) for x in args) / len(args) + class Model(nn.Module): def __init__(self, x): super().__init__() @@ -41,17 +41,24 @@ def forward(self): @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: 
List[str] = typer.Option(['float32'], help='Data type to use'), steps: int = 100, - weight_decay: float = 0, opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), - show_image: bool = False, trials: int = 100, win_condition_multiplier: float = 1.0, size: int = 2): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + steps: int = 100, + weight_decay: float = 0, + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), + show_image: bool = False, + trials: int = 100, + win_condition_multiplier: float = 1.0, + size: int = 2, +): if show_image: assert size == 2, "Image can only be displayed for 2D functions" dtype = [getattr(torch, d) for d in dtype] coords = (-2.2,) * size # Clean up old plots - for path in pathlib.Path('.').glob('rastrigin.png'): + for path in pathlib.Path(".").glob("rastrigin.png"): path.unlink() colors = list(matplotlib.colors.TABLEAU_COLORS.values()) @@ -60,8 +67,14 @@ def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to us rng.shuffle(colors) if show_image: - model = Plotter(lambda *x: objective(*x).log(), coords=coords, xlim=(-8, 2), ylim=(-8, 2), normalize=8, - after_step=torch.exp) + model = Plotter( + lambda *x: objective(*x).log(), + coords=coords, + xlim=(-8, 2), + ylim=(-8, 2), + normalize=8, + after_step=torch.exp, + ) else: model = Model(coords) model.double() @@ -69,9 +82,24 @@ def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to us def data(): return None, None - model = trial(model, data, None, loss_win_condition(win_condition_multiplier * 1e-2 * (not show_image)), steps, - opt[0], dtype[0], 1, 1, weight_decay, method[0], 1, 1, base_lr=1e-4, trials=trials, - return_best=show_image) + model = trial( + model, + data, + None, + loss_win_condition(win_condition_multiplier * 1e-2 * (not show_image)), + steps, + opt[0], + dtype[0], + 1, + 1, + weight_decay, + method[0], + 1, + 1, + base_lr=1e-4, + trials=trials, + return_best=show_image, + ) if not show_image: return @@ -80,14 +108,20 @@ def data(): ax.set_frame_on(False) c = colors[0] - ax.plot(*list(zip(*model.coords_history)), linewidth=1, color=c, zorder=2, label=f'{method[0]} {opt[0]}') - ax.scatter(*list(zip(*model.coords_history[::stride])), s=8, zorder=1, alpha=0.75, marker='x', color=c) - ax.scatter(*model.coords_history[-1], s=64, zorder=3, marker='x', color=c) + ax.plot( + *list(zip(*model.coords_history)), + linewidth=1, + color=c, + zorder=2, + label=f"{method[0]} {opt[0]}", + ) + ax.scatter(*list(zip(*model.coords_history[::stride])), s=8, zorder=1, alpha=0.75, marker="x", color=c) + ax.scatter(*model.coords_history[-1], s=64, zorder=3, marker="x", color=c) fig.legend() - fig.savefig('rastrigin.png', dpi=1000) + fig.savefig("rastrigin.png", dpi=1000) plt.close(fig) -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/relu_boundaries.py b/benchmark/relu_boundaries.py index 6f14e03..f742752 100644 --- a/benchmark/relu_boundaries.py +++ b/benchmark/relu_boundaries.py @@ -1,107 +1,111 @@ +import subprocess +from abc import ABC, abstractmethod +from pathlib import Path + +import imageio +import matplotlib.pyplot as plt +import numpy as np import torch import torch.nn as nn -import numpy as np -import matplotlib.pyplot as plt -from matplotlib.animation import FuncAnimation -import os -import imageio -import heavyball import typer -from typing import List, Optional, 
Literal -from pathlib import Path -import subprocess +from torch.utils.data import DataLoader + +import heavyball from benchmark.utils import get_optim -from torch.utils.data import TensorDataset, DataLoader from heavyball.utils import set_torch -from abc import ABC, abstractmethod app = typer.Typer(pretty_exceptions_enable=False) set_torch() + class BaseDataset(torch.utils.data.Dataset, ABC): classes = 2 + def __init__(self, n_samples: int, batch_size: int, seed: int = 42): self.n_samples = n_samples self.batch_size = batch_size self.seed = seed self.X, self.y = self._generate_data() self.indices = torch.arange(len(self.X)) - + @abstractmethod def _generate_data(self): pass def __len__(self): return self.n_samples // self.batch_size - + def __getitem__(self, idx): - batch_idx = torch.randperm(len(self.indices))[:self.batch_size] + batch_idx = torch.randperm(len(self.indices))[: self.batch_size] return self.X[batch_idx], self.y[batch_idx] - + def get_full_data(self): return self.X, self.y + class CircleDataset(BaseDataset): def _generate_data(self): generator = torch.Generator() generator.manual_seed(self.seed) - - r1 = torch.normal(mean=2.0, std=0.2, size=(self.n_samples//2,), generator=generator) - theta1 = torch.rand(self.n_samples//2, generator=generator) * 2 * np.pi - - r2 = torch.normal(mean=4.0, std=0.2, size=(self.n_samples//2,), generator=generator) - theta2 = torch.rand(self.n_samples//2, generator=generator) * 2 * np.pi - + + r1 = torch.normal(mean=2.0, std=0.2, size=(self.n_samples // 2,), generator=generator) + theta1 = torch.rand(self.n_samples // 2, generator=generator) * 2 * np.pi + + r2 = torch.normal(mean=4.0, std=0.2, size=(self.n_samples // 2,), generator=generator) + theta2 = torch.rand(self.n_samples // 2, generator=generator) * 2 * np.pi + x1 = torch.stack([r1 * torch.cos(theta1), r1 * torch.sin(theta1)], dim=1) x2 = torch.stack([r2 * torch.cos(theta2), r2 * torch.sin(theta2)], dim=1) - + X = torch.cat([x1, x2], dim=0).float() - y = torch.cat([torch.zeros(self.n_samples//2), - torch.ones(self.n_samples//2)]).reshape(-1).long() - + y = torch.cat([torch.zeros(self.n_samples // 2), torch.ones(self.n_samples // 2)]).reshape(-1).long() + return X, y + class ModularAdditionDataset(BaseDataset): def __init__(self, n_samples: int, batch_size: int, modulo: int = 11, seed: int = 42): self.modulo = modulo super().__init__(n_samples, batch_size, seed) self.classes = modulo - + def _generate_data(self): generator = torch.Generator() generator.manual_seed(self.seed) - + x1 = torch.randint(0, self.modulo, (self.n_samples,), generator=generator).float() x2 = torch.randint(0, self.modulo, (self.n_samples,), generator=generator).float() - + X = torch.stack([x1, x2], dim=1) / (self.modulo - 1) - + y = ((x1 + x2) % self.modulo).long().reshape(-1) - + return X, y + class XORDataset(BaseDataset): def _generate_data(self): generator = torch.Generator() generator.manual_seed(self.seed) - + x1 = torch.randint(0, 2, (self.n_samples,), generator=generator).float() x2 = torch.randint(0, 2, (self.n_samples,), generator=generator).float() - + noise = torch.normal(0, 0.1, (self.n_samples, 2), generator=generator) X = torch.stack([x1, x2], dim=1) + noise - + y = (x1 != x2).reshape(-1).long() - + return X, y + class SimpleMLP(nn.Module): def __init__(self, hidden_size=32, classes=2): super().__init__() self.fc1 = nn.Linear(2, hidden_size) self.fc2 = nn.Linear(hidden_size, hidden_size) self.fc3 = nn.Linear(hidden_size, classes) - + def forward(self, x): x1 = self.fc1(x) r1 = torch.relu(x1) @@ 
-109,102 +113,122 @@ def forward(self, x): r2 = torch.relu(x2) out = self.fc3(r2) return out - + def get_boundaries(self, x): x1 = self.fc1(x) # First layer pre-activations r1 = torch.relu(x1) x2 = self.fc2(r1) # Second layer pre-activations return x1, x2 -def plot_decision_boundary(model, loader, ax, resolution, device='cuda'): + +def plot_decision_boundary(model, loader, ax, resolution, device="cuda"): model.eval() - + X, y = loader.dataset.get_full_data() - + margin = 1.0 x_min, x_max = X[:, 0].min() - margin, X[:, 0].max() + margin y_min, y_max = X[:, 1].min() - margin, X[:, 1].max() + margin - - xx, yy = np.meshgrid(np.linspace(x_min, x_max, resolution), - np.linspace(y_min, y_max, resolution)) - + + xx, yy = np.meshgrid(np.linspace(x_min, x_max, resolution), np.linspace(y_min, y_max, resolution)) + grid = torch.FloatTensor(np.c_[xx.ravel(), yy.ravel()]).to(device) - - batch_size = 2 ** 14 + + batch_size = 2**14 boundaries1 = [] boundaries2 = [] - + with torch.no_grad(): for i in range(0, len(grid), batch_size): - batch = grid[i:i+batch_size] + batch = grid[i : i + batch_size] b1, b2 = model.get_boundaries(batch) boundaries1.append(b1.cpu()) boundaries2.append(b2.cpu()) - + boundaries1 = torch.cat(boundaries1, dim=0).numpy() boundaries2 = torch.cat(boundaries2, dim=0).numpy() - - scatter = ax.scatter(X[:, 0], X[:, 1], c=y.squeeze(), - cmap=plt.cm.RdYlBu, alpha=0.6, label='Training Data') - + + ax.scatter(X[:, 0], X[:, 1], c=y.squeeze(), cmap=plt.cm.RdYlBu, alpha=0.6, label="Training Data") + # First layer (blue) for i in range(boundaries1.shape[1]): values = boundaries1[:, i].reshape(xx.shape) - ax.contour(xx, yy, values, levels=[0], colors=['#0066CC'], - alpha=0.4, linewidths=1.0, - label='Layer 1 ReLU' if i == 0 else None) - + ax.contour( + xx, + yy, + values, + levels=[0], + colors=["#0066CC"], + alpha=0.4, + linewidths=1.0, + label="Layer 1 ReLU" if i == 0 else None, + ) + # Second layer (red) for i in range(boundaries2.shape[1]): values = boundaries2[:, i].reshape(xx.shape) - ax.contour(xx, yy, values, levels=[0], colors=['#CC0000'], - alpha=0.4, linewidths=1.0, - label='Layer 2 ReLU' if i == 0 else None) - + ax.contour( + xx, + yy, + values, + levels=[0], + colors=["#CC0000"], + alpha=0.4, + linewidths=1.0, + label="Layer 2 ReLU" if i == 0 else None, + ) + ax.set_xlim([x_min, x_max]) ax.set_ylim([y_min, y_max]) - ax.legend(loc='upper right') - - ax.text(0.02, 0.98, 'Blue: Layer 1 ReLU boundaries\nRed: Layer 2 ReLU boundaries', - transform=ax.transAxes, fontsize=8, verticalalignment='top', - bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)) + ax.legend(loc="upper right") + + ax.text( + 0.02, + 0.98, + "Blue: Layer 1 ReLU boundaries\nRed: Layer 2 ReLU boundaries", + transform=ax.transAxes, + fontsize=8, + verticalalignment="top", + bbox=dict(boxstyle="round", facecolor="white", alpha=0.8), + ) + @app.command() def main( - hidden_size: int = typer.Option(32, help='Number of neurons in hidden layers'), - n_samples: int = typer.Option(2 ** 14, help='Number of training samples'), - learning_rate: float = typer.Option(0.001, help='Learning rate for the optimizer'), - n_epochs: int = typer.Option(10, help='Number of training epochs'), - batch_size: int = typer.Option(128, help='Batch size for training'), - n_frames: int = typer.Option(10, help='Number of frames to generate'), - output_file: str = typer.Option('relu_boundaries.mp4', help='Output filename'), - output_format: str = typer.Option('mp4', help='Output format: mp4 or gif'), - fps: int = typer.Option(10, help='Frames 
per second in output video'), - seed: int = typer.Option(42, help='Random seed for reproducibility'), - optimizer: str = typer.Option('PrecondScheduleForeachSOAP', help=f'Optimizer to use'), - weight_decay: float = typer.Option(0.0, help='Weight decay for the optimizer'), - beta1: float = typer.Option(0.9, help='Beta1 parameter for Adam-like optimizers'), - beta2: float = typer.Option(0.999, help='Beta2 parameter for Adam-like optimizers'), - resolution: int = typer.Option(32, help='Resolution of the decision boundary plot'), - dataset: str = typer.Option('circle', help='Dataset to use: circle, modular, xor'), - modulo: int = typer.Option(11, help='Modulo for modular addition dataset') + hidden_size: int = typer.Option(32, help="Number of neurons in hidden layers"), + n_samples: int = typer.Option(2**14, help="Number of training samples"), + learning_rate: float = typer.Option(0.001, help="Learning rate for the optimizer"), + n_epochs: int = typer.Option(10, help="Number of training epochs"), + batch_size: int = typer.Option(128, help="Batch size for training"), + n_frames: int = typer.Option(10, help="Number of frames to generate"), + output_file: str = typer.Option("relu_boundaries.mp4", help="Output filename"), + output_format: str = typer.Option("mp4", help="Output format: mp4 or gif"), + fps: int = typer.Option(10, help="Frames per second in output video"), + seed: int = typer.Option(42, help="Random seed for reproducibility"), + optimizer: str = typer.Option("PrecondScheduleForeachSOAP", help="Optimizer to use"), + weight_decay: float = typer.Option(0.0, help="Weight decay for the optimizer"), + beta1: float = typer.Option(0.9, help="Beta1 parameter for Adam-like optimizers"), + beta2: float = typer.Option(0.999, help="Beta2 parameter for Adam-like optimizers"), + resolution: int = typer.Option(32, help="Resolution of the decision boundary plot"), + dataset: str = typer.Option("circle", help="Dataset to use: circle, modular, xor"), + modulo: int = typer.Option(11, help="Modulo for modular addition dataset"), ): torch.manual_seed(seed) np.random.seed(seed) torch.backends.cudnn.benchmark = True - - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") - + dataset_classes = { - 'circle': lambda: CircleDataset(n_samples, batch_size, seed), - 'modular': lambda: ModularAdditionDataset(n_samples, batch_size, modulo, seed), - 'xor': lambda: XORDataset(n_samples, batch_size, seed) + "circle": lambda: CircleDataset(n_samples, batch_size, seed), + "modular": lambda: ModularAdditionDataset(n_samples, batch_size, modulo, seed), + "xor": lambda: XORDataset(n_samples, batch_size, seed), } - + if dataset not in dataset_classes: raise ValueError(f"Unknown dataset: {dataset}. 
Choose from {list(dataset_classes.keys())}") - + train_data = dataset_classes[dataset]() train_loader = DataLoader( train_data, @@ -213,91 +237,99 @@ def main( pin_memory=True, num_workers=4, prefetch_factor=2, - persistent_workers=True + persistent_workers=True, ) - + model = SimpleMLP(hidden_size=hidden_size, classes=train_data.classes).to(device) - model = torch.compile(model, mode='max-autotune-no-cudagraphs') - + model = torch.compile(model, mode="max-autotune-no-cudagraphs") + optimizer_class = getattr(heavyball, optimizer) optimizer = get_optim( optimizer_class, model.parameters(), lr=learning_rate, weight_decay=weight_decay, - betas=(beta1, beta2) + betas=(beta1, beta2), ) criterion = nn.CrossEntropyLoss() - + plt.ioff() - fig, ax = plt.subplots(figsize=(10, 10)) + _fig, ax = plt.subplots(figsize=(10, 10)) frames = [] frame_files = [] frame_count = 0 - + log_space = np.logspace(0, np.log10(n_epochs + 1), n_frames) - 1 frame_epochs = np.unique(log_space.astype(int)) print(f"Will capture frames at epochs: {frame_epochs.tolist()}") - - output_dir = Path('frames') + + output_dir = Path("frames") output_dir.mkdir(exist_ok=True) - for f in output_dir.glob('*.png'): + for f in output_dir.glob("*.png"): f.unlink() - + train_iter = iter(train_loader) for epoch in range(n_epochs): model.train() - + try: x, y = next(train_iter) except StopIteration: train_iter = iter(train_loader) x, y = next(train_iter) - + x, y = x.squeeze(0).to(device), y.squeeze(0).to(device) - + optimizer.zero_grad() outputs = model(x) loss = criterion(outputs, y) loss.backward() optimizer.step() - + if epoch in frame_epochs: ax.clear() plot_decision_boundary(model, train_loader, ax, resolution, device=device) - ax.set_title(f'Epoch {epoch}\nLoss: {loss.item():.4f}') + ax.set_title(f"Epoch {epoch}\nLoss: {loss.item():.4f}") plt.tight_layout() - - frame_file = output_dir / f'frame_{frame_count:05d}.png' + + frame_file = output_dir / f"frame_{frame_count:05d}.png" frame_files.append(frame_file) - plt.savefig(frame_file, dpi=100, bbox_inches='tight') + plt.savefig(frame_file, dpi=100, bbox_inches="tight") frame_count += 1 - - if output_format == 'gif': + + if output_format == "gif": frames.append(imageio.imread(frame_file)) - + torch.cuda.synchronize() # Ensure GPU operations are complete - + if epoch % (n_epochs // 10) == 0: - print(f'Epoch {epoch}, Loss: {loss.item():.4f}') - + print(f"Epoch {epoch}, Loss: {loss.item():.4f}") + print(f"Generated {frame_count} frames") - + output_path = Path(output_file) if output_format != output_path.suffix[1:]: - output_path = output_path.with_suffix(f'.{output_format}') - - if output_format == 'mp4': + output_path = output_path.with_suffix(f".{output_format}") + + if output_format == "mp4": cmd = [ - 'ffmpeg', '-y', - '-framerate', str(fps), - '-i', str(output_dir / 'frame_%05d.png'), - '-c:v', 'libx264', - '-preset', 'medium', - '-crf', '23', - '-pix_fmt', 'yuv420p', - '-vf', 'pad=ceil(iw/2)*2:ceil(ih/2)*2', - str(output_path) + "ffmpeg", + "-y", + "-framerate", + str(fps), + "-i", + str(output_dir / "frame_%05d.png"), + "-c:v", + "libx264", + "-preset", + "medium", + "-crf", + "23", + "-pix_fmt", + "yuv420p", + "-vf", + "pad=ceil(iw/2)*2:ceil(ih/2)*2", + str(output_path), ] print(f"Running FFmpeg command: {' '.join(cmd)}") result = subprocess.run(cmd, capture_output=True, text=True) @@ -307,9 +339,9 @@ def main( raise RuntimeError("FFmpeg failed to create video") else: imageio.mimsave(output_path, frames, fps=fps) - + # Clean up - plt.close('all') + plt.close("all") for frame_file in 
frame_files: try: frame_file.unlink() @@ -319,8 +351,9 @@ def main( output_dir.rmdir() except OSError: pass - + print(f"Animation saved as {output_path}") + if __name__ == "__main__": app() diff --git a/benchmark/rosenbrock.py b/benchmark/rosenbrock.py index 296b12c..bf7ce45 100644 --- a/benchmark/rosenbrock.py +++ b/benchmark/rosenbrock.py @@ -9,8 +9,7 @@ from hyperopt import early_stop from torch import nn -from benchmark.utils import Plotter -from benchmark.utils import trial, loss_win_condition +from benchmark.utils import Plotter, loss_win_condition, trial from heavyball.utils import set_torch early_stop.no_progress_loss() @@ -19,7 +18,7 @@ def objective(x, y): - return (1 - x) ** 2 + 1 * (y - x ** 2) ** 2 + return (1 - x) ** 2 + 1 * (y - x**2) ** 2 class Model(nn.Module): @@ -32,19 +31,24 @@ def forward(self): @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), steps: int = 100, - weight_decay: float = 0, opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), - show_image: bool = False, trials: int = 100, win_condition_multiplier: float = 1.0, ): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + steps: int = 100, + weight_decay: float = 0, + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), + show_image: bool = False, + trials: int = 100, + win_condition_multiplier: float = 1.0, +): dtype = [getattr(torch, d) for d in dtype] coords = (-7, -4) # Clean up old plots - for path in pathlib.Path('.').glob('rosenbrock.png'): + for path in pathlib.Path(".").glob("rosenbrock.png"): path.unlink() colors = list(matplotlib.colors.TABLEAU_COLORS.values()) - stride = max(1, steps // 20) rng = random.Random(0x1239121) rng.shuffle(colors) @@ -57,15 +61,30 @@ def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to us def data(): return None, None - model = trial(model, data, None, loss_win_condition(win_condition_multiplier * 1e-9 * (not show_image)), steps, - opt[0], dtype[0], 1, 1, weight_decay, method[0], 1, 1, base_lr=1e-4, trials=trials, - return_best=show_image) + model = trial( + model, + data, + None, + loss_win_condition(win_condition_multiplier * 1e-9 * (not show_image)), + steps, + opt[0], + dtype[0], + 1, + 1, + weight_decay, + method[0], + 1, + 1, + base_lr=1e-4, + trials=trials, + return_best=show_image, + ) if not show_image: return - model.plot(save_path='rosenbrock.png') + model.plot(save_path="rosenbrock.png") -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/run_all_benchmarks.py b/benchmark/run_all_benchmarks.py index ae503ac..a7ffdc7 100644 --- a/benchmark/run_all_benchmarks.py +++ b/benchmark/run_all_benchmarks.py @@ -19,39 +19,39 @@ def last_match(pattern, text): def run_benchmark(script, opt, steps, dtype, trials): - base = {'name': script.replace('.py', ''), 'opt': opt} + base = {"name": script.replace(".py", ""), "opt": opt} - import sys import io - import time import pathlib + import sys + import time sys.path.append(str(pathlib.Path(__file__).parent.resolve())) stdout = sys.stdout sys.stdout = io.StringIO() start_time = time.time() try: - module_name = script.replace('.py', '') + module_name = script.replace(".py", "") module = __import__(module_name) # Build arguments arguments = { - 'method': ['qr'], - 'dtype': [dtype], 
- 'steps': steps, - 'weight_decay': 0, - 'opt': [opt], - 'trials': trials, - 'win_condition_multiplier': 1.0, + "method": ["qr"], + "dtype": [dtype], + "steps": steps, + "weight_decay": 0, + "opt": [opt], + "trials": trials, + "win_condition_multiplier": 1.0, } # Run the main function module.main(**arguments) - except Exception as e: + except Exception: output = sys.stdout.getvalue() error = traceback.format_exc() else: output = sys.stdout.getvalue() - error = '' + error = "" finally: sys.stdout = stdout @@ -60,46 +60,49 @@ def run_benchmark(script, opt, steps, dtype, trials): runtime = last_match(r"Took: ([0-9.]+)", output) loss = last_match(r"Best Loss: ([0-9.e\-+]+)", output) - attempts = int(last_match(r'Attempt: ([0-9]+)', output) or trials) + attempts = int(last_match(r"Attempt: ([0-9]+)", output) or trials) total_runtime = time.time() - start_time - return {**base, 'success': success, - 'runtime': float(runtime or total_runtime), - 'loss': float(loss) if loss else float('inf'), - 'attempts': attempts, - 'error': error if error else ''} + return { + **base, + "success": success, + "runtime": float(runtime or total_runtime), + "loss": float(loss) if loss else float("inf"), + "attempts": attempts, + "error": error if error else "", + } def opt_to_config(opt): caution = "No" mars = "No" - if opt.startswith('cautious-'): - opt = opt[len('cautious-'):] + if opt.startswith("cautious-"): + opt = opt[len("cautious-") :] caution = "Yes" - if opt.startswith('unscaled_cautious-'): - opt = opt[len('unscaled_cautious-'):] + if opt.startswith("unscaled_cautious-"): + opt = opt[len("unscaled_cautious-") :] caution = "Unscaled" - if opt.startswith('mars-'): - opt = opt[len('mars-'):] - mars = 'Yes' + if opt.startswith("mars-"): + opt = opt[len("mars-") :] + mars = "Yes" return opt, caution, mars def write_progress(results, opt, output): - with open(output, 'w') as f: + with open(output, "w") as f: f.write(f"# Benchmark Results\nGenerated: {datetime.now()}\nLast updated: {datetime.now()}\n\n") f.write("## Summary (In Progress)\n\n") f.write("| Optimizer | Caution | Mars | Success | Runtime | Average Attempts |\n") f.write("|-----------|---|---|---------|----------|------|\n") for o in opt: - opt_results = [r for r in results if r['opt'] == o] + opt_results = [r for r in results if r["opt"] == o] if not opt_results: continue - success = sum(r['success'] for r in opt_results) - runtime = np.mean([r['runtime'] for r in opt_results if r['success']]) if success else 0 - attempts = np.mean([r['attempts'] for r in opt_results if r['success']]) if success else 0 + success = sum(r["success"] for r in opt_results) + runtime = np.mean([r["runtime"] for r in opt_results if r["success"]]) if success else 0 + attempts = np.mean([r["attempts"] for r in opt_results if r["success"]]) if success else 0 o, caution, mars = opt_to_config(o) f.write(f"| {o} | {caution} | {mars} | {success}/{len(opt_results)} | {runtime:.2f}s | {attempts:.1f} |\n") @@ -108,61 +111,70 @@ def write_progress(results, opt, output): f.write("| Benchmark | Optimizer | Cautious | Mars | Success | Runtime | Loss | Attempts | \n") f.write("|-----------|-----------|---------|---|---|----------|------|---|\n") - for r in sorted(results, key=lambda x: (x['name'], x['opt'])): - mark = "✓" if r['success'] else "✗" + for r in sorted(results, key=lambda x: (x["name"], x["opt"])): + mark = "✓" if r["success"] else "✗" runtime = f"{r['runtime']:.2f}s" loss = f"{r['loss']:.2e}" attempts = f"{r['attempts']:d}" - opt, caution, mars = opt_to_config(r['opt']) + 
opt, caution, mars = opt_to_config(r["opt"]) f.write(f"| {r['name']} | {opt} | {caution} | {mars} | {mark} | {runtime} | {loss} | {attempts} | \n") - if any(not r['success'] for r in results): + if any(not r["success"] for r in results): f.write("\n## Errors\n\n") - for r in sorted(results, key=lambda x: (x['name'], x['opt'])): - if not r['success'] and r['error']: + for r in sorted(results, key=lambda x: (x["name"], x["opt"])): + if not r["success"] and r["error"]: f.write(f"\n### {r['name']} - {r['opt']}\n```\n{r['error']}\n```\n") @app.command() -def main(opt: list[str] = typer.Option([], help='Optimizers'), steps: int = 100_000, timeout: int = 3600 * 4, - output: str = 'benchmark_results.md', trials: int = 1000, dtype: str = 'float32', parallelism: int = 16, - caution: bool = False, mars: bool = False, unscaled_caution: bool = False): +def main( + opt: list[str] = typer.Option([], help="Optimizers"), + steps: int = 100_000, + timeout: int = 3600 * 4, + output: str = "benchmark_results.md", + trials: int = 1000, + dtype: str = "float32", + parallelism: int = 16, + caution: bool = False, + mars: bool = False, + unscaled_caution: bool = False, +): benchmarks = [ - 'beale.py', - 'rosenbrock.py', - 'rastrigin.py', - 'quadratic_varying_scale.py', - 'quadratic_varying_target.py', - 'noisy_matmul.py', - 'xor_sequence.py', - 'xor_digit.py', - 'xor_spot.py', - 'saddle_point.py', - 'saddle_point_0init.py', - 'discontinuous_gradient.py', - 'wide_linear.py', - 'minimax.py', - 'plateau_navigation.py', - 'scale_invariant.py', - 'momentum_utilization.py', - 'batch_size_scaling.py', - 'sparse_gradient.py', - 'layer_wise_scale.py', - 'parameter_scale.py', - 'gradient_delay.py', - 'gradient_noise_scale.py', - 'adversarial_gradient.py', - 'dynamic_landscape.py', - 'exploding_gradient.py' + "beale.py", + "rosenbrock.py", + "rastrigin.py", + "quadratic_varying_scale.py", + "quadratic_varying_target.py", + "noisy_matmul.py", + "xor_sequence.py", + "xor_digit.py", + "xor_spot.py", + "saddle_point.py", + "saddle_point_0init.py", + "discontinuous_gradient.py", + "wide_linear.py", + "minimax.py", + "plateau_navigation.py", + "scale_invariant.py", + "momentum_utilization.py", + "batch_size_scaling.py", + "sparse_gradient.py", + "layer_wise_scale.py", + "parameter_scale.py", + "gradient_delay.py", + "gradient_noise_scale.py", + "adversarial_gradient.py", + "dynamic_landscape.py", + "exploding_gradient.py", ] if mars: - opt = ['mars-' + o for o in opt] + opt = ["mars-" + o for o in opt] if caution: - opt = ['cautious-' + o for o in opt] + opt = ["cautious-" + o for o in opt] if unscaled_caution: - opt = ['unscaled_cautious-' + o for o in opt] + opt = ["unscaled_cautious-" + o for o in opt] # Create task queue and result queue task_queue = multiprocessing.Queue() @@ -182,16 +194,16 @@ def worker(task_queue, result_queue): result = run_benchmark(script, o, steps, dtype, trials) except Exception as exc: result = { - 'name': script.replace('.py', ''), - 'opt': o, - 'success': False, - 'runtime': None, - 'attempts': 0, - 'loss': float('inf'), - 'error': str(exc), + "name": script.replace(".py", ""), + "opt": o, + "success": False, + "runtime": None, + "attempts": 0, + "loss": float("inf"), + "error": str(exc), } result_queue.put(result) - except: + except Exception: break # Start worker processes @@ -211,7 +223,8 @@ def worker(task_queue, result_queue): results.append(result) completed += 1 print( - f"Progress: [{completed}/{total_tasks}] {result['name']}.py - {result['opt']}: {'✓' if result['success'] else '✗'}") + 
f"Progress: [{completed}/{total_tasks}] {result['name']}.py - {result['opt']}: {'✓' if result['success'] else '✗'}" + ) write_progress(results, opt, output) except KeyboardInterrupt: @@ -226,5 +239,5 @@ def worker(task_queue, result_queue): write_progress(results, opt, output) -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/saddle_point.py b/benchmark/saddle_point.py index 488de6d..a99e4a5 100644 --- a/benchmark/saddle_point.py +++ b/benchmark/saddle_point.py @@ -6,10 +6,10 @@ import torch import torch.backends.opt_einsum import typer -from utils import Plotter from torch import nn +from utils import Plotter -from benchmark.utils import trial, loss_win_condition +from benchmark.utils import loss_win_condition, trial from heavyball.utils import set_torch app = typer.Typer(pretty_exceptions_enable=False) @@ -32,28 +32,38 @@ def forward(self): @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), steps: int = 100, - weight_decay: float = 0, opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), - show_image: bool = False, trials: int = 100, win_condition_multiplier: float = 1.0, ): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + steps: int = 100, + weight_decay: float = 0, + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), + show_image: bool = False, + trials: int = 100, + win_condition_multiplier: float = 1.0, +): dtype = [getattr(torch, d) for d in dtype] coords = (1e-6, 1e-6) # Start near but not at saddle point # Clean up old plots - for path in pathlib.Path('.').glob('saddle_point.png'): + for path in pathlib.Path(".").glob("saddle_point.png"): path.unlink() - img = None colors = list(matplotlib.colors.TABLEAU_COLORS.values()) - stride = max(1, steps // 20) rng = random.Random(0x1239121) rng.shuffle(colors) offset = win_condition_multiplier * 10 if show_image: - model = Plotter(lambda *x: objective(*x).add(offset).log(), coords=coords, xlim=(-2, 2), ylim=(-2, 2), normalize=8, - after_step=torch.exp) + model = Plotter( + lambda *x: objective(*x).add(offset).log(), + coords=coords, + xlim=(-2, 2), + ylim=(-2, 2), + normalize=8, + after_step=torch.exp, + ) else: model = Model(coords, offset) model.double() @@ -61,9 +71,25 @@ def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to us def data(): return None, None - trial(model, data, None, loss_win_condition(0.1), steps, opt[0], dtype[0], 1, 1, - weight_decay, method[0], 1, 1, failure_threshold=3, base_lr=1e-3, trials=trials) - - -if __name__ == '__main__': + trial( + model, + data, + None, + loss_win_condition(0.1), + steps, + opt[0], + dtype[0], + 1, + 1, + weight_decay, + method[0], + 1, + 1, + failure_threshold=3, + base_lr=1e-3, + trials=trials, + ) + + +if __name__ == "__main__": app() diff --git a/benchmark/saddle_point_0init.py b/benchmark/saddle_point_0init.py index 810439d..999b5e3 100644 --- a/benchmark/saddle_point_0init.py +++ b/benchmark/saddle_point_0init.py @@ -6,18 +6,19 @@ import torch import torch.backends.opt_einsum import typer -from benchmark.utils import trial, param0_win_condition -from heavyball.utils import set_torch from torch import nn from utils import Plotter +from benchmark.utils import param0_win_condition, trial +from heavyball.utils import set_torch + 
app = typer.Typer(pretty_exceptions_enable=False) set_torch() def objective(x, y): """Classic saddle point objective - tests ability to escape saddle points.""" - return x ** 2 - y ** 2 # Saddle point at (0,0) + return x**2 - y**2 # Saddle point at (0,0) class Model(nn.Module): @@ -31,28 +32,38 @@ def forward(self): @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), steps: int = 100, - weight_decay: float = 0, opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), - show_image: bool = False, trials: int = 100, win_condition_multiplier: float = 1.0, ): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + steps: int = 100, + weight_decay: float = 0, + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), + show_image: bool = False, + trials: int = 100, + win_condition_multiplier: float = 1.0, +): dtype = [getattr(torch, d) for d in dtype] coords = (0, 1e-6) # One dimension starts on the saddle point # Clean up old plots - for path in pathlib.Path('.').glob('saddle_point.png'): + for path in pathlib.Path(".").glob("saddle_point.png"): path.unlink() - img = None colors = list(matplotlib.colors.TABLEAU_COLORS.values()) - stride = max(1, steps // 20) rng = random.Random(0x1239121) rng.shuffle(colors) offset = win_condition_multiplier * 10 if show_image: - model = Plotter(lambda *x: objective(*x).add(offset).log(), coords=coords, xlim=(-2, 2), ylim=(-2, 2), - normalize=8, after_step=torch.exp) + model = Plotter( + lambda *x: objective(*x).add(offset).log(), + coords=coords, + xlim=(-2, 2), + ylim=(-2, 2), + normalize=8, + after_step=torch.exp, + ) else: model = Model(coords, offset) model.double() @@ -60,9 +71,25 @@ def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to us def data(): return None, None - trial(model, data, None, param0_win_condition(-10), steps, opt[0], dtype[0], 1, 1, weight_decay, method[0], 1, 1, - failure_threshold=3, base_lr=1e-3, trials=trials) - - -if __name__ == '__main__': + trial( + model, + data, + None, + param0_win_condition(-10), + steps, + opt[0], + dtype[0], + 1, + 1, + weight_decay, + method[0], + 1, + 1, + failure_threshold=3, + base_lr=1e-3, + trials=trials, + ) + + +if __name__ == "__main__": app() diff --git a/benchmark/scale_invariant.py b/benchmark/scale_invariant.py index bcdd0ac..532c04e 100644 --- a/benchmark/scale_invariant.py +++ b/benchmark/scale_invariant.py @@ -1,5 +1,3 @@ -import pathlib -import random from typing import List import torch @@ -7,7 +5,7 @@ import typer from torch import nn -from benchmark.utils import trial, loss_win_condition +from benchmark.utils import loss_win_condition, trial from heavyball.utils import set_torch app = typer.Typer(pretty_exceptions_enable=False) @@ -31,19 +29,40 @@ def forward(self): @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), steps: int = 100, - weight_decay: float = 0, opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), - trials: int = 100, win_condition_multiplier: float = 1.0, ): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data 
type to use"), + steps: int = 100, + weight_decay: float = 0, + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), + trials: int = 100, + win_condition_multiplier: float = 1.0, +): dtype = [getattr(torch, d) for d in dtype] model = Model().cuda().double() def data(): return None, None - trial(model, data, None, loss_win_condition(win_condition_multiplier * 1e-3), steps, opt[0], dtype[0], 1, 1, - weight_decay, method[0], 1, 1, failure_threshold=3, base_lr=1e-3, trials=trials) + trial( + model, + data, + None, + loss_win_condition(win_condition_multiplier * 1e-3), + steps, + opt[0], + dtype[0], + 1, + 1, + weight_decay, + method[0], + 1, + 1, + failure_threshold=3, + base_lr=1e-3, + trials=trials, + ) -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/shakespeare.txt b/benchmark/shakespeare.txt index e0aff78..3777fb1 100644 --- a/benchmark/shakespeare.txt +++ b/benchmark/shakespeare.txt @@ -39997,4 +39997,4 @@ And yet so fast asleep. ANTONIO: Noble Sebastian, Thou let'st thy fortune sleep--die, rather; wink'st -Whiles thou art waking. \ No newline at end of file +Whiles thou art waking. diff --git a/benchmark/sparse_gradient.py b/benchmark/sparse_gradient.py index f72e6a3..1006626 100644 --- a/benchmark/sparse_gradient.py +++ b/benchmark/sparse_gradient.py @@ -5,7 +5,7 @@ import typer from torch import nn -from benchmark.utils import trial, loss_win_condition +from benchmark.utils import loss_win_condition, trial from heavyball.utils import set_torch app = typer.Typer(pretty_exceptions_enable=False) @@ -13,11 +13,11 @@ class Model(nn.Module): - def __init__(self, size=2 ** 12, sparsity=2 ** -6): + def __init__(self, size=2**12, sparsity=2**-6): super().__init__() self.param = nn.Parameter(torch.randn(size)) self.sparsity = sparsity - self.register_buffer('prev_mask', torch.zeros_like(self.param)) + self.register_buffer("prev_mask", torch.zeros_like(self.param)) def forward(self): """Test optimizer's ability to handle sparse gradients.""" @@ -25,15 +25,20 @@ def forward(self): new_mask = (torch.rand_like(self.param) < self.sparsity).float() mask = (new_mask + self.prev_mask) > 0 # Union of current and previous mask self.prev_mask.copy_(new_mask) - + return (self.param * mask.float()).square().mean() @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), steps: int = 100, - weight_decay: float = 0, opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), - trials: int = 100, win_condition_multiplier: float = 1.0, ): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + steps: int = 100, + weight_decay: float = 0, + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), + trials: int = 100, + win_condition_multiplier: float = 1.0, +): dtype = [getattr(torch, d) for d in dtype] model = Model().cuda().double() @@ -41,9 +46,25 @@ def data(): return None, None # Win condition accounts for sparsity - harder to reach very low loss - trial(model, data, None, loss_win_condition(win_condition_multiplier * 1e-4), steps, opt[0], dtype[0], 1, 1, - weight_decay, method[0], 1, 1, failure_threshold=5, base_lr=1e-3, trials=trials) # More failure attempts allowed + trial( + model, + data, + None, + loss_win_condition(win_condition_multiplier * 1e-4), + steps, + opt[0], + 
dtype[0], + 1, + 1, + weight_decay, + method[0], + 1, + 1, + failure_threshold=5, + base_lr=1e-3, + trials=trials, + ) # More failure attempts allowed -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/utils.py b/benchmark/utils.py index 954f4cc..6d74394 100644 --- a/benchmark/utils.py +++ b/benchmark/utils.py @@ -6,24 +6,32 @@ import sys import time import warnings -from datetime import datetime -from multiprocessing import Value from typing import Union -import heavyball.utils import hyperopt import numpy as np import torch from torch import nn from torch._dynamo import config -config.cache_size_limit = 2 ** 16 +import heavyball.utils + +config.cache_size_limit = 2**16 np.warnings = warnings -base_args = {'betas': (0.9, 0.999), 'precondition_frequency': 1, 'merge_dims': False, 'warmup_steps': 100, - 'max_precond_dim': 2 ** 16, 'beta': 0.9, 'max_size_triangular': 2 ** 16, 'split': False, 'eps': 1e-8, - 'weight_decay': 0} +base_args = { + "betas": (0.9, 0.999), + "precondition_frequency": 1, + "merge_dims": False, + "warmup_steps": 100, + "max_precond_dim": 2**16, + "beta": 0.9, + "max_size_triangular": 2**16, + "split": False, + "eps": 1e-8, + "weight_decay": 0, +} def get_optim(optim, params, **kwargs): @@ -38,10 +46,11 @@ def __init__(self, mapping, broadcast: int = 1): self.mapping = mapping self.broadcast = broadcast max_consecutive_failures, minimal_improvement = zip(*mapping.items()) - self.max_consecutive_failures = torch.tensor(max_consecutive_failures, dtype=torch.float64, device='cuda') - self.minimal_improvement = torch.tensor(minimal_improvement, dtype=torch.float64, device='cuda') - self.consecutive_failures = torch.zeros(len(minimal_improvement), dtype=torch.int64, device='cuda').repeat( - broadcast) + self.max_consecutive_failures = torch.tensor(max_consecutive_failures, dtype=torch.float64, device="cuda") + self.minimal_improvement = torch.tensor(minimal_improvement, dtype=torch.float64, device="cuda") + self.consecutive_failures = torch.zeros(len(minimal_improvement), dtype=torch.int64, device="cuda").repeat( + broadcast + ) def compare(self, inp, other): old_state = inp.reshape(1, -1, 1) # vertical @@ -57,7 +66,8 @@ def __call__(self, comparison, failure_scale: float = 1): failed = torch.any(comparison, axis=tuple(range(1, comparison.ndim))) self.consecutive_failures.copy_(torch.where(failed, self.consecutive_failures + 1, 0)) return torch.any( - self.consecutive_failures >= (self.max_consecutive_failures.view(-1, 1) * failure_scale).flatten()) + self.consecutive_failures >= (self.max_consecutive_failures.view(-1, 1) * failure_scale).flatten() + ) class Validator: @@ -70,13 +80,13 @@ def __init__(self, ema_mapping, global_min_mapping, global_avg_mapping, steps, e self.step = 0 self.emas = emas - self.ema_states = torch.zeros((self.emas,), dtype=torch.float64, device='cuda') + self.ema_states = torch.zeros((self.emas,), dtype=torch.float64, device="cuda") es = self.ema_start + 1 - self.update_factor = 2.0 ** (-torch.arange(es, 20 + es, dtype=torch.float64, device='cuda')) - self.ema_failures = FailureCounter(ema_mapping); + self.update_factor = 2.0 ** (-torch.arange(es, 20 + es, dtype=torch.float64, device="cuda")) + self.ema_failures = FailureCounter(ema_mapping) self.triu_indices = torch.triu_indices(self.emas, self.emas, offset=1) - self.global_min_loss = torch.tensor((float('inf'),) * steps, dtype=torch.float64, device='cuda') + self.global_min_loss = torch.tensor((float("inf"),) * steps, dtype=torch.float64, device="cuda") 
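The FailureCounter/Validator machinery above drives early stopping from consecutive steps without sufficient improvement, keyed by the {max_consecutive_failures: minimal_improvement} mappings passed in. What follows is a simplified, hypothetical restatement of that idea in plain Python, not the actual class from benchmark/utils.py (which tracks several EMAs and per-step global minima on the GPU):

def should_stop(losses, max_failures=32, min_improvement=1e-3):
    # Stop once the best-seen loss has not improved by at least `min_improvement`
    # (relative) for `max_failures` consecutive steps.
    best = float("inf")
    consecutive = 0
    for loss in losses:
        if loss < best * (1 - min_improvement):
            best = loss
            consecutive = 0
        else:
            consecutive += 1
        if consecutive >= max_failures:
            return True
    return False

assert not should_stop([1.0, 0.9, 0.8, 0.7], max_failures=3, min_improvement=0.01)
assert should_stop([1.0, 0.999, 0.999, 0.999, 0.999], max_failures=3, min_improvement=0.01)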
self.global_min_failures = FailureCounter(global_min_mapping, steps) self.global_avg_loss = torch.zeros_like(self.global_min_loss) @@ -103,11 +113,12 @@ def _update_ema(self, loss): def _global_min(self): loss = self.ema_states[self.ema_index] comparison = self.global_min_failures.compare(loss, self.global_min_loss).view(-1, 1) - global_failed = self.global_min_failures(comparison, - torch.arange(1, 1 + self.global_min_loss.size(0), device='cuda').view( - 1, -1).clamp(min=self.global_warmup)) + global_failed = self.global_min_failures( + comparison, + torch.arange(1, 1 + self.global_min_loss.size(0), device="cuda").view(1, -1).clamp(min=self.global_warmup), + ) - loss_slice = self.global_min_loss[self.step - 1:] + loss_slice = self.global_min_loss[self.step - 1 :] loss_slice.copy_(torch.where(torch.logical_and(loss < loss_slice, torch.isfinite(loss)), loss, loss_slice)) return global_failed @@ -118,11 +129,11 @@ def _global_avg(self): self.global_avg_loss[self.step - 1].lerp_(loss, 1 / self.global_avg_step[self.step - 1]) comparison = self.global_avg_failures.compare(loss, self.global_avg_loss).view(-1, 1) - comparison[self.seen_until - 1:].fill_(False) - return self.global_avg_failures(comparison, - torch.arange(1, 1 + self.global_avg_loss.size(0), device='cuda').view(1, - -1).clamp( - min=self.global_warmup)) + comparison[self.seen_until - 1 :].fill_(False) + return self.global_avg_failures( + comparison, + torch.arange(1, 1 + self.global_avg_loss.size(0), device="cuda").view(1, -1).clamp(min=self.global_warmup), + ) def _local_convergence(self): comparison = self.ema_failures.compare(self.ema_states, self.ema_states) @@ -141,8 +152,16 @@ class Stop(Exception): class Plotter(nn.Module): - def __init__(self, objective_fn, x_limits=(-5, 5), y_limits=(-5, 5), resolution=300, transform=None, - inverse_transform=None, should_normalize: bool = True): + def __init__( + self, + objective_fn, + x_limits=(-5, 5), + y_limits=(-5, 5), + resolution=300, + transform=None, + inverse_transform=None, + should_normalize: bool = True, + ): super().__init__() self.should_normalize = should_normalize self.objective = objective_fn @@ -158,11 +177,11 @@ def __init__(self, objective_fn, x_limits=(-5, 5), y_limits=(-5, 5), resolution= with torch.no_grad(): x = torch.linspace(x_limits[0], x_limits[1], resolution) y = torch.linspace(y_limits[0], y_limits[1], resolution) - self.X, self.Y = torch.meshgrid(x, y, indexing='ij') + self.X, self.Y = torch.meshgrid(x, y, indexing="ij") Z = torch.zeros_like(self.X) for i in range(resolution): for j in range(resolution): - objective_fn.param.data[:] = torch.tensor([self.X[i, j].item(), self.Y[i, j].item()], device='cuda') + objective_fn.param.data[:] = torch.tensor([self.X[i, j].item(), self.Y[i, j].item()], device="cuda") Z[i, j] = self.transform(objective_fn()) objective_fn.param.data[:] = self.initial self.Z = Z @@ -177,7 +196,7 @@ def forward(self, *args): def plot(self, title=None, save_path=None): """Create contour plot with optimization trajectory. 
- + Args: title: Optional title for the plot save_path: Optional path to save the plot @@ -194,13 +213,13 @@ def plot(self, title=None, save_path=None): # Plot trajectory trajectory = np.array(self.trajectory) - plt.plot(trajectory[:, 0], trajectory[:, 1], 'r.-', label='Optimization path') - plt.plot(trajectory[0, 0], trajectory[0, 1], 'go', label='Start') - plt.plot(trajectory[-1, 0], trajectory[-1, 1], 'ro', label='End') + plt.plot(trajectory[:, 0], trajectory[:, 1], "r.-", label="Optimization path") + plt.plot(trajectory[0, 0], trajectory[0, 1], "go", label="Start") + plt.plot(trajectory[-1, 0], trajectory[-1, 1], "ro", label="End") - plt.colorbar(label=f'Log({"Normalized" * self.should_normalize}ObjectiveValue)') - plt.xlabel('x') - plt.ylabel('y') + plt.colorbar(label=f"Log({'Normalized' * self.should_normalize}ObjectiveValue)") + plt.xlabel("x") + plt.ylabel("y") if title: plt.title(title) plt.legend() @@ -212,8 +231,20 @@ def plot(self, title=None, save_path=None): class Objective: - def __init__(self, failure_threshold, model, opt, steps, group, data, loss_fn, win_condition, weight_decay, - ema_index: int = 0, **kwargs): + def __init__( + self, + failure_threshold, + model, + opt, + steps, + group, + data, + loss_fn, + win_condition, + weight_decay, + ema_index: int = 0, + **kwargs, + ): self.failure_threshold = failure_threshold self.model = model.cuda() for mod in self.model.modules(): @@ -230,9 +261,27 @@ def __init__(self, failure_threshold, model, opt, steps, group, data, loss_fn, w self.ema_index = ema_index self.validator = Validator( - {32768: 1e-7, 16384: 1e-6, 8192: 1e-5, 4096: 1e-4, 1024: 1e-3, 512: 1e-2, 256: 0, 128: -1e-4, 64: -1e-3, - 32: -0.01, 16: -0.1, 8: -0.33, 4: -0.5, 2: -0.75, 1: -0.99}, {2: 0}, {6: 0}, - steps) # same loss as best after 3x as many steps; 6x higher loss at same step - for every per-step minimum + { + 32768: 1e-7, + 16384: 1e-6, + 8192: 1e-5, + 4096: 1e-4, + 1024: 1e-3, + 512: 1e-2, + 256: 0, + 128: -1e-4, + 64: -1e-3, + 32: -0.01, + 16: -0.1, + 8: -0.33, + 4: -0.5, + 2: -0.75, + 1: -0.99, + }, + {2: 0}, + {6: 0}, + steps, + ) # same loss as best after 3x as many steps; 6x higher loss at same step - for every per-step minimum self.m = None self.attempt = 0 self.best_loss = None @@ -241,16 +290,20 @@ def __init__(self, failure_threshold, model, opt, steps, group, data, loss_fn, w self.use_cudnn = True def _inner(self, params): - params = {'lr': params[0], 'betas': (1 - params[1], 1 - params[2]), 'shampoo_beta': 1 - params[3], 'eps': 1e-8, - 'precond_lr': params[3] # we never have both precond_lr and shampoo_beta - } + params = { + "lr": params[0], + "betas": (1 - params[1], 1 - params[2]), + "shampoo_beta": 1 - params[3], + "eps": 1e-8, + "precond_lr": params[3], # we never have both precond_lr and shampoo_beta + } self.m = copy.deepcopy(self.model) o = get_optim(self.opt, self.m.parameters(), **params, weight_decay=self.weight_decay, **self.kwargs) - torch_hist = torch.empty(self.group, dtype=torch.float64, device='cuda') + torch_hist = torch.empty(self.group, dtype=torch.float64, device="cuda") validator = self.validator.new() for i in range(self.steps // self.group): - if hasattr(o, 'train'): + if hasattr(o, "train"): o.train() for j in range(self.group): @@ -276,7 +329,7 @@ def _closure(): with torch.no_grad(): torch_hist[j] = loss.detach() - if hasattr(o, 'eval'): + if hasattr(o, "eval"): o.eval() with torch.no_grad(): for loss in torch_hist: @@ -306,14 +359,14 @@ def get_best(self): def loss_win_condition(target): def win(_model, loss: 
Union[float, hyperopt.Trials]): if not isinstance(loss, (float, torch.Tensor)): - loss = loss.results[-1]['loss'] + loss = loss.results[-1]["loss"] return loss <= target, {} return win def param_norm_win_condition(target, offset): - target = torch.full((), target, device='cuda') + target = torch.full((), target, device="cuda") def win(model, loss): with torch.no_grad(): @@ -322,8 +375,9 @@ def win(model, loss): return win + def param0_win_condition(target): - target = torch.full((), target, device='cuda') + target = torch.full((), target, device="cuda") def win(model, loss): with torch.no_grad(): @@ -331,26 +385,45 @@ def win(model, loss): return win -def trial(model, data, loss_fn, win_condition, steps, opt, dtype, size, batch, weight_decay, method, length, depth, - trials=10, failure_threshold=3, group=64, base_lr: float = 1e-3, return_best: bool = False): + +def trial( + model, + data, + loss_fn, + win_condition, + steps, + opt, + dtype, + size, + batch, + weight_decay, + method, + length, + depth, + trials=10, + failure_threshold=3, + group=64, + base_lr: float = 1e-3, + return_best: bool = False, +): heavyball.utils.set_torch() if isinstance(opt, list): opt = opt[0] - kwargs = {'caution': False, 'mars': False} - if opt.startswith('cautious-'): - opt = opt[len('cautious-'):] - kwargs['caution'] = True - if opt.startswith('unscaled_cautious-'): - opt = opt[len('unscaled_cautious-'):] + kwargs = {"caution": False, "mars": False} + if opt.startswith("cautious-"): + opt = opt[len("cautious-") :] + kwargs["caution"] = True + if opt.startswith("unscaled_cautious-"): + opt = opt[len("unscaled_cautious-") :] heavyball.utils.disable_caution_scaling() - kwargs['caution'] = True - if opt.startswith('mars-'): - opt = opt[len('mars-'):] - kwargs['mars'] = True + kwargs["caution"] = True + if opt.startswith("mars-"): + opt = opt[len("mars-") :] + kwargs["mars"] = True opt = getattr(heavyball, opt) - if "soap" not in opt.__name__.lower() and method != 'qr': + if "soap" not in opt.__name__.lower() and method != "qr": return heavyball.utils.zeroth_power_mode = method @@ -370,19 +443,38 @@ def _win_condition(*args): did_win |= win_state return did_win, out - obj = Objective(failure_threshold, model, opt, steps, group, data, loss_fn, _win_condition, weight_decay, **kwargs) + obj = Objective( + failure_threshold, + model, + opt, + steps, + group, + data, + loss_fn, + _win_condition, + weight_decay, + **kwargs, + ) start_time = time.time() stdout, sys.stdout = sys.stdout, sys.stderr try: # LR=1000 seems way too high, but some problems get solved in one step with it, so it'd be unfair to exclude it - out = hyperopt.fmin(obj.objective, (hyperopt.hp.loguniform('lr', np.log(1e-7), np.log(1000)), # - hyperopt.hp.loguniform('1mbeta1', np.log(1e-3), np.log(1)), # - hyperopt.hp.loguniform('1mbeta2', np.log(1e-5), np.log(1)), # - hyperopt.hp.loguniform('1mshampoo_beta', np.log(1e-4), np.log(1))), # - max_evals=trials, algo=hyperopt.atpe.suggest, - early_stop_fn=lambda x: _win_condition(obj.m, x), return_argmin=True, show_progressbar=True) + out = hyperopt.fmin( + obj.objective, + ( + hyperopt.hp.loguniform("lr", np.log(1e-7), np.log(1000)), # + hyperopt.hp.loguniform("1mbeta1", np.log(1e-3), np.log(1)), # + hyperopt.hp.loguniform("1mbeta2", np.log(1e-5), np.log(1)), # + hyperopt.hp.loguniform("1mshampoo_beta", np.log(1e-4), np.log(1)), + ), # + max_evals=trials, + algo=hyperopt.atpe.suggest, + early_stop_fn=lambda x: _win_condition(obj.m, x), + return_argmin=True, + show_progressbar=True, + ) except Stop: - out = 
{'lr': 0, '1mbeta1': 0, '1mbeta2': 0, '1mshampoo_beta': 0} + out = {"lr": 0, "1mbeta1": 0, "1mbeta2": 0, "1mshampoo_beta": 0} finally: sys.stdout = stdout torch.cuda.synchronize() @@ -390,8 +482,10 @@ def _win_condition(*args): end_time = time.time() if did_win: print("Successfully found the minimum.") - print(f"Took: {end_time - start_time} | Attempt: {obj.attempt} | " # - f"{opt.__name__}(lr={out['lr']:.5f}, betas=({1 - out['1mbeta1']:.3f}, {1 - out['1mbeta2']:.4f}), " # - f"shampoo_beta={1 - out['1mshampoo_beta']:.3f}) | Best Loss: {obj.best_loss}") + print( + f"Took: {end_time - start_time} | Attempt: {obj.attempt} | " # + f"{opt.__name__}(lr={out['lr']:.5f}, betas=({1 - out['1mbeta1']:.3f}, {1 - out['1mbeta2']:.4f}), " # + f"shampoo_beta={1 - out['1mshampoo_beta']:.3f}) | Best Loss: {obj.best_loss}" + ) if return_best: return obj.get_best() diff --git a/benchmark/wide_linear.py b/benchmark/wide_linear.py index 023d37f..e9e1158 100644 --- a/benchmark/wide_linear.py +++ b/benchmark/wide_linear.py @@ -3,11 +3,12 @@ import torch import torch.backends.opt_einsum import typer -from benchmark.utils import trial, param_norm_win_condition -from heavyball.utils import set_torch from torch import nn from torch.nn import functional as F +from benchmark.utils import param_norm_win_condition, trial +from heavyball.utils import set_torch + app = typer.Typer(pretty_exceptions_enable=False) set_torch() @@ -23,22 +24,44 @@ def forward(self, inp): @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), size: int = 1024, depth: int = 4, - batch: int = 16, steps: int = 10, weight_decay: float = 0, - opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), win_condition_multiplier: float = 1.0, - trials: int = 10, ): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + size: int = 1024, + depth: int = 4, + batch: int = 16, + steps: int = 10, + weight_decay: float = 0, + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), + win_condition_multiplier: float = 1.0, + trials: int = 10, +): dtype = [getattr(torch, d) for d in dtype] model = Model(size).cuda() def data(): - inp = torch.randn((batch, size), device='cuda', dtype=dtype[0]) + inp = torch.randn((batch, size), device="cuda", dtype=dtype[0]) return inp, inp.cumsum(1) - trial(model, data, F.mse_loss, param_norm_win_condition(1e-7 * win_condition_multiplier, model.target), steps, - opt[0], dtype[0], size, batch, weight_decay, method[0], 1, depth, failure_threshold=depth * 2, base_lr=1e-3, - trials=trials) + trial( + model, + data, + F.mse_loss, + param_norm_win_condition(1e-7 * win_condition_multiplier, model.target), + steps, + opt[0], + dtype[0], + size, + batch, + weight_decay, + method[0], + 1, + depth, + failure_threshold=depth * 2, + base_lr=1e-3, + trials=trials, + ) -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/xor_digit.py b/benchmark/xor_digit.py index 96da1e8..7a04e0b 100644 --- a/benchmark/xor_digit.py +++ b/benchmark/xor_digit.py @@ -1,5 +1,3 @@ -import os -import itertools from typing import List import torch @@ -8,9 +6,8 @@ import typer from torch.nn import functional as F -import heavyball -from heavyball.utils import set_torch from benchmark.utils import loss_win_condition, trial +from heavyball.utils import 
set_torch app = typer.Typer(pretty_exceptions_enable=False) set_torch() @@ -22,8 +19,10 @@ def __init__(self, size, depth): self.embed = nn.Embedding(2, size) self.enc = nn.LSTM(size, size, depth, batch_first=False) self.enc.flatten_parameters() - self.proj = nn.Sequential(nn.LayerNorm(size), # - nn.Linear(size, 1)) + self.proj = nn.Sequential( + nn.LayerNorm(size), # + nn.Linear(size, 1), + ) def forward(self, inp): inp = inp.transpose(0, 1) @@ -34,30 +33,46 @@ def forward(self, inp): @app.command() def main( - method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), length: int = 64, size: int = 64, depth: int = 1, batch: int = 256, steps: int = 10, weight_decay: float = 0, - opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), trials: int = 10, - win_condition_multiplier: float = 1.0 + win_condition_multiplier: float = 1.0, ): dtype = [getattr(torch, d) for d in dtype] torch.manual_seed(0x1239121) model = Model(size, depth).cuda() def data(): - inp = torch.randn((batch, length, 1), device='cuda', dtype=dtype[0]) + inp = torch.randn((batch, length, 1), device="cuda", dtype=dtype[0]) inp = inp > 0 return inp.to(dtype[0]), (inp.sum(1) % 2).to(dtype[0]) - trial(model, data, F.binary_cross_entropy_with_logits, loss_win_condition(win_condition_multiplier * 1e-3), steps, opt[0], dtype[0], size, batch, weight_decay, method[0], length, depth, - failure_threshold=10, base_lr=1e-6, trials=trials) + trial( + model, + data, + F.binary_cross_entropy_with_logits, + loss_win_condition(win_condition_multiplier * 1e-3), + steps, + opt[0], + dtype[0], + size, + batch, + weight_decay, + method[0], + length, + depth, + failure_threshold=10, + base_lr=1e-6, + trials=trials, + ) -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/xor_sequence.py b/benchmark/xor_sequence.py index a2564c6..66fd2b3 100644 --- a/benchmark/xor_sequence.py +++ b/benchmark/xor_sequence.py @@ -1,14 +1,13 @@ -import itertools from typing import List -import heavyball import torch import torch.backends.opt_einsum import torch.nn as nn import typer +from torch.nn import functional as F + from benchmark.utils import loss_win_condition, trial from heavyball.utils import set_torch -from torch.nn import functional as F app = typer.Typer(pretty_exceptions_enable=False) set_torch() @@ -23,8 +22,10 @@ def __init__(self, size, depth): self.dec = nn.LSTM(size, size, depth, batch_first=False) self.enc.flatten_parameters() self.dec.flatten_parameters() - self.proj = nn.Sequential(nn.LayerNorm(size), # - nn.Linear(size, 1)) + self.proj = nn.Sequential( + nn.LayerNorm(size), # + nn.Linear(size, 1), + ) def forward(self, inp): i0, i1 = inp.chunk(2, 1) @@ -38,26 +39,49 @@ def forward(self, inp): @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(['float32'], help='Data type to use'), length: int = 14, size: int = 16, - depth: int = 1, batch: int = 256, steps: int = 100, weight_decay: float = 0, - opt: List[str] = typer.Option(['ForeachSOAP'], help='Optimizers to use'), win_condition_multiplier: float = 1, - trials: int = 10): +def main( + method: List[str] = 
typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + length: int = 14, + size: int = 16, + depth: int = 1, + batch: int = 256, + steps: int = 100, + weight_decay: float = 0, + opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"), + win_condition_multiplier: float = 1, + trials: int = 10, +): dtype = [getattr(torch, d) for d in dtype] torch.manual_seed(0x1239121) model = Model(size, depth).cuda() def data(): - inp = torch.randn((batch, length, 1), device='cuda', dtype=dtype[0]) + inp = torch.randn((batch, length, 1), device="cuda", dtype=dtype[0]) inp = inp > 0 i0, i1 = inp.chunk(2, 1) xored = torch.logical_xor(i0, i1) return inp.long().squeeze(-1), xored.to(dtype[0]) - trial(model, data, F.binary_cross_entropy_with_logits, loss_win_condition(win_condition_multiplier * 1e-2), steps, - opt[0], dtype[0], size, batch, weight_decay, method[0], length, depth, failure_threshold=10, base_lr=0.001, - trials=trials) + trial( + model, + data, + F.binary_cross_entropy_with_logits, + loss_win_condition(win_condition_multiplier * 1e-2), + steps, + opt[0], + dtype[0], + size, + batch, + weight_decay, + method[0], + length, + depth, + failure_threshold=10, + base_lr=0.001, + trials=trials, + ) -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/benchmark/xor_spot.py b/benchmark/xor_spot.py index 7ac2b6b..b22ca82 100644 --- a/benchmark/xor_spot.py +++ b/benchmark/xor_spot.py @@ -6,6 +6,7 @@ 3) Train a model to predict the xor of the two spots This does NOT elicit memory in the RNN, but it does force it to learn a pointwise forget mechanism. """ + import itertools from typing import List @@ -14,8 +15,8 @@ import torch.nn as nn import typer -from heavyball.utils import set_torch from benchmark.utils import loss_win_condition, trial +from heavyball.utils import set_torch app = typer.Typer(pretty_exceptions_enable=False) set_torch() @@ -27,8 +28,10 @@ def __init__(self, size, depth): self.embed = nn.Embedding(4, size) self.enc = nn.LSTM(size, size, depth, batch_first=False) self.enc.flatten_parameters() - self.proj = nn.Sequential(nn.LayerNorm(size), # - nn.Linear(size, 1)) + self.proj = nn.Sequential( + nn.LayerNorm(size), # + nn.Linear(size, 1), + ) def forward(self, inp): inp = self.embed(inp.squeeze(-1).long()) @@ -38,12 +41,21 @@ def forward(self, inp): @app.command() -def main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to use (for SOAP)'), - dtype: List[str] = typer.Option(["float32"], help='Data type to use'), length: int = 64, size: int = 64, - depth: int = 1, batch: int = 256, steps: int = 10, weight_decay: float = 0, - opt: List[str] = typer.Option(['ForeachSOAP', 'PaLMForeachSOAP', 'PrecondScheduleForeachSOAP'], - help='Optimizers to use'), win_condition_multiplier: float = 1.0, - trials: int = 10): +def main( + method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"), + dtype: List[str] = typer.Option(["float32"], help="Data type to use"), + length: int = 64, + size: int = 64, + depth: int = 1, + batch: int = 256, + steps: int = 10, + weight_decay: float = 0, + opt: List[str] = typer.Option( + ["ForeachSOAP", "PaLMForeachSOAP", "PrecondScheduleForeachSOAP"], help="Optimizers to use" + ), + win_condition_multiplier: float = 1.0, + trials: int = 10, +): dtype = [getattr(torch, d) for d in dtype] for args in itertools.product(method, dtype, [(length, size, depth, batch)], opt, [weight_decay]): @@ -52,18 +64,32 @@ def 
main(method: List[str] = typer.Option(['qr'], help='Eigenvector method to us model = Model(s, dp).cuda() def data(): - inp = torch.randn((b, l, 1), device='cuda', dtype=d) + inp = torch.randn((b, l, 1), device="cuda", dtype=d) inp = inp > 0 zeros = torch.zeros_like(inp) - zeros[:, torch.randint(0, l, (b,), device='cuda')] = 1 - zeros[:, torch.randint(0, l, (b,), device='cuda')] = 1 + zeros[:, torch.randint(0, l, (b,), device="cuda")] = 1 + zeros[:, torch.randint(0, l, (b,), device="cuda")] = 1 target = (inp * zeros).sum(1) % 2 return torch.stack((inp, zeros + 2), 0).to(d), target.to(d) - trial(model, data, torch.nn.functional.binary_cross_entropy_with_logits, - loss_win_condition(win_condition_multiplier * 1e-2), steps, o, d, s, b, wd, m, l, dp, - failure_threshold=10, trials=trials) + trial( + model, + data, + torch.nn.functional.binary_cross_entropy_with_logits, + loss_win_condition(win_condition_multiplier * 1e-2), + steps, + o, + d, + s, + b, + wd, + m, + l, + dp, + failure_threshold=10, + trials=trials, + ) -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/docs/psgd_efficiency.md b/docs/psgd_efficiency.md index a443932..5902b64 100644 --- a/docs/psgd_efficiency.md +++ b/docs/psgd_efficiency.md @@ -29,4 +29,4 @@ the memory overhead. If the doubled memory cost of `CachedPSGDKron` is too high, it's possible to use `CachedPSGDKron` with `triu_as_line=True`, which reduces the total memory cost from 2x `Q` to 1.5x `Q`. -![psgd_efficiency_cache_triu_as_line.png](assets/psgd_efficiency_cache_triu_as_line.png) \ No newline at end of file +![psgd_efficiency_cache_triu_as_line.png](assets/psgd_efficiency_cache_triu_as_line.png) diff --git a/examples/soap.py b/examples/autoencoder.py similarity index 72% rename from examples/soap.py rename to examples/autoencoder.py index aed1f7d..8dc5614 100644 --- a/examples/soap.py +++ b/examples/autoencoder.py @@ -1,7 +1,6 @@ import os from datetime import datetime -import heavyball import matplotlib.pyplot as plt import torch import torch.nn as nn @@ -13,9 +12,12 @@ from torchvision.transforms import v2 from torchvision.utils import make_grid -heavyball.utils.compile_mode = 'default' +import heavyball + +heavyball.utils.compile_mode = "default" heavyball.utils.set_torch() + class Residual(nn.Sequential): def forward(self, input): out = super().forward(input) @@ -23,16 +25,28 @@ def forward(self, input): class Block(nn.Sequential): - def __init__(self, in_features: int, intermediate: int, out_features: int, kernel: int, stride: int, up: bool, - depth: int): + def __init__( + self, + in_features: int, + intermediate: int, + out_features: int, + kernel: int, + stride: int, + up: bool, + depth: int, + ): padding = kernel // 2 layers = [nn.Conv2d(in_features, intermediate, kernel_size=kernel, padding=padding)] for _ in range(depth): - layers.append(Residual(nn.Upsample(scale_factor=stride) if up else nn.MaxPool2d(stride), - nn.BatchNorm2d(intermediate), - nn.ReLU(), - nn.Conv2d(intermediate, intermediate, kernel_size=kernel, padding=padding))) + layers.append( + Residual( + nn.Upsample(scale_factor=stride) if up else nn.MaxPool2d(stride), + nn.BatchNorm2d(intermediate), + nn.ReLU(), + nn.Conv2d(intermediate, intermediate, kernel_size=kernel, padding=padding), + ) + ) layers.append(nn.ReLU()) layers.append(nn.Conv2d(intermediate, out_features, kernel_size=kernel, padding=padding)) @@ -41,7 +55,6 @@ def __init__(self, in_features: int, intermediate: int, out_features: int, kerne class Autoencoder(nn.Module): - def __init__(self, kernel: int 
= 5, stride: int = 2, hidden: int = 8, intermediate: int = 256): super(Autoencoder, self).__init__() self.enc = Block(1, intermediate, hidden, kernel, stride, False, 5) @@ -57,7 +70,7 @@ def forward(self, x): return out -def plot_samples(model, data, epoch, save_dir='samples'): +def plot_samples(model, data, epoch, save_dir="samples"): os.makedirs(save_dir, exist_ok=True) model.eval() with torch.no_grad(): @@ -67,8 +80,8 @@ def plot_samples(model, data, epoch, save_dir='samples'): grid = make_grid(comparison, nrow=8, normalize=True, padding=2) plt.figure(figsize=(10, 5)) plt.imshow(grid.permute(1, 2, 0)) - plt.axis('off') - plt.savefig(os.path.join(save_dir, f'epoch_{epoch}.png')) + plt.axis("off") + plt.savefig(os.path.join(save_dir, f"epoch_{epoch}.png")) plt.close() model.train() @@ -81,24 +94,28 @@ def __init__(self, amount: int): def forward(self, inp): x = torch.randint(0, self.amount, (inp.size(0),)) y = torch.randint(0, self.amount, (inp.size(0),)) - new = torch.zeros([inp.shape[0], inp.shape[1] + self.amount, inp.shape[2] + self.amount], device=inp.device, dtype=inp.dtype) - new[:, x:x + inp.size(1), y:y + inp.size(2)] = inp + new = torch.zeros( + [inp.shape[0], inp.shape[1] + self.amount, inp.shape[2] + self.amount], + device=inp.device, + dtype=inp.dtype, + ) + new[:, x : x + inp.size(1), y : y + inp.size(2)] = inp return new def main(epochs: int, batch: int): # Setup tensorboard logging - log_dir = os.path.join('runs', f'soap_{datetime.now().strftime("%Y%m%d_%H%M%S")}') + log_dir = os.path.join("runs", f"soap_{datetime.now().strftime('%Y%m%d_%H%M%S')}") writer = SummaryWriter(log_dir) - model = torch.compile(Autoencoder().cuda(), mode='default') + model = torch.compile(Autoencoder().cuda(), mode="default") optimizer = heavyball.SOAP(model.parameters(), lr=1e-3, precondition_frequency=1) # optimizer = heavyball.PSGDKron(optimizer, lr=1e-3, mars=True) # optimizer = heavyball.AdamW(model.parameters(), lr=1e-3, mars=True) transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32), RandomPad(4)]) - trainset = MNIST(root='./data', train=True, download=True, transform=transform) - testset = MNIST(root='./data', train=False, download=True, transform=transform) + trainset = MNIST(root="./data", train=True, download=True, transform=transform) + testset = MNIST(root="./data", train=False, download=True, transform=transform) dataloader = DataLoader(trainset, batch_size=batch, shuffle=True, num_workers=8, drop_last=True, pin_memory=True) testloader = DataLoader(testset, batch_size=batch * 8, shuffle=False, num_workers=1, pin_memory=True) @@ -121,8 +138,8 @@ def _closure(): total_loss = total_loss + loss.detach() avg_loss = (total_loss / len(dataloader)).item() - print(f'epoch [{epoch}/{epochs}], loss:{avg_loss:.4f}') - writer.add_scalar('Loss/train', avg_loss, epoch) + print(f"epoch [{epoch}/{epochs}], loss:{avg_loss:.4f}") + writer.add_scalar("Loss/train", avg_loss, epoch) # Plot samples every 2 epochs if epoch % 2 == 0: @@ -135,10 +152,10 @@ def _closure(): samples = model(eval_batch.cuda()) comparison = torch.cat([eval_batch, samples.cpu()], dim=0) grid = make_grid(comparison, nrow=8, normalize=True, padding=2) - writer.add_image('reconstructions', grid, epoch) + writer.add_image("reconstructions", grid, epoch) model.train() writer.flush() -if __name__ == '__main__': - main(epochs=10, batch=1024 ) +if __name__ == "__main__": + main(epochs=10, batch=1024) diff --git a/heavyball/__init__.py b/heavyball/__init__.py index e47085b..9324458 100644 --- a/heavyball/__init__.py +++ 
b/heavyball/__init__.py @@ -6,10 +6,24 @@ class ForeachAdamW(C.BaseOpt): - def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, - foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, - mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default, - update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8): + def __init__( + self, + params, + lr=0.0025, + betas=(0.9, 0.99), + eps=1e-8, + weight_decay=0, + warmup_steps=0, + foreach: bool = True, + storage_dtype: str = "float32", + mars: bool = False, + caution: bool = False, + mars_gamma: float = 0.0025, + gradient_clipping: C.str_or_fn = C.use_default, + update_clipping: C.str_or_fn = C.use_default, + palm: bool = C.use_default, + beta2_scale: float = 0.8, + ): defaults = locals() defaults.pop("self") params = defaults.pop("params") @@ -21,26 +35,74 @@ class ForeachRMSprop(C.BaseOpt): Debiased RMSprop (not torch.optim.RMSprop) """ - def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, warmup_steps=0, r=0.0, - weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, - caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default, - update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8): + def __init__( + self, + params, + lr=0.0025, + betas=(0.9, 0.99), + eps=1e-6, + weight_decay=0, + warmup_steps=0, + r=0.0, + weight_lr_power=2.0, + foreach: bool = True, + storage_dtype: str = "float32", + mars: bool = False, + caution: bool = False, + mars_gamma: float = 0.0025, + gradient_clipping: C.str_or_fn = C.use_default, + update_clipping: C.str_or_fn = C.use_default, + palm: bool = C.use_default, + beta2_scale: float = 0.8, + ): defaults = locals() defaults.pop("self") params = defaults.pop("params") - super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_exp_avg_sq) + super().__init__( + params, + defaults, + foreach, + gradient_clipping, + update_clipping, + palm, + C.scale_by_exp_avg_sq, + ) class ForeachSFAdamW(C.ScheduleFree): - def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, warmup_steps=0, r=0.0, - weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, - caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default, - update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8): + def __init__( + self, + params, + lr=0.0025, + betas=(0.9, 0.99), + eps=1e-6, + weight_decay=0, + warmup_steps=0, + r=0.0, + weight_lr_power=2.0, + foreach: bool = True, + storage_dtype: str = "float32", + mars: bool = False, + caution: bool = False, + mars_gamma: float = 0.0025, + gradient_clipping: C.str_or_fn = C.use_default, + update_clipping: C.str_or_fn = C.use_default, + palm: bool = C.use_default, + beta2_scale: float = 0.8, + ): defaults = locals() defaults.pop("self") params = defaults.pop("params") - super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_exp_avg_sq, - C.update_by_schedule_free) + super().__init__( + params, + defaults, + foreach, + gradient_clipping, + update_clipping, + palm, + C.scale_by_exp_avg_sq, + C.update_by_schedule_free, + ) class PaLMForeachSFAdamW(ForeachSFAdamW): @@ -48,10 +110,24 @@ class 
PaLMForeachSFAdamW(ForeachSFAdamW): class ForeachADOPT(C.BaseOpt): - def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, - foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, - mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default, - update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8): + def __init__( + self, + params, + lr=0.0025, + betas=(0.9, 0.99), + eps=1e-8, + weight_decay=0, + warmup_steps=0, + foreach: bool = True, + storage_dtype: str = "float32", + mars: bool = False, + caution: bool = False, + mars_gamma: float = 0.0025, + gradient_clipping: C.str_or_fn = C.use_default, + update_clipping: C.str_or_fn = C.use_default, + palm: bool = C.use_default, + beta2_scale: float = 0.8, + ): defaults = locals() defaults.pop("self") params = defaults.pop("params") @@ -59,23 +135,59 @@ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay= class ForeachMuon(C.BaseOpt): - def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, - foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, - mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default, - update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8, - nesterov: bool = True): + def __init__( + self, + params, + lr=0.0025, + betas=(0.9, 0.99), + eps=1e-8, + weight_decay=0, + warmup_steps=0, + foreach: bool = True, + storage_dtype: str = "float32", + mars: bool = False, + caution: bool = False, + mars_gamma: float = 0.0025, + gradient_clipping: C.str_or_fn = C.use_default, + update_clipping: C.str_or_fn = C.use_default, + palm: bool = C.use_default, + beta2_scale: float = 0.8, + nesterov: bool = True, + ): defaults = locals() defaults.pop("self") params = defaults.pop("params") - super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, - C.nesterov_momentum if nesterov else C.heavyball_momentum, C.orthogonalize_update) + super().__init__( + params, + defaults, + foreach, + gradient_clipping, + update_clipping, + palm, + C.nesterov_momentum if nesterov else C.heavyball_momentum, + C.orthogonalize_update, + ) class ForeachLaProp(C.BaseOpt): - def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, - foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, - mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default, - update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8): + def __init__( + self, + params, + lr=0.0025, + betas=(0.9, 0.99), + eps=1e-8, + weight_decay=0, + warmup_steps=0, + foreach: bool = True, + storage_dtype: str = "float32", + mars: bool = False, + caution: bool = False, + mars_gamma: float = 0.0025, + gradient_clipping: C.str_or_fn = C.use_default, + update_clipping: C.str_or_fn = C.use_default, + palm: bool = C.use_default, + beta2_scale: float = 0.8, + ): defaults = locals() defaults.pop("self") params = defaults.pop("params") @@ -83,15 +195,37 @@ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay= class MuonLaProp(C.BaseOpt): - def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, - foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, 
caution: bool = False, - mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default, - update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8): + def __init__( + self, + params, + lr=0.0025, + betas=(0.9, 0.99), + eps=1e-8, + weight_decay=0, + warmup_steps=0, + foreach: bool = True, + storage_dtype: str = "float32", + mars: bool = False, + caution: bool = False, + mars_gamma: float = 0.0025, + gradient_clipping: C.str_or_fn = C.use_default, + update_clipping: C.str_or_fn = C.use_default, + palm: bool = C.use_default, + beta2_scale: float = 0.8, + ): defaults = locals() defaults.pop("self") params = defaults.pop("params") - super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop, - C.orthogonalize_update) + super().__init__( + params, + defaults, + foreach, + gradient_clipping, + update_clipping, + palm, + C.scale_by_laprop, + C.orthogonalize_update, + ) class ForeachSOAP(C.BaseOpt): @@ -105,16 +239,38 @@ class ForeachSOAP(C.BaseOpt): https://arxiv.org/abs/2409.11321 https://github.com/nikhilvyas/SOAP """ + use_precond_schedule: bool = False - def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8, - weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, # - merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False, - correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True, - mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default, - precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default, - gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, - storage_dtype: str = 'float32', stochastic_schedule: bool = False): + def __init__( + self, + params, + lr: float = 3e-3, + betas=(0.9, 0.95), + shampoo_beta: float = 0.95, + eps: float = 1e-8, + weight_decay: float = 0.01, + precondition_frequency: int = 2, + max_precond_dim: int = 2048, # + merge_dims: bool = True, + precondition_1d: bool = False, + normalize_grads: bool = False, + correct_bias: bool = True, + warmup_steps: int = 0, + split: bool = False, + foreach: bool = True, + mars: bool = False, + caution: bool = False, + mars_gamma: float = 0.0025, + palm: bool = C.use_default, + precond_scheduler=(1 / 3, 9), + beta2_scale: float = 0.8, + use_precond_schedule: bool = C.use_default, + gradient_clipping: C.str_or_fn = C.use_default, + update_clipping: C.str_or_fn = C.use_default, + storage_dtype: str = "float32", + stochastic_schedule: bool = False, + ): use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule) defaults = locals() @@ -122,24 +278,54 @@ def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: fl params = defaults.pop("params") if use_precond_schedule: - del defaults['precondition_frequency'] + del defaults["precondition_frequency"] self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler")) else: - del defaults['precond_scheduler'] + del defaults["precond_scheduler"] self.precond_schedule = 1 / defaults.pop("precondition_frequency") - super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, # - C.scale_by_soap) + super().__init__( + params, + defaults, + foreach, + gradient_clipping, + update_clipping, + palm, # + C.scale_by_soap, + ) class 
ForeachSignLaProp(C.BaseOpt): - def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, - foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, - mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default, - update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8): + def __init__( + self, + params, + lr=0.0025, + betas=(0.9, 0.99), + eps=1e-8, + weight_decay=0, + warmup_steps=0, + foreach: bool = True, + storage_dtype: str = "float32", + mars: bool = False, + caution: bool = False, + mars_gamma: float = 0.0025, + gradient_clipping: C.str_or_fn = C.use_default, + update_clipping: C.str_or_fn = C.use_default, + palm: bool = C.use_default, + beta2_scale: float = 0.8, + ): defaults = locals() defaults.pop("self") params = defaults.pop("params") - super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop, C.sign) + super().__init__( + params, + defaults, + foreach, + gradient_clipping, + update_clipping, + palm, + C.scale_by_laprop, + C.sign, + ) class ForeachSOLP(C.BaseOpt): @@ -153,16 +339,38 @@ class ForeachSOLP(C.BaseOpt): https://arxiv.org/abs/2409.11321 https://github.com/nikhilvyas/SOAP """ + use_precond_schedule: bool = False - def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8, - weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, # - merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False, - correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True, - mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default, - precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default, - gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, - storage_dtype: str = 'float32', stochastic_schedule: bool = False): + def __init__( + self, + params, + lr: float = 3e-3, + betas=(0.9, 0.95), + shampoo_beta: float = 0.95, + eps: float = 1e-8, + weight_decay: float = 0.01, + precondition_frequency: int = 2, + max_precond_dim: int = 2048, # + merge_dims: bool = True, + precondition_1d: bool = False, + normalize_grads: bool = False, + correct_bias: bool = True, + warmup_steps: int = 0, + split: bool = False, + foreach: bool = True, + mars: bool = False, + caution: bool = False, + mars_gamma: float = 0.0025, + palm: bool = C.use_default, + precond_scheduler=(1 / 3, 9), + beta2_scale: float = 0.8, + use_precond_schedule: bool = C.use_default, + gradient_clipping: C.str_or_fn = C.use_default, + update_clipping: C.str_or_fn = C.use_default, + storage_dtype: str = "float32", + stochastic_schedule: bool = False, + ): use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule) defaults = locals() @@ -170,13 +378,20 @@ def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: fl params = defaults.pop("params") if use_precond_schedule: - del defaults['precondition_frequency'] + del defaults["precondition_frequency"] self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler")) else: - del defaults['precond_scheduler'] + del defaults["precond_scheduler"] self.precond_schedule = 1 / defaults.pop("precondition_frequency") - super().__init__(params, defaults, foreach, gradient_clipping, 
update_clipping, palm, # - functools.partial(C.scale_by_soap, inner='laprop')) + super().__init__( + params, + defaults, + foreach, + gradient_clipping, + update_clipping, + palm, # + functools.partial(C.scale_by_soap, inner="laprop"), + ) class PaLMForeachSOAP(ForeachSOAP): @@ -194,27 +409,71 @@ class PrecondSchedulePaLMForeachSOAP(ForeachSOAP): class OrthoLaProp(C.BaseOpt): - def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, - foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, - mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default, - update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8): + def __init__( + self, + params, + lr=0.0025, + betas=(0.9, 0.99), + eps=1e-8, + weight_decay=0, + warmup_steps=0, + foreach: bool = True, + storage_dtype: str = "float32", + mars: bool = False, + caution: bool = False, + mars_gamma: float = 0.0025, + gradient_clipping: C.str_or_fn = C.use_default, + update_clipping: C.str_or_fn = C.use_default, + palm: bool = C.use_default, + beta2_scale: float = 0.8, + ): defaults = locals() defaults.pop("self") params = defaults.pop("params") - super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, - C.orthogonalize_grad_to_param, C.scale_by_laprop) + super().__init__( + params, + defaults, + foreach, + gradient_clipping, + update_clipping, + palm, + C.orthogonalize_grad_to_param, + C.scale_by_laprop, + ) class LaPropOrtho(C.BaseOpt): - def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0, - foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, - mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default, - update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8): + def __init__( + self, + params, + lr=0.0025, + betas=(0.9, 0.99), + eps=1e-8, + weight_decay=0, + warmup_steps=0, + foreach: bool = True, + storage_dtype: str = "float32", + mars: bool = False, + caution: bool = False, + mars_gamma: float = 0.0025, + gradient_clipping: C.str_or_fn = C.use_default, + update_clipping: C.str_or_fn = C.use_default, + palm: bool = C.use_default, + beta2_scale: float = 0.8, + ): defaults = locals() defaults.pop("self") params = defaults.pop("params") - super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop, - C.orthogonalize_grad_to_param) + super().__init__( + params, + defaults, + foreach, + gradient_clipping, + update_clipping, + palm, + C.scale_by_laprop, + C.orthogonalize_grad_to_param, + ) class ForeachPSGDKron(C.BaseOpt): @@ -228,20 +487,42 @@ class ForeachPSGDKron(C.BaseOpt): cached: bool = False exp_avg_input: bool = True - def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None, - max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None, - momentum_into_precond_update=True, warmup_steps: int = 0, merge_dims: bool = False, - split: bool = False, store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32', - stochastic_schedule: bool = False, storage_dtype: str = 'float32', mars: bool = False, - caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default, - cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default, - gradient_clipping: C.str_or_fn = 
C.use_default, update_clipping: C.str_or_fn = C.use_default, # - # expert parameters - precond_init_scale=1.0, precond_lr=0.1): + def __init__( + self, + params, + lr=0.001, + beta=0.9, + weight_decay=0.0, + preconditioner_update_probability=None, + max_size_triangular=2048, + min_ndim_triangular=2, + memory_save_mode=None, + momentum_into_precond_update=True, + warmup_steps: int = 0, + merge_dims: bool = False, + split: bool = False, + store_triu_as_line: bool = True, + foreach: bool = True, + q_dtype="float32", + stochastic_schedule: bool = False, + storage_dtype: str = "float32", + mars: bool = False, + caution: bool = False, + mars_gamma: float = 0.0025, + delayed: Optional[bool] = C.use_default, + cached: Optional[bool] = C.use_default, + exp_avg_input: Optional[bool] = C.use_default, + gradient_clipping: C.str_or_fn = C.use_default, + update_clipping: C.str_or_fn = C.use_default, # + # expert parameters + precond_init_scale=1.0, + precond_lr=0.1, + ): defaults = locals() defaults.pop("self") - self.precond_schedule = defaults.pop( - "preconditioner_update_probability") or utils.precond_update_prob_schedule() + self.precond_schedule = ( + defaults.pop("preconditioner_update_probability") or utils.precond_update_prob_schedule() + ) params = defaults.pop("params") delayed = C.default(delayed, self.delayed) @@ -249,9 +530,16 @@ def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_ exp_avg_input = C.default(exp_avg_input, self.exp_avg_input) update_clipping = C.default(update_clipping, utils.trust_region_clip_) - super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, False, # - *(C.exp_avg,) * exp_avg_input, # - functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached)) + super().__init__( + params, + defaults, + foreach, + gradient_clipping, + update_clipping, + False, # + *(C.exp_avg,) * exp_avg_input, # + functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached), + ) class ForeachPurePSGD(ForeachPSGDKron): @@ -294,10 +582,39 @@ class ForeachCachedNewtonPSGD(ForeachCachedPSGDKron): Muon = ForeachMuon SignLaProp = ForeachSignLaProp -__all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron", - "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT", - "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', 'ForeachSignLaProp' # - "ForeachAdamW", "ForeachSFAdamW", - "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron", "ForeachPurePSGD", "ForeachDelayedPSGD", - "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron", "ForeachRMSprop", "ForeachMuon", - 'ForeachCachedNewtonPSGD', 'OrthoLaProp', 'LaPropOrtho', 'SignLaProp'] +__all__ = [ + "Muon", + "RMSprop", + "PrecondSchedulePaLMSOAP", + "PSGDKron", + "PurePSGD", + "DelayedPSGD", + "CachedPSGDKron", + "CachedDelayedPSGDKron", + "PalmForEachSoap", + "PaLMSOAP", + "PaLMSFAdamW", + "LaProp", + "ADOPT", + "PrecondScheduleSOAP", + "PrecondSchedulePaLMSOAP", + "RMSprop", + "MuonLaProp", + "ForeachSignLaProp", # + "ForeachAdamW", + "ForeachSFAdamW", + "ForeachLaProp", + "ForeachADOPT", + "ForeachSOAP", + "ForeachPSGDKron", + "ForeachPurePSGD", + "ForeachDelayedPSGD", + "ForeachCachedPSGDKron", + "ForeachCachedDelayedPSGDKron", + "ForeachRMSprop", + "ForeachMuon", + "ForeachCachedNewtonPSGD", + "OrthoLaProp", + "LaPropOrtho", + "SignLaProp", +] diff --git a/heavyball/chainable.py b/heavyball/chainable.py index 
d16dcd7..892ec7b 100644 --- a/heavyball/chainable.py +++ b/heavyball/chainable.py @@ -1,6 +1,6 @@ import functools import random -from typing import Optional, Union, Literal, List +from typing import List, Literal, Optional, Union import torch @@ -42,7 +42,7 @@ def __call__(self, state, group, update, grad, param, *args, **kwargs): raise NotImplementedError def get_fn(self): - if hasattr(self.fn, 'get_fn'): + if hasattr(self.fn, "get_fn"): return self.fn.get_fn() return self.fn @@ -55,7 +55,7 @@ def _zero_guard(state, key, ref, dtype): def _storage_dtype(group): - dtype = group.get('storage_dtype', "float32") + dtype = group.get("storage_dtype", "float32") return getattr(torch, dtype) @@ -65,8 +65,10 @@ def __init__(self, fn, names): self.names = names def __call__(self, state, group, update, grad, param, *args, **kwargs): - vars = [[_zero_guard(state(p), self.val_name(name), p, _storage_dtype(group)) for p in param] # - for name in self.names] + vars = [ + [_zero_guard(state(p), self.val_name(name), p, _storage_dtype(group)) for p in param] # + for name in self.names + ] return self.fn(state, group, update, grad, param, *args, *vars, **kwargs) @@ -78,8 +80,10 @@ def __init__(self, fn, index, names): def __call__(self, state, group, update, grad, param, *args, **kwargs): val = [update, grad, param, *args][self.index] - vars = [[_guard_in_state(state(p), self.val_name(name), lambda: torch.clone(v)) for p, v in zip(param, val)] # - for name in self.names] + vars = [ + [_guard_in_state(state(p), self.val_name(name), lambda: torch.clone(v)) for p, v in zip(param, val)] # + for name in self.names + ] return self.fn(state, group, update, grad, param, *args, *vars, **kwargs) @@ -152,145 +156,227 @@ def exp_avg(group, update, grad, param, exp_avg): return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"])) -@zero_guard('exp_avg') +@zero_guard("exp_avg") @no_state def weight_decay_to_ema(group, update, grad, param, exp_avg): - utils.weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']), - group['weight_decay_to_ema'] * group['lr']) + utils.weight_decay_to_ema_( + exp_avg, + update, + utils.beta_debias(group["ema_beta"], group["step"]), + group["weight_decay_to_ema"] * group["lr"], + ) return update -@zero_guard('exp_avg') +@zero_guard("exp_avg") @no_state def l1_weight_decay_to_ema(group, update, grad, param, exp_avg): - utils.l1_weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']), - group['weight_decay_to_ema'] * group['lr']) + utils.l1_weight_decay_to_ema_( + exp_avg, + update, + utils.beta_debias(group["ema_beta"], group["step"]), + group["weight_decay_to_ema"] * group["lr"], + ) return update @zero_guard("exp_avg_sq") @no_state def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq): - return utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]), - group['eps']) + return utils.scale_by_exp_avg_sq_( + exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]), group["eps"] + ) @zero_guard("exp_avg", "exp_avg_sq") @no_state def scale_by_adam(group, update, grad, param, exp_avg, exp_avg_sq): - return utils.adam_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step'], # - group['eps']) + return utils.adam_( + exp_avg, + exp_avg_sq, + update, + utils.get_beta1(group), + utils.get_beta2(group), + group["step"], # + group["eps"], + ) @zero_guard("exp_avg", "exp_avg_sq") 
@no_state def update_by_adam(group, update, grad, param, exp_avg, exp_avg_sq): - utils.fused_adam_(param, exp_avg, exp_avg_sq, update, grad, utils.get_beta1(group), utils.get_beta2(group), - group['step'], group['lr'], group['eps'], group['weight_decay'], group['caution']) + utils.fused_adam_( + param, + exp_avg, + exp_avg_sq, + update, + grad, + utils.get_beta1(group), + utils.get_beta2(group), + group["step"], + group["lr"], + group["eps"], + group["weight_decay"], + group["caution"], + ) raise SkipUpdate @zero_guard("exp_avg", "exp_avg_sq") @no_state def scale_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq): - return utils.laprop_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step']) + return utils.laprop_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group["step"]) @zero_guard("exp_avg", "exp_avg_sq") @no_state def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq): - utils.fused_laprop_(param, exp_avg, exp_avg_sq, update, grad, utils.get_beta1(group), utils.get_beta2(group), - group['step'], group['lr'], group['weight_decay'], group['caution']) + utils.fused_laprop_( + param, + exp_avg, + exp_avg_sq, + update, + grad, + utils.get_beta1(group), + utils.get_beta2(group), + group["step"], + group["lr"], + group["weight_decay"], + group["caution"], + ) raise SkipUpdate @no_state def orthogonalize_grad_to_param(group, update, grad, param): - return utils.orthogonalize_grad_to_param(param, update, group['eps']) + return utils.orthogonalize_grad_to_param(param, update, group["eps"]) @copy_guard(2, "z") @no_state def update_by_schedule_free(group, update, grad, param, z): - group['weight_sum'] = utils.schedule_free_(group['lr'], group['weight_lr_power'], group.get('weight_sum', 0), - utils.get_beta1(group), param, z, update, grad, group['caution'], - group['r'], group['step'], group['weight_decay']) + group["weight_sum"] = utils.schedule_free_( + group["lr"], + group["weight_lr_power"], + group.get("weight_sum", 0), + utils.get_beta1(group), + param, + z, + update, + grad, + group["caution"], + group["r"], + group["step"], + group["weight_decay"], + ) raise SkipUpdate @zero_guard("exp_avg", "exp_avg_sq") @no_state def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq): - if group['step'] == 1: - utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps']) + if group["step"] == 1: + utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"]) raise SkipUpdate - if group['step'] == 2: + if group["step"] == 2: update = utils.promote(update) easq = utils.promote(exp_avg_sq) - [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)] - utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']), - group['eps']) + [utils.set_(ea, u / easq_.sqrt().clamp_(min=group["eps"])) for ea, u, easq_ in zip(exp_avg, update, easq)] + utils.scale_by_exp_avg_sq_( + exp_avg_sq, + update, + utils.beta_debias(utils.get_beta2(group), group["step"]), + group["eps"], + ) raise SkipUpdate - utils.fused_adopt_(param, update, grad, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group), - group['step'] - 2, group['lr'], group['eps'], group['weight_decay'], group['caution']) + utils.fused_adopt_( + param, + update, + grad, + exp_avg_sq, + exp_avg, + utils.get_beta1(group), + utils.get_beta2(group), + group["step"] - 2, + group["lr"], + group["eps"], + group["weight_decay"], + group["caution"], + ) raise 
SkipUpdate @zero_guard("exp_avg", "exp_avg_sq") @no_state def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq): - if group['step'] == 1: - utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps']) + if group["step"] == 1: + utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"]) raise SkipUpdate - if group['step'] == 2: + if group["step"] == 2: update = utils.promote(update) easq = utils.promote(exp_avg_sq) - [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)] - utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']), - group['eps']) + [utils.set_(ea, u / easq_.sqrt().clamp_(min=group["eps"])) for ea, u, easq_ in zip(exp_avg, update, easq)] + utils.scale_by_exp_avg_sq_( + exp_avg_sq, + update, + utils.beta_debias(utils.get_beta2(group), group["step"]), + group["eps"], + ) raise SkipUpdate - return utils.adopt(update, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 2) + return utils.adopt( + update, + exp_avg_sq, + exp_avg, + utils.get_beta1(group), + utils.get_beta2(group), + group["step"] - 2, + ) -def _init_soap(state, group, update, grad, param, inner: str = ''): - utils.init_preconditioner(grad, state, group['max_precond_dim'], group['precondition_1d']) +def _init_soap(state, group, update, grad, param, inner: str = ""): + utils.init_preconditioner(grad, state, group["max_precond_dim"], group["precondition_1d"]) def _init_psgd(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None): - Q, state["exprs"] = utils.init_Q_exprs(grad, group['precond_init_scale'], group['max_size_triangular'], - group['min_ndim_triangular'], group['memory_save_mode'], - dtype=getattr(torch, group['q_dtype'])) - state["Q"] = utils.triu_to_line(Q) if group['store_triu_as_line'] else Q + Q, state["exprs"] = utils.init_Q_exprs( + grad, + group["precond_init_scale"], + group["max_size_triangular"], + group["min_ndim_triangular"], + group["memory_save_mode"], + dtype=getattr(torch, group["q_dtype"]), + ) + state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q if not cached: return - state['Q_cache'] = [torch.empty_like(q) for q in Q] + state["Q_cache"] = [torch.empty_like(q) for q in Q] - expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(utils.einsum_base, Q)] - expr = ','.join(expr) - grad_expr = ''.join(c for c, _ in zip(utils.einsum_base, grad.shape)) - out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr) - expr = f'{expr},{grad_expr}->{out_expr}' + expr = [f"{c.upper()}{c}" if q_.ndim == 2 else c for c, q_ in zip(utils.einsum_base, Q)] + expr = ",".join(expr) + grad_expr = "".join(c for c, _ in zip(utils.einsum_base, grad.shape)) + out_expr = "".join(c.upper() if c.upper() in expr else c for c in grad_expr) + expr = f"{expr},{grad_expr}->{out_expr}" - state['cache_expr'] = expr + state["cache_expr"] = expr -def precond_schedule(group, prob: Union[callable, float, None] = None, name: str = 'cumulative_prob'): - step = group['step'] - if 'precondition_frequency' in group: - return step > 0 and step % group['precondition_frequency'] == 0 +def precond_schedule(group, prob: Union[callable, float, None] = None, name: str = "cumulative_prob"): + step = group["step"] + if "precondition_frequency" in group: + return step > 0 and step % group["precondition_frequency"] == 0 if isinstance(step, torch.Tensor): utils.warn_once("Preconditioner schedule is not supported 
with torch.Tensor step.") rng = random.Random(0x172381) else: rng = random.Random(0x172381 ^ step) - if 'precond_scheduler' in group: - return utils.precond_schedule(step, group['precond_scheduler'], rng) + if "precond_scheduler" in group: + return utils.precond_schedule(step, group["precond_scheduler"], rng) if prob is not None: return utils.psgd_should_update(group, prob, rng, name=name) raise ValueError("No preconditioner update schedule specified.") @@ -313,14 +399,14 @@ def nesterov_momentum(group, updates, grads, params, momentum): return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group)) -@zero_guard('momentum') +@zero_guard("momentum") @no_state def nesterov_ema(group, updates, grads, params, momentum): # equivalent to Grokfast return utils.nesterov_ema(momentum, updates, utils.get_beta1(group)) def _store_std(state, group, update, grad, param): - state['init_std'] = torch.std(grad, dim=0) + state["init_std"] = torch.std(grad, dim=0) @general_guard("init_std", init_fn=_store_std) @@ -338,25 +424,39 @@ def heavyball_momentum(group, updates, grads, params, momentum): return utils.heavyball_momentum(momentum, updates, utils.get_beta1(group)) -_optim_fns = {'adam': utils.adam_, 'laprop': utils.laprop_} +_optim_fns = {"adam": utils.adam_, "laprop": utils.laprop_} @zero_guard("exp_avg", "exp_avg_sq") @general_guard("Q", "GG", init_fn=_init_soap) @no_state -def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = 'adam'): +def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = "adam"): update = utils.promote(update) # Promote to highest precision if needed grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)] fn = _optim_fns[inner] - precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 1, - group['eps']) + precond = fn( + exp_avg, + exp_avg_sq, + grad_projected, + utils.get_beta1(group), + utils.get_beta2(group), + group["step"] - 1, + group["eps"], + ) precond = [utils.project(p, q, True) for p, q in zip(precond, Q)] for u, q, gg, ea in zip(update, Q, GG, exp_avg): - utils.update_preconditioner(u, q, gg, ea, group['max_precond_dim'], group['precondition_1d'], - utils.beta_debias(group['shampoo_beta'], group['step']), - group['is_preconditioning']) + utils.update_preconditioner( + u, + q, + gg, + ea, + group["max_precond_dim"], + group["precondition_1d"], + utils.beta_debias(group["shampoo_beta"], group["step"]), + group["is_preconditioning"], + ) return precond @@ -364,17 +464,24 @@ def _update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p if prob is None: prob = utils.precond_update_prob_schedule() - if not group['is_preconditioning']: + if not group["is_preconditioning"]: return Q_mat - utils.psgd_update_precond(Q_mat, exprs, getattr(param, 'hessian_vector', grad), group['precond_lr'], Q, - group['store_triu_as_line'], getattr(param, 'vector', None)) - if hasattr(param, 'vector'): + utils.psgd_update_precond( + Q_mat, + exprs, + getattr(param, "hessian_vector", grad), + group["precond_lr"], + Q, + group["store_triu_as_line"], + getattr(param, "vector", None), + ) + if hasattr(param, "vector"): del param.vector del param.hessian_vector if grad.dim() > 1 and precond_schedule(group, balance_probability, f"balance_prob_{id(Q)}"): - if group['store_triu_as_line']: + if group["store_triu_as_line"]: utils.psgd_balance_Q([q_ for _, q_ in Q]) else: utils.psgd_balance_Q(Q) @@ -382,8 +489,8 @@ def 
_update_psgd_precond(cached, Q_cache, group, param, grad, Q_mat, Q, exprs, p if isinstance(prob, float): float_prob = prob else: - float_prob = prob(group.get(f'cumulative_prob_{id(Q)}_prob_step', 1)) - group['is_cached'] = should_use_cache = cached and float_prob < 0.5 + float_prob = prob(group.get(f"cumulative_prob_{id(Q)}_prob_step", 1)) + group["is_cached"] = should_use_cache = cached and float_prob < 0.5 if should_use_cache: # caching adds extra ops and is not worth the overhead when we precondition at every step return _update_psgd_cache(cached, Q_cache, Q_mat) @@ -403,51 +510,124 @@ def _update_psgd_cache(cached, Q_cache, q): def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad): - if group.get('is_cached', False): - out = utils.precond_grad_cached_(cache_expr, update, *Q_cache, caution=group['caution'], grad=grad) - out = utils.psgd_precond_grad(exprs[-1], update, *Q_mat, caution=group['caution'], grad=grad) - group['caution'] = False # we already cautioned here - shouldn't do it again + if group.get("is_cached", False): + out = utils.precond_grad_cached_(cache_expr, update, *Q_cache, caution=group["caution"], grad=grad) + out = utils.psgd_precond_grad(exprs[-1], update, *Q_mat, caution=group["caution"], grad=grad) + group["caution"] = False # we already cautioned here - shouldn't do it again return out def _fused_cached_psgd_precond_grad(group, grad, param, cache_expr, exprs, update, Q_mat, Q_cache): - if group.get('is_cached', False): - utils.fused_precond_grad_cached_(cache_expr, update, param, group['lr'], grad, group['weight_decay'], - group['caution'], *Q_cache) + if group.get("is_cached", False): + utils.fused_precond_grad_cached_( + cache_expr, + update, + param, + group["lr"], + grad, + group["weight_decay"], + group["caution"], + *Q_cache, + ) else: - utils.fused_psgd_precond_grad(exprs[-1], update, param, group['lr'], grad, group['weight_decay'], - group['caution'], *Q_mat) + utils.fused_psgd_precond_grad( + exprs[-1], + update, + param, + group["lr"], + grad, + group["weight_decay"], + group["caution"], + *Q_mat, + ) @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False) @no_state_no_foreach -def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False, - prob: Optional[callable] = None): +def scale_by_psgd( + group, + update, + grad, + param, + Q, + exprs, + Q_cache, + cache_expr: str, + cached: bool = False, + prob: Optional[callable] = None, +): update = update.to(memory_format=torch.contiguous_format) - Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q - Q_mat = _update_psgd_precond(cached, Q_cache, group, param, - update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob) + Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q + Q_mat = _update_psgd_precond( + cached, + Q_cache, + group, + param, + update if group["momentum_into_precond_update"] else grad, + Q_mat, + Q, + exprs, + prob, + ) return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad) @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False) @no_state_no_foreach -def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False, - prob: Optional[callable] = None): - Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q +def scale_by_delayed_psgd( + group, + update, + grad, + param, + Q, 
+ exprs, + Q_cache, + cache_expr: str, + cached: bool = False, + prob: Optional[callable] = None, +): + Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad) - _ = _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad, - Q_mat, Q, exprs, prob) + _ = _update_psgd_precond( + cached, + Q_cache, + group, + param, + update if group["momentum_into_precond_update"] else grad, + Q_mat, + Q, + exprs, + prob, + ) return precond @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False) @no_state_no_foreach -def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False, - prob: Optional[callable] = None): - Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q - Q_mat = _update_psgd_precond(cached, Q_cache, group, param, - update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob) +def update_by_psgd( + group, + update, + grad, + param, + Q, + exprs, + Q_cache, + cache_expr: str, + cached: bool = False, + prob: Optional[callable] = None, +): + Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q + Q_mat = _update_psgd_precond( + cached, + Q_cache, + group, + param, + update if group["momentum_into_precond_update"] else grad, + Q_mat, + Q, + exprs, + prob, + ) _fused_cached_psgd_precond_grad(group, update, param, cache_expr, exprs, update, Q_mat, Q_cache) raise SkipUpdate @@ -459,18 +639,37 @@ def sign(group, update, grad, param, graft: bool = True): @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False) @no_state_no_foreach -def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False, - prob: Optional[callable] = None): - Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q +def update_by_delayed_psgd( + group, + update, + grad, + param, + Q, + exprs, + Q_cache, + cache_expr: str, + cached: bool = False, + prob: Optional[callable] = None, +): + Q_mat = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q _fused_cached_psgd_precond_grad(group, update, param, cache_expr, exprs, update, Q_mat, Q_cache) - _ = _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad, - Q_mat, Q, exprs, prob) + _ = _update_psgd_precond( + cached, + Q_cache, + group, + param, + update if group["momentum_into_precond_update"] else grad, + Q_mat, + Q, + exprs, + prob, + ) raise SkipUpdate def palm_beta2(state, group, update, grad, param): - beta2 = 1 - group['step'] ** -group['beta2_scale'] - group['betas'] = (utils.get_beta1(group), beta2) + beta2 = 1 - group["step"] ** -group["beta2_scale"] + group["betas"] = (utils.get_beta1(group), beta2) return update @@ -499,7 +698,7 @@ def chain(state: Union[callable, dict], group, grad, param, *fns): update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad] update, skip_update = _inner_chain(state, group, update, grad, param, *fns) if not skip_update and update is not None: - utils.update_param_(param, update, group['lr'], group['weight_decay'], caution=group['caution'], grad=grad) + utils.update_param_(param, update, group["lr"], group["weight_decay"], caution=group["caution"], grad=grad) def create_branch(branches: List[List[callable]], merge_fn: callable): @@ -524,14 +723,16 @@ def 
__init__(self, params, defaults, foreach: bool, *fns): self.fns = tuple(fns) def _step(self, group): - if 'base_lr' not in group: - group['base_lr'] = group['lr'] - if 'prev_lr' in group and group['prev_lr'] != group['lr']: - utils.warn_once(f'Learning rate changed between steps. This is an experimental feature and ' - f'only supported with foreach=True (currently foreach={group["foreach"]}).') - group['base_lr'] = group['lr'] + if "base_lr" not in group: + group["base_lr"] = group["lr"] + if "prev_lr" in group and group["prev_lr"] != group["lr"]: + utils.warn_once( + f"Learning rate changed between steps. This is an experimental feature and " + f"only supported with foreach=True (currently foreach={group['foreach']})." + ) + group["base_lr"] = group["lr"] - caution = group['caution'] + caution = group["caution"] vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group))) @@ -541,26 +742,26 @@ def _step(self, group): for param in p: state = self.state_(param) - if 'step' in state: - step = state['step'] + if "step" in state: + step = state["step"] elif self.compile_step: step = utils.scalar_guard(0, param) else: step = 0 break - group['step'] = state['step'] = step = step + 1 - group['prev_lr'] = group['lr'] = group['base_lr'] * step / max(step, group['warmup_steps'] + 1) + group["step"] = state["step"] = step = step + 1 + group["prev_lr"] = group["lr"] = group["base_lr"] * step / max(step, group["warmup_steps"] + 1) - if not group['foreach'] or len(p) == 1: + if not group["foreach"] or len(p) == 1: for param, grad in zip(p, g): chain(self.state_, group, [grad], [param], *self.fns) else: chain(self.state_, group, g, p, *self.fns) - group['caution'] = caution - group['lr'] = group['prev_lr'] - group['step'] = None + group["caution"] = caution + group["lr"] = group["prev_lr"] + group["step"] = None use_default = object() @@ -571,7 +772,13 @@ def _get_clip_fn(name: str_or_fn, default_val: str_or_fn): name = default(name, default_val) if callable(name): return name - elif name not in ('l2_clip_', 'rmsnorm_clip_', 'trust_region_clip_', 'a_law_compress', 'mu_law_compress'): + elif name not in ( + "l2_clip_", + "rmsnorm_clip_", + "trust_region_clip_", + "a_law_compress", + "mu_law_compress", + ): raise ValueError(f"Clipping function {name} not found") return getattr(utils, name) @@ -581,16 +788,20 @@ def default(a, b): # not supported: update_by_schedule_free, scale_by_soap, scale_by_exp_avg_sq -_scale_to_update_map = {scale_by_delayed_psgd.get_fn(): update_by_delayed_psgd, # - scale_by_psgd.get_fn(): update_by_psgd, # - scale_by_adam.get_fn(): update_by_adam, # - scale_by_laprop.get_fn(): update_by_laprop, # - scale_by_adopt.get_fn(): update_by_adopt} -_scale_to_update_map_inv = {update_by_delayed_psgd.get_fn(): scale_by_delayed_psgd, # - update_by_psgd.get_fn(): scale_by_psgd, # - update_by_adam.get_fn(): scale_by_adam, # - update_by_laprop.get_fn(): scale_by_laprop, # - update_by_adopt.get_fn(): scale_by_adopt} +_scale_to_update_map = { + scale_by_delayed_psgd.get_fn(): update_by_delayed_psgd, # + scale_by_psgd.get_fn(): update_by_psgd, # + scale_by_adam.get_fn(): update_by_adam, # + scale_by_laprop.get_fn(): update_by_laprop, # + scale_by_adopt.get_fn(): update_by_adopt, +} +_scale_to_update_map_inv = { + update_by_delayed_psgd.get_fn(): scale_by_delayed_psgd, # + update_by_psgd.get_fn(): scale_by_psgd, # + update_by_adam.get_fn(): scale_by_adam, # + update_by_laprop.get_fn(): scale_by_laprop, # + update_by_adopt.get_fn(): scale_by_adopt, +} 
class BaseOpt(ChainOpt): @@ -622,8 +833,18 @@ class BaseOpt(ChainOpt): palm: bool = False auto_fuse: bool = True - def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn, - palm: bool = use_default, *fns, compile_step: bool = use_default, promote: bool = use_default): + def __init__( + self, + params, + defaults, + foreach: bool, + gradient_clipping: str_or_fn, + update_clipping: str_or_fn, + palm: bool = use_default, + *fns, + compile_step: bool = use_default, + promote: bool = use_default, + ): if not fns: raise ValueError("No functions provided. If that's on purpose (SGD-like), use `identity`") @@ -643,8 +864,10 @@ def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn fns = tuple(fns)[:-1] + (fn,) elif fn in _scale_to_update_map_inv: if not self.auto_fuse: - raise ValueError("update_clipping is currently not compatible with update_by_* functions. " - "Manually select scale_by_* functions or set auto_fuse=True.") + raise ValueError( + "update_clipping is currently not compatible with update_by_* functions. " + "Manually select scale_by_* functions or set auto_fuse=True." + ) fn = _scale_to_update_map_inv[fn] if args is not None: fn = functools.partial(fn, *args, **kwargs) @@ -665,27 +888,27 @@ def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn class ScheduleFree(BaseOpt): def eval(self): for group in self.param_groups: - group['train_mode'] = train_mode = not group.get('train_mode') + group["train_mode"] = train_mode = not group.get("train_mode") beta1 = utils.get_beta1(group) if beta1 > 0 and not train_mode: - for p in group['params']: + for p in group["params"]: state = self.state_(p) - if 'z' in state: + if "z" in state: # Set p.data to x - z = utils.promote(state['z']) + z = utils.promote(state["z"]) p32 = utils.promote(p.data) p32.lerp_(end=z, weight=1 - 1 / beta1) utils.copy_stochastic_(p.data, p32) def train(self): for group in self.param_groups: - group['train_mode'] = train_mode = not group.get('train_mode') + group["train_mode"] = train_mode = not group.get("train_mode") beta1 = utils.get_beta1(group) if beta1 > 0 and train_mode: - for p in group['params']: + for p in group["params"]: state = self.state_(p) - if 'z' in state: - z = utils.promote(state['z']) + if "z" in state: + z = utils.promote(state["z"]) p32 = utils.promote(p.data) p32.lerp_(end=z, weight=1 - beta1) utils.copy_stochastic_(p.data, p32) diff --git a/heavyball/utils.py b/heavyball/utils.py index ab3eb00..3c76bc9 100644 --- a/heavyball/utils.py +++ b/heavyball/utils.py @@ -4,7 +4,7 @@ import random import string import warnings -from typing import List, Optional, Tuple, Callable, Union +from typing import Callable, List, Optional, Tuple, Union from unittest.mock import patch import numpy as np @@ -15,27 +15,20 @@ from torch.backends import cudnn, opt_einsum from torch.utils._pytree import tree_map -config.cache_size_limit = 2 ** 16 - -np.warnings = warnings +config.cache_size_limit = 2**16 compile_mode = "max-autotune-no-cudagraphs" dynamic = False compile_mode_recommended_to_none = None -zeroth_power_mode = 'qr' # 'qr' is baseline, 'newtonschulz' converges better and faster +zeroth_power_mode = "qr" # 'qr' is baseline, 'newtonschulz' converges better and faster tiny_bf16 = torch.finfo(torch.bfloat16).tiny -base_args = {'betas': (0.9, 0.999), 'precondition_frequency': 1, 'merge_dims': False, 'warmup_steps': 100, - 'max_precond_dim': 2 ** 16, 'beta': 0.9, 'max_size_triangular': 2 ** 16, 'split': False, 
'eps': 1e-8, - 'weight_decay': 1e-4} - def decorator(func): compiled = None @functools.wraps(func) def _fn(*args, **kwargs): - disable = compile_mode_recommended_to_none is None if is_compiling() or compile_mode_recommended_to_none is None: return func(*args, **kwargs) nonlocal compiled @@ -65,8 +58,17 @@ def _fn(*args, **kwargs): @decorator_knowngood -def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, update: List[Tensor], lr: Tensor, - beta1: Tensor, decay: float, grad: List[Tensor], caution): +def _compilable_schedule_free_( + p: List[Tensor], + z: List[Tensor], + ckp1: Tensor, + update: List[Tensor], + lr: Tensor, + beta1: Tensor, + decay: float, + grad: List[Tensor], + caution, +): for op, oz, u_, g_ in zip(p, z, update, grad): u_ = u_.view_as(op) p_, z_, u_ = map(promote, (op, oz, u_)) @@ -81,9 +83,20 @@ def _compilable_schedule_free_(p: List[Tensor], z: List[Tensor], ckp1: Tensor, u copy_stochastic_(oz, z_) -def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1: float, parameters: List[Tensor], - z: List[Tensor], update: List[Tensor], grad: List[Tensor], caution: bool = False, r: float = 0.0, - step: int = 0, decay: float = 0.0): +def schedule_free_( + lr: float, + weight_lr_power: float, + weight_sum: float, + beta1: float, + parameters: List[Tensor], + z: List[Tensor], + update: List[Tensor], + grad: List[Tensor], + caution: bool = False, + r: float = 0.0, + step: int = 0, + decay: float = 0.0, +): weight = abs(lr) ** weight_lr_power * max(step, 1) ** r weight_sum = weight_sum + weight @@ -156,7 +169,7 @@ def dim_merger(grad, max_precond_dim, split: bool = False): def beta_debias(beta, step): - return 1 - (1 - beta) / (1 - beta ** step) + return 1 - (1 - beta) / (1 - beta**step) def eps_sqrt(item, eps): @@ -164,8 +177,9 @@ def eps_sqrt(item, eps): @decorator_knowngood -def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, - out: List[Optional[Tensor]]): +def _compilable_exp_avg_sq_( + state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor, out: List[Optional[Tensor]] +): g32 = promote(grad) s32 = _lerp(state, torch._foreach_mul(g32, g32), beta2) @@ -226,8 +240,9 @@ def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val copy_stochastic_list_(gradients, g32) -def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float, - minimum: float = 1e-3, eps: float = 1e-8): +def adaptive_gradient_clipping_( + parameters: List[Tensor], gradients: List[Tensor], clip_val: float, minimum: float = 1e-3, eps: float = 1e-8 +): if clip_val <= 0: return gradients parameters, gradients = list_guard(parameters, gradients) @@ -253,23 +268,24 @@ def clean(): def _ignore_warning(msg): - warnings.filterwarnings('ignore', f'.*{msg}.*') + warnings.filterwarnings("ignore", f".*{msg}.*") -def set_torch(benchmark_limit: int = 32): +def set_torch(benchmark_limit: int = 32, einsum_strategy: str = "auto"): cudnn.benchmark = True cudnn.deterministic = False cudnn.benchmark_limit = benchmark_limit torch.use_deterministic_algorithms(False) torch.set_float32_matmul_precision("high") # highest: FP32, high: TF32, medium: bf16 - opt_einsum.enabled = False - opt_einsum.strategy = "auto" + opt_einsum.set_flags(True, einsum_strategy) # Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled. 
_ignore_warning( - 'Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak') + "Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak" + ) _ignore_warning( - 'We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak') + "We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak" + ) @decorator @@ -277,7 +293,7 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7): assert len(G.shape) == 2 a, b, c = (3.4445, -4.7750, 2.0315) X = G.to(torch.bfloat16 if G.dtype != torch.float64 else G.dtype) # Preserve float64 if present - X /= (X.norm() + eps) # ensure top singular value <= 1 + X /= X.norm() + eps # ensure top singular value <= 1 if G.size(0) > G.size(1): X = X.T for _ in range(steps): @@ -290,10 +306,10 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7): def ortho(x): - if zeroth_power_mode == 'qr': + if zeroth_power_mode == "qr": return torch.linalg.qr(x).Q - if zeroth_power_mode == 'svd': - u, s, v = torch.linalg.svd(x) + if zeroth_power_mode == "svd": + u, _s, v = torch.linalg.svd(x) return u @ v.T raise NotImplementedError(f"Unknown zeroth_power_mode: {zeroth_power_mode}") @@ -351,12 +367,12 @@ def _compilable_grafting(magnitude, direction): @decorator_knowngood def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str): - if mode == 'newtonschulz' or x.shape[0] != x.shape[1]: + if mode == "newtonschulz" or x.shape[0] != x.shape[1]: y = zeropower_via_newtonschulz5(x, 5) - elif mode == 'qr': + elif mode == "qr": y = torch.linalg.qr(promote(x)).Q - elif mode == 'svd': - u, s, v = torch.linalg.svd(promote(x)) + elif mode == "svd": + u, _s, v = torch.linalg.svd(promote(x)) y = u @ v.T else: raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}") @@ -403,7 +419,7 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona q_old = promote(q.data) tmp = m @ q_old - est_eig = torch.einsum('ij,ij->j', q_old, tmp) + est_eig = torch.einsum("ij,ij->j", q_old, tmp) sort_idx = torch.argsort(est_eig, descending=True) tmp[:, sort_idx], _ = torch.linalg.qr(tmp[:, sort_idx]) @@ -415,19 +431,20 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optiona return assert exp_avg.ndim < 13, "exp_avg.ndim must be less than 13" - in_str = einsum_base[:exp_avg.dim()] - out_str = einsum_base[exp_avg.dim():2 * exp_avg.dim()] + in_str = einsum_base[: exp_avg.dim()] + out_str = einsum_base[exp_avg.dim() : 2 * exp_avg.dim()] from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if m is not None]) if not from_shampoo: return - to_shampoo = ','.join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if m is not None]) - out_str = ''.join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)]) + to_shampoo = ",".join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if m is not None]) + out_str = "".join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)]) - subscripts = f'{in_str},{from_shampoo},{to_shampoo}->{out_str}' - exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q if q is not None], 
- *[q for q in new_qs if q is not None]) + subscripts = f"{in_str},{from_shampoo},{to_shampoo}->{out_str}" + exp_avg_new = torch.einsum( + subscripts, exp_avg, *[q for q in Q if q is not None], *[q for q in new_qs if q is not None] + ) copy_stochastic_(exp_avg, exp_avg_new) for q, q_new in zip(Q, new_qs): @@ -453,11 +470,11 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30): while True: try: eye = torch.eye(m.shape[0], device=m.device, dtype=m.dtype) - eigval, eigvec = torch.linalg.eigh(m + eps * eye) + _eigval, eigvec = torch.linalg.eigh(m + eps * eye) eigvec = eigvec.to(device=device, dtype=dtype) break except torch.OutOfMemoryError: - if m.device.type == 'cpu': + if m.device.type == "cpu": raise else: m = m.cpu() @@ -489,21 +506,21 @@ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[floa def get_beta1(group): beta = None - if 'beta' in group: - beta = group['beta'] - if beta is None and 'betas' in group: - beta = group['betas'][0] + if "beta" in group: + beta = group["beta"] + if beta is None and "betas" in group: + beta = group["betas"][0] if beta is None: raise ValueError("Beta not found in group.") return beta def get_beta2(group): - if 'palm' in group and group['palm'] is True and 'beta2_scale' in group: + if "palm" in group and group["palm"] is True and "beta2_scale" in group: step = max(group.get("step", 1), 1) - return 1 - step ** -group['beta2_scale'] - if 'betas' in group: - return group['betas'][1] + return 1 - step ** -group["beta2_scale"] + if "betas" in group: + return group["betas"][1] raise ValueError("Beta2 not found in group.") @@ -580,9 +597,9 @@ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta): if not isinstance(m, Tensor): continue b = einsum_base[idx] - g0 = einsum_base[:grad.dim()] + g0 = einsum_base[: grad.dim()] g1 = g0.replace(b, b.upper()) - outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad) + outer_product = torch.einsum(f"{g0},{g1}->{b + b.upper()}", grad, grad) stochastic_lerp_(m, outer_product, 1 - beta) @@ -623,19 +640,19 @@ def init_preconditioner(grad, state, max_precond_dim, precondition_1d): """ Initializes the preconditioner matrices (L and R in the paper). """ - state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper). + state["GG"] = [] # Will hold all the preconditioner matrices (L and R in the paper). 
if grad.numel() > 1 and (grad.ndim > 1 or precondition_1d): for sh in grad.shape: if sh > max_precond_dim or sh == 1: # via @francois-rozet: https://github.com/HomebrewML/HeavyBall/commit/8b86be04967e2d095136d5603724f488f2d46592#diff-a430393dd0a6ee393944a9ed16416115c175de2414cf4a96e647197697f265e9R621 - state['GG'].append(None) + state["GG"].append(None) else: - state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype)) + state["GG"].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype)) else: - state['GG'].append(None) + state["GG"].append(None) - update_ggt(grad, state['GG'], max_precond_dim, precondition_1d, 0) - state['Q'] = get_orthogonal_matrix(state['GG']) + update_ggt(grad, state["GG"], max_precond_dim, precondition_1d, 0) + state["Q"] = get_orthogonal_matrix(state["GG"]) @decorator @@ -646,11 +663,11 @@ def project(grad, Q, back: bool): :param back: whether to project to Shampoo eigenbases or back to original space :return: """ - param = einsum_base[:grad.dim()] - preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if m is not None]) + param = einsum_base[: grad.dim()] + preconditioners = ",".join([(g + g.upper())[:: -1 if back else 1] for m, g in zip(Q, param) if m is not None]) if preconditioners: - out = ''.join([c.upper() if c.upper() in preconditioners else c for c in param]) - out = torch.einsum(f'{param},{preconditioners}->{out}', promote(grad), *[q for q in Q if q is not None]) + out = "".join([c.upper() if c.upper() in preconditioners else c for c in param]) + out = torch.einsum(f"{param},{preconditioners}->{out}", promote(grad), *[q for q in Q if q is not None]) grad = out.to(grad.dtype) return grad @@ -667,12 +684,12 @@ def modify_closure(closure): """ def patched_backward(self, *args, **kwargs): - kwargs['create_graph'] = True + kwargs["create_graph"] = True return original_backward(self, *args, **kwargs) original_backward = torch.Tensor.backward - with patch.object(torch.Tensor, 'backward', patched_backward): + with patch.object(torch.Tensor, "backward", patched_backward): return closure() @@ -683,6 +700,7 @@ class StatefulOptimizer(torch.optim.Optimizer): The previous (heavyball<=1.5.3) default was `True`, which is incompatible with some benchmarks but works better with RevNet Further notice that both methods have different numerics outputs """ + ema_decay: float = 0.001 compile_step: bool = False hessian_approx: bool = False @@ -691,10 +709,10 @@ class StatefulOptimizer(torch.optim.Optimizer): finite_differences: bool = False def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False): - super().__init__(params, {**defaults, 'foreach': foreach}) + super().__init__(params, {**defaults, "foreach": foreach}) self.use_ema = use_ema self.mapping = {} - self._inner_group = {'stochastic_schedule': self.stochastic_schedule} + self._inner_group = {"stochastic_schedule": self.stochastic_schedule} self._precond_rng = random.Random(0x12312) self._is_preconditioning = None @@ -710,24 +728,25 @@ def state_(self, arg: Tensor): def mars_correct_list(self, group, p_list, g_list, mars_gamma, beta): for p, g in zip(p_list, g_list): state = self.state_(p) - if 'mars_old_grad' not in state: - state['mars_old_grad'] = torch.zeros_like(g) - old_gs = [self.state_(p)['mars_old_grad'] for p in p_list] + if "mars_old_grad" not in state: + state["mars_old_grad"] = torch.zeros_like(g) + old_gs = [self.state_(p)["mars_old_grad"] for p in p_list] mars_correction(g_list, old_gs, mars_gamma, beta) - def 
split_p_and_g_in_group(self, group: dict, skip_none: bool = True, should_promote: bool = True, - beta1: float = -1.0): + def split_p_and_g_in_group( + self, group: dict, skip_none: bool = True, should_promote: bool = True, beta1: float = -1.0 + ): for p in group["params"]: if p in self.mapping: p_views = self.mapping[p] else: self.mapping[p] = p_views = merge_group(group, p) - grad = getattr(p, 'grad', None) + grad = getattr(p, "grad", None) p.grad = None if grad is None: - grad = [getattr(pv, 'grad', None) for pv in p_views] + grad = [getattr(pv, "grad", None) for pv in p_views] else: grad = merge_group(group, grad) @@ -736,8 +755,8 @@ def split_p_and_g_in_group(self, group: dict, skip_none: bool = True, should_pro continue if should_promote: g = promote(g) - if beta1 >= 0 and group.get('mars', False): - self.mars_correct_list(group, [pv], [g], group['mars_gamma'], beta1) + if beta1 >= 0 and group.get("mars", False): + self.mars_correct_list(group, [pv], [g], group["mars_gamma"], beta1) yield pv, g def state_size(self) -> int: @@ -759,46 +778,46 @@ def _step(self, group): def ema_update(self): with torch.no_grad(): for group in self.param_groups: - active_p = [p for p in group['params']] + active_p = [p for p in group["params"]] if not active_p: return - k = group['ema_step'] = group.get('ema_step', -1) + 1 + k = group["ema_step"] = group.get("ema_step", -1) + 1 for p in active_p: - if 'param_ema' not in self.state_(p): - self.state_(p)['param_ema'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) + if "param_ema" not in self.state_(p): + self.state_(p)["param_ema"] = torch.zeros_like(p.data, memory_format=torch.preserve_format) - y, param_ema = zip(*[(p.data, self.state_(p)['param_ema']) for p in active_p]) + y, param_ema = zip(*[(p.data, self.state_(p)["param_ema"]) for p in active_p]) torch._foreach_lerp_(param_ema, y, weight=beta_debias(1 - self.ema_decay, k + 1)) def copy_emas_to_params(self): with torch.no_grad(): for group in self.param_groups: - active_p = [p for p in group['params']] + active_p = [p for p in group["params"]] if not active_p: return for p in active_p: - if 'param_ema' in self.state_(p): + if "param_ema" in self.state_(p): p_clone = p.data.clone() - set_(p.data, self.state_(p)['param_ema']) - set_(self.state_(p)['param_ema'], p_clone) + set_(p.data, self.state_(p)["param_ema"]) + set_(self.state_(p)["param_ema"], p_clone) def copy_params_to_emas(self): with torch.no_grad(): for group in self.param_groups: - active_p = [p for p in group['params']] + active_p = [p for p in group["params"]] if not active_p: return for p in active_p: - if 'param_ema' in self.state_(p): - ema_clone = self.state_(p)['param_ema'].data.clone() - set_(self.state_(p)['param_ema'], p.data) + if "param_ema" in self.state_(p): + ema_clone = self.state_(p)["param_ema"].data.clone() + set_(self.state_(p)["param_ema"], p.data) set_(p.data, ema_clone) def _handle_closure(self, closure): @@ -844,8 +863,11 @@ def _handle_closure(self, closure): for group in self.param_groups: for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False): p.grad = g - params, grads = zip(*[x for group in self.param_groups for x in - self.split_p_and_g_in_group(group, skip_none=True, should_promote=False)]) + params, grads = zip(*[ + x + for group in self.param_groups + for x in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False) + ]) vs = [torch.randn_like(p) for p in params] with torch.enable_grad(): hvs = torch.autograd.grad(grads, params, vs) @@ -867,7 
+889,7 @@ def step(self, closure: Optional[Callable] = None): # we assume that parameters are constant and that there are no excessive recompiles with torch.no_grad(), torch._dynamo.utils.disable_cache_limit(): for group in self.param_groups: - group['is_preconditioning'] = self._is_preconditioning + group["is_preconditioning"] = self._is_preconditioning self._step(group) if self.use_ema: self.ema_update() @@ -891,8 +913,15 @@ def _lerp(state: List[Tensor], grad: List[Tensor], beta): @decorator_knowngood -def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor, - step: Tensor, eps: Tensor): +def _compilable_adam_( + exp_avg: List[Tensor], + exp_avg_sq: List[Tensor], + grad: List[Tensor], + beta1: Tensor, + beta2: Tensor, + step: Tensor, + eps: Tensor, +): beta1 = beta_debias(beta1, step) beta2 = beta_debias(beta2, step) @@ -903,8 +932,15 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis copy_stochastic_list_(grad, u32) -def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int, - eps: float = 1e-8): +def adam_( + exp_avg: List[Tensor], + exp_avg_sq: List[Tensor], + grad: List[Tensor], + beta1: float, + beta2: float, + step: int, + eps: float = 1e-8, +): exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad)) beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0]) _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps) @@ -912,9 +948,20 @@ def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], b @decorator_knowngood -def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor], - grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, decay: Tensor, lr: Tensor, - eps: Tensor, caution: bool): +def _fused_compilable_adam_( + y: List[Tensor], + exp_avg: List[Tensor], + exp_avg_sq: List[Tensor], + update: List[Tensor], + grad: List[Tensor], + beta1: Tensor, + beta2: Tensor, + step: Tensor, + decay: Tensor, + lr: Tensor, + eps: Tensor, + caution: bool, +): beta1 = beta_debias(beta1, step) beta2 = beta_debias(beta2, step) @@ -925,17 +972,35 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: _compilable_update_(y, u32, decay, lr, caution, g32) -def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor], - grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, eps: float, decay: float, - caution: bool): +def fused_adam_( + y: List[Tensor], + exp_avg: List[Tensor], + exp_avg_sq: List[Tensor], + update: List[Tensor], + grad: List[Tensor], + beta1: float, + beta2: float, + step: int, + lr: float, + eps: float, + decay: float, + caution: bool, +): y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad) beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0]) _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution) @decorator_knowngood -def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, - beta2: Tensor, step: Tensor, eps: Tensor): +def _compilable_laprop_( + exp_avg: List[Tensor], + exp_avg_sq: List[Tensor], + grad: List[Tensor], + beta1: Tensor, + beta2: Tensor, + step: Tensor, + eps: Tensor, +): beta1 = beta_debias(beta1, step) beta2 = beta_debias(beta2, step) @@ -946,8 +1011,15 @@ def 
_compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: L copy_stochastic_list_(grad, gp32) -def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int, - eps: float = 1e-8): +def laprop_( + exp_avg: List[Tensor], + exp_avg_sq: List[Tensor], + grad: List[Tensor], + beta1: float, + beta2: float, + step: int, + eps: float = 1e-8, +): exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad) beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0]) _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps) @@ -955,9 +1027,20 @@ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], @decorator_knowngood -def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor], - grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor, decay: Tensor, - caution: bool, eps: Tensor): +def _fused_compilable_laprop_( + y: List[Tensor], + exp_avg: List[Tensor], + exp_avg_sq: List[Tensor], + update: List[Tensor], + grad: List[Tensor], + beta1: Tensor, + beta2: Tensor, + step: Tensor, + lr: Tensor, + decay: Tensor, + caution: bool, + eps: Tensor, +): beta1 = beta_debias(beta1, step) beta2 = beta_debias(beta2, step) @@ -968,9 +1051,20 @@ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq _compilable_update_(y, u32, decay, lr, caution, gp32) -def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor], - grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool, - eps: float = 1e-8): +def fused_laprop_( + y: List[Tensor], + exp_avg: List[Tensor], + exp_avg_sq: List[Tensor], + update: List[Tensor], + grad: List[Tensor], + beta1: float, + beta2: float, + step: int, + lr: float, + decay: float, + caution: bool, + eps: float = 1e-8, +): exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y) beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0]) _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps) @@ -978,7 +1072,7 @@ def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tenso @decorator_knowngood def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution): - u32, g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [update, grad, exp_avg_sq, exp_avg]] + u32, g32, exp_avg_sq32 = [list(map(promote, x)) for x in [update, grad, exp_avg_sq]] _compilable_update_(y, u32, decay, lr, caution, g32) beta1 = beta_debias(beta1, step) @@ -997,7 +1091,7 @@ def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, e @decorator_knowngood def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step, eps): - g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]] + g32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg_sq]] update = [e.clone() for e in exp_avg] beta1 = beta_debias(beta1, step) @@ -1044,8 +1138,9 @@ def copy_stochastic_(target: Tensor, source: Tensor): @decorator_knowngood -def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool, - g: List[Optional[Tensor]]): +def _compilable_update_( + p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Tensor, caution: bool, g: List[Optional[Tensor]] +): for 
u_, g_, p_ in zip(u, g, p): # lr is data-dependent -> can't compile a foreach u_ = promote(u_.view_as(p_)) p32_ = promote(p_) @@ -1055,8 +1150,9 @@ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, lr: Ten copy_stochastic_(p_, p32_) -def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, caution: bool = False, - grad: List[Tensor] = None): +def update_param_( + param: List[Tensor], update: List[Tensor], lr: float, decay: float, caution: bool = False, grad: List[Tensor] = None +): param, update, grad = list_guard(param, update, grad) lr = scalar_guard(lr, param[0]) if not caution: @@ -1119,8 +1215,10 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp elif memory_save_mode == "all_diag": dim_diag = [True for _ in shape] else: - raise ValueError(f"Invalid memory_save_mode: {memory_save_mode}, must be one of " - "[None, 'one_diag', 'all_diag', 'smart_one_diag']") + raise ValueError( + f"Invalid memory_save_mode: {memory_save_mode}, must be one of " + "[None, 'one_diag', 'all_diag', 'smart_one_diag']" + ) Q = [] piece1A, piece2A, piece3A = ([], "", "") @@ -1149,7 +1247,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp piece3A = piece3A + letters[i] piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))]) piece2 = "".join([(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))]) - subscripts = (piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26]) + subscripts = piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26] exprGs.append(subscripts) a, b, c = (letters[i], letters[i + 13], letters[i + 26]) piece1P.append(a + b) @@ -1158,7 +1256,7 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp piece4P = piece4P + b exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A - exprP = (",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P) + exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P return [Q, (exprA, tuple(exprGs), exprP)] @@ -1187,7 +1285,8 @@ def psgd_calc_A_and_conjB(exprA, G, Q, V=None): conjB /= q else: conjB = torch.linalg.solve_triangular(q, conjB.reshape(-1, q.size(0)), upper=True, left=False).reshape_as( - conjB) + conjB + ) if i < order - 1: conjB = torch.transpose(conjB, i, order - 1) return A, conjB @@ -1195,12 +1294,12 @@ def psgd_calc_A_and_conjB(exprA, G, Q, V=None): def psgd_lb(A, max_abs): A /= max_abs - a0 = torch.einsum('ij,ij->j', A, A) + a0 = torch.einsum("ij,ij->j", A, A) i = torch.argmax(a0) x = torch.index_select(A, 1, i).flatten().contiguous() - x = torch.einsum('i,ij->j', x, A) + x = torch.einsum("i,ij->j", x, A) x /= x.norm() - x = torch.einsum('j,kj->k', x, A) + x = torch.einsum("j,kj->k", x, A) x = x.norm() x *= max_abs return x @@ -1217,7 +1316,7 @@ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V): term2 = promote(torch.einsum(exprG, conjB, conjB)) term1, term2 = term1 - term2, term1 + term2 term1 *= precond_lr - norm = term2.norm(float('inf')) + norm = term2.norm(float("inf")) if q.dim() < 2: term1 *= q.to(term1.dtype) / norm.clamp_(min=tiny_bf16) else: @@ -1245,7 +1344,7 @@ def l2_normalization_(x, clip_at: float = 1e-8): return _compilable_l2_clip_(x, clip_at) -def l2_clip_(x, clip_at: float = 1.): +def l2_clip_(x, clip_at: float = 1.0): x = list_guard(x) return _compilable_l2_clip_(x, clip_at) @@ -1437,12 +1536,13 @@ def warn_once(msg): _warned.add(msg) 
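
[Editorial note] The next hunk reformats `psgd_should_update`, which controls how often the PSGD preconditioner is refreshed: each call adds `prob` to a per-group running counter, and in the non-stochastic mode an update fires roughly once every `1/prob` steps, while `stochastic_schedule=True` (the flag mentioned in the `StatefulOptimizer` docstring) draws a fresh Bernoulli sample instead. The sketch below only illustrates the deterministic accumulation; the integer-crossing rule is an assumption made for illustration, since the diff hunk here shows the bookkeeping but not the return expression.

```python
# Sketch of a cumulative-probability update schedule (deterministic mode),
# assuming an update fires whenever the running sum crosses an integer.
def should_update(group: dict, prob: float, name: str = "cumulative_prob") -> bool:
    prev = group.get(name, 0.0)
    group[name] = prev + prob
    return int(group[name]) > int(prev)  # fires when the running sum crosses an integer


group = {}
fired = [step for step in range(1, 13) if should_update(group, prob=0.25)]
print(fired)  # [4, 8, 12] -> roughly one preconditioner refresh every 1/prob steps
```
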
-def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random.Random] = None, - name: str = 'cumulative_prob'): - group[f'{name}_prob_step'] = group.get(f'{name}_prob_step', 0) + 1 +def psgd_should_update( + group, prob: Union[float, callable], rng: Optional[random.Random] = None, name: str = "cumulative_prob" +): + group[f"{name}_prob_step"] = group.get(f"{name}_prob_step", 0) + 1 if not isinstance(prob, float): - prob = prob(group[f'{name}_prob_step']) - if group['stochastic_schedule']: + prob = prob(group[f"{name}_prob_step"]) + if group["stochastic_schedule"]: return rng.random() < prob cumulative_prob = group.get(name, 0) group[name] = cumulative_prob + prob @@ -1450,8 +1550,9 @@ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random @decorator_knowngood -def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None, - cast: bool = True): +def precond_grad_cached_( + expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None, cast: bool = True +): if caution: ea = _compilable_cautioning(grad, ea) md = min_dtype(list(cached_q) + [ea]) @@ -1564,15 +1665,21 @@ def _schedule(n): def merge_group(group, *tensors): - if not group.get('merge_dims', False): + if not group.get("merge_dims", False): return tensors if isinstance(tensors[0], list): return [merge_group(group, *t) for t in tensors] out = [] for t in tensors: - append_or_extend(out, dim_merger(t, group['max_size_triangular'] if 'max_size_triangular' in group else group[ - 'max_precond_dim'], group.get('split', False))) + append_or_extend( + out, + dim_merger( + t, + group["max_size_triangular"] if "max_size_triangular" in group else group["max_precond_dim"], + group.get("split", False), + ), + ) return out @@ -1598,8 +1705,9 @@ def fused_hook(parameters, optimizer, *args, **kwargs): o = optimizer(parameters, *args, **kwargs) step_fn = o.step - o.step = functools.partial(warn_once, - msg="You're trying to call `step` on a fused optimizer. This will not do anything.") + o.step = functools.partial( + warn_once, msg="You're trying to call `step` on a fused optimizer. This will not do anything." 
+ ) def _step(p: Tensor): seen_params.add(p) diff --git a/landscape.md b/landscape.md new file mode 100644 index 0000000..edb2f6b --- /dev/null +++ b/landscape.md @@ -0,0 +1,47 @@ +# Benchmark Results +Generated: 2025-02-22 10:51:19.457377 +Last updated: 2025-02-22 10:51:19.457388 + +## Summary (In Progress) + +| Optimizer | Caution | Mars | Success | Runtime | Average Attempts | +|-----------|---|---|---------|----------|------| +| ForeachSOAP | No | No | 1/1 | 19.37s | 5.0 | +| LaProp | No | No | 1/1 | 20.04s | 5.0 | +| AdamW | No | No | 1/1 | 17.83s | 5.0 | +| Muon | No | No | 0/1 | 0.00s | 0.0 | +| ForeachCachedNewtonPSGD | No | No | 1/1 | 51.90s | 18.0 | +| RMSprop | No | No | 1/1 | 31.00s | 10.0 | +| OrthoLaProp | No | No | 0/1 | 0.00s | 0.0 | +| ForeachSFAdamW | No | No | 1/1 | 23.61s | 5.0 | +| ForeachADOPT | No | No | 0/1 | 0.00s | 0.0 | +| LaPropOrtho | No | No | 0/1 | 0.00s | 0.0 | +| CachedPSGDKron | No | No | 1/1 | 22.17s | 4.0 | +| SignLaProp | No | No | 1/1 | 68.20s | 36.0 | +| ForeachSOLP | No | No | 1/1 | 18.54s | 5.0 | +| AdamW | Yes | No | 1/1 | 20.71s | 5.0 | +| AdamW | Unscaled | No | 1/1 | 23.29s | 5.0 | +| AdamW | No | Yes | 1/1 | 19.08s | 5.0 | + +## Details + +| Benchmark | Optimizer | Cautious | Mars | Success | Runtime | Loss | Attempts | +|-----------|-----------|---------|---|---|----------|------|---| +| dynamic_landscape | AdamW | No | No | ✓ | 17.83s | 8.95e-03 | 5 | +| dynamic_landscape | CachedPSGDKron | No | No | ✓ | 22.17s | 9.35e-03 | 4 | +| dynamic_landscape | ForeachADOPT | No | No | ✗ | 645.10s | 4.90e-01 | 1000 | +| dynamic_landscape | ForeachCachedNewtonPSGD | No | No | ✓ | 51.90s | 9.94e-03 | 18 | +| dynamic_landscape | ForeachSFAdamW | No | No | ✓ | 23.61s | 9.37e-03 | 5 | +| dynamic_landscape | ForeachSOAP | No | No | ✓ | 19.37s | 9.74e-03 | 5 | +| dynamic_landscape | ForeachSOLP | No | No | ✓ | 18.54s | 8.25e-03 | 5 | +| dynamic_landscape | LaProp | No | No | ✓ | 20.04s | 9.22e-03 | 5 | +| dynamic_landscape | LaPropOrtho | No | No | ✗ | 290.69s | 9.62e-01 | 453 | +| dynamic_landscape | Muon | No | No | ✗ | 287.13s | 3.79e-01 | 245 | +| dynamic_landscape | OrthoLaProp | No | No | ✗ | 238.36s | 9.30e-01 | 349 | +| dynamic_landscape | RMSprop | No | No | ✓ | 31.00s | 9.80e-03 | 10 | +| dynamic_landscape | SignLaProp | No | No | ✓ | 68.20s | 9.78e-03 | 36 | +| dynamic_landscape | AdamW | Yes | No | ✓ | 20.71s | 8.92e-03 | 5 | +| dynamic_landscape | AdamW | No | Yes | ✓ | 19.08s | 8.86e-03 | 5 | +| dynamic_landscape | AdamW | Unscaled | No | ✓ | 23.29s | 9.75e-03 | 5 | + +## Errors diff --git a/test/readme.md b/test/readme.md index b65249c..04973f5 100644 --- a/test/readme.md +++ b/test/readme.md @@ -2,4 +2,4 @@ - [ ] regression test against SOAP (due to implementation challenges) - [ ] peak memory test -- [ ] compute (runtime) test \ No newline at end of file +- [ ] compute (runtime) test diff --git a/test/test_bf16_params.py b/test/test_bf16_params.py index 39422b1..49fddc0 100644 --- a/test/test_bf16_params.py +++ b/test/test_bf16_params.py @@ -1,20 +1,21 @@ import copy import os -import heavyball -import heavyball.utils import pytest import torch -from benchmark.utils import get_optim -from heavyball.utils import clean, set_torch from torch import nn from torch._dynamo import config -import torch._inductor.config as ind_cfg -os.environ['TORCH_LOGS'] = '+recompiles' +import heavyball +import heavyball.utils +from benchmark.utils import get_optim +from heavyball.utils import clean, set_torch + +os.environ["TORCH_LOGS"] = "+recompiles" 
config.cache_size_limit = 128 + @pytest.mark.parametrize("opt", heavyball.__all__) @pytest.mark.parametrize("size,depth", [(256, 1)]) def test_foreach(opt, size, depth: int, iterations: int = 512, outer_iterations: int = 1): @@ -37,7 +38,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 512, outer_iterations: o = get_optim(opt, mdl.parameters(), lr=1e-4, update_clipping=None, warmup_steps=128) print(f"\n\n\n{dtype} {opt} {size} {depth}\n\n\n") for _ in range(iterations): - loss = mdl(torch.randn((1024, size), device='cuda', dtype=dtype)).double().abs().mean() + loss = mdl(torch.randn((1024, size), device="cuda", dtype=dtype)).double().abs().mean() loss.backward() print(mdl[0].weight.double().norm().item()) o.step() diff --git a/test/test_bf16_q.py b/test/test_bf16_q.py index 67b7102..7370fbd 100644 --- a/test/test_bf16_q.py +++ b/test/test_bf16_q.py @@ -6,7 +6,7 @@ import heavyball import heavyball.utils from benchmark.utils import get_optim -from heavyball.utils import clean, set_torch, PSGDBase +from heavyball.utils import PSGDBase, clean, set_torch config.cache_size_limit = 128 @@ -18,12 +18,12 @@ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: opt = getattr(heavyball, opt) if not issubclass(opt, PSGDBase): - raise pytest.skip('Only PSGD is supported') + raise pytest.skip("Only PSGD is supported") peaks = [] losses = [] - for q_dtype in ['float32', 'bfloat16']: + for q_dtype in ["float32", "bfloat16"]: torch.manual_seed(0x2131290) peaks.append([]) losses.append([]) @@ -33,7 +33,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: o = get_optim(opt, model.parameters(), lr=1e-3, q_dtype=q_dtype) for _ in range(iterations): - loss = model(torch.randn((1024, size), device='cuda')).square().mean() + loss = model(torch.randn((1024, size), device="cuda")).square().mean() loss.backward() o.step() o.zero_grad() diff --git a/test/test_bf16_storage.py b/test/test_bf16_storage.py index dc524f5..4fb4958 100644 --- a/test/test_bf16_storage.py +++ b/test/test_bf16_storage.py @@ -6,24 +6,23 @@ import heavyball import heavyball.utils from benchmark.utils import get_optim -from heavyball.utils import clean, set_torch, PSGDBase +from heavyball.utils import PSGDBase, clean, set_torch config.cache_size_limit = 128 - @pytest.mark.parametrize("opt", heavyball.__all__) @pytest.mark.parametrize("size,depth", [(256, 2)]) def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: int = 3): set_torch() - if 'soap' in opt.lower(): - raise pytest.skip('soap is not supported') + if "soap" in opt.lower(): + raise pytest.skip("soap is not supported") opt = getattr(heavyball, opt) if PSGDBase in opt.__mro__: - raise pytest.skip('PSGD is not supported') + raise pytest.skip("PSGD is not supported") peaks = [] losses = [] @@ -40,7 +39,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: o = get_optim(opt, model.parameters(), lr=1e-3, storage_dtype=dtype_name) for _ in range(iterations): - loss = model(torch.randn((1024, size), device='cuda', dtype=dtype)).square().mean() + loss = model(torch.randn((1024, size), device="cuda", dtype=dtype)).square().mean() loss.backward() o.step() o.zero_grad() diff --git a/test/test_caution.py b/test/test_caution.py index b34736e..77ba67b 100644 --- a/test/test_caution.py +++ b/test/test_caution.py @@ -1,16 +1,17 @@ import os -os.environ['TORCH_LOGS'] = '+recompiles' +os.environ["TORCH_LOGS"] = "+recompiles" -import heavyball -import heavyball.utils import 
pytest import torch -from benchmark.utils import get_optim -from heavyball.utils import clean, set_torch from torch import nn from torch._dynamo import config +import heavyball +import heavyball.utils +from benchmark.utils import get_optim +from heavyball.utils import clean, set_torch + config.cache_size_limit = 128 @@ -32,7 +33,7 @@ def test_caution(opt, size, depth: int, iterations: int = 16, outer_iterations: o = get_optim(opt, model.parameters(), lr=1e-5, caution=caution) for _ in range(iterations): - loss = model(torch.randn((1024, size), device='cuda')).square().mean() + loss = model(torch.randn((1024, size), device="cuda")).square().mean() loss.backward() o.step() o.zero_grad() diff --git a/test/test_channels_last.py b/test/test_channels_last.py index 24f24e9..d7e5c7d 100644 --- a/test/test_channels_last.py +++ b/test/test_channels_last.py @@ -2,17 +2,18 @@ os.environ["TORCH_LOGS"] = "+recompiles" -import heavyball -import heavyball.utils import pytest import torch -from benchmark.utils import get_optim -from heavyball.utils import clean, set_torch from torch import nn from torch._dynamo import config -heavyball.utils.zeroth_power_mode = 'newtonschulz' -heavyball.utils.compile_mode = 'default' +import heavyball +import heavyball.utils +from benchmark.utils import get_optim +from heavyball.utils import clean, set_torch + +heavyball.utils.zeroth_power_mode = "newtonschulz" +heavyball.utils.compile_mode = "default" config.cache_size_limit = 128 @@ -38,7 +39,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 1024, outer_iterations o = get_optim(opt, model.parameters(), lr=1e-3, weight_decay=1e-4, warmup_steps=16) for _ in range(iterations): - loss = model(torch.randn((1024, size, 4, 4), device='cuda')).square().mean() + loss = model(torch.randn((1024, size, 4, 4), device="cuda")).square().mean() loss.backward() o.step() o.zero_grad() diff --git a/test/test_closure.py b/test/test_closure.py index 8ab1446..57a52e1 100644 --- a/test/test_closure.py +++ b/test/test_closure.py @@ -1,13 +1,14 @@ from typing import List -import heavyball -import heavyball.utils import pytest import torch -from benchmark.utils import get_optim -from heavyball.utils import set_torch, clean from torch import nn +import heavyball +import heavyball.utils +from benchmark.utils import get_optim +from heavyball.utils import clean, set_torch + class Param(nn.Module): def __init__(self, size): @@ -19,7 +20,12 @@ def forward(self, inp): @pytest.mark.parametrize("opt", heavyball.__all__) -@pytest.mark.parametrize("size", [(4, 4, 4, 4), ]) +@pytest.mark.parametrize( + "size", + [ + (4, 4, 4, 4), + ], +) def test_closure(opt, size: List[int], depth: int = 2, iterations: int = 5, outer_iterations: int = 3): clean() set_torch() @@ -32,7 +38,7 @@ def test_closure(opt, size: List[int], depth: int = 2, iterations: int = 5, oute o = get_optim(opt, model.parameters(), lr=1e-3) def _closure(): - loss = model(torch.randn((1, size[0]), device='cuda')).sum() + loss = model(torch.randn((1, size[0]), device="cuda")).sum() loss.backward() return loss @@ -40,5 +46,3 @@ def _closure(): o.step(_closure) o.zero_grad() print(o.state_size()) - - del model, o diff --git a/test/test_ema.py b/test/test_ema.py index 584115d..dab9425 100644 --- a/test/test_ema.py +++ b/test/test_ema.py @@ -38,7 +38,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: o = get_optim(opt, model.parameters(), lr=1e-3) for _ in range(iterations): - loss = model(torch.randn((1024, size), device='cuda')).square().mean() + loss 
= model(torch.randn((1024, size), device="cuda")).square().mean() loss.backward() o.step() o.zero_grad() @@ -50,7 +50,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: if do_ema: o.copy_emas_to_params() - loss = model(torch.randn((1024, size), device='cuda')).square().mean() + loss = model(torch.randn((1024, size), device="cuda")).square().mean() losses[-1].append(loss.detach()) del model, o diff --git a/test/test_foreach.py b/test/test_foreach.py index cbeb59e..779176f 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -1,11 +1,12 @@ -import heavyball -import heavyball.utils import pytest import torch -from benchmark.utils import get_optim -from heavyball.utils import clean, set_torch, PSGDBase from torch import nn +import heavyball +import heavyball.utils +from benchmark.utils import get_optim +from heavyball.utils import PSGDBase, clean, set_torch + def get_memory(): clean() @@ -45,7 +46,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 4096, outer_iterations clean() for _ in range(iterations): - loss = model(torch.randn((1, size), device='cuda')).sum() + loss = model(torch.randn((1, size), device="cuda")).sum() loss.backward() o.step() o.zero_grad() @@ -54,7 +55,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 4096, outer_iterations del model, o clean() - peak = torch.cuda.memory_stats()['allocated_bytes.all.peak'] + peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"] if i > 0: peaks[-1].append(peak) diff --git a/test/test_hook.py b/test/test_hook.py index 16c23fb..d8319bb 100644 --- a/test/test_hook.py +++ b/test/test_hook.py @@ -2,16 +2,17 @@ os.environ["TORCH_LOGS"] = "+recompiles" -import heavyball -import heavyball.utils import pytest import torch -from benchmark.utils import get_optim -from heavyball.utils import clean, set_torch, hook_optimizer_into_model from torch import nn from torch._dynamo import config -heavyball.utils.compile_mode = 'default' +import heavyball +import heavyball.utils +from benchmark.utils import get_optim +from heavyball.utils import clean, hook_optimizer_into_model, set_torch + +heavyball.utils.compile_mode = "default" config.cache_size_limit = 128 @@ -37,7 +38,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: else: o = get_optim(opt, model.parameters(), lr=1e-3, weight_decay=1e-4, warmup_steps=16) for _ in range(iterations): - loss = model(torch.randn((1024, size), device='cuda')).square().mean() + loss = model(torch.randn((1024, size), device="cuda")).square().mean() loss.backward() if not use_hook: o.step() diff --git a/test/test_mars.py b/test/test_mars.py index 5b173dc..be8fef3 100644 --- a/test/test_mars.py +++ b/test/test_mars.py @@ -1,12 +1,13 @@ -import heavyball -import heavyball.utils import pytest import torch -from benchmark.utils import get_optim -from heavyball.utils import clean, set_torch, ScheduleFree from torch import nn from torch._dynamo import config +import heavyball +import heavyball.utils +from benchmark.utils import get_optim +from heavyball.utils import ScheduleFree, clean, set_torch + config.cache_size_limit = 128 @@ -31,7 +32,7 @@ def test_mars(opt, size, depth: int, iterations: int = 16384, outer_iterations: o = get_optim(opt, model.parameters(), lr=1e-5, mars=mars) for _ in range(iterations): - loss = model(torch.randn((1024, size), device='cuda')).square().mean() + loss = model(torch.randn((1024, size), device="cuda")).square().mean() loss.backward() o.step() o.zero_grad() diff --git 
a/test/test_memory.py b/test/test_memory.py index 72433f9..c3ff488 100644 --- a/test/test_memory.py +++ b/test/test_memory.py @@ -16,23 +16,27 @@ def get_memory(): return torch.cuda.memory_allocated() -expected_memory = {'adamw': {'after': 4, 'peak': 5.1}, 'soap': {'after': 7, 'peak': 14}, - 'psgd': {'after': 4, 'peak': 11.5}, 'padam': {'after': 5, 'peak': 11.4}} +expected_memory = { + "adamw": {"after": 4, "peak": 5.1}, + "soap": {"after": 7, "peak": 14}, + "psgd": {"after": 4, "peak": 11.5}, + "padam": {"after": 5, "peak": 11.4}, +} -@pytest.mark.parametrize("opt", ['ForeachPSGDKron']) -@pytest.mark.parametrize("method", ['qr', 'newtonschulz2', 'svd', 'eigh']) +@pytest.mark.parametrize("opt", ["ForeachPSGDKron"]) +@pytest.mark.parametrize("method", ["qr", "newtonschulz2", "svd", "eigh"]) @pytest.mark.parametrize("size,depth", [(8192, 1), (2048, 16)]) def test_memory(opt, method, size, depth: int, iterations: int = 5, outer_iterations: int = 3): - if 'soap' not in opt.lower() and method != 'qr': - raise pytest.skip('Only SOAP supports `method` argument') + if "soap" not in opt.lower() and method != "qr": + raise pytest.skip("Only SOAP supports `method` argument") set_torch() for k, v in expected_memory.items(): if k in opt.lower(): break else: - raise pytest.skip(f'Opt {opt} not supported') + raise pytest.skip(f"Opt {opt} not supported") opt = getattr(heavyball, opt) heavyball.utils.zeroth_power_mode = method @@ -48,16 +52,16 @@ def test_memory(opt, method, size, depth: int, iterations: int = 5, outer_iterat model_allocated = get_memory() o = get_optim(opt, model.parameters(), lr=1e-3) for _ in range(iterations): - model(torch.randn((1, size), device='cuda')).sum().backward() + model(torch.randn((1, size), device="cuda")).sum().backward() o.step() opt_allocated = get_memory() o.zero_grad() del model, o - peak = torch.cuda.memory_stats()['allocated_bytes.all.peak'] + peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"] - print(f'Peak: {peak / model_allocated:.2f}x | Opt: {opt_allocated / model_allocated:.2f}x') + print(f"Peak: {peak / model_allocated:.2f}x | Opt: {opt_allocated / model_allocated:.2f}x") if i > 0: - assert peak / model_allocated < v['peak'] - assert opt_allocated / model_allocated < v['after'] + assert peak / model_allocated < v["peak"] + assert opt_allocated / model_allocated < v["after"] diff --git a/test/test_merge.py b/test/test_merge.py index bbf95ca..05d6644 100644 --- a/test/test_merge.py +++ b/test/test_merge.py @@ -7,7 +7,7 @@ import heavyball import heavyball.utils from benchmark.utils import get_optim -from heavyball.utils import set_torch, clean +from heavyball.utils import clean, set_torch class Param(nn.Module): @@ -19,14 +19,22 @@ def forward(self, inp): return self.weight.mean() * inp -@pytest.mark.parametrize("opt", ['ForeachPSGDKron', 'ForeachPaLMPAdam']) -@pytest.mark.parametrize("method", ['qr', 'newtonschulz2', 'svd', 'eigh']) +@pytest.mark.parametrize("opt", ["ForeachPSGDKron", "ForeachPaLMPAdam"]) +@pytest.mark.parametrize("method", ["qr", "newtonschulz2", "svd", "eigh"]) @pytest.mark.parametrize("size", [(16, 16, 16, 16), (4, 4, 4, 4), (512, 1, 128), (32128, 768)]) @pytest.mark.parametrize("merge,split", [(False, False), (True, False), (True, True)]) -def test_merge(opt, method, size: List[int], merge, split, depth: int = 2, iterations: int = 5, - outer_iterations: int = 3): - if 'soap' not in opt.lower() and method != 'qr': - raise pytest.skip('Only SOAP supports `method` argument') +def test_merge( + opt, + method, + size: List[int], + 
merge, + split, + depth: int = 2, + iterations: int = 5, + outer_iterations: int = 3, +): + if "soap" not in opt.lower() and method != "qr": + raise pytest.skip("Only SOAP supports `method` argument") clean() set_torch() @@ -37,11 +45,18 @@ def test_merge(opt, method, size: List[int], merge, split, depth: int = 2, itera clean() model = nn.Sequential(*[Param(size) for _ in range(depth)]).cuda() # We don't know if merging will use more or less memory, but we do know that it shouldn't crash. This test is to check if it crashes - o = get_optim(opt, model.parameters(), lr=1e-3, merge_dims=merge, split=split, max_precond_dim=256, - max_size_triangular=256) + o = get_optim( + opt, + model.parameters(), + lr=1e-3, + merge_dims=merge, + split=split, + max_precond_dim=256, + max_size_triangular=256, + ) for i in range(iterations): - model(torch.randn((1, size[0]), device='cuda')).sum().backward() + model(torch.randn((1, size[0]), device="cuda")).sum().backward() o.step() o.zero_grad() print(o.state_size()) diff --git a/test/test_no_grad.py b/test/test_no_grad.py index 4315d88..c83269d 100644 --- a/test/test_no_grad.py +++ b/test/test_no_grad.py @@ -1,13 +1,14 @@ from typing import List -import heavyball -import heavyball.utils import pytest import torch -from benchmark.utils import get_optim -from heavyball.utils import set_torch, clean from torch import nn +import heavyball +import heavyball.utils +from benchmark.utils import get_optim +from heavyball.utils import clean, set_torch + class Param(nn.Module): def __init__(self, size): @@ -19,7 +20,12 @@ def forward(self, inp): @pytest.mark.parametrize("opt", heavyball.__all__) -@pytest.mark.parametrize("size", [(4, 4, 4, 4), ]) +@pytest.mark.parametrize( + "size", + [ + (4, 4, 4, 4), + ], +) def test_closre(opt, size: List[int], depth: int = 2, iterations: int = 5, outer_iterations: int = 3): clean() set_torch() diff --git a/test/test_psgd.py b/test/test_psgd.py deleted file mode 100644 index e501474..0000000 --- a/test/test_psgd.py +++ /dev/null @@ -1,66 +0,0 @@ -import heavyball -import heavyball.utils -import pytest -import torch -from benchmark.utils import get_optim -from heavyball.utils import clean, set_torch -from torch import nn - - -def get_memory(): - clean() - torch.cuda.synchronize() - clean() - torch.cuda.synchronize() - return torch.cuda.memory_allocated() - - -@pytest.mark.parametrize("opt", ['ForeachPSGDKron', 'ForeachPaLMPAdam', 'ForeachPurePSGD', 'ForeachDelayedPSGD']) -@pytest.mark.parametrize("method", - ['norm_clip_', 'mu_law_compress', 'a_law_compress', 'trust_region_clip_', 'identity', - 'normalize_grads']) -@pytest.mark.parametrize("size,depth", [(128, 1), (16, 4)]) -def test_clip(opt, method, size, depth: int, iterations: int = 100, outer_iterations: int = 3): - set_torch() - - opt = getattr(heavyball, opt) - - for i in range(outer_iterations): - model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda() - - torch.cuda.reset_peak_memory_stats() - torch.cuda.reset_max_memory_allocated() - torch.cuda.reset_max_memory_cached() - torch.cuda.reset_accumulated_memory_stats() - - model_allocated = get_memory() - o = get_optim( - opt, model.parameters(), lr=1e-3, - clip_fn=getattr(heavyball.utils, method) if method != 'normalize_grads' else 'identity', - normalize_grads=method == 'normalize_grads' - ) - losses = torch.zeros((iterations,), device='cuda') - for itr in range(iterations): - src = torch.randn((4, size), device='cuda') - tgt = src - loss = (model(src) - tgt).square().mean() - loss.backward() - o.step() - 
- opt_allocated = get_memory() - o.zero_grad() - losses[itr] = loss - - del model, o - - arange = torch.arange(iterations, device='cuda', dtype=torch.float32) - lwma_bwd = (losses @ torch.flip(arange, [0])).item() - lwma_fwd = (losses @ arange).item() - assert lwma_bwd > lwma_fwd - - peak = torch.cuda.memory_stats()['allocated_bytes.all.peak'] - - print(f'Peak: {peak / model_allocated:.2f}x | Opt: {opt_allocated / model_allocated:.2f}x') - if i > 0: - assert peak / model_allocated < v['peak'] - assert opt_allocated / model_allocated < v['after'] diff --git a/test/test_soap.py b/test/test_soap.py index 35cdb9b..27574d8 100644 --- a/test/test_soap.py +++ b/test/test_soap.py @@ -7,30 +7,37 @@ from heavyball.utils import dim_merger, promote -def init_preconditioner(grad, state, precondition_frequency=10, shampoo_beta=0.95, max_precond_dim=10000, - precondition_1d=False, merge_dims=False): +def init_preconditioner( + grad, + state, + precondition_frequency=10, + shampoo_beta=0.95, + max_precond_dim=10000, + precondition_1d=False, + merge_dims=False, +): """ Initializes the preconditioner matrices (L and R in the paper). """ - state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper). + state["GG"] = [] # Will hold all the preconditioner matrices (L and R in the paper). if grad.dim() == 1: if not precondition_1d or grad.shape[0] > max_precond_dim: - state['GG'].append([]) + state["GG"].append([]) else: - state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype)) + state["GG"].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype)) else: if merge_dims: grad = dim_merger(grad, max_precond_dim) for sh in grad.shape: if sh > max_precond_dim: - state['GG'].append([]) + state["GG"].append([]) else: - state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype)) + state["GG"].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype)) - state['Q'] = None # Will hold all the eigenbases of the preconditioner. - state['precondition_frequency'] = precondition_frequency - state['shampoo_beta'] = shampoo_beta + state["Q"] = None # Will hold all the eigenbases of the preconditioner. 
+ state["precondition_frequency"] = precondition_frequency + state["shampoo_beta"] = shampoo_beta def project(grad, state, merge_dims=False, max_precond_dim=10000): @@ -41,9 +48,13 @@ def project(grad, state, merge_dims=False, max_precond_dim=10000): if merge_dims: grad = dim_merger(grad, max_precond_dim) - for mat in state['Q']: + for mat in state["Q"]: if len(mat) > 0: - grad = torch.tensordot(grad, mat, dims=[[0], [0]], ) + grad = torch.tensordot( + grad, + mat, + dims=[[0], [0]], + ) else: permute_order = list(range(1, len(grad.shape))) + [0] grad = grad.permute(permute_order) @@ -59,26 +70,32 @@ def update_preconditioner(grad, state, max_precond_dim=10000, merge_dims=False, """ if grad.dim() == 1: if precondition_1d and grad.shape[0] <= max_precond_dim: - state['GG'][0].lerp_(grad.unsqueeze(1) @ grad.unsqueeze(0), 1 - state['shampoo_beta']) + state["GG"][0].lerp_(grad.unsqueeze(1) @ grad.unsqueeze(0), 1 - state["shampoo_beta"]) else: if merge_dims: new_grad = dim_merger(grad, max_precond_dim) for idx, sh in enumerate(new_grad.shape): if sh <= max_precond_dim: - outer_product = torch.tensordot(new_grad, new_grad, dims=[[*chain(range(idx), range(idx + 1, - len(new_grad.shape)))]] * 2, ) - state['GG'][idx].lerp_(outer_product, 1 - state['shampoo_beta']) + outer_product = torch.tensordot( + new_grad, + new_grad, + dims=[[*chain(range(idx), range(idx + 1, len(new_grad.shape)))]] * 2, + ) + state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"]) else: for idx, sh in enumerate(grad.shape): if sh <= max_precond_dim: - outer_product = torch.tensordot(grad, grad, # Contracts across all dimensions except for k. - dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] * 2, ) - state['GG'][idx].lerp_(outer_product, 1 - state['shampoo_beta']) + outer_product = torch.tensordot( + grad, + grad, # Contracts across all dimensions except for k. + dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] * 2, + ) + state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"]) - if state['Q'] is None: - state['Q'] = get_orthogonal_matrix(state['GG']) - if state['step'] > 0 and state['step'] % state['precondition_frequency'] == 0: - state['Q'] = get_orthogonal_matrix_QR(state, max_precond_dim, merge_dims) + if state["Q"] is None: + state["Q"] = get_orthogonal_matrix(state["GG"]) + if state["step"] > 0 and state["step"] % state["precondition_frequency"] == 0: + state["Q"] = get_orthogonal_matrix_QR(state, max_precond_dim, merge_dims) def project_back(grad, state, merge_dims=False, max_precond_dim=10000): @@ -88,9 +105,13 @@ def project_back(grad, state, merge_dims=False, max_precond_dim=10000): original_shape = grad.shape if merge_dims: grad = dim_merger(grad, max_precond_dim) - for mat in state['Q']: + for mat in state["Q"]: if len(mat) > 0: - grad = torch.tensordot(grad, mat, dims=[[0], [1]], ) + grad = torch.tensordot( + grad, + mat, + dims=[[0], [1]], + ) else: permute_order = list(range(1, len(grad.shape))) + [0] grad = grad.permute(permute_order) @@ -125,7 +146,7 @@ def get_orthogonal_matrix(mat): continue try: _, Q = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device)) - except: + except Exception: _, Q = torch.linalg.eigh(m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device)) Q = Q.to(m.dtype) Q = torch.flip(Q, [1]) @@ -141,8 +162,8 @@ def get_orthogonal_matrix_QR(state, max_precond_dim=10000, merge_dims=False): Computes the eigenbases of the preconditioner using one round of power iteration followed by torch.linalg.qr decomposition. 
""" - precond_list = state['GG'] - orth_list = state['Q'] + precond_list = state["GG"] + orth_list = state["Q"] matrix = [] orth_matrix = [] @@ -162,11 +183,11 @@ def get_orthogonal_matrix_QR(state, max_precond_dim=10000, merge_dims=False): matrix.append(promote(m.data)) orth_matrix.append(promote(o.data)) - orig_shape = state['exp_avg_sq'].shape + orig_shape = state["exp_avg_sq"].shape if merge_dims: - exp_avg_sq = dim_merger(state['exp_avg_sq'], max_precond_dim) + exp_avg_sq = dim_merger(state["exp_avg_sq"], max_precond_dim) else: - exp_avg_sq = state['exp_avg_sq'] + exp_avg_sq = state["exp_avg_sq"] final = [] for ind, (m, o) in enumerate(zip(matrix, orth_matrix)): @@ -187,30 +208,56 @@ def get_orthogonal_matrix_QR(state, max_precond_dim=10000, merge_dims=False): if merge_dims: exp_avg_sq = exp_avg_sq.reshape(orig_shape) - state['exp_avg_sq'] = exp_avg_sq + state["exp_avg_sq"] = exp_avg_sq return final def _init(size, max_precond, merge_dims, precondition_1d, beta, precondition_frequency=1): grad = torch.randn(size, dtype=torch.double) - ref_state = {'step': 1, 'exp_avg': torch.randn_like(grad), 'exp_avg_sq': torch.randn_like(grad)} - new_state = {'step': 1, **{k: v.clone() for k, v in ref_state.items() if isinstance(v, torch.Tensor)}} - init_preconditioner(grad.clone(), ref_state, precondition_frequency=precondition_frequency, shampoo_beta=beta, - max_precond_dim=max_precond, precondition_1d=precondition_1d, merge_dims=merge_dims) - utils.init_preconditioner(grad.clone(), new_state, max_precond_dim=max_precond, precondition_1d=precondition_1d, - merge_dims=merge_dims) + ref_state = { + "step": 1, + "exp_avg": torch.randn_like(grad), + "exp_avg_sq": torch.randn_like(grad), + } + new_state = { + "step": 1, + **{k: v.clone() for k, v in ref_state.items() if isinstance(v, torch.Tensor)}, + } + init_preconditioner( + grad.clone(), + ref_state, + precondition_frequency=precondition_frequency, + shampoo_beta=beta, + max_precond_dim=max_precond, + precondition_1d=precondition_1d, + merge_dims=merge_dims, + ) + utils.init_preconditioner( + grad.clone(), + new_state, + max_precond_dim=max_precond, + precondition_1d=precondition_1d, + merge_dims=merge_dims, + ) return grad, ref_state, new_state def _updated(size, max_precond, merge_dims, precondition_1d, beta, iterations, precondition_frequency=1): grad, ref_state, new_state = _init(size, max_precond, merge_dims, precondition_1d, beta, precondition_frequency) for _ in range(iterations): - ref_state['step'] += 1 - new_state['step'] += 1 + ref_state["step"] += 1 + new_state["step"] += 1 grad = torch.randn_like(grad) update_preconditioner(grad.clone(), ref_state, max_precond, merge_dims, precondition_1d) - utils.update_preconditioner(grad.clone(), new_state, max_precond, merge_dims, precondition_1d, beta, - precondition_frequency == 1) + utils.update_preconditioner( + grad.clone(), + new_state, + max_precond, + merge_dims, + precondition_1d, + beta, + precondition_frequency == 1, + ) yield grad, ref_state, new_state @@ -222,58 +269,65 @@ def _check(ref, new): if isinstance(rr, list): for r, n in zip(rr, nn): if isinstance(r, torch.Tensor): - assert ref['step'] and k and torch.allclose(r, n) + assert ref["step"] and k and torch.allclose(r, n) elif isinstance(rr, torch.Tensor): - assert ref['step'] and k and torch.allclose(rr, nn) + assert ref["step"] and k and torch.allclose(rr, nn) _size = 16 -@pytest.mark.parametrize('size', [(_size,), (_size,) * 2, (_size,) * 3]) -@pytest.mark.parametrize('max_precond', [_size ** 2 * 2, _size * 2, _size // 2]) 
-@pytest.mark.parametrize('merge_dims', [True, False]) -@pytest.mark.parametrize('precondition_1d', [True, False]) -@pytest.mark.parametrize('beta', [0.5, 0.9, 0.99]) +@pytest.mark.parametrize("size", [(_size,), (_size,) * 2, (_size,) * 3]) +@pytest.mark.parametrize("max_precond", [_size**2 * 2, _size * 2, _size // 2]) +@pytest.mark.parametrize("merge_dims", [True, False]) +@pytest.mark.parametrize("precondition_1d", [True, False]) +@pytest.mark.parametrize("beta", [0.5, 0.9, 0.99]) @torch.no_grad() def test_init(size, max_precond, merge_dims, precondition_1d, beta): - grad, ref_state, new_state = _init(size, max_precond, merge_dims, precondition_1d, beta) + _grad, ref_state, new_state = _init(size, max_precond, merge_dims, precondition_1d, beta) _check(ref_state, new_state) -@pytest.mark.parametrize('size', [(_size,), (_size,) * 2, (_size,) * 3]) -@pytest.mark.parametrize('max_precond', [_size ** 2 * 2, _size * 2, _size // 2]) -@pytest.mark.parametrize('merge_dims', [True, False]) -@pytest.mark.parametrize('precondition_1d', [True, False]) -@pytest.mark.parametrize('beta', [0.5, 0.9, 0.99]) +@pytest.mark.parametrize("size", [(_size,), (_size,) * 2, (_size,) * 3]) +@pytest.mark.parametrize("max_precond", [_size**2 * 2, _size * 2, _size // 2]) +@pytest.mark.parametrize("merge_dims", [True, False]) +@pytest.mark.parametrize("precondition_1d", [True, False]) +@pytest.mark.parametrize("beta", [0.5, 0.9, 0.99]) @torch.no_grad() def test_ggt(size, max_precond, merge_dims, precondition_1d, beta, iterations: int = 5): - for grad, ref_state, new_state in _updated(size, max_precond, merge_dims, precondition_1d, beta, iterations, - precondition_frequency=10**12): + for grad, ref_state, new_state in _updated( + size, + max_precond, + merge_dims, + precondition_1d, + beta, + iterations, + precondition_frequency=10**12, + ): _check(ref_state, new_state) -@pytest.mark.parametrize('size', [(_size,), (_size,) * 2, (_size,) * 3]) -@pytest.mark.parametrize('max_precond', [_size ** 2 * 2, _size * 2, _size // 2]) -@pytest.mark.parametrize('merge_dims', [True, False]) -@pytest.mark.parametrize('precondition_1d', [True, False]) -@pytest.mark.parametrize('beta', [0.5, 0.9, 0.99]) +@pytest.mark.parametrize("size", [(_size,), (_size,) * 2, (_size,) * 3]) +@pytest.mark.parametrize("max_precond", [_size**2 * 2, _size * 2, _size // 2]) +@pytest.mark.parametrize("merge_dims", [True, False]) +@pytest.mark.parametrize("precondition_1d", [True, False]) +@pytest.mark.parametrize("beta", [0.5, 0.9, 0.99]) @torch.no_grad() def test_update(size, max_precond, merge_dims, precondition_1d, beta, iterations: int = 5): for grad, ref_state, new_state in _updated(size, max_precond, merge_dims, precondition_1d, beta, iterations): _check(ref_state, new_state) -@pytest.mark.parametrize('size', [(_size,), (_size,) * 2, (_size,) * 3]) -@pytest.mark.parametrize('max_precond', [_size ** 2 * 2, _size * 2, _size // 2]) -@pytest.mark.parametrize('merge_dims', [True, False]) -@pytest.mark.parametrize('precondition_1d', [True, False]) -@pytest.mark.parametrize('beta', [0.5, 0.9, 0.99]) -@pytest.mark.parametrize('back', [True, False]) +@pytest.mark.parametrize("size", [(_size,), (_size,) * 2, (_size,) * 3]) +@pytest.mark.parametrize("max_precond", [_size**2 * 2, _size * 2, _size // 2]) +@pytest.mark.parametrize("merge_dims", [True, False]) +@pytest.mark.parametrize("precondition_1d", [True, False]) +@pytest.mark.parametrize("beta", [0.5, 0.9, 0.99]) +@pytest.mark.parametrize("back", [True, False]) @torch.no_grad() def test_project(size, 
max_precond, merge_dims, precondition_1d, beta, back, iterations: int = 5): for grad, ref_state, new_state in _updated(size, max_precond, merge_dims, precondition_1d, beta, iterations): proj_ref = (project_back if back else project)(grad.clone(), ref_state, merge_dims, max_precond) - proj_new = utils.project(grad.clone(), ref_state['Q'], merge_dims, max_precond, back) + proj_new = utils.project(grad.clone(), ref_state["Q"], merge_dims, max_precond, back) - assert ref_state['step'] and torch.allclose(proj_ref.contiguous(), proj_new.contiguous()) \ No newline at end of file + assert ref_state["step"] and torch.allclose(proj_ref.contiguous(), proj_new.contiguous()) diff --git a/test/test_stochastic_updates.py b/test/test_stochastic_updates.py index 1c9f22a..0a7828a 100644 --- a/test/test_stochastic_updates.py +++ b/test/test_stochastic_updates.py @@ -1,11 +1,12 @@ -import heavyball -import heavyball.utils import pytest import torch -from benchmark.utils import get_optim -from heavyball.utils import clean, set_torch, PSGDBase from torch import nn +import heavyball +import heavyball.utils +from benchmark.utils import get_optim +from heavyball.utils import PSGDBase, clean, set_torch + def get_memory(): clean() @@ -22,13 +23,13 @@ def test_foreach(opt, size, depth: int, iterations: int = 8192, outer_iterations opt = getattr(heavyball, opt) if not issubclass(opt, PSGDBase): - raise pytest.skip('Only PSGD is supported') + raise pytest.skip("Only PSGD is supported") peaks = [] losses = [] for stochastic in [False, True]: - print('stochastic', stochastic) + print("stochastic", stochastic) torch.manual_seed(0x2131290) peaks.append([]) losses.append([]) @@ -38,7 +39,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 8192, outer_iterations o = get_optim(opt, model.parameters(), lr=1e-3, stochastic_schedule=stochastic) for _ in range(iterations): - loss = model(torch.randn((128, size), device='cuda')).square().mean() + loss = model(torch.randn((128, size), device="cuda")).square().mean() loss.backward() o.step() o.zero_grad()