diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py
index ea55f557a8..5a08a42727 100644
--- a/areal/api/cli_args.py
+++ b/areal/api/cli_args.py
@@ -338,30 +338,48 @@ class OptimizerConfig:
     type: str = field(
         default="adam",
         metadata={
-            "help": "Optimizer type. For FSDP Engine, adam_bf16 enables memory-efficient BF16 optimizer states. "
-            "For Megatron Engine, adam_bf16 requires dtype=bfloat16 and is automatically converted to adam "
-            "with precision-aware optimizer enabled.",
-            "choices": ["adam", "sgd", "adam_bf16"],
+            "help": "Optimizer type. 'adam': AdamW (default). 'adam_bf16': memory-efficient BF16 AdamW "
+            "(FSDP: uses AnyPrecisionAdamW; Megatron: requires dtype=bfloat16, auto-converted to adam "
+            "with precision-aware optimizer). 'sgd': plain SGD. 'muon': Muon optimizer for >=2D params "
+            "with an AdamW backend for <2D params (biases, norms, embeddings).",
+            "choices": ["adam", "sgd", "adam_bf16", "muon"],
+        },
+    )
+    lr: float = field(
+        default=1e-3,
+        metadata={
+            "help": "Learning rate. When type='muon', this is shared by both the Muon sub-optimizer "
+            "(>=2D params) and the AdamW backend (<2D params). Pair "
+            "muon_scale_mode='spectral' with muon_extra_scale_factor=0.2 (Moonlight-style) to "
+            "make Muon's update RMS match AdamW's so a single lr works for both."
+        },
+    )
+    weight_decay: float = field(
+        default=0.01,
+        metadata={
+            "help": "Weight decay. Applied to all optimizer types, including Muon (>=2D params) "
+            "and the AdamW backend (<2D params)."
         },
     )
-    lr: float = field(default=1e-3, metadata={"help": "Learning rate"})
-    weight_decay: float = field(default=0.01, metadata={"help": "Weight decay"})
     beta1: float = field(
         default=0.9,
         metadata={
-            "help": "Adam beta1 parameter. Only effective when optimizer_type is adam/adam_bf16"
+            "help": "Adam beta1 parameter. Used by adam/adam_bf16, and by the AdamW backend "
+            "when type='muon'. Not used by the Muon sub-optimizer itself."
         },
     )
     beta2: float = field(
         default=0.999,
         metadata={
-            "help": "Adam beta2 parameter. Only effective when optimizer_type is adam/adam_bf16"
+            "help": "Adam beta2 parameter. Used by adam/adam_bf16, and by the AdamW backend "
+            "when type='muon'. Not used by the Muon sub-optimizer itself."
         },
     )
     eps: float = field(
         default=1e-8,
         metadata={
-            "help": "Adam epsilon parameter. Only effective when optimizer_type is adam/adam_bf16"
+            "help": "Adam epsilon for numerical stability. Used by adam/adam_bf16, and by the "
+            "AdamW backend when type='muon'. Not used by the Muon sub-optimizer itself."
         },
     )
     min_lr_ratio: float = field(
@@ -398,6 +416,50 @@
     gradient_clipping: float = field(
         default=1.0, metadata={"help": "Gradient clipping threshold"}
     )
+    muon_momentum: float = field(
+        default=0.95,
+        metadata={
+            "help": "Muon momentum parameter. Only effective when type='muon'."
+        },
+    )
+    muon_use_nesterov: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to use Nesterov momentum in Muon. Only effective when type='muon'. "
+            "Mirrors Megatron-Core OptimizerConfig.muon_use_nesterov."
+        },
+    )
+    muon_num_ns_steps: int = field(
+        default=5,
+        metadata={
+            "help": "Number of Newton-Schulz iteration steps in Muon. Only effective when type='muon'. "
+            "Mirrors Megatron-Core OptimizerConfig.muon_num_ns_steps."
+        },
+    )
+    muon_scale_mode: str = field(
+        default="spectral",
+        metadata={
+            "help": "Muon update scaling mode (final scale = scale_factor(mode) * muon_extra_scale_factor). "
+            "Only used when type='muon'. Mirrors Megatron-Core OptimizerConfig.muon_scale_mode.",
+            "choices": ["spectral", "unit_rms_norm", "shape_scaling"],
+        },
+    )
+    muon_extra_scale_factor: float = field(
+        default=1.0,
+        metadata={
+            "help": "Extra multiplier on top of muon_scale_mode. Use 0.2 with "
+            "muon_scale_mode='spectral' for Moonlight-style RMS-matched scaling. "
+            "Only used when type='muon'. Mirrors Megatron-Core OptimizerConfig.muon_extra_scale_factor."
+        },
+    )
+
+    def __post_init__(self):
+        """Validate optimizer configuration."""
+        valid_muon_scale_modes = {"spectral", "unit_rms_norm", "shape_scaling"}
+        if self.muon_scale_mode not in valid_muon_scale_modes:
+            raise ValueError(
+                f"muon_scale_mode must be one of {valid_muon_scale_modes}, got {self.muon_scale_mode!r}."
+            )
 
 
 @dataclass
diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py
index 0e9682bc0d..22e4debdff 100644
--- a/areal/engine/fsdp_engine.py
+++ b/areal/engine/fsdp_engine.py
@@ -83,7 +83,11 @@
 )
 from areal.engine.fsdp_utils.checkpoint import DCPState
 from areal.engine.fsdp_utils.grad import fsdp2_clip_grad_norm
-from areal.engine.fsdp_utils.optimizer import AnyPrecisionAdamW, PerLayerOptimWrapper
+from areal.engine.fsdp_utils.muon import Muon as MuonOptimizer
+from areal.engine.fsdp_utils.optimizer import (
+    AnyPrecisionAdamW,
+    PerLayerOptimWrapper,
+)
 from areal.engine.fsdp_utils.parallel import ParallelHelper, parallelize_model
 from areal.infra.dist_rollout import DistRolloutCoordinator
 from areal.infra.platforms import current_platform
@@ -470,7 +474,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
         self._create_optimizer(ft_spec)
 
         if self.config.fsdp.per_layer_optim_step:
-            if self.optimizer_config.type != "adam":
+            if self.optimizer_config.type not in ("adam",):
                 raise ValueError(
                     f"per_layer_optim_step only supports 'adam' optimizer, got '{self.optimizer_config.type}'."
                 )
@@ -1111,7 +1115,8 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None:
             "adam",
             "adam_bf16",
             "sgd",
-        ], "Only adam/adam_bf16/sgd optimizer is supported in this engine."
+            "muon",
+        ], "Only the adam/adam_bf16/sgd/muon optimizers are supported in this engine."
         if self.optimizer_config.type in ["sgd", "adam_bf16"]:
             self.logger.warning(
                 f"Using the '{self.optimizer_config.type}' optimizer with FSDP may be less stable. Consider using the 'adam' (AdamW) optimizer for improved stability and performance."
@@ -1121,7 +1126,44 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None:
         beta1 = self.optimizer_config.beta1
         beta2 = self.optimizer_config.beta2
         eps = self.optimizer_config.eps
-        if self.optimizer_config.type == "adam":
+        if self.optimizer_config.type == "muon":
+            muon_params: list[torch.nn.Parameter] = []
+            backend_params: list[torch.nn.Parameter] = []
+            for p in self.model.parameters():
+                if not p.requires_grad:
+                    continue
+                if p.ndim >= 2:
+                    muon_params.append(p)
+                else:
+                    backend_params.append(p)
+            self.optimizer = MuonOptimizer(
+                [
+                    dict(
+                        params=muon_params,
+                        lr=lr,
+                        momentum=self.optimizer_config.muon_momentum,
+                        weight_decay=weight_decay,
+                        scale_mode=self.optimizer_config.muon_scale_mode,
+                        extra_scale_factor=self.optimizer_config.muon_extra_scale_factor,
+                        nesterov=self.optimizer_config.muon_use_nesterov,
+                        ns_steps=self.optimizer_config.muon_num_ns_steps,
+                        use_muon=True,
+                    ),
+                    dict(
+                        params=backend_params,
+                        lr=lr,
+                        betas=(beta1, beta2),
+                        eps=eps,
+                        weight_decay=weight_decay,
+                        use_muon=False,
+                    ),
+                ]
+            )
+            self.logger.info(
+                f"Muon optimizer: {len(muon_params)} params (>=2D), "
+                f"AdamW backend: {len(backend_params)} params (<2D)"
+            )
+        elif self.optimizer_config.type == "adam":
             self.optimizer = torch.optim.AdamW(
                 self.model.parameters(),
                 lr=lr,
diff --git a/areal/engine/fsdp_utils/__init__.py b/areal/engine/fsdp_utils/__init__.py
index 8620b34fda..a988a7fd3b 100644
--- a/areal/engine/fsdp_utils/__init__.py
+++ b/areal/engine/fsdp_utils/__init__.py
@@ -13,6 +13,7 @@
 )
 from transformers import PreTrainedModel
 
+from areal.engine.fsdp_utils.muon import Muon
 from areal.engine.fsdp_utils.optimizer import (
     AdamKernel,
     OptimKernel,
@@ -33,6 +34,7 @@
     "apply_fsdp2",
     "fsdp2_load_full_state_dict",
     "get_cosine_schedule_with_warmup",
+    "Muon",
     "PerLayerOptimWrapper",
     "OptimKernel",
     "AdamKernel",
diff --git a/areal/engine/fsdp_utils/muon.py b/areal/engine/fsdp_utils/muon.py
new file mode 100644
index 0000000000..35107e3768
--- /dev/null
+++ b/areal/engine/fsdp_utils/muon.py
@@ -0,0 +1,491 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# Credits:
+# - Keller Jordan's Muon optimizer: https://github.com/KellerJordan/Muon
+# - Newton-Schulz replication strategy: https://gist.github.com/main-horse/7314170780e36f7443d1926418d75823
+# - Moonlight RMS scaling: https://arxiv.org/abs/2502.16982
+
+import math
+from collections import deque
+from typing import Protocol
+
+import torch
+from torch import Tensor
+from torch.distributed import gather, scatter
+from torch.distributed.tensor import DTensor, distribute_tensor
+
+__all__ = ["Muon"]
+
+
+# ---------------------------------------------------------------------------
+# Newton-Schulz iteration (bf16-accelerated, batched)
+# ---------------------------------------------------------------------------
+
+
+@torch.compile(fullgraph=True)
+def _nsloop_torch(X: Tensor, steps: int, *, a=3.4445, b=-4.7750, c=2.0315):
+    """Compiled Newton-Schulz inner loop.
+
+    When compiled, inductor fuses this into efficient matmul + triton kernels.
+    """
+    for _ in range(steps):
+        A = X @ X.mT
+        B = b * A + c * A @ A
+        X = a * X + B @ X
+    return X
+
+
+def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
+    """Newton-Schulz iteration to compute the zeroth power / orthogonalization of G.
+
+    We opt to use a quintic iteration whose coefficients are selected to maximize
+    the slope at zero. For the purpose of minimizing steps, it turns out to be
+    empirically effective to keep increasing the slope at zero even beyond the
+    point where the iteration no longer converges all the way to one everywhere
+    on the interval. This iteration therefore does not produce UV^T but rather
+    something like US'V^T, where S' is diagonal with S'_{ii} ~ Uniform(0.5, 1.5),
+    which turns out not to hurt model performance at all relative to UV^T, where
+    USV^T = G is the SVD.
+
+    Credits: @scottjmaddox (batched impl), @YouJiacheng (record practice).
+    """
+    assert G.ndim >= 2  # batched Muon support
+    X = G.bfloat16()
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+
+    # Ensure spectral norm is at most 1
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
+    # Perform the NS iterations
+    X = _nsloop_torch(X, steps)
+
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+    return X
+
+
+# ---------------------------------------------------------------------------
+# Muon sub-operations
+# ---------------------------------------------------------------------------
+
+
+def apply_momentum(
+    grad: Tensor, momentum_buf: Tensor, beta: float, nesterov: bool
+) -> Tensor:
+    """Apply momentum with the lerp_ formulation and optional Nesterov."""
+    momentum_buf.lerp_(grad, 1 - beta)
+    update = grad.lerp_(momentum_buf, beta) if nesterov else momentum_buf
+    if update.ndim == 4:  # conv filters: flatten to 2D
+        update = update.view(len(update), -1)
+    return update
+
+
+def apply_scaling(
+    grad: Tensor,
+    mode: str = "spectral",
+    extra_scale_factor: float = 1.0,
+) -> Tensor:
+    """Post-Newton-Schulz update scaling.
+
+    Naming is aligned with Megatron-Core / emerging_optimizers (NVIDIA-NeMo).
+
+    Final scale = scale_factor(mode) * extra_scale_factor, where:
+    - 'spectral'      : sqrt(max(m, n))
+      Kimi/Moonlight (arXiv:2502.16982); emerging_optimizers default.
+    - 'unit_rms_norm' : sqrt(m / n)
+      Scion (arXiv:2502.07529) / Bernstein
+      (https://jeremybernste.in/writing/deriving-muon).
+    - 'shape_scaling' : max(1, m / n)**0.5
+      Keller Jordan's original (https://kellerjordan.github.io/posts/muon).
+
+    Set extra_scale_factor=0.2 with mode='spectral' to reproduce the legacy
+    Moonlight setting
+    (https://github.com/MoonshotAI/Moonlight/blob/5afcb6911077e7f182d05865fe90d9f39abcbcbd/examples/toy_train.py#L146),
+    i.e. 0.2 * sqrt(max(m, n)), which approximately matches AdamW's update RMS
+    norm so a single lr works for both Muon and the AdamW backend.
+    """
+    m = grad.size(-2)
+    n = grad.size(-1)
+    if mode == "spectral":
+        scale = math.sqrt(max(m, n))
+    elif mode == "unit_rms_norm":
+        scale = math.sqrt(m / n)
+    elif mode == "shape_scaling":
+        scale = max(1, m / n) ** 0.5
+    else:
+        raise ValueError(
+            f"Invalid muon_scale_mode {mode!r}. Valid: "
+            "{'spectral', 'unit_rms_norm', 'shape_scaling'}."
+        )
+    grad *= scale * extra_scale_factor
+    return grad
+
+
+def adam_update(
+    grad: Tensor,
+    buf1: Tensor,
+    buf2: Tensor,
+    step: int,
+    betas: tuple[float, float],
+    eps: float,
+) -> Tensor:
+    """Standard Adam update (bias-corrected)."""
+    buf1.lerp_(grad, 1 - betas[0])
+    buf2.lerp_(grad.square(), 1 - betas[1])
+    buf1c = buf1 / (1 - betas[0] ** step)
+    buf2c = buf2 / (1 - betas[1] ** step)
+    return buf1c / (buf2c.sqrt() + eps)
+
+
+# ---------------------------------------------------------------------------
+# Work protocol & implementations for distributed NS
+# ---------------------------------------------------------------------------
+
+
+class Work(Protocol):
+    """Protocol for distributed Muon work items (gather -> NS -> scatter)."""
+
+    def __init__(self, param, state, group, index: int): ...
+    def start(self): ...
+    def finish(self): ...
+
+
+class Fsdp1dWork:
+    """Muon work for an FSDP2 1D mesh: gather to one rank, run NS, scatter back."""
+
+    def __init__(self, param, state, group, index: int):
+        self.param = param
+        self.state = state
+        self.group = group
+        self.index = index
+        self._intermediate_state = None
+
+    def start(self):
+        self.param.grad = apply_momentum(
+            self.param.grad,
+            self.state["momentum_buffer"],
+            self.group["momentum"],
+            self.group["nesterov"],
+        )
+
+        grad = self.param.grad
+        assert isinstance(grad, DTensor), "only supports DTensor parameters"
+        assert grad.device_mesh.ndim == 1, "only supports 1D mesh"
+
+        rank = grad.device_mesh.get_rank()
+        world_size = grad.device_mesh.size()
+        pg = grad.device_mesh.get_group()
+
+        dest_rank = self.index % world_size
+
+        if rank == dest_rank:
+            gather_lists = [
+                torch.zeros_like(input=grad.to_local()) for _ in range(world_size)
+            ]
+            gather_handle = gather(
+                grad.to_local(),
+                gather_lists,
+                group_dst=dest_rank,
+                group=pg,
+                async_op=True,
+            )
+        else:
+            gather_lists = None
+            gather_handle = gather(
+                grad.to_local(), None, group_dst=dest_rank, group=pg, async_op=True
+            )
+
+        self._intermediate_state = [dest_rank, gather_handle, gather_lists]
+
+    def finish(self):
+        assert self._intermediate_state is not None, "start() must be called first"
+
+        grad = self.param.grad
+        rank = grad.device_mesh.get_rank()
+        world_size = grad.device_mesh.size()
+        pg = grad.device_mesh.get_group()
+
+        dest_rank, gather_handle, gather_lists = self._intermediate_state
+        gather_handle.wait()
+        if rank == dest_rank:
+            g_full_block = torch.cat(gather_lists, dim=0)
+            g_full_block.copy_(
+                zeropower_via_newtonschulz5(g_full_block, self.group["ns_steps"])
+            )
+            g_full_block = g_full_block.type_as(grad)
+            chunks = list(g_full_block.chunk(chunks=world_size, dim=0))
+            scatter(
+                grad.to_local(),
+                scatter_list=chunks,
+                src=dest_rank,
+                group=pg,
+                async_op=False,
+            )
+        else:
+            scatter(grad.to_local(), None, src=dest_rank, group=pg, async_op=False)
+
+        update = apply_scaling(
+            grad,
+            self.group["scale_mode"],
+            self.group["extra_scale_factor"],
+        )
+
+        self.param.mul_(1 - self.group["lr"] * self.group["weight_decay"])
+        self.param.add_(update.reshape(self.param.shape), alpha=-self.group["lr"])
+
+
+class TpFsdp2dWork:
+    """Muon work for a TP + FSDP 2D mesh (strategy A: DTensor redistribute).
+
+    Layout assumption (matches ``areal.engine.fsdp_utils.parallel.parallelize_model``):
+
+    - Mesh dims = (dp_sp, tp), i.e. TP is applied first and FSDP2 wraps on dp_sp
+      outside.
+    - dp_sp placement is always ``Shard(0)`` (FSDP2 invariant).
+    - tp placement is one of ``Shard(0)`` (Colwise), ``Shard(1)`` (Rowwise),
+      or ``Replicate()`` (e.g. ReplicateParallel / SequenceParallel 1D params
+      that won't hit Muon anyway).
+
+    Strategy A (correctness-first, redundant NS on every rank):
+
+    1. Local momentum on the sharded grad (element-wise, safe under
+       Shard/Replicate).
+    2. ``grad.full_tensor()`` does an all-gather along every mesh dim -> the
+       full matrix is replicated on every rank.
+    3. Every rank runs Newton-Schulz on the full matrix. Since NS is deterministic
+       and inputs are identical across ranks, results are bit-wise identical.
+    4. ``distribute_tensor`` re-shards the result back using the original
+       ``device_mesh`` + ``placements``, letting DTensor handle the Shard(0) /
+       Shard(1) / Replicate slicing automatically.
+
+    NOTE: This is a simple but correct baseline. It trades communication volume
+    (every rank all-gathers) and compute (every rank repeats NS) for code
+    simplicity and DTensor-handled correctness across mixed Col/Row-wise plans.
+    A future ``Fsdp2dWork`` could instead gather to a single rank (as in
+    ``Fsdp1dWork``) to save redundant NS compute and allow pipelined prefetch.
+    """
+
+    def __init__(self, param, state, group, index: int):
+        self.param = param
+        self.state = state
+        self.group = group
+        self.index = index
+        self._intermediate_state = None
+
+    def start(self):
+        self.param.grad = apply_momentum(
+            self.param.grad,
+            self.state["momentum_buffer"],
+            self.group["momentum"],
+            self.group["nesterov"],
+        )
+
+        grad = self.param.grad
+        assert isinstance(grad, DTensor), "only supports DTensor parameters"
+        assert grad.device_mesh.ndim == 2, "TpFsdp2dWork expects a 2D mesh (dp_sp, tp)"
+
+        # Strategy A performs NS independently on every rank after a full gather,
+        # so there is no cross-rank async work to overlap here. We just cache
+        # the original metadata for ``finish`` to re-shard the result.
+        self._intermediate_state = (grad.device_mesh, grad.placements)
+
+    def finish(self):
+        assert self._intermediate_state is not None, "start() must be called first"
+        mesh, placements = self._intermediate_state
+
+        grad = self.param.grad
+
+        # 1) All-gather along every mesh dim -> replicated full matrix on each rank.
+        g_full = grad.full_tensor()
+
+        # 2) Newton-Schulz on the full matrix (deterministic, identical across ranks).
+        g_full = zeropower_via_newtonschulz5(g_full, self.group["ns_steps"])
+        g_full = g_full.type_as(grad)
+
+        # 3) Re-shard back to the original (dp_sp, tp) placements and write the
+        #    result into grad's local storage. ``distribute_tensor`` correctly
+        #    handles Shard(0) / Shard(1) / Replicate on each mesh dim; the final
+        #    ``copy_`` makes the mutation to ``grad`` explicit (no reliance on
+        #    ``to_local()`` returning a view).
+        new_local = distribute_tensor(g_full, mesh, placements).to_local()
+        grad.to_local().copy_(new_local)
+
+        update = apply_scaling(
+            grad,
+            self.group["scale_mode"],
+            self.group["extra_scale_factor"],
+        )
+
+        self.param.mul_(1 - self.group["lr"] * self.group["weight_decay"])
+        self.param.add_(update.reshape(self.param.shape), alpha=-self.group["lr"])
+
+
+class EpFsdp2dWork:
+    """Muon work for an EP + FSDP 2D mesh (not yet implemented)."""
+
+    def __init__(self, param, state, group, index: int):
+        raise NotImplementedError("EP + FSDP 2D mesh Muon is not yet implemented")
+
+
+class TpEpFsdp3dWork:
+    """Muon work for a TP + EP + FSDP 3D mesh (not yet implemented)."""
+
+    def __init__(self, param, state, group, index: int):
+        raise NotImplementedError("TP + EP + FSDP 3D mesh Muon is not yet implemented")
+
+
+class SingleDeviceWork:
+    """Muon work for a single device (no distributed communication)."""
+
+    def __init__(self, param, state, group, index: int):
+        self.param = param
+        self.state = state
+        self.group = group
+
+    def start(self):
+        update = apply_momentum(
+            self.param.grad,
+            self.state["momentum_buffer"],
+            self.group["momentum"],
+            self.group["nesterov"],
+        )
+        update = zeropower_via_newtonschulz5(update, self.group["ns_steps"])
+        update = update.to(self.param.grad.dtype)
+        update = apply_scaling(
+            update,
+            self.group["scale_mode"],
+            self.group["extra_scale_factor"],
+        )
+        self.param.mul_(1 - self.group["lr"] * self.group["weight_decay"])
+        self.param.add_(update.reshape(self.param.shape), alpha=-self.group["lr"])
+
+    def finish(self):
+        pass
+
+
+# ---------------------------------------------------------------------------
+# Muon optimizer (unified: Muon for >=2D, Adam for <2D)
+# ---------------------------------------------------------------------------
+
+
+class Muon(torch.optim.Optimizer):
+    """DTensor-aware Muon optimizer with a built-in Adam backend.
+
+    Original code: https://github.com/KellerJordan/Muon (muon.py); also supports
+    a single-device variant.
+
+    Notable changes:
+
+    - DTensor/FSDP2 native: uses gather/scatter for distributed NS instead of DDP.
+    - ``scale_mode`` / ``extra_scale_factor`` arguments aligned with Megatron-Core /
+      emerging_optimizers (NVIDIA-NeMo). See :func:`apply_scaling` for details.
+
+    Example::
+
+        optimizer = Muon([
+            dict(params=model.square_params(), lr=1e-3, use_muon=True),
+            dict(params=model.non_square_params(), lr=1e-3, use_muon=False),
+        ])
+
+    Param group args (``use_muon=True``):
+        lr, momentum, weight_decay, scale_mode, extra_scale_factor, nesterov, ns_steps
+
+    Param group args (``use_muon=False``):
+        lr, betas, eps, weight_decay
+    """
+
+    def __init__(self, param_groups):
+        for group in param_groups:
+            assert "use_muon" in group
+            if group["use_muon"]:
+                group.setdefault("lr", 0.02)
+                group.setdefault("momentum", 0.95)
+                group.setdefault("weight_decay", 0)
+                group.setdefault("scale_mode", "spectral")
+                group.setdefault("extra_scale_factor", 1.0)
+                group.setdefault("nesterov", True)
+                group.setdefault("ns_steps", 5)
+                assert set(group.keys()) == {
+                    "params",
+                    "lr",
+                    "momentum",
+                    "weight_decay",
+                    "use_muon",
+                    "scale_mode",
+                    "extra_scale_factor",
+                    "nesterov",
+                    "ns_steps",
+                }
+            else:
+                group.setdefault("lr", 3e-4)
+                group.setdefault("betas", (0.9, 0.95))
+                group.setdefault("eps", 1e-10)
+                group.setdefault("weight_decay", 0)
+                assert set(group.keys()) == {
+                    "params",
+                    "lr",
+                    "betas",
+                    "eps",
+                    "weight_decay",
+                    "use_muon",
+                }
+        super().__init__(param_groups, dict())
+
+    def _get_work_class(self, p: Tensor) -> tuple[type[Work], int]:
+        """Dispatch the work class (and prefetch depth) based on mesh dimensionality."""
+        if isinstance(p, DTensor):
+            if p.device_mesh.ndim == 1:
+                return Fsdp1dWork, 8
+            elif p.device_mesh.ndim == 2:
+                return TpFsdp2dWork, 8
+            else:
+                raise ValueError(f"Unsupported mesh dimension: {p.device_mesh.ndim}")
+        else:
+            return SingleDeviceWork, 1
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        dq: deque[Work] = deque()
+
+        for group in self.param_groups:
+            if group["use_muon"]:
+                for i, p in enumerate(group["params"]):
+                    if p.grad is None:
+                        p.grad = torch.zeros_like(p)  # Force synchronization
+                    state = self.state[p]
+                    if len(state) == 0:
+                        state["momentum_buffer"] = torch.zeros_like(p)
+
+                    work_class, prefetch_factor = self._get_work_class(p)
+
+                    work = work_class(p, state, group, i)
+                    work.start()
+                    dq.append(work)
+
+                    if len(dq) > prefetch_factor:
+                        dq.popleft().finish()
+            else:
+                for p in group["params"]:
+                    if p.grad is None:
+                        p.grad = torch.zeros_like(p)  # Force synchronization
+                    state = self.state[p]
+                    if len(state) == 0:
+                        state["exp_avg"] = torch.zeros_like(p)
+                        state["exp_avg_sq"] = torch.zeros_like(p)
+                        state["step"] = 0
+                    state["step"] += 1
+                    update = adam_update(
+                        p.grad,
+                        state["exp_avg"],
+                        state["exp_avg_sq"],
+                        state["step"],
+                        group["betas"],
+                        group["eps"],
+                    )
+                    p.mul_(1 - group["lr"] * group["weight_decay"])
+                    p.add_(update, alpha=-group["lr"])
+
+        for work in dq:
+            work.finish()
+
+        return loss
diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py
index a512469bc0..ed6f08c94f 100644
--- a/areal/engine/megatron_engine.py
+++ b/areal/engine/megatron_engine.py
@@ -25,6 +25,7 @@
 from megatron.core.distributed import finalize_model_grads
 from megatron.core.optimizer import OptimizerConfig as MCoreOptimizerConfig
 from megatron.core.optimizer import get_megatron_optimizer
+from megatron.core.optimizer.muon import get_megatron_muon_optimizer
 from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
 from megatron.core.pipeline_parallel import get_forward_backward_func
 from megatron.core.transformer import TransformerConfig
@@ -1282,12 +1283,30 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None:
         assert self.optimizer_config.type in [
             "adam",
             "sgd",
-        ], "Only AdamW/sgd optimizer is supported in this engine."
+            "muon",
+        ], "MegatronEngine supports the 'adam'/'sgd'/'muon' optimizers."
         if self.optimizer_config.type == "sgd":
             self.logger.warning(
                 "Using the 'sgd' optimizer with Megatron may be less stable. Consider using the 'adam' (AdamW) optimizer for improved stability."
             )
+        if self.optimizer_config.type == "muon":
+            # Native Megatron Muon has two hard constraints (see
+            # megatron/core/optimizer/muon.py::get_megatron_muon_optimizer):
+            #   - use_distributed_optimizer must be False (grad buffer coupling)
+            #   - fp16 is not supported (bf16 or fp32 only)
+            if use_distributed_optimizer:
+                self.logger.warning(
+                    "Muon is incompatible with the Megatron distributed optimizer; "
+                    "forcing use_distributed_optimizer=False for this run."
+                )
+                use_distributed_optimizer = False
+            if self.dtype is torch.float16:
+                raise ValueError(
+                    "The Muon optimizer does not support fp16 in Megatron-Core; "
+                    "use bf16 or fp32."
+                )
+
         # Make megatron optimizer config
         mcore_opt_config = MCoreOptimizerConfig(
             optimizer=self.optimizer_config.type,
@@ -1304,6 +1323,32 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None:
             clip_grad=self.optimizer_config.gradient_clipping,
             fp8_recipe=(self.fp8_config.recipe if self.enable_fp8 else None),
         )
+
+        # Forward Muon-specific hyperparameters onto Megatron-Core's OptimizerConfig.
+        # AReaL's muon_* fields are aligned 1:1 with Megatron-Core >= 0.17, so no
+        # translation is required. Fields not exposed by AReaL (muon_coefficient_type,
+        # muon_split_qkv, muon_tp_mode, muon_fp32_matmul_prec) keep their Megatron
+        # defaults.
+        if self.optimizer_config.type == "muon":
+            muon_passthrough_fields = (
+                "muon_momentum",
+                "muon_use_nesterov",
+                "muon_num_ns_steps",
+                "muon_scale_mode",
+                "muon_extra_scale_factor",
+            )
+            for attr in muon_passthrough_fields:
+                if hasattr(mcore_opt_config, attr):
+                    setattr(
+                        mcore_opt_config, attr, getattr(self.optimizer_config, attr)
+                    )
+                else:
+                    self.logger.warning(
+                        f"Megatron-Core OptimizerConfig has no attribute '{attr}'; "
+                        "your Megatron-Core may be too old to fully support Muon."
+                    )
+            # AdamW backend for embeddings/biases/norms (AReaL policy).
+            mcore_opt_config.muon_scalar_optimizer = "adam"
 mcore_opt_config.overlap_param_gather_with_optimizer_step = (
             self.mcore_config.overlap_param_gather_with_optimizer_step
         )
@@ -1321,7 +1366,26 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None:
             torch, self.mcore_config.exp_avg_sq_dtype
         )
 
-        self.optimizer = get_megatron_optimizer(mcore_opt_config, self.model)
+        # Muon is incompatible with Megatron's precision-aware optimizer path
+        # (OptimizerConfig.__post_init__ asserts ``optimizer == 'adam'``).
+        if (
+            self.optimizer_config.type == "muon"
+            and mcore_opt_config.use_precision_aware_optimizer
+        ):
+            self.logger.warning(
+                "The Muon optimizer is incompatible with "
+                "use_precision_aware_optimizer=True; disabling it for this run."
+            )
+            mcore_opt_config.use_precision_aware_optimizer = False
+
+        if self.optimizer_config.type == "muon":
+            # Megatron-LM's native Muon path: builds TensorParallelMuon for 2D
+            # linear weights and chains an Adam optimizer for embeddings,
+            # biases and norms. Returns a ChainedOptimizer that is transparently
+            # compatible with OptimizerParamScheduler and MegatronCheckpointManager.
+            self.optimizer = get_megatron_muon_optimizer(mcore_opt_config, self.model)
+        else:
+            self.optimizer = get_megatron_optimizer(mcore_opt_config, self.model)
 
         warmup_steps_proportion = self.optimizer_config.warmup_steps_proportion
         warmup_steps = int(warmup_steps_proportion * ft_spec.total_train_steps)
@@ -1349,8 +1413,16 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None:
         )
         self.lr_scheduler = lr_scheduler
 
-        # MegatronCheckpointManager now only support distributed optimizer which lora does not support
-        if not self.config.use_lora:
+        # MegatronCheckpointManager currently requires Megatron's distributed
+        # optimizer for its checkpoint format. Two cases fall outside this:
+        #   - LoRA does not use the distributed optimizer.
+        #   - Muon is incompatible with the distributed optimizer (enforced by
+        #     Megatron's own get_megatron_muon_optimizer) and therefore runs
+        #     with use_distributed_optimizer=False.
+        # In both cases we skip the checkpoint manager; attempting save/load
+        # via self.checkpointer will then raise AttributeError, which is the
+        # correct signal that checkpointing is not wired up for this backend combo.
+        if not self.config.use_lora and use_distributed_optimizer:
             self.checkpointer = MegatronCheckpointManager(
                 model=self.model,
                 optimizer=self.optimizer,
@@ -1359,6 +1431,14 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None:
                 use_checkpoint_opt_param_scheduler=self.mcore_config.use_checkpoint_opt_param_scheduler,
                 async_save=self.mcore_config.async_save,
             )
+        elif self.optimizer_config.type == "muon":
+            self.logger.warning(
+                "MegatronCheckpointManager is not constructed because Muon "
+                "forces use_distributed_optimizer=False, and the checkpoint "
+                "manager currently only supports the distributed-optimizer "
+                "format. save_model()/load_model() will be unavailable for "
+                "this run."
+            )
 
     def _check_rollout_engine_connected(self) -> None:
         """Validate that rollout engine has been connected via connect_engine()."""
diff --git a/docs/en/_toc.yml b/docs/en/_toc.yml
index 845032dd7d..3792a28cbd 100644
--- a/docs/en/_toc.yml
+++ b/docs/en/_toc.yml
@@ -40,6 +40,7 @@ parts:
       - file: algorithms/prox_approx
   - caption: Reference
     chapters:
+      - file: reference/optimizer
       - file: reference/checkpointing
      - file: reference/metrics_tracking
       - file: reference/alloc_mode
diff --git a/docs/en/reference/optimizer.md b/docs/en/reference/optimizer.md
new file mode 100644
index 0000000000..675d30734d
--- /dev/null
+++ b/docs/en/reference/optimizer.md
@@ -0,0 +1,96 @@
+(section-optimizer-guide)=
+
+# Optimizer Configuration Guide
+
+AReaL supports multiple optimizer types, configurable via the `optimizer.type` field.
+This document covers the support matrix across training backends and the
+implementation differences of the Muon optimizer.
+
+## Supported Optimizer Types
+
+| Type        | Description                                                                                        |
+| ----------- | -------------------------------------------------------------------------------------------------- |
+| `adam`      | AdamW optimizer (default)                                                                          |
+| `adam_bf16` | BF16-precision AdamW, reduces optimizer state memory                                               |
+| `sgd`       | Standard SGD                                                                                       |
+| `muon`      | Muon optimizer: Newton-Schulz orthogonalized updates for ≥2D params, AdamW backend for \<2D params |
+
+## Engine Support Matrix
+
+| Optimizer   | FSDP Engine            | Megatron Engine                | Archon Engine        |
+| ----------- | :--------------------: | :----------------------------: | :------------------: |
+| `adam`      | ✅                     | ✅                             | ✅                   |
+| `adam_bf16` | ✅ (AnyPrecisionAdamW) | ✅ (precision-aware optimizer) | ❌                   |
+| `sgd`       | ✅                     | ✅                             | ✅                   |
+| `muon`      | ✅                     | ✅ (Megatron-Core ≥ 0.17)      | ❌ (not implemented) |
+
+### Notes
+
+- **FSDP Engine**: `adam_bf16` uses `AnyPrecisionAdamW`, storing momentum and variance
+  in BF16.
+- **Megatron Engine**: `adam_bf16` requires the model dtype to be bfloat16; it is
+  auto-converted to adam with the precision-aware optimizer enabled.
+- **Archon Engine**: Currently only supports `adam` and `sgd`. Muon support is under
+  development.
+
+## Muon Optimizer
+
+### Overview
+
+Muon (MomentUm Orthogonalized by Newton-schulz) is an optimizer that applies approximate
+orthogonalization to the gradient momentum via a Newton-Schulz iteration. The core idea
+is to impose an orthogonality constraint on weight-matrix updates, making update
+directions more "uniform" in parameter space and accelerating convergence.
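+
+For intuition, the snippet below is a minimal, momentum-free sketch of the Muon
+update rule for a single weight matrix. It is illustrative standalone code (the
+names `newton_schulz`, `W`, and `grad` are ours, not AReaL's API); the real
+implementation in `areal/engine/fsdp_utils/muon.py` adds Nesterov momentum, a
+compiled bf16 Newton-Schulz loop, and DTensor gather/scatter for sharded params.
+
+```python
+import torch
+
+
+def newton_schulz(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
+    """Quintic Newton-Schulz orthogonalization (coefficients from Keller Jordan's Muon)."""
+    a, b, c = 3.4445, -4.7750, 2.0315
+    X = G / (G.norm() + 1e-7)  # Frobenius norm bounds the spectral norm by 1
+    if G.size(0) > G.size(1):  # iterate on the smaller Gram matrix
+        X = X.T
+    for _ in range(steps):
+        A = X @ X.T
+        X = a * X + (b * A + c * A @ A) @ X
+    if G.size(0) > G.size(1):
+        X = X.T
+    return X
+
+
+# One Muon update for a 2D weight, with Moonlight-style spectral scaling:
+W = torch.randn(256, 128)
+grad = torch.randn_like(W)  # stands in for the momentum-averaged gradient
+lr, wd, extra = 2e-3, 0.05, 0.2
+update = newton_schulz(grad)
+scale = extra * max(W.shape) ** 0.5  # 0.2 * sqrt(max(m, n)) ~ AdamW's update RMS
+W.mul_(1 - lr * wd).add_(update, alpha=-lr * scale)
+```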
+
+### Reference Implementations and Papers
+
+| Resource                                 | Link                                               |
+| ---------------------------------------- | -------------------------------------------------- |
+| Original implementation (Keller Jordan)  | https://github.com/KellerJordan/Muon               |
+| Moonlight paper (RMS scaling)            | https://arxiv.org/abs/2502.16982                   |
+| AReaL FSDP implementation                | `areal/engine/fsdp_utils/muon.py`                  |
+| Emerging-Optimizers (Megatron-Core Muon) | https://github.com/NVIDIA-NeMo/Emerging-Optimizers |
+
+### FSDP vs Megatron Implementation Differences
+
+The FSDP Engine and Megatron Engine differ significantly in how they partition
+parameters for Muon:
+
+| Dimension                         | FSDP Engine                                                      | Megatron Engine                                                                           |
+| --------------------------------- | ---------------------------------------------------------------- | ----------------------------------------------------------------------------------------- |
+| **Muon parameter scope**          | **All** ≥2D parameters (including embedding weight matrices)     | **Linear layer weights**                                                                  |
+| **AdamW backend parameters**      | All \<2D parameters (bias, LayerNorm weight/bias)                | Embeddings, biases, norms, and non-Linear 2D parameters                                   |
+| **Distributed NS implementation** | DTensor gather/scatter (FSDP2 native)                            | TP-aware `TensorParallelMuon` (distributed Newton-Schulz over the TP communication group) |
+| **TP + EP support**               | TP + FSDP 2D mesh ✅; TP + EP + FSDP 3D mesh ❌ (not implemented) | Full TP / EP / PP support                                                                 |
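+
+In code, the FSDP grouping rule is simply "Muon for every trainable parameter with
+`ndim >= 2`, AdamW for the rest". The sketch below mirrors the grouping in
+`areal/engine/fsdp_engine.py` (`model` is assumed to be any `nn.Module` you have
+already constructed):
+
+```python
+from areal.engine.fsdp_utils.muon import Muon
+
+muon_params = [p for p in model.parameters() if p.requires_grad and p.ndim >= 2]
+adamw_params = [p for p in model.parameters() if p.requires_grad and p.ndim < 2]
+
+optimizer = Muon(
+    [
+        dict(
+            params=muon_params,
+            lr=2e-3,
+            momentum=0.95,
+            weight_decay=0.05,
+            scale_mode="spectral",
+            extra_scale_factor=0.2,
+            use_muon=True,
+        ),
+        # <2D params (biases, norms) fall back to the built-in AdamW backend.
+        dict(
+            params=adamw_params,
+            lr=2e-3,
+            betas=(0.9, 0.95),
+            eps=1e-5,
+            weight_decay=0.05,
+            use_muon=False,
+        ),
+    ]
+)
+```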
+
+### Configuration Example
+
+```yaml
+optimizer:
+  type: muon
+  lr: 2e-3 # Shared lr (Muon and AdamW backend)
+  muon_momentum: 0.95
+  muon_use_nesterov: true
+  muon_num_ns_steps: 5
+  muon_scale_mode: spectral # spectral / unit_rms_norm / shape_scaling
+  muon_extra_scale_factor: 0.2 # 0.2 + spectral = Moonlight-style RMS-matched scaling
+  weight_decay: 0.05
+  beta1: 0.9 # AdamW backend params
+  beta2: 0.95
+  eps: 1e-5
+  lr_scheduler_type: cosine
+  warmup_steps_proportion: 0.03
+```
+
+### Configuration Parameters
+
+| Parameter                 | Type  | Default            | Description |
+| ------------------------- | ----- | ------------------ | ----------- |
+| `lr`                      | float | 0.001              | Shared learning rate for both Muon (≥2D params) and the AdamW backend (\<2D params). A single lr works well when pairing `muon_scale_mode=spectral` with `muon_extra_scale_factor=0.2` (Moonlight-style) |
+| `muon_momentum`           | float | 0.95               | Muon momentum coefficient |
+| `muon_use_nesterov`       | bool  | true               | Whether to use Nesterov momentum |
+| `muon_num_ns_steps`       | int   | 5                  | Number of Newton-Schulz iteration steps |
+| `muon_scale_mode`         | str   | "spectral"         | Update scaling mode. `spectral`: `sqrt(max(m, n))` (Kimi/Moonlight, emerging_optimizers default). `unit_rms_norm`: `sqrt(m / n)` (Scion / Bernstein). `shape_scaling`: `max(1, m/n)**0.5` (Keller Jordan's original) |
+| `muon_extra_scale_factor` | float | 1.0                | Extra multiplier; final scale = `scale_factor(mode) * muon_extra_scale_factor`. Use `0.2` with `spectral` to reproduce Moonlight-style RMS-matched scaling |
+| `weight_decay`            | float | 0.01               | Weight decay, applied to both Muon and the AdamW backend |
+| `beta1` / `beta2` / `eps` | float | 0.9 / 0.999 / 1e-8 | AdamW backend hyperparameters |
diff --git a/docs/zh/_toc.yml b/docs/zh/_toc.yml
index 16f9e7b713..661bef6a2f 100644
--- a/docs/zh/_toc.yml
+++ b/docs/zh/_toc.yml
@@ -40,6 +40,7 @@ parts:
       - file: algorithms/prox_approx
   - caption: 参考
     chapters:
+      - file: reference/optimizer
       - file: reference/checkpointing
       - file: reference/metrics_tracking
       - file: reference/alloc_mode
diff --git a/docs/zh/reference/optimizer.md b/docs/zh/reference/optimizer.md
new file mode 100644
index 0000000000..ca00b2de55
--- /dev/null
+++ b/docs/zh/reference/optimizer.md
@@ -0,0 +1,89 @@
+(section-optimizer-guide)=
+
+# 优化器配置指南
+
+AReaL 支持多种优化器类型,可通过 `optimizer.type` 字段进行配置。本文档介绍各优化器在不同训练后端的支持情况,以及 Muon 优化器的实现差异。
+
+## 支持的优化器类型
+
+| 类型        | 说明                                                                            |
+| ----------- | ------------------------------------------------------------------------------- |
+| `adam`      | AdamW 优化器(默认)                                                            |
+| `adam_bf16` | BF16 精度的 AdamW,降低优化器状态显存占用                                       |
+| `sgd`       | 标准 SGD                                                                        |
+| `muon`      | Muon 优化器,对 ≥2D 参数使用 Newton-Schulz 正交化更新,\<2D 参数使用 AdamW 后端 |
+
+## 各引擎支持矩阵
+
+| 优化器      | FSDP Engine            | Megatron Engine                | Archon Engine |
+| ----------- | :--------------------: | :----------------------------: | :-----------: |
+| `adam`      | ✅                     | ✅                             | ✅            |
+| `adam_bf16` | ✅ (AnyPrecisionAdamW) | ✅ (precision-aware optimizer) | ❌            |
+| `sgd`       | ✅                     | ✅                             | ✅            |
+| `muon`      | ✅                     | ✅ (Megatron-Core ≥ 0.17)      | ❌ (未实现)   |
+
+### 备注
+
+- **FSDP Engine** 中 `adam_bf16` 使用 `AnyPrecisionAdamW`,将 momentum 和 variance 存储为 BF16。
+- **Megatron Engine** 中 `adam_bf16` 要求模型 dtype 为 bfloat16,会自动转换为 adam 并启用
+  precision-aware optimizer。
+- **Archon Engine** 目前仅支持 `adam` 和 `sgd`,Muon 支持尚在开发中。
+
+## Muon 优化器
+
+### 简介
+
+Muon (MomentUm Orthogonalized by Newton-schulz) 是一种利用 Newton-Schulz
+迭代对梯度动量进行近似正交化的优化器。其核心思想是:对权重矩阵的梯度施加正交约束,使更新方向在参数空间中更加"均匀",从而加速收敛。
+
+### 参考实现与论文
+
+| 资源                                     | 链接                                               |
+| ---------------------------------------- | -------------------------------------------------- |
+| 原始实现 (Keller Jordan)                 | https://github.com/KellerJordan/Muon               |
+| Moonlight 论文 (RMS scaling)             | https://arxiv.org/abs/2502.16982                   |
+| AReaL FSDP 实现                          | `areal/engine/fsdp_utils/muon.py`                  |
+| Emerging-Optimizers (Megatron-Core Muon) | https://github.com/NVIDIA-NeMo/Emerging-Optimizers |
+
+### FSDP 与 Megatron 实现差异
+
+FSDP Engine 和 Megatron Engine 对 Muon 的参数分组策略存在显著差异:
+
+| 维度               | FSDP Engine                                               | Megatron Engine                                                          |
+| ------------------ | --------------------------------------------------------- | ------------------------------------------------------------------------ |
+| **Muon 参数范围**  | **所有** ≥2D 参数(包括 embedding 权重矩阵)              | **Linear 层的 weight**                                                   |
+| **AdamW 后端参数** | 所有 \<2D 参数(bias、LayerNorm weight/bias)             | embedding、bias、norm 以及非 Linear 的 2D 参数                           |
+| **分布式 NS 实现** | DTensor gather/scatter(FSDP2 原生)                      | TP-aware 的 `TensorParallelMuon`(利用 TP 通信组做分布式 Newton-Schulz) |
+| **TP + EP 支持**   | TP + FSDP 2D mesh ✅;TP + EP + FSDP 3D mesh ❌(未实现) | 完整支持 TP / EP / PP                                                    |
+
+### 配置示例
+
+```yaml
+optimizer:
+  type: muon
+  lr: 2e-3 # 统一 lr(Muon 和 AdamW 后端共用)
+  muon_momentum: 0.95
+  muon_use_nesterov: true
+  muon_num_ns_steps: 5
+  muon_scale_mode: spectral # spectral / unit_rms_norm / shape_scaling
+  muon_extra_scale_factor: 0.2 # 0.2 + spectral 等价于 Moonlight 风格的 RMS 对齐
+  weight_decay: 0.05
+  beta1: 0.9 # AdamW 后端参数
+  beta2: 0.95
+  eps: 1e-5
+  lr_scheduler_type: cosine
+  warmup_steps_proportion: 0.03
+```
+
+### 配置参数说明
+
+| 参数                      | 类型  | 默认值             | 说明 |
+| ------------------------- | ----- | ------------------ | ---- |
+| `lr`                      | float | 0.001              | 统一学习率,Muon(≥2D 参数)和 AdamW 后端(\<2D 参数)共用。配合 `muon_scale_mode=spectral` + `muon_extra_scale_factor=0.2`(Moonlight 风格)时单一 lr 即可 |
+| `muon_momentum`           | float | 0.95               | Muon 动量系数 |
+| `muon_use_nesterov`       | bool  | true               | 是否使用 Nesterov 动量 |
+| `muon_num_ns_steps`       | int   | 5                  | Newton-Schulz 迭代步数 |
+| `muon_scale_mode`         | str   | "spectral"         | 更新缩放模式。`spectral`:`sqrt(max(m, n))`(Kimi/Moonlight、emerging_optimizers 默认);`unit_rms_norm`:`sqrt(m / n)`(Scion / Bernstein);`shape_scaling`:`max(1, m/n)**0.5`(Keller Jordan 原版) |
+| `muon_extra_scale_factor` | float | 1.0                | 额外乘性缩放,最终 scale = `scale_factor(mode) * muon_extra_scale_factor`。配合 `spectral` 使用 `0.2` 可复刻 Moonlight 风格的 RMS 对齐缩放 |
+| `weight_decay`            | float | 0.01               | 权重衰减,同时作用于 Muon 和 AdamW 后端 |
+| `beta1` / `beta2` / `eps` | float | 0.9 / 0.999 / 1e-8 | AdamW 后端的超参数 |
diff --git a/pyproject.toml b/pyproject.toml
index 01b79b870c..bc38050093 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -162,6 +162,7 @@ megatron = [
     "megatron-core==0.17.0; python_version >= '3.12' and sys_platform == 'linux' and platform_machine == 'x86_64'",
     "mbridge @ git+https://github.com/ISEEKYAN/mbridge.git@310e8fb; sys_platform == 'linux' and platform_machine == 'x86_64'",
     "megatron-bridge==0.4.0; python_version >= '3.12' and sys_platform == 'linux' and platform_machine == 'x86_64'",
+    "emerging-optimizers @ git+https://github.com/NVIDIA-NeMo/Emerging-Optimizers.git@v0.2.0; python_version >= '3.12' and sys_platform == 'linux' and platform_machine == 'x86_64'",
 ]
 # Convenience extra for CUDA training packages (no inference backend)
 cuda-train = [
diff --git a/pyproject.vllm.toml b/pyproject.vllm.toml
index 3e998c4ffd..334bc49ff1 100644
--- a/pyproject.vllm.toml
+++ b/pyproject.vllm.toml
@@ -174,6 +174,7 @@ megatron = [
     "megatron-core==0.17.0; python_version >= '3.12' and sys_platform == 'linux' and platform_machine == 'x86_64'",
     "mbridge @ git+https://github.com/ISEEKYAN/mbridge.git@310e8fb; sys_platform == 'linux' and platform_machine == 'x86_64'",
     "megatron-bridge==0.4.0; python_version >= '3.12' and sys_platform == 'linux' and platform_machine == 'x86_64'",
+    "emerging-optimizers @ git+https://github.com/NVIDIA-NeMo/Emerging-Optimizers.git@v0.2.0; python_version >= '3.12' and sys_platform == 'linux' and platform_machine == 'x86_64'",
 ]
 cuda-train = [
     "areal[tms]",
diff --git a/tests/test_muon_optimizer.py b/tests/test_muon_optimizer.py
new file mode 100644
index 0000000000..6a5d624f72
--- /dev/null
+++ b/tests/test_muon_optimizer.py
@@ -0,0 +1,177 @@
+# SPDX-License-Identifier: Apache-2.0
+
+"""Unit tests for the unified Muon optimizer."""
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _make_simple_model() -> nn.Module:
+    """A small model with both >=2D and <2D params."""
+    model = nn.Sequential(
+        nn.Linear(16, 32),
+        nn.LayerNorm(32),
+        nn.Linear(32, 8),
+    )
+    return model
+
+
+def _make_param_groups(model: nn.Module) -> list[dict]:
+    """Split parameters into muon (>=2D) and backend (<2D) groups."""
+    muon_params = [p for p in model.parameters() if p.requires_grad and p.ndim >= 2]
+    backend_params = [p for p in model.parameters() if p.requires_grad and p.ndim < 2]
+    return [
+        dict(params=muon_params, lr=1e-2, use_muon=True),
+        dict(params=backend_params, lr=1e-3, use_muon=False),
+    ]
+
+
+# ---------------------------------------------------------------------------
+# Tests for the unified Muon optimizer
+# ---------------------------------------------------------------------------
+
+
+class TestMuonOptimizer:
+    """Tests for the unified Muon optimizer with its built-in Adam backend."""
+
+    def test_step_and_zero_grad(self):
+        from areal.engine.fsdp_utils.muon import Muon
+
+        model = _make_simple_model()
+        opt = Muon(_make_param_groups(model))
+
+        x = torch.randn(4, 16)
+        loss = model(x).sum()
+        loss.backward()
+
+        opt.step()
+        opt.zero_grad()
+
+        for p in model.parameters():
+            assert p.grad is None or (p.grad == 0).all()
+
+    def test_param_groups_structure(self):
+        from areal.engine.fsdp_utils.muon import Muon
+
+        model = _make_simple_model()
+        groups = _make_param_groups(model)
+        opt = Muon(groups)
+
+        assert len(opt.param_groups) == 2
+        assert opt.param_groups[0]["use_muon"] is True
+        assert opt.param_groups[1]["use_muon"] is False
+
+    def test_state_dict_roundtrip(self):
+        from areal.engine.fsdp_utils.muon import Muon
+
+        model = _make_simple_model()
+        opt = Muon(_make_param_groups(model))
+
+        # Do a step to populate state
+        x = torch.randn(4, 16)
+        loss = model(x).sum()
+        loss.backward()
+        opt.step()
+
+        sd = opt.state_dict()
+        assert "state" in sd
+        assert "param_groups" in sd
+
+        # Reload into a fresh optimizer
+        opt2 = Muon(_make_param_groups(model))
+        opt2.load_state_dict(sd)
+
+        sd2 = opt2.state_dict()
+        assert len(sd2["state"]) == len(sd["state"])
+
+    def test_lr_scheduler_compat(self):
+        from areal.engine.fsdp_utils.muon import Muon
+
+        model = _make_simple_model()
+        opt = Muon(_make_param_groups(model))
+
+        scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5)
+
+        x = torch.randn(4, 16)
+        loss = model(x).sum()
+        loss.backward()
+        opt.step()
+        scheduler.step()
+
+        # LR should be halved for all groups
+        assert abs(opt.param_groups[0]["lr"] - 5e-3) < 1e-9
+        assert abs(opt.param_groups[1]["lr"] - 5e-4) < 1e-9
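+
+    def test_newton_schulz_near_orthogonal(self):
+        """Sanity-check the NS iteration against its own docstring.
+
+        A hedged sketch: the docstring claims the result is ~US'V^T with
+        S'_{ii} ~ Uniform(0.5, 1.5), so we only assert loose singular-value
+        bounds (bf16 NS is approximate by design).
+        """
+        from areal.engine.fsdp_utils.muon import zeropower_via_newtonschulz5
+
+        torch.manual_seed(0)
+        g = torch.randn(32, 16)
+        out = zeropower_via_newtonschulz5(g, steps=5).float()
+
+        assert out.shape == g.shape
+        s = torch.linalg.svdvals(out)
+        assert (s > 0.2).all() and (s < 2.0).all()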
+
+    def test_all_params_updated(self):
+        from areal.engine.fsdp_utils.muon import Muon
+
+        model = _make_simple_model()
+        opt = Muon(_make_param_groups(model))
+
+        params_before = {
+            name: p.clone()
+            for name, p in model.named_parameters()
+            if p.requires_grad
+        }
+
+        x = torch.randn(4, 16)
+        loss = model(x).sum()
+        loss.backward()
+        opt.step()
+
+        for name, p in model.named_parameters():
+            if p.requires_grad:
+                assert not torch.equal(p.data, params_before[name]), (
+                    f"Parameter {name} was not updated"
+                )
+
+    def test_convergence(self):
+        """Muon + the Adam backend should minimize a simple least-squares loss."""
+        from areal.engine.fsdp_utils.muon import Muon
+
+        torch.manual_seed(42)
+        model = nn.Linear(8, 1, bias=True)
+        target_w = torch.randn(1, 8)
+        target_b = torch.randn(1)
+
+        opt = Muon(
+            [
+                dict(params=[model.weight], lr=0.02, use_muon=True),
+                dict(params=[model.bias], lr=0.02, use_muon=False),
+            ]
+        )
+
+        for _ in range(200):
+            x = torch.randn(32, 8)
+            pred = model(x)
+            target = x @ target_w.T + target_b
+            loss = ((pred - target) ** 2).mean()
+
+            opt.zero_grad()
+            loss.backward()
+            opt.step()
+
+        final_loss = loss.item()
+        assert final_loss < 0.1, f"Muon did not converge, final loss={final_loss}"
+
+    def test_multi_step_finite(self):
+        """Multiple steps should keep producing finite parameters."""
+        from areal.engine.fsdp_utils.muon import Muon
+
+        torch.manual_seed(42)
+        model = _make_simple_model()
+        opt = Muon(_make_param_groups(model))
+
+        for _ in range(10):
+            x = torch.randn(4, 16)
+            loss = model(x).sum()
+            opt.zero_grad()
+            loss.backward()
+            opt.step()
+
+        for p in model.parameters():
+            assert torch.isfinite(p.data).all()
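+
+
+class TestMuonScaling:
+    """Direct checks of the documented apply_scaling formulas.
+
+    A minimal sketch: it assumes only that apply_scaling implements the
+    spectral / unit_rms_norm / shape_scaling rules described in its docstring.
+    """
+
+    def test_scale_factors(self):
+        import math
+
+        from areal.engine.fsdp_utils.muon import apply_scaling
+
+        m, n = 64, 16
+        base = torch.ones(m, n)
+
+        # spectral: sqrt(max(m, n)) = 8
+        out = apply_scaling(base.clone(), "spectral", 1.0)
+        assert torch.allclose(out, base * math.sqrt(max(m, n)))
+
+        # unit_rms_norm: sqrt(m / n) = 2
+        out = apply_scaling(base.clone(), "unit_rms_norm", 1.0)
+        assert torch.allclose(out, base * math.sqrt(m / n))
+
+        # shape_scaling: max(1, m / n) ** 0.5 = 2
+        out = apply_scaling(base.clone(), "shape_scaling", 1.0)
+        assert torch.allclose(out, base * (max(1, m / n) ** 0.5))
+
+        # extra_scale_factor multiplies on top (Moonlight-style 0.2 * spectral).
+        out = apply_scaling(base.clone(), "spectral", 0.2)
+        assert torch.allclose(out, base * (0.2 * math.sqrt(max(m, n))))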