
Commit c84c424

Merge pull request #2373 from AI-Hypercomputer:qinwen/update_sharding_moe
PiperOrigin-RevId: 810951610
2 parents 1089b16 + 859abdf commit c84c424

File tree

3 files changed: +25 -5 lines

src/MaxText/configs/base.yml
src/MaxText/layers/moe.py
src/MaxText/pyconfig.py

src/MaxText/configs/base.yml

Lines changed: 2 additions & 0 deletions
@@ -182,6 +182,8 @@ expert_shard_attention_option: "fsdp"
 
 # When MoE weight matrices are sharded on both FSDP and FSDP-transpose axes, use two separate All-Gather calls
 moe_fsdp_use_two_stage_all_gather: False
+# Shard the MoE weights on the expert dimension. This can be performant when num_experts is divisible by fsdp parallelism.
+fsdp_shard_on_exp: False
 
 # deepseek moe
 base_moe_mlp_dim: 7168 # intermediate dimension at MoE layer. For a fully MoE model, base_mlp_dim must be equal to base_moe_mlp_dim.

src/MaxText/layers/moe.py

Lines changed: 17 additions & 5 deletions
@@ -300,8 +300,13 @@ def __init__(
     self.quant = quant
     self.rngs = rngs
 
-    self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
-    self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
+    if self.config.fsdp_shard_on_exp:
+      # Special sharding for DeepSeek-V3: carry the fsdp sharding on the expert dimension.
+      self.wi_kernel_axes = ("embed_no_exp", None, "mlp")
+      self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
+    else:
+      self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
+      self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
 
     self.gate = GateLogit(
         in_features_shape=self.config.emb_dim,
@@ -427,6 +432,7 @@ def get_topk(self, gate_logits, pre_bias_logits, rngs=None):
 
     return top_k_weights, top_k_indices
 
+
   def deepseek_scale_weights(self, weights):
     """Scales weights according to DeepSeek's v3 reference implementation."""
     # https://github.com/deepseek-ai/DeepSeek-V3/blob/2f7b80eecebf3d1c84da5a0d465f6639ea175012/inference/model.py#L592-L594.
@@ -900,9 +906,15 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):
 
     # w0, w1, wo needs to be un sharded on fsdp / fsdp_transpose axis, so use
     # mlp_no_fsdp axis
-    w0_pspec = nn.logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
-    w1_pspec = nn.logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
-    wo_pspec = nn.logical_to_mesh_axes(("exp", "mlp_no_fsdp", "embed_tensor_transpose"))
+    if self.config.fsdp_shard_on_exp:
+      # Special sharding for DeepSeek-V3 to remove the overhead between the grouped matmul (gmm) and the all-gather.
+      w0_pspec = nn.logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp"))
+      w1_pspec = nn.logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp"))
+      wo_pspec = nn.logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
+    else:
+      w0_pspec = nn.logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
+      w1_pspec = nn.logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
+      wo_pspec = nn.logical_to_mesh_axes(("exp", "mlp_no_fsdp", "embed_tensor_transpose"))
     if isinstance(w0_kernel, aqt.QTensor):
       w0_pspec = aqt.partition_spec(w0_pspec, (1,), w0_kernel.dtype, use_bias=False)
     if isinstance(w1_kernel, aqt.QTensor):
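
To make the moe.py change easier to picture, here is a minimal sketch (not part of the commit) of the two weight layouts the flag toggles. It uses plain jax.sharding with a toy (num_experts, emb, mlp) weight and a hypothetical 1-D "fsdp" mesh; the diff's logical axis names are only approximated here by placing the fsdp mesh axis either on the embed dimension (default) or on the expert dimension (fsdp_shard_on_exp=True).

# Minimal sketch, not MaxText code: contrast the two MoE weight layouts.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

num_experts, emb, mlp = 8, 16, 32  # toy sizes (hypothetical)

# Use however many local devices evenly divide the toy dimensions, so the sketch runs anywhere.
n = jax.device_count()
while num_experts % n or emb % n:
  n -= 1
mesh = Mesh(np.array(jax.devices()[:n]), axis_names=("fsdp",))

w = jnp.zeros((num_experts, emb, mlp))

# Default layout ("exp", "embed_no_exp", "mlp"): the embed dimension carries the fsdp axis,
# so the grouped matmul must first all-gather the weights along fsdp.
w_default = jax.device_put(w, NamedSharding(mesh, P(None, "fsdp", None)))

# fsdp_shard_on_exp layout ("embed_no_exp", None, "mlp" in the diff's logical names): the
# expert dimension carries the fsdp axis, so every shard already holds complete (emb, mlp)
# matrices for its subset of experts.
w_on_exp = jax.device_put(w, NamedSharding(mesh, P("fsdp", None, None)))

print(w_default.sharding)  # PartitionSpec(None, 'fsdp', None)
print(w_on_exp.sharding)   # PartitionSpec('fsdp', None, None)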

src/MaxText/pyconfig.py

Lines changed: 6 additions & 0 deletions
@@ -219,6 +219,7 @@ def validate_keys(keys):
   validate_mlp_dim(keys)
   validate_sparse_matmul_parallelism(keys)
   validate_ring_of_experts_parallelism(keys)
+  validate_shard_fsdp_on_expert_parallelism(keys)
   validate_ragged_dot(keys)
   validate_deepseek_moe(keys)
   validate_expert_shard_attention_option(keys["expert_shard_attention_option"])
@@ -1050,6 +1051,11 @@ def validate_ring_of_experts_parallelism(raw_keys):
   if raw_keys["use_ring_of_experts"] and not using_expert_parallelism(raw_keys):
     raise ValueError("Ring-of-experts requires expert-parallelism to be enabled.")
 
+def validate_shard_fsdp_on_expert_parallelism(raw_keys):
+  if raw_keys["fsdp_shard_on_exp"] and raw_keys["num_experts"] % raw_keys["ici_fsdp_parallelism"] != 0:
+    raise ValueError("fsdp_shard_on_exp requires num_experts to be divisible by ici_fsdp_parallelism.")
+  if raw_keys["fsdp_shard_on_exp"] and (using_tensor_parallelism(raw_keys) or using_expert_parallelism(raw_keys)):
+    raise ValueError("fsdp_shard_on_exp requires ici_expert_parallelism = 1 and ici_tensor_parallelism/ici_tensor_transpose_parallelism = 1.")
 
 def validate_ragged_dot(raw_keys):
   if raw_keys["sparse_matmul"] and not raw_keys["megablox"]:
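
The new validator only inspects a handful of config keys, so its behavior is easy to check in isolation. The snippet below is a stand-alone approximation (not the MaxText function; key values are hypothetical, and using_tensor_parallelism / using_expert_parallelism are reduced to simple parallelism counts) showing one accepted and one rejected configuration.

# Stand-alone approximation of validate_shard_fsdp_on_expert_parallelism; values are hypothetical.
def check(raw_keys):
  if raw_keys["fsdp_shard_on_exp"] and raw_keys["num_experts"] % raw_keys["ici_fsdp_parallelism"] != 0:
    raise ValueError("fsdp_shard_on_exp requires num_experts to be divisible by ici_fsdp_parallelism.")
  if raw_keys["fsdp_shard_on_exp"] and (raw_keys["ici_tensor_parallelism"] > 1 or raw_keys["ici_expert_parallelism"] > 1):
    raise ValueError("fsdp_shard_on_exp requires ici_expert_parallelism = 1 and ici_tensor_parallelism = 1.")

# 256 experts over 64-way fsdp with EP/TP disabled: passes (256 % 64 == 0).
check({"fsdp_shard_on_exp": True, "num_experts": 256, "ici_fsdp_parallelism": 64,
       "ici_tensor_parallelism": 1, "ici_expert_parallelism": 1})

# 256 experts over 48-way fsdp: raises ValueError (256 % 48 != 0).
try:
  check({"fsdp_shard_on_exp": True, "num_experts": 256, "ici_fsdp_parallelism": 48,
         "ici_tensor_parallelism": 1, "ici_expert_parallelism": 1})
except ValueError as e:
  print(e)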
