2 changes: 1 addition & 1 deletion .github/workflows/run_tests_internal.yml
@@ -67,4 +67,4 @@ jobs:
FINAL_PYTEST_MARKER="${{ inputs.pytest_marker }} and not scheduled_only"
fi
python3 -m pip install -e . --no-dependencies &&
-python3 -m pytest -v -m "${FINAL_PYTEST_MARKER}" --durations=0
+LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536' python3 -m pytest -v -m "${FINAL_PYTEST_MARKER}" --durations=0
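For reference, a minimal sketch (not part of the PR) of picking up the same scoped-vmem flag in a local run: LIBTPU_INIT_ARGS is read when the TPU backend is first initialized, so it has to be set before any JAX call that touches devices.

# Illustrative only: mirrors the CI change above for a local TPU run.
import os

os.environ.setdefault("LIBTPU_INIT_ARGS", "--xla_tpu_scoped_vmem_limit_kib=65536")

import jax  # imported after setting the env var on purpose

print(jax.devices())
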
6 changes: 6 additions & 0 deletions src/MaxText/configs/base.yml
@@ -182,6 +182,8 @@ expert_shard_attention_option: "fsdp"

# When MoE weight matrices are sharded on both FSDP and FSDP-transpose axes, use two separate All-Gather calls
moe_fsdp_use_two_stage_all_gather: False
+# Shard the MoE weights on the num_experts dim; this can be performant when num_experts % ici_fsdp_parallelism == 0.
+fsdp_shard_on_exp: False

# deepseek moe
base_moe_mlp_dim: 7168 # intermediate dimension at MoE layer. For a fully MoE model, base_mlp_dim must be equal to base_moe_mlp_dim.
@@ -385,10 +387,12 @@ logical_axis_rules: [
['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
["q_lora_up_proj",[]],
['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
["kv_lora_up_proj",[]],
['norm', ['tensor', 'tensor_transpose']],
['layers', 'stage'],
['kv', []],
@@ -405,6 +409,8 @@
['num_pages', []],
['tokens_per_page', []],
['paged_kv_head_dim_size', []],
+['dense_layers', []],
+['moe_layers', []],
]
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
4 changes: 2 additions & 2 deletions src/MaxText/layers/attention_mla.py
@@ -398,7 +398,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
out_features_shape=self.q_lora_rank,
axis=-1,
kernel_init=self.kernel_init,
kernel_axes=("embed", "q_lora"),
kernel_axes=("embed", "q_lora_up_proj"),
[Review comment — @RissyRan (Collaborator), Sep 24, 2025]
This part is helpful for all MLA models or only DS V3?
cc @richjames0 another case.

dtype=self.dtype,
weight_dtype=self.weight_dtype,
quant=self.quant,
@@ -432,7 +432,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
out_features_shape=self.kv_lora_rank + self.qk_rope_head_dim,
axis=-1,
kernel_init=self.kernel_init,
kernel_axes=("embed", "kv_lora"),
kernel_axes=("embed", "kv_lora_up_proj"),
dtype=self.dtype,
weight_dtype=self.weight_dtype,
quant=self.quant,
22 changes: 17 additions & 5 deletions src/MaxText/layers/moe.py
@@ -300,8 +300,13 @@ def __init__(
self.quant = quant
self.rngs = rngs

-self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
-self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
+if self.config.fsdp_shard_on_exp:
+  # special sharding for dsv3
+  self.wi_kernel_axes = ("embed_no_exp", None, "mlp")
+  self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
+else:
+  self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
+  self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")

self.gate = GateLogit(
in_features_shape=self.config.emb_dim,
@@ -427,6 +432,7 @@ def get_topk(self, gate_logits, pre_bias_logits, rngs=None):

return top_k_weights, top_k_indices


def deepseek_scale_weights(self, weights):
"""Scales weights according to DeepSeek's v3 reference implementation."""
# https://github.com/deepseek-ai/DeepSeek-V3/blob/2f7b80eecebf3d1c84da5a0d465f6639ea175012/inference/model.py#L592-L594.
@@ -900,9 +906,15 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):

# w0, w1, wo need to be unsharded on the fsdp / fsdp_transpose axis, so use the
# mlp_no_fsdp axis
-w0_pspec = nn.logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
-w1_pspec = nn.logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
-wo_pspec = nn.logical_to_mesh_axes(("exp", "mlp_no_fsdp", "embed_tensor_transpose"))
+if self.config.fsdp_shard_on_exp:
+  # special sharding for dsv3 to remove overhead between gmm/AG
+  w0_pspec = nn.logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp"))
+  w1_pspec = nn.logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp"))
+  wo_pspec = nn.logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
+else:
+  w0_pspec = nn.logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
+  w1_pspec = nn.logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
+  wo_pspec = nn.logical_to_mesh_axes(("exp", "mlp_no_fsdp", "embed_tensor_transpose"))
if isinstance(w0_kernel, aqt.QTensor):
w0_pspec = aqt.partition_spec(w0_pspec, (1,), w0_kernel.dtype, use_bias=False)
if isinstance(w1_kernel, aqt.QTensor):
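A minimal standalone sketch of what the fsdp_shard_on_exp branches above change (the rule set below is an assumption for illustration, not MaxText's full logical_axis_rules): with the flag on, the expert dimension of the MoE kernels carries the fsdp-backed logical name, so FSDP shards across experts while the embedding dimension is left unsharded.

# Illustrative only; assumed mesh-axis names and rules.
from flax import linen as nn

rules = (("exp", "expert"), ("embed_no_exp", "fsdp"), ("mlp", "tensor"))

# Default wi layout (exp, embed, mlp): experts on 'expert', embedding on 'fsdp'.
default_wi = nn.logical_to_mesh_axes(("exp", "embed_no_exp", "mlp"), rules)

# fsdp_shard_on_exp layout: the expert dim carries the fsdp-backed name and the
# embedding dim is left unconstrained.
exp_sharded_wi = nn.logical_to_mesh_axes(("embed_no_exp", None, "mlp"), rules)

print(default_wi)      # PartitionSpec('expert', 'fsdp', 'tensor')
print(exp_sharded_wi)  # PartitionSpec('fsdp', None, 'tensor')
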
6 changes: 6 additions & 0 deletions src/MaxText/pyconfig.py
@@ -219,6 +219,7 @@ def validate_keys(keys):
validate_mlp_dim(keys)
validate_sparse_matmul_parallelism(keys)
validate_ring_of_experts_parallelism(keys)
+validate_shard_fsdp_on_expert_parallelism(keys)
validate_ragged_dot(keys)
validate_deepseek_moe(keys)
validate_expert_shard_attention_option(keys["expert_shard_attention_option"])
@@ -1049,6 +1050,11 @@ def validate_ring_of_experts_parallelism(raw_keys):
if raw_keys["use_ring_of_experts"] and not using_expert_parallelism(raw_keys):
raise ValueError("Ring-of-experts requires expert-parallelism to be enabled.")

+def validate_shard_fsdp_on_expert_parallelism(raw_keys):
+  if raw_keys["fsdp_shard_on_exp"] and raw_keys["num_experts"] % raw_keys["ici_fsdp_parallelism"] != 0:
+    raise ValueError("fsdp_shard_on_exp requires num_experts to be divisible by ici_fsdp_parallelism.")
+  if raw_keys["fsdp_shard_on_exp"] and (using_tensor_parallelism(raw_keys) or using_expert_parallelism(raw_keys)):
+    raise ValueError("fsdp_shard_on_exp requires ici_expert_parallelism = 1 and ici_tensor_parallelism/ici_tensor_transpose_parallelism = 1.")
[Review comment — Collaborator]
Nit: probably just say fsdp_shard_on_exp does not support EP and TP shardings?


def validate_ragged_dot(raw_keys):
if raw_keys["sparse_matmul"] and not raw_keys["megablox"]:
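As a toy illustration of the constraint the new validator enforces (the numbers below are made up, not taken from any config in the PR):

# Hypothetical values, chosen only to illustrate the divisibility requirement.
num_experts = 256
ici_fsdp_parallelism = 64

# fsdp_shard_on_exp=True is only valid when this holds (and EP/TP are disabled).
assert num_experts % ici_fsdp_parallelism == 0
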
17 changes: 16 additions & 1 deletion tests/attention_test.py
@@ -26,7 +26,7 @@

import numpy as np

-from jax.sharding import Mesh
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
import jax
import jax.numpy as jnp

@@ -1383,6 +1383,21 @@ def _forward_with_context_expert_parallelism(cfg_cp, mesh_cp, attention_cp, lnx,
decoder_positions = reordered_batch["inputs_position"]
# apply attention with sharding
with mesh_cp, nn_partitioning.axis_rules(cfg_cp.logical_axis_rules):
+lnx_spec = nn_partitioning.logical_to_mesh_axes(
+    ('activation_batch_no_exp', 'activation_length_no_exp', 'activation_embed'),
+    nn_partitioning.get_axis_rules()
+)
+pos_spec = nn_partitioning.logical_to_mesh_axes(
+    ('activation_batch_no_exp', 'activation_length_no_exp'),
+    nn_partitioning.get_axis_rules()
+)
+lnx_sharding = NamedSharding(mesh_cp, lnx_spec)
+pos_sharding = NamedSharding(mesh_cp, pos_spec)
+
+lnx = jax.device_put(lnx, lnx_sharding)
+decoder_segment_ids = jax.device_put(decoder_segment_ids, pos_sharding)
+decoder_positions = jax.device_put(decoder_positions, pos_sharding)
+
attention_cp_output = attention_cp(
lnx,
lnx,
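The same placement pattern in isolation, as a rough sketch under an assumed one-axis 'data' mesh (the test itself uses the CP/EP mesh and MaxText's logical axis rules):

# Standalone sketch: logical spec -> NamedSharding -> device_put. Assumptions:
# a 1-D 'data' mesh over whatever devices are available, batch divisible by it.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

batch = jnp.ones((devices.size * 2, 16))
sharding = NamedSharding(mesh, PartitionSpec("data", None))

# Place the host array on the mesh before calling the sharded attention function.
batch = jax.device_put(batch, sharding)
print(batch.sharding)
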
10 changes: 5 additions & 5 deletions tests/train_compile_test.py
@@ -493,7 +493,7 @@ def test_moe_deepseek_scanned_bf16(self):
"megablox=False",
"per_device_batch_size=2",
"max_target_length=1024",
"attention=dot_product", # Change to flash attention once it works for MLA
"attention=flash",
"dtype=bfloat16",
"weight_dtype=bfloat16",
"scan_layers=True",
@@ -518,7 +518,7 @@ def test_moe_deepseek_unscanned_bf16(self):
"megablox=False",
"per_device_batch_size=1",
"max_target_length=1024",
"attention=dot_product", # Change to flash attention once it works for MLA
"attention=flash",
"dtype=bfloat16",
"weight_dtype=bfloat16",
"scan_layers=False",
@@ -541,7 +541,7 @@ def test_moe_deepseek_with_device_limit(self):
"megablox=False",
"per_device_batch_size=1",
"max_target_length=1024",
"attention=dot_product", # Change to flash attention once it works for MLA
"attention=flash",
"dtype=bfloat16",
"weight_dtype=bfloat16",
"n_routing_groups=8",
@@ -565,7 +565,7 @@ def test_moe_deepseek_without_device_limit(self):
"megablox=False",
"per_device_batch_size=1",
"max_target_length=1024",
"attention=dot_product", # Change to flash attention once it works for MLA
"attention=flash",
"dtype=bfloat16",
"weight_dtype=bfloat16",
"n_routing_groups=-1",
@@ -585,7 +585,7 @@ def test_moe_deepseek_pipeline_subset(self):
"compile_topology_num_slices=8",
"use_iota_embed=true",
"model_name=deepseek3-671b",
"megablox=False", # dropless not yet supported (b/418313093)
"megablox=True",
"sparse_matmul=False",
"capacity_factor=1",
"per_device_batch_size=1",