feat: muon optimizer support #1270

Open
HT-Yuan wants to merge 3 commits into areal-project:main from HT-Yuan:feature/muon-optimizer-support

Conversation

@HT-Yuan (Contributor) commented Apr 27, 2026

Description

Add Muon optimizer support to AReaL, with native distributed implementations on
both FSDP2 and Megatron backends.

Muon (MomentUm Orthogonalized by Newton-Schulz) applies an orthogonalization
step (Newton-Schulz iteration) to the momentum buffer before each update,
yielding more "uniform" updates in parameter space. Empirically this matches or
beats AdamW with comparable cost on LLM pre-/post-training workloads
(Kimi/Moonlight, arXiv:2502.16982).
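
For reference, a minimal sketch of the Newton-Schulz orthogonalization step, using the quintic coefficients from the original open-source Muon implementation (illustrative only, not the exact code in this PR):

import torch

def newton_schulz(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Quintic Newton-Schulz iteration; the coefficients push the singular
    # values of the momentum matrix toward 1 without an explicit SVD.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.bfloat16()                       # bf16 is accurate enough here
    transposed = G.size(0) > G.size(1)
    if transposed:                         # keep the short side first
        X = X.mT
    X = X / (X.norm() + 1e-7)              # bound the spectral norm near 1
    for _ in range(steps):
        A = X @ X.mT
        X = a * X + (b * A + c * (A @ A)) @ X
    if transposed:
        X = X.mT
    return X.to(G.dtype)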

This PR wires Muon end-to-end:

  • New optimizer config (OptimizerConfig.type='muon') in
    areal/api/cli_args.py, with hyperparameters mirroring Megatron-Core's
    OptimizerConfig:
    • muon_momentum, muon_use_nesterov, muon_num_ns_steps
    • muon_scale_mode ∈ {spectral, unit_rms_norm, shape_scaling}
    • muon_extra_scale_factor
  • FSDP backend (areal/engine/fsdp_engine.py +
    areal/engine/fsdp_utils/muon.py):
    • Pure-Python Muon optimizer with torch.compile-fused Newton-Schulz inner
      loop (bf16-accelerated, batched).
    • DTensor-aware distributed dispatch:
      • 1D mesh (FSDP only): gather → NS on owner rank → scatter (Fsdp1dWork)
      • 2D mesh (TP + FSDP, strategy A): full_tensor() all-gather → redundant NS
        on every rank → distribute_tensor re-shard (TpFsdp2dWork)
      • Single device fallback (SingleDeviceWork)
      • 3D mesh (TP + EP + FSDP) raises NotImplementedError for now.
  • Hybrid grouping: ≥2D params use Muon, <2D params (biases, LayerNorms,
    embeddings if 1D) use the built-in AdamW backend, sharing a single lr
    when paired with Moonlight-style RMS scaling
    (scale_mode='spectral' + extra_scale_factor=0.2); see the grouping
    sketch after this list.
  • Megatron backend (areal/engine/megatron_engine.py): delegates to
    Megatron-Core's TensorParallelMuon (requires Megatron-Core ≥ 0.17),
    groups Linear weights to Muon and the rest (embeddings, biases, norms,
    non-Linear 2D params) to AdamW.
  • Docs: new optimizer reference page (docs/{en,zh}/reference/optimizer.md)
    with engine support matrix, FSDP-vs-Megatron grouping differences, full
    config example, and parameter table.
  • Examples: examples/math/gsm8k_sft_fsdp_muon.yaml,
    examples/math/gsm8k_sft_megatron.yaml.
  • Tests: see Test Plan below.
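
As referenced in the hybrid-grouping bullet above, the parameter split reduces to a dimensionality check. A hypothetical helper (names are illustrative, not this PR's API):

import torch

def split_params_for_muon(model: torch.nn.Module):
    # Parameters with ndim >= 2 are matrices that Newton-Schulz can
    # orthogonalize and go to Muon; 1D parameters (biases, norm scales)
    # fall back to the AdamW backend.
    muon_params, adamw_params = [], []
    for p in model.parameters():
        if p.requires_grad:
            (muon_params if p.ndim >= 2 else adamw_params).append(p)
    return muon_params, adamw_params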

Related Issue

Fixes #(issue)

Type of Change

  • 🐛 Bug fix
  • ✨ New feature
  • 💥 Breaking change
  • 📝 Documentation update
  • ♻️ Refactoring
  • ⚡ Performance improvement
  • ✅ Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • Pre-commit hooks pass (pre-commit run --all-files)
  • Relevant tests pass; new tests added for new functionality
  • Documentation updated (if applicable; built with ./docs/build_all.sh)
  • Branch is up to date with main
  • Self-reviewed via /review-pr command
  • This PR was created by a coding agent via /create-pr
  • This PR is a breaking change

Breaking Change Details (if applicable):

Additional Context

Design notes

  • Naming alignment with Megatron-Core / NVIDIA-NeMo emerging_optimizers:
    muon_scale_mode and muon_extra_scale_factor deliberately match
    Megatron-Core OptimizerConfig so a single YAML works across both backends.
  • Why TP+FSDP 2D uses "Strategy A" (redundant NS): it accepts extra comm
    volume and redundant compute in exchange for code simplicity and
    correctness across mixed Col-/Row-wise/Replicate plans. A future
    Fsdp2dWork could gather to a single rank to avoid the redundant NS;
    left as a TODO.
  • Single-lr ergonomics: with scale_mode='spectral' +
    extra_scale_factor=0.2, Muon's update RMS approximately matches AdamW's,
    so users don't need separate LRs for the Muon and AdamW param groups
    (see the scaling sketch after this list).
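
The scaling sketch referenced above, showing why one lr can serve both param groups (a sketch of the idea, assuming the update is roughly semi-orthogonal after Newton-Schulz; not this PR's exact code):

import math
import torch

def moonlight_scaled_update(ortho: torch.Tensor, extra: float = 0.2) -> torch.Tensor:
    # A roughly semi-orthogonal (m, n) update has RMS ~ 1/sqrt(max(m, n)),
    # so multiplying by extra * sqrt(max(m, n)) yields an update RMS of
    # ~extra (0.2 here), close to AdamW's typical update RMS.
    m, n = ortho.shape
    return ortho * (extra * math.sqrt(max(m, n)))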

Known limitations

  • TP + EP + FSDP (3D mesh) on FSDP backend: NotImplementedError.
  • EP + FSDP (2D mesh) on FSDP backend: NotImplementedError.
  • Archon engine: not yet wired up.

Need help? Check the Contributing Guide or ask in
GitHub Discussions!

HT-Yuan marked this pull request as draft on April 27, 2026 04:14
HT-Yuan changed the title from "Feature/muon optimizer support" to "[wip] Feature/muon optimizer support" on Apr 27, 2026
@gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces support for the Muon optimizer across the FSDP, Megatron, and Archon engines, including a new _CombinedOptimizer wrapper to manage parameter splitting between Muon (for weights with 2 or more dimensions) and AdamW (for 1D parameters). While the integration is comprehensive, several critical issues were identified: the current parameter splitting logic is mathematically incorrect when used with FSDP sharding, as Muon requires full weight matrices, and applying Muon to large embedding or head layers may cause significant performance degradation. Additionally, the _CombinedOptimizer implementation uses a non-standard state_dict format and lacks proper base-class initialization, which could break compatibility with PyTorch checkpointing and other utilities. Minor typos in the required PyTorch version for Muon were also noted.

Comment thread areal/engine/fsdp_utils/optimizer.py Outdated
Comment on lines +689 to +693
def state_dict(self) -> dict[str, Any]:
    return {
        "muon": self._muon.state_dict(),
        "backend": self._backend.state_dict(),
    }
Severity: high

This state_dict format is non-standard as it lacks the top-level state and param_groups keys. This will likely break compatibility with standard PyTorch checkpointing utilities and might cause issues with sharded saving/loading in FSDP (DCP) if it expects the standard structure.
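
For illustration, one standard-shaped merge (a sketch assuming both inner optimizers return ordinary torch.optim state dicts; not code from this PR) would re-index parameter ids and keep the top-level state/param_groups keys:

from typing import Any

def merged_state_dict(muon, backend) -> dict[str, Any]:
    # torch.optim state_dict() keys the "state" dict by integer indices
    # into the flattened param list across groups; offsetting the backend
    # indices by the Muon param count keeps them unique after merging.
    m, b = muon.state_dict(), backend.state_dict()
    offset = sum(len(g["params"]) for g in m["param_groups"])
    state = dict(m["state"])
    state.update({k + offset: v for k, v in b["state"].items()})
    groups = [dict(g) for g in m["param_groups"]]
    for g in b["param_groups"]:
        g = dict(g)
        g["params"] = [i + offset for i in g["params"]]
        groups.append(g)
    return {"state": state, "param_groups": groups}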

Comment thread areal/engine/fsdp_engine.py Outdated
Comment on lines +1014 to +1019
if self.optimizer_config.type == "muon":
    if not hasattr(torch.optim, "Muon"):
        raise RuntimeError(
            "torch.optim.Muon is not available in the current PyTorch version. "
            "Please upgrade to torch>=2.9.0."
        )
Severity: medium

The error message mentions torch>=2.9.0, which is likely a typo. Muon is expected to be available in much earlier versions (e.g., PyTorch 2.6.0).

Suggested change:

# Before
if self.optimizer_config.type == "muon":
    if not hasattr(torch.optim, "Muon"):
        raise RuntimeError(
            "torch.optim.Muon is not available in the current PyTorch version. "
            "Please upgrade to torch>=2.9.0."
        )

# After
if self.optimizer_config.type == "muon":
    if not hasattr(torch.optim, "Muon"):
        raise RuntimeError(
            "torch.optim.Muon is not available in the current PyTorch version. "
            "Please upgrade to torch>=2.6.0."
        )

Comment thread areal/engine/fsdp_utils/optimizer.py Outdated
Comment on lines +649 to +661
def __init__(
    self,
    muon_optimizer: torch.optim.Optimizer,
    backend_optimizer: torch.optim.Optimizer,
) -> None:
    # We do NOT call super().__init__() because we manage state ourselves.
    self._muon = muon_optimizer
    self._backend = backend_optimizer

    # Expose a unified param_groups list so LR schedulers work.
    self.param_groups: list[dict[str, Any]] = (
        self._muon.param_groups + self._backend.param_groups
    )
Severity: medium

It is recommended to call super().__init__ with the merged param_groups to ensure the base Optimizer class is correctly initialized. This maintains compatibility with PyTorch utilities and internal checks that might rely on attributes like defaults or the internal state of the Optimizer object.

Suggested change:

# Before
def __init__(
    self,
    muon_optimizer: torch.optim.Optimizer,
    backend_optimizer: torch.optim.Optimizer,
) -> None:
    # We do NOT call super().__init__() because we manage state ourselves.
    self._muon = muon_optimizer
    self._backend = backend_optimizer
    # Expose a unified param_groups list so LR schedulers work.
    self.param_groups: list[dict[str, Any]] = (
        self._muon.param_groups + self._backend.param_groups
    )

# After
def __init__(
    self,
    muon_optimizer: torch.optim.Optimizer,
    backend_optimizer: torch.optim.Optimizer,
) -> None:
    # We call super().__init__ with merged param groups for compatibility.
    param_groups = muon_optimizer.param_groups + backend_optimizer.param_groups
    super().__init__(param_groups, {})
    self._muon = muon_optimizer
    self._backend = backend_optimizer

Comment thread areal/engine/fsdp_utils/optimizer.py Outdated
Comment on lines +682 to +687
@property
def state(self) -> dict:  # type: ignore[override]
    """Merged view used by checkpoint utilities."""
    merged: dict = {}
    merged.update(self._muon.state)
    merged.update(self._backend.state)
    return merged
Severity: medium

The state property returns a new merged dictionary on every access. This makes the state effectively read-only for any external utility that expects to modify optimizer states in-place (e.g., for state initialization or normalization). While Muon is currently excluded from per_layer_optim_step, this behavior could lead to subtle bugs if the optimizer is used with other utilities that expect standard Optimizer.state behavior.
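
To make that concrete (illustrative snippet; opt and p stand in for the combined optimizer and one of its parameters):

# The property rebuilds the merged dict on every access, so nested
# objects are shared by reference but top-level mutations are lost.
opt.state[p]["momentum_buffer"].zero_()  # works: the tensor is shared
opt.state[p] = {"step": 0}               # silently lost: rebinds a key in
                                         # the temporary merged dict only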

Comment on lines +63 to +68
elif optimizer_config.type == "muon":
    if not hasattr(torch.optim, "Muon"):
        raise RuntimeError(
            "torch.optim.Muon is not available in the current PyTorch version. "
            "Please upgrade to torch>=2.9.0."
        )
Severity: medium

Typo in torch version requirement (2.9.0 -> 2.6.0).

Suggested change:

# Before
elif optimizer_config.type == "muon":
    if not hasattr(torch.optim, "Muon"):
        raise RuntimeError(
            "torch.optim.Muon is not available in the current PyTorch version. "
            "Please upgrade to torch>=2.9.0."
        )

# After
elif optimizer_config.type == "muon":
    if not hasattr(torch.optim, "Muon"):
        raise RuntimeError(
            "torch.optim.Muon is not available in the current PyTorch version. "
            "Please upgrade to torch>=2.6.0."
        )

HT-Yuan force-pushed the feature/muon-optimizer-support branch from 416ba57 to 4e4f135 on April 29, 2026 12:15
HT-Yuan force-pushed the feature/muon-optimizer-support branch 3 times, most recently from bf4b583 to f320451 on May 7, 2026 17:08
Add Muon optimizer (Newton-Schulz orthogonalization) with distributed
FSDP support, ported from samsja/muon_fsdp_2 v0.3.0.

Core changes:
- areal/utils/optimizer.py: Full Muon implementation with Work pipeline
  for async NCCL overlap (Fsdp1dWork, SingleDeviceWork), Newton-Schulz
  iteration, Moonlight RMS scaling option, and AdamW fallback for
  non-2D parameters (embeddings, norms, biases); the gather/NS/scatter
  pattern is sketched after this message
- areal/api/cli_args.py: Add Muon-specific config fields (muon_momentum,
  muon_nesterov, muon_ns_steps, muon_backend_steps, muon_rms_scale)
- areal/engine/fsdp_engine.py: Integrate Muon into FSDP optimizer creation
- areal/experimental/engine/archon_engine.py + archon_utils.py: Integrate
  Muon into Archon engine optimizer creation
- pyproject.toml / pyproject.vllm.toml: Add muon_fsdp_2 dependency
- tests/test_muon_optimizer.py: Unit tests for Newton-Schulz, scaling,
  optimizer step, and config validation
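
As referenced above, a conceptual sketch of the Fsdp1dWork flow with blocking collectives for brevity (the ported implementation overlaps these as async NCCL work; names and signatures here are illustrative, and newton_schulz refers to the sketch in the PR description above):

import torch
import torch.distributed as dist

def fsdp1d_muon_orthogonalize(shard: torch.Tensor, owner: int,
                              group=None, steps: int = 5) -> torch.Tensor:
    # FSDP shards a 2D weight along dim 0 into equally sized slices
    # (padding details omitted). Gather the slices on the owner rank,
    # run Newton-Schulz on the full matrix there, then scatter back.
    world = dist.get_world_size(group)
    rank = dist.get_rank(group)
    bufs = [torch.empty_like(shard) for _ in range(world)] if rank == owner else None
    dist.gather(shard, bufs, dst=owner, group=group)
    if rank == owner:
        full = newton_schulz(torch.cat(bufs, dim=0), steps)
        chunks = list(full.chunk(world, dim=0))
    else:
        chunks = None
    out = torch.empty_like(shard)
    dist.scatter(out, chunks, src=owner, group=group)
    return out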

feat(megatron): enable Muon optimizer via Megatron-Core native dispatch

Megatron-Core natively supports Muon via _get_megatron_emerging_optimizer
when optimizer type is not in ('adam', 'sgd'). It handles TP-aware
Newton-Schulz, QKV splitting, and ChainedOptimizer (Muon for 2D weights,
AdamW for norms/biases/embeddings) out of the box.

- Allow 'muon' in _create_optimizer assertion
- Forward muon_momentum/muon_nesterov/muon_num_ns_steps from OptimizerConfig
  to MCoreOptimizerConfig (with hasattr guard for older Megatron-Core;
  sketched below)
- Requires the 'emerging-optimizers' package to be installed at runtime
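
A minimal sketch of the guarded forwarding mentioned above (field names follow the final OptimizerConfig from the PR description; the exact attribute mapping onto MCoreOptimizerConfig is an assumption):

def forward_muon_fields(optimizer_config, mcore_config) -> None:
    # Copy Muon hyperparameters only when the installed Megatron-Core
    # version already defines them, so older versions keep working.
    for field in ("muon_momentum", "muon_use_nesterov", "muon_num_ns_steps"):
        if hasattr(mcore_config, field):
            setattr(mcore_config, field, getattr(optimizer_config, field))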
HT-Yuan force-pushed the feature/muon-optimizer-support branch from 75441d8 to 0deea42 on May 8, 2026 03:48
HT-Yuan changed the title from "[wip] Feature/muon optimizer support" to "feat: muon optimizer support" on May 8, 2026
HT-Yuan force-pushed the feature/muon-optimizer-support branch from 0deea42 to 5c545e6 on May 8, 2026 06:01
@HT-Yuan (Contributor, Author) commented May 8, 2026

[Training-curve screenshot] Qwen3-1.7B, gsm8k, FSDP (d4p1t1):

optimizer:
  type: muon
  lr: 2e-3
  muon_momentum: 0.95
  muon_use_nesterov: true
  muon_num_ns_steps: 5
  muon_scale_mode: spectral
  muon_extra_scale_factor: 0.2
  weight_decay: 0.05
  beta1: 0.9
  beta2: 0.95
  eps: 1e-5
  lr_scheduler_type: cosine
  gradient_clipping: 1.0
  warmup_steps_proportion: 0.03

@HT-Yuan (Contributor, Author) commented May 8, 2026

[Training-curve screenshot] Qwen3-1.7B, gsm8k, Megatron (d4p1t1):

optimizer:
  type: muon
  lr: 2e-3
  muon_momentum: 0.95
  muon_use_nesterov: true
  muon_num_ns_steps: 5
  muon_scale_mode: spectral
  muon_extra_scale_factor: 0.2
  weight_decay: 0.05
  beta1: 0.9
  beta2: 0.95
  eps: 1e-5
  lr_scheduler_type: cosine
  gradient_clipping: 1.0
  warmup_steps_proportion: 0.03

HT-Yuan marked this pull request as ready for review on May 8, 2026 06:06