Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion torchtitan/config/job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1410,14 +1410,19 @@ class Debug:
"""Only warns about ops without deterministic implementations rather than erroring out """

moe_force_load_balance: bool = False
"""If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only."""
    """If True, we force each expert to get the same number of tokens via round-robin.
    This option is for debugging purposes only."""

enable_nan_tracker: bool = False
"""If True, enable lightweight NaN/Inf tracking to find where NaN first appears in the model."""

nan_tracker_verbose: bool = False
"""If True, print stats for every layer (very verbose output)."""

moe_routing_noise_std: float = 0.0
    """Standard deviation of the Gaussian noise added to MoE routing scores before top-k
    selection during training. A value of 0.0 disables the noise."""


@dataclass
class JobConfig:
Expand Down
5 changes: 4 additions & 1 deletion torchtitan/models/deepseek_v3/model/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ class DeepSeekV3ModelArgs(BaseModelArgs):
beta_fast: int = 32
beta_slow: int = 1
mscale: float = 1.0
mscale_all_dim: float = 1.0 # When mscale == mscale_all_dim, effective mscale is 1.0
mscale_all_dim: float = (
1.0 # When mscale == mscale_all_dim, effective mscale is 1.0
)

def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
seq_len = job_config.training.seq_len
Expand All @@ -102,6 +104,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
self.moe_args._debug_force_load_balance = (
job_config.debug.moe_force_load_balance
)
self.moe_args.routing_noise_std = job_config.debug.moe_routing_noise_std

# Configure expert parallel communication backend from config (defaults to "standard")
self.moe_impl = job_config.parallelism.expert_parallel_comm_backend
Expand Down
15 changes: 14 additions & 1 deletion torchtitan/models/moe/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class MoEArgs:
num_limited_groups: int | None = None
use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation
load_balance_coeff: float | None = 1e-3
routing_noise_std: float = 0.0

_debug_force_load_balance: bool = False

Expand Down Expand Up @@ -610,6 +611,7 @@ def __init__(
route_norm: bool,
route_scale: float,
_debug_force_load_balance: bool = False,
routing_noise_std: float = 0.0,
):
super().__init__()
self.gate = nn.Linear(dim, num_experts, bias=False)
Expand All @@ -621,6 +623,7 @@ def __init__(
self.route_norm = route_norm
self.route_scale = route_scale
self._debug_force_load_balance = _debug_force_load_balance
self.routing_noise_std = routing_noise_std

def _debug_force_load_balance_routing(
self, scores: torch.Tensor
Expand Down Expand Up @@ -714,6 +717,11 @@ def forward(
# Apply node-limited routing if configured
if self.num_expert_groups is not None:
scores_for_choice = self._get_node_limited_routing_scores(scores_for_choice)
if self.training and self.routing_noise_std > 0.0:
scores_for_choice = (
scores_for_choice
+ torch.randn_like(scores_for_choice) * self.routing_noise_std
)
_, selected_experts_indices = torch.topk(
scores_for_choice, k=self.top_k, dim=-1, sorted=False
)
Expand Down Expand Up @@ -878,6 +886,7 @@ def __init__(
route_norm=moe_args.route_norm,
route_scale=moe_args.route_scale,
_debug_force_load_balance=moe_args._debug_force_load_balance,
routing_noise_std=moe_args.routing_noise_std,
)
if peft_config is not None and peft_config.enable_peft:
self.router.gate.weight.requires_grad = False
Expand Down Expand Up @@ -1055,7 +1064,11 @@ def init_weights(self, init_std: float, buffer_device: torch.device, n_layers: i


def build_moe(
args: MoEArgs, dim: int, hidden_dim: int, peft_config: PEFT, moe_impl: str = "standard",
args: MoEArgs,
dim: int,
hidden_dim: int,
peft_config: PEFT,
moe_impl: str = "standard",
) -> nn.Module:
"""Factory for MoE with different backends: 'standard' (all-to-all) or 'deepep' (DeepEP)."""
if moe_impl == "deepep":
Expand Down
1 change: 1 addition & 0 deletions torchtitan/models/qwen3/model/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
self.moe_args._debug_force_load_balance = (
job_config.debug.moe_force_load_balance
)
self.moe_args.routing_noise_std = job_config.debug.moe_routing_noise_std

def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
return get_moe_model_nparams_and_flops(self, model, 2 * self.head_dim, seq_len)
6 changes: 1 addition & 5 deletions torchtitan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,11 +637,7 @@ def _collect_moe_expert_metrics(self) -> dict[str, Any]:

tokens_per_expert_by_layer = torch.stack(moe_counts)

dp_group = None
if self.parallel_dims.dp_cp_enabled:
dp_group = self.parallel_dims.world_mesh["dp_cp"].get_group()
elif self.parallel_dims.dp_enabled:
dp_group = self.parallel_dims.world_mesh["dp"].get_group()
dp_group = self.parallel_dims.get_optional_mesh("loss")

if dp_group is not None:
torch.distributed.all_reduce(
Expand Down