@@ -300,8 +300,13 @@ def __init__(
     self.quant = quant
     self.rngs = rngs
 
-    self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
-    self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
+    if self.config.fsdp_shard_on_exp:
+      # special sharding for dsv3
+      self.wi_kernel_axes = ("embed_no_exp", None, "mlp")
+      self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
+    else:
+      self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
+      self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
 
     self.gate = GateLogit(
         in_features_shape=self.config.emb_dim,
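Note on the branch above: with `fsdp_shard_on_exp` enabled, the expert dimension of the MoE kernels takes the logical name that normally carries the FSDP-style sharding (`embed_no_exp`), while the embedding dimension is left unsharded (`None`). The sketch below contrasts the two layouts for a `wi` kernel of shape `(num_experts, emb_dim, mlp_dim)`; the mesh axis names and the logical-to-mesh mapping are illustrative assumptions, not the repository's actual `logical_axis_rules`.

```python
# Illustrative sketch only: mesh axes and the assumed mapping
# ("exp" -> "expert", "embed_no_exp" -> "fsdp", "mlp" -> "tensor") are
# stand-ins chosen to make the layout contrast visible.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1x1x1 mesh so the sketch runs on a single device; real meshes span many chips.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1, 1), ("expert", "fsdp", "tensor"))

num_experts, emb_dim, mlp_dim = 8, 16, 32
wi_kernel = jnp.zeros((num_experts, emb_dim, mlp_dim))

# Default layout ("exp", "embed_no_exp", "mlp"): expert dim on the expert axis,
# embed dim carries the fsdp sharding (assumed mapping).
default_spec = P("expert", "fsdp", "tensor")

# fsdp_shard_on_exp layout ("embed_no_exp", None, "mlp"): the fsdp sharding moves
# onto the expert dim and the embed dim is replicated.
dsv3_spec = P("fsdp", None, "tensor")

for spec in (default_spec, dsv3_spec):
  sharded = jax.device_put(wi_kernel, NamedSharding(mesh, spec))
  print(spec, sharded.sharding)
```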
@@ -427,6 +432,7 @@ def get_topk(self, gate_logits, pre_bias_logits, rngs=None):
 
     return top_k_weights, top_k_indices
 
+
   def deepseek_scale_weights(self, weights):
     """Scales weights according to DeepSeek's v3 reference implementation."""
     # https://github.com/deepseek-ai/DeepSeek-V3/blob/2f7b80eecebf3d1c84da5a0d465f6639ea175012/inference/model.py#L592-L594.
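For context on the method touched here: the linked DeepSeek-V3 reference normalizes the selected top-k (sigmoid) routing weights so they sum to one, then multiplies by a routed scaling factor. A rough JAX sketch of that scaling, assuming sigmoid scores and a hypothetical `routed_scaling_factor` argument (the default value below is illustrative, not the repository's configuration):

```python
import jax.numpy as jnp

def deepseek_scale_weights_sketch(top_k_weights, routed_scaling_factor=2.5):
  """Rough sketch of DeepSeek-V3-style weight scaling (see the linked reference).

  Normalizes the selected experts' weights to sum to 1 over the top-k axis,
  then applies the routed scaling factor.
  """
  top_k_weights = top_k_weights / jnp.sum(top_k_weights, axis=-1, keepdims=True)
  return top_k_weights * routed_scaling_factor
```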
@@ -900,9 +906,15 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):
 
     # w0, w1, wo needs to be un sharded on fsdp / fsdp_transpose axis, so use
     # mlp_no_fsdp axis
-    w0_pspec = nn.logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
-    w1_pspec = nn.logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
-    wo_pspec = nn.logical_to_mesh_axes(("exp", "mlp_no_fsdp", "embed_tensor_transpose"))
+    if self.config.fsdp_shard_on_exp:
+      # special sharding for dsv3 to remove overhead between gmm/AG
+      w0_pspec = nn.logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp"))
+      w1_pspec = nn.logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp"))
+      wo_pspec = nn.logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
+    else:
+      w0_pspec = nn.logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
+      w1_pspec = nn.logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
+      wo_pspec = nn.logical_to_mesh_axes(("exp", "mlp_no_fsdp", "embed_tensor_transpose"))
     if isinstance(w0_kernel, aqt.QTensor):
       w0_pspec = aqt.partition_spec(w0_pspec, (1,), w0_kernel.dtype, use_bias=False)
     if isinstance(w1_kernel, aqt.QTensor):
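To see what the two branches in this hunk resolve to, here is a small sketch of `flax.linen.logical_to_mesh_axes` with an explicitly supplied rule set. The rules are hypothetical stand-ins for the project's real logical-axis rules, chosen only to make the contrast between the two gmm operand specs visible:

```python
import flax.linen as nn

# Hypothetical logical-axis rules, for illustration only.
rules = (
    ("exp", "expert"),
    ("embed_tensor_transpose", "fsdp"),
    ("mlp_no_fsdp", "tensor"),
)

# Default branch: expert dim -> expert axis, embed dim -> fsdp.
print(nn.logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"), rules))
# PartitionSpec('expert', 'fsdp', 'tensor')

# fsdp_shard_on_exp branch: the fsdp sharding moves onto the expert dim and the
# embed dim is replicated, which is what the commit comment attributes the
# reduced gmm/all-gather overhead to.
print(nn.logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp"), rules))
# PartitionSpec('fsdp', None, 'tensor')
```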