diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index d3d8148d63..7702ebcd44 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -1410,7 +1410,8 @@ class Debug:
     """Only warns about ops without deterministic implementations rather than erroring out
     """
     moe_force_load_balance: bool = False
-    """If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only."""
+    """If True, we force each expert to get the same amount of tokens via round-robin.
+    This option is for debugging usage only."""
     enable_nan_tracker: bool = False
     """If True, enable lightweight NaN/Inf tracking to find where NaN first appears in the model."""
 
@@ -1418,6 +1419,10 @@ class Debug:
     nan_tracker_verbose: bool = False
     """If True, print stats for every layer (very verbose output)."""
 
+    moe_routing_noise_std: float = 0.0
+    """Standard deviation of Gaussian noise added to MoE routing scores before top-k selection during training.
+    0.0 disables noise."""
+
 
 @dataclass
 class JobConfig:
diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py
index d7ed015300..f11c8ee70b 100644
--- a/torchtitan/models/deepseek_v3/model/args.py
+++ b/torchtitan/models/deepseek_v3/model/args.py
@@ -83,7 +83,9 @@ class DeepSeekV3ModelArgs(BaseModelArgs):
     beta_fast: int = 32
     beta_slow: int = 1
     mscale: float = 1.0
-    mscale_all_dim: float = 1.0  # When mscale == mscale_all_dim, effective mscale is 1.0
+    mscale_all_dim: float = (
+        1.0  # When mscale == mscale_all_dim, effective mscale is 1.0
+    )
 
     def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         seq_len = job_config.training.seq_len
@@ -102,6 +104,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         self.moe_args._debug_force_load_balance = (
             job_config.debug.moe_force_load_balance
         )
+        self.moe_args.routing_noise_std = job_config.debug.moe_routing_noise_std
 
         # Configure expert parallel communication backend from config (defaults to "standard")
         self.moe_impl = job_config.parallelism.expert_parallel_comm_backend
diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py
index 9caa4a1d9f..f4f89941ad 100644
--- a/torchtitan/models/moe/moe.py
+++ b/torchtitan/models/moe/moe.py
@@ -85,6 +85,7 @@ class MoEArgs:
     num_limited_groups: int | None = None
     use_grouped_mm: bool = True  # grouped mm or for-loop for the experts computation
     load_balance_coeff: float | None = 1e-3
+    routing_noise_std: float = 0.0
 
     _debug_force_load_balance: bool = False
 
@@ -610,6 +611,7 @@ def __init__(
         route_norm: bool,
         route_scale: float,
         _debug_force_load_balance: bool = False,
+        routing_noise_std: float = 0.0,
     ):
         super().__init__()
         self.gate = nn.Linear(dim, num_experts, bias=False)
@@ -621,6 +623,7 @@
         self.route_norm = route_norm
         self.route_scale = route_scale
         self._debug_force_load_balance = _debug_force_load_balance
+        self.routing_noise_std = routing_noise_std
 
     def _debug_force_load_balance_routing(
         self, scores: torch.Tensor
@@ -714,6 +717,11 @@ def forward(
         # Apply node-limited routing if configured
         if self.num_expert_groups is not None:
             scores_for_choice = self._get_node_limited_routing_scores(scores_for_choice)
+        if self.training and self.routing_noise_std > 0.0:
+            scores_for_choice = (
+                scores_for_choice
+                + torch.randn_like(scores_for_choice) * self.routing_noise_std
+            )
         _, selected_experts_indices = torch.topk(
             scores_for_choice, k=self.top_k, dim=-1, sorted=False
         )
@@ -878,6 +886,7 @@ def __init__(
             route_norm=moe_args.route_norm,
             route_scale=moe_args.route_scale,
             _debug_force_load_balance=moe_args._debug_force_load_balance,
+            routing_noise_std=moe_args.routing_noise_std,
         )
         if peft_config is not None and peft_config.enable_peft:
             self.router.gate.weight.requires_grad = False
@@ -1055,7 +1064,11 @@ def init_weights(self, init_std: float, buffer_device: torch.device, n_layers: i
 
 
 def build_moe(
-    args: MoEArgs, dim: int, hidden_dim: int, peft_config: PEFT, moe_impl: str = "standard",
+    args: MoEArgs,
+    dim: int,
+    hidden_dim: int,
+    peft_config: PEFT,
+    moe_impl: str = "standard",
 ) -> nn.Module:
     """Factory for MoE with different backends: 'standard' (all-to-all) or 'deepep' (DeepEP)."""
     if moe_impl == "deepep":
diff --git a/torchtitan/models/qwen3/model/args.py b/torchtitan/models/qwen3/model/args.py
index d0a0556bf1..8766d43706 100644
--- a/torchtitan/models/qwen3/model/args.py
+++ b/torchtitan/models/qwen3/model/args.py
@@ -58,6 +58,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         self.moe_args._debug_force_load_balance = (
             job_config.debug.moe_force_load_balance
         )
+        self.moe_args.routing_noise_std = job_config.debug.moe_routing_noise_std
 
     def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
         return get_moe_model_nparams_and_flops(self, model, 2 * self.head_dim, seq_len)
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 2ff6b3a49c..7bf29420bf 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -637,11 +637,7 @@ def _collect_moe_expert_metrics(self) -> dict[str, Any]:
 
         tokens_per_expert_by_layer = torch.stack(moe_counts)
 
-        dp_group = None
-        if self.parallel_dims.dp_cp_enabled:
-            dp_group = self.parallel_dims.world_mesh["dp_cp"].get_group()
-        elif self.parallel_dims.dp_enabled:
-            dp_group = self.parallel_dims.world_mesh["dp"].get_group()
+        dp_group = self.parallel_dims.get_optional_mesh("loss")
 
         if dp_group is not None:
             torch.distributed.all_reduce(