feat: muon optimizer support #1270
Code Review
This pull request introduces support for the Muon optimizer across the FSDP, Megatron, and Archon engines, including a new _CombinedOptimizer wrapper to manage parameter splitting between Muon (for weights with 2 or more dimensions) and AdamW (for 1D parameters). While the integration is comprehensive, several critical issues were identified: the current parameter splitting logic is mathematically incorrect when used with FSDP sharding as Muon requires full weight matrices, and applying Muon to large embedding or head layers may cause significant performance degradation. Additionally, the _CombinedOptimizer implementation uses a non-standard state_dict format and lacks proper base class initialization, which could break compatibility with PyTorch checkpointing and other utilities. Minor typos regarding the required PyTorch version for Muon were also noted.
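For context, here is a minimal sketch of the splitting rule the review summary refers to; the function name and arguments are illustrative, not this PR's actual API:

```python
def split_params_for_muon(named_params):
    """Illustrative split: weights with 2+ dimensions go to Muon, 1-D tensors
    (biases, norms) go to the AdamW backend."""
    muon_params, adamw_params = [], []
    for name, p in named_params:
        if not p.requires_grad:
            continue
        # Note the review's concern: embedding and lm_head weights are also
        # >=2-D, so a purely dimension-based rule sends them to Muon unless
        # they are filtered out by name.
        (muon_params if p.ndim >= 2 else adamw_params).append(p)
    return muon_params, adamw_params
```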
def state_dict(self) -> dict[str, Any]:
    return {
        "muon": self._muon.state_dict(),
        "backend": self._backend.state_dict(),
    }
if self.optimizer_config.type == "muon":
    if not hasattr(torch.optim, "Muon"):
        raise RuntimeError(
            "torch.optim.Muon is not available in the current PyTorch version. "
            "Please upgrade to torch>=2.9.0."
        )
The error message mentions torch>=2.9.0, which is likely a typo. Muon is expected to be available in much earlier versions (e.g., PyTorch 2.6.0).
Suggested change:
if self.optimizer_config.type == "muon":
    if not hasattr(torch.optim, "Muon"):
        raise RuntimeError(
            "torch.optim.Muon is not available in the current PyTorch version. "
            "Please upgrade to torch>=2.6.0."
        )
def __init__(
    self,
    muon_optimizer: torch.optim.Optimizer,
    backend_optimizer: torch.optim.Optimizer,
) -> None:
    # We do NOT call super().__init__() because we manage state ourselves.
    self._muon = muon_optimizer
    self._backend = backend_optimizer

    # Expose a unified param_groups list so LR schedulers work.
    self.param_groups: list[dict[str, Any]] = (
        self._muon.param_groups + self._backend.param_groups
    )
It is recommended to call super().__init__ with the merged param_groups to ensure the base Optimizer class is correctly initialized. This maintains compatibility with PyTorch utilities and internal checks that might rely on attributes like defaults or the internal state of the Optimizer object.
Suggested change:
def __init__(
    self,
    muon_optimizer: torch.optim.Optimizer,
    backend_optimizer: torch.optim.Optimizer,
) -> None:
    # We call super().__init__ with merged param groups for compatibility.
    param_groups = muon_optimizer.param_groups + backend_optimizer.param_groups
    super().__init__(param_groups, {})
    self._muon = muon_optimizer
    self._backend = backend_optimizer
def state(self) -> dict:  # type: ignore[override]
    """Merged view used by checkpoint utilities."""
    merged: dict = {}
    merged.update(self._muon.state)
    merged.update(self._backend.state)
    return merged
The state property returns a new merged dictionary on every access. This makes the state effectively read-only for any external utility that expects to modify optimizer states in-place (e.g., for state initialization or normalization). While Muon is currently excluded from per_layer_optim_step, this behavior could lead to subtle bugs if the optimizer is used with other utilities that expect standard Optimizer.state behavior.
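To make the concern concrete, here is a hypothetical caller (not from this PR) that lazily initializes per-parameter state, as some checkpoint or state-normalization helpers do:

```python
import torch

# optimizer is a _CombinedOptimizer instance; the loop below is hypothetical.
for group in optimizer.param_groups:
    for p in group["params"]:
        if p not in optimizer.state:
            # This assignment lands in the temporary merged dict built by the
            # `state` property, not in self._muon.state or self._backend.state,
            # so the new entry silently disappears on the next access.
            optimizer.state[p] = {"momentum_buffer": torch.zeros_like(p)}
```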
elif optimizer_config.type == "muon":
    if not hasattr(torch.optim, "Muon"):
        raise RuntimeError(
            "torch.optim.Muon is not available in the current PyTorch version. "
            "Please upgrade to torch>=2.9.0."
        )
Typo in torch version requirement (2.9.0 -> 2.6.0).
Suggested change:
elif optimizer_config.type == "muon":
    if not hasattr(torch.optim, "Muon"):
        raise RuntimeError(
            "torch.optim.Muon is not available in the current PyTorch version. "
            "Please upgrade to torch>=2.6.0."
        )
Add Muon optimizer (Newton-Schulz orthogonalization) with distributed
FSDP support, ported from samsja/muon_fsdp_2 v0.3.0.
Core changes:
- areal/utils/optimizer.py: Full Muon implementation with Work pipeline
for async NCCL overlap (Fsdp1dWork, SingleDeviceWork), Newton-Schulz
iteration, Moonlight RMS scaling option, and AdamW fallback for
non-2D parameters (embeddings, norms, biases)
- areal/api/cli_args.py: Add Muon-specific config fields (muon_momentum,
muon_nesterov, muon_ns_steps, muon_backend_steps, muon_rms_scale)
- areal/engine/fsdp_engine.py: Integrate Muon into FSDP optimizer creation
- areal/experimental/engine/archon_engine.py + archon_utils.py: Integrate
Muon into Archon engine optimizer creation
- pyproject.toml / pyproject.vllm.toml: Add muon_fsdp_2 dependency
- tests/test_muon_optimizer.py: Unit tests for Newton-Schulz, scaling,
optimizer step, and config validation
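The Work pipeline mentioned above is, in spirit, a two-phase object: a start() that launches the NCCL collective asynchronously and a finish() that waits and hands the gathered tensor back for Newton-Schulz. A hedged sketch under that assumption (class and method names are illustrative, not the module's actual API):

```python
import torch
import torch.distributed as dist

class IllustrativeFsdp1dWork:
    """Two-phase work item: overlap the all-gather of one parameter's momentum
    with Newton-Schulz compute on previously gathered parameters."""

    def __init__(self, shard: torch.Tensor, group: dist.ProcessGroup):
        self.shard, self.group = shard, group

    def start(self) -> "IllustrativeFsdp1dWork":
        world = dist.get_world_size(self.group)
        self.full = torch.empty(world * self.shard.numel(),
                                dtype=self.shard.dtype, device=self.shard.device)
        # async_op=True returns immediately; the NCCL transfer proceeds in the
        # background while the optimizer loop works on other parameters.
        self.handle = dist.all_gather_into_tensor(
            self.full, self.shard.contiguous(), group=self.group, async_op=True)
        return self

    def finish(self) -> torch.Tensor:
        self.handle.wait()      # block only once the full tensor is needed
        return self.full        # caller reshapes, runs NS, scales, re-shards
```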
feat(megatron): enable Muon optimizer via Megatron-Core native dispatch
Megatron-Core natively supports Muon via _get_megatron_emerging_optimizer
when optimizer type is not in ('adam', 'sgd'). It handles TP-aware
Newton-Schulz, QKV splitting, and ChainedOptimizer (Muon for 2D weights,
AdamW for norms/biases/embeddings) out of the box.
- Allow 'muon' in _create_optimizer assertion
- Forward muon_momentum/muon_nesterov/muon_num_ns_steps from OptimizerConfig
to MCoreOptimizerConfig (with hasattr guard for older Megatron-Core)
- Requires the 'emerging-optimizers' package to be installed at runtime
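A hedged sketch of the hasattr-guarded forwarding described above; the Megatron-Core field names are assumptions taken from the commit message, not verified against a specific release:

```python
from megatron.core.optimizer import OptimizerConfig as MCoreOptimizerConfig

def forward_muon_fields(optimizer_config, mcore_kwargs: dict) -> dict:
    """Copy Muon hyperparameters into the Megatron-Core optimizer config only
    when the installed Megatron-Core version defines the corresponding fields."""
    for name in ("muon_momentum", "muon_nesterov", "muon_num_ns_steps"):
        # Older Megatron-Core releases predate these fields; skipping them keeps
        # the same AReaL config usable there (falling back to adam/sgd types).
        if hasattr(MCoreOptimizerConfig, name):
            mcore_kwargs[name] = getattr(optimizer_config, name)
    return mcore_kwargs
```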
Description
Add Muon optimizer support to AReaL, with native distributed implementations on
both FSDP2 and Megatron backends.
Muon (MomentUm Orthogonalized by Newton-Schulz) applies an orthogonalization
step (Newton-Schulz iteration) to the momentum buffer before each update,
yielding more "uniform" updates in parameter space. Empirically this matches or
beats AdamW with comparable cost on LLM pre-/post-training workloads
(Kimi/Moonlight, arXiv:2502.16982).
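For readers unfamiliar with the method, a minimal single-device sketch of the orthogonalization step; the quintic coefficients follow the widely used Muon reference implementation and may differ from this PR's fused torch.compile version:

```python
import torch

def newton_schulz_orthogonalize(g: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately push g's singular values toward 1 (Newton-Schulz iteration)."""
    a, b, c = 3.4445, -4.7750, 2.0315      # coefficients from the Muon reference impl
    x = g.to(torch.bfloat16)
    transposed = x.shape[0] > x.shape[1]
    if transposed:                          # iterate on the wide orientation
        x = x.T
    x = x / (x.norm() + 1e-7)               # bound the spectral norm by ~1
    for _ in range(steps):
        gram = x @ x.T
        x = a * x + (b * gram + c * gram @ gram) @ x
    if transposed:
        x = x.T
    return x.to(g.dtype)

# Muon applies this to the (optionally Nesterov) momentum of each 2-D weight,
# then takes an SGD-style step: p -= lr * scale * orthogonalized_momentum.
```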
This PR wires Muon end-to-end:
- Config: `OptimizerConfig.type='muon'` in `areal/api/cli_args.py`, with hyperparameters mirroring Megatron-Core's `OptimizerConfig`: `muon_momentum`, `muon_use_nesterov`, `muon_num_ns_steps`, `muon_scale_mode` ∈ {`spectral`, `unit_rms_norm`, `shape_scaling`}, and `muon_extra_scale_factor`.
- FSDP engine (`areal/engine/fsdp_engine.py` + `areal/engine/fsdp_utils/muon.py`): torch.compile-fused Newton-Schulz inner loop (bf16-accelerated, batched). Sharded weights (`Fsdp1dWork`) follow a `full_tensor()` all-gather → redundant NS on every rank → `distribute_tensor` re-shard flow (`TpFsdp2dWork` for TP+FSDP layouts, `SingleDeviceWork` for unsharded weights); other placements raise `NotImplementedError` for now. Non-2D parameters (biases, norms, embeddings if 1D) use the built-in AdamW backend, sharing a single `lr` when paired with Moonlight-style RMS scaling (`scale_mode='spectral'` + `extra_scale_factor=0.2`). The gather/re-shard flow is sketched right after this list.
- Megatron engine (`areal/engine/megatron_engine.py`): delegates to Megatron-Core's `TensorParallelMuon` (requires Megatron-Core ≥ 0.17), which groups Linear weights to Muon and the rest (embeddings, biases, norms, non-Linear 2D params) to AdamW.
- Docs (`docs/{en,zh}/reference/optimizer.md`) with engine support matrix, FSDP-vs-Megatron grouping differences, full config example, and parameter table.
- Examples: `examples/math/gsm8k_sft_fsdp_muon.yaml`, `examples/math/gsm8k_sft_megatron.yaml`.
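As referenced in the FSDP bullet above, a hedged sketch of the gather → Newton-Schulz → re-shard flow using DTensor's public API; the helper below is illustrative, reuses the `newton_schulz_orthogonalize` sketch shown earlier, and the PR's `Fsdp1dWork`/`TpFsdp2dWork` classes may structure it differently (e.g. with async overlap):

```python
import torch
from torch.distributed.tensor import DTensor, distribute_tensor

def muon_step_on_sharded_weight(param: DTensor, momentum: DTensor,
                                lr: float, scale: float) -> None:
    # newton_schulz_orthogonalize: see the sketch earlier in this description.
    full = momentum.full_tensor()                # all-gather the sharded momentum
    ortho = newton_schulz_orthogonalize(full)    # redundant NS on every rank
    update = distribute_tensor(                  # re-shard to match the parameter
        ortho, device_mesh=param.device_mesh, placements=param.placements)
    param.add_(update, alpha=-lr * scale)        # SGD-style step on the local shard
```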
Related Issue
Fixes #(issue)
Additional Context
Design notes
- `muon_scale_mode` and `muon_extra_scale_factor` deliberately match Megatron-Core's `OptimizerConfig` so a single YAML works across both backends.
- The FSDP full-tensor gather trades memory and redundant compute for code simplicity and correctness across mixed Col/Row-wise/Replicate plans. A future `Fsdp2dWork` could gather to a single rank to save redundant NS; left as a TODO.
- `lr` ergonomics: with `scale_mode='spectral'` + `extra_scale_factor=0.2`, Muon's update RMS approximately matches AdamW's, so users don't need separate LRs for Muon vs. the AdamW backend (see the back-of-envelope check below).
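A back-of-envelope check of that last point, assuming 'spectral' scaling multiplies the orthogonalized update by `extra_scale_factor * sqrt(max(m, n))` as in the Moonlight report (an assumption about this PR's exact formula):

```python
import math

# For an orthogonalized m x n update with ~unit singular values,
# RMS ≈ sqrt(min(m, n) / (m * n)) = 1 / sqrt(max(m, n)).
m, n = 4096, 1024                              # example weight shape
rms_before_scale = 1.0 / math.sqrt(max(m, n))  # ≈ 0.0156
scale = 0.2 * math.sqrt(max(m, n))             # extra_scale_factor = 0.2
print(rms_before_scale * scale)                # ≈ 0.2, in AdamW's typical update-RMS range
```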
Known limitations
- Parameter placements not covered by the FSDP Work implementations currently raise `NotImplementedError`.
GitHub Discussions!