2 changes: 1 addition & 1 deletion src/backend.py
@@ -115,7 +115,7 @@ def prefixed_name(ctx: Context, name: str):

def assign(ctx: Context, name: str, inp: jax.Array):
name = prefixed_name(ctx, name)
ctx.parameters[name] = inp
ctx.parameters[name] = inp.astype(ctx.parameters[name].dtype)


def normal(ctx: Context, shape: Sequence[int]):
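The assign change makes every write go through the stored parameter's dtype, so an update computed at higher precision lands back in the parameter's storage dtype. A minimal sketch with a plain dict standing in for ctx.parameters (assign_sketch and the shapes are illustrative only):

```python
import jax.numpy as jnp

# Hypothetical stand-in for ctx.parameters; only the dtype-preserving write is shown.
params = {"w": jnp.zeros((4,), dtype=jnp.bfloat16)}

def assign_sketch(store, name, inp):
    # Cast the incoming value back to the stored dtype, as the patched assign now does.
    store[name] = inp.astype(store[name].dtype)

update = jnp.full((4,), 0.1, dtype=jnp.float32)   # higher-precision optimizer output
assign_sketch(params, "w", params["w"].astype(jnp.float32) + update)
assert params["w"].dtype == jnp.bfloat16          # storage dtype is preserved
```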
3 changes: 3 additions & 0 deletions src/context.py
@@ -116,6 +116,9 @@ class Optimizer(DataClass):
weight_decay: float = 0.01
warmup_end: int = 16384
exponential_decay: float = 3e-6
svd_components: int = 8
fisher_decay: float = 0.99
log_matrix_power: int = 5  # effective matrix power is 2**log_matrix_power + 1


class Normalization(DataClass):
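The three new fields feed svd_fisher in src/optimizer.py: svd_components is the rank of the low-rank Fisher estimate, fisher_decay its exponential decay, and log_matrix_power the number of squarings applied to the small inner matrix, giving an effective power of 2**log_matrix_power + 1 (33 for the default of 5). A short sketch of that repeated-squaring relation (the 2x2 matrix is arbitrary, chosen only for the check):

```python
import jax.numpy as jnp

log_matrix_power = 5                           # default from Optimizer
effective_power = 2 ** log_matrix_power + 1    # 33; the +1 comes from the final b0 @ ... @ b1 sandwich

# x squarings turn A into A**(2**x); svd_fisher applies this to the small inner = b1 @ b0 matrix.
a = jnp.array([[0.9, 0.1], [0.0, 0.8]])
powered = a
for _ in range(log_matrix_power):
    powered = powered @ powered
assert jnp.allclose(powered, jnp.linalg.matrix_power(a, 2 ** log_matrix_power), atol=1e-5)
```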
82 changes: 52 additions & 30 deletions src/optimizer.py
@@ -3,7 +3,7 @@
import jax
from jax import lax, numpy as jnp

from src.backend import add_sq, assign, default, get_param, is_stacked, stable_rsqrt, with_context
from src.backend import assign, default, get_param, is_stacked, with_context, stable_rsqrt
from src.constants import MomentumType
from src.context import Context

@@ -63,19 +63,6 @@ def graft(param_name: str, magnitude: jax.Array, direction: jax.Array) -> jax.Array:
return direction * jnp.sqrt(norm(param_name, magnitude) / jnp.maximum(norm(param_name, direction), 1e-16))


def tg_adam(ctx: Context, param_name: str, grad: jax.Array, tg_grad: jax.Array, step: jax.Array) -> jax.Array:
ema_g = ema(ctx, grad, step, 1 - ctx.optimizer.adam_beta1)
ema_gsq = ema(ctx, grad ** 2, step, 1 - ctx.optimizer.adam_beta2)
ema_tgsq = ema(ctx, tg_grad, step, 1 - ctx.optimizer.adam_beta3)

if ctx.is_initializing:
return grad

adam_update = ema_g * stable_rsqrt(ema_gsq, ctx.optimizer.epsilon)
tg_update = ema_g * stable_rsqrt(ema_tgsq, ctx.optimizer.epsilon)
return graft(param_name, adam_update, tg_update)


def get_current_lr(ctx: Context, step: jax.Array) -> jax.Array:
opt = ctx.optimizer
learning_rate = opt.learning_rate
@@ -85,26 +72,61 @@ def get_current_lr(ctx: Context, step: jax.Array) -> jax.Array:
return learning_rate.astype(ctx.model.storage_dtype)


def normalize(x: jax.Array) -> jax.Array:
    return x * lax.rsqrt(lax.square(x).sum(1, keepdims=True))  # row-wise normalization; keepdims so the scale broadcasts


def svd_fisher(ctx: Context, grad: jax.Array):
key = jax.random.PRNGKey(ctx.seed)
vectors = normalize(jax.random.normal(key, (ctx.optimizer.svd_components, grad.shape[0])).astype(jnp.float64))
u = get_param(ctx, "u", vectors.shape[::-1], dtype=ctx.optimizer.momentum_dtype, tied=True,
init_val=jnp.zeros_like(vectors)).astype(jnp.float64)
v = get_param(ctx, "v", vectors.shape, dtype=ctx.optimizer.momentum_dtype, tied=True,
init_val=jnp.zeros_like(vectors)).astype(jnp.float64)

mid = jnp.eye(ctx.optimizer.svd_components * 2 + 1)
mid = mid.at[:ctx.optimizer.svd_components, :ctx.optimizer.svd_components].set(jnp.transpose(u, (1, 0)) @ u)
grad = grad * (1 - ctx.optimizer.fisher_decay)
x0 = jnp.concatenate([u * ctx.optimizer.fisher_decay, grad.reshape(-1, 1)], 1)  # grad arrives flattened; lift to a column
x0t = jnp.concatenate([v * ctx.optimizer.fisher_decay, grad.reshape(1, -1)], 0)  # and to a row for the transposed stack
grad = grad - ((grad @ x0) @ jnp.linalg.inv(jnp.eye(ctx.optimizer.svd_components + 1) + x0t @ x0)) @ x0t

for i, v in enumerate(vectors, 1):
local_mid = mid[:ctx.optimizer.svd_components + i, :ctx.optimizer.svd_components + i]
b0 = normalize(x0 @ local_mid)
b1 = normalize(x0t)
inner = b1 @ b0
for _ in range(ctx.optimizer.log_matrix_power):
inner = inner @ inner
v = b0 @ (inner @ (b1 @ v))  # brackets for speed (v=[N], b1=[K,N], inner=[K,K], b0=[N,K])
u = x0 @ (local_mid @ (x0t @ v))
x0 = jnp.concatenate([x0, u.reshape(-1, 1)], 1)
x0t = jnp.concatenate([x0t, v.reshape(1, -1)], 0)
assign(ctx, "u", x0[:, -ctx.optimizer.svd_components:])
assign(ctx, "v", x0t[-ctx.optimizer.svd_components:, :])
return grad
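The grad correction above is a Woodbury step: subtracting x0 @ inv(I + x0t @ x0) @ x0t applies the inverse of a rank-k-perturbed identity while only inverting a small matrix. A self-contained check of that identity in plain jax.numpy (symmetric case, with one factor x standing in for the u/v pair; sizes are arbitrary):

```python
import jax
import jax.numpy as jnp

n, k = 64, 8
key_x, key_g = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (n, k))   # low-rank factor, analogous to x0 (with x0t ~ x.T)
g = jax.random.normal(key_g, (n,))     # flattened gradient

# Direct application of (I + X X^T)^-1 to g: needs an n x n solve.
direct = jnp.linalg.solve(jnp.eye(n) + x @ x.T, g)

# Woodbury form, as used in svd_fisher: only a k x k inverse is required.
small = jnp.linalg.inv(jnp.eye(k) + x.T @ x)
low_rank = g - (g @ x) @ small @ x.T

assert jnp.allclose(direct, low_rank, atol=1e-4)
```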


def update(ctx: Context, grads: Dict[str, jax.Array], step: jax.Array):
outer_ctx = ctx.add_to_prefix("optimizer")
ctx = ctx.add_to_prefix("optimizer")
lr = -get_current_lr(ctx, step)
keys = [k for k in grads.keys() if "optimizer" not in k and not k.endswith('_sq') and not k.endswith('_sq_stacked')]
grads = jnp.concatenate([adaptive_gradient_clipping(ctx, k, grads[k].reshape(-1), False) for k in keys], 0)

for param_name, grad in grads.items():
if "optimizer" in param_name or param_name.endswith('_sq') or param_name.endswith('_sq_stacked'):
continue
ctx = outer_ctx.add_to_prefix(param_name, count=False)
ctx.name_cache = {}
dtype = ctx.parameters[param_name].dtype
parameter_lr = lr * ctx.parameter_variance.get(param_name, 1)

grad = adaptive_gradient_clipping(ctx, param_name, grad, False)
grad_sq = adaptive_gradient_clipping(ctx, param_name, grads[add_sq(param_name)], True)
weight_update = tg_adam(ctx, param_name, grad, grad_sq, step) * parameter_lr
ctx.name_cache = {}
ema_gsq = ema(ctx, lax.square(grads), step, 1 - ctx.optimizer.adam_beta2)
adam = ema(ctx, grads * stable_rsqrt(ema_gsq, ctx.optimizer.epsilon), step, 1 - ctx.optimizer.adam_beta1)
prec = svd_fisher(ctx, grads)

if ctx.is_initializing:
continue
if ctx.is_initializing:
return

param = ctx.parameters[param_name].astype(jnp.float64)
offset = 0
for param_name in keys:
param = ctx.parameters[param_name]
dtype = ctx.parameters[param_name].dtype
parameter_lr = lr * ctx.parameter_variance.get(param_name, 1)
grad = graft(param_name, adam[offset:offset + param.size], prec[offset:offset + param.size]).reshape(param.shape) * parameter_lr
offset += param.size  # advance into the flat adam/prec buffers
if not small_parameter(param_name, grad):
param *= 1 + ctx.optimizer.weight_decay * parameter_lr
ctx.parameters[param_name] = (param + weight_update).astype(dtype)
ctx.parameters[param_name] = (param + grad).astype(dtype)
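The final loop grafts the two estimates per parameter: the Adam EMA sets the step magnitude while the SVD-Fisher preconditioned gradient sets the direction. A simplified sketch of that grafting step, using a plain L2 norm in place of the per-parameter norm helper (graft_sketch and the toy vectors are illustrative only):

```python
import jax
import jax.numpy as jnp

def graft_sketch(magnitude: jax.Array, direction: jax.Array) -> jax.Array:
    # Rescale `direction` so its L2 norm matches `magnitude`, mirroring optimizer.graft
    # (the real helper uses a per-parameter norm and the same 1e-16 stability floor).
    return direction * jnp.sqrt(jnp.square(magnitude).sum() / jnp.maximum(jnp.square(direction).sum(), 1e-16))

adam_update = jnp.array([0.3, -0.4, 0.0])     # magnitude source, norm 0.5
fisher_update = jnp.array([2.0, 2.0, -1.0])   # direction source, norm 3.0
update = graft_sketch(adam_update, fisher_update)
assert jnp.allclose(jnp.linalg.norm(update), jnp.linalg.norm(adam_update), atol=1e-5)
```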