80 changes: 71 additions & 9 deletions areal/api/cli_args.py
@@ -338,30 +338,48 @@ class OptimizerConfig:
    type: str = field(
        default="adam",
        metadata={
-            "help": "Optimizer type. For FSDP Engine, adam_bf16 enables memory-efficient BF16 optimizer states. "
-            "For Megatron Engine, adam_bf16 requires dtype=bfloat16 and is automatically converted to adam "
-            "with precision-aware optimizer enabled.",
-            "choices": ["adam", "sgd", "adam_bf16"],
+            "help": "Optimizer type. 'adam': AdamW (default). 'adam_bf16': memory-efficient BF16 AdamW "
+            "(FSDP: uses AnyPrecisionAdamW; Megatron: requires dtype=bfloat16, auto-converted to adam "
+            "with precision-aware optimizer). 'sgd': plain SGD. 'muon': Muon optimizer for >=2D params "
+            "with AdamW backend for <2D params (biases, norms).",
+            "choices": ["adam", "sgd", "adam_bf16", "muon"],
        },
    )
+    lr: float = field(
+        default=1e-3,
+        metadata={
+            "help": "Learning rate. When type='muon', this is shared by both the Muon sub-optimizer "
+            "(>=2D params) and the AdamW backend (<2D params). Pair "
+            "muon_scale_mode='spectral' with muon_extra_scale_factor=0.2 (Moonlight-style) to "
+            "make Muon's update RMS match AdamW's so a single lr works for both."
+        },
+    )
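A quick sanity check on the RMS-matching claim in this help text. A minimal sketch, assuming 'spectral' mode multiplies the Newton-Schulz-orthogonalized update by sqrt(max(m, n)) as in the Moonlight recipe; the helper name and the mode factor are assumptions here, not Megatron-Core's actual code:

```python
import math

def muon_final_scale(m: int, n: int, extra_scale_factor: float = 0.2) -> float:
    # An (approximately) orthogonalized m x n update O has RMS ~= 1 / sqrt(max(m, n)),
    # so multiplying by sqrt(max(m, n)) normalizes the update RMS to ~1.
    spectral_factor = math.sqrt(max(m, n))
    # extra_scale_factor=0.2 then pins the update RMS near 0.2, close to a typical
    # AdamW step RMS, which is what lets one lr serve both param groups.
    return spectral_factor * extra_scale_factor

print(muon_final_scale(4096, 1024))  # 64 * 0.2 = 12.8
```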
+    weight_decay: float = field(
+        default=0.01,
+        metadata={
+            "help": "Weight decay. Applied to all optimizer types, including the Muon "
+            "sub-optimizer (>=2D params) and the AdamW backend (<2D params)."
+        },
+    )
-    lr: float = field(default=1e-3, metadata={"help": "Learning rate"})
-    weight_decay: float = field(default=0.01, metadata={"help": "Weight decay"})
    beta1: float = field(
        default=0.9,
        metadata={
-            "help": "Adam beta1 parameter. Only effective when optimizer_type is adam/adam_bf16"
+            "help": "Adam beta1 parameter. Used by adam/adam_bf16, and by the AdamW backend "
+            "when type='muon'. Not used by the Muon sub-optimizer itself."
        },
    )
    beta2: float = field(
        default=0.999,
        metadata={
-            "help": "Adam beta2 parameter. Only effective when optimizer_type is adam/adam_bf16"
+            "help": "Adam beta2 parameter. Used by adam/adam_bf16, and by the AdamW backend "
+            "when type='muon'. Not used by the Muon sub-optimizer itself."
        },
    )
    eps: float = field(
        default=1e-8,
        metadata={
-            "help": "Adam epsilon parameter. Only effective when optimizer_type is adam/adam_bf16"
+            "help": "Adam epsilon for numerical stability. Used by adam/adam_bf16, and by the "
+            "AdamW backend when type='muon'. Not used by the Muon sub-optimizer itself."
        },
    )
    min_lr_ratio: float = field(
@@ -398,6 +416,50 @@ class OptimizerConfig:
    gradient_clipping: float = field(
        default=1.0, metadata={"help": "Gradient clipping threshold"}
    )
+    muon_momentum: float = field(
+        default=0.95,
+        metadata={
+            "help": "Muon momentum parameter. Only effective when type='muon'."
+        },
+    )
+    muon_use_nesterov: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to use Nesterov momentum in Muon. Only effective when type='muon'. "
+            "Mirrors Megatron-Core OptimizerConfig.muon_use_nesterov."
+        },
+    )
+    muon_num_ns_steps: int = field(
+        default=5,
+        metadata={
+            "help": "Number of Newton-Schulz iteration steps in Muon. Only effective when type='muon'. "
+            "Mirrors Megatron-Core OptimizerConfig.muon_num_ns_steps."
+        },
+    )
+    muon_scale_mode: str = field(
+        default="spectral",
+        metadata={
+            "help": "Muon update scaling mode (final scale = mode_factor * muon_extra_scale_factor). "
+            "Only used when type='muon'. Mirrors Megatron-Core OptimizerConfig.muon_scale_mode.",
+            "choices": ["spectral", "unit_rms_norm", "shape_scaling"],
+        },
+    )
+    muon_extra_scale_factor: float = field(
+        default=1.0,
+        metadata={
+            "help": "Extra multiplier on top of muon_scale_mode. Use 0.2 with "
+            "muon_scale_mode='spectral' for Moonlight-style RMS-matched scaling. "
+            "Only used when type='muon'. Mirrors Megatron-Core OptimizerConfig.muon_extra_scale_factor."
+        },
+    )
+
+    def __post_init__(self):
+        """Validate optimizer configuration."""
+        valid_muon_scale_modes = {"spectral", "unit_rms_norm", "shape_scaling"}
+        if self.muon_scale_mode not in valid_muon_scale_modes:
+            raise ValueError(
+                f"muon_scale_mode must be one of {valid_muon_scale_modes}, got {self.muon_scale_mode!r}."
+            )

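For completeness, a construction sketch using the new fields (assuming all other OptimizerConfig fields keep their defaults; the lr value is illustrative):

```python
from areal.api.cli_args import OptimizerConfig

# Moonlight-style setup: spectral scaling with the 0.2 multiplier recommended
# in the lr help text, so one learning rate serves both param groups.
opt_cfg = OptimizerConfig(
    type="muon",
    lr=2e-5,
    weight_decay=0.01,
    muon_momentum=0.95,
    muon_scale_mode="spectral",
    muon_extra_scale_factor=0.2,
)

# __post_init__ rejects unknown scale modes at construction time:
# OptimizerConfig(type="muon", muon_scale_mode="frobenius")  -> ValueError
```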

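And since muon_num_ns_steps ultimately counts iterations of a Newton-Schulz orthogonalization loop, here is what the default of 5 controls. A sketch using the quintic coefficients from Keller Jordan's reference Muon implementation; areal/engine/fsdp_utils/muon.py may differ in details:

```python
import torch

def newton_schulz_orthogonalize(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize a 2-D update matrix (reference-style sketch)."""
    a, b, c = 3.4445, -4.7750, 2.0315  # quintic iteration coefficients
    X = G.bfloat16()
    X = X / (X.norm() + 1e-7)  # Frobenius norm bounds the spectral norm by 1
    if G.size(0) > G.size(1):
        X = X.mT  # iterate on the wide orientation
    for _ in range(steps):  # steps == muon_num_ns_steps
        A = X @ X.mT
        X = a * X + (b * A + c * (A @ A)) @ X
    if G.size(0) > G.size(1):
        X = X.mT
    return X
```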
@dataclass
50 changes: 46 additions & 4 deletions areal/engine/fsdp_engine.py
@@ -83,7 +83,11 @@
)
from areal.engine.fsdp_utils.checkpoint import DCPState
from areal.engine.fsdp_utils.grad import fsdp2_clip_grad_norm
-from areal.engine.fsdp_utils.optimizer import AnyPrecisionAdamW, PerLayerOptimWrapper
+from areal.engine.fsdp_utils.muon import Muon as MuonOptimizer
+from areal.engine.fsdp_utils.optimizer import (
+    AnyPrecisionAdamW,
+    PerLayerOptimWrapper,
+)
from areal.engine.fsdp_utils.parallel import ParallelHelper, parallelize_model
from areal.infra.dist_rollout import DistRolloutCoordinator
from areal.infra.platforms import current_platform
@@ -470,7 +474,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
        self._create_optimizer(ft_spec)

        if self.config.fsdp.per_layer_optim_step:
-            if self.optimizer_config.type != "adam":
+            if self.optimizer_config.type not in ("adam",):
                raise ValueError(
                    f"per_layer_optim_step only supports 'adam' optimizer, got '{self.optimizer_config.type}'."
                )
@@ -1111,7 +1115,8 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None:
"adam",
"adam_bf16",
"sgd",
], "Only adam/adam_bf16/sgd optimizer is supported in this engine."
"muon",
], "Only adam/adam_bf16/sgd/muon optimizer is supported in this engine."
if self.optimizer_config.type in ["sgd", "adam_bf16"]:
self.logger.warning(
f"Using the '{self.optimizer_config.type}' optimizer with FSDP may be less stable. Consider using the 'adam' (AdamW) optimizer for improved stability and performance."
@@ -1121,7 +1126,44 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None:
        beta1 = self.optimizer_config.beta1
        beta2 = self.optimizer_config.beta2
        eps = self.optimizer_config.eps
-        if self.optimizer_config.type == "adam":
+        if self.optimizer_config.type == "muon":
+            muon_params: list[torch.nn.Parameter] = []
+            backend_params: list[torch.nn.Parameter] = []
+            for p in self.model.parameters():
+                if not p.requires_grad:
+                    continue
+                if p.ndim >= 2:
+                    muon_params.append(p)
+                else:
+                    backend_params.append(p)
+            self.optimizer = MuonOptimizer(
+                [
+                    dict(
+                        params=muon_params,
+                        lr=lr,
+                        momentum=self.optimizer_config.muon_momentum,
+                        weight_decay=weight_decay,
+                        scale_mode=self.optimizer_config.muon_scale_mode,
+                        extra_scale_factor=self.optimizer_config.muon_extra_scale_factor,
+                        nesterov=self.optimizer_config.muon_use_nesterov,
+                        ns_steps=self.optimizer_config.muon_num_ns_steps,
+                        use_muon=True,
+                    ),
+                    dict(
+                        params=backend_params,
+                        lr=lr,
+                        betas=(beta1, beta2),
+                        eps=eps,
+                        weight_decay=weight_decay,
+                        use_muon=False,
+                    ),
+                ]
+            )
+            self.logger.info(
+                f"Muon optimizer: {len(muon_params)} params (>=2D), "
+                f"AdamW backend: {len(backend_params)} params (<2D)"
+            )
+        elif self.optimizer_config.type == "adam":
            self.optimizer = torch.optim.AdamW(
                self.model.parameters(),
                lr=lr,
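One note on the ndim-based split above: a pure p.ndim >= 2 rule also routes 2-D embedding (and any output-head) weights to the Muon group, not the AdamW backend, which only picks up 1-D tensors such as biases and norm scales. A self-contained illustration (module shapes are made up):

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Embedding(1000, 64),  # weight ndim=2 -> Muon group
    nn.Linear(64, 64),       # weight ndim=2 -> Muon; bias ndim=1 -> AdamW backend
    nn.LayerNorm(64),        # weight/bias ndim=1 -> AdamW backend
)
muon_params = [p for p in model.parameters() if p.requires_grad and p.ndim >= 2]
backend_params = [p for p in model.parameters() if p.requires_grad and p.ndim < 2]
print(len(muon_params), len(backend_params))  # -> 2 3
```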
2 changes: 2 additions & 0 deletions areal/engine/fsdp_utils/__init__.py
@@ -13,6 +13,7 @@
)
from transformers import PreTrainedModel

+from areal.engine.fsdp_utils.muon import Muon
from areal.engine.fsdp_utils.optimizer import (
    AdamKernel,
    OptimKernel,
@@ -33,6 +34,7 @@
"apply_fsdp2",
"fsdp2_load_full_state_dict",
"get_cosine_schedule_with_warmup",
"Muon",
"PerLayerOptimWrapper",
"OptimKernel",
"AdamKernel",
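With the re-export in place, callers can pull Muon from the package root rather than the submodule path that fsdp_engine.py uses:

```python
from areal.engine.fsdp_utils import Muon  # same class as areal.engine.fsdp_utils.muon.Muon
```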