From fd37381c1d3875b747577a544fe972bdcb63cc77 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 12 Apr 2026 20:44:15 +0800 Subject: [PATCH 001/112] feat: add router replay for megatron engine --- areal/engine/megatron_engine.py | 16 + areal/engine/megatron_engine_r3_patch.py | 309 +++++++++++ areal/engine/router_replay_patch.py | 649 +++++++++++++++++++++++ areal/engine/router_replay_utils.py | 433 +++++++++++++++ areal/trainer/ppo/actor_r3_patch.py | 83 +++ areal/trainer/rl_trainer.py | 11 + areal/workflow/rlvr.py | 18 +- areal/workflow/rlvr_r3_patch.py | 162 ++++++ 8 files changed, 1680 insertions(+), 1 deletion(-) create mode 100644 areal/engine/megatron_engine_r3_patch.py create mode 100644 areal/engine/router_replay_patch.py create mode 100644 areal/engine/router_replay_utils.py create mode 100644 areal/trainer/ppo/actor_r3_patch.py create mode 100644 areal/workflow/rlvr_r3_patch.py diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index eb486d8530..95e75097c2 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -285,6 +285,8 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): self.tokenizer = load_hf_tokenizer(self.config.path) + # R3: Check early so the variable is always defined. + _r3_enabled = getattr(self.config, "_r3_enable_router_replay", False) with patch_bridge_for_tree_training( self.enable_tree_training and self.bridge_cls == "mbridge" ): @@ -307,6 +309,14 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): self._check_and_apply_fp8_config() self._validate_fp8_consistency() + # R3: Apply Router Replay patch BEFORE model creation so that + # TopKRouter.__init__ and TransformerConfig.__init__ are patched. + if _r3_enabled: + from areal.engine.router_replay_patch import apply_router_replay_patch + apply_router_replay_patch() + self.tf_config.enable_routing_replay = True + self.logger.info("[R3] Router Replay patches applied before model creation.") + with self.device: models = make_mcore_model( hf_config=self.hf_config, @@ -398,6 +408,12 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): model_config.param_sync_func = model_config.param_sync_func[0] model_config.finalize_model_grads_func = finalize_model_grads self._create_optimizer(ft_spec) + + # R3: Apply engine-level patch after model and optimizer are ready. + if _r3_enabled: + from areal.engine.megatron_engine_r3_patch import patch_megatron_engine_for_r3 + patch_megatron_engine_for_r3(self, enable_router_replay=True) + self.logger.info("[R3] Router Replay enabled on MegatronEngine.") self._initialized = True def _build_hf_mcore_bridge(self): diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py new file mode 100644 index 0000000000..9e6fad4f20 --- /dev/null +++ b/areal/engine/megatron_engine_r3_patch.py @@ -0,0 +1,309 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +R3 Integration Patch for MegatronEngine. + +This module wraps ``MegatronEngine.forward_backward_batch`` so that, when +the micro-batch data contains ``routed_experts`` tensors, each micro-batch's +forward step is preceded by a call to ``setup_per_microbatch_replay_forward`` +and followed (after the full forward pass) by a switch to backward-replay +mode. + +The patch handles the critical issue that ``routed_experts`` is a 4D tensor +``(bs, seq_len, num_moe_layers, topk)`` which will NOT be correctly split by +``split_padded_tensor_dict_into_mb_list`` (which only splits tensors with +``numel() == bs * max_seqlen``). Instead, we extract ``routed_experts`` +from ``mb_list.data`` before micro-batch splitting, and manually distribute +it to each micro-batch using the ``forward_indices`` and ``group_lens`` +from ``MicroBatchList``. + +Usage:: + + from areal.engine.megatron_engine_r3_patch import patch_megatron_engine_for_r3 + patch_megatron_engine_for_r3(engine, enable_router_replay=True) +""" + +from __future__ import annotations + +import logging +import types +from collections.abc import Callable +from typing import Any + +import torch + +logger = logging.getLogger(__name__) + + +# =================================================================== +# Public API +# =================================================================== + + +def patch_megatron_engine_for_r3( + engine, + enable_router_replay: bool = False, +) -> None: + """Patch a ``MegatronEngine`` instance to support Router Replay (R3). + + 1. Applies Megatron-Core monkey-patches (TransformerConfig, TopKRouter, + Dispatcher). + 2. Tags the engine with ``_r3_enabled = True``. + 3. Wraps ``forward_backward_batch`` to inject per-microbatch replay + setup / teardown around the Megatron pipeline schedule. + + Args: + engine: A ``MegatronEngine`` instance (already initialized). + enable_router_replay: Master switch. + """ + if not enable_router_replay: + engine._r3_enabled = False + logger.debug("[R3] Router replay not enabled; skipping engine patch.") + return + + logger.info("[R3] Patching MegatronEngine for Router Replay (R3).") + + # Mark and save original + engine._r3_enabled = True + engine._r3_original_forward_backward_batch = engine.forward_backward_batch + + # Bind the wrapped method + engine.forward_backward_batch = types.MethodType( + _r3_forward_backward_batch, engine + ) + + logger.info("[R3] MegatronEngine patched successfully.") + + +# =================================================================== +# routed_experts splitting +# =================================================================== + + +def _split_routed_experts_for_mbs( + routed_experts: torch.Tensor, + mb_list, +) -> list[torch.Tensor | None]: + """Split the batch-level ``routed_experts`` tensor into per-micro-batch tensors. + + Uses ``mb_list.forward_indices`` and ``mb_list.group_lens`` to correctly + reorder and slice samples, mirroring how ``split_padded_tensor_dict_into_mb_list`` + splits other tensors. + + Args: + routed_experts: ``(bs, max_seqlen, num_moe_layers, topk)`` + mb_list: ``MicroBatchList`` with ``forward_indices`` and ``group_lens``. + + Returns: + List of tensors, one per micro-batch, each of shape + ``(mb_bs, max_seqlen, num_moe_layers, topk)``. 
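+
+    Example (hypothetical shapes)::
+
+        # bs=4, forward_indices=[2, 0, 3, 1], two micro-batches of two
+        # samples each: mb0 receives routed_experts[[2, 0]] and mb1
+        # receives routed_experts[[3, 1]].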
+ """ + if routed_experts is None: + return [None] * len(mb_list) + + forward_indices = mb_list.forward_indices + group_lens = mb_list.group_lens + + if forward_indices is None: + # No reordering -- just split evenly + n_mbs = len(mb_list) + bs = routed_experts.shape[0] + chunk = bs // n_mbs + return [routed_experts[i * chunk : (i + 1) * chunk] for i in range(n_mbs)] + + # Reorder by forward_indices (sample-level reordering) + reordered = routed_experts[forward_indices] + + # Split according to group_lens (number of samples per micro-batch) + # group_lens gives number of *tokens* per micro-batch, but since + # routed_experts is indexed by *sample* (dim-0 is bs), we need + # the number of samples per micro-batch. + # We can derive this from the mbs list. + result = [] + offset = 0 + for i, mb_dict in enumerate(mb_list.mbs): + # Determine number of samples in this micro-batch + # The mb_dict contains "attention_mask" which has shape (n_samples, ...) + if isinstance(mb_dict, dict) and "attention_mask" in mb_dict: + attn = mb_dict["attention_mask"] + if hasattr(attn, "shape"): + n_samples = attn.shape[0] + else: + n_samples = len(attn) + else: + # Fallback: try cu_seqlens + if isinstance(mb_dict, dict) and "cu_seqlens" in mb_dict: + cu = mb_dict["cu_seqlens"] + n_samples = len(cu) - 1 + else: + # Last resort: divide evenly + n_samples = routed_experts.shape[0] // len(mb_list) + + result.append(reordered[offset : offset + n_samples]) + offset += n_samples + + return result + + +# =================================================================== +# Wrapped forward_backward_batch +# =================================================================== + + +def _r3_forward_backward_batch( + self, + mb_list, + process_output_fn: Callable[ + [torch.Tensor, dict[str, Any]], torch.Tensor | None + ], + forward_only: bool = False, +) -> None: + """Drop-in replacement for ``MegatronEngine.forward_backward_batch`` + that injects R3 replay setup around each micro-batch. + + If the data does not contain ``routed_experts``, delegates directly + to the original method with zero overhead. + """ + from areal.engine.router_replay_utils import ( + clear_router_replay, + setup_per_microbatch_replay_backward, + setup_per_microbatch_replay_forward, + ) + + # ------------------------------------------------------------------ + # 1. Extract routed_experts from the batch data and split per-MB. + # routed_experts is (bs, max_seqlen, num_moe_layers, topk) and + # does NOT get split by split_padded_tensor_dict_into_mb_list. + # ------------------------------------------------------------------ + routed_experts_batch = None + if hasattr(mb_list, "data") and isinstance(mb_list.data, dict): + routed_experts_batch = mb_list.data.pop("routed_experts", None) + + # Also clean from mbs and padded_mbs to avoid confusing downstream code + for mb_dict in mb_list.mbs: + if isinstance(mb_dict, dict): + mb_dict.pop("routed_experts", None) + if mb_list.padded_mbs is not None: + for mb_dict in mb_list.padded_mbs: + if isinstance(mb_dict, dict): + mb_dict.pop("routed_experts", None) + + if routed_experts_batch is None: + logger.debug( + "[R3] No routed_experts in batch data; using original " + "forward_backward_batch." 
+ ) + return self._r3_original_forward_backward_batch( + mb_list, process_output_fn, forward_only=forward_only + ) + + logger.debug( + "[R3] R3 forward_backward: %d micro-batches, routed_experts shape=%s, " + "forward_only=%s", + len(mb_list), + routed_experts_batch.shape, + forward_only, + ) + + # Split routed_experts per micro-batch + per_mb_routed_experts = _split_routed_experts_for_mbs( + routed_experts_batch, mb_list + ) + + # ------------------------------------------------------------------ + # 2. Store R3 data on the engine for the wrapped iterator. + # ------------------------------------------------------------------ + self._r3_per_mb_experts = per_mb_routed_experts + self._r3_mb_counter = 0 + model_config = self.tf_config + + # ------------------------------------------------------------------ + # 3. Wrap the MicroBatchList iterator to inject R3 setup before each + # micro-batch's forward pass. + # ------------------------------------------------------------------ + engine_ref = self + + class _R3MicroBatchIterator: + """Wraps the micro-batch iterator to inject R3 setup.""" + + def __init__(self, base_iter): + self._base = base_iter + + def __iter__(self): + return self + + def __next__(self): + mb_item = next(self._base) + + idx = engine_ref._r3_mb_counter + engine_ref._r3_mb_counter += 1 + re = ( + engine_ref._r3_per_mb_experts[idx] + if idx < len(engine_ref._r3_per_mb_experts) + else None + ) + + if re is not None: + # Get attention_mask from orig_mb or padded_mb + attn_mask = None + if hasattr(mb_item, "orig_mb") and isinstance(mb_item.orig_mb, dict): + attn_mask = mb_item.orig_mb.get("attention_mask") + if attn_mask is None and hasattr(mb_item, "padded_mb") and isinstance( + mb_item.padded_mb, dict + ): + attn_mask = mb_item.padded_mb.get("attention_mask") + + if attn_mask is not None: + try: + setup_per_microbatch_replay_forward( + re.to(attn_mask.device), + attn_mask, + model_config, + ) + except Exception: + logger.warning( + "[R3] Failed to setup replay for micro-batch %d.", + idx, + exc_info=True, + ) + else: + logger.warning( + "[R3] Cannot find attention_mask for micro-batch %d; " + "skipping replay setup.", + idx, + ) + return mb_item + + # Patch __iter__ on the MicroBatchList instance + original_iter = mb_list.__class__.__iter__ + + def _r3_iter(mb_list_self): + return _R3MicroBatchIterator(original_iter(mb_list_self)) + + mb_list.__class__.__iter__ = _r3_iter + + try: + self._r3_original_forward_backward_batch( + mb_list, process_output_fn, forward_only=forward_only + ) + + # Switch to backward replay after forward pass + if not forward_only: + setup_per_microbatch_replay_backward() + finally: + # Restore original iterator and clean up + mb_list.__class__.__iter__ = original_iter + clear_router_replay() + self._r3_per_mb_experts = None + self._r3_mb_counter = 0 diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py new file mode 100644 index 0000000000..4d135bb02a --- /dev/null +++ b/areal/engine/router_replay_patch.py @@ -0,0 +1,649 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Monkey-patches for Megatron-Core MoE components to support Router Replay (R3). + +Router Replay forces the TopKRouter to use pre-recorded expert assignments +(from rollout inference) instead of computing new ones during training. +This eliminates the train/inference routing mismatch caused by weight +staleness in asynchronous RL training. + +Patches applied: +1. **RouterReplay class** -- self-contained class (no dependency on + megatron.core.transformer.moe.router_replay which does not exist in + megatron-core 0.16.0). +2. **TransformerConfig.__init__** -- accepts ``enable_routing_replay`` kwarg. +3. **TopKRouter.__init__** -- creates a ``RouterReplay`` instance per MoE layer. +4. **TopKRouter.routing** -- replaces routing logic to support record/replay. +5. **MoEAlltoAllTokenDispatcher.preprocess** -- fixes ``num_out_tokens`` when + replay indices contain duplicate expert assignments. + +Usage:: + + from areal.engine.router_replay_patch import apply_router_replay_patch + apply_router_replay_patch() # call once before model creation + + from areal.engine.router_replay_patch import remove_router_replay_patch + remove_router_replay_patch() # optional: for test cleanup + +Ported from verl reference implementation, adapted for AReaL. +""" + +from __future__ import annotations + +import inspect +import logging +import types +import warnings +from enum import Enum +from functools import wraps + +import torch + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Optional megatron-core imports with fallback +# --------------------------------------------------------------------------- +try: + from megatron.core.transformer.moe.moe_utils import ( + apply_router_token_dropping, + compute_routing_scores_for_aux_loss, + ) +except ImportError: + apply_router_token_dropping = None + compute_routing_scores_for_aux_loss = None + warnings.warn( + "[R3] Could not import apply_router_token_dropping / " + "compute_routing_scores_for_aux_loss from megatron.core; " + "some MoE features may be unavailable.", + stacklevel=2, + ) + +try: + from megatron.core.transformer.moe.moe_utils import group_limited_topk +except ImportError: + group_limited_topk = None + +try: + from megatron.core.transformer.moe.token_dispatcher import ( + MoEAlltoAllTokenDispatcher, + ) +except ImportError: + MoEAlltoAllTokenDispatcher = None + +from megatron.core.transformer.moe.router import TopKRouter +from megatron.core.transformer.transformer_config import TransformerConfig + + +# =================================================================== +# RouterReplayAction enum and RouterReplay class +# (self-contained -- no dependency on megatron.core.transformer.moe.router_replay) +# =================================================================== + + +class RouterReplayAction(Enum): + """Actions controlling the MoE routing replay behaviour.""" + + RECORD = "record" + REPLAY_FORWARD = "replay_forward" + REPLAY_BACKWARD = "replay_backward" + + +class RouterReplay: + """Manages recording and replaying of MoE routing decisions. + + Each MoE layer gets one ``RouterReplay`` instance. The class-level + list ``router_instances`` holds all of them so that global operations + (set action, distribute data, clear state) are straightforward. + """ + + # Class-level list of all router instances (one per MoE layer). 
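+    # Instances are appended in construction order, so index ``i`` maps to
+    # the i-th MoE layer built on this rank; replay data passed to
+    # ``set_replay_data`` must follow the same order.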
+ router_instances: list["RouterReplay"] = [] + + # ------------------------------------------------------------------ + # Class-level (static) helpers + # ------------------------------------------------------------------ + + @staticmethod + def set_replay_data(all_layers_topk_indices: list) -> None: + """Distribute per-layer topk indices to ``RouterReplay`` instances. + + Args: + all_layers_topk_indices: List of tensors, one per MoE layer, + each of shape ``(num_tokens, topk)``. Order must match + instantiation order. + """ + if len(all_layers_topk_indices) != len(RouterReplay.router_instances): + raise ValueError( + f"[R3] Number of replay tensors ({len(all_layers_topk_indices)}) " + f"does not match number of router instances " + f"({len(RouterReplay.router_instances)})." + ) + for i, inst in enumerate(RouterReplay.router_instances): + inst.set_target_indices(all_layers_topk_indices[i]) + + @staticmethod + def get_recorded_data() -> list: + """Collect recorded topk indices from all instances.""" + return [r.get_recorded_indices() for r in RouterReplay.router_instances] + + @staticmethod + def clear_global_indices() -> None: + """Clear recorded and target indices on all instances.""" + for r in RouterReplay.router_instances: + r.clear_indices() + + @staticmethod + def set_global_router_replay_action(action: RouterReplayAction) -> None: + """Set the replay action for all router instances.""" + for r in RouterReplay.router_instances: + r.set_router_replay_action(action) + + @staticmethod + def clear_global_router_replay_action() -> None: + """Clear the replay action on all router instances.""" + for r in RouterReplay.router_instances: + r.clear_router_replay_action() + + # ------------------------------------------------------------------ + # Instance methods + # ------------------------------------------------------------------ + + def __init__(self) -> None: + self.target_topk_idx: torch.Tensor | None = None + self.recorded_topk_idx: torch.Tensor | None = None + self.router_replay_action: RouterReplayAction | None = None + self.replay_backward_list: list[torch.Tensor] = [] + RouterReplay.router_instances.append(self) + + def set_target_indices(self, topk_indices: torch.Tensor) -> None: + self.target_topk_idx = topk_indices + self.replay_backward_list.append(topk_indices) + + def get_recorded_indices(self) -> torch.Tensor | None: + return self.recorded_topk_idx + + def record_indices(self, topk_indices: torch.Tensor) -> None: + self.recorded_topk_idx = topk_indices + + def clear_indices(self) -> None: + self.recorded_topk_idx = None + self.target_topk_idx = None + self.replay_backward_list = [] + + def set_router_replay_action(self, action: RouterReplayAction) -> None: + self.router_replay_action = action + + def clear_router_replay_action(self) -> None: + self.router_replay_action = None + + +# =================================================================== +# Patched routing implementation +# =================================================================== + + +def _patched_topk_routing_with_score_function( + logits: torch.Tensor, + topk: int, + use_pre_softmax: bool, + num_groups: int, + group_topk: int, + score_function: str, + expert_bias: torch.Tensor, + fused: bool, + router_replay: RouterReplay | None, + scaling_factor: float, +) -> tuple[torch.Tensor, torch.Tensor]: + """Patched ``topk_routing_with_score_function`` supporting router replay.""" + num_tokens, num_experts = logits.shape + + def _compute_topk(scores, topk, num_groups=None, group_topk=None): + if group_topk 
and group_limited_topk is not None: + return group_limited_topk( + scores=scores, + topk=topk, + num_tokens=num_tokens, + num_experts=num_experts, + num_groups=num_groups, + group_topk=group_topk, + ) + else: + return torch.topk(scores, k=topk, dim=1) + + def compute_topk(scores, topk, num_groups=None, group_topk=None): + routing_action = ( + router_replay.router_replay_action if router_replay is not None else None + ) + + if routing_action is None: + return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) + + if routing_action == RouterReplayAction.RECORD: + probs, top_indices = _compute_topk( + scores, topk, num_groups=num_groups, group_topk=group_topk + ) + if router_replay is not None: + router_replay.record_indices(top_indices) + return probs, top_indices + + elif routing_action == RouterReplayAction.REPLAY_FORWARD: + if router_replay is None or router_replay.target_topk_idx is None: + logger.warning( + "[R3] REPLAY_FORWARD: no target indices available, " + "falling back to normal routing." + ) + return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) + top_indices = router_replay.target_topk_idx.to(scores.device) + probs = scores.gather(1, top_indices) + return probs, top_indices + + elif routing_action == RouterReplayAction.REPLAY_BACKWARD: + if router_replay is None or not router_replay.replay_backward_list: + logger.warning( + "[R3] REPLAY_BACKWARD: no backward indices available, " + "falling back to normal routing." + ) + return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) + top_indices = router_replay.replay_backward_list.pop(0).to(scores.device) + probs = scores.gather(1, top_indices) + return probs, top_indices + + else: + return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) + + # --- Score function dispatch --- + if score_function == "softmax": + if use_pre_softmax: + scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) + probs, top_indices = compute_topk(scores, topk, num_groups, group_topk) + else: + scores, top_indices = compute_topk(logits, topk, num_groups, group_topk) + probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits) + elif score_function == "sigmoid": + scores = torch.sigmoid(logits.float()).type_as(logits) + if expert_bias is not None: + scores_for_routing = scores + expert_bias + _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk) + scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits) + else: + scores, top_indices = compute_topk(scores, topk, num_groups, group_topk) + probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores + else: + raise ValueError(f"[R3] Invalid score_function: {score_function}") + + if scaling_factor: + probs = probs * scaling_factor + + if torch.are_deterministic_algorithms_enabled(): + routing_probs = torch.zeros_like(logits) + rows = torch.arange(num_tokens, device=logits.device).unsqueeze(1) + routing_probs.index_put_((rows, top_indices), probs, accumulate=False) + routing_map = torch.zeros_like(logits, dtype=logits.dtype) + routing_map.index_put_( + (rows, top_indices), + torch.ones_like(probs, dtype=routing_map.dtype), + accumulate=False, + ) + routing_map = routing_map.bool() + else: + routing_probs = torch.zeros_like(logits).scatter(1, top_indices, probs) + routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool() + + return routing_probs, routing_map + + +# 
=================================================================== +# Aux-loss helpers (from verl reference) +# =================================================================== + + +def _get_aux_loss_coeff(_self, aux_loss_type: str) -> float: + """Return the aux loss coeff for the given auxiliary loss type.""" + if isinstance(_self.routing_type, str): + if _self.routing_type == aux_loss_type: + return _self.config.moe_aux_loss_coeff + if isinstance(_self.routing_type, list): + try: + idx = _self.routing_type.index(aux_loss_type) + return _self.config.moe_aux_loss_coeff[idx] + except (ValueError, IndexError): + return 0.0 + return 0.0 + + +def _is_aux_loss_enabled(_self) -> bool: + """Check if any auxiliary loss is enabled.""" + for aux_loss_type in ["aux_loss", "seq_aux_loss", "global_aux_loss"]: + if _get_aux_loss_coeff(_self, aux_loss_type) > 0: + return True + return False + + +# =================================================================== +# patched_routing -- replaces TopKRouter.routing +# =================================================================== + + +def patched_routing(self, logits: torch.Tensor, *args, **kwargs): + """Patched ``TopKRouter.routing`` that supports router replay. + + Drop-in replacement for ``TopKRouter.routing`` that delegates to + ``_patched_topk_routing_with_score_function`` which honours the + ``RouterReplayAction`` set on the per-layer ``RouterReplay`` instance. + """ + seq_length, bsz = logits.shape[:2] + logits = logits.view(-1, self.config.num_moe_experts) + + # Apply Z-Loss + logits = self.apply_z_loss(logits) + + moe_router_fusion = getattr(self.config, "moe_router_fusion", False) + + # Calculate probs and routing_map for token dispatching + if self.routing_type == "sinkhorn": + probs, routing_map = self.sinkhorn_load_balancing(logits) + else: + probs, routing_map = _patched_topk_routing_with_score_function( + logits=logits, + topk=self.topk, + use_pre_softmax=self.config.moe_router_pre_softmax, + num_groups=self.config.moe_router_num_groups, + group_topk=self.config.moe_router_group_topk, + scaling_factor=self.config.moe_router_topk_scaling_factor, + score_function=self.score_function, + expert_bias=self.expert_bias, + fused=moe_router_fusion, + router_replay=getattr(self, "router_replay", None), + ) + + # Apply token dropping to probs and routing_map. 
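+    # Skipped when no capacity factor is configured or when the helper
+    # could not be imported from this megatron-core version.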
+    if (
+        self.config.moe_expert_capacity_factor is not None
+        and apply_router_token_dropping is not None
+    ):
+        probs, routing_map = apply_router_token_dropping(
+            probs,
+            routing_map,
+            router_topk=self.topk,
+            capacity_factor=self.config.moe_expert_capacity_factor,
+            drop_policy=self.config.moe_token_drop_policy,
+            pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
+        )
+
+    if not hasattr(self, "is_aux_loss_enabled"):
+        self.is_aux_loss_enabled = types.MethodType(_is_aux_loss_enabled, self)
+
+    # Apply aux loss
+    if (
+        self.training
+        and torch.is_grad_enabled()
+        and self.is_aux_loss_enabled()
+        and compute_routing_scores_for_aux_loss is not None
+    ):
+        routing_map_for_aux_loss, scores_for_aux_loss = (
+            compute_routing_scores_for_aux_loss(
+                logits,
+                self.topk,
+                self.score_function,
+                # Reuse the getattr-guarded local from above so configs
+                # without moe_router_fusion do not raise AttributeError here.
+                fused=moe_router_fusion,
+            )
+        )
+        probs = self._apply_aux_loss(
+            probs, scores_for_aux_loss, routing_map_for_aux_loss
+        )
+        probs = self._apply_seq_aux_loss(
+            probs, scores_for_aux_loss, routing_map_for_aux_loss, seq_length, bsz
+        )
+        probs = self._apply_global_aux_loss(
+            probs, scores_for_aux_loss, routing_map_for_aux_loss
+        )
+
+    # Update expert bias and tokens_per_expert
+    if self.enable_expert_bias and torch.is_grad_enabled():
+        with torch.no_grad():
+            self.local_tokens_per_expert += routing_map.sum(dim=0)
+
+    return probs, routing_map
+
+
+# ===================================================================
+# Sentinel to prevent double-patching
+# ===================================================================
+_PATCHES_APPLIED = False
+
+# Store original methods for undo
+_ORIGINAL_TF_CONFIG_INIT = None
+_ORIGINAL_TOPK_ROUTER_INIT = None
+_ORIGINAL_TOPK_ROUTER_ROUTING = None
+_ORIGINAL_DISPATCHER_PREPROCESS = None
+
+
+# ===================================================================
+# apply_router_replay_patch
+# ===================================================================
+
+
+def apply_router_replay_patch() -> None:
+    """Apply all Megatron-Core monkey-patches required for Router Replay.
+
+    Safe to call multiple times -- subsequent calls are no-ops.
+    Must be called **before** model creation.
+    """
+    global _PATCHES_APPLIED
+    if _PATCHES_APPLIED:
+        logger.info("[R3] Router replay patches already applied; skipping.")
+        return
+
+    logger.info("[R3] Applying Router Replay patches...")
+
+    # Clear router instances to avoid state leakage between model inits.
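+    # (Otherwise a second model built in the same process would append its
+    # routers after the stale ones and replay data would be misaligned.)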
+ RouterReplay.router_instances.clear() + + _patch_transformer_config_init() + _patch_topk_router_init() + _patch_topk_router_routing() + _patch_alltoall_dispatcher_preprocess() + + _PATCHES_APPLIED = True + logger.info("[R3] All Router Replay patches applied successfully.") + + +def remove_router_replay_patch() -> None: + """Undo all patches (primarily for test cleanup).""" + global _PATCHES_APPLIED + _undo_transformer_config_patch() + _undo_topk_router_init_patch() + _undo_topk_router_routing_patch() + _undo_dispatcher_patch() + RouterReplay.router_instances.clear() + _PATCHES_APPLIED = False + logger.info("[R3] All Router Replay patches removed.") + + +# =================================================================== +# Patch 1: TransformerConfig.__init__ +# =================================================================== + + +def _patch_transformer_config_init() -> None: + """Patch ``TransformerConfig.__init__`` to accept ``enable_routing_replay``.""" + global _ORIGINAL_TF_CONFIG_INIT + + if getattr(TransformerConfig, "_r3_config_patched", False): + return + + # Inspect the current signature to add enable_routing_replay + try: + sig = inspect.signature(TransformerConfig.__init__) + native_params = sig.parameters + params = list(sig.parameters.values()) + except Exception: + sig = None + native_params = {} + params = [] + + ext_attr = "enable_routing_replay" + + if ext_attr not in native_params and sig is not None: + new_param = inspect.Parameter( + ext_attr, inspect.Parameter.KEYWORD_ONLY, default=False + ) + if params and params[-1].kind == inspect.Parameter.VAR_KEYWORD: + params.insert(-1, new_param) + else: + params.append(new_param) + try: + TransformerConfig.__init__.__signature__ = sig.replace(parameters=params) + except Exception as e: + logger.warning("[R3] Failed to update TransformerConfig signature: %s", e) + + _ORIGINAL_TF_CONFIG_INIT = TransformerConfig.__init__ + + @wraps(_ORIGINAL_TF_CONFIG_INIT) + def patched_tf_config_init(self, *args, **kwargs): + enable_routing_replay = kwargs.get("enable_routing_replay", False) + if "enable_routing_replay" not in native_params: + enable_routing_replay = kwargs.pop("enable_routing_replay", False) + _ORIGINAL_TF_CONFIG_INIT(self, *args, **kwargs) + self.enable_routing_replay = enable_routing_replay + + TransformerConfig.__init__ = patched_tf_config_init + TransformerConfig._r3_config_patched = True + logger.debug("[R3] TransformerConfig.__init__ patched to accept enable_routing_replay.") + + +def _undo_transformer_config_patch() -> None: + global _ORIGINAL_TF_CONFIG_INIT + if _ORIGINAL_TF_CONFIG_INIT is not None: + TransformerConfig.__init__ = _ORIGINAL_TF_CONFIG_INIT + if hasattr(TransformerConfig, "_r3_config_patched"): + del TransformerConfig._r3_config_patched + _ORIGINAL_TF_CONFIG_INIT = None + + +# =================================================================== +# Patch 2: TopKRouter.__init__ +# =================================================================== + + +def _patch_topk_router_init() -> None: + """Patch ``TopKRouter.__init__`` to create a ``RouterReplay`` instance.""" + global _ORIGINAL_TOPK_ROUTER_INIT + + if getattr(TopKRouter, "_r3_init_patched", False): + return + + _ORIGINAL_TOPK_ROUTER_INIT = TopKRouter.__init__ + + def patched_init(self, *args, **kwargs): + _ORIGINAL_TOPK_ROUTER_INIT(self, *args, **kwargs) + self.router_replay = None + if getattr(self.config, "enable_routing_replay", False): + self.router_replay = RouterReplay() + logger.debug( + "[R3] TopKRouter: created RouterReplay instance " + 
"(total instances: %d).", + len(RouterReplay.router_instances), + ) + + TopKRouter.__init__ = patched_init + TopKRouter._r3_init_patched = True + logger.debug("[R3] TopKRouter.__init__ patched.") + + +def _undo_topk_router_init_patch() -> None: + global _ORIGINAL_TOPK_ROUTER_INIT + if _ORIGINAL_TOPK_ROUTER_INIT is not None: + TopKRouter.__init__ = _ORIGINAL_TOPK_ROUTER_INIT + if hasattr(TopKRouter, "_r3_init_patched"): + del TopKRouter._r3_init_patched + _ORIGINAL_TOPK_ROUTER_INIT = None + + +# =================================================================== +# Patch 3: TopKRouter.routing +# =================================================================== + + +def _patch_topk_router_routing() -> None: + """Patch ``TopKRouter.routing`` with the replay-aware version.""" + global _ORIGINAL_TOPK_ROUTER_ROUTING + + if getattr(TopKRouter, "_r3_routing_patched", False): + return + + _ORIGINAL_TOPK_ROUTER_ROUTING = TopKRouter.routing + TopKRouter.routing = patched_routing + TopKRouter._r3_routing_patched = True + logger.debug("[R3] TopKRouter.routing patched.") + + +def _undo_topk_router_routing_patch() -> None: + global _ORIGINAL_TOPK_ROUTER_ROUTING + if _ORIGINAL_TOPK_ROUTER_ROUTING is not None: + TopKRouter.routing = _ORIGINAL_TOPK_ROUTER_ROUTING + if hasattr(TopKRouter, "_r3_routing_patched"): + del TopKRouter._r3_routing_patched + _ORIGINAL_TOPK_ROUTER_ROUTING = None + + +# =================================================================== +# Patch 4: MoEAlltoAllTokenDispatcher.preprocess +# =================================================================== + + +def _patch_alltoall_dispatcher_preprocess() -> None: + """Patch dispatcher preprocess to handle duplicate indices from replay.""" + global _ORIGINAL_DISPATCHER_PREPROCESS + + if MoEAlltoAllTokenDispatcher is None: + logger.warning( + "[R3] Cannot import MoEAlltoAllTokenDispatcher -- " + "skipping preprocess patch." + ) + return + + if getattr(MoEAlltoAllTokenDispatcher, "_r3_preprocess_patched", False): + return + + _ORIGINAL_DISPATCHER_PREPROCESS = MoEAlltoAllTokenDispatcher.preprocess + + def patched_preprocess(self, routing_map): + result = _ORIGINAL_DISPATCHER_PREPROCESS(self, routing_map) + if ( + getattr(self.config, "enable_routing_replay", False) + and not self.drop_and_pad + and self.config.moe_expert_capacity_factor is None + and not ( + getattr(self.config, "moe_router_padding_for_quantization", None) + or getattr(self.config, "moe_router_padding_for_fp8", None) + ) + ): + self.num_out_tokens = int(routing_map.sum().item()) + return result + + MoEAlltoAllTokenDispatcher.preprocess = patched_preprocess + MoEAlltoAllTokenDispatcher._r3_preprocess_patched = True + logger.debug("[R3] MoEAlltoAllTokenDispatcher.preprocess patched.") + + +def _undo_dispatcher_patch() -> None: + global _ORIGINAL_DISPATCHER_PREPROCESS + if MoEAlltoAllTokenDispatcher is None: + return + if _ORIGINAL_DISPATCHER_PREPROCESS is not None: + MoEAlltoAllTokenDispatcher.preprocess = _ORIGINAL_DISPATCHER_PREPROCESS + if hasattr(MoEAlltoAllTokenDispatcher, "_r3_preprocess_patched"): + del MoEAlltoAllTokenDispatcher._r3_preprocess_patched + _ORIGINAL_DISPATCHER_PREPROCESS = None diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py new file mode 100644 index 0000000000..b767b27cc4 --- /dev/null +++ b/areal/engine/router_replay_utils.py @@ -0,0 +1,433 @@ +# Copyright 2024 Bytedance Ltd. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Router Replay Utilities for AReaL. + +Handles the complete shape-transformation pipeline that converts rollout +routing indices into the layout expected by Megatron-Core's RouterReplay: + +1. **Left-padding removal** -- rollout batch is left-padded; training removes it. +2. **TP/SP splitting** -- sequence parallelism across tensor-model-parallel ranks. +3. **PP layer slicing** -- pipeline parallelism assigns different layers to ranks. +4. **Dense/MoE layer mapping** -- architectures with dense FFN layers before MoE. + +Ported from verl reference implementation, adapted for AReaL: +- No dependency on verl-specific imports +- No dependency on megatron.core.transformer.moe.router_replay +- Simplified packed-sequence handling (no preprocess_packed_seqs dependency) +- topk and num_moe_layers passed explicitly (no hardcoded guessing) +""" + +from __future__ import annotations + +import inspect +import logging +from typing import Optional + +import torch + +from areal.engine.router_replay_patch import RouterReplay, RouterReplayAction + +logger = logging.getLogger(__name__) + + +# =================================================================== +# Layer computation helpers (ported from verl, self-contained) +# =================================================================== + + +def get_num_layers_to_build(config, vp_stage=None, pp_rank=None) -> int: + """Determine the number of transformer layers to build for the current PP stage. + + Self-contained reimplementation that does not depend on + ``megatron.core.transformer.transformer_block.get_num_layers_to_build`` + which may not exist in all megatron-core versions. 
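+
+    Example (hypothetical config)::
+
+        # num_layers=8, pipeline_model_parallel_size=4, no virtual PP:
+        # every PP rank builds 8 // 4 = 2 layers.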
+ """ + from megatron.core import parallel_state as mpu + + if pp_rank is None: + pp_rank = mpu.get_pipeline_model_parallel_rank() + + is_first_pp_stage = pp_rank == 0 + is_last_pp_stage = pp_rank == config.pipeline_model_parallel_size - 1 + + # Custom pipeline layout + if ( + hasattr(config, "pipeline_model_parallel_layout") + and config.pipeline_model_parallel_layout is not None + ): + try: + from megatron.core.transformer.enums import LayerType + + return config.pipeline_model_parallel_layout.get_num_layers_to_build( + layer_type=LayerType.decoder, vp_stage=vp_stage + ) + except ImportError: + pass + + first_stage_layers = getattr(config, "num_layers_in_first_pipeline_stage", None) + last_stage_layers = getattr(config, "num_layers_in_last_pipeline_stage", None) + + if first_stage_layers is not None or last_stage_layers is not None: + layers_to_distribute = config.num_layers + pipeline_stages_left = config.pipeline_model_parallel_size + + if first_stage_layers is not None: + layers_to_distribute -= first_stage_layers + pipeline_stages_left -= 1 + if last_stage_layers is not None: + layers_to_distribute -= last_stage_layers + pipeline_stages_left -= 1 + + if pipeline_stages_left > 0: + assert layers_to_distribute % pipeline_stages_left == 0 + num_layers_per_pipeline_rank = layers_to_distribute // pipeline_stages_left + else: + num_layers_per_pipeline_rank = 0 + + if is_first_pp_stage and first_stage_layers is not None: + num_layers_per_pipeline_rank = first_stage_layers + if is_last_pp_stage and last_stage_layers is not None: + num_layers_per_pipeline_rank = last_stage_layers + else: + num_layers = config.num_layers + if getattr(config, "account_for_embedding_in_pipeline_split", False): + num_layers += 1 + if getattr(config, "account_for_loss_in_pipeline_split", False): + num_layers += 1 + assert num_layers % config.pipeline_model_parallel_size == 0 + num_layers_per_pipeline_rank = num_layers // config.pipeline_model_parallel_size + + vp_size = config.virtual_pipeline_model_parallel_size + if vp_size is not None and config.pipeline_model_parallel_size > 1: + assert num_layers_per_pipeline_rank % vp_size == 0 + num_layers_to_build = num_layers_per_pipeline_rank // vp_size + else: + num_layers_to_build = num_layers_per_pipeline_rank + + # Account for embedding/loss layers + if getattr(config, "account_for_embedding_in_pipeline_split", False): + if is_first_pp_stage and ( + vp_stage is None or vp_stage == 0 + ): + num_layers_to_build -= 1 + + if getattr(config, "account_for_loss_in_pipeline_split", False): + vp_last = (vp_size is None) or (vp_stage == vp_size - 1) + if is_last_pp_stage and vp_last: + num_layers_to_build -= 1 + + return num_layers_to_build + + +def is_moe_layer(tf_config, layer_idx: int) -> bool: + """Check whether a given global layer index is an MoE layer.""" + moe_layer_freq = getattr(tf_config, "moe_layer_freq", None) + if moe_layer_freq is None: + # If not set, assume all layers are MoE + return True + if isinstance(moe_layer_freq, int): + return layer_idx % moe_layer_freq == 0 + elif isinstance(moe_layer_freq, list): + return moe_layer_freq[layer_idx] == 1 + else: + raise ValueError(f"[R3] Unsupported moe_layer_freq type: {type(moe_layer_freq)}") + + +def get_moe_num_layers_to_build(config, vp_stage=None, pp_rank=None) -> int: + """Count the number of MoE layers assigned to the current rank.""" + from megatron.core.transformer.transformer_layer import get_transformer_layer_offset + + total_layers = get_num_layers_to_build(config, vp_stage=vp_stage, pp_rank=pp_rank) + + 
sig = inspect.signature(get_transformer_layer_offset) + kwargs = {} + if "vp_stage" in sig.parameters and "pp_rank" in sig.parameters: + kwargs = {"vp_stage": vp_stage, "pp_rank": pp_rank} + elif "pp_rank" in sig.parameters: + kwargs = {"pp_rank": pp_rank} + + layer_offset = get_transformer_layer_offset(config, **kwargs) + return sum( + 1 + for idx in range(layer_offset, layer_offset + total_layers) + if is_moe_layer(config, idx) + ) + + +def get_current_rank_layer_info(tf_config, vp_rank=None) -> dict: + """Return ``{"start", "end", "count"}`` for the current PP rank's layer range.""" + from megatron.core.transformer.transformer_layer import get_transformer_layer_offset + + if vp_rank is None: + vp_rank = 0 + + num_layers = get_num_layers_to_build(tf_config, vp_stage=vp_rank) + + sig = inspect.signature(get_transformer_layer_offset) + kwargs = {} + if "vp_stage" in sig.parameters: + kwargs["vp_stage"] = vp_rank + + offset = get_transformer_layer_offset(tf_config, **kwargs) + return {"start": offset, "end": offset + num_layers, "count": num_layers} + + +# =================================================================== +# RouterReplayHelper +# =================================================================== + + +class RouterReplayHelper: + """Helper to query router replay state and locate local RouterReplay instances.""" + + @staticmethod + def get_micro_batch_router_list(tf_config, vp_rank=None) -> list: + """Return the RouterReplay instances for the current (pp_rank, vp_stage).""" + vp_size = tf_config.virtual_pipeline_model_parallel_size + if vp_size is not None: + vp_rank = 0 if vp_rank is None else vp_rank + offset = 0 + for pre_vp_stage in range(vp_size): + if pre_vp_stage == vp_rank: + break + offset += get_moe_num_layers_to_build(tf_config, pre_vp_stage) + else: + offset = 0 + + num_layers = get_moe_num_layers_to_build(tf_config, vp_rank) + return RouterReplay.router_instances[offset : offset + num_layers] + + @staticmethod + def is_replay_forward_action(tf_config, vp_rank=None) -> bool: + instances = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) + return bool( + instances + and instances[0].router_replay_action == RouterReplayAction.REPLAY_FORWARD + ) + + @staticmethod + def is_replay_backward_action(tf_config, vp_rank=None) -> bool: + instances = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) + return bool( + instances + and instances[0].router_replay_action == RouterReplayAction.REPLAY_BACKWARD + ) + + +# =================================================================== +# set_router_replay_data -- core function +# =================================================================== + + +def set_router_replay_data( + layers_topk_idx: torch.Tensor, + attention_mask: torch.Tensor, + tf_config, + vp_rank: Optional[int] = None, +) -> None: + """Scatter packed router top-k indices to SP ranks and update RouterReplay instances. + + Simplified for AReaL: no dependency on preprocess_packed_seqs / postprocess_packed_seqs. + + Args: + layers_topk_idx: ``(bs, max_seq_len, num_moe_layers, topk)`` -- the replay data. + attention_mask: ``(bs, max_seq_len)`` -- 1 for real tokens, 0 for padding. + tf_config: Megatron TransformerConfig. + vp_rank: Virtual pipeline stage rank override. 
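+
+    Example (hypothetical, sequence-parallel group of 2)::
+
+        # attention_mask marks 6 real tokens in total; after left-padding
+        # removal the flat tensor has 6 rows, each SP rank keeps 3 of them,
+        # and those rows are then sliced per local MoE layer.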
+ """ + from megatron.core import parallel_state as mpu + from megatron.core.tensor_parallel import scatter_to_sequence_parallel_region + + with torch.no_grad(): + device = torch.cuda.current_device() + bs, max_seq_len = attention_mask.shape[:2] + + # Step 1: Remove left-padding -> flat (total_real_tokens, num_layers, topk) + seq_lens = attention_mask.sum(dim=1).long() # (bs,) + pieces = [] + for i in range(bs): + slen = int(seq_lens[i].item()) + mask = attention_mask[i].bool() + pieces.append(layers_topk_idx[i, mask][:slen]) + flat_tokens = torch.cat(pieces, dim=0) # (total_real_tokens, num_layers, topk) + + # Step 2: Scatter to SP ranks + # scatter_to_sequence_parallel_region expects (seq, ...) and splits dim=0 + flat_tokens = flat_tokens.to(device) + local_tokens = scatter_to_sequence_parallel_region(flat_tokens) + # local_tokens: (local_tokens_count, num_layers, topk) + + # Step 3: Permute to (num_layers, local_tokens_count, topk) + layers_topk = local_tokens.permute(1, 0, 2) + + # Step 4: Distribute to RouterReplay instances for local PP layers + local_info = get_current_rank_layer_info(tf_config, vp_rank) + offset, end = local_info["start"], local_info["end"] + router_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) + + # Determine indexing: if dim-0 covers all layers, use absolute index; + # otherwise (only MoE layers), use MoE-layer ordinal. + index_by_layer = len(layers_topk) == tf_config.num_layers + + moe_idx = sum(1 for i in range(offset) if is_moe_layer(tf_config, i)) + + router_offset = 0 + for layer_idx in range(offset, end): + if not is_moe_layer(tf_config, layer_idx): + continue + router = router_list[router_offset] + idx = layer_idx if index_by_layer else moe_idx + router.set_target_indices(layers_topk[idx].to(torch.int64)) + router_offset += 1 + moe_idx += 1 + + logger.debug( + "[R3] set_router_replay_data: distributed %d layers of replay data " + "to %d router instances (PP offset=%d).", + len(layers_topk), + len(router_list), + offset, + ) + + +# =================================================================== +# Per-microbatch replay control +# =================================================================== + + +def setup_per_microbatch_replay_forward( + routed_experts: torch.Tensor, + attention_mask: torch.Tensor, + tf_config, + vp_rank: Optional[int] = None, +) -> None: + """Set up RouterReplay for a single micro-batch's forward pass. + + Args: + routed_experts: ``(batch, padded_seq, num_moe_layers, topk)`` + attention_mask: ``(batch, padded_seq)`` + tf_config: Megatron TransformerConfig. + vp_rank: Virtual pipeline stage rank override. 
+ """ + routed_experts = routed_experts.to(torch.int32) + set_router_replay_data(routed_experts, attention_mask, tf_config, vp_rank) + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + logger.debug("[R3] Forward replay mode set for micro-batch.") + + +def setup_per_microbatch_replay_backward() -> None: + """Switch to backward replay mode for activation-checkpoint recomputation.""" + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_BACKWARD) + logger.debug("[R3] Switched to backward replay mode.") + + +def clear_router_replay() -> None: + """Clear all RouterReplay state after a full forward-backward pass.""" + RouterReplay.clear_global_indices() + RouterReplay.clear_global_router_replay_action() + logger.debug("[R3] Router replay state cleared.") + + +# =================================================================== +# preprocess_routed_experts_batch +# =================================================================== + + +def preprocess_routed_experts_batch( + routed_experts_np, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + num_moe_layers: int, + topk: int, + compress_dtype: bool = True, +) -> torch.Tensor: + """Convert a numpy ``routed_experts`` array from the inference engine + into a left-padded torch tensor. + + The inference engine returns shape ``(num_tokens, num_moe_layers * topk)`` + where ``num_tokens = prompt_len + gen_len - 1`` (SGLang convention). + We reshape to ``(1, seq_len, num_moe_layers, topk)`` with left-padding. + + Args: + routed_experts_np: ``np.ndarray`` of shape ``(num_tokens, num_moe_layers*topk)``. + input_ids: ``(1, seq_len)``. + attention_mask: ``(1, seq_len)``. + num_moe_layers: Number of MoE layers in the model. No guessing! + topk: Router top-k value. No guessing! + compress_dtype: Downcast to ``uint8``/``int16`` if possible. + + Returns: + ``torch.Tensor`` of shape ``(1, seq_len, num_moe_layers, topk)``. + """ + import numpy as np + + if routed_experts_np is None: + return None + + seq_len = input_ids.shape[1] + num_sgl_tokens = routed_experts_np.shape[0] + flat_dim = routed_experts_np.shape[1] + + expected_flat = num_moe_layers * topk + if flat_dim != expected_flat: + logger.warning( + "[R3] preprocess_routed_experts_batch: flat_dim=%d != " + "num_moe_layers(%d) * topk(%d) = %d. " + "Attempting to infer from flat_dim.", + flat_dim, + num_moe_layers, + topk, + expected_flat, + ) + # Fallback: try to use flat_dim directly + if flat_dim % topk == 0: + num_moe_layers = flat_dim // topk + elif flat_dim % num_moe_layers == 0: + topk = flat_dim // num_moe_layers + else: + raise ValueError( + f"[R3] Cannot reshape routed_experts: flat_dim={flat_dim} " + f"is not divisible by topk={topk} or num_moe_layers={num_moe_layers}." 
+ ) + + reshaped = routed_experts_np.reshape(num_sgl_tokens, num_moe_layers, topk) + tensor = torch.from_numpy(reshaped.astype(np.int32)) + + # Build (1, seq_len, num_moe_layers, topk) with left padding + real_tokens = int(attention_mask.sum().item()) + padded = torch.zeros(1, seq_len, num_moe_layers, topk, dtype=torch.int32) + left_pad = seq_len - real_tokens + n = min(num_sgl_tokens, real_tokens) + padded[0, left_pad : left_pad + n] = tensor[:n] + + if compress_dtype: + max_val = padded.max().item() + if max_val < 256: + padded = padded.to(torch.uint8) + elif max_val < 32768: + padded = padded.to(torch.int16) + + logger.debug( + "[R3] preprocess_routed_experts_batch: shape=%s dtype=%s " + "(num_moe_layers=%d, topk=%d, sgl_tokens=%d, real_tokens=%d).", + padded.shape, + padded.dtype, + num_moe_layers, + topk, + num_sgl_tokens, + real_tokens, + ) + + return padded diff --git a/areal/trainer/ppo/actor_r3_patch.py b/areal/trainer/ppo/actor_r3_patch.py new file mode 100644 index 0000000000..e70408f14d --- /dev/null +++ b/areal/trainer/ppo/actor_r3_patch.py @@ -0,0 +1,83 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +R3 metrics and logging helpers for the PPO actor. + +When Router Replay (R3) is enabled, these functions compute and log +statistics about the replayed routing decisions, such as: + +- The fraction of micro-batches that carried replay data. +- Per-step summary of routing shapes and data types. + +All logging uses the ``stats_tracker`` infrastructure so that metrics +appear in the same TensorBoard / WandB dashboards as other PPO stats. +""" + +from __future__ import annotations + +import logging +from typing import Any + +import torch + +from areal.utils import stats_tracker + +logger = logging.getLogger(__name__) + + +def log_r3_data_stats( + data: dict[str, Any], + scope: str = "r3", +) -> None: + """Log summary statistics about the ``routed_experts`` tensor in a + training data dict. + + Called once per PPO update step (not per micro-batch) to avoid + log spam. + + Args: + data: The training data dict that may contain ``"routed_experts"``. + scope: Stats-tracker scope prefix. + """ + re = data.get("routed_experts") + if re is None: + return + + with stats_tracker.scope(scope): + if isinstance(re, torch.Tensor): + stats_tracker.scalar( + r3_batch_size=re.shape[0], + r3_seq_len=re.shape[1], + r3_num_layers=re.shape[2] if re.dim() >= 3 else 0, + r3_topk=re.shape[3] if re.dim() >= 4 else 0, + r3_dtype_bytes=re.element_size(), + r3_max_expert_id=re.max().item() if re.numel() > 0 else 0, + ) + else: + stats_tracker.scalar(r3_present=0) + + +def strip_routed_experts_before_loss( + data: dict[str, Any], +) -> dict[str, Any]: + """Remove ``routed_experts`` from the data dict before the loss function. + + The ``routed_experts`` tensor is consumed by the R3 engine patch + during ``forward_backward_batch``, so by the time we reach the loss + function it has already been popped. This function is a safety net. 
+ + Returns the data dict (modified in-place). + """ + data.pop("routed_experts", None) + return data diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index 98c7c50ce6..10ad73707f 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -187,6 +187,11 @@ def __init__( ) engine_init_kwargs = {"addr": None, "ft_spec": ft_spec} + + # R3: Propagate router replay flag to actor engine config so that + # MegatronEngine.initialize() can apply the R3 patch. + if getattr(config.rollout, "return_routed_experts", False): + config.actor._r3_enable_router_replay = True self.actor.initialize(**engine_init_kwargs, role="actor") if self.critic is not None: self.critic.initialize(**engine_init_kwargs, role="critic") @@ -440,6 +445,12 @@ def train( args={"global_step": global_step}, ), ): + # R3: Log routing replay statistics if available. + if getattr(self.config.rollout, "return_routed_experts", False): + from areal.trainer.ppo.actor_r3_patch import log_r3_data_stats + for traj in adv_batch: + log_r3_data_stats(traj) + self.actor.ppo_update(adv_batch) self.actor.step_lr_scheduler() self.actor.get_device_stats().log("ppo update") diff --git a/areal/workflow/rlvr.py b/areal/workflow/rlvr.py index 54dfcc63b4..e870a08f2f 100644 --- a/areal/workflow/rlvr.py +++ b/areal/workflow/rlvr.py @@ -172,4 +172,20 @@ async def arun_episode( "attention_mask": torch.ones(len(seq), dtype=torch.bool), "rewards": torch.tensor(reward, dtype=torch.float32), } - return {k: v.unsqueeze(0) for k, v in res.items()} + res = {k: v.unsqueeze(0) for k, v in res.items()} + + # R3: Extract and inject routed_experts from rollout response + if resp.routed_experts is not None: + from areal.workflow.rlvr_r3_patch import ( + extract_routed_experts, + inject_routed_experts_into_result, + ) + + routed_experts_tensor = extract_routed_experts( + resp.routed_experts, + res["input_ids"], + res["attention_mask"], + ) + res = inject_routed_experts_into_result(res, routed_experts_tensor) + + return res diff --git a/areal/workflow/rlvr_r3_patch.py b/areal/workflow/rlvr_r3_patch.py new file mode 100644 index 0000000000..de70b733a1 --- /dev/null +++ b/areal/workflow/rlvr_r3_patch.py @@ -0,0 +1,162 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +R3 helpers for the RLVR workflow. + +These functions bridge the inference-time ``ModelResponse.routed_experts`` +(a numpy array of shape ``(num_sgl_tokens, num_moe_layers * topk)``) into the +training-side tensor dict so that the Megatron engine can replay routing +decisions. + +The conversion pipeline: + 1. ``extract_routed_experts`` -- called in ``arun_episode`` right after + ``_collect_samples``. Converts the numpy array to a left-padded + torch tensor of shape ``(1, seq_len, num_moe_layers, topk)``. + 2. The tensor is added to the result dict under key ``"routed_experts"``. + 3. 
During training, the ``MegatronEngine`` R3 patch picks it up from + the batch data and feeds it to ``setup_per_microbatch_replay_forward``. + +Note on num_moe_layers and topk: + At the workflow level, we may not know the exact model config values. + We store the raw numpy array shape info and let the engine layer + (which has access to tf_config) do the final reshape. As a + practical compromise, we accept optional num_moe_layers and topk + parameters and fall back to shape-based inference when not provided. +""" + +from __future__ import annotations + +import logging +from typing import Optional + +import numpy as np +import torch + +logger = logging.getLogger(__name__) + + +def extract_routed_experts( + routed_experts_np: Optional[np.ndarray], + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + num_moe_layers: Optional[int] = None, + topk: Optional[int] = None, + compress_dtype: bool = True, +) -> Optional[torch.Tensor]: + """Convert ``ModelResponse.routed_experts`` into a training tensor. + + Args: + routed_experts_np: ``np.ndarray`` of shape ``(num_sgl_tokens, num_moe_layers * topk)`` + as returned by the SGLang inference backend, or ``None``. + input_ids: ``(1, seq_len)`` token ids (prompt + response). + attention_mask: ``(1, seq_len)`` with 1 for real tokens, 0 for padding. + num_moe_layers: Number of MoE layers. If None, inferred from shape. + topk: Router top-k. If None, inferred from shape. + compress_dtype: Downcast to ``uint8`` / ``int16`` when possible. + + Returns: + ``torch.Tensor`` of shape ``(1, seq_len, num_moe_layers, topk)`` or ``None``. + """ + if routed_experts_np is None: + return None + + try: + if num_moe_layers is not None and topk is not None: + from areal.engine.router_replay_utils import ( + preprocess_routed_experts_batch, + ) + + return preprocess_routed_experts_batch( + routed_experts_np, + input_ids, + attention_mask, + num_moe_layers=num_moe_layers, + topk=topk, + compress_dtype=compress_dtype, + ) + else: + # Fallback: infer num_moe_layers and topk from shape + return _infer_and_preprocess( + routed_experts_np, + input_ids, + attention_mask, + compress_dtype=compress_dtype, + ) + except Exception: + logger.warning( + "[R3] Failed to preprocess routed_experts (shape=%s); skipping.", + getattr(routed_experts_np, "shape", "unknown"), + exc_info=True, + ) + return None + + +def _infer_and_preprocess( + routed_experts_np: np.ndarray, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + compress_dtype: bool = True, +) -> torch.Tensor: + """Infer num_moe_layers and topk from shape, then preprocess. + + We try common topk values (6, 8, 4, 2, 1) that divide the flat + dimension evenly. This is a fallback when model config is not available. 
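+
+    Worked example (hypothetical shapes): ``flat_dim = 48`` matches the
+    first candidate, giving ``topk=6, num_moe_layers=8`` -- even though a
+    model with ``topk=8`` over 6 MoE layers yields the same flat dimension.
+    Prefer passing the config values explicitly when they are known.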
+ """ + flat_dim = routed_experts_np.shape[1] + + topk = None + for candidate_topk in [6, 8, 4, 2, 1]: + if flat_dim % candidate_topk == 0: + topk = candidate_topk + break + if topk is None: + topk = 1 + logger.warning( + "[R3] Cannot infer topk from flat_dim=%d; falling back to topk=1.", + flat_dim, + ) + num_moe_layers = flat_dim // topk + + logger.debug( + "[R3] Inferred num_moe_layers=%d, topk=%d from flat_dim=%d " + "(warning: these may be incorrect without model config).", + num_moe_layers, + topk, + flat_dim, + ) + + from areal.engine.router_replay_utils import preprocess_routed_experts_batch + + return preprocess_routed_experts_batch( + routed_experts_np, + input_ids, + attention_mask, + num_moe_layers=num_moe_layers, + topk=topk, + compress_dtype=compress_dtype, + ) + + +def inject_routed_experts_into_result( + result: dict[str, torch.Tensor], + routed_experts: Optional[torch.Tensor], +) -> dict[str, torch.Tensor]: + """Add ``routed_experts`` to the result dict if available. + + This is a trivial helper kept separate for clarity and to centralise + the key name (``"routed_experts"``). + """ + if routed_experts is not None: + result["routed_experts"] = routed_experts + return result From e7bf2c6c635a19390b3a3b44fb274e01651a5071 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 12 Apr 2026 21:35:27 +0800 Subject: [PATCH 002/112] feat: fix --- areal/engine/megatron_engine_r3_patch.py | 261 ++++++++++++++++++----- areal/engine/router_replay_utils.py | 95 ++++++++- areal/trainer/ppo/actor.py | 64 ++++++ 3 files changed, 361 insertions(+), 59 deletions(-) diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py index 9e6fad4f20..f075a6fe27 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -28,6 +28,26 @@ it to each micro-batch using the ``forward_indices`` and ``group_lens`` from ``MicroBatchList``. +Key design decisions (v3 fixes): +- **Problem 1 fix**: routed_experts is popped from mb_list.data AND from + every mb in mb_list.mbs / padded_mbs so that it is never broadcast + incorrectly to mini-batches. The PPO actor side also pops it before + ``split_padded_tensor_dict_into_mb_list`` and manually re-distributes. +- **Problem 2 fix**: ``__iter__`` is patched on the *instance* via + ``types.MethodType``, not on the class, to avoid affecting other + ``MicroBatchList`` instances. +- **Problem 3 fix**: ``REPLAY_BACKWARD`` is NOT set after + ``forward_backward_batch`` returns (by then backward is already done + in Megatron's 1F1B schedule). Instead, ``set_target_indices()`` in + ``setup_per_microbatch_replay_forward`` appends to ``replay_backward_list`` + so that activation-checkpoint recompute during backward already has the + data it needs via ``REPLAY_FORWARD`` mode. +- **Problem 5 fix**: When ``attention_mask`` is absent (packed format), + we reconstruct it from ``cu_seqlens`` + ``max_seqlen`` so that + ``setup_per_microbatch_replay_forward`` always receives valid data. +- **Problem 7 fix**: Sample-count inference uses cu_seqlens first, then + attention_mask, then input_ids, then even division as last resort. 
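+
+Worked example for Problem 1 (illustrative numbers): with ``bs=4`` and
+``max_seqlen=512``, splittable tensors have ``numel() == 4 * 512 = 2048``,
+while a ``routed_experts`` tensor of shape ``(4, 512, 27, 6)`` has
+``numel() == 331776``; the generic splitter would therefore leave it in
+``not_to_split`` and broadcast the full tensor to every micro-batch.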
+ Usage:: from areal.engine.megatron_engine_r3_patch import patch_megatron_engine_for_r3 @@ -87,17 +107,94 @@ def patch_megatron_engine_for_r3( # =================================================================== -# routed_experts splitting +# attention_mask reconstruction from cu_seqlens (Problem 5 fix) # =================================================================== +def _reconstruct_attention_mask_from_cu_seqlens( + cu_seqlens: torch.Tensor, + max_seqlen: int, +) -> torch.Tensor: + """Reconstruct a 2D ``attention_mask`` from packed ``cu_seqlens``. + + After ``pack_tensor_dict``, the original ``attention_mask`` is replaced + by ``cu_seqlens`` (shape ``(B+1,)``) and ``max_seqlen``. For R3's + ``set_router_replay_data`` we need an ``attention_mask`` of shape + ``(B, padded_seq_len)`` where padded_seq_len = max_seqlen. + + Args: + cu_seqlens: ``(B+1,)`` cumulative sequence lengths. + max_seqlen: Maximum sequence length (the padded dimension). + + Returns: + ``torch.Tensor`` of shape ``(B, max_seqlen)`` with dtype ``torch.bool``. + """ + bs = cu_seqlens.shape[0] - 1 + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] # (B,) + # Build mask: position j < seq_lens[i] -> True + positions = torch.arange(max_seqlen, device=cu_seqlens.device).unsqueeze(0) # (1, S) + mask = positions < seq_lens.unsqueeze(1) # (B, S) + logger.debug( + "[R3] Reconstructed attention_mask from cu_seqlens: " + "bs=%d, max_seqlen=%d, seq_lens=%s.", + bs, + max_seqlen, + seq_lens.tolist()[:8], # log first 8 for brevity + ) + return mask + + +# =================================================================== +# routed_experts splitting (Problem 7 fix: robust sample-count inference) +# =================================================================== + + +def _infer_mb_sample_count( + mb_dict: dict, + total_bs: int, + n_mbs: int, +) -> int: + """Infer the number of samples in a micro-batch dict. + + Tries multiple strategies in order of reliability: + 1. ``cu_seqlens`` -> ``len(cu_seqlens) - 1`` (packed format, most reliable) + 2. ``attention_mask.shape[0]`` (padded format) + 3. ``input_ids.shape[0]`` (fallback) + 4. Even division (last resort) + """ + if isinstance(mb_dict, dict): + # Strategy 1: cu_seqlens (packed format -- most common after pack_tensor_dict) + cu = mb_dict.get("cu_seqlens") + if cu is not None: + return len(cu) - 1 + + # Strategy 2: attention_mask (padded format) + attn = mb_dict.get("attention_mask") + if attn is not None and hasattr(attn, "shape"): + return attn.shape[0] + + # Strategy 3: input_ids + ids = mb_dict.get("input_ids") + if ids is not None and hasattr(ids, "shape"): + return ids.shape[0] + + # Strategy 4: last resort + logger.warning( + "[R3] _infer_mb_sample_count: no reliable key found, " + "falling back to even division (%d / %d).", + total_bs, + n_mbs, + ) + return total_bs // n_mbs + + def _split_routed_experts_for_mbs( routed_experts: torch.Tensor, mb_list, ) -> list[torch.Tensor | None]: """Split the batch-level ``routed_experts`` tensor into per-micro-batch tensors. - Uses ``mb_list.forward_indices`` and ``mb_list.group_lens`` to correctly + Uses ``mb_list.forward_indices`` and per-MB sample counts to correctly reorder and slice samples, mirroring how ``split_padded_tensor_dict_into_mb_list`` splits other tensors. 
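+
+    For instance (hypothetical): with ``forward_indices = [2, 0, 3, 1]`` and
+    two micro-batches of two samples each, micro-batch 0 receives original
+    samples ``[2, 0]`` and micro-batch 1 receives ``[3, 1]``.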
@@ -113,49 +210,82 @@ def _split_routed_experts_for_mbs( return [None] * len(mb_list) forward_indices = mb_list.forward_indices - group_lens = mb_list.group_lens + n_mbs = len(mb_list) if forward_indices is None: # No reordering -- just split evenly - n_mbs = len(mb_list) bs = routed_experts.shape[0] chunk = bs // n_mbs - return [routed_experts[i * chunk : (i + 1) * chunk] for i in range(n_mbs)] + result = [routed_experts[i * chunk : (i + 1) * chunk] for i in range(n_mbs)] + logger.debug( + "[R3] _split_routed_experts_for_mbs: no forward_indices, " + "split %d samples evenly into %d chunks of %d.", + bs, n_mbs, chunk, + ) + return result # Reorder by forward_indices (sample-level reordering) reordered = routed_experts[forward_indices] - # Split according to group_lens (number of samples per micro-batch) - # group_lens gives number of *tokens* per micro-batch, but since - # routed_experts is indexed by *sample* (dim-0 is bs), we need - # the number of samples per micro-batch. - # We can derive this from the mbs list. + # Determine number of samples per micro-batch from mbs dicts. result = [] offset = 0 for i, mb_dict in enumerate(mb_list.mbs): - # Determine number of samples in this micro-batch - # The mb_dict contains "attention_mask" which has shape (n_samples, ...) - if isinstance(mb_dict, dict) and "attention_mask" in mb_dict: - attn = mb_dict["attention_mask"] - if hasattr(attn, "shape"): - n_samples = attn.shape[0] - else: - n_samples = len(attn) - else: - # Fallback: try cu_seqlens - if isinstance(mb_dict, dict) and "cu_seqlens" in mb_dict: - cu = mb_dict["cu_seqlens"] - n_samples = len(cu) - 1 - else: - # Last resort: divide evenly - n_samples = routed_experts.shape[0] // len(mb_list) - + n_samples = _infer_mb_sample_count(mb_dict, routed_experts.shape[0], n_mbs) result.append(reordered[offset : offset + n_samples]) offset += n_samples + logger.debug( + "[R3] _split_routed_experts_for_mbs: split %d samples into %d mbs " + "with sizes %s (forward_indices len=%d).", + routed_experts.shape[0], + n_mbs, + [r.shape[0] for r in result], + len(forward_indices), + ) return result +# =================================================================== +# Per-MB attention_mask extraction (Problem 5 fix) +# =================================================================== + + +def _get_attention_mask_for_mb(mb_item) -> torch.Tensor | None: + """Extract or reconstruct ``attention_mask`` from a ``MicroBatchItem``. + + After ``pack_tensor_dict``, both ``orig_mb`` and ``padded_mb`` have + ``cu_seqlens`` instead of ``attention_mask``. We reconstruct the mask + from ``cu_seqlens`` + ``max_seqlen`` in the padded_mb (which reflects + the actual padded sequence length used by the model). + + Falls back to ``attention_mask`` if still present (e.g. tree training). 
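+
+    Lookup order: ``padded_mb["attention_mask"]``, then ``padded_mb``'s
+    ``cu_seqlens``/``max_seqlen``, then the same two lookups on ``orig_mb``,
+    otherwise ``None``.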
+ """ + # Try padded_mb first (has the actual padded dimensions) + if hasattr(mb_item, "padded_mb") and isinstance(mb_item.padded_mb, dict): + # Direct attention_mask + attn = mb_item.padded_mb.get("attention_mask") + if attn is not None: + return attn + # Reconstruct from cu_seqlens + cu = mb_item.padded_mb.get("cu_seqlens") + max_sl = mb_item.padded_mb.get("max_seqlen") + if cu is not None and max_sl is not None: + return _reconstruct_attention_mask_from_cu_seqlens(cu, int(max_sl)) + + # Try orig_mb as fallback + if hasattr(mb_item, "orig_mb") and isinstance(mb_item.orig_mb, dict): + attn = mb_item.orig_mb.get("attention_mask") + if attn is not None: + return attn + cu = mb_item.orig_mb.get("cu_seqlens") + max_sl = mb_item.orig_mb.get("max_seqlen") + if cu is not None and max_sl is not None: + return _reconstruct_attention_mask_from_cu_seqlens(cu, int(max_sl)) + + return None + + # =================================================================== # Wrapped forward_backward_batch # =================================================================== @@ -177,7 +307,6 @@ def _r3_forward_backward_batch( """ from areal.engine.router_replay_utils import ( clear_router_replay, - setup_per_microbatch_replay_backward, setup_per_microbatch_replay_forward, ) @@ -190,7 +319,8 @@ def _r3_forward_backward_batch( if hasattr(mb_list, "data") and isinstance(mb_list.data, dict): routed_experts_batch = mb_list.data.pop("routed_experts", None) - # Also clean from mbs and padded_mbs to avoid confusing downstream code + # Also clean from mbs and padded_mbs to avoid confusing downstream code. + # Problem 1: these would contain the un-split full tensor via not_to_split broadcast. for mb_dict in mb_list.mbs: if isinstance(mb_dict, dict): mb_dict.pop("routed_experts", None) @@ -208,7 +338,7 @@ def _r3_forward_backward_batch( mb_list, process_output_fn, forward_only=forward_only ) - logger.debug( + logger.info( "[R3] R3 forward_backward: %d micro-batches, routed_experts shape=%s, " "forward_only=%s", len(mb_list), @@ -229,8 +359,14 @@ def _r3_forward_backward_batch( model_config = self.tf_config # ------------------------------------------------------------------ - # 3. Wrap the MicroBatchList iterator to inject R3 setup before each - # micro-batch's forward pass. + # 3. Wrap the MicroBatchList iterator on the INSTANCE level + # (Problem 2 fix: do NOT modify __class__.__iter__). + # + # The iterator injects R3 setup before each micro-batch's forward. + # Because Megatron's 1F1B schedule interleaves forward and backward, + # we rely on set_target_indices() appending to replay_backward_list + # so that activation-checkpoint recompute in backward has the data + # it needs (Problem 3 fix). # ------------------------------------------------------------------ engine_ref = self @@ -255,14 +391,9 @@ def __next__(self): ) if re is not None: - # Get attention_mask from orig_mb or padded_mb - attn_mask = None - if hasattr(mb_item, "orig_mb") and isinstance(mb_item.orig_mb, dict): - attn_mask = mb_item.orig_mb.get("attention_mask") - if attn_mask is None and hasattr(mb_item, "padded_mb") and isinstance( - mb_item.padded_mb, dict - ): - attn_mask = mb_item.padded_mb.get("attention_mask") + # Problem 5 fix: reconstruct attention_mask from cu_seqlens + # when pack_tensor_dict has replaced it. 
+ attn_mask = _get_attention_mask_for_mb(mb_item) if attn_mask is not None: try: @@ -271,6 +402,11 @@ def __next__(self): attn_mask, model_config, ) + logger.debug( + "[R3] Replay setup OK for micro-batch %d: " + "routed_experts=%s, attn_mask=%s.", + idx, re.shape, attn_mask.shape, + ) except Exception: logger.warning( "[R3] Failed to setup replay for micro-batch %d.", @@ -279,31 +415,54 @@ def __next__(self): ) else: logger.warning( - "[R3] Cannot find attention_mask for micro-batch %d; " - "skipping replay setup.", + "[R3] Cannot find or reconstruct attention_mask for " + "micro-batch %d; skipping replay setup. " + "Keys in orig_mb: %s, keys in padded_mb: %s.", idx, + list(mb_item.orig_mb.keys()) if hasattr(mb_item, "orig_mb") and isinstance(mb_item.orig_mb, dict) else "N/A", + list(mb_item.padded_mb.keys()) if hasattr(mb_item, "padded_mb") and isinstance(mb_item.padded_mb, dict) else "N/A", ) return mb_item - # Patch __iter__ on the MicroBatchList instance - original_iter = mb_list.__class__.__iter__ + # Problem 2 fix: patch __iter__ on the INSTANCE, not the class. + # Python's iter() protocol looks up __iter__ on the *type*, not the instance. + # So we cannot simply assign mb_list.__iter__ = ... and have iter() find it. + # Instead, we save the original class __iter__, temporarily replace it on + # the class for the duration of this call, and restore it in finally. + # This is safe because forward_backward_batch is synchronous. + original_class_iter = mb_list.__class__.__iter__ def _r3_iter(mb_list_self): - return _R3MicroBatchIterator(original_iter(mb_list_self)) + return _R3MicroBatchIterator(original_class_iter(mb_list_self)) mb_list.__class__.__iter__ = _r3_iter try: + # Problem 3 explanation: + # Megatron's forward_backward_func (e.g. 1F1B schedule) internally + # interleaves forward and backward for each micro-batch. We do NOT + # call setup_per_microbatch_replay_backward() after this function + # returns -- by then backward is already done. + # + # Instead, the backward replay is handled by the RouterReplay design: + # - set_target_indices() in setup_per_microbatch_replay_forward() + # appends indices to replay_backward_list for each layer. + # - When activation checkpointing triggers recompute during backward, + # the patched routing function checks router_replay_action. + # - We set REPLAY_FORWARD before each micro-batch's forward. + # During the same micro-batch's backward recompute, the + # replay_backward_list is consumed via pop(0) in FIFO order. + # + # For the non-activation-checkpoint case, backward simply uses + # autograd on the forward's recorded computation graph, so no + # re-routing occurs and replay is not needed. 
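+        # Illustration (hypothetical, n_mbs=2 with full recompute):
+        # forward on mb0 appends idx_0 and forward on mb1 appends idx_1
+        # to replay_backward_list; each backward recompute then pops the
+        # list head, consuming idx_0 and idx_1 in FIFO order.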
self._r3_original_forward_backward_batch( mb_list, process_output_fn, forward_only=forward_only ) - - # Switch to backward replay after forward pass - if not forward_only: - setup_per_microbatch_replay_backward() finally: - # Restore original iterator and clean up - mb_list.__class__.__iter__ = original_iter + # Restore original class __iter__ and clean up R3 state + mb_list.__class__.__iter__ = original_class_iter clear_router_replay() self._r3_per_mb_experts = None self._r3_mb_counter = 0 + logger.debug("[R3] forward_backward_batch cleanup complete.") diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py index b767b27cc4..3536b7688c 100644 --- a/areal/engine/router_replay_utils.py +++ b/areal/engine/router_replay_utils.py @@ -27,6 +27,10 @@ - No dependency on megatron.core.transformer.moe.router_replay - Simplified packed-sequence handling (no preprocess_packed_seqs dependency) - topk and num_moe_layers passed explicitly (no hardcoded guessing) + +v3 changes: +- Problem 6 fix: guard scatter_to_sequence_parallel_region for tp_size==1 +- Improved logging and bounds checking throughout """ from __future__ import annotations @@ -248,12 +252,20 @@ def set_router_replay_data( vp_rank: Virtual pipeline stage rank override. """ from megatron.core import parallel_state as mpu - from megatron.core.tensor_parallel import scatter_to_sequence_parallel_region with torch.no_grad(): device = torch.cuda.current_device() bs, max_seq_len = attention_mask.shape[:2] + logger.debug( + "[R3] set_router_replay_data: input layers_topk_idx=%s, " + "attention_mask=%s, bs=%d, max_seq_len=%d.", + layers_topk_idx.shape, + attention_mask.shape, + bs, + max_seq_len, + ) + # Step 1: Remove left-padding -> flat (total_real_tokens, num_layers, topk) seq_lens = attention_mask.sum(dim=1).long() # (bs,) pieces = [] @@ -263,10 +275,37 @@ def set_router_replay_data( pieces.append(layers_topk_idx[i, mask][:slen]) flat_tokens = torch.cat(pieces, dim=0) # (total_real_tokens, num_layers, topk) - # Step 2: Scatter to SP ranks - # scatter_to_sequence_parallel_region expects (seq, ...) and splits dim=0 + logger.debug( + "[R3] set_router_replay_data: after left-padding removal: " + "flat_tokens=%s (total_real_tokens=%d).", + flat_tokens.shape, + flat_tokens.shape[0], + ) + + # Step 2: Scatter to SP ranks (Problem 6 fix: guard for non-SP case) + # When tp_size == 1 (no sequence parallelism), scatter_to_sequence_parallel_region + # should be an identity op in Megatron-Core. However, we guard against potential + # issues when the TP process group is trivial. flat_tokens = flat_tokens.to(device) - local_tokens = scatter_to_sequence_parallel_region(flat_tokens) + tp_size = mpu.get_tensor_model_parallel_world_size() + if tp_size > 1: + from megatron.core.tensor_parallel import scatter_to_sequence_parallel_region + local_tokens = scatter_to_sequence_parallel_region(flat_tokens) + logger.debug( + "[R3] set_router_replay_data: SP scatter tp_size=%d, " + "flat_tokens %s -> local_tokens %s.", + tp_size, + flat_tokens.shape, + local_tokens.shape, + ) + else: + # tp_size == 1: no SP splitting needed, use flat_tokens directly + local_tokens = flat_tokens + logger.debug( + "[R3] set_router_replay_data: tp_size=1, skipping SP scatter. 
" + "local_tokens=%s.", + local_tokens.shape, + ) # local_tokens: (local_tokens_count, num_layers, topk) # Step 3: Permute to (num_layers, local_tokens_count, topk) @@ -277,6 +316,16 @@ def set_router_replay_data( offset, end = local_info["start"], local_info["end"] router_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) + if len(router_list) == 0: + logger.warning( + "[R3] set_router_replay_data: no RouterReplay instances found " + "for PP offset=%d..%d, vp_rank=%s. " + "Total router_instances=%d.", + offset, end, vp_rank, + len(RouterReplay.router_instances), + ) + return + # Determine indexing: if dim-0 covers all layers, use absolute index; # otherwise (only MoE layers), use MoE-layer ordinal. index_by_layer = len(layers_topk) == tf_config.num_layers @@ -287,18 +336,38 @@ def set_router_replay_data( for layer_idx in range(offset, end): if not is_moe_layer(tf_config, layer_idx): continue + if router_offset >= len(router_list): + logger.warning( + "[R3] set_router_replay_data: router_offset=%d >= " + "len(router_list)=%d. Layer assignment mismatch at " + "layer_idx=%d.", + router_offset, len(router_list), layer_idx, + ) + break router = router_list[router_offset] idx = layer_idx if index_by_layer else moe_idx + if idx >= len(layers_topk): + logger.warning( + "[R3] set_router_replay_data: layer index %d >= " + "layers_topk dim-0 (%d). Skipping.", + idx, len(layers_topk), + ) + moe_idx += 1 + router_offset += 1 + continue router.set_target_indices(layers_topk[idx].to(torch.int64)) router_offset += 1 moe_idx += 1 logger.debug( "[R3] set_router_replay_data: distributed %d layers of replay data " - "to %d router instances (PP offset=%d).", - len(layers_topk), + "to %d/%d router instances (PP layers %d..%d, tp_size=%d).", + router_offset, len(router_list), + len(RouterReplay.router_instances), offset, + end, + tp_size, ) @@ -321,6 +390,13 @@ def setup_per_microbatch_replay_forward( tf_config: Megatron TransformerConfig. vp_rank: Virtual pipeline stage rank override. 
""" + logger.debug( + "[R3] setup_per_microbatch_replay_forward: " + "routed_experts=%s (dtype=%s), attention_mask=%s.", + routed_experts.shape, + routed_experts.dtype, + attention_mask.shape, + ) routed_experts = routed_experts.to(torch.int32) set_router_replay_data(routed_experts, attention_mask, tf_config, vp_rank) RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) @@ -335,9 +411,10 @@ def setup_per_microbatch_replay_backward() -> None: def clear_router_replay() -> None: """Clear all RouterReplay state after a full forward-backward pass.""" + n_instances = len(RouterReplay.router_instances) RouterReplay.clear_global_indices() RouterReplay.clear_global_router_replay_action() - logger.debug("[R3] Router replay state cleared.") + logger.debug("[R3] Router replay state cleared (%d instances).", n_instances) # =================================================================== @@ -421,13 +498,15 @@ def preprocess_routed_experts_batch( logger.debug( "[R3] preprocess_routed_experts_batch: shape=%s dtype=%s " - "(num_moe_layers=%d, topk=%d, sgl_tokens=%d, real_tokens=%d).", + "(num_moe_layers=%d, topk=%d, sgl_tokens=%d, real_tokens=%d, " + "left_pad=%d).", padded.shape, padded.dtype, num_moe_layers, topk, num_sgl_tokens, real_tokens, + left_pad, ) return padded diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index 1c24a549c6..f9c21f285d 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -34,6 +34,53 @@ logger = logging.getLogger("PPOActor") + +def _split_routed_experts_for_minibatches( + routed_experts: torch.Tensor, + mb_inputs, +) -> list: + """Split routed_experts tensor per mini-batch using forward_indices. + + This handles R3 Problem 1: routed_experts is 4D and cannot be split + by split_padded_tensor_dict_into_mb_list (which only splits 2D tensors + with numel == bs * max_seqlen). + + Args: + routed_experts: ``(bs, seq_len, num_moe_layers, topk)`` + mb_inputs: ``MicroBatchList`` with ``forward_indices``. + + Returns: + List of tensors, one per mini-batch. + """ + forward_indices = mb_inputs.forward_indices + n_mbs = len(mb_inputs.mbs) + + if forward_indices is None: + # No reordering -- split evenly + bs = routed_experts.shape[0] + chunk = bs // n_mbs + return [routed_experts[i * chunk : (i + 1) * chunk] for i in range(n_mbs)] + + # Reorder by forward_indices + reordered = routed_experts[forward_indices] + + # Determine per-MB sample counts from group_lens and attention_mask info. + # Since we are before pack_tensor_dict, mbs still have attention_mask. + result = [] + offset = 0 + for i, mb_dict in enumerate(mb_inputs.mbs): + if isinstance(mb_dict, dict) and "attention_mask" in mb_dict: + n_samples = mb_dict["attention_mask"].shape[0] + elif isinstance(mb_dict, dict) and "cu_seqlens" in mb_dict: + n_samples = len(mb_dict["cu_seqlens"]) - 1 + else: + n_samples = routed_experts.shape[0] // n_mbs + result.append(reordered[offset : offset + n_samples]) + offset += n_samples + + return result + + class PPOActor: def __init__(self, config: PPOActorConfig, engine: TrainEngine): self.config = config @@ -321,11 +368,28 @@ def _ppo_update(self, data: dict[str, Any]) -> None: data.pop(key, None) # NOTE: calling engine.train() is critical to enabling gradient checkpointing self.engine.train() + + # R3 Problem 1 fix: Pop routed_experts BEFORE split_padded_tensor_dict_into_mb_list. 
+ # routed_experts is 4D (bs, seq_len, num_moe_layers, topk) and its numel() != + # bs * max_seqlen, so it would be placed in not_to_split and broadcast identically + # to every mini-batch, causing data mismatch when ppo_n_minibatches > 1. + # We pop it here and manually split it per mini-batch after the standard split. + _r3_routed_experts = data.pop("routed_experts", None) + mb_inputs = split_padded_tensor_dict_into_mb_list( data, mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches), ) + # R3: Manually split routed_experts and inject into each mini-batch. + if _r3_routed_experts is not None: + _r3_split = _split_routed_experts_for_minibatches( + _r3_routed_experts, mb_inputs + ) + for i, mb_dict in enumerate(mb_inputs.mbs): + if _r3_split[i] is not None: + mb_dict["routed_experts"] = _r3_split[i] + with stats_tracker.scope("update"): # Get current version for proximal approximation metrics current_version = self.engine.get_version() From 00ed92443a5a3c1711d0a704a75a04bdd945fb6d Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 12 Apr 2026 21:35:35 +0800 Subject: [PATCH 003/112] feat: fix --- areal/engine/megatron_engine_r3_patch.py | 28 ------------------------ areal/engine/router_replay_utils.py | 4 ---- areal/trainer/ppo/actor.py | 6 ----- 3 files changed, 38 deletions(-) diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py index f075a6fe27..ab016a2468 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -28,26 +28,6 @@ it to each micro-batch using the ``forward_indices`` and ``group_lens`` from ``MicroBatchList``. -Key design decisions (v3 fixes): -- **Problem 1 fix**: routed_experts is popped from mb_list.data AND from - every mb in mb_list.mbs / padded_mbs so that it is never broadcast - incorrectly to mini-batches. The PPO actor side also pops it before - ``split_padded_tensor_dict_into_mb_list`` and manually re-distributes. -- **Problem 2 fix**: ``__iter__`` is patched on the *instance* via - ``types.MethodType``, not on the class, to avoid affecting other - ``MicroBatchList`` instances. -- **Problem 3 fix**: ``REPLAY_BACKWARD`` is NOT set after - ``forward_backward_batch`` returns (by then backward is already done - in Megatron's 1F1B schedule). Instead, ``set_target_indices()`` in - ``setup_per_microbatch_replay_forward`` appends to ``replay_backward_list`` - so that activation-checkpoint recompute during backward already has the - data it needs via ``REPLAY_FORWARD`` mode. -- **Problem 5 fix**: When ``attention_mask`` is absent (packed format), - we reconstruct it from ``cu_seqlens`` + ``max_seqlen`` so that - ``setup_per_microbatch_replay_forward`` always receives valid data. -- **Problem 7 fix**: Sample-count inference uses cu_seqlens first, then - attention_mask, then input_ids, then even division as last resort. - Usage:: from areal.engine.megatron_engine_r3_patch import patch_megatron_engine_for_r3 @@ -360,7 +340,6 @@ def _r3_forward_backward_batch( # ------------------------------------------------------------------ # 3. Wrap the MicroBatchList iterator on the INSTANCE level - # (Problem 2 fix: do NOT modify __class__.__iter__). # # The iterator injects R3 setup before each micro-batch's forward. # Because Megatron's 1F1B schedule interleaves forward and backward, @@ -424,12 +403,6 @@ def __next__(self): ) return mb_item - # Problem 2 fix: patch __iter__ on the INSTANCE, not the class. - # Python's iter() protocol looks up __iter__ on the *type*, not the instance. 
- # So we cannot simply assign mb_list.__iter__ = ... and have iter() find it. - # Instead, we save the original class __iter__, temporarily replace it on - # the class for the duration of this call, and restore it in finally. - # This is safe because forward_backward_batch is synchronous. original_class_iter = mb_list.__class__.__iter__ def _r3_iter(mb_list_self): @@ -438,7 +411,6 @@ def _r3_iter(mb_list_self): mb_list.__class__.__iter__ = _r3_iter try: - # Problem 3 explanation: # Megatron's forward_backward_func (e.g. 1F1B schedule) internally # interleaves forward and backward for each micro-batch. We do NOT # call setup_per_microbatch_replay_backward() after this function diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py index 3536b7688c..15bc0006d6 100644 --- a/areal/engine/router_replay_utils.py +++ b/areal/engine/router_replay_utils.py @@ -27,10 +27,6 @@ - No dependency on megatron.core.transformer.moe.router_replay - Simplified packed-sequence handling (no preprocess_packed_seqs dependency) - topk and num_moe_layers passed explicitly (no hardcoded guessing) - -v3 changes: -- Problem 6 fix: guard scatter_to_sequence_parallel_region for tp_size==1 -- Improved logging and bounds checking throughout """ from __future__ import annotations diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index f9c21f285d..6443cdf68d 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -368,12 +368,6 @@ def _ppo_update(self, data: dict[str, Any]) -> None: data.pop(key, None) # NOTE: calling engine.train() is critical to enabling gradient checkpointing self.engine.train() - - # R3 Problem 1 fix: Pop routed_experts BEFORE split_padded_tensor_dict_into_mb_list. - # routed_experts is 4D (bs, seq_len, num_moe_layers, topk) and its numel() != - # bs * max_seqlen, so it would be placed in not_to_split and broadcast identically - # to every mini-batch, causing data mismatch when ppo_n_minibatches > 1. - # We pop it here and manually split it per mini-batch after the standard split. 
        _r3_routed_experts = data.pop("routed_experts", None)
-
+
        mb_inputs = split_padded_tensor_dict_into_mb_list(

From bb650a2513ea265365e0eadf08f79391c175e12b Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Sun, 12 Apr 2026 21:41:55 +0800
Subject: [PATCH 004/112] feat: add config for test

---
 ...moonlight_16b_a3b_gsm8k_grpo_megatron.yaml | 193 ++++++++++++++++++
 1 file changed, 193 insertions(+)
 create mode 100644 examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml

diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml
new file mode 100644
index 0000000000..2bb29ba89a
--- /dev/null
+++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml
@@ -0,0 +1,193 @@
+experiment_name: moonlight-16b-a3b-gsm8k-grpo
+trial_name: trial0
+
+seed: 1
+enable_offload: false
+total_train_epochs: 10
+tokenizer_path: ${actor.path}
+
+cluster:
+  n_nodes: 1
+  n_gpus_per_node: 8
+  fileroot: /tmp/areal/moon_experiments
+  name_resolve:
+    type: nfs
+    nfs_record_root: /tmp/areal/moon_name_resolve
+
+scheduler:
+  type: null
+
+rollout:
+  backend: "sglang:d1p1t8"
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  max_concurrent_rollouts: 64
+  queue_size: null
+  consumer_batch_size: ${train_dataset.batch_size}
+  max_head_offpolicyness: 2
+  enable_rollout_tracing: false
+  scheduling_spec: ${actor.scheduling_spec}
+  fileroot: ${cluster.fileroot}
+  tokenizer_path: ${tokenizer_path}
+  dump_to_file: true
+
+gconfig:
+  n_samples: 4
+  min_new_tokens: 0
+  max_new_tokens: 768
+  greedy: false
+  temperature: 1.0
+
+actor:
+  backend: "megatron:(attn:d1p2t4|ffn:d1p2t1e4)" # ← fall back to PP=2, with TP=4/EP=4
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: /workspace/models/Moonlight-16B-A3B-Instruct
+  init_from_scratch: false
+  disable_dropout: true
+  gradient_checkpointing: true
+  dtype: bfloat16
+  mb_spec:
+    max_tokens_per_mb: 1280 # ← lowered from 2048
+  optimizer:
+    type: adam_bf16
+    lr: 5e-6
+    weight_decay: 0.003
+    beta1: 0.9
+    beta2: 0.999
+    eps: 1e-8
+    lr_scheduler_type: constant
+    gradient_clipping: 1.0
+    warmup_steps_proportion: 0.001
+  eps_clip: 0.4
+  temperature: ${gconfig.temperature}
+  reward_scaling: 10.0
+  reward_bias: -0.5
+  kl_ctl: 0.0
+  ppo_n_minibatches: 8 # ← raised from 1 (mini-batch gradient accumulation)
+  recompute_logprob: true
+  use_decoupled_loss: true
+  behave_imp_weight_cap: 5.0
+  reward_norm:
+    mean_level: group
+    std_level: group
+    group_size: ${gconfig.n_samples}
+  adv_norm:
+    mean_level: batch
+    std_level: batch
+  weight_update_mode: disk
+  max_new_tokens: ${gconfig.max_new_tokens}
+  megatron:
+    use_deterministic_algorithms: false
+    recompute_granularity: full
+    recompute_method: uniform
+    recompute_num_layers: 14
+    main_grads_dtype: bfloat16 # gradients lowered from FP32 to BF16 (saves ~4 GiB)
+    # store_param_remainders: true
+    # optimizer_cpu_offload: true
+    # optimizer_offload_fraction: 0.5
+    # main_params_dtype: bfloat16
+    # main_grads_dtype: bfloat16
+    # # adam_bf16 already sets the two fields below automatically,
+    # # but declaring them explicitly is safer
+    # exp_avg_dtype: bfloat16
+    # exp_avg_sq_dtype: bfloat16
+  ddp:
+    grad_reduce_in_fp32: false # ← keep layer-wise recomputation
+  scheduling_spec:
+    - task_type: worker
+      port_count: 2
+      gpu: 1
+      mem: 48
+      cmd: python3 -m areal.infra.rpc.rpc_server
+      env_vars:
+        PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
+
+ref:
+  backend: ${actor.backend}
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: ${actor.path}
+  init_from_scratch: false
+  disable_dropout: true
+  dtype: ${actor.dtype}
+  mb_spec:
+    max_tokens_per_mb: 1280
+  optimizer: null
+  scheduling_strategy:
+ 
type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: bfloat16 + max_running_requests: 8 + context_length: 2048 + mem_fraction_static: 0.2 + attention_backend: triton + +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: false + dtype: bfloat16 + max_model_len: 4096 + gpu_memory_utilization: 0.75 + +train_dataset: + batch_size: 64 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + max_length: 512 + +valid_dataset: + batch_size: 128 + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false \ No newline at end of file From 87dca2ad453d048f302d37806a45f67fa25661c2 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 13 Apr 2026 00:53:10 +0800 Subject: [PATCH 005/112] feat: fix --- areal/engine/megatron_engine_r3_patch.py | 71 ++++-- areal/engine/router_replay_patch.py | 84 +++++- areal/trainer/ppo/actor.py | 63 +---- areal/trainer/ppo/actor_r3_patch.py | 241 +++++++++++++++++- ...moonlight_16b_a3b_gsm8k_grpo_megatron.yaml | 9 + 5 files changed, 383 insertions(+), 85 deletions(-) diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py index ab016a2468..6c02e96b1b 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -285,6 +285,7 @@ def _r3_forward_backward_batch( If the data does not contain ``routed_experts``, delegates directly to the original method with zero overhead. """ + from areal.engine.router_replay_patch import RouterReplay from areal.engine.router_replay_utils import ( clear_router_replay, setup_per_microbatch_replay_forward, @@ -338,14 +339,39 @@ def _r3_forward_backward_batch( self._r3_mb_counter = 0 model_config = self.tf_config + # ------------------------------------------------------------------ + # 2b. Set PP size on RouterReplay for backward recompute remapping. + # Reset call counters so each forward_backward_batch starts clean. + # ------------------------------------------------------------------ + try: + from megatron.core import parallel_state as mpu + pp_size = mpu.get_pipeline_model_parallel_world_size() + except Exception: + pp_size = getattr(model_config, "pipeline_model_parallel_size", 1) + RouterReplay.pp_size = pp_size + RouterReplay.reset_all_call_counters() + logger.debug( + "[R3] Set RouterReplay.pp_size=%d, reset all call counters " + "(%d instances).", + pp_size, + len(RouterReplay.router_instances), + ) + # ------------------------------------------------------------------ # 3. 
Wrap the MicroBatchList iterator on the INSTANCE level # # The iterator injects R3 setup before each micro-batch's forward. - # Because Megatron's 1F1B schedule interleaves forward and backward, - # we rely on set_target_indices() appending to replay_backward_list - # so that activation-checkpoint recompute in backward has the data - # it needs (Problem 3 fix). + # The call-counter mechanism in RouterReplay handles backward + # recompute correctly for both PP=1 and PP>1 scenarios: + # + # - PP=1 (forward_backward_no_pipelining): all forwards first, + # then all backwards in REVERSE order. The call counter tracks + # which MB's data to use: forward calls [0..N-1], backward + # recompute calls [N..2N-1] remapped as [N-1..0]. + # + # - PP>1 (1F1B): interleaved forward-backward. Each MB's backward + # follows its forward closely. The call counter handles this + # with same-order remapping. # ------------------------------------------------------------------ engine_ref = self @@ -375,16 +401,23 @@ def __next__(self): attn_mask = _get_attention_mask_for_mb(mb_item) if attn_mask is not None: + # Truncate `re` sequence length to match `attn_mask` + # Since data is right-padded (real tokens are on the left), + # slicing the first max_seqlen elements safely removes extra padding. + # This prevents a shape mismatch when the micro-batch has a smaller max_seqlen + # than the full batch (which occurs when sequences are packed). + re_matched = re[:, :attn_mask.shape[1], ...] + try: setup_per_microbatch_replay_forward( - re.to(attn_mask.device), + re_matched.to(attn_mask.device), attn_mask, model_config, ) logger.debug( "[R3] Replay setup OK for micro-batch %d: " "routed_experts=%s, attn_mask=%s.", - idx, re.shape, attn_mask.shape, + idx, re_matched.shape, attn_mask.shape, ) except Exception: logger.warning( @@ -412,22 +445,20 @@ def _r3_iter(mb_list_self): try: # Megatron's forward_backward_func (e.g. 1F1B schedule) internally - # interleaves forward and backward for each micro-batch. We do NOT - # call setup_per_microbatch_replay_backward() after this function - # returns -- by then backward is already done. + # interleaves forward and backward for each micro-batch. # - # Instead, the backward replay is handled by the RouterReplay design: - # - set_target_indices() in setup_per_microbatch_replay_forward() - # appends indices to replay_backward_list for each layer. - # - When activation checkpointing triggers recompute during backward, - # the patched routing function checks router_replay_action. - # - We set REPLAY_FORWARD before each micro-batch's forward. - # During the same micro-batch's backward recompute, the - # replay_backward_list is consumed via pop(0) in FIFO order. + # The call-counter mechanism in RouterReplay handles backward + # recompute (activation checkpointing) correctly: + # - Each RouterReplay instance tracks how many times its routing + # function is called via routing_call_count. + # - Forward calls (call_idx < N) use routing_data_list[call_idx]. + # - Backward recompute calls (call_idx >= N) are remapped based + # on pp_size: + # * PP=1: reverse order (N-1-recompute_idx) + # * PP>1: same order (recompute_idx) # - # For the non-activation-checkpoint case, backward simply uses - # autograd on the forward's recorded computation graph, so no - # re-routing occurs and replay is not needed. + # This replaces the previous design that relied on target_topk_idx + # (which got overwritten by subsequent MBs in PP=1 multi-MB case). 
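+        # Worked example (hypothetical, n_mbs=3): forward calls 0..2 read
+        # routing_data_list[0], [1], [2]; with PP=1 the recompute calls
+        # 3..5 map back to [2], [1], [0], while with PP>1 (1F1B) they map
+        # to [0], [1], [2].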
self._r3_original_forward_backward_batch( mb_list, process_output_fn, forward_only=forward_only ) diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py index 4d135bb02a..a859fe4853 100644 --- a/areal/engine/router_replay_patch.py +++ b/areal/engine/router_replay_patch.py @@ -112,6 +112,10 @@ class RouterReplay: # Class-level list of all router instances (one per MoE layer). router_instances: list["RouterReplay"] = [] + # Class-level pipeline parallelism size for backward remapping. + # Set by the engine patch before forward_backward_func. + pp_size: int = 1 + # ------------------------------------------------------------------ # Class-level (static) helpers # ------------------------------------------------------------------ @@ -157,6 +161,12 @@ def clear_global_router_replay_action() -> None: for r in RouterReplay.router_instances: r.clear_router_replay_action() + @staticmethod + def reset_all_call_counters() -> None: + """Reset call counters on all instances (call before forward_backward).""" + for r in RouterReplay.router_instances: + r.routing_call_count = 0 + # ------------------------------------------------------------------ # Instance methods # ------------------------------------------------------------------ @@ -165,11 +175,18 @@ def __init__(self) -> None: self.target_topk_idx: torch.Tensor | None = None self.recorded_topk_idx: torch.Tensor | None = None self.router_replay_action: RouterReplayAction | None = None + # Legacy list kept for API compatibility; no longer used for + # backward replay (call-counter approach replaces it). self.replay_backward_list: list[torch.Tensor] = [] + # Call-counter mechanism for multi-MB activation checkpointing. + self.routing_data_list: list[torch.Tensor] = [] + self.routing_call_count: int = 0 RouterReplay.router_instances.append(self) def set_target_indices(self, topk_indices: torch.Tensor) -> None: self.target_topk_idx = topk_indices + self.routing_data_list.append(topk_indices) + # Keep replay_backward_list in sync for backward compatibility. self.replay_backward_list.append(topk_indices) def get_recorded_indices(self) -> torch.Tensor | None: @@ -182,6 +199,8 @@ def clear_indices(self) -> None: self.recorded_topk_idx = None self.target_topk_idx = None self.replay_backward_list = [] + self.routing_data_list = [] + self.routing_call_count = 0 def set_router_replay_action(self, action: RouterReplayAction) -> None: self.router_replay_action = action @@ -189,6 +208,54 @@ def set_router_replay_action(self, action: RouterReplayAction) -> None: def clear_router_replay_action(self) -> None: self.router_replay_action = None + def get_replay_indices_by_call_count(self) -> torch.Tensor | None: + """Return the correct replay indices based on the current call count. + + Uses the call-counter mechanism to handle multi-MB + activation + checkpointing correctly. + + Returns: + The routing indices tensor for the current call, or None if + no data is available. 
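+
+        Example (``n_mbs=2``, ``pp_size=1``): four successive calls return
+        ``routing_data_list[0]``, ``[1]``, ``[1]``, ``[0]``.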
+ """ + n_mbs = len(self.routing_data_list) + if n_mbs == 0: + return None + + call_idx = self.routing_call_count + self.routing_call_count += 1 + + if call_idx < n_mbs: + # Forward pass: use indices in order [0, 1, ..., N-1] + data_idx = call_idx + phase = "forward" + else: + # Backward recompute: remap based on pipeline schedule + recompute_idx = call_idx - n_mbs + if RouterReplay.pp_size <= 1: + # PP=1: forward_backward_no_pipelining + # Backward order is REVERSE: MB_{N-1}, MB_{N-2}, ..., MB_0 + data_idx = n_mbs - 1 - recompute_idx + else: + # PP>1: 1F1B schedule + # Backward order is SAME as forward: MB_0, MB_1, ..., MB_{N-1} + data_idx = recompute_idx + phase = "backward_recompute" + + # Clamp to valid range as safety measure + data_idx = max(0, min(data_idx, n_mbs - 1)) + + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[R3] get_replay_indices_by_call_count: " + "call_idx=%d, n_mbs=%d, pp_size=%d, phase=%s, " + "data_idx=%d, shape=%s.", + call_idx, n_mbs, RouterReplay.pp_size, phase, + data_idx, self.routing_data_list[data_idx].shape, + ) + + return self.routing_data_list[data_idx] + # =================================================================== # Patched routing implementation @@ -240,17 +307,26 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): return probs, top_indices elif routing_action == RouterReplayAction.REPLAY_FORWARD: - if router_replay is None or router_replay.target_topk_idx is None: + # Use call-counter approach for multi-MB + activation checkpointing. + # This handles both forward and backward recompute correctly: + # - Forward calls (call_idx < N): use routing_data_list[call_idx] + # - Backward recompute (call_idx >= N): remapped by PP schedule + top_indices = router_replay.get_replay_indices_by_call_count() + if top_indices is None: logger.warning( - "[R3] REPLAY_FORWARD: no target indices available, " - "falling back to normal routing." + "[R3] REPLAY_FORWARD: no replay indices available " + "(routing_data_list empty), falling back to normal routing." ) return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) - top_indices = router_replay.target_topk_idx.to(scores.device) + top_indices = top_indices.to(scores.device) probs = scores.gather(1, top_indices) return probs, top_indices elif routing_action == RouterReplayAction.REPLAY_BACKWARD: + # Legacy backward mode using replay_backward_list.pop(0). + # Kept for backward compatibility but no longer the primary + # mechanism. The call-counter approach in REPLAY_FORWARD + # handles both forward and backward recompute. if router_replay is None or not router_replay.replay_backward_list: logger.warning( "[R3] REPLAY_BACKWARD: no backward indices available, " diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index 6443cdf68d..515790362a 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -34,53 +34,6 @@ logger = logging.getLogger("PPOActor") - -def _split_routed_experts_for_minibatches( - routed_experts: torch.Tensor, - mb_inputs, -) -> list: - """Split routed_experts tensor per mini-batch using forward_indices. - - This handles R3 Problem 1: routed_experts is 4D and cannot be split - by split_padded_tensor_dict_into_mb_list (which only splits 2D tensors - with numel == bs * max_seqlen). - - Args: - routed_experts: ``(bs, seq_len, num_moe_layers, topk)`` - mb_inputs: ``MicroBatchList`` with ``forward_indices``. - - Returns: - List of tensors, one per mini-batch. 
- """ - forward_indices = mb_inputs.forward_indices - n_mbs = len(mb_inputs.mbs) - - if forward_indices is None: - # No reordering -- split evenly - bs = routed_experts.shape[0] - chunk = bs // n_mbs - return [routed_experts[i * chunk : (i + 1) * chunk] for i in range(n_mbs)] - - # Reorder by forward_indices - reordered = routed_experts[forward_indices] - - # Determine per-MB sample counts from group_lens and attention_mask info. - # Since we are before pack_tensor_dict, mbs still have attention_mask. - result = [] - offset = 0 - for i, mb_dict in enumerate(mb_inputs.mbs): - if isinstance(mb_dict, dict) and "attention_mask" in mb_dict: - n_samples = mb_dict["attention_mask"].shape[0] - elif isinstance(mb_dict, dict) and "cu_seqlens" in mb_dict: - n_samples = len(mb_dict["cu_seqlens"]) - 1 - else: - n_samples = routed_experts.shape[0] // n_mbs - result.append(reordered[offset : offset + n_samples]) - offset += n_samples - - return result - - class PPOActor: def __init__(self, config: PPOActorConfig, engine: TrainEngine): self.config = config @@ -368,22 +321,16 @@ def _ppo_update(self, data: dict[str, Any]) -> None: data.pop(key, None) # NOTE: calling engine.train() is critical to enabling gradient checkpointing self.engine.train() - _r3_routed_experts = data.pop("routed_experts", None) - + # NOTE: routed_experts is intentionally left in data. + # It is a 4D tensor and will be put into not_to_split by + # split_padded_tensor_dict_into_mb_list. The R3 engine patch + # (megatron_engine_r3_patch.py) will extract it from mb_inputs.data + # and handle the per-microbatch splitting. mb_inputs = split_padded_tensor_dict_into_mb_list( data, mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches), ) - # R3: Manually split routed_experts and inject into each mini-batch. - if _r3_routed_experts is not None: - _r3_split = _split_routed_experts_for_minibatches( - _r3_routed_experts, mb_inputs - ) - for i, mb_dict in enumerate(mb_inputs.mbs): - if _r3_split[i] is not None: - mb_dict["routed_experts"] = _r3_split[i] - with stats_tracker.scope("update"): # Get current version for proximal approximation metrics current_version = self.engine.get_version() diff --git a/areal/trainer/ppo/actor_r3_patch.py b/areal/trainer/ppo/actor_r3_patch.py index e70408f14d..bd73c26b00 100644 --- a/areal/trainer/ppo/actor_r3_patch.py +++ b/areal/trainer/ppo/actor_r3_patch.py @@ -15,10 +15,23 @@ R3 metrics and logging helpers for the PPO actor. When Router Replay (R3) is enabled, these functions compute and log -statistics about the replayed routing decisions, such as: +statistics about the replayed routing decisions. +The key effectiveness metrics are: -- The fraction of micro-batches that carried replay data. -- Per-step summary of routing shapes and data types. +1. **Router Agreement Rate** -- fraction of tokens where training routing + matches the replayed (inference-time) routing. Measures how effectively + R3 forces routing alignment. + +2. **Per-Layer Routing Entropy** -- Shannon entropy of the expert probability + distribution per MoE layer. Lower entropy under replay indicates stronger + routing concentration (expected when replay overrides natural routing). + +3. **Expert Utilization Balance** -- standard deviation of per-expert token + counts normalised by the mean. High balance (low std/mean) indicates + evenly distributed expert usage; replay may skew this. + +4. **Routing Data Coverage** -- fraction of micro-batches that carried valid + replay data. Should be 1.0 in a healthy R3 run. 
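+
+For intuition (hypothetical numbers): with 64 experts, uniform routing has
+entropy ``log2(64) = 6`` bits -- the maximum used to normalise metric 2 --
+while a layer that always selects the same expert has entropy 0 and a
+top-1 concentration of 1.0.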
All logging uses the ``stats_tracker`` infrastructure so that metrics appear in the same TensorBoard / WandB dashboards as other PPO stats. @@ -64,10 +77,232 @@ def log_r3_data_stats( r3_dtype_bytes=re.element_size(), r3_max_expert_id=re.max().item() if re.numel() > 0 else 0, ) + + # Compute R3 effectiveness metrics + _log_r3_effectiveness_metrics(re) else: stats_tracker.scalar(r3_present=0) +def _log_r3_effectiveness_metrics( + routed_experts: torch.Tensor, +) -> None: + """Compute and log R3 effectiveness metrics following SkyRL's approach. + + These metrics help assess whether Router Replay is working correctly + and how it affects the MoE routing distribution. + + Args: + routed_experts: ``(bs, seq_len, num_moe_layers, topk)`` int tensor + containing the expert indices from inference. + """ + if routed_experts.dim() != 4 or routed_experts.numel() == 0: + return + + bs, seq_len, num_moe_layers, topk = routed_experts.shape + + try: + # --- Metric 1: Per-Layer Routing Entropy --- + # Measures the diversity of expert assignments per layer. + # Lower entropy = more concentrated routing. + # Under R3, this reflects the inference-time routing distribution. + _log_per_layer_routing_entropy(routed_experts, num_moe_layers, topk) + + # --- Metric 2: Expert Utilization Balance --- + # Measures how evenly tokens are distributed across experts. + # Coefficient of variation (std/mean) -- lower = more balanced. + _log_expert_utilization_balance(routed_experts, num_moe_layers) + + # --- Metric 3: Routing Data Coverage --- + # Fraction of (batch, layer) combinations with non-zero routing data. + _log_routing_data_coverage(routed_experts, bs, num_moe_layers) + + # --- Metric 4: Top-1 Expert Concentration --- + # How often the most popular expert is selected (per layer). + _log_top1_expert_concentration(routed_experts, num_moe_layers) + + except Exception: + logger.warning( + "[R3] Failed to compute R3 effectiveness metrics.", + exc_info=True, + ) + + +def _log_per_layer_routing_entropy( + routed_experts: torch.Tensor, + num_moe_layers: int, + topk: int, +) -> None: + """Log per-layer Shannon entropy of expert routing distribution. + + For each MoE layer, computes the probability distribution over experts + (from the replay data) and its Shannon entropy. Reports mean, min, + max across layers. 
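+
+    Entropy per layer: ``H = -sum_e p_e * log2(p_e)``, where ``p_e`` is the
+    empirical fraction of (token, slot) assignments routed to expert ``e``.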
+ """ + bs, seq_len = routed_experts.shape[:2] + # Flatten batch and seq dimensions + flat = routed_experts.view(-1, num_moe_layers, topk) # (bs*seq_len, L, K) + num_tokens = flat.shape[0] + + if num_tokens == 0: + return + + # Determine number of experts from max index + num_experts = int(routed_experts.max().item()) + 1 + if num_experts <= 0: + return + + layer_entropies = [] + for layer_idx in range(num_moe_layers): + # Count expert occurrences for this layer across all tokens and topk slots + expert_ids = flat[:, layer_idx, :].reshape(-1).long() + # Filter out padding (expert_id == 0 might be valid, but -1 or very large is not) + valid_mask = (expert_ids >= 0) & (expert_ids < num_experts) + expert_ids = expert_ids[valid_mask] + if expert_ids.numel() == 0: + continue + + counts = torch.bincount(expert_ids, minlength=num_experts).float() + probs = counts / counts.sum() + # Shannon entropy: -sum(p * log(p)), with 0*log(0) = 0 + log_probs = torch.where(probs > 0, torch.log2(probs), torch.zeros_like(probs)) + entropy = -(probs * log_probs).sum().item() + layer_entropies.append(entropy) + + if layer_entropies: + mean_entropy = sum(layer_entropies) / len(layer_entropies) + min_entropy = min(layer_entropies) + max_entropy = max(layer_entropies) + # Maximum possible entropy for reference + max_possible = torch.log2(torch.tensor(float(num_experts))).item() + + stats_tracker.scalar( + r3_routing_entropy_mean=mean_entropy, + r3_routing_entropy_min=min_entropy, + r3_routing_entropy_max=max_entropy, + r3_routing_entropy_normalised=mean_entropy / max_possible if max_possible > 0 else 0, + r3_num_experts=num_experts, + ) + + +def _log_expert_utilization_balance( + routed_experts: torch.Tensor, + num_moe_layers: int, +) -> None: + """Log expert utilization balance (coefficient of variation per layer). + + For each layer, compute the standard deviation of per-expert token + counts divided by the mean. Aggregate across layers. + """ + flat = routed_experts.view(-1, num_moe_layers, routed_experts.shape[-1]) + num_experts = int(routed_experts.max().item()) + 1 + if num_experts <= 1: + return + + layer_cv_values = [] + for layer_idx in range(num_moe_layers): + expert_ids = flat[:, layer_idx, :].reshape(-1).long() + valid_mask = (expert_ids >= 0) & (expert_ids < num_experts) + expert_ids = expert_ids[valid_mask] + if expert_ids.numel() == 0: + continue + + counts = torch.bincount(expert_ids, minlength=num_experts).float() + mean_count = counts.mean() + if mean_count > 0: + cv = counts.std() / mean_count + layer_cv_values.append(cv.item()) + + if layer_cv_values: + stats_tracker.scalar( + r3_expert_util_cv_mean=sum(layer_cv_values) / len(layer_cv_values), + r3_expert_util_cv_max=max(layer_cv_values), + r3_expert_util_cv_min=min(layer_cv_values), + ) + + +def _log_routing_data_coverage( + routed_experts: torch.Tensor, + bs: int, + num_moe_layers: int, +) -> None: + """Log fraction of (sample, layer) with non-zero routing data.""" + # Check each sample x layer has at least one non-zero expert id + # routed_experts: (bs, seq_len, num_moe_layers, topk) + # Sum over seq_len and topk dimensions + has_data = (routed_experts.sum(dim=(1, 3)) > 0).float() # (bs, num_moe_layers) + coverage = has_data.mean().item() + stats_tracker.scalar(r3_routing_data_coverage=coverage) + + +def _log_top1_expert_concentration( + routed_experts: torch.Tensor, + num_moe_layers: int, +) -> None: + """Log how concentrated routing is on the most popular expert per layer. 
+ + For each layer, the concentration ratio = count(most_popular_expert) / total_count. + High concentration suggests the replay data has strong routing preferences. + """ + flat = routed_experts.view(-1, num_moe_layers, routed_experts.shape[-1]) + num_experts = int(routed_experts.max().item()) + 1 + if num_experts <= 0: + return + + layer_concentrations = [] + for layer_idx in range(num_moe_layers): + expert_ids = flat[:, layer_idx, :].reshape(-1).long() + valid_mask = (expert_ids >= 0) & (expert_ids < num_experts) + expert_ids = expert_ids[valid_mask] + if expert_ids.numel() == 0: + continue + + counts = torch.bincount(expert_ids, minlength=num_experts) + max_count = counts.max().item() + total = counts.sum().item() + if total > 0: + layer_concentrations.append(max_count / total) + + if layer_concentrations: + stats_tracker.scalar( + r3_top1_expert_concentration_mean=sum(layer_concentrations) / len(layer_concentrations), + r3_top1_expert_concentration_max=max(layer_concentrations), + ) + + +def compute_router_agreement_rate( + replay_indices: torch.Tensor, + actual_indices: torch.Tensor, +) -> float: + """Compute the fraction of tokens where actual routing matches replay target. + + This is the KEY R3 effectiveness metric: if R3 is working correctly, + agreement should be very close to 1.0 (training router produces the same + assignments as the replayed inference routing). + + Args: + replay_indices: ``(num_tokens, topk)`` target expert indices from replay. + actual_indices: ``(num_tokens, topk)`` actual expert indices from training. + + Returns: + Agreement rate in [0, 1]. Returns -1.0 if inputs are invalid. + """ + if replay_indices is None or actual_indices is None: + return -1.0 + if replay_indices.shape != actual_indices.shape: + logger.warning( + "[R3] Agreement rate: shape mismatch replay=%s vs actual=%s.", + replay_indices.shape, actual_indices.shape, + ) + return -1.0 + + # Sort topk indices per token to handle different ordering + replay_sorted = replay_indices.sort(dim=-1).values + actual_sorted = actual_indices.sort(dim=-1).values + matches = (replay_sorted == actual_sorted).all(dim=-1).float() + return matches.mean().item() + + def strip_routed_experts_before_loss( data: dict[str, Any], ) -> dict[str, Any]: diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml index 2bb29ba89a..e6d4f1c0fc 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml @@ -30,6 +30,11 @@ rollout: fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} dump_to_file: true + # R3: Enable returning routed expert assignments from rollout inference. + # This triggers the entire Router Replay pipeline: SGLang returns per-token + # expert indices, which are then replayed during training to eliminate + # train/inference routing mismatch in MoE models. + return_routed_experts: true gconfig: n_samples: 4 @@ -127,6 +132,10 @@ sglang: context_length: 2048 mem_fraction_static: 0.2 attention_backend: triton + # R3: Enable SGLang to capture and return per-token routed expert indices + # during inference. This is auto-set by rl_trainer when + # rollout.return_routed_experts=true, but explicitly declared here for clarity. 
+  enable_return_routed_experts: true
 
 vllm:
   model: ${actor.path}

From 6198aeabb43e0314753c43672535d4618f7af14a Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Mon, 13 Apr 2026 13:49:21 +0800
Subject: [PATCH 006/112] feat: fix

---
 areal/engine/megatron_engine_r3_patch.py | 153 ++++++++++++++--
 areal/trainer/ppo/actor.py               |  32 +++-
 areal/trainer/ppo/actor_r3_patch.py      | 214 ++++++++++++++++++++++-
 areal/trainer/rl_trainer.py              |  18 +-
 4 files changed, 384 insertions(+), 33 deletions(-)

diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py
index 6c02e96b1b..db46213d32 100644
--- a/areal/engine/megatron_engine_r3_patch.py
+++ b/areal/engine/megatron_engine_r3_patch.py
@@ -77,6 +77,7 @@ def patch_megatron_engine_for_r3(
     # Mark and save original
     engine._r3_enabled = True
     engine._r3_original_forward_backward_batch = engine.forward_backward_batch
+    engine._r3_pending_routed_experts = None
 
     # Bind the wrapped method
     engine.forward_backward_batch = types.MethodType(
@@ -124,6 +125,95 @@ def _reconstruct_attention_mask_from_cu_seqlens(
     return mask
 
 
+# ===================================================================
+# Problem 2 fix: Align routed_experts seq dim to attention_mask
+# ===================================================================
+
+
+def _align_routed_experts_to_mask(
+    routed_experts: torch.Tensor,
+    attention_mask: torch.Tensor,
+) -> torch.Tensor:
+    """Align ``routed_experts`` seq dimension to match ``attention_mask``.
+
+    **Problem 2 Fix**: After pack_tensor_dict + pad_mb_list, the
+    cu_seqlens-reconstructed ``attention_mask`` has ``mb_max_seqlen``
+    which may be SMALLER than ``routed_experts``' seq dimension
+    (``batch_max_seqlen``). The rollout-produced ``routed_experts`` is
+    LEFT-padded (padding on the left, real tokens on the right), while
+    the post-pack ``attention_mask`` is LEFT-aligned (real tokens first,
+    no left-padding).
+
+    This function extracts the right-most ``actual_len`` tokens from each
+    sample's left-padded ``routed_experts`` and places them at the
+    left-aligned positions expected by ``attention_mask``.
+
+    Args:
+        routed_experts: ``(bs, batch_max_seqlen, num_moe_layers, topk)``
+            Left-padded routing indices from rollout.
+        attention_mask: ``(bs, mb_max_seqlen)``
+            Left-aligned mask (1 for real tokens, 0 for padding).
+
+    Returns:
+        ``(bs, mb_max_seqlen, num_moe_layers, topk)`` aligned tensor.
+    """
+    bs, re_seqlen = routed_experts.shape[:2]
+    _, mask_seqlen = attention_mask.shape[:2]
+
+    if re_seqlen == mask_seqlen:
+        # No alignment needed
+        return routed_experts
+
+    if re_seqlen < mask_seqlen:
+        # Unlikely but possible if mask was padded beyond batch_max_seqlen.
+        # Right-pad routed_experts with zeros.
+        extra_dims = routed_experts.shape[2:]  # (num_moe_layers, topk)
+        padded = torch.zeros(
+            bs, mask_seqlen, *extra_dims,
+            dtype=routed_experts.dtype,
+            device=routed_experts.device,
+        )
+        padded[:, :re_seqlen] = routed_experts
+        logger.info(
+            "[R3] _align_routed_experts_to_mask: re_seqlen(%d) < mask_seqlen(%d), "
+            "right-padded routed_experts with zeros.",
+            re_seqlen, mask_seqlen,
+        )
+        return padded
+
+    # re_seqlen > mask_seqlen: the common case.
+    # routed_experts is LEFT-padded: real tokens are at the RIGHT end.
+    # attention_mask is LEFT-aligned: real tokens are at the LEFT end.
+    # For each sample, extract the rightmost `actual_len` tokens from
+    # routed_experts and place them at positions [0, actual_len) in output.
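The left-padded-to-left-aligned extraction described in the comment block above can be reproduced standalone. A minimal sketch with toy shapes; the names below are illustrative only and are not AReaL APIs:

    import torch

    bs, re_seqlen, mask_seqlen, n_layers, topk = 2, 6, 4, 1, 2

    # Left-padded routing indices: real tokens sit at the RIGHT end.
    routed_experts = torch.zeros(bs, re_seqlen, n_layers, topk, dtype=torch.int32)
    routed_experts[0, 3:] = 7  # sample 0 has 3 real tokens
    routed_experts[1, 2:] = 9  # sample 1 has 4 real tokens

    # Left-aligned mask: real tokens sit at the LEFT end.
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])

    aligned = torch.zeros(bs, mask_seqlen, n_layers, topk, dtype=torch.int32)
    for i, actual_len in enumerate(attention_mask.sum(dim=1).tolist()):
        n = min(actual_len, mask_seqlen)
        src_start = re_seqlen - actual_len
        aligned[i, :n] = routed_experts[i, src_start : src_start + n]

    assert (aligned[0, :3] == 7).all() and (aligned[1, :4] == 9).all()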
+ extra_dims = routed_experts.shape[2:] # (num_moe_layers, topk) + aligned = torch.zeros( + bs, mask_seqlen, *extra_dims, + dtype=routed_experts.dtype, + device=routed_experts.device, + ) + + seq_lens = attention_mask.sum(dim=1).long() # actual lengths per sample + for i in range(bs): + actual_len = int(seq_lens[i].item()) + if actual_len <= 0: + continue + # Take rightmost actual_len tokens from left-padded routed_experts + src_start = re_seqlen - actual_len + n = min(actual_len, mask_seqlen) + aligned[i, :n] = routed_experts[i, src_start : src_start + n] + + logger.info( + "[R3] _align_routed_experts_to_mask: re_seqlen=%d -> mask_seqlen=%d, " + "bs=%d, seq_lens=%s (aligned left-padded RE to left-aligned mask).", + re_seqlen, + mask_seqlen, + bs, + seq_lens.tolist()[:8], + ) + return aligned + + # =================================================================== # routed_experts splitting (Problem 7 fix: robust sample-count inference) # =================================================================== @@ -284,6 +374,14 @@ def _r3_forward_backward_batch( If the data does not contain ``routed_experts``, delegates directly to the original method with zero overhead. + + **Problem 1 Fix**: Retrieves routed_experts from engine side-channel + (``self._r3_pending_routed_experts``) set by actor._ppo_update FIRST, + falling back to ``mb_list.data`` for backward compatibility. + + **Problem 2 Fix**: Before passing per-MB routed_experts to + ``setup_per_microbatch_replay_forward``, aligns the seq dimension + to match the attention_mask's seq dimension. """ from areal.engine.router_replay_patch import RouterReplay from areal.engine.router_replay_utils import ( @@ -292,16 +390,36 @@ def _r3_forward_backward_batch( ) # ------------------------------------------------------------------ - # 1. Extract routed_experts from the batch data and split per-MB. - # routed_experts is (bs, max_seqlen, num_moe_layers, topk) and - # does NOT get split by split_padded_tensor_dict_into_mb_list. + # 1. Retrieve routed_experts. + # Problem 1 Fix: Prefer side-channel from actor._ppo_update, which + # bypasses _prepare_mb_list/pack_tensor_dict entirely. + # Fall back to mb_list.data for backward compatibility. # ------------------------------------------------------------------ routed_experts_batch = None - if hasattr(mb_list, "data") and isinstance(mb_list.data, dict): - routed_experts_batch = mb_list.data.pop("routed_experts", None) + + # Strategy A: Side-channel (Problem 1 fix -- preferred path) + if hasattr(self, '_r3_pending_routed_experts') and self._r3_pending_routed_experts is not None: + routed_experts_batch = self._r3_pending_routed_experts + self._r3_pending_routed_experts = None # Consume it + logger.info( + "[R3] Retrieved routed_experts from engine side-channel: shape=%s.", + routed_experts_batch.shape, + ) + + # Strategy B: Legacy path from mb_list.data (backward compatibility) + if routed_experts_batch is None: + if hasattr(mb_list, "data") and isinstance(mb_list.data, dict): + routed_experts_batch = mb_list.data.pop("routed_experts", None) + if routed_experts_batch is not None: + logger.info( + "[R3] Retrieved routed_experts from mb_list.data (legacy path): " + "shape=%s.", + routed_experts_batch.shape, + ) # Also clean from mbs and padded_mbs to avoid confusing downstream code. - # Problem 1: these would contain the un-split full tensor via not_to_split broadcast. + # Problem 1: these would contain the un-split full tensor via not_to_split broadcast, + # or corrupted 3D tensors from pack_tensor_dict. 
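The side-channel retrieval in Strategy A above is, at its core, a consume-once slot on the engine object. A toy sketch of the pattern with hypothetical names (ToyEngine stands in for the patched engine and is not the real MegatronEngine interface):

    class ToyEngine:
        def __init__(self):
            self._r3_pending = None  # side-channel slot, written by the caller

        def forward_backward(self, batch):
            # Consume the slot exactly once so stale routing data can
            # never leak into the next mini-batch.
            extra, self._r3_pending = self._r3_pending, None
            return batch, extra

    engine = ToyEngine()
    engine._r3_pending = "routed_experts_for_mb_0"
    print(engine.forward_backward({"x": 1}))  # ({'x': 1}, 'routed_experts_for_mb_0')
    print(engine.forward_backward({"x": 2}))  # ({'x': 2}, None)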
for mb_dict in mb_list.mbs: if isinstance(mb_dict, dict): mb_dict.pop("routed_experts", None) @@ -312,8 +430,8 @@ def _r3_forward_backward_batch( if routed_experts_batch is None: logger.debug( - "[R3] No routed_experts in batch data; using original " - "forward_backward_batch." + "[R3] No routed_experts found (neither side-channel nor mb_list.data); " + "using original forward_backward_batch." ) return self._r3_original_forward_backward_batch( mb_list, process_output_fn, forward_only=forward_only @@ -401,23 +519,22 @@ def __next__(self): attn_mask = _get_attention_mask_for_mb(mb_item) if attn_mask is not None: - # Truncate `re` sequence length to match `attn_mask` - # Since data is right-padded (real tokens are on the left), - # slicing the first max_seqlen elements safely removes extra padding. - # This prevents a shape mismatch when the micro-batch has a smaller max_seqlen - # than the full batch (which occurs when sequences are packed). - re_matched = re[:, :attn_mask.shape[1], ...] - try: + # Problem 2 fix: Align routed_experts seq dimension + # to match attention_mask's seq dimension. + # routed_experts is left-padded (batch_max_seqlen), + # attn_mask is left-aligned (mb_max_seqlen). + aligned_re = _align_routed_experts_to_mask(re, attn_mask) + setup_per_microbatch_replay_forward( - re_matched.to(attn_mask.device), + aligned_re.to(attn_mask.device), attn_mask, model_config, ) logger.debug( "[R3] Replay setup OK for micro-batch %d: " - "routed_experts=%s, attn_mask=%s.", - idx, re_matched.shape, attn_mask.shape, + "original_re=%s, aligned_re=%s, attn_mask=%s.", + idx, re.shape, aligned_re.shape, attn_mask.shape, ) except Exception: logger.warning( diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index 515790362a..00eb9c673c 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -321,21 +321,41 @@ def _ppo_update(self, data: dict[str, Any]) -> None: data.pop(key, None) # NOTE: calling engine.train() is critical to enabling gradient checkpointing self.engine.train() - # NOTE: routed_experts is intentionally left in data. - # It is a 4D tensor and will be put into not_to_split by - # split_padded_tensor_dict_into_mb_list. The R3 engine patch - # (megatron_engine_r3_patch.py) will extract it from mb_inputs.data - # and handle the per-microbatch splitting. + _r3_routed_experts = data.pop("routed_experts", None) + mb_inputs = split_padded_tensor_dict_into_mb_list( data, mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches), ) + # R3: Split routed_experts per mini-batch for side-channel delivery. + _r3_split = None + if _r3_routed_experts is not None: + from areal.trainer.ppo.actor_r3_patch import split_routed_experts_for_minibatches + _r3_split = split_routed_experts_for_minibatches( + _r3_routed_experts, mb_inputs + ) + logger.info( + "[R3] Split routed_experts for %d mini-batches via side-channel " + "(shapes: %s).", + len(mb_inputs.mbs), + [s.shape if s is not None else None for s in _r3_split], + ) + with stats_tracker.scope("update"): # Get current version for proximal approximation metrics current_version = self.engine.get_version() - for mb in mb_inputs.mbs: + for i, mb in enumerate(mb_inputs.mbs): + # deliver routed_experts via engine side-channel + # to bypass pack_tensor_dict corruption and ensure correct per-mini-batch data. 
+ if _r3_split is not None and hasattr(self.engine, '_r3_enabled'): + self.engine._r3_pending_routed_experts = ( + _r3_split[i] + if i < len(_r3_split) and _r3_split[i] is not None + else None + ) + train_stat = self.engine.train_batch( mb, loss_fn=functools.partial( diff --git a/areal/trainer/ppo/actor_r3_patch.py b/areal/trainer/ppo/actor_r3_patch.py index bd73c26b00..376a485c07 100644 --- a/areal/trainer/ppo/actor_r3_patch.py +++ b/areal/trainer/ppo/actor_r3_patch.py @@ -12,11 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -R3 metrics and logging helpers for the PPO actor. +MoE routing metrics and R3 logging helpers for the PPO actor. -When Router Replay (R3) is enabled, these functions compute and log -statistics about the replayed routing decisions. -The key effectiveness metrics are: +Provides two categories of metrics: + +1. **R3 data stats** (``log_r3_data_stats``): Summary of the routed_experts + tensor shape, dtype, and basic coverage info. Logged when R3 is enabled. + +2. **MoE routing effectiveness metrics** (``log_moe_routing_metrics``): + SkyRL-style routing quality indicators that are useful for ANY MoE model, + regardless of whether R3 is enabled. These include: + - Routing entropy (per-layer and aggregated) + - Expert utilization balance (std dev of expert load) + - Data coverage ratio (fraction of samples with valid routing data) + - Top-1 expert concentration (how much traffic goes to most-used expert) + - Expert diversity (number of unique experts used per token) + +The key R3-specific effectiveness metrics are: 1. **Router Agreement Rate** -- fraction of tokens where training routing matches the replayed (inference-time) routing. Measures how effectively @@ -84,6 +96,81 @@ def log_r3_data_stats( stats_tracker.scalar(r3_present=0) +def split_routed_experts_for_minibatches( + routed_experts: torch.Tensor, + mb_list, +) -> list[torch.Tensor | None]: + """Split ``routed_experts`` tensor for actor-level mini-batches. + + This handles the Level-1 split (actor._ppo_update splits into + ppo_n_minibatches). The tensor is reordered by ``forward_indices`` + and then sliced according to each mini-batch's sample count. + + Args: + routed_experts: ``(bs, seq_len, num_moe_layers, topk)`` full batch tensor. + mb_list: ``MicroBatchList`` from ``split_padded_tensor_dict_into_mb_list``. + + Returns: + List of tensors, one per mini-batch, each of shape + ``(mini_bs, seq_len, num_moe_layers, topk)``. 
+ """ + if routed_experts is None: + return [None] * len(mb_list) + + forward_indices = mb_list.forward_indices + n_mbs = len(mb_list) + + if forward_indices is None: + # No reordering -- just split evenly + bs = routed_experts.shape[0] + chunk = bs // n_mbs + result = [routed_experts[i * chunk : (i + 1) * chunk] for i in range(n_mbs)] + logger.debug( + "[R3] split_routed_experts_for_minibatches: no forward_indices, " + "split %d samples evenly into %d chunks of %d.", + bs, n_mbs, chunk, + ) + return result + + # Reorder by forward_indices (sample-level reordering) + reordered = routed_experts[forward_indices] + + # Determine number of samples per mini-batch from mbs dicts + result = [] + offset = 0 + for i, mb_dict in enumerate(mb_list.mbs): + n_samples = _infer_mb_sample_count_from_dict( + mb_dict, routed_experts.shape[0], n_mbs + ) + result.append(reordered[offset : offset + n_samples]) + offset += n_samples + + logger.debug( + "[R3] split_routed_experts_for_minibatches: split %d samples into " + "%d mini-batches with sizes %s.", + routed_experts.shape[0], + n_mbs, + [r.shape[0] for r in result], + ) + return result + + +def _infer_mb_sample_count_from_dict( + mb_dict: dict, + total_bs: int, + n_mbs: int, +) -> int: + """Infer sample count from a mini-batch dict.""" + if isinstance(mb_dict, dict): + attn = mb_dict.get("attention_mask") + if attn is not None and hasattr(attn, "shape"): + return attn.shape[0] + ids = mb_dict.get("input_ids") + if ids is not None and hasattr(ids, "shape"): + return ids.shape[0] + return total_bs // n_mbs + + def _log_r3_effectiveness_metrics( routed_experts: torch.Tensor, ) -> None: @@ -303,6 +390,125 @@ def compute_router_agreement_rate( return matches.mean().item() +def log_moe_routing_metrics( + data: dict[str, Any], + scope: str = "moe_routing", +) -> None: + """Log MoE routing effectiveness metrics for ANY MoE model. + + Computes routing quality indicators from the + ``routed_experts`` tensor. These metrics help diagnose routing + quality issues (expert collapse, load imbalance, etc.) and are + useful even without R3. + + Args: + data: Training data dict containing ``"routed_experts"`` + of shape ``(bs, seq_len, num_moe_layers, topk)``. + scope: Stats-tracker scope prefix. + """ + re = data.get("routed_experts") + if re is None: + return + if not isinstance(re, torch.Tensor) or re.dim() < 4: + return + + bs, seq_len, num_layers, topk = re.shape + attn_mask = data.get("attention_mask") + + with stats_tracker.scope(scope): + # ------------------------------------------------------------------ + # 1. Data coverage: fraction of samples with non-zero routing data + # ------------------------------------------------------------------ + has_routing = (re.sum(dim=(1, 2, 3)) != 0).float() + coverage = has_routing.mean().item() + stats_tracker.scalar(data_coverage=coverage) + + # ------------------------------------------------------------------ + # 2. 
Expert utilization and load balance (per-layer) + # ------------------------------------------------------------------ + if attn_mask is not None: + real_mask = attn_mask.bool() # (bs, seq_len) + else: + real_mask = torch.ones(bs, seq_len, dtype=torch.bool, device=re.device) + + # Expand mask for layers and topk: (bs, seq_len, 1, 1) + token_mask = real_mask.unsqueeze(-1).unsqueeze(-1).expand_as(re) + max_expert_id = re[token_mask].max().item() if token_mask.any() else 0 + num_experts = int(max_expert_id) + 1 + if num_experts < 2: + stats_tracker.scalar( + num_experts=num_experts, + insufficient_data=1, + ) + return + + entropy_sum = 0.0 + balance_sum = 0.0 + top1_concentration_sum = 0.0 + diversity_sum = 0.0 + valid_layers = 0 + + for layer_idx in range(num_layers): + layer_re = re[:, :, layer_idx, :] + layer_mask = real_mask.unsqueeze(-1).expand_as(layer_re) + valid_experts = layer_re[layer_mask] + + if valid_experts.numel() == 0: + continue + + valid_layers += 1 + + expert_counts = torch.bincount( + valid_experts.long().clamp(0, num_experts - 1), + minlength=num_experts, + ).float() + total_assignments = expert_counts.sum() + + if total_assignments == 0: + continue + + expert_probs = expert_counts / total_assignments + + # Routing entropy (normalized) + log_probs = torch.log(expert_probs + 1e-10) + entropy = -(expert_probs * log_probs).sum().item() + max_entropy = torch.log(torch.tensor(float(num_experts))).item() + normalized_entropy = entropy / max_entropy if max_entropy > 0 else 0.0 + entropy_sum += normalized_entropy + + # Expert load imbalance (CV) + load_std = expert_probs.std().item() + load_mean = expert_probs.mean().item() + balance = load_std / (load_mean + 1e-10) + balance_sum += balance + + # Top-1 expert concentration + top1_ratio = expert_probs.max().item() + top1_concentration_sum += top1_ratio + + # Expert diversity + unique_experts_used = (expert_counts > 0).sum().item() + diversity = unique_experts_used / num_experts + diversity_sum += diversity + + if valid_layers > 0: + stats_tracker.scalar( + num_experts=num_experts, + num_moe_layers=num_layers, + routing_entropy=entropy_sum / valid_layers, + expert_load_imbalance_cv=balance_sum / valid_layers, + top1_expert_concentration=top1_concentration_sum / valid_layers, + expert_diversity=diversity_sum / valid_layers, + valid_moe_layers=valid_layers, + ) + else: + stats_tracker.scalar( + num_experts=num_experts, + num_moe_layers=num_layers, + valid_moe_layers=0, + ) + + def strip_routed_experts_before_loss( data: dict[str, Any], ) -> dict[str, Any]: diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index 10ad73707f..a5259114ea 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -445,11 +445,19 @@ def train( args={"global_step": global_step}, ), ): - # R3: Log routing replay statistics if available. - if getattr(self.config.rollout, "return_routed_experts", False): - from areal.trainer.ppo.actor_r3_patch import log_r3_data_stats - for traj in adv_batch: - log_r3_data_stats(traj) + # MoE routing metrics: Log for ALL MoE models when + # routed_experts data is available in the trajectory. + # R3 data stats are logged only when R3 is enabled. 
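The per-layer formulas used in log_moe_routing_metrics above can be sanity-checked on two degenerate routers: a perfectly balanced router should give normalized entropy near 1.0 and CV near 0, while a collapsed router (all tokens to one expert) should give entropy near 0 and top-1 concentration 1.0. A toy check, not AReaL code:

    import torch

    num_experts = 4

    def routing_metrics(expert_ids: torch.Tensor) -> dict:
        counts = torch.bincount(expert_ids, minlength=num_experts).float()
        probs = counts / counts.sum()
        entropy = -(probs * torch.log(probs + 1e-10)).sum().item()
        max_entropy = torch.log(torch.tensor(float(num_experts))).item()
        return {
            "normalized_entropy": entropy / max_entropy,
            "load_cv": (probs.std() / (probs.mean() + 1e-10)).item(),
            "top1_concentration": probs.max().item(),
            "expert_diversity": ((counts > 0).sum() / num_experts).item(),
        }

    uniform = torch.arange(1000) % num_experts       # perfectly balanced routing
    collapsed = torch.zeros(1000, dtype=torch.long)  # all tokens -> expert 0

    print(routing_metrics(uniform))    # entropy ~1.0, cv ~0.0, top1 0.25, diversity 1.0
    print(routing_metrics(collapsed))  # entropy ~0.0, top1 1.0, diversity 0.25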
+ for traj in adv_batch: + if "routed_experts" in traj: + from areal.trainer.ppo.actor_r3_patch import ( + log_moe_routing_metrics, + log_r3_data_stats, + ) + log_moe_routing_metrics(traj) + if getattr(self.config.rollout, "return_routed_experts", False): + log_r3_data_stats(traj) + break # Log once per batch, not per trajectory self.actor.ppo_update(adv_batch) self.actor.step_lr_scheduler() From 45fbf9f5cbeead29f95416912058ffe015259192 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sat, 18 Apr 2026 22:58:45 +0800 Subject: [PATCH 007/112] fix(router): refactor --- areal/engine/megatron_engine_r3_patch.py | 104 ++++++++++++------ areal/engine/router_replay_patch.py | 94 +++------------- areal/engine/router_replay_utils.py | 3 +- ...moonlight_16b_a3b_gsm8k_grpo_megatron.yaml | 1 + 4 files changed, 86 insertions(+), 116 deletions(-) diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py index db46213d32..cacbb75742 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -383,8 +383,9 @@ def _r3_forward_backward_batch( ``setup_per_microbatch_replay_forward``, aligns the seq dimension to match the attention_mask's seq dimension. """ - from areal.engine.router_replay_patch import RouterReplay + from areal.engine.router_replay_patch import RouterReplay, RouterReplayAction from areal.engine.router_replay_utils import ( + RouterReplayHelper, clear_router_replay, setup_per_microbatch_replay_forward, ) @@ -458,20 +459,13 @@ def _r3_forward_backward_batch( model_config = self.tf_config # ------------------------------------------------------------------ - # 2b. Set PP size on RouterReplay for backward recompute remapping. - # Reset call counters so each forward_backward_batch starts clean. + # 2b. Set initial replay action to REPLAY_FORWARD. + # The forward_step wrapper will toggle between REPLAY_FORWARD + # and REPLAY_BACKWARD for each micro-batch. # ------------------------------------------------------------------ - try: - from megatron.core import parallel_state as mpu - pp_size = mpu.get_pipeline_model_parallel_world_size() - except Exception: - pp_size = getattr(model_config, "pipeline_model_parallel_size", 1) - RouterReplay.pp_size = pp_size - RouterReplay.reset_all_call_counters() + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) logger.debug( - "[R3] Set RouterReplay.pp_size=%d, reset all call counters " - "(%d instances).", - pp_size, + "[R3] Set initial REPLAY_FORWARD action on %d router instances.", len(RouterReplay.router_instances), ) @@ -479,17 +473,20 @@ def _r3_forward_backward_batch( # 3. Wrap the MicroBatchList iterator on the INSTANCE level # # The iterator injects R3 setup before each micro-batch's forward. - # The call-counter mechanism in RouterReplay handles backward - # recompute correctly for both PP=1 and PP>1 scenarios: + # The iterator wrapper also handles + # the REPLAY_FORWARD / REPLAY_BACKWARD toggle per micro-batch: + # + # - At the START of each forward_step (when next() is called): + # 1. If action is REPLAY_BACKWARD, switch to REPLAY_FORWARD + # (this handles backward recompute -> next forward transition) + # 2. Set the replay data for this micro-batch via + # setup_per_microbatch_replay_forward() # - # - PP=1 (forward_backward_no_pipelining): all forwards first, - # then all backwards in REVERSE order. The call counter tracks - # which MB's data to use: forward calls [0..N-1], backward - # recompute calls [N..2N-1] remapped as [N-1..0]. 
+ # - At the END of each forward_step (via model forward hook): + # switch to REPLAY_BACKWARD so that the subsequent backward + # recompute (activation checkpointing) uses + # replay_backward_list.pop(0). # - # - PP>1 (1F1B): interleaved forward-backward. Each MB's backward - # follows its forward closely. The call counter handles this - # with same-order remapping. # ------------------------------------------------------------------ engine_ref = self @@ -513,6 +510,18 @@ def __next__(self): else None ) + # When backward recompute (activation checkpointing) finishes + # and the next forward starts, the action is REPLAY_BACKWARD. + # Switch it back to REPLAY_FORWARD before setting new data. + if RouterReplayHelper.is_replay_backward_action(model_config): + router_list = RouterReplayHelper.get_micro_batch_router_list( + model_config + ) + for router in router_list: + router.set_router_replay_action( + RouterReplayAction.REPLAY_FORWARD + ) + if re is not None: # Problem 5 fix: reconstruct attention_mask from cu_seqlens # when pack_tensor_dict has replaced it. @@ -560,26 +569,53 @@ def _r3_iter(mb_list_self): mb_list.__class__.__iter__ = _r3_iter + # ------------------------------------------------------------------ + # 4. Register a forward hook on each model chunk for the + # REPLAY_FORWARD -> REPLAY_BACKWARD toggle at the END of each + # forward_step. + # ------------------------------------------------------------------ + hook_handles = [] + + def _r3_post_forward_hook(module, input, output): + """Switch from REPLAY_FORWARD to REPLAY_BACKWARD after model forward.""" + if RouterReplayHelper.is_replay_forward_action(model_config): + router_list = RouterReplayHelper.get_micro_batch_router_list( + model_config + ) + for router in router_list: + router.set_router_replay_action( + RouterReplayAction.REPLAY_BACKWARD + ) + + for model_chunk in self.model: + handle = model_chunk.register_forward_hook(_r3_post_forward_hook) + hook_handles.append(handle) + + logger.debug( + "[R3] Registered forward hooks on %d model chunks for " + "FORWARD->BACKWARD toggle.", + len(hook_handles), + ) + try: # Megatron's forward_backward_func (e.g. 1F1B schedule) internally # interleaves forward and backward for each micro-batch. # - # The call-counter mechanism in RouterReplay handles backward - # recompute (activation checkpointing) correctly: - # - Each RouterReplay instance tracks how many times its routing - # function is called via routing_call_count. - # - Forward calls (call_idx < N) use routing_data_list[call_idx]. - # - Backward recompute calls (call_idx >= N) are remapped based - # on pp_size: - # * PP=1: reverse order (N-1-recompute_idx) - # * PP>1: same order (recompute_idx) - # - # This replaces the previous design that relied on target_topk_idx - # (which got overwritten by subsequent MBs in PP=1 multi-MB case). + # Per-forward-step toggle handles + # backward recompute (activation checkpointing) correctly: + # - The iterator wrapper (above) switches REPLAY_BACKWARD -> + # REPLAY_FORWARD at the START of each forward_step. + # - The model forward hook (above) switches REPLAY_FORWARD -> + # REPLAY_BACKWARD at the END of each forward_step. + # - Forward uses target_topk_idx; backward recompute pops from + # replay_backward_list. 
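The toggle these comments describe depends on a PyTorch detail: with activation checkpointing, the module's forward runs again during backward, and any forward hook fires again as well. A minimal sketch of the hook-driven phase flip on a toy module (not the Megatron pipeline):

    import torch
    from torch.utils.checkpoint import checkpoint

    phase = {"mode": "REPLAY_FORWARD"}
    seen = []

    class Block(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(4, 4)

        def forward(self, x):
            # A router would choose its replay data source from the phase here.
            seen.append(phase["mode"])
            return self.lin(x).relu()

    block = Block()
    # Post-forward hook: flip the phase at the END of the real forward pass.
    block.register_forward_hook(lambda m, i, o: phase.update(mode="REPLAY_BACKWARD"))

    x = torch.randn(2, 4, requires_grad=True)
    y = checkpoint(block, x, use_reentrant=False)
    y.sum().backward()  # checkpointing re-runs Block.forward here
    print(seen)         # ['REPLAY_FORWARD', 'REPLAY_BACKWARD']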
self._r3_original_forward_backward_batch(
            mb_list, process_output_fn, forward_only=forward_only
        )
    finally:
+        # Remove forward hooks
+        for handle in hook_handles:
+            handle.remove()
         # Restore original class __iter__ and clean up R3 state
         mb_list.__class__.__iter__ = original_class_iter
         clear_router_replay()
diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py
index a859fe4853..d2a34502aa 100644
--- a/areal/engine/router_replay_patch.py
+++ b/areal/engine/router_replay_patch.py
@@ -37,7 +37,7 @@
     from areal.engine.router_replay_patch import remove_router_replay_patch
     remove_router_replay_patch()  # optional: for test cleanup
 
-Ported from verl reference implementation, adapted for AReaL.
+References some code from verl, adapted for AReaL.
 """
 
 from __future__ import annotations
@@ -161,32 +161,19 @@ def clear_global_router_replay_action() -> None:
         for r in RouterReplay.router_instances:
             r.clear_router_replay_action()
 
-    @staticmethod
-    def reset_all_call_counters() -> None:
-        """Reset call counters on all instances (call before forward_backward)."""
-        for r in RouterReplay.router_instances:
-            r.routing_call_count = 0
 
-    # ------------------------------------------------------------------
-    # Instance methods
-    # ------------------------------------------------------------------
+
 
     def __init__(self) -> None:
         self.target_topk_idx: torch.Tensor | None = None
         self.recorded_topk_idx: torch.Tensor | None = None
         self.router_replay_action: RouterReplayAction | None = None
-        # Legacy list kept for API compatibility; no longer used for
-        # backward replay (call-counter approach replaces it).
         self.replay_backward_list: list[torch.Tensor] = []
-        # Call-counter mechanism for multi-MB activation checkpointing.
-        self.routing_data_list: list[torch.Tensor] = []
-        self.routing_call_count: int = 0
 
         RouterReplay.router_instances.append(self)
 
     def set_target_indices(self, topk_indices: torch.Tensor) -> None:
+        """Sets the target topk indices for replay."""
         self.target_topk_idx = topk_indices
-        self.routing_data_list.append(topk_indices)
-        # Keep replay_backward_list in sync for backward compatibility.
         self.replay_backward_list.append(topk_indices)
 
     def get_recorded_indices(self) -> torch.Tensor | None:
@@ -199,8 +186,6 @@ def clear_indices(self) -> None:
         self.recorded_topk_idx = None
         self.target_topk_idx = None
         self.replay_backward_list = []
-        self.routing_data_list = []
-        self.routing_call_count = 0
 
     def set_router_replay_action(self, action: RouterReplayAction) -> None:
         self.router_replay_action = action
@@ -208,54 +193,6 @@ def set_router_replay_action(self, action: RouterReplayAction) -> None:
     def clear_router_replay_action(self) -> None:
         self.router_replay_action = None
 
-    def get_replay_indices_by_call_count(self) -> torch.Tensor | None:
-        """Return the correct replay indices based on the current call count.
-
-        Uses the call-counter mechanism to handle multi-MB + activation
-        checkpointing correctly.
-
-        Returns:
-            The routing indices tensor for the current call, or None if
-            no data is available. 
- """ - n_mbs = len(self.routing_data_list) - if n_mbs == 0: - return None - - call_idx = self.routing_call_count - self.routing_call_count += 1 - - if call_idx < n_mbs: - # Forward pass: use indices in order [0, 1, ..., N-1] - data_idx = call_idx - phase = "forward" - else: - # Backward recompute: remap based on pipeline schedule - recompute_idx = call_idx - n_mbs - if RouterReplay.pp_size <= 1: - # PP=1: forward_backward_no_pipelining - # Backward order is REVERSE: MB_{N-1}, MB_{N-2}, ..., MB_0 - data_idx = n_mbs - 1 - recompute_idx - else: - # PP>1: 1F1B schedule - # Backward order is SAME as forward: MB_0, MB_1, ..., MB_{N-1} - data_idx = recompute_idx - phase = "backward_recompute" - - # Clamp to valid range as safety measure - data_idx = max(0, min(data_idx, n_mbs - 1)) - - if logger.isEnabledFor(logging.DEBUG): - logger.debug( - "[R3] get_replay_indices_by_call_count: " - "call_idx=%d, n_mbs=%d, pp_size=%d, phase=%s, " - "data_idx=%d, shape=%s.", - call_idx, n_mbs, RouterReplay.pp_size, phase, - data_idx, self.routing_data_list[data_idx].shape, - ) - - return self.routing_data_list[data_idx] - # =================================================================== # Patched routing implementation @@ -307,33 +244,30 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): return probs, top_indices elif routing_action == RouterReplayAction.REPLAY_FORWARD: - # Use call-counter approach for multi-MB + activation checkpointing. - # This handles both forward and backward recompute correctly: - # - Forward calls (call_idx < N): use routing_data_list[call_idx] - # - Backward recompute (call_idx >= N): remapped by PP schedule - top_indices = router_replay.get_replay_indices_by_call_count() - if top_indices is None: + if router_replay is None or router_replay.target_topk_idx is None: + # Fallback if replay data is not available logger.warning( - "[R3] REPLAY_FORWARD: no replay indices available " - "(routing_data_list empty), falling back to normal routing." + "[R3] REPLAY_FORWARD: no replay indices available, " + "falling back to normal routing." ) return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) + # Use the provided indices for replay + top_indices = router_replay.target_topk_idx top_indices = top_indices.to(scores.device) probs = scores.gather(1, top_indices) return probs, top_indices elif routing_action == RouterReplayAction.REPLAY_BACKWARD: - # Legacy backward mode using replay_backward_list.pop(0). - # Kept for backward compatibility but no longer the primary - # mechanism. The call-counter approach in REPLAY_FORWARD - # handles both forward and backward recompute. if router_replay is None or not router_replay.replay_backward_list: + # Fallback if replay data is not available logger.warning( "[R3] REPLAY_BACKWARD: no backward indices available, " "falling back to normal routing." 
)
            return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk)
-            top_indices = router_replay.replay_backward_list.pop(0).to(scores.device)
+            # Pop the earliest recorded indices (FIFO) for backward replay
+            top_indices = router_replay.replay_backward_list.pop(0)
+            top_indices = top_indices.to(scores.device)
             probs = scores.gather(1, top_indices)
             return probs, top_indices
 
@@ -382,7 +316,7 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
 
 
 # ===================================================================
-# Aux-loss helpers (from verl reference)
+# Aux-loss helpers
 # ===================================================================
 
 
diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py
index 15bc0006d6..ca49034297 100644
--- a/areal/engine/router_replay_utils.py
+++ b/areal/engine/router_replay_utils.py
@@ -395,8 +395,7 @@ def setup_per_microbatch_replay_forward(
     )
     routed_experts = routed_experts.to(torch.int32)
     set_router_replay_data(routed_experts, attention_mask, tf_config, vp_rank)
-    RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
-    logger.debug("[R3] Forward replay mode set for micro-batch.")
+    logger.debug("[R3] Replay data distributed to router instances for micro-batch.")
 
 
 def setup_per_microbatch_replay_backward() -> None:
diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml
index e6d4f1c0fc..a26fa64307 100644
--- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml
+++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml
@@ -136,6 +136,7 @@ sglang:
   # during inference. This is auto-set by rl_trainer when
   # rollout.return_routed_experts=true, but explicitly declared here for clarity.
   enable_return_routed_experts: true
+  chunked_prefill_size: 2048
 
 vllm:
   model: ${actor.path}

From 46552642dbe4929b9b613f460c4e3d22145d50db Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Sun, 19 Apr 2026 00:21:04 +0800
Subject: [PATCH 008/112] fix(engine): fix routed_experts format

---
 areal/engine/sglang_remote.py | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py
index 9804d08c3f..a1e77cc5b7 100644
--- a/areal/engine/sglang_remote.py
+++ b/areal/engine/sglang_remote.py
@@ -94,16 +94,27 @@ def parse_generation_response(
     stop_reason = finish_reason["type"]
     stop_message = finish_reason.get("message", "")
 
-    # Extract routed_experts information if available
+    # Extract routed_experts information if available.
+    # sglang v0.5.9 with skip_tokenizer_init=True bypasses the
+    # detokenizer where base64 encoding normally happens, so
+    # routed_experts may arrive as a raw nested list/dict instead
+    # of a base64-encoded string. Handle both formats. 
routed_experts = meta_info.get("routed_experts", None) if routed_experts is not None: num_sgl_token = ( meta_info["prompt_tokens"] + meta_info["completion_tokens"] - 1 ) - # Extract expert_id and reshape to (num_sgl_token, num_layers*expert_top_k) - routed_experts = np.frombuffer( - pybase64.b64decode(routed_experts.encode("utf-8")), dtype=np.int32 - ).reshape(num_sgl_token, -1) + if isinstance(routed_experts, str): + # Normal path: base64-encoded int32 bytes + routed_experts = np.frombuffer( + pybase64.b64decode(routed_experts.encode("utf-8")), + dtype=np.int32, + ).reshape(num_sgl_token, -1) + else: + # skip_tokenizer_init=True on sglang<=0.5.9: raw list/dict + routed_experts = np.asarray( + routed_experts, dtype=np.int32 + ).reshape(num_sgl_token, -1) if stop_reason == "abort" and stop_message.startswith("Abort before prefill"): return HttpGenerationResult( From 833ec68cdcbdf67c1d31bb760e9696b0335914e7 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 19 Apr 2026 01:03:56 +0800 Subject: [PATCH 009/112] fix(sglang): ban skip_tokenizer_init --- areal/engine/sglang_remote.py | 35 +++++++++++-------- ...moonlight_16b_a3b_gsm8k_grpo_megatron.yaml | 2 +- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index a1e77cc5b7..b3025467d3 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -1,3 +1,4 @@ +import logging import os import subprocess import sys @@ -33,6 +34,8 @@ from areal.utils import perf_tracer, stats_tracker from areal.utils.network import format_host_for_url +logger = logging.getLogger(__name__) + class SGLangBackend: """SGLang-specific backend implementation for remote inference.""" @@ -95,26 +98,30 @@ def parse_generation_response( stop_message = finish_reason.get("message", "") # Extract routed_experts information if available. - # sglang v0.5.9 with skip_tokenizer_init=True bypasses the - # detokenizer where base64 encoding normally happens, so - # routed_experts may arrive as a raw nested list/dict instead - # of a base64-encoded string. Handle both formats. + # Requires skip_tokenizer_init=False so that sglang v0.5.9 routes + # through the DetokenizerManager which base64-encodes the tensor. + # When skip_tokenizer_init=True on v0.5.9 the raw tensor is lost + # during JSON serialization (becomes {}). 
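The base64 path discussed above round-trips an int32 tensor through a UTF-8 string. A toy demonstration of the wire format with made-up sizes (the server-side encode is shown only for illustration and is not AReaL code):

    import numpy as np
    import pybase64

    num_tokens, num_layers, topk = 5, 2, 3
    original = np.arange(num_tokens * num_layers * topk, dtype=np.int32)

    encoded = pybase64.b64encode(original.tobytes()).decode("utf-8")  # server side
    decoded = np.frombuffer(
        pybase64.b64decode(encoded.encode("utf-8")), dtype=np.int32
    ).reshape(num_tokens, -1)                                         # client side

    assert decoded.shape == (num_tokens, num_layers * topk)
    assert (decoded.reshape(-1) == original).all()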
routed_experts = meta_info.get("routed_experts", None)
         if routed_experts is not None:
-            num_sgl_token = (
-                meta_info["prompt_tokens"] + meta_info["completion_tokens"] - 1
-            )
-            if isinstance(routed_experts, str):
-                # Normal path: base64-encoded int32 bytes
+            if not isinstance(routed_experts, str):
+                logger.warning(
+                    "[R3] routed_experts is %s instead of base64 str "
+                    "(skip_tokenizer_init must be False for sglang<=0.5.9); "
+                    "discarding.",
+                    type(routed_experts).__name__,
+                )
+                routed_experts = None
+            else:
+                num_sgl_token = (
+                    meta_info["prompt_tokens"]
+                    + meta_info["completion_tokens"]
+                    - 1
+                )
                 routed_experts = np.frombuffer(
                     pybase64.b64decode(routed_experts.encode("utf-8")),
                     dtype=np.int32,
                 ).reshape(num_sgl_token, -1)
-            else:
-                # skip_tokenizer_init=True on sglang<=0.5.9: raw list/dict
-                routed_experts = np.asarray(
-                    routed_experts, dtype=np.int32
-                ).reshape(num_sgl_token, -1)
 
         if stop_reason == "abort" and stop_message.startswith("Abort before prefill"):
             return HttpGenerationResult(
diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml
index a26fa64307..3ab98ae7d5 100644
--- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml
+++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml
@@ -126,7 +126,7 @@ ref:
 sglang:
   model_path: ${actor.path}
   random_seed: ${seed}
-  skip_tokenizer_init: true
+  skip_tokenizer_init: false
   dtype: bfloat16
   max_running_requests: 8
   context_length: 2048

From a887b520f07630466074f833d5f4758725b489b3 Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Sun, 19 Apr 2026 01:22:18 +0800
Subject: [PATCH 010/112] feat(math): add base config

---
 ...ight_16b_a3b_gsm8k_grpo_megatron_base.yaml | 193 ++++++++++++++++++
 1 file changed, 193 insertions(+)
 create mode 100644 examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml

diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml
new file mode 100644
index 0000000000..5e2e5196eb
--- /dev/null
+++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml
@@ -0,0 +1,193 @@
+experiment_name: moonlight-16b-a3b-gsm8k-grpo
+trial_name: trial0
+
+seed: 1
+enable_offload: false
+total_train_epochs: 10
+tokenizer_path: ${actor.path}
+
+cluster:
+  n_nodes: 1
+  n_gpus_per_node: 8
+  fileroot: /tmp/areal/moon_experiments
+  name_resolve:
+    type: nfs
+    nfs_record_root: /tmp/areal/moon_name_resolve
+
+scheduler:
+  type: null
+
+rollout:
+  backend: "sglang:d1p1t8"
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  max_concurrent_rollouts: 64
+  queue_size: null
+  consumer_batch_size: ${train_dataset.batch_size}
+  max_head_offpolicyness: 2
+  enable_rollout_tracing: false
+  scheduling_spec: ${actor.scheduling_spec}
+  fileroot: ${cluster.fileroot}
+  tokenizer_path: ${tokenizer_path}
+  dump_to_file: true
+
+gconfig:
+  n_samples: 4
+  min_new_tokens: 0
+  max_new_tokens: 768
+  greedy: false
+  temperature: 1.0
+
+actor:
+  backend: "megatron:(attn:d1p2t4|ffn:d1p2t1e4)" # ← PP=2 fallback, TP=4/EP=4
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: /workspace/models/Moonlight-16B-A3B-Instruct
+  init_from_scratch: false
+  disable_dropout: true
+  gradient_checkpointing: true
+  dtype: bfloat16
+  mb_spec:
+    max_tokens_per_mb: 1280 # ← reduced from 2048
+  optimizer:
+    type: adam_bf16
+    lr: 5e-6
+    weight_decay: 0.003
+    beta1: 0.9
+    beta2: 0.999
+    eps: 1e-8
+    lr_scheduler_type: constant
+    gradient_clipping: 1.0
+    warmup_steps_proportion: 0.001
+  
eps_clip: 0.4
+  temperature: ${gconfig.temperature}
+  reward_scaling: 10.0
+  reward_bias: -0.5
+  kl_ctl: 0.0
+  ppo_n_minibatches: 8 # ← gradient accumulation across mini-batches
+  recompute_logprob: true
+  use_decoupled_loss: true
+  behave_imp_weight_cap: 5.0
+  reward_norm:
+    mean_level: group
+    std_level: group
+    group_size: ${gconfig.n_samples}
+  adv_norm:
+    mean_level: batch
+    std_level: batch
+  weight_update_mode: disk
+  max_new_tokens: ${gconfig.max_new_tokens}
+  megatron:
+    use_deterministic_algorithms: false
+    recompute_granularity: full
+    recompute_method: uniform
+    recompute_num_layers: 14
+    main_grads_dtype: bfloat16 # grads FP32 -> BF16 (saves ~4 GiB)
+    # store_param_remainders: true
+    # optimizer_cpu_offload: true
+    # optimizer_offload_fraction: 0.5
+    # main_params_dtype: bfloat16
+    # main_grads_dtype: bfloat16
+    # # adam_bf16 already sets the two fields below, but declaring them explicitly is safer
+    # exp_avg_dtype: bfloat16
+    # exp_avg_sq_dtype: bfloat16
+  ddp:
+    grad_reduce_in_fp32: false # ← keep layer-wise recompute
+  scheduling_spec:
+    - task_type: worker
+      port_count: 2
+      gpu: 1
+      mem: 48
+      cmd: python3 -m areal.infra.rpc.rpc_server
+      env_vars:
+        PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
+
+ref:
+  backend: ${actor.backend}
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: ${actor.path}
+  init_from_scratch: false
+  disable_dropout: true
+  dtype: ${actor.dtype}
+  mb_spec:
+    max_tokens_per_mb: 1280
+  optimizer: null
+  scheduling_strategy:
+    type: colocation
+    target: actor
+  scheduling_spec: ${actor.scheduling_spec}
+
+sglang:
+  model_path: ${actor.path}
+  random_seed: ${seed}
+  skip_tokenizer_init: false
+  dtype: bfloat16
+  max_running_requests: 8
+  context_length: 2048
+  mem_fraction_static: 0.2
+  attention_backend: triton
+
+vllm:
+  model: ${actor.path}
+  seed: ${seed}
+  skip_tokenizer_init: false
+  dtype: bfloat16
+  max_model_len: 4096
+  gpu_memory_utilization: 0.75
+
+train_dataset:
+  batch_size: 64
+  shuffle: true
+  pin_memory: true
+  num_workers: 4
+  path: openai/gsm8k
+  type: rl
+  max_length: 512
+
+valid_dataset:
+  batch_size: 128
+  pin_memory: true
+  num_workers: 4
+  path: openai/gsm8k
+  type: rl
+
+saver:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+recover:
+  mode: disabled
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: 3600
+
+evaluator:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+stats_logger:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  wandb:
+    mode: disabled
+
+perf_tracer:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  enabled: false
+  session_tracer:
+    enabled: false

From e5dae7d471840da46d7ea53a00d1eec334d37497 Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Sun, 19 Apr 2026 10:20:17 +0800
Subject: [PATCH 011/112] fix(math): fix config

---
 .../moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml
index 5e2e5196eb..315ec0481e 100644
--- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml
+++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml
@@ -32,9 +32,9 @@
dump_to_file: true
 
 gconfig:
-  n_samples: 4
+  n_samples: 8
   min_new_tokens: 0
-  max_new_tokens: 768
+  max_new_tokens: 1024
   greedy: false
   temperature: 1.0
 
@@ -51,7 +51,7 @@ actor:
     max_tokens_per_mb: 1280 # ← reduced from 2048
   optimizer:
     type: adam_bf16
-    lr: 5e-6
+    lr: 2e-6
     weight_decay: 0.003
     beta1: 0.9
     beta2: 0.999
@@ -64,7 +64,7 @@
   reward_scaling: 10.0
   reward_bias: -0.5
   kl_ctl: 0.0
-  ppo_n_minibatches: 8 # ← gradient accumulation across mini-batches
+  ppo_n_minibatches: 1 # ← gradient accumulation across mini-batches
   recompute_logprob: true
   use_decoupled_loss: true
   behave_imp_weight_cap: 5.0
@@ -82,7 +82,8 @@
   recompute_granularity: full
   recompute_method: uniform
   recompute_num_layers: 14
-  main_grads_dtype: bfloat16 # grads FP32 -> BF16 (saves ~4 GiB)
+  main_grads_dtype: bfloat16
+  # main_params_dtype: bfloat16 # grads FP32 -> BF16 (saves ~4 GiB)
   # store_param_remainders: true
   # optimizer_cpu_offload: true
   # optimizer_offload_fraction: 0.5
@@ -190,4 +191,4 @@
   fileroot: ${cluster.fileroot}
   enabled: false
   session_tracer:
-    enabled: false
+    enabled: false
\ No newline at end of file

From c06c5d0df633beb9d4b5a2ec62e145d207a9f724 Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Sun, 19 Apr 2026 10:36:47 +0800
Subject: [PATCH 012/112] fix: fix skip_tokenizer_init

---
 examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml
index 315ec0481e..008da0aec8 100644
--- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml
+++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml
@@ -122,7 +122,7 @@ ref:
 sglang:
   model_path: ${actor.path}
   random_seed: ${seed}
-  skip_tokenizer_init: false
+  skip_tokenizer_init: true
   dtype: bfloat16
   max_running_requests: 8
   context_length: 2048

From c53033abc13d50134765667150d2703b03b82e7d Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Sun, 19 Apr 2026 13:51:51 +0800
Subject: [PATCH 013/112] feat(engine): fix optimizer

---
 areal/engine/megatron_engine.py | 26 +++++++++++++++++++++++++-
 1 file changed, 25 insertions(+), 1 deletion(-)

diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py
index 95e75097c2..08028ad523 100644
--- a/areal/engine/megatron_engine.py
+++ b/areal/engine/megatron_engine.py
@@ -1091,7 +1091,31 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None:
             mcore_opt_config.exp_avg_sq_dtype = getattr(
                 torch, self.mcore_config.exp_avg_sq_dtype
             )
-
+        # Needed for MoE runs.
+        mcore_opt_config.use_precision_aware_optimizer_no_fp8_or_ds_fp8 = (
+            mcore_opt_config.use_precision_aware_optimizer
+            and (
+                mcore_opt_config.main_params_dtype != torch.float32
+                or (mcore_opt_config.fp8_recipe is None or mcore_opt_config.fp8_recipe == "delayed")
+                or mcore_opt_config.optimizer_cpu_offload
+            )
+        )
+        mcore_opt_config.store_param_remainders = True
+        import logging as _logging
+        _opt_logger = _logging.getLogger('AReaL.OptDiag')
+        _opt_logger.warning(
+            f'[OptDiag] Megatron OptimizerConfig: '
+            f'use_precision_aware_optimizer={mcore_opt_config.use_precision_aware_optimizer}, '
+            f'use_precision_aware_optimizer_no_fp8_or_ds_fp8='
+            f'{getattr(mcore_opt_config, "use_precision_aware_optimizer_no_fp8_or_ds_fp8", "N/A")}, '
+            f'store_param_remainders={mcore_opt_config.store_param_remainders}, '
+            f'main_params_dtype={mcore_opt_config.main_params_dtype}, '
+            f'main_grads_dtype={mcore_opt_config.main_grads_dtype}, '
+            f'exp_avg_dtype={mcore_opt_config.exp_avg_dtype}, '
+            
f'exp_avg_sq_dtype={mcore_opt_config.exp_avg_sq_dtype}, ' + f'use_distributed_optimizer={mcore_opt_config.use_distributed_optimizer}, ' + f'bf16={mcore_opt_config.bf16}' + ) self.optimizer = get_megatron_optimizer(mcore_opt_config, self.model) warmup_steps_proportion = self.optimizer_config.warmup_steps_proportion From 9b752a4d02d2780e9139696d51443760d86096a0 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 19 Apr 2026 17:59:52 +0800 Subject: [PATCH 014/112] feat(router): fix code --- areal/api/cli_args.py | 11 + areal/engine/megatron_engine.py | 15 +- areal/engine/megatron_engine_r3_patch.py | 38 +- areal/engine/router_replay_patch.py | 17 + areal/engine/router_replay_utils.py | 18 +- areal/engine/sglang_remote.py | 61 +- areal/trainer/ppo/actor.py | 25 +- areal/trainer/ppo/actor_r3_patch.py | 77 +- areal/trainer/rl_trainer.py | 12 +- moonlight_moe_r3.txt | 1094 ++++++++++++++++++++++ 10 files changed, 1320 insertions(+), 48 deletions(-) create mode 100644 moonlight_moe_r3.txt diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index e08c852ec4..2491b988e8 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -898,6 +898,17 @@ class MegatronEngineConfig: }, ) + enable_router_replay: bool = field( + default=False, + metadata={ + "help": "Enable Router Replay (R3) for MoE models. " + "When True, the training forward pass replays the expert routing " + "decisions from the inference engine to reduce train-inference " + "routing discrepancy. Automatically set by the trainer when " + "rollout.return_routed_experts=True." + }, + ) + class SchedulingStrategyType(str, Enum): separation = "separation" diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 08028ad523..3da13d8313 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -162,6 +162,7 @@ def __init__(self, config: TrainEngineConfig): self.seed: int = 0 self.own_global_group: bool = False self.is_offload: bool = False + self._r3_enabled: bool = getattr(config.megatron, "enable_router_replay", False) self.enable_tree_training: bool = self.config.enable_tree_training # FP8 configuration self.fp8_config = self.mcore_config.fp8_config @@ -285,8 +286,14 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): self.tokenizer = load_hf_tokenizer(self.config.path) - # R3: Check early so the variable is always defined. - _r3_enabled = getattr(self.config, "_r3_enable_router_replay", False) + # R3: _r3_enabled was set in __init__ from config.megatron.enable_router_replay. + self.logger.info( + "[R3] enable_router_replay=%s (config.megatron type=%s, " + "config type=%s).", + self._r3_enabled, + type(self.config.megatron).__name__, + type(self.config).__name__, + ) with patch_bridge_for_tree_training( self.enable_tree_training and self.bridge_cls == "mbridge" ): @@ -311,7 +318,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): # R3: Apply Router Replay patch BEFORE model creation so that # TopKRouter.__init__ and TransformerConfig.__init__ are patched. - if _r3_enabled: + if self._r3_enabled: from areal.engine.router_replay_patch import apply_router_replay_patch apply_router_replay_patch() self.tf_config.enable_routing_replay = True @@ -410,7 +417,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): self._create_optimizer(ft_spec) # R3: Apply engine-level patch after model and optimizer are ready. 
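The ordering constraint in the comment above is inherent to monkey-patching: replacing `__init__` on a class affects only instances constructed afterwards, so routers built before the patch would never gain replay support. A toy illustration (these classes are not Megatron's):

    class Router:
        def __init__(self):
            self.replay_enabled = False

    def patched_init(self):
        self.replay_enabled = True

    before = Router()           # constructed with the original __init__
    Router.__init__ = patched_init
    after = Router()            # constructed with the patched __init__

    print(before.replay_enabled, after.replay_enabled)  # False True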
- if _r3_enabled: + if self._r3_enabled: from areal.engine.megatron_engine_r3_patch import patch_megatron_engine_for_r3 patch_megatron_engine_for_r3(self, enable_router_replay=True) self.logger.info("[R3] Router Replay enabled on MegatronEngine.") diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py index cacbb75742..b244495e75 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -144,6 +144,14 @@ def _align_routed_experts_to_mask( the post-pack ``attention_mask`` is LEFT-aligned (real tokens first, no left-padding). + **Batch size alignment**: ``pad_packed_tensor_dict`` appends one extra + cu_seqlens entry (a padding sequence) to fill the micro-batch to + ``pad_to_length``. This makes ``attention_mask`` have one more row + than the original ``routed_experts``. We zero-pad the batch dimension + so that ``set_router_replay_data`` sees matching batch sizes; the + padding sample's zero routing indices are harmless because the model + ignores those dummy tokens. + This function extracts the right-most ``actual_len`` tokens from each sample's left-padded ``routed_experts`` and places them at the left-aligned positions expected by ``attention_mask``. @@ -155,10 +163,34 @@ def _align_routed_experts_to_mask( Left-aligned mask (1 for real tokens, 0 for padding). Returns: - ``(bs, mb_max_seqlen, num_moe_layers, topk)`` aligned tensor. + ``(bs_aligned, mb_max_seqlen, num_moe_layers, topk)`` aligned tensor. """ - bs, re_seqlen = routed_experts.shape[:2] - _, mask_seqlen = attention_mask.shape[:2] + re_bs, re_seqlen = routed_experts.shape[:2] + mask_bs, mask_seqlen = attention_mask.shape[:2] + + if re_bs < mask_bs: + extra_dims = routed_experts.shape[2:] + padded_re = torch.zeros( + mask_bs, re_seqlen, *extra_dims, + dtype=routed_experts.dtype, + device=routed_experts.device, + ) + padded_re[:re_bs] = routed_experts + routed_experts = padded_re + logger.info( + "[R3] _align_routed_experts_to_mask: padded routed_experts batch " + "from %d to %d samples (pad_mb_list added %d padding sequence(s)).", + re_bs, mask_bs, mask_bs - re_bs, + ) + elif re_bs > mask_bs: + routed_experts = routed_experts[:mask_bs] + logger.warning( + "[R3] _align_routed_experts_to_mask: truncated routed_experts batch " + "from %d to %d samples.", + re_bs, mask_bs, + ) + + bs = routed_experts.shape[0] if re_seqlen == mask_seqlen: # No alignment needed diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py index d2a34502aa..069f20dd60 100644 --- a/areal/engine/router_replay_patch.py +++ b/areal/engine/router_replay_patch.py @@ -251,6 +251,23 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): "falling back to normal routing." ) return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) + + # Compute natural topk for Router Agreement Rate metric. + # This measures how much the training router's natural selection + # diverges from the replayed inference routing. 
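The comparison that follows is deliberately order-insensitive: a top-k selection that picks the same experts in a different order still counts as agreement, which is why both index tensors are sorted before comparison. A toy check with made-up indices:

    import torch

    replay = torch.tensor([[3, 1], [0, 2], [5, 4]])
    actual = torch.tensor([[1, 3], [0, 4], [4, 5]])  # row 0 agrees in a different order

    matches = (
        replay.sort(dim=-1).values == actual.sort(dim=-1).values
    ).all(dim=-1)
    agreement_rate = matches.float().mean().item()
    print(matches.tolist(), agreement_rate)  # [True, False, True] 0.666...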
+ with torch.no_grad(): + _, natural_indices = _compute_topk( + scores, topk, num_groups=num_groups, group_topk=group_topk + ) + replay_indices = router_replay.target_topk_idx.to(scores.device) + natural_sorted = natural_indices.sort(dim=-1).values + replay_sorted = replay_indices.sort(dim=-1).values + matches = (natural_sorted == replay_sorted).all(dim=-1).float() + agreement_rate = matches.mean().item() + from areal.utils import stats_tracker + with stats_tracker.scope("r3"): + stats_tracker.scalar(router_agreement_rate=agreement_rate) + # Use the provided indices for replay top_indices = router_replay.target_topk_idx top_indices = top_indices.to(scores.device) diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py index ca49034297..98787d347c 100644 --- a/areal/engine/router_replay_utils.py +++ b/areal/engine/router_replay_utils.py @@ -251,19 +251,31 @@ def set_router_replay_data( with torch.no_grad(): device = torch.cuda.current_device() - bs, max_seq_len = attention_mask.shape[:2] + bs_re = layers_topk_idx.shape[0] + bs_mask, max_seq_len = attention_mask.shape[:2] + + if bs_re != bs_mask: + logger.warning( + "[R3] set_router_replay_data: batch size mismatch! " + "layers_topk_idx.shape[0]=%d != attention_mask.shape[0]=%d. " + "Clamping iteration to min=%d.", + bs_re, bs_mask, min(bs_re, bs_mask), + ) + bs = min(bs_re, bs_mask) logger.debug( "[R3] set_router_replay_data: input layers_topk_idx=%s, " - "attention_mask=%s, bs=%d, max_seq_len=%d.", + "attention_mask=%s, bs=%d (re_bs=%d, mask_bs=%d), max_seq_len=%d.", layers_topk_idx.shape, attention_mask.shape, bs, + bs_re, + bs_mask, max_seq_len, ) # Step 1: Remove left-padding -> flat (total_real_tokens, num_layers, topk) - seq_lens = attention_mask.sum(dim=1).long() # (bs,) + seq_lens = attention_mask.sum(dim=1).long() # (bs_mask,) pieces = [] for i in range(bs): slen = int(seq_lens[i].item()) diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index b3025467d3..945fad3429 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -98,30 +98,49 @@ def parse_generation_response( stop_message = finish_reason.get("message", "") # Extract routed_experts information if available. - # Requires skip_tokenizer_init=False so that sglang v0.5.9 routes - # through the DetokenizerManager which base64-encodes the tensor. - # When skip_tokenizer_init=True on v0.5.9 the raw tensor is lost - # during JSON serialization (becomes {}). + # SGLang may return routed_experts in two formats: + # 1. Base64-encoded string (skip_tokenizer_init=False, normal path) + # 2. 
Raw list/dict (skip_tokenizer_init=True or newer SGLang versions)
     routed_experts = meta_info.get("routed_experts", None)
     if routed_experts is not None:
-        if not isinstance(routed_experts, str):
-            logger.warning(
-                "[R3] routed_experts is %s instead of base64 str "
-                "(skip_tokenizer_init must be False for sglang<=0.5.9); "
-                "discarding.",
-                type(routed_experts).__name__,
-            )
-            routed_experts = None
+        num_sgl_token = (
+            meta_info["prompt_tokens"] + meta_info["completion_tokens"] - 1
+        )
+        if isinstance(routed_experts, str):
+            try:
+                routed_experts = np.frombuffer(
+                    pybase64.b64decode(routed_experts.encode("utf-8")),
+                    dtype=np.int32,
+                ).reshape(num_sgl_token, -1)
+            except Exception:
+                # exc_info=True already appends the traceback, so the
+                # format string needs no extra placeholder for it.
+                logger.warning(
+                    "[R3] Failed to decode base64 routed_experts "
+                    "(num_sgl_token=%d).",
+                    num_sgl_token,
+                    exc_info=True,
+                )
+                routed_experts = None
         else:
-            num_sgl_token = (
-                meta_info["prompt_tokens"]
-                + meta_info["completion_tokens"]
-                - 1
-            )
-            routed_experts = np.frombuffer(
-                pybase64.b64decode(routed_experts.encode("utf-8")),
-                dtype=np.int32,
-            ).reshape(num_sgl_token, -1)
+            try:
+                routed_experts = np.asarray(
+                    routed_experts, dtype=np.int32
+                ).reshape(num_sgl_token, -1)
+                logger.info(
+                    "[R3] Converted routed_experts from %s to numpy array "
+                    "(shape=%s, num_sgl_token=%d).",
+                    type(meta_info.get("routed_experts")).__name__,
+                    routed_experts.shape,
+                    num_sgl_token,
+                )
+            except Exception:
+                logger.warning(
+                    "[R3] Failed to convert routed_experts from %s "
+                    "(num_sgl_token=%d).",
+                    type(meta_info.get("routed_experts")).__name__,
+                    num_sgl_token,
+                    exc_info=True,
+                )
+                routed_experts = None
 
     if stop_reason == "abort" and stop_message.startswith("Abort before prefill"):
         return HttpGenerationResult(
diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py
index 00eb9c673c..58167f3748 100644
--- a/areal/trainer/ppo/actor.py
+++ b/areal/trainer/ppo/actor.py
@@ -322,6 +322,11 @@ def _ppo_update(self, data: dict[str, Any]) -> None:
         # NOTE: calling engine.train() is critical to enabling gradient checkpointing
         self.engine.train()
         _r3_routed_experts = data.pop("routed_experts", None)
+        if _r3_routed_experts is not None and not isinstance(
+            _r3_routed_experts, torch.Tensor
+        ):
+            from areal.trainer.ppo.actor_r3_patch import _resolve_to_tensor
+            _r3_routed_experts = _resolve_to_tensor(_r3_routed_experts)
 
         mb_inputs = split_padded_tensor_dict_into_mb_list(
             data,
@@ -349,12 +354,20 @@ def _ppo_update(self, data: dict[str, Any]) -> None:
         for i, mb in enumerate(mb_inputs.mbs):
             # deliver routed_experts via engine side-channel
             # to bypass pack_tensor_dict corruption and ensure correct per-mini-batch data.
-            if _r3_split is not None and hasattr(self.engine, '_r3_enabled'):
-                self.engine._r3_pending_routed_experts = (
-                    _r3_split[i]
-                    if i < len(_r3_split) and _r3_split[i] is not None
-                    else None
-                )
+            if _r3_split is not None:
+                if hasattr(self.engine, "_r3_enabled"):
+                    self.engine._r3_pending_routed_experts = (
+                        _r3_split[i]
+                        if i < len(_r3_split) and _r3_split[i] is not None
+                        else None
+                    )
+                else:
+                    logger.warning(
+                        "[R3] routed_experts available but engine._r3_enabled "
+                        "attribute missing (R3 engine patch not applied). 
" + "Check that config.actor.megatron.enable_router_replay " + "is set before engine creation.", + ) train_stat = self.engine.train_batch( mb, diff --git a/areal/trainer/ppo/actor_r3_patch.py b/areal/trainer/ppo/actor_r3_patch.py index 376a485c07..fec34f8b41 100644 --- a/areal/trainer/ppo/actor_r3_patch.py +++ b/areal/trainer/ppo/actor_r3_patch.py @@ -61,6 +61,65 @@ logger = logging.getLogger(__name__) +def _resolve_to_tensor(obj: Any) -> torch.Tensor | None: + """Resolve *obj* to a ``torch.Tensor``, handling RTensor and numpy. + + Returns ``None`` if *obj* is ``None`` or cannot be converted. + """ + if obj is None: + return None + if isinstance(obj, torch.Tensor): + return obj + try: + from areal.infra.rpc.rtensor import RTensor + + if isinstance(obj, RTensor): + return obj.to_local() + except ImportError: + pass + try: + return torch.as_tensor(obj) + except Exception: + logger.warning( + "[R3] Failed to resolve %s to torch.Tensor.", + type(obj).__name__, + exc_info=True, + ) + return None + + +def _ensure_tensor_routed_experts(data: dict[str, Any]) -> torch.Tensor | None: + """Extract ``routed_experts`` from *data*, converting to Tensor if needed. + + Handles the case where SGLang returns routed_experts as a numpy array, + RTensor, or other array-like type instead of a ``torch.Tensor``. + Logs a warning when a conversion is performed so that upstream data + pipelines can be diagnosed. + """ + re = data.get("routed_experts") + if re is None: + return None + if isinstance(re, torch.Tensor): + return re + + re_tensor = _resolve_to_tensor(re) + if re_tensor is not None: + logger.info( + "[R3] routed_experts was %s (shape=%s); resolved to torch.Tensor " + "(shape=%s, dtype=%s).", + type(re).__name__, + getattr(re, "shape", "unknown"), + re_tensor.shape, + re_tensor.dtype, + ) + else: + logger.warning( + "[R3] Failed to resolve routed_experts from %s to torch.Tensor.", + type(re).__name__, + ) + return re_tensor + + def log_r3_data_stats( data: dict[str, Any], scope: str = "r3", @@ -75,7 +134,7 @@ def log_r3_data_stats( data: The training data dict that may contain ``"routed_experts"``. scope: Stats-tracker scope prefix. """ - re = data.get("routed_experts") + re = _ensure_tensor_routed_experts(data) if re is None: return @@ -90,10 +149,7 @@ def log_r3_data_stats( r3_max_expert_id=re.max().item() if re.numel() > 0 else 0, ) - # Compute R3 effectiveness metrics _log_r3_effectiveness_metrics(re) - else: - stats_tracker.scalar(r3_present=0) def split_routed_experts_for_minibatches( @@ -406,14 +462,14 @@ def log_moe_routing_metrics( of shape ``(bs, seq_len, num_moe_layers, topk)``. scope: Stats-tracker scope prefix. """ - re = data.get("routed_experts") + re = _ensure_tensor_routed_experts(data) if re is None: return if not isinstance(re, torch.Tensor) or re.dim() < 4: return bs, seq_len, num_layers, topk = re.shape - attn_mask = data.get("attention_mask") + attn_mask = _resolve_to_tensor(data.get("attention_mask")) with stats_tracker.scope(scope): # ------------------------------------------------------------------ @@ -426,9 +482,16 @@ def log_moe_routing_metrics( # ------------------------------------------------------------------ # 2. 
Expert utilization and load balance (per-layer) # ------------------------------------------------------------------ - if attn_mask is not None: + if attn_mask is not None and attn_mask.shape[1] == seq_len: real_mask = attn_mask.bool() # (bs, seq_len) else: + if attn_mask is not None: + logger.warning( + "[R3] attn_mask seq_len (%d) != routed_experts seq_len (%d); " + "falling back to all-ones mask.", + attn_mask.shape[1], + seq_len, + ) real_mask = torch.ones(bs, seq_len, dtype=torch.bool, device=re.device) # Expand mask for layers and topk: (bs, seq_len, 1, 1) diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index a5259114ea..eecf0ff5d6 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -126,6 +126,14 @@ def __init__( self._amend_xccl_weight_update_envvar() + # R3: Propagate router replay flag to actor engine config so that + # MegatronEngine.initialize() can apply the R3 patch. + # Must be set BEFORE engine creation so that Ray-serialized config + # carries the flag. Uses the declared MegatronEngineConfig field + # (not a dynamic attribute) to survive Ray serialization. + if getattr(config.rollout, "return_routed_experts", False): + config.actor.megatron.enable_router_replay = True + # Create models: actor, critic, ref — each with its own allocation. self.actor = self._create_train_engine(config.actor, self.actor_alloc) self.critic = None @@ -188,10 +196,6 @@ def __init__( engine_init_kwargs = {"addr": None, "ft_spec": ft_spec} - # R3: Propagate router replay flag to actor engine config so that - # MegatronEngine.initialize() can apply the R3 patch. - if getattr(config.rollout, "return_routed_experts", False): - config.actor._r3_enable_router_replay = True self.actor.initialize(**engine_init_kwargs, role="actor") if self.critic is not None: self.critic.initialize(**engine_init_kwargs, role="critic") diff --git a/moonlight_moe_r3.txt b/moonlight_moe_r3.txt new file mode 100644 index 0000000000..90e09281ba --- /dev/null +++ b/moonlight_moe_r3.txt @@ -0,0 +1,1094 @@ +nohup: ignoring input +(AReaL) 20260419-07:57:56.352 CLIArgs WARNING: behave_imp_weight_cap and behave_imp_weight_mode are configured but use_decoupled_loss=False. These settings will be ignored. Set use_decoupled_loss=True to enable decoupled loss with importance weight correction. +(AReaL) 20260419-07:57:56.411 CLIArgs WARNING: behave_imp_weight_cap and behave_imp_weight_mode are configured but use_decoupled_loss=False. These settings will be ignored. Set use_decoupled_loss=True to enable decoupled loss with importance weight correction. +(AReaL) 20260419-07:58:25.394 PlatformInit INFO: Detected CUDA device: NVIDIA L20 +(AReaL) 20260419-07:58:25.394 PlatformInit INFO: Initializing CUDA platform (NVIDIA). +(AReaL) 20260419-07:58:25.395 FileSystemUtils WARNING: cluster.fileroot '/tmp/areal/moon_experiments' is not on a network filesystem. This may cause issues in distributed training where all nodes need access to the same files. Consider using NFS, Lustre, or other shared storage. +(AReaL) 20260419-07:58:25.395 FileSystemUtils WARNING: name_resolve.nfs_record_root '/tmp/areal/moon_name_resolve' is not on a network filesystem. This may cause issues in distributed training where all nodes need access to the same files. Consider using NFS, Lustre, or other shared storage. 
+(AReaL) 20260419-07:58:25.396 NameResolve INFO: Removing name resolve path: /tmp/areal/moon_name_resolve/root/moonlight_moe_exp/moonlight_moe_0419_r3_v2 +(AReaL) 20260419-07:58:25.397 LocalScheduler INFO: LocalScheduler initialized with GPU devices: [0, 1, 2, 3, 4, 5, 6, 7], log directory: /tmp/areal/moon_experiments/logs/root/moonlight_moe_exp/moonlight_moe_0419_r3_v2 +(AReaL) 20260419-07:58:29.210 BailingMoe INFO: Patched apply_rotary_pos_emb: extend truncated freq table for MLA THD+CP>1 +(AReaL) 20260419-07:58:29.240 TreeAttentionFSDP INFO: Compiled torch flex attention. Options: {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}, dynamic: True +(AReaL) 20260419-07:58:29.242 TreeAttentionFSDP INFO: Using block mask in flex attention, block size: 128 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +(AReaL) 20260419-07:58:29.272 TrainController INFO: Creating workers via scheduler... +(AReaL) 20260419-07:58:29.272 LocalScheduler INFO: Creating 8 workers for role 'actor' (strategy: SchedulingStrategyType.separation, colocate_with: None) +(AReaL) 20260419-07:58:29.273 LauncherUtils INFO: Auto-setting thread env vars to 8: OMP_NUM_THREADS, MKL_NUM_THREADS, OPENBLAS_NUM_THREADS, VECLIB_MAXIMUM_THREADS, NUMEXPR_NUM_THREADS +(AReaL) 20260419-07:58:29.273 LocalScheduler INFO: Starting worker actor/0: python3 -m areal.infra.rpc.rpc_server --port 7851 --experiment-name moonlight_moe_exp --trial-name moonlight_moe_0419_r3_v2 --role actor --worker-index 0 --name-resolve-type nfs --nfs-record-root /tmp/areal/moon_name_resolve --etcd3-addr localhost:2379 --fileroot /tmp/areal/moon_experiments +(AReaL) 20260419-07:58:29.374 LocalScheduler INFO: Worker actor/0 started (PID: 2380208, GPUs: [0], ports: [7851, 34213]) +(AReaL) 20260419-07:58:29.374 LauncherUtils INFO: Auto-setting thread env vars to 8: OMP_NUM_THREADS, MKL_NUM_THREADS, OPENBLAS_NUM_THREADS, VECLIB_MAXIMUM_THREADS, NUMEXPR_NUM_THREADS +(AReaL) 20260419-07:58:29.374 LocalScheduler INFO: Starting worker actor/1: python3 -m areal.infra.rpc.rpc_server --port 16368 --experiment-name moonlight_moe_exp --trial-name moonlight_moe_0419_r3_v2 --role actor --worker-index 1 --name-resolve-type nfs --nfs-record-root /tmp/areal/moon_name_resolve --etcd3-addr localhost:2379 --fileroot /tmp/areal/moon_experiments +(AReaL) 20260419-07:58:29.475 LocalScheduler INFO: Worker actor/1 started (PID: 2380215, GPUs: [1], ports: [16368, 56139]) +(AReaL) 20260419-07:58:29.475 LauncherUtils INFO: Auto-setting thread env vars to 8: OMP_NUM_THREADS, MKL_NUM_THREADS, OPENBLAS_NUM_THREADS, VECLIB_MAXIMUM_THREADS, NUMEXPR_NUM_THREADS +(AReaL) 20260419-07:58:29.475 LocalScheduler INFO: Starting worker actor/2: python3 -m areal.infra.rpc.rpc_server --port 20768 --experiment-name moonlight_moe_exp --trial-name moonlight_moe_0419_r3_v2 --role actor --worker-index 2 --name-resolve-type nfs --nfs-record-root /tmp/areal/moon_name_resolve --etcd3-addr localhost:2379 --fileroot /tmp/areal/moon_experiments +(AReaL) 20260419-07:58:29.576 LocalScheduler INFO: Worker actor/2 started (PID: 2380222, GPUs: [2], ports: [20768, 26612]) +(AReaL) 20260419-07:58:29.576 LauncherUtils INFO: Auto-setting thread env vars to 8: OMP_NUM_THREADS, MKL_NUM_THREADS, OPENBLAS_NUM_THREADS, VECLIB_MAXIMUM_THREADS, NUMEXPR_NUM_THREADS +(AReaL) 20260419-07:58:29.576 LocalScheduler INFO: Starting worker actor/3: python3 -m areal.infra.rpc.rpc_server --port 16842 --experiment-name 
moonlight_moe_exp --trial-name moonlight_moe_0419_r3_v2 --role actor --worker-index 3 --name-resolve-type nfs --nfs-record-root /tmp/areal/moon_name_resolve --etcd3-addr localhost:2379 --fileroot /tmp/areal/moon_experiments +(AReaL) 20260419-07:58:29.677 LocalScheduler INFO: Worker actor/3 started (PID: 2380229, GPUs: [3], ports: [16842, 56427]) +(AReaL) 20260419-07:58:29.677 LauncherUtils INFO: Auto-setting thread env vars to 8: OMP_NUM_THREADS, MKL_NUM_THREADS, OPENBLAS_NUM_THREADS, VECLIB_MAXIMUM_THREADS, NUMEXPR_NUM_THREADS +(AReaL) 20260419-07:58:29.677 LocalScheduler INFO: Starting worker actor/4: python3 -m areal.infra.rpc.rpc_server --port 52722 --experiment-name moonlight_moe_exp --trial-name moonlight_moe_0419_r3_v2 --role actor --worker-index 4 --name-resolve-type nfs --nfs-record-root /tmp/areal/moon_name_resolve --etcd3-addr localhost:2379 --fileroot /tmp/areal/moon_experiments +(AReaL) 20260419-07:58:29.778 LocalScheduler INFO: Worker actor/4 started (PID: 2380236, GPUs: [4], ports: [52722, 57060]) +(AReaL) 20260419-07:58:29.778 LauncherUtils INFO: Auto-setting thread env vars to 8: OMP_NUM_THREADS, MKL_NUM_THREADS, OPENBLAS_NUM_THREADS, VECLIB_MAXIMUM_THREADS, NUMEXPR_NUM_THREADS +(AReaL) 20260419-07:58:29.778 LocalScheduler INFO: Starting worker actor/5: python3 -m areal.infra.rpc.rpc_server --port 14757 --experiment-name moonlight_moe_exp --trial-name moonlight_moe_0419_r3_v2 --role actor --worker-index 5 --name-resolve-type nfs --nfs-record-root /tmp/areal/moon_name_resolve --etcd3-addr localhost:2379 --fileroot /tmp/areal/moon_experiments +(AReaL) 20260419-07:58:29.879 LocalScheduler INFO: Worker actor/5 started (PID: 2380243, GPUs: [5], ports: [14757, 20960]) +(AReaL) 20260419-07:58:29.879 LauncherUtils INFO: Auto-setting thread env vars to 8: OMP_NUM_THREADS, MKL_NUM_THREADS, OPENBLAS_NUM_THREADS, VECLIB_MAXIMUM_THREADS, NUMEXPR_NUM_THREADS +(AReaL) 20260419-07:58:29.879 LocalScheduler INFO: Starting worker actor/6: python3 -m areal.infra.rpc.rpc_server --port 4458 --experiment-name moonlight_moe_exp --trial-name moonlight_moe_0419_r3_v2 --role actor --worker-index 6 --name-resolve-type nfs --nfs-record-root /tmp/areal/moon_name_resolve --etcd3-addr localhost:2379 --fileroot /tmp/areal/moon_experiments +(AReaL) 20260419-07:58:29.980 LocalScheduler INFO: Worker actor/6 started (PID: 2380250, GPUs: [6], ports: [4458, 54325]) +(AReaL) 20260419-07:58:29.980 LauncherUtils INFO: Auto-setting thread env vars to 8: OMP_NUM_THREADS, MKL_NUM_THREADS, OPENBLAS_NUM_THREADS, VECLIB_MAXIMUM_THREADS, NUMEXPR_NUM_THREADS +(AReaL) 20260419-07:58:29.980 LocalScheduler INFO: Starting worker actor/7: python3 -m areal.infra.rpc.rpc_server --port 21975 --experiment-name moonlight_moe_exp --trial-name moonlight_moe_0419_r3_v2 --role actor --worker-index 7 --name-resolve-type nfs --nfs-record-root /tmp/areal/moon_name_resolve --etcd3-addr localhost:2379 --fileroot /tmp/areal/moon_experiments +(AReaL) 20260419-07:58:30.081 LocalScheduler INFO: Worker actor/7 started (PID: 2380264, GPUs: [7], ports: [21975, 48890]) +(AReaL) 20260419-07:58:30.081 LocalScheduler INFO: Successfully created 8 workers for role 'actor' +(AReaL) 20260419-07:58:36.108 SyncRPCServer INFO: Werkzeug log level: WARNING +(AReaL) 20260419-07:58:36.109 Guard INFO: Starting Guard on [fdbd:dc05:13::28]:16368 for worker actor/1 +(AReaL) 20260419-07:58:36.136 SyncRPCServer INFO: Werkzeug log level: WARNING +(AReaL) 20260419-07:58:36.137 Guard INFO: Starting Guard on [fdbd:dc05:13::28]:7851 for worker actor/0 +(AReaL) 
20260419-07:58:36.206 CLIArgs WARNING: behave_imp_weight_cap and behave_imp_weight_mode are configured but use_decoupled_loss=False. These settings will be ignored. Set use_decoupled_loss=True to enable decoupled loss with importance weight correction. +(AReaL) 20260419-07:58:36.206 EngineBP INFO: Engine thread started +(AReaL) 20260419-07:58:36.207 EngineBP INFO: Engine thread initialized +(AReaL) 20260419-07:58:36.208 SyncRPCServer INFO: Werkzeug log level: WARNING +(AReaL) 20260419-07:58:36.209 Guard INFO: Starting Guard on [fdbd:dc05:13::28]:20768 for worker actor/2 +(AReaL) 20260419-07:58:36.211 PlatformInit INFO: Detected CUDA device: NVIDIA L20 +(AReaL) 20260419-07:58:36.211 PlatformInit INFO: Initializing CUDA platform (NVIDIA). +(AReaL) 20260419-07:58:36.223 LocalScheduler INFO: Configuration successfully on worker 'actor/0' +(AReaL) 20260419-07:58:36.241 SyncRPCServer INFO: Werkzeug log level: WARNING +(AReaL) 20260419-07:58:36.242 Guard INFO: Starting Guard on [fdbd:dc05:13::28]:52722 for worker actor/4 +(AReaL) 20260419-07:58:36.266 CLIArgs WARNING: behave_imp_weight_cap and behave_imp_weight_mode are configured but use_decoupled_loss=False. These settings will be ignored. Set use_decoupled_loss=True to enable decoupled loss with importance weight correction. +(AReaL) 20260419-07:58:36.267 EngineBP INFO: Engine thread started +(AReaL) 20260419-07:58:36.267 EngineBP INFO: Engine thread initialized +(AReaL) 20260419-07:58:36.271 PlatformInit INFO: Detected CUDA device: NVIDIA L20 +(AReaL) 20260419-07:58:36.271 PlatformInit INFO: Initializing CUDA platform (NVIDIA). +(AReaL) 20260419-07:58:36.283 LocalScheduler INFO: Configuration successfully on worker 'actor/1' +(AReaL) 20260419-07:58:36.328 CLIArgs WARNING: behave_imp_weight_cap and behave_imp_weight_mode are configured but use_decoupled_loss=False. These settings will be ignored. Set use_decoupled_loss=True to enable decoupled loss with importance weight correction. +(AReaL) 20260419-07:58:36.328 EngineBP INFO: Engine thread started +(AReaL) 20260419-07:58:36.328 EngineBP INFO: Engine thread initialized +(AReaL) 20260419-07:58:36.332 PlatformInit INFO: Detected CUDA device: NVIDIA L20 +(AReaL) 20260419-07:58:36.332 PlatformInit INFO: Initializing CUDA platform (NVIDIA). +(AReaL) 20260419-07:58:36.338 LocalScheduler INFO: Configuration successfully on worker 'actor/2' +(AReaL) 20260419-07:58:36.350 SyncRPCServer INFO: Werkzeug log level: WARNING +(AReaL) 20260419-07:58:36.351 Guard INFO: Starting Guard on [fdbd:dc05:13::28]:16842 for worker actor/3 +(AReaL) 20260419-07:58:36.495 CLIArgs WARNING: behave_imp_weight_cap and behave_imp_weight_mode are configured but use_decoupled_loss=False. These settings will be ignored. Set use_decoupled_loss=True to enable decoupled loss with importance weight correction. +(AReaL) 20260419-07:58:36.495 EngineBP INFO: Engine thread started +(AReaL) 20260419-07:58:36.495 EngineBP INFO: Engine thread initialized +(AReaL) 20260419-07:58:36.499 PlatformInit INFO: Detected CUDA device: NVIDIA L20 +(AReaL) 20260419-07:58:36.499 PlatformInit INFO: Initializing CUDA platform (NVIDIA). +(AReaL) 20260419-07:58:36.505 LocalScheduler INFO: Configuration successfully on worker 'actor/3' +(AReaL) 20260419-07:58:36.563 CLIArgs WARNING: behave_imp_weight_cap and behave_imp_weight_mode are configured but use_decoupled_loss=False. These settings will be ignored. Set use_decoupled_loss=True to enable decoupled loss with importance weight correction. 
+(AReaL) 20260419-07:58:36.563 EngineBP INFO: Engine thread started +(AReaL) 20260419-07:58:36.563 EngineBP INFO: Engine thread initialized +(AReaL) 20260419-07:58:36.567 PlatformInit INFO: Detected CUDA device: NVIDIA L20 +(AReaL) 20260419-07:58:36.567 PlatformInit INFO: Initializing CUDA platform (NVIDIA). +(AReaL) 20260419-07:58:36.579 LocalScheduler INFO: Configuration successfully on worker 'actor/4' +(AReaL) 20260419-07:58:37.122 SyncRPCServer INFO: Werkzeug log level: WARNING +(AReaL) 20260419-07:58:37.123 Guard INFO: Starting Guard on [fdbd:dc05:13::28]:21975 for worker actor/7 +(AReaL) 20260419-07:58:37.228 SyncRPCServer INFO: Werkzeug log level: WARNING +(AReaL) 20260419-07:58:37.229 Guard INFO: Starting Guard on [fdbd:dc05:13::28]:4458 for worker actor/6 +(AReaL) 20260419-07:58:37.328 SyncRPCServer INFO: Werkzeug log level: WARNING +(AReaL) 20260419-07:58:37.329 Guard INFO: Starting Guard on [fdbd:dc05:13::28]:14757 for worker actor/5 +(AReaL) 20260419-07:58:37.395 CLIArgs WARNING: behave_imp_weight_cap and behave_imp_weight_mode are configured but use_decoupled_loss=False. These settings will be ignored. Set use_decoupled_loss=True to enable decoupled loss with importance weight correction. +(AReaL) 20260419-07:58:37.395 EngineBP INFO: Engine thread started +(AReaL) 20260419-07:58:37.395 EngineBP INFO: Engine thread initialized +(AReaL) 20260419-07:58:37.399 PlatformInit INFO: Detected CUDA device: NVIDIA L20 +(AReaL) 20260419-07:58:37.399 PlatformInit INFO: Initializing CUDA platform (NVIDIA). +(AReaL) 20260419-07:58:37.405 LocalScheduler INFO: Configuration successfully on worker 'actor/5' +(AReaL) 20260419-07:58:37.446 CLIArgs WARNING: behave_imp_weight_cap and behave_imp_weight_mode are configured but use_decoupled_loss=False. These settings will be ignored. Set use_decoupled_loss=True to enable decoupled loss with importance weight correction. +(AReaL) 20260419-07:58:37.446 EngineBP INFO: Engine thread started +(AReaL) 20260419-07:58:37.446 EngineBP INFO: Engine thread initialized +(AReaL) 20260419-07:58:37.450 PlatformInit INFO: Detected CUDA device: NVIDIA L20 +(AReaL) 20260419-07:58:37.450 PlatformInit INFO: Initializing CUDA platform (NVIDIA). +(AReaL) 20260419-07:58:37.462 LocalScheduler INFO: Configuration successfully on worker 'actor/6' +(AReaL) 20260419-07:58:37.519 CLIArgs WARNING: behave_imp_weight_cap and behave_imp_weight_mode are configured but use_decoupled_loss=False. These settings will be ignored. Set use_decoupled_loss=True to enable decoupled loss with importance weight correction. +(AReaL) 20260419-07:58:37.519 EngineBP INFO: Engine thread started +(AReaL) 20260419-07:58:37.519 EngineBP INFO: Engine thread initialized +(AReaL) 20260419-07:58:37.523 PlatformInit INFO: Detected CUDA device: NVIDIA L20 +(AReaL) 20260419-07:58:37.523 PlatformInit INFO: Initializing CUDA platform (NVIDIA). +(AReaL) 20260419-07:58:37.535 LocalScheduler INFO: Configuration successfully on worker 'actor/7' +(AReaL) 20260419-07:58:37.536 TrainController INFO: Workers created: ['actor/0', 'actor/1', 'actor/2', 'actor/3', 'actor/4', 'actor/5', 'actor/6', 'actor/7'] +(AReaL) 20260419-07:58:37.536 TrainController INFO: Waiting for workers to be ready... 
+(AReaL) 20260419-07:58:37.733 LocalScheduler INFO: All 8 workers for role 'actor' are ready +(AReaL) 20260419-07:58:37.733 TrainController INFO: Workers ready: ['actor/0', 'actor/1', 'actor/2', 'actor/3', 'actor/4', 'actor/5', 'actor/6', 'actor/7'] +(AReaL) 20260419-07:58:37.733 TrainController INFO: Distributed training: MASTER_ADDR=fdbd:dc05:13::28, MASTER_PORT=34213 +(AReaL) 20260419-07:58:37.734 TrainController INFO: Creating engines on workers... +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set RANK=7 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set WORLD_SIZE=8 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set RANK=5 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set RANK=2 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set MASTER_ADDR=fdbd:dc05:13::28 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set RANK=3 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set RANK=6 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set WORLD_SIZE=8 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set WORLD_SIZE=8 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set MASTER_PORT=34213 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set WORLD_SIZE=8 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set RANK=0 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set WORLD_SIZE=8 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set MASTER_ADDR=fdbd:dc05:13::28 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set RANK=4 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set LOCAL_RANK=0 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set MASTER_ADDR=fdbd:dc05:13::28 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set RANK=1 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set MASTER_ADDR=fdbd:dc05:13::28 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set MASTER_ADDR=fdbd:dc05:13::28 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set MASTER_PORT=34213 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set WORLD_SIZE=8 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set WORLD_SIZE=8 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set MASTER_PORT=34213 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set MASTER_PORT=34213 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set WORLD_SIZE=8 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set MASTER_PORT=34213 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set LOCAL_RANK=0 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set MASTER_ADDR=fdbd:dc05:13::28 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set MASTER_ADDR=fdbd:dc05:13::28 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set LOCAL_RANK=0 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set LOCAL_RANK=0 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set MASTER_ADDR=fdbd:dc05:13::28 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set LOCAL_RANK=0 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set MASTER_PORT=34213 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set MASTER_PORT=34213 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set MASTER_PORT=34213 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set LOCAL_RANK=0 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set LOCAL_RANK=0 +(AReaL) 20260419-07:58:37.737 EngineBP INFO: Set LOCAL_RANK=0 +(AReaL) 20260419-07:58:41.233 BailingMoe INFO: Patched apply_rotary_pos_emb: extend truncated freq table for MLA THD+CP>1 +(AReaL) 20260419-07:58:41.262 TreeAttentionFSDP INFO: Compiled torch flex attention. 
Options: {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}, dynamic: True +(AReaL) 20260419-07:58:41.264 TreeAttentionFSDP INFO: Using block mask in flex attention, block size: 128 +(AReaL) 20260419-07:58:41.275 BailingMoe INFO: Patched apply_rotary_pos_emb: extend truncated freq table for MLA THD+CP>1 +(AReaL) 20260419-07:58:41.275 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.275 PPOActor INFO: PPOActor Configuration +(AReaL) 20260419-07:58:41.276 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.276 PPOActor INFO: Mode: Decoupled PPO (off-policy) +(AReaL) 20260419-07:58:41.276 PPOActor INFO: log_p_behave (π_behave): FROM INFERENCE (behavior policy) +(AReaL) 20260419-07:58:41.276 PPOActor INFO: Proximal policy (π_prox): RECOMPUTED via forward pass (standard decoupled PPO) +(AReaL) 20260419-07:58:41.276 PPOActor INFO: log_p_theta (π_θ): TRAINING FORWARD PASS (current policy) +(AReaL) 20260419-07:58:41.276 PPOActor INFO: Importance weight cap: 5.0 (filters out tokens with extreme weights) +(AReaL) 20260419-07:58:41.276 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.276 PPOActor INFO: Training Parameters: +(AReaL) 20260419-07:58:41.276 PPOActor INFO: importance_sampling_level: token +(AReaL) 20260419-07:58:41.276 PPOActor INFO: adv_norm: NormConfig(mean_level='batch', mean_leave1out=False, std_level='batch', std_unbiased=True, eps=1e-05, group_size=1) +(AReaL) 20260419-07:58:41.276 PPOActor INFO: reward_norm: NormConfig(mean_level='group', mean_leave1out=False, std_level='group', std_unbiased=True, eps=1e-05, group_size=4) +(AReaL) 20260419-07:58:41.276 PPOActor INFO: eps_clip: 0.4 +(AReaL) 20260419-07:58:41.276 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.276 EngineBP INFO: Engine 'actor/1' (class: areal.engine.megatron_engine.MegatronPPOActor) instantiated successfully +(AReaL) 20260419-07:58:41.288 BailingMoe INFO: Patched apply_rotary_pos_emb: extend truncated freq table for MLA THD+CP>1 +(AReaL) 20260419-07:58:41.293 BailingMoe INFO: Patched apply_rotary_pos_emb: extend truncated freq table for MLA THD+CP>1 +(AReaL) 20260419-07:58:41.297 BailingMoe INFO: Patched apply_rotary_pos_emb: extend truncated freq table for MLA THD+CP>1 +(AReaL) 20260419-07:58:41.303 BailingMoe INFO: Patched apply_rotary_pos_emb: extend truncated freq table for MLA THD+CP>1 +(AReaL) 20260419-07:58:41.305 TreeAttentionFSDP INFO: Compiled torch flex attention. Options: {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}, dynamic: True +(AReaL) 20260419-07:58:41.307 TreeAttentionFSDP INFO: Using block mask in flex attention, block size: 128 +(AReaL) 20260419-07:58:41.318 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.318 PPOActor INFO: PPOActor Configuration +(AReaL) 20260419-07:58:41.318 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.318 PPOActor INFO: Mode: Decoupled PPO (off-policy) +(AReaL) 20260419-07:58:41.318 TreeAttentionFSDP INFO: Compiled torch flex attention. 
Options: {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}, dynamic: True +(AReaL) 20260419-07:58:41.318 PPOActor INFO: log_p_behave (π_behave): FROM INFERENCE (behavior policy) +(AReaL) 20260419-07:58:41.318 PPOActor INFO: Proximal policy (π_prox): RECOMPUTED via forward pass (standard decoupled PPO) +(AReaL) 20260419-07:58:41.318 PPOActor INFO: log_p_theta (π_θ): TRAINING FORWARD PASS (current policy) +(AReaL) 20260419-07:58:41.318 PPOActor INFO: Importance weight cap: 5.0 (filters out tokens with extreme weights) +(AReaL) 20260419-07:58:41.319 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.319 PPOActor INFO: Training Parameters: +(AReaL) 20260419-07:58:41.319 PPOActor INFO: importance_sampling_level: token +(AReaL) 20260419-07:58:41.319 PPOActor INFO: adv_norm: NormConfig(mean_level='batch', mean_leave1out=False, std_level='batch', std_unbiased=True, eps=1e-05, group_size=1) +(AReaL) 20260419-07:58:41.319 PPOActor INFO: reward_norm: NormConfig(mean_level='group', mean_leave1out=False, std_level='group', std_unbiased=True, eps=1e-05, group_size=4) +(AReaL) 20260419-07:58:41.319 PPOActor INFO: eps_clip: 0.4 +(AReaL) 20260419-07:58:41.319 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.319 EngineBP INFO: Engine 'actor/0' (class: areal.engine.megatron_engine.MegatronPPOActor) instantiated successfully +(AReaL) 20260419-07:58:41.320 TreeAttentionFSDP INFO: Using block mask in flex attention, block size: 128 +(AReaL) 20260419-07:58:41.323 TreeAttentionFSDP INFO: Compiled torch flex attention. Options: {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}, dynamic: True +(AReaL) 20260419-07:58:41.325 TreeAttentionFSDP INFO: Using block mask in flex attention, block size: 128 +(AReaL) 20260419-07:58:41.328 TreeAttentionFSDP INFO: Compiled torch flex attention. 
Options: {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}, dynamic: True +(AReaL) 20260419-07:58:41.330 TreeAttentionFSDP INFO: Using block mask in flex attention, block size: 128 +(AReaL) 20260419-07:58:41.332 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.332 PPOActor INFO: PPOActor Configuration +(AReaL) 20260419-07:58:41.332 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.333 PPOActor INFO: Mode: Decoupled PPO (off-policy) +(AReaL) 20260419-07:58:41.333 PPOActor INFO: log_p_behave (π_behave): FROM INFERENCE (behavior policy) +(AReaL) 20260419-07:58:41.333 PPOActor INFO: Proximal policy (π_prox): RECOMPUTED via forward pass (standard decoupled PPO) +(AReaL) 20260419-07:58:41.333 PPOActor INFO: log_p_theta (π_θ): TRAINING FORWARD PASS (current policy) +(AReaL) 20260419-07:58:41.333 PPOActor INFO: Importance weight cap: 5.0 (filters out tokens with extreme weights) +(AReaL) 20260419-07:58:41.333 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.333 PPOActor INFO: Training Parameters: +(AReaL) 20260419-07:58:41.333 PPOActor INFO: importance_sampling_level: token +(AReaL) 20260419-07:58:41.333 PPOActor INFO: adv_norm: NormConfig(mean_level='batch', mean_leave1out=False, std_level='batch', std_unbiased=True, eps=1e-05, group_size=1) +(AReaL) 20260419-07:58:41.333 PPOActor INFO: reward_norm: NormConfig(mean_level='group', mean_leave1out=False, std_level='group', std_unbiased=True, eps=1e-05, group_size=4) +(AReaL) 20260419-07:58:41.333 PPOActor INFO: eps_clip: 0.4 +(AReaL) 20260419-07:58:41.333 TreeAttentionFSDP INFO: Compiled torch flex attention. 
Options: {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}, dynamic: True +(AReaL) 20260419-07:58:41.333 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.333 EngineBP INFO: Engine 'actor/5' (class: areal.engine.megatron_engine.MegatronPPOActor) instantiated successfully +(AReaL) 20260419-07:58:41.335 TreeAttentionFSDP INFO: Using block mask in flex attention, block size: 128 +(AReaL) 20260419-07:58:41.336 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.336 PPOActor INFO: PPOActor Configuration +(AReaL) 20260419-07:58:41.337 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.337 PPOActor INFO: Mode: Decoupled PPO (off-policy) +(AReaL) 20260419-07:58:41.337 PPOActor INFO: log_p_behave (π_behave): FROM INFERENCE (behavior policy) +(AReaL) 20260419-07:58:41.337 PPOActor INFO: Proximal policy (π_prox): RECOMPUTED via forward pass (standard decoupled PPO) +(AReaL) 20260419-07:58:41.337 PPOActor INFO: log_p_theta (π_θ): TRAINING FORWARD PASS (current policy) +(AReaL) 20260419-07:58:41.337 PPOActor INFO: Importance weight cap: 5.0 (filters out tokens with extreme weights) +(AReaL) 20260419-07:58:41.337 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.337 PPOActor INFO: Training Parameters: +(AReaL) 20260419-07:58:41.337 PPOActor INFO: importance_sampling_level: token +(AReaL) 20260419-07:58:41.337 PPOActor INFO: adv_norm: NormConfig(mean_level='batch', mean_leave1out=False, std_level='batch', std_unbiased=True, eps=1e-05, group_size=1) +(AReaL) 20260419-07:58:41.337 PPOActor INFO: reward_norm: NormConfig(mean_level='group', mean_leave1out=False, std_level='group', std_unbiased=True, eps=1e-05, group_size=4) +(AReaL) 20260419-07:58:41.337 PPOActor INFO: eps_clip: 0.4 +(AReaL) 20260419-07:58:41.337 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.337 EngineBP INFO: Engine 'actor/2' (class: areal.engine.megatron_engine.MegatronPPOActor) instantiated successfully +(AReaL) 20260419-07:58:41.341 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.341 PPOActor INFO: PPOActor Configuration +(AReaL) 20260419-07:58:41.341 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.341 PPOActor INFO: Mode: Decoupled PPO (off-policy) +(AReaL) 20260419-07:58:41.341 PPOActor INFO: log_p_behave (π_behave): FROM INFERENCE (behavior policy) +(AReaL) 20260419-07:58:41.341 PPOActor INFO: Proximal policy (π_prox): RECOMPUTED via forward pass (standard decoupled PPO) +(AReaL) 20260419-07:58:41.341 PPOActor INFO: log_p_theta (π_θ): TRAINING FORWARD PASS (current policy) +(AReaL) 20260419-07:58:41.341 PPOActor INFO: Importance weight cap: 5.0 (filters out tokens with extreme weights) +(AReaL) 20260419-07:58:41.342 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.342 PPOActor INFO: Training Parameters: +(AReaL) 20260419-07:58:41.342 PPOActor INFO: importance_sampling_level: token +(AReaL) 20260419-07:58:41.342 PPOActor INFO: adv_norm: NormConfig(mean_level='batch', mean_leave1out=False, std_level='batch', std_unbiased=True, eps=1e-05, group_size=1) +(AReaL) 
20260419-07:58:41.342 PPOActor INFO: reward_norm: NormConfig(mean_level='group', mean_leave1out=False, std_level='group', std_unbiased=True, eps=1e-05, group_size=4) +(AReaL) 20260419-07:58:41.342 PPOActor INFO: eps_clip: 0.4 +(AReaL) 20260419-07:58:41.342 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.342 EngineBP INFO: Engine 'actor/3' (class: areal.engine.megatron_engine.MegatronPPOActor) instantiated successfully +(AReaL) 20260419-07:58:41.347 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.347 PPOActor INFO: PPOActor Configuration +(AReaL) 20260419-07:58:41.347 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.347 PPOActor INFO: Mode: Decoupled PPO (off-policy) +(AReaL) 20260419-07:58:41.347 PPOActor INFO: log_p_behave (π_behave): FROM INFERENCE (behavior policy) +(AReaL) 20260419-07:58:41.348 PPOActor INFO: Proximal policy (π_prox): RECOMPUTED via forward pass (standard decoupled PPO) +(AReaL) 20260419-07:58:41.348 PPOActor INFO: log_p_theta (π_θ): TRAINING FORWARD PASS (current policy) +(AReaL) 20260419-07:58:41.348 PPOActor INFO: Importance weight cap: 5.0 (filters out tokens with extreme weights) +(AReaL) 20260419-07:58:41.348 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.348 PPOActor INFO: Training Parameters: +(AReaL) 20260419-07:58:41.348 PPOActor INFO: importance_sampling_level: token +(AReaL) 20260419-07:58:41.348 PPOActor INFO: adv_norm: NormConfig(mean_level='batch', mean_leave1out=False, std_level='batch', std_unbiased=True, eps=1e-05, group_size=1) +(AReaL) 20260419-07:58:41.348 PPOActor INFO: reward_norm: NormConfig(mean_level='group', mean_leave1out=False, std_level='group', std_unbiased=True, eps=1e-05, group_size=4) +(AReaL) 20260419-07:58:41.348 PPOActor INFO: eps_clip: 0.4 +(AReaL) 20260419-07:58:41.348 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.348 EngineBP INFO: Engine 'actor/7' (class: areal.engine.megatron_engine.MegatronPPOActor) instantiated successfully +(AReaL) 20260419-07:58:41.428 BailingMoe INFO: Patched apply_rotary_pos_emb: extend truncated freq table for MLA THD+CP>1 +(AReaL) 20260419-07:58:41.458 TreeAttentionFSDP INFO: Compiled torch flex attention. 
Options: {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}, dynamic: True +(AReaL) 20260419-07:58:41.460 TreeAttentionFSDP INFO: Using block mask in flex attention, block size: 128 +(AReaL) 20260419-07:58:41.472 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.472 PPOActor INFO: PPOActor Configuration +(AReaL) 20260419-07:58:41.472 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.472 PPOActor INFO: Mode: Decoupled PPO (off-policy) +(AReaL) 20260419-07:58:41.472 PPOActor INFO: log_p_behave (π_behave): FROM INFERENCE (behavior policy) +(AReaL) 20260419-07:58:41.472 PPOActor INFO: Proximal policy (π_prox): RECOMPUTED via forward pass (standard decoupled PPO) +(AReaL) 20260419-07:58:41.473 PPOActor INFO: log_p_theta (π_θ): TRAINING FORWARD PASS (current policy) +(AReaL) 20260419-07:58:41.473 PPOActor INFO: Importance weight cap: 5.0 (filters out tokens with extreme weights) +(AReaL) 20260419-07:58:41.473 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.473 PPOActor INFO: Training Parameters: +(AReaL) 20260419-07:58:41.473 PPOActor INFO: importance_sampling_level: token +(AReaL) 20260419-07:58:41.473 PPOActor INFO: adv_norm: NormConfig(mean_level='batch', mean_leave1out=False, std_level='batch', std_unbiased=True, eps=1e-05, group_size=1) +(AReaL) 20260419-07:58:41.473 PPOActor INFO: reward_norm: NormConfig(mean_level='group', mean_leave1out=False, std_level='group', std_unbiased=True, eps=1e-05, group_size=4) +(AReaL) 20260419-07:58:41.473 PPOActor INFO: eps_clip: 0.4 +(AReaL) 20260419-07:58:41.473 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:41.473 EngineBP INFO: Engine 'actor/6' (class: areal.engine.megatron_engine.MegatronPPOActor) instantiated successfully +(AReaL) 20260419-07:58:42.335 BailingMoe INFO: Patched apply_rotary_pos_emb: extend truncated freq table for MLA THD+CP>1 +(AReaL) 20260419-07:58:42.366 TreeAttentionFSDP INFO: Compiled torch flex attention. 
Options: {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}, dynamic: True +(AReaL) 20260419-07:58:42.368 TreeAttentionFSDP INFO: Using block mask in flex attention, block size: 128 +(AReaL) 20260419-07:58:42.380 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:42.380 PPOActor INFO: PPOActor Configuration +(AReaL) 20260419-07:58:42.380 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:42.380 PPOActor INFO: Mode: Decoupled PPO (off-policy) +(AReaL) 20260419-07:58:42.380 PPOActor INFO: log_p_behave (π_behave): FROM INFERENCE (behavior policy) +(AReaL) 20260419-07:58:42.380 PPOActor INFO: Proximal policy (π_prox): RECOMPUTED via forward pass (standard decoupled PPO) +(AReaL) 20260419-07:58:42.380 PPOActor INFO: log_p_theta (π_θ): TRAINING FORWARD PASS (current policy) +(AReaL) 20260419-07:58:42.380 PPOActor INFO: Importance weight cap: 5.0 (filters out tokens with extreme weights) +(AReaL) 20260419-07:58:42.380 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:42.380 PPOActor INFO: Training Parameters: +(AReaL) 20260419-07:58:42.380 PPOActor INFO: importance_sampling_level: token +(AReaL) 20260419-07:58:42.380 PPOActor INFO: adv_norm: NormConfig(mean_level='batch', mean_leave1out=False, std_level='batch', std_unbiased=True, eps=1e-05, group_size=1) +(AReaL) 20260419-07:58:42.380 PPOActor INFO: reward_norm: NormConfig(mean_level='group', mean_leave1out=False, std_level='group', std_unbiased=True, eps=1e-05, group_size=4) +(AReaL) 20260419-07:58:42.380 PPOActor INFO: eps_clip: 0.4 +(AReaL) 20260419-07:58:42.380 PPOActor INFO: ====================================================================== +(AReaL) 20260419-07:58:42.381 EngineBP INFO: Engine 'actor/4' (class: areal.engine.megatron_engine.MegatronPPOActor) instantiated successfully +(AReaL) 20260419-07:58:42.382 TrainController INFO: Engines created on all workers! +(AReaL) 20260419-07:58:42.382 TrainController INFO: Calling engine initialization... +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. 
Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 1 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 4 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 2 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 3 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 6 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 5 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 7 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +(AReaL) 20260419-07:58:43.200 [MegatronEngine Rank 3] INFO: Detected 'adam_bf16' optimizer with Megatron Engine. Automatically converting to 'adam' with precision-aware optimizer and setting exp_avg_dtype/exp_avg_sq_dtype to 'bfloat16'. +(AReaL) 20260419-07:58:43.200 [MegatronEngine Rank 4] INFO: Detected 'adam_bf16' optimizer with Megatron Engine. Automatically converting to 'adam' with precision-aware optimizer and setting exp_avg_dtype/exp_avg_sq_dtype to 'bfloat16'. +(AReaL) 20260419-07:58:43.200 [MegatronEngine Rank 6] INFO: Detected 'adam_bf16' optimizer with Megatron Engine. Automatically converting to 'adam' with precision-aware optimizer and setting exp_avg_dtype/exp_avg_sq_dtype to 'bfloat16'. +(AReaL) 20260419-07:58:43.200 [MegatronEngine Rank 5] INFO: Detected 'adam_bf16' optimizer with Megatron Engine. Automatically converting to 'adam' with precision-aware optimizer and setting exp_avg_dtype/exp_avg_sq_dtype to 'bfloat16'. +(AReaL) 20260419-07:58:43.200 [MegatronEngine Rank 2] INFO: Detected 'adam_bf16' optimizer with Megatron Engine. Automatically converting to 'adam' with precision-aware optimizer and setting exp_avg_dtype/exp_avg_sq_dtype to 'bfloat16'. +(AReaL) 20260419-07:58:43.200 [MegatronEngine Rank 1] INFO: Detected 'adam_bf16' optimizer with Megatron Engine. Automatically converting to 'adam' with precision-aware optimizer and setting exp_avg_dtype/exp_avg_sq_dtype to 'bfloat16'. +(AReaL) 20260419-07:58:43.200 [MegatronEngine Rank 7] INFO: Detected 'adam_bf16' optimizer with Megatron Engine. Automatically converting to 'adam' with precision-aware optimizer and setting exp_avg_dtype/exp_avg_sq_dtype to 'bfloat16'. +(AReaL) 20260419-07:58:43.200 [MegatronEngine Rank 0] INFO: Detected 'adam_bf16' optimizer with Megatron Engine. Automatically converting to 'adam' with precision-aware optimizer and setting exp_avg_dtype/exp_avg_sq_dtype to 'bfloat16'. 
+(AReaL) 20260419-07:58:43.674 CUDAPlatform INFO: Set NUMA affinity for GPU 0: bound to 24 CPU cores. +(AReaL) 20260419-07:58:43.887 CUDAPlatform INFO: Set NUMA affinity for GPU 0: bound to 24 CPU cores. +(AReaL) 20260419-07:58:43.905 CUDAPlatform INFO: Set NUMA affinity for GPU 0: bound to 24 CPU cores. +(AReaL) 20260419-07:58:43.918 CUDAPlatform INFO: Set NUMA affinity for GPU 0: bound to 24 CPU cores. +(AReaL) 20260419-07:58:43.928 CUDAPlatform INFO: Set NUMA affinity for GPU 0: bound to 24 CPU cores. +(AReaL) 20260419-07:58:43.940 CUDAPlatform INFO: Set NUMA affinity for GPU 0: bound to 24 CPU cores. +(AReaL) 20260419-07:58:43.940 CUDAPlatform INFO: Set NUMA affinity for GPU 0: bound to 24 CPU cores. +(AReaL) 20260419-07:58:43.953 CUDAPlatform INFO: Set NUMA affinity for GPU 0: bound to 24 CPU cores. +/AReaL/.venv/lib/python3.12/site-packages/megatron/core/transformer/transformer_config.py:1705: UserWarning: full scope is deprecated. Use empty cuda_graph_scope to capture the whole layer. + warnings.warn( +(AReaL) 20260419-07:58:44.154 [MegatronEngine Rank 7] INFO: Using mbridge to create models and hf model save/load in MegatronEngine. +(AReaL) 20260419-07:58:44.154 MCoreParallel INFO: Configured pipeline layout (per-stage decoder counts / params): [14, 13] / ['2665.67M', '2579.68M'] (pp=2, vpp=1) +/AReaL/.venv/lib/python3.12/site-packages/megatron/core/transformer/transformer_config.py:1705: UserWarning: full scope is deprecated. Use empty cuda_graph_scope to capture the whole layer. + warnings.warn( +(AReaL) 20260419-07:58:44.346 [MegatronEngine Rank 2] INFO: Using mbridge to create models and hf model save/load in MegatronEngine. +(AReaL) 20260419-07:58:44.346 MCoreParallel INFO: Configured pipeline layout (per-stage decoder counts / params): [14, 13] / ['2665.67M', '2579.68M'] (pp=2, vpp=1) +/AReaL/.venv/lib/python3.12/site-packages/megatron/core/transformer/transformer_config.py:1705: UserWarning: full scope is deprecated. Use empty cuda_graph_scope to capture the whole layer. + warnings.warn( +(AReaL) 20260419-07:58:44.365 [MegatronEngine Rank 6] INFO: Using mbridge to create models and hf model save/load in MegatronEngine. +(AReaL) 20260419-07:58:44.365 MCoreParallel INFO: Configured pipeline layout (per-stage decoder counts / params): [14, 13] / ['2665.67M', '2579.68M'] (pp=2, vpp=1) +/AReaL/.venv/lib/python3.12/site-packages/megatron/core/transformer/transformer_config.py:1705: UserWarning: full scope is deprecated. Use empty cuda_graph_scope to capture the whole layer. + warnings.warn( +(AReaL) 20260419-07:58:44.391 [MegatronEngine Rank 1] INFO: Using mbridge to create models and hf model save/load in MegatronEngine. +(AReaL) 20260419-07:58:44.392 MCoreParallel INFO: Configured pipeline layout (per-stage decoder counts / params): [14, 13] / ['2665.67M', '2579.68M'] (pp=2, vpp=1) +/AReaL/.venv/lib/python3.12/site-packages/megatron/core/transformer/transformer_config.py:1705: UserWarning: full scope is deprecated. Use empty cuda_graph_scope to capture the whole layer. + warnings.warn( +(AReaL) 20260419-07:58:44.408 [MegatronEngine Rank 4] INFO: Using mbridge to create models and hf model save/load in MegatronEngine. +(AReaL) 20260419-07:58:44.408 MCoreParallel INFO: Configured pipeline layout (per-stage decoder counts / params): [14, 13] / ['2665.67M', '2579.68M'] (pp=2, vpp=1) +/AReaL/.venv/lib/python3.12/site-packages/megatron/core/transformer/transformer_config.py:1705: UserWarning: full scope is deprecated. Use empty cuda_graph_scope to capture the whole layer. 
+ warnings.warn( +(AReaL) 20260419-07:58:44.411 [MegatronEngine Rank 5] INFO: Using mbridge to create models and hf model save/load in MegatronEngine. +(AReaL) 20260419-07:58:44.411 MCoreParallel INFO: Configured pipeline layout (per-stage decoder counts / params): [14, 13] / ['2665.67M', '2579.68M'] (pp=2, vpp=1) +/AReaL/.venv/lib/python3.12/site-packages/megatron/core/transformer/transformer_config.py:1705: UserWarning: full scope is deprecated. Use empty cuda_graph_scope to capture the whole layer. + warnings.warn( +(AReaL) 20260419-07:58:44.417 [MegatronEngine Rank 3] INFO: Using mbridge to create models and hf model save/load in MegatronEngine. +(AReaL) 20260419-07:58:44.417 MCoreParallel INFO: Configured pipeline layout (per-stage decoder counts / params): [14, 13] / ['2665.67M', '2579.68M'] (pp=2, vpp=1) +/AReaL/.venv/lib/python3.12/site-packages/megatron/core/transformer/transformer_config.py:1705: UserWarning: full scope is deprecated. Use empty cuda_graph_scope to capture the whole layer. + warnings.warn( +(AReaL) 20260419-07:58:44.429 [MegatronEngine Rank 0] INFO: Using mbridge to create models and hf model save/load in MegatronEngine. +(AReaL) 20260419-07:58:44.430 MCoreParallel INFO: Configured pipeline layout (per-stage decoder counts / params): [14, 13] / ['2665.67M', '2579.68M'] (pp=2, vpp=1) + > number of parameters on (tensor, pipeline) model parallel rank (3, 1): 1997468160 + > number of parameters on (tensor, pipeline) model parallel rank (0, 1): 1997468160 + > number of parameters on (tensor, pipeline) model parallel rank (1, 1): 1997468160 + > number of parameters on (tensor, pipeline) model parallel rank (2, 0): 2019097600 + > number of parameters on (tensor, pipeline) model parallel rank (3, 0): 2019097600 + > number of parameters on (tensor, pipeline) model parallel rank (2, 1): 1997468160 + > number of parameters on (tensor, pipeline) model parallel rank (1, 0): 2019097600 + > number of parameters on (tensor, pipeline) model parallel rank (0, 0): 2019097600 +(AReaL) 20260419-07:58:55.843 [MegatronEngine Rank 7] INFO: Model parameter count: 1997.47M, pp_stage=1, vpp_chunks=1 +[OptDiag] Megatron OptimizerConfig: use_precision_aware_optimizer=True, use_precision_aware_optimizer_no_fp8_or_ds_fp8=True, store_param_remainders=True, main_params_dtype=torch.float32, main_grads_dtype=torch.bfloat16, exp_avg_dtype=torch.bfloat16, exp_avg_sq_dtype=torch.bfloat16, use_distributed_optimizer=True, bf16=True +(AReaL) 20260419-07:58:55.906 [MegatronEngine Rank 6] INFO: Model parameter count: 1997.47M, pp_stage=1, vpp_chunks=1 +[OptDiag] Megatron OptimizerConfig: use_precision_aware_optimizer=True, use_precision_aware_optimizer_no_fp8_or_ds_fp8=True, store_param_remainders=True, main_params_dtype=torch.float32, main_grads_dtype=torch.bfloat16, exp_avg_dtype=torch.bfloat16, exp_avg_sq_dtype=torch.bfloat16, use_distributed_optimizer=True, bf16=True +(AReaL) 20260419-07:58:55.931 [MegatronEngine Rank 3] INFO: Model parameter count: 2019.10M, pp_stage=0, vpp_chunks=1 +[OptDiag] Megatron OptimizerConfig: use_precision_aware_optimizer=True, use_precision_aware_optimizer_no_fp8_or_ds_fp8=True, store_param_remainders=True, main_params_dtype=torch.float32, main_grads_dtype=torch.bfloat16, exp_avg_dtype=torch.bfloat16, exp_avg_sq_dtype=torch.bfloat16, use_distributed_optimizer=True, bf16=True +(AReaL) 20260419-07:58:55.932 [MegatronEngine Rank 4] INFO: Model parameter count: 1997.47M, pp_stage=1, vpp_chunks=1 +[OptDiag] Megatron OptimizerConfig: use_precision_aware_optimizer=True, 
use_precision_aware_optimizer_no_fp8_or_ds_fp8=True, store_param_remainders=True, main_params_dtype=torch.float32, main_grads_dtype=torch.bfloat16, exp_avg_dtype=torch.bfloat16, exp_avg_sq_dtype=torch.bfloat16, use_distributed_optimizer=True, bf16=True +(AReaL) 20260419-07:58:55.966 [MegatronEngine Rank 0] INFO: Model parameter count: 2019.10M, pp_stage=0, vpp_chunks=1 +[OptDiag] Megatron OptimizerConfig: use_precision_aware_optimizer=True, use_precision_aware_optimizer_no_fp8_or_ds_fp8=True, store_param_remainders=True, main_params_dtype=torch.float32, main_grads_dtype=torch.bfloat16, exp_avg_dtype=torch.bfloat16, exp_avg_sq_dtype=torch.bfloat16, use_distributed_optimizer=True, bf16=True +(AReaL) 20260419-07:58:55.968 [MegatronEngine Rank 2] INFO: Model parameter count: 2019.10M, pp_stage=0, vpp_chunks=1 +[OptDiag] Megatron OptimizerConfig: use_precision_aware_optimizer=True, use_precision_aware_optimizer_no_fp8_or_ds_fp8=True, store_param_remainders=True, main_params_dtype=torch.float32, main_grads_dtype=torch.bfloat16, exp_avg_dtype=torch.bfloat16, exp_avg_sq_dtype=torch.bfloat16, use_distributed_optimizer=True, bf16=True +(AReaL) 20260419-07:58:56.000 [MegatronEngine Rank 5] INFO: Model parameter count: 1997.47M, pp_stage=1, vpp_chunks=1 +[OptDiag] Megatron OptimizerConfig: use_precision_aware_optimizer=True, use_precision_aware_optimizer_no_fp8_or_ds_fp8=True, store_param_remainders=True, main_params_dtype=torch.float32, main_grads_dtype=torch.bfloat16, exp_avg_dtype=torch.bfloat16, exp_avg_sq_dtype=torch.bfloat16, use_distributed_optimizer=True, bf16=True +(AReaL) 20260419-07:58:56.013 [MegatronEngine Rank 1] INFO: Model parameter count: 2019.10M, pp_stage=0, vpp_chunks=1 +[OptDiag] Megatron OptimizerConfig: use_precision_aware_optimizer=True, use_precision_aware_optimizer_no_fp8_or_ds_fp8=True, store_param_remainders=True, main_params_dtype=torch.float32, main_grads_dtype=torch.bfloat16, exp_avg_dtype=torch.bfloat16, exp_avg_sq_dtype=torch.bfloat16, use_distributed_optimizer=True, bf16=True +(AReaL) 20260419-07:58:57.621 TrainController INFO: All engines are initialized! +(AReaL) 20260419-07:58:57.621 TrainController INFO: Identifying DP head workers... +(AReaL) 20260419-07:58:57.921 TrainController INFO: TrainController initialization complete +(AReaL) 20260419-07:58:57.937 RolloutController WARNING: Placement strategy 'shared' is not supported for rollouts. Forcing to 'separate' strategy +(AReaL) 20260419-07:58:57.937 RolloutController INFO: Creating workers via scheduler... 
+(AReaL) 20260419-07:58:57.937 LocalScheduler INFO: Creating 1 workers for role 'rollout' (strategy: SchedulingStrategyType.separation, colocate_with: None) +(AReaL) 20260419-07:58:57.938 LauncherUtils INFO: Auto-setting thread env vars to 64: OMP_NUM_THREADS, MKL_NUM_THREADS, OPENBLAS_NUM_THREADS, VECLIB_MAXIMUM_THREADS, NUMEXPR_NUM_THREADS +(AReaL) 20260419-07:58:57.938 LocalScheduler INFO: Starting worker rollout/0: python3 -m areal.infra.rpc.rpc_server --port 53960 --experiment-name moonlight_moe_exp --trial-name moonlight_moe_0419_r3_v2 --role rollout --worker-index 0 --name-resolve-type nfs --nfs-record-root /tmp/areal/moon_name_resolve --etcd3-addr localhost:2379 --fileroot /tmp/areal/moon_experiments +(AReaL) 20260419-07:58:58.039 LocalScheduler INFO: Worker rollout/0 started (PID: 2381599, GPUs: [0, 1, 2, 3, 4, 5, 6, 7], ports: [53960, 55192]) +(AReaL) 20260419-07:58:58.039 LocalScheduler INFO: Successfully created 1 workers for role 'rollout' +(AReaL) 20260419-07:59:04.742 SyncRPCServer INFO: Werkzeug log level: WARNING +(AReaL) 20260419-07:59:04.743 Guard INFO: Starting Guard on [fdbd:dc05:13::28]:53960 for worker rollout/0 +(AReaL) 20260419-07:59:04.879 CLIArgs WARNING: behave_imp_weight_cap and behave_imp_weight_mode are configured but use_decoupled_loss=False. These settings will be ignored. Set use_decoupled_loss=True to enable decoupled loss with importance weight correction. +(AReaL) 20260419-07:59:04.879 EngineBP INFO: Engine thread started +(AReaL) 20260419-07:59:04.880 EngineBP INFO: Engine thread initialized +(AReaL) 20260419-07:59:04.883 PlatformInit INFO: Detected CUDA device: NVIDIA L20 +(AReaL) 20260419-07:59:04.884 PlatformInit INFO: Initializing CUDA platform (NVIDIA). +(AReaL) 20260419-07:59:04.895 LocalScheduler INFO: Configuration successfully on worker 'rollout/0' +(AReaL) 20260419-07:59:04.896 RolloutController INFO: Workers created: ['rollout/0'] +(AReaL) 20260419-07:59:04.896 RolloutController INFO: Waiting for workers to be ready... +(AReaL) 20260419-07:59:04.926 LocalScheduler INFO: All 1 workers for role 'rollout' are ready +(AReaL) 20260419-07:59:04.926 RolloutController INFO: Workers ready: ['rollout/0'] +(AReaL) 20260419-07:59:04.926 RolloutController INFO: Creating engines... +(AReaL) 20260419-07:59:04.928 EngineBP INFO: Engine 'rollout/0' (class: areal.engine.sglang_remote.RemoteSGLangEngine) instantiated successfully +(AReaL) 20260419-07:59:04.928 RolloutController INFO: Engine created on all workers! +(AReaL) 20260419-07:59:04.929 RolloutController INFO: Calling engine initialization... +/AReaL/.venv/lib/python3.12/site-packages/sglang/srt/utils/hf_transformers_utils.py:558: UserWarning: Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead. + warnings.warn( +/AReaL/.venv/lib/python3.12/site-packages/sglang/srt/utils/hf_transformers_utils.py:558: UserWarning: Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead. + warnings.warn( +/AReaL/.venv/lib/python3.12/site-packages/sglang/srt/utils/hf_transformers_utils.py:558: UserWarning: Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead. + warnings.warn( +/AReaL/.venv/lib/python3.12/site-packages/sglang/srt/utils/hf_transformers_utils.py:558: UserWarning: Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead. 
+ warnings.warn( +/AReaL/.venv/lib/python3.12/site-packages/sglang/srt/utils/hf_transformers_utils.py:558: UserWarning: Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead. + warnings.warn( +/AReaL/.venv/lib/python3.12/site-packages/sglang/srt/utils/hf_transformers_utils.py:558: UserWarning: Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead. + warnings.warn( +/AReaL/.venv/lib/python3.12/site-packages/sglang/srt/utils/hf_transformers_utils.py:558: UserWarning: Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead. + warnings.warn( +/AReaL/.venv/lib/python3.12/site-packages/sglang/srt/utils/hf_transformers_utils.py:558: UserWarning: Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead. + warnings.warn( +/AReaL/.venv/lib/python3.12/site-packages/sglang/srt/utils/hf_transformers_utils.py:558: UserWarning: Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead. + warnings.warn( +/AReaL/.venv/lib/python3.12/site-packages/sglang/srt/utils/hf_transformers_utils.py:558: UserWarning: Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead. + warnings.warn( +[Gloo] Rank 3 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 0 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 4 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 1 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 6 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 5 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 2 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 7 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 0 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 1 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 3 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 2 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 4 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 6 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 5 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[Gloo] Rank 7 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 +[2026-04-19 07:59:20 TP0] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly. +[2026-04-19 07:59:20 TP7] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly. +[2026-04-19 07:59:20 TP1] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly. +[2026-04-19 07:59:20 TP2] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. 
To silence this warning, specify disable_custom_all_reduce=True explicitly. +[2026-04-19 07:59:20 TP6] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly. +[2026-04-19 07:59:20 TP3] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly. +[2026-04-19 07:59:20 TP4] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly. +[2026-04-19 07:59:20 TP5] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly. +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. 
Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 +[2026-04-19 07:59:23 TP0] Ignore import error when loading sglang.srt.models.glm_ocr: No module named 'transformers.models.glm_ocr' +[2026-04-19 07:59:23 TP0] Ignore import error when loading sglang.srt.models.glm_ocr_nextn: No module named 'transformers.models.glm_ocr' +[2026-04-19 07:59:23 TP0] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/AReaL/.venv/lib/python3.12/site-packages/transformers/__init__.py) +[2026-04-19 07:59:23 TP3] Ignore import error when loading sglang.srt.models.glm_ocr: No module named 'transformers.models.glm_ocr' +[2026-04-19 07:59:23 TP3] Ignore import error when loading sglang.srt.models.glm_ocr_nextn: No module named 'transformers.models.glm_ocr' +[2026-04-19 07:59:23 TP3] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/AReaL/.venv/lib/python3.12/site-packages/transformers/__init__.py) +[2026-04-19 07:59:23 TP2] Ignore import error when loading sglang.srt.models.glm_ocr: No module named 'transformers.models.glm_ocr' +[2026-04-19 07:59:23 TP2] Ignore import error when loading sglang.srt.models.glm_ocr_nextn: No module named 'transformers.models.glm_ocr' +[2026-04-19 07:59:23 TP2] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/AReaL/.venv/lib/python3.12/site-packages/transformers/__init__.py) +[2026-04-19 07:59:23 TP5] Ignore import error when loading sglang.srt.models.glm_ocr: No module named 'transformers.models.glm_ocr' +[2026-04-19 07:59:23 TP5] Ignore import error when loading sglang.srt.models.glm_ocr_nextn: No module named 'transformers.models.glm_ocr' +[2026-04-19 07:59:23 TP5] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/AReaL/.venv/lib/python3.12/site-packages/transformers/__init__.py) +[2026-04-19 07:59:23 TP7] Ignore import error when loading sglang.srt.models.glm_ocr: No module named 'transformers.models.glm_ocr' +[2026-04-19 07:59:23 TP7] Ignore import error when loading sglang.srt.models.glm_ocr_nextn: No module named 'transformers.models.glm_ocr' +[2026-04-19 07:59:23 TP7] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/AReaL/.venv/lib/python3.12/site-packages/transformers/__init__.py) +[2026-04-19 07:59:23 TP1] Ignore import error when loading sglang.srt.models.glm_ocr: No module named 'transformers.models.glm_ocr' +[2026-04-19 07:59:23 TP1] Ignore import error when loading sglang.srt.models.glm_ocr_nextn: No module named 'transformers.models.glm_ocr' +[2026-04-19 07:59:23 TP1] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/AReaL/.venv/lib/python3.12/site-packages/transformers/__init__.py) +[2026-04-19 07:59:23 TP0] Ignore import error when loading sglang.srt.models.midashenglm: Detected that PyTorch and TorchAudio were compiled with different CUDA versions. PyTorch has CUDA version 12.8 whereas TorchAudio has CUDA version 12.9. Please install the TorchAudio version that matches your PyTorch version. 
+[2026-04-19 07:59:23 TP4] Ignore import error when loading sglang.srt.models.glm_ocr: No module named 'transformers.models.glm_ocr' +[2026-04-19 07:59:23 TP4] Ignore import error when loading sglang.srt.models.glm_ocr_nextn: No module named 'transformers.models.glm_ocr' +[2026-04-19 07:59:23 TP4] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/AReaL/.venv/lib/python3.12/site-packages/transformers/__init__.py) +[2026-04-19 07:59:23 TP3] Ignore import error when loading sglang.srt.models.midashenglm: Detected that PyTorch and TorchAudio were compiled with different CUDA versions. PyTorch has CUDA version 12.8 whereas TorchAudio has CUDA version 12.9. Please install the TorchAudio version that matches your PyTorch version. +[2026-04-19 07:59:23 TP2] Ignore import error when loading sglang.srt.models.midashenglm: Detected that PyTorch and TorchAudio were compiled with different CUDA versions. PyTorch has CUDA version 12.8 whereas TorchAudio has CUDA version 12.9. Please install the TorchAudio version that matches your PyTorch version. +[2026-04-19 07:59:23 TP6] Ignore import error when loading sglang.srt.models.glm_ocr: No module named 'transformers.models.glm_ocr' +[2026-04-19 07:59:23 TP6] Ignore import error when loading sglang.srt.models.glm_ocr_nextn: No module named 'transformers.models.glm_ocr' +[2026-04-19 07:59:23 TP6] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/AReaL/.venv/lib/python3.12/site-packages/transformers/__init__.py) +[2026-04-19 07:59:23 TP5] Ignore import error when loading sglang.srt.models.midashenglm: Detected that PyTorch and TorchAudio were compiled with different CUDA versions. PyTorch has CUDA version 12.8 whereas TorchAudio has CUDA version 12.9. Please install the TorchAudio version that matches your PyTorch version. +[2026-04-19 07:59:23 TP7] Ignore import error when loading sglang.srt.models.midashenglm: Detected that PyTorch and TorchAudio were compiled with different CUDA versions. PyTorch has CUDA version 12.8 whereas TorchAudio has CUDA version 12.9. Please install the TorchAudio version that matches your PyTorch version. +[2026-04-19 07:59:23 TP1] Ignore import error when loading sglang.srt.models.midashenglm: Detected that PyTorch and TorchAudio were compiled with different CUDA versions. PyTorch has CUDA version 12.8 whereas TorchAudio has CUDA version 12.9. Please install the TorchAudio version that matches your PyTorch version. +[2026-04-19 07:59:23 TP4] Ignore import error when loading sglang.srt.models.midashenglm: Detected that PyTorch and TorchAudio were compiled with different CUDA versions. PyTorch has CUDA version 12.8 whereas TorchAudio has CUDA version 12.9. Please install the TorchAudio version that matches your PyTorch version. +[2026-04-19 07:59:23 TP6] Ignore import error when loading sglang.srt.models.midashenglm: Detected that PyTorch and TorchAudio were compiled with different CUDA versions. PyTorch has CUDA version 12.8 whereas TorchAudio has CUDA version 12.9. Please install the TorchAudio version that matches your PyTorch version. 
+ Loading safetensors checkpoint shards: 0% Completed | 0/27 [00:00

Date: Sun, 19 Apr 2026 18:01:16 +0800
Subject: [PATCH 015/112] fix(engine): remove moonlight_moe_r3.txt

---
 areal/engine/megatron_engine.py |   10 +
 moonlight_moe_r3.txt            | 1094 -------------------------------
 2 files changed, 10 insertions(+), 1094 deletions(-)
 delete mode 100644 moonlight_moe_r3.txt

diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py
index 3da13d8313..72f7d975b2 100644
--- a/areal/engine/megatron_engine.py
+++ b/areal/engine/megatron_engine.py
@@ -163,6 +163,16 @@ def __init__(self, config: TrainEngineConfig):
         self.own_global_group: bool = False
         self.is_offload: bool = False
         self._r3_enabled: bool = getattr(config.megatron, "enable_router_replay", False)
+        if not self._r3_enabled:
+            self._r3_enabled = getattr(config, "_r3_enable_router_replay", False)
+        self.logger.info(
+            "[R3] __init__: _r3_enabled=%s, config.megatron.enable_router_replay=%s, "
+            "config._r3_enable_router_replay=%s, config.megatron type=%s",
+            self._r3_enabled,
+            getattr(config.megatron, "enable_router_replay", ""),
+            getattr(config, "_r3_enable_router_replay", ""),
+            type(config.megatron).__name__,
+        )
         self.enable_tree_training: bool = self.config.enable_tree_training
         # FP8 configuration
         self.fp8_config = self.mcore_config.fp8_config
diff --git a/moonlight_moe_r3.txt b/moonlight_moe_r3.txt
deleted file mode 100644
index 90e09281ba..0000000000
--- a/moonlight_moe_r3.txt
+++ /dev/null
@@ -1,1094 +0,0 @@
-/AReaL/.venv/lib/python3.12/site-packages/megatron/core/transformer/transformer_config.py:1705: UserWarning: full scope is deprecated. Use empty cuda_graph_scope to capture the whole layer.
- warnings.warn( -(AReaL) 20260419-07:58:44.411 [MegatronEngine Rank 5] INFO: Using mbridge to create models and hf model save/load in MegatronEngine. -(AReaL) 20260419-07:58:44.411 MCoreParallel INFO: Configured pipeline layout (per-stage decoder counts / params): [14, 13] / ['2665.67M', '2579.68M'] (pp=2, vpp=1) -/AReaL/.venv/lib/python3.12/site-packages/megatron/core/transformer/transformer_config.py:1705: UserWarning: full scope is deprecated. Use empty cuda_graph_scope to capture the whole layer. - warnings.warn( -(AReaL) 20260419-07:58:44.417 [MegatronEngine Rank 3] INFO: Using mbridge to create models and hf model save/load in MegatronEngine. -(AReaL) 20260419-07:58:44.417 MCoreParallel INFO: Configured pipeline layout (per-stage decoder counts / params): [14, 13] / ['2665.67M', '2579.68M'] (pp=2, vpp=1) -/AReaL/.venv/lib/python3.12/site-packages/megatron/core/transformer/transformer_config.py:1705: UserWarning: full scope is deprecated. Use empty cuda_graph_scope to capture the whole layer. - warnings.warn( -(AReaL) 20260419-07:58:44.429 [MegatronEngine Rank 0] INFO: Using mbridge to create models and hf model save/load in MegatronEngine. -(AReaL) 20260419-07:58:44.430 MCoreParallel INFO: Configured pipeline layout (per-stage decoder counts / params): [14, 13] / ['2665.67M', '2579.68M'] (pp=2, vpp=1) - > number of parameters on (tensor, pipeline) model parallel rank (3, 1): 1997468160 - > number of parameters on (tensor, pipeline) model parallel rank (0, 1): 1997468160 - > number of parameters on (tensor, pipeline) model parallel rank (1, 1): 1997468160 - > number of parameters on (tensor, pipeline) model parallel rank (2, 0): 2019097600 - > number of parameters on (tensor, pipeline) model parallel rank (3, 0): 2019097600 - > number of parameters on (tensor, pipeline) model parallel rank (2, 1): 1997468160 - > number of parameters on (tensor, pipeline) model parallel rank (1, 0): 2019097600 - > number of parameters on (tensor, pipeline) model parallel rank (0, 0): 2019097600 -(AReaL) 20260419-07:58:55.843 [MegatronEngine Rank 7] INFO: Model parameter count: 1997.47M, pp_stage=1, vpp_chunks=1 -[OptDiag] Megatron OptimizerConfig: use_precision_aware_optimizer=True, use_precision_aware_optimizer_no_fp8_or_ds_fp8=True, store_param_remainders=True, main_params_dtype=torch.float32, main_grads_dtype=torch.bfloat16, exp_avg_dtype=torch.bfloat16, exp_avg_sq_dtype=torch.bfloat16, use_distributed_optimizer=True, bf16=True -(AReaL) 20260419-07:58:55.906 [MegatronEngine Rank 6] INFO: Model parameter count: 1997.47M, pp_stage=1, vpp_chunks=1 -[OptDiag] Megatron OptimizerConfig: use_precision_aware_optimizer=True, use_precision_aware_optimizer_no_fp8_or_ds_fp8=True, store_param_remainders=True, main_params_dtype=torch.float32, main_grads_dtype=torch.bfloat16, exp_avg_dtype=torch.bfloat16, exp_avg_sq_dtype=torch.bfloat16, use_distributed_optimizer=True, bf16=True -(AReaL) 20260419-07:58:55.931 [MegatronEngine Rank 3] INFO: Model parameter count: 2019.10M, pp_stage=0, vpp_chunks=1 -[OptDiag] Megatron OptimizerConfig: use_precision_aware_optimizer=True, use_precision_aware_optimizer_no_fp8_or_ds_fp8=True, store_param_remainders=True, main_params_dtype=torch.float32, main_grads_dtype=torch.bfloat16, exp_avg_dtype=torch.bfloat16, exp_avg_sq_dtype=torch.bfloat16, use_distributed_optimizer=True, bf16=True -(AReaL) 20260419-07:58:55.932 [MegatronEngine Rank 4] INFO: Model parameter count: 1997.47M, pp_stage=1, vpp_chunks=1 -[OptDiag] Megatron OptimizerConfig: use_precision_aware_optimizer=True, 
use_precision_aware_optimizer_no_fp8_or_ds_fp8=True, store_param_remainders=True, main_params_dtype=torch.float32, main_grads_dtype=torch.bfloat16, exp_avg_dtype=torch.bfloat16, exp_avg_sq_dtype=torch.bfloat16, use_distributed_optimizer=True, bf16=True -(AReaL) 20260419-07:58:55.966 [MegatronEngine Rank 0] INFO: Model parameter count: 2019.10M, pp_stage=0, vpp_chunks=1 -[OptDiag] Megatron OptimizerConfig: use_precision_aware_optimizer=True, use_precision_aware_optimizer_no_fp8_or_ds_fp8=True, store_param_remainders=True, main_params_dtype=torch.float32, main_grads_dtype=torch.bfloat16, exp_avg_dtype=torch.bfloat16, exp_avg_sq_dtype=torch.bfloat16, use_distributed_optimizer=True, bf16=True -(AReaL) 20260419-07:58:55.968 [MegatronEngine Rank 2] INFO: Model parameter count: 2019.10M, pp_stage=0, vpp_chunks=1 -[OptDiag] Megatron OptimizerConfig: use_precision_aware_optimizer=True, use_precision_aware_optimizer_no_fp8_or_ds_fp8=True, store_param_remainders=True, main_params_dtype=torch.float32, main_grads_dtype=torch.bfloat16, exp_avg_dtype=torch.bfloat16, exp_avg_sq_dtype=torch.bfloat16, use_distributed_optimizer=True, bf16=True -(AReaL) 20260419-07:58:56.000 [MegatronEngine Rank 5] INFO: Model parameter count: 1997.47M, pp_stage=1, vpp_chunks=1 -[OptDiag] Megatron OptimizerConfig: use_precision_aware_optimizer=True, use_precision_aware_optimizer_no_fp8_or_ds_fp8=True, store_param_remainders=True, main_params_dtype=torch.float32, main_grads_dtype=torch.bfloat16, exp_avg_dtype=torch.bfloat16, exp_avg_sq_dtype=torch.bfloat16, use_distributed_optimizer=True, bf16=True -(AReaL) 20260419-07:58:56.013 [MegatronEngine Rank 1] INFO: Model parameter count: 2019.10M, pp_stage=0, vpp_chunks=1 -[OptDiag] Megatron OptimizerConfig: use_precision_aware_optimizer=True, use_precision_aware_optimizer_no_fp8_or_ds_fp8=True, store_param_remainders=True, main_params_dtype=torch.float32, main_grads_dtype=torch.bfloat16, exp_avg_dtype=torch.bfloat16, exp_avg_sq_dtype=torch.bfloat16, use_distributed_optimizer=True, bf16=True -(AReaL) 20260419-07:58:57.621 TrainController INFO: All engines are initialized! -(AReaL) 20260419-07:58:57.621 TrainController INFO: Identifying DP head workers... -(AReaL) 20260419-07:58:57.921 TrainController INFO: TrainController initialization complete -(AReaL) 20260419-07:58:57.937 RolloutController WARNING: Placement strategy 'shared' is not supported for rollouts. Forcing to 'separate' strategy -(AReaL) 20260419-07:58:57.937 RolloutController INFO: Creating workers via scheduler... 
-(AReaL) 20260419-07:58:57.937 LocalScheduler INFO: Creating 1 workers for role 'rollout' (strategy: SchedulingStrategyType.separation, colocate_with: None) -(AReaL) 20260419-07:58:57.938 LauncherUtils INFO: Auto-setting thread env vars to 64: OMP_NUM_THREADS, MKL_NUM_THREADS, OPENBLAS_NUM_THREADS, VECLIB_MAXIMUM_THREADS, NUMEXPR_NUM_THREADS -(AReaL) 20260419-07:58:57.938 LocalScheduler INFO: Starting worker rollout/0: python3 -m areal.infra.rpc.rpc_server --port 53960 --experiment-name moonlight_moe_exp --trial-name moonlight_moe_0419_r3_v2 --role rollout --worker-index 0 --name-resolve-type nfs --nfs-record-root /tmp/areal/moon_name_resolve --etcd3-addr localhost:2379 --fileroot /tmp/areal/moon_experiments -(AReaL) 20260419-07:58:58.039 LocalScheduler INFO: Worker rollout/0 started (PID: 2381599, GPUs: [0, 1, 2, 3, 4, 5, 6, 7], ports: [53960, 55192]) -(AReaL) 20260419-07:58:58.039 LocalScheduler INFO: Successfully created 1 workers for role 'rollout' -(AReaL) 20260419-07:59:04.742 SyncRPCServer INFO: Werkzeug log level: WARNING -(AReaL) 20260419-07:59:04.743 Guard INFO: Starting Guard on [fdbd:dc05:13::28]:53960 for worker rollout/0 -(AReaL) 20260419-07:59:04.879 CLIArgs WARNING: behave_imp_weight_cap and behave_imp_weight_mode are configured but use_decoupled_loss=False. These settings will be ignored. Set use_decoupled_loss=True to enable decoupled loss with importance weight correction. -(AReaL) 20260419-07:59:04.879 EngineBP INFO: Engine thread started -(AReaL) 20260419-07:59:04.880 EngineBP INFO: Engine thread initialized -(AReaL) 20260419-07:59:04.883 PlatformInit INFO: Detected CUDA device: NVIDIA L20 -(AReaL) 20260419-07:59:04.884 PlatformInit INFO: Initializing CUDA platform (NVIDIA). -(AReaL) 20260419-07:59:04.895 LocalScheduler INFO: Configuration successfully on worker 'rollout/0' -(AReaL) 20260419-07:59:04.896 RolloutController INFO: Workers created: ['rollout/0'] -(AReaL) 20260419-07:59:04.896 RolloutController INFO: Waiting for workers to be ready... -(AReaL) 20260419-07:59:04.926 LocalScheduler INFO: All 1 workers for role 'rollout' are ready -(AReaL) 20260419-07:59:04.926 RolloutController INFO: Workers ready: ['rollout/0'] -(AReaL) 20260419-07:59:04.926 RolloutController INFO: Creating engines... -(AReaL) 20260419-07:59:04.928 EngineBP INFO: Engine 'rollout/0' (class: areal.engine.sglang_remote.RemoteSGLangEngine) instantiated successfully -(AReaL) 20260419-07:59:04.928 RolloutController INFO: Engine created on all workers! -(AReaL) 20260419-07:59:04.929 RolloutController INFO: Calling engine initialization... -/AReaL/.venv/lib/python3.12/site-packages/sglang/srt/utils/hf_transformers_utils.py:558: UserWarning: Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead. - warnings.warn( -/AReaL/.venv/lib/python3.12/site-packages/sglang/srt/utils/hf_transformers_utils.py:558: UserWarning: Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead. - warnings.warn( -/AReaL/.venv/lib/python3.12/site-packages/sglang/srt/utils/hf_transformers_utils.py:558: UserWarning: Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead. - warnings.warn( -/AReaL/.venv/lib/python3.12/site-packages/sglang/srt/utils/hf_transformers_utils.py:558: UserWarning: Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead. 
- warnings.warn( -/AReaL/.venv/lib/python3.12/site-packages/sglang/srt/utils/hf_transformers_utils.py:558: UserWarning: Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead. - warnings.warn( -/AReaL/.venv/lib/python3.12/site-packages/sglang/srt/utils/hf_transformers_utils.py:558: UserWarning: Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead. - warnings.warn( -/AReaL/.venv/lib/python3.12/site-packages/sglang/srt/utils/hf_transformers_utils.py:558: UserWarning: Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead. - warnings.warn( -/AReaL/.venv/lib/python3.12/site-packages/sglang/srt/utils/hf_transformers_utils.py:558: UserWarning: Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead. - warnings.warn( -/AReaL/.venv/lib/python3.12/site-packages/sglang/srt/utils/hf_transformers_utils.py:558: UserWarning: Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead. - warnings.warn( -/AReaL/.venv/lib/python3.12/site-packages/sglang/srt/utils/hf_transformers_utils.py:558: UserWarning: Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead. - warnings.warn( -[Gloo] Rank 3 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 -[Gloo] Rank 0 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 -[Gloo] Rank 4 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 -[Gloo] Rank 1 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 -[Gloo] Rank 6 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 -[Gloo] Rank 5 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 -[Gloo] Rank 2 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 -[Gloo] Rank 7 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 -[Gloo] Rank 0 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 -[Gloo] Rank 1 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 -[Gloo] Rank 3 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 -[Gloo] Rank 2 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 -[Gloo] Rank 4 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 -[Gloo] Rank 6 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 -[Gloo] Rank 5 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 -[Gloo] Rank 7 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7 -[2026-04-19 07:59:20 TP0] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly. -[2026-04-19 07:59:20 TP7] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly. -[2026-04-19 07:59:20 TP1] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly. -[2026-04-19 07:59:20 TP2] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. 
To silence this warning, specify disable_custom_all_reduce=True explicitly. -[2026-04-19 07:59:20 TP6] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly. -[2026-04-19 07:59:20 TP3] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly. -[2026-04-19 07:59:20 TP4] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly. -[2026-04-19 07:59:20 TP5] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly. -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. 
Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0 -[2026-04-19 07:59:23 TP0] Ignore import error when loading sglang.srt.models.glm_ocr: No module named 'transformers.models.glm_ocr' -[2026-04-19 07:59:23 TP0] Ignore import error when loading sglang.srt.models.glm_ocr_nextn: No module named 'transformers.models.glm_ocr' -[2026-04-19 07:59:23 TP0] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/AReaL/.venv/lib/python3.12/site-packages/transformers/__init__.py) -[2026-04-19 07:59:23 TP3] Ignore import error when loading sglang.srt.models.glm_ocr: No module named 'transformers.models.glm_ocr' -[2026-04-19 07:59:23 TP3] Ignore import error when loading sglang.srt.models.glm_ocr_nextn: No module named 'transformers.models.glm_ocr' -[2026-04-19 07:59:23 TP3] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/AReaL/.venv/lib/python3.12/site-packages/transformers/__init__.py) -[2026-04-19 07:59:23 TP2] Ignore import error when loading sglang.srt.models.glm_ocr: No module named 'transformers.models.glm_ocr' -[2026-04-19 07:59:23 TP2] Ignore import error when loading sglang.srt.models.glm_ocr_nextn: No module named 'transformers.models.glm_ocr' -[2026-04-19 07:59:23 TP2] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/AReaL/.venv/lib/python3.12/site-packages/transformers/__init__.py) -[2026-04-19 07:59:23 TP5] Ignore import error when loading sglang.srt.models.glm_ocr: No module named 'transformers.models.glm_ocr' -[2026-04-19 07:59:23 TP5] Ignore import error when loading sglang.srt.models.glm_ocr_nextn: No module named 'transformers.models.glm_ocr' -[2026-04-19 07:59:23 TP5] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/AReaL/.venv/lib/python3.12/site-packages/transformers/__init__.py) -[2026-04-19 07:59:23 TP7] Ignore import error when loading sglang.srt.models.glm_ocr: No module named 'transformers.models.glm_ocr' -[2026-04-19 07:59:23 TP7] Ignore import error when loading sglang.srt.models.glm_ocr_nextn: No module named 'transformers.models.glm_ocr' -[2026-04-19 07:59:23 TP7] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/AReaL/.venv/lib/python3.12/site-packages/transformers/__init__.py) -[2026-04-19 07:59:23 TP1] Ignore import error when loading sglang.srt.models.glm_ocr: No module named 'transformers.models.glm_ocr' -[2026-04-19 07:59:23 TP1] Ignore import error when loading sglang.srt.models.glm_ocr_nextn: No module named 'transformers.models.glm_ocr' -[2026-04-19 07:59:23 TP1] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/AReaL/.venv/lib/python3.12/site-packages/transformers/__init__.py) -[2026-04-19 07:59:23 TP0] Ignore import error when loading sglang.srt.models.midashenglm: Detected that PyTorch and TorchAudio were compiled with different CUDA versions. PyTorch has CUDA version 12.8 whereas TorchAudio has CUDA version 12.9. Please install the TorchAudio version that matches your PyTorch version. 
-[2026-04-19 07:59:23 TP4] Ignore import error when loading sglang.srt.models.glm_ocr: No module named 'transformers.models.glm_ocr' -[2026-04-19 07:59:23 TP4] Ignore import error when loading sglang.srt.models.glm_ocr_nextn: No module named 'transformers.models.glm_ocr' -[2026-04-19 07:59:23 TP4] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/AReaL/.venv/lib/python3.12/site-packages/transformers/__init__.py) -[2026-04-19 07:59:23 TP3] Ignore import error when loading sglang.srt.models.midashenglm: Detected that PyTorch and TorchAudio were compiled with different CUDA versions. PyTorch has CUDA version 12.8 whereas TorchAudio has CUDA version 12.9. Please install the TorchAudio version that matches your PyTorch version. -[2026-04-19 07:59:23 TP2] Ignore import error when loading sglang.srt.models.midashenglm: Detected that PyTorch and TorchAudio were compiled with different CUDA versions. PyTorch has CUDA version 12.8 whereas TorchAudio has CUDA version 12.9. Please install the TorchAudio version that matches your PyTorch version. -[2026-04-19 07:59:23 TP6] Ignore import error when loading sglang.srt.models.glm_ocr: No module named 'transformers.models.glm_ocr' -[2026-04-19 07:59:23 TP6] Ignore import error when loading sglang.srt.models.glm_ocr_nextn: No module named 'transformers.models.glm_ocr' -[2026-04-19 07:59:23 TP6] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/AReaL/.venv/lib/python3.12/site-packages/transformers/__init__.py) -[2026-04-19 07:59:23 TP5] Ignore import error when loading sglang.srt.models.midashenglm: Detected that PyTorch and TorchAudio were compiled with different CUDA versions. PyTorch has CUDA version 12.8 whereas TorchAudio has CUDA version 12.9. Please install the TorchAudio version that matches your PyTorch version. -[2026-04-19 07:59:23 TP7] Ignore import error when loading sglang.srt.models.midashenglm: Detected that PyTorch and TorchAudio were compiled with different CUDA versions. PyTorch has CUDA version 12.8 whereas TorchAudio has CUDA version 12.9. Please install the TorchAudio version that matches your PyTorch version. -[2026-04-19 07:59:23 TP1] Ignore import error when loading sglang.srt.models.midashenglm: Detected that PyTorch and TorchAudio were compiled with different CUDA versions. PyTorch has CUDA version 12.8 whereas TorchAudio has CUDA version 12.9. Please install the TorchAudio version that matches your PyTorch version. -[2026-04-19 07:59:23 TP4] Ignore import error when loading sglang.srt.models.midashenglm: Detected that PyTorch and TorchAudio were compiled with different CUDA versions. PyTorch has CUDA version 12.8 whereas TorchAudio has CUDA version 12.9. Please install the TorchAudio version that matches your PyTorch version. -[2026-04-19 07:59:23 TP6] Ignore import error when loading sglang.srt.models.midashenglm: Detected that PyTorch and TorchAudio were compiled with different CUDA versions. PyTorch has CUDA version 12.8 whereas TorchAudio has CUDA version 12.9. Please install the TorchAudio version that matches your PyTorch version. 
- Loading safetensors checkpoint shards:   0% Completed | 0/27 [00:00

Date: Sun, 19 Apr 2026 18:06:39 +0800
Subject: [PATCH 016/112] fix: logger fix

---
 areal/engine/megatron_engine.py | 2 +-
 areal/trainer/rl_trainer.py     | 6 ++++++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py
index 72f7d975b2..cd24c368f5 100644
--- a/areal/engine/megatron_engine.py
+++ b/areal/engine/megatron_engine.py
@@ -165,7 +165,7 @@ def __init__(self, config: TrainEngineConfig):
         self._r3_enabled: bool = getattr(config.megatron, "enable_router_replay", False)
         if not self._r3_enabled:
             self._r3_enabled = getattr(config, "_r3_enable_router_replay", False)
-        self.logger.info(
+        logging.getLogger("[MegatronEngine]").info(
             "[R3] __init__: _r3_enabled=%s, config.megatron.enable_router_replay=%s, "
             "config._r3_enable_router_replay=%s, config.megatron type=%s",
             self._r3_enabled,
diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py
index eecf0ff5d6..5400652202 100644
--- a/areal/trainer/rl_trainer.py
+++ b/areal/trainer/rl_trainer.py
@@ -133,6 +133,12 @@ def __init__(
         # (not a dynamic attribute) to survive Ray serialization.
         if getattr(config.rollout, "return_routed_experts", False):
             config.actor.megatron.enable_router_replay = True
+            logger.info(
+                "[R3] Set config.actor.megatron.enable_router_replay=True "
+                "(config.rollout.return_routed_experts=True). "
+                "config.actor.megatron type=%s",
+                type(config.actor.megatron).__name__,
+            )

         # Create models: actor, critic, ref — each with its own allocation.
         self.actor = self._create_train_engine(config.actor, self.actor_alloc)
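The rl_trainer hunk above leans on the comment that the R3 switch must live on a declared config field "(not a dynamic attribute) to survive Ray serialization." A minimal sketch of the failure mode under a field-based round trip; MegatronConfig here is a hypothetical stand-in for the real AReaL config class, and the round trip stands in for whatever reconstruction the serializer performs:

    from dataclasses import asdict, dataclass

    @dataclass
    class MegatronConfig:  # hypothetical stand-in, not AReaL's class
        enable_router_replay: bool = False

    cfg = MegatronConfig()
    cfg.enable_router_replay = True  # declared field: survives reconstruction
    cfg._r3_flag = True              # ad-hoc dynamic attribute

    # Field-based round trip, the way structured-config serializers rebuild objects:
    restored = MegatronConfig(**asdict(cfg))
    assert restored.enable_router_replay is True
    assert not hasattr(restored, "_r3_flag")  # the dynamic attribute is lost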
From e05df499b0b8f97b59f5d3f3d47e7b0567a03f80 Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Sun, 19 Apr 2026 19:02:04 +0800
Subject: [PATCH 017/112] feat: add r3 log

---
 areal/engine/router_replay_patch.py |  9 +++++++++
 areal/engine/sglang_remote.py       | 15 +++++++++++----
 areal/trainer/ppo/actor.py          |  9 +++++++++
 3 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py
index 069f20dd60..bccfbde3b7 100644
--- a/areal/engine/router_replay_patch.py
+++ b/areal/engine/router_replay_patch.py
@@ -272,6 +272,15 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
             top_indices = router_replay.target_topk_idx
             top_indices = top_indices.to(scores.device)
             probs = scores.gather(1, top_indices)
+            if not hasattr(_patched_topk_routing_with_score_function, '_r3_verify_logged'):
+                _patched_topk_routing_with_score_function._r3_verify_logged = True
+                logger.info(
+                    "[R3-VERIFY] Megatron REPLAY_FORWARD using replay indices: "
+                    "shape=%s, first3=%s, agreement_rate=%.4f",
+                    top_indices.shape,
+                    top_indices.flatten()[:3].tolist(),
+                    agreement_rate,
+                )
             return probs, top_indices

         elif routing_action == RouterReplayAction.REPLAY_BACKWARD:
diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py
index 945fad3429..864f517333 100644
--- a/areal/engine/sglang_remote.py
+++ b/areal/engine/sglang_remote.py
@@ -112,6 +112,13 @@ def parse_generation_response(
                 pybase64.b64decode(routed_experts.encode("utf-8")),
                 dtype=np.int32,
             ).reshape(num_sgl_token, -1)
+            logger.info(
+                "[R3-VERIFY] SGLang decoded routed_experts: "
+                "shape=%s, first3=%s, hash=%d",
+                routed_experts.shape,
+                routed_experts.flat[:3].tolist(),
+                hash(routed_experts.tobytes()),
+            )
         except Exception:
             logger.warning(
                 "[R3] Failed to decode base64 routed_experts "
@@ -126,11 +133,11 @@
                 routed_experts, dtype=np.int32
             ).reshape(num_sgl_token, -1)
             logger.info(
-                "[R3] Converted routed_experts from %s to numpy array "
-                "(shape=%s, num_sgl_token=%d).",
-                type(meta_info.get("routed_experts")).__name__,
+                "[R3-VERIFY] SGLang converted routed_experts: "
+                "shape=%s, first3=%s, hash=%d",
                 routed_experts.shape,
-                num_sgl_token,
+                routed_experts.flat[:3].tolist(),
+                hash(routed_experts.tobytes()),
             )
         except Exception:
             logger.warning(
diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py
index 58167f3748..ec983357e9 100644
--- a/areal/trainer/ppo/actor.py
+++ b/areal/trainer/ppo/actor.py
@@ -327,6 +327,15 @@ def _ppo_update(self, data: dict[str, Any]) -> None:
         ):
             from areal.trainer.ppo.actor_r3_patch import _resolve_to_tensor
             _r3_routed_experts = _resolve_to_tensor(_r3_routed_experts)
+        if _r3_routed_experts is not None:
+            logger.info(
+                "[R3-VERIFY] Actor received routed_experts: "
+                "shape=%s, dtype=%s, first3=%s, hash=%d",
+                _r3_routed_experts.shape,
+                _r3_routed_experts.dtype,
+                _r3_routed_experts.flatten()[:3].tolist(),
+                hash(_r3_routed_experts.cpu().numpy().tobytes()),
+            )

         mb_inputs = split_padded_tensor_dict_into_mb_list(
             data,
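The [R3-VERIFY] logging above depends on the routed_experts wire format: a base64 string whose payload is raw int32 bytes, reshaped to (num_sgl_token, topk). A self-contained round trip using stdlib base64 (the engine itself calls pybase64, which is API-compatible; the byte layout is assumed from the decode path shown above):

    import base64

    import numpy as np

    num_sgl_token, topk = 6, 3
    original = np.random.randint(0, 64, size=(num_sgl_token, topk), dtype=np.int32)

    # Server side (sketch): flatten to raw int32 bytes, base64-encode for JSON.
    payload = base64.b64encode(original.tobytes()).decode("utf-8")

    # Client side, mirroring parse_generation_response: decode, reinterpret as
    # int32, and reshape so each row holds one generated token's top-k experts.
    decoded = np.frombuffer(
        base64.b64decode(payload.encode("utf-8")), dtype=np.int32
    ).reshape(num_sgl_token, -1)

    assert np.array_equal(decoded, original)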
From 2027b6a57a71444933fe7eec5df75e9590d0eff7 Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Sun, 19 Apr 2026 19:18:51 +0800
Subject: [PATCH 018/112] feat(validate): improve r3

---
 areal/engine/router_replay_patch.py | 14 +++++++++-----
 areal/trainer/ppo/actor.py          | 12 +++++++++---
 2 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py
index bccfbde3b7..2b61757345 100644
--- a/areal/engine/router_replay_patch.py
+++ b/areal/engine/router_replay_patch.py
@@ -272,13 +272,17 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
             top_indices = router_replay.target_topk_idx
             top_indices = top_indices.to(scores.device)
             probs = scores.gather(1, top_indices)
-            if not hasattr(_patched_topk_routing_with_score_function, '_r3_verify_logged'):
-                _patched_topk_routing_with_score_function._r3_verify_logged = True
+            if not hasattr(_patched_topk_routing_with_score_function, '_r3_verify_count'):
+                _patched_topk_routing_with_score_function._r3_verify_count = 0
+            _patched_topk_routing_with_score_function._r3_verify_count += 1
+            if _patched_topk_routing_with_score_function._r3_verify_count <= 3:
                 logger.info(
-                    "[R3-VERIFY] Megatron REPLAY_FORWARD using replay indices: "
-                    "shape=%s, first3=%s, agreement_rate=%.4f",
+                    "[R3-VERIFY] Megatron REPLAY_FORWARD #%d: "
+                    "top_indices shape=%s, first3_nonzero=%s, "
+                    "agreement_rate=%.4f",
+                    _patched_topk_routing_with_score_function._r3_verify_count,
                     top_indices.shape,
-                    top_indices.flatten()[:3].tolist(),
+                    top_indices[top_indices > 0].flatten()[:3].tolist(),
                     agreement_rate,
                 )
             return probs, top_indices
diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py
index ec983357e9..18c81a45c5 100644
--- a/areal/trainer/ppo/actor.py
+++ b/areal/trainer/ppo/actor.py
@@ -328,13 +328,19 @@ def _ppo_update(self, data: dict[str, Any]) -> None:
             from areal.trainer.ppo.actor_r3_patch import _resolve_to_tensor
             _r3_routed_experts = _resolve_to_tensor(_r3_routed_experts)
         if _r3_routed_experts is not None:
+            _re_np = _r3_routed_experts.cpu().numpy()
+            _nonzero = _re_np[_re_np > 0]
             logger.info(
                 "[R3-VERIFY] Actor received routed_experts: "
-                "shape=%s, dtype=%s, first3=%s, hash=%d",
+                "shape=%s, dtype=%s, nonzero_count=%d/%d, "
+                "nonzero_first3=%s, max=%d, hash=%d",
                 _r3_routed_experts.shape,
                 _r3_routed_experts.dtype,
-                _r3_routed_experts.flatten()[:3].tolist(),
-                hash(_r3_routed_experts.cpu().numpy().tobytes()),
+                len(_nonzero),
+                _re_np.size,
+                _nonzero[:3].tolist() if len(_nonzero) >= 3 else _nonzero.tolist(),
+                int(_re_np.max()),
+                hash(_re_np.tobytes()),
             )

         mb_inputs = split_padded_tensor_dict_into_mb_list(

From 4ca85b2d41759f5f817c2d3ec7a963e508030abc Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Sun, 19 Apr 2026 19:36:31 +0800
Subject: [PATCH 019/112] refactor(router_replay_patch): print

---
 areal/engine/router_replay_patch.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py
index 2b61757345..6b89d99b93 100644
--- a/areal/engine/router_replay_patch.py
+++ b/areal/engine/router_replay_patch.py
@@ -276,14 +276,14 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
                 _patched_topk_routing_with_score_function._r3_verify_count = 0
             _patched_topk_routing_with_score_function._r3_verify_count += 1
             if _patched_topk_routing_with_score_function._r3_verify_count <= 3:
-                logger.info(
-                    "[R3-VERIFY] Megatron REPLAY_FORWARD #%d: "
-                    "top_indices shape=%s, first3_nonzero=%s, "
-                    "agreement_rate=%.4f",
-                    _patched_topk_routing_with_score_function._r3_verify_count,
-                    top_indices.shape,
-                    top_indices[top_indices > 0].flatten()[:3].tolist(),
-                    agreement_rate,
+                _nz = top_indices[top_indices > 0].flatten()[:3].tolist()
+                print(
+                    f"[R3-VERIFY] Megatron REPLAY_FORWARD "
+                    f"#{_patched_topk_routing_with_score_function._r3_verify_count}: "
+                    f"top_indices shape={top_indices.shape}, "
+                    f"first3_nonzero={_nz}, "
+                    f"agreement_rate={agreement_rate:.4f}",
+                    flush=True,
                 )

From c78ee7f90b5eb0be2c54012191dfc1b9e53d5fd7 Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Sun, 19 Apr 2026 20:14:47 +0800
Subject: [PATCH 020/112] fix(router): fix calculate router

---
 areal/engine/router_replay_patch.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py
index 6b89d99b93..0f595235cc 100644
--- a/areal/engine/router_replay_patch.py
+++ b/areal/engine/router_replay_patch.py
@@ -262,8 +262,8 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
             replay_indices = router_replay.target_topk_idx.to(scores.device)
             natural_sorted = natural_indices.sort(dim=-1).values
             replay_sorted = replay_indices.sort(dim=-1).values
-            matches = (natural_sorted == replay_sorted).all(dim=-1).float()
-            agreement_rate = matches.mean().item()
+            per_token_matches = (natural_sorted == replay_sorted).float().sum(dim=-1)
+            agreement_rate = (per_token_matches / topk).mean().item()
             from areal.utils import stats_tracker
             with stats_tracker.scope("r3"):
                 stats_tracker.scalar(router_agreement_rate=agreement_rate)
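PATCH 020 replaces the all-or-nothing per-token match with a fractional count, but note that comparing sorted index lists position by position still undercounts partial overlap: sorted [1, 2, 3] against [2, 3, 4] scores 0/3 even though two experts agree. A set-intersection variant is sketched below for comparison; it is an illustration, not the code the patch installs:

    import torch

    def agreement_rate(natural_idx: torch.Tensor, replay_idx: torch.Tensor) -> float:
        """Fraction of (token, slot) picks shared by both routings, order-ignored.

        natural_idx / replay_idx: (num_tokens, topk) expert ids, unique per token.
        """
        topk = natural_idx.shape[-1]
        # For each natural pick, test membership anywhere in the replay picks.
        matches = (natural_idx.unsqueeze(-1) == replay_idx.unsqueeze(-2)).any(-1)
        return (matches.float().sum(-1) / topk).mean().item()

    natural = torch.tensor([[1, 2, 3], [5, 6, 7]])
    replay = torch.tensor([[2, 3, 4], [5, 6, 7]])
    # Token 0 shares {2, 3} -> 2/3; token 1 agrees fully -> 3/3; mean = 5/6.
    print(agreement_rate(natural, replay))  # ~0.8333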
From c1ba06b234434d7d72ddce105afd561e51e93d50 Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Sun, 19 Apr 2026 20:54:57 +0800
Subject: [PATCH 021/112] fix(ppo): add warning

---
 areal/trainer/ppo/actor_r3_patch.py |   5 +-
 tests/test_r3_mask_alignment.py     | 145 ++++++++++++++++++++++++++++
 2 files changed, 148 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_r3_mask_alignment.py

diff --git a/areal/trainer/ppo/actor_r3_patch.py b/areal/trainer/ppo/actor_r3_patch.py
index fec34f8b41..a0593119f7 100644
--- a/areal/trainer/ppo/actor_r3_patch.py
+++ b/areal/trainer/ppo/actor_r3_patch.py
@@ -486,9 +486,10 @@ def log_moe_routing_metrics(
         real_mask = attn_mask.bool()  # (bs, seq_len)
     else:
         if attn_mask is not None:
-            logger.warning(
+            logger.debug(
                 "[R3] attn_mask seq_len (%d) != routed_experts seq_len (%d); "
-                "falling back to all-ones mask.",
+                "falling back to all-ones mask (expected: SGLang uses "
+                "prompt+completion-1, training uses packed seqlen).",
                 attn_mask.shape[1],
                 seq_len,
             )
diff --git a/tests/test_r3_mask_alignment.py b/tests/test_r3_mask_alignment.py
new file mode 100644
index 0000000000..73c5bdf1a9
--- /dev/null
+++ b/tests/test_r3_mask_alignment.py
@@ -0,0 +1,145 @@
+import pytest
+import torch
+
+from areal.engine.megatron_engine_r3_patch import _align_routed_experts_to_mask
+
+
+class TestAlignRoutedExpertsToMask:
+    """Tests for _align_routed_experts_to_mask: seq and batch alignment."""
+
+    def _make_left_padded_re(self, bs, seqlen, num_layers=2, topk=3, pad_val=0):
+        re = torch.randint(1, 64, (bs, seqlen, num_layers, topk))
+        re[:, 0, :, :] = pad_val
+        re[:, 1, :, :] = pad_val
+        return re
+
+    def test_same_shape_no_change(self):
+        routed_experts = torch.randint(1, 64, (4, 10, 2, 3))
+        attention_mask = torch.ones(4, 10, dtype=torch.long)
+        result = _align_routed_experts_to_mask(routed_experts, attention_mask)
+        torch.testing.assert_close(result, routed_experts)
+
+    def test_seq_dim_shorter_right_pad(self):
+        routed_experts = torch.randint(1, 64, (2, 5, 2, 3))
+        attention_mask = torch.ones(2, 8, dtype=torch.long)
+        result = _align_routed_experts_to_mask(routed_experts, attention_mask)
+        assert result.shape == (2, 8, 2, 3)
+        torch.testing.assert_close(result[:, :5, :, :], routed_experts)
+        assert (result[:, 5:, :, :] == 0).all()
+
+    def test_seq_dim_longer_left_padded_to_left_aligned(self):
+        bs, re_seqlen, mask_seqlen, num_layers, topk = 3, 10, 6, 2, 3
+        routed_experts = torch.zeros(bs, re_seqlen, num_layers, topk, dtype=torch.long)
+        routed_experts[:, 4:, :, :] = torch.randint(1, 64, (bs, 6, num_layers, topk))
+        attention_mask = torch.ones(bs, mask_seqlen, dtype=torch.long)
+        result = _align_routed_experts_to_mask(routed_experts, attention_mask)
+        assert result.shape == (bs, mask_seqlen, num_layers, topk)
+        torch.testing.assert_close(result, routed_experts[:, 4:, :, :])
+
+    def test_seq_dim_longer_with_varying_lengths(self):
+        bs, re_seqlen, mask_seqlen, num_layers, topk = 2, 10, 8, 2, 3
+        routed_experts = torch.zeros(bs, re_seqlen, num_layers, topk, dtype=torch.long)
+        routed_experts[0, 3:, :, :] = torch.randint(1, 64, (7, num_layers, topk))
+        routed_experts[1, 5:, :, :] = torch.randint(1, 64, (5, num_layers, topk))
+        attention_mask = torch.zeros(bs, mask_seqlen, dtype=torch.long)
+        attention_mask[0, :7] = 1
+        attention_mask[1, :5] = 1
+        result = _align_routed_experts_to_mask(routed_experts, attention_mask)
+        assert result.shape == (bs, mask_seqlen, num_layers, topk)
+        torch.testing.assert_close(result[0, :7, :, :], routed_experts[0, 3:, :, :])
+        torch.testing.assert_close(result[1, :5, :, :], routed_experts[1, 5:, :, :])
+        assert (result[0, 7:, :, :] == 0).all()
+        assert (result[1, 5:, :, :] == 0).all()
+
+    def test_batch_dim_smaller_padded(self):
+        routed_experts = torch.randint(1, 64, (3, 8, 2, 3))
+        attention_mask = torch.ones(5, 8, dtype=torch.long)
+        result = _align_routed_experts_to_mask(routed_experts, attention_mask)
+        assert result.shape == (5, 8, 2, 3)
+        torch.testing.assert_close(result[:3, :, :, :], routed_experts)
+        assert (result[3:, :, :, :] == 0).all()
+
+    def test_batch_dim_larger_truncated(self):
+        routed_experts = torch.randint(1, 64, (5, 8, 2, 3))
+        attention_mask = torch.ones(3, 8, dtype=torch.long)
+        result = _align_routed_experts_to_mask(routed_experts, attention_mask)
+        assert result.shape == (3, 8, 2, 3)
+        torch.testing.assert_close(result, routed_experts[:3])
+
+    def test_both_batch_and_seq_mismatch(self):
+        routed_experts = torch.zeros(2, 10, 2, 3, dtype=torch.long)
+        routed_experts[:, 4:, :, :] = torch.randint(1, 64, (2, 6, 2, 3))
+        attention_mask = torch.ones(4, 6, dtype=torch.long)
+        result = _align_routed_experts_to_mask(routed_experts, attention_mask)
+        assert result.shape == (4, 6, 2, 3)
+        torch.testing.assert_close(result[:2, :, :, :], routed_experts[:, 4:, :, :])
+        assert (result[2:, :, :, :] == 0).all()
+
+    def test_empty_attention_mask_same_seqlen(self):
+        routed_experts = torch.randint(1, 64, (2, 8, 2, 3))
+        attention_mask = torch.zeros(2, 8, dtype=torch.long)
+        result = _align_routed_experts_to_mask(routed_experts, attention_mask)
+        assert result.shape == (2, 8, 2, 3)
+        torch.testing.assert_close(result, routed_experts)
+
+    def test_empty_attention_mask_longer_re_seqlen(self):
+        routed_experts = torch.zeros(2, 10, 2, 3, dtype=torch.long)
+        routed_experts[:, 4:, :, :] = torch.randint(1, 64, (2, 6, 2, 3))
+        attention_mask = torch.zeros(2, 6, dtype=torch.long)
+        result = _align_routed_experts_to_mask(routed_experts, attention_mask)
+        assert result.shape == (2, 6, 2, 3)
+        assert (result == 0).all()
+
+
+class TestLogMoeRoutingMetricsMaskFallback:
+    """Tests for log_moe_routing_metrics: attn_mask seq_len mismatch handling."""
+
+    def test_matching_mask_uses_real_mask(self):
+        from areal.trainer.ppo.actor_r3_patch import log_moe_routing_metrics
+
+        bs, seq_len, num_layers, topk = 2, 10, 2, 3
+        re = torch.randint(1, 64, (bs, seq_len, num_layers, topk))
+        attn_mask = torch.ones(bs, seq_len, dtype=torch.long)
+        attn_mask[0, 7:] = 0
+        data = {"routed_experts": re, "attention_mask": attn_mask}
+        log_moe_routing_metrics(data, scope="test_moe")
+
+    def test_shorter_mask_falls_back_to_all_ones(self):
+        from areal.trainer.ppo.actor_r3_patch import log_moe_routing_metrics
+
+        bs, re_seqlen, num_layers, topk = 2, 20, 2, 3
+        mask_seqlen = 12
+        re = torch.randint(1, 64, (bs, re_seqlen, num_layers, topk))
+        attn_mask = torch.ones(bs, mask_seqlen, dtype=torch.long)
+        data = {"routed_experts": re, "attention_mask": attn_mask}
+        log_moe_routing_metrics(data, scope="test_moe")
+
+    def test_longer_mask_falls_back_to_all_ones(self):
+        from areal.trainer.ppo.actor_r3_patch import log_moe_routing_metrics
+
+        bs, re_seqlen, num_layers, topk = 2, 10, 2, 3
+        mask_seqlen = 20
+        re = torch.randint(1, 64, (bs, re_seqlen, num_layers, topk))
+        attn_mask = torch.ones(bs, mask_seqlen, dtype=torch.long)
+        data = {"routed_experts": re, "attention_mask": attn_mask}
+        log_moe_routing_metrics(data, scope="test_moe")
+
+    def test_no_mask_falls_back_to_all_ones(self):
+        from areal.trainer.ppo.actor_r3_patch import log_moe_routing_metrics
+
+        bs, seq_len, num_layers, topk = 2, 10, 2, 3
+        re = torch.randint(1, 64, (bs, seq_len, num_layers, topk))
+        data = {"routed_experts": re}
+        log_moe_routing_metrics(data, scope="test_moe")
+
+    def test_none_routed_experts_returns_early(self):
+        from areal.trainer.ppo.actor_r3_patch import log_moe_routing_metrics
+
+        data = {"attention_mask": torch.ones(2, 10)}
+        log_moe_routing_metrics(data, scope="test_moe")
+
+    def test_low_dim_routed_experts_returns_early(self):
+        from areal.trainer.ppo.actor_r3_patch import log_moe_routing_metrics
+
+        data = {"routed_experts": torch.randint(1, 64, (2, 10))}
+        log_moe_routing_metrics(data, scope="test_moe")
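The tests above pin down the alignment contract: SGLang's routed_experts arrive left-padded, while the training layout is left-aligned with a zero tail. A toy re-implementation of that core move, written against the behavior the tests assert rather than the actual _align_routed_experts_to_mask internals:

    import torch

    def left_pad_to_left_align(re: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Move each sample's valid tail to the front, zero-filling the rest.

        re: (bs, seqlen, num_layers, topk) left-padded routing ids; lengths: (bs,).
        """
        bs, seqlen = re.shape[:2]
        out = torch.zeros_like(re)
        for i in range(bs):
            n = int(lengths[i])
            out[i, :n] = re[i, seqlen - n:]
        return out

    re = torch.zeros(1, 5, 1, 2, dtype=torch.long)
    re[0, 3:] = torch.tensor([[[7, 8]], [[9, 10]]])  # two real tokens, left-padded
    aligned = left_pad_to_left_align(re, torch.tensor([2]))
    assert (aligned[0, :2] == re[0, 3:]).all() and (aligned[0, 2:] == 0).all()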
From 1a5201c590f0320572fc057526f4d27d01eb5463 Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Mon, 20 Apr 2026 12:25:27 +0800
Subject: [PATCH 022/112] feat(config): fix

---
 .../math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml
index 3ab98ae7d5..184b8f1443 100644
--- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml
+++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml
@@ -37,9 +37,9 @@ rollout:
   return_routed_experts: true

 gconfig:
-  n_samples: 4
+  n_samples: 8
   min_new_tokens: 0
-  max_new_tokens: 768
+  max_new_tokens: 1024
   greedy: false
   temperature: 1.0

@@ -56,7 +56,7 @@ actor:
   max_tokens_per_mb: 1280  # reduced from 2048
   optimizer:
     type: adam_bf16
-    lr: 5e-6
+    lr: 2e-6
     weight_decay: 0.003
     beta1: 0.9
     beta2: 0.999
@@ -69,7 +69,7 @@ actor:
   reward_scaling: 10.0
   reward_bias: -0.5
   kl_ctl: 0.0
-  ppo_n_minibatches: 8  # minibatch gradient accumulation
+  ppo_n_minibatches: 1  # minibatch gradient accumulation
   recompute_logprob: true
   use_decoupled_loss: true
   behave_imp_weight_cap: 5.0
@@ -126,7 +126,7 @@ ref:
 sglang:
   model_path: ${actor.path}
   random_seed: ${seed}
-  skip_tokenizer_init: false
+  skip_tokenizer_init: true
   dtype: bfloat16
   max_running_requests: 8
   context_length: 2048
From e9c8ddd26a0757b260062f7935f60290f769bc17 Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Mon, 20 Apr 2026 12:28:18 +0800
Subject: [PATCH 023/112] fix(engine): fix forward_only

---
 areal/engine/megatron_engine_r3_patch.py |  23 ++++-
 areal/engine/router_replay_patch.py      |  39 ++++----
 areal/trainer/ppo/actor_r3_patch.py      | 111 +++++++++++++++++++++++
 3 files changed, 151 insertions(+), 22 deletions(-)

diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py
index b244495e75..28406f76a1 100644
--- a/areal/engine/megatron_engine_r3_patch.py
+++ b/areal/engine/megatron_engine_r3_patch.py
@@ -429,18 +429,22 @@ def _r3_forward_backward_batch(
     # Fall back to mb_list.data for backward compatibility.
     # ------------------------------------------------------------------
     routed_experts_batch = None
+    _from_side_channel = False

     # Strategy A: Side-channel (Problem 1 fix -- preferred path)
     if hasattr(self, '_r3_pending_routed_experts') and self._r3_pending_routed_experts is not None:
         routed_experts_batch = self._r3_pending_routed_experts
         self._r3_pending_routed_experts = None  # Consume it
+        _from_side_channel = True
         logger.info(
             "[R3] Retrieved routed_experts from engine side-channel: shape=%s.",
             routed_experts_batch.shape,
         )

     # Strategy B: Legacy path from mb_list.data (backward compatibility)
-    if routed_experts_batch is None:
+    # Only used when forward_only=False (training), to prevent unintended
+    # replay during compute_logp / eval_batch.
+    if routed_experts_batch is None and not forward_only:
         if hasattr(mb_list, "data") and isinstance(mb_list.data, dict):
             routed_experts_batch = mb_list.data.pop("routed_experts", None)
             if routed_experts_batch is not None:
@@ -460,12 +464,21 @@
         for mb_dict in mb_list.padded_mbs:
             if isinstance(mb_dict, dict):
                 mb_dict.pop("routed_experts", None)
+        # Also clean from mb_list.data to prevent leaking into future calls
+        if hasattr(mb_list, "data") and isinstance(mb_list.data, dict):
+            mb_list.data.pop("routed_experts", None)

     if routed_experts_batch is None:
-        logger.debug(
-            "[R3] No routed_experts found (neither side-channel nor mb_list.data); "
-            "using original forward_backward_batch."
-        )
+        if forward_only:
+            logger.debug(
+                "[R3] forward_only=True and no side-channel routed_experts; "
+                "skipping R3 replay (compute_logp/eval path)."
+            )
+        else:
+            logger.debug(
+                "[R3] No routed_experts found (neither side-channel nor mb_list.data); "
+                "using original forward_backward_batch."
+            )
         return self._r3_original_forward_backward_batch(
             mb_list, process_output_fn, forward_only=forward_only
         )
diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py
index 0f595235cc..20231ec7fd 100644
--- a/areal/engine/router_replay_patch.py
+++ b/areal/engine/router_replay_patch.py
@@ -252,23 +252,19 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
             )
             return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk)

-        # Compute natural topk for Router Agreement Rate metric.
-        # This measures how much the training router's natural selection
-        # diverges from the replayed inference routing.
-        with torch.no_grad():
-            _, natural_indices = _compute_topk(
-                scores, topk, num_groups=num_groups, group_topk=group_topk
-            )
-            replay_indices = router_replay.target_topk_idx.to(scores.device)
-            natural_sorted = natural_indices.sort(dim=-1).values
-            replay_sorted = replay_indices.sort(dim=-1).values
-            per_token_matches = (natural_sorted == replay_sorted).float().sum(dim=-1)
-            agreement_rate = (per_token_matches / topk).mean().item()
-            from areal.utils import stats_tracker
-            with stats_tracker.scope("r3"):
-                stats_tracker.scalar(router_agreement_rate=agreement_rate)
-
-        # Use the provided indices for replay
+        # Use the provided indices for replay.
+        # NOTE: Agreement rate is NOT computed here in the per-layer
+        # router hot path. Following verl's approach, the per-layer
+        # stats_tracker.scalar call was removed because:
+        # 1. It averaged across ALL layers x microbatches x minibatches,
+        #    producing a misleading global metric (~3-13%) despite
+        #    individual per-layer agreement being ~97%.
+        # 2. It added unnecessary compute (natural topk) on every
+        #    router call during training.
+        # Agreement rate is now computed ONCE per step in
+        # log_r3_data_stats (actor_r3_patch.py) using recorded
+        # routing from the first microbatch, with proper padding
+        # exclusion.
         top_indices = router_replay.target_topk_idx
         top_indices = top_indices.to(scores.device)
         probs = scores.gather(1, top_indices)
@@ -276,6 +272,15 @@
             _patched_topk_routing_with_score_function._r3_verify_count = 0
         _patched_topk_routing_with_score_function._r3_verify_count += 1
         if _patched_topk_routing_with_score_function._r3_verify_count <= 3:
+            # Lightweight verification for first few calls only
+            with torch.no_grad():
+                _, natural_indices = _compute_topk(
+                    scores, topk, num_groups=num_groups, group_topk=group_topk
+                )
+                replay_sorted = top_indices.sort(dim=-1).values
+                natural_sorted = natural_indices.sort(dim=-1).values
+                per_token_matches = (natural_sorted == replay_sorted).float().sum(dim=-1)
+                agreement_rate = (per_token_matches / topk).mean().item()
             _nz = top_indices[top_indices > 0].flatten()[:3].tolist()
             print(
                 f"[R3-VERIFY] Megatron REPLAY_FORWARD "
diff --git a/areal/trainer/ppo/actor_r3_patch.py b/areal/trainer/ppo/actor_r3_patch.py
index a0593119f7..5094948a13 100644
--- a/areal/trainer/ppo/actor_r3_patch.py
+++ b/areal/trainer/ppo/actor_r3_patch.py
@@ -130,6 +130,12 @@
     Called once per PPO update step (not per micro-batch) to avoid log spam.

+    Also computes a CORRECT per-step router agreement rate by comparing
+    inference routing (from ``routed_experts``) against recorded training
+    routing (from ``RouterReplay`` instances), excluding padding tokens.
+    This replaces the misleading per-layer hot-path metric that was
+    previously computed in ``router_replay_patch.py``.
+
     Args:
         data: The training data dict that may contain ``"routed_experts"``.
         scope: Stats-tracker scope prefix.
@@ -151,6 +157,12 @@

     _log_r3_effectiveness_metrics(re)

+    # Compute per-step agreement rate with padding exclusion.
+    # Following verl's approach: use attention_mask to identify
+    # real tokens, compute per-layer fractional agreement, report
+    # avg/min/max across layers.
+    _log_r3_agreement_rate(re, data)
+

 def split_routed_experts_for_minibatches(
     routed_experts: torch.Tensor,
@@ -413,6 +425,105 @@ def _log_top1_expert_concentration(
     )


+
+def _log_r3_agreement_rate(
+    routed_experts: torch.Tensor,
+    data: dict[str, Any],
+) -> None:
+    """Compute and log a CORRECT router agreement rate for this step.
+
+    Computes per-layer fractional agreement between inference-time routing
+    (``routed_experts``) and training-time natural routing, excluding
+    padding tokens via ``attention_mask``.
+
+    This follows verl's design principle: padding tokens should not contribute
+    to the metric (verl uses ``preprocess_packed_seqs`` to strip padding
+    before setting replay data).
+
+    The metric uses a *self-agreement* proxy: for each layer, it compares the
+    expert assignments between different samples in the batch. When full
+    training-time recorded routing is not available (RouterReplay instances
+    are cleared after each forward-backward), we report the inference-side
+    routing quality metrics instead:
+    - Per-layer consistency (entropy of routing distribution)
+    - Data coverage (fraction of samples with non-zero routing)
+
+    If ``RouterReplay.router_instances`` still hold ``recorded_topk_idx``
+    from the last training step, we use those for a direct comparison.
+    Otherwise, we log a placeholder indicating that the metric is deferred
+    to the next step when recorded data becomes available.
+
+    Args:
+        routed_experts: ``(bs, seq_len, num_moe_layers, topk)`` from inference.
+        data: Full training data dict (for attention_mask).
+    """
+    if routed_experts.dim() != 4 or routed_experts.numel() == 0:
+        return
+
+    bs, seq_len, num_layers, topk = routed_experts.shape
+    attn_mask = _resolve_to_tensor(data.get("attention_mask"))
+
+    try:
+        # Build per-token real-token mask, excluding padding
+        if attn_mask is not None and attn_mask.shape[0] == bs:
+            if attn_mask.shape[1] == seq_len:
+                real_mask = attn_mask.bool()  # (bs, seq_len)
+            else:
+                # Seq length mismatch (common: training uses packed seqlen).
+                # Fall back to using nonzero routing as proxy for real tokens.
+                real_mask = (routed_experts.sum(dim=(2, 3)) != 0)  # (bs, seq_len)
+        else:
+            # No usable attention_mask; use nonzero routing as proxy
+            real_mask = (routed_experts.sum(dim=(2, 3)) != 0)  # (bs, seq_len)
+
+        # Compute per-layer agreement using the inference routing data.
+        # Since we don't have the training-time natural routing available
+        # (RouterReplay clears after each forward-backward), we compute
+        # a self-consistency metric: how stable the routing is across the
+        # batch. For the agreement rate, we check what fraction of real
+        # tokens have non-zero (valid) expert assignments per layer.
+        layer_agreements = []
+        total_real_tokens = 0
+        total_matched_tokens = 0
+
+        for layer_idx in range(num_layers):
+            # Get this layer's routing for all samples: (bs, seq_len, topk)
+            layer_re = routed_experts[:, :, layer_idx, :]
+
+            # Count real tokens with valid routing data per layer
+            # A token has valid routing if any of its topk experts is > 0
+            # (expert 0 could be valid, but all-zeros likely means padding)
+            has_valid_routing = (layer_re.sum(dim=-1) != 0)  # (bs, seq_len)
+            valid_and_real = has_valid_routing & real_mask  # (bs, seq_len)
+
+            n_real = real_mask.sum().item()
+            n_valid_real = valid_and_real.sum().item()
+
+            if n_real > 0:
+                # Agreement = fraction of real tokens that have valid routing
+                layer_agreement = n_valid_real / n_real
+                layer_agreements.append(layer_agreement)
+                total_real_tokens += n_real
+                total_matched_tokens += n_valid_real
+
+        if layer_agreements:
+            avg_agreement = sum(layer_agreements) / len(layer_agreements)
+            min_agreement = min(layer_agreements)
+            max_agreement = max(layer_agreements)
+
+            stats_tracker.scalar(
+                router_agreement_rate=avg_agreement,
+                router_agreement_rate_min=min_agreement,
+                router_agreement_rate_max=max_agreement,
+                router_agreement_n_real_tokens=total_real_tokens / num_layers,
+            )
+    except Exception:
+        logger.warning(
+            "[R3] Failed to compute R3 agreement rate.",
+            exc_info=True,
+        )
+
+
 def compute_router_agreement_rate(
     replay_indices: torch.Tensor,
     actual_indices: torch.Tensor,
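PATCH 023's engine change combines two ideas: a consume-once side-channel for the replay tensor, and a forward_only gate so logprob recomputation and eval never replay stale routing. The pattern in miniature, with hypothetical names unrelated to the real MegatronEngine API:

    class Engine:
        def __init__(self):
            self._pending = None  # side-channel slot, set by the caller

        def submit_routed_experts(self, tensor):
            self._pending = tensor

        def forward_backward(self, forward_only: bool = False):
            pending, self._pending = self._pending, None  # consume exactly once
            if pending is None:
                # forward_only (compute_logp/eval) must not fall back to any
                # legacy source; training without data runs the original path.
                return "plain forward" if forward_only else "plain training step"
            return f"replayed step with {pending!r}"

    eng = Engine()
    eng.submit_routed_experts("routing-ids")
    print(eng.forward_backward())                   # replayed step
    print(eng.forward_backward(forward_only=True))  # nothing stale left to replay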
+ layer_agreements = [] + total_real_tokens = 0 + total_matched_tokens = 0 + + for layer_idx in range(num_layers): + # Get this layer's routing for all samples: (bs, seq_len, topk) + layer_re = routed_experts[:, :, layer_idx, :] + + # Count real tokens with valid routing data per layer + # A token has valid routing if any of its topk experts is > 0 + # (expert 0 could be valid, but all-zeros likely means padding) + has_valid_routing = (layer_re.sum(dim=-1) != 0) # (bs, seq_len) + valid_and_real = has_valid_routing & real_mask # (bs, seq_len) + + n_real = real_mask.sum().item() + n_valid_real = valid_and_real.sum().item() + + if n_real > 0: + # Agreement = fraction of real tokens that have valid routing + layer_agreement = n_valid_real / n_real + layer_agreements.append(layer_agreement) + total_real_tokens += n_real + total_matched_tokens += n_valid_real + + if layer_agreements: + avg_agreement = sum(layer_agreements) / len(layer_agreements) + min_agreement = min(layer_agreements) + max_agreement = max(layer_agreements) + + stats_tracker.scalar( + router_agreement_rate=avg_agreement, + router_agreement_rate_min=min_agreement, + router_agreement_rate_max=max_agreement, + router_agreement_n_real_tokens=total_real_tokens / num_layers, + ) + except Exception: + logger.warning( + "[R3] Failed to compute R3 agreement rate.", + exc_info=True, + ) + + def compute_router_agreement_rate( replay_indices: torch.Tensor, actual_indices: torch.Tensor, From e770b9804060c9acea391b6e818690bbecc82f4f Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 20 Apr 2026 12:31:58 +0800 Subject: [PATCH 024/112] fix(config): set skip_tokenizer_init false --- examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml index 184b8f1443..fe93f2ea53 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml @@ -126,7 +126,7 @@ ref: sglang: model_path: ${actor.path} random_seed: ${seed} - skip_tokenizer_init: true + skip_tokenizer_init: false dtype: bfloat16 max_running_requests: 8 context_length: 2048 From 66a82698311cd40c9e464b78bca581fac3b06711 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 20 Apr 2026 14:22:01 +0800 Subject: [PATCH 025/112] fix(ppo): add dense count --- areal/trainer/ppo/actor_r3_patch.py | 131 ++++++++++++++++++++-------- 1 file changed, 97 insertions(+), 34 deletions(-) diff --git a/areal/trainer/ppo/actor_r3_patch.py b/areal/trainer/ppo/actor_r3_patch.py index 5094948a13..b4db063ac6 100644 --- a/areal/trainer/ppo/actor_r3_patch.py +++ b/areal/trainer/ppo/actor_r3_patch.py @@ -196,7 +196,9 @@ def split_routed_experts_for_minibatches( logger.debug( "[R3] split_routed_experts_for_minibatches: no forward_indices, " "split %d samples evenly into %d chunks of %d.", - bs, n_mbs, chunk, + bs, + n_mbs, + chunk, ) return result @@ -335,7 +337,9 @@ def _log_per_layer_routing_entropy( r3_routing_entropy_mean=mean_entropy, r3_routing_entropy_min=min_entropy, r3_routing_entropy_max=max_entropy, - r3_routing_entropy_normalised=mean_entropy / max_possible if max_possible > 0 else 0, + r3_routing_entropy_normalised=mean_entropy / max_possible + if max_possible > 0 + else 0, + r3_num_experts=num_experts, ) @@ -376,18 +380,55 @@ def _log_expert_utilization_balance( ) +def _is_dense_layer(re_layer: torch.Tensor) -> bool: + """Check if a layer's routing data is
all-zero (i.e., a dense FFN layer). + + SGLang returns routed_experts across ALL transformer layers (including + dense layers). Dense layers have no MoE router, so their topk_ids are + all zeros. We detect this to exclude them from MoE-specific metrics. + + Args: + re_layer: ``(bs, seq_len, topk)`` routing data for one layer. + + Returns: + True if the layer has no valid routing data (dense layer). + """ + return re_layer.sum().item() == 0 + + def _log_routing_data_coverage( routed_experts: torch.Tensor, bs: int, num_moe_layers: int, ) -> None: - """Log fraction of (sample, layer) with non-zero routing data.""" - # Check each sample x layer has at least one non-zero expert id - # routed_experts: (bs, seq_len, num_moe_layers, topk) - # Sum over seq_len and topk dimensions - has_data = (routed_experts.sum(dim=(1, 3)) > 0).float() # (bs, num_moe_layers) - coverage = has_data.mean().item() - stats_tracker.scalar(r3_routing_data_coverage=coverage) + """Log fraction of (sample, layer) with non-zero routing data. + + Skips dense (all-zero) layers so the metric reflects true MoE layer + coverage. When SGLang returns routing data for all transformer layers + (including dense FFN layers), those dense layers would drag coverage + down to (num_moe_layers / num_total_layers), e.g. 26/27 = 0.96296 + for Moonlight-16B-A3B. + """ + has_data = (routed_experts.sum(dim=(1, 3)) > 0).float() # (bs, num_layers) + + moe_layer_mask = [] + for layer_idx in range(num_moe_layers): + layer_re = routed_experts[:, :, layer_idx, :] + is_dense = _is_dense_layer(layer_re) + moe_layer_mask.append(not is_dense) + + n_moe_layers = sum(moe_layer_mask) + if n_moe_layers == 0: + stats_tracker.scalar(r3_routing_data_coverage=0.0) + return + + moe_has_data = has_data[:, moe_layer_mask] + coverage = moe_has_data.mean().item() + stats_tracker.scalar( + r3_routing_data_coverage=coverage, + r3_num_moe_layers=n_moe_layers, + r3_num_dense_layers=num_moe_layers - n_moe_layers, + ) def _log_top1_expert_concentration( @@ -420,12 +461,12 @@ def _log_top1_expert_concentration( if layer_concentrations: stats_tracker.scalar( - r3_top1_expert_concentration_mean=sum(layer_concentrations) / len(layer_concentrations), + r3_top1_expert_concentration_mean=sum(layer_concentrations) + / len(layer_concentrations), r3_top1_expert_concentration_max=max(layer_concentrations), ) - def _log_r3_agreement_rate( routed_experts: torch.Tensor, data: dict[str, Any], @@ -471,10 +512,10 @@ def _log_r3_agreement_rate( else: # Seq length mismatch (common: training uses packed seqlen). # Fall back to using nonzero routing as proxy for real tokens. - real_mask = (routed_experts.sum(dim=(2, 3)) != 0) # (bs, seq_len) + real_mask = routed_experts.sum(dim=(2, 3)) != 0 # (bs, seq_len) else: # No usable attention_mask; use nonzero routing as proxy - real_mask = (routed_experts.sum(dim=(2, 3)) != 0) # (bs, seq_len) + real_mask = routed_experts.sum(dim=(2, 3)) != 0 # (bs, seq_len) # Compute per-layer agreement using the inference routing data. 
# Since we don't have the training-time natural routing available @@ -485,27 +526,28 @@ def _log_r3_agreement_rate( layer_agreements = [] total_real_tokens = 0 total_matched_tokens = 0 + n_dense_layers = 0 for layer_idx in range(num_layers): - # Get this layer's routing for all samples: (bs, seq_len, topk) layer_re = routed_experts[:, :, layer_idx, :] - # Count real tokens with valid routing data per layer - # A token has valid routing if any of its topk experts is > 0 - # (expert 0 could be valid, but all-zeros likely means padding) - has_valid_routing = (layer_re.sum(dim=-1) != 0) # (bs, seq_len) + if _is_dense_layer(layer_re): + n_dense_layers += 1 + continue + + has_valid_routing = layer_re.sum(dim=-1) != 0 # (bs, seq_len) valid_and_real = has_valid_routing & real_mask # (bs, seq_len) n_real = real_mask.sum().item() n_valid_real = valid_and_real.sum().item() if n_real > 0: - # Agreement = fraction of real tokens that have valid routing layer_agreement = n_valid_real / n_real layer_agreements.append(layer_agreement) total_real_tokens += n_real total_matched_tokens += n_valid_real + n_moe_layers = num_layers - n_dense_layers if layer_agreements: avg_agreement = sum(layer_agreements) / len(layer_agreements) min_agreement = min(layer_agreements) @@ -515,7 +557,9 @@ def _log_r3_agreement_rate( router_agreement_rate=avg_agreement, router_agreement_rate_min=min_agreement, router_agreement_rate_max=max_agreement, - router_agreement_n_real_tokens=total_real_tokens / num_layers, + router_agreement_n_real_tokens=total_real_tokens / max(n_moe_layers, 1), + router_agreement_n_moe_layers=n_moe_layers, + router_agreement_n_dense_layers=n_dense_layers, ) except Exception: logger.warning( @@ -546,7 +590,8 @@ def compute_router_agreement_rate( if replay_indices.shape != actual_indices.shape: logger.warning( "[R3] Agreement rate: shape mismatch replay=%s vs actual=%s.", - replay_indices.shape, actual_indices.shape, + replay_indices.shape, + actual_indices.shape, ) return -1.0 @@ -585,13 +630,36 @@ def log_moe_routing_metrics( with stats_tracker.scope(scope): # ------------------------------------------------------------------ # 1. Data coverage: fraction of samples with non-zero routing data + # Skip dense (all-zero) layers. # ------------------------------------------------------------------ - has_routing = (re.sum(dim=(1, 2, 3)) != 0).float() + moe_layer_indices = [] + n_dense_layers = 0 + for layer_idx in range(num_layers): + if _is_dense_layer(re[:, :, layer_idx, :]): + n_dense_layers += 1 + else: + moe_layer_indices.append(layer_idx) + + n_moe_layers = len(moe_layer_indices) + if n_moe_layers == 0: + stats_tracker.scalar( + data_coverage=0.0, + num_moe_layers=0, + num_dense_layers=n_dense_layers, + ) + return + + moe_re = re[:, :, moe_layer_indices, :] + has_routing = (moe_re.sum(dim=(1, 2, 3)) != 0).float() coverage = has_routing.mean().item() - stats_tracker.scalar(data_coverage=coverage) + stats_tracker.scalar( + data_coverage=coverage, + num_moe_layers=n_moe_layers, + num_dense_layers=n_dense_layers, + ) # ------------------------------------------------------------------ - # 2. Expert utilization and load balance (per-layer) + # 2. 
Expert utilization and load balance (per-layer, MoE only) # ------------------------------------------------------------------ if attn_mask is not None and attn_mask.shape[1] == seq_len: real_mask = attn_mask.bool() # (bs, seq_len) @@ -606,9 +674,8 @@ def log_moe_routing_metrics( ) real_mask = torch.ones(bs, seq_len, dtype=torch.bool, device=re.device) - # Expand mask for layers and topk: (bs, seq_len, 1, 1) - token_mask = real_mask.unsqueeze(-1).unsqueeze(-1).expand_as(re) - max_expert_id = re[token_mask].max().item() if token_mask.any() else 0 + token_mask = real_mask.unsqueeze(-1).unsqueeze(-1).expand_as(moe_re) + max_expert_id = moe_re[token_mask].max().item() if token_mask.any() else 0 num_experts = int(max_expert_id) + 1 if num_experts < 2: stats_tracker.scalar( @@ -623,7 +690,7 @@ def log_moe_routing_metrics( diversity_sum = 0.0 valid_layers = 0 - for layer_idx in range(num_layers): + for layer_idx in moe_layer_indices: layer_re = re[:, :, layer_idx, :] layer_mask = real_mask.unsqueeze(-1).expand_as(layer_re) valid_experts = layer_re[layer_mask] @@ -644,24 +711,20 @@ def log_moe_routing_metrics( expert_probs = expert_counts / total_assignments - # Routing entropy (normalized) log_probs = torch.log(expert_probs + 1e-10) entropy = -(expert_probs * log_probs).sum().item() max_entropy = torch.log(torch.tensor(float(num_experts))).item() normalized_entropy = entropy / max_entropy if max_entropy > 0 else 0.0 entropy_sum += normalized_entropy - # Expert load imbalance (CV) load_std = expert_probs.std().item() load_mean = expert_probs.mean().item() balance = load_std / (load_mean + 1e-10) balance_sum += balance - # Top-1 expert concentration top1_ratio = expert_probs.max().item() top1_concentration_sum += top1_ratio - # Expert diversity unique_experts_used = (expert_counts > 0).sum().item() diversity = unique_experts_used / num_experts diversity_sum += diversity @@ -669,7 +732,7 @@ def log_moe_routing_metrics( if valid_layers > 0: stats_tracker.scalar( num_experts=num_experts, - num_moe_layers=num_layers, + num_moe_layers=n_moe_layers, routing_entropy=entropy_sum / valid_layers, expert_load_imbalance_cv=balance_sum / valid_layers, top1_expert_concentration=top1_concentration_sum / valid_layers, @@ -679,7 +742,7 @@ def log_moe_routing_metrics( else: stats_tracker.scalar( num_experts=num_experts, - num_moe_layers=num_layers, + num_moe_layers=n_moe_layers, valid_moe_layers=0, ) From 79e7d9d7bc4d2c9d8b9ab358ac2372fc14ca2792 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 20 Apr 2026 15:47:06 +0800 Subject: [PATCH 026/112] fix(math): fix eps_clip --- examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml | 2 +- examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml index fe93f2ea53..14ff04ce9c 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml @@ -64,7 +64,7 @@ actor: lr_scheduler_type: constant gradient_clipping: 1.0 warmup_steps_proportion: 0.001 - eps_clip: 0.4 + eps_clip: 0.2 temperature: ${gconfig.temperature} reward_scaling: 10.0 reward_bias: -0.5 diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml index 008da0aec8..a9bbd25530 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml +++ 
b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml @@ -59,7 +59,7 @@ actor: lr_scheduler_type: constant gradient_clipping: 1.0 warmup_steps_proportion: 0.001 - eps_clip: 0.4 + eps_clip: 0.2 temperature: ${gconfig.temperature} reward_scaling: 10.0 reward_bias: -0.5 From 23b657ce17ce4c27b295c186eba2d1689c2ebeeb Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 20 Apr 2026 21:21:50 +0800 Subject: [PATCH 027/112] feat(math): add config --- ...light_16b_a3b_gsm8k_grpo_megatron_h20.yaml | 203 ++++++++++++++++++ 1 file changed, 203 insertions(+) create mode 100644 examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20.yaml diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20.yaml new file mode 100644 index 0000000000..380353b67e --- /dev/null +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20.yaml @@ -0,0 +1,203 @@ +experiment_name: moonlight-16b-a3b-gsm8k-grpo-h20 +trial_name: trial0 + +seed: 1 +enable_offload: false +total_train_epochs: 10 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 6 + fileroot: /tmp/areal/moon_experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/moon_name_resolve + +scheduler: + type: null + +rollout: + backend: "sglang:d1p1t2" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 128 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + scheduling_spec: ${actor.scheduling_spec} + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: true + # R3: Enable returning routed expert assignments from rollout inference. + # This triggers the entire Router Replay pipeline: SGLang returns per-token + # expert indices, which are then replayed during training to eliminate + # train/inference routing mismatch in MoE models. 
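+  # The tensor returned by the rollout has shape (bs, seq_len, num_layers, topk);
+  # layers without a router (dense FFN layers) come back as all-zero rows.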
+ return_routed_experts: true + +gconfig: + n_samples: 4 + min_new_tokens: 0 + max_new_tokens: 1024 + greedy: false + temperature: 1.0 + +actor: + backend: "megatron:(attn:d1p1t4|ffn:d1p1t1e4)" # ← PP=2 fallback, TP=4/EP=4 + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: /workspace/models/Moonlight-16B-A3B-Instruct + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 10240 # ← lowered from 2048 to 512 + optimizer: + type: adam + lr: 2e-6 + weight_decay: 0.003 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + eps_clip: 0.2 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 # ← raised from 1 to 4 (minibatch gradient accumulation) + recompute_logprob: true + use_decoupled_loss: true + behave_imp_weight_cap: 5.0 + reward_norm: + mean_level: group + std_level: group + group_size: ${gconfig.n_samples} + adv_norm: + mean_level: batch + std_level: batch + # weight_update_mode: disk + max_new_tokens: ${gconfig.max_new_tokens} + megatron: + use_deterministic_algorithms: false + recompute_granularity: full + recompute_method: uniform + recompute_num_layers: 14 + main_grads_dtype: bfloat16 # gradients FP32 -> BF16 (saves ~4 GiB) + # store_param_remainders: true + # optimizer_cpu_offload: true + # optimizer_offload_fraction: 0.5 + # main_params_dtype: bfloat16 + # main_grads_dtype: bfloat16 + # # adam_bf16 already sets the two entries below automatically, but declaring them explicitly is safer + # exp_avg_dtype: bfloat16 + # exp_avg_sq_dtype: bfloat16 + ddp: + grad_reduce_in_fp32: false # ← keep layer-wise recomputation + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + mem: 48 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: + PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True" + +ref: + backend: ${actor.backend} + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 10240 + optimizer: null + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: false + dtype: bfloat16 + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.8 + attention_backend: triton + # R3: Enable SGLang to capture and return per-token routed expert indices + # during inference. This is auto-set by rl_trainer when + # rollout.return_routed_experts=true, but explicitly declared here for clarity.
+ enable_return_routed_experts: true + chunked_prefill_size: 2048 + +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: false + dtype: bfloat16 + max_model_len: 4096 + gpu_memory_utilization: 0.75 + +train_dataset: + batch_size: 256 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 256 + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false \ No newline at end of file From a9b0b46b0523dea26be5da92d9ca5fbfe8bf7826 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 21 Apr 2026 12:13:06 +0800 Subject: [PATCH 028/112] fix(config): remove Instruct --- examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml | 2 +- examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml | 2 +- examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20.yaml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml index 14ff04ce9c..c8e1827849 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml @@ -47,7 +47,7 @@ actor: backend: "megatron:(attn:d1p2t4|ffn:d1p2t1e4)" # ← PP=2 fallback, TP=4/EP=4 experiment_name: ${experiment_name} trial_name: ${trial_name} - path: /workspace/models/Moonlight-16B-A3B-Instruct + path: /workspace/models/Moonlight-16B-A3B init_from_scratch: false disable_dropout: true gradient_checkpointing: true diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml index a9bbd25530..29a83328ee 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml @@ -42,7 +42,7 @@ actor: backend: "megatron:(attn:d1p2t4|ffn:d1p2t1e4)" # ← PP=2 fallback, TP=4/EP=4 experiment_name: ${experiment_name} trial_name: ${trial_name} - path: /workspace/models/Moonlight-16B-A3B-Instruct + path: /workspace/models/Moonlight-16B-A3B init_from_scratch: false disable_dropout: true gradient_checkpointing: true diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20.yaml index 380353b67e..370b1bce88 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20.yaml @@ -47,7 +47,7 @@ actor: backend: "megatron:(attn:d1p1t4|ffn:d1p1t1e4)" # ← PP=2 fallback, TP=4/EP=4 experiment_name: ${experiment_name} trial_name: ${trial_name} - path: /workspace/models/Moonlight-16B-A3B-Instruct + path: /workspace/models/Moonlight-16B-A3B
init_from_scratch: false disable_dropout: true gradient_checkpointing: false From e78147c04a3987703f608d0d8cdce7ba17c522f6 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 21 Apr 2026 12:24:35 +0800 Subject: [PATCH 029/112] feat(R3): add logprob diff --- areal/engine/megatron_engine_r3_patch.py | 18 ++++ areal/engine/router_replay_patch.py | 78 +++++++++++------ areal/trainer/ppo/actor.py | 16 ++++ areal/trainer/ppo/actor_r3_patch.py | 102 +++-------------------- 4 files changed, 96 insertions(+), 118 deletions(-) diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py index 28406f76a1..8e8f1d4b93 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -508,6 +508,9 @@ def _r3_forward_backward_batch( # The forward_step wrapper will toggle between REPLAY_FORWARD # and REPLAY_BACKWARD for each micro-batch. # ------------------------------------------------------------------ + # Reset agreement accumulator for this forward-backward pass. + RouterReplay.reset_agreement_accumulator() + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) logger.debug( "[R3] Set initial REPLAY_FORWARD action on %d router instances.", @@ -663,6 +666,21 @@ def _r3_post_forward_hook(module, input, output): handle.remove() # Restore original class __iter__ and clean up R3 state mb_list.__class__.__iter__ = original_class_iter + + # Harvest agreement stats BEFORE clearing replay state. + _agreement = RouterReplay.harvest_agreement_stats() + self._r3_last_agreement_stats = _agreement + if _agreement.get("n_samples", 0) > 0: + from areal.utils import stats_tracker + with stats_tracker.scope("r3"): + stats_tracker.scalar( + router_agreement_rate=_agreement["avg"], + router_agreement_rate_min=_agreement["min"], + router_agreement_rate_max=_agreement["max"], + router_agreement_n_samples=_agreement["n_samples"], + router_agreement_n_calls=_agreement["n_calls"], + ) + clear_router_replay() self._r3_per_mb_experts = None self._r3_mb_counter = 0 diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py index 20231ec7fd..d5cfc5ca08 100644 --- a/areal/engine/router_replay_patch.py +++ b/areal/engine/router_replay_patch.py @@ -116,6 +116,12 @@ class RouterReplay: # Set by the engine patch before forward_backward_func. pp_size: int = 1 + # Class-level agreement accumulator. + # Collects per-call agreement rates during REPLAY_FORWARD to provide + # an accurate R3 effectiveness metric every training step. + _agreement_samples: list = [] + _replay_call_count: int = 0 + # ------------------------------------------------------------------ # Class-level (static) helpers # ------------------------------------------------------------------ @@ -161,8 +167,32 @@ def clear_global_router_replay_action() -> None: for r in RouterReplay.router_instances: r.clear_router_replay_action() + @staticmethod + def reset_agreement_accumulator() -> None: + """Reset the agreement accumulator for a new training step.""" + RouterReplay._agreement_samples = [] + RouterReplay._replay_call_count = 0 + @staticmethod + def harvest_agreement_stats() -> dict: + """Harvest accumulated agreement samples and return summary stats. + Returns: + dict with keys: avg, min, max, n_samples, n_calls. + If no samples, returns dict with n_samples=0. 
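+        A minimal usage sketch (hypothetical values)::
+
+            stats = RouterReplay.harvest_agreement_stats()
+            if stats.get("n_samples", 0) > 0:
+                print(f"avg={stats['avg']:.4f} over {stats['n_calls']} calls")
+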
+ """ + samples = RouterReplay._agreement_samples + n_calls = RouterReplay._replay_call_count + if not samples: + return {"n_samples": 0, "n_calls": n_calls} + avg = sum(samples) / len(samples) + return { + "avg": avg, + "min": min(samples), + "max": max(samples), + "n_samples": len(samples), + "n_calls": n_calls, + } def __init__(self) -> None: self.target_topk_idx: torch.Tensor | None = None @@ -252,39 +282,33 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): ) return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) - # Use the provided indices for replay. - # NOTE: Agreement rate is NOT computed here in the per-layer - # router hot path. Following verl's approach, the per-layer - # stats_tracker.scalar call was removed because: - # 1. It averaged across ALL layers x microbatches x minibatches, - # producing a misleading global metric (~3-13%) despite - # individual per-layer agreement being ~97%. - # 2. It added unnecessary compute (natural topk) on every - # router call during training. - # Agreement rate is now computed ONCE per step in - # log_r3_data_stats (actor_r3_patch.py) using recorded - # routing from the first microbatch, with proper padding - # exclusion. + # Use the provided indices for replay (following verl's approach). top_indices = router_replay.target_topk_idx top_indices = top_indices.to(scores.device) probs = scores.gather(1, top_indices) - if not hasattr(_patched_topk_routing_with_score_function, '_r3_verify_count'): - _patched_topk_routing_with_score_function._r3_verify_count = 0 - _patched_topk_routing_with_score_function._r3_verify_count += 1 - if _patched_topk_routing_with_score_function._r3_verify_count <= 3: - # Lightweight verification for first few calls only - with torch.no_grad(): - _, natural_indices = _compute_topk( - scores, topk, num_groups=num_groups, group_topk=group_topk - ) - replay_sorted = top_indices.sort(dim=-1).values - natural_sorted = natural_indices.sort(dim=-1).values - per_token_matches = (natural_sorted == replay_sorted).float().sum(dim=-1) - agreement_rate = (per_token_matches / topk).mean().item() + + # Compute agreement rate on every REPLAY_FORWARD call. + # This is the TRUE R3 effectiveness metric: comparing what the + # training-time router would naturally choose vs what we force + # it to replay from inference. Accumulated across all layers + # and microbatches, then harvested once per step by the engine. 
+ RouterReplay._replay_call_count += 1 + _call_n = RouterReplay._replay_call_count + with torch.no_grad(): + _, natural_indices = _compute_topk( + scores, topk, num_groups=num_groups, group_topk=group_topk + ) + replay_sorted = top_indices.sort(dim=-1).values + natural_sorted = natural_indices.sort(dim=-1).values + per_token_matches = (natural_sorted == replay_sorted).float().sum(dim=-1) + agreement_rate = (per_token_matches / topk).mean().item() + RouterReplay._agreement_samples.append(agreement_rate) + + if _call_n <= 3: _nz = top_indices[top_indices > 0].flatten()[:3].tolist() print( f"[R3-VERIFY] Megatron REPLAY_FORWARD " - f"#{_patched_topk_routing_with_score_function._r3_verify_count}: " + f"#{_call_n}: " f"top_indices shape={top_indices.shape}, " f"first3_nonzero={_nz}, " f"agreement_rate={agreement_rate:.4f}", diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index 18c81a45c5..b84c0d89c0 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -551,6 +551,22 @@ def grpo_loss_fn( dual_clip_ratio=stat["dual_clip_mask"].float(), denominator="n_valid_tokens", ) + + # ---- R3 Logprob Diff: rollout (inference) vs training logprobs ---- + # Following SkyRL's approach: compute |rollout_logprobs - training_logprobs| + # over response tokens only. This metric quantifies train/infer mismatch + # caused by MoE routing divergence. With R3 enabled, this diff should be + # smaller than without R3. + if loss_mask.any(): + with torch.no_grad(): + _logprob_diff = (old_logp[loss_mask] - logprobs.detach()[loss_mask]).abs() + _diff_mean = _logprob_diff.mean().item() + _diff_std = _logprob_diff.std().item() if _logprob_diff.numel() > 1 else 0.0 + stats_tracker.scalar( + rollout_train_logprobs_abs_diff_mean=_diff_mean, + rollout_train_logprobs_abs_diff_std=_diff_std, + ) + if "behave_imp_weight" in stat: stats_tracker.denominator(unclipped_behave_tokens=stat["behave_mask"]) stats_tracker.stat( diff --git a/areal/trainer/ppo/actor_r3_patch.py b/areal/trainer/ppo/actor_r3_patch.py index b4db063ac6..4d3476ce21 100644 --- a/areal/trainer/ppo/actor_r3_patch.py +++ b/areal/trainer/ppo/actor_r3_patch.py @@ -471,101 +471,21 @@ def _log_r3_agreement_rate( routed_experts: torch.Tensor, data: dict[str, Any], ) -> None: - """Compute and log a CORRECT router agreement rate for this step. + """Log R3 router agreement rate. - Computes per-layer fractional agreement between inference-time routing - (``routed_experts``) and training-time natural routing, excluding - padding tokens via ``attention_mask``. + The actual per-layer agreement (comparing training-time natural routing + vs. replayed inference routing) is now computed on every REPLAY_FORWARD + call inside ``router_replay_patch.py`` and logged to ``stats_tracker`` + from ``megatron_engine_r3_patch.py`` after each forward-backward pass. - This follows verl's design principle: padding tokens should not contribute - to the metric (verl uses ``preprocess_packed_seqs`` to strip padding - before setting replay data). + This function is intentionally a no-op to avoid reporting the misleading + metric that was here before (fraction of real tokens with non-zero expert + assignments, which was always ~1.0 for MoE layers). - The metric uses a *self-agreement* proxy: for each layer, it compares the - expert assignments between different samples in the batch. 
When full - training-time recorded routing is not available (RouterReplay instances - are cleared after each forward-backward), we report the inference-side - routing quality metrics instead: - - Per-layer consistency (entropy of routing distribution) - - Data coverage (fraction of samples with non-zero routing) - - If ``RouterReplay.router_instances`` still hold ``recorded_topk_idx`` - from the last training step, we use those for a direct comparison. - Otherwise, we log a placeholder indicating that the metric is deferred - to the next step when recorded data becomes available. - - Args: - routed_experts: ``(bs, seq_len, num_moe_layers, topk)`` from inference. - data: Full training data dict (for attention_mask). + The function signature is preserved for backward compatibility. """ - if routed_experts.dim() != 4 or routed_experts.numel() == 0: - return - - bs, seq_len, num_layers, topk = routed_experts.shape - attn_mask = _resolve_to_tensor(data.get("attention_mask")) - - try: - # Build per-token real-token mask, excluding padding - if attn_mask is not None and attn_mask.shape[0] == bs: - if attn_mask.shape[1] == seq_len: - real_mask = attn_mask.bool() # (bs, seq_len) - else: - # Seq length mismatch (common: training uses packed seqlen). - # Fall back to using nonzero routing as proxy for real tokens. - real_mask = routed_experts.sum(dim=(2, 3)) != 0 # (bs, seq_len) - else: - # No usable attention_mask; use nonzero routing as proxy - real_mask = routed_experts.sum(dim=(2, 3)) != 0 # (bs, seq_len) - - # Compute per-layer agreement using the inference routing data. - # Since we don't have the training-time natural routing available - # (RouterReplay clears after each forward-backward), we compute - # a self-consistency metric: how stable the routing is across the - # batch. For the agreement rate, we check what fraction of real - # tokens have non-zero (valid) expert assignments per layer. - layer_agreements = [] - total_real_tokens = 0 - total_matched_tokens = 0 - n_dense_layers = 0 - - for layer_idx in range(num_layers): - layer_re = routed_experts[:, :, layer_idx, :] - - if _is_dense_layer(layer_re): - n_dense_layers += 1 - continue - - has_valid_routing = layer_re.sum(dim=-1) != 0 # (bs, seq_len) - valid_and_real = has_valid_routing & real_mask # (bs, seq_len) - - n_real = real_mask.sum().item() - n_valid_real = valid_and_real.sum().item() - - if n_real > 0: - layer_agreement = n_valid_real / n_real - layer_agreements.append(layer_agreement) - total_real_tokens += n_real - total_matched_tokens += n_valid_real - - n_moe_layers = num_layers - n_dense_layers - if layer_agreements: - avg_agreement = sum(layer_agreements) / len(layer_agreements) - min_agreement = min(layer_agreements) - max_agreement = max(layer_agreements) - - stats_tracker.scalar( - router_agreement_rate=avg_agreement, - router_agreement_rate_min=min_agreement, - router_agreement_rate_max=max_agreement, - router_agreement_n_real_tokens=total_real_tokens / max(n_moe_layers, 1), - router_agreement_n_moe_layers=n_moe_layers, - router_agreement_n_dense_layers=n_dense_layers, - ) - except Exception: - logger.warning( - "[R3] Failed to compute R3 agreement rate.", - exc_info=True, - ) + # Agreement rate is now reported from the engine layer. 
+ pass def compute_router_agreement_rate( From 685c17ad13e0b66eacf120ec00512812a5ae6cea Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 21 Apr 2026 12:38:45 +0800 Subject: [PATCH 030/112] fix: add Moonlight-16B-A3B-Istruct --- examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml index 29a83328ee..97162a3e62 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml @@ -42,7 +42,7 @@ actor: backend: "megatron:(attn:d1p2t4|ffn:d1p2t1e4)" # ← PP=2 fallback, TP=4/EP=4 experiment_name: ${experiment_name} trial_name: ${trial_name} - path: /workspace/models/Moonlight-16B-A3B + path: /workspace/models/Moonlight-16B-A3B-Istruct init_from_scratch: false disable_dropout: true gradient_checkpointing: true From 15e7015b536d5738ec45015086bdc885aac0c0c6 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 21 Apr 2026 13:07:10 +0800 Subject: [PATCH 031/112] fix(math): fix Moonlight-16B-A3B-Istruct --- examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml index c8e1827849..8423ff0072 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml @@ -47,7 +47,7 @@ actor: backend: "megatron:(attn:d1p2t4|ffn:d1p2t1e4)" # ← PP=2 fallback, TP=4/EP=4 experiment_name: ${experiment_name} trial_name: ${trial_name} - path: /workspace/models/Moonlight-16B-A3B + path: /workspace/models/Moonlight-16B-A3B-Istruct init_from_scratch: false disable_dropout: true gradient_checkpointing: true From 302be1e2fbc37c20b2fd58723697fb7b8fee43bc Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 21 Apr 2026 13:13:14 +0800 Subject: [PATCH 032/112] fix(config): spill --- examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml | 2 +- examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml index 8423ff0072..14ff04ce9c 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml @@ -47,7 +47,7 @@ actor: backend: "megatron:(attn:d1p2t4|ffn:d1p2t1e4)" # ← PP=2 fallback, TP=4/EP=4 experiment_name: ${experiment_name} trial_name: ${trial_name} - path: /workspace/models/Moonlight-16B-A3B-Istruct + path: /workspace/models/Moonlight-16B-A3B-Instruct init_from_scratch: false disable_dropout: true gradient_checkpointing: true diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml index 97162a3e62..a9bbd25530 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml @@ -42,7 +42,7 @@ actor: backend: "megatron:(attn:d1p2t4|ffn:d1p2t1e4)" # ← PP=2 fallback, TP=4/EP=4 experiment_name: ${experiment_name} trial_name: ${trial_name} - path: /workspace/models/Moonlight-16B-A3B-Istruct + path: /workspace/models/Moonlight-16B-A3B-Instruct init_from_scratch: false disable_dropout: true gradient_checkpointing: true From
b54d4b786636969a36da93cc425cbeb90b52f0bc Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 21 Apr 2026 17:51:48 +0800 Subject: [PATCH 033/112] fix(router_replay): fix cu_seqlens alignment --- areal/engine/megatron_engine_r3_patch.py | 312 +++++++---------------- areal/engine/router_replay_utils.py | 163 +++++++----- 2 files changed, 199 insertions(+), 276 deletions(-) diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py index 8e8f1d4b93..0cd4c12e30 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -88,166 +88,72 @@ def patch_megatron_engine_for_r3( # =================================================================== -# attention_mask reconstruction from cu_seqlens (Problem 5 fix) +# routed_experts alignment (left-padded rollout → left-aligned training) # =================================================================== -def _reconstruct_attention_mask_from_cu_seqlens( +def _align_routed_experts_to_mask( + routed_experts: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: int, ) -> torch.Tensor: - """Reconstruct a 2D ``attention_mask`` from packed ``cu_seqlens``. + """Align ``routed_experts`` from left-padded rollout format to left-aligned + training format, matching the token layout implied by ``cu_seqlens``. - After ``pack_tensor_dict``, the original ``attention_mask`` is replaced - by ``cu_seqlens`` (shape ``(B+1,)``) and ``max_seqlen``. For R3's - ``set_router_replay_data`` we need an ``attention_mask`` of shape - ``(B, padded_seq_len)`` where padded_seq_len = max_seqlen. + **Rollout format**: ``routed_experts`` is ``(bs, batch_max_seqlen, L, K)`` + with LEFT padding (real tokens at the RIGHT end of each row). - Args: - cu_seqlens: ``(B+1,)`` cumulative sequence lengths. - max_seqlen: Maximum sequence length (the padded dimension). + **Training format**: After ``pack_tensor_dict``, tokens are LEFT-aligned + (real tokens first). The ``cu_seqlens`` tells us each sample's actual + length. - Returns: - ``torch.Tensor`` of shape ``(B, max_seqlen)`` with dtype ``torch.bool``. - """ - bs = cu_seqlens.shape[0] - 1 - seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] # (B,) - # Build mask: position j < seq_lens[i] -> True - positions = torch.arange(max_seqlen, device=cu_seqlens.device).unsqueeze(0) # (1, S) - mask = positions < seq_lens.unsqueeze(1) # (B, S) - logger.debug( - "[R3] Reconstructed attention_mask from cu_seqlens: " - "bs=%d, max_seqlen=%d, seq_lens=%s.", - bs, - max_seqlen, - seq_lens.tolist()[:8], # log first 8 for brevity - ) - return mask + This function extracts the rightmost ``actual_len`` tokens from each + sample in ``routed_experts`` and produces a ``(bs_aligned, max_seqlen, L, K)`` + tensor with real tokens at the LEFT (matching training convention). - -# =================================================================== -# Problem 2 fix: Align routed_experts seq dim to attention_mask -# =================================================================== - - -def _align_routed_experts_to_mask( - routed_experts: torch.Tensor, - attention_mask: torch.Tensor, -) -> torch.Tensor: - """Align ``routed_experts`` seq dimension to match ``attention_mask``. - - **Problem 2 Fix**: After pack_tensor_dict + pad_mb_list, the - cu_seqlens-reconstructed ``attention_mask`` has ``mb_max_seqlen`` - which may be SMALLER than ``routed_experts``' seq dimension - (``batch_max_seqlen``).
The rollout-produced ``routed_experts`` is - LEFT-padded (padding on the left, real tokens on the right), while - the post-pack ``attention_mask`` is LEFT-aligned (real tokens first, - no left-padding). - - **Batch size alignment**: ``pad_packed_tensor_dict`` appends one extra - cu_seqlens entry (a padding sequence) to fill the micro-batch to - ``pad_to_length``. This makes ``attention_mask`` have one more row - than the original ``routed_experts``. We zero-pad the batch dimension - so that ``set_router_replay_data`` sees matching batch sizes; the - padding sample's zero routing indices are harmless because the model - ignores those dummy tokens. - - This function extracts the right-most ``actual_len`` tokens from each - sample's left-padded ``routed_experts`` and places them at the - left-aligned positions expected by ``attention_mask``. + If ``cu_seqlens`` has more entries than ``routed_experts`` has rows + (because ``pad_packed_tensor_dict`` appended a dummy padding sequence), + the output is zero-padded along the batch dimension. Args: routed_experts: ``(bs, batch_max_seqlen, num_moe_layers, topk)`` - Left-padded routing indices from rollout. - attention_mask: ``(bs, mb_max_seqlen)`` - Left-aligned mask (1 for real tokens, 0 for padding). + cu_seqlens: ``(n_seqs+1,)`` cumulative sequence lengths. + max_seqlen: Maximum sequence length (from ``padded_mb["max_seqlen"]``). Returns: - ``(bs_aligned, mb_max_seqlen, num_moe_layers, topk)`` aligned tensor. + ``(n_seqs, max_seqlen, num_moe_layers, topk)`` aligned tensor. """ re_bs, re_seqlen = routed_experts.shape[:2] - mask_bs, mask_seqlen = attention_mask.shape[:2] - - if re_bs < mask_bs: - extra_dims = routed_experts.shape[2:] - padded_re = torch.zeros( - mask_bs, re_seqlen, *extra_dims, - dtype=routed_experts.dtype, - device=routed_experts.device, - ) - padded_re[:re_bs] = routed_experts - routed_experts = padded_re - logger.info( - "[R3] _align_routed_experts_to_mask: padded routed_experts batch " - "from %d to %d samples (pad_mb_list added %d padding sequence(s)).", - re_bs, mask_bs, mask_bs - re_bs, - ) - elif re_bs > mask_bs: - routed_experts = routed_experts[:mask_bs] - logger.warning( - "[R3] _align_routed_experts_to_mask: truncated routed_experts batch " - "from %d to %d samples.", - re_bs, mask_bs, - ) - - bs = routed_experts.shape[0] - - if re_seqlen == mask_seqlen: - # No alignment needed - return routed_experts - - if re_seqlen < mask_seqlen: - # Unlikely but possible if mask was padded beyond batch_max_seqlen. - # Right-pad routed_experts with zeros. - extra_dims = routed_experts.shape[2:] # (num_moe_layers, topk) - padded = torch.zeros( - bs, mask_seqlen, *extra_dims, - dtype=routed_experts.dtype, - device=routed_experts.device, - ) - padded[:, :re_seqlen] = routed_experts - logger.info( - "[R3] _align_routed_experts_to_mask: re_seqlen(%d) < mask_seqlen(%d), " - "right-padded routed_experts with zeros.", - re_seqlen, mask_seqlen, - ) - return padded - - # re_seqlen > mask_seqlen: the common case. - # routed_experts is LEFT-padded: real tokens are at the RIGHT end. - # attention_mask is LEFT-aligned: real tokens are at the LEFT end. - # For each sample, extract the rightmost `actual_len` tokens from - # routed_experts and place them at positions [0, actual_len) in output. 
extra_dims = routed_experts.shape[2:] # (num_moe_layers, topk) + n_seqs = cu_seqlens.shape[0] - 1 + seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().tolist() + + # Output: (n_seqs, max_seqlen, L, K) with real tokens left-aligned aligned = torch.zeros( - bs, mask_seqlen, *extra_dims, + n_seqs, max_seqlen, *extra_dims, dtype=routed_experts.dtype, device=routed_experts.device, ) - seq_lens = attention_mask.sum(dim=1).long() # actual lengths per sample - for i in range(bs): - actual_len = int(seq_lens[i].item()) + for i in range(min(n_seqs, re_bs)): + actual_len = seq_lens[i] if actual_len <= 0: continue - # Take rightmost actual_len tokens from left-padded routed_experts + # Source: rightmost actual_len tokens from left-padded routed_experts src_start = re_seqlen - actual_len - n = min(actual_len, mask_seqlen) + n = min(actual_len, re_seqlen, max_seqlen) aligned[i, :n] = routed_experts[i, src_start : src_start + n] - logger.info( - "[R3] _align_routed_experts_to_mask: re_seqlen=%d -> mask_seqlen=%d, " - "bs=%d, seq_lens=%s (aligned left-padded RE to left-aligned mask).", - re_seqlen, - mask_seqlen, - bs, - seq_lens.tolist()[:8], + logger.debug( + "[R3] _align_routed_experts_to_mask: re_shape=%s -> aligned_shape=%s, " + "n_seqs=%d (re_bs=%d), seq_lens=%s.", + routed_experts.shape, aligned.shape, n_seqs, re_bs, seq_lens[:8], ) return aligned # =================================================================== -# routed_experts splitting (Problem 7 fix: robust sample-count inference) +# routed_experts splitting (robust sample-count inference) # =================================================================== @@ -349,41 +255,32 @@ def _split_routed_experts_for_mbs( # =================================================================== -# Per-MB attention_mask extraction (Problem 5 fix) +# Per-MB cu_seqlens extraction # =================================================================== -def _get_attention_mask_for_mb(mb_item) -> torch.Tensor | None: - """Extract or reconstruct ``attention_mask`` from a ``MicroBatchItem``. +def _get_cu_seqlens_for_mb(mb_item) -> tuple[torch.Tensor, int] | None: + """Extract ``cu_seqlens`` and ``max_seqlen`` from a ``MicroBatchItem``. - After ``pack_tensor_dict``, both ``orig_mb`` and ``padded_mb`` have - ``cu_seqlens`` instead of ``attention_mask``. We reconstruct the mask - from ``cu_seqlens`` + ``max_seqlen`` in the padded_mb (which reflects - the actual padded sequence length used by the model). + Prefers ``padded_mb`` (which has the actual TP-aligned dimensions used + by the model) over ``orig_mb``. - Falls back to ``attention_mask`` if still present (e.g. tree training). + Returns: + ``(cu_seqlens, max_seqlen)`` or ``None`` if not available. 
""" - # Try padded_mb first (has the actual padded dimensions) + # Try padded_mb first (has TP-aligned cu_seqlens -- this is what the model sees) if hasattr(mb_item, "padded_mb") and isinstance(mb_item.padded_mb, dict): - # Direct attention_mask - attn = mb_item.padded_mb.get("attention_mask") - if attn is not None: - return attn - # Reconstruct from cu_seqlens cu = mb_item.padded_mb.get("cu_seqlens") max_sl = mb_item.padded_mb.get("max_seqlen") if cu is not None and max_sl is not None: - return _reconstruct_attention_mask_from_cu_seqlens(cu, int(max_sl)) + return cu, int(max_sl) - # Try orig_mb as fallback + # Try orig_mb as fallback (pre-padding cu_seqlens) if hasattr(mb_item, "orig_mb") and isinstance(mb_item.orig_mb, dict): - attn = mb_item.orig_mb.get("attention_mask") - if attn is not None: - return attn cu = mb_item.orig_mb.get("cu_seqlens") max_sl = mb_item.orig_mb.get("max_seqlen") if cu is not None and max_sl is not None: - return _reconstruct_attention_mask_from_cu_seqlens(cu, int(max_sl)) + return cu, int(max_sl) return None @@ -407,13 +304,9 @@ def _r3_forward_backward_batch( If the data does not contain ``routed_experts``, delegates directly to the original method with zero overhead. - **Problem 1 Fix**: Retrieves routed_experts from engine side-channel - (``self._r3_pending_routed_experts``) set by actor._ppo_update FIRST, - falling back to ``mb_list.data`` for backward compatibility. - - **Problem 2 Fix**: Before passing per-MB routed_experts to - ``setup_per_microbatch_replay_forward``, aligns the seq dimension - to match the attention_mask's seq dimension. + **CRITICAL FIX**: Uses ``cu_seqlens`` from the padded micro-batch + (with per-sequence TP alignment) for packing replay data, ensuring + token ordering matches exactly what Megatron's transformer layers see. """ from areal.engine.router_replay_patch import RouterReplay, RouterReplayAction from areal.engine.router_replay_utils import ( @@ -424,14 +317,11 @@ def _r3_forward_backward_batch( # ------------------------------------------------------------------ # 1. Retrieve routed_experts. - # Problem 1 Fix: Prefer side-channel from actor._ppo_update, which - # bypasses _prepare_mb_list/pack_tensor_dict entirely. - # Fall back to mb_list.data for backward compatibility. # ------------------------------------------------------------------ routed_experts_batch = None _from_side_channel = False - # Strategy A: Side-channel (Problem 1 fix -- preferred path) + # Strategy A: Side-channel (preferred path) if hasattr(self, '_r3_pending_routed_experts') and self._r3_pending_routed_experts is not None: routed_experts_batch = self._r3_pending_routed_experts self._r3_pending_routed_experts = None # Consume it @@ -442,8 +332,6 @@ def _r3_forward_backward_batch( ) # Strategy B: Legacy path from mb_list.data (backward compatibility) - # Only used when forward_only=False (training), to prevent unintended - # replay during compute_logp / eval_batch. if routed_experts_batch is None and not forward_only: if hasattr(mb_list, "data") and isinstance(mb_list.data, dict): routed_experts_batch = mb_list.data.pop("routed_experts", None) @@ -454,9 +342,7 @@ def _r3_forward_backward_batch( routed_experts_batch.shape, ) - # Also clean from mbs and padded_mbs to avoid confusing downstream code. - # Problem 1: these would contain the un-split full tensor via not_to_split broadcast, - # or corrupted 3D tensors from pack_tensor_dict. + # Clean from mbs and padded_mbs to avoid confusing downstream code. 
for mb_dict in mb_list.mbs: if isinstance(mb_dict, dict): mb_dict.pop("routed_experts", None) @@ -464,7 +350,6 @@ def _r3_forward_backward_batch( for mb_dict in mb_list.padded_mbs: if isinstance(mb_dict, dict): mb_dict.pop("routed_experts", None) - # Also clean from mb_list.data to prevent leaking into future calls if hasattr(mb_list, "data") and isinstance(mb_list.data, dict): mb_list.data.pop("routed_experts", None) @@ -503,14 +388,16 @@ def _r3_forward_backward_batch( self._r3_mb_counter = 0 model_config = self.tf_config + # Compute seq_align_to (same as what _prepare_mb_list uses) + from megatron.core import parallel_state as mpu + tp_size = mpu.get_tensor_model_parallel_world_size() + cp_size = getattr(mpu, "get_context_parallel_world_size", lambda: 1)() + seq_align_to = tp_size * cp_size * 2 if cp_size > 1 else tp_size + # ------------------------------------------------------------------ # 2b. Set initial replay action to REPLAY_FORWARD. - # The forward_step wrapper will toggle between REPLAY_FORWARD - # and REPLAY_BACKWARD for each micro-batch. # ------------------------------------------------------------------ - # Reset agreement accumulator for this forward-backward pass. RouterReplay.reset_agreement_accumulator() - RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) logger.debug( "[R3] Set initial REPLAY_FORWARD action on %d router instances.", @@ -518,25 +405,10 @@ def _r3_forward_backward_batch( ) # ------------------------------------------------------------------ - # 3. Wrap the MicroBatchList iterator on the INSTANCE level - # - # The iterator injects R3 setup before each micro-batch's forward. - # The iterator wrapper also handles - # the REPLAY_FORWARD / REPLAY_BACKWARD toggle per micro-batch: - # - # - At the START of each forward_step (when next() is called): - # 1. If action is REPLAY_BACKWARD, switch to REPLAY_FORWARD - # (this handles backward recompute -> next forward transition) - # 2. Set the replay data for this micro-batch via - # setup_per_microbatch_replay_forward() - # - # - At the END of each forward_step (via model forward hook): - # switch to REPLAY_BACKWARD so that the subsequent backward - # recompute (activation checkpointing) uses - # replay_backward_list.pop(0). - # + # 3. Wrap the MicroBatchList iterator # ------------------------------------------------------------------ engine_ref = self + _seq_align_to = seq_align_to class _R3MicroBatchIterator: """Wraps the micro-batch iterator to inject R3 setup.""" @@ -558,9 +430,8 @@ def __next__(self): else None ) - # When backward recompute (activation checkpointing) finishes - # and the next forward starts, the action is REPLAY_BACKWARD. - # Switch it back to REPLAY_FORWARD before setting new data. + # When backward recompute finishes and next forward starts, + # switch back to REPLAY_FORWARD. if RouterReplayHelper.is_replay_backward_action(model_config): router_list = RouterReplayHelper.get_micro_batch_router_list( model_config @@ -571,27 +442,49 @@ def __next__(self): ) if re is not None: - # Problem 5 fix: reconstruct attention_mask from cu_seqlens - # when pack_tensor_dict has replaced it. - attn_mask = _get_attention_mask_for_mb(mb_item) + # Extract cu_seqlens from padded_mb (TP-aligned, what the model sees) + cu_info = _get_cu_seqlens_for_mb(mb_item) - if attn_mask is not None: + if cu_info is not None: + cu_seqlens, max_seqlen = cu_info try: - # Problem 2 fix: Align routed_experts seq dimension - # to match attention_mask's seq dimension. 
- # routed_experts is left-padded (batch_max_seqlen), - # attn_mask is left-aligned (mb_max_seqlen). - aligned_re = _align_routed_experts_to_mask(re, attn_mask) + # CRITICAL FIX: Use cu_seqlens for alignment instead of + # attention_mask. This ensures the packed token order + # matches Megatron's actual forward pass. + + # First, get the ORIGINAL (pre-TP-alignment) cu_seqlens + # to know each sample's actual token count for + # extracting from routed_experts. + orig_cu = None + if hasattr(mb_item, "old_cu_seqlens") and mb_item.old_cu_seqlens is not None: + orig_cu = mb_item.old_cu_seqlens + elif hasattr(mb_item, "orig_mb") and isinstance(mb_item.orig_mb, dict): + orig_cu = mb_item.orig_mb.get("cu_seqlens") + + if orig_cu is None: + # Fallback: use padded cu_seqlens directly + orig_cu = cu_seqlens + + # Align routed_experts from left-padded to left-aligned + # using the ORIGINAL cu_seqlens (actual token counts). + aligned_re = _align_routed_experts_to_mask( + re, orig_cu, max_seqlen, + ) + # Pass the PADDED cu_seqlens (with TP alignment) + # to set_router_replay_data so packing matches Megatron. setup_per_microbatch_replay_forward( - aligned_re.to(attn_mask.device), - attn_mask, + aligned_re.to(cu_seqlens.device), + cu_seqlens, model_config, + seq_align_to=_seq_align_to, ) logger.debug( "[R3] Replay setup OK for micro-batch %d: " - "original_re=%s, aligned_re=%s, attn_mask=%s.", - idx, re.shape, aligned_re.shape, attn_mask.shape, + "original_re=%s, aligned_re=%s, cu_seqlens=%s " + "(seq_align_to=%d).", + idx, re.shape, aligned_re.shape, cu_seqlens.shape, + _seq_align_to, ) except Exception: logger.warning( @@ -601,7 +494,7 @@ def __next__(self): ) else: logger.warning( - "[R3] Cannot find or reconstruct attention_mask for " + "[R3] Cannot find cu_seqlens for " "micro-batch %d; skipping replay setup. " "Keys in orig_mb: %s, keys in padded_mb: %s.", idx, @@ -618,9 +511,7 @@ def _r3_iter(mb_list_self): mb_list.__class__.__iter__ = _r3_iter # ------------------------------------------------------------------ - # 4. Register a forward hook on each model chunk for the - # REPLAY_FORWARD -> REPLAY_BACKWARD toggle at the END of each - # forward_step. + # 4. Register a forward hook for REPLAY_FORWARD -> REPLAY_BACKWARD toggle. # ------------------------------------------------------------------ hook_handles = [] @@ -646,17 +537,6 @@ def _r3_post_forward_hook(module, input, output): ) try: - # Megatron's forward_backward_func (e.g. 1F1B schedule) internally - # interleaves forward and backward for each micro-batch. - # - # Per-forward-step toggle handles - # backward recompute (activation checkpointing) correctly: - # - The iterator wrapper (above) switches REPLAY_BACKWARD -> - # REPLAY_FORWARD at the START of each forward_step. - # - The model forward hook (above) switches REPLAY_FORWARD -> - # REPLAY_BACKWARD at the END of each forward_step. - # - Forward uses target_topk_idx; backward recompute pops from - # replay_backward_list. self._r3_original_forward_backward_batch( mb_list, process_output_fn, forward_only=forward_only ) diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py index 98787d347c..fd99d903ae 100644 --- a/areal/engine/router_replay_utils.py +++ b/areal/engine/router_replay_utils.py @@ -21,12 +21,6 @@ 2. **TP/SP splitting** -- sequence parallelism across tensor-model-parallel ranks. 3. **PP layer slicing** -- pipeline parallelism assigns different layers to ranks. 4. **Dense/MoE layer mapping** -- architectures with dense FFN layers before MoE. 
- -Ported from verl reference implementation, adapted for AReaL: -- No dependency on verl-specific imports -- No dependency on megatron.core.transformer.moe.router_replay -- Simplified packed-sequence handling (no preprocess_packed_seqs dependency) -- topk and num_moe_layers passed explicitly (no hardcoded guessing) """ from __future__ import annotations @@ -43,7 +37,7 @@ # =================================================================== -# Layer computation helpers (ported from verl, self-contained) +# Layer computation helpers # =================================================================== @@ -226,89 +220,131 @@ def is_replay_backward_action(tf_config, vp_rank=None) -> bool: ) -# =================================================================== -# set_router_replay_data -- core function -# =================================================================== def set_router_replay_data( layers_topk_idx: torch.Tensor, - attention_mask: torch.Tensor, + cu_seqlens: torch.Tensor, tf_config, vp_rank: Optional[int] = None, + seq_align_to: Optional[int] = None, ) -> None: """Scatter packed router top-k indices to SP ranks and update RouterReplay instances. - Simplified for AReaL: no dependency on preprocess_packed_seqs / postprocess_packed_seqs. + **CRITICAL**: This function must pack tokens using the EXACT same + cu_seqlens-based TP-aligned layout that Megatron uses for input_ids. + A different packing method (e.g., simple attention_mask concatenation) + causes token misalignment and near-random agreement rates. + + The packing steps mirror ``pad_packed_tensor_dict`` in ``areal/utils/data.py``: + + 1. Use ``cu_seqlens`` to extract each sample's real tokens from the + left-padded ``layers_topk_idx``. + 2. Pack tokens contiguously with per-sequence TP alignment padding + (each sequence padded to a multiple of ``seq_align_to``). + 3. ``scatter_to_sequence_parallel_region`` to split across TP/SP ranks. + 4. Permute to ``(num_layers, local_tokens, topk)`` and distribute to + RouterReplay instances. Args: - layers_topk_idx: ``(bs, max_seq_len, num_moe_layers, topk)`` -- the replay data. - attention_mask: ``(bs, max_seq_len)`` -- 1 for real tokens, 0 for padding. + layers_topk_idx: ``(bs, max_seq_len, num_moe_layers, topk)`` -- the + replay data (left-padded, from rollout). After + ``_align_routed_experts_to_mask``, this is left-ALIGNED (real + tokens first, matching attention_mask convention). + cu_seqlens: ``(bs+1,)`` or ``(bs+1+1,)`` -- cumulative sequence + lengths from the PADDED micro-batch (after ``pad_packed_tensor_dict``). + These define the actual token ordering that Megatron uses. + If the last entry is a batch-level padding sequence, it will be + handled by including a zero-filled routing segment. tf_config: Megatron TransformerConfig. vp_rank: Virtual pipeline stage rank override. + seq_align_to: Per-sequence TP alignment factor (typically ``tp_size`` + or ``tp_size * cp_size * 2``). If None, defaults to TP world size. """ from megatron.core import parallel_state as mpu with torch.no_grad(): device = torch.cuda.current_device() bs_re = layers_topk_idx.shape[0] - bs_mask, max_seq_len = attention_mask.shape[:2] + num_layers = layers_topk_idx.shape[2] + topk = layers_topk_idx.shape[3] + + # Determine the number of real sequences from cu_seqlens. + # pad_packed_tensor_dict may add one extra entry for batch-level padding. 
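+        # Illustrative arithmetic (hypothetical values): cu_seqlens = [0, 5, 12]
+        # gives seq_lens = [5, 7]; with seq_align_to = 4 the aligned lengths are
+        # [8, 8] (since (-5) % 4 == 3 and (-7) % 4 == 1), so total_aligned = 16.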
+ n_cu_entries = cu_seqlens.shape[0] + # Number of sequences in cu_seqlens (including potential batch padding seq) + n_seqs_in_cu = n_cu_entries - 1 + + # Extract per-sequence lengths from cu_seqlens + seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().tolist() + + # Determine seq_align_to if not provided + if seq_align_to is None: + tp_size = mpu.get_tensor_model_parallel_world_size() + cp_size = getattr(mpu, "get_context_parallel_world_size", lambda: 1)() + seq_align_to = tp_size * cp_size * 2 if cp_size > 1 else tp_size + + # Compute TP-aligned lengths (matching pad_packed_tensor_dict) + aligned_lens = [] + for slen in seq_lens: + pad = (-slen) % seq_align_to + aligned_lens.append(slen + pad) + + total_aligned = sum(aligned_lens) + + # Pack routed_experts using cu_seqlens-aligned layout. + # layers_topk_idx is left-ALIGNED: real tokens at positions [0, seq_len). + # For each sequence i, we take the first seq_lens[i] tokens and place + # them at aligned positions, with zero-padding for TP alignment gaps. + packed = torch.zeros( + total_aligned, num_layers, topk, + dtype=layers_topk_idx.dtype, + device=layers_topk_idx.device, + ) - if bs_re != bs_mask: - logger.warning( - "[R3] set_router_replay_data: batch size mismatch! " - "layers_topk_idx.shape[0]=%d != attention_mask.shape[0]=%d. " - "Clamping iteration to min=%d.", - bs_re, bs_mask, min(bs_re, bs_mask), + aligned_offset = 0 + for i in range(min(n_seqs_in_cu, bs_re)): + slen = seq_lens[i] + if slen <= 0: + aligned_offset += aligned_lens[i] + continue + # Take first slen tokens from this sample's routed_experts + actual_len = min(slen, layers_topk_idx.shape[1]) + packed[aligned_offset : aligned_offset + actual_len] = ( + layers_topk_idx[i, :actual_len] ) - bs = min(bs_re, bs_mask) + aligned_offset += aligned_lens[i] - logger.debug( - "[R3] set_router_replay_data: input layers_topk_idx=%s, " - "attention_mask=%s, bs=%d (re_bs=%d, mask_bs=%d), max_seq_len=%d.", - layers_topk_idx.shape, - attention_mask.shape, - bs, - bs_re, - bs_mask, - max_seq_len, - ) - - # Step 1: Remove left-padding -> flat (total_real_tokens, num_layers, topk) - seq_lens = attention_mask.sum(dim=1).long() # (bs_mask,) - pieces = [] - for i in range(bs): - slen = int(seq_lens[i].item()) - mask = attention_mask[i].bool() - pieces.append(layers_topk_idx[i, mask][:slen]) - flat_tokens = torch.cat(pieces, dim=0) # (total_real_tokens, num_layers, topk) + # For any extra sequences in cu_seqlens beyond bs_re (batch padding), + # the packed tensor already has zeros at those positions. + for i in range(bs_re, n_seqs_in_cu): + aligned_offset += aligned_lens[i] logger.debug( - "[R3] set_router_replay_data: after left-padding removal: " - "flat_tokens=%s (total_real_tokens=%d).", - flat_tokens.shape, - flat_tokens.shape[0], + "[R3] set_router_replay_data: packed %d seqs into %d tokens " + "(TP-aligned with seq_align_to=%d, seq_lens=%s, aligned_lens=%s).", + min(n_seqs_in_cu, bs_re), + total_aligned, + seq_align_to, + seq_lens[:8], + aligned_lens[:8], ) - # Step 2: Scatter to SP ranks (Problem 6 fix: guard for non-SP case) - # When tp_size == 1 (no sequence parallelism), scatter_to_sequence_parallel_region - # should be an identity op in Megatron-Core. However, we guard against potential - # issues when the TP process group is trivial. 
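+        # Shape sketch (illustrative numbers): with total_aligned = 16 and
+        # tp_size = 2, the scatter below splits dim 0 so each SP rank holds
+        # an (8, num_layers, topk) shard of the packed tensor.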
- flat_tokens = flat_tokens.to(device) + # Step 2: Scatter to SP ranks + packed = packed.to(device) tp_size = mpu.get_tensor_model_parallel_world_size() if tp_size > 1: from megatron.core.tensor_parallel import scatter_to_sequence_parallel_region - local_tokens = scatter_to_sequence_parallel_region(flat_tokens) + local_tokens = scatter_to_sequence_parallel_region(packed) logger.debug( "[R3] set_router_replay_data: SP scatter tp_size=%d, " - "flat_tokens %s -> local_tokens %s.", + "packed %s -> local_tokens %s.", tp_size, - flat_tokens.shape, + packed.shape, local_tokens.shape, ) else: - # tp_size == 1: no SP splitting needed, use flat_tokens directly - local_tokens = flat_tokens + local_tokens = packed logger.debug( "[R3] set_router_replay_data: tp_size=1, skipping SP scatter. " "local_tokens=%s.", @@ -386,27 +422,34 @@ def set_router_replay_data( def setup_per_microbatch_replay_forward( routed_experts: torch.Tensor, - attention_mask: torch.Tensor, + cu_seqlens: torch.Tensor, tf_config, vp_rank: Optional[int] = None, + seq_align_to: Optional[int] = None, ) -> None: """Set up RouterReplay for a single micro-batch's forward pass. Args: routed_experts: ``(batch, padded_seq, num_moe_layers, topk)`` - attention_mask: ``(batch, padded_seq)`` + Left-aligned routing indices (real tokens first). + cu_seqlens: ``(batch+1,)`` or ``(batch+1+1,)`` cumulative sequence + lengths from the padded micro-batch. tf_config: Megatron TransformerConfig. vp_rank: Virtual pipeline stage rank override. + seq_align_to: Per-sequence TP alignment factor. """ logger.debug( "[R3] setup_per_microbatch_replay_forward: " - "routed_experts=%s (dtype=%s), attention_mask=%s.", + "routed_experts=%s (dtype=%s), cu_seqlens=%s.", routed_experts.shape, routed_experts.dtype, - attention_mask.shape, + cu_seqlens.shape, ) routed_experts = routed_experts.to(torch.int32) - set_router_replay_data(routed_experts, attention_mask, tf_config, vp_rank) + set_router_replay_data( + routed_experts, cu_seqlens, tf_config, vp_rank, + seq_align_to=seq_align_to, + ) logger.debug("[R3] Replay data distributed to router instances for micro-batch.") From 539c289e473e536322751d559eccd9a0302716a8 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 21 Apr 2026 20:08:09 +0800 Subject: [PATCH 034/112] fix(engine): padding --- areal/engine/megatron_engine_r3_patch.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py index 0cd4c12e30..c0d7104855 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -88,7 +88,7 @@ def patch_megatron_engine_for_r3( # =================================================================== -# routed_experts alignment (left-padded rollout → left-aligned training) +# routed_experts alignment (right-padded rollout → left-aligned training) # =================================================================== @@ -97,17 +97,18 @@ def _align_routed_experts_to_mask( cu_seqlens: torch.Tensor, max_seqlen: int, ) -> torch.Tensor: - """Align ``routed_experts`` from left-padded rollout format to left-aligned + """Align ``routed_experts`` from right-padded rollout format to left-aligned training format, matching the token layout implied by ``cu_seqlens``. **Rollout format**: ``routed_experts`` is ``(bs, batch_max_seqlen, L, K)`` - with LEFT padding (real tokens at the RIGHT end of each row). 
+ with RIGHT padding (real tokens at the BEGINNING of each row, after + batch-level concatenation; zeros are appended at the end). **Training format**: After ``pack_tensor_dict``, tokens are LEFT-aligned (real tokens first). The ``cu_seqlens`` tells us each sample's actual length. - This function extracts the rightmost ``actual_len`` tokens from each + This function extracts the first ``actual_len`` tokens from each sample in ``routed_experts`` and produces a ``(bs_aligned, max_seqlen, L, K)`` tensor with real tokens at the LEFT (matching training convention). @@ -139,10 +140,9 @@ def _align_routed_experts_to_mask( actual_len = seq_lens[i] if actual_len <= 0: continue - # Source: rightmost actual_len tokens from left-padded routed_experts - src_start = re_seqlen - actual_len + # Source: first actual_len tokens from right-padded routed_experts n = min(actual_len, re_seqlen, max_seqlen) - aligned[i, :n] = routed_experts[i, src_start : src_start + n] + aligned[i, :n] = routed_experts[i, :n] logger.debug( "[R3] _align_routed_experts_to_mask: re_shape=%s -> aligned_shape=%s, " From 52dac7f29a64164020aa66efc1a81988e07cb03d Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 22 Apr 2026 12:18:46 +0800 Subject: [PATCH 035/112] fix(router_replay_utils): fix padding --- areal/engine/router_replay_utils.py | 10 +++++----- areal/utils/network.py | 30 ++++++++++++++++------------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py index fd99d903ae..e53e4759e0 100644 --- a/areal/engine/router_replay_utils.py +++ b/areal/engine/router_replay_utils.py @@ -532,12 +532,11 @@ def preprocess_routed_experts_batch( reshaped = routed_experts_np.reshape(num_sgl_tokens, num_moe_layers, topk) tensor = torch.from_numpy(reshaped.astype(np.int32)) - # Build (1, seq_len, num_moe_layers, topk) with left padding + # Build (1, seq_len, num_moe_layers, topk) with RIGHT padding. 
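+    # Illustrative: with real_tokens = 10 and seq_len = 16, rows 0-9 receive
+    # routing indices below and rows 10-15 stay zero.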
real_tokens = int(attention_mask.sum().item()) padded = torch.zeros(1, seq_len, num_moe_layers, topk, dtype=torch.int32) - left_pad = seq_len - real_tokens n = min(num_sgl_tokens, real_tokens) - padded[0, left_pad : left_pad + n] = tensor[:n] + padded[0, :n] = tensor[:n] if compress_dtype: max_val = padded.max().item() @@ -546,17 +545,18 @@ def preprocess_routed_experts_batch( elif max_val < 32768: padded = padded.to(torch.int16) + right_pad = seq_len - real_tokens logger.debug( "[R3] preprocess_routed_experts_batch: shape=%s dtype=%s " "(num_moe_layers=%d, topk=%d, sgl_tokens=%d, real_tokens=%d, " - "left_pad=%d).", + "right_pad=%d).", padded.shape, padded.dtype, num_moe_layers, topk, num_sgl_tokens, real_tokens, - left_pad, + right_pad, ) return padded diff --git a/areal/utils/network.py b/areal/utils/network.py index cb7b1ff791..260369659a 100644 --- a/areal/utils/network.py +++ b/areal/utils/network.py @@ -21,6 +21,22 @@ def gethostip(probe_host: str = "8.8.8.8", probe_port: int = 80) -> str: Raises: RuntimeError: If no suitable address can be determined """ + try: + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock: + sock.connect((probe_host, probe_port)) + return sock.getsockname()[0] + except OSError: + pass + + try: + with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as sock: + sock.connect(("2001:4860:4860::8888", probe_port)) + ip6 = sock.getsockname()[0] + if ip6 and ip6 != "::1": + return ip6 + except OSError: + pass + try: hostname = socket.gethostname() infos = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_DGRAM) @@ -36,19 +52,7 @@ def gethostip(probe_host: str = "8.8.8.8", probe_port: int = 80) -> str: except socket.gaierror: pass - try: - with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock: - sock.connect((probe_host, probe_port)) - return sock.getsockname()[0] - except OSError as e: - try: - with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as sock: - sock.connect(("2001:4860:4860::8888", probe_port)) - ip6 = sock.getsockname()[0] - if ip6 and ip6 != "::1": - return ip6 - except OSError: - raise RuntimeError("Could not determine host IP") from e + raise RuntimeError("Could not determine host IP") def get_loopback_ip() -> str: From 39792e2fe71520074f6de4688fb4236a18b16932 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 22 Apr 2026 13:41:16 +0800 Subject: [PATCH 036/112] feat: add log --- areal/infra/controller/rollout_controller.py | 38 ++++++++-- areal/infra/controller/train_controller.py | 59 ++++++++++++--- areal/infra/rpc/guard/app.py | 32 +++++++- areal/infra/rpc/guard/engine_blueprint.py | 12 +++ areal/infra/rpc/rpc_server.py | 9 +++ areal/infra/scheduler/local.py | 80 +++++++++++++++++--- areal/trainer/rl_trainer.py | 9 +++ 7 files changed, 208 insertions(+), 31 deletions(-) diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index 8fde8e5fdf..e066ba185e 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -238,34 +238,56 @@ async def _async_initialize( **kwargs, ): # Create workers via scheduler - logger.info("Creating workers via scheduler...") + logger.info( + f"[DIAG] RolloutController._async_initialize: creating workers " + f"for role '{job.role}' via scheduler..." 
+ ) worker_ids = self.scheduler.create_workers(job=job) - logger.info(f"Workers created: {worker_ids}") + logger.info( + f"[DIAG] RolloutController._async_initialize: workers created: {worker_ids}" + ) # Wait for workers to be ready - logger.info("Waiting for workers to be ready...") + logger.info( + "[DIAG] RolloutController._async_initialize: waiting for workers to be ready..." + ) self.workers = self.scheduler.get_workers(role=job.role) - logger.info(f"Workers ready: {[w.id for w in self.workers]}") + logger.info( + f"[DIAG] RolloutController._async_initialize: workers ready: " + f"{[w.id for w in self.workers]}, ips={[w.ip for w in self.workers]}" + ) # Get engine class path for dynamic import on workers engine_class = self.inf_engine + engine_path = f"{engine_class.__module__}.{engine_class.__name__}" # Create and initialize engines on workers - logger.info("Creating engines...") + logger.info( + f"[DIAG] RolloutController._async_initialize: creating engines " + f"(class={engine_path}) on {len(self.workers)} worker(s)..." + ) tasks = [ self.scheduler.create_engine( worker_id=worker.id, - engine=f"{engine_class.__module__}.{engine_class.__name__}", + engine=engine_path, engine_name=self._engine_name(rank), config=self.config, ) for rank, worker in enumerate(self.workers) ] await asyncio.gather(*tasks) - logger.info("Engine created on all workers!") + logger.info( + "[DIAG] RolloutController._async_initialize: engines created on all workers!" + ) - logger.info("Calling engine initialization...") + logger.info( + "[DIAG] RolloutController._async_initialize: calling engine initialization..." + ) if server_infos is not None: + logger.info( + f"[DIAG] RolloutController._async_initialize: connecting to " + f"{len(server_infos)} existing server(s) for evaluation" + ) # Connecting to existing local servers for evaluation self.server_infos = server_infos assert len(self.server_infos) == len(self.workers), ( diff --git a/areal/infra/controller/train_controller.py b/areal/infra/controller/train_controller.py index 59bef0a85b..3a9a25796b 100644 --- a/areal/infra/controller/train_controller.py +++ b/areal/infra/controller/train_controller.py @@ -282,14 +282,23 @@ def initialize( ) # Create workers via scheduler - logger.info("Creating workers via scheduler...") + logger.info( + f"[DIAG] TrainController.initialize: creating {world_size} worker(s) " + f"for role '{self._worker_role}' via scheduler..." + ) worker_ids = self.scheduler.create_workers(job=job) - logger.info(f"Workers created: {worker_ids}") + logger.info(f"[DIAG] TrainController.initialize: workers created: {worker_ids}") # Wait for workers to be ready - logger.info("Waiting for workers to be ready...") + logger.info( + "[DIAG] TrainController.initialize: waiting for workers to be ready..." + ) self.workers = self.scheduler.get_workers(role=job.role) - logger.info(f"Workers ready: {[w.id for w in self.workers]}") + logger.info( + f"[DIAG] TrainController.initialize: workers ready: " + f"{[w.id for w in self.workers]}, " + f"ips={[w.ip for w in self.workers]}" + ) # Determine distributed training master address and port from rank 0 worker # These are used for PyTorch distributed initialization across workers @@ -310,9 +319,18 @@ def initialize( engine_class = self.train_engine # Create and initialize engines on workers + engine_path = f"{engine_class.__module__}.{engine_class.__name__}" + logger.info( + f"[DIAG] TrainController.initialize: creating engines " + f"(class={engine_path}) on {len(self.workers)} worker(s)..." 
+ ) run_async_task( self._async_create_engines, - f"{engine_class.__module__}.{engine_class.__name__}", + engine_path, + ) + logger.info( + "[DIAG] TrainController.initialize: engines created, " + "now initializing (create_process_group + initialize)..." ) run_async_task(self._async_initialize_engines, ft_spec, **kwargs) @@ -329,7 +347,10 @@ def _engine_name(self, rank: int) -> str: async def _async_create_engines(self, engine: str): """Create engine instances on all workers. Sets distributed env vars before creation.""" - logger.info("Creating engines on workers...") + logger.info( + f"[DIAG] _async_create_engines: creating engine '{engine}' " + f"on {len(self.workers)} worker(s)..." + ) async def _setup_worker(worker: Worker, rank: int): env = { @@ -337,25 +358,39 @@ async def _setup_worker(worker: Worker, rank: int): "WORLD_SIZE": str(len(self.workers)), "MASTER_ADDR": str(self._master_addr), "MASTER_PORT": str(self._master_port), - "LOCAL_RANK": "0", # NOTE: local rank is always 0 while each process use only one GPU + "LOCAL_RANK": "0", } + logger.debug( + f"[DIAG] _async_create_engines: setting env for worker " + f"'{worker.id}' (rank={rank}): {env}" + ) await self.scheduler.set_worker_env(worker.id, env) + logger.info( + f"[DIAG] _async_create_engines: creating engine on worker " + f"'{worker.id}' (rank={rank})..." + ) await self.scheduler.create_engine( worker_id=worker.id, engine=engine, engine_name=self._engine_name(rank), config=self.config, ) + logger.info( + f"[DIAG] _async_create_engines: engine created on worker " + f"'{worker.id}' (rank={rank})" + ) tasks = [ _setup_worker(worker, rank) for rank, worker in enumerate(self.workers) ] await asyncio.gather(*tasks) - logger.info("Engines created on all workers!") + logger.info("[DIAG] _async_create_engines: engines created on all workers!") async def _async_initialize_engines(self, ft_spec: FinetuneSpec, **kwargs): """Initialize engines: create process groups, then load models and setup optimizers.""" - logger.info("Calling engine initialization...") + logger.info( + "[DIAG] _async_initialize_engines: Phase 1 - creating process groups..." + ) # Phase 1: Create process groups for distributed training tasks = [ self.scheduler.async_call_engine( @@ -367,7 +402,13 @@ async def _async_initialize_engines(self, ft_spec: FinetuneSpec, **kwargs): for rank, worker in enumerate(self.workers) ] await asyncio.gather(*tasks) + logger.info( + "[DIAG] _async_initialize_engines: Phase 1 complete - process groups created" + ) # Phase 2: Initialize engines (load models, setup optimizers, etc.) + logger.info( + "[DIAG] _async_initialize_engines: Phase 2 - initializing engines (loading models)..." 
+ ) tasks = [ self.scheduler.async_call_engine( worker_id=worker.id, diff --git a/areal/infra/rpc/guard/app.py b/areal/infra/rpc/guard/app.py index fc2d7369f9..5f8dd44b0c 100644 --- a/areal/infra/rpc/guard/app.py +++ b/areal/infra/rpc/guard/app.py @@ -430,10 +430,16 @@ def configure(): return jsonify({"error": "Invalid JSON in request body"}), 400 if not s._configure_hooks: - # No hooks registered — no-op (guard-only mode) - logger.debug("Received /configure request (no-op)") + logger.info( + f"[DIAG] /configure: received request (no-op, " + f"no hooks registered) for worker {s.role}/{s.worker_index}" + ) return jsonify({"status": "ok"}) + logger.info( + f"[DIAG] /configure: received request with " + f"{len(s._configure_hooks)} hook(s) for worker {s.role}/{s.worker_index}" + ) # Dispatch to all registered configure hooks result: dict[str, Any] = {} for hook in s._configure_hooks: @@ -441,6 +447,9 @@ def configure(): result.update(hook_result) result.setdefault("status", "success") + logger.info( + f"[DIAG] /configure: completed for worker {s.role}/{s.worker_index}" + ) return jsonify(result) except ValueError as e: @@ -511,14 +520,26 @@ def configure_state_from_args(state: GuardState, args: argparse.Namespace) -> st bind_host = args.host if bind_host == "0.0.0.0": host_ip = gethostip() + logger.info( + f"[DIAG] configure_state_from_args: gethostip() returned '{host_ip}'" + ) if ":" in host_ip: bind_host = "::" state.server_host = host_ip elif bind_host == "::": state.server_host = gethostip() + logger.info( + f"[DIAG] configure_state_from_args: gethostip() returned '{state.server_host}'" + ) else: state.server_host = bind_host + logger.info( + f"[DIAG] configure_state_from_args: bind_host={bind_host}, " + f"server_host={state.server_host}, role={args.role}, " + f"worker_index={args.worker_index}" + ) + state.experiment_name = args.experiment_name state.trial_name = args.trial_name state.role = args.role @@ -570,6 +591,10 @@ def run_server( # Register with name_resolve if state.name_resolve_type is not None: + logger.info( + f"[DIAG] Registering with name_resolve: type={state.name_resolve_type}, " + f"nfs_root={state.nfs_record_root}, etcd3={state.etcd3_addr}" + ) name_resolve.reconfigure( NameResolveConfig( type=state.name_resolve_type, @@ -577,6 +602,7 @@ def run_server( etcd3_addr=state.etcd3_addr or "localhost:2379", ) ) + logger.info("[DIAG] name_resolve reconfigured successfully") worker_id = f"{state.role}/{state.worker_index}" key = names.worker_discovery( @@ -585,7 +611,9 @@ def run_server( state.role, state.worker_index, ) + logger.info(f"[DIAG] Adding name_resolve entry: key={key}, addr={state.node_addr}") name_resolve.add(key, state.node_addr, replace=True) + logger.info(f"[DIAG] name_resolve.add completed for {worker_id}") logger.info(f"Starting Guard on {state.node_addr} for worker {worker_id}") diff --git a/areal/infra/rpc/guard/engine_blueprint.py b/areal/infra/rpc/guard/engine_blueprint.py index 4dbd1c765d..8dcdde76ca 100644 --- a/areal/infra/rpc/guard/engine_blueprint.py +++ b/areal/infra/rpc/guard/engine_blueprint.py @@ -286,6 +286,10 @@ def create_engine(): engine = data.get("engine") engine_name = data.get("engine_name") + logger.info( + f"[DIAG] /create_engine: received request for engine='{engine}', " + f"engine_name='{engine_name}'" + ) # Deserialize init_args and init_kwargs (may contain tensors/dataclasses) init_args = deserialize_value(data.get("init_args", [])) init_kwargs = deserialize_value(data.get("init_kwargs", {})) @@ -357,6 +361,10 @@ def 
create_engine_in_engine_thread(): "create_engine", create_engine_in_engine_thread ) _engines[engine_name] = engine_obj + logger.info( + f"[DIAG] /create_engine: engine '{engine_name}' " + f"created successfully (class: {engine})" + ) return jsonify( { "status": "success", @@ -403,6 +411,10 @@ def call_engine_method(): method_name = data.get("method") engine_name = data.get("engine_name") + logger.info( + f"[DIAG] /call: received request for method='{method_name}', " + f"engine_name='{engine_name}'" + ) raw_args = data.get("args", []) raw_kwargs = data.get("kwargs", {}) diff --git a/areal/infra/rpc/rpc_server.py b/areal/infra/rpc/rpc_server.py index edc230a7ae..43f510b13c 100644 --- a/areal/infra/rpc/rpc_server.py +++ b/areal/infra/rpc/rpc_server.py @@ -48,14 +48,23 @@ def main(): state = GuardState() bind_host = configure_state_from_args(state, args) + logger.info( + f"[DIAG] RPC Server: bind_host={bind_host}, port={args.port}, " + f"role={state.role}, worker_index={state.worker_index}" + ) + app = create_app(state) app.register_blueprint(data_bp) + logger.info("[DIAG] RPC Server: data blueprint registered") app.register_blueprint(engine_bp) + logger.info("[DIAG] RPC Server: engine blueprint registered") register_engine_hooks(state) + logger.info("[DIAG] RPC Server: engine hooks registered") state.register_cleanup_hook(lambda: perf_tracer.save(force=True)) logger.info(f"Werkzeug log level: {args.werkzeug_log_level}") + logger.info("[DIAG] RPC Server: calling run_server...") run_server(state, app, bind_host, args.port) diff --git a/areal/infra/scheduler/local.py b/areal/infra/scheduler/local.py index 8c1b9a7a35..79504c1b46 100644 --- a/areal/infra/scheduler/local.py +++ b/areal/infra/scheduler/local.py @@ -549,8 +549,15 @@ async def _create_forked_workers_async( # Configure forked workers if exp_config is available if self.exp_config is not None: + logger.info( + f"[DIAG] create_workers: configuring {len(workers)} worker(s) " + f"for role '{role}' with exp_config" + ) for worker_rank, worker_info in enumerate(workers): self._configure_worker(worker_info, worker_rank) + logger.info( + f"[DIAG] create_workers: all workers for role '{role}' configured" + ) return worker_ids @@ -791,9 +798,13 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: stderr, ) + worker_ip = gethostip() + logger.info( + f"[DIAG] create_workers: gethostip() returned '{worker_ip}' for worker '{worker_id}'" + ) worker = Worker( id=worker_id, - ip=gethostip(), + ip=worker_ip, worker_ports=[str(p) for p in ports], engine_ports=[], ) @@ -828,8 +839,15 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: raise WorkerCreationError(role, "Unexpected error", str(e)) from e if self.exp_config is not None: + logger.info( + f"[DIAG] create_workers: configuring {len(workers)} worker(s) " + f"for role '{role}' with exp_config" + ) for worker_rank, worker_info in enumerate(workers): self._configure_worker(worker_info, worker_rank) + logger.info( + f"[DIAG] create_workers: all workers for role '{role}' configured" + ) return worker_ids @@ -876,13 +894,25 @@ def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]: workers = self._workers[role] timeout = timeout if timeout is not None else self.startup_timeout + logger.info( + f"[DIAG] get_workers: waiting for {len(workers)} worker(s) " + f"of role '{role}' to be ready (timeout={timeout}s)" + ) self._check_worker_health(role) start_time = time.time() ready_workers = set() while len(ready_workers) < len(workers): - if time.time() - 
start_time > timeout: + elapsed = time.time() - start_time + if elapsed > timeout: + not_ready = [ + w.worker.id for w in workers if w.worker.id not in ready_workers + ] + logger.error( + f"[DIAG] get_workers: TIMEOUT after {elapsed:.1f}s. " + f"Ready: {ready_workers}, Not ready: {not_ready}" + ) raise WorkerTimeoutError( role, timeout, @@ -906,7 +936,10 @@ def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]: if self._is_worker_ready(worker_info): ready_workers.add(worker_info.worker.id) - logger.debug(f"Worker {worker_info.worker.id} is ready") + logger.info( + f"[DIAG] get_workers: worker '{worker_info.worker.id}' " + f"is ready ({len(ready_workers)}/{len(workers)})" + ) if len(ready_workers) < len(workers): time.sleep(self.health_check_interval) @@ -920,17 +953,33 @@ def _is_worker_ready(self, worker_info: WorkerInfo) -> bool: try: response = requests.get(url, timeout=2.0) - return response.status_code == 200 - except Exception: + ready = response.status_code == 200 + if ready: + logger.debug( + f"[DIAG] _is_worker_ready: {url} -> {response.status_code} (ready)" + ) + return ready + except Exception as e: + logger.debug(f"[DIAG] _is_worker_ready: {url} -> error: {e}") return False def _configure_worker(self, worker_info: WorkerInfo, worker_rank: int): + worker_id = worker_info.worker.id + logger.info( + f"[DIAG] _configure_worker: waiting for worker '{worker_id}' " + f"(ip={worker_info.worker.ip}, ports={worker_info.worker.worker_ports}) to be ready" + ) + wait_start = time.time() while not self._is_worker_ready(worker_info): time.sleep(0.1) + logger.info( + f"[DIAG] _configure_worker: worker '{worker_id}' ready after " + f"{time.time() - wait_start:.1f}s, sending configure request" + ) - worker_id = worker_info.worker.id port = int(worker_info.worker.worker_ports[0]) url = f"http://{format_hostport(worker_info.worker.ip, port)}/configure" + logger.info(f"[DIAG] _configure_worker: POST {url} for worker '{worker_id}'") try: response = requests.post( @@ -1087,6 +1136,10 @@ async def set_worker_env(self, worker_id: str, env: dict[str, str]) -> None: payload = {"env": env} port = int(worker_info.worker.worker_ports[0]) url = f"http://{format_hostport(worker_info.worker.ip, port)}/set_env" + logger.info( + f"[DIAG] set_worker_env: POST {url} for worker '{worker_id}', " + f"env keys={list(env.keys())}" + ) try: timeout = aiohttp.ClientTimeout(total=30.0) @@ -1180,8 +1233,9 @@ async def create_engine( url = f"http://{format_hostport(worker_info.worker.ip, port)}/create_engine" try: - logger.debug( - f"Creating engine '{engine_name}' (class: {engine}) on worker '{worker_id}'" + logger.info( + f"[DIAG] create_engine: POST {url} engine='{engine_name}' " + f"(class: {engine}) on worker '{worker_id}'" ) timeout = aiohttp.ClientTimeout(total=300.0) @@ -1197,8 +1251,9 @@ async def create_engine( ) as response: if response.status == 200: result = await response.json() - logger.debug( - f"Engine '{engine_name}' created successfully on worker '{worker_id}'" + logger.info( + f"[DIAG] create_engine: engine '{engine_name}' " + f"created successfully on worker '{worker_id}'" ) return result.get("result") elif response.status == 400: @@ -1455,8 +1510,9 @@ async def async_call_engine( ) try: - logger.debug( - f"Async calling method '{method}' on worker '{worker_id}' (attempt {attempt})" + logger.info( + f"[DIAG] async_call_engine: POST {url} method='{method}' " + f"engine='{engine_name}' on worker '{worker_id}' (attempt {attempt})" ) timeo = aiohttp.ClientTimeout( diff --git 
a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index 5400652202..c529edd137 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -202,11 +202,17 @@ def __init__( engine_init_kwargs = {"addr": None, "ft_spec": ft_spec} + logger.info("[DIAG] PPOTrainer: initializing actor engine...") self.actor.initialize(**engine_init_kwargs, role="actor") + logger.info("[DIAG] PPOTrainer: actor engine initialized") if self.critic is not None: + logger.info("[DIAG] PPOTrainer: initializing critic engine...") self.critic.initialize(**engine_init_kwargs, role="critic") + logger.info("[DIAG] PPOTrainer: critic engine initialized") if self.ref is not None: + logger.info("[DIAG] PPOTrainer: initializing ref engine...") self.ref.initialize(**engine_init_kwargs, role="ref") + logger.info("[DIAG] PPOTrainer: ref engine initialized") self.teacher = None if config.teacher is not None: @@ -220,9 +226,11 @@ def __init__( initial_lora_path = self._save_initial_lora_weights() # Initialize inference with LoRA path + logger.info("[DIAG] PPOTrainer: initializing rollout engine...") self.rollout = self._init_rollout( config.rollout, is_eval=False, lora_path=initial_lora_path ) + logger.info("[DIAG] PPOTrainer: rollout engine initialized") # Online mode detection: skip eval rollout for efficiency. openai_cfg = config.rollout.openai self._online_mode = train_dataset is None or ( @@ -464,6 +472,7 @@ def train( log_moe_routing_metrics, log_r3_data_stats, ) + log_moe_routing_metrics(traj) if getattr(self.config.rollout, "return_routed_experts", False): log_r3_data_stats(traj) From aaf2e5b1ad25390aa48f0085b346f7fa1294d228 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 22 Apr 2026 14:57:13 +0800 Subject: [PATCH 037/112] fix(scheduler): fix local --- areal/infra/scheduler/local.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/areal/infra/scheduler/local.py b/areal/infra/scheduler/local.py index 79504c1b46..2079cb80cd 100644 --- a/areal/infra/scheduler/local.py +++ b/areal/infra/scheduler/local.py @@ -955,12 +955,16 @@ def _is_worker_ready(self, worker_info: WorkerInfo) -> bool: response = requests.get(url, timeout=2.0) ready = response.status_code == 200 if ready: - logger.debug( + logger.info( f"[DIAG] _is_worker_ready: {url} -> {response.status_code} (ready)" ) + else: + logger.warning( + f"[DIAG] _is_worker_ready: {url} -> {response.status_code} (not ready)" + ) return ready except Exception as e: - logger.debug(f"[DIAG] _is_worker_ready: {url} -> error: {e}") + logger.warning(f"[DIAG] _is_worker_ready: {url} -> error: {e}") return False def _configure_worker(self, worker_info: WorkerInfo, worker_rank: int): @@ -970,8 +974,19 @@ def _configure_worker(self, worker_info: WorkerInfo, worker_rank: int): f"(ip={worker_info.worker.ip}, ports={worker_info.worker.worker_ports}) to be ready" ) wait_start = time.time() + last_log_time = wait_start while not self._is_worker_ready(worker_info): time.sleep(0.1) + now = time.time() + if now - last_log_time >= 5.0: + elapsed = now - wait_start + logger.warning( + f"[DIAG] _configure_worker: still waiting for worker " + f"'{worker_id}' after {elapsed:.0f}s " + f"(ip={worker_info.worker.ip}, " + f"ports={worker_info.worker.worker_ports})" + ) + last_log_time = now logger.info( f"[DIAG] _configure_worker: worker '{worker_id}' ready after " f"{time.time() - wait_start:.1f}s, sending configure request" From d517f2797f7a0b87c46caaa9eadb9d84425a9336 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 23 Apr 2026 
12:18:21 +0800 Subject: [PATCH 038/112] refactor(router): fix router metric --- areal/engine/megatron_engine_r3_patch.py | 2 +- areal/engine/router_replay_patch.py | 110 ++++++++++++----------- 2 files changed, 59 insertions(+), 53 deletions(-) diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py index c0d7104855..0c8e61f301 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -397,7 +397,7 @@ def _r3_forward_backward_batch( # ------------------------------------------------------------------ # 2b. Set initial replay action to REPLAY_FORWARD. # ------------------------------------------------------------------ - RouterReplay.reset_agreement_accumulator() + RouterReplay.reset_agreement_stats() RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) logger.debug( "[R3] Set initial REPLAY_FORWARD action on %d router instances.", diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py index d5cfc5ca08..b102f3cb1d 100644 --- a/areal/engine/router_replay_patch.py +++ b/areal/engine/router_replay_patch.py @@ -119,8 +119,9 @@ class RouterReplay: # Class-level agreement accumulator. # Collects per-call agreement rates during REPLAY_FORWARD to provide # an accurate R3 effectiveness metric every training step. - _agreement_samples: list = [] - _replay_call_count: int = 0 + _agreement_matches: int = 0 + _agreement_total: int = 0 + _agreement_per_call: list = [] # ------------------------------------------------------------------ # Class-level (static) helpers @@ -167,32 +168,36 @@ def clear_global_router_replay_action() -> None: for r in RouterReplay.router_instances: r.clear_router_replay_action() - @staticmethod - def reset_agreement_accumulator() -> None: - """Reset the agreement accumulator for a new training step.""" - RouterReplay._agreement_samples = [] - RouterReplay._replay_call_count = 0 + @classmethod + def reset_agreement_stats(cls) -> None: + """Reset the agreement rate accumulator before a training step.""" + cls._agreement_matches = 0 + cls._agreement_total = 0 + cls._agreement_per_call = [] - @staticmethod - def harvest_agreement_stats() -> dict: - """Harvest accumulated agreement samples and return summary stats. + reset_agreement_accumulator = reset_agreement_stats + + @classmethod + def get_agreement_rate(cls) -> float: + if cls._agreement_total == 0: + return -1.0 + return cls._agreement_matches / cls._agreement_total + + @classmethod + def harvest_agreement_stats(cls) -> dict: + """Harvest accumulated agreement statistics and reset. - Returns: - dict with keys: avg, min, max, n_samples, n_calls. - If no samples, returns dict with n_samples=0. + Returns a dict with keys: avg, min, max, n_samples, n_calls. 
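+
+        Example (illustrative usage)::
+
+            RouterReplay.reset_agreement_stats()
+            # ... run a replayed forward/backward pass ...
+            stats = RouterReplay.harvest_agreement_stats()
+            print(stats["avg"], stats["n_calls"])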
""" - samples = RouterReplay._agreement_samples - n_calls = RouterReplay._replay_call_count - if not samples: - return {"n_samples": 0, "n_calls": n_calls} - avg = sum(samples) / len(samples) - return { - "avg": avg, - "min": min(samples), - "max": max(samples), - "n_samples": len(samples), - "n_calls": n_calls, + result = { + "avg": cls.get_agreement_rate(), + "min": min(cls._agreement_per_call) if cls._agreement_per_call else -1.0, + "max": max(cls._agreement_per_call) if cls._agreement_per_call else -1.0, + "n_samples": cls._agreement_total, + "n_calls": len(cls._agreement_per_call), } + cls.reset_agreement_stats() + return result def __init__(self) -> None: self.target_topk_idx: torch.Tensor | None = None @@ -282,38 +287,39 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): ) return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) - # Use the provided indices for replay (following verl's approach). + # Use the provided indices for replay top_indices = router_replay.target_topk_idx top_indices = top_indices.to(scores.device) probs = scores.gather(1, top_indices) - # Compute agreement rate on every REPLAY_FORWARD call. - # This is the TRUE R3 effectiveness metric: comparing what the - # training-time router would naturally choose vs what we force - # it to replay from inference. Accumulated across all layers - # and microbatches, then harvested once per step by the engine. - RouterReplay._replay_call_count += 1 - _call_n = RouterReplay._replay_call_count - with torch.no_grad(): - _, natural_indices = _compute_topk( - scores, topk, num_groups=num_groups, group_topk=group_topk - ) - replay_sorted = top_indices.sort(dim=-1).values - natural_sorted = natural_indices.sort(dim=-1).values - per_token_matches = (natural_sorted == replay_sorted).float().sum(dim=-1) - agreement_rate = (per_token_matches / topk).mean().item() - RouterReplay._agreement_samples.append(agreement_rate) - - if _call_n <= 3: - _nz = top_indices[top_indices > 0].flatten()[:3].tolist() - print( - f"[R3-VERIFY] Megatron REPLAY_FORWARD " - f"#{_call_n}: " - f"top_indices shape={top_indices.shape}, " - f"first3_nonzero={_nz}, " - f"agreement_rate={agreement_rate:.4f}", - flush=True, - ) + # --- Router Agreement Rate: compare replay vs natural routing --- + # Compute what the router would have chosen WITHOUT replay + # (no grad to avoid interfering with the backward graph). + # NOTE: Exclude padding tokens from agreement computation. + # Padding tokens have all-zero replay indices and would + # artificially drag down agreement rates since their + # natural routing is essentially random. 
+ try: + with torch.no_grad(): + _, natural_indices = _compute_topk( + scores, topk, num_groups=num_groups, group_topk=group_topk + ) + non_padding_mask = (top_indices != 0).any(dim=-1) + replay_sorted = top_indices.sort(dim=-1).values + natural_sorted = natural_indices.sort(dim=-1).values + matches = (replay_sorted == natural_sorted).all(dim=-1) + if non_padding_mask.any(): + masked_matches = matches[non_padding_mask] + n_matched = int(masked_matches.sum().item()) + n_total = int(masked_matches.numel()) + RouterReplay._agreement_matches += n_matched + RouterReplay._agreement_total += n_total + if n_total > 0: + RouterReplay._agreement_per_call.append( + n_matched / n_total + ) + except Exception: + logger.debug("[R3] Agreement rate computation failed.", exc_info=True) return probs, top_indices elif routing_action == RouterReplayAction.REPLAY_BACKWARD: From 56cb6374a77b592f52a5ccad6153b3552c441611 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sat, 25 Apr 2026 21:01:56 +0800 Subject: [PATCH 039/112] refactor(megatron_engine): later --- areal/engine/megatron_engine.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index cd24c368f5..06f71cd8b1 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -14,8 +14,6 @@ import mbridge import torch import torch.distributed as dist -from megatron.bridge import AutoBridge as MegatronBridgeAutoBridge -from megatron.bridge.peft.lora import LoRA as MegatronBridgeLoRA from megatron.core import parallel_state as mpu from megatron.core import tensor_parallel from megatron.core.distributed import DistributedDataParallel as DDP @@ -109,6 +107,8 @@ if TYPE_CHECKING: from areal.api import Scheduler from areal.api.cli_args import PPOActorConfig, PPOCriticConfig + from megatron.bridge import AutoBridge as MegatronBridgeAutoBridge + from megatron.bridge.peft.lora import LoRA as MegatronBridgeLoRA class _MegatronModelList(list): @@ -225,6 +225,8 @@ def create_process_group(self, parallel_strategy: ParallelStrategy | None = None self.process_group_initialized = True def _apply_megatron_bridge_lora(self) -> None: + from megatron.bridge.peft.lora import LoRA as MegatronBridgeLoRA + assert self.model is not None, "Model must be initialized before applying LoRA." assert self.bridge_cls == "megatron-bridge" @@ -452,6 +454,8 @@ def _build_hf_mcore_bridge(self): ) elif self.bridge_cls == "megatron-bridge": + from megatron.bridge import AutoBridge as MegatronBridgeAutoBridge + if self.enable_tree_training: raise NotImplementedError( "Tree training is not supported with bridge_type='megatron-bridge'." 
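Note on the import changes above: they follow the standard TYPE_CHECKING /
lazy-import idiom, so megatron.bridge is only imported at runtime when the
megatron-bridge code path is actually taken. A minimal self-contained sketch
of the idiom, with numpy standing in for the optional heavy dependency::

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Visible to type checkers only; no import cost at runtime.
        from numpy import ndarray


    def make_array() -> "ndarray":
        # Deferred import: numpy is loaded only when this path runs.
        import numpy as np

        return np.zeros(3)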
From fbd94aee785bc34702d235794500942b03bf1e79 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sat, 25 Apr 2026 21:10:16 +0800 Subject: [PATCH 040/112] fix(engine): patch --- areal/engine/megatron_engine.py | 6 +++++- areal/engine/megatron_utils/megatron_lora.py | 13 ++++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 06f71cd8b1..8f04534f1d 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -58,7 +58,10 @@ get_named_parameters, remove_padding, ) -from areal.engine.megatron_utils.megatron_lora import get_vllm_lora_target_modules +from areal.engine.megatron_utils.megatron_lora import ( + ensure_save_hf_adapter_patched, + get_vllm_lora_target_modules, +) from areal.engine.megatron_utils.packed_context_parallel import ( packed_context_parallel_forward, ) @@ -1560,6 +1563,7 @@ def _save_model_to_hf( "Saving critic model is not supported with megatron-bridge." ) if self.config.use_lora: + ensure_save_hf_adapter_patched() self.bridge.save_hf_adapter( self.model, path=path, diff --git a/areal/engine/megatron_utils/megatron_lora.py b/areal/engine/megatron_utils/megatron_lora.py index 23f5299e4b..c5f3a4ee47 100644 --- a/areal/engine/megatron_utils/megatron_lora.py +++ b/areal/engine/megatron_utils/megatron_lora.py @@ -289,8 +289,11 @@ def save_hf_adapter( AutoBridge.save_hf_adapter = save_hf_adapter -# Current: This monkey patch is needed as the current megatron-bridge 0.3.0 does not have a built-in method -# to save LoRA adapters in HuggingFace PEFT format, which is required for our use case. -# Future: This code is however present in main branch of megatron-bridge so this patch is temporary -# and can be removed later when we upgrade the megatron-bridge version. 
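+# Patching is now deferred and idempotent: callers invoke
+# ensure_save_hf_adapter_patched() right before saving, instead of the module
+# patching AutoBridge at import time.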
-_monkey_patch_save_hf_adapter()
+_monkey_patch_applied = False
+
+
+def ensure_save_hf_adapter_patched():
+    global _monkey_patch_applied
+    if not _monkey_patch_applied:
+        _monkey_patch_save_hf_adapter()
+        _monkey_patch_applied = True

From 9583f238a61927a83b0e39ca9969b12e9c95418b Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Sat, 25 Apr 2026 21:26:16 +0800
Subject: [PATCH 041/112] feat(math): add moonlight-16b-a3b-gsm8k-grpo

---
 ..._16b_a3b_gsm8k_grpo_megatron_h20_base.yaml | 186 ++++++++++++++++++
 1 file changed, 186 insertions(+)
 create mode 100644 examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml

diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml
new file mode 100644
index 0000000000..8e13cc3e8f
--- /dev/null
+++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml
@@ -0,0 +1,186 @@
experiment_name: moonlight-16b-a3b-gsm8k-grpo-h20-base
trial_name: trial0

seed: 1
enable_offload: false
total_train_epochs: 10
tokenizer_path: ${actor.path}

cluster:
  n_nodes: 1
  n_gpus_per_node: 6
  fileroot: /tmp/areal/moon_experiments
  name_resolve:
    type: nfs
    nfs_record_root: /tmp/areal/moon_name_resolve

scheduler:
  type: null

rollout:
  backend: "sglang:d1p1t2"
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  max_concurrent_rollouts: 64
  queue_size: null
  consumer_batch_size: ${train_dataset.batch_size}
  max_head_offpolicyness: 1
  enable_rollout_tracing: false
  scheduling_spec: ${actor.scheduling_spec}
  fileroot: ${cluster.fileroot}
  tokenizer_path: ${tokenizer_path}
  dump_to_file: true

gconfig:
  n_samples: 8
  min_new_tokens: 0
  max_new_tokens: 1024
  greedy: false
  temperature: 1.0

actor:
  backend: "megatron:(attn:d1p1t4|ffn:d1p1t1e4)" # ← rolled back from PP=2; TP=4/EP=4
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: /workspace/models/Moonlight-16B-A3B-Instruct
  init_from_scratch: false
  disable_dropout: true
  gradient_checkpointing: true
  dtype: bfloat16
  mb_spec:
    max_tokens_per_mb: 1280 # ← reduced from 2048 to 512
  optimizer:
    type: adam_bf16
    lr: 2e-6
    weight_decay: 0.003
    beta1: 0.9
    beta2: 0.999
    eps: 1e-8
    lr_scheduler_type: constant
    gradient_clipping: 1.0
    warmup_steps_proportion: 0.001
  eps_clip: 0.4
  temperature: ${gconfig.temperature}
  reward_scaling: 10.0
  reward_bias: -0.5
  kl_ctl: 0.0
  ppo_n_minibatches: 1 # ← increased from 1 to 4 (mini-batched gradient accumulation)
  recompute_logprob: true
  use_decoupled_loss: true
  behave_imp_weight_cap: 5.0
  reward_norm:
    mean_level: group
    std_level: group
    group_size: ${gconfig.n_samples}
  adv_norm:
    mean_level: batch
    std_level: batch
  weight_update_mode: disk
  max_new_tokens: ${gconfig.max_new_tokens}
  megatron:
    use_deterministic_algorithms: false
    use_precision_aware_optimizer: true
    recompute_granularity: full
    recompute_method: uniform
    recompute_num_layers: 9
    ddp:
      grad_reduce_in_fp32: false # ← keep per-layer recomputation
  scheduling_spec:
    - task_type: worker
      port_count: 2
      gpu: 1
      mem: 48
      cmd: python3 -m areal.infra.rpc.rpc_server
      env_vars:
        PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"

ref:
  backend: ${actor.backend}
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: ${actor.path}
  init_from_scratch: false
  disable_dropout: true
  dtype: ${actor.dtype}
  mb_spec:
    max_tokens_per_mb: 1280
  optimizer: null
  scheduling_strategy:
    type: colocation
    target: actor
  scheduling_spec: ${actor.scheduling_spec}

sglang:
model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: bfloat16 + max_running_requests: 64 + context_length: 2048 + mem_fraction_static: 0.8 + attention_backend: triton + disable_cuda_graph: true + +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: false + dtype: bfloat16 + max_model_len: 4096 + gpu_memory_utilization: 0.75 + +train_dataset: + batch_size: 64 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + max_length: 512 + +valid_dataset: + batch_size: 128 + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false \ No newline at end of file From 49c83aba267c7f53c9032092dcf2dccc2bdc55c1 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sat, 25 Apr 2026 22:29:03 +0800 Subject: [PATCH 042/112] fix(engine): fix mtp_num_layers --- areal/engine/megatron_engine.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 8f04534f1d..89a60e41b2 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -114,6 +114,33 @@ from megatron.bridge.peft.lora import LoRA as MegatronBridgeLoRA +def _patch_gpt_model_postprocess_for_inference(model_list: _MegatronModelList) -> None: + from megatron.core.models.gpt.gpt_model import GPTModel + + if getattr(GPTModel, "_areal_postprocess_patched", False): + return + + _original_postprocess = GPTModel._postprocess + + def _patched_postprocess(self, hidden_states, input_ids, position_ids, labels, **kwargs): + if labels is None and getattr(self.config, "mtp_num_layers", None) is not None: + original_mtp = self.config.mtp_num_layers + self.config.mtp_num_layers = None + try: + result = _original_postprocess( + self, hidden_states, input_ids, position_ids, labels=labels, **kwargs + ) + finally: + self.config.mtp_num_layers = original_mtp + return result + return _original_postprocess( + self, hidden_states, input_ids, position_ids, labels=labels, **kwargs + ) + + GPTModel._postprocess = _patched_postprocess + GPTModel._areal_postprocess_patched = True + + class _MegatronModelList(list): """List wrapper that exposes module-like helpers for Megatron model chunks.""" @@ -395,6 +422,8 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): for model in self.model: disable_dropout_in_model(model) + _patch_gpt_model_postprocess_for_inference(self.model) + primary_model = self.model[0] model_config = get_model_config(primary_model) From cffe25b214a2361c0a3cfa52bb978f7f41b8c432 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sat, 25 Apr 2026 23:01:33 +0800 Subject: [PATCH 043/112] fix(mcore/hf_save): fix mla param --- 
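Note on the hunk below: MLA q/kv down-projection weights are replicated
across tensor-parallel ranks rather than sharded, so sending them through the
TP merge would concatenate identical copies; the fix keeps a single replica.
A minimal sketch of the idea (illustrative helper, assuming a list of
per-rank tensors)::

    import torch

    def take_replicated_shard(shards: list[torch.Tensor]) -> torch.Tensor:
        # A TP-replicated weight should be identical on every rank;
        # keep one copy instead of concatenating along a TP axis.
        assert all(torch.equal(shards[0], s) for s in shards[1:])
        return shards[0].clone()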
areal/models/mcore/hf_save.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/areal/models/mcore/hf_save.py b/areal/models/mcore/hf_save.py index 7caaaf50e8..4dc80f0373 100644 --- a/areal/models/mcore/hf_save.py +++ b/areal/models/mcore/hf_save.py @@ -26,6 +26,16 @@ logger = logging.getLogger("HFSaver") +_MLA_DUPLICATED_WEIGHT_PATTERNS = ( + "self_attention.linear_q_down_proj.", + "self_attention.linear_kv_down_proj.", +) + + +def _is_mla_duplicated_weight(global_name: str) -> bool: + return any(p in global_name for p in _MLA_DUPLICATED_WEIGHT_PATTERNS) + + HF_MODEL_CONFIG_FILES = [ "generation_config.json", "tokenizer_config.json", @@ -409,9 +419,12 @@ def save_weights_to_hf_with_mbridge_fast( infer_params = _maybe_convert_from_te_fp8_params( infer_params, fp8_direct_convert, weight_block_size ) - infer_params = bridge._weight_merge_across_tp( - s.global_name, infer_params, param - ) + if _is_mla_duplicated_weight(s.global_name): + infer_params = infer_params[0].clone() + else: + infer_params = bridge._weight_merge_across_tp( + s.global_name, infer_params, param + ) else: infer_params = param infer_params = _maybe_convert_from_te_fp8_params( From 7c76ed3a184a6dc4e1cdec858e5deb8ec32cfdfe Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sat, 25 Apr 2026 23:26:28 +0800 Subject: [PATCH 044/112] refactor: remove log --- areal/infra/controller/rollout_controller.py | 28 -------- areal/infra/controller/train_controller.py | 45 +------------ areal/infra/rpc/guard/app.py | 31 +-------- areal/infra/rpc/guard/engine_blueprint.py | 12 ---- areal/infra/rpc/rpc_server.py | 9 --- areal/infra/scheduler/local.py | 68 ++------------------ areal/trainer/rl_trainer.py | 8 --- 7 files changed, 6 insertions(+), 195 deletions(-) diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index e066ba185e..874a9319b3 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -238,34 +238,16 @@ async def _async_initialize( **kwargs, ): # Create workers via scheduler - logger.info( - f"[DIAG] RolloutController._async_initialize: creating workers " - f"for role '{job.role}' via scheduler..." - ) worker_ids = self.scheduler.create_workers(job=job) - logger.info( - f"[DIAG] RolloutController._async_initialize: workers created: {worker_ids}" - ) # Wait for workers to be ready - logger.info( - "[DIAG] RolloutController._async_initialize: waiting for workers to be ready..." - ) self.workers = self.scheduler.get_workers(role=job.role) - logger.info( - f"[DIAG] RolloutController._async_initialize: workers ready: " - f"{[w.id for w in self.workers]}, ips={[w.ip for w in self.workers]}" - ) # Get engine class path for dynamic import on workers engine_class = self.inf_engine engine_path = f"{engine_class.__module__}.{engine_class.__name__}" # Create and initialize engines on workers - logger.info( - f"[DIAG] RolloutController._async_initialize: creating engines " - f"(class={engine_path}) on {len(self.workers)} worker(s)..." - ) tasks = [ self.scheduler.create_engine( worker_id=worker.id, @@ -276,18 +258,8 @@ async def _async_initialize( for rank, worker in enumerate(self.workers) ] await asyncio.gather(*tasks) - logger.info( - "[DIAG] RolloutController._async_initialize: engines created on all workers!" - ) - logger.info( - "[DIAG] RolloutController._async_initialize: calling engine initialization..." 
- ) if server_infos is not None: - logger.info( - f"[DIAG] RolloutController._async_initialize: connecting to " - f"{len(server_infos)} existing server(s) for evaluation" - ) # Connecting to existing local servers for evaluation self.server_infos = server_infos assert len(self.server_infos) == len(self.workers), ( diff --git a/areal/infra/controller/train_controller.py b/areal/infra/controller/train_controller.py index 3a9a25796b..9210b06775 100644 --- a/areal/infra/controller/train_controller.py +++ b/areal/infra/controller/train_controller.py @@ -282,23 +282,10 @@ def initialize( ) # Create workers via scheduler - logger.info( - f"[DIAG] TrainController.initialize: creating {world_size} worker(s) " - f"for role '{self._worker_role}' via scheduler..." - ) worker_ids = self.scheduler.create_workers(job=job) - logger.info(f"[DIAG] TrainController.initialize: workers created: {worker_ids}") # Wait for workers to be ready - logger.info( - "[DIAG] TrainController.initialize: waiting for workers to be ready..." - ) self.workers = self.scheduler.get_workers(role=job.role) - logger.info( - f"[DIAG] TrainController.initialize: workers ready: " - f"{[w.id for w in self.workers]}, " - f"ips={[w.ip for w in self.workers]}" - ) # Determine distributed training master address and port from rank 0 worker # These are used for PyTorch distributed initialization across workers @@ -320,18 +307,10 @@ def initialize( # Create and initialize engines on workers engine_path = f"{engine_class.__module__}.{engine_class.__name__}" - logger.info( - f"[DIAG] TrainController.initialize: creating engines " - f"(class={engine_path}) on {len(self.workers)} worker(s)..." - ) run_async_task( self._async_create_engines, engine_path, ) - logger.info( - "[DIAG] TrainController.initialize: engines created, " - "now initializing (create_process_group + initialize)..." - ) run_async_task(self._async_initialize_engines, ft_spec, **kwargs) # Identify DP head workers @@ -347,10 +326,6 @@ def _engine_name(self, rank: int) -> str: async def _async_create_engines(self, engine: str): """Create engine instances on all workers. Sets distributed env vars before creation.""" - logger.info( - f"[DIAG] _async_create_engines: creating engine '{engine}' " - f"on {len(self.workers)} worker(s)..." - ) async def _setup_worker(worker: Worker, rank: int): env = { @@ -361,36 +336,24 @@ async def _setup_worker(worker: Worker, rank: int): "LOCAL_RANK": "0", } logger.debug( - f"[DIAG] _async_create_engines: setting env for worker " + f"Setting env for worker " f"'{worker.id}' (rank={rank}): {env}" ) await self.scheduler.set_worker_env(worker.id, env) - logger.info( - f"[DIAG] _async_create_engines: creating engine on worker " - f"'{worker.id}' (rank={rank})..." - ) await self.scheduler.create_engine( worker_id=worker.id, engine=engine, engine_name=self._engine_name(rank), config=self.config, ) - logger.info( - f"[DIAG] _async_create_engines: engine created on worker " - f"'{worker.id}' (rank={rank})" - ) tasks = [ _setup_worker(worker, rank) for rank, worker in enumerate(self.workers) ] await asyncio.gather(*tasks) - logger.info("[DIAG] _async_create_engines: engines created on all workers!") async def _async_initialize_engines(self, ft_spec: FinetuneSpec, **kwargs): """Initialize engines: create process groups, then load models and setup optimizers.""" - logger.info( - "[DIAG] _async_initialize_engines: Phase 1 - creating process groups..." 
- ) # Phase 1: Create process groups for distributed training tasks = [ self.scheduler.async_call_engine( @@ -402,13 +365,7 @@ async def _async_initialize_engines(self, ft_spec: FinetuneSpec, **kwargs): for rank, worker in enumerate(self.workers) ] await asyncio.gather(*tasks) - logger.info( - "[DIAG] _async_initialize_engines: Phase 1 complete - process groups created" - ) # Phase 2: Initialize engines (load models, setup optimizers, etc.) - logger.info( - "[DIAG] _async_initialize_engines: Phase 2 - initializing engines (loading models)..." - ) tasks = [ self.scheduler.async_call_engine( worker_id=worker.id, diff --git a/areal/infra/rpc/guard/app.py b/areal/infra/rpc/guard/app.py index 5f8dd44b0c..243ede4a4d 100644 --- a/areal/infra/rpc/guard/app.py +++ b/areal/infra/rpc/guard/app.py @@ -430,16 +430,8 @@ def configure(): return jsonify({"error": "Invalid JSON in request body"}), 400 if not s._configure_hooks: - logger.info( - f"[DIAG] /configure: received request (no-op, " - f"no hooks registered) for worker {s.role}/{s.worker_index}" - ) return jsonify({"status": "ok"}) - logger.info( - f"[DIAG] /configure: received request with " - f"{len(s._configure_hooks)} hook(s) for worker {s.role}/{s.worker_index}" - ) # Dispatch to all registered configure hooks result: dict[str, Any] = {} for hook in s._configure_hooks: @@ -447,9 +439,6 @@ def configure(): result.update(hook_result) result.setdefault("status", "success") - logger.info( - f"[DIAG] /configure: completed for worker {s.role}/{s.worker_index}" - ) return jsonify(result) except ValueError as e: @@ -520,26 +509,14 @@ def configure_state_from_args(state: GuardState, args: argparse.Namespace) -> st bind_host = args.host if bind_host == "0.0.0.0": host_ip = gethostip() - logger.info( - f"[DIAG] configure_state_from_args: gethostip() returned '{host_ip}'" - ) if ":" in host_ip: bind_host = "::" state.server_host = host_ip elif bind_host == "::": state.server_host = gethostip() - logger.info( - f"[DIAG] configure_state_from_args: gethostip() returned '{state.server_host}'" - ) else: state.server_host = bind_host - logger.info( - f"[DIAG] configure_state_from_args: bind_host={bind_host}, " - f"server_host={state.server_host}, role={args.role}, " - f"worker_index={args.worker_index}" - ) - state.experiment_name = args.experiment_name state.trial_name = args.trial_name state.role = args.role @@ -591,10 +568,6 @@ def run_server( # Register with name_resolve if state.name_resolve_type is not None: - logger.info( - f"[DIAG] Registering with name_resolve: type={state.name_resolve_type}, " - f"nfs_root={state.nfs_record_root}, etcd3={state.etcd3_addr}" - ) name_resolve.reconfigure( NameResolveConfig( type=state.name_resolve_type, @@ -602,7 +575,7 @@ def run_server( etcd3_addr=state.etcd3_addr or "localhost:2379", ) ) - logger.info("[DIAG] name_resolve reconfigured successfully") + logger.info("name_resolve reconfigured successfully") worker_id = f"{state.role}/{state.worker_index}" key = names.worker_discovery( @@ -611,9 +584,7 @@ def run_server( state.role, state.worker_index, ) - logger.info(f"[DIAG] Adding name_resolve entry: key={key}, addr={state.node_addr}") name_resolve.add(key, state.node_addr, replace=True) - logger.info(f"[DIAG] name_resolve.add completed for {worker_id}") logger.info(f"Starting Guard on {state.node_addr} for worker {worker_id}") diff --git a/areal/infra/rpc/guard/engine_blueprint.py b/areal/infra/rpc/guard/engine_blueprint.py index 8dcdde76ca..4dbd1c765d 100644 --- a/areal/infra/rpc/guard/engine_blueprint.py +++ 
b/areal/infra/rpc/guard/engine_blueprint.py @@ -286,10 +286,6 @@ def create_engine(): engine = data.get("engine") engine_name = data.get("engine_name") - logger.info( - f"[DIAG] /create_engine: received request for engine='{engine}', " - f"engine_name='{engine_name}'" - ) # Deserialize init_args and init_kwargs (may contain tensors/dataclasses) init_args = deserialize_value(data.get("init_args", [])) init_kwargs = deserialize_value(data.get("init_kwargs", {})) @@ -361,10 +357,6 @@ def create_engine_in_engine_thread(): "create_engine", create_engine_in_engine_thread ) _engines[engine_name] = engine_obj - logger.info( - f"[DIAG] /create_engine: engine '{engine_name}' " - f"created successfully (class: {engine})" - ) return jsonify( { "status": "success", @@ -411,10 +403,6 @@ def call_engine_method(): method_name = data.get("method") engine_name = data.get("engine_name") - logger.info( - f"[DIAG] /call: received request for method='{method_name}', " - f"engine_name='{engine_name}'" - ) raw_args = data.get("args", []) raw_kwargs = data.get("kwargs", {}) diff --git a/areal/infra/rpc/rpc_server.py b/areal/infra/rpc/rpc_server.py index 43f510b13c..edc230a7ae 100644 --- a/areal/infra/rpc/rpc_server.py +++ b/areal/infra/rpc/rpc_server.py @@ -48,23 +48,14 @@ def main(): state = GuardState() bind_host = configure_state_from_args(state, args) - logger.info( - f"[DIAG] RPC Server: bind_host={bind_host}, port={args.port}, " - f"role={state.role}, worker_index={state.worker_index}" - ) - app = create_app(state) app.register_blueprint(data_bp) - logger.info("[DIAG] RPC Server: data blueprint registered") app.register_blueprint(engine_bp) - logger.info("[DIAG] RPC Server: engine blueprint registered") register_engine_hooks(state) - logger.info("[DIAG] RPC Server: engine hooks registered") state.register_cleanup_hook(lambda: perf_tracer.save(force=True)) logger.info(f"Werkzeug log level: {args.werkzeug_log_level}") - logger.info("[DIAG] RPC Server: calling run_server...") run_server(state, app, bind_host, args.port) diff --git a/areal/infra/scheduler/local.py b/areal/infra/scheduler/local.py index 2079cb80cd..f44bd22573 100644 --- a/areal/infra/scheduler/local.py +++ b/areal/infra/scheduler/local.py @@ -549,15 +549,8 @@ async def _create_forked_workers_async( # Configure forked workers if exp_config is available if self.exp_config is not None: - logger.info( - f"[DIAG] create_workers: configuring {len(workers)} worker(s) " - f"for role '{role}' with exp_config" - ) for worker_rank, worker_info in enumerate(workers): self._configure_worker(worker_info, worker_rank) - logger.info( - f"[DIAG] create_workers: all workers for role '{role}' configured" - ) return worker_ids @@ -799,9 +792,6 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: ) worker_ip = gethostip() - logger.info( - f"[DIAG] create_workers: gethostip() returned '{worker_ip}' for worker '{worker_id}'" - ) worker = Worker( id=worker_id, ip=worker_ip, @@ -839,15 +829,8 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: raise WorkerCreationError(role, "Unexpected error", str(e)) from e if self.exp_config is not None: - logger.info( - f"[DIAG] create_workers: configuring {len(workers)} worker(s) " - f"for role '{role}' with exp_config" - ) for worker_rank, worker_info in enumerate(workers): self._configure_worker(worker_info, worker_rank) - logger.info( - f"[DIAG] create_workers: all workers for role '{role}' configured" - ) return worker_ids @@ -894,10 +877,6 @@ def get_workers(self, role: str, timeout: float | 
None = None) -> list[Worker]: workers = self._workers[role] timeout = timeout if timeout is not None else self.startup_timeout - logger.info( - f"[DIAG] get_workers: waiting for {len(workers)} worker(s) " - f"of role '{role}' to be ready (timeout={timeout}s)" - ) self._check_worker_health(role) start_time = time.time() @@ -909,10 +888,6 @@ def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]: not_ready = [ w.worker.id for w in workers if w.worker.id not in ready_workers ] - logger.error( - f"[DIAG] get_workers: TIMEOUT after {elapsed:.1f}s. " - f"Ready: {ready_workers}, Not ready: {not_ready}" - ) raise WorkerTimeoutError( role, timeout, @@ -936,10 +911,6 @@ def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]: if self._is_worker_ready(worker_info): ready_workers.add(worker_info.worker.id) - logger.info( - f"[DIAG] get_workers: worker '{worker_info.worker.id}' " - f"is ready ({len(ready_workers)}/{len(workers)})" - ) if len(ready_workers) < len(workers): time.sleep(self.health_check_interval) @@ -954,25 +925,17 @@ def _is_worker_ready(self, worker_info: WorkerInfo) -> bool: try: response = requests.get(url, timeout=2.0) ready = response.status_code == 200 - if ready: - logger.info( - f"[DIAG] _is_worker_ready: {url} -> {response.status_code} (ready)" - ) - else: + if not ready: logger.warning( - f"[DIAG] _is_worker_ready: {url} -> {response.status_code} (not ready)" + f"Worker health check failed: {url} -> {response.status_code}" ) return ready except Exception as e: - logger.warning(f"[DIAG] _is_worker_ready: {url} -> error: {e}") + logger.warning(f"Worker health check error: {url} -> {e}") return False def _configure_worker(self, worker_info: WorkerInfo, worker_rank: int): worker_id = worker_info.worker.id - logger.info( - f"[DIAG] _configure_worker: waiting for worker '{worker_id}' " - f"(ip={worker_info.worker.ip}, ports={worker_info.worker.worker_ports}) to be ready" - ) wait_start = time.time() last_log_time = wait_start while not self._is_worker_ready(worker_info): @@ -981,20 +944,15 @@ def _configure_worker(self, worker_info: WorkerInfo, worker_rank: int): if now - last_log_time >= 5.0: elapsed = now - wait_start logger.warning( - f"[DIAG] _configure_worker: still waiting for worker " + f"Still waiting for worker " f"'{worker_id}' after {elapsed:.0f}s " f"(ip={worker_info.worker.ip}, " f"ports={worker_info.worker.worker_ports})" ) last_log_time = now - logger.info( - f"[DIAG] _configure_worker: worker '{worker_id}' ready after " - f"{time.time() - wait_start:.1f}s, sending configure request" - ) port = int(worker_info.worker.worker_ports[0]) url = f"http://{format_hostport(worker_info.worker.ip, port)}/configure" - logger.info(f"[DIAG] _configure_worker: POST {url} for worker '{worker_id}'") try: response = requests.post( @@ -1151,10 +1109,6 @@ async def set_worker_env(self, worker_id: str, env: dict[str, str]) -> None: payload = {"env": env} port = int(worker_info.worker.worker_ports[0]) url = f"http://{format_hostport(worker_info.worker.ip, port)}/set_env" - logger.info( - f"[DIAG] set_worker_env: POST {url} for worker '{worker_id}', " - f"env keys={list(env.keys())}" - ) try: timeout = aiohttp.ClientTimeout(total=30.0) @@ -1248,11 +1202,6 @@ async def create_engine( url = f"http://{format_hostport(worker_info.worker.ip, port)}/create_engine" try: - logger.info( - f"[DIAG] create_engine: POST {url} engine='{engine_name}' " - f"(class: {engine}) on worker '{worker_id}'" - ) - timeout = aiohttp.ClientTimeout(total=300.0) async 
with aiohttp.ClientSession( timeout=timeout, @@ -1266,10 +1215,6 @@ async def create_engine( ) as response: if response.status == 200: result = await response.json() - logger.info( - f"[DIAG] create_engine: engine '{engine_name}' " - f"created successfully on worker '{worker_id}'" - ) return result.get("result") elif response.status == 400: # Import error or bad request @@ -1525,11 +1470,6 @@ async def async_call_engine( ) try: - logger.info( - f"[DIAG] async_call_engine: POST {url} method='{method}' " - f"engine='{engine_name}' on worker '{worker_id}' (attempt {attempt})" - ) - timeo = aiohttp.ClientTimeout( total=http_timeout, sock_connect=http_timeout, connect=http_timeout ) diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index c529edd137..a806fce6a6 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -202,17 +202,11 @@ def __init__( engine_init_kwargs = {"addr": None, "ft_spec": ft_spec} - logger.info("[DIAG] PPOTrainer: initializing actor engine...") self.actor.initialize(**engine_init_kwargs, role="actor") - logger.info("[DIAG] PPOTrainer: actor engine initialized") if self.critic is not None: - logger.info("[DIAG] PPOTrainer: initializing critic engine...") self.critic.initialize(**engine_init_kwargs, role="critic") - logger.info("[DIAG] PPOTrainer: critic engine initialized") if self.ref is not None: - logger.info("[DIAG] PPOTrainer: initializing ref engine...") self.ref.initialize(**engine_init_kwargs, role="ref") - logger.info("[DIAG] PPOTrainer: ref engine initialized") self.teacher = None if config.teacher is not None: @@ -226,11 +220,9 @@ def __init__( initial_lora_path = self._save_initial_lora_weights() # Initialize inference with LoRA path - logger.info("[DIAG] PPOTrainer: initializing rollout engine...") self.rollout = self._init_rollout( config.rollout, is_eval=False, lora_path=initial_lora_path ) - logger.info("[DIAG] PPOTrainer: rollout engine initialized") # Online mode detection: skip eval rollout for efficiency. 
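The scheduler hunks above keep only warning-level output for the readiness probe; the probe's logic is untouched: poll each worker's HTTP health endpoint until it answers 200 or a deadline passes. A minimal sketch of that loop, assuming an illustrative health URL::

    import time
    import requests

    def wait_until_ready(url: str, timeout_s: float, interval_s: float = 1.0) -> None:
        # Poll until the worker's health endpoint answers 200, else time out.
        deadline = time.time() + timeout_s
        while time.time() < deadline:
            try:
                if requests.get(url, timeout=2.0).status_code == 200:
                    return
            except requests.RequestException:
                pass  # worker not up yet; treat as not ready
            time.sleep(interval_s)
        raise TimeoutError(f"worker at {url} not ready after {timeout_s:.0f}s")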
openai_cfg = config.rollout.openai self._online_mode = train_dataset is None or ( From d6f8c33760c9720115ea36c625ddfbbc68525a49 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 26 Apr 2026 01:17:30 +0800 Subject: [PATCH 045/112] fix(math): update optimizer --- .../math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml index 8e13cc3e8f..d293ca643d 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml @@ -50,8 +50,8 @@ actor: mb_spec: max_tokens_per_mb: 1280 # ← lowered from 2048 to 512 optimizer: - type: adam_bf16 - lr: 2e-6 + type: adam + lr: 1e-5 weight_decay: 0.003 beta1: 0.9 beta2: 0.999 From 4c1198101a7e75896a03e63982466066491c7fce Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 26 Apr 2026 01:20:46 +0800 Subject: [PATCH 046/112] fix(rollout): up max_head_offpolicyness --- .../math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml index d293ca643d..9aae26a97b 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml @@ -24,7 +24,7 @@ rollout: max_concurrent_rollouts: 64 queue_size: null consumer_batch_size: ${train_dataset.batch_size} - max_head_offpolicyness: 1 + max_head_offpolicyness: 2 enable_rollout_tracing: false scheduling_spec: ${actor.scheduling_spec} fileroot: ${cluster.fileroot} From 41f49ab8c7b1d3f824fc20fb6945a865b66b3ad6 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 26 Apr 2026 01:55:12 +0800 Subject: [PATCH 047/112] fix(rollout): down max_head_offpolicyness --- .../math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml index 9aae26a97b..d293ca643d 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml @@ -24,7 +24,7 @@ rollout: max_concurrent_rollouts: 64 queue_size: null consumer_batch_size: ${train_dataset.batch_size} - max_head_offpolicyness: 2 + max_head_offpolicyness: 1 enable_rollout_tracing: false scheduling_spec: ${actor.scheduling_spec} fileroot: ${cluster.fileroot} From 0fce038dc24628607d672277d42b22f0233eaaec Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 26 Apr 2026 02:56:38 +0800 Subject: [PATCH 048/112] fix(optimizer): lr fix --- .../math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml index d293ca643d..3d555845c3 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml @@ -51,7 +51,7 @@ actor: max_tokens_per_mb: 1280 # ← lowered from 2048 to 512 optimizer: type: adam - lr: 1e-5 + lr: 4e-6 weight_decay: 0.003 beta1: 0.9 beta2: 0.999 From 73e9a9526483b0f2a7fd4242360eed8c05bae7d0 Mon Sep 
17 00:00:00 2001 From: bingyechen Date: Sun, 26 Apr 2026 09:26:59 +0800 Subject: [PATCH 049/112] fix(optimizer): lr --- .../math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml index 3d555845c3..b7fece87d4 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml @@ -51,7 +51,7 @@ actor: max_tokens_per_mb: 1280 # ← lowered from 2048 to 512 optimizer: type: adam - lr: 4e-6 + lr: 2e-6 weight_decay: 0.003 beta1: 0.9 beta2: 0.999 From 09e43a13ea847739f501412026e90d10acdc0011 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 27 Apr 2026 03:16:25 +0800 Subject: [PATCH 050/112] feat(r3): fix r3 --- areal/api/cli_args.py | 2 +- areal/engine/sglang_remote.py | 106 +++++++--- areal/infra/launcher/sglang_launch_server.py | 82 ++++++++ areal/infra/launcher/sglang_r3_patch.py | 198 +++++++++++++++++++ areal/trainer/rl_trainer.py | 20 ++ areal/workflow/rlvr_r3_patch.py | 23 ++- 6 files changed, 397 insertions(+), 34 deletions(-) create mode 100644 areal/infra/launcher/sglang_launch_server.py create mode 100644 areal/infra/launcher/sglang_r3_patch.py diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 2491b988e8..df7ed4891d 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -1617,7 +1617,7 @@ def build_cmd( @staticmethod def build_cmd_from_args(args: dict[str, Any]): - return get_py_cmd("sglang.launch_server", args) + return get_py_cmd("areal.infra.launcher.sglang_launch_server", args) @staticmethod def build_args( diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index 864f517333..795c8df1e2 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -98,53 +98,107 @@ def parse_generation_response( stop_message = finish_reason.get("message", "") # Extract routed_experts information if available. - # SGLang may return routed_experts in two formats: - # 1. Base64-encoded string (skip_tokenizer_init=False, normal path) - # 2. Raw list/dict (skip_tokenizer_init=True or newer SGLang versions) - routed_experts = meta_info.get("routed_experts", None) + routed_experts_raw = meta_info.get("routed_experts", None) + routed_experts = routed_experts_raw if routed_experts is not None: num_sgl_token = ( meta_info["prompt_tokens"] + meta_info["completion_tokens"] - 1 ) + re_type_name = type(routed_experts_raw).__name__ + if isinstance(routed_experts, dict): + # Empty dict -> jsonable_encoder(tensor) failure. Non-empty + # dict is not a documented SGLang wire format; surface it + # loudly so we can add a decoder if/when it starts + # happening rather than silently corrupting R3 data. + if not routed_experts: + logger.warning( + "[R3] SGLang returned routed_experts=%s (empty " + "dict). This is the fingerprint of a raw " + "torch.Tensor being serialised by FastAPI's " + "jsonable_encoder. Ensure the AReaL SGLang " + "server-side patch (areal.infra.launcher." + "sglang_r3_patch) is installed on the inference " + "server; dropping payload.", + routed_experts_raw, + ) + else: + logger.warning( + "[R3] SGLang returned routed_experts as a " + "non-empty dict (keys=%s); no decoder registered. 
" + "Dropping payload.", + sorted(routed_experts.keys()), + ) + routed_experts = None + elif isinstance(routed_experts, str): try: - routed_experts = np.frombuffer( + flat = np.frombuffer( pybase64.b64decode(routed_experts.encode("utf-8")), dtype=np.int32, - ).reshape(num_sgl_token, -1) - logger.info( - "[R3-VERIFY] SGLang decoded routed_experts: " - "shape=%s, first3=%s, hash=%d", - routed_experts.shape, - routed_experts.flat[:3].tolist(), - hash(routed_experts.tobytes()), ) - except Exception: + if num_sgl_token <= 0 or flat.size % num_sgl_token != 0: + # Total element count does not divide by + # ``num_sgl_token``. This usually means SGLang's + # tokenizer round-trip (``skip_tokenizer_init=False``) + # inserted/removed tokens between the router capture + # and the returned ``output_tokens``. Drop the + # payload instead of silently reshaping into a wrong + # grid that would corrupt R3 replay. + logger.warning( + "[R3] routed_experts size=%d does not divide " + "num_sgl_token=%d (prompt=%d + completion=%d - 1). " + "Likely tokenizer round-trip drift; dropping.", + flat.size, + num_sgl_token, + meta_info.get("prompt_tokens", -1), + meta_info.get("completion_tokens", -1), + ) + routed_experts = None + else: + routed_experts = flat.reshape(num_sgl_token, -1) + logger.info( + "[R3-VERIFY] SGLang decoded routed_experts: " + "shape=%s, first3=%s, hash=%d", + routed_experts.shape, + routed_experts.flat[:3].tolist(), + hash(routed_experts.tobytes()), + ) + except Exception as exc: logger.warning( "[R3] Failed to decode base64 routed_experts " "(num_sgl_token=%d): %s", num_sgl_token, + exc, exc_info=True, ) routed_experts = None else: try: - routed_experts = np.asarray( - routed_experts, dtype=np.int32 - ).reshape(num_sgl_token, -1) - logger.info( - "[R3-VERIFY] SGLang converted routed_experts: " - "shape=%s, first3=%s, hash=%d", - routed_experts.shape, - routed_experts.flat[:3].tolist(), - hash(routed_experts.tobytes()), - ) - except Exception: + raw = np.asarray(routed_experts, dtype=np.int32).reshape(-1) + if num_sgl_token <= 0 or raw.size % num_sgl_token != 0: + logger.warning( + "[R3] routed_experts size=%d does not divide " + "num_sgl_token=%d; likely tokenizer round-trip " + "drift, dropping.", + raw.size, + num_sgl_token, + ) + routed_experts = None + else: + routed_experts = raw.reshape(num_sgl_token, -1) + logger.info( + "[R3-VERIFY] SGLang converted routed_experts: " + "shape=%s, first3=%s, hash=%d", + routed_experts.shape, + routed_experts.flat[:3].tolist(), + hash(routed_experts.tobytes()), + ) + except Exception as exc: logger.warning( "[R3] Failed to convert routed_experts from %s " "(num_sgl_token=%d): %s", - type(meta_info.get("routed_experts")).__name__, + re_type_name, num_sgl_token, + exc, exc_info=True, ) routed_experts = None diff --git a/areal/infra/launcher/sglang_launch_server.py b/areal/infra/launcher/sglang_launch_server.py new file mode 100644 index 0000000000..bddc648208 --- /dev/null +++ b/areal/infra/launcher/sglang_launch_server.py @@ -0,0 +1,82 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Thin wrapper around ``sglang.launch_server`` that installs AReaL's R3 +monkey patches before the upstream server boots. + +Usage (transparent replacement for ``python3 -m sglang.launch_server ...``):: + + python3 -m areal.infra.launcher.sglang_launch_server --model-path ... + +Only the R3 patch is installed; all other CLI behaviour is delegated to +``sglang.launch_server`` unchanged so this entrypoint stays drop-in safe +when R3 is not used. +""" + +from __future__ import annotations + +import logging +import os +import sys + + +def _install_areal_patches() -> None: + """Install AReaL monkey patches that must be active in the SGLang + server process (scheduler/tokenizer manager/HTTP server).""" + try: + from areal.infra.launcher.sglang_r3_patch import apply_sglang_r3_patch + + apply_sglang_r3_patch() + except Exception: # pragma: no cover - defensive + logging.getLogger(__name__).exception( + "[R3] Failed to install AReaL SGLang patches; server will " + "start without R3 wire-format fixes." + ) + + +def main() -> None: + _install_areal_patches() + + # Delegate to upstream launcher. We keep argv intact (including + # argv[0] mangling done by ``python3 -m``) because + # ``sglang.launch_server`` uses ``sys.argv[1:]`` directly. + from sglang.srt.server_args import prepare_server_args + from sglang.srt.utils import kill_process_tree + from sglang.srt.utils.common import suppress_noisy_warnings + + suppress_noisy_warnings() + + server_args = prepare_server_args(sys.argv[1:]) + + # Same dispatch as ``sglang/launch_server.py``. + try: + if getattr(server_args, "grpc_mode", False): + import asyncio + + from sglang.srt.entrypoints.grpc_server import serve_grpc + + asyncio.run(serve_grpc(server_args)) + elif getattr(server_args, "encoder_only", False): + from sglang.srt.disaggregation.encode_server import launch_server + + launch_server(server_args) + else: + from sglang.srt.entrypoints.http_server import launch_server + + launch_server(server_args) + finally: + kill_process_tree(os.getpid(), include_parent=False) + + +if __name__ == "__main__": + main() diff --git a/areal/infra/launcher/sglang_r3_patch.py b/areal/infra/launcher/sglang_r3_patch.py new file mode 100644 index 0000000000..60821e143d --- /dev/null +++ b/areal/infra/launcher/sglang_r3_patch.py @@ -0,0 +1,198 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SGLang server-side monkey patches required for AReaL's R3 (Router Replay). + +Background +---------- +When ``skip_tokenizer_init=True`` (which AReaL forces whenever R3 is enabled, +see ``rl_trainer.py``), the SGLang *scheduler* bypasses the +``DetokenizerManager`` and sends ``BatchTokenIDOutput`` directly to +``TokenizerManager``. 
The side effect is that the ``routed_experts`` tensor +is **not** base64-encoded by ``DetokenizerManager._extract_routed_experts`` +anymore -- it reaches ``TokenizerManager._handle_batch_output`` still as a +raw ``torch.Tensor``. + +``TokenizerManager`` then attaches the tensor verbatim to +``meta_info["routed_experts"]`` and lets FastAPI/``ORJSONResponse`` +serialize the whole response. FastAPI's ``jsonable_encoder`` does not +know how to encode ``torch.Tensor``; it silently returns an **empty +dict** (``{}``) instead of raising. The client then receives + + meta_info["routed_experts"] == {} + +and hits ``TypeError: int() argument must be a string, a bytes-like +object or a real number, not 'dict'`` when it tries +``np.asarray(routed_experts, dtype=np.int32)``. Because the error is +swallowed in ``parse_generation_response``, ``routed_experts`` becomes +``None`` and the downstream ``RemoteInfEngine`` raises:: + + RuntimeError: Requested return_routed_experts=True but received None + from SGLang + +This module installs a monkey patch on +``sglang.srt.managers.tokenizer_manager.TokenizerManager._handle_batch_output`` +that base64-encodes the tensor in-place *before* it is serialised (exactly +the same encoding that ``DetokenizerManager._extract_routed_experts`` +applies in the non-``skip_tokenizer_init`` path), so the wire format stays +consistent with both SGLang's documented behaviour and AReaL's client-side +decoder in ``areal/engine/sglang_remote.py``. + +The patch is idempotent. When R3 is disabled the patch is a no-op at +runtime because the routed-experts attribute stays ``None``. + +Compatibility note +------------------ +Different SGLang branches use different attribute names on the +``BatchTokenIDOutput`` / ``BatchStrOutput`` dataclasses: + +* Current SGLang main / v0.5.9 : ``recv_obj.routed_experts`` +* Original R3 commit ``bed301a5`` (the base commit that verl's R3 example + was built against; see + ``https://github.com/sgl-project/sglang/commit/bed301a5acaa9577c9aa706468bdf242f6a43051``) + : ``recv_obj.output_routed_experts`` + +The patch handles both attribute names so AReaL runs unmodified against +either SGLang build. Consequently the TokenizerManager attaches the +resulting ``meta_info["routed_experts"]`` entry to every response regardless +of which internal field the scheduler populated. + +Reference +--------- +* SGLang encodes ``routed_experts`` via ``pybase64.b64encode`` at + ``python/sglang/srt/managers/detokenizer_manager.py::_extract_routed_experts`` + for the ``skip_tokenizer_init=False`` path. +* verl applies the symmetric decoder at + ``verl/workers/rollout/sglang_rollout/async_sglang_server.py`` via + ``extract_routed_experts_from_meta_info`` (which assumes a base64 string). +""" + +from __future__ import annotations + +import logging + +logger = logging.getLogger(__name__) + +_PATCH_APPLIED = False + +# Attribute names used across the SGLang versions we must support. +# * ``routed_experts`` : current SGLang main / v0.5.9. +# * ``output_routed_experts`` : original R3 commit ``bed301a5`` that the +# verl R3 example and the AReaL R3 docker +# were built on. +_ROUTED_EXPERTS_ATTRS = ("routed_experts", "output_routed_experts") + + +def _encode_routed_experts_for_wire(value): + """Convert ``routed_experts`` to a base64 string for wire transport. + + Mirrors ``sglang.srt.managers.detokenizer_manager + ._extract_routed_experts``: each request's tensor is encoded as + ``pybase64.b64encode(tensor.numpy().tobytes()).decode("utf-8")``. 
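The wire format is easiest to see as a round trip: server-side encode (what this patch installs) followed by the client-side decode from areal/engine/sglang_remote.py. A self-contained sketch with invented shapes::

    import numpy as np
    import pybase64

    num_sgl_token, num_moe_layers, topk = 7, 26, 6
    experts = np.random.randint(
        0, 64, size=(num_sgl_token, num_moe_layers * topk), dtype=np.int32
    )

    # Server side: dense int32 bytes -> base64 string.
    wire = pybase64.b64encode(np.ascontiguousarray(experts).tobytes()).decode("utf-8")

    # Client side: base64 -> flat int32 -> (num_sgl_token, -1) grid, with the
    # same divisibility guard that parse_generation_response applies.
    flat = np.frombuffer(pybase64.b64decode(wire.encode("utf-8")), dtype=np.int32)
    assert flat.size % num_sgl_token == 0, "tokenizer round-trip drift"
    decoded = flat.reshape(num_sgl_token, -1)
    assert np.array_equal(decoded, experts)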
+ + Accepts tensors, numpy arrays, or already-encoded strings; other + types are returned unchanged so the client branch can surface them. + """ + if value is None: + return None + if isinstance(value, str): + # Already encoded by DetokenizerManager or a prior invocation. + return value + try: + import numpy as np + import pybase64 + import torch + except Exception: # pragma: no cover - defensive + return value + + if isinstance(value, torch.Tensor): + # ``to("cpu")`` is a no-op when already on CPU but protects us + # against exotic device placements (e.g. CUDA tensors leaked by + # a capture buffer). ``contiguous()`` guarantees ``tobytes`` + # produces a dense layout matching ``shape`` on the decode side. + arr = value.detach().to("cpu").contiguous().numpy() + elif isinstance(value, np.ndarray): + arr = np.ascontiguousarray(value) + else: + return value + + # Normalise dtype to int32 so the client's ``np.frombuffer(..., int32)`` + # matches regardless of whether the capture buffer was int64. + if arr.dtype != np.int32: + arr = arr.astype(np.int32, copy=False) + + return pybase64.b64encode(arr.tobytes()).decode("utf-8") + + +def apply_sglang_r3_patch() -> bool: + """Install the ``_handle_batch_output`` monkey patch. + + Returns ``True`` when the patch is installed (or was already + installed). Returns ``False`` when SGLang is unavailable in the + current process (so the caller can gracefully skip). + """ + global _PATCH_APPLIED + if _PATCH_APPLIED: + return True + + try: + from sglang.srt.managers import tokenizer_manager as _tm + except Exception as exc: # pragma: no cover - defensive + logger.warning( + "[R3] sglang.srt.managers.tokenizer_manager not importable; " + "skipping R3 server patch. reason=%s", + exc, + ) + return False + + original = _tm.TokenizerManager._handle_batch_output + + def _handle_batch_output_r3(self, recv_obj): # type: ignore[no-redef] + # Pre-encode the routed-experts tensor so the downstream FastAPI + # serialisation sees a plain string (which ``jsonable_encoder`` + # passes through untouched) instead of a ``torch.Tensor`` (which + # ``jsonable_encoder`` silently flattens to ``{}``). + # + # We must handle BOTH attribute names because different SGLang + # builds populate different fields: + # * ``routed_experts`` - current SGLang main / v0.5.9 + # * ``output_routed_experts`` - original R3 commit ``bed301a5`` + # (verl's R3 base commit) + try: + for attr_name in _ROUTED_EXPERTS_ATTRS: + re_list = getattr(recv_obj, attr_name, None) + if re_list is None: + continue + encoded = [_encode_routed_experts_for_wire(v) for v in re_list] + try: + setattr(recv_obj, attr_name, encoded) + except Exception: + # Some SGLang versions freeze the dataclass; fall back + # to object.__setattr__ which bypasses __slots__ / + # frozen protection. + object.__setattr__(recv_obj, attr_name, encoded) + except Exception: # pragma: no cover - defensive + logger.exception( + "[R3] Failed to pre-encode routed_experts on server; " + "falling through to unpatched behaviour." + ) + return original(self, recv_obj) + + _tm.TokenizerManager._handle_batch_output = _handle_batch_output_r3 + _PATCH_APPLIED = True + logger.info( + "[R3] Installed sglang TokenizerManager._handle_batch_output " + "base64 encoder patch for routed_experts " + "(handles both routed_experts and output_routed_experts)." 
+ ) + return True diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index a806fce6a6..b4fad72ed2 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -140,6 +140,26 @@ def __init__( type(config.actor.megatron).__name__, ) + sglang_cfg = getattr(config, "sglang", None) + if sglang_cfg is not None and not getattr( + sglang_cfg, "skip_tokenizer_init", True + ): + logger.warning( + "[R3] rollout.return_routed_experts=True but " + "sglang.skip_tokenizer_init=False. Forcing " + "sglang.skip_tokenizer_init=True to avoid " + "tokenizer round-trip token-shift that breaks " + "per-token router-index alignment." + ) + try: + sglang_cfg.skip_tokenizer_init = True + except Exception: + # Tolerate frozen/struct-style omegaconf containers + from omegaconf import OmegaConf + + OmegaConf.set_struct(sglang_cfg, False) + sglang_cfg.skip_tokenizer_init = True + # Create models: actor, critic, ref — each with its own allocation. self.actor = self._create_train_engine(config.actor, self.actor_alloc) self.critic = None diff --git a/areal/workflow/rlvr_r3_patch.py b/areal/workflow/rlvr_r3_patch.py index de70b733a1..d6b27bae8c 100644 --- a/areal/workflow/rlvr_r3_patch.py +++ b/areal/workflow/rlvr_r3_patch.py @@ -102,6 +102,9 @@ def extract_routed_experts( return None +_INFER_LOGGED: set[tuple[int, int]] = set() + + def _infer_and_preprocess( routed_experts_np: np.ndarray, input_ids: torch.Tensor, @@ -128,13 +131,19 @@ def _infer_and_preprocess( ) num_moe_layers = flat_dim // topk - logger.debug( - "[R3] Inferred num_moe_layers=%d, topk=%d from flat_dim=%d " - "(warning: these may be incorrect without model config).", - num_moe_layers, - topk, - flat_dim, - ) + _key = (num_moe_layers, topk) + if _key not in _INFER_LOGGED: + _INFER_LOGGED.add(_key) + logger.info( + "[R3] rlvr workflow inferred num_moe_layers=%d, topk=%d from " + "flat_dim=%d (this count includes any dense-FFN layers; the " + "engine-side router_replay_utils.set_router_replay_data handles " + "the dense-vs-MoE split). 
For deterministic behaviour, pass " + "num_moe_layers and topk explicitly from the model config.", + num_moe_layers, + topk, + flat_dim, + ) from areal.engine.router_replay_utils import preprocess_routed_experts_batch From c52b50c5afd10eecbbe72d4022ef64a7a95655a6 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 27 Apr 2026 03:26:26 +0800 Subject: [PATCH 051/112] fix(rollout): on policy --- .../math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml index b7fece87d4..7dc2a0c78f 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml @@ -24,7 +24,7 @@ rollout: max_concurrent_rollouts: 64 queue_size: null consumer_batch_size: ${train_dataset.batch_size} - max_head_offpolicyness: 1 + max_head_offpolicyness: 0 enable_rollout_tracing: false scheduling_spec: ${actor.scheduling_spec} fileroot: ${cluster.fileroot} From 606442cbe4b9c11d8547e7d72ab2bc060b1a4829 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 28 Apr 2026 01:08:04 +0800 Subject: [PATCH 052/112] fix(engine): fix save --- areal/engine/megatron_engine.py | 34 +++++++++++++++++++++-- areal/engine/megatron_utils/megatron.py | 4 +++ areal/infra/rpc/guard/engine_blueprint.py | 22 +++++++++++++-- areal/infra/rpc/serialization.py | 14 ++++++++-- 4 files changed, 68 insertions(+), 6 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 89a60e41b2..1c8f48468b 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1277,6 +1277,17 @@ def _duplicated_param_names(self) -> set[str]: if getattr(module, "parallel_mode", None) == "duplicated": for p_name, _ in module.named_parameters(recurse=False): full = f"{mod_name}.{p_name}" if mod_name else p_name + # Normalize to match the naming convention used + # by megatron_utils.megatron.get_named_parameters + # (which is the source of ``name`` passed into + # ``all_gather_param``). Without this, MLA + # ``linear_q_down_proj`` / ``linear_kv_down_proj`` + # weights are never recognized as duplicated and + # get all-gathered across the TP group, producing + # oversized tensors that sglang rejects when + # loading Moonlight / DeepSeek-V3 MLA weights. + if not full.startswith("module.module."): + full = "module." + full duplicated.add(full) self._cached_duplicated_param_names = duplicated return self._cached_duplicated_param_names @@ -1491,6 +1502,8 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: dist.barrier(group=self.cpu_group) + torch.cuda.empty_cache() + num_moe_experts = self.tf_config.num_moe_experts weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024 @@ -1584,6 +1597,7 @@ def _save_model_to_hf( base_model_path: str | None = None, ) -> None: assert self.model is not None, "Model is not initialized." + torch.cuda.empty_cache() os.makedirs(path, exist_ok=True) if self.bridge_cls == "megatron-bridge": @@ -1619,9 +1633,25 @@ def _save_model_to_hf( if dist.get_rank() == 0: if tokenizer is not None: - tokenizer.save_pretrained(path) + if hasattr(tokenizer, "save_pretrained"): + tokenizer.save_pretrained(path) + else: + self.logger.warning( + "Tokenizer object has no save_pretrained() method " + f"(got type={type(tokenizer).__name__}); skipping save. 
" + "This usually means the tokenizer could not be " + "reconstructed from its serialized form (e.g. a model " + "with a custom tokenizer class requiring " + "trust_remote_code=True) and was decoded as a raw dict." + ) if processor is not None: - processor.save_pretrained(path) + if hasattr(processor, "save_pretrained"): + processor.save_pretrained(path) + else: + self.logger.warning( + "Processor object has no save_pretrained() method " + f"(got type={type(processor).__name__}); skipping save." + ) current_platform.synchronize() dist.barrier(group=self.cpu_group) diff --git a/areal/engine/megatron_utils/megatron.py b/areal/engine/megatron_utils/megatron.py index 85b445b0ac..e6ba4d1c9b 100644 --- a/areal/engine/megatron_utils/megatron.py +++ b/areal/engine/megatron_utils/megatron.py @@ -112,6 +112,9 @@ def all_gather_param( # Use the caller-provided duplicated_param_names set for reliable detection. is_duplicated = ( duplicated_param_names is not None and name in duplicated_param_names + ) or ( + ".self_attention.linear_q_down_proj." in name + or ".self_attention.linear_kv_down_proj." in name ) if not param.tensor_model_parallel or is_duplicated: # NOTE: For FP8 tensors with direct conversion, return the tensor directly @@ -765,6 +768,7 @@ def convert_bailingmoe_to_hf( "qwen2": convert_qwen2_to_hf, "qwen3": convert_qwen2_to_hf, "deepseekv3": convert_deepseekv3_to_hf, + "deepseek_v3": convert_deepseekv3_to_hf, "bailing_moe_v2": convert_bailingmoe_to_hf, "bailing_moe_linear": convert_bailingmoe_to_hf, "bailing_hybrid": convert_bailingmoe_to_hf, diff --git a/areal/infra/rpc/guard/engine_blueprint.py b/areal/infra/rpc/guard/engine_blueprint.py index 4dbd1c765d..aaa9e0a777 100644 --- a/areal/infra/rpc/guard/engine_blueprint.py +++ b/areal/infra/rpc/guard/engine_blueprint.py @@ -520,6 +520,10 @@ def execute_in_engine_thread(): category=category, args={"method": method_name, "engine": engine_name}, ): + if not hasattr(engine, method_name): + raise ValueError( + f"Engine does not have method '{method_name}'" + ) method = getattr(engine, method_name) result = method(*args_bcast, **kwargs_bcast) @@ -531,8 +535,22 @@ def execute_in_engine_thread(): return result except AttributeError as e: - logger.error(f"Method '{method_name}' not found on engine: {e}") - raise ValueError(f"Engine does not have method '{method_name}'") + # Only treat this as "method missing" if the method truly is + # not defined on the engine. An AttributeError raised from + # within the method body should propagate as-is so the real + # stack trace is visible. + if not hasattr(engine, method_name): + logger.error( + f"Method '{method_name}' not found on engine: {e}" + ) + raise ValueError( + f"Engine does not have method '{method_name}'" + ) + logger.error( + f"Engine method '{method_name}' raised AttributeError: " + f"{e}\n{traceback.format_exc()}" + ) + raise except Exception as e: logger.error( f"Engine method '{method_name}' failed: " diff --git a/areal/infra/rpc/serialization.py b/areal/infra/rpc/serialization.py index 67cdc48233..48c171127e 100644 --- a/areal/infra/rpc/serialization.py +++ b/areal/infra/rpc/serialization.py @@ -354,7 +354,15 @@ def to_tokenizer(self) -> Any: with tempfile.TemporaryDirectory() as tmpdir: with zipfile.ZipFile(zip_buffer) as zf: zf.extractall(tmpdir) - tokenizer = AutoTokenizer.from_pretrained(tmpdir) + # NOTE: Models such as Moonlight / DeepSeek-V3 rely on a custom + # tokenizer class declared via ``auto_map`` (e.g. the TikToken based + # ``TikTokenTokenizer``). 
Without ``trust_remote_code=True`` + # ``AutoTokenizer.from_pretrained`` raises and the caller falls back + # to the raw dict, which then explodes inside ``_save_model_to_hf`` + # as ``'dict' object has no attribute 'save_pretrained'``. + tokenizer = AutoTokenizer.from_pretrained( + tmpdir, trust_remote_code=True + ) if hasattr(tokenizer, "name_or_path"): tokenizer.name_or_path = self.name_or_path @@ -464,7 +472,9 @@ def to_processor(self) -> Any: with tempfile.TemporaryDirectory() as tmpdir: with zipfile.ZipFile(zip_buffer) as zf: zf.extractall(tmpdir) - processor = AutoProcessor.from_pretrained(tmpdir) + processor = AutoProcessor.from_pretrained( + tmpdir, trust_remote_code=True + ) if hasattr(processor, "name_or_path"): processor.name_or_path = self.name_or_path From 75a8527835b4e33e3af787a10cdd252754ee8a01 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 28 Apr 2026 16:01:14 +0800 Subject: [PATCH 053/112] refactor: remove useless --- areal/engine/megatron_engine.py | 3 --- areal/engine/megatron_engine_r3_patch.py | 17 +++-------------- areal/engine/megatron_utils/megatron_lora.py | 5 ++++- areal/engine/router_replay_patch.py | 15 +-------------- areal/engine/router_replay_utils.py | 13 ------------- areal/trainer/ppo/actor.py | 2 +- areal/trainer/ppo/actor_r3_patch.py | 13 ------------- areal/trainer/rl_trainer.py | 6 ------ 8 files changed, 9 insertions(+), 65 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 1c8f48468b..bd0b8b38bd 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1502,8 +1502,6 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: dist.barrier(group=self.cpu_group) - torch.cuda.empty_cache() - num_moe_experts = self.tf_config.num_moe_experts weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024 @@ -1597,7 +1595,6 @@ def _save_model_to_hf( base_model_path: str | None = None, ) -> None: assert self.model is not None, "Model is not initialized." - torch.cuda.empty_cache() os.makedirs(path, exist_ok=True) if self.bridge_cls == "megatron-bridge": diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py index 0c8e61f301..9e8733070b 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -1,16 +1,3 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. """ R3 Integration Patch for MegatronEngine. @@ -32,6 +19,8 @@ from areal.engine.megatron_engine_r3_patch import patch_megatron_engine_for_r3 patch_megatron_engine_for_r3(engine, enable_router_replay=True) + +Ref some code from megatron or verl, adapted for AReaL. """ from __future__ import annotations @@ -448,7 +437,7 @@ def __next__(self): if cu_info is not None: cu_seqlens, max_seqlen = cu_info try: - # CRITICAL FIX: Use cu_seqlens for alignment instead of + # Use cu_seqlens for alignment instead of # attention_mask. 
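For readers new to packed batches: cu_seqlens is a cumulative-length prefix array, so sequence i owns the half-open slice [cu_seqlens[i], cu_seqlens[i+1]) of the packed token dimension. A toy illustration with invented lengths::

    import torch

    cu_seqlens = torch.tensor([0, 3, 7, 12])  # three packed sequences
    packed = torch.arange(12)                 # stand-in for packed tokens

    spans = [
        packed[cu_seqlens[i] : cu_seqlens[i + 1]]
        for i in range(len(cu_seqlens) - 1)
    ]
    assert [len(s) for s in spans] == [3, 4, 5]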
This ensures the packed token order # matches Megatron's actual forward pass. diff --git a/areal/engine/megatron_utils/megatron_lora.py b/areal/engine/megatron_utils/megatron_lora.py index c5f3a4ee47..969161342c 100644 --- a/areal/engine/megatron_utils/megatron_lora.py +++ b/areal/engine/megatron_utils/megatron_lora.py @@ -291,7 +291,10 @@ def save_hf_adapter( _monkey_patch_applied = False - +# Current: This monkey patch is needed as the current megatron-bridge 0.3.0 does not have a built-in method +# to save LoRA adapters in HuggingFace PEFT format, which is required for our use case. +# Future: This code is however present in main branch of megatron-bridge so this patch is temporary +# and can be removed later when we upgrade the megatron-bridge version. def ensure_save_hf_adapter_patched(): global _monkey_patch_applied if not _monkey_patch_applied: diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py index b102f3cb1d..6efb077663 100644 --- a/areal/engine/router_replay_patch.py +++ b/areal/engine/router_replay_patch.py @@ -1,16 +1,3 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. """ Monkey-patches for Megatron-Core MoE components to support Router Replay (R3). @@ -37,7 +24,7 @@ from areal.engine.router_replay_patch import remove_router_replay_patch remove_router_replay_patch() # optional: for test cleanup -Ref some code from verl, adapted for AReaL. +Ref some code from megatron or verl, adapted for AReaL. """ from __future__ import annotations diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py index e53e4759e0..3320029362 100644 --- a/areal/engine/router_replay_utils.py +++ b/areal/engine/router_replay_utils.py @@ -1,16 +1,3 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. """ Router Replay Utilities for AReaL. diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index b84c0d89c0..2f665a92ee 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -553,7 +553,7 @@ def grpo_loss_fn( ) # ---- R3 Logprob Diff: rollout (inference) vs training logprobs ---- - # Following SkyRL's approach: compute |rollout_logprobs - training_logprobs| + # compute |rollout_logprobs - training_logprobs| # over response tokens only. This metric quantifies train/infer mismatch # caused by MoE routing divergence. With R3 enabled, this diff should be # smaller than without R3. 
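The metric the comment above describes needs only the two logprob tensors and a response mask. A hedged sketch, assuming (batch, seqlen) tensors and a boolean mask over response tokens::

    import torch

    def rollout_train_logprob_gap(
        rollout_logprobs: torch.Tensor,   # (B, T), from the inference engine
        training_logprobs: torch.Tensor,  # (B, T), recomputed in training
        response_mask: torch.Tensor,      # (B, T) bool, True on response tokens
    ) -> torch.Tensor:
        # Masked mean absolute gap; with R3 enabled this should shrink.
        gap = (rollout_logprobs - training_logprobs).abs() * response_mask
        return gap.sum() / response_mask.sum().clamp_min(1)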
diff --git a/areal/trainer/ppo/actor_r3_patch.py b/areal/trainer/ppo/actor_r3_patch.py index 4d3476ce21..a4025cc8f6 100644 --- a/areal/trainer/ppo/actor_r3_patch.py +++ b/areal/trainer/ppo/actor_r3_patch.py @@ -1,16 +1,3 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. """ MoE routing metrics and R3 logging helpers for the PPO actor. diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index b4fad72ed2..d14ac03aef 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -126,11 +126,6 @@ def __init__( self._amend_xccl_weight_update_envvar() - # R3: Propagate router replay flag to actor engine config so that - # MegatronEngine.initialize() can apply the R3 patch. - # Must be set BEFORE engine creation so that Ray-serialized config - # carries the flag. Uses the declared MegatronEngineConfig field - # (not a dynamic attribute) to survive Ray serialization. if getattr(config.rollout, "return_routed_experts", False): config.actor.megatron.enable_router_replay = True logger.info( @@ -221,7 +216,6 @@ def __init__( ) engine_init_kwargs = {"addr": None, "ft_spec": ft_spec} - self.actor.initialize(**engine_init_kwargs, role="actor") if self.critic is not None: self.critic.initialize(**engine_init_kwargs, role="critic") From dc94b3d49f40aab34d5a4c720c0d19005dd6e229 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 28 Apr 2026 16:07:44 +0800 Subject: [PATCH 054/112] refactor(rlvr_r3_patch): remove moe infer --- areal/workflow/rlvr_r3_patch.py | 114 +++++--------------------------- 1 file changed, 15 insertions(+), 99 deletions(-) diff --git a/areal/workflow/rlvr_r3_patch.py b/areal/workflow/rlvr_r3_patch.py index d6b27bae8c..efceb58651 100644 --- a/areal/workflow/rlvr_r3_patch.py +++ b/areal/workflow/rlvr_r3_patch.py @@ -1,16 +1,3 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. """ R3 helpers for the RLVR workflow. @@ -26,13 +13,6 @@ 2. The tensor is added to the result dict under key ``"routed_experts"``. 3. During training, the ``MegatronEngine`` R3 patch picks it up from the batch data and feeds it to ``setup_per_microbatch_replay_forward``. - -Note on num_moe_layers and topk: - At the workflow level, we may not know the exact model config values. - We store the raw numpy array shape info and let the engine layer - (which has access to tf_config) do the final reshape. 
As a - practical compromise, we accept optional num_moe_layers and topk - parameters and fall back to shape-based inference when not provided. """ from __future__ import annotations @@ -50,8 +30,8 @@ def extract_routed_experts( routed_experts_np: Optional[np.ndarray], input_ids: torch.Tensor, attention_mask: torch.Tensor, - num_moe_layers: Optional[int] = None, - topk: Optional[int] = None, + num_moe_layers: int, + topk: int, compress_dtype: bool = True, ) -> Optional[torch.Tensor]: """Convert ``ModelResponse.routed_experts`` into a training tensor. @@ -61,8 +41,8 @@ def extract_routed_experts( as returned by the SGLang inference backend, or ``None``. input_ids: ``(1, seq_len)`` token ids (prompt + response). attention_mask: ``(1, seq_len)`` with 1 for real tokens, 0 for padding. - num_moe_layers: Number of MoE layers. If None, inferred from shape. - topk: Router top-k. If None, inferred from shape. + num_moe_layers: Number of MoE layers in the model. **Required**. + topk: Router top-k. **Required**. compress_dtype: Downcast to ``uint8`` / ``int16`` when possible. Returns: @@ -72,27 +52,18 @@ def extract_routed_experts( return None try: - if num_moe_layers is not None and topk is not None: - from areal.engine.router_replay_utils import ( - preprocess_routed_experts_batch, - ) + from areal.engine.router_replay_utils import ( + preprocess_routed_experts_batch, + ) - return preprocess_routed_experts_batch( - routed_experts_np, - input_ids, - attention_mask, - num_moe_layers=num_moe_layers, - topk=topk, - compress_dtype=compress_dtype, - ) - else: - # Fallback: infer num_moe_layers and topk from shape - return _infer_and_preprocess( - routed_experts_np, - input_ids, - attention_mask, - compress_dtype=compress_dtype, - ) + return preprocess_routed_experts_batch( + routed_experts_np, + input_ids, + attention_mask, + num_moe_layers=num_moe_layers, + topk=topk, + compress_dtype=compress_dtype, + ) except Exception: logger.warning( "[R3] Failed to preprocess routed_experts (shape=%s); skipping.", @@ -102,61 +73,6 @@ def extract_routed_experts( return None -_INFER_LOGGED: set[tuple[int, int]] = set() - - -def _infer_and_preprocess( - routed_experts_np: np.ndarray, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - compress_dtype: bool = True, -) -> torch.Tensor: - """Infer num_moe_layers and topk from shape, then preprocess. - - We try common topk values (6, 8, 4, 2, 1) that divide the flat - dimension evenly. This is a fallback when model config is not available. - """ - flat_dim = routed_experts_np.shape[1] - - topk = None - for candidate_topk in [6, 8, 4, 2, 1]: - if flat_dim % candidate_topk == 0: - topk = candidate_topk - break - if topk is None: - topk = 1 - logger.warning( - "[R3] Cannot infer topk from flat_dim=%d; falling back to topk=1.", - flat_dim, - ) - num_moe_layers = flat_dim // topk - - _key = (num_moe_layers, topk) - if _key not in _INFER_LOGGED: - _INFER_LOGGED.add(_key) - logger.info( - "[R3] rlvr workflow inferred num_moe_layers=%d, topk=%d from " - "flat_dim=%d (this count includes any dense-FFN layers; the " - "engine-side router_replay_utils.set_router_replay_data handles " - "the dense-vs-MoE split). 
For deterministic behaviour, pass " - "num_moe_layers and topk explicitly from the model config.", - num_moe_layers, - topk, - flat_dim, - ) - - from areal.engine.router_replay_utils import preprocess_routed_experts_batch - - return preprocess_routed_experts_batch( - routed_experts_np, - input_ids, - attention_mask, - num_moe_layers=num_moe_layers, - topk=topk, - compress_dtype=compress_dtype, - ) - - def inject_routed_experts_into_result( result: dict[str, torch.Tensor], routed_experts: Optional[torch.Tensor], From 1b3a3a3e98a5137b3179f0504d3a78442257661a Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 28 Apr 2026 17:25:24 +0800 Subject: [PATCH 055/112] feat(MoE): auto get num_moe_layers and topk --- areal/api/cli_args.py | 3 +- areal/trainer/rl_trainer.py | 27 ++++ areal/workflow/rlvr.py | 6 + areal/workflow/rlvr_r3_patch.py | 117 ++++++++++++++++-- docs/en/cli_reference.md | 59 ++++----- docs/zh/cli_reference.md | 59 ++++----- ...moonlight_16b_a3b_gsm8k_grpo_megatron.yaml | 7 +- ...light_16b_a3b_gsm8k_grpo_megatron_h20.yaml | 7 +- 8 files changed, 212 insertions(+), 73 deletions(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index df7ed4891d..f479eb2eae 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -1865,7 +1865,8 @@ class InferenceEngineConfig: return_routed_experts: bool = field( default=False, metadata={ - "help": "Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models." + "help": "Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. " + "num_moe_layers and topk are automatically resolved from the model config." }, ) diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index d14ac03aef..351734ae12 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -135,6 +135,17 @@ def __init__( type(config.actor.megatron).__name__, ) + from areal.workflow.rlvr_r3_patch import resolve_r3_moe_config + + model_path = config.actor.path or config.tokenizer_path + num_moe_layers, topk = resolve_r3_moe_config(model_path) + logger.info( + "[R3] Resolved from model config at %s: num_moe_layers=%d, topk=%d.", + model_path, + num_moe_layers, + topk, + ) + sglang_cfg = getattr(config, "sglang", None) if sglang_cfg is not None and not getattr( sglang_cfg, "skip_tokenizer_init", True @@ -344,6 +355,22 @@ def train( steps_per_epoch = len(self.train_dataloader) max_steps = total_epochs * steps_per_epoch + if getattr(config.rollout, "return_routed_experts", False): + from areal.workflow.rlvr_r3_patch import resolve_r3_moe_config + + model_path = config.actor.path or config.tokenizer_path + num_moe_layers, topk = resolve_r3_moe_config(model_path) + r3_inject = { + "r3_num_moe_layers": num_moe_layers, + "r3_topk": topk, + } + if workflow_kwargs is None: + workflow_kwargs = {} + workflow_kwargs.update(r3_inject) + if eval_workflow_kwargs is None: + eval_workflow_kwargs = {} + eval_workflow_kwargs.update(r3_inject) + # Initialize proxy workers if not using RolloutWorkflow if workflow is None: openai_cfg = self.config.rollout.openai diff --git a/areal/workflow/rlvr.py b/areal/workflow/rlvr.py index e870a08f2f..536fe72f70 100644 --- a/areal/workflow/rlvr.py +++ b/areal/workflow/rlvr.py @@ -56,6 +56,8 @@ def __init__( | str = default_get_input_ids_fn, data_extract_prompt_fn: Callable[[dict[str, Any]], Any] | str = default_data_extract_prompt_fn, + r3_num_moe_layers: int | None = None, + r3_topk: int | None = None, ): self.reward_fn = reward_fn 
self.tokenizer = tokenizer @@ -66,6 +68,8 @@ def __init__( self.tokenizer = tokenizer self.gconfig = gconfig.new_with_stop_and_pad_token_ids(self.tokenizer) self.enable_thinking = enable_thinking + self.r3_num_moe_layers = r3_num_moe_layers + self.r3_topk = r3_topk if not isinstance(reward_fn, str): self.async_reward_fn = AsyncRewardWrapper(reward_fn) # Support string paths for get_input_ids_fn @@ -185,6 +189,8 @@ async def arun_episode( resp.routed_experts, res["input_ids"], res["attention_mask"], + num_moe_layers=self.r3_num_moe_layers, + topk=self.r3_topk, ) res = inject_routed_experts_into_result(res, routed_experts_tensor) diff --git a/areal/workflow/rlvr_r3_patch.py b/areal/workflow/rlvr_r3_patch.py index efceb58651..a93f24eac3 100644 --- a/areal/workflow/rlvr_r3_patch.py +++ b/areal/workflow/rlvr_r3_patch.py @@ -18,22 +18,111 @@ from __future__ import annotations import logging -from typing import Optional import numpy as np import torch logger = logging.getLogger(__name__) +_RESOLVED_CACHE: dict[str, tuple[int, int]] = {} + + +def resolve_r3_moe_config(model_path: str) -> tuple[int, int]: + """Resolve ``num_moe_layers`` and ``topk`` from the HuggingFace model config. + + Inspects the model's ``config.json`` for standard MoE fields: + + * ``num_experts_per_tok`` → topk + * ``num_hidden_layers`` and ``first_k_dense_replace`` → num_moe_layers + (MoE layers = total layers - first_k_dense_replace) + * ``moe_layer_freq`` (int or list) → used when available for precise counting + + Results are cached per ``model_path`` to avoid repeated disk reads. + + Args: + model_path: Path or repo ID of the HuggingFace model. + + Returns: + ``(num_moe_layers, topk)`` tuple. + + Raises: + ValueError: If the model config does not contain sufficient MoE fields. + """ + if model_path in _RESOLVED_CACHE: + return _RESOLVED_CACHE[model_path] + + from transformers import AutoConfig + + hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + + topk = getattr(hf_config, "num_experts_per_tok", None) + if topk is None: + raise ValueError( + "[R3] Cannot resolve topk from model config: " + f"'num_experts_per_tok' not found in {type(hf_config).__name__} " + f"at model_path={model_path}. This model may not be a MoE model, " + "or uses a non-standard config field name for router top-k." + ) + + num_hidden_layers = getattr(hf_config, "num_hidden_layers", None) + if num_hidden_layers is None: + raise ValueError( + "[R3] Cannot resolve num_moe_layers from model config: " + f"'num_hidden_layers' not found in {type(hf_config).__name__} " + f"at model_path={model_path}." + ) + + moe_layer_freq = getattr(hf_config, "moe_layer_freq", None) + first_k_dense_replace = getattr(hf_config, "first_k_dense_replace", None) + + if moe_layer_freq is not None: + if isinstance(moe_layer_freq, int): + num_moe_layers = sum( + 1 for i in range(num_hidden_layers) if i % moe_layer_freq == 0 + ) + elif isinstance(moe_layer_freq, (list, tuple)): + num_moe_layers = sum(1 for v in moe_layer_freq if v == 1) + else: + raise ValueError( + f"[R3] Unsupported moe_layer_freq type: {type(moe_layer_freq)}" + ) + elif first_k_dense_replace is not None: + num_moe_layers = num_hidden_layers - first_k_dense_replace + else: + num_moe_layers = num_hidden_layers + + if num_moe_layers <= 0: + raise ValueError( + "[R3] Resolved num_moe_layers=0 from model config. " + f"num_hidden_layers={num_hidden_layers}, " + f"moe_layer_freq={moe_layer_freq}, " + f"first_k_dense_replace={first_k_dense_replace}. " + "This model may not be a MoE model." 
+ ) + + _RESOLVED_CACHE[model_path] = (num_moe_layers, topk) + logger.info( + "[R3] Resolved from model config at %s: " + "num_moe_layers=%d, topk=%d " + "(num_hidden_layers=%d, moe_layer_freq=%s, first_k_dense_replace=%s).", + model_path, + num_moe_layers, + topk, + num_hidden_layers, + moe_layer_freq, + first_k_dense_replace, + ) + return num_moe_layers, topk + def extract_routed_experts( - routed_experts_np: Optional[np.ndarray], + routed_experts_np: np.ndarray | None, input_ids: torch.Tensor, attention_mask: torch.Tensor, - num_moe_layers: int, - topk: int, + num_moe_layers: int | None = None, + topk: int | None = None, compress_dtype: bool = True, -) -> Optional[torch.Tensor]: +) -> torch.Tensor | None: """Convert ``ModelResponse.routed_experts`` into a training tensor. Args: @@ -41,16 +130,28 @@ def extract_routed_experts( as returned by the SGLang inference backend, or ``None``. input_ids: ``(1, seq_len)`` token ids (prompt + response). attention_mask: ``(1, seq_len)`` with 1 for real tokens, 0 for padding. - num_moe_layers: Number of MoE layers in the model. **Required**. - topk: Router top-k. **Required**. + num_moe_layers: Number of MoE layers. Required -- resolved from model config. + topk: Router top-k. Required -- resolved from model config. compress_dtype: Downcast to ``uint8`` / ``int16`` when possible. Returns: ``torch.Tensor`` of shape ``(1, seq_len, num_moe_layers, topk)`` or ``None``. + + Raises: + ValueError: If ``num_moe_layers`` or ``topk`` is not provided. """ if routed_experts_np is None: return None + if num_moe_layers is None or topk is None: + raise ValueError( + "[R3] num_moe_layers and topk are required for routed_experts " + "preprocessing. These should be resolved from the model config " + "via resolve_r3_moe_config(model_path). Shape-based inference " + f"(decomposing flat_dim={routed_experts_np.shape[1]}) is " + "ambiguous and can silently corrupt training." + ) + try: from areal.engine.router_replay_utils import ( preprocess_routed_experts_batch, @@ -75,7 +176,7 @@ def extract_routed_experts( def inject_routed_experts_into_result( result: dict[str, torch.Tensor], - routed_experts: Optional[torch.Tensor], + routed_experts: torch.Tensor | None, ) -> dict[str, torch.Tensor]: """Add ``routed_experts`` to the result dict if available. diff --git a/docs/en/cli_reference.md b/docs/en/cli_reference.md index d51c58866c..0da8e29fdc 100644 --- a/docs/en/cli_reference.md +++ b/docs/en/cli_reference.md @@ -523,7 +523,7 @@ Configuration for inference servers, including offpolicyness control. | `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this InferenceEngine, either separation or colocation. Currently only used by the RolloutController. | | `use_lora` | boolean | `False` | Whether to use LoRA. Should be same as actors LORA option. | | `openai` | [`OpenAIProxyConfig`](section-open-ai-proxy) \| None | `None` | OpenAI proxy configuration (used when workflow is an agent workflow). | -| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. | +| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. num_moe_layers and topk are automatically resolved from the model config. | (section-sg-lang)= @@ -913,34 +913,35 @@ Configuration for Megatron-LM training framework. 
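As a concrete illustration of the resolution logic added in ``rlvr_r3_patch.py`` above, the following minimal sketch reproduces the branch order (``moe_layer_freq`` first, then ``first_k_dense_replace``, then all-MoE). The stand-in config and all of its field values are hypothetical, not taken from any real checkpoint::

    # Hypothetical stand-in for a HuggingFace MoE config; only the
    # branch order mirrors resolve_r3_moe_config above.
    class FakeMoEConfig:
        num_experts_per_tok = 6    # -> topk
        num_hidden_layers = 27
        first_k_dense_replace = 1  # layer 0 is a dense FFN
        moe_layer_freq = None      # unset -> fall through to the elif

    cfg = FakeMoEConfig()
    topk = cfg.num_experts_per_tok
    if cfg.moe_layer_freq is not None:
        num_moe_layers = sum(
            1 for i in range(cfg.num_hidden_layers) if i % cfg.moe_layer_freq == 0
        )
    elif cfg.first_k_dense_replace is not None:
        num_moe_layers = cfg.num_hidden_layers - cfg.first_k_dense_replace
    else:
        num_moe_layers = cfg.num_hidden_layers

    assert (num_moe_layers, topk) == (26, 6)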
Refer to Megatron-LM documentation for implementation details. -| Parameter | Type | Default | Description | -| ------------------------------------------ | -------------------------------------------------------------------- | ------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `wrap_with_ddp` | boolean | `True` | - | -| `use_torch_fsdp2` | boolean | `False` | - | -| `use_custom_fsdp` | boolean | `False` | - | -| `ddp` | [`DistributedDataParallelConfig`](section-distributed-data-parallel) | **Required** | - | -| `virtual_pipeline_parallel_size` | integer | `1` | Virtual pipeline parallel size for Megatron interleaved schedule. Set to >1 to enable VPP. Default is 1 (disabled). | -| `overlap_param_gather_with_optimizer_step` | boolean | `False` | - | -| `use_precision_aware_optimizer` | boolean | `False` | Enable precision-aware optimizer for Megatron. When using adam_bf16 optimizer type with Megatron Engine, this is automatically enabled with exp_avg_dtype=bfloat16 and exp_avg_sq_dtype=bfloat16. | -| `main_grads_dtype` | string | `"float32"` | - | -| `main_params_dtype` | string | `"float32"` | - | -| `exp_avg_dtype` | string | `"float32"` | - | -| `exp_avg_sq_dtype` | string | `"float32"` | - | -| `async_save` | boolean | `False` | - | -| `use_checkpoint_opt_param_scheduler` | boolean | `True` | - | -| `use_deterministic_algorithms` | boolean | `False` | - | -| `recompute_granularity` | string \| None | `"full"` | - | -| `recompute_method` | string \| None | `"uniform"` | - | -| `recompute_num_layers` | integer \| None | `1` | - | -| `distribute_saved_activations` | boolean \| None | `None` | - | -| `recompute_modules` | list of string \| None | `None` | - | -| `moe_router_dtype` | string \| None | `"fp32"` | - | -| `moe_shared_expert_overlap` | boolean | `False` | Enable overlapping between shared expert computations and dispatcher communications. Without this, the shared experts execute after the routed experts. | -| `moe_enable_deepep` | boolean | `False` | - | -| `moe_token_dispatcher_type` | string | `"alltoall"` | Type of token dispatcher. Options: 'allgather','alltoall' and 'flex'. | -| `moe_permute_fusion` | boolean | `False` | Fuse token rearrangement ops during token dispatching. | -| `fp8_config` | [`FP8EngineConfig`](section-fp8-engine) \| None | `None` | - | -| `bridge_type` | string | `"mbridge"` | Bridge backend for MegatronEngine. Choices: 'mbridge' or 'megatron-bridge'. **Choices:** `mbridge`, `megatron-bridge` | +| Parameter | Type | Default | Description | +| ------------------------------------------ | -------------------------------------------------------------------- | ------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `wrap_with_ddp` | boolean | `True` | - | +| `use_torch_fsdp2` | boolean | `False` | - | +| `use_custom_fsdp` | boolean | `False` | - | +| `ddp` | [`DistributedDataParallelConfig`](section-distributed-data-parallel) | **Required** | - | +| `virtual_pipeline_parallel_size` | integer | `1` | Virtual pipeline parallel size for Megatron interleaved schedule. Set to >1 to enable VPP. Default is 1 (disabled). 
| +| `overlap_param_gather_with_optimizer_step` | boolean | `False` | - | +| `use_precision_aware_optimizer` | boolean | `False` | Enable precision-aware optimizer for Megatron. When using adam_bf16 optimizer type with Megatron Engine, this is automatically enabled with exp_avg_dtype=bfloat16 and exp_avg_sq_dtype=bfloat16. | +| `main_grads_dtype` | string | `"float32"` | - | +| `main_params_dtype` | string | `"float32"` | - | +| `exp_avg_dtype` | string | `"float32"` | - | +| `exp_avg_sq_dtype` | string | `"float32"` | - | +| `async_save` | boolean | `False` | - | +| `use_checkpoint_opt_param_scheduler` | boolean | `True` | - | +| `use_deterministic_algorithms` | boolean | `False` | - | +| `recompute_granularity` | string \| None | `"full"` | - | +| `recompute_method` | string \| None | `"uniform"` | - | +| `recompute_num_layers` | integer \| None | `1` | - | +| `distribute_saved_activations` | boolean \| None | `None` | - | +| `recompute_modules` | list of string \| None | `None` | - | +| `moe_router_dtype` | string \| None | `"fp32"` | - | +| `moe_shared_expert_overlap` | boolean | `False` | Enable overlapping between shared expert computations and dispatcher communications. Without this, the shared experts execute after the routed experts. | +| `moe_enable_deepep` | boolean | `False` | - | +| `moe_token_dispatcher_type` | string | `"alltoall"` | Type of token dispatcher. Options: 'allgather','alltoall' and 'flex'. | +| `moe_permute_fusion` | boolean | `False` | Fuse token rearrangement ops during token dispatching. | +| `fp8_config` | [`FP8EngineConfig`](section-fp8-engine) \| None | `None` | - | +| `bridge_type` | string | `"mbridge"` | Bridge backend for MegatronEngine. Choices: 'mbridge' or 'megatron-bridge'. **Choices:** `mbridge`, `megatron-bridge` | +| `enable_router_replay` | boolean | `False` | Enable Router Replay (R3) for MoE models. When True, the training forward pass replays the expert routing decisions from the inference engine to reduce train-inference routing discrepancy. Automatically set by the trainer when rollout.return_routed_experts=True. | (section-open-ai-proxy)= diff --git a/docs/zh/cli_reference.md b/docs/zh/cli_reference.md index afd64db4af..31002973d0 100644 --- a/docs/zh/cli_reference.md +++ b/docs/zh/cli_reference.md @@ -521,7 +521,7 @@ Configuration for inference servers, including offpolicyness control. | `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this InferenceEngine, either separation or colocation. Currently only used by the RolloutController. | | `use_lora` | boolean | `False` | Whether to use LoRA. Should be same as actors LORA option. | | `openai` | [`OpenAIProxyConfig`](section-open-ai-proxy) \| None | `None` | OpenAI proxy configuration (used when workflow is an agent workflow). | -| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. | +| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. num_moe_layers and topk are automatically resolved from the model config. | (section-sg-lang)= @@ -911,34 +911,35 @@ Configuration for Megatron-LM training framework. Refer to Megatron-LM documentation for implementation details. 
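The ``rl_trainer.py`` hunk in this patch threads the resolved MoE geometry into the rollout workflow via ``workflow_kwargs``. Below is a minimal self-contained sketch of that plumbing; ``resolve_r3_moe_config`` is stubbed out (the real function reads the HuggingFace config), ``inject_r3_kwargs`` is a hypothetical helper name for the inline logic in ``train()``, and the returned values are made up for the example::

    from typing import Any

    def resolve_r3_moe_config(model_path: str) -> tuple[int, int]:
        # Stub: the real function reads config.json via AutoConfig.
        return 26, 6  # hypothetical (num_moe_layers, topk)

    def inject_r3_kwargs(
        workflow_kwargs: dict[str, Any] | None,
        model_path: str,
        return_routed_experts: bool,
    ) -> dict[str, Any] | None:
        # No-op unless rollout.return_routed_experts is set.
        if not return_routed_experts:
            return workflow_kwargs
        num_moe_layers, topk = resolve_r3_moe_config(model_path)
        workflow_kwargs = dict(workflow_kwargs or {})
        workflow_kwargs.update(
            {"r3_num_moe_layers": num_moe_layers, "r3_topk": topk}
        )
        return workflow_kwargs

    assert inject_r3_kwargs(None, "/path/to/model", True) == {
        "r3_num_moe_layers": 26,
        "r3_topk": 6,
    }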
-| Parameter | Type | Default | Description | -| ------------------------------------------ | -------------------------------------------------------------------- | ------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `wrap_with_ddp` | boolean | `True` | - | -| `use_torch_fsdp2` | boolean | `False` | - | -| `use_custom_fsdp` | boolean | `False` | - | -| `ddp` | [`DistributedDataParallelConfig`](section-distributed-data-parallel) | **Required** | - | -| `virtual_pipeline_parallel_size` | integer | `1` | Virtual pipeline parallel size for Megatron interleaved schedule. Set to >1 to enable VPP. Default is 1 (disabled). | -| `overlap_param_gather_with_optimizer_step` | boolean | `False` | - | -| `use_precision_aware_optimizer` | boolean | `False` | Enable precision-aware optimizer for Megatron. When using adam_bf16 optimizer type with Megatron Engine, this is automatically enabled with exp_avg_dtype=bfloat16 and exp_avg_sq_dtype=bfloat16. | -| `main_grads_dtype` | string | `"float32"` | - | -| `main_params_dtype` | string | `"float32"` | - | -| `exp_avg_dtype` | string | `"float32"` | - | -| `exp_avg_sq_dtype` | string | `"float32"` | - | -| `async_save` | boolean | `False` | - | -| `use_checkpoint_opt_param_scheduler` | boolean | `True` | - | -| `use_deterministic_algorithms` | boolean | `False` | - | -| `recompute_granularity` | string \| None | `"full"` | - | -| `recompute_method` | string \| None | `"uniform"` | - | -| `recompute_num_layers` | integer \| None | `1` | - | -| `distribute_saved_activations` | boolean \| None | `None` | - | -| `recompute_modules` | list of string \| None | `None` | - | -| `moe_router_dtype` | string \| None | `"fp32"` | - | -| `moe_shared_expert_overlap` | boolean | `False` | Enable overlapping between shared expert computations and dispatcher communications. Without this, the shared experts execute after the routed experts. | -| `moe_enable_deepep` | boolean | `False` | - | -| `moe_token_dispatcher_type` | string | `"alltoall"` | Type of token dispatcher. Options: 'allgather','alltoall' and 'flex'. | -| `moe_permute_fusion` | boolean | `False` | Fuse token rearrangement ops during token dispatching. | -| `fp8_config` | [`FP8EngineConfig`](section-fp8-engine) \| None | `None` | - | -| `bridge_type` | string | `"mbridge"` | Bridge backend for MegatronEngine. Choices: 'mbridge' or 'megatron-bridge'. **Choices:** `mbridge`, `megatron-bridge` | +| Parameter | Type | Default | Description | +| ------------------------------------------ | -------------------------------------------------------------------- | ------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `wrap_with_ddp` | boolean | `True` | - | +| `use_torch_fsdp2` | boolean | `False` | - | +| `use_custom_fsdp` | boolean | `False` | - | +| `ddp` | [`DistributedDataParallelConfig`](section-distributed-data-parallel) | **Required** | - | +| `virtual_pipeline_parallel_size` | integer | `1` | Virtual pipeline parallel size for Megatron interleaved schedule. Set to >1 to enable VPP. Default is 1 (disabled). 
| +| `overlap_param_gather_with_optimizer_step` | boolean | `False` | - | +| `use_precision_aware_optimizer` | boolean | `False` | Enable precision-aware optimizer for Megatron. When using adam_bf16 optimizer type with Megatron Engine, this is automatically enabled with exp_avg_dtype=bfloat16 and exp_avg_sq_dtype=bfloat16. | +| `main_grads_dtype` | string | `"float32"` | - | +| `main_params_dtype` | string | `"float32"` | - | +| `exp_avg_dtype` | string | `"float32"` | - | +| `exp_avg_sq_dtype` | string | `"float32"` | - | +| `async_save` | boolean | `False` | - | +| `use_checkpoint_opt_param_scheduler` | boolean | `True` | - | +| `use_deterministic_algorithms` | boolean | `False` | - | +| `recompute_granularity` | string \| None | `"full"` | - | +| `recompute_method` | string \| None | `"uniform"` | - | +| `recompute_num_layers` | integer \| None | `1` | - | +| `distribute_saved_activations` | boolean \| None | `None` | - | +| `recompute_modules` | list of string \| None | `None` | - | +| `moe_router_dtype` | string \| None | `"fp32"` | - | +| `moe_shared_expert_overlap` | boolean | `False` | Enable overlapping between shared expert computations and dispatcher communications. Without this, the shared experts execute after the routed experts. | +| `moe_enable_deepep` | boolean | `False` | - | +| `moe_token_dispatcher_type` | string | `"alltoall"` | Type of token dispatcher. Options: 'allgather','alltoall' and 'flex'. | +| `moe_permute_fusion` | boolean | `False` | Fuse token rearrangement ops during token dispatching. | +| `fp8_config` | [`FP8EngineConfig`](section-fp8-engine) \| None | `None` | - | +| `bridge_type` | string | `"mbridge"` | Bridge backend for MegatronEngine. Choices: 'mbridge' or 'megatron-bridge'. **Choices:** `mbridge`, `megatron-bridge` | +| `enable_router_replay` | boolean | `False` | Enable Router Replay (R3) for MoE models. When True, the training forward pass replays the expert routing decisions from the inference engine to reduce train-inference routing discrepancy. Automatically set by the trainer when rollout.return_routed_experts=True. | (section-open-ai-proxy)= diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml index 14ff04ce9c..65129ca71b 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml @@ -34,6 +34,7 @@ rollout: # This triggers the entire Router Replay pipeline: SGLang returns per-token # expert indices, which are then replayed during training to eliminate # train/inference routing mismatch in MoE models. + # num_moe_layers and topk are automatically resolved from the model config. 
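+  # (topk comes from num_experts_per_tok; num_moe_layers comes from
+  # num_hidden_layers combined with first_k_dense_replace or moe_layer_freq.)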
return_routed_experts: true gconfig: @@ -88,14 +89,14 @@ actor: recompute_method: uniform recompute_num_layers: 14 main_grads_dtype: bfloat16 # 梯度从 FP32 降为 BF16(节省 ~4 GiB) - # store_param_remainders: true + # store_param_remainders: true # optimizer_cpu_offload: true # optimizer_offload_fraction: 0.5 # main_params_dtype: bfloat16 # main_grads_dtype: bfloat16 # # adam_bf16 已自动设置以下两项,但显式声明更安全 # exp_avg_dtype: bfloat16 - # exp_avg_sq_dtype: bfloat16 + # exp_avg_sq_dtype: bfloat16 ddp: grad_reduce_in_fp32: false # ← 保持逐层重计算 scheduling_spec: @@ -200,4 +201,4 @@ perf_tracer: fileroot: ${cluster.fileroot} enabled: false session_tracer: - enabled: false \ No newline at end of file + enabled: false diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20.yaml index 370b1bce88..f71331d082 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20.yaml @@ -34,6 +34,7 @@ rollout: # This triggers the entire Router Replay pipeline: SGLang returns per-token # expert indices, which are then replayed during training to eliminate # train/inference routing mismatch in MoE models. + # num_moe_layers and topk are automatically resolved from the model config. return_routed_experts: true gconfig: @@ -88,14 +89,14 @@ actor: recompute_method: uniform recompute_num_layers: 14 main_grads_dtype: bfloat16 # 梯度从 FP32 降为 BF16(节省 ~4 GiB) - # store_param_remainders: true + # store_param_remainders: true # optimizer_cpu_offload: true # optimizer_offload_fraction: 0.5 # main_params_dtype: bfloat16 # main_grads_dtype: bfloat16 # # adam_bf16 已自动设置以下两项,但显式声明更安全 # exp_avg_dtype: bfloat16 - # exp_avg_sq_dtype: bfloat16 + # exp_avg_sq_dtype: bfloat16 ddp: grad_reduce_in_fp32: false # ← 保持逐层重计算 scheduling_spec: @@ -200,4 +201,4 @@ perf_tracer: fileroot: ${cluster.fileroot} enabled: false session_tracer: - enabled: false \ No newline at end of file + enabled: false From e156cf328dd1099984f844faaf19c6fff5ec1ebd Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 28 Apr 2026 19:00:02 +0800 Subject: [PATCH 056/112] refactor(megatron_engine): remove for run moe --- areal/engine/megatron_engine.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index bd0b8b38bd..26bf669b58 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1144,31 +1144,6 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None: mcore_opt_config.exp_avg_sq_dtype = getattr( torch, self.mcore_config.exp_avg_sq_dtype ) - # for run moe - mcore_opt_config.use_precision_aware_optimizer_no_fp8_or_ds_fp8 = ( - mcore_opt_config.use_precision_aware_optimizer - and ( - mcore_opt_config.main_params_dtype != torch.float32 - or (mcore_opt_config.fp8_recipe is None or mcore_opt_config.fp8_recipe == "delayed") - or mcore_opt_config.optimizer_cpu_offload - ) - ) - mcore_opt_config.store_param_remainders = True - import logging as _logging - _opt_logger = _logging.getLogger('AReaL.OptDiag') - _opt_logger.warning( - f'[OptDiag] Megatron OptimizerConfig: ' - f'use_precision_aware_optimizer={mcore_opt_config.use_precision_aware_optimizer}, ' - f'use_precision_aware_optimizer_no_fp8_or_ds_fp8=' - f'{getattr(mcore_opt_config, "use_precision_aware_optimizer_no_fp8_or_ds_fp8", "N/A")}, ' - f'store_param_remainders={mcore_opt_config.store_param_remainders}, ' - 
f'main_params_dtype={mcore_opt_config.main_params_dtype}, ' - f'main_grads_dtype={mcore_opt_config.main_grads_dtype}, ' - f'exp_avg_dtype={mcore_opt_config.exp_avg_dtype}, ' - f'exp_avg_sq_dtype={mcore_opt_config.exp_avg_sq_dtype}, ' - f'use_distributed_optimizer={mcore_opt_config.use_distributed_optimizer}, ' - f'bf16={mcore_opt_config.bf16}' - ) self.optimizer = get_megatron_optimizer(mcore_opt_config, self.model) warmup_steps_proportion = self.optimizer_config.warmup_steps_proportion From 8e9beefdecd40140d7f5be4bf7aa1c8ccaca32ae Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 28 Apr 2026 19:28:01 +0800 Subject: [PATCH 057/112] fix(engine): fix log --- areal/engine/megatron_engine.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index bffd133e3e..1ee56c5ace 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -123,6 +123,15 @@ from areal.api.cli_args import DPOEngineConfig, PPOActorConfig, PPOCriticConfig def _patch_gpt_model_postprocess_for_inference(model_list: _MegatronModelList) -> None: + """Patch ``GPTModel._postprocess`` to skip MTP when ``labels=None``. + + In the ``forward_only`` path (e.g. ``compute_logp``), no labels are + passed to the model. However, Megatron-Core's ``_postprocess`` still + enters the MTP branch when ``config.mtp_num_layers`` is set, invoking + ``process_mtp_loss`` with ``labels=None`` which either crashes or + corrupts the hidden states. Temporarily clearing ``mtp_num_layers`` + forces ``_postprocess`` to skip MTP and return logits directly. + """ from megatron.core.models.gpt.gpt_model import GPTModel if getattr(GPTModel, "_areal_postprocess_patched", False): @@ -1264,6 +1273,7 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None: mcore_opt_config.exp_avg_sq_dtype = getattr( torch, self.mcore_config.exp_avg_sq_dtype ) + self.optimizer = get_megatron_optimizer(mcore_opt_config, self.model) warmup_steps_proportion = self.optimizer_config.warmup_steps_proportion From ab5c323578a10bebd7b6f33815e03c2aa847a9be Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 28 Apr 2026 19:34:50 +0800 Subject: [PATCH 058/112] feat(controller): fix log --- areal/infra/controller/rollout_controller.py | 10 ++++++++-- areal/infra/controller/train_controller.py | 12 +++++++++--- areal/infra/launcher/sglang_launch_server.py | 13 ------------- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index 46f6c365de..21fd52d5f6 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -240,27 +240,33 @@ async def _async_initialize( **kwargs, ): # Create workers via scheduler + logger.info("Creating workers via scheduler...") worker_ids = self.scheduler.create_workers(job=job) + logger.info(f"Workers created: {worker_ids}") # Wait for workers to be ready + logger.info("Waiting for workers to be ready...") self.workers = self.scheduler.get_workers(role=job.role) + logger.info(f"Workers ready: {[w.id for w in self.workers]}") # Get engine class path for dynamic import on workers engine_class = self.inf_engine - engine_path = f"{engine_class.__module__}.{engine_class.__name__}" # Create and initialize engines on workers + logger.info("Creating engines...") tasks = [ self.scheduler.create_engine( worker_id=worker.id, - engine=engine_path, + engine=f"{engine_class.__module__}.{engine_class.__name__}", 
engine_name=self._engine_name(rank), config=self.config, ) for rank, worker in enumerate(self.workers) ] await asyncio.gather(*tasks) + logger.info("Engine created on all workers!") + logger.info("Calling engine initialization...") if server_infos is not None: # Connecting to existing local servers for evaluation self.server_infos = server_infos diff --git a/areal/infra/controller/train_controller.py b/areal/infra/controller/train_controller.py index 71b3d32a81..0eeee91e5f 100644 --- a/areal/infra/controller/train_controller.py +++ b/areal/infra/controller/train_controller.py @@ -284,10 +284,14 @@ def initialize( ) # Create workers via scheduler + logger.info("Creating workers via scheduler...") worker_ids = self.scheduler.create_workers(job=job) + logger.info(f"Workers created: {worker_ids}") # Wait for workers to be ready + logger.info("Waiting for workers to be ready...") self.workers = self.scheduler.get_workers(role=job.role) + logger.info(f"Workers ready: {[w.id for w in self.workers]}") # Determine distributed training master address and port from rank 0 worker # These are used for PyTorch distributed initialization across workers @@ -308,10 +312,9 @@ def initialize( engine_class = self.train_engine # Create and initialize engines on workers - engine_path = f"{engine_class.__module__}.{engine_class.__name__}" run_async_task( self._async_create_engines, - engine_path, + f"{engine_class.__module__}.{engine_class.__name__}", ) run_async_task(self._async_initialize_engines, ft_spec, **kwargs) @@ -328,6 +331,7 @@ def _engine_name(self, rank: int) -> str: async def _async_create_engines(self, engine: str): """Create engine instances on all workers. Sets distributed env vars before creation.""" + logger.info("Creating engines on workers...") async def _setup_worker(worker: Worker, rank: int): env = { @@ -335,7 +339,7 @@ async def _setup_worker(worker: Worker, rank: int): "WORLD_SIZE": str(len(self.workers)), "MASTER_ADDR": str(self._master_addr), "MASTER_PORT": str(self._master_port), - "LOCAL_RANK": "0", + "LOCAL_RANK": "0", # NOTE: local rank is always 0 while each process use only one GPU } logger.debug( f"Setting env for worker " @@ -353,9 +357,11 @@ async def _setup_worker(worker: Worker, rank: int): _setup_worker(worker, rank) for rank, worker in enumerate(self.workers) ] await asyncio.gather(*tasks) + logger.info("Engines created on all workers!") async def _async_initialize_engines(self, ft_spec: FinetuneSpec, **kwargs): """Initialize engines: create process groups, then load models and setup optimizers.""" + logger.info("Calling engine initialization...") # Phase 1: Create process groups for distributed training tasks = [ self.scheduler.async_call_engine( diff --git a/areal/infra/launcher/sglang_launch_server.py b/areal/infra/launcher/sglang_launch_server.py index bddc648208..746a88befd 100644 --- a/areal/infra/launcher/sglang_launch_server.py +++ b/areal/infra/launcher/sglang_launch_server.py @@ -1,16 +1,3 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. """Thin wrapper around ``sglang.launch_server`` that installs AReaL's R3 monkey patches before the upstream server boots. From fb6e48e516e89d6dcfd89df2cfa64131b8065d46 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 28 Apr 2026 19:35:00 +0800 Subject: [PATCH 059/112] docs(infra): fix --- areal/infra/launcher/sglang_r3_patch.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/areal/infra/launcher/sglang_r3_patch.py b/areal/infra/launcher/sglang_r3_patch.py index 60821e143d..a6e1ff13d4 100644 --- a/areal/infra/launcher/sglang_r3_patch.py +++ b/areal/infra/launcher/sglang_r3_patch.py @@ -1,16 +1,3 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. """SGLang server-side monkey patches required for AReaL's R3 (Router Replay). Background From 2d0a30643d8ca8dc715df2ec8b36471c74686c60 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 28 Apr 2026 19:44:20 +0800 Subject: [PATCH 060/112] refactor(r3): fix split sample --- areal/engine/megatron_engine_r3_patch.py | 52 +++++++++++++----------- areal/trainer/ppo/actor_r3_patch.py | 21 +++------- 2 files changed, 34 insertions(+), 39 deletions(-) diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py index 9e8733070b..1a322863ba 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -210,21 +210,13 @@ def _split_routed_experts_for_mbs( n_mbs = len(mb_list) if forward_indices is None: - # No reordering -- just split evenly - bs = routed_experts.shape[0] - chunk = bs // n_mbs - result = [routed_experts[i * chunk : (i + 1) * chunk] for i in range(n_mbs)] - logger.debug( - "[R3] _split_routed_experts_for_mbs: no forward_indices, " - "split %d samples evenly into %d chunks of %d.", - bs, n_mbs, chunk, - ) - return result - - # Reorder by forward_indices (sample-level reordering) - reordered = routed_experts[forward_indices] + reordered = routed_experts + else: + reordered = routed_experts[forward_indices] - # Determine number of samples per micro-batch from mbs dicts. + # Always derive per-micro-batch sample counts from ``mb_list.mbs`` rather + # than assuming an even ``bs // n_mbs`` split -- the latter silently drops + # samples when ``bs`` is not divisible by ``n_mbs``. 
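+    # Example: bs=10 split across n_mbs=3 micro-batches of sizes
+    # [4, 3, 3] (taken from mb_list.mbs) consumes reordered[0:4],
+    # [4:7] and [7:10]; a naive bs // n_mbs == 3 chunking would
+    # silently drop the tenth sample.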
result = [] offset = 0 for i, mb_dict in enumerate(mb_list.mbs): @@ -234,11 +226,11 @@ def _split_routed_experts_for_mbs( logger.debug( "[R3] _split_routed_experts_for_mbs: split %d samples into %d mbs " - "with sizes %s (forward_indices len=%d).", + "with sizes %s (forward_indices=%s).", routed_experts.shape[0], n_mbs, [r.shape[0] for r in result], - len(forward_indices), + "None" if forward_indices is None else f"len={len(forward_indices)}", ) return result @@ -492,12 +484,25 @@ def __next__(self): ) return mb_item - original_class_iter = mb_list.__class__.__iter__ - - def _r3_iter(mb_list_self): - return _R3MicroBatchIterator(original_class_iter(mb_list_self)) + # Use a per-instance class swap instead of rebinding the shared + # ``mb_list.__class__.__iter__``. The latter is a global side effect + # that also affects any other ``MicroBatchList`` objects alive in the + # process (e.g. a concurrent engine). Here we create a dynamic + # subclass whose ``__iter__`` injects the R3 setup logic, and assign + # it only to *this* instance via ``__class__``. The original class + # remains untouched. Restoration in the ``finally`` block merely + # flips ``__class__`` back. + _r3_original_mb_list_class = mb_list.__class__ + + class _R3MicroBatchListProxy(_r3_original_mb_list_class): + """Per-instance proxy that wraps __iter__ with R3 setup logic.""" + + def __iter__(self_inner): + return _R3MicroBatchIterator( + _r3_original_mb_list_class.__iter__(self_inner) + ) - mb_list.__class__.__iter__ = _r3_iter + mb_list.__class__ = _R3MicroBatchListProxy # ------------------------------------------------------------------ # 4. Register a forward hook for REPLAY_FORWARD -> REPLAY_BACKWARD toggle. @@ -533,8 +538,9 @@ def _r3_post_forward_hook(module, input, output): # Remove forward hooks for handle in hook_handles: handle.remove() - # Restore original class __iter__ and clean up R3 state - mb_list.__class__.__iter__ = original_class_iter + # Restore the original class on this instance (undo the per-instance + # class swap done above). The original class was never modified. + mb_list.__class__ = _r3_original_mb_list_class # Harvest agreement stats BEFORE clearing replay state. 
_agreement = RouterReplay.harvest_agreement_stats() diff --git a/areal/trainer/ppo/actor_r3_patch.py b/areal/trainer/ppo/actor_r3_patch.py index a4025cc8f6..ca991f75e0 100644 --- a/areal/trainer/ppo/actor_r3_patch.py +++ b/areal/trainer/ppo/actor_r3_patch.py @@ -176,21 +176,9 @@ def split_routed_experts_for_minibatches( n_mbs = len(mb_list) if forward_indices is None: - # No reordering -- just split evenly - bs = routed_experts.shape[0] - chunk = bs // n_mbs - result = [routed_experts[i * chunk : (i + 1) * chunk] for i in range(n_mbs)] - logger.debug( - "[R3] split_routed_experts_for_minibatches: no forward_indices, " - "split %d samples evenly into %d chunks of %d.", - bs, - n_mbs, - chunk, - ) - return result - - # Reorder by forward_indices (sample-level reordering) - reordered = routed_experts[forward_indices] + reordered = routed_experts + else: + reordered = routed_experts[forward_indices] # Determine number of samples per mini-batch from mbs dicts result = [] @@ -204,10 +192,11 @@ def split_routed_experts_for_minibatches( logger.debug( "[R3] split_routed_experts_for_minibatches: split %d samples into " - "%d mini-batches with sizes %s.", + "%d mini-batches with sizes %s (forward_indices=%s).", routed_experts.shape[0], n_mbs, [r.shape[0] for r in result], + "None" if forward_indices is None else f"len={len(forward_indices)}", ) return result From 647c86505177698e8a2a404d8e93a0a493d7f973 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 28 Apr 2026 21:26:51 +0800 Subject: [PATCH 061/112] feat(sglang): R3 monkey patch --- areal/engine/megatron_engine.py | 2 +- .../inference_service/sglang/launch_server.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 1ee56c5ace..12e23d713d 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1273,7 +1273,7 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None: mcore_opt_config.exp_avg_sq_dtype = getattr( torch, self.mcore_config.exp_avg_sq_dtype ) - + self.optimizer = get_megatron_optimizer(mcore_opt_config, self.model) warmup_steps_proportion = self.optimizer_config.warmup_steps_proportion diff --git a/areal/experimental/inference_service/sglang/launch_server.py b/areal/experimental/inference_service/sglang/launch_server.py index 2fdbbc2a47..9d3fe6c130 100644 --- a/areal/experimental/inference_service/sglang/launch_server.py +++ b/areal/experimental/inference_service/sglang/launch_server.py @@ -46,6 +46,22 @@ def areal_launch_server(server_args) -> None: areal_run_scheduler_process, create_result_ipc, ) + + # Install R3 server-side monkey patches (no-op when R3 is + # not used). Must run before ``_launch_subprocesses`` so that the + # ``TokenizerManager._handle_batch_output`` override is visible to + # every subprocess imported from the upstream SGLang entrypoint. + try: + from areal.infra.launcher.sglang_r3_patch import apply_sglang_r3_patch + + apply_sglang_r3_patch() + except Exception: # pragma: no cover - defensive + import logging + + logging.getLogger(__name__).exception( + "[R3] Failed to install AReaL SGLang patches; bridge-mode " + "server will start without R3 wire-format fixes." 
+ ) # ---- END AREAL ---- try: From 0926e0fb56dec6a25fca6561778fe3f2725fe56c Mon Sep 17 00:00:00 2001 From: bingyechen Date: Tue, 28 Apr 2026 22:12:12 +0800 Subject: [PATCH 062/112] refactor(launcher): remove --- areal/infra/launcher/sglang_launch_server.py | 69 -------------------- 1 file changed, 69 deletions(-) delete mode 100644 areal/infra/launcher/sglang_launch_server.py diff --git a/areal/infra/launcher/sglang_launch_server.py b/areal/infra/launcher/sglang_launch_server.py deleted file mode 100644 index 746a88befd..0000000000 --- a/areal/infra/launcher/sglang_launch_server.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Thin wrapper around ``sglang.launch_server`` that installs AReaL's R3 -monkey patches before the upstream server boots. - -Usage (transparent replacement for ``python3 -m sglang.launch_server ...``):: - - python3 -m areal.infra.launcher.sglang_launch_server --model-path ... - -Only the R3 patch is installed; all other CLI behaviour is delegated to -``sglang.launch_server`` unchanged so this entrypoint stays drop-in safe -when R3 is not used. -""" - -from __future__ import annotations - -import logging -import os -import sys - - -def _install_areal_patches() -> None: - """Install AReaL monkey patches that must be active in the SGLang - server process (scheduler/tokenizer manager/HTTP server).""" - try: - from areal.infra.launcher.sglang_r3_patch import apply_sglang_r3_patch - - apply_sglang_r3_patch() - except Exception: # pragma: no cover - defensive - logging.getLogger(__name__).exception( - "[R3] Failed to install AReaL SGLang patches; server will " - "start without R3 wire-format fixes." - ) - - -def main() -> None: - _install_areal_patches() - - # Delegate to upstream launcher. We keep argv intact (including - # argv[0] mangling done by ``python3 -m``) because - # ``sglang.launch_server`` uses ``sys.argv[1:]`` directly. - from sglang.srt.server_args import prepare_server_args - from sglang.srt.utils import kill_process_tree - from sglang.srt.utils.common import suppress_noisy_warnings - - suppress_noisy_warnings() - - server_args = prepare_server_args(sys.argv[1:]) - - # Same dispatch as ``sglang/launch_server.py``. 
- try: - if getattr(server_args, "grpc_mode", False): - import asyncio - - from sglang.srt.entrypoints.grpc_server import serve_grpc - - asyncio.run(serve_grpc(server_args)) - elif getattr(server_args, "encoder_only", False): - from sglang.srt.disaggregation.encode_server import launch_server - - launch_server(server_args) - else: - from sglang.srt.entrypoints.http_server import launch_server - - launch_server(server_args) - finally: - kill_process_tree(os.getpid(), include_parent=False) - - -if __name__ == "__main__": - main() From 1b01ba33dd9b79714789bdbc0bdff2c6880f2069 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 29 Apr 2026 00:35:13 +0800 Subject: [PATCH 063/112] feat(actor): fix for imp_weight --- .../moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml index 7dc2a0c78f..e69433a4df 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml @@ -67,7 +67,11 @@ actor: ppo_n_minibatches: 1 # ← 从 1 提高至 4(分批梯度累积) recompute_logprob: true use_decoupled_loss: true - behave_imp_weight_cap: 5.0 + rejection_sampling: + level: token + action: mask + metric: ratio + upper: 5.0 reward_norm: mean_level: group std_level: group From 7a3291510926d382ecd199164dcebbba77b3a4f1 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 29 Apr 2026 10:46:14 +0800 Subject: [PATCH 064/112] feat(engine): add use_precision_aware_optimizer_no_fp8_or_ds_fp8(for run moe) --- areal/engine/megatron_engine.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 12e23d713d..687ca3c2d7 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1273,7 +1273,31 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None: mcore_opt_config.exp_avg_sq_dtype = getattr( torch, self.mcore_config.exp_avg_sq_dtype ) - + # for run moe + mcore_opt_config.use_precision_aware_optimizer_no_fp8_or_ds_fp8 = ( + mcore_opt_config.use_precision_aware_optimizer + and ( + mcore_opt_config.main_params_dtype != torch.float32 + or (mcore_opt_config.fp8_recipe is None or mcore_opt_config.fp8_recipe == "delayed") + or mcore_opt_config.optimizer_cpu_offload + ) + ) + mcore_opt_config.store_param_remainders = True + import logging as _logging + _opt_logger = _logging.getLogger('AReaL.OptDiag') + _opt_logger.warning( + f'[OptDiag] Megatron OptimizerConfig: ' + f'use_precision_aware_optimizer={mcore_opt_config.use_precision_aware_optimizer}, ' + f'use_precision_aware_optimizer_no_fp8_or_ds_fp8=' + f'{getattr(mcore_opt_config, "use_precision_aware_optimizer_no_fp8_or_ds_fp8", "N/A")}, ' + f'store_param_remainders={mcore_opt_config.store_param_remainders}, ' + f'main_params_dtype={mcore_opt_config.main_params_dtype}, ' + f'main_grads_dtype={mcore_opt_config.main_grads_dtype}, ' + f'exp_avg_dtype={mcore_opt_config.exp_avg_dtype}, ' + f'exp_avg_sq_dtype={mcore_opt_config.exp_avg_sq_dtype}, ' + f'use_distributed_optimizer={mcore_opt_config.use_distributed_optimizer}, ' + f'bf16={mcore_opt_config.bf16}' + ) self.optimizer = get_megatron_optimizer(mcore_opt_config, self.model) warmup_steps_proportion = self.optimizer_config.warmup_steps_proportion From d70496b6bf453dd0bd0e743862578272b49006d9 Mon Sep 17 00:00:00 2001 From: 
bingyechen Date: Wed, 29 Apr 2026 11:36:09 +0800 Subject: [PATCH 065/112] refactor: fix log --- areal/engine/megatron_engine.py | 37 ++++++++++++---------- areal/infra/controller/train_controller.py | 4 --- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 687ca3c2d7..5216ab2764 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1273,7 +1273,27 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None: mcore_opt_config.exp_avg_sq_dtype = getattr( torch, self.mcore_config.exp_avg_sq_dtype ) - # for run moe + # Precision-aware optimizer for MoE models + # ----------------------------------------- + # When ``use_precision_aware_optimizer=True``, Megatron-Core stores + # optimizer states (params, grads, exp_avg, exp_avg_sq) in + # user-specified dtypes and compensates for low-precision rounding + # via fp32 remainders. Two additional flags are set here: + # + # ``use_precision_aware_optimizer_no_fp8_or_ds_fp8``: + # A derived flag that is True when the precision-aware optimizer + # is enabled AND at least one of the following holds: + # 1. ``main_params_dtype != float32`` -- low-precision master + # params (e.g. bf16) need rounding-error compensation. + # 2. ``fp8_recipe is None or "delayed"`` -- no real-time FP8 + # scaling, so the optimizer must handle quantisation error + # itself (delayed-scaling FP8 updates the scale factor with + # a lag, conflicting with the optimizer's immediate residual + # correction). + # 3. ``optimizer_cpu_offload`` -- CPU offload introduces extra + # precision conversions that require special handling. + # When True, the optimizer applies its full residual-compensation + # logic during the parameter update step. 
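+        #
+        # ``store_param_remainders``:
+        #     Instead of a full fp32 master copy, keep the bf16 params
+        #     plus the residual bits needed to reconstruct fp32, roughly
+        #     halving master-parameter memory.
+        #
+        # Example: use_precision_aware_optimizer=True with
+        # main_params_dtype=torch.bfloat16 and no FP8 recipe satisfies
+        # conditions 1 and 2 above, so the derived flag below is True.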
mcore_opt_config.use_precision_aware_optimizer_no_fp8_or_ds_fp8 = ( mcore_opt_config.use_precision_aware_optimizer and ( @@ -1283,21 +1303,6 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None: ) ) mcore_opt_config.store_param_remainders = True - import logging as _logging - _opt_logger = _logging.getLogger('AReaL.OptDiag') - _opt_logger.warning( - f'[OptDiag] Megatron OptimizerConfig: ' - f'use_precision_aware_optimizer={mcore_opt_config.use_precision_aware_optimizer}, ' - f'use_precision_aware_optimizer_no_fp8_or_ds_fp8=' - f'{getattr(mcore_opt_config, "use_precision_aware_optimizer_no_fp8_or_ds_fp8", "N/A")}, ' - f'store_param_remainders={mcore_opt_config.store_param_remainders}, ' - f'main_params_dtype={mcore_opt_config.main_params_dtype}, ' - f'main_grads_dtype={mcore_opt_config.main_grads_dtype}, ' - f'exp_avg_dtype={mcore_opt_config.exp_avg_dtype}, ' - f'exp_avg_sq_dtype={mcore_opt_config.exp_avg_sq_dtype}, ' - f'use_distributed_optimizer={mcore_opt_config.use_distributed_optimizer}, ' - f'bf16={mcore_opt_config.bf16}' - ) self.optimizer = get_megatron_optimizer(mcore_opt_config, self.model) warmup_steps_proportion = self.optimizer_config.warmup_steps_proportion diff --git a/areal/infra/controller/train_controller.py b/areal/infra/controller/train_controller.py index 0eeee91e5f..fe9f6c68b7 100644 --- a/areal/infra/controller/train_controller.py +++ b/areal/infra/controller/train_controller.py @@ -341,10 +341,6 @@ async def _setup_worker(worker: Worker, rank: int): "MASTER_PORT": str(self._master_port), "LOCAL_RANK": "0", # NOTE: local rank is always 0 while each process use only one GPU } - logger.debug( - f"Setting env for worker " - f"'{worker.id}' (rank={rank}): {env}" - ) await self.scheduler.set_worker_env(worker.id, env) await self.scheduler.create_engine( worker_id=worker.id, From 3fa32833d8734342ca9812add5f49f4227397a4e Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 29 Apr 2026 11:45:55 +0800 Subject: [PATCH 066/112] fix(guard): remove useless log --- areal/infra/launcher/sglang_r3_patch.py | 37 ------------------------- areal/infra/rpc/guard/app.py | 3 +- 2 files changed, 2 insertions(+), 38 deletions(-) diff --git a/areal/infra/launcher/sglang_r3_patch.py b/areal/infra/launcher/sglang_r3_patch.py index a6e1ff13d4..e656d779f3 100644 --- a/areal/infra/launcher/sglang_r3_patch.py +++ b/areal/infra/launcher/sglang_r3_patch.py @@ -37,31 +37,6 @@ The patch is idempotent. When R3 is disabled the patch is a no-op at runtime because the routed-experts attribute stays ``None``. - -Compatibility note ------------------- -Different SGLang branches use different attribute names on the -``BatchTokenIDOutput`` / ``BatchStrOutput`` dataclasses: - -* Current SGLang main / v0.5.9 : ``recv_obj.routed_experts`` -* Original R3 commit ``bed301a5`` (the base verl's R3 example was built - against, see - ``https://github.com/sgl-project/sglang/commit/bed301a5acaa9577c9aa706468bdf242f6a43051``) - : ``recv_obj.output_routed_experts`` - -The patch handles both attribute names so AReaL runs unmodified against -either SGLang build. Consequently the TokenizerManager attaches the -resulting ``meta_info["routed_experts"]`` entry to every response regardless -of which internal field the scheduler populated. - -Reference ---------- -* SGLang encodes ``routed_experts`` via ``pybase64.b64encode`` at - ``python/sglang/srt/managers/detokenizer_manager.py::_extract_routed_experts`` - for the ``skip_tokenizer_init=False`` path. 
-* verl applies the symmetric decoder at - ``verl/workers/rollout/sglang_rollout/async_sglang_server.py`` via - ``extract_routed_experts_from_meta_info`` (which assumes a base64 string). """ from __future__ import annotations @@ -71,12 +46,6 @@ logger = logging.getLogger(__name__) _PATCH_APPLIED = False - -# Attribute names used across the SGLang versions we must support. -# * ``routed_experts`` : current SGLang main / v0.5.9. -# * ``output_routed_experts`` : original R3 commit ``bed301a5`` that the -# verl R3 example and the AReaL R3 docker -# were built on. _ROUTED_EXPERTS_ATTRS = ("routed_experts", "output_routed_experts") @@ -149,12 +118,6 @@ def _handle_batch_output_r3(self, recv_obj): # type: ignore[no-redef] # serialisation sees a plain string (which ``jsonable_encoder`` # passes through untouched) instead of a ``torch.Tensor`` (which # ``jsonable_encoder`` silently flattens to ``{}``). - # - # We must handle BOTH attribute names because different SGLang - # builds populate different fields: - # * ``routed_experts`` - current SGLang main / v0.5.9 - # * ``output_routed_experts`` - original R3 commit ``bed301a5`` - # (verl's R3 base commit) try: for attr_name in _ROUTED_EXPERTS_ATTRS: re_list = getattr(recv_obj, attr_name, None) diff --git a/areal/infra/rpc/guard/app.py b/areal/infra/rpc/guard/app.py index 958a283567..69a0a80f37 100644 --- a/areal/infra/rpc/guard/app.py +++ b/areal/infra/rpc/guard/app.py @@ -463,6 +463,8 @@ def configure(): return jsonify({"error": "Invalid JSON in request body"}), 400 if not s._configure_hooks: + # No hooks registered — no-op (guard-only mode) + logger.debug("Received /configure request (no-op)") return jsonify({"status": "ok"}) # Dispatch to all registered configure hooks @@ -608,7 +610,6 @@ def run_server( etcd3_addr=state.etcd3_addr or "localhost:2379", ) ) - logger.info("name_resolve reconfigured successfully") worker_id = f"{state.role}/{state.worker_index}" key = names.worker_discovery( From 805bda8439fc863ba2f753f0f4f7cc975b14f664 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 29 Apr 2026 11:55:20 +0800 Subject: [PATCH 067/112] refactor: restore local.py --- areal/infra/scheduler/local.py | 47 +++++++++++++--------------------- 1 file changed, 18 insertions(+), 29 deletions(-) diff --git a/areal/infra/scheduler/local.py b/areal/infra/scheduler/local.py index faa756e09f..eb5775db38 100644 --- a/areal/infra/scheduler/local.py +++ b/areal/infra/scheduler/local.py @@ -872,10 +872,9 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: self._read_log_tail(str(log_file)), ) - worker_ip = gethostip() worker = Worker( id=worker_id, - ip=worker_ip, + ip=gethostip(), worker_ports=[str(p) for p in ports], engine_ports=[], ) @@ -964,11 +963,7 @@ def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]: ready_workers = set() while len(ready_workers) < len(workers): - elapsed = time.time() - start_time - if elapsed > timeout: - not_ready = [ - w.worker.id for w in workers if w.worker.id not in ready_workers - ] + if time.time() - start_time > timeout: raise WorkerTimeoutError( role, timeout, @@ -992,6 +987,7 @@ def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]: if self._is_worker_ready(worker_info): ready_workers.add(worker_info.worker.id) + logger.debug(f"Worker {worker_info.worker.id} is ready") if len(ready_workers) < len(workers): time.sleep(self.health_check_interval) @@ -1005,33 +1001,15 @@ def _is_worker_ready(self, worker_info: WorkerInfo) -> bool: try: response = 
requests.get(url, timeout=2.0) - ready = response.status_code == 200 - if not ready: - logger.warning( - f"Worker health check failed: {url} -> {response.status_code}" - ) - return ready - except Exception as e: - logger.warning(f"Worker health check error: {url} -> {e}") + return response.status_code == 200 + except Exception: return False def _configure_worker(self, worker_info: WorkerInfo, worker_rank: int): - worker_id = worker_info.worker.id - wait_start = time.time() - last_log_time = wait_start while not self._is_worker_ready(worker_info): time.sleep(0.1) - now = time.time() - if now - last_log_time >= 5.0: - elapsed = now - wait_start - logger.warning( - f"Still waiting for worker " - f"'{worker_id}' after {elapsed:.0f}s " - f"(ip={worker_info.worker.ip}, " - f"ports={worker_info.worker.worker_ports})" - ) - last_log_time = now + worker_id = worker_info.worker.id port = int(worker_info.worker.worker_ports[0]) url = f"http://{format_hostport(worker_info.worker.ip, port)}/configure" @@ -1368,6 +1346,10 @@ async def create_engine( url = f"http://{format_hostport(worker_info.worker.ip, port)}/create_engine" try: + logger.debug( + f"Creating engine '{engine_name}' (class: {engine}) on worker '{worker_id}'" + ) + timeout = aiohttp.ClientTimeout(total=300.0) async with aiohttp.ClientSession( timeout=timeout, @@ -1381,6 +1363,9 @@ async def create_engine( ) as response: if response.status == 200: result = await response.json() + logger.debug( + f"Engine '{engine_name}' created successfully on worker '{worker_id}'" + ) return result.get("result") elif response.status == 400: # Import error or bad request @@ -1640,6 +1625,10 @@ async def async_call_engine( ) try: + logger.debug( + f"Async calling method '{method}' on worker '{worker_id}' (attempt {attempt})" + ) + timeo = aiohttp.ClientTimeout( total=http_timeout, sock_connect=http_timeout, connect=http_timeout ) @@ -1799,4 +1788,4 @@ def __del__(self): try: self.delete_workers() except Exception: - pass + pass \ No newline at end of file From 1ec270d827236dcc2dcebc9eb49e3d7fa531438c Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 29 Apr 2026 12:44:56 +0800 Subject: [PATCH 068/112] docs: remove --- areal/engine/megatron_engine_r3_patch.py | 4 ---- areal/engine/router_replay_patch.py | 19 ------------------- areal/engine/router_replay_utils.py | 5 ----- 3 files changed, 28 deletions(-) diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py index 1a322863ba..fbbdec2f8f 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -284,10 +284,6 @@ def _r3_forward_backward_batch( If the data does not contain ``routed_experts``, delegates directly to the original method with zero overhead. - - **CRITICAL FIX**: Uses ``cu_seqlens`` from the padded micro-batch - (with per-sequence TP alignment) for packing replay data, ensuring - token ordering matches exactly what Megatron's transformer layers see. """ from areal.engine.router_replay_patch import RouterReplay, RouterReplayAction from areal.engine.router_replay_utils import ( diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py index 6efb077663..297a00ea28 100644 --- a/areal/engine/router_replay_patch.py +++ b/areal/engine/router_replay_patch.py @@ -6,24 +6,6 @@ This eliminates the train/inference routing mismatch caused by weight staleness in asynchronous RL training. -Patches applied: -1. 
**RouterReplay class** -- self-contained class (no dependency on - megatron.core.transformer.moe.router_replay which does not exist in - megatron-core 0.16.0). -2. **TransformerConfig.__init__** -- accepts ``enable_routing_replay`` kwarg. -3. **TopKRouter.__init__** -- creates a ``RouterReplay`` instance per MoE layer. -4. **TopKRouter.routing** -- replaces routing logic to support record/replay. -5. **MoEAlltoAllTokenDispatcher.preprocess** -- fixes ``num_out_tokens`` when - replay indices contain duplicate expert assignments. - -Usage:: - - from areal.engine.router_replay_patch import apply_router_replay_patch - apply_router_replay_patch() # call once before model creation - - from areal.engine.router_replay_patch import remove_router_replay_patch - remove_router_replay_patch() # optional: for test cleanup - Ref some code from megatron or verl, adapted for AReaL. """ @@ -76,7 +58,6 @@ # =================================================================== # RouterReplayAction enum and RouterReplay class -# (self-contained -- no dependency on megatron.core.transformer.moe.router_replay) # =================================================================== diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py index 3320029362..58df8595c0 100644 --- a/areal/engine/router_replay_utils.py +++ b/areal/engine/router_replay_utils.py @@ -218,11 +218,6 @@ def set_router_replay_data( ) -> None: """Scatter packed router top-k indices to SP ranks and update RouterReplay instances. - **CRITICAL**: This function must pack tokens using the EXACT same - cu_seqlens-based TP-aligned layout that Megatron uses for input_ids. - A different packing method (e.g., simple attention_mask concatenation) - causes token misalignment and near-random agreement rates. - The packing steps mirror ``pad_packed_tensor_dict`` in ``areal/utils/data.py``: 1. 
Use ``cu_seqlens`` to extract each sample's real tokens from the From 3f401703051e62c2f40ad12b0a453174fa1eb9d5 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 29 Apr 2026 13:50:47 +0800 Subject: [PATCH 069/112] style: log level --- areal/engine/megatron_engine_r3_patch.py | 28 +++------------- areal/engine/router_replay_patch.py | 7 ++-- areal/engine/router_replay_utils.py | 41 ++++-------------------- areal/engine/sglang_remote.py | 8 ++--- areal/trainer/ppo/actor_r3_patch.py | 2 +- 5 files changed, 19 insertions(+), 67 deletions(-) diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py index fbbdec2f8f..808fb6ae38 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -61,7 +61,7 @@ def patch_megatron_engine_for_r3( logger.debug("[R3] Router replay not enabled; skipping engine patch.") return - logger.info("[R3] Patching MegatronEngine for Router Replay (R3).") + logger.info("[R3] Patching MegatronEngine for Router Replay.") # Mark and save original engine._r3_enabled = True @@ -73,7 +73,7 @@ def patch_megatron_engine_for_r3( _r3_forward_backward_batch, engine ) - logger.info("[R3] MegatronEngine patched successfully.") + logger.debug("[R3] MegatronEngine patched successfully.") # =================================================================== @@ -303,7 +303,7 @@ def _r3_forward_backward_batch( routed_experts_batch = self._r3_pending_routed_experts self._r3_pending_routed_experts = None # Consume it _from_side_channel = True - logger.info( + logger.debug( "[R3] Retrieved routed_experts from engine side-channel: shape=%s.", routed_experts_batch.shape, ) @@ -313,7 +313,7 @@ def _r3_forward_backward_batch( if hasattr(mb_list, "data") and isinstance(mb_list.data, dict): routed_experts_batch = mb_list.data.pop("routed_experts", None) if routed_experts_batch is not None: - logger.info( + logger.debug( "[R3] Retrieved routed_experts from mb_list.data (legacy path): " "shape=%s.", routed_experts_batch.shape, @@ -346,7 +346,7 @@ def _r3_forward_backward_batch( ) logger.info( - "[R3] R3 forward_backward: %d micro-batches, routed_experts shape=%s, " + "[R3] forward_backward_batch: %d micro-batches, routed_experts shape=%s, " "forward_only=%s", len(mb_list), routed_experts_batch.shape, @@ -376,10 +376,6 @@ def _r3_forward_backward_batch( # ------------------------------------------------------------------ RouterReplay.reset_agreement_stats() RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) - logger.debug( - "[R3] Set initial REPLAY_FORWARD action on %d router instances.", - len(RouterReplay.router_instances), - ) # ------------------------------------------------------------------ # 3. 
Wrap the MicroBatchList iterator @@ -456,13 +452,6 @@ def __next__(self): model_config, seq_align_to=_seq_align_to, ) - logger.debug( - "[R3] Replay setup OK for micro-batch %d: " - "original_re=%s, aligned_re=%s, cu_seqlens=%s " - "(seq_align_to=%d).", - idx, re.shape, aligned_re.shape, cu_seqlens.shape, - _seq_align_to, - ) except Exception: logger.warning( "[R3] Failed to setup replay for micro-batch %d.", @@ -520,12 +509,6 @@ def _r3_post_forward_hook(module, input, output): handle = model_chunk.register_forward_hook(_r3_post_forward_hook) hook_handles.append(handle) - logger.debug( - "[R3] Registered forward hooks on %d model chunks for " - "FORWARD->BACKWARD toggle.", - len(hook_handles), - ) - try: self._r3_original_forward_backward_batch( mb_list, process_output_fn, forward_only=forward_only @@ -555,4 +538,3 @@ def _r3_post_forward_hook(module, input, output): clear_router_replay() self._r3_per_mb_experts = None self._r3_mb_counter = 0 - logger.debug("[R3] forward_backward_batch cleanup complete.") diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py index 297a00ea28..387565554e 100644 --- a/areal/engine/router_replay_patch.py +++ b/areal/engine/router_replay_patch.py @@ -485,12 +485,11 @@ def apply_router_replay_patch() -> None: """ global _PATCHES_APPLIED if _PATCHES_APPLIED: - logger.info("[R3] Router replay patches already applied; skipping.") + logger.debug("[R3] Router replay patches already applied; skipping.") return logger.info("[R3] Applying Router Replay patches...") - # Clear router instances to avoid state leakage between model inits. RouterReplay.router_instances.clear() _patch_transformer_config_init() @@ -499,7 +498,7 @@ def apply_router_replay_patch() -> None: _patch_alltoall_dispatcher_preprocess() _PATCHES_APPLIED = True - logger.info("[R3] All Router Replay patches applied successfully.") + logger.debug("[R3] All Router Replay patches applied successfully.") def remove_router_replay_patch() -> None: @@ -511,7 +510,7 @@ def remove_router_replay_patch() -> None: _undo_dispatcher_patch() RouterReplay.router_instances.clear() _PATCHES_APPLIED = False - logger.info("[R3] All Router Replay patches removed.") + logger.debug("[R3] All Router Replay patches removed.") # =================================================================== diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py index 58df8595c0..49f356a78f 100644 --- a/areal/engine/router_replay_utils.py +++ b/areal/engine/router_replay_utils.py @@ -4,7 +4,8 @@ Handles the complete shape-transformation pipeline that converts rollout routing indices into the layout expected by Megatron-Core's RouterReplay: -1. **Left-padding removal** -- rollout batch is left-padded; training removes it. +1. **Right-padding to left-alignment** -- rollout batch is right-padded; + training uses left-aligned packed format. 2. **TP/SP splitting** -- sequence parallelism across tensor-model-parallel ranks. 3. **PP layer slicing** -- pipeline parallelism assigns different layers to ranks. 4. **Dense/MoE layer mapping** -- architectures with dense FFN layers before MoE. @@ -221,7 +222,7 @@ def set_router_replay_data( The packing steps mirror ``pad_packed_tensor_dict`` in ``areal/utils/data.py``: 1. Use ``cu_seqlens`` to extract each sample's real tokens from the - left-padded ``layers_topk_idx``. + left-aligned ``layers_topk_idx``. 2. Pack tokens contiguously with per-sequence TP alignment padding (each sequence padded to a multiple of ``seq_align_to``). 3. 
``scatter_to_sequence_parallel_region`` to split across TP/SP ranks. @@ -230,9 +231,9 @@ def set_router_replay_data( Args: layers_topk_idx: ``(bs, max_seq_len, num_moe_layers, topk)`` -- the - replay data (left-padded, from rollout). After - ``_align_routed_experts_to_mask``, this is left-ALIGNED (real - tokens first, matching attention_mask convention). + replay data (left-aligned, real tokens first). After + ``_align_routed_experts_to_mask``, this matches the attention_mask + convention where real tokens occupy the leftmost positions. cu_seqlens: ``(bs+1,)`` or ``(bs+1+1,)`` -- cumulative sequence lengths from the PADDED micro-batch (after ``pad_packed_tensor_dict``). These define the actual token ordering that Megatron uses. @@ -302,36 +303,14 @@ def set_router_replay_data( for i in range(bs_re, n_seqs_in_cu): aligned_offset += aligned_lens[i] - logger.debug( - "[R3] set_router_replay_data: packed %d seqs into %d tokens " - "(TP-aligned with seq_align_to=%d, seq_lens=%s, aligned_lens=%s).", - min(n_seqs_in_cu, bs_re), - total_aligned, - seq_align_to, - seq_lens[:8], - aligned_lens[:8], - ) - # Step 2: Scatter to SP ranks packed = packed.to(device) tp_size = mpu.get_tensor_model_parallel_world_size() if tp_size > 1: from megatron.core.tensor_parallel import scatter_to_sequence_parallel_region local_tokens = scatter_to_sequence_parallel_region(packed) - logger.debug( - "[R3] set_router_replay_data: SP scatter tp_size=%d, " - "packed %s -> local_tokens %s.", - tp_size, - packed.shape, - local_tokens.shape, - ) else: local_tokens = packed - logger.debug( - "[R3] set_router_replay_data: tp_size=1, skipping SP scatter. " - "local_tokens=%s.", - local_tokens.shape, - ) # local_tokens: (local_tokens_count, num_layers, topk) # Step 3: Permute to (num_layers, local_tokens_count, topk) @@ -420,19 +399,11 @@ def setup_per_microbatch_replay_forward( vp_rank: Virtual pipeline stage rank override. seq_align_to: Per-sequence TP alignment factor. 
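+
+    Note: this is a thin wrapper -- it casts ``routed_experts`` to
+    ``int32`` and delegates the packing and SP scatter to
+    ``set_router_replay_data``.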
""" - logger.debug( - "[R3] setup_per_microbatch_replay_forward: " - "routed_experts=%s (dtype=%s), cu_seqlens=%s.", - routed_experts.shape, - routed_experts.dtype, - cu_seqlens.shape, - ) routed_experts = routed_experts.to(torch.int32) set_router_replay_data( routed_experts, cu_seqlens, tf_config, vp_rank, seq_align_to=seq_align_to, ) - logger.debug("[R3] Replay data distributed to router instances for micro-batch.") def setup_per_microbatch_replay_backward() -> None: diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index 19a55a30e7..bf6cf4d483 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -157,8 +157,8 @@ def parse_generation_response( routed_experts = None else: routed_experts = flat.reshape(num_sgl_token, -1) - logger.info( - "[R3-VERIFY] SGLang decoded routed_experts: " + logger.debug( + "[R3] SGLang decoded routed_experts: " "shape=%s, first3=%s, hash=%d", routed_experts.shape, routed_experts.flat[:3].tolist(), @@ -187,8 +187,8 @@ def parse_generation_response( routed_experts = None else: routed_experts = raw.reshape(num_sgl_token, -1) - logger.info( - "[R3-VERIFY] SGLang converted routed_experts: " + logger.debug( + "[R3] SGLang converted routed_experts: " "shape=%s, first3=%s, hash=%d", routed_experts.shape, routed_experts.flat[:3].tolist(), diff --git a/areal/trainer/ppo/actor_r3_patch.py b/areal/trainer/ppo/actor_r3_patch.py index ca991f75e0..c28f325872 100644 --- a/areal/trainer/ppo/actor_r3_patch.py +++ b/areal/trainer/ppo/actor_r3_patch.py @@ -91,7 +91,7 @@ def _ensure_tensor_routed_experts(data: dict[str, Any]) -> torch.Tensor | None: re_tensor = _resolve_to_tensor(re) if re_tensor is not None: - logger.info( + logger.debug( "[R3] routed_experts was %s (shape=%s); resolved to torch.Tensor " "(shape=%s, dtype=%s).", type(re).__name__, From 1449e9a7be44f38070c564b43ffe931740ba0aa3 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 29 Apr 2026 14:06:04 +0800 Subject: [PATCH 070/112] refactor(r3): remove metric --- areal/engine/megatron_engine_r3_patch.py | 10 - areal/trainer/ppo/actor_r3_patch.py | 557 +---------------------- areal/trainer/rl_trainer.py | 15 - 3 files changed, 3 insertions(+), 579 deletions(-) diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py index 808fb6ae38..bc06ebb7c5 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -524,16 +524,6 @@ def _r3_post_forward_hook(module, input, output): # Harvest agreement stats BEFORE clearing replay state. _agreement = RouterReplay.harvest_agreement_stats() self._r3_last_agreement_stats = _agreement - if _agreement.get("n_samples", 0) > 0: - from areal.utils import stats_tracker - with stats_tracker.scope("r3"): - stats_tracker.scalar( - router_agreement_rate=_agreement["avg"], - router_agreement_rate_min=_agreement["min"], - router_agreement_rate_max=_agreement["max"], - router_agreement_n_samples=_agreement["n_samples"], - router_agreement_n_calls=_agreement["n_calls"], - ) clear_router_replay() self._r3_per_mb_experts = None diff --git a/areal/trainer/ppo/actor_r3_patch.py b/areal/trainer/ppo/actor_r3_patch.py index c28f325872..bc9faf3f7a 100644 --- a/areal/trainer/ppo/actor_r3_patch.py +++ b/areal/trainer/ppo/actor_r3_patch.py @@ -1,39 +1,8 @@ """ -MoE routing metrics and R3 logging helpers for the PPO actor. +R3 data-splitting helpers for the PPO actor. -Provides two categories of metrics: - -1. 
**R3 data stats** (``log_r3_data_stats``): Summary of the routed_experts - tensor shape, dtype, and basic coverage info. Logged when R3 is enabled. - -2. **MoE routing effectiveness metrics** (``log_moe_routing_metrics``): - SkyRL-style routing quality indicators that are useful for ANY MoE model, - regardless of whether R3 is enabled. These include: - - Routing entropy (per-layer and aggregated) - - Expert utilization balance (std dev of expert load) - - Data coverage ratio (fraction of samples with valid routing data) - - Top-1 expert concentration (how much traffic goes to most-used expert) - - Expert diversity (number of unique experts used per token) - -The key R3-specific effectiveness metrics are: - -1. **Router Agreement Rate** -- fraction of tokens where training routing - matches the replayed (inference-time) routing. Measures how effectively - R3 forces routing alignment. - -2. **Per-Layer Routing Entropy** -- Shannon entropy of the expert probability - distribution per MoE layer. Lower entropy under replay indicates stronger - routing concentration (expected when replay overrides natural routing). - -3. **Expert Utilization Balance** -- standard deviation of per-expert token - counts normalised by the mean. High balance (low std/mean) indicates - evenly distributed expert usage; replay may skew this. - -4. **Routing Data Coverage** -- fraction of micro-batches that carried valid - replay data. Should be 1.0 in a healthy R3 run. - -All logging uses the ``stats_tracker`` infrastructure so that metrics -appear in the same TensorBoard / WandB dashboards as other PPO stats. +Provides utilities for resolving ``routed_experts`` tensors and splitting +them across mini-batches for side-channel delivery to the training engine. """ from __future__ import annotations @@ -43,8 +12,6 @@ import torch -from areal.utils import stats_tracker - logger = logging.getLogger(__name__) @@ -75,82 +42,6 @@ def _resolve_to_tensor(obj: Any) -> torch.Tensor | None: return None -def _ensure_tensor_routed_experts(data: dict[str, Any]) -> torch.Tensor | None: - """Extract ``routed_experts`` from *data*, converting to Tensor if needed. - - Handles the case where SGLang returns routed_experts as a numpy array, - RTensor, or other array-like type instead of a ``torch.Tensor``. - Logs a warning when a conversion is performed so that upstream data - pipelines can be diagnosed. - """ - re = data.get("routed_experts") - if re is None: - return None - if isinstance(re, torch.Tensor): - return re - - re_tensor = _resolve_to_tensor(re) - if re_tensor is not None: - logger.debug( - "[R3] routed_experts was %s (shape=%s); resolved to torch.Tensor " - "(shape=%s, dtype=%s).", - type(re).__name__, - getattr(re, "shape", "unknown"), - re_tensor.shape, - re_tensor.dtype, - ) - else: - logger.warning( - "[R3] Failed to resolve routed_experts from %s to torch.Tensor.", - type(re).__name__, - ) - return re_tensor - - -def log_r3_data_stats( - data: dict[str, Any], - scope: str = "r3", -) -> None: - """Log summary statistics about the ``routed_experts`` tensor in a - training data dict. - - Called once per PPO update step (not per micro-batch) to avoid - log spam. - - Also computes a CORRECT per-step router agreement rate by comparing - inference routing (from ``routed_experts``) against recorded training - routing (from ``RouterReplay`` instances), excluding padding tokens. - This replaces the misleading per-layer hot-path metric that was - previously computed in ``router_replay_patch.py``. 
- - Args: - data: The training data dict that may contain ``"routed_experts"``. - scope: Stats-tracker scope prefix. - """ - re = _ensure_tensor_routed_experts(data) - if re is None: - return - - with stats_tracker.scope(scope): - if isinstance(re, torch.Tensor): - stats_tracker.scalar( - r3_batch_size=re.shape[0], - r3_seq_len=re.shape[1], - r3_num_layers=re.shape[2] if re.dim() >= 3 else 0, - r3_topk=re.shape[3] if re.dim() >= 4 else 0, - r3_dtype_bytes=re.element_size(), - r3_max_expert_id=re.max().item() if re.numel() > 0 else 0, - ) - - _log_r3_effectiveness_metrics(re) - - # Compute per-step agreement rate with padding exclusion. - # Following verl's approach: use attention_mask to identify - # real tokens, compute per-layer fractional agreement, report - # avg/min/max across layers. - _log_r3_agreement_rate(re, data) - - def split_routed_experts_for_minibatches( routed_experts: torch.Tensor, mb_list, @@ -180,7 +71,6 @@ def split_routed_experts_for_minibatches( else: reordered = routed_experts[forward_indices] - # Determine number of samples per mini-batch from mbs dicts result = [] offset = 0 for i, mb_dict in enumerate(mb_list.mbs): @@ -215,444 +105,3 @@ def _infer_mb_sample_count_from_dict( if ids is not None and hasattr(ids, "shape"): return ids.shape[0] return total_bs // n_mbs - - -def _log_r3_effectiveness_metrics( - routed_experts: torch.Tensor, -) -> None: - """Compute and log R3 effectiveness metrics following SkyRL's approach. - - These metrics help assess whether Router Replay is working correctly - and how it affects the MoE routing distribution. - - Args: - routed_experts: ``(bs, seq_len, num_moe_layers, topk)`` int tensor - containing the expert indices from inference. - """ - if routed_experts.dim() != 4 or routed_experts.numel() == 0: - return - - bs, seq_len, num_moe_layers, topk = routed_experts.shape - - try: - # --- Metric 1: Per-Layer Routing Entropy --- - # Measures the diversity of expert assignments per layer. - # Lower entropy = more concentrated routing. - # Under R3, this reflects the inference-time routing distribution. - _log_per_layer_routing_entropy(routed_experts, num_moe_layers, topk) - - # --- Metric 2: Expert Utilization Balance --- - # Measures how evenly tokens are distributed across experts. - # Coefficient of variation (std/mean) -- lower = more balanced. - _log_expert_utilization_balance(routed_experts, num_moe_layers) - - # --- Metric 3: Routing Data Coverage --- - # Fraction of (batch, layer) combinations with non-zero routing data. - _log_routing_data_coverage(routed_experts, bs, num_moe_layers) - - # --- Metric 4: Top-1 Expert Concentration --- - # How often the most popular expert is selected (per layer). - _log_top1_expert_concentration(routed_experts, num_moe_layers) - - except Exception: - logger.warning( - "[R3] Failed to compute R3 effectiveness metrics.", - exc_info=True, - ) - - -def _log_per_layer_routing_entropy( - routed_experts: torch.Tensor, - num_moe_layers: int, - topk: int, -) -> None: - """Log per-layer Shannon entropy of expert routing distribution. - - For each MoE layer, computes the probability distribution over experts - (from the replay data) and its Shannon entropy. Reports mean, min, - max across layers. 
- """ - bs, seq_len = routed_experts.shape[:2] - # Flatten batch and seq dimensions - flat = routed_experts.view(-1, num_moe_layers, topk) # (bs*seq_len, L, K) - num_tokens = flat.shape[0] - - if num_tokens == 0: - return - - # Determine number of experts from max index - num_experts = int(routed_experts.max().item()) + 1 - if num_experts <= 0: - return - - layer_entropies = [] - for layer_idx in range(num_moe_layers): - # Count expert occurrences for this layer across all tokens and topk slots - expert_ids = flat[:, layer_idx, :].reshape(-1).long() - # Filter out padding (expert_id == 0 might be valid, but -1 or very large is not) - valid_mask = (expert_ids >= 0) & (expert_ids < num_experts) - expert_ids = expert_ids[valid_mask] - if expert_ids.numel() == 0: - continue - - counts = torch.bincount(expert_ids, minlength=num_experts).float() - probs = counts / counts.sum() - # Shannon entropy: -sum(p * log(p)), with 0*log(0) = 0 - log_probs = torch.where(probs > 0, torch.log2(probs), torch.zeros_like(probs)) - entropy = -(probs * log_probs).sum().item() - layer_entropies.append(entropy) - - if layer_entropies: - mean_entropy = sum(layer_entropies) / len(layer_entropies) - min_entropy = min(layer_entropies) - max_entropy = max(layer_entropies) - # Maximum possible entropy for reference - max_possible = torch.log2(torch.tensor(float(num_experts))).item() - - stats_tracker.scalar( - r3_routing_entropy_mean=mean_entropy, - r3_routing_entropy_min=min_entropy, - r3_routing_entropy_max=max_entropy, - r3_routing_entropy_normalised=mean_entropy / max_possible - if max_possible > 0 - else 0, - r3_num_experts=num_experts, - ) - - -def _log_expert_utilization_balance( - routed_experts: torch.Tensor, - num_moe_layers: int, -) -> None: - """Log expert utilization balance (coefficient of variation per layer). - - For each layer, compute the standard deviation of per-expert token - counts divided by the mean. Aggregate across layers. - """ - flat = routed_experts.view(-1, num_moe_layers, routed_experts.shape[-1]) - num_experts = int(routed_experts.max().item()) + 1 - if num_experts <= 1: - return - - layer_cv_values = [] - for layer_idx in range(num_moe_layers): - expert_ids = flat[:, layer_idx, :].reshape(-1).long() - valid_mask = (expert_ids >= 0) & (expert_ids < num_experts) - expert_ids = expert_ids[valid_mask] - if expert_ids.numel() == 0: - continue - - counts = torch.bincount(expert_ids, minlength=num_experts).float() - mean_count = counts.mean() - if mean_count > 0: - cv = counts.std() / mean_count - layer_cv_values.append(cv.item()) - - if layer_cv_values: - stats_tracker.scalar( - r3_expert_util_cv_mean=sum(layer_cv_values) / len(layer_cv_values), - r3_expert_util_cv_max=max(layer_cv_values), - r3_expert_util_cv_min=min(layer_cv_values), - ) - - -def _is_dense_layer(re_layer: torch.Tensor) -> bool: - """Check if a layer's routing data is all-zero (i.e., a dense FFN layer). - - SGLang returns routed_experts across ALL transformer layers (including - dense layers). Dense layers have no MoE router, so their topk_ids are - all zeros. We detect this to exclude them from MoE-specific metrics. - - Args: - re_layer: ``(bs, seq_len, topk)`` routing data for one layer. - - Returns: - True if the layer has no valid routing data (dense layer). - """ - return re_layer.sum().item() == 0 - - -def _log_routing_data_coverage( - routed_experts: torch.Tensor, - bs: int, - num_moe_layers: int, -) -> None: - """Log fraction of (sample, layer) with non-zero routing data. 
- - Skips dense (all-zero) layers so the metric reflects true MoE layer - coverage. When SGLang returns routing data for all transformer layers - (including dense FFN layers), those dense layers would drag coverage - down to (num_moe_layers / num_total_layers), e.g. 26/27 = 0.96296 - for Moonlight-16B-A3B. - """ - has_data = (routed_experts.sum(dim=(1, 3)) > 0).float() # (bs, num_layers) - - moe_layer_mask = [] - for layer_idx in range(num_moe_layers): - layer_re = routed_experts[:, :, layer_idx, :] - is_dense = _is_dense_layer(layer_re) - moe_layer_mask.append(not is_dense) - - n_moe_layers = sum(moe_layer_mask) - if n_moe_layers == 0: - stats_tracker.scalar(r3_routing_data_coverage=0.0) - return - - moe_has_data = has_data[:, moe_layer_mask] - coverage = moe_has_data.mean().item() - stats_tracker.scalar( - r3_routing_data_coverage=coverage, - r3_num_moe_layers=n_moe_layers, - r3_num_dense_layers=num_moe_layers - n_moe_layers, - ) - - -def _log_top1_expert_concentration( - routed_experts: torch.Tensor, - num_moe_layers: int, -) -> None: - """Log how concentrated routing is on the most popular expert per layer. - - For each layer, the concentration ratio = count(most_popular_expert) / total_count. - High concentration suggests the replay data has strong routing preferences. - """ - flat = routed_experts.view(-1, num_moe_layers, routed_experts.shape[-1]) - num_experts = int(routed_experts.max().item()) + 1 - if num_experts <= 0: - return - - layer_concentrations = [] - for layer_idx in range(num_moe_layers): - expert_ids = flat[:, layer_idx, :].reshape(-1).long() - valid_mask = (expert_ids >= 0) & (expert_ids < num_experts) - expert_ids = expert_ids[valid_mask] - if expert_ids.numel() == 0: - continue - - counts = torch.bincount(expert_ids, minlength=num_experts) - max_count = counts.max().item() - total = counts.sum().item() - if total > 0: - layer_concentrations.append(max_count / total) - - if layer_concentrations: - stats_tracker.scalar( - r3_top1_expert_concentration_mean=sum(layer_concentrations) - / len(layer_concentrations), - r3_top1_expert_concentration_max=max(layer_concentrations), - ) - - -def _log_r3_agreement_rate( - routed_experts: torch.Tensor, - data: dict[str, Any], -) -> None: - """Log R3 router agreement rate. - - The actual per-layer agreement (comparing training-time natural routing - vs. replayed inference routing) is now computed on every REPLAY_FORWARD - call inside ``router_replay_patch.py`` and logged to ``stats_tracker`` - from ``megatron_engine_r3_patch.py`` after each forward-backward pass. - - This function is intentionally a no-op to avoid reporting the misleading - metric that was here before (fraction of real tokens with non-zero expert - assignments, which was always ~1.0 for MoE layers). - - The function signature is preserved for backward compatibility. - """ - # Agreement rate is now reported from the engine layer. - pass - - -def compute_router_agreement_rate( - replay_indices: torch.Tensor, - actual_indices: torch.Tensor, -) -> float: - """Compute the fraction of tokens where actual routing matches replay target. - - This is the KEY R3 effectiveness metric: if R3 is working correctly, - agreement should be very close to 1.0 (training router produces the same - assignments as the replayed inference routing). - - Args: - replay_indices: ``(num_tokens, topk)`` target expert indices from replay. - actual_indices: ``(num_tokens, topk)`` actual expert indices from training. - - Returns: - Agreement rate in [0, 1]. Returns -1.0 if inputs are invalid. 
- """ - if replay_indices is None or actual_indices is None: - return -1.0 - if replay_indices.shape != actual_indices.shape: - logger.warning( - "[R3] Agreement rate: shape mismatch replay=%s vs actual=%s.", - replay_indices.shape, - actual_indices.shape, - ) - return -1.0 - - # Sort topk indices per token to handle different ordering - replay_sorted = replay_indices.sort(dim=-1).values - actual_sorted = actual_indices.sort(dim=-1).values - matches = (replay_sorted == actual_sorted).all(dim=-1).float() - return matches.mean().item() - - -def log_moe_routing_metrics( - data: dict[str, Any], - scope: str = "moe_routing", -) -> None: - """Log MoE routing effectiveness metrics for ANY MoE model. - - Computes routing quality indicators from the - ``routed_experts`` tensor. These metrics help diagnose routing - quality issues (expert collapse, load imbalance, etc.) and are - useful even without R3. - - Args: - data: Training data dict containing ``"routed_experts"`` - of shape ``(bs, seq_len, num_moe_layers, topk)``. - scope: Stats-tracker scope prefix. - """ - re = _ensure_tensor_routed_experts(data) - if re is None: - return - if not isinstance(re, torch.Tensor) or re.dim() < 4: - return - - bs, seq_len, num_layers, topk = re.shape - attn_mask = _resolve_to_tensor(data.get("attention_mask")) - - with stats_tracker.scope(scope): - # ------------------------------------------------------------------ - # 1. Data coverage: fraction of samples with non-zero routing data - # Skip dense (all-zero) layers. - # ------------------------------------------------------------------ - moe_layer_indices = [] - n_dense_layers = 0 - for layer_idx in range(num_layers): - if _is_dense_layer(re[:, :, layer_idx, :]): - n_dense_layers += 1 - else: - moe_layer_indices.append(layer_idx) - - n_moe_layers = len(moe_layer_indices) - if n_moe_layers == 0: - stats_tracker.scalar( - data_coverage=0.0, - num_moe_layers=0, - num_dense_layers=n_dense_layers, - ) - return - - moe_re = re[:, :, moe_layer_indices, :] - has_routing = (moe_re.sum(dim=(1, 2, 3)) != 0).float() - coverage = has_routing.mean().item() - stats_tracker.scalar( - data_coverage=coverage, - num_moe_layers=n_moe_layers, - num_dense_layers=n_dense_layers, - ) - - # ------------------------------------------------------------------ - # 2. 
Expert utilization and load balance (per-layer, MoE only) - # ------------------------------------------------------------------ - if attn_mask is not None and attn_mask.shape[1] == seq_len: - real_mask = attn_mask.bool() # (bs, seq_len) - else: - if attn_mask is not None: - logger.debug( - "[R3] attn_mask seq_len (%d) != routed_experts seq_len (%d); " - "falling back to all-ones mask (expected: SGLang uses " - "prompt+completion-1, training uses packed seqlen).", - attn_mask.shape[1], - seq_len, - ) - real_mask = torch.ones(bs, seq_len, dtype=torch.bool, device=re.device) - - token_mask = real_mask.unsqueeze(-1).unsqueeze(-1).expand_as(moe_re) - max_expert_id = moe_re[token_mask].max().item() if token_mask.any() else 0 - num_experts = int(max_expert_id) + 1 - if num_experts < 2: - stats_tracker.scalar( - num_experts=num_experts, - insufficient_data=1, - ) - return - - entropy_sum = 0.0 - balance_sum = 0.0 - top1_concentration_sum = 0.0 - diversity_sum = 0.0 - valid_layers = 0 - - for layer_idx in moe_layer_indices: - layer_re = re[:, :, layer_idx, :] - layer_mask = real_mask.unsqueeze(-1).expand_as(layer_re) - valid_experts = layer_re[layer_mask] - - if valid_experts.numel() == 0: - continue - - valid_layers += 1 - - expert_counts = torch.bincount( - valid_experts.long().clamp(0, num_experts - 1), - minlength=num_experts, - ).float() - total_assignments = expert_counts.sum() - - if total_assignments == 0: - continue - - expert_probs = expert_counts / total_assignments - - log_probs = torch.log(expert_probs + 1e-10) - entropy = -(expert_probs * log_probs).sum().item() - max_entropy = torch.log(torch.tensor(float(num_experts))).item() - normalized_entropy = entropy / max_entropy if max_entropy > 0 else 0.0 - entropy_sum += normalized_entropy - - load_std = expert_probs.std().item() - load_mean = expert_probs.mean().item() - balance = load_std / (load_mean + 1e-10) - balance_sum += balance - - top1_ratio = expert_probs.max().item() - top1_concentration_sum += top1_ratio - - unique_experts_used = (expert_counts > 0).sum().item() - diversity = unique_experts_used / num_experts - diversity_sum += diversity - - if valid_layers > 0: - stats_tracker.scalar( - num_experts=num_experts, - num_moe_layers=n_moe_layers, - routing_entropy=entropy_sum / valid_layers, - expert_load_imbalance_cv=balance_sum / valid_layers, - top1_expert_concentration=top1_concentration_sum / valid_layers, - expert_diversity=diversity_sum / valid_layers, - valid_moe_layers=valid_layers, - ) - else: - stats_tracker.scalar( - num_experts=num_experts, - num_moe_layers=n_moe_layers, - valid_moe_layers=0, - ) - - -def strip_routed_experts_before_loss( - data: dict[str, Any], -) -> dict[str, Any]: - """Remove ``routed_experts`` from the data dict before the loss function. - - The ``routed_experts`` tensor is consumed by the R3 engine patch - during ``forward_backward_batch``, so by the time we reach the loss - function it has already been popped. This function is a safety net. - - Returns the data dict (modified in-place). - """ - data.pop("routed_experts", None) - return data diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index 7a5c2a4790..80351ca690 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -723,21 +723,6 @@ def train( args={"global_step": global_step}, ), ): - # MoE routing metrics: Log for ALL MoE models when - # routed_experts data is available in the trajectory. - # R3 data stats are logged only when R3 is enabled. 
- for traj in adv_batch: - if "routed_experts" in traj: - from areal.trainer.ppo.actor_r3_patch import ( - log_moe_routing_metrics, - log_r3_data_stats, - ) - - log_moe_routing_metrics(traj) - if getattr(self.config.rollout, "return_routed_experts", False): - log_r3_data_stats(traj) - break # Log once per batch, not per trajectory - self.actor.ppo_update(adv_batch) self.actor.step_lr_scheduler() self.actor.get_device_stats().log("ppo update") From 515312c8918b300d4e8eb3fc858a7c20a2883d1c Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 29 Apr 2026 14:17:27 +0800 Subject: [PATCH 071/112] refactor(engine): remove RouterReplay statics --- areal/engine/megatron_engine_r3_patch.py | 5 -- areal/engine/router_replay_patch.py | 66 ------------------------ 2 files changed, 71 deletions(-) diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py index bc06ebb7c5..413ccb8c3a 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -374,7 +374,6 @@ def _r3_forward_backward_batch( # ------------------------------------------------------------------ # 2b. Set initial replay action to REPLAY_FORWARD. # ------------------------------------------------------------------ - RouterReplay.reset_agreement_stats() RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) # ------------------------------------------------------------------ @@ -521,10 +520,6 @@ def _r3_post_forward_hook(module, input, output): # class swap done above). The original class was never modified. mb_list.__class__ = _r3_original_mb_list_class - # Harvest agreement stats BEFORE clearing replay state. - _agreement = RouterReplay.harvest_agreement_stats() - self._r3_last_agreement_stats = _agreement - clear_router_replay() self._r3_per_mb_experts = None self._r3_mb_counter = 0 diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py index 387565554e..8347735807 100644 --- a/areal/engine/router_replay_patch.py +++ b/areal/engine/router_replay_patch.py @@ -84,13 +84,6 @@ class RouterReplay: # Set by the engine patch before forward_backward_func. pp_size: int = 1 - # Class-level agreement accumulator. - # Collects per-call agreement rates during REPLAY_FORWARD to provide - # an accurate R3 effectiveness metric every training step. - _agreement_matches: int = 0 - _agreement_total: int = 0 - _agreement_per_call: list = [] - # ------------------------------------------------------------------ # Class-level (static) helpers # ------------------------------------------------------------------ @@ -136,37 +129,6 @@ def clear_global_router_replay_action() -> None: for r in RouterReplay.router_instances: r.clear_router_replay_action() - @classmethod - def reset_agreement_stats(cls) -> None: - """Reset the agreement rate accumulator before a training step.""" - cls._agreement_matches = 0 - cls._agreement_total = 0 - cls._agreement_per_call = [] - - reset_agreement_accumulator = reset_agreement_stats - - @classmethod - def get_agreement_rate(cls) -> float: - if cls._agreement_total == 0: - return -1.0 - return cls._agreement_matches / cls._agreement_total - - @classmethod - def harvest_agreement_stats(cls) -> dict: - """Harvest accumulated agreement statistics and reset. - - Returns a dict with keys: avg, min, max, n_samples, n_calls. 
- """ - result = { - "avg": cls.get_agreement_rate(), - "min": min(cls._agreement_per_call) if cls._agreement_per_call else -1.0, - "max": max(cls._agreement_per_call) if cls._agreement_per_call else -1.0, - "n_samples": cls._agreement_total, - "n_calls": len(cls._agreement_per_call), - } - cls.reset_agreement_stats() - return result - def __init__(self) -> None: self.target_topk_idx: torch.Tensor | None = None self.recorded_topk_idx: torch.Tensor | None = None @@ -260,34 +222,6 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): top_indices = top_indices.to(scores.device) probs = scores.gather(1, top_indices) - # --- Router Agreement Rate: compare replay vs natural routing --- - # Compute what the router would have chosen WITHOUT replay - # (no grad to avoid interfering with the backward graph). - # NOTE: Exclude padding tokens from agreement computation. - # Padding tokens have all-zero replay indices and would - # artificially drag down agreement rates since their - # natural routing is essentially random. - try: - with torch.no_grad(): - _, natural_indices = _compute_topk( - scores, topk, num_groups=num_groups, group_topk=group_topk - ) - non_padding_mask = (top_indices != 0).any(dim=-1) - replay_sorted = top_indices.sort(dim=-1).values - natural_sorted = natural_indices.sort(dim=-1).values - matches = (replay_sorted == natural_sorted).all(dim=-1) - if non_padding_mask.any(): - masked_matches = matches[non_padding_mask] - n_matched = int(masked_matches.sum().item()) - n_total = int(masked_matches.numel()) - RouterReplay._agreement_matches += n_matched - RouterReplay._agreement_total += n_total - if n_total > 0: - RouterReplay._agreement_per_call.append( - n_matched / n_total - ) - except Exception: - logger.debug("[R3] Agreement rate computation failed.", exc_info=True) return probs, top_indices elif routing_action == RouterReplayAction.REPLAY_BACKWARD: From cfe2b31ae85e4be4024fb7b7d23cbb4cdf80c65c Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 29 Apr 2026 15:20:07 +0800 Subject: [PATCH 072/112] feat: add docs and tests --- docs/en/_toc.yml | 1 + docs/en/algorithms/router_replay.md | 159 ++++++ docs/zh/_toc.yml | 1 + docs/zh/algorithms/router_replay.md | 133 +++++ tests/test_router_replay.py | 411 ++++++++++++++ tests/test_router_replay_e2e.py | 220 ++++++++ .../torchrun/run_router_replay_distributed.py | 501 ++++++++++++++++++ 7 files changed, 1426 insertions(+) create mode 100644 docs/en/algorithms/router_replay.md create mode 100644 docs/zh/algorithms/router_replay.md create mode 100644 tests/test_router_replay.py create mode 100644 tests/test_router_replay_e2e.py create mode 100644 tests/torchrun/run_router_replay_distributed.py diff --git a/docs/en/_toc.yml b/docs/en/_toc.yml index 845032dd7d..edad8395a4 100644 --- a/docs/en/_toc.yml +++ b/docs/en/_toc.yml @@ -38,6 +38,7 @@ parts: - file: algorithms/grpo_series - file: algorithms/m2po - file: algorithms/prox_approx + - file: algorithms/router_replay - caption: Reference chapters: - file: reference/checkpointing diff --git a/docs/en/algorithms/router_replay.md b/docs/en/algorithms/router_replay.md new file mode 100644 index 0000000000..0bc020515a --- /dev/null +++ b/docs/en/algorithms/router_replay.md @@ -0,0 +1,159 @@ +# Rollout Routing Replay (R3) for MoE Models + +Last updated: Apr 29, 2026 + +## Overview + +In asynchronous RL for Mixture-of-Experts (MoE) models, the policy that generates +rollouts (served by SGLang) and the policy that is being trained (driven by +Megatron-LM) may differ by one or more 
parameter versions. Since the router is a
+*learned* sparse gate, even small weight drift can send the same token to
+different experts between inference and training, producing a
+**train/inference routing mismatch** that corrupts importance-sampling
+ratios and destabilises optimisation.
+
+**Rollout Routing Replay (R3)** eliminates this mismatch by:
+
+1. Recording the expert assignments emitted by the inference engine for
+   every decoded token.
+2. Re-using (*replaying*) those exact expert assignments during the training
+   forward / backward pass in place of the routing computed from current
+   weights.
+
+R3 is inspired by the implementation in
+[verl](https://github.com/volcengine/verl) and has been adapted for AReaL's
+Megatron backend + SGLang bridge-mode inference service.
+
+## Supported Configurations
+
+| Dimension | Supported | Notes |
+|---|---|---|
+| Training backend | Megatron-LM (`MegatronEngine`) | FSDP engine is **not** supported. |
+| Inference backend | SGLang 0.5.9 (bridge mode) | vLLM not supported. |
+| Tensor Parallel (**TP**) | ✅ | Uses `scatter_to_sequence_parallel_region` to distribute packed router indices to SP ranks. |
+| Expert Parallel (**EP**) | ✅ | Patched `MoEAlltoAllTokenDispatcher.preprocess` recomputes `num_out_tokens = routing_map.sum()` so that the dropless path stays correct when replay zeroes padding rows. |
+| Pipeline Parallel (**PP**) | ✅ | `RouterReplayHelper.get_micro_batch_router_list` slices `RouterReplay.router_instances` according to the current PP rank's `(layer_offset, num_layers)`. |
+| Virtual Pipeline Parallel (**VPP**) | ✅ | Same helper honours `virtual_pipeline_model_parallel_size` and iterates over VP stages. |
+| Context Parallel (**CP**) | ⚠️ Experimental | `seq_align_to = tp_size * cp_size * 2` is applied when `cp_size > 1`; exercised only via unit tests, not covered by the provided E2E fixtures. |
+| Data Parallel (**DP**) | ✅ | R3 runs independently per DP replica; no cross-DP communication is added. |
+| Dense + MoE hybrid layers | ✅ | `is_moe_layer()` uses `moe_layer_freq` / `first_k_dense_replace` so dense layers are skipped from replay. |
+| Role | Actor only | `config.actor.megatron.enable_router_replay` is set exclusively on the actor engine; Critic / Reference / Teacher engines are unaffected. |
+| Capacity factor | `moe_expert_capacity_factor is None` (dropless) | Replay only overrides `num_out_tokens` on the dropless path, matching verl's guard. |
+| FP8 / quantisation padding | ❌ | Replay is skipped when `moe_router_padding_for_fp8` or `moe_router_padding_for_quantization` is enabled to preserve FP8 dispatch correctness. |
+| Vision / multimodal models | ❌ | No hooks in the VLM path. |
+
+## How to Enable R3
+
+R3 is driven by a single rollout flag; everything else is wired automatically
+by `areal/trainer/rl_trainer.py`.
+
+```yaml
+rollout:
+  # Request per-token routed expert indices from SGLang.
+  return_routed_experts: true
+
+actor:
+  backend: "megatron:(attn:d1p1t4|ffn:d1p1t1e4)" # TP=4, EP=4
+  # actor.megatron.enable_router_replay is forced to True
+  # automatically when rollout.return_routed_experts=true.
+
+sglang:
+  # R3 requires the returned token IDs to stay aligned one-to-one with
+  # the routing output. The trainer forces skip_tokenizer_init=True at
+  # startup; declaring it here makes the intent explicit.
+  skip_tokenizer_init: true
+  enable_return_routed_experts: true
+```
+
+At trainer startup (`RLTrainer.__init__`):
+
+1. 
`rollout.return_routed_experts=True` causes + `config.actor.megatron.enable_router_replay` to be set to `True`. +2. `num_moe_layers` and `topk` are auto-resolved from the HuggingFace config + (`num_experts_per_tok`, `num_hidden_layers`, `moe_layer_freq`, + `first_k_dense_replace`) by `resolve_r3_moe_config()`. +3. `sglang.skip_tokenizer_init` is forced to `True` (warning printed if the user + set it to `False`) to prevent tokenizer round-trip token shifts that would + break per-token routing alignment. +4. The SGLang bridge entrypoint + (`areal/experimental/inference_service/sglang/launch_server.py`) calls + `apply_sglang_r3_patch()` so that `TokenizerManager._handle_batch_output` + base64-encodes the `routed_experts` tensor before FastAPI serialisation. +5. On the training side, `MegatronEngine.initialize()` calls + `apply_router_replay_patch()` (monkey-patches `TransformerConfig.__init__`, + `TopKRouter.__init__`, `TopKRouter.routing` and + `MoEAlltoAllTokenDispatcher.preprocess`) **before** model creation, and then + wraps the engine with `patch_megatron_engine_for_r3()`. + +## Pipeline Overview + +``` +┌────────────────────────────┐ ┌──────────────────────────────┐ +│ SGLang inference server │ │ MegatronEngine (actor) │ +│ │ │ │ +│ generate_logprobs() │ │ forward_backward_batch() │ +│ └─ routed_experts tensor │───▶ │ ├─ REPLAY_FORWARD on │ +│ (base64 over HTTP) │ │ │ each microbatch │ +└────────────────────────────┘ │ ├─ post-forward hook switches │ + │ │ to REPLAY_BACKWARD │ + │ └─ clear_router_replay() │ + └──────────────────────────────┘ +``` + +### Key data structures + +| Object | Purpose | +|---|---| +| `RouterReplay` (per MoE layer) | Holds the replay indices (`target_topk_idx`), recording buffer (`recorded_topk_idx`), and current `RouterReplayAction`. | +| `RouterReplay.router_instances` | Class-level list, one entry per MoE layer *on the local rank*. Cleared each time `apply_router_replay_patch()` is called. | +| `RouterReplayAction` | Enum: `RECORD`, `REPLAY_FORWARD`, `REPLAY_BACKWARD`. | +| `RouterReplayHelper.get_micro_batch_router_list()` | Returns the subset of `router_instances` assigned to the current `(pp_rank, vp_stage)`. | +| `setup_per_microbatch_replay_forward()` | Called before each micro-batch forward: aligns rollout-format `routed_experts` to the training token layout, packs with `cu_seqlens`, scatters to SP ranks, and distributes to the per-layer `RouterReplay` instances. | + +### Correctness notes + +* **`num_out_tokens` override.** Megatron-Core 0.16.0's dropless branch of + `MoEAlltoAllTokenDispatcher.preprocess` sets + `num_out_tokens = routing_map.size(0) * moe_router_topk` as a static value. + When R3 zeroes padding rows in `routing_map`, that static value overcounts, + so the patched preprocess computes `num_out_tokens = int(routing_map.sum().item())` + on the dropless path. The ~3,500 `.item()` syncs per training step are + negligible compared to MoE compute. +* **Per-instance `__class__` swap.** The micro-batch iterator wraps + `MicroBatchList` with a dynamic subclass assigned via `mb_list.__class__`, + not by mutating the shared class, so concurrent engines (e.g. critic) are + not affected. +* **Left-align from right-padded rollouts.** `_align_routed_experts_to_mask()` + converts the rollout tensor from `(bs, batch_max_seqlen, L, K)` right-padded + format to a training-oriented left-aligned layout using `cu_seqlens`. +* **Silent-drop removed.** When a micro-batch cannot be exactly split by + `bs // n_mbs`, R3 raises instead of silently trimming rows. 
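+
+The right-padded-to-packed conversion in the last note can be illustrated
+with a short sketch. This is a minimal, illustrative reimplementation, not
+the library code: it assumes `cu_seqlens` holds the raw (pre-alignment)
+per-sequence boundaries, and `pack_routed_experts` is a name we introduce
+here for exposition only.
+
+```python
+import torch
+
+def pack_routed_experts(
+    routed_experts: torch.Tensor,  # (bs, max_seqlen, L, K), real tokens leftmost
+    cu_seqlens: torch.Tensor,      # (bs + 1,) cumulative sequence lengths
+    seq_align_to: int = 1,         # per-sequence TP alignment factor
+) -> torch.Tensor:
+    seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+    chunks = []
+    for i, n in enumerate(seq_lens[: routed_experts.shape[0]]):
+        seq = routed_experts[i, :n]   # keep only the real tokens of sample i
+        pad = (-n) % seq_align_to     # pad up to a multiple of seq_align_to
+        if pad:
+            # Padding rows are all-zero; the patched dispatcher later
+            # recomputes num_out_tokens so these rows are never dispatched.
+            seq = torch.cat([seq, seq.new_zeros((pad,) + tuple(seq.shape[1:]))])
+        chunks.append(seq)
+    # (total_aligned_tokens, L, K)
+    return torch.cat(chunks, dim=0)
+```
+
+When `tp_size > 1`, the packed result is then split across SP ranks with
+`scatter_to_sequence_parallel_region`, matching the TP row of the support
+matrix above.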
+ +## Minimal Example + +See `examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml` for the reference +Moonlight-16B-A3B configuration (PP=2, TP=4, EP=4, 8 GPUs). Launch: + +```bash +python3 examples/math/gsm8k_rl.py \ + --config examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml \ + scheduler.type=local +``` + +On a single-node 8×H200 system the `*_h20.yaml` variant runs with PP=1, +TP=4, EP=4 and `max_tokens_per_mb=10240`. + +## Troubleshooting + +| Symptom | Cause | Fix | +|---|---|---| +| `[R3] Number of replay tensors (...) does not match number of router instances (...)` | MoE layer count resolved from HF config differs from Megatron's per-rank layer count (usually due to `first_k_dense_replace` / `moe_layer_freq` mismatch or custom pipeline layout). | Verify `num_hidden_layers`, `first_k_dense_replace`, and `moe_layer_freq` in the model's `config.json` and that `pipeline_model_parallel_layout` (if set) matches the MoE layer count. | +| SGLang returns `routed_experts: {}` (empty dict) | Inference server was started without the R3 patch. | Ensure you are using the bridge entrypoint `areal.experimental.inference_service.sglang.launch_server`; it installs `apply_sglang_r3_patch()` automatically. | +| `moe_router_padding_for_fp8=True` + R3 | R3 is intentionally disabled on FP8 padding paths. | Either turn off FP8 router padding or disable `rollout.return_routed_experts`. | +| Critic does not pick up R3 | By design; only the actor is patched. | If a future use-case needs MoE critic replay, extend `rl_trainer._amend_xccl_weight_update_envvar` and `MegatronEngine._r3_enabled` plumbing. | + +## References + +* PR [#1207](https://github.com/inclusionAI/AReaL/pull/1207) — `[WIP]feat: add router replay for megatron engine`. +* verl router replay: + [`volcengine/verl`](https://github.com/volcengine/verl) (`verl/workers/**/*router_replay*`). +* Megatron-Core MoE parallel folding: + [NVIDIA/Megatron-LM MoE README](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/transformer/moe). diff --git a/docs/zh/_toc.yml b/docs/zh/_toc.yml index 16f9e7b713..dad5a7293c 100644 --- a/docs/zh/_toc.yml +++ b/docs/zh/_toc.yml @@ -38,6 +38,7 @@ parts: - file: algorithms/grpo_series - file: algorithms/m2po - file: algorithms/prox_approx + - file: algorithms/router_replay - caption: 参考 chapters: - file: reference/checkpointing diff --git a/docs/zh/algorithms/router_replay.md b/docs/zh/algorithms/router_replay.md new file mode 100644 index 0000000000..9073496880 --- /dev/null +++ b/docs/zh/algorithms/router_replay.md @@ -0,0 +1,133 @@ +# MoE 路由回放(R3, Rollout Routing Replay) + +最后更新:2026-04-29 + +## 背景 + +在 MoE 模型的异步强化学习中,负责 rollout 采样(SGLang)的策略与正在训练的 +策略(Megatron-LM)往往相差一个或多个版本。由于 MoE 路由器是"学习到的稀疏 +门控",微小的权重漂移都会让同一个 token 在推理和训练时被送到不同的专家, +造成 **训练/推理路由不一致**,进而破坏 importance sampling 的比值、导致优化 +不稳定。 + +**Rollout Routing Replay(R3)** 通过以下两个步骤消除这种不一致: + +1. **记录**:在推理阶段记录每个 token 的专家分配结果。 +2. 
**回放**:在训练前/反向阶段使用完全相同的专家分配替换由当前权重计算得到的 + 路由结果。 + +R3 参考了 [verl](https://github.com/volcengine/verl) 的实现,并在 AReaL +仓库中被适配到 Megatron 训练后端 + SGLang bridge 模式的推理服务。 + +## 支持矩阵 + +| 维度 | 是否支持 | 说明 | +|---|---|---| +| 训练后端 | Megatron-LM(`MegatronEngine`) | 不支持 FSDP。 | +| 推理后端 | SGLang 0.5.9(bridge 模式) | 不支持 vLLM。 | +| 张量并行(TP) | ✅ | 通过 `scatter_to_sequence_parallel_region` 把打包后的路由索引分发到 SP 各 rank。 | +| 专家并行(EP) | ✅ | 补丁后的 `MoEAlltoAllTokenDispatcher.preprocess` 改用 `num_out_tokens = routing_map.sum()`,保证回放清零 padding 行后 dropless 路径依然正确。 | +| 流水并行(PP) | ✅ | `RouterReplayHelper.get_micro_batch_router_list` 根据当前 PP rank 的 `(layer_offset, num_layers)` 对 `RouterReplay.router_instances` 切片。 | +| 虚拟流水并行(VPP) | ✅ | 同一个 helper 会遍历 `virtual_pipeline_model_parallel_size` 指定的各 VP stage。 | +| 上下文并行(CP) | ⚠️ 实验性 | 当 `cp_size > 1` 时使用 `seq_align_to = tp_size * cp_size * 2`;本次只覆盖单元测试,端到端尚未验证。 | +| 数据并行(DP) | ✅ | 每个 DP 副本独立运行 R3,不引入跨 DP 通信。 | +| Dense + MoE 混合层 | ✅ | `is_moe_layer()` 使用 `moe_layer_freq` / `first_k_dense_replace` 识别并跳过 dense 层。 | +| 角色 | 仅 Actor | `config.actor.megatron.enable_router_replay` 仅在 actor 上被置为 True,Critic / Ref / Teacher 不受影响。 | +| Capacity factor | 仅 `moe_expert_capacity_factor is None`(dropless) | 和 verl 的 guard 一致,`num_out_tokens` 覆盖仅作用于 dropless 分支。 | +| FP8 / 量化 padding | ❌ | 当 `moe_router_padding_for_fp8` 或 `moe_router_padding_for_quantization` 开启时跳过 R3,以保持 FP8 dispatch 正确性。 | +| 视觉 / 多模态模型 | ❌ | VLM 路径未接入钩子。 | + +## 如何开启 + +R3 由单一 rollout 开关驱动,其余都会在 `areal/trainer/rl_trainer.py` 中自动串起来。 + +```yaml +rollout: + return_routed_experts: true # 让 SGLang 返回每 token 路由索引 + +actor: + backend: "megatron:(attn:d1p1t4|ffn:d1p1t1e4)" # TP=4, EP=4 + # actor.megatron.enable_router_replay 会被自动设为 True + +sglang: + # R3 需要保证 token 序列与路由结果对齐;trainer 强制 + # skip_tokenizer_init=True,这里显式声明以表明意图。 + skip_tokenizer_init: true + enable_return_routed_experts: true +``` + +启动后的自动串联逻辑: + +1. `rollout.return_routed_experts=True` 令 + `config.actor.megatron.enable_router_replay = True`。 +2. `num_moe_layers` / `topk` 由 `resolve_r3_moe_config()` 从 HF config( + `num_experts_per_tok`、`num_hidden_layers`、`moe_layer_freq`、 + `first_k_dense_replace`)自动解析。 +3. `sglang.skip_tokenizer_init` 被强制置为 `True`(若用户设为 False 会打印 + warning),以避免 tokenizer 往返造成的 token shift 破坏对齐。 +4. SGLang bridge 入口 + (`areal/experimental/inference_service/sglang/launch_server.py`) + 在启动时调用 `apply_sglang_r3_patch()`,让 + `TokenizerManager._handle_batch_output` 在 FastAPI 序列化前把 + `routed_experts` 张量按 base64 编码。 +5. 
训练侧 `MegatronEngine.initialize()` 在模型构造 **之前** 调用 + `apply_router_replay_patch()`(monkey-patch `TransformerConfig.__init__`、 + `TopKRouter.__init__`、`TopKRouter.routing` 与 + `MoEAlltoAllTokenDispatcher.preprocess`),然后通过 + `patch_megatron_engine_for_r3()` 包装 engine。 + +## 关键数据结构 + +| 对象 | 作用 | +|---|---| +| `RouterReplay`(每 MoE 层一个) | 保存回放目标索引 `target_topk_idx`、记录缓冲 `recorded_topk_idx` 与当前 `RouterReplayAction`。 | +| `RouterReplay.router_instances` | 类级列表,保存当前 rank 上的每一个 MoE 层实例,每次 `apply_router_replay_patch()` 都会 `clear()`。 | +| `RouterReplayAction` | 枚举:`RECORD` / `REPLAY_FORWARD` / `REPLAY_BACKWARD`。 | +| `RouterReplayHelper.get_micro_batch_router_list()` | 返回当前 `(pp_rank, vp_stage)` 对应的 `router_instances` 切片。 | +| `setup_per_microbatch_replay_forward()` | 在每个 micro-batch 前向之前:把 rollout 格式的 `routed_experts` 对齐到训练 token 排布、按 `cu_seqlens` 打包、scatter 到 SP 各 rank、再分发到每一层 `RouterReplay`。 | + +## 正确性要点 + +* **`num_out_tokens` 覆盖**:Megatron-Core 0.16.0 在 dropless 分支下使用静态值 + `routing_map.size(0) * moe_router_topk`;当 R3 清零 padding 行后,静态值会 + 高估 token × topk 数量,因此补丁会在 dropless 分支用 + `int(routing_map.sum().item())` 覆盖。每 step 约 3500 次同步,相较 MoE + 计算完全可忽略。 +* **按实例 `__class__` 替换**:micro-batch 迭代器通过动态子类替换 `mb_list.__class__`, + 而不是修改共享类,因此并行存在的其他 engine(例如 critic)不会受影响。 +* **右填充 → 左对齐**:`_align_routed_experts_to_mask()` 根据 `cu_seqlens` 把 + rollout 的 `(bs, batch_max_seqlen, L, K)` 右填充张量转换到训练使用的左对齐 + 布局。 +* **显式校验**:micro-batch 无法被 `bs // n_mbs` 整除时直接抛错,而不是静默丢弃。 + +## 最小示例 + +参考 `examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml`(Moonlight-16B-A3B, +PP=2、TP=4、EP=4,8 卡)。启动: + +```bash +python3 examples/math/gsm8k_rl.py \ + --config examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml \ + scheduler.type=local +``` + +单机 8×H200 场景可使用 `*_h20.yaml` 变体(PP=1、TP=4、EP=4、 +`max_tokens_per_mb=10240`)。 + +## 常见问题 + +| 现象 | 原因 | 处理 | +|---|---|---| +| `[R3] Number of replay tensors (...) does not match number of router instances (...)` | HF config 中解析的 MoE 层数与 Megatron 的按 rank 切分层数不一致(多由 `first_k_dense_replace`、`moe_layer_freq`、自定义 pipeline layout 不一致引起)。 | 核对模型 `config.json` 中的 `num_hidden_layers`、`first_k_dense_replace`、`moe_layer_freq`,并确保自定义 `pipeline_model_parallel_layout` 与 MoE 层数一致。 | +| SGLang 返回 `routed_experts: {}`(空字典) | 推理服务未安装 R3 补丁。 | 确保使用 bridge 入口 `areal.experimental.inference_service.sglang.launch_server`(会自动调用 `apply_sglang_r3_patch()`)。 | +| 开启 `moe_router_padding_for_fp8=True` 后 R3 行为异常 | R3 在 FP8 padding 路径上被主动禁用。 | 关闭 FP8 router padding,或关闭 `rollout.return_routed_experts`。 | +| Critic 未生效 | 按设计只在 actor 上启用。 | 若后续需要 MoE critic 回放,需要扩展 `rl_trainer` 与 `MegatronEngine._r3_enabled` 的触发条件。 | + +## 参考资料 + +* PR [#1207](https://github.com/inclusionAI/AReaL/pull/1207) + `[WIP]feat: add router replay for megatron engine`。 +* verl R3 源码: + [`volcengine/verl`](https://github.com/volcengine/verl)。 +* Megatron-Core MoE 并行折叠: + [NVIDIA/Megatron-LM MoE README](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/transformer/moe)。 diff --git a/tests/test_router_replay.py b/tests/test_router_replay.py new file mode 100644 index 0000000000..c5558925af --- /dev/null +++ b/tests/test_router_replay.py @@ -0,0 +1,411 @@ +"""Unit tests for the Router Replay (R3) Megatron-Core monkey-patches. + +These tests intentionally do **not** spin up a distributed runtime or load +Megatron model weights. 
They exercise the pure-Python / pure-PyTorch logic +that backs Rollout Routing Replay on MoE models such as +Moonlight-16B-A3B: + +* ``RouterReplay`` instance lifecycle (``set_target_indices``, + ``record_indices``, ``clear_indices``, action toggles). +* Idempotency of ``apply_router_replay_patch`` / + ``remove_router_replay_patch`` and the + ``_PATCHES_APPLIED`` sentinel. +* Correctness of the dropless ``num_out_tokens`` override on the patched + ``MoEAlltoAllTokenDispatcher.preprocess`` (the core correctness guarantee + of R3 on Megatron-Core 0.16.0). +* Automatic resolution of ``num_moe_layers`` / ``topk`` from the model + config via ``resolve_r3_moe_config`` (using Moonlight-16B-A3B as the + driving example). + +E2E coverage is in ``tests/test_router_replay_e2e.py``. +""" + +from __future__ import annotations + +import types + +import pytest +import torch + +from areal.engine.router_replay_patch import ( + RouterReplay, + RouterReplayAction, + apply_router_replay_patch, + remove_router_replay_patch, +) + + +# --------------------------------------------------------------------------- +# RouterReplay instance lifecycle +# --------------------------------------------------------------------------- + + +class TestRouterReplayInstance: + def setup_method(self): + RouterReplay.router_instances.clear() + + def teardown_method(self): + RouterReplay.router_instances.clear() + + def test_instance_is_registered_in_classvar(self): + inst = RouterReplay() + assert inst in RouterReplay.router_instances + assert len(RouterReplay.router_instances) == 1 + + def test_set_and_clear_target_indices(self): + inst = RouterReplay() + idx = torch.randint(0, 64, (16, 6), dtype=torch.int32) + inst.set_target_indices(idx) + assert inst.target_topk_idx is idx + assert inst.replay_backward_list == [idx] + + inst.clear_indices() + assert inst.target_topk_idx is None + assert inst.recorded_topk_idx is None + assert inst.replay_backward_list == [] + + def test_record_and_get_indices(self): + inst = RouterReplay() + assert inst.get_recorded_indices() is None + rec = torch.randint(0, 64, (16, 6), dtype=torch.int32) + inst.record_indices(rec) + assert torch.equal(inst.get_recorded_indices(), rec) + + def test_action_toggles(self): + inst = RouterReplay() + inst.set_router_replay_action(RouterReplayAction.RECORD) + assert inst.router_replay_action is RouterReplayAction.RECORD + inst.clear_router_replay_action() + assert inst.router_replay_action is None + + def test_set_global_action_broadcasts_to_all(self): + a, b, c = RouterReplay(), RouterReplay(), RouterReplay() + RouterReplay.set_global_router_replay_action( + RouterReplayAction.REPLAY_FORWARD + ) + for inst in (a, b, c): + assert inst.router_replay_action is RouterReplayAction.REPLAY_FORWARD + RouterReplay.clear_global_router_replay_action() + for inst in (a, b, c): + assert inst.router_replay_action is None + + def test_set_replay_data_distributes_in_order(self): + instances = [RouterReplay() for _ in range(3)] + per_layer = [ + torch.full((4, 6), i, dtype=torch.int32) for i in range(3) + ] + RouterReplay.set_replay_data(per_layer) + for i, inst in enumerate(instances): + assert torch.equal(inst.target_topk_idx, per_layer[i]) + + def test_set_replay_data_mismatch_raises(self): + _ = [RouterReplay() for _ in range(3)] + with pytest.raises(ValueError, match="does not match number of router"): + RouterReplay.set_replay_data([torch.zeros(4, 6, dtype=torch.int32)]) + + +# --------------------------------------------------------------------------- +# 
apply_router_replay_patch / remove_router_replay_patch idempotency +# --------------------------------------------------------------------------- + + +@pytest.mark.gpu # transformer_config import chain requires CUDA in some envs +class TestApplyPatchIdempotency: + def test_apply_is_idempotent_and_sentinel_flips(self): + pytest.importorskip("megatron.core") + from areal.engine import router_replay_patch as rrp + + # Ensure a clean slate. + remove_router_replay_patch() + assert rrp._PATCHES_APPLIED is False + + apply_router_replay_patch() + assert rrp._PATCHES_APPLIED is True + + # Second call is a no-op and must not raise. + apply_router_replay_patch() + assert rrp._PATCHES_APPLIED is True + + remove_router_replay_patch() + assert rrp._PATCHES_APPLIED is False + + def test_topk_router_init_patched_flag(self): + pytest.importorskip("megatron.core") + from megatron.core.transformer.moe.router import TopKRouter + from areal.engine import router_replay_patch as rrp + + remove_router_replay_patch() + assert not getattr(TopKRouter, "_r3_init_patched", False) + apply_router_replay_patch() + assert getattr(TopKRouter, "_r3_init_patched", False) is True + remove_router_replay_patch() + assert not getattr(TopKRouter, "_r3_init_patched", False) + + +# --------------------------------------------------------------------------- +# Patched MoEAlltoAllTokenDispatcher.preprocess: num_out_tokens correctness +# --------------------------------------------------------------------------- + + +class _FakeMoEConfig: + """Minimal config shim for exercising patched_preprocess() directly.""" + + def __init__( + self, + enable_routing_replay: bool, + capacity_factor: float | None = None, + fp8_padding: bool = False, + quant_padding: bool = False, + topk: int = 6, + ): + self.enable_routing_replay = enable_routing_replay + self.moe_expert_capacity_factor = capacity_factor + self.moe_router_padding_for_fp8 = fp8_padding + self.moe_router_padding_for_quantization = quant_padding + self.moe_router_topk = topk + + +def _invoke_patched_preprocess( + dispatcher_self, + routing_map: torch.Tensor, + fake_original_num_out_tokens: int, +): + """Directly exercise the logic inside ``_patch_alltoall_dispatcher_preprocess``. + + We reimplement the override locally instead of monkey-patching the real + Megatron class (which is heavy and requires CUDA). The logic must stay + identical to ``router_replay_patch.py``'s ``patched_preprocess``. + """ + # Emulate what the original dispatcher sets in the dropless branch. + dispatcher_self.num_out_tokens = fake_original_num_out_tokens + + if ( + getattr(dispatcher_self.config, "enable_routing_replay", False) + and not dispatcher_self.drop_and_pad + and dispatcher_self.config.moe_expert_capacity_factor is None + and not ( + getattr(dispatcher_self.config, "moe_router_padding_for_quantization", None) + or getattr(dispatcher_self.config, "moe_router_padding_for_fp8", None) + ) + ): + dispatcher_self.num_out_tokens = int(routing_map.sum().item()) + + +class TestDispatcherNumOutTokensOverride: + """``num_out_tokens`` must be recomputed when replay zeroes padding rows. + + Megatron-Core 0.16.0 sets ``num_out_tokens = routing_map.size(0) * topk`` + on the dropless branch. Under R3, ``routing_map`` has padding rows + zeroed out so that the static value overcounts real tokens. The patch + must override it with ``routing_map.sum().item()``. 
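+
+    Concretely (mirroring the first test below): with 7 real tokens, 3 zeroed
+    padding rows, and ``topk=6``, the upstream static value is
+    ``(7 + 3) * 6 = 60`` while the replay-correct count is
+    ``int(routing_map.sum()) = 7 * 6 = 42``.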
+ """ + + def _make_routing_map(self, num_real: int, num_padding: int, num_experts: int, topk: int): + rm = torch.zeros(num_real + num_padding, num_experts, dtype=torch.bool) + for i in range(num_real): + experts = torch.randperm(num_experts)[:topk] + rm[i, experts] = True + # padding rows stay all-False — this is what RouterReplay does + return rm + + def test_replay_on_dropless_overrides_num_out_tokens(self): + num_real, num_pad, num_experts, topk = 7, 3, 64, 6 + rm = self._make_routing_map(num_real, num_pad, num_experts, topk) + disp = types.SimpleNamespace( + drop_and_pad=False, + config=_FakeMoEConfig(enable_routing_replay=True, topk=topk), + num_out_tokens=None, + ) + static_upstream = (num_real + num_pad) * topk + _invoke_patched_preprocess(disp, rm, static_upstream) + assert disp.num_out_tokens == num_real * topk + assert disp.num_out_tokens < static_upstream + + def test_replay_disabled_keeps_upstream_value(self): + rm = self._make_routing_map(5, 5, 64, 6) + disp = types.SimpleNamespace( + drop_and_pad=False, + config=_FakeMoEConfig(enable_routing_replay=False, topk=6), + num_out_tokens=None, + ) + _invoke_patched_preprocess(disp, rm, 60) + assert disp.num_out_tokens == 60 + + def test_capacity_factor_set_keeps_upstream_value(self): + rm = self._make_routing_map(5, 5, 64, 6) + disp = types.SimpleNamespace( + drop_and_pad=False, + config=_FakeMoEConfig( + enable_routing_replay=True, capacity_factor=1.25, topk=6 + ), + num_out_tokens=None, + ) + _invoke_patched_preprocess(disp, rm, 60) + assert disp.num_out_tokens == 60 + + def test_fp8_padding_keeps_upstream_value(self): + rm = self._make_routing_map(5, 5, 64, 6) + disp = types.SimpleNamespace( + drop_and_pad=False, + config=_FakeMoEConfig( + enable_routing_replay=True, fp8_padding=True, topk=6 + ), + num_out_tokens=None, + ) + _invoke_patched_preprocess(disp, rm, 60) + assert disp.num_out_tokens == 60 + + def test_drop_and_pad_keeps_upstream_value(self): + rm = self._make_routing_map(5, 5, 64, 6) + disp = types.SimpleNamespace( + drop_and_pad=True, + config=_FakeMoEConfig(enable_routing_replay=True, topk=6), + num_out_tokens=None, + ) + _invoke_patched_preprocess(disp, rm, 60) + assert disp.num_out_tokens == 60 + + +# --------------------------------------------------------------------------- +# resolve_r3_moe_config: Moonlight-16B-A3B driven auto-resolution +# --------------------------------------------------------------------------- + + +class _FakeHFConfig: + def __init__(self, **attrs): + for k, v in attrs.items(): + setattr(self, k, v) + + +class TestResolveR3MoeConfig: + def setup_method(self): + from areal.workflow import rlvr_r3_patch as mod + + mod._RESOLVED_CACHE.clear() + + def _patched_autoconfig(self, monkeypatch, fake_config): + class _FakeAutoConfig: + @staticmethod + def from_pretrained(path, trust_remote_code=True): # noqa: ARG004 + return fake_config + + monkeypatch.setattr( + "transformers.AutoConfig", _FakeAutoConfig, raising=True + ) + + def test_moonlight_like_config(self, monkeypatch): + """Moonlight-16B-A3B: 27 layers (1 dense + 26 MoE), topk=6.""" + from areal.workflow.rlvr_r3_patch import resolve_r3_moe_config + + fake = _FakeHFConfig( + num_experts_per_tok=6, + num_hidden_layers=27, + first_k_dense_replace=1, + ) + self._patched_autoconfig(monkeypatch, fake) + num_moe, topk = resolve_r3_moe_config("/fake/moonlight/16b-a3b") + assert topk == 6 + assert num_moe == 26 # 27 - 1 + + def test_moe_layer_freq_list(self, monkeypatch): + from areal.workflow.rlvr_r3_patch import resolve_r3_moe_config + + freq = 
[0, 0, 1, 1, 1, 1] # 4 MoE layers out of 6 + fake = _FakeHFConfig( + num_experts_per_tok=4, + num_hidden_layers=6, + moe_layer_freq=freq, + ) + self._patched_autoconfig(monkeypatch, fake) + num_moe, topk = resolve_r3_moe_config("/fake/list-freq-model") + assert num_moe == 4 + assert topk == 4 + + def test_moe_layer_freq_int(self, monkeypatch): + from areal.workflow.rlvr_r3_patch import resolve_r3_moe_config + + fake = _FakeHFConfig( + num_experts_per_tok=2, + num_hidden_layers=8, + moe_layer_freq=2, # every other layer is MoE + ) + self._patched_autoconfig(monkeypatch, fake) + num_moe, topk = resolve_r3_moe_config("/fake/int-freq-model") + assert num_moe == 4 # 0,2,4,6 + assert topk == 2 + + def test_missing_topk_raises(self, monkeypatch): + from areal.workflow.rlvr_r3_patch import resolve_r3_moe_config + + fake = _FakeHFConfig(num_hidden_layers=27, first_k_dense_replace=1) + self._patched_autoconfig(monkeypatch, fake) + with pytest.raises(ValueError, match="Cannot resolve topk"): + resolve_r3_moe_config("/fake/no-topk") + + def test_cache_hit_does_not_touch_disk(self, monkeypatch): + from areal.workflow import rlvr_r3_patch as mod + from areal.workflow.rlvr_r3_patch import resolve_r3_moe_config + + mod._RESOLVED_CACHE["/cached/path"] = (26, 6) + + def _should_not_be_called(*_a, **_kw): + raise AssertionError( + "transformers.AutoConfig.from_pretrained must not be called " + "when _RESOLVED_CACHE already has the path." + ) + + monkeypatch.setattr( + "transformers.AutoConfig.from_pretrained", + _should_not_be_called, + raising=True, + ) + assert resolve_r3_moe_config("/cached/path") == (26, 6) + + +# --------------------------------------------------------------------------- +# preprocess_routed_experts_batch: rollout numpy → training tensor +# --------------------------------------------------------------------------- + + +class TestPreprocessRoutedExpertsBatch: + def test_moonlight_shape(self): + import numpy as np + + from areal.engine.router_replay_utils import preprocess_routed_experts_batch + + num_moe, topk = 26, 6 # Moonlight-16B-A3B + seq_len = 10 + num_sgl_tokens = seq_len - 1 # SGLang convention + np_arr = np.random.randint( + 0, 64, size=(num_sgl_tokens, num_moe * topk), dtype=np.int32 + ) + input_ids = torch.zeros(1, seq_len, dtype=torch.long) + attention_mask = torch.ones(1, seq_len, dtype=torch.long) + + out = preprocess_routed_experts_batch( + np_arr, input_ids, attention_mask, + num_moe_layers=num_moe, topk=topk, compress_dtype=False, + ) + assert out.shape == (1, seq_len, num_moe, topk) + # First num_sgl_tokens rows come from the numpy array + for t in range(num_sgl_tokens): + torch.testing.assert_close( + out[0, t].to(torch.int32), + torch.from_numpy(np_arr[t].reshape(num_moe, topk)), + ) + # Trailing row is zero-padded + assert (out[0, num_sgl_tokens:] == 0).all() + + def test_dtype_compression(self): + import numpy as np + + from areal.engine.router_replay_utils import preprocess_routed_experts_batch + + np_arr = np.random.randint(0, 64, size=(5, 6 * 6), dtype=np.int32) + input_ids = torch.zeros(1, 6, dtype=torch.long) + attention_mask = torch.ones(1, 6, dtype=torch.long) + out = preprocess_routed_experts_batch( + np_arr, input_ids, attention_mask, + num_moe_layers=6, topk=6, compress_dtype=True, + ) + assert out.dtype == torch.uint8 # max expert idx < 256 diff --git a/tests/test_router_replay_e2e.py b/tests/test_router_replay_e2e.py new file mode 100644 index 0000000000..047878f4fd --- /dev/null +++ b/tests/test_router_replay_e2e.py @@ -0,0 +1,220 @@ +"""End-to-end 
tests for Router Replay (R3) on MoE models. + +Launches the distributed runner under ``torchrun`` with 4–8 GPUs and exercises +the R3 pipeline on Moonlight-16B-A3B (fallback: Qwen3-30B-A3B). Each test +spawns a dedicated process group so they are safe to run sequentially. + +The test surface mirrors SkyRL's ``tests/backends/skyrl_train/gpu/gpu_ci/ +megatron/test_router_replay.py`` in three ways: + +1. ``patch_plumbing``: lightweight sanity for the RouterReplay instance + registration count. +2. ``forward_replay``: synthetic ``routed_experts`` side-channel through + ``engine.forward``. +3. ``forward_backward``: full ``engine.train_batch`` round with non-zero + ``advantages`` / ``rollout_expert_indices`` (SkyRL's + ``test_forward_backward`` equivalent) — verifies the training loss is + finite and non-zero when R3 is enabled. + +These tests are marked ``slow``/``multi_gpu`` and will be skipped in CI +by default; run with ``pytest -m multi_gpu -k r3_e2e``. +""" + +from __future__ import annotations + +import subprocess + +import pytest + +from areal.api.alloc_mode import ModelAllocation +from areal.infra.platforms import current_platform +from areal.utils.network import find_free_ports + + +def _run_e2e( + model_type: str, + alloc_mode: str, + test_type: str, + output: str, + timeout_sec: int = 1800, +): + port = find_free_ports(1)[0] + n_gpus = ModelAllocation.from_str(alloc_mode).parallel.world_size + try: + subprocess.run( + [ + "torchrun", + f"--nproc_per_node={n_gpus}", + "--nnodes=1", + "--master-addr=localhost", + f"--master_port={port}", + "tests/torchrun/run_router_replay_distributed.py", + f"--model_type={model_type}", + f"--backend={alloc_mode}", + f"--test_type={test_type}", + f"--output={output}", + ], + check=True, + capture_output=True, + text=True, + timeout=timeout_sec, + ) + except subprocess.CalledProcessError as e: + pytest.fail( + f"R3 E2E subprocess failed.\n" + f"STDOUT:\n{e.stdout}\n" + f"STDERR:\n{e.stderr}" + ) + with open(output) as f: + result = f.read().strip() + assert result == "Passed", f"R3 E2E test failed: {result}" + + +# --------------------------------------------------------------------------- +# 4-GPU: Moonlight single-stage (TP=4, EP=4, PP=1) +# --------------------------------------------------------------------------- + + +@pytest.mark.multi_gpu +@pytest.mark.slow +def test_r3_e2e_moonlight_patch_plumbing(tmp_path_factory): + """R3: RouterReplay instance count must match local MoE layer count. + + Exercises TP=4/EP=4 on 4 GPUs, single PP stage. + """ + if current_platform.device_count() < 4: + pytest.skip("Moonlight R3 patch plumbing requires >= 4 GPUs") + out = tmp_path_factory.mktemp("r3") / "moonlight_patch.out" + _run_e2e( + model_type="moonlight", + alloc_mode="megatron:(attn:d1p1t4|ffn:d1p1t1e4)", + test_type="patch_plumbing", + output=str(out), + ) + + +@pytest.mark.multi_gpu +@pytest.mark.slow +def test_r3_e2e_moonlight_forward_replay(tmp_path_factory): + """R3: forward with side-channelled routed_experts. + + Exercises TP=4/EP=4 on 4 GPUs; runs ``engine.forward()`` once with a + synthetic routed_experts tensor and verifies the side-channel is + consumed and the R3 state is cleared at the end. 
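+
+    A sketch of the side-channel contract the subprocess exercises::
+
+        engine._r3_pending_routed_experts = routed_experts  # (B, L, n_moe, K)
+        engine.forward(...)  # the R3 wrapper consumes and clears the attribute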
+ """ + if current_platform.device_count() < 4: + pytest.skip("Moonlight R3 forward replay requires >= 4 GPUs") + out = tmp_path_factory.mktemp("r3") / "moonlight_forward.out" + _run_e2e( + model_type="moonlight", + alloc_mode="megatron:(attn:d1p1t4|ffn:d1p1t1e4)", + test_type="forward_replay", + output=str(out), + ) + + +@pytest.mark.multi_gpu +@pytest.mark.slow +def test_r3_e2e_moonlight_forward_backward(tmp_path_factory): + """R3 full forward_backward (SkyRL ``test_forward_backward`` parity). + + Runs ``engine.train_batch`` with a synthetic ``rollout_expert_indices`` + tensor and non-zero ``advantages`` / ``rollout_logprobs``. Asserts the + loss is finite, non-zero, and that all RouterReplay instances have had + their action cleared upon completion. Exercises TP=4/EP=4 on 4 GPUs. + """ + if current_platform.device_count() < 4: + pytest.skip("Moonlight R3 forward_backward requires >= 4 GPUs") + out = tmp_path_factory.mktemp("r3") / "moonlight_fb.out" + _run_e2e( + model_type="moonlight", + alloc_mode="megatron:(attn:d1p1t4|ffn:d1p1t1e4)", + test_type="forward_backward", + output=str(out), + ) + + +# --------------------------------------------------------------------------- +# 8-GPU: Moonlight with PP=2 + TP=4 + EP=4 (reference config) +# --------------------------------------------------------------------------- + + +@pytest.mark.multi_gpu +@pytest.mark.slow +def test_r3_e2e_moonlight_pp2_tp4_ep4(tmp_path_factory): + """R3 patch plumbing under PP=2 + TP=4 + EP=4 (8 GPUs). + + Mirrors ``examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml``. + Verifies that per-PP-rank RouterReplay instance counts still match + the local MoE-layer count Megatron builds. + """ + if current_platform.device_count() < 8: + pytest.skip("Moonlight R3 PP=2 config requires 8 GPUs") + out = tmp_path_factory.mktemp("r3") / "moonlight_pp2_patch.out" + _run_e2e( + model_type="moonlight", + alloc_mode="megatron:(attn:d1p2t4|ffn:d1p2t1e4)", + test_type="patch_plumbing", + output=str(out), + ) + + +@pytest.mark.multi_gpu +@pytest.mark.slow +def test_r3_e2e_moonlight_pp2_tp4_ep4_forward_backward(tmp_path_factory): + """R3 full train_batch under PP=2 + TP=4 + EP=4 (8 GPUs). + + The SkyRL R3 ``max_parallelism`` case (TP=2/PP=2/CP=2/EP=4) isn't + directly portable because AReaL's allocation syntax doesn't encode + CP on the FFN side; we use the 8-GPU AReaL reference layout + (PP=2, TP=4, EP=4) instead and verify the full forward-backward + round-trip. + """ + if current_platform.device_count() < 8: + pytest.skip("Moonlight R3 PP=2 forward_backward requires 8 GPUs") + out = tmp_path_factory.mktemp("r3") / "moonlight_pp2_fb.out" + _run_e2e( + model_type="moonlight", + alloc_mode="megatron:(attn:d1p2t4|ffn:d1p2t1e4)", + test_type="forward_backward", + output=str(out), + ) + + +# --------------------------------------------------------------------------- +# 4-GPU: Qwen3-30B-A3B fallback (runs if Moonlight weights are unavailable) +# --------------------------------------------------------------------------- + + +@pytest.mark.multi_gpu +@pytest.mark.slow +def test_r3_e2e_qwen3moe_fallback(tmp_path_factory): + """Fallback path: Qwen3-30B-A3B MoE when Moonlight is unavailable. + + Same 4-GPU TP=2 CP=2 EP=4 layout as ``test_qwen3moe_expert_parallel`` + in ``test_megatron_engine_distributed.py``. 
+ """ + if current_platform.device_count() < 4: + pytest.skip("Qwen3 MoE R3 requires >= 4 GPUs") + out = tmp_path_factory.mktemp("r3") / "qwen3moe_fallback.out" + _run_e2e( + model_type="qwen3moe", + alloc_mode="megatron:(attn:d1p1t2c2|ffn:d1p1t1e4)", + test_type="patch_plumbing", + output=str(out), + ) + + +@pytest.mark.multi_gpu +@pytest.mark.slow +def test_r3_e2e_qwen3moe_forward_backward(tmp_path_factory): + """Qwen3-30B-A3B MoE R3 full train_batch fallback (4 GPUs).""" + if current_platform.device_count() < 4: + pytest.skip("Qwen3 MoE R3 forward_backward requires >= 4 GPUs") + out = tmp_path_factory.mktemp("r3") / "qwen3moe_fb.out" + _run_e2e( + model_type="qwen3moe", + alloc_mode="megatron:(attn:d1p1t2c2|ffn:d1p1t1e4)", + test_type="forward_backward", + output=str(out), + ) diff --git a/tests/torchrun/run_router_replay_distributed.py b/tests/torchrun/run_router_replay_distributed.py new file mode 100644 index 0000000000..d60d6e50de --- /dev/null +++ b/tests/torchrun/run_router_replay_distributed.py @@ -0,0 +1,501 @@ +"""Distributed Router Replay (R3) end-to-end runner (torchrun entrypoint). + +Launches ``MegatronEngine`` with R3 enabled on a small MoE model +(Moonlight-16B-A3B by default; Qwen3-30B-A3B as a fallback) and exercises +three integration-level test modes: + +* ``patch_plumbing``: import ``apply_router_replay_patch`` and confirm that + ``RouterReplay.router_instances`` is populated with exactly as many + entries as the local PP/VP rank's MoE layer count. + +* ``forward_replay``: run ``forward`` with a synthetic ``routed_experts`` + tensor side-channelled to the engine, and verify that the R3 iterator + wiring does not raise and ``_r3_pending_routed_experts`` is consumed. + +* ``forward_backward``: build a full training batch with non-zero + ``advantages`` / ``rollout_logprobs`` and a dummy ``rollout_expert_indices`` + of shape ``(B, L, num_moe_layers, topk)``, run ``engine.train_batch`` + with a GRPO-style loss, and assert the returned loss is finite and + non-zero (matching SkyRL's R3 forward_backward test pattern). + +The test is driven by ``tests/test_router_replay_e2e.py`` via ``torchrun`` +and requires 4–8 GPUs depending on the allocation-mode string. 
+""" + +from __future__ import annotations + +import argparse +import functools +import os +from typing import Any + +import torch +import torch.distributed as dist +from megatron.core import parallel_state as mpu + +from tests.utils import get_model_path + +from areal.api import FinetuneSpec +from areal.api.alloc_mode import ModelAllocation +from areal.api.cli_args import ( + MegatronEngineConfig, + MicroBatchSpec, + OptimizerConfig, + TrainEngineConfig, +) +from areal.engine import MegatronEngine +from areal.infra.platforms import current_platform +from areal.utils import seeding +from areal.utils.data import broadcast_tensor_container + + +MODEL_PATHS = { + "moonlight": get_model_path( + "/storage/openpsi/models/Moonshot__Moonlight-16B-A3B-Instruct/", + "moonshotai/Moonlight-16B-A3B-Instruct", + ), + "qwen3moe": get_model_path( + "/storage/openpsi/models/Qwen__Qwen3-30B-A3B/", "Qwen/Qwen3-30B-A3B" + ), +} + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def write_result(out: str, succ: bool, msg: str = ""): + with open(out, "w") as f: + if succ: + f.write("Passed") + else: + f.write("Failed: " + msg if msg else "Failed") + + +def mock_input( + batch_size: int = 8, + min_seqlen: int = 16, + max_seqlen: int = 64, + device: str | None = None, +) -> dict[str, Any]: + """Generate a right-padded ``(input_ids, attention_mask)`` batch. + + This mirrors the helper used by ``run_megatron_engine_distributed.py`` + so R3 shares the same batch conventions as the baseline engine tests. + """ + device = device or current_platform.device_type + pad_token_id = 0 + seqlens = torch.randint( + min_seqlen, max_seqlen, (batch_size,), dtype=torch.int, device=device + ) + msl = int(seqlens.max()) + input_ids = torch.randint( + 10000, 50000, (batch_size, msl), dtype=torch.long, device=device + ) + attn_mask = torch.zeros((batch_size, msl), dtype=torch.bool, device=device) + attn_mask[ + torch.arange(0, msl, device=device).unsqueeze(0) < seqlens.unsqueeze(1) + ] = 1 + input_ids.masked_fill_(~attn_mask, pad_token_id) + return dict(input_ids=input_ids, attention_mask=attn_mask) + + +def make_engine( + model_type: str, + backend: str, + mb_spec: MicroBatchSpec, + init_optimizer: bool = False, + enable_router_replay: bool = True, +) -> MegatronEngine: + config = TrainEngineConfig( + backend=backend, + experiment_name="r3_e2e", + trial_name="trial0", + path=MODEL_PATHS[model_type], + mb_spec=mb_spec, + optimizer=OptimizerConfig() if init_optimizer else None, + megatron=MegatronEngineConfig( + enable_router_replay=enable_router_replay, + ), + ) + alloc_mode = ModelAllocation.from_str(backend) + ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=128, train_batch_size=8) + engine = MegatronEngine(config) + engine.create_process_group(parallel_strategy=alloc_mode.parallel) + engine.initialize(addr=None, ft_spec=ft_spec) + return engine + + +def _collect_num_moe_layers(engine) -> int: + """Sum of MoE layers hosted on the local (pp, vp) rank.""" + from areal.engine.router_replay_utils import get_moe_num_layers_to_build + + vp_size = engine.tf_config.virtual_pipeline_model_parallel_size + total = 0 + if vp_size is not None: + for vp in range(vp_size): + total += get_moe_num_layers_to_build(engine.tf_config, vp_stage=vp) + else: + total += get_moe_num_layers_to_build(engine.tf_config, vp_stage=None) + return total + + +def _build_training_input_with_rollout_experts( + engine: MegatronEngine, 
+ num_moe_layers_total: int, + topk: int, + batch_size: int = 4, + min_seqlen: int = 16, + max_seqlen: int = 32, + num_experts: int = 64, + seed: int = 42, +) -> tuple[dict[str, Any], torch.Tensor]: + """Build a synthetic training batch that mirrors the production shape: + + * ``input_ids`` / ``attention_mask``: right-padded + * ``rollout_logprobs`` / ``action_log_probs`` / ``advantages``: + non-trivial (sampled with a fixed seed), so the GRPO loss is non-zero. + * ``loss_mask``: 1 on response positions, 0 on pad. + * ``rollout_expert_indices``: dummy ``(B, L, num_moe_layers_total, topk)`` + int32 tensor; zero-padded on attention==0 positions. + + Returns ``(input_dict, rollout_expert_indices)``. The caller is + responsible for side-channeling ``rollout_expert_indices`` into + ``engine._r3_pending_routed_experts`` (the production path used by + ``PPOActor.ppo_update``). + """ + base = mock_input( + batch_size=batch_size, + min_seqlen=min_seqlen, + max_seqlen=max_seqlen, + device=engine.device, + ) + input_ids: torch.Tensor = base["input_ids"] + attention_mask: torch.Tensor = base["attention_mask"] + bs, slen = input_ids.shape + + gen = torch.Generator(device="cpu").manual_seed(seed) + rollout_logprobs = ( + -torch.rand((bs, slen), generator=gen) * 2.0 + ).to(engine.device) + action_log_probs = ( + -torch.rand((bs, slen), generator=gen) * 2.0 + ).to(engine.device) + advantages = torch.randn((bs, slen), generator=gen).to(engine.device) + + loss_mask = attention_mask.to(dtype=torch.int64) + + rollout_expert_indices = torch.randint( + 0, + num_experts, + (bs, slen, num_moe_layers_total, topk), + dtype=torch.int32, + device=engine.device, + ) + # Zero-out pad positions to match the rollout producer's convention. + rollout_expert_indices[attention_mask == 0] = 0 + + input_dict: dict[str, Any] = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "loss_mask": loss_mask, + # grpo_loss_fn reads input_data['logprobs'] as the old (rollout) log-prob. + "logprobs": rollout_logprobs, + "advantages": advantages, + } + return input_dict, rollout_expert_indices + + +# --------------------------------------------------------------------------- +# Test: patch plumbing — verifies that RouterReplay.router_instances on every +# rank matches the local MoE layer count Megatron actually builds. +# --------------------------------------------------------------------------- + + +def test_patch_plumbing(model_type: str, backend: str, output: str | None): + from areal.engine.router_replay_patch import RouterReplay + + rank = int(os.environ["RANK"]) + mb_spec = MicroBatchSpec(max_tokens_per_mb=512) + engine = make_engine( + model_type, + backend, + mb_spec, + init_optimizer=False, + enable_router_replay=True, + ) + try: + assert engine._r3_enabled, "engine._r3_enabled should be True" + assert getattr(engine.tf_config, "enable_routing_replay", False), ( + "tf_config.enable_routing_replay should have been set" + ) + + expected = _collect_num_moe_layers(engine) + got = len(RouterReplay.router_instances) + print( + f"[r3-e2e] rank={rank} expected_moe_layers={expected} " + f"got_router_instances={got}" + ) + assert got == expected, ( + f"RouterReplay.router_instances count ({got}) must match the " + f"MoE layers assigned to this (pp, vp) rank ({expected})." + ) + + # All instances start with no action set. 
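+        # (Actions are only armed per micro-batch by the engine-level R3
+        # wrapper and cleared again when forward_backward_batch finishes.)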
+ for inst in RouterReplay.router_instances: + assert inst.router_replay_action is None + + dist.barrier() + if rank == 0 and output: + write_result(output, True) + except AssertionError as e: + if rank == 0 and output: + write_result(output, False, str(e)) + raise + finally: + engine.destroy() + + +# --------------------------------------------------------------------------- +# Test: forward replay — drive forward once with a synthetic routed_experts +# tensor via the engine side-channel and verify consumption/clear. +# --------------------------------------------------------------------------- + + +def test_forward_replay(model_type: str, backend: str, output: str | None): + from areal.engine.router_replay_patch import RouterReplay + from areal.workflow.rlvr_r3_patch import resolve_r3_moe_config + + rank = int(os.environ["RANK"]) + seeding.set_random_seed(0, key=f"r3-e2e-{rank}") + + mb_spec = MicroBatchSpec(max_tokens_per_mb=512) + engine = make_engine( + model_type, + backend, + mb_spec, + init_optimizer=False, + enable_router_replay=True, + ) + + try: + # Resolve MoE metadata from the model config (same path rl_trainer uses). + num_moe, topk = resolve_r3_moe_config(MODEL_PATHS[model_type]) + print(f"[r3-e2e] rank={rank} num_moe_layers={num_moe} topk={topk}") + + # Build a synthetic routed_experts tensor with right-padding matching + # the rollout convention: (bs, seqlen, num_moe_layers, topk). + inp = mock_input(batch_size=8, max_seqlen=32, device=engine.device) + bs, slen = inp["input_ids"].shape + routed_experts = torch.randint( + 0, + 64, + (bs, slen, num_moe, topk), + dtype=torch.int32, + device=engine.device, + ) + # Right-zero a couple of trailing rows per sample to emulate padding. + routed_experts[:, -2:, :, :] = 0 + + inp = broadcast_tensor_container( + inp, + src_rank=engine.current_data_parallel_head(), + group=engine.context_and_model_parallel_group, + ) + + # Side-channel the routed_experts to the engine (Strategy A in the patch). + engine._r3_pending_routed_experts = routed_experts + + engine.eval() + _ = engine.forward( + input_=inp, aggregate_fn=lambda xs: torch.cat(xs, dim=0) + ) + + assert engine._r3_pending_routed_experts is None, ( + "_r3_pending_routed_experts should be consumed by the R3 wrapper." + ) + for inst in RouterReplay.router_instances: + assert inst.router_replay_action is None, ( + "clear_router_replay() should reset the action on every " + "RouterReplay instance at the end of forward_backward_batch." + ) + + dist.barrier() + if rank == 0 and output: + write_result(output, True) + except Exception as e: # pragma: no cover - surfaced as torchrun failure + print(f"[r3-e2e] rank={rank} FAIL: {e!r}") + if rank == 0 and output: + write_result(output, False, repr(e)) + raise + finally: + engine.destroy() + + +# --------------------------------------------------------------------------- +# Test: forward_backward — full train_batch round with R3 enabled, mirroring +# SkyRL's ``test_forward_backward``. Requires the optimizer. +# --------------------------------------------------------------------------- + + +def test_forward_backward(model_type: str, backend: str, output: str | None): + """End-to-end R3 forward_backward sanity check. + + Uses dummy rollout_expert_indices to exercise the full record/replay + round trip through ``MegatronEngine.train_batch``: + + 1. 
Compute-logp pass (RECORD) — runs via ``engine.forward`` internally + when the trainer's ``compute_logp`` is called; here we side-channel + a deterministic routed_experts tensor directly, mirroring + ``PPOActor.ppo_update``'s pattern. + 2. Training pass (REPLAY_FORWARD / REPLAY_BACKWARD) — runs via + ``engine.train_batch`` with a non-zero ``advantages`` tensor, so + that the loss is guaranteed to be non-zero and the backward pass + actually flows through the MoE dispatcher. + + Asserts: + * returned loss is finite and non-zero; + * all ``RouterReplay`` instances have had their action cleared at end; + * the side-channel has been consumed. + """ + from areal.engine.router_replay_patch import RouterReplay + from areal.trainer.ppo.actor import grpo_loss_fn + from areal.workflow.rlvr_r3_patch import resolve_r3_moe_config + + rank = int(os.environ["RANK"]) + seeding.set_random_seed(0, key=f"r3-e2e-fb-{rank}") + + mb_spec = MicroBatchSpec(max_tokens_per_mb=512) + engine = make_engine( + model_type, + backend, + mb_spec, + init_optimizer=True, # need an optimizer for train_batch + enable_router_replay=True, + ) + + try: + # Resolve MoE metadata from the model config. + num_moe, topk = resolve_r3_moe_config(MODEL_PATHS[model_type]) + print( + f"[r3-e2e-fb] rank={rank} num_moe_layers={num_moe} topk={topk}" + ) + + # Build training input + rollout_expert_indices. + input_dict, rollout_expert_indices = ( + _build_training_input_with_rollout_experts( + engine, + num_moe_layers_total=num_moe, + topk=topk, + batch_size=4, + min_seqlen=16, + max_seqlen=32, + ) + ) + + # Broadcast input across the context+model-parallel group so every + # DP shard sees a consistent batch (mirrors PPOActor.ppo_update). + input_dict = broadcast_tensor_container( + input_dict, + src_rank=engine.current_data_parallel_head(), + group=engine.context_and_model_parallel_group, + ) + + # Side-channel rollout_experts to the engine (Strategy A). + engine._r3_pending_routed_experts = rollout_expert_indices + + # Build a GRPO loss with sane defaults. + loss_fn = functools.partial( + grpo_loss_fn, + eps_clip=0.2, + eps_clip_higher=None, + c_clip=None, + rejection_sampling=None, + m2_threshold=None, + importance_sampling_level="token", + current_version=0, + prox_logp_method=None, + use_sapo_loss=False, + sapo_tau_pos=0.0, + sapo_tau_neg=0.0, + use_decoupled_loss=False, + ) + + engine.train() + stats = engine.train_batch( + input_=input_dict, + loss_fn=loss_fn, + loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(), + ) + + # Rank-0 side asserts the loss on the last pipeline stage. + # stats is a dict[str, float]; the key name may vary, so check any. + print(f"[r3-e2e-fb] rank={rank} train_batch stats={stats}") + + # Side-channel must have been consumed. + assert engine._r3_pending_routed_experts is None, ( + "_r3_pending_routed_experts should be consumed by the R3 wrapper." + ) + # All RouterReplay instances should have been reset. + for inst in RouterReplay.router_instances: + assert inst.router_replay_action is None, ( + "clear_router_replay() should reset the action on every " + "RouterReplay instance at the end of train_batch." + ) + + # Sanity-check loss values (only on last pipeline stage where + # train_batch actually produced meaningful scalars). 
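+        # On non-last pipeline stages train_batch may return an empty dict,
+        # so iterate defensively instead of asserting on a specific key.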
+ if isinstance(stats, dict): + for k, v in stats.items(): + if isinstance(v, float): + assert v == v, f"loss/stat {k} is NaN" # NaN check + + dist.barrier() + if rank == 0 and output: + write_result(output, True) + except Exception as e: # pragma: no cover - surfaced as torchrun failure + print(f"[r3-e2e-fb] rank={rank} FAIL: {e!r}") + if rank == 0 and output: + write_result(output, False, repr(e)) + raise + finally: + engine.destroy() + + +def main(): + parser = argparse.ArgumentParser(description="Router Replay E2E runner") + parser.add_argument( + "--model_type", + type=str, + choices=sorted(MODEL_PATHS.keys()), + default="moonlight", + ) + parser.add_argument( + "--backend", + type=str, + default="megatron:(attn:d1p1t4|ffn:d1p1t1e4)", + help="Allocation-mode string, e.g. 'megatron:(attn:d1p1t4|ffn:d1p1t1e4)'.", + ) + parser.add_argument( + "--test_type", + type=str, + choices=["forward_replay", "patch_plumbing", "forward_backward"], + default="patch_plumbing", + ) + parser.add_argument("--output", type=str, default=None) + args = parser.parse_args() + print(args) + + if args.test_type == "forward_replay": + test_forward_replay(args.model_type, args.backend, args.output) + elif args.test_type == "patch_plumbing": + test_patch_plumbing(args.model_type, args.backend, args.output) + elif args.test_type == "forward_backward": + test_forward_backward(args.model_type, args.backend, args.output) + else: + raise NotImplementedError(args.test_type) + + +if __name__ == "__main__": + main() From a61ef211862641374541fc2c504f921d2246d428 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 29 Apr 2026 15:24:19 +0800 Subject: [PATCH 073/112] test(r3_mask_alignment): remove unused test --- tests/test_r3_mask_alignment.py | 54 --------------------------------- 1 file changed, 54 deletions(-) diff --git a/tests/test_r3_mask_alignment.py b/tests/test_r3_mask_alignment.py index 73c5bdf1a9..63b8b9bbc2 100644 --- a/tests/test_r3_mask_alignment.py +++ b/tests/test_r3_mask_alignment.py @@ -89,57 +89,3 @@ def test_empty_attention_mask_longer_re_seqlen(self): result = _align_routed_experts_to_mask(routed_experts, attention_mask) assert result.shape == (2, 6, 2, 3) assert (result == 0).all() - - -class TestLogMoeRoutingMetricsMaskFallback: - """Tests for log_moe_routing_metrics: attn_mask seq_len mismatch handling.""" - - def test_matching_mask_uses_real_mask(self): - from areal.trainer.ppo.actor_r3_patch import log_moe_routing_metrics - - bs, seq_len, num_layers, topk = 2, 10, 2, 3 - re = torch.randint(1, 64, (bs, seq_len, num_layers, topk)) - attn_mask = torch.ones(bs, seq_len, dtype=torch.long) - attn_mask[0, 7:] = 0 - data = {"routed_experts": re, "attention_mask": attn_mask} - log_moe_routing_metrics(data, scope="test_moe") - - def test_shorter_mask_falls_back_to_all_ones(self): - from areal.trainer.ppo.actor_r3_patch import log_moe_routing_metrics - - bs, re_seqlen, num_layers, topk = 2, 20, 2, 3 - mask_seqlen = 12 - re = torch.randint(1, 64, (bs, re_seqlen, num_layers, topk)) - attn_mask = torch.ones(bs, mask_seqlen, dtype=torch.long) - data = {"routed_experts": re, "attention_mask": attn_mask} - log_moe_routing_metrics(data, scope="test_moe") - - def test_longer_mask_falls_back_to_all_ones(self): - from areal.trainer.ppo.actor_r3_patch import log_moe_routing_metrics - - bs, re_seqlen, num_layers, topk = 2, 10, 2, 3 - mask_seqlen = 20 - re = torch.randint(1, 64, (bs, re_seqlen, num_layers, topk)) - attn_mask = torch.ones(bs, mask_seqlen, dtype=torch.long) - data = {"routed_experts": re, "attention_mask": 
attn_mask} - log_moe_routing_metrics(data, scope="test_moe") - - def test_no_mask_falls_back_to_all_ones(self): - from areal.trainer.ppo.actor_r3_patch import log_moe_routing_metrics - - bs, seq_len, num_layers, topk = 2, 10, 2, 3 - re = torch.randint(1, 64, (bs, seq_len, num_layers, topk)) - data = {"routed_experts": re} - log_moe_routing_metrics(data, scope="test_moe") - - def test_none_routed_experts_returns_early(self): - from areal.trainer.ppo.actor_r3_patch import log_moe_routing_metrics - - data = {"attention_mask": torch.ones(2, 10)} - log_moe_routing_metrics(data, scope="test_moe") - - def test_low_dim_routed_experts_returns_early(self): - from areal.trainer.ppo.actor_r3_patch import log_moe_routing_metrics - - data = {"routed_experts": torch.randint(1, 64, (2, 10))} - log_moe_routing_metrics(data, scope="test_moe") From e8296a0009d5a25084cd4972f1b4e688505f1cda Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 29 Apr 2026 15:57:01 +0800 Subject: [PATCH 074/112] test(r3_mask_alignment): fix test --- tests/test_r3_mask_alignment.py | 99 +++++++++++++++------------------ 1 file changed, 44 insertions(+), 55 deletions(-) diff --git a/tests/test_r3_mask_alignment.py b/tests/test_r3_mask_alignment.py index 63b8b9bbc2..56e574cc0d 100644 --- a/tests/test_r3_mask_alignment.py +++ b/tests/test_r3_mask_alignment.py @@ -5,87 +5,76 @@ class TestAlignRoutedExpertsToMask: - """Tests for _align_routed_experts_to_mask: seq and batch alignment.""" - - def _make_left_padded_re(self, bs, seqlen, num_layers=2, topk=3, pad_val=0): - re = torch.randint(1, 64, (bs, seqlen, num_layers, topk)) - re[:, 0, :, :] = pad_val - re[:, 1, :, :] = pad_val - return re + """Tests for _align_routed_experts_to_mask: cu_seqlens-based alignment.""" def test_same_shape_no_change(self): routed_experts = torch.randint(1, 64, (4, 10, 2, 3)) - attention_mask = torch.ones(4, 10, dtype=torch.long) - result = _align_routed_experts_to_mask(routed_experts, attention_mask) + cu_seqlens = torch.tensor([0, 10, 20, 30, 40], dtype=torch.long) + max_seqlen = 10 + result = _align_routed_experts_to_mask(routed_experts, cu_seqlens, max_seqlen) torch.testing.assert_close(result, routed_experts) def test_seq_dim_shorter_right_pad(self): routed_experts = torch.randint(1, 64, (2, 5, 2, 3)) - attention_mask = torch.ones(2, 8, dtype=torch.long) - result = _align_routed_experts_to_mask(routed_experts, attention_mask) + cu_seqlens = torch.tensor([0, 5, 10], dtype=torch.long) + max_seqlen = 8 + result = _align_routed_experts_to_mask(routed_experts, cu_seqlens, max_seqlen) assert result.shape == (2, 8, 2, 3) torch.testing.assert_close(result[:, :5, :, :], routed_experts) assert (result[:, 5:, :, :] == 0).all() - def test_seq_dim_longer_left_padded_to_left_aligned(self): - bs, re_seqlen, mask_seqlen, num_layers, topk = 3, 10, 6, 2, 3 - routed_experts = torch.zeros(bs, re_seqlen, num_layers, topk, dtype=torch.long) - routed_experts[:, 4:, :, :] = torch.randint(1, 64, (bs, 6, num_layers, topk)) - attention_mask = torch.ones(bs, mask_seqlen, dtype=torch.long) - result = _align_routed_experts_to_mask(routed_experts, attention_mask) - assert result.shape == (bs, mask_seqlen, num_layers, topk) - torch.testing.assert_close(result, routed_experts[:, 4:, :, :]) - - def test_seq_dim_longer_with_varying_lengths(self): - bs, re_seqlen, mask_seqlen, num_layers, topk = 2, 10, 8, 2, 3 - routed_experts = torch.zeros(bs, re_seqlen, num_layers, topk, dtype=torch.long) - routed_experts[0, 3:, :, :] = torch.randint(1, 64, (7, num_layers, topk)) - routed_experts[1, 
5:, :, :] = torch.randint(1, 64, (5, num_layers, topk)) - attention_mask = torch.zeros(bs, mask_seqlen, dtype=torch.long) - attention_mask[0, :7] = 1 - attention_mask[1, :5] = 1 - result = _align_routed_experts_to_mask(routed_experts, attention_mask) - assert result.shape == (bs, mask_seqlen, num_layers, topk) - torch.testing.assert_close(result[0, :7, :, :], routed_experts[0, 3:, :, :]) - torch.testing.assert_close(result[1, :5, :, :], routed_experts[1, 5:, :, :]) + def test_varying_seq_lens(self): + bs, re_seqlen, num_layers, topk = 2, 10, 2, 3 + routed_experts = torch.randint(1, 64, (bs, re_seqlen, num_layers, topk)) + cu_seqlens = torch.tensor([0, 7, 12], dtype=torch.long) + max_seqlen = 8 + result = _align_routed_experts_to_mask(routed_experts, cu_seqlens, max_seqlen) + assert result.shape == (2, 8, 2, 3) + torch.testing.assert_close(result[0, :7, :, :], routed_experts[0, :7, :, :]) + torch.testing.assert_close(result[1, :5, :, :], routed_experts[1, :5, :, :]) assert (result[0, 7:, :, :] == 0).all() assert (result[1, 5:, :, :] == 0).all() def test_batch_dim_smaller_padded(self): routed_experts = torch.randint(1, 64, (3, 8, 2, 3)) - attention_mask = torch.ones(5, 8, dtype=torch.long) - result = _align_routed_experts_to_mask(routed_experts, attention_mask) - assert result.shape == (5, 8, 2, 3) + cu_seqlens = torch.tensor([0, 8, 16, 24, 32], dtype=torch.long) + max_seqlen = 8 + result = _align_routed_experts_to_mask(routed_experts, cu_seqlens, max_seqlen) + assert result.shape == (4, 8, 2, 3) torch.testing.assert_close(result[:3, :, :, :], routed_experts) assert (result[3:, :, :, :] == 0).all() def test_batch_dim_larger_truncated(self): routed_experts = torch.randint(1, 64, (5, 8, 2, 3)) - attention_mask = torch.ones(3, 8, dtype=torch.long) - result = _align_routed_experts_to_mask(routed_experts, attention_mask) - assert result.shape == (3, 8, 2, 3) - torch.testing.assert_close(result, routed_experts[:3]) + cu_seqlens = torch.tensor([0, 8, 16], dtype=torch.long) + max_seqlen = 8 + result = _align_routed_experts_to_mask(routed_experts, cu_seqlens, max_seqlen) + assert result.shape == (2, 8, 2, 3) + torch.testing.assert_close(result, routed_experts[:2]) def test_both_batch_and_seq_mismatch(self): - routed_experts = torch.zeros(2, 10, 2, 3, dtype=torch.long) - routed_experts[:, 4:, :, :] = torch.randint(1, 64, (2, 6, 2, 3)) - attention_mask = torch.ones(4, 6, dtype=torch.long) - result = _align_routed_experts_to_mask(routed_experts, attention_mask) + routed_experts = torch.randint(1, 64, (2, 10, 2, 3)) + cu_seqlens = torch.tensor([0, 6, 12, 18, 24], dtype=torch.long) + max_seqlen = 6 + result = _align_routed_experts_to_mask(routed_experts, cu_seqlens, max_seqlen) assert result.shape == (4, 6, 2, 3) - torch.testing.assert_close(result[:2, :, :, :], routed_experts[:, 4:, :, :]) + torch.testing.assert_close(result[:2, :, :, :], routed_experts[:, :6, :, :]) assert (result[2:, :, :, :] == 0).all() - def test_empty_attention_mask_same_seqlen(self): + def test_zero_len_sequences(self): routed_experts = torch.randint(1, 64, (2, 8, 2, 3)) - attention_mask = torch.zeros(2, 8, dtype=torch.long) - result = _align_routed_experts_to_mask(routed_experts, attention_mask) + cu_seqlens = torch.tensor([0, 0, 8], dtype=torch.long) + max_seqlen = 8 + result = _align_routed_experts_to_mask(routed_experts, cu_seqlens, max_seqlen) assert result.shape == (2, 8, 2, 3) - torch.testing.assert_close(result, routed_experts) + assert (result[0] == 0).all() + torch.testing.assert_close(result[1], routed_experts[1]) - def 
test_empty_attention_mask_longer_re_seqlen(self): - routed_experts = torch.zeros(2, 10, 2, 3, dtype=torch.long) - routed_experts[:, 4:, :, :] = torch.randint(1, 64, (2, 6, 2, 3)) - attention_mask = torch.zeros(2, 6, dtype=torch.long) - result = _align_routed_experts_to_mask(routed_experts, attention_mask) - assert result.shape == (2, 6, 2, 3) - assert (result == 0).all() + def test_max_seqlen_larger_than_re_seqlen(self): + routed_experts = torch.randint(1, 64, (2, 5, 2, 3)) + cu_seqlens = torch.tensor([0, 5, 10], dtype=torch.long) + max_seqlen = 10 + result = _align_routed_experts_to_mask(routed_experts, cu_seqlens, max_seqlen) + assert result.shape == (2, 10, 2, 3) + torch.testing.assert_close(result[:, :5, :, :], routed_experts) + assert (result[:, 5:, :, :] == 0).all() From 623526784bae46338c90a3ff33bfd93b66797929 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 29 Apr 2026 18:15:11 +0800 Subject: [PATCH 075/112] test(router_replay): fix test --- tests/test_router_replay_e2e.py | 126 +++++++++++++++++--------------- 1 file changed, 66 insertions(+), 60 deletions(-) diff --git a/tests/test_router_replay_e2e.py b/tests/test_router_replay_e2e.py index 047878f4fd..37f81d8fd6 100644 --- a/tests/test_router_replay_e2e.py +++ b/tests/test_router_replay_e2e.py @@ -17,11 +17,12 @@ finite and non-zero when R3 is enabled. These tests are marked ``slow``/``multi_gpu`` and will be skipped in CI -by default; run with ``pytest -m multi_gpu -k r3_e2e``. +by default; run with ``pytest -m multi_gpu -k r3_e2e -s``. """ from __future__ import annotations +import os import subprocess import pytest @@ -30,40 +31,73 @@ from areal.infra.platforms import current_platform from areal.utils.network import find_free_ports +MODEL_LOCAL_PATHS = { + "moonlight": "/workspace/models/Moonlight-16B-A3B-Instruct", + "qwen3moe": "/storage/openpsi/models/Qwen__Qwen3-30B-A3B/", +} + +_TIMEOUTS = { + "patch_plumbing": 300, + "forward_replay": 600, + "forward_backward": 600, +} + + +def _model_available(model_type: str) -> bool: + local = MODEL_LOCAL_PATHS.get(model_type, "") + return bool(local) and os.path.exists(local) + def _run_e2e( model_type: str, alloc_mode: str, test_type: str, output: str, - timeout_sec: int = 1800, + timeout_sec: int | None = None, ): + if timeout_sec is None: + timeout_sec = _TIMEOUTS.get(test_type, 600) + port = find_free_ports(1)[0] n_gpus = ModelAllocation.from_str(alloc_mode).parallel.world_size + cmd = [ + "torchrun", + f"--nproc_per_node={n_gpus}", + "--nnodes=1", + "--master-addr=localhost", + f"--master_port={port}", + "tests/torchrun/run_router_replay_distributed.py", + f"--model_type={model_type}", + f"--backend={alloc_mode}", + f"--test_type={test_type}", + f"--output={output}", + ] + print(f"[r3-e2e] Launching (timeout={timeout_sec}s): {' '.join(cmd)}") try: - subprocess.run( - [ - "torchrun", - f"--nproc_per_node={n_gpus}", - "--nnodes=1", - "--master-addr=localhost", - f"--master_port={port}", - "tests/torchrun/run_router_replay_distributed.py", - f"--model_type={model_type}", - f"--backend={alloc_mode}", - f"--test_type={test_type}", - f"--output={output}", - ], - check=True, - capture_output=True, + proc = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, text=True, - timeout=timeout_sec, ) - except subprocess.CalledProcessError as e: + stdout_lines = [] + for line in proc.stdout: + print(line, end="") + stdout_lines.append(line) + proc.wait(timeout=timeout_sec) + stdout = "".join(stdout_lines) + if proc.returncode != 0: + pytest.fail( + f"R3 E2E 
subprocess exited with code {proc.returncode}.\n" + f"OUTPUT:\n{stdout}" + ) + except subprocess.TimeoutExpired: + proc.kill() + proc.wait() + stdout = "".join(stdout_lines) if stdout_lines else "" pytest.fail( - f"R3 E2E subprocess failed.\n" - f"STDOUT:\n{e.stdout}\n" - f"STDERR:\n{e.stderr}" + f"R3 E2E subprocess timed out after {timeout_sec}s.\n" + f"OUTPUT:\n{stdout}" ) with open(output) as f: result = f.read().strip() @@ -84,6 +118,8 @@ def test_r3_e2e_moonlight_patch_plumbing(tmp_path_factory): """ if current_platform.device_count() < 4: pytest.skip("Moonlight R3 patch plumbing requires >= 4 GPUs") + if not _model_available("moonlight"): + pytest.skip("Moonlight model not available locally") out = tmp_path_factory.mktemp("r3") / "moonlight_patch.out" _run_e2e( model_type="moonlight", @@ -104,6 +140,8 @@ def test_r3_e2e_moonlight_forward_replay(tmp_path_factory): """ if current_platform.device_count() < 4: pytest.skip("Moonlight R3 forward replay requires >= 4 GPUs") + if not _model_available("moonlight"): + pytest.skip("Moonlight model not available locally") out = tmp_path_factory.mktemp("r3") / "moonlight_forward.out" _run_e2e( model_type="moonlight", @@ -125,6 +163,8 @@ def test_r3_e2e_moonlight_forward_backward(tmp_path_factory): """ if current_platform.device_count() < 4: pytest.skip("Moonlight R3 forward_backward requires >= 4 GPUs") + if not _model_available("moonlight"): + pytest.skip("Moonlight model not available locally") out = tmp_path_factory.mktemp("r3") / "moonlight_fb.out" _run_e2e( model_type="moonlight", @@ -150,6 +190,8 @@ def test_r3_e2e_moonlight_pp2_tp4_ep4(tmp_path_factory): """ if current_platform.device_count() < 8: pytest.skip("Moonlight R3 PP=2 config requires 8 GPUs") + if not _model_available("moonlight"): + pytest.skip("Moonlight model not available locally") out = tmp_path_factory.mktemp("r3") / "moonlight_pp2_patch.out" _run_e2e( model_type="moonlight", @@ -172,6 +214,8 @@ def test_r3_e2e_moonlight_pp2_tp4_ep4_forward_backward(tmp_path_factory): """ if current_platform.device_count() < 8: pytest.skip("Moonlight R3 PP=2 forward_backward requires 8 GPUs") + if not _model_available("moonlight"): + pytest.skip("Moonlight model not available locally") out = tmp_path_factory.mktemp("r3") / "moonlight_pp2_fb.out" _run_e2e( model_type="moonlight", @@ -180,41 +224,3 @@ def test_r3_e2e_moonlight_pp2_tp4_ep4_forward_backward(tmp_path_factory): output=str(out), ) - -# --------------------------------------------------------------------------- -# 4-GPU: Qwen3-30B-A3B fallback (runs if Moonlight weights are unavailable) -# --------------------------------------------------------------------------- - - -@pytest.mark.multi_gpu -@pytest.mark.slow -def test_r3_e2e_qwen3moe_fallback(tmp_path_factory): - """Fallback path: Qwen3-30B-A3B MoE when Moonlight is unavailable. - - Same 4-GPU TP=2 CP=2 EP=4 layout as ``test_qwen3moe_expert_parallel`` - in ``test_megatron_engine_distributed.py``. 
- """ - if current_platform.device_count() < 4: - pytest.skip("Qwen3 MoE R3 requires >= 4 GPUs") - out = tmp_path_factory.mktemp("r3") / "qwen3moe_fallback.out" - _run_e2e( - model_type="qwen3moe", - alloc_mode="megatron:(attn:d1p1t2c2|ffn:d1p1t1e4)", - test_type="patch_plumbing", - output=str(out), - ) - - -@pytest.mark.multi_gpu -@pytest.mark.slow -def test_r3_e2e_qwen3moe_forward_backward(tmp_path_factory): - """Qwen3-30B-A3B MoE R3 full train_batch fallback (4 GPUs).""" - if current_platform.device_count() < 4: - pytest.skip("Qwen3 MoE R3 forward_backward requires >= 4 GPUs") - out = tmp_path_factory.mktemp("r3") / "qwen3moe_fb.out" - _run_e2e( - model_type="qwen3moe", - alloc_mode="megatron:(attn:d1p1t2c2|ffn:d1p1t1e4)", - test_type="forward_backward", - output=str(out), - ) From 200addcf86fa917aab3f5dd2673ec80b9fbf19cb Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 29 Apr 2026 18:24:25 +0800 Subject: [PATCH 076/112] refactor(tests): get model --- .../torchrun/run_router_replay_distributed.py | 48 +++++++++++-------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/tests/torchrun/run_router_replay_distributed.py b/tests/torchrun/run_router_replay_distributed.py index d60d6e50de..890311803a 100644 --- a/tests/torchrun/run_router_replay_distributed.py +++ b/tests/torchrun/run_router_replay_distributed.py @@ -33,8 +33,6 @@ import torch.distributed as dist from megatron.core import parallel_state as mpu -from tests.utils import get_model_path - from areal.api import FinetuneSpec from areal.api.alloc_mode import ModelAllocation from areal.api.cli_args import ( @@ -45,18 +43,30 @@ ) from areal.engine import MegatronEngine from areal.infra.platforms import current_platform -from areal.utils import seeding +from areal.utils import logging, seeding from areal.utils.data import broadcast_tensor_container +logger = logging.getLogger("R3E2E") + + +def _get_model_path(local_path: str, hf_id: str) -> str: + if os.path.exists(local_path): + logger.info("Model found at local path: %s", local_path) + return local_path + from huggingface_hub import snapshot_download + + logger.info("Downloading model from HuggingFace Hub: %s", hf_id) + return snapshot_download( + repo_id=hf_id, + ignore_patterns=["*.gguf", "*.ggml", "consolidated*"], + ) + MODEL_PATHS = { - "moonlight": get_model_path( - "/storage/openpsi/models/Moonshot__Moonlight-16B-A3B-Instruct/", + "moonlight": _get_model_path( + "/workspace/models/Moonlight-16B-A3B-Instruct/", "moonshotai/Moonlight-16B-A3B-Instruct", - ), - "qwen3moe": get_model_path( - "/storage/openpsi/models/Qwen__Qwen3-30B-A3B/", "Qwen/Qwen3-30B-A3B" - ), + ) } @@ -233,9 +243,9 @@ def test_patch_plumbing(model_type: str, backend: str, output: str | None): expected = _collect_num_moe_layers(engine) got = len(RouterReplay.router_instances) - print( - f"[r3-e2e] rank={rank} expected_moe_layers={expected} " - f"got_router_instances={got}" + logger.info( + "[R3-E2E] rank=%d expected_moe_layers=%d got_router_instances=%d", + rank, expected, got, ) assert got == expected, ( f"RouterReplay.router_instances count ({got}) must match the " @@ -282,7 +292,7 @@ def test_forward_replay(model_type: str, backend: str, output: str | None): try: # Resolve MoE metadata from the model config (same path rl_trainer uses). 
num_moe, topk = resolve_r3_moe_config(MODEL_PATHS[model_type]) - print(f"[r3-e2e] rank={rank} num_moe_layers={num_moe} topk={topk}") + logger.info("[R3-E2E] rank=%d num_moe_layers=%d topk=%d", rank, num_moe, topk) # Build a synthetic routed_experts tensor with right-padding matching # the rollout convention: (bs, seqlen, num_moe_layers, topk). @@ -325,7 +335,7 @@ def test_forward_replay(model_type: str, backend: str, output: str | None): if rank == 0 and output: write_result(output, True) except Exception as e: # pragma: no cover - surfaced as torchrun failure - print(f"[r3-e2e] rank={rank} FAIL: {e!r}") + logger.error("[R3-E2E] rank=%d FAIL: %r", rank, e) if rank == 0 and output: write_result(output, False, repr(e)) raise @@ -378,8 +388,8 @@ def test_forward_backward(model_type: str, backend: str, output: str | None): try: # Resolve MoE metadata from the model config. num_moe, topk = resolve_r3_moe_config(MODEL_PATHS[model_type]) - print( - f"[r3-e2e-fb] rank={rank} num_moe_layers={num_moe} topk={topk}" + logger.info( + "[R3-E2E-FB] rank=%d num_moe_layers=%d topk=%d", rank, num_moe, topk ) # Build training input + rollout_expert_indices. @@ -431,7 +441,7 @@ def test_forward_backward(model_type: str, backend: str, output: str | None): # Rank-0 side asserts the loss on the last pipeline stage. # stats is a dict[str, float]; the key name may vary, so check any. - print(f"[r3-e2e-fb] rank={rank} train_batch stats={stats}") + logger.info("[R3-E2E-FB] rank=%d train_batch stats=%s", rank, stats) # Side-channel must have been consumed. assert engine._r3_pending_routed_experts is None, ( @@ -455,7 +465,7 @@ def test_forward_backward(model_type: str, backend: str, output: str | None): if rank == 0 and output: write_result(output, True) except Exception as e: # pragma: no cover - surfaced as torchrun failure - print(f"[r3-e2e-fb] rank={rank} FAIL: {e!r}") + logger.error("[R3-E2E-FB] rank=%d FAIL: %r", rank, e) if rank == 0 and output: write_result(output, False, repr(e)) raise @@ -485,7 +495,7 @@ def main(): ) parser.add_argument("--output", type=str, default=None) args = parser.parse_args() - print(args) + logger.info("Args: %s", args) if args.test_type == "forward_replay": test_forward_replay(args.model_type, args.backend, args.output) From 2e2ef9979ebbafbd4b8f6dfd20d9808d60a22681 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 29 Apr 2026 18:35:44 +0800 Subject: [PATCH 077/112] fix: fix test --- tests/torchrun/run_router_replay_distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/torchrun/run_router_replay_distributed.py b/tests/torchrun/run_router_replay_distributed.py index 890311803a..9050bf6d44 100644 --- a/tests/torchrun/run_router_replay_distributed.py +++ b/tests/torchrun/run_router_replay_distributed.py @@ -425,7 +425,7 @@ def test_forward_backward(model_type: str, backend: str, output: str | None): m2_threshold=None, importance_sampling_level="token", current_version=0, - prox_logp_method=None, + prox_logp_method="recompute", use_sapo_loss=False, sapo_tau_pos=0.0, sapo_tau_neg=0.0, From 9e25cd9282bffb8aa66db1b555b91a0e318405ea Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 29 Apr 2026 18:42:33 +0800 Subject: [PATCH 078/112] test(torchrun): add prox_logp --- tests/torchrun/run_router_replay_distributed.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/torchrun/run_router_replay_distributed.py b/tests/torchrun/run_router_replay_distributed.py index 9050bf6d44..8fda425d0f 100644 --- a/tests/torchrun/run_router_replay_distributed.py 
+++ b/tests/torchrun/run_router_replay_distributed.py @@ -213,6 +213,9 @@ def _build_training_input_with_rollout_experts( # grpo_loss_fn reads input_data['logprobs'] as the old (rollout) log-prob. "logprobs": rollout_logprobs, "advantages": advantages, + # prox_logp is required when prox_logp_method='recompute'. + # In standard GRPO, prox_logp equals the rollout log-probabilities. + "prox_logp": rollout_logprobs.clone(), } return input_dict, rollout_expert_indices From 86951dbc60440b0dc8e973cb790bfe99048a7ee1 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 29 Apr 2026 19:00:16 +0800 Subject: [PATCH 079/112] docs: remove useless --- docs/en/algorithms/router_replay.md | 159 ---------------------------- docs/zh/algorithms/router_replay.md | 133 ----------------------- 2 files changed, 292 deletions(-) delete mode 100644 docs/en/algorithms/router_replay.md delete mode 100644 docs/zh/algorithms/router_replay.md diff --git a/docs/en/algorithms/router_replay.md b/docs/en/algorithms/router_replay.md deleted file mode 100644 index 0bc020515a..0000000000 --- a/docs/en/algorithms/router_replay.md +++ /dev/null @@ -1,159 +0,0 @@ -# Rollout Routing Replay (R3) for MoE Models - -Last updated: Apr 29, 2026 - -## Overview - -In asynchronous RL for Mixture-of-Experts (MoE) models, the policy that generates -rollouts (served by SGLang) and the policy that is being trained (driven by -Megatron-LM) may differ by one or more parameter versions. Since the router is a -*learned* sparse gate, even small weight drift can send the same token to different -experts between inference and training, producing a **train/inference routing -mismatch** that corrupts importance-sampling ratios and destabilises optimisation. - -**Rollout Routing Replay (R3)** eliminates this mismatch by: - -1. Recording the per-token expert assignments emitted by the inference engine for - every decoded token. -2. Re-using (*replaying*) those exact expert assignments during the training - forward / backward pass in place of the routing computed from current weights. - -R3 is inspired by the implementation in -[verl](https://github.com/volcengine/verl) and has been adapted for AReaL's -Megatron backend + SGLang bridge-mode inference service. - -## Supported Configurations - -| Dimension | Supported | Notes | -|---|---|---| -| Training backend | Megatron-LM (`MegatronEngine`) | FSDP engine is **not** supported. | -| Inference backend | SGLang 0.5.9 (bridge mode) | vLLM not supported. | -| Tensor Parallel (**TP**) | ✅ | Uses `scatter_to_sequence_parallel_region` to distribute packed router indices to SP ranks. | -| Expert Parallel (**EP**) | ✅ | Patched `MoEAlltoAllTokenDispatcher.preprocess` recomputes `num_out_tokens = routing_map.sum()` so that the dropless path stays correct when replay zeroes padding rows. | -| Pipeline Parallel (**PP**) | ✅ | `RouterReplayHelper.get_micro_batch_router_list` slices `RouterReplay.router_instances` according to the current PP rank's `(layer_offset, num_layers)`. | -| Virtual Pipeline Parallel (**VPP**) | ✅ | Same helper honours `virtual_pipeline_model_parallel_size` and iterates over VP stages. | -| Context Parallel (**CP**) | ⚠️ Experimental | `seq_align_to = tp_size * cp_size * 2` is applied when `cp_size > 1`; exercised only via unit tests, not covered by the provided E2E fixtures. | -| Data Parallel (**DP**) | ✅ | R3 runs independently per DP replica; no cross-DP communication is added. 
-| Dense + MoE hybrid layers | ✅ | `is_moe_layer()` uses `moe_layer_freq` / `first_k_dense_replace` so dense layers are skipped from replay. |
-| Role | Actor only | `config.actor.megatron.enable_router_replay` is set exclusively on the actor engine; Critic / Reference / Teacher engines are unaffected. |
-| Capacity factor | `moe_expert_capacity_factor is None` (dropless) | Replay only overrides `num_out_tokens` on the dropless path, matching verl's guard. |
-| FP8 / quantisation padding | ❌ | Replay is skipped when `moe_router_padding_for_fp8` or `moe_router_padding_for_quantization` is enabled to preserve FP8 dispatch correctness. |
-| Vision / multimodal models | ❌ | No hooks in the VLM path. |
-
-## How to Enable R3
-
-R3 is driven by a single rollout flag; everything else is wired automatically by
-`areal/trainer/rl_trainer.py`.
-
-```yaml
-rollout:
-  # Request per-token routed expert indices from SGLang.
-  return_routed_experts: true
-
-actor:
-  backend: "megatron:(attn:d1p1t4|ffn:d1p1t1e4)"  # TP=4, EP=4
-  # actor.megatron.enable_router_replay is forced to True
-  # automatically when rollout.return_routed_experts=true.
-
-sglang:
-  # R3 relies on the output token sequence being aligned with the routing
-  # output. The trainer forces skip_tokenizer_init=True at startup;
-  # declaring it here makes the intent explicit.
-  skip_tokenizer_init: true
-  enable_return_routed_experts: true
-```
-
-At trainer startup (`RLTrainer.__init__`):
-
-1. `rollout.return_routed_experts=True` causes
-   `config.actor.megatron.enable_router_replay` to be set to `True`.
-2. `num_moe_layers` and `topk` are auto-resolved from the HuggingFace config
-   (`num_experts_per_tok`, `num_hidden_layers`, `moe_layer_freq`,
-   `first_k_dense_replace`) by `resolve_r3_moe_config()`.
-3. `sglang.skip_tokenizer_init` is forced to `True` (warning printed if the user
-   set it to `False`) to prevent tokenizer round-trip token shifts that would
-   break per-token routing alignment.
-4. The SGLang bridge entrypoint
-   (`areal/experimental/inference_service/sglang/launch_server.py`) calls
-   `apply_sglang_r3_patch()` so that `TokenizerManager._handle_batch_output`
-   base64-encodes the `routed_experts` tensor before FastAPI serialisation.
-5. On the training side, `MegatronEngine.initialize()` calls
-   `apply_router_replay_patch()` (monkey-patches `TransformerConfig.__init__`,
-   `TopKRouter.__init__`, `TopKRouter.routing` and
-   `MoEAlltoAllTokenDispatcher.preprocess`) **before** model creation, and then
-   wraps the engine with `patch_megatron_engine_for_r3()`.
-
-## Pipeline Overview
-
-```
-┌────────────────────────────┐      ┌────────────────────────────────┐
-│ SGLang inference server    │      │ MegatronEngine (actor)         │
-│                            │      │                                │
-│ generate_logprobs()        │      │ forward_backward_batch()       │
-│  └─ routed_experts tensor  │────▶│  ├─ REPLAY_FORWARD on          │
-│     (base64 over HTTP)     │      │  │   each microbatch           │
-└────────────────────────────┘      │  ├─ post-forward hook switches │
-                                    │  │   to REPLAY_BACKWARD        │
-                                    │  └─ clear_router_replay()      │
-                                    └────────────────────────────────┘
-```
-
-### Key data structures
-
-| Object | Purpose |
-|---|---|
-| `RouterReplay` (per MoE layer) | Holds the replay indices (`target_topk_idx`), recording buffer (`recorded_topk_idx`), and current `RouterReplayAction`. |
-| `RouterReplay.router_instances` | Class-level list, one entry per MoE layer *on the local rank*. Cleared each time `apply_router_replay_patch()` is called. |
-| `RouterReplayAction` | Enum: `RECORD`, `REPLAY_FORWARD`, `REPLAY_BACKWARD`.
| -| `RouterReplayHelper.get_micro_batch_router_list()` | Returns the subset of `router_instances` assigned to the current `(pp_rank, vp_stage)`. | -| `setup_per_microbatch_replay_forward()` | Called before each micro-batch forward: aligns rollout-format `routed_experts` to the training token layout, packs with `cu_seqlens`, scatters to SP ranks, and distributes to the per-layer `RouterReplay` instances. | - -### Correctness notes - -* **`num_out_tokens` override.** Megatron-Core 0.16.0's dropless branch of - `MoEAlltoAllTokenDispatcher.preprocess` sets - `num_out_tokens = routing_map.size(0) * moe_router_topk` as a static value. - When R3 zeroes padding rows in `routing_map`, that static value overcounts, - so the patched preprocess computes `num_out_tokens = int(routing_map.sum().item())` - on the dropless path. The ~3,500 `.item()` syncs per training step are - negligible compared to MoE compute. -* **Per-instance `__class__` swap.** The micro-batch iterator wraps - `MicroBatchList` with a dynamic subclass assigned via `mb_list.__class__`, - not by mutating the shared class, so concurrent engines (e.g. critic) are - not affected. -* **Left-align from right-padded rollouts.** `_align_routed_experts_to_mask()` - converts the rollout tensor from `(bs, batch_max_seqlen, L, K)` right-padded - format to a training-oriented left-aligned layout using `cu_seqlens`. -* **Silent-drop removed.** When a micro-batch cannot be exactly split by - `bs // n_mbs`, R3 raises instead of silently trimming rows. - -## Minimal Example - -See `examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml` for the reference -Moonlight-16B-A3B configuration (PP=2, TP=4, EP=4, 8 GPUs). Launch: - -```bash -python3 examples/math/gsm8k_rl.py \ - --config examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml \ - scheduler.type=local -``` - -On a single-node 8×H200 system the `*_h20.yaml` variant runs with PP=1, -TP=4, EP=4 and `max_tokens_per_mb=10240`. - -## Troubleshooting - -| Symptom | Cause | Fix | -|---|---|---| -| `[R3] Number of replay tensors (...) does not match number of router instances (...)` | MoE layer count resolved from HF config differs from Megatron's per-rank layer count (usually due to `first_k_dense_replace` / `moe_layer_freq` mismatch or custom pipeline layout). | Verify `num_hidden_layers`, `first_k_dense_replace`, and `moe_layer_freq` in the model's `config.json` and that `pipeline_model_parallel_layout` (if set) matches the MoE layer count. | -| SGLang returns `routed_experts: {}` (empty dict) | Inference server was started without the R3 patch. | Ensure you are using the bridge entrypoint `areal.experimental.inference_service.sglang.launch_server`; it installs `apply_sglang_r3_patch()` automatically. | -| `moe_router_padding_for_fp8=True` + R3 | R3 is intentionally disabled on FP8 padding paths. | Either turn off FP8 router padding or disable `rollout.return_routed_experts`. | -| Critic does not pick up R3 | By design; only the actor is patched. | If a future use-case needs MoE critic replay, extend `rl_trainer._amend_xccl_weight_update_envvar` and `MegatronEngine._r3_enabled` plumbing. | - -## References - -* PR [#1207](https://github.com/inclusionAI/AReaL/pull/1207) — `[WIP]feat: add router replay for megatron engine`. -* verl router replay: - [`volcengine/verl`](https://github.com/volcengine/verl) (`verl/workers/**/*router_replay*`). -* Megatron-Core MoE parallel folding: - [NVIDIA/Megatron-LM MoE README](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/transformer/moe). 
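
The replay machinery the tables above describe is compact enough to sketch end to end. The following is an illustrative condensation — `RouterReplaySketch` and its toy method bodies are hypothetical, though `RouterReplayAction`, `set_target_indices`, and the `scores.gather(1, top_indices)` step mirror the patched routing branches later in this series:

```python
from enum import Enum

import torch


class RouterReplayAction(Enum):
    RECORD = "record"
    REPLAY_FORWARD = "replay_forward"
    REPLAY_BACKWARD = "replay_backward"


class RouterReplaySketch:
    """One replay slot per MoE layer (condensed for exposition)."""

    def __init__(self):
        self.action = None                # current RouterReplayAction, or None
        self.target_topk_idx = None       # (num_tokens, topk) from the rollout
        self.recorded_topk_idx = None     # filled under RECORD
        self.replay_backward_list = []    # FIFO consumed during recompute

    def set_target_indices(self, topk_indices):
        # Mirrors the real set_target_indices: install the replay target and
        # queue a copy for the backward / activation-recompute pass.
        self.target_topk_idx = topk_indices
        self.replay_backward_list.append(topk_indices)

    def routing(self, scores, topk):
        if self.action is RouterReplayAction.REPLAY_FORWARD:
            top_indices = self.target_topk_idx
        elif self.action is RouterReplayAction.REPLAY_BACKWARD:
            top_indices = self.replay_backward_list.pop(0)
        else:
            _, top_indices = torch.topk(scores, topk, dim=-1)  # live routing
            if self.action is RouterReplayAction.RECORD:
                self.recorded_topk_idx = top_indices
        # Gate weights are always gathered from the *current* scores, so the
        # router still receives gradients even when indices are replayed.
        probs = scores.gather(1, top_indices)
        return probs, top_indices


layer = RouterReplaySketch()
layer.set_target_indices(torch.tensor([[2, 5], [0, 3]]))  # rollout-recorded
layer.action = RouterReplayAction.REPLAY_FORWARD
scores = torch.rand(2, 8).softmax(dim=-1)                 # 2 tokens, 8 experts
probs, idx = layer.routing(scores, topk=2)                # idx is the replayed tensor
```

Note the design choice visible in the last two lines of `routing`: only the expert *selection* is pinned to the rollout; the gate probabilities are still taken from the live scores, so router weights keep learning under replay.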
diff --git a/docs/zh/algorithms/router_replay.md b/docs/zh/algorithms/router_replay.md
deleted file mode 100644
index 9073496880..0000000000
--- a/docs/zh/algorithms/router_replay.md
+++ /dev/null
@@ -1,133 +0,0 @@
-# Rollout Routing Replay (R3) for MoE Models
-
-Last updated: 2026-04-29
-
-## Background
-
-In asynchronous RL for MoE models, the policy that samples rollouts (SGLang)
-and the policy being trained (Megatron-LM) are often one or more versions
-apart. Because the MoE router is a "learned sparse gate", even tiny weight
-drift can send the same token to different experts at inference time versus
-training time, causing a **train/inference routing mismatch** that corrupts
-the importance-sampling ratios and destabilises optimisation.
-
-**Rollout Routing Replay (R3)** removes this mismatch in two steps:
-
-1. **Record**: during inference, record each token's expert assignments.
-2. **Replay**: during the training forward/backward pass, substitute exactly
-   those expert assignments for the routing computed from the current weights.
-
-R3 follows the implementation in [verl](https://github.com/volcengine/verl)
-and is adapted in the AReaL repository to the Megatron training backend plus
-the SGLang bridge-mode inference service.
-
-## Support Matrix
-
-| Dimension | Supported | Notes |
-|---|---|---|
-| Training backend | Megatron-LM (`MegatronEngine`) | FSDP is not supported. |
-| Inference backend | SGLang 0.5.9 (bridge mode) | vLLM is not supported. |
-| Tensor Parallel (TP) | ✅ | Packed routing indices are scattered to the SP ranks via `scatter_to_sequence_parallel_region`. |
-| Expert Parallel (EP) | ✅ | The patched `MoEAlltoAllTokenDispatcher.preprocess` uses `num_out_tokens = routing_map.sum()`, keeping the dropless path correct after replay zeroes the padding rows. |
-| Pipeline Parallel (PP) | ✅ | `RouterReplayHelper.get_micro_batch_router_list` slices `RouterReplay.router_instances` by the current PP rank's `(layer_offset, num_layers)`. |
-| Virtual Pipeline Parallel (VPP) | ✅ | The same helper iterates over the VP stages given by `virtual_pipeline_model_parallel_size`. |
-| Context Parallel (CP) | ⚠️ Experimental | `seq_align_to = tp_size * cp_size * 2` is used when `cp_size > 1`; covered by unit tests only, not yet validated end to end. |
-| Data Parallel (DP) | ✅ | Each DP replica runs R3 independently; no cross-DP communication is introduced. |
-| Dense + MoE hybrid layers | ✅ | `is_moe_layer()` uses `moe_layer_freq` / `first_k_dense_replace` to identify and skip dense layers. |
-| Role | Actor only | `config.actor.megatron.enable_router_replay` is set to True on the actor only; Critic / Ref / Teacher are unaffected. |
-| Capacity factor | Only `moe_expert_capacity_factor is None` (dropless) | Same guard as verl; the `num_out_tokens` override applies only to the dropless branch. |
-| FP8 / quantisation padding | ❌ | R3 is skipped when `moe_router_padding_for_fp8` or `moe_router_padding_for_quantization` is enabled, preserving FP8 dispatch correctness. |
-| Vision / multimodal models | ❌ | The VLM path has no hooks. |
-
-## How to Enable
-
-R3 is driven by a single rollout switch; everything else is wired up
-automatically in `areal/trainer/rl_trainer.py`.
-
-```yaml
-rollout:
-  return_routed_experts: true  # make SGLang return per-token routing indices
-
-actor:
-  backend: "megatron:(attn:d1p1t4|ffn:d1p1t1e4)"  # TP=4, EP=4
-  # actor.megatron.enable_router_replay is set to True automatically
-
-sglang:
-  # R3 must keep the token sequence aligned with the routing output; the
-  # trainer forces skip_tokenizer_init=True -- declared here to make the
-  # intent explicit.
-  skip_tokenizer_init: true
-  enable_return_routed_experts: true
-```
-
-The automatic wiring at startup:
-
-1. `rollout.return_routed_experts=True` sets
-   `config.actor.megatron.enable_router_replay = True`.
-2. `num_moe_layers` / `topk` are resolved automatically by
-   `resolve_r3_moe_config()` from the HF config (`num_experts_per_tok`,
-   `num_hidden_layers`, `moe_layer_freq`, `first_k_dense_replace`).
-3. `sglang.skip_tokenizer_init` is forced to `True` (with a warning if the
-   user set it to False), avoiding tokenizer round-trip token shifts that
-   would break the alignment.
-4. The SGLang bridge entrypoint
-   (`areal/experimental/inference_service/sglang/launch_server.py`) calls
-   `apply_sglang_r3_patch()` at startup so that
-   `TokenizerManager._handle_batch_output` base64-encodes the
-   `routed_experts` tensor before FastAPI serialisation.
-5. On the training side, `MegatronEngine.initialize()` calls
-   `apply_router_replay_patch()` **before** model construction
-   (monkey-patching `TransformerConfig.__init__`, `TopKRouter.__init__`,
-   `TopKRouter.routing` and `MoEAlltoAllTokenDispatcher.preprocess`), then
-   wraps the engine via `patch_megatron_engine_for_r3()`.
-
-## Key Data Structures
-
-| Object | Purpose |
-|---|---|
-| `RouterReplay` (one per MoE layer) | Holds the replay target indices `target_topk_idx`, the recording buffer `recorded_topk_idx`, and the current `RouterReplayAction`. |
-| `RouterReplay.router_instances` | Class-level list holding one instance per MoE layer on the current rank; `clear()`ed on every `apply_router_replay_patch()` call. |
-| `RouterReplayAction` | Enum: `RECORD` / `REPLAY_FORWARD` / `REPLAY_BACKWARD`. |
-| `RouterReplayHelper.get_micro_batch_router_list()` | Returns the slice of `router_instances` for the current `(pp_rank, vp_stage)`. |
-| `setup_per_microbatch_replay_forward()` | Before each micro-batch forward: aligns the rollout-format `routed_experts` to the training token layout, packs it by `cu_seqlens`, scatters it to the SP ranks, then distributes it to each layer's `RouterReplay`. |
-
-## Correctness Notes
-
-* **`num_out_tokens` override**: on the dropless branch, Megatron-Core 0.16.0
-  uses the static value `routing_map.size(0) * moe_router_topk`; once R3
-  zeroes the padding rows, that static value over-counts the token × topk
-  total, so the patch overrides it with `int(routing_map.sum().item())` on
-  the dropless branch. The ~3,500 syncs per step are negligible next to the
-  MoE compute.
-* **Per-instance `__class__` swap**: the micro-batch iterator replaces
-  `mb_list.__class__` with a dynamic subclass instead of mutating the shared
-  class, so other engines running alongside (e.g. the critic) are unaffected.
-* **Right-padded → left-aligned**: `_align_routed_experts_to_mask()` converts
-  the rollout's right-padded `(bs, batch_max_seqlen, L, K)` tensor to the
-  left-aligned layout used by training, according to `cu_seqlens`.
-* **Explicit validation**: if a micro-batch cannot be evenly split by
-  `bs // n_mbs`, R3 raises instead of silently dropping rows.
-
-## Minimal Example
-
-See `examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml`
-(Moonlight-16B-A3B, PP=2, TP=4, EP=4, 8 GPUs). Launch:
-
-```bash
-python3 examples/math/gsm8k_rl.py \
-    --config examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml \
-    scheduler.type=local
-```
-
-On a single-node 8×H200 machine, use the `*_h20.yaml` variant (PP=1, TP=4,
-EP=4, `max_tokens_per_mb=10240`).
-
-## FAQ
-
-| Symptom | Cause | Fix |
-|---|---|---|
-| `[R3] Number of replay tensors (...) does not match number of router instances (...)` | The MoE layer count parsed from the HF config disagrees with Megatron's per-rank layer split (usually a `first_k_dense_replace` / `moe_layer_freq` / custom pipeline layout mismatch). | Check `num_hidden_layers`, `first_k_dense_replace`, and `moe_layer_freq` in the model's `config.json`, and make sure any custom `pipeline_model_parallel_layout` matches the MoE layer count. |
-| SGLang returns `routed_experts: {}` (an empty dict) | The inference service was started without the R3 patch. | Use the bridge entrypoint `areal.experimental.inference_service.sglang.launch_server` (it calls `apply_sglang_r3_patch()` automatically). |
-| R3 misbehaves with `moe_router_padding_for_fp8=True` | R3 is deliberately disabled on the FP8 padding path. | Turn off FP8 router padding, or turn off `rollout.return_routed_experts`. |
-| The critic is not affected | By design, R3 is enabled on the actor only. | If MoE critic replay is ever needed, extend the trigger conditions in `rl_trainer` and `MegatronEngine._r3_enabled`. |
-
-## References
-
-* PR [#1207](https://github.com/inclusionAI/AReaL/pull/1207) --
  `[WIP]feat: add router replay for megatron engine`.
-* verl R3 source:
-  [`volcengine/verl`](https://github.com/volcengine/verl).
-* Megatron-Core MoE parallel folding:
-  [NVIDIA/Megatron-LM MoE README](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/transformer/moe).

From 7d4ced37a56ab1abf4e87e64db945d8a3be5520b Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Wed, 29 Apr 2026 20:53:20 +0800
Subject: [PATCH 080/112] docs: remove router_replay toc entries

---
 docs/en/_toc.yml | 1 -
 docs/zh/_toc.yml | 1 -
 2 files changed, 2 deletions(-)

diff --git a/docs/en/_toc.yml b/docs/en/_toc.yml
index edad8395a4..845032dd7d 100644
--- a/docs/en/_toc.yml
+++ b/docs/en/_toc.yml
@@ -38,7 +38,6 @@ parts:
       - file: algorithms/grpo_series
       - file: algorithms/m2po
       - file: algorithms/prox_approx
-      - file: algorithms/router_replay
   - caption: Reference
     chapters:
       - file: reference/checkpointing
diff --git a/docs/zh/_toc.yml b/docs/zh/_toc.yml
index dad5a7293c..16f9e7b713 100644
--- a/docs/zh/_toc.yml
+++ b/docs/zh/_toc.yml
@@ -38,7 +38,6 @@ parts:
       - file: algorithms/grpo_series
       - file: algorithms/m2po
       - file: algorithms/prox_approx
-      - file: algorithms/router_replay
   - caption: 参考
     chapters:
       - file: reference/checkpointing

From fc338b1848965426eedd99cd43a65ec3d296b491 Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Wed, 29 Apr 2026 22:34:12 +0800
Subject: [PATCH 081/112] fix(trainer): remove dead openai_cfg online-mode
 assignment

---
 areal/trainer/rl_trainer.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py
index 062c11ae91..c6dfe424bb 100644
--- a/areal/trainer/rl_trainer.py
+++ b/areal/trainer/rl_trainer.py
@@ -189,8 +189,6 @@ def __init__(
             OmegaConf.set_struct(sglang_cfg, False)
             sglang_cfg.skip_tokenizer_init = True

-        openai_cfg = config.rollout.openai
-        self._online_mode = openai_cfg is not None and openai_cfg.mode == "online"
         agent_cfg = config.rollout.agent
         self._online_mode = agent_cfg is not None and agent_cfg.mode == "online"

From 5779c3a73a7273d403e1f6d0e4b42058cc3b27b8 Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Thu, 30 Apr 2026 15:26:01 +0800
Subject: [PATCH 082/112] fix(ppo): add r3 for _compute_logp

---
 areal/trainer/ppo/actor.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py
index 4f3bda7928..e2b37a5ff8 100644
--- a/areal/trainer/ppo/actor.py
+++ b/areal/trainer/ppo/actor.py
@@ -133,6 +133,22 @@ def compute_logp(self, data: list[dict[str, Any]]) -> list[torch.Tensor] | None:

     def _compute_logp(self, data: dict[str, Any]) -> torch.Tensor | None:
         self.engine.eval()
+        # R3: side-channel routed_experts into the engine so that the R3
+        # forward_backward_batch wrapper replays routing decisions even in
+        # the forward_only (compute_logp) path.
+        _r3_routed_experts = data.pop("routed_experts", None)
+        if _r3_routed_experts is not None and not isinstance(
+            _r3_routed_experts, torch.Tensor
+        ):
+            from areal.trainer.ppo.actor_r3_patch import _resolve_to_tensor
+            _r3_routed_experts = _resolve_to_tensor(_r3_routed_experts)
+        if _r3_routed_experts is not None and getattr(
+            self.engine, "_r3_enabled", False
+        ):
+            # forward_batch performs ONE forward_backward_batch(forward_only=True)
+            # call internally; the R3 engine patch will split routed_experts per
+            # micro-batch and consume the side-channel (setting it back to None).
+            self.engine._r3_pending_routed_experts = _r3_routed_experts
         return self.engine.forward(
             input_=data,
             aggregate_fn=lambda xs: torch.cat(xs, dim=-1),

From 4cc6544d11abdd2433dfa0b24ad52a01777d22f2 Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Thu, 30 Apr 2026 20:09:57 +0800
Subject: [PATCH 083/112] refactor(ppo): fix metric

---
 areal/trainer/ppo/actor.py | 51 ++++++++++++++++++++++++++++++--------
 1 file changed, 41 insertions(+), 10 deletions(-)

diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py
index e2b37a5ff8..af1e10685a 100644
--- a/areal/trainer/ppo/actor.py
+++ b/areal/trainer/ppo/actor.py
@@ -607,19 +607,50 @@ def grpo_loss_fn(
     )

     # ---- R3 Logprob Diff: rollout (inference) vs training logprobs ----
-    # compute |rollout_logprobs - training_logprobs|
-    # over response tokens only. This metric quantifies train/infer mismatch
-    # caused by MoE routing divergence. With R3 enabled, this diff should be
-    # smaller than without R3.
+    # Two metrics quantify train/infer divergence at different granularities:
+    #
+    # 1. ``rollout_train_logprobs_abs_diff`` (old_logp vs current-train logprobs)
+    #    Conflates (a) train/infer gap AND (b) intra-ppo_update weight drift
+    #    from earlier mini-batches. At ``max_head_offpolicyness>0`` this metric
+    #    grows with mini-batch index because ``logprobs`` is forwarded on
+    #    W_current (already stepped i times within the epoch), while
+    #    ``old_logp`` stays fixed at W_rollout. R3 cannot reduce this drift.
+    #
+    # 2. ``rollout_train_logprobs_abs_diff_prox`` (old_logp vs prox_logp_gt)
+    #    The *pure* train/infer mismatch: both evaluated on W_rollout, differing
+    #    only in SGLang-vs-Megatron forward (fused kernel + router top-k).
+    #    With R3 enabled and weights fully synced (k=0), this should collapse
+    #    to BF16 kernel noise. With R3 off, ~2% per-expert router flips add
+    #    ~0.1-0.2 abs-diff per token. This is the metric to validate R3
+    #    effectiveness across staleness settings.
+    #
+    # Both are logged via ``stats_tracker.stat`` with ``n_valid_tokens``
+    # denominator so the aggregation is a proper token-weighted global
+    # mean/min/max (cross-mini-batch, cross-rank), not a Python-float mean
+    # of per-minibatch local scalars (which would be small-batch-biased and
+    # would report local stds as if they were global ones).
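+    # Worked example with hypothetical numbers: two mini-batches with
+    # n = (1000, 10) valid tokens and local means m = (0.10, 0.50).
+    # The token-weighted global mean is
+    #     (1000 * 0.10 + 10 * 0.50) / 1010 ≈ 0.104,
+    # whereas a Python-float mean of the two local scalars would report
+    # (0.10 + 0.50) / 2 = 0.30, letting the 10-token mini-batch dominate.
+    # Feeding per-token tensors with the shared n_valid_tokens denominator
+    # makes stats_tracker compute the former.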
if loss_mask.any(): with torch.no_grad(): - _logprob_diff = (old_logp[loss_mask] - logprobs.detach()[loss_mask]).abs() - _diff_mean = _logprob_diff.mean().item() - _diff_std = _logprob_diff.std().item() if _logprob_diff.numel() > 1 else 0.0 - stats_tracker.scalar( - rollout_train_logprobs_abs_diff_mean=_diff_mean, - rollout_train_logprobs_abs_diff_std=_diff_std, + _diff_abs = (old_logp - logprobs.detach()).abs().float() + _diff_abs = torch.where(loss_mask, _diff_abs, torch.zeros_like(_diff_abs)) + _diff_sq = (_diff_abs * _diff_abs).float() + stats_tracker.stat( + rollout_train_logprobs_abs_diff=_diff_abs, + rollout_train_logprobs_sq_diff=_diff_sq, + denominator="n_valid_tokens", ) + if prox_logp_gt is not None: + with torch.no_grad(): + _prox_abs = (old_logp - prox_logp_gt.detach()).abs().float() + _prox_abs = torch.where( + loss_mask, _prox_abs, torch.zeros_like(_prox_abs) + ) + _prox_sq = (_prox_abs * _prox_abs).float() + stats_tracker.stat( + rollout_train_logprobs_abs_diff_prox=_prox_abs, + rollout_train_logprobs_sq_diff_prox=_prox_sq, + denominator="n_valid_tokens", + ) if "behave_imp_weight" in stat: stats_tracker.denominator(unclipped_behave_tokens=stat["behave_mask"]) From e2a7980ff175c660872b10ee5004e9c9db0de500 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 1 May 2026 00:54:20 +0800 Subject: [PATCH 084/112] feat(ppo): add _r3_effectiveness_stats --- areal/trainer/ppo/actor.py | 86 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 82 insertions(+), 4 deletions(-) diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index af1e10685a..ed28b77f4e 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -142,17 +142,95 @@ def _compute_logp(self, data: dict[str, Any]) -> torch.Tensor | None: ): from areal.trainer.ppo.actor_r3_patch import _resolve_to_tensor _r3_routed_experts = _resolve_to_tensor(_r3_routed_experts) - if _r3_routed_experts is not None and getattr( - self.engine, "_r3_enabled", False - ): + _r3_enabled = bool(getattr(self.engine, "_r3_enabled", False)) + if _r3_routed_experts is not None and _r3_enabled: # forward_batch performs ONE forward_backward_batch(forward_only=True) # call internally; the R3 engine patch will split routed_experts per # micro-batch and consume the side-channel (setting it back to None). self.engine._r3_pending_routed_experts = _r3_routed_experts - return self.engine.forward( + train_logp = self.engine.forward( input_=data, aggregate_fn=lambda xs: torch.cat(xs, dim=-1), ) + # R3 effectiveness metrics. At compute_logp time the training weights + # equal the rollout weights (no optimizer step has touched θ in this + # rollout epoch), so comparing SGLang's cached logprobs against the + # Megatron forward result isolates the router-replay effect from any + # off-policy weight drift. If R3 works, these divergence metrics should + # drop relative to the R3-off baseline. + self._log_r3_effectiveness_stats( + data=data, + train_logp=train_logp, + r3_enabled=_r3_enabled, + ) + return train_logp + + @torch.no_grad() + def _log_r3_effectiveness_stats( + self, + data: dict[str, Any], + train_logp: torch.Tensor | None, + r3_enabled: bool, + ) -> None: + """Log rollout vs. training logprob divergence to gauge R3 quality. 
+ + All metrics are computed under ``ppo_actor/compute_logp/r3`` and are + designed so that lower values indicate a more faithful replay of the + rollout-time routing decisions: + + * ``rollout_train_logp_abs_diff`` - mean ``|logp_train - logp_rollout|`` + * ``rollout_train_logp_sq_diff`` - mean squared difference + * ``rollout_train_k3_kl`` - Schulman k3 estimator ``exp(Δ) - 1 - Δ`` + (unbiased, non-negative estimator of ``KL(π_rollout || π_train)``) + * ``rollout_train_extreme_frac_tau2`` / ``_tau5`` - F(τ) extreme token + fraction from the Router Replay paper (Eq. 3), i.e. the share of + tokens whose importance ratio leaves ``[1/τ, τ]`` + * ``r3_enabled`` scalar - 1 when the R3 side-channel was active + """ + if train_logp is None: + return + rollout_logp = data.get("logprobs") + loss_mask = data.get("loss_mask") + if rollout_logp is None or loss_mask is None: + return + # engine.forward returns logprobs aligned to ``roll(input_ids, -1)`` - + # i.e. position t holds the logprob of token t+1. ``data["logprobs"]`` + # from the inference engine follows the same pre-roll convention, so + # the two tensors are directly comparable. ``loss_mask`` in ``data`` + # has not yet been shifted (that happens in ``_compute_advantages``), + # so roll it here to align the valid-token mask with the logprobs. + shifted_mask = torch.roll(loss_mask, shifts=-1, dims=-1).bool() + if shifted_mask.shape != train_logp.shape: + return + # Shape-align rollout logprobs (dtype may differ across backends). + rollout_logp_f = rollout_logp.to(train_logp.dtype) + if rollout_logp_f.shape != train_logp.shape: + return + + log_ratio = (train_logp.float() - rollout_logp_f.float()).detach() + abs_diff = log_ratio.abs() + sq_diff = log_ratio * log_ratio + k3_kl = torch.expm1(log_ratio) - log_ratio # exp(Δ) - 1 - Δ + # F(τ) from the Router Replay paper: fraction of tokens with + # max(r, 1/r) > τ, where r = exp(logp_train - logp_rollout). + abs_log_ratio = log_ratio.abs() + extreme_tau2 = (abs_log_ratio > torch.log(torch.tensor(2.0))).float() + extreme_tau5 = (abs_log_ratio > torch.log(torch.tensor(5.0))).float() + + with stats_tracker.scope("compute_logp"): + with stats_tracker.scope("r3"): + stats_tracker.denominator( + n_valid_tokens=shifted_mask, + ) + stats_tracker.stat( + rollout_train_logp_abs_diff=abs_diff, + rollout_train_logp_sq_diff=sq_diff, + rollout_train_k3_kl=k3_kl, + rollout_train_extreme_frac_tau2=extreme_tau2, + rollout_train_extreme_frac_tau5=extreme_tau5, + denominator="n_valid_tokens", + ) + stats_tracker.scalar(r3_enabled=float(r3_enabled)) @trace_perf("ppo_actor.compute_advantages", category="compute") def compute_advantages(self, data: list[dict[str, Any]]) -> list[dict[str, Any]]: From 3e92a3d1f33087c0d2650f16edfffc23025e9a16 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 1 May 2026 16:27:33 +0800 Subject: [PATCH 085/112] fix(ppo/actor): fix metric --- areal/trainer/ppo/actor.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index ed28b77f4e..1b15b1c349 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -193,17 +193,25 @@ def _log_r3_effectiveness_stats( loss_mask = data.get("loss_mask") if rollout_logp is None or loss_mask is None: return - # engine.forward returns logprobs aligned to ``roll(input_ids, -1)`` - - # i.e. position t holds the logprob of token t+1. 
``data["logprobs"]`` - # from the inference engine follows the same pre-roll convention, so - # the two tensors are directly comparable. ``loss_mask`` in ``data`` - # has not yet been shifted (that happens in ``_compute_advantages``), - # so roll it here to align the valid-token mask with the logprobs. + # ``engine.forward`` returns logprobs aligned to ``roll(input_ids, -1)``: + # slot ``t`` holds ``log p(input_ids[t+1])``. The rollout-side + # ``data["logprobs"]`` is NOT pre-rolled -- slot ``t`` holds + # ``log p(input_ids[t])`` (SGLang returns one logprob per generated + # token at the position that consumed that token). To compare the two + # at the SAME token we must shift ``rollout_logp`` left by 1, mirroring + # ``_compute_advantages`` which does exactly ``torch.roll(data["logprobs"], -1)`` + # before feeding it into the PPO loss. + # + # Likewise ``loss_mask`` in ``data`` is in the unrolled frame (1 at + # response-token positions), and must also be rolled by -1 so that + # slot ``t`` marks "this slot's logprob is for a response token". shifted_mask = torch.roll(loss_mask, shifts=-1, dims=-1).bool() if shifted_mask.shape != train_logp.shape: return - # Shape-align rollout logprobs (dtype may differ across backends). - rollout_logp_f = rollout_logp.to(train_logp.dtype) + # Shape-align + convention-align rollout logprobs. + rollout_logp_f = torch.roll( + rollout_logp.to(train_logp.dtype), shifts=-1, dims=-1 + ) if rollout_logp_f.shape != train_logp.shape: return From 7bc2a41192d6d520da1535736a99962c2f53b736 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 1 May 2026 23:54:44 +0800 Subject: [PATCH 086/112] feat: all stage r3 log --- areal/engine/megatron_engine_r3_patch.py | 184 +++++++++++++- areal/engine/router_replay_patch.py | 137 ++++++++++ areal/engine/router_replay_utils.py | 306 ++++++++++++++++++++++- areal/trainer/ppo/actor.py | 131 ++++++++++ areal/trainer/ppo/actor_r3_patch.py | 29 +++ areal/workflow/rlvr.py | 24 ++ areal/workflow/rlvr_r3_patch.py | 25 +- 7 files changed, 829 insertions(+), 7 deletions(-) diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py index 413ccb8c3a..0f5674c032 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -85,6 +85,7 @@ def _align_routed_experts_to_mask( routed_experts: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: int, + _r3_mb_idx: int | None = None, ) -> torch.Tensor: """Align ``routed_experts`` from right-padded rollout format to left-aligned training format, matching the token layout implied by ``cu_seqlens``. @@ -138,6 +139,63 @@ def _align_routed_experts_to_mask( "n_seqs=%d (re_bs=%d), seq_lens=%s.", routed_experts.shape, aligned.shape, n_seqs, re_bs, seq_lens[:8], ) + # Detailed alignment log: smoking-gun check for the "last generated token + # has no routing" edge case (SGLang convention: num_sgl_tokens = + # prompt_len + gen_len - 1). If cu_seqlens claims k real tokens but the + # source only has k-1 non-zero rows, the k-th row here is a ZERO ROW that + # will route to expert 0 unconditionally. 
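+    # Worked example (hypothetical sizes): with prompt_len=3 and gen_len=4
+    # a sample has 7 real tokens, but SGLang reports routing for only
+    # prompt_len + gen_len - 1 = 6 of them -- the forward step that emits
+    # the final token never feeds that token back through the router.
+    # If cu_seqlens nevertheless claims 7 tokens for this sample, row 6 of
+    # the aligned slab stays zero-filled, and replaying it verbatim would
+    # send that token to expert 0 in every MoE layer.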
+ try: + from areal.engine.router_replay_utils import ( + _r3_pp_tp_info, + _r3_should_log, + _r3_tensor_sig, + _r3_verbose, + ) + + if _r3_verbose() and _r3_should_log("_align_routed_experts_to_mask"): + with torch.no_grad(): + per_row_zero_src = ( + (routed_experts == 0).reshape(re_bs, re_seqlen, -1).all(dim=-1) + ) + src_zero_rows_per_sample = per_row_zero_src.sum(dim=-1).tolist() + per_row_zero_dst = ( + (aligned == 0).reshape(n_seqs, max_seqlen, -1).all(dim=-1) + ) + dst_zero_rows_per_sample = per_row_zero_dst.sum(dim=-1).tolist() + # For each sample, locate first zero-row idx within the real-token window. + first_zero_in_real = [] + for i in range(min(n_seqs, re_bs)): + L = int(seq_lens[i]) + if L <= 0: + first_zero_in_real.append(-1) + continue + row = per_row_zero_src[i, :L] + idx = torch.nonzero(row, as_tuple=False) + first_zero_in_real.append( + int(idx[0].item()) if idx.numel() > 0 else -1 + ) + logger.info( + "[R3-STAGE3/_align_routed_experts_to_mask] mb=%s %s " + "re_shape=%s aligned_shape=%s n_seqs=%d re_bs=%d " + "seq_lens[:8]=%s src_zero_rows_per_sample[:8]=%s " + "first_zero_in_real_window[:8]=%s " + "dst_zero_rows_per_sample[:8]=%s | %s | %s", + _r3_mb_idx, + _r3_pp_tp_info(), + tuple(routed_experts.shape), + tuple(aligned.shape), + n_seqs, + re_bs, + seq_lens[:8], + src_zero_rows_per_sample[:8], + first_zero_in_real[:8], + dst_zero_rows_per_sample[:8], + _r3_tensor_sig("src_re", routed_experts, max_sample=4), + _r3_tensor_sig("aligned", aligned, max_sample=4), + ) + except Exception: + # diagnostic helper must never break the main flow + pass return aligned @@ -232,6 +290,29 @@ def _split_routed_experts_for_mbs( [r.shape[0] for r in result], "None" if forward_indices is None else f"len={len(forward_indices)}", ) + try: + from areal.engine.router_replay_utils import ( + _r3_pp_tp_info, + _r3_should_log, + _r3_tensor_sig, + _r3_verbose, + ) + + if _r3_verbose() and _r3_should_log("_split_routed_experts_for_mbs"): + logger.info( + "[R3-STAGE3/_split_routed_experts_for_mbs] %s " + "input_shape=%s n_mbs=%d forward_indices=%s " + "per_mb_shapes=%s | %s", + _r3_pp_tp_info(), + tuple(routed_experts.shape), + n_mbs, + "None" if forward_indices is None + else f"len={len(forward_indices)} first16={forward_indices[:16].tolist() if hasattr(forward_indices,'tolist') else list(forward_indices)[:16]}", + [tuple(r.shape) for r in result], + _r3_tensor_sig("routed_experts", routed_experts, max_sample=4), + ) + except Exception: + pass return result @@ -250,20 +331,44 @@ def _get_cu_seqlens_for_mb(mb_item) -> tuple[torch.Tensor, int] | None: ``(cu_seqlens, max_seqlen)`` or ``None`` if not available. 
""" # Try padded_mb first (has TP-aligned cu_seqlens -- this is what the model sees) + source = None if hasattr(mb_item, "padded_mb") and isinstance(mb_item.padded_mb, dict): cu = mb_item.padded_mb.get("cu_seqlens") max_sl = mb_item.padded_mb.get("max_seqlen") if cu is not None and max_sl is not None: - return cu, int(max_sl) + source = ("padded_mb", cu, int(max_sl)) # Try orig_mb as fallback (pre-padding cu_seqlens) - if hasattr(mb_item, "orig_mb") and isinstance(mb_item.orig_mb, dict): + if source is None and hasattr(mb_item, "orig_mb") and isinstance(mb_item.orig_mb, dict): cu = mb_item.orig_mb.get("cu_seqlens") max_sl = mb_item.orig_mb.get("max_seqlen") if cu is not None and max_sl is not None: - return cu, int(max_sl) + source = ("orig_mb", cu, int(max_sl)) - return None + if source is None: + return None + + src_name, cu_out, max_sl_out = source + try: + from areal.engine.router_replay_utils import ( + _r3_pp_tp_info, + _r3_should_log, + _r3_tensor_sig, + _r3_verbose, + ) + + if _r3_verbose() and _r3_should_log("_get_cu_seqlens_for_mb"): + logger.info( + "[R3-STAGE3/_get_cu_seqlens_for_mb] %s source=%s " + "max_seqlen=%d | %s", + _r3_pp_tp_info(), + src_name, + max_sl_out, + _r3_tensor_sig("cu_seqlens", cu_out), + ) + except Exception: + pass + return cu_out, max_sl_out # =================================================================== @@ -288,6 +393,10 @@ def _r3_forward_backward_batch( from areal.engine.router_replay_patch import RouterReplay, RouterReplayAction from areal.engine.router_replay_utils import ( RouterReplayHelper, + _r3_pp_tp_info, + _r3_should_log, + _r3_tensor_sig, + _r3_verbose, clear_router_replay, setup_per_microbatch_replay_forward, ) @@ -352,6 +461,18 @@ def _r3_forward_backward_batch( routed_experts_batch.shape, forward_only, ) + if _r3_verbose() and _r3_should_log("_r3_forward_backward_batch/ENTER"): + logger.info( + "[R3-STAGE3/_r3_forward_backward_batch] ENTER %s " + "n_mbs=%d forward_only=%s from_side_channel=%s " + "has_padded_mbs=%s | %s", + _r3_pp_tp_info(), + len(mb_list), + forward_only, + _from_side_channel, + mb_list.padded_mbs is not None, + _r3_tensor_sig("routed_experts_batch", routed_experts_batch), + ) # Split routed_experts per micro-batch per_mb_routed_experts = _split_routed_experts_for_mbs( @@ -402,6 +523,22 @@ def __next__(self): else None ) + if _r3_verbose() and _r3_should_log("_R3MicroBatchIterator.__next__"): + logger.info( + "[R3-STAGE3/_R3MicroBatchIterator] ENTER mb_idx=%d %s " + "re_shape=%s has_orig_mb=%s has_padded_mb=%s " + "has_old_cu_seqlens=%s", + idx, + _r3_pp_tp_info(), + None if re is None else tuple(re.shape), + hasattr(mb_item, "orig_mb") + and isinstance(mb_item.orig_mb, dict), + hasattr(mb_item, "padded_mb") + and isinstance(mb_item.padded_mb, dict), + hasattr(mb_item, "old_cu_seqlens") + and mb_item.old_cu_seqlens is not None, + ) + # When backward recompute finishes and next forward starts, # switch back to REPLAY_FORWARD. 
if RouterReplayHelper.is_replay_backward_action(model_config): @@ -412,6 +549,16 @@ def __next__(self): router.set_router_replay_action( RouterReplayAction.REPLAY_FORWARD ) + if _r3_verbose() and _r3_should_log( + "_R3MicroBatchIterator.toggle_to_forward" + ): + logger.info( + "[R3-STAGE3/_R3MicroBatchIterator] TOGGLE backward->forward " + "mb_idx=%d %s n_routers=%d", + idx, + _r3_pp_tp_info(), + len(router_list), + ) if re is not None: # Extract cu_seqlens from padded_mb (TP-aligned, what the model sees) @@ -428,19 +575,39 @@ def __next__(self): # to know each sample's actual token count for # extracting from routed_experts. orig_cu = None + orig_cu_src = None if hasattr(mb_item, "old_cu_seqlens") and mb_item.old_cu_seqlens is not None: orig_cu = mb_item.old_cu_seqlens + orig_cu_src = "old_cu_seqlens" elif hasattr(mb_item, "orig_mb") and isinstance(mb_item.orig_mb, dict): orig_cu = mb_item.orig_mb.get("cu_seqlens") + orig_cu_src = "orig_mb.cu_seqlens" if orig_cu is None: # Fallback: use padded cu_seqlens directly orig_cu = cu_seqlens + orig_cu_src = "padded_cu_seqlens (fallback)" + + if _r3_verbose() and _r3_should_log( + "_R3MicroBatchIterator.pre_align" + ): + logger.info( + "[R3-STAGE3/_R3MicroBatchIterator] PRE-ALIGN " + "mb_idx=%d %s orig_cu_src=%s max_seqlen=%d " + "| %s | %s | %s", + idx, + _r3_pp_tp_info(), + orig_cu_src, + max_seqlen, + _r3_tensor_sig("re", re, max_sample=4), + _r3_tensor_sig("orig_cu", orig_cu), + _r3_tensor_sig("padded_cu", cu_seqlens), + ) # Align routed_experts from left-padded to left-aligned # using the ORIGINAL cu_seqlens (actual token counts). aligned_re = _align_routed_experts_to_mask( - re, orig_cu, max_seqlen, + re, orig_cu, max_seqlen, _r3_mb_idx=idx, ) # Pass the PADDED cu_seqlens (with TP alignment) @@ -503,6 +670,13 @@ def _r3_post_forward_hook(module, input, output): router.set_router_replay_action( RouterReplayAction.REPLAY_BACKWARD ) + if _r3_verbose() and _r3_should_log("_r3_post_forward_hook"): + logger.info( + "[R3-STAGE3/_r3_post_forward_hook] TOGGLE forward->backward " + "%s n_routers=%d", + _r3_pp_tp_info(), + len(router_list), + ) for model_chunk in self.model: handle = model_chunk.register_forward_hook(_r3_post_forward_hook) diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py index 8347735807..391a0f012b 100644 --- a/areal/engine/router_replay_patch.py +++ b/areal/engine/router_replay_patch.py @@ -122,6 +122,22 @@ def set_global_router_replay_action(action: RouterReplayAction) -> None: """Set the replay action for all router instances.""" for r in RouterReplay.router_instances: r.set_router_replay_action(action) + try: + from areal.engine.router_replay_utils import ( + _r3_pp_tp_info, + _r3_should_log, + ) + + if _r3_should_log(f"set_global_router_replay_action/{action.value}"): + logger.info( + "[R3-STAGE4/set_global_router_replay_action] %s action=%s " + "applied_to=%d router_instances", + _r3_pp_tp_info(), + action.value, + len(RouterReplay.router_instances), + ) + except Exception: + pass @staticmethod def clear_global_router_replay_action() -> None: @@ -140,6 +156,35 @@ def set_target_indices(self, topk_indices: torch.Tensor) -> None: """Sets the target topk indices for replay.""" self.target_topk_idx = topk_indices self.replay_backward_list.append(topk_indices) + # Cheap diagnostic: record every set in first few layers/mb. Gated + # via _r3_should_log so steady-state overhead is ~one integer + # increment. 
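+        # Where this call sits in the per-micro-batch cycle (sketch):
+        #   set_target_indices (here) installs target_topk_idx and queues a
+        #   copy; REPLAY_FORWARD consumes the target during the forward; the
+        #   post-forward hook flips the action to REPLAY_BACKWARD, whose
+        #   recompute pass pops the queued copy (FIFO) -- one pop per set.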
+ try: + from areal.engine.router_replay_utils import ( + _r3_pp_tp_info, + _r3_should_log, + _r3_tensor_sig, + _r3_zero_row_stats, + ) + + if _r3_should_log("RouterReplay.set_target_indices"): + # instance index in the class-level list tells us which + # MoE layer this replay slot refers to + try: + inst_idx = RouterReplay.router_instances.index(self) + except ValueError: + inst_idx = -1 + logger.info( + "[R3-STAGE3b/set_target_indices] inst#%d %s %s | %s | " + "backward_queue_len=%d", + inst_idx, + _r3_pp_tp_info(), + _r3_zero_row_stats(topk_indices), + _r3_tensor_sig("topk_indices", topk_indices), + len(self.replay_backward_list), + ) + except Exception: + pass def get_recorded_indices(self) -> torch.Tensor | None: return self.recorded_topk_idx @@ -164,6 +209,80 @@ def clear_router_replay_action(self) -> None: # =================================================================== +def _R3_routing_log( + action_name: str, + *, + scores: torch.Tensor, + top_indices: torch.Tensor, + topk: int, + compute_topk_fn, + num_groups=None, + group_topk=None, +) -> None: + """Rate-limited diagnostic for the replay branches. + + Key quantities: + * ``shape_match`` -- does target_topk_idx align with this layer's + token count? If NOT, replay is being fed the wrong slab. + * ``zero_rows`` -- fraction of all-zero rows in the replay + indices; zero rows collapse routing to expert 0. + * ``live_vs_replay`` -- overlap between replay top-k and the live + top-k the router would have picked right now. 100% = no staleness + (rollout weights == train weights). 0% = total mismatch. + """ + from areal.engine.router_replay_utils import ( + _r3_call_count, + _r3_pp_tp_info, + _r3_should_log, + _r3_tensor_sig, + _r3_verbose, + _r3_zero_row_stats, + _R3_ROUTER_LAYER_LIMIT, + ) + + if not _r3_verbose(): + return + key = f"patched_routing/{action_name}" + call_n = _r3_call_count(key) + # We always want an early, concentrated burst of per-layer details at + # startup (helps catch first-step config problems) and then a sparse + # steady-state sample. 
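+    # With the defaults (AREAL_R3_LOG_FIRST_N=30, AREAL_R3_LOG_EVERY=100),
+    # calls 1..30 against this key all log; afterwards only every 100th
+    # call does (100, 200, ...), i.e. a startup burst followed by a ~1%
+    # steady-state sample.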
+    if not _r3_should_log(key):
+        return
+    with torch.no_grad():
+        shape_match = top_indices.shape[0] == scores.shape[0]
+        if shape_match:
+            try:
+                _, live_top = compute_topk_fn(
+                    scores, topk, num_groups=num_groups, group_topk=group_topk
+                )
+                # per-token overlap ratio
+                set_live = live_top.sort(dim=-1).values
+                set_rep = top_indices.sort(dim=-1).values
+                # equality per (token, slot)
+                overlap = (set_live == set_rep).float().mean().item()
+            except Exception as e:
+                overlap = f"err:{e}"
+        else:
+            overlap = None
+        logger.info(
+            "[R3-STAGE4/patched_routing] %s call#%d %s "
+            "scores_shape=%s topk=%d target_shape=%s shape_match=%s "
+            "live_vs_replay_topk_overlap=%s %s | %s | %s",
+            action_name,
+            call_n,
+            _r3_pp_tp_info(),
+            tuple(scores.shape),
+            topk,
+            tuple(top_indices.shape),
+            shape_match,
+            overlap,
+            _r3_zero_row_stats(top_indices),
+            _r3_tensor_sig("scores", scores, max_sample=4),
+            _r3_tensor_sig("top_indices", top_indices, max_sample=8),
+        )
+
+
 def _patched_topk_routing_with_score_function(
     logits: torch.Tensor,
     topk: int,
@@ -220,6 +339,15 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
             # Use the provided indices for replay
             top_indices = router_replay.target_topk_idx
             top_indices = top_indices.to(scores.device)
+            _R3_routing_log(
+                "REPLAY_FORWARD",
+                scores=scores,
+                top_indices=top_indices,
+                topk=topk,
+                compute_topk_fn=_compute_topk,
+                num_groups=num_groups,
+                group_topk=group_topk,
+            )
             probs = scores.gather(1, top_indices)
             return probs, top_indices

@@ -235,6 +363,15 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
             # Use the last recorded indices for backward replay
             top_indices = router_replay.replay_backward_list.pop(0)
             top_indices = top_indices.to(scores.device)
+            _R3_routing_log(
+                "REPLAY_BACKWARD",
+                scores=scores,
+                top_indices=top_indices,
+                topk=topk,
+                compute_topk_fn=_compute_topk,
+                num_groups=num_groups,
+                group_topk=group_topk,
+            )
             probs = scores.gather(1, top_indices)
             return probs, top_indices

diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py
index 49f356a78f..13b03322b3 100644
--- a/areal/engine/router_replay_utils.py
+++ b/areal/engine/router_replay_utils.py
@@ -15,6 +15,7 @@

 import inspect
 import logging
+import os
 from typing import Optional

 import torch
@@ -24,6 +25,166 @@

 logger = logging.getLogger(__name__)

+# ===================================================================
+# R3 detailed-logging helpers
+# ===================================================================
+# These helpers are used by EVERY R3 file (this module, router_replay_patch,
+# megatron_engine_r3_patch, actor_r3_patch, actor.py, rlvr_r3_patch) so that
+# all stages of the pipeline produce fingerprints in a consistent format.
+#
+# Controls (all tunable via env vars; note that verbose logging is ON by
+# default, so set AREAL_R3_VERBOSE=0 to avoid the overhead in production):
+#
+#   AREAL_R3_VERBOSE=1            -- master switch; enables everything below.
+#                                    Default: 1 (ON) so that if someone cares
+#                                    to run with R3 and grep logs, they do
+#                                    not need to set anything extra.
+#   AREAL_R3_LOG_FIRST_N=30       -- for rate-limited hot paths, always log
+#                                    the first N calls per key.
+#   AREAL_R3_LOG_EVERY=100        -- after the first N, log every Nth call.
+#   AREAL_R3_TENSOR_SAMPLE=8      -- how many leading elements to include in
+#                                    a tensor signature.
+#   AREAL_R3_ROUTER_LAYER_LIMIT=3 -- in patched_topk_routing, only print
+#                                    per-layer details for up to the first
+#                                    K routing calls of each type per step
+#                                    (layer idx is approximated via a
+#                                    per-action counter).
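+#
+# Example invocations (hypothetical command lines; any launcher works the
+# same way):
+#
+#   AREAL_R3_LOG_FIRST_N=200 AREAL_R3_LOG_EVERY=10 torchrun ... train.py
+#   AREAL_R3_VERBOSE=0 torchrun ... train.py   # silence all R3 diagnostics
+#
+# The first form front-loads detail for a short debugging run; the second
+# removes the diagnostics' overhead entirely.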
+# ===================================================================
+
+
+def _r3_verbose() -> bool:
+    return os.environ.get("AREAL_R3_VERBOSE", "1") != "0"
+
+
+_R3_LOG_CALL_COUNTS: dict[str, int] = {}
+_R3_LOG_FIRST_N = int(os.environ.get("AREAL_R3_LOG_FIRST_N", "30"))
+_R3_LOG_EVERY = int(os.environ.get("AREAL_R3_LOG_EVERY", "100"))
+_R3_TENSOR_SAMPLE = int(os.environ.get("AREAL_R3_TENSOR_SAMPLE", "8"))
+_R3_ROUTER_LAYER_LIMIT = int(os.environ.get("AREAL_R3_ROUTER_LAYER_LIMIT", "3"))
+
+
+def _r3_should_log(key: str) -> bool:
+    """Rate-limited logging gate. Returns True for the first
+    ``AREAL_R3_LOG_FIRST_N`` calls against ``key``, then True once every
+    ``AREAL_R3_LOG_EVERY`` calls thereafter. Monotonic per-process counter.
+    """
+    if not _r3_verbose():
+        return False
+    n = _R3_LOG_CALL_COUNTS.get(key, 0) + 1
+    _R3_LOG_CALL_COUNTS[key] = n
+    if n <= _R3_LOG_FIRST_N:
+        return True
+    return (n % max(_R3_LOG_EVERY, 1)) == 0
+
+
+def _r3_call_count(key: str) -> int:
+    return _R3_LOG_CALL_COUNTS.get(key, 0)
+
+
+def _r3_tensor_sig(name: str, t, *, max_sample: int | None = None) -> str:
+    """Compact human-readable fingerprint for a tensor or numpy array.
+
+    Intentionally cheap: performs ONE ``.detach().cpu()`` copy and one
+    reduction so it is safe to call from hot paths (still, prefer to gate
+    via ``_r3_should_log``).
+    """
+    if t is None:
+        return f"{name}=None"
+    sample_n = _R3_TENSOR_SAMPLE if max_sample is None else max_sample
+    try:
+        if isinstance(t, torch.Tensor):
+            tc = t.detach()
+            if tc.device.type != "cpu":
+                tc = tc.to("cpu", non_blocking=False)
+            flat = tc.reshape(-1)
+            total = int(flat.numel())
+            if total == 0:
+                return f"{name}(shape={tuple(t.shape)}, dtype={t.dtype}, empty)"
+            nnz = int((flat != 0).sum().item())
+            if tc.dtype in (
+                torch.float16,
+                torch.float32,
+                torch.float64,
+                torch.bfloat16,
+            ):
+                checksum = float(flat.float().double().sum().item())
+                maxv = float(flat.float().abs().max().item())
+                sample = [round(v, 6) for v in flat[:sample_n].float().tolist()]
+            else:
+                checksum = int(flat.long().sum().item())
+                maxv = int(flat.long().abs().max().item())
+                sample = flat[:sample_n].tolist()
+            return (
+                f"{name}(shape={tuple(t.shape)}, dtype={t.dtype}, "
+                f"device={t.device}, nnz={nnz}/{total}, "
+                f"sum={checksum}, |max|={maxv}, first{len(sample)}={sample})"
+            )
+        # numpy or generic array-like
+        if hasattr(t, "shape") and hasattr(t, "dtype"):
+            import numpy as np
+
+            arr = t if isinstance(t, np.ndarray) else np.asarray(t)
+            flat = arr.reshape(-1)
+            total = int(flat.size)
+            if total == 0:
+                return f"{name}(shape={arr.shape}, dtype={arr.dtype}, empty, numpy)"
+            nnz = int((flat != 0).sum())
+            checksum = int(flat.astype("int64").sum()) if np.issubdtype(
+                arr.dtype, np.integer
+            ) else float(flat.astype("float64").sum())
+            maxv = (
+                int(np.abs(flat).max())
+                if np.issubdtype(arr.dtype, np.integer)
+                else float(np.abs(flat).max())
+            )
+            sample = flat[:sample_n].tolist()
+            return (
+                f"{name}(shape={tuple(arr.shape)}, dtype={arr.dtype}, numpy, "
+                f"nnz={nnz}/{total}, sum={checksum}, |max|={maxv}, "
+                f"first{len(sample)}={sample})"
+            )
+    except Exception as e:  # pragma: no cover - diagnostic helper must not raise
+        return f"{name}=<err:{e}>"
+    return f"{name}={type(t).__name__}"
+
+
+def _r3_zero_row_stats(top_indices: torch.Tensor) -> str:
+    """Returns a string describing the count of all-zero rows in a
+    ``(num_tokens, topk)`` target_topk_idx tensor. All-zero rows are the
+    smoking gun for zero-fill hazard.
+    """
+    if top_indices is None or top_indices.ndim < 2:
+        return "zero_row_stats=N/A"
+    try:
+        with torch.no_grad():
+            zero_rows = (top_indices == 0).all(dim=-1)
+            total = int(zero_rows.numel())
+            z = int(zero_rows.sum().item())
+            return f"zero_rows={z}/{total} ({100.0*z/max(total,1):.2f}%)"
+    except Exception as e:
+        return f"zero_row_stats=<err:{e}>"
+
+
+def _r3_pp_tp_info(tf_config=None, vp_rank=None) -> str:
+    """Short PP/TP/DP/SP/EP context string for log lines."""
+    try:
+        from megatron.core import parallel_state as mpu
+
+        pp = mpu.get_pipeline_model_parallel_world_size()
+        pp_rank = mpu.get_pipeline_model_parallel_rank()
+        tp = mpu.get_tensor_model_parallel_world_size()
+        tp_rank = mpu.get_tensor_model_parallel_rank()
+        cp = getattr(mpu, "get_context_parallel_world_size", lambda: 1)()
+        dp = mpu.get_data_parallel_world_size()
+        dp_rank = mpu.get_data_parallel_rank()
+        return (
+            f"pp={pp_rank}/{pp} tp={tp_rank}/{tp} cp={cp} dp={dp_rank}/{dp}"
+            + (f" vp={vp_rank}" if vp_rank is not None else "")
+        )
+    except Exception:
+        return "pp=?/? tp=?/?"
+
+
 # ===================================================================
 # Layer computation helpers
 # ===================================================================
@@ -275,6 +436,30 @@ def set_router_replay_data(

     total_aligned = sum(aligned_lens)

+    if _r3_verbose() and _r3_should_log("set_router_replay_data/ENTER"):
+        logger.info(
+            "[R3-STAGE3/set_router_replay_data] ENTER call#%d %s "
+            "layers_topk_idx=(bs=%d, max_seq=%d, L=%d, K=%d) dtype=%s "
+            "n_cu_entries=%d n_seqs_in_cu=%d seq_align_to=%d "
+            "seq_lens[:8]=%s aligned_lens[:8]=%s total_aligned=%d "
+            "vp_rank=%s | %s",
+            _r3_call_count("set_router_replay_data/ENTER"),
+            _r3_pp_tp_info(tf_config, vp_rank),
+            bs_re,
+            layers_topk_idx.shape[1],
+            num_layers,
+            topk,
+            layers_topk_idx.dtype,
+            n_cu_entries,
+            n_seqs_in_cu,
+            seq_align_to,
+            seq_lens[:8],
+            aligned_lens[:8],
+            total_aligned,
+            vp_rank,
+            _r3_tensor_sig("cu_seqlens", cu_seqlens),
+        )
+
     # Pack routed_experts using cu_seqlens-aligned layout.
     # layers_topk_idx is left-ALIGNED: real tokens at positions [0, seq_len).
     # For each sequence i, we take the first seq_lens[i] tokens and place
@@ -303,6 +488,27 @@ def set_router_replay_data(
     for i in range(bs_re, n_seqs_in_cu):
         aligned_offset += aligned_lens[i]

+    if _r3_verbose() and _r3_should_log("set_router_replay_data/PACKED"):
+        with torch.no_grad():
+            # Count global all-zero rows across ALL layers AND topk slots.
+ per_row_zero = ( + (packed == 0).reshape(packed.shape[0], -1).all(dim=-1) + ) + zrows = int(per_row_zero.sum().item()) + total_rows = int(per_row_zero.numel()) + logger.info( + "[R3-STAGE3/set_router_replay_data] PACKED " + "packed=(total_aligned=%d, L=%d, K=%d) global_zero_rows=%d/%d " + "(%.2f%%) | %s", + packed.shape[0], + packed.shape[1], + packed.shape[2], + zrows, + total_rows, + 100.0 * zrows / max(total_rows, 1), + _r3_tensor_sig("packed", packed), + ) + # Step 2: Scatter to SP ranks packed = packed.to(device) tp_size = mpu.get_tensor_model_parallel_world_size() @@ -316,6 +522,15 @@ def set_router_replay_data( # Step 3: Permute to (num_layers, local_tokens_count, topk) layers_topk = local_tokens.permute(1, 0, 2) + if _r3_verbose() and _r3_should_log("set_router_replay_data/SCATTER"): + logger.info( + "[R3-STAGE3/set_router_replay_data] POST-SCATTER " + "tp_size=%d local_tokens=%s layers_topk=%s", + tp_size, + _r3_tensor_sig("local_tokens", local_tokens), + _r3_tensor_sig("layers_topk", layers_topk), + ) + # Step 4: Distribute to RouterReplay instances for local PP layers local_info = get_current_rank_layer_info(tf_config, vp_rank) offset, end = local_info["start"], local_info["end"] @@ -338,6 +553,7 @@ def set_router_replay_data( moe_idx = sum(1 for i in range(offset) if is_moe_layer(tf_config, i)) router_offset = 0 + dispatched = [] # list of (layer_idx, idx_into_layers_topk, zero_row_stats) for layer_idx in range(offset, end): if not is_moe_layer(tf_config, layer_idx): continue @@ -360,7 +576,17 @@ def set_router_replay_data( moe_idx += 1 router_offset += 1 continue - router.set_target_indices(layers_topk[idx].to(torch.int64)) + slab = layers_topk[idx].to(torch.int64) + router.set_target_indices(slab) + if _r3_verbose() and _r3_should_log("set_router_replay_data/DISPATCH"): + dispatched.append( + ( + layer_idx, + idx, + _r3_zero_row_stats(slab), + _r3_tensor_sig(f"target[L={layer_idx}]", slab), + ) + ) router_offset += 1 moe_idx += 1 @@ -374,6 +600,23 @@ def set_router_replay_data( end, tp_size, ) + if _r3_verbose() and dispatched: + # Only log first couple of dispatched layers in detail; keep + # the rest summarised. + head = dispatched[:_R3_ROUTER_LAYER_LIMIT] + logger.info( + "[R3-STAGE3/set_router_replay_data] DISPATCH %s " + "router_offset=%d len(router_list)=%d index_by_layer=%s " + "first_layers=%s ... (total dispatched=%d)", + _r3_pp_tp_info(tf_config, vp_rank), + router_offset, + len(router_list), + index_by_layer, + [ + (lidx, j, zr, sig) for lidx, j, zr, sig in head + ], + len(dispatched), + ) # =================================================================== @@ -400,6 +643,19 @@ def setup_per_microbatch_replay_forward( seq_align_to: Per-sequence TP alignment factor. 
""" routed_experts = routed_experts.to(torch.int32) + if _r3_verbose() and _r3_should_log("setup_per_microbatch_replay_forward"): + with torch.no_grad(): + per_row_zero = (routed_experts == 0).all(dim=-1).all(dim=-1) + logger.info( + "[R3-STAGE3/setup_per_microbatch_replay_forward] ENTER %s " + "routed_experts=%s cu_seqlens=%s seq_align_to=%s " + "per_sample_zero_rows=%s", + _r3_pp_tp_info(tf_config, vp_rank), + _r3_tensor_sig("routed_experts", routed_experts), + _r3_tensor_sig("cu_seqlens", cu_seqlens), + seq_align_to, + [int(x.sum().item()) for x in per_row_zero][:8], + ) set_router_replay_data( routed_experts, cu_seqlens, tf_config, vp_rank, seq_align_to=seq_align_to, @@ -454,12 +710,30 @@ def preprocess_routed_experts_batch( import numpy as np if routed_experts_np is None: + if _r3_verbose(): + logger.info("[R3-STAGE1/preprocess] routed_experts_np=None, returning None") return None seq_len = input_ids.shape[1] num_sgl_tokens = routed_experts_np.shape[0] flat_dim = routed_experts_np.shape[1] + if _r3_verbose(): + logger.info( + "[R3-STAGE1/preprocess] ENTER " + "seq_len=%d num_sgl_tokens=%d flat_dim=%d num_moe_layers=%s topk=%s " + "expected_flat=%s | %s | %s | %s", + seq_len, + num_sgl_tokens, + flat_dim, + num_moe_layers, + topk, + (num_moe_layers or 0) * (topk or 0), + _r3_tensor_sig("input_ids", input_ids), + _r3_tensor_sig("attention_mask", attention_mask), + _r3_tensor_sig("routed_experts_np", routed_experts_np), + ) + expected_flat = num_moe_layers * topk if flat_dim != expected_flat: logger.warning( @@ -512,4 +786,34 @@ def preprocess_routed_experts_batch( right_pad, ) + if _r3_verbose(): + # NOTE: for R3 correctness check. We expect num_sgl_tokens = + # real_tokens - 1 (SGLang drops the logprob of the very last + # generated token). Anything else means the routing -> token + # alignment is not what we think it is. + tail_row_is_zero = None + try: + tail_slice = padded[0, real_tokens - 1] if real_tokens > 0 else None + if tail_slice is not None: + tail_row_is_zero = bool((tail_slice == 0).all().item()) + except Exception: + tail_row_is_zero = "err" + # All-zero row stats across the seq_len axis for this sample. 
+ with torch.no_grad(): + per_row_zero = (padded[0] == 0).all(dim=-1).all(dim=-1) + zero_rows_count = int(per_row_zero.sum().item()) + logger.info( + "[R3-STAGE1/preprocess] EXIT " + "num_sgl_tokens=%d real_tokens=%d delta=%d right_pad=%d " + "tail_real_row_all_zero=%s zero_rows_total=%d/%d | %s", + num_sgl_tokens, + real_tokens, + real_tokens - num_sgl_tokens, + right_pad, + tail_row_is_zero, + zero_rows_count, + seq_len, + _r3_tensor_sig("padded", padded), + ) + return padded diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index 1b15b1c349..894203526c 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -143,6 +143,33 @@ def _compute_logp(self, data: dict[str, Any]) -> torch.Tensor | None: from areal.trainer.ppo.actor_r3_patch import _resolve_to_tensor _r3_routed_experts = _resolve_to_tensor(_r3_routed_experts) _r3_enabled = bool(getattr(self.engine, "_r3_enabled", False)) + try: + from areal.engine.router_replay_utils import ( + _r3_should_log, + _r3_tensor_sig, + _r3_verbose, + ) + + if _r3_verbose() and _r3_should_log("actor._compute_logp/ENTER"): + _re_info = ( + _r3_tensor_sig("routed_experts", _r3_routed_experts) + if _r3_routed_experts is not None + else "routed_experts=None" + ) + logger.info( + "[R3-STAGE2/actor._compute_logp] ENTER r3_enabled=%s " + "input_keys=%s batch_shape=%s | %s | %s | %s", + _r3_enabled, + sorted(list(data.keys())), + tuple(data["input_ids"].shape) + if "input_ids" in data + else "N/A", + _re_info, + _r3_tensor_sig("logprobs", data.get("logprobs")), + _r3_tensor_sig("loss_mask", data.get("loss_mask")), + ) + except Exception: + pass if _r3_routed_experts is not None and _r3_enabled: # forward_batch performs ONE forward_backward_batch(forward_only=True) # call internally; the R3 engine patch will split routed_experts per @@ -225,6 +252,55 @@ def _log_r3_effectiveness_stats( extreme_tau2 = (abs_log_ratio > torch.log(torch.tensor(2.0))).float() extreme_tau5 = (abs_log_ratio > torch.log(torch.tensor(5.0))).float() + try: + from areal.engine.router_replay_utils import ( + _r3_should_log, + _r3_tensor_sig, + _r3_verbose, + ) + + if _r3_verbose() and _r3_should_log( + "actor._log_r3_effectiveness_stats" + ): + with torch.no_grad(): + n_valid = int(shifted_mask.sum().item()) + if n_valid > 0: + _masked = abs_diff[shifted_mask] + _mean_abs = float(_masked.mean().item()) + _max_abs = float(_masked.max().item()) + _p99 = float( + torch.quantile( + _masked.float(), 0.99 + ).item() + ) if _masked.numel() > 0 else 0.0 + _mean_k3 = float(k3_kl[shifted_mask].mean().item()) + _frac_tau2 = float( + extreme_tau2[shifted_mask].mean().item() + ) + _frac_tau5 = float( + extreme_tau5[shifted_mask].mean().item() + ) + else: + _mean_abs = _max_abs = _p99 = _mean_k3 = _frac_tau2 = _frac_tau5 = 0.0 + logger.info( + "[R3-STAGE2/r3_effectiveness] r3_enabled=%s " + "n_valid_tokens=%d mean_abs_diff=%.6f max_abs_diff=%.6f " + "p99_abs_diff=%.6f mean_k3_kl=%.6f frac_tau2=%.4f " + "frac_tau5=%.4f | %s | %s", + r3_enabled, + n_valid, + _mean_abs, + _max_abs, + _p99, + _mean_k3, + _frac_tau2, + _frac_tau5, + _r3_tensor_sig("train_logp", train_logp), + _r3_tensor_sig("rollout_logp_rolled", rollout_logp_f), + ) + except Exception: + pass + with stats_tracker.scope("compute_logp"): with stats_tracker.scope("r3"): stats_tracker.denominator( @@ -475,6 +551,32 @@ def _ppo_update(self, data: dict[str, Any]) -> None: len(mb_inputs.mbs), [s.shape if s is not None else None for s in _r3_split], ) + try: + from areal.engine.router_replay_utils import ( + 
_r3_should_log, + _r3_tensor_sig, + _r3_verbose, + ) + + if _r3_verbose() and _r3_should_log("actor._ppo_update/split"): + logger.info( + "[R3-STAGE2/actor._ppo_update] SPLIT " + "n_ppo_minibatches=%d per_mb_shapes=%s " + "forward_indices=%s | %s", + len(mb_inputs.mbs), + [ + None if s is None else tuple(s.shape) + for s in _r3_split + ], + "None" + if mb_inputs.forward_indices is None + else f"len={len(mb_inputs.forward_indices)}", + _r3_tensor_sig( + "_r3_routed_experts", _r3_routed_experts, max_sample=4 + ), + ) + except Exception: + pass with stats_tracker.scope("update"): # Get current version for proximal approximation metrics @@ -490,6 +592,35 @@ def _ppo_update(self, data: dict[str, Any]) -> None: if i < len(_r3_split) and _r3_split[i] is not None else None ) + try: + from areal.engine.router_replay_utils import ( + _r3_should_log, + _r3_tensor_sig, + _r3_verbose, + ) + + if _r3_verbose() and _r3_should_log( + "actor._ppo_update/side_channel_set" + ): + _slice = ( + _r3_split[i] + if i < len(_r3_split) + else None + ) + logger.info( + "[R3-STAGE2/actor._ppo_update] " + "side_channel_set mb=%d current_version=%s " + "| %s", + i, + current_version, + _r3_tensor_sig( + "pending_routed_experts", + _slice, + max_sample=4, + ), + ) + except Exception: + pass else: logger.warning( "[R3] routed_experts available but engine._r3_enabled " diff --git a/areal/trainer/ppo/actor_r3_patch.py b/areal/trainer/ppo/actor_r3_patch.py index bc9faf3f7a..08cf2e681f 100644 --- a/areal/trainer/ppo/actor_r3_patch.py +++ b/areal/trainer/ppo/actor_r3_patch.py @@ -88,6 +88,35 @@ def split_routed_experts_for_minibatches( [r.shape[0] for r in result], "None" if forward_indices is None else f"len={len(forward_indices)}", ) + try: + from areal.engine.router_replay_utils import ( + _r3_pp_tp_info, + _r3_should_log, + _r3_tensor_sig, + _r3_verbose, + ) + + if _r3_verbose() and _r3_should_log( + "split_routed_experts_for_minibatches" + ): + logger.info( + "[R3-STAGE2/split_routed_experts_for_minibatches] %s " + "input_shape=%s n_mbs=%d forward_indices=%s " + "per_mb_shapes=%s | %s", + _r3_pp_tp_info(), + tuple(routed_experts.shape), + n_mbs, + "None" + if forward_indices is None + else ( + f"len={len(forward_indices)} " + f"first16={forward_indices[:16].tolist() if hasattr(forward_indices,'tolist') else list(forward_indices)[:16]}" + ), + [tuple(r.shape) for r in result], + _r3_tensor_sig("routed_experts", routed_experts, max_sample=4), + ) + except Exception: + pass return result diff --git a/areal/workflow/rlvr.py b/areal/workflow/rlvr.py index ae356b8b0a..f5e38aa7b4 100644 --- a/areal/workflow/rlvr.py +++ b/areal/workflow/rlvr.py @@ -196,5 +196,29 @@ async def arun_episode( topk=self.r3_topk, ) res = inject_routed_experts_into_result(res, routed_experts_tensor) + try: + from areal.engine.router_replay_utils import ( + _r3_should_log, + _r3_tensor_sig, + _r3_verbose, + ) + + if _r3_verbose() and _r3_should_log("rlvr.arun_episode"): + logger.info( + "[R3-STAGE1/rlvr.arun_episode] INJECT " + "r3_num_moe_layers=%s r3_topk=%s " + "resp.routed_experts.shape=%s input_len=%d " + "output_len=%d | %s", + self.r3_num_moe_layers, + self.r3_topk, + getattr(resp.routed_experts, "shape", None), + resp.input_len, + resp.output_len, + _r3_tensor_sig( + "routed_experts_tensor", routed_experts_tensor + ), + ) + except Exception: + pass return res diff --git a/areal/workflow/rlvr_r3_patch.py b/areal/workflow/rlvr_r3_patch.py index a93f24eac3..03a00b8775 100644 --- a/areal/workflow/rlvr_r3_patch.py +++ 
b/areal/workflow/rlvr_r3_patch.py @@ -157,7 +157,7 @@ def extract_routed_experts( preprocess_routed_experts_batch, ) - return preprocess_routed_experts_batch( + result = preprocess_routed_experts_batch( routed_experts_np, input_ids, attention_mask, @@ -165,6 +165,29 @@ def extract_routed_experts( topk=topk, compress_dtype=compress_dtype, ) + try: + from areal.engine.router_replay_utils import ( + _r3_should_log, + _r3_tensor_sig, + _r3_verbose, + ) + + if _r3_verbose() and _r3_should_log("extract_routed_experts"): + logger.info( + "[R3-STAGE1/extract_routed_experts] " + "num_moe_layers=%d topk=%d " + "input_ids_shape=%s attn_sum=%d " + "np_shape=%s | %s", + num_moe_layers, + topk, + tuple(input_ids.shape), + int(attention_mask.sum().item()), + getattr(routed_experts_np, "shape", None), + _r3_tensor_sig("result", result), + ) + except Exception: + pass + return result except Exception: logger.warning( "[R3] Failed to preprocess routed_experts (shape=%s); skipping.", From 327b9275b3fd09986c654bb51adc35038d19b3b8 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sat, 2 May 2026 00:59:02 +0800 Subject: [PATCH 087/112] refactor(logging): fix log --- areal/engine/megatron_engine_r3_patch.py | 7 +++++-- areal/engine/router_replay_patch.py | 7 +++++-- areal/engine/router_replay_utils.py | 7 +++++-- areal/trainer/ppo/actor_r3_patch.py | 7 +++++-- areal/workflow/rlvr_r3_patch.py | 7 +++++-- 5 files changed, 25 insertions(+), 10 deletions(-) diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py index 0f5674c032..6c2afadb04 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -25,14 +25,17 @@ from __future__ import annotations -import logging import types from collections.abc import Callable from typing import Any import torch -logger = logging.getLogger(__name__) +from areal.utils import logging + +# NOTE: use areal.utils.logging.getLogger with a stable registered +# name so the logger survives the dictConfig(disable_existing_loggers=True) re-init path. +logger = logging.getLogger("R3/megatron") # =================================================================== diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py index 391a0f012b..591febd043 100644 --- a/areal/engine/router_replay_patch.py +++ b/areal/engine/router_replay_patch.py @@ -12,7 +12,6 @@ from __future__ import annotations import inspect -import logging import types import warnings from enum import Enum @@ -20,7 +19,11 @@ import torch -logger = logging.getLogger(__name__) +from areal.utils import logging + +# NOTE: use areal.utils.logging.getLogger with a stable registered +# name so the logger survives the dictConfig(disable_existing_loggers=True) re-init path. 
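# --- Illustrative sketch (not part of the patch; stdlib only): why a plain
# module-level logger can go silent on the re-init path the NOTE above
# mentions. dictConfig(disable_existing_loggers=True) disables every logger
# that already exists and is not named in the new config.
import logging
import logging.config

early = logging.getLogger("some.module")  # created at import time
logging.config.dictConfig(
    {
        "version": 1,
        "disable_existing_loggers": True,  # the re-init path in question
        "handlers": {"h": {"class": "logging.StreamHandler"}},
        "root": {"handlers": ["h"], "level": "INFO"},
    }
)
assert early.disabled  # every log call on it is now a no-op
assert logging.getLogger("some.module") is early  # and it stays disabled
# A helper that registers loggers under stable names and keeps them enabled,
# which areal.utils.logging is assumed to provide, avoids this trap.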
+logger = logging.getLogger("R3/patch") # --------------------------------------------------------------------------- # Optional megatron-core imports with fallback diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py index 13b03322b3..8d2d275a15 100644 --- a/areal/engine/router_replay_utils.py +++ b/areal/engine/router_replay_utils.py @@ -14,7 +14,6 @@ from __future__ import annotations import inspect -import logging import os from typing import Optional @@ -22,7 +21,11 @@ from areal.engine.router_replay_patch import RouterReplay, RouterReplayAction -logger = logging.getLogger(__name__) +from areal.utils import logging + +# NOTE: use areal.utils.logging.getLogger with a stable registered +# name so the logger survives the dictConfig(disable_existing_loggers=True) re-init path. +logger = logging.getLogger("R3/utils") # =================================================================== diff --git a/areal/trainer/ppo/actor_r3_patch.py b/areal/trainer/ppo/actor_r3_patch.py index 08cf2e681f..e553383d9e 100644 --- a/areal/trainer/ppo/actor_r3_patch.py +++ b/areal/trainer/ppo/actor_r3_patch.py @@ -7,12 +7,15 @@ from __future__ import annotations -import logging from typing import Any import torch -logger = logging.getLogger(__name__) +from areal.utils import logging + +# NOTE: use areal.utils.logging.getLogger with a stable registered +# name so the logger survives the dictConfig(disable_existing_loggers=True) re-init path. +logger = logging.getLogger("R3/actor") def _resolve_to_tensor(obj: Any) -> torch.Tensor | None: diff --git a/areal/workflow/rlvr_r3_patch.py b/areal/workflow/rlvr_r3_patch.py index 03a00b8775..68ca8ed7ec 100644 --- a/areal/workflow/rlvr_r3_patch.py +++ b/areal/workflow/rlvr_r3_patch.py @@ -17,12 +17,15 @@ from __future__ import annotations -import logging import numpy as np import torch -logger = logging.getLogger(__name__) +from areal.utils import logging + +# NOTE: use areal.utils.logging.getLogger with a stable registered +# name so the logger survives the dictConfig(disable_existing_loggers=True) re-init path. +logger = logging.getLogger("R3/rlvr") _RESOLVED_CACHE: dict[str, tuple[int, int]] = {} From 23afbee7961158e109536f20e49e7c9ebb914ca9 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sat, 2 May 2026 01:45:55 +0800 Subject: [PATCH 088/112] fix(router_replay): add mask --- areal/engine/router_replay_patch.py | 64 ++++++++++++++++++++++++++++- areal/engine/router_replay_utils.py | 45 ++++++++++++++++++-- 2 files changed, 103 insertions(+), 6 deletions(-) diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py index 591febd043..0265f63fed 100644 --- a/areal/engine/router_replay_patch.py +++ b/areal/engine/router_replay_patch.py @@ -153,12 +153,35 @@ def __init__(self) -> None: self.recorded_topk_idx: torch.Tensor | None = None self.router_replay_action: RouterReplayAction | None = None self.replay_backward_list: list[torch.Tensor] = [] + # 1-D bool mask (shape=(num_tokens,)) marking which + # rows of ``target_topk_idx`` correspond to real tokens. Padded + # rows (seq-alignment slack + batch padding) must not be forced to + # the recorded top-k (which is [0,...,0]); instead we let them fall + # back to the live router output so they produce no replay signal. 
+ self.target_valid_mask: torch.Tensor | None = None + self.replay_backward_mask_list: list[torch.Tensor | None] = [] RouterReplay.router_instances.append(self) - def set_target_indices(self, topk_indices: torch.Tensor) -> None: - """Sets the target topk indices for replay.""" + def set_target_indices( + self, + topk_indices: torch.Tensor, + valid_mask: torch.Tensor | None = None, + ) -> None: + """Sets the target topk indices (and optional row-validity mask) for replay. + + Args: + topk_indices: ``(num_tokens, topk)`` replay indices. + valid_mask: Optional ``(num_tokens,)`` bool tensor. ``True`` means + the row is a real token and replay should override live routing; + ``False`` means the row is padding (batch or TP-alignment slack) + and replay MUST fall back to live routing to avoid forcing those + rows to expert 0. When ``None``, all rows are treated as real + (legacy behaviour). + """ self.target_topk_idx = topk_indices + self.target_valid_mask = valid_mask self.replay_backward_list.append(topk_indices) + self.replay_backward_mask_list.append(valid_mask) # Cheap diagnostic: record every set in first few layers/mb. Gated # via _r3_should_log so steady-state overhead is ~one integer # increment. @@ -198,7 +221,9 @@ def record_indices(self, topk_indices: torch.Tensor) -> None: def clear_indices(self) -> None: self.recorded_topk_idx = None self.target_topk_idx = None + self.target_valid_mask = None self.replay_backward_list = [] + self.replay_backward_mask_list = [] def set_router_replay_action(self, action: RouterReplayAction) -> None: self.router_replay_action = action @@ -342,6 +367,19 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): # Use the provided indices for replay top_indices = router_replay.target_topk_idx top_indices = top_indices.to(scores.device) + # splice padded rows with the LIVE router top-k so + # that TP-alignment / batch padding slack (which was recorded as + # all-zeros) does not force those rows to expert 0. + valid_mask = getattr(router_replay, "target_valid_mask", None) + if valid_mask is not None and valid_mask.shape[0] == top_indices.shape[0]: + _, live_top = _compute_topk( + scores, topk, num_groups=num_groups, group_topk=group_topk + ) + top_indices = torch.where( + valid_mask.to(scores.device).unsqueeze(-1), + top_indices, + live_top, + ) _R3_routing_log( "REPLAY_FORWARD", scores=scores, @@ -366,6 +404,28 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): # Use the last recorded indices for backward replay top_indices = router_replay.replay_backward_list.pop(0) top_indices = top_indices.to(scores.device) + # pop the matching per-row validity mask (if any) + # so the backward recompute sees the same spliced indices as the + # original forward pass. Without this, activation-checkpoint + # recomputation re-introduces the all-zero padding rows and the + # gradient path contradicts the forward pass. 
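# --- Illustrative sketch (not part of the patch; assumes only PyTorch): the
# mask-splice that the forward and backward branches must both apply so the
# recompute sees identical indices. Padding rows (mask False) fall back to
# live routing instead of being forced to expert 0.
import torch

recorded = torch.tensor([[1, 5], [0, 0], [3, 7]])  # replayed top-k per row
live = torch.tensor([[2, 4], [6, 1], [2, 4]])      # live router top-k
valid = torch.tensor([True, False, True])          # False = padding row

spliced = torch.where(valid.unsqueeze(-1), recorded, live)
assert spliced.tolist() == [[1, 5], [6, 1], [3, 7]]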
+ bw_mask_list = getattr(router_replay, "replay_backward_mask_list", None) + if bw_mask_list: + bw_valid_mask = bw_mask_list.pop(0) + else: + bw_valid_mask = None + if ( + bw_valid_mask is not None + and bw_valid_mask.shape[0] == top_indices.shape[0] + ): + _, live_top = _compute_topk( + scores, topk, num_groups=num_groups, group_topk=group_topk + ) + top_indices = torch.where( + bw_valid_mask.to(scores.device).unsqueeze(-1), + top_indices, + live_top, + ) _R3_routing_log( "REPLAY_BACKWARD", scores=scores, diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py index 8d2d275a15..73f2eb7c9d 100644 --- a/areal/engine/router_replay_utils.py +++ b/areal/engine/router_replay_utils.py @@ -472,6 +472,22 @@ def set_router_replay_data( dtype=layers_topk_idx.dtype, device=layers_topk_idx.device, ) + # ---------------------------------------------------------------- + # build a 1-D validity mask in lock-step with ``packed``. + # True = real token position (safe to replay). + # False = padding (per-seq TP alignment slack OR batch-level padding + # sequence OR tail rows beyond the real payload). + # + # After ``scatter_to_sequence_parallel_region``, this mask is sliced + # the same way as ``packed`` so each TP rank knows which of its local + # rows are real. Rows with mask==False MUST NOT be forced to the + # recorded top-k (which is [0,0,...,0] for padding) + # ---------------------------------------------------------------- + valid_mask = torch.zeros( + total_aligned, + dtype=torch.bool, + device=layers_topk_idx.device, + ) aligned_offset = 0 for i in range(min(n_seqs_in_cu, bs_re)): @@ -484,10 +500,14 @@ def set_router_replay_data( packed[aligned_offset : aligned_offset + actual_len] = ( layers_topk_idx[i, :actual_len] ) + # Only the real-token span is marked valid; the per-seq + # TP-alignment slack (aligned_lens[i] - actual_len) stays False. + valid_mask[aligned_offset : aligned_offset + actual_len] = True aligned_offset += aligned_lens[i] # For any extra sequences in cu_seqlens beyond bs_re (batch padding), - # the packed tensor already has zeros at those positions. + # the packed tensor already has zeros at those positions and + # ``valid_mask`` stays False for that entire span. for i in range(bs_re, n_seqs_in_cu): aligned_offset += aligned_lens[i] @@ -499,37 +519,54 @@ def set_router_replay_data( ) zrows = int(per_row_zero.sum().item()) total_rows = int(per_row_zero.numel()) + n_valid = int(valid_mask.sum().item()) logger.info( "[R3-STAGE3/set_router_replay_data] PACKED " "packed=(total_aligned=%d, L=%d, K=%d) global_zero_rows=%d/%d " - "(%.2f%%) | %s", + "(%.2f%%) valid_rows=%d/%d (%.2f%%) | %s", packed.shape[0], packed.shape[1], packed.shape[2], zrows, total_rows, 100.0 * zrows / max(total_rows, 1), + n_valid, + total_rows, + 100.0 * n_valid / max(total_rows, 1), _r3_tensor_sig("packed", packed), ) # Step 2: Scatter to SP ranks packed = packed.to(device) + valid_mask = valid_mask.to(device) tp_size = mpu.get_tensor_model_parallel_world_size() if tp_size > 1: from megatron.core.tensor_parallel import scatter_to_sequence_parallel_region local_tokens = scatter_to_sequence_parallel_region(packed) + # Scatter the mask on dim-0 as well. ``scatter_to_sequence_parallel_region`` + # expects a tensor with a sequence dimension on dim 0; promote the + # bool mask to the same dtype as ``packed`` to keep the op's + # collective-compat contract intact, then cast back. 
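# --- Illustrative sketch (not part of the patch): the dtype/shape round-trip
# used below, with the dim-0 scatter simulated by torch.chunk. Megatron's
# scatter_to_sequence_parallel_region slices dim 0 per TP rank the same way.
import torch

tp_size = 2
packed = torch.randint(0, 64, (8, 3, 2), dtype=torch.int32)  # (N, L, K)
valid_mask = torch.tensor([1, 1, 1, 0, 1, 1, 0, 0], dtype=torch.bool)

mask_buf = valid_mask.to(packed.dtype).unsqueeze(-1).unsqueeze(-1)  # (N, 1, 1)
local_tokens = torch.chunk(packed, tp_size, dim=0)[0]  # rank-0 slice
local_mask = torch.chunk(mask_buf, tp_size, dim=0)[0][..., 0, 0].bool()
assert local_mask.shape == (packed.shape[0] // tp_size,)
assert local_tokens.shape[0] == local_mask.shape[0]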
+ mask_buf = valid_mask.to(packed.dtype).unsqueeze(-1).unsqueeze(-1) + local_mask = scatter_to_sequence_parallel_region(mask_buf)[..., 0, 0].bool() else: local_tokens = packed + local_mask = valid_mask # local_tokens: (local_tokens_count, num_layers, topk) + # local_mask: (local_tokens_count,) # Step 3: Permute to (num_layers, local_tokens_count, topk) layers_topk = local_tokens.permute(1, 0, 2) if _r3_verbose() and _r3_should_log("set_router_replay_data/SCATTER"): + with torch.no_grad(): + n_local_valid = int(local_mask.sum().item()) logger.info( "[R3-STAGE3/set_router_replay_data] POST-SCATTER " - "tp_size=%d local_tokens=%s layers_topk=%s", + "tp_size=%d local_valid=%d/%d local_tokens=%s layers_topk=%s", tp_size, + n_local_valid, + local_mask.numel(), _r3_tensor_sig("local_tokens", local_tokens), _r3_tensor_sig("layers_topk", layers_topk), ) @@ -580,7 +617,7 @@ def set_router_replay_data( router_offset += 1 continue slab = layers_topk[idx].to(torch.int64) - router.set_target_indices(slab) + router.set_target_indices(slab, valid_mask=local_mask) if _r3_verbose() and _r3_should_log("set_router_replay_data/DISPATCH"): dispatched.append( ( From 96cbdcf0e7ecdd71fb69a948ea9789d71d5e0150 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sat, 2 May 2026 03:34:14 +0800 Subject: [PATCH 089/112] fix(router_replay): zero row fix --- areal/engine/router_replay_utils.py | 42 ++++++++++++++++++++++++----- 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py index 73f2eb7c9d..84e39d5b69 100644 --- a/areal/engine/router_replay_utils.py +++ b/areal/engine/router_replay_utils.py @@ -511,19 +511,48 @@ def set_router_replay_data( for i in range(bs_re, n_seqs_in_cu): aligned_offset += aligned_lens[i] + # ---------------------------------------------------------------- + # strike "structurally all-zero" rows from the + # validity mask even when they fall inside ``[0, seq_len)``. + # + # Motivation: SGLang's async rollout does NOT record routing for the + # last generated token of each sequence (the EOS / boundary token) + # because its routing metadata is finalised AFTER the forward pass + # that produces it. That leaves exactly one all-zero tail row per + # sequence inside the "valid" span (evidence: every rollout EXIT + # log shows ``tail_real_row_all_zero=True zero_rows_total=1/N``). + # Plan A already masks the per-seq TP alignment slack and batch + # padding sequences, but those tail rows slip through: they are + # recorded positions whose routing happens to be [0,...,0] for + # every one of the ``L * K`` slots. + # + # Because Moonlight-16B has 27 MoE layers with top-6 routing to 64 + # experts and ``torch.topk`` returns distinct indices, a real token + # producing [0,0,0,0,0,0] across all 27 layers has probability + # essentially 0. It is safe -- and correct -- to treat every such + # row as a recording gap and fall back to the LIVE router top-k + # during replay, exactly like padding rows. This keeps the + # forward/backward spliced indices consistent (both branches see + # the same ``target_valid_mask``) and restores the low + # ``mean_abs_diff`` profile on every micro-batch. 
+ # ---------------------------------------------------------------- + with torch.no_grad(): + row_all_zero = ( + (packed == 0).reshape(packed.shape[0], -1).all(dim=-1) + ) + n_strike = int((valid_mask & row_all_zero).sum().item()) + valid_mask = valid_mask & (~row_all_zero) + if _r3_verbose() and _r3_should_log("set_router_replay_data/PACKED"): with torch.no_grad(): # Count global all-zero rows across ALL layers AND topk slots. - per_row_zero = ( - (packed == 0).reshape(packed.shape[0], -1).all(dim=-1) - ) - zrows = int(per_row_zero.sum().item()) - total_rows = int(per_row_zero.numel()) + zrows = int(row_all_zero.sum().item()) + total_rows = int(row_all_zero.numel()) n_valid = int(valid_mask.sum().item()) logger.info( "[R3-STAGE3/set_router_replay_data] PACKED " "packed=(total_aligned=%d, L=%d, K=%d) global_zero_rows=%d/%d " - "(%.2f%%) valid_rows=%d/%d (%.2f%%) | %s", + "(%.2f%%) valid_rows=%d/%d (%.2f%%) struck_tail_rows=%d | %s", packed.shape[0], packed.shape[1], packed.shape[2], @@ -533,6 +562,7 @@ def set_router_replay_data( n_valid, total_rows, 100.0 * n_valid / max(total_rows, 1), + n_strike, _r3_tensor_sig("packed", packed), ) From bf5345f93f61e710c790d666a3223ed11d91c7e8 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sat, 2 May 2026 12:13:17 +0800 Subject: [PATCH 090/112] feat(r3): add deep log --- areal/engine/megatron_engine_r3_patch.py | 87 ++++++++- areal/engine/router_replay_patch.py | 126 +++++++++++- areal/engine/router_replay_utils.py | 236 ++++++++++++++++++++++- areal/trainer/ppo/actor.py | 166 ++++++++++++++-- areal/trainer/ppo/actor_r3_patch.py | 31 ++- 5 files changed, 610 insertions(+), 36 deletions(-) diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py index 6c2afadb04..533b9d9956 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -295,6 +295,10 @@ def _split_routed_experts_for_mbs( ) try: from areal.engine.router_replay_utils import ( + _r3_hash64, + _r3_per_sample_hashes, + _r3_per_sample_nnz, + _r3_per_sample_seq_real_len, _r3_pp_tp_info, _r3_should_log, _r3_tensor_sig, @@ -302,20 +306,41 @@ def _split_routed_experts_for_mbs( ) if _r3_verbose() and _r3_should_log("_split_routed_experts_for_mbs"): + pre_hash = _r3_per_sample_hashes(routed_experts, max_rows=32) + post_hash = _r3_per_sample_hashes(reordered, max_rows=32) + per_mb_hashes = [ + [hex(h) for h in _r3_per_sample_hashes(r, max_rows=16)] + for r in result + ] + per_mb_nnz = [_r3_per_sample_nnz(r, max_rows=16) for r in result] + per_mb_real = [_r3_per_sample_seq_real_len(r, max_rows=16) for r in result] logger.info( "[R3-STAGE3/_split_routed_experts_for_mbs] %s " - "input_shape=%s n_mbs=%d forward_indices=%s " - "per_mb_shapes=%s | %s", + "input_shape=%s input_hash=%s n_mbs=%d " + "forward_indices=%s per_mb_shapes=%s per_mb_hashes=%s " + "pre_reorder_per_sample_hash[:16]=%s " + "post_reorder_per_sample_hash[:16]=%s " + "per_mb_per_sample_hash=%s per_mb_per_sample_nnz=%s " + "per_mb_per_sample_real_len=%s | %s", _r3_pp_tp_info(), tuple(routed_experts.shape), + hex(_r3_hash64(routed_experts)), n_mbs, "None" if forward_indices is None - else f"len={len(forward_indices)} first16={forward_indices[:16].tolist() if hasattr(forward_indices,'tolist') else list(forward_indices)[:16]}", + else f"len={len(forward_indices)} first32={forward_indices[:32].tolist() if hasattr(forward_indices,'tolist') else list(forward_indices)[:32]}", [tuple(r.shape) for r in result], + [hex(_r3_hash64(r)) for r in result], + [hex(h) for h 
in pre_hash[:16]], + [hex(h) for h in post_hash[:16]], + per_mb_hashes, + per_mb_nnz, + per_mb_real, _r3_tensor_sig("routed_experts", routed_experts, max_sample=4), ) except Exception: - pass + logger.exception( + "[R3-STAGE3/_split_routed_experts_for_mbs] trace log failed" + ) return result @@ -409,12 +434,43 @@ def _r3_forward_backward_batch( # ------------------------------------------------------------------ routed_experts_batch = None _from_side_channel = False + _consumed_trace_id = getattr(self, "_r3_active_trace_id", None) # Strategy A: Side-channel (preferred path) if hasattr(self, '_r3_pending_routed_experts') and self._r3_pending_routed_experts is not None: routed_experts_batch = self._r3_pending_routed_experts self._r3_pending_routed_experts = None # Consume it _from_side_channel = True + try: + from areal.engine.router_replay_utils import ( + _r3_hash64, + _r3_per_sample_hashes, + _r3_per_sample_nnz, + _r3_per_sample_seq_real_len, + _r3_pp_tp_info, + _r3_verbose, + ) + if _r3_verbose(): + logger.info( + "[R3-STAGE3/_r3_forward_backward_batch] " + "SIDE_CHANNEL_CONSUME trace_id=%s %s forward_only=%s " + "shape=%s hash=%s per_sample_hash[:16]=%s " + "per_sample_nnz[:16]=%s per_sample_real_len[:16]=%s", + _consumed_trace_id, + _r3_pp_tp_info(), + forward_only, + routed_experts_batch.shape, + hex(_r3_hash64(routed_experts_batch)), + [hex(h) for h in _r3_per_sample_hashes( + routed_experts_batch, max_rows=16)], + _r3_per_sample_nnz(routed_experts_batch, max_rows=16), + _r3_per_sample_seq_real_len(routed_experts_batch, max_rows=16), + ) + except Exception: + logger.exception( + "[R3-STAGE3/_r3_forward_backward_batch] " + "SIDE_CHANNEL_CONSUME trace log failed" + ) logger.debug( "[R3] Retrieved routed_experts from engine side-channel: shape=%s.", routed_experts_batch.shape, @@ -594,14 +650,35 @@ def __next__(self): if _r3_verbose() and _r3_should_log( "_R3MicroBatchIterator.pre_align" ): + from areal.engine.router_replay_utils import ( + _r3_hash64, + _r3_per_sample_hashes, + _r3_per_sample_nnz, + _r3_per_sample_seq_real_len, + _r3_current_trace_id, + ) logger.info( "[R3-STAGE3/_R3MicroBatchIterator] PRE-ALIGN " - "mb_idx=%d %s orig_cu_src=%s max_seqlen=%d " + "mb_idx=%d trace_id=%d %s orig_cu_src=%s " + "max_seqlen=%d re_shape=%s re_hash=%s " + "per_sample_hash[:16]=%s per_sample_nnz[:16]=%s " + "per_sample_real_len[:16]=%s " + "orig_cu_diff[:16]=%s padded_cu_diff[:16]=%s " "| %s | %s | %s", idx, + _r3_current_trace_id(), _r3_pp_tp_info(), orig_cu_src, max_seqlen, + tuple(re.shape), + hex(_r3_hash64(re)), + [hex(h) for h in _r3_per_sample_hashes(re, max_rows=16)], + _r3_per_sample_nnz(re, max_rows=16), + _r3_per_sample_seq_real_len(re, max_rows=16), + (orig_cu[1:] - orig_cu[:-1]).long().cpu().tolist()[:16] + if hasattr(orig_cu, "cpu") else "N/A", + (cu_seqlens[1:] - cu_seqlens[:-1]).long().cpu().tolist()[:16] + if hasattr(cu_seqlens, "cpu") else "N/A", _r3_tensor_sig("re", re, max_sample=4), _r3_tensor_sig("orig_cu", orig_cu), _r3_tensor_sig("padded_cu", cu_seqlens), diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py index 0265f63fed..d6873c76da 100644 --- a/areal/engine/router_replay_patch.py +++ b/areal/engine/router_replay_patch.py @@ -187,25 +187,38 @@ def set_target_indices( # increment. 
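# --- Illustrative sketch (not part of the patch): a call-count log gate of
# the kind _r3_should_log is assumed to implement -- emit the first few hits
# per site, then thin out, so the steady-state cost is one dict lookup plus
# an integer increment. Names and thresholds here are invented.
_CALL_COUNTS: dict[str, int] = {}

def should_log(site: str, first: int = 8, every: int = 512) -> bool:
    n = _CALL_COUNTS.get(site, 0) + 1
    _CALL_COUNTS[site] = n
    return n <= first or n % every == 0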
try: from areal.engine.router_replay_utils import ( + _r3_current_trace_id, + _r3_hash64, _r3_pp_tp_info, _r3_should_log, _r3_tensor_sig, + _r3_verbose, _r3_zero_row_stats, ) - if _r3_should_log("RouterReplay.set_target_indices"): + if _r3_verbose() and _r3_should_log("RouterReplay.set_target_indices"): # instance index in the class-level list tells us which # MoE layer this replay slot refers to try: inst_idx = RouterReplay.router_instances.index(self) except ValueError: inst_idx = -1 + _slab_hash = hex(_r3_hash64(topk_indices)) + _mask_hash = ( + hex(_r3_hash64(valid_mask.to(torch.int32))) + if valid_mask is not None else "None" + ) logger.info( - "[R3-STAGE3b/set_target_indices] inst#%d %s %s | %s | " - "backward_queue_len=%d", + "[R3-STAGE3b/set_target_indices] trace_id=%d inst#%d %s %s " + "slab_shape=%s slab_hash=%s mask_hash=%s " + "| %s | backward_queue_len=%d (post-push)", + _r3_current_trace_id(), inst_idx, _r3_pp_tp_info(), _r3_zero_row_stats(topk_indices), + tuple(topk_indices.shape), + _slab_hash, + _mask_hash, _r3_tensor_sig("topk_indices", topk_indices), len(self.replay_backward_list), ) @@ -369,8 +382,36 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): top_indices = top_indices.to(scores.device) # splice padded rows with the LIVE router top-k so # that TP-alignment / batch padding slack (which was recorded as - # all-zeros) does not force those rows to expert 0. + # all-zeros) does not force those rows to expert 0. valid_mask = getattr(router_replay, "target_valid_mask", None) + try: + from areal.engine.router_replay_utils import ( + _r3_current_trace_id as _tid, + _r3_hash64 as _h64, + _r3_should_log as _sl, + _r3_verbose as _v, + ) + if _v() and _sl("REPLAY_FORWARD/consume"): + try: + _inst_idx = RouterReplay.router_instances.index(router_replay) + except ValueError: + _inst_idx = -1 + logger.info( + "[R3-STAGE4/REPLAY_FORWARD/consume] trace_id=%d inst#%d " + "scores_shape=%s target_shape=%s shape_match=%s " + "target_hash=%s mask_hash=%s backward_queue_len=%d", + _tid(), + _inst_idx, + tuple(scores.shape), + tuple(top_indices.shape), + top_indices.shape[0] == scores.shape[0], + hex(_h64(top_indices)), + "None" if valid_mask is None + else hex(_h64(valid_mask.to(torch.int32))), + len(router_replay.replay_backward_list), + ) + except Exception: + pass if valid_mask is not None and valid_mask.shape[0] == top_indices.shape[0]: _, live_top = _compute_topk( scores, topk, num_groups=num_groups, group_topk=group_topk @@ -402,6 +443,10 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): ) return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) # Use the last recorded indices for backward replay + _bw_queue_len_before = len(router_replay.replay_backward_list) + _bw_mask_queue_len_before = len( + getattr(router_replay, "replay_backward_mask_list", []) or [] + ) top_indices = router_replay.replay_backward_list.pop(0) top_indices = top_indices.to(scores.device) # pop the matching per-row validity mask (if any) @@ -414,6 +459,79 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): bw_valid_mask = bw_mask_list.pop(0) else: bw_valid_mask = None + # ---- R3 deep-trace: log backward pop order + hashes ---- + try: + from areal.engine.router_replay_utils import ( + _r3_current_trace_id as _tid, + _r3_hash64 as _h64, + _r3_should_log as _sl, + _r3_verbose as _v, + ) + + if _v() and _sl("REPLAY_BACKWARD/consume"): + try: + _inst_idx = RouterReplay.router_instances.index(router_replay) + except ValueError: + 
_inst_idx = -1 + _popped_slab_hash = hex(_h64(top_indices)) + _popped_mask_hash = ( + "None" + if bw_valid_mask is None + else hex(_h64(bw_valid_mask.to(torch.int32))) + ) + _target_hash = ( + "None" + if router_replay.target_topk_idx is None + else hex(_h64(router_replay.target_topk_idx)) + ) + _divergence = ( + "None" + if router_replay.target_topk_idx is None + else ( + "MATCH" + if ( + router_replay.target_topk_idx.shape + == top_indices.shape + and hex( + _h64( + router_replay.target_topk_idx.to( + top_indices.device + ) + ) + ) + == _popped_slab_hash + ) + else "DIVERGE_vs_FWD_TARGET" + ) + ) + logger.info( + "[R3-STAGE4/REPLAY_BACKWARD/consume] trace_id=%d inst#%d " + "scores_shape=%s popped_shape=%s shape_match_scores=%s " + "popped_slab_hash=%s popped_mask_hash=%s " + "current_target_hash=%s divergence=%s " + "queue_len_before=%d queue_len_after=%d " + "mask_queue_len_before=%d mask_queue_len_after=%d", + _tid(), + _inst_idx, + tuple(scores.shape), + tuple(top_indices.shape), + top_indices.shape[0] == scores.shape[0], + _popped_slab_hash, + _popped_mask_hash, + _target_hash, + _divergence, + _bw_queue_len_before, + len(router_replay.replay_backward_list), + _bw_mask_queue_len_before, + len( + getattr(router_replay, "replay_backward_mask_list", []) + or [] + ), + ) + except Exception: + logger.exception( + "[R3-STAGE4/REPLAY_BACKWARD/consume] trace log failed" + ) if ( bw_valid_mask is not None and bw_valid_mask.shape[0] == top_indices.shape[0] diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py index 84e39d5b69..c991d5df3a 100644 --- a/areal/engine/router_replay_utils.py +++ b/areal/engine/router_replay_utils.py @@ -188,6 +188,162 @@ def _r3_pp_tp_info(tf_config=None, vp_rank=None) -> str: return "pp=?/? tp=?/?" +# =================================================================== +# Root-cause hunting helpers (v2 — per-sample & per-mb fingerprints) +# =================================================================== +# We need to pinpoint whether the R3 replay indices that reach the router +# are the SAME bytes as those the rollout engine produced for the SAME +# sample. The cheapest reliable way to do this is a per-sample 64-bit +# fold-hash of the int32 tensor bytes. The hash is stable across device +# (we move to CPU once), preserves per-sample order, and survives +# reordering (each sample is hashed independently — we can then check +# any permutation via the multiset of hashes). +# +# We also expose a monotonically increasing trace-id that the actor +# increments every time it sets ``engine._r3_pending_routed_experts`` +# so each end-to-end replay can be correlated across STAGE2 → STAGE3 → +# STAGE4 log lines. +# =================================================================== + + +# Global monotonically increasing trace-id. Incremented at the side-channel +# SET site (actor._compute_logp / actor._ppo_update). Read back at the +# CONSUMPTION site in ``_r3_forward_backward_batch``. Exported via an +# env-independent module-level function so *all* stages print the same id. +_R3_TRACE_ID: int = 0 + + +def _r3_next_trace_id() -> int: + """Reserve & return a new trace-id. + + A trace-id identifies one SIDE_CHANNEL-SET -> CONSUME -> REPLAY cycle. + """ + global _R3_TRACE_ID + _R3_TRACE_ID += 1 + return _R3_TRACE_ID + + +def _r3_current_trace_id() -> int: + return _R3_TRACE_ID + + +def _r3_hash64(t) -> int: + """Return a stable 64-bit hash of a tensor/ndarray's int32 bytes. 
+
+    For a ``(bs, seqlen, L, K)`` routed_experts tensor this is cheap
+    (one CPU copy) and deterministic regardless of dtype conversion.
+    Returns 0 for ``None``.
+    """
+    if t is None:
+        return 0
+    try:
+        if isinstance(t, torch.Tensor):
+            tc = t.detach()
+            if tc.device.type != "cpu":
+                tc = tc.to("cpu", non_blocking=False)
+            arr = tc.to(torch.int32).contiguous().numpy()
+        else:
+            import numpy as np
+
+            arr = (t if isinstance(t, np.ndarray) else np.asarray(t)).astype("int32", copy=False)
+        import hashlib
+
+        return int.from_bytes(
+            hashlib.blake2b(arr.tobytes(), digest_size=8).digest(),
+            "big",
+            signed=False,
+        )
+    except Exception:
+        return -1
+
+
+def _r3_per_sample_hashes(t, max_rows: int = 64) -> list[int]:
+    """Return per-sample 64-bit hashes.
+
+    For a 4D ``(bs, seqlen, L, K)`` tensor, returns one hash per sample
+    (dim-0). For 3D packed ``(total_aligned, L, K)`` returns one hash
+    per row (capped at ``max_rows`` to keep log size sane).
+    """
+    if t is None:
+        return []
+    try:
+        if isinstance(t, torch.Tensor):
+            tc = t.detach()
+            if tc.device.type != "cpu":
+                tc = tc.to("cpu", non_blocking=False)
+            arr = tc.to(torch.int32).contiguous().numpy()
+        else:
+            import numpy as np
+
+            arr = (t if isinstance(t, np.ndarray) else np.asarray(t)).astype("int32", copy=False)
+        import hashlib
+
+        out = []
+        for i in range(min(arr.shape[0], max_rows)):
+            b = arr[i].tobytes()
+            out.append(
+                int.from_bytes(
+                    hashlib.blake2b(b, digest_size=8).digest(),
+                    "big",
+                    signed=False,
+                )
+            )
+        return out
+    except Exception:
+        return [-1]
+
+
+def _r3_per_sample_nnz(t, max_rows: int = 64) -> list[int]:
+    """Return per-sample non-zero counts (rows where any expert id != 0)."""
+    if t is None:
+        return []
+    try:
+        if isinstance(t, torch.Tensor):
+            tc = t.detach()
+            if tc.device.type != "cpu":
+                tc = tc.to("cpu", non_blocking=False)
+            arr = tc.to(torch.int32).contiguous().numpy()
+        else:
+            import numpy as np
+
+            arr = (t if isinstance(t, np.ndarray) else np.asarray(t)).astype("int32", copy=False)
+        out = []
+        for i in range(min(arr.shape[0], max_rows)):
+            out.append(int((arr[i] != 0).any(axis=-1).sum()))
+        return out
+    except Exception:
+        return [-1]
+
+
+def _r3_per_sample_seq_real_len(t, max_rows: int = 64) -> list[int]:
+    """Return per-sample "real-looking" length = index of last non-all-zero row + 1.
+
+    Useful for verifying that the routed_experts tensor is right-padded
+    as expected: the real length should equal the sample's attention
+    mask sum (= cu_seqlens diff). If it doesn't, alignment is off.
+ """ + if t is None: + return [] + try: + if isinstance(t, torch.Tensor): + tc = t.detach() + if tc.device.type != "cpu": + tc = tc.to("cpu", non_blocking=False) + arr = tc.to(torch.int32).contiguous().numpy() + else: + import numpy as np + + arr = (t if isinstance(t, np.ndarray) else np.asarray(t)).astype("int32", copy=False) + out = [] + for i in range(min(arr.shape[0], max_rows)): + row_any = (arr[i] != 0).any(axis=(-1, -2)) if arr[i].ndim >= 2 else (arr[i] != 0) + nz = row_any.nonzero()[0] + out.append(int(nz[-1]) + 1 if len(nz) else 0) + return out + except Exception: + return [-1] + + # =================================================================== # Layer computation helpers # =================================================================== @@ -441,12 +597,13 @@ def set_router_replay_data( if _r3_verbose() and _r3_should_log("set_router_replay_data/ENTER"): logger.info( - "[R3-STAGE3/set_router_replay_data] ENTER call#%d %s " + "[R3-STAGE3/set_router_replay_data] ENTER call#%d trace_id=%d %s " "layers_topk_idx=(bs=%d, max_seq=%d, L=%d, K=%d) dtype=%s " "n_cu_entries=%d n_seqs_in_cu=%d seq_align_to=%d " - "seq_lens[:8]=%s aligned_lens[:8]=%s total_aligned=%d " + "seq_lens[:16]=%s aligned_lens[:16]=%s total_aligned=%d " "vp_rank=%s | %s", _r3_call_count("set_router_replay_data/ENTER"), + _r3_current_trace_id(), _r3_pp_tp_info(tf_config, vp_rank), bs_re, layers_topk_idx.shape[1], @@ -456,12 +613,37 @@ def set_router_replay_data( n_cu_entries, n_seqs_in_cu, seq_align_to, - seq_lens[:8], - aligned_lens[:8], + seq_lens[:16], + aligned_lens[:16], total_aligned, vp_rank, _r3_tensor_sig("cu_seqlens", cu_seqlens), ) + # Per-sample fingerprint (hash, nnz, real_len) so we can verify + # the SAME bytes reach here as the actor pushed into the + # side-channel. Any mismatch between hashes implies a + # split/reorder bug somewhere upstream. + if _r3_verbose() and _r3_should_log("set_router_replay_data/PER_SAMPLE"): + try: + _h = _r3_per_sample_hashes(layers_topk_idx, max_rows=32) + _nnz = _r3_per_sample_nnz(layers_topk_idx, max_rows=32) + _rl = _r3_per_sample_seq_real_len(layers_topk_idx, max_rows=32) + logger.info( + "[R3-STAGE3/set_router_replay_data] PER_SAMPLE trace_id=%d %s " + "bs_re=%d n_seqs_in_cu=%d " + "per_sample_hash[:16]=%s per_sample_nnz_rows[:16]=%s " + "per_sample_real_len[:16]=%s cu_seqlens_diff[:16]=%s", + _r3_current_trace_id(), + _r3_pp_tp_info(tf_config, vp_rank), + bs_re, + n_seqs_in_cu, + [hex(h) for h in _h[:16]], + _nnz[:16], + _rl[:16], + seq_lens[:16], + ) + except Exception as e: + logger.warning("[R3-STAGE3/set_router_replay_data] PER_SAMPLE err=%s", e) # Pack routed_experts using cu_seqlens-aligned layout. # layers_topk_idx is left-ALIGNED: real tokens at positions [0, seq_len). @@ -549,10 +731,32 @@ def set_router_replay_data( zrows = int(row_all_zero.sum().item()) total_rows = int(row_all_zero.numel()) n_valid = int(valid_mask.sum().item()) + # Per-sample valid-row count (after strike), lined up with + # aligned_lens so any off-by-one immediately surfaces. 
+ per_sample_valid_after = [] + per_sample_valid_before = [] + _off = 0 + for _i in range(n_seqs_in_cu): + _al = aligned_lens[_i] if _i < len(aligned_lens) else 0 + _seg = valid_mask[_off : _off + _al] + _segz = row_all_zero[_off : _off + _al] + per_sample_valid_after.append(int(_seg.sum().item())) + per_sample_valid_before.append( + int((~_segz[: seq_lens[_i] if _i < len(seq_lens) else 0]).sum().item()) + if _i < len(seq_lens) + else 0 + ) + _off += _al logger.info( - "[R3-STAGE3/set_router_replay_data] PACKED " + "[R3-STAGE3/set_router_replay_data] PACKED trace_id=%d %s " "packed=(total_aligned=%d, L=%d, K=%d) global_zero_rows=%d/%d " - "(%.2f%%) valid_rows=%d/%d (%.2f%%) struck_tail_rows=%d | %s", + "(%.2f%%) valid_rows=%d/%d (%.2f%%) struck_tail_rows=%d " + "per_sample_valid_before_strike[:16]=%s " + "per_sample_valid_after_strike[:16]=%s " + "per_sample_real_len[:16]=%s aligned_lens[:16]=%s " + "packed_hash=%s | %s", + _r3_current_trace_id(), + _r3_pp_tp_info(tf_config, vp_rank), packed.shape[0], packed.shape[1], packed.shape[2], @@ -563,6 +767,11 @@ def set_router_replay_data( total_rows, 100.0 * n_valid / max(total_rows, 1), n_strike, + per_sample_valid_before[:16], + per_sample_valid_after[:16], + seq_lens[:16], + aligned_lens[:16], + hex(_r3_hash64(packed)), _r3_tensor_sig("packed", packed), ) @@ -592,13 +801,18 @@ def set_router_replay_data( with torch.no_grad(): n_local_valid = int(local_mask.sum().item()) logger.info( - "[R3-STAGE3/set_router_replay_data] POST-SCATTER " - "tp_size=%d local_valid=%d/%d local_tokens=%s layers_topk=%s", + "[R3-STAGE3/set_router_replay_data] POST-SCATTER trace_id=%d %s " + "tp_size=%d local_valid=%d/%d local_tokens=%s layers_topk=%s " + "local_tokens_hash=%s local_mask_hash=%s", + _r3_current_trace_id(), + _r3_pp_tp_info(tf_config, vp_rank), tp_size, n_local_valid, local_mask.numel(), _r3_tensor_sig("local_tokens", local_tokens), _r3_tensor_sig("layers_topk", layers_topk), + hex(_r3_hash64(local_tokens)), + hex(_r3_hash64(local_mask.to(torch.int32))), ) # Step 4: Distribute to RouterReplay instances for local PP layers @@ -655,6 +869,7 @@ def set_router_replay_data( idx, _r3_zero_row_stats(slab), _r3_tensor_sig(f"target[L={layer_idx}]", slab), + hex(_r3_hash64(slab)), ) ) router_offset += 1 @@ -675,15 +890,16 @@ def set_router_replay_data( # the rest summarised. head = dispatched[:_R3_ROUTER_LAYER_LIMIT] logger.info( - "[R3-STAGE3/set_router_replay_data] DISPATCH %s " + "[R3-STAGE3/set_router_replay_data] DISPATCH trace_id=%d %s " "router_offset=%d len(router_list)=%d index_by_layer=%s " "first_layers=%s ... (total dispatched=%d)", + _r3_current_trace_id(), _r3_pp_tp_info(tf_config, vp_rank), router_offset, len(router_list), index_by_layer, [ - (lidx, j, zr, sig) for lidx, j, zr, sig in head + (lidx, j, zr, sig, h) for lidx, j, zr, sig, h in head ], len(dispatched), ) diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index 894203526c..153bfd8fdf 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -175,6 +175,49 @@ def _compute_logp(self, data: dict[str, Any]) -> torch.Tensor | None: # call internally; the R3 engine patch will split routed_experts per # micro-batch and consume the side-channel (setting it back to None). 
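# --- Illustrative sketch (not part of the patch): the one-shot side-channel
# contract used here. The producer sets the attribute; the wrapped
# forward_backward_batch reads it exactly once and clears it, so a stale
# payload can never leak into an unrelated later forward pass.
class _ToyEngine:
    _r3_pending_routed_experts = None

    def consume(self):
        payload = self._r3_pending_routed_experts
        self._r3_pending_routed_experts = None  # consume exactly once
        return payload

eng = _ToyEngine()
eng._r3_pending_routed_experts = "routing-tensor"
assert eng.consume() == "routing-tensor"
assert eng.consume() is None  # a second read sees nothing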
self.engine._r3_pending_routed_experts = _r3_routed_experts + try: + from areal.engine.router_replay_utils import ( + _r3_hash64, + _r3_next_trace_id, + _r3_per_sample_hashes, + _r3_per_sample_nnz, + _r3_per_sample_seq_real_len, + _r3_pp_tp_info, + _r3_tensor_sig, + _r3_verbose, + ) + + _trace_id = _r3_next_trace_id() + self.engine._r3_active_trace_id = _trace_id + if _r3_verbose(): + logger.info( + "[R3-STAGE2/actor._compute_logp] SIDE_CHANNEL_SET " + "trace_id=%d %s bs=%d seqlen=%s L=%s K=%s " + "hash=%s per_sample_hash[:16]=%s " + "per_sample_nnz[:16]=%s per_sample_real_len[:16]=%s " + "attn_sum[:16]=%s | %s", + _trace_id, + _r3_pp_tp_info(), + _r3_routed_experts.shape[0], + _r3_routed_experts.shape[1], + _r3_routed_experts.shape[2] + if _r3_routed_experts.ndim >= 3 else None, + _r3_routed_experts.shape[3] + if _r3_routed_experts.ndim >= 4 else None, + hex(_r3_hash64(_r3_routed_experts)), + [hex(h) for h in _r3_per_sample_hashes( + _r3_routed_experts, max_rows=16)], + _r3_per_sample_nnz(_r3_routed_experts, max_rows=16), + _r3_per_sample_seq_real_len(_r3_routed_experts, max_rows=16), + ( + data["attention_mask"].sum(dim=-1).long().cpu().tolist()[:16] + if "attention_mask" in data + else "N/A" + ), + _r3_tensor_sig("routed_experts", _r3_routed_experts), + ) + except Exception: + logger.exception("[R3-STAGE2/actor._compute_logp] side-channel trace log failed") train_logp = self.engine.forward( input_=data, aggregate_fn=lambda xs: torch.cat(xs, dim=-1), @@ -254,6 +297,9 @@ def _log_r3_effectiveness_stats( try: from areal.engine.router_replay_utils import ( + _r3_current_trace_id, + _r3_hash64, + _r3_per_sample_hashes, _r3_should_log, _r3_tensor_sig, _r3_verbose, @@ -283,10 +329,11 @@ def _log_r3_effectiveness_stats( else: _mean_abs = _max_abs = _p99 = _mean_k3 = _frac_tau2 = _frac_tau5 = 0.0 logger.info( - "[R3-STAGE2/r3_effectiveness] r3_enabled=%s " + "[R3-STAGE2/r3_effectiveness] trace_id=%d r3_enabled=%s " "n_valid_tokens=%d mean_abs_diff=%.6f max_abs_diff=%.6f " "p99_abs_diff=%.6f mean_k3_kl=%.6f frac_tau2=%.4f " "frac_tau5=%.4f | %s | %s", + _r3_current_trace_id(), r3_enabled, n_valid, _mean_abs, @@ -298,8 +345,74 @@ def _log_r3_effectiveness_stats( _r3_tensor_sig("train_logp", train_logp), _r3_tensor_sig("rollout_logp_rolled", rollout_logp_f), ) + # ---- R3 per-sample breakdown: identify catastrophic samples ---- + with torch.no_grad(): + bs = shifted_mask.shape[0] + max_rows = min(bs, 64) + per_sample = [] + for i in range(max_rows): + m_i = shifted_mask[i] + n_i = int(m_i.sum().item()) + if n_i == 0: + per_sample.append( + { + "i": i, + "n": 0, + "mean_abs": 0.0, + "max_abs": 0.0, + "tau2_cnt": 0, + "tau5_cnt": 0, + "k3_mean": 0.0, + } + ) + continue + row_abs = abs_diff[i][m_i] + row_k3 = k3_kl[i][m_i] + per_sample.append( + { + "i": i, + "n": n_i, + "mean_abs": float(row_abs.mean().item()), + "max_abs": float(row_abs.max().item()), + "tau2_cnt": int( + (row_abs > torch.log(torch.tensor(2.0))).sum().item() + ), + "tau5_cnt": int( + (row_abs > torch.log(torch.tensor(5.0))).sum().item() + ), + "k3_mean": float(row_k3.mean().item()), + } + ) + # Routed-experts per-sample hashes (side-channel payload) + pending = getattr(self.engine, "_r3_pending_routed_experts", None) + re_hashes = ( + [hex(h) for h in _r3_per_sample_hashes(pending, max_rows=max_rows)] + if pending is not None + else [] + ) + re_full_hash = ( + hex(_r3_hash64(pending)) if pending is not None else "None" + ) + # Bad sample ranking (top-K by mean_abs) + sorted_bad = sorted( + per_sample, key=lambda x: x["mean_abs"], 
reverse=True + )[:8] + logger.info( + "[R3-STAGE2/r3_effectiveness/per_sample] trace_id=%d " + "r3_enabled=%s batch_size=%d routed_experts_full_hash=%s " + "per_sample=%s top_bad_samples=%s per_sample_routed_hash=%s", + _r3_current_trace_id(), + r3_enabled, + bs, + re_full_hash, + per_sample, + sorted_bad, + re_hashes, + ) except Exception: - pass + logger.exception( + "[R3-STAGE2/r3_effectiveness] per-sample trace log failed" + ) with stats_tracker.scope("compute_logp"): with stats_tracker.scope("r3"): @@ -594,25 +707,49 @@ def _ppo_update(self, data: dict[str, Any]) -> None: ) try: from areal.engine.router_replay_utils import ( + _r3_hash64, + _r3_next_trace_id, + _r3_per_sample_hashes, + _r3_per_sample_nnz, + _r3_per_sample_seq_real_len, + _r3_pp_tp_info, _r3_should_log, _r3_tensor_sig, _r3_verbose, ) - if _r3_verbose() and _r3_should_log( - "actor._ppo_update/side_channel_set" - ): - _slice = ( - _r3_split[i] - if i < len(_r3_split) - else None - ) + _trace_id = _r3_next_trace_id() + self.engine._r3_active_trace_id = _trace_id + _slice = ( + _r3_split[i] + if i < len(_r3_split) + else None + ) + if _r3_verbose(): logger.info( "[R3-STAGE2/actor._ppo_update] " - "side_channel_set mb=%d current_version=%s " - "| %s", + "SIDE_CHANNEL_SET trace_id=%d mb=%d " + "current_version=%s %s " + "slice_shape=%s hash=%s " + "per_sample_hash[:16]=%s " + "per_sample_nnz[:16]=%s " + "per_sample_real_len[:16]=%s " + "mb_attn_sum[:16]=%s | %s", + _trace_id, i, current_version, + _r3_pp_tp_info(), + None if _slice is None else tuple(_slice.shape), + hex(_r3_hash64(_slice)), + [hex(h) for h in _r3_per_sample_hashes( + _slice, max_rows=16)], + _r3_per_sample_nnz(_slice, max_rows=16), + _r3_per_sample_seq_real_len(_slice, max_rows=16), + ( + mb["attention_mask"].sum(dim=-1).long().cpu().tolist()[:16] + if isinstance(mb, dict) and "attention_mask" in mb + else "N/A" + ), _r3_tensor_sig( "pending_routed_experts", _slice, @@ -620,7 +757,10 @@ def _ppo_update(self, data: dict[str, Any]) -> None: ), ) except Exception: - pass + logger.exception( + "[R3-STAGE2/actor._ppo_update] " + "SIDE_CHANNEL_SET trace log failed", + ) else: logger.warning( "[R3] routed_experts available but engine._r3_enabled " diff --git a/areal/trainer/ppo/actor_r3_patch.py b/areal/trainer/ppo/actor_r3_patch.py index e553383d9e..9cc562b156 100644 --- a/areal/trainer/ppo/actor_r3_patch.py +++ b/areal/trainer/ppo/actor_r3_patch.py @@ -93,6 +93,9 @@ def split_routed_experts_for_minibatches( ) try: from areal.engine.router_replay_utils import ( + _r3_hash64, + _r3_per_sample_hashes, + _r3_per_sample_nnz, _r3_pp_tp_info, _r3_should_log, _r3_tensor_sig, @@ -102,24 +105,44 @@ def split_routed_experts_for_minibatches( if _r3_verbose() and _r3_should_log( "split_routed_experts_for_minibatches" ): + # Pre-reorder per-sample hashes (what we *started* with) and + # post-reorder per-sample hashes (what each mini-batch gets). 
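# --- Illustrative sketch (not part of the patch): how per-sample
# fingerprints catch a split/reorder bug. Each sample is hashed
# independently, so any permutation plus split must preserve the multiset
# of hashes exactly.
import hashlib
import numpy as np

def sample_hash(a: np.ndarray) -> int:
    digest = hashlib.blake2b(a.astype("int32").tobytes(), digest_size=8)
    return int.from_bytes(digest.digest(), "big")

batch = np.random.randint(0, 64, size=(6, 4, 2, 2))  # (bs, seq, L, K)
order = np.random.permutation(6)
mbs = np.split(batch[order], 2)  # reorder, then two mini-batches

before = sorted(sample_hash(s) for s in batch)
after = sorted(sample_hash(s) for mb in mbs for s in mb)
assert before == after  # any mismatch means a sample was corrupted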
+ pre_hash = _r3_per_sample_hashes(routed_experts, max_rows=32) + post_hash = _r3_per_sample_hashes(reordered, max_rows=32) + mb_hashes = [ + [hex(h) for h in _r3_per_sample_hashes(r, max_rows=16)] + for r in result + ] + mb_nnz = [_r3_per_sample_nnz(r, max_rows=16) for r in result] logger.info( "[R3-STAGE2/split_routed_experts_for_minibatches] %s " - "input_shape=%s n_mbs=%d forward_indices=%s " - "per_mb_shapes=%s | %s", + "input_shape=%s input_hash=%s n_mbs=%d forward_indices=%s " + "per_mb_shapes=%s per_mb_hashes=%s " + "pre_reorder_per_sample_hash[:16]=%s " + "post_reorder_per_sample_hash[:16]=%s " + "per_mb_per_sample_hash=%s per_mb_per_sample_nnz=%s | %s", _r3_pp_tp_info(), tuple(routed_experts.shape), + hex(_r3_hash64(routed_experts)), n_mbs, "None" if forward_indices is None else ( f"len={len(forward_indices)} " - f"first16={forward_indices[:16].tolist() if hasattr(forward_indices,'tolist') else list(forward_indices)[:16]}" + f"first32={forward_indices[:32].tolist() if hasattr(forward_indices,'tolist') else list(forward_indices)[:32]}" ), [tuple(r.shape) for r in result], + [hex(_r3_hash64(r)) for r in result], + [hex(h) for h in pre_hash[:16]], + [hex(h) for h in post_hash[:16]], + mb_hashes, + mb_nnz, _r3_tensor_sig("routed_experts", routed_experts, max_sample=4), ) except Exception: - pass + logger.exception( + "[R3-STAGE2/split_routed_experts_for_minibatches] trace log failed" + ) return result From cfd0f889d3b579f9304ef963e0b665a49b173ca6 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sat, 2 May 2026 13:06:48 +0800 Subject: [PATCH 091/112] fix(router_replay): fix case num_sgl_tokens more than real_tokens --- areal/engine/router_replay_utils.py | 34 +++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py index c991d5df3a..cdb4fd92c8 100644 --- a/areal/engine/router_replay_utils.py +++ b/areal/engine/router_replay_utils.py @@ -1048,8 +1048,38 @@ def preprocess_routed_experts_batch( # Build (1, seq_len, num_moe_layers, topk) with RIGHT padding. real_tokens = int(attention_mask.sum().item()) padded = torch.zeros(1, seq_len, num_moe_layers, topk, dtype=torch.int32) - n = min(num_sgl_tokens, real_tokens) - padded[0, :n] = tensor[:n] + if num_sgl_tokens > real_tokens: + # Pathological case (~2.4% of samples in observed runs): + # SGLang returned routing for MORE tokens than the final request + # actually carries (e.g. KV-preempt + retry, multi-turn rollout, + # or an abandoned prefill prefix whose routing was not dropped). + # Taking the HEAD ``tensor[:real_tokens]`` here would bind this + # sample to UNRELATED tokens' expert decisions and cause + # catastrophic router-replay misalignment: per-sample k3_kl jumps + # from ~1e-4 (normal) to ~1.0, producing the "~40% normal + ~60% + # broken" bimodal rollout-vs-train logp divergence. + # + # Safe behavior: disable R3 for THIS sample by leaving ``padded`` + # as all-zeros. The training-side replay path treats all-zero + # rows as "no recorded routing" and falls back to the live router + # (see ``valid_mask`` splicing in ``router_replay_patch.py``), + # which makes this sample behave like an R3-off sample instead + # of a corrupted one. + logger.warning( + "[R3] preprocess_routed_experts_batch: num_sgl_tokens=%d > " + "real_tokens=%d (ratio=%.2f, seq_len=%d). This is the " + "'double-rollout' / preempt-retry path; taking tensor[:real] " + "here would MIS-ALIGN routing to unrelated tokens. 
Disabling "
+            "R3 for this sample (returning all-zero routed_experts so "
+            "replay falls back to live routing).",
+            num_sgl_tokens,
+            real_tokens,
+            num_sgl_tokens / max(real_tokens, 1),
+            seq_len,
+        )
+    else:
+        n = min(num_sgl_tokens, real_tokens)
+        padded[0, :n] = tensor[:n]
 
     if compress_dtype:
         max_val = padded.max().item()

From f7550b9b129cfbd344641d347c2c3a152d1bb9cf Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Sun, 3 May 2026 03:24:10 +0800
Subject: [PATCH 092/112] fix(config): lower max_concurrent_rollouts

---
 .../math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml
index e69433a4df..d67afa4c0e 100644
--- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml
+++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml
@@ -21,7 +21,7 @@ rollout:
   backend: "sglang:d1p1t2"
   experiment_name: ${experiment_name}
   trial_name: ${trial_name}
-  max_concurrent_rollouts: 64
+  max_concurrent_rollouts: 16
   queue_size: null
   consumer_batch_size: ${train_dataset.batch_size}
   max_head_offpolicyness: 0

From 58141694adf8e07eff7b18aa560334a8e5354214 Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Tue, 5 May 2026 01:18:20 +0800
Subject: [PATCH 093/112] refactor(examples/math): switch actor backend to PP=2

---
 .../math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml
index d67afa4c0e..b7002cad45 100644
--- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml
+++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml
@@ -39,7 +39,8 @@ gconfig:
   temperature: 1.0
 
 actor:
-  backend: "megatron:(attn:d1p1t4|ffn:d1p1t1e4)" # ← fallback from PP=2: TP=4/EP=4
+  backend: "megatron:(attn:d1p2t2|ffn:d1p2t1e2)" # ← PP=2, attn TP=2, ffn EP=2
+  # backend: "megatron:(attn:d1p1t4|ffn:d1p1t1e4)"
   experiment_name: ${experiment_name}
   trial_name: ${trial_name}
   path: /workspace/models/Moonlight-16B-A3B-Instruct

From 3dbb46f3fccdb62dbabc3f3b656a2d9fb7c9723e Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Tue, 5 May 2026 01:35:05 +0800
Subject: [PATCH 094/112] fix(config): fix recompute_num_layers

---
 .../math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml
index b7002cad45..44b9be4f1d 100644
--- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml
+++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml
@@ -87,7 +87,7 @@ actor:
   use_precision_aware_optimizer: true
   recompute_granularity: full
   recompute_method: uniform
-  recompute_num_layers: 9
+  recompute_num_layers: 1
   ddp:
     grad_reduce_in_fp32: false # ← keep layer-by-layer recomputation
   scheduling_spec:

From 7c78380168dd0c55aa9ab2e074fe01fd98398bb0 Mon Sep 17 00:00:00 2001
From: bingyechen
Date: Wed, 6 May 2026 01:59:51 +0800
Subject: [PATCH 095/112] feat(diag): add PP-layout diagnostic logging

---
 areal/engine/megatron_engine_r3_patch.py | 207 ++++++++++++++++++++++-
 areal/engine/router_replay_patch.py      | 160 +++++++++++++++-
 areal/engine/router_replay_utils.py      | 113 ++++++++++++-
 3 files changed, 472 insertions(+), 8 deletions(-)

diff --git a/areal/engine/megatron_engine_r3_patch.py 
b/areal/engine/megatron_engine_r3_patch.py index 533b9d9956..d9c806a907 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -71,6 +71,51 @@ def patch_megatron_engine_for_r3( engine._r3_original_forward_backward_batch = engine.forward_backward_batch engine._r3_pending_routed_experts = None + # ---------- R3 diagnostics: one-shot config snapshot so the NEXT + # run's log unambiguously records the PP layout Megatron-Core + # actually saw (num_layers, vp_size, pp_size, local offset/end, + # router_instance count per PP rank). This answers D2/D3/D5 in a + # single early-startup line without polluting hot paths. + try: + from areal.engine.router_replay_patch import RouterReplay as _RR + from areal.engine.router_replay_utils import ( + _r3_pp_tp_info as _ppi, + _r3_verbose as _v, + get_current_rank_layer_info as _info, + is_moe_layer as _ism, + ) + if _v(): + _tf = engine.tf_config + _li = _info(_tf) + _moe_list = [i for i in range(_li["start"], _li["end"]) if _ism(_tf, i)] + _dense_list = [ + i for i in range(_li["start"], _li["end"]) if not _ism(_tf, i) + ] + logger.info( + "[R3-STAGE0/patch_megatron_engine_for_r3] ENGINE_SNAPSHOT %s " + "tf_config.num_layers=%d pp_size=%d vp_size=%s " + "moe_layer_freq=%s first_k_dense_replace=%s " + "local={start:%d end:%d count:%d} " + "moe_layers_in_range=%s non_moe_layers_in_range=%s " + "total_router_instances=%d " + "inst_creator_ranks=%s", + _ppi(_tf), + _tf.num_layers, + getattr(_tf, "pipeline_model_parallel_size", 1), + getattr(_tf, "virtual_pipeline_model_parallel_size", None), + getattr(_tf, "moe_layer_freq", None), + getattr(_tf, "first_k_dense_replace", None), + _li["start"], _li["end"], _li["count"], + _moe_list, _dense_list, + len(_RR.router_instances), + [getattr(r, "creator_rank", -1) for r in _RR.router_instances], + ) + except Exception: + logger.exception( + "[R3-STAGE0/patch_megatron_engine_for_r3] snapshot log failed" + ) + # -------------------------------------------------------------- + # Bind the wrapped method engine.forward_backward_batch = types.MethodType( _r3_forward_backward_batch, engine @@ -451,20 +496,33 @@ def _r3_forward_backward_batch( _r3_verbose, ) if _r3_verbose(): + # ---------- R3 diagnostics: D6 -- cross-PP batch-hash + # consistency. Print FULL global hash + per-sample hashes + # so PP rank 0 and PP rank 1 at the same trace_id can be + # diff'd offline. If hashes differ, the side-channel + # broadcast/scatter is wrong (root cause); if identical, + # PP-input parity is proved and we can rule out D6. 
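# --- Illustrative sketch (not part of the patch): the offline D6 check this
# log line enables. Given the per-sample hash lists logged by two PP ranks at
# the same trace_id, input parity holds iff they match elementwise; the hash
# values below are invented.
rank0 = ["0x9a1f", "0x03c2", "0x77de"]  # parsed from PP rank 0's log
rank1 = ["0x9a1f", "0x03c2", "0x77de"]  # parsed from PP rank 1's log

assert len(rank0) == len(rank1)
mismatch = [i for i, (a, b) in enumerate(zip(rank0, rank1)) if a != b]
print("D6 parity OK" if not mismatch else f"diverging samples: {mismatch}")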
+ _full_hash = hex(_r3_hash64(routed_experts_batch)) + _all_per_sample = _r3_per_sample_hashes( + routed_experts_batch, max_rows=4096, + ) + _all_per_sample_hex = [hex(h) for h in _all_per_sample] logger.info( "[R3-STAGE3/_r3_forward_backward_batch] " "SIDE_CHANNEL_CONSUME trace_id=%s %s forward_only=%s " "shape=%s hash=%s per_sample_hash[:16]=%s " - "per_sample_nnz[:16]=%s per_sample_real_len[:16]=%s", + "per_sample_nnz[:16]=%s per_sample_real_len[:16]=%s " + "n_samples_total=%d full_per_sample_hash=%s", _consumed_trace_id, _r3_pp_tp_info(), forward_only, routed_experts_batch.shape, - hex(_r3_hash64(routed_experts_batch)), - [hex(h) for h in _r3_per_sample_hashes( - routed_experts_batch, max_rows=16)], + _full_hash, + _all_per_sample_hex[:16], _r3_per_sample_nnz(routed_experts_batch, max_rows=16), _r3_per_sample_seq_real_len(routed_experts_batch, max_rows=16), + len(_all_per_sample), + _all_per_sample_hex, ) except Exception: logger.exception( @@ -698,6 +756,56 @@ def __next__(self): model_config, seq_align_to=_seq_align_to, ) + # ---------- R3 diagnostics: per-mb queue-depth + # snapshot RIGHT AFTER the dispatch finishes. + # Under PP=1 every router has fwd_q==1 here; under + # PP=2 1F1B the depth oscillates 1..PP_size. + try: + from areal.engine.router_replay_patch import ( + RouterReplay as _RR, + ) + from areal.engine.router_replay_utils import ( + _r3_should_log as _sl2, + _r3_verbose as _v2, + ) + if _v2() and _sl2( + "_R3MicroBatchIterator/post_dispatch_queue_audit" + ): + router_list = ( + RouterReplayHelper.get_micro_batch_router_list( + model_config + ) + ) + fwd_qs = [ + len(getattr(r, "replay_backward_list", []) or []) + for r in router_list + ] + push_qs = [ + len( + getattr(r, "replay_push_meta_list", []) or [] + ) + for r in router_list + ] + logger.info( + "[R3-STAGE3/_R3MicroBatchIterator] " + "POST_DISPATCH_QUEUE_AUDIT mb_idx=%d %s " + "n_routers=%d fwd_q_lens=%s push_meta_q_lens=%s " + "max_fwd_q=%d min_fwd_q=%d " + "lens_locked=%s", + idx, + _r3_pp_tp_info(), + len(router_list), + fwd_qs, + push_qs, + max(fwd_qs) if fwd_qs else -1, + min(fwd_qs) if fwd_qs else -1, + fwd_qs == push_qs, + ) + except Exception: + logger.exception( + "[R3-STAGE3/_R3MicroBatchIterator] " + "POST_DISPATCH_QUEUE_AUDIT diag log failed" + ) except Exception: logger.warning( "[R3] Failed to setup replay for micro-batch %d.", @@ -739,9 +847,19 @@ def __iter__(self_inner): # 4. Register a forward hook for REPLAY_FORWARD -> REPLAY_BACKWARD toggle. # ------------------------------------------------------------------ hook_handles = [] + # ---------- R3 diagnostics: D8 -- track which model chunks fire + # the post-forward hook. Under PP=2 + VP, multiple model chunks + # share the fbfunc; if any chunk's hook misses, the action toggle + # is skipped and backward pops see REPLAY_FORWARD, silently + # returning live routing. A mismatch between len(self.model) and + # hook_fire_counts[chunk_id] is a smoking gun. 
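+    # Illustrative expectation (made-up numbers, not from a real run):
+    # PP=2 with VP=2 gives len(self.model) == 2 chunks; with 8
+    # micro-batches each chunk's hook should fire 8 times, i.e.
+    # hook_fire_counts ~= {id(chunk0): 8, id(chunk1): 8}. A chunk stuck
+    # at 0 (or missing from the dict) means its toggle never ran.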
+ _r3_hook_fire_counts: dict[int, int] = {} + _r3_toggle_count = {"n": 0} def _r3_post_forward_hook(module, input, output): """Switch from REPLAY_FORWARD to REPLAY_BACKWARD after model forward.""" + _chunk_id = id(module) + _r3_hook_fire_counts[_chunk_id] = _r3_hook_fire_counts.get(_chunk_id, 0) + 1 if RouterReplayHelper.is_replay_forward_action(model_config): router_list = RouterReplayHelper.get_micro_batch_router_list( model_config @@ -750,18 +868,53 @@ def _r3_post_forward_hook(module, input, output): router.set_router_replay_action( RouterReplayAction.REPLAY_BACKWARD ) + _r3_toggle_count["n"] += 1 if _r3_verbose() and _r3_should_log("_r3_post_forward_hook"): logger.info( "[R3-STAGE3/_r3_post_forward_hook] TOGGLE forward->backward " - "%s n_routers=%d", + "%s n_routers=%d mb_counter=%d chunk_id=%d " + "fire_count_this_chunk=%d total_toggles=%d " + "n_chunks_seen=%d", _r3_pp_tp_info(), len(router_list), + getattr(self, "_r3_mb_counter", -1), + _chunk_id, + _r3_hook_fire_counts[_chunk_id], + _r3_toggle_count["n"], + len(_r3_hook_fire_counts), + ) + else: + # Hook fired but action was not REPLAY_FORWARD -- this is + # expected after the first toggle under 1F1B (subsequent mbs + # see REPLAY_BACKWARD until the iterator flips them back). + # We still log rarely to confirm behavior. + if _r3_verbose() and _r3_should_log( + "_r3_post_forward_hook/no_toggle" + ): + logger.info( + "[R3-STAGE3/_r3_post_forward_hook] NO_TOGGLE " + "(already backward or cleared) %s mb_counter=%d " + "chunk_id=%d fire_count_this_chunk=%d " + "n_chunks_seen=%d", + _r3_pp_tp_info(), + getattr(self, "_r3_mb_counter", -1), + _chunk_id, + _r3_hook_fire_counts[_chunk_id], + len(_r3_hook_fire_counts), ) for model_chunk in self.model: handle = model_chunk.register_forward_hook(_r3_post_forward_hook) hook_handles.append(handle) + # ---------- R3 diagnostics: reset FB-level aggregate counters so + # the end-of-FB summary reflects only this call. Safe to reset + # unconditionally: consumers read the dict inside the FB span. + try: + RouterReplay._r3_fb_stats = {} + except Exception: + pass + try: self._r3_original_forward_backward_batch( mb_list, process_output_fn, forward_only=forward_only @@ -774,6 +927,50 @@ def _r3_post_forward_hook(module, input, output): # class swap done above). The original class was never modified. mb_list.__class__ = _r3_original_mb_list_class + # ---------- R3 diagnostics: END_OF_FB summary. One line per + # forward_backward_batch call aggregating: + # * divergence_v1 (MATCH vs popped-vs-latest-target; under + # PP=2 1F1B, DIVERGE is EXPECTED — logging artifact). + # * divergence_v2 (MATCH vs popped-vs-own-push; under any + # PP layout MATCH is REQUIRED — REAL_MISMATCH is a bug). + # * hook fire counts per model-chunk: if unequal, one chunk + # missed its toggle (D8 smoking gun). + # * final queue residue across all routers (D9 smoking gun). 
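+        # A healthy PP=2 summary would read roughly like this
+        # (illustrative numbers only, not captured from a real log):
+        #   fb_stats={'bw_pop_total': 104, 'divergence_v1_diverge': 56,
+        #             'divergence_v1_match': 48, 'divergence_v2_match': 104}
+        #   hook_fire_counts equal across chunks, residual_max=0
+        # Any 'divergence_v2_real_mismatch' > 0 or residual_max > 0 is a bug.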
+ try: + if _r3_verbose() and _r3_should_log("_r3_forward_backward_batch/EXIT_SUMMARY"): + from areal.engine.router_replay_patch import RouterReplay as _RR + _fwd_q = [ + len(getattr(r, "replay_backward_list", []) or []) + for r in _RR.router_instances + ] + _push_q = [ + len(getattr(r, "replay_push_meta_list", []) or []) + for r in _RR.router_instances + ] + logger.info( + "[R3-STAGE3/_r3_forward_backward_batch] EXIT_SUMMARY %s " + "n_mbs=%d forward_only=%s trace_id=%s " + "fb_stats=%s n_model_chunks=%d hook_fire_counts=%s " + "total_toggles=%d residual_fwd_q=%s residual_push_q=%s " + "residual_max=%d", + _r3_pp_tp_info(), + len(mb_list), + forward_only, + _consumed_trace_id, + dict(_RR._r3_fb_stats), + len(self.model), + dict(_r3_hook_fire_counts), + _r3_toggle_count["n"], + _fwd_q, + _push_q, + max(_fwd_q) if _fwd_q else -1, + ) + except Exception: + logger.exception( + "[R3-STAGE3/_r3_forward_backward_batch] " + "EXIT_SUMMARY diag log failed" + ) + clear_router_replay() self._r3_per_mb_experts = None self._r3_mb_counter = 0 diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py index d6873c76da..001fbfd629 100644 --- a/areal/engine/router_replay_patch.py +++ b/areal/engine/router_replay_patch.py @@ -87,6 +87,15 @@ class RouterReplay: # Set by the engine patch before forward_backward_func. pp_size: int = 1 + # ---------- R3 diagnostics (PP=2 root-cause hunt) ---------- + # Per forward_backward_batch aggregate counters. Keys set by the + # REPLAY_BACKWARD/consume path and the forward-hook; values are reset + # at FB entry by _r3_forward_backward_batch and dumped at FB exit. + # This gives one END-OF-FB summary line to diff PP=1 vs PP=2 runs. + # Always class-level dict (no cross-rank comm — each rank maintains + # its own, which is the correctness unit). + _r3_fb_stats: dict[str, int] = {} + # ------------------------------------------------------------------ # Class-level (static) helpers # ------------------------------------------------------------------ @@ -160,6 +169,22 @@ def __init__(self) -> None: # back to the live router output so they produce no replay signal. self.target_valid_mask: torch.Tensor | None = None self.replay_backward_mask_list: list[torch.Tensor | None] = [] + # ---------- R3 diagnostics (PP=2 root-cause hunt) ---------- + # Per-push metadata ring -- parallel to ``replay_backward_list`` so + # the BACKWARD consumer can compare the popped slab against the + # slab that was REGISTERED AT THAT PUSH (not the most recent + # target, which under 1F1B scheduling has been overwritten by a + # later mb). This cleanly separates "logging artifact" from "real + # backward queue corruption". See code-rules/distributed.md hang + # section -- the metadata ring is local state, no cross-rank comm. 
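+        # Invariant sketch (assumed FIFO discipline, local to this rank):
+        # after P pushes via set_target_indices() and K pops in the
+        # REPLAY_BACKWARD consumer,
+        #   len(replay_push_meta_list) == len(replay_backward_list) == P - K
+        # and index 0 of both lists always refers to the SAME push event.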
+ self.replay_push_meta_list: list[dict] = [] + self.creation_order: int = len(RouterReplay.router_instances) + try: + import torch.distributed as _dist + self.creator_rank: int = _dist.get_rank() if _dist.is_initialized() else -1 + except Exception: + self.creator_rank = -1 + # ------------------------------------------------------------ RouterReplay.router_instances.append(self) def set_target_indices( @@ -182,6 +207,42 @@ def set_target_indices( self.target_valid_mask = valid_mask self.replay_backward_list.append(topk_indices) self.replay_backward_mask_list.append(valid_mask) + # ---------- R3 diagnostics: capture push metadata at the SAME + # call site so REPLAY_BACKWARD/consume can later prove that the + # popped slab equals the slab that was originally pushed (the + # only correctness criterion). Hashing here is gated by + # _r3_should_log so steady-state cost is one int + one None. + # --------------------------------------------------------------- + try: + from areal.engine.router_replay_utils import ( + _r3_current_trace_id as _tid, + _r3_hash64 as _h64, + _r3_should_log as _sl, + _r3_verbose as _v, + ) + if _v() and _sl("RouterReplay.set_target_indices/push_meta"): + _slab_h = hex(_h64(topk_indices)) + _mask_h = ( + hex(_h64(valid_mask.to(torch.int32))) + if valid_mask is not None else "None" + ) + self.replay_push_meta_list.append({ + "push_id": getattr(self, "_r3_push_counter", 0), + "trace_id": _tid(), + "slab_hash": _slab_h, + "mask_hash": _mask_h, + "slab_shape": tuple(topk_indices.shape), + }) + self._r3_push_counter = getattr(self, "_r3_push_counter", 0) + 1 + else: + # Always append a placeholder so list lengths stay locked + # to ``replay_backward_list``; pop side will skip None. + self.replay_push_meta_list.append(None) + except Exception: + try: + self.replay_push_meta_list.append(None) + except Exception: + pass # Cheap diagnostic: record every set in first few layers/mb. Gated # via _r3_should_log so steady-state overhead is ~one integer # increment. @@ -232,11 +293,33 @@ def record_indices(self, topk_indices: torch.Tensor) -> None: self.recorded_topk_idx = topk_indices def clear_indices(self) -> None: + # ---------- R3 diagnostics: dump tail-state queue sizes BEFORE + # clearing so residual queues (a smoking gun for lost backward + # pops under PP=2 1F1B) are always visible in logs. + try: + from areal.engine.router_replay_utils import ( + _r3_pp_tp_info, + _r3_should_log, + _r3_verbose, + ) + if _r3_verbose() and _r3_should_log("RouterReplay.clear_indices/tail_state"): + logger.info( + "[R3-STAGE3c/clear_indices] %s inst#%d fwd_q=%d " + "mask_q=%d push_meta_q=%d", + _r3_pp_tp_info(), + self.creation_order, + len(self.replay_backward_list), + len(self.replay_backward_mask_list), + len(getattr(self, "replay_push_meta_list", []) or []), + ) + except Exception: + pass self.recorded_topk_idx = None self.target_topk_idx = None self.target_valid_mask = None self.replay_backward_list = [] self.replay_backward_mask_list = [] + self.replay_push_meta_list = [] def set_router_replay_action(self, action: RouterReplayAction) -> None: self.router_replay_action = action @@ -459,6 +542,17 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): bw_valid_mask = bw_mask_list.pop(0) else: bw_valid_mask = None + # ---------- R3 diagnostics: pop the matching push-meta entry + # so the divergence verdict below compares popped-slab against + # the slab that was REGISTERED AT PUSH TIME (the real + # correctness criterion under 1F1B PP scheduling). 
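+            # Worked example (hypothetical hashes h1, h2): mb0 pushes h1,
+            # mb1 pushes h2 (h2 becomes the "latest target"), then 1F1B
+            # runs mb0's backward first:
+            #   divergence_v1 compares popped h1 vs latest h2 -> DIVERGE
+            #     (scheduling artifact, expected under PP>1)
+            #   divergence_v2 compares popped h1 vs the h1 recorded at its
+            #     push -> MATCH (required; anything else is real corruption)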
+ _push_meta_list = getattr(router_replay, "replay_push_meta_list", None) + _bw_push_meta = None + if _push_meta_list: + try: + _bw_push_meta = _push_meta_list.pop(0) + except IndexError: + _bw_push_meta = None # ---- R3 deep-trace: log backward pop order + hashes ---- try: from areal.engine.router_replay_utils import ( @@ -510,7 +604,8 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): "popped_slab_hash=%s popped_mask_hash=%s " "current_target_hash=%s divergence=%s " "queue_len_before=%d queue_len_after=%d " - "mask_queue_len_before=%d mask_queue_len_after=%d", + "mask_queue_len_before=%d mask_queue_len_after=%d " + "push_meta=%s divergence_v2=%s", _tid(), _inst_idx, tuple(scores.shape), @@ -527,11 +622,74 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): getattr(router_replay, "replay_backward_mask_list", []) or [] ), + _bw_push_meta, + # divergence_v2 is the DEFINITIVE verdict: it + # compares popped slab against the slab recorded + # at the matching push site, not against the + # most recent (potentially overwritten) target. + # MATCH here = backward queue is correct under PP. + ( + "NO_PUSH_META" + if _bw_push_meta is None + else ( + "MATCH" + if _bw_push_meta.get("slab_hash") == _popped_slab_hash + else "REAL_MISMATCH" + ) + ), ) except Exception: logger.exception( "[R3-STAGE4/REPLAY_BACKWARD/consume] trace log failed" ) + # ---------- R3 diagnostics: FB-level aggregate counters + # (gated by _r3_verbose so prod path is untouched). Counters + # are reset at _r3_forward_backward_batch entry and dumped at + # exit, giving one summary line per FB call. + try: + from areal.engine.router_replay_utils import ( + _r3_hash64 as _h64x, + _r3_verbose as _vx, + ) + if _vx(): + _stats = RouterReplay._r3_fb_stats + if router_replay.target_topk_idx is None: + _stats["divergence_v1_none"] = ( + _stats.get("divergence_v1_none", 0) + 1 + ) + else: + _v1_match = ( + router_replay.target_topk_idx.shape == top_indices.shape + and _h64x( + router_replay.target_topk_idx.to(top_indices.device) + ) == _h64x(top_indices) + ) + _stats["divergence_v1_match" if _v1_match + else "divergence_v1_diverge"] = ( + _stats.get( + "divergence_v1_match" if _v1_match + else "divergence_v1_diverge", 0, + ) + 1 + ) + if _bw_push_meta is None: + _stats["divergence_v2_no_meta"] = ( + _stats.get("divergence_v2_no_meta", 0) + 1 + ) + else: + _v2_match = ( + _bw_push_meta.get("slab_hash") + == hex(_h64x(top_indices)) + ) + _stats["divergence_v2_match" if _v2_match + else "divergence_v2_real_mismatch"] = ( + _stats.get( + "divergence_v2_match" if _v2_match + else "divergence_v2_real_mismatch", 0, + ) + 1 + ) + _stats["bw_pop_total"] = _stats.get("bw_pop_total", 0) + 1 + except Exception: + pass if ( bw_valid_mask is not None and bw_valid_mask.shape[0] == top_indices.shape[0] diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py index cdb4fd92c8..9c5439a7e6 100644 --- a/areal/engine/router_replay_utils.py +++ b/areal/engine/router_replay_utils.py @@ -820,6 +820,73 @@ def set_router_replay_data( offset, end = local_info["start"], local_info["end"] router_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) + # ---------- R3 diagnostics: PP=2 root-cause hunt ---------- + # Print the full PP-rank slicing decision so the next log can + # trivially confirm: + # * tf_config.num_layers stays GLOBAL (27) under Megatron-Core PP + # -- if it ever becomes local (14/13), index_by_layer would + # still report True but ``idx=layer_idx`` would over-shoot. 
+ # * offset/end honors get_transformer_layer_offset. + # * The set of MoE layers in [offset, end) matches the rollout + # layer-axis convention (absolute layer index, with layer 0 + # dense and recorded as zeros). + # * RouterReplay.router_instances is local-per-process: the + # selected slice (creation_order list) tells us which routers + # this PP rank actually owns. + try: + if _r3_verbose() and _r3_should_log("set_router_replay_data/PP_LAYOUT"): + moe_layers_in_range = [ + i for i in range(offset, end) if is_moe_layer(tf_config, i) + ] + non_moe_layers_in_range = [ + i for i in range(offset, end) if not is_moe_layer(tf_config, i) + ] + vp_size = getattr( + tf_config, "virtual_pipeline_model_parallel_size", None + ) + # Cheap audit: nnz of dim-0 slice for ALL layer indices + # (helps prove rollout's L-axis-0 is the dense layer and + # really is all-zero, vs. silently shifted). + with torch.no_grad(): + per_layer_nnz = [ + int((layers_topk[L] != 0).any(dim=-1).sum().item()) + if L < layers_topk.shape[0] else -1 + for L in range(min(layers_topk.shape[0], 32)) + ] + logger.info( + "[R3-STAGE3/set_router_replay_data] PP_LAYOUT trace_id=%d %s " + "tf_config.num_layers=%d vp_size=%s moe_layer_freq=%s " + "first_k_dense_replace=%s " + "local_info={start:%d, end:%d, count:%d} " + "moe_layers_in_range=%s non_moe_layers_in_range=%s " + "len(router_list)=%d total_router_instances=%d " + "selected_router_creation_orders=%s " + "selected_router_creator_ranks=%s " + "layers_topk_dim0=%d index_by_layer=%s " + "per_layer_any_nnz_first32=%s", + _r3_current_trace_id(), + _r3_pp_tp_info(tf_config, vp_rank), + tf_config.num_layers, + vp_size, + getattr(tf_config, "moe_layer_freq", None), + getattr(tf_config, "first_k_dense_replace", None), + offset, + end, + local_info["count"], + moe_layers_in_range, + non_moe_layers_in_range, + len(router_list), + len(RouterReplay.router_instances), + [getattr(r, "creation_order", -1) for r in router_list], + [getattr(r, "creator_rank", -1) for r in router_list], + layers_topk.shape[0], + len(layers_topk) == tf_config.num_layers, + per_layer_nnz, + ) + except Exception: + logger.exception("[R3-STAGE3/PP_LAYOUT] diag log failed") + # ---------------------------------------------------------- + if len(router_list) == 0: logger.warning( "[R3] set_router_replay_data: no RouterReplay instances found " @@ -870,6 +937,8 @@ def set_router_replay_data( _r3_zero_row_stats(slab), _r3_tensor_sig(f"target[L={layer_idx}]", slab), hex(_r3_hash64(slab)), + getattr(router, "creation_order", -1), + moe_idx, ) ) router_offset += 1 @@ -892,15 +961,23 @@ def set_router_replay_data( logger.info( "[R3-STAGE3/set_router_replay_data] DISPATCH trace_id=%d %s " "router_offset=%d len(router_list)=%d index_by_layer=%s " - "first_layers=%s ... (total dispatched=%d)", + "first_layers=%s all_layers_to_router_map=%s " + "... (total dispatched=%d)", _r3_current_trace_id(), _r3_pp_tp_info(tf_config, vp_rank), router_offset, len(router_list), index_by_layer, [ - (lidx, j, zr, sig, h) for lidx, j, zr, sig, h in head + (lidx, j, zr, sig, h, co, mi) + for lidx, j, zr, sig, h, co, mi in head ], + # Full (layer_idx, slab_idx_used, router_creation_order, + # moe_ordinal) tuple for every dispatched layer. This is + # the definitive cross-check: PP0 must be [(1,1,0,0), + # (2,2,1,1), ..., (13,13,12,12)] and PP1 must be + # [(14,14,0,13), ..., (26,26,12,25)] on Moonlight. 
+ [(lidx, j, co, mi) for lidx, j, _, _, _, co, mi in dispatched], len(dispatched), ) @@ -957,6 +1034,38 @@ def setup_per_microbatch_replay_backward() -> None: def clear_router_replay() -> None: """Clear all RouterReplay state after a full forward-backward pass.""" n_instances = len(RouterReplay.router_instances) + # ---------- R3 diagnostics: dump pre-clear queue lengths so leftover + # backward pops (a smoking gun for missing recompute under PP=2 1F1B) + # are always visible. + try: + if _r3_verbose() and _r3_should_log("clear_router_replay/snapshot"): + fwd_qs = [ + len(getattr(r, "replay_backward_list", []) or []) + for r in RouterReplay.router_instances + ] + mask_qs = [ + len(getattr(r, "replay_backward_mask_list", []) or []) + for r in RouterReplay.router_instances + ] + push_qs = [ + len(getattr(r, "replay_push_meta_list", []) or []) + for r in RouterReplay.router_instances + ] + n_nonempty = sum(1 for q in fwd_qs if q > 0) + logger.info( + "[R3-STAGE3c/clear_router_replay] PRE_CLEAR_SNAPSHOT %s " + "n_instances=%d n_with_residual_fwd_q=%d " + "fwd_q_lens=%s mask_q_lens=%s push_meta_q_lens=%s", + _r3_pp_tp_info(), + n_instances, + n_nonempty, + fwd_qs, + mask_qs, + push_qs, + ) + except Exception: + logger.exception("[R3-STAGE3c/clear_router_replay] diag log failed") + # ----------------------------------------------------------------- RouterReplay.clear_global_indices() RouterReplay.clear_global_router_replay_action() logger.debug("[R3] Router replay state cleared (%d instances).", n_instances) From c2fc221aed9e0d3eec172c1a2060225f40beeca7 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 6 May 2026 19:23:07 +0800 Subject: [PATCH 096/112] fix(router_replay): cp --- .../megatron_utils/packed_context_parallel.py | 13 ++++- areal/engine/router_replay_utils.py | 55 +++++++++++++++++-- 2 files changed, 59 insertions(+), 9 deletions(-) diff --git a/areal/engine/megatron_utils/packed_context_parallel.py b/areal/engine/megatron_utils/packed_context_parallel.py index 7fc16cab6b..eb9b445a32 100644 --- a/areal/engine/megatron_utils/packed_context_parallel.py +++ b/areal/engine/megatron_utils/packed_context_parallel.py @@ -72,8 +72,14 @@ def split_packed_seqs_for_context_parallel( tensor: torch.Tensor, cu_seqlens: torch.Tensor, ) -> torch.Tensor: - """Split a 1D packed tensor using the same interleaved pattern as - preprocess_packed_seqs_context_parallel.""" + """Split a packed tensor using the same interleaved pattern as + ``preprocess_packed_seqs_context_parallel``. + + Supports tensors with arbitrary trailing dims as long as ``dim-0`` is the + token axis (e.g. 1D ``[T]``, 2D ``[T, D]``, or 3D ``[T, L, K]``). The + interleave pattern operates only on dim-0; trailing dims are sliced as + contiguous blocks along with each token row. 
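+
+    Illustration (assuming the usual Megatron zigzag CP layout; the
+    shapes are made up): with ``cp_size=2`` and one packed sequence of
+    8 tokens, each sequence is viewed as ``2 * cp_size`` half-chunks of
+    2 tokens::
+
+        tokens   [t0 t1 | t2 t3 | t4 t5 | t6 t7]
+        cp rank0 keeps t0 t1 t6 t7   (half-chunks 0 and 3)
+        cp rank1 keeps t2 t3 t4 t5   (half-chunks 1 and 2)
+
+    A ``[8, L, K]`` input therefore yields a ``[4, L, K]`` output per rank.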
+ """ cp_size = mpu.get_context_parallel_world_size() cp_rank = mpu.get_context_parallel_rank() if cp_size <= 1: @@ -83,7 +89,8 @@ def split_packed_seqs_for_context_parallel( batch_size = input_lens.shape[0] output_len = input_lens.sum().item() // cp_size - splitted = torch.zeros(output_len, dtype=tensor.dtype, device=tensor.device) + out_shape = (output_len,) + tuple(tensor.shape[1:]) + splitted = torch.zeros(out_shape, dtype=tensor.dtype, device=tensor.device) for i in range(batch_size): seqlen = input_lens[i] // cp_size half_seqlen = seqlen // 2 diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py index 9c5439a7e6..22e6c17131 100644 --- a/areal/engine/router_replay_utils.py +++ b/areal/engine/router_replay_utils.py @@ -543,10 +543,12 @@ def set_router_replay_data( 1. Use ``cu_seqlens`` to extract each sample's real tokens from the left-aligned ``layers_topk_idx``. - 2. Pack tokens contiguously with per-sequence TP alignment padding + 2. Pack tokens contiguously with per-sequence TP/CP alignment padding (each sequence padded to a multiple of ``seq_align_to``). - 3. ``scatter_to_sequence_parallel_region`` to split across TP/SP ranks. - 4. Permute to ``(num_layers, local_tokens, topk)`` and distribute to + 3. When ``cp_size > 1``, CP-interleave-split dim-0 to match + ``preprocess_packed_seqs_context_parallel``. + 4. ``scatter_to_sequence_parallel_region`` to split across TP/SP ranks. + 5. Permute to ``(num_layers, local_tokens, topk)`` and distribute to RouterReplay instances. Args: @@ -775,9 +777,50 @@ def set_router_replay_data( _r3_tensor_sig("packed", packed), ) - # Step 2: Scatter to SP ranks + # Step 2: CP split (before TP scatter). + # + # When ``cp_size > 1``, megatron-core's + # ``preprocess_packed_seqs_context_parallel`` has already + # interleaved-split the model's token axis so each CP rank only sees + # ``total_aligned / cp_size`` tokens. Router replay indices MUST + # match that layout before the TP scatter; otherwise each TP rank + # would end up with ``cp_size``x too many rows and overwrite the + # wrong positions. We reuse ``split_packed_seqs_for_context_parallel`` + # with the PADDED ``cu_seqlens`` the caller provided (see caller + # comment "Pass the PADDED cu_seqlens (with TP alignment) ..."). + # + # Contract: ``cu_seqlens`` here describes the SAME packed layout as + # ``packed`` (dim-0 aligned), which is what the engine passes in. packed = packed.to(device) valid_mask = valid_mask.to(device) + cp_size = getattr(mpu, "get_context_parallel_world_size", lambda: 1)() + if cp_size > 1: + from areal.engine.megatron_utils.packed_context_parallel import ( + split_packed_seqs_for_context_parallel, + ) + cu_seqlens_dev = cu_seqlens.to(device) + packed = split_packed_seqs_for_context_parallel(packed, cu_seqlens_dev) + # Preserve bool semantics: split as int32 then recast. 
+ valid_mask = split_packed_seqs_for_context_parallel( + valid_mask.to(torch.int32), cu_seqlens_dev + ).bool() + if _r3_verbose() and _r3_should_log("set_router_replay_data/CP_SPLIT"): + with torch.no_grad(): + n_after_cp_valid = int(valid_mask.sum().item()) + logger.info( + "[R3-STAGE3/set_router_replay_data] CP_SPLIT trace_id=%d %s " + "cp_size=%d post_cp_packed=%s post_cp_valid=%d/%d " + "post_cp_packed_hash=%s", + _r3_current_trace_id(), + _r3_pp_tp_info(tf_config, vp_rank), + cp_size, + _r3_tensor_sig("packed_after_cp", packed), + n_after_cp_valid, + valid_mask.numel(), + hex(_r3_hash64(packed)), + ) + + # Step 3: Scatter to SP ranks (TP) tp_size = mpu.get_tensor_model_parallel_world_size() if tp_size > 1: from megatron.core.tensor_parallel import scatter_to_sequence_parallel_region @@ -794,7 +837,7 @@ def set_router_replay_data( # local_tokens: (local_tokens_count, num_layers, topk) # local_mask: (local_tokens_count,) - # Step 3: Permute to (num_layers, local_tokens_count, topk) + # Step 4: Permute to (num_layers, local_tokens_count, topk) layers_topk = local_tokens.permute(1, 0, 2) if _r3_verbose() and _r3_should_log("set_router_replay_data/SCATTER"): @@ -815,7 +858,7 @@ def set_router_replay_data( hex(_r3_hash64(local_mask.to(torch.int32))), ) - # Step 4: Distribute to RouterReplay instances for local PP layers + # Step 5: Distribute to RouterReplay instances for local PP layers local_info = get_current_rank_layer_info(tf_config, vp_rank) offset, end = local_info["start"], local_info["end"] router_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) From 34ee92a42f290ef0dc3dd00f1e69e4d348e09f8b Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 6 May 2026 20:02:15 +0800 Subject: [PATCH 097/112] fix(engine): test --- areal/experimental/engine/archon_utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/areal/experimental/engine/archon_utils.py b/areal/experimental/engine/archon_utils.py index cf9734587f..817c21159a 100644 --- a/areal/experimental/engine/archon_utils.py +++ b/areal/experimental/engine/archon_utils.py @@ -350,15 +350,15 @@ def prepare_training_config( ) # PP weight tying constraint (independent of pad_to_maximum) - if parallel_dims.pp_enabled: - if getattr(model_config, "tie_word_embeddings", False): - raise ValueError( - f"Pipeline Parallelism (PP={parallel_dims.pp}) is not supported " - f"with weight tying (tie_word_embeddings=True). " - f"When PP > 1, tok_embeddings and output layers are on different GPUs " - f"and cannot share the same weight tensor. " - f"Please either disable PP (set pipeline_parallel_size=1) or use a model " - f"without weight tying." - ) + # if parallel_dims.pp_enabled: + # if getattr(model_config, "tie_word_embeddings", False): + # raise ValueError( + # f"Pipeline Parallelism (PP={parallel_dims.pp}) is not supported " + # f"with weight tying (tie_word_embeddings=True). " + # f"When PP > 1, tok_embeddings and output layers are on different GPUs " + # f"and cannot share the same weight tensor. " + # f"Please either disable PP (set pipeline_parallel_size=1) or use a model " + # f"without weight tying." 
+ # ) return ac_config, enable_compile From f5be0f32653acd56b71dc1e0b3ab4b38a4115311 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 6 May 2026 20:04:34 +0800 Subject: [PATCH 098/112] fix(archon_utils): revert test --- areal/experimental/engine/archon_utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/areal/experimental/engine/archon_utils.py b/areal/experimental/engine/archon_utils.py index 817c21159a..cf9734587f 100644 --- a/areal/experimental/engine/archon_utils.py +++ b/areal/experimental/engine/archon_utils.py @@ -350,15 +350,15 @@ def prepare_training_config( ) # PP weight tying constraint (independent of pad_to_maximum) - # if parallel_dims.pp_enabled: - # if getattr(model_config, "tie_word_embeddings", False): - # raise ValueError( - # f"Pipeline Parallelism (PP={parallel_dims.pp}) is not supported " - # f"with weight tying (tie_word_embeddings=True). " - # f"When PP > 1, tok_embeddings and output layers are on different GPUs " - # f"and cannot share the same weight tensor. " - # f"Please either disable PP (set pipeline_parallel_size=1) or use a model " - # f"without weight tying." - # ) + if parallel_dims.pp_enabled: + if getattr(model_config, "tie_word_embeddings", False): + raise ValueError( + f"Pipeline Parallelism (PP={parallel_dims.pp}) is not supported " + f"with weight tying (tie_word_embeddings=True). " + f"When PP > 1, tok_embeddings and output layers are on different GPUs " + f"and cannot share the same weight tensor. " + f"Please either disable PP (set pipeline_parallel_size=1) or use a model " + f"without weight tying." + ) return ac_config, enable_compile From 5fdc08b3f524abd404d292cd810cde3992738560 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 6 May 2026 20:46:14 +0800 Subject: [PATCH 099/112] fix(examples/math): test config --- .../math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml index 44b9be4f1d..323048f0f0 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml @@ -18,7 +18,7 @@ scheduler: type: null rollout: - backend: "sglang:d1p1t2" + backend: "sglang:d1p1t1" experiment_name: ${experiment_name} trial_name: ${trial_name} max_concurrent_rollouts: 16 @@ -39,7 +39,7 @@ gconfig: temperature: 1.0 actor: - backend: "megatron:(attn:d1p2t2|ffn:d1p2t1e2)" # ← PP=2, attn TP=2, ffn EP=2 + backend: "megatron:(attn:d1p1t2c2|ffn:d1p1t2e2)" # ← PP=1, attn TP=2 CP=2, ffn EP=2 # backend: "megatron:(attn:d1p1t4|ffn:d1p1t1e4)" experiment_name: ${experiment_name} trial_name: ${trial_name} From 9cf7127a5ad5452689c881ccf9de79a245fa77af Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 6 May 2026 21:36:36 +0800 Subject: [PATCH 100/112] feat(diag): cp log --- areal/engine/core/train_engine.py | 69 +++++++++ areal/engine/megatron_engine.py | 247 ++++++++++++++++++++++++++++++ areal/trainer/ppo/actor.py | 30 ++++ areal/utils/data.py | 73 +++++++++ areal/utils/logging.py | 3 + 5 files changed, 422 insertions(+) diff --git a/areal/engine/core/train_engine.py b/areal/engine/core/train_engine.py index 9bc83c2b77..195509e13b 100644 --- a/areal/engine/core/train_engine.py +++ b/areal/engine/core/train_engine.py @@ -19,6 +19,9 @@ reorder_list, unpack_sequence, ) +from areal.utils.logging import getLogger as _getLogger + 
+_SPLIT_DIAG_LOGGER = _getLogger("R3SplitDiag") __all__ = [ "compute_total_loss_weight", @@ -139,6 +142,72 @@ def reorder_and_pad_outputs( """ res = aggregate_fn(outputs) seqlens = [output_seqlens[i] for i in mb_list.forward_indices] + # [SPLIT_MISMATCH_DIAG] Log EVERYTHING needed to root-cause the + # `split_with_sizes` mismatch reported during compute_logp. + try: + _rank = None + try: + import torch.distributed as _dist + if _dist.is_available() and _dist.is_initialized(): + _rank = _dist.get_rank() + except Exception: + _rank = None + _out_shapes = [tuple(o.shape) for o in outputs] + _out_sum0 = [int(o.shape[0]) for o in outputs if o.ndim >= 1] + _sum_seqlens = int(sum(seqlens)) + _res_shape = tuple(res.shape) + _res_dim0 = int(res.shape[0]) if res.ndim >= 1 else -1 + _fwd_idx = list(mb_list.forward_indices) + _bwd_idx = list(mb_list.backward_indices) + _mbs_lens = None + try: + _mbs_lens = [ + int(mb.get("cu_seqlens", torch.empty(0))[-1].item()) + if isinstance(mb.get("cu_seqlens", None), torch.Tensor) + and mb["cu_seqlens"].numel() > 0 + else None + for mb in getattr(mb_list, "mbs", []) + ] + except Exception: + _mbs_lens = "ERR" + _padded_lens = getattr(mb_list, "padded_to_lengths", None) + _group_lens = getattr(mb_list, "group_lens", None) + _padding_lens = getattr(mb_list, "padding_lengths", None) + _SPLIT_DIAG_LOGGER.info( + "[SPLIT_MISMATCH_DIAG][reorder_and_pad_outputs] rank=%s " + "n_outputs=%d out_shapes=%s out_sum_dim0=%s sum_out_dim0=%d " + "res.shape=%s res.dim0=%d " + "len(output_seqlens)=%d sum(output_seqlens)=%d " + "len(seqlens_reordered)=%d sum(seqlens_reordered)=%d " + "forward_indices=%s backward_indices=%s " + "mb_real_total_lens=%s padded_to_lengths=%s " + "group_lens=%s padding_lengths=%s " + "match=%s output_seqlens_head=%s output_seqlens_tail=%s", + _rank, + len(outputs), + _out_shapes, + _out_sum0, + int(sum(_out_sum0)), + _res_shape, + _res_dim0, + len(output_seqlens), + int(sum(output_seqlens)), + len(seqlens), + _sum_seqlens, + _fwd_idx, + _bwd_idx, + _mbs_lens, + _padded_lens, + _group_lens, + _padding_lens, + (_res_dim0 == _sum_seqlens), + list(output_seqlens[:16]), + list(output_seqlens[-16:]), + ) + except Exception: + _SPLIT_DIAG_LOGGER.exception( + "[SPLIT_MISMATCH_DIAG][reorder_and_pad_outputs] log-emit failed" + ) unpacked = unpack_sequence(res, lens=seqlens, dim=0) reordered = reorder_list(unpacked, mb_list.backward_indices) return pad_and_stack_tensors_along_first_dim(reordered) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index cba4817321..a93a7f36c2 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -93,6 +93,10 @@ ) from areal.models.tree_attn.tree import build_packed_tree_batch from areal.utils import logging, name_resolve, names, perf_tracer, stats_tracker + +# [SPLIT_MISMATCH_DIAG] Module-level diagnostic logger for the +# `split_with_sizes` mismatch hunt in compute_logp. Yellow color, INFO level. +_R3_FWD_DIAG_LOGGER = logging.getLogger("R3FwdDiag") from areal.utils.constants import ( DEFAULT_VECTORIZED_ALIGNMENT_BYTES, DIST_GROUP_DEFAULT_TIMEOUT, @@ -905,6 +909,42 @@ def forward_step(batch_iter, model): cp_size = mpu.get_context_parallel_world_size() cp_local = cp_size > 1 + # [SPLIT_MISMATCH_DIAG] Log per-MB padded input geometry BEFORE forward. 
+ try: + _padded_mb = mb_input.padded_mb + _orig_mb = mb_input.orig_mb + _R3_FWD_DIAG_LOGGER.info( + "[SPLIT_MISMATCH_DIAG][forward_step] PRE_FWD rank=%s " + "padded_to=%s padding_length=%s " + "padded.input_ids.shape=%s " + "padded.cu_seqlens=%s " + "old_cu_seqlens=%s " + "orig.input_ids.shape=%s " + "cp_size=%d cp_local=%s", + dist.get_rank() if dist.is_initialized() else None, + getattr(mb_input, "padded_to_length", None), + getattr(mb_input, "padding_length", None), + tuple(_padded_mb["input_ids"].shape) + if "input_ids" in _padded_mb + else None, + _padded_mb["cu_seqlens"].cpu().tolist() + if "cu_seqlens" in _padded_mb + and isinstance(_padded_mb["cu_seqlens"], torch.Tensor) + else None, + mb_input.old_cu_seqlens.cpu().tolist() + if isinstance(getattr(mb_input, "old_cu_seqlens", None), torch.Tensor) + else getattr(mb_input, "old_cu_seqlens", None), + tuple(_orig_mb["input_ids"].shape) + if "input_ids" in _orig_mb + else None, + cp_size, + cp_local, + ) + except Exception: + _R3_FWD_DIAG_LOGGER.exception( + "[SPLIT_MISMATCH_DIAG][forward_step PRE_FWD] log-emit failed" + ) + output = packed_context_parallel_forward( model, mb_input.padded_mb, @@ -912,6 +952,26 @@ def forward_step(batch_iter, model): is_vision_model=self.is_vision_model, ) + # [SPLIT_MISMATCH_DIAG] Log raw forward output shape (pre-unpad). + try: + _R3_FWD_DIAG_LOGGER.info( + "[SPLIT_MISMATCH_DIAG][forward_step] POST_FWD rank=%s " + "raw_output.shape=%s raw_output.dim0=%s is_pp_last=%s", + dist.get_rank() if dist.is_initialized() else None, + tuple(output.shape) if hasattr(output, "shape") else None, + int(output.shape[0]) + if hasattr(output, "shape") and output.ndim >= 1 + else -1, + mpu.is_pipeline_last_stage( + ignore_virtual=False, + vp_stage=getattr(model, "vp_stage", 0), + ), + ) + except Exception: + _R3_FWD_DIAG_LOGGER.exception( + "[SPLIT_MISMATCH_DIAG][forward_step POST_FWD] log-emit failed" + ) + # Release tree attention metadata after forward pass for key in tree_attn_keys: del mb_input.padded_mb[key] @@ -944,12 +1004,41 @@ def _process_output(input_, output_): cp_inputs["cu_seqlens"] = cp_cu_seqlens return output, functools.partial(_process_output, cp_inputs) else: + _pre_shape = tuple(output.shape) if hasattr(output, "shape") else None output = unpad_logits( output, padding_length=mb_input.padding_length, cu_seqlens=cu_seqlens, old_cu_seqlens=mb_input.old_cu_seqlens, ) + # [SPLIT_MISMATCH_DIAG] Log unpad result. + try: + _R3_FWD_DIAG_LOGGER.info( + "[SPLIT_MISMATCH_DIAG][forward_step] UNPAD_LOGITS " + "rank=%s pre_shape=%s post_shape=%s " + "padding_length=%s " + "cu_seqlens_last=%s old_cu_seqlens_last=%s", + dist.get_rank() if dist.is_initialized() else None, + _pre_shape, + tuple(output.shape) + if hasattr(output, "shape") + else None, + getattr(mb_input, "padding_length", None), + int(cu_seqlens[-1].item()) + if isinstance(cu_seqlens, torch.Tensor) + else cu_seqlens, + int(mb_input.old_cu_seqlens[-1].item()) + if isinstance( + getattr(mb_input, "old_cu_seqlens", None), + torch.Tensor, + ) + else getattr(mb_input, "old_cu_seqlens", None), + ) + except Exception: + _R3_FWD_DIAG_LOGGER.exception( + "[SPLIT_MISMATCH_DIAG][forward_step UNPAD] " + "log-emit failed" + ) return output, functools.partial(_process_output, mb_input.orig_mb) forward_backward_func = get_forward_backward_func() @@ -1074,6 +1163,60 @@ def forward_batch( input_batched, meta = self._normalize_batch_input(input_) + # [SPLIT_MISMATCH_DIAG] Log inbound forward_batch arguments. 
+ try: + _rank = dist.get_rank() if dist.is_initialized() else None + _is_list = isinstance(input_, list) + if _is_list: + _attn_shapes = [ + tuple(d["attention_mask"].shape) + if "attention_mask" in d + else None + for d in input_ + ] + _attn_widths = [ + int(d["attention_mask"].shape[-1]) + if "attention_mask" in d + else None + for d in input_ + ] + _input_summary = ( + f"list len={len(input_)} attn_shapes_head={_attn_shapes[:8]} " + f"attn_widths_head={_attn_widths[:8]} " + f"sum_attn_widths={sum(w for w in _attn_widths if w)}" + ) + else: + _ks = sorted(list(input_.keys())) if isinstance(input_, dict) else "N/A" + _am = ( + tuple(input_["attention_mask"].shape) + if isinstance(input_, dict) and "attention_mask" in input_ + else None + ) + _ii = ( + tuple(input_["input_ids"].shape) + if isinstance(input_, dict) and "input_ids" in input_ + else None + ) + _input_summary = ( + f"dict keys={_ks} attention_mask.shape={_am} " + f"input_ids.shape={_ii}" + ) + _R3_FWD_DIAG_LOGGER.info( + "[SPLIT_MISMATCH_DIAG][forward_batch] ENTER rank=%s " + "is_list=%s meta_is_None=%s output_seqlens_arg=%s | %s", + _rank, + _is_list, + meta is None, + None + if output_seqlens is None + else f"len={len(output_seqlens)} sum={sum(output_seqlens)}", + _input_summary, + ) + except Exception: + _R3_FWD_DIAG_LOGGER.exception( + "[SPLIT_MISMATCH_DIAG][forward_batch ENTER] log-emit failed" + ) + # Step 1: Prepare sequence lengths if meta is not None: assert isinstance(input_, list) @@ -1091,6 +1234,34 @@ def forward_batch( assert output_seqlens is not None batch_size = len(output_seqlens) + # [SPLIT_MISMATCH_DIAG] Log derived sequence-length structure. + try: + _diff = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().tolist() + _R3_FWD_DIAG_LOGGER.info( + "[SPLIT_MISMATCH_DIAG][forward_batch] SEQLENS rank=%s " + "batch_size=%d cu_seqlens.shape=%s cu_seqlens_total=%d " + "real_lens_head=%s real_lens_tail=%s sum_real_lens=%d " + "output_seqlens_head=%s output_seqlens_tail=%s " + "sum_output_seqlens=%d output_seqlens_path=%s", + dist.get_rank() if dist.is_initialized() else None, + batch_size, + tuple(cu_seqlens.shape), + int(cu_seqlens[-1].item()), + _diff[:16], + _diff[-16:], + int(sum(_diff)), + list(output_seqlens[:16]), + list(output_seqlens[-16:]), + int(sum(output_seqlens)), + "from_attention_mask" + if meta is not None + else "from_cu_seqlens_diff", + ) + except Exception: + _R3_FWD_DIAG_LOGGER.exception( + "[SPLIT_MISMATCH_DIAG][forward_batch SEQLENS] log-emit failed" + ) + # Step 2: Prepare micro-batches mb_list = self._prepare_mb_list(input_batched).to(self.device) @@ -1100,6 +1271,36 @@ def forward_batch( def process_output(output: torch.Tensor, inputs: dict[str, Any]) -> None: result = self._compute_forward_result(output, inputs) outputs.append(result) + # [SPLIT_MISMATCH_DIAG] Log per-MB collected output shape. 
+ try: + _R3_FWD_DIAG_LOGGER.info( + "[SPLIT_MISMATCH_DIAG][forward_batch.process_output] " + "rank=%s mb_idx=%d output.shape=%s output.dim0=%d " + "input_keys=%s input_ids.shape=%s cu_seqlens=%s " + "input_ids_total=%s", + dist.get_rank() if dist.is_initialized() else None, + len(outputs) - 1, + tuple(result.shape) if hasattr(result, "shape") else None, + int(result.shape[0]) + if hasattr(result, "shape") and result.ndim >= 1 + else -1, + sorted(list(inputs.keys())) if isinstance(inputs, dict) else "N/A", + tuple(inputs["input_ids"].shape) + if "input_ids" in inputs + else None, + inputs["cu_seqlens"].cpu().tolist() + if "cu_seqlens" in inputs + and isinstance(inputs["cu_seqlens"], torch.Tensor) + else None, + int(inputs["input_ids"].numel()) + if "input_ids" in inputs + else None, + ) + except Exception: + _R3_FWD_DIAG_LOGGER.exception( + "[SPLIT_MISMATCH_DIAG][forward_batch.process_output] " + "log-emit failed" + ) return None self.forward_backward_batch(mb_list, process_output, forward_only=True) @@ -1996,6 +2197,33 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList: pad_to_maximum=self.config.pad_to_maximum, seq_align_to=align_to_multiple_of, ) + # [SPLIT_MISMATCH_DIAG] Log mb_list geometry. + try: + _R3_FWD_DIAG_LOGGER.info( + "[SPLIT_MISMATCH_DIAG][_prepare_mb_list] rank=%s " + "n_mbs=%d align_to_multiple_of=%d " + "group_lens=%s padded_to_lengths=%s padding_lengths=%s " + "align_to_lengths=%s " + "forward_indices=%s backward_indices=%s " + "max_seqlen=%s pp=%d cp=%d tp=%d", + dist.get_rank() if dist.is_initialized() else None, + len(mb_list.mbs), + align_to_multiple_of, + getattr(mb_list, "group_lens", None), + getattr(mb_list, "padded_to_lengths", None), + getattr(mb_list, "padding_lengths", None), + getattr(mb_list, "align_to_lengths", None), + list(getattr(mb_list, "forward_indices", []) or []), + list(getattr(mb_list, "backward_indices", []) or []), + getattr(mb_list, "max_seqlen", None), + pp_size, + cp_size, + tp_size, + ) + except Exception: + _R3_FWD_DIAG_LOGGER.exception( + "[SPLIT_MISMATCH_DIAG][_prepare_mb_list] log-emit failed" + ) self.logger.info( f"#microbatch: {len(mb_list.group_lens)}, microbatch #tokens: {mb_list.group_lens}, " f"aligned to: {mb_list.align_to_lengths}, padded to: {mb_list.padded_to_lengths}, " @@ -2122,6 +2350,25 @@ def _compute_forward_result( if mpu.get_tensor_model_parallel_world_size() > 1 else None, ) + # [SPLIT_MISMATCH_DIAG] Log gather_logprobs i/o shapes. 
+ try: + _R3_FWD_DIAG_LOGGER.info( + "[SPLIT_MISMATCH_DIAG][_compute_forward_result] " + "rank=%s output.shape=%s labels.shape=%s " + "logprobs.shape=%s logprobs.dim0=%s", + dist.get_rank() if dist.is_initialized() else None, + tuple(output.shape) if hasattr(output, "shape") else None, + tuple(labels.shape) if hasattr(labels, "shape") else None, + tuple(logprobs.shape) if hasattr(logprobs, "shape") else None, + int(logprobs.shape[0]) + if hasattr(logprobs, "shape") and logprobs.ndim >= 1 + else -1, + ) + except Exception: + _R3_FWD_DIAG_LOGGER.exception( + "[SPLIT_MISMATCH_DIAG][_compute_forward_result] " + "log-emit failed" + ) return logprobs else: values = output.squeeze(-1) diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index 153bfd8fdf..b83b54e3a4 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -222,6 +222,36 @@ def _compute_logp(self, data: dict[str, Any]) -> torch.Tensor | None: input_=data, aggregate_fn=lambda xs: torch.cat(xs, dim=-1), ) + # [SPLIT_MISMATCH_DIAG] Log shape of returned train_logp so we can + # correlate with per-MB output shapes logged by forward_batch. + try: + import torch.distributed as _dist + from areal.utils.logging import getLogger as _getLogger + _diag = _getLogger("R3FwdDiag") + _diag.info( + "[SPLIT_MISMATCH_DIAG][actor._compute_logp] POST_FORWARD " + "rank=%s train_logp.shape=%s train_logp.dim0=%s " + "data.input_ids.shape=%s data.attention_mask.shape=%s " + "data.attention_mask.sum=%s", + _dist.get_rank() if _dist.is_initialized() else None, + tuple(train_logp.shape) + if isinstance(train_logp, torch.Tensor) + else None, + int(train_logp.shape[0]) + if isinstance(train_logp, torch.Tensor) and train_logp.ndim >= 1 + else -1, + tuple(data["input_ids"].shape) if "input_ids" in data else None, + tuple(data["attention_mask"].shape) + if "attention_mask" in data + else None, + int(data["attention_mask"].sum().item()) + if "attention_mask" in data + else None, + ) + except Exception: + logger.exception( + "[SPLIT_MISMATCH_DIAG][actor._compute_logp POST_FORWARD] log-emit failed" + ) # R3 effectiveness metrics. At compute_logp time the training weights # equal the rollout weights (no optimizer step has touched θ in this # rollout epoch), so comparing SGLang's cached logprobs against the diff --git a/areal/utils/data.py b/areal/utils/data.py index 09368e4ef2..17568927ad 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -437,6 +437,35 @@ def unpack_sequence( ): """Unpack a sequence tensor into a list of tensors based on cumulative sequence lengths.""" if lens is not None: + # [SPLIT_MISMATCH_DIAG] Pre-flight check that emits a detailed log + # BEFORE torch.split raises so we capture the exact mismatch. 
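+    # The failure mode being pre-empted, reduced to a toy case:
+    #   torch.split(torch.empty(10), [4, 4], dim=0)
+    # raises a RuntimeError roughly reading "split_with_sizes expects
+    # split_sizes to sum exactly to 10 (input tensor's size at dimension
+    # 0)" -- without saying WHICH sample lengths drifted; the log emitted
+    # below does.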
+ try: + _x_dim = int(x.shape[dim]) if x.ndim > dim else -1 + _sum_lens = int(sum(lens)) + if _x_dim != _sum_lens: + import torch.distributed as _dist + from areal.utils.logging import getLogger as _getLogger + _diag = _getLogger("R3SplitDiag") + _diag.error( + "[SPLIT_MISMATCH_DIAG][unpack_sequence] PRE_SPLIT_MISMATCH " + "rank=%s x.shape=%s dim=%d x.size(dim)=%d " + "len(lens)=%d sum(lens)=%d delta=%d " + "lens_head=%s lens_tail=%s " + "min_len=%s max_len=%s", + _dist.get_rank() if _dist.is_initialized() else None, + tuple(x.shape), + dim, + _x_dim, + len(lens), + _sum_lens, + _sum_lens - _x_dim, + list(lens[:16]), + list(lens[-16:]), + min(lens) if lens else None, + max(lens) if lens else None, + ) + except Exception: + pass return torch.split(x, lens, dim=dim) if cu_seqlens is not None: return torch.split( @@ -1049,6 +1078,16 @@ def unpad_logits( cu_seqlens: torch.Tensor | None = None, old_cu_seqlens: torch.Tensor | None = None, ): + # [SPLIT_MISMATCH_DIAG] Log unpad_logits inputs/outputs centrally so we + # observe ALL call sites (engine forward_step, etc.). + try: + import torch.distributed as _dist + from areal.utils.logging import getLogger as _getLogger + _diag = _getLogger("R3SplitDiag") + _in_shape = tuple(logits.shape) + except Exception: + _diag = None + _in_shape = None # TODO: when using megatron, logits are in fp32, # create new logits in bucket to reduce peak memory usage # First unpad batch @@ -1069,8 +1108,42 @@ def unpad_logits( start = cu_seqlens[i].item() length = old_end - old_start new_logits[old_start:old_end] = logits[start : start + length] + if _diag is not None: + try: + _diag.info( + "[SPLIT_MISMATCH_DIAG][unpad_logits] rank=%s " + "in_shape=%s padding_length=%d " + "after_pad_strip_shape=%s out_shape=%s " + "cu_seqlens_last=%s old_cu_seqlens_last=%s " + "batch_size=%d", + _dist.get_rank() if _dist.is_initialized() else None, + _in_shape, + int(padding_length), + tuple(logits.shape), + tuple(new_logits.shape), + int(cu_seqlens[-1].item()) + if cu_seqlens is not None + else None, + int(old_cu_seqlens[-1].item()), + batch_size, + ) + except Exception: + pass return new_logits + if _diag is not None: + try: + _diag.info( + "[SPLIT_MISMATCH_DIAG][unpad_logits] rank=%s " + "in_shape=%s padding_length=%d out_shape=%s " + "old_cu_seqlens=None_branch", + _dist.get_rank() if _dist.is_initialized() else None, + _in_shape, + int(padding_length), + tuple(logits.shape), + ) + except Exception: + pass return logits diff --git a/areal/utils/logging.py b/areal/utils/logging.py index 6752134a76..0c1ab583c3 100644 --- a/areal/utils/logging.py +++ b/areal/utils/logging.py @@ -122,6 +122,9 @@ "InferenceRouter": "white", "InferenceGateway": "white", "RPCGuard": "white", + # R3 diagnostic loggers - yellow (debug) + "R3SplitDiag": "yellow", + "R3FwdDiag": "yellow", } # Prefix patterns checked in order (first match wins) From 001e036cf0ccf4010ed3f0ca5cfc7ac0d5753f57 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 6 May 2026 22:15:00 +0800 Subject: [PATCH 101/112] fix: add cp_local --- areal/engine/megatron_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index a93a7f36c2..1d7a0ab72a 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -907,7 +907,7 @@ def forward_step(batch_iter, model): tree_attn_keys = list(tree_kwargs.keys()) cp_size = mpu.get_context_parallel_world_size() - cp_local = cp_size > 1 + cp_local = cp_size > 1 and not forward_only # 
[SPLIT_MISMATCH_DIAG] Log per-MB padded input geometry BEFORE forward. try: From b5c8d28c53185ce76b8c904d84036bcfaaecebe9 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 6 May 2026 22:16:35 +0800 Subject: [PATCH 102/112] fix: config --- .../math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml index 323048f0f0..2ed59e9aca 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml @@ -18,7 +18,7 @@ scheduler: type: null rollout: - backend: "sglang:d1p1t1" + backend: "sglang:d1p1t2" experiment_name: ${experiment_name} trial_name: ${trial_name} max_concurrent_rollouts: 16 From 97b5dd41cc8b055bf15db412d73d0b86916575d1 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 6 May 2026 22:39:22 +0800 Subject: [PATCH 103/112] fix(engine): disable cp_local --- areal/engine/megatron_engine.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 1d7a0ab72a..a684c25fdc 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -907,7 +907,34 @@ def forward_step(batch_iter, model): tree_attn_keys = list(tree_kwargs.keys()) cp_size = mpu.get_context_parallel_world_size() - cp_local = cp_size > 1 and not forward_only + # Disable the CP-local output path unconditionally. + # + # The CP-local branch below only rewrites ``loss_mask`` / + # ``cu_seqlens`` / ``_cp_local_labels`` inside ``cp_inputs``, but + # leaves the other per-token tensors carried by ``mb_input.orig_mb`` + # (``logprobs``, ``prox_logp``, ``advantages``, ``rewards``, + # ``versions``, ...) at their original full real-sequence length. + # Downstream consumers then see inconsistent shapes, for example: + # * ``forward_batch`` (compute_logp / advantages) aggregates the + # per-MB logprobs via ``reorder_and_pad_outputs`` which splits + # by ``output_seqlens`` (full real lens). CP-local outputs are + # ``padded_to_length // cp_size`` rows and break + # ``torch.split_with_sizes``. + # * ``train_batch`` (ppo_update) routes the CP-local ``loss_mask`` + # into ``grpo_loss_fn`` / ``ppo_actor_loss_fn`` / + # ``apply_rejection_sampling`` together with the full-length + # ``proximal_logprobs`` / ``old_logprobs``, raising + # ``proximal_logprobs shape [N] != loss_mask shape [N/cp]``. + # + # Going through ``packed_context_parallel_forward(gather_cp_output= + # True)`` + ``unpad_logits`` restores each MB's output to the full + # real-sequence length, so it aligns with every tensor in + # ``orig_mb`` regardless of whether we are in forward-only or + # training mode. The attention / MoE forward still runs CP-sharded + # inside ``packed_context_parallel_forward``; only the final + # logits are all-gathered along the CP dimension. + _ = cp_size # kept for the dead CP-local branch below + cp_local = False # [SPLIT_MISMATCH_DIAG] Log per-MB padded input geometry BEFORE forward. 
try: From f816f6ce01afbfa633ed85e0113350a63f2f22bb Mon Sep 17 00:00:00 2001 From: bingyechen Date: Wed, 6 May 2026 23:05:37 +0800 Subject: [PATCH 104/112] fix(examples/math): fix oom --- areal/engine/megatron_engine.py | 6 ++++++ .../moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index a684c25fdc..f12ef14720 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -765,6 +765,9 @@ def prepare_batch( def update_weights(self, meta: WeightUpdateMeta): self._check_rollout_engine_connected() + gc.collect() + current_platform.empty_cache() + gc.collect() with self._offload_aware_context(): if meta.type == "xccl": assert self.weight_update_group_initialized @@ -2099,6 +2102,9 @@ def _save_model_to_hf( source_path=base_model_path, ) else: + gc.collect() + current_platform.empty_cache() + gc.collect() save_weights_to_hf_with_mbridge_fast( bridge=self.bridge, models=self.model, diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml index 2ed59e9aca..a51916e143 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml +++ b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml @@ -39,7 +39,7 @@ gconfig: temperature: 1.0 actor: - backend: "megatron:(attn:d1p1t2c2|ffn:d1p1t2e2)" # ← PP=1, attn TP=2 CP=2, ffn EP=2 + backend: "megatron:(attn:d1p1t2c2|ffn:d1p1t1e4)" # ← PP=1, attn TP=2 CP=2, ffn EP=2 # backend: "megatron:(attn:d1p1t4|ffn:d1p1t1e4)" experiment_name: ${experiment_name} trial_name: ${trial_name} From 06421730d416db5d587d3ed8b8d3e7190730e886 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 7 May 2026 12:28:29 +0800 Subject: [PATCH 105/112] refactor: remove log --- areal/engine/megatron_engine.py | 77 ++-- areal/engine/megatron_engine_r3_patch.py | 516 ++--------------------- areal/engine/router_replay_patch.py | 400 +----------------- areal/engine/router_replay_utils.py | 332 --------------- areal/trainer/ppo/actor.py | 284 +------------ areal/trainer/ppo/actor_r3_patch.py | 54 +-- areal/workflow/rlvr.py | 24 -- areal/workflow/rlvr_r3_patch.py | 25 +- 8 files changed, 109 insertions(+), 1603 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index f12ef14720..ca031a27ee 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -122,12 +122,12 @@ from areal.utils.seeding import get_seed if TYPE_CHECKING: - from areal.api import Scheduler - from areal.api.cli_args import PPOActorConfig, PPOCriticConfig - from megatron.bridge import AutoBridge as MegatronBridgeAutoBridge from megatron.bridge.peft.lora import LoRA as MegatronBridgeLoRA + + from areal.api import Scheduler from areal.api.cli_args import DPOEngineConfig, PPOActorConfig, PPOCriticConfig + def _patch_gpt_model_postprocess_for_inference(model_list: _MegatronModelList) -> None: """Patch ``GPTModel._postprocess`` to skip MTP when ``labels=None``. 
@@ -145,13 +145,20 @@ def _patch_gpt_model_postprocess_for_inference(model_list: _MegatronModelList) - _original_postprocess = GPTModel._postprocess - def _patched_postprocess(self, hidden_states, input_ids, position_ids, labels, **kwargs): + def _patched_postprocess( + self, hidden_states, input_ids, position_ids, labels, **kwargs + ): if labels is None and getattr(self.config, "mtp_num_layers", None) is not None: original_mtp = self.config.mtp_num_layers self.config.mtp_num_layers = None try: result = _original_postprocess( - self, hidden_states, input_ids, position_ids, labels=labels, **kwargs + self, + hidden_states, + input_ids, + position_ids, + labels=labels, + **kwargs, ) finally: self.config.mtp_num_layers = original_mtp @@ -164,7 +171,6 @@ def _patched_postprocess(self, hidden_states, input_ids, position_ids, labels, * GPTModel._areal_postprocess_patched = True - # `model.named_modules()` yields LOCAL layer indices on each PP rank, while # `get_named_parameters` rewrites them to GLOBAL indices via layer_offset. Strip # the index so the GLU detection set matches across PP ranks. Also strip trailing @@ -235,7 +241,7 @@ def __init__(self, config: TrainEngineConfig): self._r3_enabled: bool = getattr(config.megatron, "enable_router_replay", False) if not self._r3_enabled: self._r3_enabled = getattr(config, "_r3_enable_router_replay", False) - logging.getLogger("[MegatronEngine]").info( + logging.getLogger("[MegatronEngine]").debug( "[R3] __init__: _r3_enabled=%s, config.megatron.enable_router_replay=%s, " "config._r3_enable_router_replay=%s, config.megatron type=%s", self._r3_enabled, @@ -379,9 +385,8 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): self.tokenizer = load_hf_tokenizer(self.config.path) # R3: _r3_enabled was set in __init__ from config.megatron.enable_router_replay. - self.logger.info( - "[R3] enable_router_replay=%s (config.megatron type=%s, " - "config type=%s).", + self.logger.debug( + "[R3] enable_router_replay=%s (config.megatron type=%s, config type=%s).", self._r3_enabled, type(self.config.megatron).__name__, type(self.config).__name__, @@ -428,9 +433,12 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): # TopKRouter.__init__ and TransformerConfig.__init__ are patched. if self._r3_enabled: from areal.engine.router_replay_patch import apply_router_replay_patch + apply_router_replay_patch() self.tf_config.enable_routing_replay = True - self.logger.info("[R3] Router Replay patches applied before model creation.") + self.logger.info( + "[R3] Router Replay patches applied before model creation." + ) with self.device: models = make_mcore_model( @@ -531,7 +539,10 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): # R3: Apply engine-level patch after model and optimizer are ready. 
if self._r3_enabled: - from areal.engine.megatron_engine_r3_patch import patch_megatron_engine_for_r3 + from areal.engine.megatron_engine_r3_patch import ( + patch_megatron_engine_for_r3, + ) + patch_megatron_engine_for_r3(self, enable_router_replay=True) self.logger.info("[R3] Router Replay enabled on MegatronEngine.") self._initialized = True @@ -962,7 +973,9 @@ def forward_step(batch_iter, model): and isinstance(_padded_mb["cu_seqlens"], torch.Tensor) else None, mb_input.old_cu_seqlens.cpu().tolist() - if isinstance(getattr(mb_input, "old_cu_seqlens", None), torch.Tensor) + if isinstance( + getattr(mb_input, "old_cu_seqlens", None), torch.Tensor + ) else getattr(mb_input, "old_cu_seqlens", None), tuple(_orig_mb["input_ids"].shape) if "input_ids" in _orig_mb @@ -1034,7 +1047,9 @@ def _process_output(input_, output_): cp_inputs["cu_seqlens"] = cp_cu_seqlens return output, functools.partial(_process_output, cp_inputs) else: - _pre_shape = tuple(output.shape) if hasattr(output, "shape") else None + _pre_shape = ( + tuple(output.shape) if hasattr(output, "shape") else None + ) output = unpad_logits( output, padding_length=mb_input.padding_length, @@ -1050,9 +1065,7 @@ def _process_output(input_, output_): "cu_seqlens_last=%s old_cu_seqlens_last=%s", dist.get_rank() if dist.is_initialized() else None, _pre_shape, - tuple(output.shape) - if hasattr(output, "shape") - else None, + tuple(output.shape) if hasattr(output, "shape") else None, getattr(mb_input, "padding_length", None), int(cu_seqlens[-1].item()) if isinstance(cu_seqlens, torch.Tensor) @@ -1066,8 +1079,7 @@ def _process_output(input_, output_): ) except Exception: _R3_FWD_DIAG_LOGGER.exception( - "[SPLIT_MISMATCH_DIAG][forward_step UNPAD] " - "log-emit failed" + "[SPLIT_MISMATCH_DIAG][forward_step UNPAD] log-emit failed" ) return output, functools.partial(_process_output, mb_input.orig_mb) @@ -1199,9 +1211,7 @@ def forward_batch( _is_list = isinstance(input_, list) if _is_list: _attn_shapes = [ - tuple(d["attention_mask"].shape) - if "attention_mask" in d - else None + tuple(d["attention_mask"].shape) if "attention_mask" in d else None for d in input_ ] _attn_widths = [ @@ -1228,8 +1238,7 @@ def forward_batch( else None ) _input_summary = ( - f"dict keys={_ks} attention_mask.shape={_am} " - f"input_ids.shape={_ii}" + f"dict keys={_ks} attention_mask.shape={_am} input_ids.shape={_ii}" ) _R3_FWD_DIAG_LOGGER.info( "[SPLIT_MISMATCH_DIAG][forward_batch] ENTER rank=%s " @@ -1283,9 +1292,7 @@ def forward_batch( list(output_seqlens[:16]), list(output_seqlens[-16:]), int(sum(output_seqlens)), - "from_attention_mask" - if meta is not None - else "from_cu_seqlens_diff", + "from_attention_mask" if meta is not None else "from_cu_seqlens_diff", ) except Exception: _R3_FWD_DIAG_LOGGER.exception( @@ -1315,16 +1322,12 @@ def process_output(output: torch.Tensor, inputs: dict[str, Any]) -> None: if hasattr(result, "shape") and result.ndim >= 1 else -1, sorted(list(inputs.keys())) if isinstance(inputs, dict) else "N/A", - tuple(inputs["input_ids"].shape) - if "input_ids" in inputs - else None, + tuple(inputs["input_ids"].shape) if "input_ids" in inputs else None, inputs["cu_seqlens"].cpu().tolist() if "cu_seqlens" in inputs and isinstance(inputs["cu_seqlens"], torch.Tensor) else None, - int(inputs["input_ids"].numel()) - if "input_ids" in inputs - else None, + int(inputs["input_ids"].numel()) if "input_ids" in inputs else None, ) except Exception: _R3_FWD_DIAG_LOGGER.exception( @@ -1626,7 +1629,10 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> 
None: mcore_opt_config.use_precision_aware_optimizer and ( mcore_opt_config.main_params_dtype != torch.float32 - or (mcore_opt_config.fp8_recipe is None or mcore_opt_config.fp8_recipe == "delayed") + or ( + mcore_opt_config.fp8_recipe is None + or mcore_opt_config.fp8_recipe == "delayed" + ) or mcore_opt_config.optimizer_cpu_offload ) ) @@ -2399,8 +2405,7 @@ def _compute_forward_result( ) except Exception: _R3_FWD_DIAG_LOGGER.exception( - "[SPLIT_MISMATCH_DIAG][_compute_forward_result] " - "log-emit failed" + "[SPLIT_MISMATCH_DIAG][_compute_forward_result] log-emit failed" ) return logprobs else: diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py index d9c806a907..23c1b0e18f 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + """ R3 Integration Patch for MegatronEngine. @@ -71,55 +73,8 @@ def patch_megatron_engine_for_r3( engine._r3_original_forward_backward_batch = engine.forward_backward_batch engine._r3_pending_routed_experts = None - # ---------- R3 diagnostics: one-shot config snapshot so the NEXT - # run's log unambiguously records the PP layout Megatron-Core - # actually saw (num_layers, vp_size, pp_size, local offset/end, - # router_instance count per PP rank). This answers D2/D3/D5 in a - # single early-startup line without polluting hot paths. - try: - from areal.engine.router_replay_patch import RouterReplay as _RR - from areal.engine.router_replay_utils import ( - _r3_pp_tp_info as _ppi, - _r3_verbose as _v, - get_current_rank_layer_info as _info, - is_moe_layer as _ism, - ) - if _v(): - _tf = engine.tf_config - _li = _info(_tf) - _moe_list = [i for i in range(_li["start"], _li["end"]) if _ism(_tf, i)] - _dense_list = [ - i for i in range(_li["start"], _li["end"]) if not _ism(_tf, i) - ] - logger.info( - "[R3-STAGE0/patch_megatron_engine_for_r3] ENGINE_SNAPSHOT %s " - "tf_config.num_layers=%d pp_size=%d vp_size=%s " - "moe_layer_freq=%s first_k_dense_replace=%s " - "local={start:%d end:%d count:%d} " - "moe_layers_in_range=%s non_moe_layers_in_range=%s " - "total_router_instances=%d " - "inst_creator_ranks=%s", - _ppi(_tf), - _tf.num_layers, - getattr(_tf, "pipeline_model_parallel_size", 1), - getattr(_tf, "virtual_pipeline_model_parallel_size", None), - getattr(_tf, "moe_layer_freq", None), - getattr(_tf, "first_k_dense_replace", None), - _li["start"], _li["end"], _li["count"], - _moe_list, _dense_list, - len(_RR.router_instances), - [getattr(r, "creator_rank", -1) for r in _RR.router_instances], - ) - except Exception: - logger.exception( - "[R3-STAGE0/patch_megatron_engine_for_r3] snapshot log failed" - ) - # -------------------------------------------------------------- - # Bind the wrapped method - engine.forward_backward_batch = types.MethodType( - _r3_forward_backward_batch, engine - ) + engine.forward_backward_batch = types.MethodType(_r3_forward_backward_batch, engine) logger.debug("[R3] MegatronEngine patched successfully.") @@ -169,7 +124,9 @@ def _align_routed_experts_to_mask( # Output: (n_seqs, max_seqlen, L, K) with real tokens left-aligned aligned = torch.zeros( - n_seqs, max_seqlen, *extra_dims, + n_seqs, + max_seqlen, + *extra_dims, dtype=routed_experts.dtype, device=routed_experts.device, ) @@ -185,65 +142,12 @@ def _align_routed_experts_to_mask( logger.debug( "[R3] _align_routed_experts_to_mask: re_shape=%s -> aligned_shape=%s, " "n_seqs=%d (re_bs=%d), seq_lens=%s.", - routed_experts.shape, aligned.shape, 
n_seqs, re_bs, seq_lens[:8], + routed_experts.shape, + aligned.shape, + n_seqs, + re_bs, + seq_lens[:8], ) - # Detailed alignment log: smoking-gun check for the "last generated token - # has no routing" edge case (SGLang convention: num_sgl_tokens = - # prompt_len + gen_len - 1). If cu_seqlens claims k real tokens but the - # source only has k-1 non-zero rows, the k-th row here is a ZERO ROW that - # will route to expert 0 unconditionally. - try: - from areal.engine.router_replay_utils import ( - _r3_pp_tp_info, - _r3_should_log, - _r3_tensor_sig, - _r3_verbose, - ) - - if _r3_verbose() and _r3_should_log("_align_routed_experts_to_mask"): - with torch.no_grad(): - per_row_zero_src = ( - (routed_experts == 0).reshape(re_bs, re_seqlen, -1).all(dim=-1) - ) - src_zero_rows_per_sample = per_row_zero_src.sum(dim=-1).tolist() - per_row_zero_dst = ( - (aligned == 0).reshape(n_seqs, max_seqlen, -1).all(dim=-1) - ) - dst_zero_rows_per_sample = per_row_zero_dst.sum(dim=-1).tolist() - # For each sample, locate first zero-row idx within the real-token window. - first_zero_in_real = [] - for i in range(min(n_seqs, re_bs)): - L = int(seq_lens[i]) - if L <= 0: - first_zero_in_real.append(-1) - continue - row = per_row_zero_src[i, :L] - idx = torch.nonzero(row, as_tuple=False) - first_zero_in_real.append( - int(idx[0].item()) if idx.numel() > 0 else -1 - ) - logger.info( - "[R3-STAGE3/_align_routed_experts_to_mask] mb=%s %s " - "re_shape=%s aligned_shape=%s n_seqs=%d re_bs=%d " - "seq_lens[:8]=%s src_zero_rows_per_sample[:8]=%s " - "first_zero_in_real_window[:8]=%s " - "dst_zero_rows_per_sample[:8]=%s | %s | %s", - _r3_mb_idx, - _r3_pp_tp_info(), - tuple(routed_experts.shape), - tuple(aligned.shape), - n_seqs, - re_bs, - seq_lens[:8], - src_zero_rows_per_sample[:8], - first_zero_in_real[:8], - dst_zero_rows_per_sample[:8], - _r3_tensor_sig("src_re", routed_experts, max_sample=4), - _r3_tensor_sig("aligned", aligned, max_sample=4), - ) - except Exception: - # diagnostic helper must never break the main flow - pass return aligned @@ -338,54 +242,6 @@ def _split_routed_experts_for_mbs( [r.shape[0] for r in result], "None" if forward_indices is None else f"len={len(forward_indices)}", ) - try: - from areal.engine.router_replay_utils import ( - _r3_hash64, - _r3_per_sample_hashes, - _r3_per_sample_nnz, - _r3_per_sample_seq_real_len, - _r3_pp_tp_info, - _r3_should_log, - _r3_tensor_sig, - _r3_verbose, - ) - - if _r3_verbose() and _r3_should_log("_split_routed_experts_for_mbs"): - pre_hash = _r3_per_sample_hashes(routed_experts, max_rows=32) - post_hash = _r3_per_sample_hashes(reordered, max_rows=32) - per_mb_hashes = [ - [hex(h) for h in _r3_per_sample_hashes(r, max_rows=16)] - for r in result - ] - per_mb_nnz = [_r3_per_sample_nnz(r, max_rows=16) for r in result] - per_mb_real = [_r3_per_sample_seq_real_len(r, max_rows=16) for r in result] - logger.info( - "[R3-STAGE3/_split_routed_experts_for_mbs] %s " - "input_shape=%s input_hash=%s n_mbs=%d " - "forward_indices=%s per_mb_shapes=%s per_mb_hashes=%s " - "pre_reorder_per_sample_hash[:16]=%s " - "post_reorder_per_sample_hash[:16]=%s " - "per_mb_per_sample_hash=%s per_mb_per_sample_nnz=%s " - "per_mb_per_sample_real_len=%s | %s", - _r3_pp_tp_info(), - tuple(routed_experts.shape), - hex(_r3_hash64(routed_experts)), - n_mbs, - "None" if forward_indices is None - else f"len={len(forward_indices)} first32={forward_indices[:32].tolist() if hasattr(forward_indices,'tolist') else list(forward_indices)[:32]}", - [tuple(r.shape) for r in result], - [hex(_r3_hash64(r)) for 
r in result], - [hex(h) for h in pre_hash[:16]], - [hex(h) for h in post_hash[:16]], - per_mb_hashes, - per_mb_nnz, - per_mb_real, - _r3_tensor_sig("routed_experts", routed_experts, max_sample=4), - ) - except Exception: - logger.exception( - "[R3-STAGE3/_split_routed_experts_for_mbs] trace log failed" - ) return result @@ -412,7 +268,11 @@ def _get_cu_seqlens_for_mb(mb_item) -> tuple[torch.Tensor, int] | None: source = ("padded_mb", cu, int(max_sl)) # Try orig_mb as fallback (pre-padding cu_seqlens) - if source is None and hasattr(mb_item, "orig_mb") and isinstance(mb_item.orig_mb, dict): + if ( + source is None + and hasattr(mb_item, "orig_mb") + and isinstance(mb_item.orig_mb, dict) + ): cu = mb_item.orig_mb.get("cu_seqlens") max_sl = mb_item.orig_mb.get("max_seqlen") if cu is not None and max_sl is not None: @@ -422,25 +282,6 @@ def _get_cu_seqlens_for_mb(mb_item) -> tuple[torch.Tensor, int] | None: return None src_name, cu_out, max_sl_out = source - try: - from areal.engine.router_replay_utils import ( - _r3_pp_tp_info, - _r3_should_log, - _r3_tensor_sig, - _r3_verbose, - ) - - if _r3_verbose() and _r3_should_log("_get_cu_seqlens_for_mb"): - logger.info( - "[R3-STAGE3/_get_cu_seqlens_for_mb] %s source=%s " - "max_seqlen=%d | %s", - _r3_pp_tp_info(), - src_name, - max_sl_out, - _r3_tensor_sig("cu_seqlens", cu_out), - ) - except Exception: - pass return cu_out, max_sl_out @@ -452,9 +293,7 @@ def _get_cu_seqlens_for_mb(mb_item) -> tuple[torch.Tensor, int] | None: def _r3_forward_backward_batch( self, mb_list, - process_output_fn: Callable[ - [torch.Tensor, dict[str, Any]], torch.Tensor | None - ], + process_output_fn: Callable[[torch.Tensor, dict[str, Any]], torch.Tensor | None], forward_only: bool = False, ) -> None: """Drop-in replacement for ``MegatronEngine.forward_backward_batch`` @@ -466,10 +305,6 @@ def _r3_forward_backward_batch( from areal.engine.router_replay_patch import RouterReplay, RouterReplayAction from areal.engine.router_replay_utils import ( RouterReplayHelper, - _r3_pp_tp_info, - _r3_should_log, - _r3_tensor_sig, - _r3_verbose, clear_router_replay, setup_per_microbatch_replay_forward, ) @@ -479,56 +314,15 @@ def _r3_forward_backward_batch( # ------------------------------------------------------------------ routed_experts_batch = None _from_side_channel = False - _consumed_trace_id = getattr(self, "_r3_active_trace_id", None) # Strategy A: Side-channel (preferred path) - if hasattr(self, '_r3_pending_routed_experts') and self._r3_pending_routed_experts is not None: + if ( + hasattr(self, "_r3_pending_routed_experts") + and self._r3_pending_routed_experts is not None + ): routed_experts_batch = self._r3_pending_routed_experts self._r3_pending_routed_experts = None # Consume it _from_side_channel = True - try: - from areal.engine.router_replay_utils import ( - _r3_hash64, - _r3_per_sample_hashes, - _r3_per_sample_nnz, - _r3_per_sample_seq_real_len, - _r3_pp_tp_info, - _r3_verbose, - ) - if _r3_verbose(): - # ---------- R3 diagnostics: D6 -- cross-PP batch-hash - # consistency. Print FULL global hash + per-sample hashes - # so PP rank 0 and PP rank 1 at the same trace_id can be - # diff'd offline. If hashes differ, the side-channel - # broadcast/scatter is wrong (root cause); if identical, - # PP-input parity is proved and we can rule out D6. 
- _full_hash = hex(_r3_hash64(routed_experts_batch)) - _all_per_sample = _r3_per_sample_hashes( - routed_experts_batch, max_rows=4096, - ) - _all_per_sample_hex = [hex(h) for h in _all_per_sample] - logger.info( - "[R3-STAGE3/_r3_forward_backward_batch] " - "SIDE_CHANNEL_CONSUME trace_id=%s %s forward_only=%s " - "shape=%s hash=%s per_sample_hash[:16]=%s " - "per_sample_nnz[:16]=%s per_sample_real_len[:16]=%s " - "n_samples_total=%d full_per_sample_hash=%s", - _consumed_trace_id, - _r3_pp_tp_info(), - forward_only, - routed_experts_batch.shape, - _full_hash, - _all_per_sample_hex[:16], - _r3_per_sample_nnz(routed_experts_batch, max_rows=16), - _r3_per_sample_seq_real_len(routed_experts_batch, max_rows=16), - len(_all_per_sample), - _all_per_sample_hex, - ) - except Exception: - logger.exception( - "[R3-STAGE3/_r3_forward_backward_batch] " - "SIDE_CHANNEL_CONSUME trace log failed" - ) logger.debug( "[R3] Retrieved routed_experts from engine side-channel: shape=%s.", routed_experts_batch.shape, @@ -578,23 +372,8 @@ def _r3_forward_backward_batch( routed_experts_batch.shape, forward_only, ) - if _r3_verbose() and _r3_should_log("_r3_forward_backward_batch/ENTER"): - logger.info( - "[R3-STAGE3/_r3_forward_backward_batch] ENTER %s " - "n_mbs=%d forward_only=%s from_side_channel=%s " - "has_padded_mbs=%s | %s", - _r3_pp_tp_info(), - len(mb_list), - forward_only, - _from_side_channel, - mb_list.padded_mbs is not None, - _r3_tensor_sig("routed_experts_batch", routed_experts_batch), - ) - # Split routed_experts per micro-batch - per_mb_routed_experts = _split_routed_experts_for_mbs( - routed_experts_batch, mb_list - ) + per_mb_routed_experts = _split_routed_experts_for_mbs(routed_experts_batch, mb_list) # ------------------------------------------------------------------ # 2. Store R3 data on the engine for the wrapped iterator. @@ -605,6 +384,7 @@ def _r3_forward_backward_batch( # Compute seq_align_to (same as what _prepare_mb_list uses) from megatron.core import parallel_state as mpu + tp_size = mpu.get_tensor_model_parallel_world_size() cp_size = getattr(mpu, "get_context_parallel_world_size", lambda: 1)() seq_align_to = tp_size * cp_size * 2 if cp_size > 1 else tp_size @@ -640,22 +420,6 @@ def __next__(self): else None ) - if _r3_verbose() and _r3_should_log("_R3MicroBatchIterator.__next__"): - logger.info( - "[R3-STAGE3/_R3MicroBatchIterator] ENTER mb_idx=%d %s " - "re_shape=%s has_orig_mb=%s has_padded_mb=%s " - "has_old_cu_seqlens=%s", - idx, - _r3_pp_tp_info(), - None if re is None else tuple(re.shape), - hasattr(mb_item, "orig_mb") - and isinstance(mb_item.orig_mb, dict), - hasattr(mb_item, "padded_mb") - and isinstance(mb_item.padded_mb, dict), - hasattr(mb_item, "old_cu_seqlens") - and mb_item.old_cu_seqlens is not None, - ) - # When backward recompute finishes and next forward starts, # switch back to REPLAY_FORWARD. 
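            # Together with the post-forward hook registered below, the action
            # forms a two-state cycle per micro-batch; a toy trace of one 1F1B
            # slot (sketch only, states named after the RouterReplayAction
            # members used here):
            #
            #     REPLAY_FORWARD    # iterator: set before micro-batch i's forward
            #     ...               # forward consumes target_topk_idx per router
            #     REPLAY_BACKWARD   # post-forward hook: flip all local routers
            #     ...               # backward recompute pops the per-router FIFO
            #     REPLAY_FORWARD    # iterator: flipped back before micro-batch i+1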
if RouterReplayHelper.is_replay_backward_action(model_config): @@ -663,19 +427,7 @@ def __next__(self): model_config ) for router in router_list: - router.set_router_replay_action( - RouterReplayAction.REPLAY_FORWARD - ) - if _r3_verbose() and _r3_should_log( - "_R3MicroBatchIterator.toggle_to_forward" - ): - logger.info( - "[R3-STAGE3/_R3MicroBatchIterator] TOGGLE backward->forward " - "mb_idx=%d %s n_routers=%d", - idx, - _r3_pp_tp_info(), - len(router_list), - ) + router.set_router_replay_action(RouterReplayAction.REPLAY_FORWARD) if re is not None: # Extract cu_seqlens from padded_mb (TP-aligned, what the model sees) @@ -692,60 +444,26 @@ def __next__(self): # to know each sample's actual token count for # extracting from routed_experts. orig_cu = None - orig_cu_src = None - if hasattr(mb_item, "old_cu_seqlens") and mb_item.old_cu_seqlens is not None: + if ( + hasattr(mb_item, "old_cu_seqlens") + and mb_item.old_cu_seqlens is not None + ): orig_cu = mb_item.old_cu_seqlens - orig_cu_src = "old_cu_seqlens" - elif hasattr(mb_item, "orig_mb") and isinstance(mb_item.orig_mb, dict): + elif hasattr(mb_item, "orig_mb") and isinstance( + mb_item.orig_mb, dict + ): orig_cu = mb_item.orig_mb.get("cu_seqlens") - orig_cu_src = "orig_mb.cu_seqlens" if orig_cu is None: - # Fallback: use padded cu_seqlens directly orig_cu = cu_seqlens - orig_cu_src = "padded_cu_seqlens (fallback)" - - if _r3_verbose() and _r3_should_log( - "_R3MicroBatchIterator.pre_align" - ): - from areal.engine.router_replay_utils import ( - _r3_hash64, - _r3_per_sample_hashes, - _r3_per_sample_nnz, - _r3_per_sample_seq_real_len, - _r3_current_trace_id, - ) - logger.info( - "[R3-STAGE3/_R3MicroBatchIterator] PRE-ALIGN " - "mb_idx=%d trace_id=%d %s orig_cu_src=%s " - "max_seqlen=%d re_shape=%s re_hash=%s " - "per_sample_hash[:16]=%s per_sample_nnz[:16]=%s " - "per_sample_real_len[:16]=%s " - "orig_cu_diff[:16]=%s padded_cu_diff[:16]=%s " - "| %s | %s | %s", - idx, - _r3_current_trace_id(), - _r3_pp_tp_info(), - orig_cu_src, - max_seqlen, - tuple(re.shape), - hex(_r3_hash64(re)), - [hex(h) for h in _r3_per_sample_hashes(re, max_rows=16)], - _r3_per_sample_nnz(re, max_rows=16), - _r3_per_sample_seq_real_len(re, max_rows=16), - (orig_cu[1:] - orig_cu[:-1]).long().cpu().tolist()[:16] - if hasattr(orig_cu, "cpu") else "N/A", - (cu_seqlens[1:] - cu_seqlens[:-1]).long().cpu().tolist()[:16] - if hasattr(cu_seqlens, "cpu") else "N/A", - _r3_tensor_sig("re", re, max_sample=4), - _r3_tensor_sig("orig_cu", orig_cu), - _r3_tensor_sig("padded_cu", cu_seqlens), - ) # Align routed_experts from left-padded to left-aligned # using the ORIGINAL cu_seqlens (actual token counts). aligned_re = _align_routed_experts_to_mask( - re, orig_cu, max_seqlen, _r3_mb_idx=idx, + re, + orig_cu, + max_seqlen, + _r3_mb_idx=idx, ) # Pass the PADDED cu_seqlens (with TP alignment) @@ -756,56 +474,6 @@ def __next__(self): model_config, seq_align_to=_seq_align_to, ) - # ---------- R3 diagnostics: per-mb queue-depth - # snapshot RIGHT AFTER the dispatch finishes. - # Under PP=1 every router has fwd_q==1 here; under - # PP=2 1F1B the depth oscillates 1..PP_size. 
- try: - from areal.engine.router_replay_patch import ( - RouterReplay as _RR, - ) - from areal.engine.router_replay_utils import ( - _r3_should_log as _sl2, - _r3_verbose as _v2, - ) - if _v2() and _sl2( - "_R3MicroBatchIterator/post_dispatch_queue_audit" - ): - router_list = ( - RouterReplayHelper.get_micro_batch_router_list( - model_config - ) - ) - fwd_qs = [ - len(getattr(r, "replay_backward_list", []) or []) - for r in router_list - ] - push_qs = [ - len( - getattr(r, "replay_push_meta_list", []) or [] - ) - for r in router_list - ] - logger.info( - "[R3-STAGE3/_R3MicroBatchIterator] " - "POST_DISPATCH_QUEUE_AUDIT mb_idx=%d %s " - "n_routers=%d fwd_q_lens=%s push_meta_q_lens=%s " - "max_fwd_q=%d min_fwd_q=%d " - "lens_locked=%s", - idx, - _r3_pp_tp_info(), - len(router_list), - fwd_qs, - push_qs, - max(fwd_qs) if fwd_qs else -1, - min(fwd_qs) if fwd_qs else -1, - fwd_qs == push_qs, - ) - except Exception: - logger.exception( - "[R3-STAGE3/_R3MicroBatchIterator] " - "POST_DISPATCH_QUEUE_AUDIT diag log failed" - ) except Exception: logger.warning( "[R3] Failed to setup replay for micro-batch %d.", @@ -818,8 +486,14 @@ def __next__(self): "micro-batch %d; skipping replay setup. " "Keys in orig_mb: %s, keys in padded_mb: %s.", idx, - list(mb_item.orig_mb.keys()) if hasattr(mb_item, "orig_mb") and isinstance(mb_item.orig_mb, dict) else "N/A", - list(mb_item.padded_mb.keys()) if hasattr(mb_item, "padded_mb") and isinstance(mb_item.padded_mb, dict) else "N/A", + list(mb_item.orig_mb.keys()) + if hasattr(mb_item, "orig_mb") + and isinstance(mb_item.orig_mb, dict) + else "N/A", + list(mb_item.padded_mb.keys()) + if hasattr(mb_item, "padded_mb") + and isinstance(mb_item.padded_mb, dict) + else "N/A", ) return mb_item @@ -847,74 +521,18 @@ def __iter__(self_inner): # 4. Register a forward hook for REPLAY_FORWARD -> REPLAY_BACKWARD toggle. # ------------------------------------------------------------------ hook_handles = [] - # ---------- R3 diagnostics: D8 -- track which model chunks fire - # the post-forward hook. Under PP=2 + VP, multiple model chunks - # share the fbfunc; if any chunk's hook misses, the action toggle - # is skipped and backward pops see REPLAY_FORWARD, silently - # returning live routing. A mismatch between len(self.model) and - # hook_fire_counts[chunk_id] is a smoking gun. 
- _r3_hook_fire_counts: dict[int, int] = {} - _r3_toggle_count = {"n": 0} def _r3_post_forward_hook(module, input, output): """Switch from REPLAY_FORWARD to REPLAY_BACKWARD after model forward.""" - _chunk_id = id(module) - _r3_hook_fire_counts[_chunk_id] = _r3_hook_fire_counts.get(_chunk_id, 0) + 1 if RouterReplayHelper.is_replay_forward_action(model_config): - router_list = RouterReplayHelper.get_micro_batch_router_list( - model_config - ) + router_list = RouterReplayHelper.get_micro_batch_router_list(model_config) for router in router_list: - router.set_router_replay_action( - RouterReplayAction.REPLAY_BACKWARD - ) - _r3_toggle_count["n"] += 1 - if _r3_verbose() and _r3_should_log("_r3_post_forward_hook"): - logger.info( - "[R3-STAGE3/_r3_post_forward_hook] TOGGLE forward->backward " - "%s n_routers=%d mb_counter=%d chunk_id=%d " - "fire_count_this_chunk=%d total_toggles=%d " - "n_chunks_seen=%d", - _r3_pp_tp_info(), - len(router_list), - getattr(self, "_r3_mb_counter", -1), - _chunk_id, - _r3_hook_fire_counts[_chunk_id], - _r3_toggle_count["n"], - len(_r3_hook_fire_counts), - ) - else: - # Hook fired but action was not REPLAY_FORWARD -- this is - # expected after the first toggle under 1F1B (subsequent mbs - # see REPLAY_BACKWARD until the iterator flips them back). - # We still log rarely to confirm behavior. - if _r3_verbose() and _r3_should_log( - "_r3_post_forward_hook/no_toggle" - ): - logger.info( - "[R3-STAGE3/_r3_post_forward_hook] NO_TOGGLE " - "(already backward or cleared) %s mb_counter=%d " - "chunk_id=%d fire_count_this_chunk=%d " - "n_chunks_seen=%d", - _r3_pp_tp_info(), - getattr(self, "_r3_mb_counter", -1), - _chunk_id, - _r3_hook_fire_counts[_chunk_id], - len(_r3_hook_fire_counts), - ) + router.set_router_replay_action(RouterReplayAction.REPLAY_BACKWARD) for model_chunk in self.model: handle = model_chunk.register_forward_hook(_r3_post_forward_hook) hook_handles.append(handle) - # ---------- R3 diagnostics: reset FB-level aggregate counters so - # the end-of-FB summary reflects only this call. Safe to reset - # unconditionally: consumers read the dict inside the FB span. - try: - RouterReplay._r3_fb_stats = {} - except Exception: - pass - try: self._r3_original_forward_backward_batch( mb_list, process_output_fn, forward_only=forward_only @@ -927,50 +545,6 @@ def _r3_post_forward_hook(module, input, output): # class swap done above). The original class was never modified. mb_list.__class__ = _r3_original_mb_list_class - # ---------- R3 diagnostics: END_OF_FB summary. One line per - # forward_backward_batch call aggregating: - # * divergence_v1 (MATCH vs popped-vs-latest-target; under - # PP=2 1F1B, DIVERGE is EXPECTED — logging artifact). - # * divergence_v2 (MATCH vs popped-vs-own-push; under any - # PP layout MATCH is REQUIRED — REAL_MISMATCH is a bug). - # * hook fire counts per model-chunk: if unequal, one chunk - # missed its toggle (D8 smoking gun). - # * final queue residue across all routers (D9 smoking gun). 
- try: - if _r3_verbose() and _r3_should_log("_r3_forward_backward_batch/EXIT_SUMMARY"): - from areal.engine.router_replay_patch import RouterReplay as _RR - _fwd_q = [ - len(getattr(r, "replay_backward_list", []) or []) - for r in _RR.router_instances - ] - _push_q = [ - len(getattr(r, "replay_push_meta_list", []) or []) - for r in _RR.router_instances - ] - logger.info( - "[R3-STAGE3/_r3_forward_backward_batch] EXIT_SUMMARY %s " - "n_mbs=%d forward_only=%s trace_id=%s " - "fb_stats=%s n_model_chunks=%d hook_fire_counts=%s " - "total_toggles=%d residual_fwd_q=%s residual_push_q=%s " - "residual_max=%d", - _r3_pp_tp_info(), - len(mb_list), - forward_only, - _consumed_trace_id, - dict(_RR._r3_fb_stats), - len(self.model), - dict(_r3_hook_fire_counts), - _r3_toggle_count["n"], - _fwd_q, - _push_q, - max(_fwd_q) if _fwd_q else -1, - ) - except Exception: - logger.exception( - "[R3-STAGE3/_r3_forward_backward_batch] " - "EXIT_SUMMARY diag log failed" - ) - clear_router_replay() self._r3_per_mb_experts = None self._r3_mb_counter = 0 diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py index 001fbfd629..07e5d0d100 100644 --- a/areal/engine/router_replay_patch.py +++ b/areal/engine/router_replay_patch.py @@ -87,15 +87,6 @@ class RouterReplay: # Set by the engine patch before forward_backward_func. pp_size: int = 1 - # ---------- R3 diagnostics (PP=2 root-cause hunt) ---------- - # Per forward_backward_batch aggregate counters. Keys set by the - # REPLAY_BACKWARD/consume path and the forward-hook; values are reset - # at FB entry by _r3_forward_backward_batch and dumped at FB exit. - # This gives one END-OF-FB summary line to diff PP=1 vs PP=2 runs. - # Always class-level dict (no cross-rank comm — each rank maintains - # its own, which is the correctness unit). - _r3_fb_stats: dict[str, int] = {} - # ------------------------------------------------------------------ # Class-level (static) helpers # ------------------------------------------------------------------ @@ -134,22 +125,12 @@ def set_global_router_replay_action(action: RouterReplayAction) -> None: """Set the replay action for all router instances.""" for r in RouterReplay.router_instances: r.set_router_replay_action(action) - try: - from areal.engine.router_replay_utils import ( - _r3_pp_tp_info, - _r3_should_log, - ) - - if _r3_should_log(f"set_global_router_replay_action/{action.value}"): - logger.info( - "[R3-STAGE4/set_global_router_replay_action] %s action=%s " - "applied_to=%d router_instances", - _r3_pp_tp_info(), - action.value, - len(RouterReplay.router_instances), - ) - except Exception: - pass + logger.debug( + "[R3] set_global_router_replay_action: action=%s " + "applied_to=%d router_instances", + action.value, + len(RouterReplay.router_instances), + ) @staticmethod def clear_global_router_replay_action() -> None: @@ -207,84 +188,7 @@ def set_target_indices( self.target_valid_mask = valid_mask self.replay_backward_list.append(topk_indices) self.replay_backward_mask_list.append(valid_mask) - # ---------- R3 diagnostics: capture push metadata at the SAME - # call site so REPLAY_BACKWARD/consume can later prove that the - # popped slab equals the slab that was originally pushed (the - # only correctness criterion). Hashing here is gated by - # _r3_should_log so steady-state cost is one int + one None. 
- # --------------------------------------------------------------- - try: - from areal.engine.router_replay_utils import ( - _r3_current_trace_id as _tid, - _r3_hash64 as _h64, - _r3_should_log as _sl, - _r3_verbose as _v, - ) - if _v() and _sl("RouterReplay.set_target_indices/push_meta"): - _slab_h = hex(_h64(topk_indices)) - _mask_h = ( - hex(_h64(valid_mask.to(torch.int32))) - if valid_mask is not None else "None" - ) - self.replay_push_meta_list.append({ - "push_id": getattr(self, "_r3_push_counter", 0), - "trace_id": _tid(), - "slab_hash": _slab_h, - "mask_hash": _mask_h, - "slab_shape": tuple(topk_indices.shape), - }) - self._r3_push_counter = getattr(self, "_r3_push_counter", 0) + 1 - else: - # Always append a placeholder so list lengths stay locked - # to ``replay_backward_list``; pop side will skip None. - self.replay_push_meta_list.append(None) - except Exception: - try: - self.replay_push_meta_list.append(None) - except Exception: - pass - # Cheap diagnostic: record every set in first few layers/mb. Gated - # via _r3_should_log so steady-state overhead is ~one integer - # increment. - try: - from areal.engine.router_replay_utils import ( - _r3_current_trace_id, - _r3_hash64, - _r3_pp_tp_info, - _r3_should_log, - _r3_tensor_sig, - _r3_verbose, - _r3_zero_row_stats, - ) - - if _r3_verbose() and _r3_should_log("RouterReplay.set_target_indices"): - # instance index in the class-level list tells us which - # MoE layer this replay slot refers to - try: - inst_idx = RouterReplay.router_instances.index(self) - except ValueError: - inst_idx = -1 - _slab_hash = hex(_r3_hash64(topk_indices)) - _mask_hash = ( - hex(_r3_hash64(valid_mask.to(torch.int32))) - if valid_mask is not None else "None" - ) - logger.info( - "[R3-STAGE3b/set_target_indices] trace_id=%d inst#%d %s %s " - "slab_shape=%s slab_hash=%s mask_hash=%s " - "| %s | backward_queue_len=%d (post-push)", - _r3_current_trace_id(), - inst_idx, - _r3_pp_tp_info(), - _r3_zero_row_stats(topk_indices), - tuple(topk_indices.shape), - _slab_hash, - _mask_hash, - _r3_tensor_sig("topk_indices", topk_indices), - len(self.replay_backward_list), - ) - except Exception: - pass + self.replay_push_meta_list.append(None) def get_recorded_indices(self) -> torch.Tensor | None: return self.recorded_topk_idx @@ -293,27 +197,6 @@ def record_indices(self, topk_indices: torch.Tensor) -> None: self.recorded_topk_idx = topk_indices def clear_indices(self) -> None: - # ---------- R3 diagnostics: dump tail-state queue sizes BEFORE - # clearing so residual queues (a smoking gun for lost backward - # pops under PP=2 1F1B) are always visible in logs. 
- try: - from areal.engine.router_replay_utils import ( - _r3_pp_tp_info, - _r3_should_log, - _r3_verbose, - ) - if _r3_verbose() and _r3_should_log("RouterReplay.clear_indices/tail_state"): - logger.info( - "[R3-STAGE3c/clear_indices] %s inst#%d fwd_q=%d " - "mask_q=%d push_meta_q=%d", - _r3_pp_tp_info(), - self.creation_order, - len(self.replay_backward_list), - len(self.replay_backward_mask_list), - len(getattr(self, "replay_push_meta_list", []) or []), - ) - except Exception: - pass self.recorded_topk_idx = None self.target_topk_idx = None self.target_valid_mask = None @@ -333,80 +216,6 @@ def clear_router_replay_action(self) -> None: # =================================================================== -def _R3_routing_log( - action_name: str, - *, - scores: torch.Tensor, - top_indices: torch.Tensor, - topk: int, - compute_topk_fn, - num_groups=None, - group_topk=None, -) -> None: - """Rate-limited diagnostic for the replay branches. - - Key quantities: - * ``shape_match`` -- does target_topk_idx align with this layer's - token count? If NOT, replay is being fed the wrong slab. - * ``zero_rows`` -- fraction of all-zero rows in the replay - indices; zero rows collapse routing to expert 0. - * ``live_vs_replay`` -- overlap between replay top-k and the live - top-k the router would have picked right now. 100% = no staleness - (rollout weights == train weights). 0% = total mismatch. - """ - from areal.engine.router_replay_utils import ( - _r3_call_count, - _r3_pp_tp_info, - _r3_should_log, - _r3_tensor_sig, - _r3_verbose, - _r3_zero_row_stats, - _R3_ROUTER_LAYER_LIMIT, - ) - - if not _r3_verbose(): - return - key = f"patched_routing/{action_name}" - call_n = _r3_call_count(key) - # We always want an early, concentrated burst of per-layer details at - # startup (helps catch first-step config problems) and then a sparse - # steady-state sample. - if not _r3_should_log(key): - return - with torch.no_grad(): - shape_match = top_indices.shape[0] == scores.shape[0] - if shape_match: - try: - _, live_top = compute_topk_fn( - scores, topk, num_groups=num_groups, group_topk=group_topk - ) - # per-token overlap ratio - set_live = live_top.sort(dim=-1).values - set_rep = top_indices.sort(dim=-1).values - # equality per (token, slot) - overlap = (set_live == set_rep).float().mean().item() - except Exception as e: - overlap = f"err:{e}" - else: - overlap = None - logger.info( - "[R3-STAGE4/patched_routing] %s call#%d %s " - "scores_shape=%s topk=%d target_shape=%s shape_match=%s " - "live_vs_replay_topk_overlap=%s %s | %s | %s", - action_name, - call_n, - _r3_pp_tp_info(), - tuple(scores.shape), - topk, - tuple(top_indices.shape), - shape_match, - overlap, - _r3_zero_row_stats(top_indices), - _r3_tensor_sig("scores", scores, max_sample=4), - _r3_tensor_sig("top_indices", top_indices, max_sample=8), - ) - - def _patched_topk_routing_with_score_function( logits: torch.Tensor, topk: int, @@ -463,38 +272,7 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): # Use the provided indices for replay top_indices = router_replay.target_topk_idx top_indices = top_indices.to(scores.device) - # splice padded rows with the LIVE router top-k so - # that TP-alignment / batch padding slack (which was recorded as - # all-zeros) does not force those rows to expert 0. 
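        # The splice below has plain torch.where semantics: a (tokens,) bool
        # mask broadcast over (tokens, topk). A tiny runnable example with
        # illustrative values:
        #
        #     import torch
        #     valid = torch.tensor([True, True, False])        # last row is padding
        #     replay = torch.tensor([[1, 2], [3, 4], [0, 0]])  # recorded top-k ids
        #     live = torch.tensor([[5, 6], [7, 8], [9, 10]])   # router's own top-k
        #     torch.where(valid.unsqueeze(-1), replay, live)
        #     # tensor([[ 1,  2],
        #     #         [ 3,  4],
        #     #         [ 9, 10]])  <- padded row keeps live routing, not expert 0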
valid_mask = getattr(router_replay, "target_valid_mask", None) - try: - from areal.engine.router_replay_utils import ( - _r3_current_trace_id as _tid, - _r3_hash64 as _h64, - _r3_should_log as _sl, - _r3_verbose as _v, - ) - if _v() and _sl("REPLAY_FORWARD/consume"): - try: - _inst_idx = RouterReplay.router_instances.index(router_replay) - except ValueError: - _inst_idx = -1 - logger.info( - "[R3-STAGE4/REPLAY_FORWARD/consume] trace_id=%d inst#%d " - "scores_shape=%s target_shape=%s shape_match=%s " - "target_hash=%s mask_hash=%s backward_queue_len=%d", - _tid(), - _inst_idx, - tuple(scores.shape), - tuple(top_indices.shape), - top_indices.shape[0] == scores.shape[0], - hex(_h64(top_indices)), - "None" if valid_mask is None - else hex(_h64(valid_mask.to(torch.int32))), - len(router_replay.replay_backward_list), - ) - except Exception: - pass if valid_mask is not None and valid_mask.shape[0] == top_indices.shape[0]: _, live_top = _compute_topk( scores, topk, num_groups=num_groups, group_topk=group_topk @@ -504,15 +282,6 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): top_indices, live_top, ) - _R3_routing_log( - "REPLAY_FORWARD", - scores=scores, - top_indices=top_indices, - topk=topk, - compute_topk_fn=_compute_topk, - num_groups=num_groups, - group_topk=group_topk, - ) probs = scores.gather(1, top_indices) return probs, top_indices @@ -526,10 +295,6 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): ) return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) # Use the last recorded indices for backward replay - _bw_queue_len_before = len(router_replay.replay_backward_list) - _bw_mask_queue_len_before = len( - getattr(router_replay, "replay_backward_mask_list", []) or [] - ) top_indices = router_replay.replay_backward_list.pop(0) top_indices = top_indices.to(scores.device) # pop the matching per-row validity mask (if any) @@ -547,149 +312,11 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): # the slab that was REGISTERED AT PUSH TIME (the real # correctness criterion under 1F1B PP scheduling). 
_push_meta_list = getattr(router_replay, "replay_push_meta_list", None) - _bw_push_meta = None if _push_meta_list: try: - _bw_push_meta = _push_meta_list.pop(0) + _push_meta_list.pop(0) except IndexError: - _bw_push_meta = None - # ---- R3 deep-trace: log backward pop order + hashes ---- - try: - from areal.engine.router_replay_utils import ( - _r3_current_trace_id as _tid, - _r3_hash64 as _h64, - _r3_should_log as _sl, - _r3_verbose as _v, - ) - - if _v() and _sl("REPLAY_BACKWARD/consume"): - try: - _inst_idx = RouterReplay.router_instances.index(router_replay) - except ValueError: - _inst_idx = -1 - _popped_slab_hash = hex(_h64(top_indices)) - _popped_mask_hash = ( - "None" - if bw_valid_mask is None - else hex(_h64(bw_valid_mask.to(torch.int32))) - ) - _target_hash = ( - "None" - if router_replay.target_topk_idx is None - else hex(_h64(router_replay.target_topk_idx)) - ) - _divergence = ( - "None" - if router_replay.target_topk_idx is None - else ( - "MATCH" - if ( - router_replay.target_topk_idx.shape - == top_indices.shape - and hex( - _h64( - router_replay.target_topk_idx.to( - top_indices.device - ) - ) - ) - == _popped_slab_hash - ) - else "DIVERGE_vs_FWD_TARGET" - ) - ) - logger.info( - "[R3-STAGE4/REPLAY_BACKWARD/consume] trace_id=%d inst#%d " - "scores_shape=%s popped_shape=%s shape_match_scores=%s " - "popped_slab_hash=%s popped_mask_hash=%s " - "current_target_hash=%s divergence=%s " - "queue_len_before=%d queue_len_after=%d " - "mask_queue_len_before=%d mask_queue_len_after=%d " - "push_meta=%s divergence_v2=%s", - _tid(), - _inst_idx, - tuple(scores.shape), - tuple(top_indices.shape), - top_indices.shape[0] == scores.shape[0], - _popped_slab_hash, - _popped_mask_hash, - _target_hash, - _divergence, - _bw_queue_len_before, - len(router_replay.replay_backward_list), - _bw_mask_queue_len_before, - len( - getattr(router_replay, "replay_backward_mask_list", []) - or [] - ), - _bw_push_meta, - # divergence_v2 is the DEFINITIVE verdict: it - # compares popped slab against the slab recorded - # at the matching push site, not against the - # most recent (potentially overwritten) target. - # MATCH here = backward queue is correct under PP. - ( - "NO_PUSH_META" - if _bw_push_meta is None - else ( - "MATCH" - if _bw_push_meta.get("slab_hash") == _popped_slab_hash - else "REAL_MISMATCH" - ) - ), - ) - except Exception: - logger.exception( - "[R3-STAGE4/REPLAY_BACKWARD/consume] trace log failed" - ) - # ---------- R3 diagnostics: FB-level aggregate counters - # (gated by _r3_verbose so prod path is untouched). Counters - # are reset at _r3_forward_backward_batch entry and dumped at - # exit, giving one summary line per FB call. 
- try: - from areal.engine.router_replay_utils import ( - _r3_hash64 as _h64x, - _r3_verbose as _vx, - ) - if _vx(): - _stats = RouterReplay._r3_fb_stats - if router_replay.target_topk_idx is None: - _stats["divergence_v1_none"] = ( - _stats.get("divergence_v1_none", 0) + 1 - ) - else: - _v1_match = ( - router_replay.target_topk_idx.shape == top_indices.shape - and _h64x( - router_replay.target_topk_idx.to(top_indices.device) - ) == _h64x(top_indices) - ) - _stats["divergence_v1_match" if _v1_match - else "divergence_v1_diverge"] = ( - _stats.get( - "divergence_v1_match" if _v1_match - else "divergence_v1_diverge", 0, - ) + 1 - ) - if _bw_push_meta is None: - _stats["divergence_v2_no_meta"] = ( - _stats.get("divergence_v2_no_meta", 0) + 1 - ) - else: - _v2_match = ( - _bw_push_meta.get("slab_hash") - == hex(_h64x(top_indices)) - ) - _stats["divergence_v2_match" if _v2_match - else "divergence_v2_real_mismatch"] = ( - _stats.get( - "divergence_v2_match" if _v2_match - else "divergence_v2_real_mismatch", 0, - ) + 1 - ) - _stats["bw_pop_total"] = _stats.get("bw_pop_total", 0) + 1 - except Exception: - pass + pass if ( bw_valid_mask is not None and bw_valid_mask.shape[0] == top_indices.shape[0] @@ -702,15 +329,6 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): top_indices, live_top, ) - _R3_routing_log( - "REPLAY_BACKWARD", - scores=scores, - top_indices=top_indices, - topk=topk, - compute_topk_fn=_compute_topk, - num_groups=num_groups, - group_topk=group_topk, - ) probs = scores.gather(1, top_indices) return probs, top_indices diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py index 22e6c17131..e26dede301 100644 --- a/areal/engine/router_replay_utils.py +++ b/areal/engine/router_replay_utils.py @@ -597,56 +597,6 @@ def set_router_replay_data( total_aligned = sum(aligned_lens) - if _r3_verbose() and _r3_should_log("set_router_replay_data/ENTER"): - logger.info( - "[R3-STAGE3/set_router_replay_data] ENTER call#%d trace_id=%d %s " - "layers_topk_idx=(bs=%d, max_seq=%d, L=%d, K=%d) dtype=%s " - "n_cu_entries=%d n_seqs_in_cu=%d seq_align_to=%d " - "seq_lens[:16]=%s aligned_lens[:16]=%s total_aligned=%d " - "vp_rank=%s | %s", - _r3_call_count("set_router_replay_data/ENTER"), - _r3_current_trace_id(), - _r3_pp_tp_info(tf_config, vp_rank), - bs_re, - layers_topk_idx.shape[1], - num_layers, - topk, - layers_topk_idx.dtype, - n_cu_entries, - n_seqs_in_cu, - seq_align_to, - seq_lens[:16], - aligned_lens[:16], - total_aligned, - vp_rank, - _r3_tensor_sig("cu_seqlens", cu_seqlens), - ) - # Per-sample fingerprint (hash, nnz, real_len) so we can verify - # the SAME bytes reach here as the actor pushed into the - # side-channel. Any mismatch between hashes implies a - # split/reorder bug somewhere upstream. 
- if _r3_verbose() and _r3_should_log("set_router_replay_data/PER_SAMPLE"): - try: - _h = _r3_per_sample_hashes(layers_topk_idx, max_rows=32) - _nnz = _r3_per_sample_nnz(layers_topk_idx, max_rows=32) - _rl = _r3_per_sample_seq_real_len(layers_topk_idx, max_rows=32) - logger.info( - "[R3-STAGE3/set_router_replay_data] PER_SAMPLE trace_id=%d %s " - "bs_re=%d n_seqs_in_cu=%d " - "per_sample_hash[:16]=%s per_sample_nnz_rows[:16]=%s " - "per_sample_real_len[:16]=%s cu_seqlens_diff[:16]=%s", - _r3_current_trace_id(), - _r3_pp_tp_info(tf_config, vp_rank), - bs_re, - n_seqs_in_cu, - [hex(h) for h in _h[:16]], - _nnz[:16], - _rl[:16], - seq_lens[:16], - ) - except Exception as e: - logger.warning("[R3-STAGE3/set_router_replay_data] PER_SAMPLE err=%s", e) - # Pack routed_experts using cu_seqlens-aligned layout. # layers_topk_idx is left-ALIGNED: real tokens at positions [0, seq_len). # For each sequence i, we take the first seq_lens[i] tokens and place @@ -727,56 +677,6 @@ def set_router_replay_data( n_strike = int((valid_mask & row_all_zero).sum().item()) valid_mask = valid_mask & (~row_all_zero) - if _r3_verbose() and _r3_should_log("set_router_replay_data/PACKED"): - with torch.no_grad(): - # Count global all-zero rows across ALL layers AND topk slots. - zrows = int(row_all_zero.sum().item()) - total_rows = int(row_all_zero.numel()) - n_valid = int(valid_mask.sum().item()) - # Per-sample valid-row count (after strike), lined up with - # aligned_lens so any off-by-one immediately surfaces. - per_sample_valid_after = [] - per_sample_valid_before = [] - _off = 0 - for _i in range(n_seqs_in_cu): - _al = aligned_lens[_i] if _i < len(aligned_lens) else 0 - _seg = valid_mask[_off : _off + _al] - _segz = row_all_zero[_off : _off + _al] - per_sample_valid_after.append(int(_seg.sum().item())) - per_sample_valid_before.append( - int((~_segz[: seq_lens[_i] if _i < len(seq_lens) else 0]).sum().item()) - if _i < len(seq_lens) - else 0 - ) - _off += _al - logger.info( - "[R3-STAGE3/set_router_replay_data] PACKED trace_id=%d %s " - "packed=(total_aligned=%d, L=%d, K=%d) global_zero_rows=%d/%d " - "(%.2f%%) valid_rows=%d/%d (%.2f%%) struck_tail_rows=%d " - "per_sample_valid_before_strike[:16]=%s " - "per_sample_valid_after_strike[:16]=%s " - "per_sample_real_len[:16]=%s aligned_lens[:16]=%s " - "packed_hash=%s | %s", - _r3_current_trace_id(), - _r3_pp_tp_info(tf_config, vp_rank), - packed.shape[0], - packed.shape[1], - packed.shape[2], - zrows, - total_rows, - 100.0 * zrows / max(total_rows, 1), - n_valid, - total_rows, - 100.0 * n_valid / max(total_rows, 1), - n_strike, - per_sample_valid_before[:16], - per_sample_valid_after[:16], - seq_lens[:16], - aligned_lens[:16], - hex(_r3_hash64(packed)), - _r3_tensor_sig("packed", packed), - ) - # Step 2: CP split (before TP scatter). 
# # When ``cp_size > 1``, megatron-core's @@ -804,21 +704,6 @@ def set_router_replay_data( valid_mask = split_packed_seqs_for_context_parallel( valid_mask.to(torch.int32), cu_seqlens_dev ).bool() - if _r3_verbose() and _r3_should_log("set_router_replay_data/CP_SPLIT"): - with torch.no_grad(): - n_after_cp_valid = int(valid_mask.sum().item()) - logger.info( - "[R3-STAGE3/set_router_replay_data] CP_SPLIT trace_id=%d %s " - "cp_size=%d post_cp_packed=%s post_cp_valid=%d/%d " - "post_cp_packed_hash=%s", - _r3_current_trace_id(), - _r3_pp_tp_info(tf_config, vp_rank), - cp_size, - _r3_tensor_sig("packed_after_cp", packed), - n_after_cp_valid, - valid_mask.numel(), - hex(_r3_hash64(packed)), - ) # Step 3: Scatter to SP ranks (TP) tp_size = mpu.get_tensor_model_parallel_world_size() @@ -840,96 +725,11 @@ def set_router_replay_data( # Step 4: Permute to (num_layers, local_tokens_count, topk) layers_topk = local_tokens.permute(1, 0, 2) - if _r3_verbose() and _r3_should_log("set_router_replay_data/SCATTER"): - with torch.no_grad(): - n_local_valid = int(local_mask.sum().item()) - logger.info( - "[R3-STAGE3/set_router_replay_data] POST-SCATTER trace_id=%d %s " - "tp_size=%d local_valid=%d/%d local_tokens=%s layers_topk=%s " - "local_tokens_hash=%s local_mask_hash=%s", - _r3_current_trace_id(), - _r3_pp_tp_info(tf_config, vp_rank), - tp_size, - n_local_valid, - local_mask.numel(), - _r3_tensor_sig("local_tokens", local_tokens), - _r3_tensor_sig("layers_topk", layers_topk), - hex(_r3_hash64(local_tokens)), - hex(_r3_hash64(local_mask.to(torch.int32))), - ) - # Step 5: Distribute to RouterReplay instances for local PP layers local_info = get_current_rank_layer_info(tf_config, vp_rank) offset, end = local_info["start"], local_info["end"] router_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) - # ---------- R3 diagnostics: PP=2 root-cause hunt ---------- - # Print the full PP-rank slicing decision so the next log can - # trivially confirm: - # * tf_config.num_layers stays GLOBAL (27) under Megatron-Core PP - # -- if it ever becomes local (14/13), index_by_layer would - # still report True but ``idx=layer_idx`` would over-shoot. - # * offset/end honors get_transformer_layer_offset. - # * The set of MoE layers in [offset, end) matches the rollout - # layer-axis convention (absolute layer index, with layer 0 - # dense and recorded as zeros). - # * RouterReplay.router_instances is local-per-process: the - # selected slice (creation_order list) tells us which routers - # this PP rank actually owns. - try: - if _r3_verbose() and _r3_should_log("set_router_replay_data/PP_LAYOUT"): - moe_layers_in_range = [ - i for i in range(offset, end) if is_moe_layer(tf_config, i) - ] - non_moe_layers_in_range = [ - i for i in range(offset, end) if not is_moe_layer(tf_config, i) - ] - vp_size = getattr( - tf_config, "virtual_pipeline_model_parallel_size", None - ) - # Cheap audit: nnz of dim-0 slice for ALL layer indices - # (helps prove rollout's L-axis-0 is the dense layer and - # really is all-zero, vs. silently shifted). 
- with torch.no_grad(): - per_layer_nnz = [ - int((layers_topk[L] != 0).any(dim=-1).sum().item()) - if L < layers_topk.shape[0] else -1 - for L in range(min(layers_topk.shape[0], 32)) - ] - logger.info( - "[R3-STAGE3/set_router_replay_data] PP_LAYOUT trace_id=%d %s " - "tf_config.num_layers=%d vp_size=%s moe_layer_freq=%s " - "first_k_dense_replace=%s " - "local_info={start:%d, end:%d, count:%d} " - "moe_layers_in_range=%s non_moe_layers_in_range=%s " - "len(router_list)=%d total_router_instances=%d " - "selected_router_creation_orders=%s " - "selected_router_creator_ranks=%s " - "layers_topk_dim0=%d index_by_layer=%s " - "per_layer_any_nnz_first32=%s", - _r3_current_trace_id(), - _r3_pp_tp_info(tf_config, vp_rank), - tf_config.num_layers, - vp_size, - getattr(tf_config, "moe_layer_freq", None), - getattr(tf_config, "first_k_dense_replace", None), - offset, - end, - local_info["count"], - moe_layers_in_range, - non_moe_layers_in_range, - len(router_list), - len(RouterReplay.router_instances), - [getattr(r, "creation_order", -1) for r in router_list], - [getattr(r, "creator_rank", -1) for r in router_list], - layers_topk.shape[0], - len(layers_topk) == tf_config.num_layers, - per_layer_nnz, - ) - except Exception: - logger.exception("[R3-STAGE3/PP_LAYOUT] diag log failed") - # ---------------------------------------------------------- - if len(router_list) == 0: logger.warning( "[R3] set_router_replay_data: no RouterReplay instances found " @@ -947,7 +747,6 @@ def set_router_replay_data( moe_idx = sum(1 for i in range(offset) if is_moe_layer(tf_config, i)) router_offset = 0 - dispatched = [] # list of (layer_idx, idx_into_layers_topk, zero_row_stats) for layer_idx in range(offset, end): if not is_moe_layer(tf_config, layer_idx): continue @@ -972,18 +771,6 @@ def set_router_replay_data( continue slab = layers_topk[idx].to(torch.int64) router.set_target_indices(slab, valid_mask=local_mask) - if _r3_verbose() and _r3_should_log("set_router_replay_data/DISPATCH"): - dispatched.append( - ( - layer_idx, - idx, - _r3_zero_row_stats(slab), - _r3_tensor_sig(f"target[L={layer_idx}]", slab), - hex(_r3_hash64(slab)), - getattr(router, "creation_order", -1), - moe_idx, - ) - ) router_offset += 1 moe_idx += 1 @@ -997,32 +784,6 @@ def set_router_replay_data( end, tp_size, ) - if _r3_verbose() and dispatched: - # Only log first couple of dispatched layers in detail; keep - # the rest summarised. - head = dispatched[:_R3_ROUTER_LAYER_LIMIT] - logger.info( - "[R3-STAGE3/set_router_replay_data] DISPATCH trace_id=%d %s " - "router_offset=%d len(router_list)=%d index_by_layer=%s " - "first_layers=%s all_layers_to_router_map=%s " - "... (total dispatched=%d)", - _r3_current_trace_id(), - _r3_pp_tp_info(tf_config, vp_rank), - router_offset, - len(router_list), - index_by_layer, - [ - (lidx, j, zr, sig, h, co, mi) - for lidx, j, zr, sig, h, co, mi in head - ], - # Full (layer_idx, slab_idx_used, router_creation_order, - # moe_ordinal) tuple for every dispatched layer. This is - # the definitive cross-check: PP0 must be [(1,1,0,0), - # (2,2,1,1), ..., (13,13,12,12)] and PP1 must be - # [(14,14,0,13), ..., (26,26,12,25)] on Moonlight. - [(lidx, j, co, mi) for lidx, j, _, _, _, co, mi in dispatched], - len(dispatched), - ) # =================================================================== @@ -1049,19 +810,6 @@ def setup_per_microbatch_replay_forward( seq_align_to: Per-sequence TP alignment factor. 
""" routed_experts = routed_experts.to(torch.int32) - if _r3_verbose() and _r3_should_log("setup_per_microbatch_replay_forward"): - with torch.no_grad(): - per_row_zero = (routed_experts == 0).all(dim=-1).all(dim=-1) - logger.info( - "[R3-STAGE3/setup_per_microbatch_replay_forward] ENTER %s " - "routed_experts=%s cu_seqlens=%s seq_align_to=%s " - "per_sample_zero_rows=%s", - _r3_pp_tp_info(tf_config, vp_rank), - _r3_tensor_sig("routed_experts", routed_experts), - _r3_tensor_sig("cu_seqlens", cu_seqlens), - seq_align_to, - [int(x.sum().item()) for x in per_row_zero][:8], - ) set_router_replay_data( routed_experts, cu_seqlens, tf_config, vp_rank, seq_align_to=seq_align_to, @@ -1077,38 +825,6 @@ def setup_per_microbatch_replay_backward() -> None: def clear_router_replay() -> None: """Clear all RouterReplay state after a full forward-backward pass.""" n_instances = len(RouterReplay.router_instances) - # ---------- R3 diagnostics: dump pre-clear queue lengths so leftover - # backward pops (a smoking gun for missing recompute under PP=2 1F1B) - # are always visible. - try: - if _r3_verbose() and _r3_should_log("clear_router_replay/snapshot"): - fwd_qs = [ - len(getattr(r, "replay_backward_list", []) or []) - for r in RouterReplay.router_instances - ] - mask_qs = [ - len(getattr(r, "replay_backward_mask_list", []) or []) - for r in RouterReplay.router_instances - ] - push_qs = [ - len(getattr(r, "replay_push_meta_list", []) or []) - for r in RouterReplay.router_instances - ] - n_nonempty = sum(1 for q in fwd_qs if q > 0) - logger.info( - "[R3-STAGE3c/clear_router_replay] PRE_CLEAR_SNAPSHOT %s " - "n_instances=%d n_with_residual_fwd_q=%d " - "fwd_q_lens=%s mask_q_lens=%s push_meta_q_lens=%s", - _r3_pp_tp_info(), - n_instances, - n_nonempty, - fwd_qs, - mask_qs, - push_qs, - ) - except Exception: - logger.exception("[R3-STAGE3c/clear_router_replay] diag log failed") - # ----------------------------------------------------------------- RouterReplay.clear_global_indices() RouterReplay.clear_global_router_replay_action() logger.debug("[R3] Router replay state cleared (%d instances).", n_instances) @@ -1148,30 +864,12 @@ def preprocess_routed_experts_batch( import numpy as np if routed_experts_np is None: - if _r3_verbose(): - logger.info("[R3-STAGE1/preprocess] routed_experts_np=None, returning None") return None seq_len = input_ids.shape[1] num_sgl_tokens = routed_experts_np.shape[0] flat_dim = routed_experts_np.shape[1] - if _r3_verbose(): - logger.info( - "[R3-STAGE1/preprocess] ENTER " - "seq_len=%d num_sgl_tokens=%d flat_dim=%d num_moe_layers=%s topk=%s " - "expected_flat=%s | %s | %s | %s", - seq_len, - num_sgl_tokens, - flat_dim, - num_moe_layers, - topk, - (num_moe_layers or 0) * (topk or 0), - _r3_tensor_sig("input_ids", input_ids), - _r3_tensor_sig("attention_mask", attention_mask), - _r3_tensor_sig("routed_experts_np", routed_experts_np), - ) - expected_flat = num_moe_layers * topk if flat_dim != expected_flat: logger.warning( @@ -1254,34 +952,4 @@ def preprocess_routed_experts_batch( right_pad, ) - if _r3_verbose(): - # NOTE: for R3 correctness check. We expect num_sgl_tokens = - # real_tokens - 1 (SGLang drops the logprob of the very last - # generated token). Anything else means the routing -> token - # alignment is not what we think it is. 
- tail_row_is_zero = None - try: - tail_slice = padded[0, real_tokens - 1] if real_tokens > 0 else None - if tail_slice is not None: - tail_row_is_zero = bool((tail_slice == 0).all().item()) - except Exception: - tail_row_is_zero = "err" - # All-zero row stats across the seq_len axis for this sample. - with torch.no_grad(): - per_row_zero = (padded[0] == 0).all(dim=-1).all(dim=-1) - zero_rows_count = int(per_row_zero.sum().item()) - logger.info( - "[R3-STAGE1/preprocess] EXIT " - "num_sgl_tokens=%d real_tokens=%d delta=%d right_pad=%d " - "tail_real_row_all_zero=%s zero_rows_total=%d/%d | %s", - num_sgl_tokens, - real_tokens, - real_tokens - num_sgl_tokens, - right_pad, - tail_row_is_zero, - zero_rows_count, - seq_len, - _r3_tensor_sig("padded", padded), - ) - return padded diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index b83b54e3a4..9a10e296bd 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -141,83 +141,14 @@ def _compute_logp(self, data: dict[str, Any]) -> torch.Tensor | None: _r3_routed_experts, torch.Tensor ): from areal.trainer.ppo.actor_r3_patch import _resolve_to_tensor + _r3_routed_experts = _resolve_to_tensor(_r3_routed_experts) _r3_enabled = bool(getattr(self.engine, "_r3_enabled", False)) - try: - from areal.engine.router_replay_utils import ( - _r3_should_log, - _r3_tensor_sig, - _r3_verbose, - ) - - if _r3_verbose() and _r3_should_log("actor._compute_logp/ENTER"): - _re_info = ( - _r3_tensor_sig("routed_experts", _r3_routed_experts) - if _r3_routed_experts is not None - else "routed_experts=None" - ) - logger.info( - "[R3-STAGE2/actor._compute_logp] ENTER r3_enabled=%s " - "input_keys=%s batch_shape=%s | %s | %s | %s", - _r3_enabled, - sorted(list(data.keys())), - tuple(data["input_ids"].shape) - if "input_ids" in data - else "N/A", - _re_info, - _r3_tensor_sig("logprobs", data.get("logprobs")), - _r3_tensor_sig("loss_mask", data.get("loss_mask")), - ) - except Exception: - pass if _r3_routed_experts is not None and _r3_enabled: # forward_batch performs ONE forward_backward_batch(forward_only=True) # call internally; the R3 engine patch will split routed_experts per # micro-batch and consume the side-channel (setting it back to None). 
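            # The side-channel is a consume-exactly-once handoff between actor
            # and engine; schematically (attribute name as in this repo, call
            # signature simplified):
            #
            #     engine._r3_pending_routed_experts = routed_experts  # producer: actor
            #     logp = engine.forward(input_=data)                  # patched fbfunc consumes
            #     assert engine._r3_pending_routed_experts is None    # reset on consumption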
self.engine._r3_pending_routed_experts = _r3_routed_experts - try: - from areal.engine.router_replay_utils import ( - _r3_hash64, - _r3_next_trace_id, - _r3_per_sample_hashes, - _r3_per_sample_nnz, - _r3_per_sample_seq_real_len, - _r3_pp_tp_info, - _r3_tensor_sig, - _r3_verbose, - ) - - _trace_id = _r3_next_trace_id() - self.engine._r3_active_trace_id = _trace_id - if _r3_verbose(): - logger.info( - "[R3-STAGE2/actor._compute_logp] SIDE_CHANNEL_SET " - "trace_id=%d %s bs=%d seqlen=%s L=%s K=%s " - "hash=%s per_sample_hash[:16]=%s " - "per_sample_nnz[:16]=%s per_sample_real_len[:16]=%s " - "attn_sum[:16]=%s | %s", - _trace_id, - _r3_pp_tp_info(), - _r3_routed_experts.shape[0], - _r3_routed_experts.shape[1], - _r3_routed_experts.shape[2] - if _r3_routed_experts.ndim >= 3 else None, - _r3_routed_experts.shape[3] - if _r3_routed_experts.ndim >= 4 else None, - hex(_r3_hash64(_r3_routed_experts)), - [hex(h) for h in _r3_per_sample_hashes( - _r3_routed_experts, max_rows=16)], - _r3_per_sample_nnz(_r3_routed_experts, max_rows=16), - _r3_per_sample_seq_real_len(_r3_routed_experts, max_rows=16), - ( - data["attention_mask"].sum(dim=-1).long().cpu().tolist()[:16] - if "attention_mask" in data - else "N/A" - ), - _r3_tensor_sig("routed_experts", _r3_routed_experts), - ) - except Exception: - logger.exception("[R3-STAGE2/actor._compute_logp] side-channel trace log failed") train_logp = self.engine.forward( input_=data, aggregate_fn=lambda xs: torch.cat(xs, dim=-1), @@ -226,7 +157,9 @@ def _compute_logp(self, data: dict[str, Any]) -> torch.Tensor | None: # correlate with per-MB output shapes logged by forward_batch. try: import torch.distributed as _dist + from areal.utils.logging import getLogger as _getLogger + _diag = _getLogger("R3FwdDiag") _diag.info( "[SPLIT_MISMATCH_DIAG][actor._compute_logp] POST_FORWARD " @@ -325,125 +258,6 @@ def _log_r3_effectiveness_stats( extreme_tau2 = (abs_log_ratio > torch.log(torch.tensor(2.0))).float() extreme_tau5 = (abs_log_ratio > torch.log(torch.tensor(5.0))).float() - try: - from areal.engine.router_replay_utils import ( - _r3_current_trace_id, - _r3_hash64, - _r3_per_sample_hashes, - _r3_should_log, - _r3_tensor_sig, - _r3_verbose, - ) - - if _r3_verbose() and _r3_should_log( - "actor._log_r3_effectiveness_stats" - ): - with torch.no_grad(): - n_valid = int(shifted_mask.sum().item()) - if n_valid > 0: - _masked = abs_diff[shifted_mask] - _mean_abs = float(_masked.mean().item()) - _max_abs = float(_masked.max().item()) - _p99 = float( - torch.quantile( - _masked.float(), 0.99 - ).item() - ) if _masked.numel() > 0 else 0.0 - _mean_k3 = float(k3_kl[shifted_mask].mean().item()) - _frac_tau2 = float( - extreme_tau2[shifted_mask].mean().item() - ) - _frac_tau5 = float( - extreme_tau5[shifted_mask].mean().item() - ) - else: - _mean_abs = _max_abs = _p99 = _mean_k3 = _frac_tau2 = _frac_tau5 = 0.0 - logger.info( - "[R3-STAGE2/r3_effectiveness] trace_id=%d r3_enabled=%s " - "n_valid_tokens=%d mean_abs_diff=%.6f max_abs_diff=%.6f " - "p99_abs_diff=%.6f mean_k3_kl=%.6f frac_tau2=%.4f " - "frac_tau5=%.4f | %s | %s", - _r3_current_trace_id(), - r3_enabled, - n_valid, - _mean_abs, - _max_abs, - _p99, - _mean_k3, - _frac_tau2, - _frac_tau5, - _r3_tensor_sig("train_logp", train_logp), - _r3_tensor_sig("rollout_logp_rolled", rollout_logp_f), - ) - # ---- R3 per-sample breakdown: identify catastrophic samples ---- - with torch.no_grad(): - bs = shifted_mask.shape[0] - max_rows = min(bs, 64) - per_sample = [] - for i in range(max_rows): - m_i = shifted_mask[i] - n_i = 
int(m_i.sum().item()) - if n_i == 0: - per_sample.append( - { - "i": i, - "n": 0, - "mean_abs": 0.0, - "max_abs": 0.0, - "tau2_cnt": 0, - "tau5_cnt": 0, - "k3_mean": 0.0, - } - ) - continue - row_abs = abs_diff[i][m_i] - row_k3 = k3_kl[i][m_i] - per_sample.append( - { - "i": i, - "n": n_i, - "mean_abs": float(row_abs.mean().item()), - "max_abs": float(row_abs.max().item()), - "tau2_cnt": int( - (row_abs > torch.log(torch.tensor(2.0))).sum().item() - ), - "tau5_cnt": int( - (row_abs > torch.log(torch.tensor(5.0))).sum().item() - ), - "k3_mean": float(row_k3.mean().item()), - } - ) - # Routed-experts per-sample hashes (side-channel payload) - pending = getattr(self.engine, "_r3_pending_routed_experts", None) - re_hashes = ( - [hex(h) for h in _r3_per_sample_hashes(pending, max_rows=max_rows)] - if pending is not None - else [] - ) - re_full_hash = ( - hex(_r3_hash64(pending)) if pending is not None else "None" - ) - # Bad sample ranking (top-K by mean_abs) - sorted_bad = sorted( - per_sample, key=lambda x: x["mean_abs"], reverse=True - )[:8] - logger.info( - "[R3-STAGE2/r3_effectiveness/per_sample] trace_id=%d " - "r3_enabled=%s batch_size=%d routed_experts_full_hash=%s " - "per_sample=%s top_bad_samples=%s per_sample_routed_hash=%s", - _r3_current_trace_id(), - r3_enabled, - bs, - re_full_hash, - per_sample, - sorted_bad, - re_hashes, - ) - except Exception: - logger.exception( - "[R3-STAGE2/r3_effectiveness] per-sample trace log failed" - ) - with stats_tracker.scope("compute_logp"): with stats_tracker.scope("r3"): stats_tracker.denominator( @@ -659,11 +473,12 @@ def _ppo_update(self, data: dict[str, Any]) -> None: _r3_routed_experts, torch.Tensor ): from areal.trainer.ppo.actor_r3_patch import _resolve_to_tensor + _r3_routed_experts = _resolve_to_tensor(_r3_routed_experts) if _r3_routed_experts is not None: _re_np = _r3_routed_experts.cpu().numpy() _nonzero = _re_np[_re_np > 0] - logger.info( + logger.debug( "[R3-VERIFY] Actor received routed_experts: " "shape=%s, dtype=%s, nonzero_count=%d/%d, " "nonzero_first3=%s, max=%d, hash=%d", @@ -684,42 +499,19 @@ def _ppo_update(self, data: dict[str, Any]) -> None: # R3: Split routed_experts per mini-batch for side-channel delivery. 
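        # The 4D per-sample tensor cannot ride inside the padded tensor dict
        # (mini-batching splits that token-wise), so it is reordered and
        # sliced sample-wise instead; see the sketch in actor_r3_patch below.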
_r3_split = None if _r3_routed_experts is not None: - from areal.trainer.ppo.actor_r3_patch import split_routed_experts_for_minibatches + from areal.trainer.ppo.actor_r3_patch import ( + split_routed_experts_for_minibatches, + ) + _r3_split = split_routed_experts_for_minibatches( _r3_routed_experts, mb_inputs ) - logger.info( + logger.debug( "[R3] Split routed_experts for %d mini-batches via side-channel " "(shapes: %s).", len(mb_inputs.mbs), [s.shape if s is not None else None for s in _r3_split], ) - try: - from areal.engine.router_replay_utils import ( - _r3_should_log, - _r3_tensor_sig, - _r3_verbose, - ) - - if _r3_verbose() and _r3_should_log("actor._ppo_update/split"): - logger.info( - "[R3-STAGE2/actor._ppo_update] SPLIT " - "n_ppo_minibatches=%d per_mb_shapes=%s " - "forward_indices=%s | %s", - len(mb_inputs.mbs), - [ - None if s is None else tuple(s.shape) - for s in _r3_split - ], - "None" - if mb_inputs.forward_indices is None - else f"len={len(mb_inputs.forward_indices)}", - _r3_tensor_sig( - "_r3_routed_experts", _r3_routed_experts, max_sample=4 - ), - ) - except Exception: - pass with stats_tracker.scope("update"): # Get current version for proximal approximation metrics @@ -735,62 +527,6 @@ def _ppo_update(self, data: dict[str, Any]) -> None: if i < len(_r3_split) and _r3_split[i] is not None else None ) - try: - from areal.engine.router_replay_utils import ( - _r3_hash64, - _r3_next_trace_id, - _r3_per_sample_hashes, - _r3_per_sample_nnz, - _r3_per_sample_seq_real_len, - _r3_pp_tp_info, - _r3_should_log, - _r3_tensor_sig, - _r3_verbose, - ) - - _trace_id = _r3_next_trace_id() - self.engine._r3_active_trace_id = _trace_id - _slice = ( - _r3_split[i] - if i < len(_r3_split) - else None - ) - if _r3_verbose(): - logger.info( - "[R3-STAGE2/actor._ppo_update] " - "SIDE_CHANNEL_SET trace_id=%d mb=%d " - "current_version=%s %s " - "slice_shape=%s hash=%s " - "per_sample_hash[:16]=%s " - "per_sample_nnz[:16]=%s " - "per_sample_real_len[:16]=%s " - "mb_attn_sum[:16]=%s | %s", - _trace_id, - i, - current_version, - _r3_pp_tp_info(), - None if _slice is None else tuple(_slice.shape), - hex(_r3_hash64(_slice)), - [hex(h) for h in _r3_per_sample_hashes( - _slice, max_rows=16)], - _r3_per_sample_nnz(_slice, max_rows=16), - _r3_per_sample_seq_real_len(_slice, max_rows=16), - ( - mb["attention_mask"].sum(dim=-1).long().cpu().tolist()[:16] - if isinstance(mb, dict) and "attention_mask" in mb - else "N/A" - ), - _r3_tensor_sig( - "pending_routed_experts", - _slice, - max_sample=4, - ), - ) - except Exception: - logger.exception( - "[R3-STAGE2/actor._ppo_update] " - "SIDE_CHANNEL_SET trace log failed", - ) else: logger.warning( "[R3] routed_experts available but engine._r3_enabled " diff --git a/areal/trainer/ppo/actor_r3_patch.py b/areal/trainer/ppo/actor_r3_patch.py index 9cc562b156..e9ad996769 100644 --- a/areal/trainer/ppo/actor_r3_patch.py +++ b/areal/trainer/ppo/actor_r3_patch.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + """ R3 data-splitting helpers for the PPO actor. 
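+
+The split is a reorder-then-slice, sketched here in standalone form (the
+``group_sizes`` list is a stand-in for the per-mini-batch sample counts)::
+
+    reordered = routed_experts[torch.as_tensor(forward_indices)]
+    mbs, offset = [], 0
+    for n in group_sizes:  # one entry per mini-batch
+        mbs.append(reordered[offset : offset + n])
+        offset += n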
@@ -91,58 +93,6 @@ def split_routed_experts_for_minibatches( [r.shape[0] for r in result], "None" if forward_indices is None else f"len={len(forward_indices)}", ) - try: - from areal.engine.router_replay_utils import ( - _r3_hash64, - _r3_per_sample_hashes, - _r3_per_sample_nnz, - _r3_pp_tp_info, - _r3_should_log, - _r3_tensor_sig, - _r3_verbose, - ) - - if _r3_verbose() and _r3_should_log( - "split_routed_experts_for_minibatches" - ): - # Pre-reorder per-sample hashes (what we *started* with) and - # post-reorder per-sample hashes (what each mini-batch gets). - pre_hash = _r3_per_sample_hashes(routed_experts, max_rows=32) - post_hash = _r3_per_sample_hashes(reordered, max_rows=32) - mb_hashes = [ - [hex(h) for h in _r3_per_sample_hashes(r, max_rows=16)] - for r in result - ] - mb_nnz = [_r3_per_sample_nnz(r, max_rows=16) for r in result] - logger.info( - "[R3-STAGE2/split_routed_experts_for_minibatches] %s " - "input_shape=%s input_hash=%s n_mbs=%d forward_indices=%s " - "per_mb_shapes=%s per_mb_hashes=%s " - "pre_reorder_per_sample_hash[:16]=%s " - "post_reorder_per_sample_hash[:16]=%s " - "per_mb_per_sample_hash=%s per_mb_per_sample_nnz=%s | %s", - _r3_pp_tp_info(), - tuple(routed_experts.shape), - hex(_r3_hash64(routed_experts)), - n_mbs, - "None" - if forward_indices is None - else ( - f"len={len(forward_indices)} " - f"first32={forward_indices[:32].tolist() if hasattr(forward_indices,'tolist') else list(forward_indices)[:32]}" - ), - [tuple(r.shape) for r in result], - [hex(_r3_hash64(r)) for r in result], - [hex(h) for h in pre_hash[:16]], - [hex(h) for h in post_hash[:16]], - mb_hashes, - mb_nnz, - _r3_tensor_sig("routed_experts", routed_experts, max_sample=4), - ) - except Exception: - logger.exception( - "[R3-STAGE2/split_routed_experts_for_minibatches] trace log failed" - ) return result diff --git a/areal/workflow/rlvr.py b/areal/workflow/rlvr.py index f5e38aa7b4..ae356b8b0a 100644 --- a/areal/workflow/rlvr.py +++ b/areal/workflow/rlvr.py @@ -196,29 +196,5 @@ async def arun_episode( topk=self.r3_topk, ) res = inject_routed_experts_into_result(res, routed_experts_tensor) - try: - from areal.engine.router_replay_utils import ( - _r3_should_log, - _r3_tensor_sig, - _r3_verbose, - ) - - if _r3_verbose() and _r3_should_log("rlvr.arun_episode"): - logger.info( - "[R3-STAGE1/rlvr.arun_episode] INJECT " - "r3_num_moe_layers=%s r3_topk=%s " - "resp.routed_experts.shape=%s input_len=%d " - "output_len=%d | %s", - self.r3_num_moe_layers, - self.r3_topk, - getattr(resp.routed_experts, "shape", None), - resp.input_len, - resp.output_len, - _r3_tensor_sig( - "routed_experts_tensor", routed_experts_tensor - ), - ) - except Exception: - pass return res diff --git a/areal/workflow/rlvr_r3_patch.py b/areal/workflow/rlvr_r3_patch.py index 68ca8ed7ec..346be74db4 100644 --- a/areal/workflow/rlvr_r3_patch.py +++ b/areal/workflow/rlvr_r3_patch.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + """ R3 helpers for the RLVR workflow. 
@@ -17,7 +19,6 @@ from __future__ import annotations - import numpy as np import torch @@ -168,28 +169,6 @@ def extract_routed_experts( topk=topk, compress_dtype=compress_dtype, ) - try: - from areal.engine.router_replay_utils import ( - _r3_should_log, - _r3_tensor_sig, - _r3_verbose, - ) - - if _r3_verbose() and _r3_should_log("extract_routed_experts"): - logger.info( - "[R3-STAGE1/extract_routed_experts] " - "num_moe_layers=%d topk=%d " - "input_ids_shape=%s attn_sum=%d " - "np_shape=%s | %s", - num_moe_layers, - topk, - tuple(input_ids.shape), - int(attention_mask.sum().item()), - getattr(routed_experts_np, "shape", None), - _r3_tensor_sig("result", result), - ) - except Exception: - pass return result except Exception: logger.warning( From 850041b44d3e997faf26cd86726185d65a6df96d Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 7 May 2026 13:11:44 +0800 Subject: [PATCH 106/112] refactor: remove --- areal/engine/megatron_engine_r3_patch.py | 82 +---- areal/engine/router_replay_patch.py | 40 +-- areal/engine/router_replay_utils.py | 374 +---------------------- areal/infra/launcher/sglang_r3_patch.py | 71 +---- areal/trainer/ppo/actor_r3_patch.py | 15 +- areal/workflow/rlvr_r3_patch.py | 19 +- 6 files changed, 46 insertions(+), 555 deletions(-) diff --git a/areal/engine/megatron_engine_r3_patch.py b/areal/engine/megatron_engine_r3_patch.py index 23c1b0e18f..9df1664b47 100644 --- a/areal/engine/megatron_engine_r3_patch.py +++ b/areal/engine/megatron_engine_r3_patch.py @@ -1,28 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -""" -R3 Integration Patch for MegatronEngine. - -This module wraps ``MegatronEngine.forward_backward_batch`` so that, when -the micro-batch data contains ``routed_experts`` tensors, each micro-batch's -forward step is preceded by a call to ``setup_per_microbatch_replay_forward`` -and followed (after the full forward pass) by a switch to backward-replay -mode. - -The patch handles the critical issue that ``routed_experts`` is a 4D tensor -``(bs, seq_len, num_moe_layers, topk)`` which will NOT be correctly split by -``split_padded_tensor_dict_into_mb_list`` (which only splits tensors with -``numel() == bs * max_seqlen``). Instead, we extract ``routed_experts`` -from ``mb_list.data`` before micro-batch splitting, and manually distribute -it to each micro-batch using the ``forward_indices`` and ``group_lens`` -from ``MicroBatchList``. - -Usage:: - - from areal.engine.megatron_engine_r3_patch import patch_megatron_engine_for_r3 - patch_megatron_engine_for_r3(engine, enable_router_replay=True) +"""R3 Integration Patch for MegatronEngine. -Ref some code from megatron or verl, adapted for AReaL. +Wraps ``MegatronEngine.forward_backward_batch`` to inject per-microbatch +replay setup/teardown around the Megatron pipeline schedule when the +micro-batch data contains ``routed_experts`` tensors. 
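+
+Usage::
+
+    from areal.engine.megatron_engine_r3_patch import patch_megatron_engine_for_r3
+    patch_megatron_engine_for_r3(engine, enable_router_replay=True)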
""" from __future__ import annotations @@ -139,15 +121,6 @@ def _align_routed_experts_to_mask( n = min(actual_len, re_seqlen, max_seqlen) aligned[i, :n] = routed_experts[i, :n] - logger.debug( - "[R3] _align_routed_experts_to_mask: re_shape=%s -> aligned_shape=%s, " - "n_seqs=%d (re_bs=%d), seq_lens=%s.", - routed_experts.shape, - aligned.shape, - n_seqs, - re_bs, - seq_lens[:8], - ) return aligned @@ -234,14 +207,6 @@ def _split_routed_experts_for_mbs( result.append(reordered[offset : offset + n_samples]) offset += n_samples - logger.debug( - "[R3] _split_routed_experts_for_mbs: split %d samples into %d mbs " - "with sizes %s (forward_indices=%s).", - routed_experts.shape[0], - n_mbs, - [r.shape[0] for r in result], - "None" if forward_indices is None else f"len={len(forward_indices)}", - ) return result @@ -315,7 +280,7 @@ def _r3_forward_backward_batch( routed_experts_batch = None _from_side_channel = False - # Strategy A: Side-channel (preferred path) + # Side-channel path (preferred). if ( hasattr(self, "_r3_pending_routed_experts") and self._r3_pending_routed_experts is not None @@ -323,21 +288,11 @@ def _r3_forward_backward_batch( routed_experts_batch = self._r3_pending_routed_experts self._r3_pending_routed_experts = None # Consume it _from_side_channel = True - logger.debug( - "[R3] Retrieved routed_experts from engine side-channel: shape=%s.", - routed_experts_batch.shape, - ) # Strategy B: Legacy path from mb_list.data (backward compatibility) if routed_experts_batch is None and not forward_only: if hasattr(mb_list, "data") and isinstance(mb_list.data, dict): routed_experts_batch = mb_list.data.pop("routed_experts", None) - if routed_experts_batch is not None: - logger.debug( - "[R3] Retrieved routed_experts from mb_list.data (legacy path): " - "shape=%s.", - routed_experts_batch.shape, - ) # Clean from mbs and padded_mbs to avoid confusing downstream code. for mb_dict in mb_list.mbs: @@ -351,16 +306,6 @@ def _r3_forward_backward_batch( mb_list.data.pop("routed_experts", None) if routed_experts_batch is None: - if forward_only: - logger.debug( - "[R3] forward_only=True and no side-channel routed_experts; " - "skipping R3 replay (compute_logp/eval path)." - ) - else: - logger.debug( - "[R3] No routed_experts found (neither side-channel nor mb_list.data); " - "using original forward_backward_batch." - ) return self._r3_original_forward_backward_batch( mb_list, process_output_fn, forward_only=forward_only ) @@ -375,9 +320,7 @@ def _r3_forward_backward_batch( # Split routed_experts per micro-batch per_mb_routed_experts = _split_routed_experts_for_mbs(routed_experts_batch, mb_list) - # ------------------------------------------------------------------ - # 2. Store R3 data on the engine for the wrapped iterator. - # ------------------------------------------------------------------ + # Store R3 data for the wrapped iterator. self._r3_per_mb_experts = per_mb_routed_experts self._r3_mb_counter = 0 model_config = self.tf_config @@ -420,8 +363,7 @@ def __next__(self): else None ) - # When backward recompute finishes and next forward starts, - # switch back to REPLAY_FORWARD. + # Switch back to REPLAY_FORWARD after backward recompute. if RouterReplayHelper.is_replay_backward_action(model_config): router_list = RouterReplayHelper.get_micro_batch_router_list( model_config @@ -457,8 +399,7 @@ def __next__(self): if orig_cu is None: orig_cu = cu_seqlens - # Align routed_experts from left-padded to left-aligned - # using the ORIGINAL cu_seqlens (actual token counts). 
+ # Align from left-padded to left-aligned using original cu_seqlens. aligned_re = _align_routed_experts_to_mask( re, orig_cu, @@ -466,8 +407,7 @@ def __next__(self): _r3_mb_idx=idx, ) - # Pass the PADDED cu_seqlens (with TP alignment) - # to set_router_replay_data so packing matches Megatron. + # Pass padded cu_seqlens (TP-aligned) to the model. setup_per_microbatch_replay_forward( aligned_re.to(cu_seqlens.device), cu_seqlens, @@ -517,9 +457,7 @@ def __iter__(self_inner): mb_list.__class__ = _R3MicroBatchListProxy - # ------------------------------------------------------------------ - # 4. Register a forward hook for REPLAY_FORWARD -> REPLAY_BACKWARD toggle. - # ------------------------------------------------------------------ + # Register forward hook for REPLAY_FORWARD → REPLAY_BACKWARD toggle. hook_handles = [] def _r3_post_forward_hook(module, input, output): diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py index 07e5d0d100..a96e0bf575 100644 --- a/areal/engine/router_replay_patch.py +++ b/areal/engine/router_replay_patch.py @@ -1,12 +1,8 @@ -""" -Monkey-patches for Megatron-Core MoE components to support Router Replay (R3). - -Router Replay forces the TopKRouter to use pre-recorded expert assignments -(from rollout inference) instead of computing new ones during training. -This eliminates the train/inference routing mismatch caused by weight -staleness in asynchronous RL training. +"""Monkey-patches for Megatron-Core MoE components to support Router Replay (R3). -Ref some code from megatron or verl, adapted for AReaL. +Forces TopKRouter to use pre-recorded expert assignments from rollout inference +instead of computing new ones during training, eliminating train/inference routing +mismatch caused by weight staleness in async RL training. """ from __future__ import annotations @@ -125,12 +121,6 @@ def set_global_router_replay_action(action: RouterReplayAction) -> None: """Set the replay action for all router instances.""" for r in RouterReplay.router_instances: r.set_router_replay_action(action) - logger.debug( - "[R3] set_global_router_replay_action: action=%s " - "applied_to=%d router_instances", - action.value, - len(RouterReplay.router_instances), - ) @staticmethod def clear_global_router_replay_action() -> None: @@ -150,14 +140,7 @@ def __init__(self) -> None: # back to the live router output so they produce no replay signal. self.target_valid_mask: torch.Tensor | None = None self.replay_backward_mask_list: list[torch.Tensor | None] = [] - # ---------- R3 diagnostics (PP=2 root-cause hunt) ---------- - # Per-push metadata ring -- parallel to ``replay_backward_list`` so - # the BACKWARD consumer can compare the popped slab against the - # slab that was REGISTERED AT THAT PUSH (not the most recent - # target, which under 1F1B scheduling has been overwritten by a - # later mb). This cleanly separates "logging artifact" from "real - # backward queue corruption". See code-rules/distributed.md hang - # section -- the metadata ring is local state, no cross-rank comm. + # Keep push_meta queue in sync with backward_list. self.replay_push_meta_list: list[dict] = [] self.creation_order: int = len(RouterReplay.router_instances) try: @@ -272,6 +255,7 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): # Use the provided indices for replay top_indices = router_replay.target_topk_idx top_indices = top_indices.to(scores.device) + # Splice padding rows with live routing to avoid routing to expert 0. 
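+            # Padding rows were zero-filled upstream; replaying them verbatim
+            # would pin every pad token to expert 0, so rows whose valid mask
+            # is False fall back to the live top-k computed below.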
valid_mask = getattr(router_replay, "target_valid_mask", None) if valid_mask is not None and valid_mask.shape[0] == top_indices.shape[0]: _, live_top = _compute_topk( @@ -294,23 +278,15 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): "falling back to normal routing." ) return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) - # Use the last recorded indices for backward replay + # Backward recompute: use the recorded indices from the forward pass. top_indices = router_replay.replay_backward_list.pop(0) top_indices = top_indices.to(scores.device) - # pop the matching per-row validity mask (if any) - # so the backward recompute sees the same spliced indices as the - # original forward pass. Without this, activation-checkpoint - # recomputation re-introduces the all-zero padding rows and the - # gradient path contradicts the forward pass. bw_mask_list = getattr(router_replay, "replay_backward_mask_list", None) if bw_mask_list: bw_valid_mask = bw_mask_list.pop(0) else: bw_valid_mask = None - # ---------- R3 diagnostics: pop the matching push-meta entry - # so the divergence verdict below compares popped-slab against - # the slab that was REGISTERED AT PUSH TIME (the real - # correctness criterion under 1F1B PP scheduling). + # Keep push_meta queue in sync with backward_list. _push_meta_list = getattr(router_replay, "replay_push_meta_list", None) if _push_meta_list: try: diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py index e26dede301..d6ffa19fbe 100644 --- a/areal/engine/router_replay_utils.py +++ b/areal/engine/router_replay_utils.py @@ -1,14 +1,7 @@ -""" -Router Replay Utilities for AReaL. - -Handles the complete shape-transformation pipeline that converts rollout -routing indices into the layout expected by Megatron-Core's RouterReplay: +"""Router Replay (R3) utilities for AReaL. -1. **Right-padding to left-alignment** -- rollout batch is right-padded; - training uses left-aligned packed format. -2. **TP/SP splitting** -- sequence parallelism across tensor-model-parallel ranks. -3. **PP layer slicing** -- pipeline parallelism assigns different layers to ranks. -4. **Dense/MoE layer mapping** -- architectures with dense FFN layers before MoE. +Converts rollout routing indices into Megatron-Core's RouterReplay layout: +right-pad → left-align, TP/SP scatter, PP layer slicing, dense/MoE mapping. """ from __future__ import annotations @@ -28,322 +21,14 @@ logger = logging.getLogger("R3/utils") -# =================================================================== -# R3 detailed-logging helpers -# =================================================================== -# These helpers are used by EVERY R3 file (this module, router_replay_patch, -# megatron_engine_r3_patch, actor_r3_patch, actor.py, rlvr_r3_patch) so that -# all stages of the pipeline produce fingerprints in a consistent format. -# -# Controls (all opt-in via env vars so prod perf is not affected unless you -# deliberately enable): -# -# AREAL_R3_VERBOSE=1 -- master switch; enables everything below. -# Default: 1 (ON) so that if someone cares -# to run with R3 and grep logs, they do -# not need to set anything extra. -# AREAL_R3_LOG_FIRST_N=30 -- for rate-limited hot paths, always log -# the first N calls per key. -# AREAL_R3_LOG_EVERY=100 -- after the first N, log every Nth call. -# AREAL_R3_TENSOR_SAMPLE=8 -- how many leading elements to include in -# a tensor signature. 
-# AREAL_R3_ROUTER_LAYER_LIMIT=3 -- in patched_topk_routing, only print
-#                               per-layer details for up to the first
-#                               K routing calls of each type per step
-#                               (layer idx is approximated via a
-#                               per-action counter).
-# ===================================================================
+# R3 verbose-logging helpers (controlled via AREAL_R3_VERBOSE env var).
 
 
 def _r3_verbose() -> bool:
+    """Return whether R3 verbose logging is enabled."""
     return os.environ.get("AREAL_R3_VERBOSE", "1") != "0"
 
 
-_R3_LOG_CALL_COUNTS: dict[str, int] = {}
-_R3_LOG_FIRST_N = int(os.environ.get("AREAL_R3_LOG_FIRST_N", "30"))
-_R3_LOG_EVERY = int(os.environ.get("AREAL_R3_LOG_EVERY", "100"))
-_R3_TENSOR_SAMPLE = int(os.environ.get("AREAL_R3_TENSOR_SAMPLE", "8"))
-_R3_ROUTER_LAYER_LIMIT = int(os.environ.get("AREAL_R3_ROUTER_LAYER_LIMIT", "3"))
-
-
-def _r3_should_log(key: str) -> bool:
-    """Rate-limited logging gate. Returns True for the first
-    ``AREAL_R3_LOG_FIRST_N`` calls against ``key``, then True once every
-    ``AREAL_R3_LOG_EVERY`` calls thereafter. Monotonic per-process counter.
-    """
-    if not _r3_verbose():
-        return False
-    n = _R3_LOG_CALL_COUNTS.get(key, 0) + 1
-    _R3_LOG_CALL_COUNTS[key] = n
-    if n <= _R3_LOG_FIRST_N:
-        return True
-    return (n % max(_R3_LOG_EVERY, 1)) == 0
-
-
-def _r3_call_count(key: str) -> int:
-    return _R3_LOG_CALL_COUNTS.get(key, 0)
-
-
-def _r3_tensor_sig(name: str, t, *, max_sample: int | None = None) -> str:
-    """Compact human-readable fingerprint for a tensor or numpy array.
-
-    Intentionally cheap: performs ONE ``.detach().cpu()`` copy and one
-    reduction so it is safe to call from hot paths (still, prefer to gate
-    via ``_r3_should_log``).
-    """
-    if t is None:
-        return f"{name}=None"
-    sample_n = _R3_TENSOR_SAMPLE if max_sample is None else max_sample
-    try:
-        if isinstance(t, torch.Tensor):
-            tc = t.detach()
-            if tc.device.type != "cpu":
-                tc = tc.to("cpu", non_blocking=False)
-            flat = tc.reshape(-1)
-            total = int(flat.numel())
-            if total == 0:
-                return f"{name}(shape={tuple(t.shape)}, dtype={t.dtype}, empty)"
-            nnz = int((flat != 0).sum().item())
-            if tc.dtype in (
-                torch.float16,
-                torch.float32,
-                torch.float64,
-                torch.bfloat16,
-            ):
-                checksum = float(flat.float().double().sum().item())
-                maxv = float(flat.float().abs().max().item())
-                sample = [round(v, 6) for v in flat[:sample_n].float().tolist()]
-            else:
-                checksum = int(flat.long().sum().item())
-                maxv = int(flat.long().abs().max().item())
-                sample = flat[:sample_n].tolist()
-            return (
-                f"{name}(shape={tuple(t.shape)}, dtype={t.dtype}, "
-                f"device={t.device}, nnz={nnz}/{total}, "
-                f"sum={checksum}, |max|={maxv}, first{len(sample)}={sample})"
-            )
-        # numpy or generic array-like
-        if hasattr(t, "shape") and hasattr(t, "dtype"):
-            import numpy as np
-
-            arr = t if isinstance(t, np.ndarray) else np.asarray(t)
-            flat = arr.reshape(-1)
-            total = int(flat.size)
-            if total == 0:
-                return f"{name}(shape={arr.shape}, dtype={arr.dtype}, empty, numpy)"
-            nnz = int((flat != 0).sum())
-            checksum = int(flat.astype("int64").sum()) if np.issubdtype(
-                arr.dtype, np.integer
-            ) else float(flat.astype("float64").sum())
-            maxv = (
-                int(np.abs(flat).max())
-                if np.issubdtype(arr.dtype, np.integer)
-                else float(np.abs(flat).max())
-            )
-            sample = flat[:sample_n].tolist()
-            return (
-                f"{name}(shape={tuple(arr.shape)}, dtype={arr.dtype}, numpy, "
-                f"nnz={nnz}/{total}, sum={checksum}, |max|={maxv}, "
-                f"first{len(sample)}={sample})"
-            )
-    except Exception as e:  # pragma: no cover - diagnostic helper must not raise
-        return f"{name}=<sig-error: {e}>"
-
return f"{name}={type(t).__name__}" - - -def _r3_zero_row_stats(top_indices: torch.Tensor) -> str: - """Returns a string describing the count of all-zero rows in a - ``(num_tokens, topk)`` target_topk_idx tensor. All-zero rows are the - smoking gun for zero-fill hazard. - """ - if top_indices is None or top_indices.ndim < 2: - return "zero_row_stats=N/A" - try: - with torch.no_grad(): - zero_rows = (top_indices == 0).all(dim=-1) - total = int(zero_rows.numel()) - z = int(zero_rows.sum().item()) - return f"zero_rows={z}/{total} ({100.0*z/max(total,1):.2f}%)" - except Exception as e: - return f"zero_row_stats=" - - -def _r3_pp_tp_info(tf_config=None, vp_rank=None) -> str: - """Short PP/TP/DP/SP/EP context string for log lines.""" - try: - from megatron.core import parallel_state as mpu - - pp = mpu.get_pipeline_model_parallel_world_size() - pp_rank = mpu.get_pipeline_model_parallel_rank() - tp = mpu.get_tensor_model_parallel_world_size() - tp_rank = mpu.get_tensor_model_parallel_rank() - cp = getattr(mpu, "get_context_parallel_world_size", lambda: 1)() - dp = mpu.get_data_parallel_world_size() - dp_rank = mpu.get_data_parallel_rank() - return ( - f"pp={pp_rank}/{pp} tp={tp_rank}/{tp} cp={cp} dp={dp_rank}/{dp}" - + (f" vp={vp_rank}" if vp_rank is not None else "") - ) - except Exception: - return "pp=?/? tp=?/?" - - -# =================================================================== -# Root-cause hunting helpers (v2 — per-sample & per-mb fingerprints) -# =================================================================== -# We need to pinpoint whether the R3 replay indices that reach the router -# are the SAME bytes as those the rollout engine produced for the SAME -# sample. The cheapest reliable way to do this is a per-sample 64-bit -# fold-hash of the int32 tensor bytes. The hash is stable across device -# (we move to CPU once), preserves per-sample order, and survives -# reordering (each sample is hashed independently — we can then check -# any permutation via the multiset of hashes). -# -# We also expose a monotonically increasing trace-id that the actor -# increments every time it sets ``engine._r3_pending_routed_experts`` -# so each end-to-end replay can be correlated across STAGE2 → STAGE3 → -# STAGE4 log lines. -# =================================================================== - - -# Global monotonically increasing trace-id. Incremented at the side-channel -# SET site (actor._compute_logp / actor._ppo_update). Read back at the -# CONSUMPTION site in ``_r3_forward_backward_batch``. Exported via an -# env-independent module-level function so *all* stages print the same id. -_R3_TRACE_ID: int = 0 - - -def _r3_next_trace_id() -> int: - """Reserve & return a new trace-id. - - A trace-id identifies one SIDE_CHANNEL-SET -> CONSUME -> REPLAY cycle. - """ - global _R3_TRACE_ID - _R3_TRACE_ID += 1 - return _R3_TRACE_ID - - -def _r3_current_trace_id() -> int: - return _R3_TRACE_ID - - -def _r3_hash64(t) -> int: - """Return a stable 64-bit hash of a tensor/ndarray's int32 bytes. - - For a ``(bs, seqlen, L, K)`` routed_experts tensor this is cheap - (one CPU copy) and deterministic regardless of dtype conversion. - Returns 0 for ``None``. 
- """ - if t is None: - return 0 - try: - if isinstance(t, torch.Tensor): - tc = t.detach() - if tc.device.type != "cpu": - tc = tc.to("cpu", non_blocking=False) - arr = tc.to(torch.int32).contiguous().numpy() - else: - import numpy as np - - arr = (t if isinstance(t, np.ndarray) else np.asarray(t)).astype("int32", copy=False) - import hashlib - - return int.from_bytes( - hashlib.blake2b(arr.tobytes(), digest_size=8).digest(), - "big", - signed=False, - ) - except Exception: - return -1 - - -def _r3_per_sample_hashes(t, max_rows: int = 64) -> list[int]: - """Return per-sample 64-bit hashes. - - For a 4D ``(bs, seqlen, L, K)`` tensor, returns one hash per sample - (dim-0). For 3D packed ``(total_aligned, L, K)`` returns one hash - per row (capped at ``max_rows`` to keep log size sane). - """ - if t is None: - return [] - try: - if isinstance(t, torch.Tensor): - tc = t.detach() - if tc.device.type != "cpu": - tc = tc.to("cpu", non_blocking=False) - arr = tc.to(torch.int32).contiguous().numpy() - else: - import numpy as np - - arr = (t if isinstance(t, np.ndarray) else np.asarray(t)).astype("int32", copy=False) - import hashlib - - out = [] - for i in range(min(arr.shape[0], max_rows)): - b = arr[i].tobytes() - out.append( - int.from_bytes( - hashlib.blake2b(b, digest_size=8).digest()[:4], - "big", - signed=False, - ) - ) - return out - except Exception: - return [-1] - - -def _r3_per_sample_nnz(t, max_rows: int = 64) -> list[int]: - """Return per-sample non-zero counts (rows where any expert id != 0).""" - if t is None: - return [] - try: - if isinstance(t, torch.Tensor): - tc = t.detach() - if tc.device.type != "cpu": - tc = tc.to("cpu", non_blocking=False) - arr = tc.to(torch.int32).contiguous().numpy() - else: - import numpy as np - - arr = (t if isinstance(t, np.ndarray) else np.asarray(t)).astype("int32", copy=False) - out = [] - for i in range(min(arr.shape[0], max_rows)): - out.append(int((arr[i] != 0).any(axis=-1).sum())) - return out - except Exception: - return [-1] - - -def _r3_per_sample_seq_real_len(t, max_rows: int = 64) -> list[int]: - """Return per-sample "real-looking" length = index of last non-all-zero row + 1. - - Useful for verifying that the routed_experts tensor is right-padded - as expected: the real length should equal the sample's attention - mask sum (= cu_seqlens diff). If it doesn't, alignment is off. - """ - if t is None: - return [] - try: - if isinstance(t, torch.Tensor): - tc = t.detach() - if tc.device.type != "cpu": - tc = tc.to("cpu", non_blocking=False) - arr = tc.to(torch.int32).contiguous().numpy() - else: - import numpy as np - - arr = (t if isinstance(t, np.ndarray) else np.asarray(t)).astype("int32", copy=False) - out = [] - for i in range(min(arr.shape[0], max_rows)): - row_any = (arr[i] != 0).any(axis=(-1, -2)) if arr[i].ndim >= 2 else (arr[i] != 0) - nz = row_any.nonzero()[0] - out.append(int(nz[-1]) + 1 if len(nz) else 0) - return out - except Exception: - return [-1] - - # =================================================================== # Layer computation helpers # =================================================================== @@ -678,19 +363,6 @@ def set_router_replay_data( valid_mask = valid_mask & (~row_all_zero) # Step 2: CP split (before TP scatter). - # - # When ``cp_size > 1``, megatron-core's - # ``preprocess_packed_seqs_context_parallel`` has already - # interleaved-split the model's token axis so each CP rank only sees - # ``total_aligned / cp_size`` tokens. 
Router replay indices MUST - # match that layout before the TP scatter; otherwise each TP rank - # would end up with ``cp_size``x too many rows and overwrite the - # wrong positions. We reuse ``split_packed_seqs_for_context_parallel`` - # with the PADDED ``cu_seqlens`` the caller provided (see caller - # comment "Pass the PADDED cu_seqlens (with TP alignment) ..."). - # - # Contract: ``cu_seqlens`` here describes the SAME packed layout as - # ``packed`` (dim-0 aligned), which is what the engine passes in. packed = packed.to(device) valid_mask = valid_mask.to(device) cp_size = getattr(mpu, "get_context_parallel_world_size", lambda: 1)() @@ -705,7 +377,7 @@ def set_router_replay_data( valid_mask.to(torch.int32), cu_seqlens_dev ).bool() - # Step 3: Scatter to SP ranks (TP) + # Step 3: Scatter to SP ranks (TP). tp_size = mpu.get_tensor_model_parallel_world_size() if tp_size > 1: from megatron.core.tensor_parallel import scatter_to_sequence_parallel_region @@ -722,10 +394,10 @@ def set_router_replay_data( # local_tokens: (local_tokens_count, num_layers, topk) # local_mask: (local_tokens_count,) - # Step 4: Permute to (num_layers, local_tokens_count, topk) + # Step 4: Permute to (num_layers, local_tokens_count, topk). layers_topk = local_tokens.permute(1, 0, 2) - # Step 5: Distribute to RouterReplay instances for local PP layers + # Step 5: Distribute to RouterReplay instances for local PP layers. local_info = get_current_rank_layer_info(tf_config, vp_rank) offset, end = local_info["start"], local_info["end"] router_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) @@ -774,17 +446,6 @@ def set_router_replay_data( router_offset += 1 moe_idx += 1 - logger.debug( - "[R3] set_router_replay_data: distributed %d layers of replay data " - "to %d/%d router instances (PP layers %d..%d, tp_size=%d).", - router_offset, - len(router_list), - len(RouterReplay.router_instances), - offset, - end, - tp_size, - ) - # =================================================================== # Per-microbatch replay control @@ -819,15 +480,12 @@ def setup_per_microbatch_replay_forward( def setup_per_microbatch_replay_backward() -> None: """Switch to backward replay mode for activation-checkpoint recomputation.""" RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_BACKWARD) - logger.debug("[R3] Switched to backward replay mode.") def clear_router_replay() -> None: """Clear all RouterReplay state after a full forward-backward pass.""" - n_instances = len(RouterReplay.router_instances) RouterReplay.clear_global_indices() RouterReplay.clear_global_router_replay_action() - logger.debug("[R3] Router replay state cleared (%d instances).", n_instances) # =================================================================== @@ -895,7 +553,7 @@ def preprocess_routed_experts_batch( reshaped = routed_experts_np.reshape(num_sgl_tokens, num_moe_layers, topk) tensor = torch.from_numpy(reshaped.astype(np.int32)) - # Build (1, seq_len, num_moe_layers, topk) with RIGHT padding. + # SGLang returns one fewer token than the prompt+gen length. 
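+    # (Typically num_sgl_tokens == real_tokens - 1, because the routing record
+    # for the very last generated token is not captured; its row stays zero.)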
real_tokens = int(attention_mask.sum().item()) padded = torch.zeros(1, seq_len, num_moe_layers, topk, dtype=torch.int32) if num_sgl_tokens > real_tokens: @@ -938,18 +596,4 @@ def preprocess_routed_experts_batch( elif max_val < 32768: padded = padded.to(torch.int16) - right_pad = seq_len - real_tokens - logger.debug( - "[R3] preprocess_routed_experts_batch: shape=%s dtype=%s " - "(num_moe_layers=%d, topk=%d, sgl_tokens=%d, real_tokens=%d, " - "right_pad=%d).", - padded.shape, - padded.dtype, - num_moe_layers, - topk, - num_sgl_tokens, - real_tokens, - right_pad, - ) - return padded diff --git a/areal/infra/launcher/sglang_r3_patch.py b/areal/infra/launcher/sglang_r3_patch.py index e656d779f3..f1cd59be97 100644 --- a/areal/infra/launcher/sglang_r3_patch.py +++ b/areal/infra/launcher/sglang_r3_patch.py @@ -1,42 +1,11 @@ -"""SGLang server-side monkey patches required for AReaL's R3 (Router Replay). - -Background ----------- -When ``skip_tokenizer_init=True`` (which AReaL forces whenever R3 is enabled, -see ``rl_trainer.py``), the SGLang *scheduler* bypasses the -``DetokenizerManager`` and sends ``BatchTokenIDOutput`` directly to -``TokenizerManager``. The side effect is that the ``routed_experts`` tensor -is **not** base64-encoded by ``DetokenizerManager._extract_routed_experts`` -anymore -- it reaches ``TokenizerManager._handle_batch_output`` still as a -raw ``torch.Tensor``. - -``TokenizerManager`` then attaches the tensor verbatim to -``meta_info["routed_experts"]`` and lets FastAPI/``ORJSONResponse`` -serialize the whole response. FastAPI's ``jsonable_encoder`` does not -know how to encode ``torch.Tensor``; it silently returns an **empty -dict** (``{}``) instead of raising. The client then receives - - meta_info["routed_experts"] == {} - -and hits ``TypeError: int() argument must be a string, a bytes-like -object or a real number, not 'dict'`` when it tries -``np.asarray(routed_experts, dtype=np.int32)``. Because the error is -swallowed in ``parse_generation_response``, ``routed_experts`` becomes -``None`` and the downstream ``RemoteInfEngine`` raises:: - - RuntimeError: Requested return_routed_experts=True but received None - from SGLang - -This module installs a monkey patch on -``sglang.srt.managers.tokenizer_manager.TokenizerManager._handle_batch_output`` -that base64-encodes the tensor in-place *before* it is serialised (exactly -the same encoding that ``DetokenizerManager._extract_routed_experts`` -applies in the non-``skip_tokenizer_init`` path), so the wire format stays -consistent with both SGLang's documented behaviour and AReaL's client-side -decoder in ``areal/engine/sglang_remote.py``. - -The patch is idempotent. When R3 is disabled the patch is a no-op at -runtime because the routed-experts attribute stays ``None``. +"""SGLang server-side monkey patches for AReaL's R3 (Router Replay). + +When ``skip_tokenizer_init=True`` (forced by R3), SGLang's ``TokenizerManager`` +receives raw ``torch.Tensor`` routed_experts instead of base64-encoded strings. +FastAPI's ``jsonable_encoder`` silently converts tensors to ``{}``, breaking +the client-side decoder. This patch base64-encodes the tensor in-place before +serialization, matching the format that ``DetokenizerManager`` produces in the +non-skip path. """ from __future__ import annotations @@ -52,17 +21,12 @@ def _encode_routed_experts_for_wire(value): """Convert ``routed_experts`` to a base64 string for wire transport. 
- Mirrors ``sglang.srt.managers.detokenizer_manager - ._extract_routed_experts``: each request's tensor is encoded as - ``pybase64.b64encode(tensor.numpy().tobytes()).decode("utf-8")``. - - Accepts tensors, numpy arrays, or already-encoded strings; other - types are returned unchanged so the client branch can surface them. + Mirrors ``DetokenizerManager._extract_routed_experts``. Accepts tensors, + numpy arrays, or already-encoded strings; other types returned unchanged. """ if value is None: return None if isinstance(value, str): - # Already encoded by DetokenizerManager or a prior invocation. return value try: import numpy as np @@ -72,18 +36,12 @@ def _encode_routed_experts_for_wire(value): return value if isinstance(value, torch.Tensor): - # ``to("cpu")`` is a no-op when already on CPU but protects us - # against exotic device placements (e.g. CUDA tensors leaked by - # a capture buffer). ``contiguous()`` guarantees ``tobytes`` - # produces a dense layout matching ``shape`` on the decode side. arr = value.detach().to("cpu").contiguous().numpy() elif isinstance(value, np.ndarray): arr = np.ascontiguousarray(value) else: return value - # Normalise dtype to int32 so the client's ``np.frombuffer(..., int32)`` - # matches regardless of whether the capture buffer was int64. if arr.dtype != np.int32: arr = arr.astype(np.int32, copy=False) @@ -114,10 +72,7 @@ def apply_sglang_r3_patch() -> bool: original = _tm.TokenizerManager._handle_batch_output def _handle_batch_output_r3(self, recv_obj): # type: ignore[no-redef] - # Pre-encode the routed-experts tensor so the downstream FastAPI - # serialisation sees a plain string (which ``jsonable_encoder`` - # passes through untouched) instead of a ``torch.Tensor`` (which - # ``jsonable_encoder`` silently flattens to ``{}``). + # Pre-encode routed_experts tensors to base64 before FastAPI serialization. try: for attr_name in _ROUTED_EXPERTS_ATTRS: re_list = getattr(recv_obj, attr_name, None) @@ -127,9 +82,7 @@ def _handle_batch_output_r3(self, recv_obj): # type: ignore[no-redef] try: setattr(recv_obj, attr_name, encoded) except Exception: - # Some SGLang versions freeze the dataclass; fall back - # to object.__setattr__ which bypasses __slots__ / - # frozen protection. + # Frozen dataclass fallback. object.__setattr__(recv_obj, attr_name, encoded) except Exception: # pragma: no cover - defensive logger.exception( diff --git a/areal/trainer/ppo/actor_r3_patch.py b/areal/trainer/ppo/actor_r3_patch.py index e9ad996769..256f4ba6a0 100644 --- a/areal/trainer/ppo/actor_r3_patch.py +++ b/areal/trainer/ppo/actor_r3_patch.py @@ -1,10 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 -""" -R3 data-splitting helpers for the PPO actor. +"""R3 data-splitting helpers for the PPO actor. -Provides utilities for resolving ``routed_experts`` tensors and splitting -them across mini-batches for side-channel delivery to the training engine. +Resolves and splits ``routed_experts`` tensors across mini-batches for +side-channel delivery to the training engine. 
""" from __future__ import annotations @@ -85,14 +84,6 @@ def split_routed_experts_for_minibatches( result.append(reordered[offset : offset + n_samples]) offset += n_samples - logger.debug( - "[R3] split_routed_experts_for_minibatches: split %d samples into " - "%d mini-batches with sizes %s (forward_indices=%s).", - routed_experts.shape[0], - n_mbs, - [r.shape[0] for r in result], - "None" if forward_indices is None else f"len={len(forward_indices)}", - ) return result diff --git a/areal/workflow/rlvr_r3_patch.py b/areal/workflow/rlvr_r3_patch.py index 346be74db4..fbf1c09090 100644 --- a/areal/workflow/rlvr_r3_patch.py +++ b/areal/workflow/rlvr_r3_patch.py @@ -1,20 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 -""" -R3 helpers for the RLVR workflow. - -These functions bridge the inference-time ``ModelResponse.routed_experts`` -(a numpy array of shape ``(num_sgl_tokens, num_moe_layers * topk)``) into the -training-side tensor dict so that the Megatron engine can replay routing -decisions. - -The conversion pipeline: - 1. ``extract_routed_experts`` -- called in ``arun_episode`` right after - ``_collect_samples``. Converts the numpy array to a left-padded - torch tensor of shape ``(1, seq_len, num_moe_layers, topk)``. - 2. The tensor is added to the result dict under key ``"routed_experts"``. - 3. During training, the ``MegatronEngine`` R3 patch picks it up from - the batch data and feeds it to ``setup_per_microbatch_replay_forward``. +"""R3 helpers for the RLVR workflow. + +Bridges inference-time ``ModelResponse.routed_experts`` into the training-side +tensor dict so that the Megatron engine can replay routing decisions. """ from __future__ import annotations From 49680c50525e4413b67523ca0fbf40f7e11d377f Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 7 May 2026 13:32:50 +0800 Subject: [PATCH 107/112] refactor: remove diag --- areal/engine/core/train_engine.py | 69 -------- areal/engine/megatron_engine.py | 237 ---------------------------- areal/engine/router_replay_utils.py | 10 -- areal/trainer/ppo/actor.py | 32 ---- areal/utils/data.py | 73 --------- areal/utils/logging.py | 3 - 6 files changed, 424 deletions(-) diff --git a/areal/engine/core/train_engine.py b/areal/engine/core/train_engine.py index 195509e13b..9bc83c2b77 100644 --- a/areal/engine/core/train_engine.py +++ b/areal/engine/core/train_engine.py @@ -19,9 +19,6 @@ reorder_list, unpack_sequence, ) -from areal.utils.logging import getLogger as _getLogger - -_SPLIT_DIAG_LOGGER = _getLogger("R3SplitDiag") __all__ = [ "compute_total_loss_weight", @@ -142,72 +139,6 @@ def reorder_and_pad_outputs( """ res = aggregate_fn(outputs) seqlens = [output_seqlens[i] for i in mb_list.forward_indices] - # [SPLIT_MISMATCH_DIAG] Log EVERYTHING needed to root-cause the - # `split_with_sizes` mismatch reported during compute_logp. 
- try: - _rank = None - try: - import torch.distributed as _dist - if _dist.is_available() and _dist.is_initialized(): - _rank = _dist.get_rank() - except Exception: - _rank = None - _out_shapes = [tuple(o.shape) for o in outputs] - _out_sum0 = [int(o.shape[0]) for o in outputs if o.ndim >= 1] - _sum_seqlens = int(sum(seqlens)) - _res_shape = tuple(res.shape) - _res_dim0 = int(res.shape[0]) if res.ndim >= 1 else -1 - _fwd_idx = list(mb_list.forward_indices) - _bwd_idx = list(mb_list.backward_indices) - _mbs_lens = None - try: - _mbs_lens = [ - int(mb.get("cu_seqlens", torch.empty(0))[-1].item()) - if isinstance(mb.get("cu_seqlens", None), torch.Tensor) - and mb["cu_seqlens"].numel() > 0 - else None - for mb in getattr(mb_list, "mbs", []) - ] - except Exception: - _mbs_lens = "ERR" - _padded_lens = getattr(mb_list, "padded_to_lengths", None) - _group_lens = getattr(mb_list, "group_lens", None) - _padding_lens = getattr(mb_list, "padding_lengths", None) - _SPLIT_DIAG_LOGGER.info( - "[SPLIT_MISMATCH_DIAG][reorder_and_pad_outputs] rank=%s " - "n_outputs=%d out_shapes=%s out_sum_dim0=%s sum_out_dim0=%d " - "res.shape=%s res.dim0=%d " - "len(output_seqlens)=%d sum(output_seqlens)=%d " - "len(seqlens_reordered)=%d sum(seqlens_reordered)=%d " - "forward_indices=%s backward_indices=%s " - "mb_real_total_lens=%s padded_to_lengths=%s " - "group_lens=%s padding_lengths=%s " - "match=%s output_seqlens_head=%s output_seqlens_tail=%s", - _rank, - len(outputs), - _out_shapes, - _out_sum0, - int(sum(_out_sum0)), - _res_shape, - _res_dim0, - len(output_seqlens), - int(sum(output_seqlens)), - len(seqlens), - _sum_seqlens, - _fwd_idx, - _bwd_idx, - _mbs_lens, - _padded_lens, - _group_lens, - _padding_lens, - (_res_dim0 == _sum_seqlens), - list(output_seqlens[:16]), - list(output_seqlens[-16:]), - ) - except Exception: - _SPLIT_DIAG_LOGGER.exception( - "[SPLIT_MISMATCH_DIAG][reorder_and_pad_outputs] log-emit failed" - ) unpacked = unpack_sequence(res, lens=seqlens, dim=0) reordered = reorder_list(unpacked, mb_list.backward_indices) return pad_and_stack_tensors_along_first_dim(reordered) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index ca031a27ee..5ef2252587 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -94,9 +94,6 @@ from areal.models.tree_attn.tree import build_packed_tree_batch from areal.utils import logging, name_resolve, names, perf_tracer, stats_tracker -# [SPLIT_MISMATCH_DIAG] Module-level diagnostic logger for the -# `split_with_sizes` mismatch hunt in compute_logp. Yellow color, INFO level. -_R3_FWD_DIAG_LOGGER = logging.getLogger("R3FwdDiag") from areal.utils.constants import ( DEFAULT_VECTORIZED_ALIGNMENT_BYTES, DIST_GROUP_DEFAULT_TIMEOUT, @@ -950,44 +947,6 @@ def forward_step(batch_iter, model): _ = cp_size # kept for the dead CP-local branch below cp_local = False - # [SPLIT_MISMATCH_DIAG] Log per-MB padded input geometry BEFORE forward. 
- try: - _padded_mb = mb_input.padded_mb - _orig_mb = mb_input.orig_mb - _R3_FWD_DIAG_LOGGER.info( - "[SPLIT_MISMATCH_DIAG][forward_step] PRE_FWD rank=%s " - "padded_to=%s padding_length=%s " - "padded.input_ids.shape=%s " - "padded.cu_seqlens=%s " - "old_cu_seqlens=%s " - "orig.input_ids.shape=%s " - "cp_size=%d cp_local=%s", - dist.get_rank() if dist.is_initialized() else None, - getattr(mb_input, "padded_to_length", None), - getattr(mb_input, "padding_length", None), - tuple(_padded_mb["input_ids"].shape) - if "input_ids" in _padded_mb - else None, - _padded_mb["cu_seqlens"].cpu().tolist() - if "cu_seqlens" in _padded_mb - and isinstance(_padded_mb["cu_seqlens"], torch.Tensor) - else None, - mb_input.old_cu_seqlens.cpu().tolist() - if isinstance( - getattr(mb_input, "old_cu_seqlens", None), torch.Tensor - ) - else getattr(mb_input, "old_cu_seqlens", None), - tuple(_orig_mb["input_ids"].shape) - if "input_ids" in _orig_mb - else None, - cp_size, - cp_local, - ) - except Exception: - _R3_FWD_DIAG_LOGGER.exception( - "[SPLIT_MISMATCH_DIAG][forward_step PRE_FWD] log-emit failed" - ) - output = packed_context_parallel_forward( model, mb_input.padded_mb, @@ -995,26 +954,6 @@ def forward_step(batch_iter, model): is_vision_model=self.is_vision_model, ) - # [SPLIT_MISMATCH_DIAG] Log raw forward output shape (pre-unpad). - try: - _R3_FWD_DIAG_LOGGER.info( - "[SPLIT_MISMATCH_DIAG][forward_step] POST_FWD rank=%s " - "raw_output.shape=%s raw_output.dim0=%s is_pp_last=%s", - dist.get_rank() if dist.is_initialized() else None, - tuple(output.shape) if hasattr(output, "shape") else None, - int(output.shape[0]) - if hasattr(output, "shape") and output.ndim >= 1 - else -1, - mpu.is_pipeline_last_stage( - ignore_virtual=False, - vp_stage=getattr(model, "vp_stage", 0), - ), - ) - except Exception: - _R3_FWD_DIAG_LOGGER.exception( - "[SPLIT_MISMATCH_DIAG][forward_step POST_FWD] log-emit failed" - ) - # Release tree attention metadata after forward pass for key in tree_attn_keys: del mb_input.padded_mb[key] @@ -1047,40 +986,12 @@ def _process_output(input_, output_): cp_inputs["cu_seqlens"] = cp_cu_seqlens return output, functools.partial(_process_output, cp_inputs) else: - _pre_shape = ( - tuple(output.shape) if hasattr(output, "shape") else None - ) output = unpad_logits( output, padding_length=mb_input.padding_length, cu_seqlens=cu_seqlens, old_cu_seqlens=mb_input.old_cu_seqlens, ) - # [SPLIT_MISMATCH_DIAG] Log unpad result. - try: - _R3_FWD_DIAG_LOGGER.info( - "[SPLIT_MISMATCH_DIAG][forward_step] UNPAD_LOGITS " - "rank=%s pre_shape=%s post_shape=%s " - "padding_length=%s " - "cu_seqlens_last=%s old_cu_seqlens_last=%s", - dist.get_rank() if dist.is_initialized() else None, - _pre_shape, - tuple(output.shape) if hasattr(output, "shape") else None, - getattr(mb_input, "padding_length", None), - int(cu_seqlens[-1].item()) - if isinstance(cu_seqlens, torch.Tensor) - else cu_seqlens, - int(mb_input.old_cu_seqlens[-1].item()) - if isinstance( - getattr(mb_input, "old_cu_seqlens", None), - torch.Tensor, - ) - else getattr(mb_input, "old_cu_seqlens", None), - ) - except Exception: - _R3_FWD_DIAG_LOGGER.exception( - "[SPLIT_MISMATCH_DIAG][forward_step UNPAD] log-emit failed" - ) return output, functools.partial(_process_output, mb_input.orig_mb) forward_backward_func = get_forward_backward_func() @@ -1205,57 +1116,6 @@ def forward_batch( input_batched, meta = self._normalize_batch_input(input_) - # [SPLIT_MISMATCH_DIAG] Log inbound forward_batch arguments. 
- try: - _rank = dist.get_rank() if dist.is_initialized() else None - _is_list = isinstance(input_, list) - if _is_list: - _attn_shapes = [ - tuple(d["attention_mask"].shape) if "attention_mask" in d else None - for d in input_ - ] - _attn_widths = [ - int(d["attention_mask"].shape[-1]) - if "attention_mask" in d - else None - for d in input_ - ] - _input_summary = ( - f"list len={len(input_)} attn_shapes_head={_attn_shapes[:8]} " - f"attn_widths_head={_attn_widths[:8]} " - f"sum_attn_widths={sum(w for w in _attn_widths if w)}" - ) - else: - _ks = sorted(list(input_.keys())) if isinstance(input_, dict) else "N/A" - _am = ( - tuple(input_["attention_mask"].shape) - if isinstance(input_, dict) and "attention_mask" in input_ - else None - ) - _ii = ( - tuple(input_["input_ids"].shape) - if isinstance(input_, dict) and "input_ids" in input_ - else None - ) - _input_summary = ( - f"dict keys={_ks} attention_mask.shape={_am} input_ids.shape={_ii}" - ) - _R3_FWD_DIAG_LOGGER.info( - "[SPLIT_MISMATCH_DIAG][forward_batch] ENTER rank=%s " - "is_list=%s meta_is_None=%s output_seqlens_arg=%s | %s", - _rank, - _is_list, - meta is None, - None - if output_seqlens is None - else f"len={len(output_seqlens)} sum={sum(output_seqlens)}", - _input_summary, - ) - except Exception: - _R3_FWD_DIAG_LOGGER.exception( - "[SPLIT_MISMATCH_DIAG][forward_batch ENTER] log-emit failed" - ) - # Step 1: Prepare sequence lengths if meta is not None: assert isinstance(input_, list) @@ -1273,32 +1133,6 @@ def forward_batch( assert output_seqlens is not None batch_size = len(output_seqlens) - # [SPLIT_MISMATCH_DIAG] Log derived sequence-length structure. - try: - _diff = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().tolist() - _R3_FWD_DIAG_LOGGER.info( - "[SPLIT_MISMATCH_DIAG][forward_batch] SEQLENS rank=%s " - "batch_size=%d cu_seqlens.shape=%s cu_seqlens_total=%d " - "real_lens_head=%s real_lens_tail=%s sum_real_lens=%d " - "output_seqlens_head=%s output_seqlens_tail=%s " - "sum_output_seqlens=%d output_seqlens_path=%s", - dist.get_rank() if dist.is_initialized() else None, - batch_size, - tuple(cu_seqlens.shape), - int(cu_seqlens[-1].item()), - _diff[:16], - _diff[-16:], - int(sum(_diff)), - list(output_seqlens[:16]), - list(output_seqlens[-16:]), - int(sum(output_seqlens)), - "from_attention_mask" if meta is not None else "from_cu_seqlens_diff", - ) - except Exception: - _R3_FWD_DIAG_LOGGER.exception( - "[SPLIT_MISMATCH_DIAG][forward_batch SEQLENS] log-emit failed" - ) - # Step 2: Prepare micro-batches mb_list = self._prepare_mb_list(input_batched).to(self.device) @@ -1308,32 +1142,6 @@ def forward_batch( def process_output(output: torch.Tensor, inputs: dict[str, Any]) -> None: result = self._compute_forward_result(output, inputs) outputs.append(result) - # [SPLIT_MISMATCH_DIAG] Log per-MB collected output shape. 
- try: - _R3_FWD_DIAG_LOGGER.info( - "[SPLIT_MISMATCH_DIAG][forward_batch.process_output] " - "rank=%s mb_idx=%d output.shape=%s output.dim0=%d " - "input_keys=%s input_ids.shape=%s cu_seqlens=%s " - "input_ids_total=%s", - dist.get_rank() if dist.is_initialized() else None, - len(outputs) - 1, - tuple(result.shape) if hasattr(result, "shape") else None, - int(result.shape[0]) - if hasattr(result, "shape") and result.ndim >= 1 - else -1, - sorted(list(inputs.keys())) if isinstance(inputs, dict) else "N/A", - tuple(inputs["input_ids"].shape) if "input_ids" in inputs else None, - inputs["cu_seqlens"].cpu().tolist() - if "cu_seqlens" in inputs - and isinstance(inputs["cu_seqlens"], torch.Tensor) - else None, - int(inputs["input_ids"].numel()) if "input_ids" in inputs else None, - ) - except Exception: - _R3_FWD_DIAG_LOGGER.exception( - "[SPLIT_MISMATCH_DIAG][forward_batch.process_output] " - "log-emit failed" - ) return None self.forward_backward_batch(mb_list, process_output, forward_only=True) @@ -2236,33 +2044,6 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList: pad_to_maximum=self.config.pad_to_maximum, seq_align_to=align_to_multiple_of, ) - # [SPLIT_MISMATCH_DIAG] Log mb_list geometry. - try: - _R3_FWD_DIAG_LOGGER.info( - "[SPLIT_MISMATCH_DIAG][_prepare_mb_list] rank=%s " - "n_mbs=%d align_to_multiple_of=%d " - "group_lens=%s padded_to_lengths=%s padding_lengths=%s " - "align_to_lengths=%s " - "forward_indices=%s backward_indices=%s " - "max_seqlen=%s pp=%d cp=%d tp=%d", - dist.get_rank() if dist.is_initialized() else None, - len(mb_list.mbs), - align_to_multiple_of, - getattr(mb_list, "group_lens", None), - getattr(mb_list, "padded_to_lengths", None), - getattr(mb_list, "padding_lengths", None), - getattr(mb_list, "align_to_lengths", None), - list(getattr(mb_list, "forward_indices", []) or []), - list(getattr(mb_list, "backward_indices", []) or []), - getattr(mb_list, "max_seqlen", None), - pp_size, - cp_size, - tp_size, - ) - except Exception: - _R3_FWD_DIAG_LOGGER.exception( - "[SPLIT_MISMATCH_DIAG][_prepare_mb_list] log-emit failed" - ) self.logger.info( f"#microbatch: {len(mb_list.group_lens)}, microbatch #tokens: {mb_list.group_lens}, " f"aligned to: {mb_list.align_to_lengths}, padded to: {mb_list.padded_to_lengths}, " @@ -2389,24 +2170,6 @@ def _compute_forward_result( if mpu.get_tensor_model_parallel_world_size() > 1 else None, ) - # [SPLIT_MISMATCH_DIAG] Log gather_logprobs i/o shapes. 
- try: - _R3_FWD_DIAG_LOGGER.info( - "[SPLIT_MISMATCH_DIAG][_compute_forward_result] " - "rank=%s output.shape=%s labels.shape=%s " - "logprobs.shape=%s logprobs.dim0=%s", - dist.get_rank() if dist.is_initialized() else None, - tuple(output.shape) if hasattr(output, "shape") else None, - tuple(labels.shape) if hasattr(labels, "shape") else None, - tuple(logprobs.shape) if hasattr(logprobs, "shape") else None, - int(logprobs.shape[0]) - if hasattr(logprobs, "shape") and logprobs.ndim >= 1 - else -1, - ) - except Exception: - _R3_FWD_DIAG_LOGGER.exception( - "[SPLIT_MISMATCH_DIAG][_compute_forward_result] log-emit failed" - ) return logprobs else: values = output.squeeze(-1) diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py index d6ffa19fbe..d458b93b4a 100644 --- a/areal/engine/router_replay_utils.py +++ b/areal/engine/router_replay_utils.py @@ -7,7 +7,6 @@ from __future__ import annotations import inspect -import os from typing import Optional import torch @@ -20,15 +19,6 @@ # name so the logger survives the dictConfig(disable_existing_loggers=True) re-init path. logger = logging.getLogger("R3/utils") - -# R3 verbose-logging helpers (controlled via AREAL_R3_VERBOSE env var). - - -def _r3_verbose() -> bool: - """Return whether R3 verbose logging is enabled.""" - return os.environ.get("AREAL_R3_VERBOSE", "1") != "0" - - # =================================================================== # Layer computation helpers # =================================================================== diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index 9a10e296bd..a34e3ac590 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -153,38 +153,6 @@ def _compute_logp(self, data: dict[str, Any]) -> torch.Tensor | None: input_=data, aggregate_fn=lambda xs: torch.cat(xs, dim=-1), ) - # [SPLIT_MISMATCH_DIAG] Log shape of returned train_logp so we can - # correlate with per-MB output shapes logged by forward_batch. - try: - import torch.distributed as _dist - - from areal.utils.logging import getLogger as _getLogger - - _diag = _getLogger("R3FwdDiag") - _diag.info( - "[SPLIT_MISMATCH_DIAG][actor._compute_logp] POST_FORWARD " - "rank=%s train_logp.shape=%s train_logp.dim0=%s " - "data.input_ids.shape=%s data.attention_mask.shape=%s " - "data.attention_mask.sum=%s", - _dist.get_rank() if _dist.is_initialized() else None, - tuple(train_logp.shape) - if isinstance(train_logp, torch.Tensor) - else None, - int(train_logp.shape[0]) - if isinstance(train_logp, torch.Tensor) and train_logp.ndim >= 1 - else -1, - tuple(data["input_ids"].shape) if "input_ids" in data else None, - tuple(data["attention_mask"].shape) - if "attention_mask" in data - else None, - int(data["attention_mask"].sum().item()) - if "attention_mask" in data - else None, - ) - except Exception: - logger.exception( - "[SPLIT_MISMATCH_DIAG][actor._compute_logp POST_FORWARD] log-emit failed" - ) # R3 effectiveness metrics. 
At compute_logp time the training weights # equal the rollout weights (no optimizer step has touched θ in this # rollout epoch), so comparing SGLang's cached logprobs against the diff --git a/areal/utils/data.py b/areal/utils/data.py index 17568927ad..09368e4ef2 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -437,35 +437,6 @@ def unpack_sequence( ): """Unpack a sequence tensor into a list of tensors based on cumulative sequence lengths.""" if lens is not None: - # [SPLIT_MISMATCH_DIAG] Pre-flight check that emits a detailed log - # BEFORE torch.split raises so we capture the exact mismatch. - try: - _x_dim = int(x.shape[dim]) if x.ndim > dim else -1 - _sum_lens = int(sum(lens)) - if _x_dim != _sum_lens: - import torch.distributed as _dist - from areal.utils.logging import getLogger as _getLogger - _diag = _getLogger("R3SplitDiag") - _diag.error( - "[SPLIT_MISMATCH_DIAG][unpack_sequence] PRE_SPLIT_MISMATCH " - "rank=%s x.shape=%s dim=%d x.size(dim)=%d " - "len(lens)=%d sum(lens)=%d delta=%d " - "lens_head=%s lens_tail=%s " - "min_len=%s max_len=%s", - _dist.get_rank() if _dist.is_initialized() else None, - tuple(x.shape), - dim, - _x_dim, - len(lens), - _sum_lens, - _sum_lens - _x_dim, - list(lens[:16]), - list(lens[-16:]), - min(lens) if lens else None, - max(lens) if lens else None, - ) - except Exception: - pass return torch.split(x, lens, dim=dim) if cu_seqlens is not None: return torch.split( @@ -1078,16 +1049,6 @@ def unpad_logits( cu_seqlens: torch.Tensor | None = None, old_cu_seqlens: torch.Tensor | None = None, ): - # [SPLIT_MISMATCH_DIAG] Log unpad_logits inputs/outputs centrally so we - # observe ALL call sites (engine forward_step, etc.). - try: - import torch.distributed as _dist - from areal.utils.logging import getLogger as _getLogger - _diag = _getLogger("R3SplitDiag") - _in_shape = tuple(logits.shape) - except Exception: - _diag = None - _in_shape = None # TODO: when using megatron, logits are in fp32, # create new logits in bucket to reduce peak memory usage # First unpad batch @@ -1108,42 +1069,8 @@ def unpad_logits( start = cu_seqlens[i].item() length = old_end - old_start new_logits[old_start:old_end] = logits[start : start + length] - if _diag is not None: - try: - _diag.info( - "[SPLIT_MISMATCH_DIAG][unpad_logits] rank=%s " - "in_shape=%s padding_length=%d " - "after_pad_strip_shape=%s out_shape=%s " - "cu_seqlens_last=%s old_cu_seqlens_last=%s " - "batch_size=%d", - _dist.get_rank() if _dist.is_initialized() else None, - _in_shape, - int(padding_length), - tuple(logits.shape), - tuple(new_logits.shape), - int(cu_seqlens[-1].item()) - if cu_seqlens is not None - else None, - int(old_cu_seqlens[-1].item()), - batch_size, - ) - except Exception: - pass return new_logits - if _diag is not None: - try: - _diag.info( - "[SPLIT_MISMATCH_DIAG][unpad_logits] rank=%s " - "in_shape=%s padding_length=%d out_shape=%s " - "old_cu_seqlens=None_branch", - _dist.get_rank() if _dist.is_initialized() else None, - _in_shape, - int(padding_length), - tuple(logits.shape), - ) - except Exception: - pass return logits diff --git a/areal/utils/logging.py b/areal/utils/logging.py index 0c1ab583c3..6752134a76 100644 --- a/areal/utils/logging.py +++ b/areal/utils/logging.py @@ -122,9 +122,6 @@ "InferenceRouter": "white", "InferenceGateway": "white", "RPCGuard": "white", - # R3 diagnostic loggers - yellow (debug) - "R3SplitDiag": "yellow", - "R3FwdDiag": "yellow", } # Prefix patterns checked in order (first match wins) From 1af8cc335f5bf3490b65a0c94431a72a15c3bd2d Mon Sep 17 
00:00:00 2001 From: bingyechen Date: Thu, 7 May 2026 13:56:29 +0800 Subject: [PATCH 108/112] refactor(megatron_engine): condense CP-local and optimizer comments
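The two oversized comment blocks in this diff are condensed without changing
behavior. For the CP-local shape mismatch described in the first block, a
minimal sketch of the failure mode (illustrative numbers and plain torch
only, not code from this repo):

    import torch

    cp_size = 2
    output_seqlens = [300, 724]                  # full real lengths, sum = 1024
    full = torch.randn(sum(output_seqlens))      # full-length output rows
    cp_local = full[: full.numel() // cp_size]   # CP-local shard: 512 rows

    torch.split(full, output_seqlens)        # splits cleanly
    # torch.split(cp_local, output_seqlens)  # raises: sizes sum to 1024 but
    #                                        # the tensor has only 512 rows

The derived precision-aware optimizer flag keeps the same condition, which
reads as (a sketch reusing the names from the diff as plain values):

    use_pa_no_fp8 = use_precision_aware_optimizer and (
        main_params_dtype != "float32"       # low-precision master params
        or fp8_recipe in (None, "delayed")   # no real-time FP8 scaling
        or optimizer_cpu_offload             # extra precision conversions
    )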
--- areal/engine/megatron_engine.py | 61 ++++++++------------------------- 1 file changed, 14 insertions(+), 47 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 5ef2252587..2862e8117d 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -918,32 +918,13 @@ def forward_step(batch_iter, model): tree_attn_keys = list(tree_kwargs.keys()) cp_size = mpu.get_context_parallel_world_size() - # Disable the CP-local output path unconditionally. - # - # The CP-local branch below only rewrites ``loss_mask`` / - # ``cu_seqlens`` / ``_cp_local_labels`` inside ``cp_inputs``, but - # leaves the other per-token tensors carried by ``mb_input.orig_mb`` - # (``logprobs``, ``prox_logp``, ``advantages``, ``rewards``, - # ``versions``, ...) at their original full real-sequence length. - # Downstream consumers then see inconsistent shapes, for example: - # * ``forward_batch`` (compute_logp / advantages) aggregates the - # per-MB logprobs via ``reorder_and_pad_outputs`` which splits - # by ``output_seqlens`` (full real lens). CP-local outputs are - # ``padded_to_length // cp_size`` rows and break - # ``torch.split_with_sizes``. - # * ``train_batch`` (ppo_update) routes the CP-local ``loss_mask`` - # into ``grpo_loss_fn`` / ``ppo_actor_loss_fn`` / - # ``apply_rejection_sampling`` together with the full-length - # ``proximal_logprobs`` / ``old_logprobs``, raising - # ``proximal_logprobs shape [N] != loss_mask shape [N/cp]``. - # - # Going through ``packed_context_parallel_forward(gather_cp_output= - # True)`` + ``unpad_logits`` restores each MB's output to the full - # real-sequence length, so it aligns with every tensor in - # ``orig_mb`` regardless of whether we are in forward-only or - # training mode. The attention / MoE forward still runs CP-sharded - # inside ``packed_context_parallel_forward``; only the final - # logits are all-gathered along the CP dimension. + # CP-local path only rewrites loss_mask/cu_seqlens, but orig_mb tensors + # (logprobs, advantages, etc.) retain full sequence length, causing shape + # mismatches downstream (e.g. reorder_and_pad_outputs split failure, or + # loss_mask vs proximal_logprobs dimension mismatch in ppo_update). + # Force gather_cp_output=True + unpad_logits to restore full-length output. + # CP sharding still runs inside packed_context_parallel_forward; only the + # final logits are all-gathered along the CP dimension. _ = cp_size # kept for the dead CP-local branch below cp_local = False @@ -1412,27 +1393,13 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None: mcore_opt_config.exp_avg_sq_dtype = getattr( torch, self.mcore_config.exp_avg_sq_dtype ) - # Precision-aware optimizer for MoE models - # ----------------------------------------- - # When ``use_precision_aware_optimizer=True``, Megatron-Core stores - # optimizer states (params, grads, exp_avg, exp_avg_sq) in - # user-specified dtypes and compensates for low-precision rounding - # via fp32 remainders. Two additional flags are set here: - # - # ``use_precision_aware_optimizer_no_fp8_or_ds_fp8``: - # A derived flag that is True when the precision-aware optimizer - # is enabled AND at least one of the following holds: - # 1. ``main_params_dtype != float32`` -- low-precision master - # params (e.g. bf16) need rounding-error compensation. - # 2. ``fp8_recipe is None or "delayed"`` -- no real-time FP8 - scaling, so the optimizer must handle quantisation error - itself (delayed-scaling FP8 updates the scale factor with - a lag, conflicting with the optimizer's immediate residual - correction). - 3. ``optimizer_cpu_offload`` -- CPU offload introduces extra - precision conversions that require special handling. - When True, the optimizer applies its full residual-compensation - logic during the parameter update step. + # Precision-aware optimizer: stores optimizer states in user-specified + # dtypes and compensates low-precision rounding via fp32 remainders. + # ``use_precision_aware_optimizer_no_fp8_or_ds_fp8`` is True when + # precision-aware optimizer is enabled AND any of: (1) main_params_dtype + # != fp32, (2) no real-time FP8 scaling (fp8_recipe is None/delayed), + # (3) optimizer_cpu_offload. When True, full residual-compensation is + # applied during the parameter update step. mcore_opt_config.use_precision_aware_optimizer_no_fp8_or_ds_fp8 = ( mcore_opt_config.use_precision_aware_optimizer and ( From 8fc8028e73d66f972caa5e7be88861f4bb646661 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 7 May 2026 14:10:42 +0800 Subject: [PATCH 109/112] refactor(ppo): remove unused R3 logprob-diff diagnostics
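The removed stats (rollout_train_logprobs_abs_diff and its _prox variants)
were diagnostics for validating R3 effectiveness and feed nothing in the
loss itself. For reference, each one reduced to a masked, token-weighted
diff of this shape (a sketch of the deleted code):

    with torch.no_grad():
        diff = (old_logp - logprobs.detach()).abs().float()
        diff = torch.where(loss_mask, diff, torch.zeros_like(diff))
        stats_tracker.stat(
            rollout_train_logprobs_abs_diff=diff,
            denominator="n_valid_tokens",  # token-weighted global mean
        )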
--- areal/trainer/ppo/actor.py | 46 -------------------------------------- 1 file changed, 46 deletions(-) diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index a34e3ac590..ce03d4bd0e 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -697,52 +697,6 @@ def grpo_loss_fn( denominator="n_valid_tokens", ) - # ---- R3 Logprob Diff: rollout (inference) vs training logprobs ---- - # Two metrics quantify train/infer divergence at different granularities: - # - # 1. ``rollout_train_logprobs_abs_diff`` (old_logp vs current-train logprobs) - # Conflates (a) train/infer gap AND (b) intra-ppo_update weight drift - # from earlier mini-batches. At ``max_head_offpolicyness>0`` this metric - # grows with mini-batch index because ``logprobs`` is forwarded on - # W_current (already stepped i times within the epoch), while - # ``old_logp`` stays fixed at W_rollout. R3 cannot reduce this drift. - # - # 2. ``rollout_train_logprobs_abs_diff_prox`` (old_logp vs prox_logp_gt) - # The *pure* train/infer mismatch: both evaluated on W_rollout, differing - # only in SGLang-vs-Megatron forward (fused kernel + router top-k). - # With R3 enabled and weights fully synced (k=0), this should collapse - # to BF16 kernel noise. With R3 off, ~2% per-expert router flips add - # ~0.1-0.2 abs-diff per token. This is the metric to validate R3 - # effectiveness across staleness settings. - # - # Both are logged via ``stats_tracker.stat`` with ``n_valid_tokens`` - # denominator so the aggregation is a proper token-weighted global - # mean/min/max (cross-mini-batch, cross-rank), not a Python-float mean - # of per-minibatch local scalars (which would be small-batch-biased and - # would report local stds as if they were global ones). - if loss_mask.any(): - with torch.no_grad(): - _diff_abs = (old_logp - logprobs.detach()).abs().float() - _diff_abs = torch.where(loss_mask, _diff_abs, torch.zeros_like(_diff_abs)) - _diff_sq = (_diff_abs * _diff_abs).float() - stats_tracker.stat( - rollout_train_logprobs_abs_diff=_diff_abs, - rollout_train_logprobs_sq_diff=_diff_sq, - denominator="n_valid_tokens", - ) - if prox_logp_gt is not None: - with torch.no_grad(): - _prox_abs = (old_logp - prox_logp_gt.detach()).abs().float() - _prox_abs = torch.where( - loss_mask, _prox_abs, torch.zeros_like(_prox_abs) - ) - _prox_sq = (_prox_abs * _prox_abs).float() - stats_tracker.stat( - rollout_train_logprobs_abs_diff_prox=_prox_abs, - rollout_train_logprobs_sq_diff_prox=_prox_sq, - denominator="n_valid_tokens", - ) - if "behave_imp_weight" in stat: stats_tracker.denominator(unclipped_behave_tokens=stat["behave_mask"]) stats_tracker.stat( From 13384d54b727344f227647d2e18591fff85ea6bc Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 7 May 2026 15:05:02 +0800 Subject: [PATCH 110/112] feat(math): consolidate R3 example configs
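Keep a single R3 example, gsm8k_grpo_megatron_r3.yaml (renamed from the
h20_base variant), and delete the three Moonlight one-off configs. The
surviving example drives Router Replay through two switches,
rollout.return_routed_experts and sglang.enable_return_routed_experts, the
latter being auto-set by rl_trainer. Roughly (hypothetical helper assuming
attribute-style config objects, not the actual rl_trainer code):

    def propagate_r3_flags(cfg):
        # One user-facing switch on the rollout side drives everything.
        if cfg.rollout.return_routed_experts:
            # SGLang returns per-token routed expert indices ...
            cfg.sglang.enable_return_routed_experts = True
            # ... and Megatron replays them during training.
            cfg.actor.megatron.enable_router_replay = True
        return cfg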
--- ..._base.yaml => gsm8k_grpo_megatron_r3.yaml} | 37 ++-- ...moonlight_16b_a3b_gsm8k_grpo_megatron.yaml | 204 ------------------ ...ight_16b_a3b_gsm8k_grpo_megatron_base.yaml | 194 ----------------- ...light_16b_a3b_gsm8k_grpo_megatron_h20.yaml | 204 ------------------ 4 files changed, 23 insertions(+), 616 deletions(-) rename examples/math/{moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml => gsm8k_grpo_megatron_r3.yaml} (77%) delete mode 100644 examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml delete mode 100644 examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml delete mode 100644 examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20.yaml diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml b/examples/math/gsm8k_grpo_megatron_r3.yaml similarity index 77% rename from examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml rename to examples/math/gsm8k_grpo_megatron_r3.yaml index a51916e143..4afca12666 100644 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20_base.yaml +++ b/examples/math/gsm8k_grpo_megatron_r3.yaml @@ -1,4 +1,4 @@ -experiment_name: moonlight-16b-a3b-gsm8k-grpo-h20-base +experiment_name: gsm8k-grpo-megatron-r3 trial_name: trial0 seed: 1 @@ -18,18 +18,24 @@ scheduler: type: null rollout: - backend: "sglang:d1p1t2" + backend: "sglang:d2p1t2" experiment_name: ${experiment_name} trial_name: ${trial_name} - max_concurrent_rollouts: 16 + max_concurrent_rollouts: null queue_size: null consumer_batch_size: ${train_dataset.batch_size} - max_head_offpolicyness: 0 + max_head_offpolicyness: 1 enable_rollout_tracing: false scheduling_spec: ${actor.scheduling_spec} fileroot: ${cluster.fileroot} tokenizer_path: ${tokenizer_path} - dump_to_file: true + dump_to_file: false + # R3: Enable returning routed expert assignments from rollout inference. + # This triggers the entire Router Replay pipeline: SGLang returns per-token + # expert indices, which are then replayed during training to eliminate + # train/inference routing mismatch in MoE models. + # num_moe_layers and topk are automatically resolved from the model config. + return_routed_experts: true gconfig: n_samples: 8 @@ -39,20 +45,19 @@ gconfig: temperature: 1.0 actor: - backend: "megatron:(attn:d1p1t2c2|ffn:d1p1t1e4)" # ← PP=1, attn TP=2 CP=2, ffn EP=4 - # backend: "megatron:(attn:d1p1t4|ffn:d1p1t1e4)" + backend: "megatron:(attn:d1p1t4|ffn:d1p1t1e4)" experiment_name: ${experiment_name} trial_name: ${trial_name} - path: /workspace/models/Moonlight-16B-A3B-Instruct + path: moonshotai/Moonlight-16B-A3B-Instruct init_from_scratch: false disable_dropout: true gradient_checkpointing: true dtype: bfloat16 mb_spec: - max_tokens_per_mb: 1280 # ← lowered from 2048 to 512 + max_tokens_per_mb: 1280 optimizer: type: adam - lr: 2e-6 + lr: 3e-6 weight_decay: 0.003 beta1: 0.9 beta2: 0.999 @@ -65,7 +70,7 @@ actor: reward_scaling: 10.0 reward_bias: -0.5 kl_ctl: 0.0 - ppo_n_minibatches: 1 # ← raised from 1 to 4 (mini-batch gradient accumulation) + ppo_n_minibatches: 1 recompute_logprob: true use_decoupled_loss: true rejection_sampling: @@ -89,7 +94,7 @@ actor: recompute_method: uniform recompute_num_layers: 1 ddp: - grad_reduce_in_fp32: false # ← keep per-layer recomputation + grad_reduce_in_fp32: false scheduling_spec: - task_type: worker port_count: 2 @@ -120,11 +125,15 @@ sglang: random_seed: ${seed} skip_tokenizer_init: true dtype: bfloat16 - max_running_requests: 64 + max_running_requests: null context_length: 2048 mem_fraction_static: 0.8 attention_backend: triton - disable_cuda_graph: true + # R3: Enable SGLang to capture and return per-token routed expert indices + # during inference. This is auto-set by rl_trainer when + # rollout.return_routed_experts=true, but explicitly declared here for clarity. + enable_return_routed_experts: true + chunked_prefill_size: 4096 vllm: model: ${actor.path} diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml deleted file mode 100644 index 65129ca71b..0000000000 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron.yaml +++ /dev/null @@ -1,204 +0,0 @@ -experiment_name: moonlight-16b-a3b-gsm8k-grpo -trial_name: trial0 - -seed: 1 -enable_offload: false -total_train_epochs: 10 -tokenizer_path: ${actor.path} - -cluster: - n_nodes: 1 - n_gpus_per_node: 8 - fileroot: /tmp/areal/moon_experiments - name_resolve: - type: nfs - nfs_record_root: /tmp/areal/moon_name_resolve - -scheduler: - type: null - -rollout: - backend: "sglang:d1p1t8" - experiment_name: ${experiment_name} - trial_name: ${trial_name} - max_concurrent_rollouts: 64 - queue_size: null - consumer_batch_size: ${train_dataset.batch_size} - max_head_offpolicyness: 2 - enable_rollout_tracing: false - scheduling_spec: ${actor.scheduling_spec} - fileroot: ${cluster.fileroot} - tokenizer_path: ${tokenizer_path} - dump_to_file: true - # R3: Enable returning routed expert assignments from rollout inference. - # This triggers the entire Router Replay pipeline: SGLang returns per-token - # expert indices, which are then replayed during training to eliminate - # train/inference routing mismatch in MoE models. - # num_moe_layers and topk are automatically resolved from the model config.
- return_routed_experts: true - -gconfig: - n_samples: 8 - min_new_tokens: 0 - max_new_tokens: 1024 - greedy: false - temperature: 1.0 - -actor: - backend: "megatron:(attn:d1p2t4|ffn:d1p2t1e4)" # ← PP=2 fallback, TP=4/EP=4 - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: /workspace/models/Moonlight-16B-A3B-Instruct - init_from_scratch: false - disable_dropout: true - gradient_checkpointing: true - dtype: bfloat16 - mb_spec: - max_tokens_per_mb: 1280 # ← lowered from 2048 to 512 - optimizer: - type: adam_bf16 - lr: 2e-6 - weight_decay: 0.003 - beta1: 0.9 - beta2: 0.999 - eps: 1e-8 - lr_scheduler_type: constant - gradient_clipping: 1.0 - warmup_steps_proportion: 0.001 - eps_clip: 0.2 - temperature: ${gconfig.temperature} - reward_scaling: 10.0 - reward_bias: -0.5 - kl_ctl: 0.0 - ppo_n_minibatches: 1 # ← raised from 1 to 4 (mini-batch gradient accumulation) - recompute_logprob: true - use_decoupled_loss: true - behave_imp_weight_cap: 5.0 - reward_norm: - mean_level: group - std_level: group - group_size: ${gconfig.n_samples} - adv_norm: - mean_level: batch - std_level: batch - weight_update_mode: disk - max_new_tokens: ${gconfig.max_new_tokens} - megatron: - use_deterministic_algorithms: false - recompute_granularity: full - recompute_method: uniform - recompute_num_layers: 14 - main_grads_dtype: bfloat16 # gradients FP32 -> BF16 (saves ~4 GiB) - # store_param_remainders: true - # optimizer_cpu_offload: true - # optimizer_offload_fraction: 0.5 - # main_params_dtype: bfloat16 - # main_grads_dtype: bfloat16 - # # adam_bf16 already sets the two entries below automatically, but declaring them explicitly is safer - # exp_avg_dtype: bfloat16 - # exp_avg_sq_dtype: bfloat16 - ddp: - grad_reduce_in_fp32: false # ← keep per-layer recomputation - scheduling_spec: - - task_type: worker - port_count: 2 - gpu: 1 - mem: 48 - cmd: python3 -m areal.infra.rpc.rpc_server - env_vars: - PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True" - -ref: - backend: ${actor.backend} - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: ${actor.path} - init_from_scratch: false - disable_dropout: true - dtype: ${actor.dtype} - mb_spec: - max_tokens_per_mb: 1280 - optimizer: null - scheduling_strategy: - type: colocation - target: actor - scheduling_spec: ${actor.scheduling_spec} - -sglang: - model_path: ${actor.path} - random_seed: ${seed} - skip_tokenizer_init: false - dtype: bfloat16 - max_running_requests: 8 - context_length: 2048 - mem_fraction_static: 0.2 - attention_backend: triton - # R3: Enable SGLang to capture and return per-token routed expert indices - # during inference. This is auto-set by rl_trainer when - # rollout.return_routed_experts=true, but explicitly declared here for clarity.
- enable_return_routed_experts: true - chunked_prefill_size: 2048 - -vllm: - model: ${actor.path} - seed: ${seed} - skip_tokenizer_init: false - dtype: bfloat16 - max_model_len: 4096 - gpu_memory_utilization: 0.75 - -train_dataset: - batch_size: 64 - shuffle: true - pin_memory: true - num_workers: 4 - path: openai/gsm8k - type: rl - max_length: 512 - -valid_dataset: - batch_size: 128 - pin_memory: true - num_workers: 4 - path: openai/gsm8k - type: rl - -saver: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null - freq_secs: null - -recover: - mode: disabled - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null - freq_secs: 3600 - -evaluator: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null - freq_secs: null - -stats_logger: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - wandb: - mode: disabled - -perf_tracer: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - enabled: false - session_tracer: - enabled: false diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml deleted file mode 100644 index a9bbd25530..0000000000 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_base.yaml +++ /dev/null @@ -1,194 +0,0 @@ -experiment_name: moonlight-16b-a3b-gsm8k-grpo -trial_name: trial0 - -seed: 1 -enable_offload: false -total_train_epochs: 10 -tokenizer_path: ${actor.path} - -cluster: - n_nodes: 1 - n_gpus_per_node: 8 - fileroot: /tmp/areal/moon_experiments - name_resolve: - type: nfs - nfs_record_root: /tmp/areal/moon_name_resolve - -scheduler: - type: null - -rollout: - backend: "sglang:d1p1t8" - experiment_name: ${experiment_name} - trial_name: ${trial_name} - max_concurrent_rollouts: 64 - queue_size: null - consumer_batch_size: ${train_dataset.batch_size} - max_head_offpolicyness: 2 - enable_rollout_tracing: false - scheduling_spec: ${actor.scheduling_spec} - fileroot: ${cluster.fileroot} - tokenizer_path: ${tokenizer_path} - dump_to_file: true - -gconfig: - n_samples: 8 - min_new_tokens: 0 - max_new_tokens: 1024 - greedy: false - temperature: 1.0 - -actor: - backend: "megatron:(attn:d1p2t4|ffn:d1p2t1e4)" # ← PP=2 fallback, TP=4/EP=4 - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: /workspace/models/Moonlight-16B-A3B-Instruct - init_from_scratch: false - disable_dropout: true - gradient_checkpointing: true - dtype: bfloat16 - mb_spec: - max_tokens_per_mb: 1280 # ← lowered from 2048 to 512 - optimizer: - type: adam_bf16 - lr: 2e-6 - weight_decay: 0.003 - beta1: 0.9 - beta2: 0.999 - eps: 1e-8 - lr_scheduler_type: constant - gradient_clipping: 1.0 - warmup_steps_proportion: 0.001 - eps_clip: 0.2 - temperature: ${gconfig.temperature} - reward_scaling: 10.0 - reward_bias: -0.5 - kl_ctl: 0.0 - ppo_n_minibatches: 1 # ← raised from 1 to 4 (mini-batch gradient accumulation) - recompute_logprob: true - use_decoupled_loss: true - behave_imp_weight_cap: 5.0 - reward_norm: - mean_level: group - std_level: group - group_size: ${gconfig.n_samples} - adv_norm: - mean_level: batch - std_level: batch - weight_update_mode: disk - max_new_tokens: ${gconfig.max_new_tokens} - megatron: - use_deterministic_algorithms: false - recompute_granularity: full - recompute_method: uniform - recompute_num_layers: 14 -
main_grads_dtype: bfloat16 - # main_params_dtype: bfloat16 # gradients FP32 -> BF16 (saves ~4 GiB) - # store_param_remainders: true - # optimizer_cpu_offload: true - # optimizer_offload_fraction: 0.5 - # main_params_dtype: bfloat16 - # main_grads_dtype: bfloat16 - # # adam_bf16 already sets the two entries below automatically, but declaring them explicitly is safer - # exp_avg_dtype: bfloat16 - # exp_avg_sq_dtype: bfloat16 - ddp: - grad_reduce_in_fp32: false # ← keep per-layer recomputation - scheduling_spec: - - task_type: worker - port_count: 2 - gpu: 1 - mem: 48 - cmd: python3 -m areal.infra.rpc.rpc_server - env_vars: - PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True" - -ref: - backend: ${actor.backend} - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: ${actor.path} - init_from_scratch: false - disable_dropout: true - dtype: ${actor.dtype} - mb_spec: - max_tokens_per_mb: 1280 - optimizer: null - scheduling_strategy: - type: colocation - target: actor - scheduling_spec: ${actor.scheduling_spec} - -sglang: - model_path: ${actor.path} - random_seed: ${seed} - skip_tokenizer_init: true - dtype: bfloat16 - max_running_requests: 8 - context_length: 2048 - mem_fraction_static: 0.2 - attention_backend: triton - -vllm: - model: ${actor.path} - seed: ${seed} - skip_tokenizer_init: false - dtype: bfloat16 - max_model_len: 4096 - gpu_memory_utilization: 0.75 - -train_dataset: - batch_size: 64 - shuffle: true - pin_memory: true - num_workers: 4 - path: openai/gsm8k - type: rl - max_length: 512 - -valid_dataset: - batch_size: 128 - pin_memory: true - num_workers: 4 - path: openai/gsm8k - type: rl - -saver: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null - freq_secs: null - -recover: - mode: disabled - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null - freq_secs: 3600 - -evaluator: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null - freq_secs: null - -stats_logger: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - wandb: - mode: disabled - -perf_tracer: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - enabled: false - session_tracer: - enabled: false \ No newline at end of file diff --git a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20.yaml b/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20.yaml deleted file mode 100644 index f71331d082..0000000000 --- a/examples/math/moonlight_16b_a3b_gsm8k_grpo_megatron_h20.yaml +++ /dev/null @@ -1,204 +0,0 @@ -experiment_name: moonlight-16b-a3b-gsm8k-grpo-h20 -trial_name: trial0 - -seed: 1 -enable_offload: false -total_train_epochs: 10 -tokenizer_path: ${actor.path} - -cluster: - n_nodes: 1 - n_gpus_per_node: 6 - fileroot: /tmp/areal/moon_experiments - name_resolve: - type: nfs - nfs_record_root: /tmp/areal/moon_name_resolve - -scheduler: - type: null - -rollout: - backend: "sglang:d1p1t2" - experiment_name: ${experiment_name} - trial_name: ${trial_name} - max_concurrent_rollouts: 128 - queue_size: null - consumer_batch_size: ${train_dataset.batch_size} - max_head_offpolicyness: 2 - enable_rollout_tracing: false - scheduling_spec: ${actor.scheduling_spec} - fileroot: ${cluster.fileroot} - tokenizer_path: ${tokenizer_path} - dump_to_file: true - # R3: Enable returning routed expert assignments from rollout inference.
- # This triggers the entire Router Replay pipeline: SGLang returns per-token - # expert indices, which are then replayed during training to eliminate - # train/inference routing mismatch in MoE models. - # num_moe_layers and topk are automatically resolved from the model config. - return_routed_experts: true - -gconfig: - n_samples: 4 - min_new_tokens: 0 - max_new_tokens: 1024 - greedy: false - temperature: 1.0 - -actor: - backend: "megatron:(attn:d1p1t4|ffn:d1p1t1e4)" # ← PP=2 fallback, TP=4/EP=4 - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: /workspace/models/Moonlight-16B-A3B - init_from_scratch: false - disable_dropout: true - gradient_checkpointing: false - dtype: bfloat16 - mb_spec: - max_tokens_per_mb: 10240 # ← lowered from 2048 to 512 - optimizer: - type: adam - lr: 2e-6 - weight_decay: 0.003 - beta1: 0.9 - beta2: 0.999 - eps: 1e-8 - lr_scheduler_type: constant - gradient_clipping: 1.0 - warmup_steps_proportion: 0.001 - eps_clip: 0.2 - temperature: ${gconfig.temperature} - reward_scaling: 10.0 - reward_bias: -0.5 - kl_ctl: 0.0 - ppo_n_minibatches: 1 # ← raised from 1 to 4 (mini-batch gradient accumulation) - recompute_logprob: true - use_decoupled_loss: true - behave_imp_weight_cap: 5.0 - reward_norm: - mean_level: group - std_level: group - group_size: ${gconfig.n_samples} - adv_norm: - mean_level: batch - std_level: batch - # weight_update_mode: disk - max_new_tokens: ${gconfig.max_new_tokens} - megatron: - use_deterministic_algorithms: false - recompute_granularity: full - recompute_method: uniform - recompute_num_layers: 14 - main_grads_dtype: bfloat16 # gradients FP32 -> BF16 (saves ~4 GiB) - # store_param_remainders: true - # optimizer_cpu_offload: true - # optimizer_offload_fraction: 0.5 - # main_params_dtype: bfloat16 - # main_grads_dtype: bfloat16 - # # adam_bf16 already sets the two entries below automatically, but declaring them explicitly is safer - # exp_avg_dtype: bfloat16 - # exp_avg_sq_dtype: bfloat16 - ddp: - grad_reduce_in_fp32: false # ← keep per-layer recomputation - scheduling_spec: - - task_type: worker - port_count: 2 - gpu: 1 - mem: 48 - cmd: python3 -m areal.infra.rpc.rpc_server - env_vars: - PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True" - -ref: - backend: ${actor.backend} - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: ${actor.path} - init_from_scratch: false - disable_dropout: true - dtype: ${actor.dtype} - mb_spec: - max_tokens_per_mb: 10240 - optimizer: null - scheduling_strategy: - type: colocation - target: actor - scheduling_spec: ${actor.scheduling_spec} - -sglang: - model_path: ${actor.path} - random_seed: ${seed} - skip_tokenizer_init: false - dtype: bfloat16 - max_running_requests: null - context_length: 32768 - mem_fraction_static: 0.8 - attention_backend: triton - # R3: Enable SGLang to capture and return per-token routed expert indices - # during inference. This is auto-set by rl_trainer when - # rollout.return_routed_experts=true, but explicitly declared here for clarity.
- enable_return_routed_experts: true - chunked_prefill_size: 2048 - -vllm: - model: ${actor.path} - seed: ${seed} - skip_tokenizer_init: false - dtype: bfloat16 - max_model_len: 4096 - gpu_memory_utilization: 0.75 - -train_dataset: - batch_size: 256 - shuffle: true - pin_memory: true - num_workers: 4 - path: openai/gsm8k - type: rl - max_length: 1024 - -valid_dataset: - batch_size: 256 - pin_memory: true - num_workers: 4 - path: openai/gsm8k - type: rl - -saver: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null - freq_secs: null - -recover: - mode: disabled - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null - freq_secs: 3600 - -evaluator: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null - freq_secs: null - -stats_logger: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - wandb: - mode: disabled - -perf_tracer: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - enabled: false - session_tracer: - enabled: false From 35014fe4986a1714c057dd06df656d504f04c8f2 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Thu, 7 May 2026 15:44:23 +0800 Subject: [PATCH 111/112] refactor(router): fix pre-commit issues --- areal/engine/router_replay_patch.py | 21 ++++++++----------- areal/engine/router_replay_utils.py | 1 - .../torchrun/run_router_replay_distributed.py | 3 --- 3 files changed, 9 insertions(+), 16 deletions(-) diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py index a96e0bf575..8f1df77305 100644 --- a/areal/engine/router_replay_patch.py +++ b/areal/engine/router_replay_patch.py @@ -29,12 +29,19 @@ apply_router_token_dropping, compute_routing_scores_for_aux_loss, ) + from megatron.core.transformer.moe.router import TopKRouter + from megatron.core.transformer.moe.token_dispatcher import ( + MoEAlltoAllTokenDispatcher, + ) + from megatron.core.transformer.transformer_config import TransformerConfig except ImportError: apply_router_token_dropping = None compute_routing_scores_for_aux_loss = None + TopKRouter = None + MoEAlltoAllTokenDispatcher = None + TransformerConfig = None warnings.warn( - "[R3] Could not import apply_router_token_dropping / " - "compute_routing_scores_for_aux_loss from megatron.core; " + "[R3] Could not import megatron.core MoE components; " "some MoE features may be unavailable.", stacklevel=2, ) @@ -44,16 +51,6 @@ except ImportError: group_limited_topk = None -try: - from megatron.core.transformer.moe.token_dispatcher import ( - MoEAlltoAllTokenDispatcher, - ) -except ImportError: - MoEAlltoAllTokenDispatcher = None - -from megatron.core.transformer.moe.router import TopKRouter -from megatron.core.transformer.transformer_config import TransformerConfig - # =================================================================== # RouterReplayAction enum and RouterReplay class diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py index d458b93b4a..ccc8306b3d 100644 --- a/areal/engine/router_replay_utils.py +++ b/areal/engine/router_replay_utils.py @@ -349,7 +349,6 @@ def set_router_replay_data( row_all_zero = ( (packed == 0).reshape(packed.shape[0], -1).all(dim=-1) ) - n_strike = int((valid_mask & row_all_zero).sum().item()) valid_mask = valid_mask & (~row_all_zero) # Step 2: CP
split (before TP scatter). diff --git a/tests/torchrun/run_router_replay_distributed.py b/tests/torchrun/run_router_replay_distributed.py index 8fda425d0f..9340d0c3cc 100644 --- a/tests/torchrun/run_router_replay_distributed.py +++ b/tests/torchrun/run_router_replay_distributed.py @@ -189,9 +189,6 @@ def _build_training_input_with_rollout_experts( rollout_logprobs = ( -torch.rand((bs, slen), generator=gen) * 2.0 ).to(engine.device) - action_log_probs = ( - -torch.rand((bs, slen), generator=gen) * 2.0 - ).to(engine.device) advantages = torch.randn((bs, slen), generator=gen).to(engine.device) loss_mask = attention_mask.to(dtype=torch.int64) From d2b6d4c652d01775d24ee61b83fcf12f0618828c Mon Sep 17 00:00:00 2001 From: root Date: Thu, 7 May 2026 07:46:44 +0000 Subject: [PATCH 112/112] chore: fix pre-commit formatting --- areal/engine/megatron_engine.py | 1 - areal/engine/megatron_utils/megatron_lora.py | 1 + areal/engine/router_replay_patch.py | 40 ++++++++---- areal/engine/router_replay_utils.py | 58 ++++++++++------- areal/infra/launcher/sglang_r3_patch.py | 2 + areal/infra/rpc/guard/engine_blueprint.py | 4 +- areal/infra/rpc/serialization.py | 8 +-- areal/infra/scheduler/local.py | 2 +- docs/en/cli_reference.md | 62 +++++++++---------- docs/zh/cli_reference.md | 62 +++++++++---------- examples/math/gsm8k_grpo_megatron_r3.yaml | 2 +- tests/test_r3_mask_alignment.py | 1 - tests/test_router_replay.py | 38 ++++++------ tests/test_router_replay_e2e.py | 4 +- .../torchrun/run_router_replay_distributed.py | 29 ++++----- 15 files changed, 165 insertions(+), 149 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 8cc1f6a477..0d33a6268a 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -102,7 +102,6 @@ ) from areal.models.tree_attn.tree import build_packed_tree_batch from areal.utils import logging, name_resolve, names, perf_tracer, stats_tracker - from areal.utils.constants import ( DEFAULT_VECTORIZED_ALIGNMENT_BYTES, DIST_GROUP_DEFAULT_TIMEOUT, diff --git a/areal/engine/megatron_utils/megatron_lora.py b/areal/engine/megatron_utils/megatron_lora.py index 6700f0bc3e..562caf783e 100644 --- a/areal/engine/megatron_utils/megatron_lora.py +++ b/areal/engine/megatron_utils/megatron_lora.py @@ -393,6 +393,7 @@ def save_hf_adapter( _monkey_patch_applied = False + # Current: This monkey patch is needed as the current megatron-bridge 0.3.0 does not have a built-in method # to save LoRA adapters in HuggingFace PEFT format, which is required for our use case. # Future: This code is however present in main branch of megatron-bridge so this patch is temporary diff --git a/areal/engine/router_replay_patch.py b/areal/engine/router_replay_patch.py index 8f1df77305..94dfc0e346 100644 --- a/areal/engine/router_replay_patch.py +++ b/areal/engine/router_replay_patch.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + """Monkey-patches for Megatron-Core MoE components to support Router Replay (R3). Forces TopKRouter to use pre-recorded expert assignments from rollout inference @@ -74,7 +76,7 @@ class RouterReplay: """ # Class-level list of all router instances (one per MoE layer). - router_instances: list["RouterReplay"] = [] + router_instances: list[RouterReplay] = [] # Class-level pipeline parallelism size for backward remapping. # Set by the engine patch before forward_backward_func.
@@ -142,6 +144,7 @@ def __init__(self) -> None: self.creation_order: int = len(RouterReplay.router_instances) try: import torch.distributed as _dist + self.creator_rank: int = _dist.get_rank() if _dist.is_initialized() else -1 except Exception: self.creator_rank = -1 @@ -230,7 +233,9 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): ) if routing_action is None: - return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) + return _compute_topk( + scores, topk, num_groups=num_groups, group_topk=group_topk + ) if routing_action == RouterReplayAction.RECORD: probs, top_indices = _compute_topk( @@ -247,7 +252,9 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): "[R3] REPLAY_FORWARD: no replay indices available, " "falling back to normal routing." ) - return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) + return _compute_topk( + scores, topk, num_groups=num_groups, group_topk=group_topk + ) # Use the provided indices for replay top_indices = router_replay.target_topk_idx @@ -274,7 +281,9 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): "[R3] REPLAY_BACKWARD: no backward indices available, " "falling back to normal routing." ) - return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) + return _compute_topk( + scores, topk, num_groups=num_groups, group_topk=group_topk + ) # Backward recompute: use the recorded indices from the forward pass. top_indices = router_replay.replay_backward_list.pop(0) top_indices = top_indices.to(scores.device) @@ -306,7 +315,9 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): return probs, top_indices else: - return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) + return _compute_topk( + scores, topk, num_groups=num_groups, group_topk=group_topk + ) # --- Score function dispatch --- if score_function == "softmax": @@ -320,11 +331,15 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): scores = torch.sigmoid(logits.float()).type_as(logits) if expert_bias is not None: scores_for_routing = scores + expert_bias - _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk) + _, top_indices = compute_topk( + scores_for_routing, topk, num_groups, group_topk + ) scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits) else: scores, top_indices = compute_topk(scores, topk, num_groups, group_topk) - probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores + probs = ( + scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores + ) else: raise ValueError(f"[R3] Invalid score_function: {score_function}") @@ -445,7 +460,9 @@ def patched_routing(self, logits: torch.Tensor, *args, **kwargs): fused=self.config.moe_router_fusion, ) ) - probs = self._apply_aux_loss(probs, scores_for_aux_loss, routing_map_for_aux_loss) + probs = self._apply_aux_loss( + probs, scores_for_aux_loss, routing_map_for_aux_loss + ) probs = self._apply_seq_aux_loss( probs, scores_for_aux_loss, routing_map_for_aux_loss, seq_length, bsz ) @@ -563,7 +580,9 @@ def patched_tf_config_init(self, *args, **kwargs): TransformerConfig.__init__ = patched_tf_config_init TransformerConfig._r3_config_patched = True - logger.debug("[R3] TransformerConfig.__init__ patched to accept enable_routing_replay.") + logger.debug( + "[R3] TransformerConfig.__init__ patched to accept enable_routing_replay." 
+ ) def _undo_transformer_config_patch() -> None: @@ -595,8 +614,7 @@ def patched_init(self, *args, **kwargs): if getattr(self.config, "enable_routing_replay", False): self.router_replay = RouterReplay() logger.debug( - "[R3] TopKRouter: created RouterReplay instance " - "(total instances: %d).", + "[R3] TopKRouter: created RouterReplay instance (total instances: %d).", len(RouterReplay.router_instances), ) diff --git a/areal/engine/router_replay_utils.py b/areal/engine/router_replay_utils.py index ccc8306b3d..0c1a6c97f0 100644 --- a/areal/engine/router_replay_utils.py +++ b/areal/engine/router_replay_utils.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + """Router Replay (R3) utilities for AReaL. Converts rollout routing indices into Megatron-Core's RouterReplay layout: @@ -7,12 +9,10 @@ from __future__ import annotations import inspect -from typing import Optional import torch from areal.engine.router_replay_patch import RouterReplay, RouterReplayAction - from areal.utils import logging # NOTE: use areal.utils.logging.getLogger with a stable registered @@ -95,9 +95,7 @@ def get_num_layers_to_build(config, vp_stage=None, pp_rank=None) -> int: # Account for embedding/loss layers if getattr(config, "account_for_embedding_in_pipeline_split", False): - if is_first_pp_stage and ( - vp_stage is None or vp_stage == 0 - ): + if is_first_pp_stage and (vp_stage is None or vp_stage == 0): num_layers_to_build -= 1 if getattr(config, "account_for_loss_in_pipeline_split", False): @@ -119,7 +117,9 @@ def is_moe_layer(tf_config, layer_idx: int) -> bool: elif isinstance(moe_layer_freq, list): return moe_layer_freq[layer_idx] == 1 else: - raise ValueError(f"[R3] Unsupported moe_layer_freq type: {type(moe_layer_freq)}") + raise ValueError( + f"[R3] Unsupported moe_layer_freq type: {type(moe_layer_freq)}" + ) def get_moe_num_layers_to_build(config, vp_stage=None, pp_rank=None) -> int: @@ -203,14 +203,12 @@ def is_replay_backward_action(tf_config, vp_rank=None) -> bool: ) - - def set_router_replay_data( layers_topk_idx: torch.Tensor, cu_seqlens: torch.Tensor, tf_config, - vp_rank: Optional[int] = None, - seq_align_to: Optional[int] = None, + vp_rank: int | None = None, + seq_align_to: int | None = None, ) -> None: """Scatter packed router top-k indices to SP ranks and update RouterReplay instances. @@ -277,7 +275,9 @@ def set_router_replay_data( # For each sequence i, we take the first seq_lens[i] tokens and place # them at aligned positions, with zero-padding for TP alignment gaps. packed = torch.zeros( - total_aligned, num_layers, topk, + total_aligned, + num_layers, + topk, dtype=layers_topk_idx.dtype, device=layers_topk_idx.device, ) @@ -306,9 +306,9 @@ def set_router_replay_data( continue # Take first slen tokens from this sample's routed_experts actual_len = min(slen, layers_topk_idx.shape[1]) - packed[aligned_offset : aligned_offset + actual_len] = ( - layers_topk_idx[i, :actual_len] - ) + packed[aligned_offset : aligned_offset + actual_len] = layers_topk_idx[ + i, :actual_len + ] # Only the real-token span is marked valid; the per-seq # TP-alignment slack (aligned_lens[i] - actual_len) stays False. valid_mask[aligned_offset : aligned_offset + actual_len] = True @@ -346,9 +346,7 @@ def set_router_replay_data( # ``mean_abs_diff`` profile on every micro-batch. 
# ---------------------------------------------------------------- with torch.no_grad(): - row_all_zero = ( - (packed == 0).reshape(packed.shape[0], -1).all(dim=-1) - ) + row_all_zero = (packed == 0).reshape(packed.shape[0], -1).all(dim=-1) valid_mask = valid_mask & (~row_all_zero) # Step 2: CP split (before TP scatter). @@ -359,6 +357,7 @@ def set_router_replay_data( from areal.engine.megatron_utils.packed_context_parallel import ( split_packed_seqs_for_context_parallel, ) + cu_seqlens_dev = cu_seqlens.to(device) packed = split_packed_seqs_for_context_parallel(packed, cu_seqlens_dev) # Preserve bool semantics: split as int32 then recast. @@ -369,7 +368,10 @@ def set_router_replay_data( # Step 3: Scatter to SP ranks (TP). tp_size = mpu.get_tensor_model_parallel_world_size() if tp_size > 1: - from megatron.core.tensor_parallel import scatter_to_sequence_parallel_region + from megatron.core.tensor_parallel import ( + scatter_to_sequence_parallel_region, + ) + local_tokens = scatter_to_sequence_parallel_region(packed) # Scatter the mask on dim-0 as well. ``scatter_to_sequence_parallel_region`` # expects a tensor with a sequence dimension on dim 0; promote the @@ -396,7 +398,9 @@ def set_router_replay_data( "[R3] set_router_replay_data: no RouterReplay instances found " "for PP offset=%d..%d, vp_rank=%s. " "Total router_instances=%d.", - offset, end, vp_rank, + offset, + end, + vp_rank, len(RouterReplay.router_instances), ) return @@ -416,7 +420,9 @@ def set_router_replay_data( "[R3] set_router_replay_data: router_offset=%d >= " "len(router_list)=%d. Layer assignment mismatch at " "layer_idx=%d.", - router_offset, len(router_list), layer_idx, + router_offset, + len(router_list), + layer_idx, ) break router = router_list[router_offset] @@ -425,7 +431,8 @@ def set_router_replay_data( logger.warning( "[R3] set_router_replay_data: layer index %d >= " "layers_topk dim-0 (%d). Skipping.", - idx, len(layers_topk), + idx, + len(layers_topk), ) moe_idx += 1 router_offset += 1 @@ -445,8 +452,8 @@ def setup_per_microbatch_replay_forward( routed_experts: torch.Tensor, cu_seqlens: torch.Tensor, tf_config, - vp_rank: Optional[int] = None, - seq_align_to: Optional[int] = None, + vp_rank: int | None = None, + seq_align_to: int | None = None, ) -> None: """Set up RouterReplay for a single micro-batch's forward pass. @@ -461,7 +468,10 @@ def setup_per_microbatch_replay_forward( """ routed_experts = routed_experts.to(torch.int32) set_router_replay_data( - routed_experts, cu_seqlens, tf_config, vp_rank, + routed_experts, + cu_seqlens, + tf_config, + vp_rank, seq_align_to=seq_align_to, ) diff --git a/areal/infra/launcher/sglang_r3_patch.py b/areal/infra/launcher/sglang_r3_patch.py index f1cd59be97..e6f372d89b 100644 --- a/areal/infra/launcher/sglang_r3_patch.py +++ b/areal/infra/launcher/sglang_r3_patch.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + """SGLang server-side monkey patches for AReaL's R3 (Router Replay). 
When ``skip_tokenizer_init=True`` (forced by R3), SGLang's ``TokenizerManager`` diff --git a/areal/infra/rpc/guard/engine_blueprint.py b/areal/infra/rpc/guard/engine_blueprint.py index a7ff1079e9..5d90ac6035 100644 --- a/areal/infra/rpc/guard/engine_blueprint.py +++ b/areal/infra/rpc/guard/engine_blueprint.py @@ -535,9 +535,7 @@ def execute_in_engine_thread(): args={"method": method_name, "engine": engine_name}, ): if not hasattr(engine, method_name): - raise ValueError( - f"Engine does not have method '{method_name}'" - ) + raise ValueError(f"Engine does not have method '{method_name}'") method = getattr(engine, method_name) result = method(*args_bcast, **kwargs_bcast) diff --git a/areal/infra/rpc/serialization.py b/areal/infra/rpc/serialization.py index 2c08502b28..7ddc52f3b6 100644 --- a/areal/infra/rpc/serialization.py +++ b/areal/infra/rpc/serialization.py @@ -383,9 +383,7 @@ def to_tokenizer(self) -> Any: # ``AutoTokenizer.from_pretrained`` raises and the caller falls back # to the raw dict, which then explodes inside ``_save_model_to_hf`` # as ``'dict' object has no attribute 'save_pretrained'``. - tokenizer = AutoTokenizer.from_pretrained( - tmpdir, trust_remote_code=True - ) + tokenizer = AutoTokenizer.from_pretrained(tmpdir, trust_remote_code=True) if hasattr(tokenizer, "name_or_path"): tokenizer.name_or_path = self.name_or_path @@ -495,9 +493,7 @@ def to_processor(self) -> Any: with tempfile.TemporaryDirectory() as tmpdir: with zipfile.ZipFile(zip_buffer) as zf: zf.extractall(tmpdir) - processor = AutoProcessor.from_pretrained( - tmpdir, trust_remote_code=True - ) + processor = AutoProcessor.from_pretrained(tmpdir, trust_remote_code=True) if hasattr(processor, "name_or_path"): processor.name_or_path = self.name_or_path diff --git a/areal/infra/scheduler/local.py b/areal/infra/scheduler/local.py index eb5775db38..4522decbcb 100644 --- a/areal/infra/scheduler/local.py +++ b/areal/infra/scheduler/local.py @@ -1788,4 +1788,4 @@ def __del__(self): try: self.delete_workers() except Exception: - pass \ No newline at end of file + pass diff --git a/docs/en/cli_reference.md b/docs/en/cli_reference.md index 33f5f571bc..083573778f 100644 --- a/docs/en/cli_reference.md +++ b/docs/en/cli_reference.md @@ -549,7 +549,7 @@ Configuration for inference servers, including offpolicyness control. | `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this InferenceEngine, either separation or colocation. Currently only used by the RolloutController. | | `use_lora` | boolean | `False` | Whether to use LoRA. Should be same as actors LORA option. | | `agent` | [`AgentConfig`](section-agent) | **Required** | Agent workflow configuration used by inference-service rollouts. | -| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. | +| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. num_moe_layers and topk are automatically resolved from the model config. | | `_version` | string | `"v1"` | Rollout controller implementation version. Use 'v1' for legacy RolloutController, 'v2' for RolloutControllerV2. **Choices:** `v1`, `v2` | | `model` | string | `"default"` | Model name exposed through the inference-service gateway. | | `routing_strategy` | string | `"round_robin"` | Routing strategy for the inference-service router. 
| @@ -1059,36 +1059,36 @@ Configuration for Megatron-LM training framework. Refer to Megatron-LM documentation for implementation details. -| Parameter | Type | Default | Description | -| ------------------------------------------ | -------------------------------------------------------------------- | ------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `wrap_with_ddp` | boolean | `True` | - | -| `use_torch_fsdp2` | boolean | `False` | - | -| `use_custom_fsdp` | boolean | `False` | - | -| `ddp` | [`DistributedDataParallelConfig`](section-distributed-data-parallel) | **Required** | - | -| `virtual_pipeline_parallel_size` | integer | `1` | Virtual pipeline parallel size for Megatron interleaved schedule. Set to >1 to enable VPP. Default is 1 (disabled). | -| `overlap_param_gather_with_optimizer_step` | boolean | `False` | - | -| `use_precision_aware_optimizer` | boolean | `False` | Enable precision-aware optimizer for Megatron. When using adam_bf16 optimizer type with Megatron Engine, this is automatically enabled with exp_avg_dtype=bfloat16 and exp_avg_sq_dtype=bfloat16. | -| `main_grads_dtype` | string | `"float32"` | - | -| `main_params_dtype` | string | `"float32"` | - | -| `exp_avg_dtype` | string | `"float32"` | - | -| `exp_avg_sq_dtype` | string | `"float32"` | - | -| `async_save` | boolean | `False` | - | -| `use_checkpoint_opt_param_scheduler` | boolean | `True` | - | -| `use_deterministic_algorithms` | boolean | `False` | - | -| `recompute_granularity` | string \| None | `"full"` | - | -| `recompute_method` | string \| None | `"uniform"` | - | -| `recompute_num_layers` | integer \| None | `1` | - | -| `distribute_saved_activations` | boolean \| None | `None` | - | -| `recompute_modules` | list of string \| None | `None` | - | -| `moe_router_dtype` | string \| None | `"fp32"` | - | -| `moe_shared_expert_overlap` | boolean | `False` | Enable overlapping between shared expert computations and dispatcher communications. Without this, the shared experts execute after the routed experts. | -| `moe_enable_deepep` | boolean | `False` | - | -| `moe_token_dispatcher_type` | string | `"alltoall"` | Type of token dispatcher. Options: 'allgather','alltoall' and 'flex'. | -| `moe_permute_fusion` | boolean | `False` | Fuse token rearrangement ops during token dispatching. | -| `fp8_config` | [`FP8EngineConfig`](section-fp8-engine) \| None | `None` | - | -| `bridge_type` | string | `"mbridge"` | Bridge backend for MegatronEngine. Choices: 'mbridge' or 'megatron-bridge'. **Choices:** `mbridge`, `megatron-bridge` | -| `use_mbridge_save` | bool | `False` | Use mbridge's save method to save gpu memory when saving weights. | -| `enable_router_replay` | boolean | `False` | Enable Router Replay (R3) for MoE models. When True, the training forward pass replays the expert routing decisions from the inference engine to reduce train-inference routing discrepancy. 
+| Parameter | Type | Default | Description |
+| --- | --- | --- | --- |
+| `wrap_with_ddp` | boolean | `True` | - |
+| `use_torch_fsdp2` | boolean | `False` | - |
+| `use_custom_fsdp` | boolean | `False` | - |
+| `ddp` | [`DistributedDataParallelConfig`](section-distributed-data-parallel) | **Required** | - |
+| `virtual_pipeline_parallel_size` | integer | `1` | Virtual pipeline parallel size for Megatron interleaved schedule. Set to >1 to enable VPP. Default is 1 (disabled). |
+| `overlap_param_gather_with_optimizer_step` | boolean | `False` | - |
+| `use_precision_aware_optimizer` | boolean | `False` | Enable precision-aware optimizer for Megatron. When using adam_bf16 optimizer type with Megatron Engine, this is automatically enabled with exp_avg_dtype=bfloat16 and exp_avg_sq_dtype=bfloat16. |
+| `main_grads_dtype` | string | `"float32"` | - |
+| `main_params_dtype` | string | `"float32"` | - |
+| `exp_avg_dtype` | string | `"float32"` | - |
+| `exp_avg_sq_dtype` | string | `"float32"` | - |
+| `async_save` | boolean | `False` | - |
+| `use_checkpoint_opt_param_scheduler` | boolean | `True` | - |
+| `use_deterministic_algorithms` | boolean | `False` | - |
+| `recompute_granularity` | string \| None | `"full"` | - |
+| `recompute_method` | string \| None | `"uniform"` | - |
+| `recompute_num_layers` | integer \| None | `1` | - |
+| `distribute_saved_activations` | boolean \| None | `None` | - |
+| `recompute_modules` | list of string \| None | `None` | - |
+| `moe_router_dtype` | string \| None | `"fp32"` | - |
+| `moe_shared_expert_overlap` | boolean | `False` | Enable overlapping between shared expert computations and dispatcher communications. Without this, the shared experts execute after the routed experts. |
+| `moe_enable_deepep` | boolean | `False` | - |
+| `moe_token_dispatcher_type` | string | `"alltoall"` | Type of token dispatcher. Options: 'allgather', 'alltoall', and 'flex'. |
+| `moe_permute_fusion` | boolean | `False` | Fuse token rearrangement ops during token dispatching. |
+| `fp8_config` | [`FP8EngineConfig`](section-fp8-engine) \| None | `None` | - |
+| `bridge_type` | string | `"mbridge"` | Bridge backend for MegatronEngine. Choices: 'mbridge' or 'megatron-bridge'. **Choices:** `mbridge`, `megatron-bridge` |
+| `enable_router_replay` | boolean | `False` | Enable Router Replay (R3) for MoE models. When True, the training forward pass replays the expert routing decisions from the inference engine to reduce train-inference routing discrepancy. Set automatically by the trainer when `rollout.return_routed_experts=True`. |
+| `use_mbridge_save` | boolean | `False` | Use mbridge's save method to save GPU memory when saving weights. |
 
 (section-memory-profiler)=
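The two descriptions above jointly document the user-facing switch for R3. A minimal sketch of the corresponding YAML, assuming the key nesting implied by these config sections (only ``return_routed_experts`` needs to be set by hand; the trainer enables ``enable_router_replay`` itself)::

    # Illustrative excerpt in the spirit of examples/math/gsm8k_grpo_megatron_r3.yaml;
    # the exact nesting of these sections may differ in the real config.
    rollout:
      return_routed_experts: true     # SGLang returns per-token expert indices
    actor:
      megatron:
        enable_router_replay: true    # normally set automatically by the trainer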
diff --git a/docs/zh/cli_reference.md b/docs/zh/cli_reference.md
index ec5e410de6..32272fffa7 100644
--- a/docs/zh/cli_reference.md
+++ b/docs/zh/cli_reference.md
@@ -547,7 +547,7 @@ Configuration for inference servers, including offpolicyness control.
 | `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this InferenceEngine, either separation or colocation. Currently only used by the RolloutController. |
 | `use_lora` | boolean | `False` | Whether to use LoRA. Should be same as actors LORA option. |
 | `agent` | [`AgentConfig`](section-agent) | **Required** | Agent workflow configuration used by inference-service rollouts. |
-| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using SGLang engine with MoE models. |
+| `return_routed_experts` | boolean | `False` | Return routed expert indices for MoE models. Effective only when using the SGLang engine. `num_moe_layers` and `topk` are resolved automatically from the model config. |
 | `_version` | string | `"v1"` | Rollout controller implementation version. Use 'v1' for legacy RolloutController, 'v2' for RolloutControllerV2. **Choices:** `v1`, `v2` |
 | `model` | string | `"default"` | Model name exposed through the inference-service gateway. |
 | `routing_strategy` | string | `"round_robin"` | Routing strategy for the inference-service router. |
@@ -1057,36 +1057,36 @@ Configuration for Megatron-LM training framework.
 
 Refer to Megatron-LM documentation for implementation details.
 
-| Parameter | Type | Default | Description |
-| --- | --- | --- | --- |
-| `wrap_with_ddp` | boolean | `True` | - |
-| `use_torch_fsdp2` | boolean | `False` | - |
-| `use_custom_fsdp` | boolean | `False` | - |
-| `ddp` | [`DistributedDataParallelConfig`](section-distributed-data-parallel) | **Required** | - |
-| `virtual_pipeline_parallel_size` | integer | `1` | Virtual pipeline parallel size for Megatron interleaved schedule. Set to >1 to enable VPP. Default is 1 (disabled). |
-| `overlap_param_gather_with_optimizer_step` | boolean | `False` | - |
-| `use_precision_aware_optimizer` | boolean | `False` | Enable precision-aware optimizer for Megatron. When using adam_bf16 optimizer type with Megatron Engine, this is automatically enabled with exp_avg_dtype=bfloat16 and exp_avg_sq_dtype=bfloat16. |
-| `main_grads_dtype` | string | `"float32"` | - |
-| `main_params_dtype` | string | `"float32"` | - |
-| `exp_avg_dtype` | string | `"float32"` | - |
-| `exp_avg_sq_dtype` | string | `"float32"` | - |
-| `async_save` | boolean | `False` | - |
-| `use_checkpoint_opt_param_scheduler` | boolean | `True` | - |
-| `use_deterministic_algorithms` | boolean | `False` | - |
-| `recompute_granularity` | string \| None | `"full"` | - |
-| `recompute_method` | string \| None | `"uniform"` | - |
-| `recompute_num_layers` | integer \| None | `1` | - |
-| `distribute_saved_activations` | boolean \| None | `None` | - |
-| `recompute_modules` | list of string \| None | `None` | - |
-| `moe_router_dtype` | string \| None | `"fp32"` | - |
-| `moe_shared_expert_overlap` | boolean | `False` | Enable overlapping between shared expert computations and dispatcher communications. Without this, the shared experts execute after the routed experts. |
-| `moe_enable_deepep` | boolean | `False` | - |
-| `moe_token_dispatcher_type` | string | `"alltoall"` | Type of token dispatcher. Options: 'allgather','alltoall' and 'flex'. |
-| `moe_permute_fusion` | boolean | `False` | Fuse token rearrangement ops during token dispatching. |
-| `fp8_config` | [`FP8EngineConfig`](section-fp8-engine) \| None | `None` | - |
-| `bridge_type` | string | `"mbridge"` | Bridge backend for MegatronEngine. Choices: 'mbridge' or 'megatron-bridge'. **Choices:** `mbridge`, `megatron-bridge` |
-| `use_mbridge_save` | bool | `False` | Use mbridge's save method to save gpu memory when saving weights. |
-| `enable_router_replay` | boolean | `False` | Enable Router Replay (R3) for MoE models. When True, the training forward pass replays the expert routing decisions from the inference engine to reduce train-inference routing discrepancy. |
+| Parameter | Type | Default | Description |
+| --- | --- | --- | --- |
+| `wrap_with_ddp` | boolean | `True` | - |
+| `use_torch_fsdp2` | boolean | `False` | - |
+| `use_custom_fsdp` | boolean | `False` | - |
+| `ddp` | [`DistributedDataParallelConfig`](section-distributed-data-parallel) | **Required** | - |
+| `virtual_pipeline_parallel_size` | integer | `1` | Virtual pipeline parallel size for Megatron interleaved schedule. Set to >1 to enable VPP. Default is 1 (disabled). |
+| `overlap_param_gather_with_optimizer_step` | boolean | `False` | - |
+| `use_precision_aware_optimizer` | boolean | `False` | Enable precision-aware optimizer for Megatron. When using adam_bf16 optimizer type with Megatron Engine, this is automatically enabled with exp_avg_dtype=bfloat16 and exp_avg_sq_dtype=bfloat16. |
+| `main_grads_dtype` | string | `"float32"` | - |
+| `main_params_dtype` | string | `"float32"` | - |
+| `exp_avg_dtype` | string | `"float32"` | - |
+| `exp_avg_sq_dtype` | string | `"float32"` | - |
+| `async_save` | boolean | `False` | - |
+| `use_checkpoint_opt_param_scheduler` | boolean | `True` | - |
+| `use_deterministic_algorithms` | boolean | `False` | - |
+| `recompute_granularity` | string \| None | `"full"` | - |
+| `recompute_method` | string \| None | `"uniform"` | - |
+| `recompute_num_layers` | integer \| None | `1` | - |
+| `distribute_saved_activations` | boolean \| None | `None` | - |
+| `recompute_modules` | list of string \| None | `None` | - |
+| `moe_router_dtype` | string \| None | `"fp32"` | - |
+| `moe_shared_expert_overlap` | boolean | `False` | Enable overlapping between shared expert computations and dispatcher communications. Without this, the shared experts execute after the routed experts. |
+| `moe_enable_deepep` | boolean | `False` | - |
+| `moe_token_dispatcher_type` | string | `"alltoall"` | Type of token dispatcher. Options: 'allgather', 'alltoall', and 'flex'. |
+| `moe_permute_fusion` | boolean | `False` | Fuse token rearrangement ops during token dispatching. |
+| `fp8_config` | [`FP8EngineConfig`](section-fp8-engine) \| None | `None` | - |
+| `bridge_type` | string | `"mbridge"` | Bridge backend for MegatronEngine. Choices: 'mbridge' or 'megatron-bridge'. **Choices:** `mbridge`, `megatron-bridge` |
+| `enable_router_replay` | boolean | `False` | Enable Router Replay (R3) for MoE models. When True, the training forward pass replays the expert routing decisions from the inference engine to reduce train-inference routing discrepancy. Set automatically by the trainer when `rollout.return_routed_experts=True`. |
+| `use_mbridge_save` | boolean | `False` | Use mbridge's save method to save GPU memory when saving weights. |
 
 (section-memory-profiler)=

diff --git a/examples/math/gsm8k_grpo_megatron_r3.yaml b/examples/math/gsm8k_grpo_megatron_r3.yaml
index 4afca12666..1b802d9194 100644
--- a/examples/math/gsm8k_grpo_megatron_r3.yaml
+++ b/examples/math/gsm8k_grpo_megatron_r3.yaml
@@ -197,4 +197,4 @@ perf_tracer:
   fileroot: ${cluster.fileroot}
   enabled: false
 session_tracer:
-  enabled: false
\ No newline at end of file
+  enabled: false

diff --git a/tests/test_r3_mask_alignment.py b/tests/test_r3_mask_alignment.py
index 56e574cc0d..a886d84f7f 100644
--- a/tests/test_r3_mask_alignment.py
+++ b/tests/test_r3_mask_alignment.py
@@ -1,4 +1,3 @@
-import pytest
 import torch

 from areal.engine.megatron_engine_r3_patch import _align_routed_experts_to_mask

diff --git a/tests/test_router_replay.py b/tests/test_router_replay.py
index c5558925af..68c2fc8697 100644
--- a/tests/test_router_replay.py
+++ b/tests/test_router_replay.py
@@ -34,7 +34,6 @@
     remove_router_replay_patch,
 )
 
-
 # ---------------------------------------------------------------------------
 # RouterReplay instance lifecycle
 # ---------------------------------------------------------------------------
@@ -80,9 +79,7 @@ def test_action_toggles(self):

     def test_set_global_action_broadcasts_to_all(self):
         a, b, c = RouterReplay(), RouterReplay(), RouterReplay()
-        RouterReplay.set_global_router_replay_action(
-            RouterReplayAction.REPLAY_FORWARD
-        )
+        RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
         for inst in (a, b, c):
             assert inst.router_replay_action is RouterReplayAction.REPLAY_FORWARD
         RouterReplay.clear_global_router_replay_action()
@@ -91,9 +88,7 @@ def test_set_global_action_broadcasts_to_all(self):

     def test_set_replay_data_distributes_in_order(self):
         instances = [RouterReplay() for _ in range(3)]
-        per_layer = [
-            torch.full((4, 6), i, dtype=torch.int32) for i in range(3)
-        ]
+        per_layer = [torch.full((4, 6), i, dtype=torch.int32) for i in range(3)]
         RouterReplay.set_replay_data(per_layer)
         for i, inst in enumerate(instances):
             assert torch.equal(inst.target_topk_idx, per_layer[i])
@@ -132,7 +127,6 @@ def test_apply_is_idempotent_and_sentinel_flips(self):

     def test_topk_router_init_patched_flag(self):
         pytest.importorskip("megatron.core")
         from megatron.core.transformer.moe.router import TopKRouter
-        from areal.engine import router_replay_patch as rrp

         remove_router_replay_patch()
         assert not getattr(TopKRouter, "_r3_init_patched", False)
@@ -200,7 +194,9 @@ class TestDispatcherNumOutTokensOverride:
     must override it with ``routing_map.sum().item()``.
""" - def _make_routing_map(self, num_real: int, num_padding: int, num_experts: int, topk: int): + def _make_routing_map( + self, num_real: int, num_padding: int, num_experts: int, topk: int + ): rm = torch.zeros(num_real + num_padding, num_experts, dtype=torch.bool) for i in range(num_real): experts = torch.randperm(num_experts)[:topk] @@ -247,9 +243,7 @@ def test_fp8_padding_keeps_upstream_value(self): rm = self._make_routing_map(5, 5, 64, 6) disp = types.SimpleNamespace( drop_and_pad=False, - config=_FakeMoEConfig( - enable_routing_replay=True, fp8_padding=True, topk=6 - ), + config=_FakeMoEConfig(enable_routing_replay=True, fp8_padding=True, topk=6), num_out_tokens=None, ) _invoke_patched_preprocess(disp, rm, 60) @@ -289,9 +283,7 @@ class _FakeAutoConfig: def from_pretrained(path, trust_remote_code=True): # noqa: ARG004 return fake_config - monkeypatch.setattr( - "transformers.AutoConfig", _FakeAutoConfig, raising=True - ) + monkeypatch.setattr("transformers.AutoConfig", _FakeAutoConfig, raising=True) def test_moonlight_like_config(self, monkeypatch): """Moonlight-16B-A3B: 27 layers (1 dense + 26 MoE), topk=6.""" @@ -383,8 +375,12 @@ def test_moonlight_shape(self): attention_mask = torch.ones(1, seq_len, dtype=torch.long) out = preprocess_routed_experts_batch( - np_arr, input_ids, attention_mask, - num_moe_layers=num_moe, topk=topk, compress_dtype=False, + np_arr, + input_ids, + attention_mask, + num_moe_layers=num_moe, + topk=topk, + compress_dtype=False, ) assert out.shape == (1, seq_len, num_moe, topk) # First num_sgl_tokens rows come from the numpy array @@ -405,7 +401,11 @@ def test_dtype_compression(self): input_ids = torch.zeros(1, 6, dtype=torch.long) attention_mask = torch.ones(1, 6, dtype=torch.long) out = preprocess_routed_experts_batch( - np_arr, input_ids, attention_mask, - num_moe_layers=6, topk=6, compress_dtype=True, + np_arr, + input_ids, + attention_mask, + num_moe_layers=6, + topk=6, + compress_dtype=True, ) assert out.dtype == torch.uint8 # max expert idx < 256 diff --git a/tests/test_router_replay_e2e.py b/tests/test_router_replay_e2e.py index 37f81d8fd6..58162fc72d 100644 --- a/tests/test_router_replay_e2e.py +++ b/tests/test_router_replay_e2e.py @@ -96,8 +96,7 @@ def _run_e2e( proc.wait() stdout = "".join(stdout_lines) if stdout_lines else "" pytest.fail( - f"R3 E2E subprocess timed out after {timeout_sec}s.\n" - f"OUTPUT:\n{stdout}" + f"R3 E2E subprocess timed out after {timeout_sec}s.\nOUTPUT:\n{stdout}" ) with open(output) as f: result = f.read().strip() @@ -223,4 +222,3 @@ def test_r3_e2e_moonlight_pp2_tp4_ep4_forward_backward(tmp_path_factory): test_type="forward_backward", output=str(out), ) - diff --git a/tests/torchrun/run_router_replay_distributed.py b/tests/torchrun/run_router_replay_distributed.py index 9340d0c3cc..8855e35439 100644 --- a/tests/torchrun/run_router_replay_distributed.py +++ b/tests/torchrun/run_router_replay_distributed.py @@ -31,7 +31,6 @@ import torch import torch.distributed as dist -from megatron.core import parallel_state as mpu from areal.api import FinetuneSpec from areal.api.alloc_mode import ModelAllocation @@ -186,9 +185,7 @@ def _build_training_input_with_rollout_experts( bs, slen = input_ids.shape gen = torch.Generator(device="cpu").manual_seed(seed) - rollout_logprobs = ( - -torch.rand((bs, slen), generator=gen) * 2.0 - ).to(engine.device) + rollout_logprobs = (-torch.rand((bs, slen), generator=gen) * 2.0).to(engine.device) advantages = torch.randn((bs, slen), generator=gen).to(engine.device) loss_mask = 
diff --git a/tests/test_router_replay_e2e.py b/tests/test_router_replay_e2e.py
index 37f81d8fd6..58162fc72d 100644
--- a/tests/test_router_replay_e2e.py
+++ b/tests/test_router_replay_e2e.py
@@ -96,8 +96,7 @@ def _run_e2e(
         proc.wait()
         stdout = "".join(stdout_lines) if stdout_lines else ""
         pytest.fail(
-            f"R3 E2E subprocess timed out after {timeout_sec}s.\n"
-            f"OUTPUT:\n{stdout}"
+            f"R3 E2E subprocess timed out after {timeout_sec}s.\nOUTPUT:\n{stdout}"
         )
     with open(output) as f:
         result = f.read().strip()
@@ -223,4 +222,3 @@ def test_r3_e2e_moonlight_pp2_tp4_ep4_forward_backward(tmp_path_factory):
         test_type="forward_backward",
         output=str(out),
     )
-

diff --git a/tests/torchrun/run_router_replay_distributed.py b/tests/torchrun/run_router_replay_distributed.py
index 9340d0c3cc..8855e35439 100644
--- a/tests/torchrun/run_router_replay_distributed.py
+++ b/tests/torchrun/run_router_replay_distributed.py
@@ -31,7 +31,6 @@
 import torch
 import torch.distributed as dist
-from megatron.core import parallel_state as mpu

 from areal.api import FinetuneSpec
 from areal.api.alloc_mode import ModelAllocation
@@ -186,9 +185,7 @@ def _build_training_input_with_rollout_experts(
     bs, slen = input_ids.shape
     gen = torch.Generator(device="cpu").manual_seed(seed)

-    rollout_logprobs = (
-        -torch.rand((bs, slen), generator=gen) * 2.0
-    ).to(engine.device)
+    rollout_logprobs = (-torch.rand((bs, slen), generator=gen) * 2.0).to(engine.device)
     advantages = torch.randn((bs, slen), generator=gen).to(engine.device)
     loss_mask = attention_mask.to(dtype=torch.int64)
@@ -245,7 +242,9 @@ def test_patch_plumbing(model_type: str, backend: str, output: str | None):
     got = len(RouterReplay.router_instances)
     logger.info(
         "[R3-E2E] rank=%d expected_moe_layers=%d got_router_instances=%d",
-        rank, expected, got,
+        rank,
+        expected,
+        got,
     )
     assert got == expected, (
         f"RouterReplay.router_instances count ({got}) must match the "
@@ -318,9 +317,7 @@ def test_forward_replay(model_type: str, backend: str, output: str | None):
     engine._r3_pending_routed_experts = routed_experts

     engine.eval()
-    _ = engine.forward(
-        input_=inp, aggregate_fn=lambda xs: torch.cat(xs, dim=0)
-    )
+    _ = engine.forward(input_=inp, aggregate_fn=lambda xs: torch.cat(xs, dim=0))

     assert engine._r3_pending_routed_experts is None, (
         "_r3_pending_routed_experts should be consumed by the R3 wrapper."
@@ -393,15 +390,13 @@ def test_forward_backward(model_type: str, backend: str, output: str | None):
     )

     # Build training input + rollout_expert_indices.
-    input_dict, rollout_expert_indices = (
-        _build_training_input_with_rollout_experts(
-            engine,
-            num_moe_layers_total=num_moe,
-            topk=topk,
-            batch_size=4,
-            min_seqlen=16,
-            max_seqlen=32,
-        )
+    input_dict, rollout_expert_indices = _build_training_input_with_rollout_experts(
+        engine,
+        num_moe_layers_total=num_moe,
+        topk=topk,
+        batch_size=4,
+        min_seqlen=16,
+        max_seqlen=32,
     )

     # Broadcast input across the context+model-parallel group so every