diff --git a/.github/workflows/run_tests_internal.yml b/.github/workflows/run_tests_internal.yml
index 7fa9a72865..8132a3a13d 100644
--- a/.github/workflows/run_tests_internal.yml
+++ b/.github/workflows/run_tests_internal.yml
@@ -67,4 +67,4 @@ jobs:
             FINAL_PYTEST_MARKER="${{ inputs.pytest_marker }} and not scheduled_only"
           fi
           python3 -m pip install -e . --no-dependencies &&
-          python3 -m pytest -v -m "${FINAL_PYTEST_MARKER}" --durations=0
+          LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536' python3 -m pytest -v -m "${FINAL_PYTEST_MARKER}" --durations=0
diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml
index 472673af9f..3a1c7bf7d1 100644
--- a/src/MaxText/configs/base.yml
+++ b/src/MaxText/configs/base.yml
@@ -182,6 +182,8 @@ expert_shard_attention_option: "fsdp"
 
 # When MoE weight matrices are sharded on both FSDP and FSDP-transpose axes, use two separate All-Gather calls
 moe_fsdp_use_two_stage_all_gather: False
+# Shard the MoE weights on the num_experts dimension. This can be performant when num_experts % fsdp_parallelism == 0.
+fsdp_shard_on_exp: False
 
 # deepseek moe
 base_moe_mlp_dim: 7168 # intermediate dimension at MoE layer. For a fully MoE model, base_mlp_dim must be equal to base_moe_mlp_dim.
@@ -385,10 +387,12 @@ logical_axis_rules: [
   ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
   ['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
+  ['q_lora_up_proj', []],
   ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
   ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
+  ['kv_lora_up_proj', []],
   ['norm', ['tensor', 'tensor_transpose']],
   ['layers', 'stage'],
   ['kv', []],
@@ -405,6 +409,8 @@
   ['num_pages', []],
   ['tokens_per_page', []],
   ['paged_kv_head_dim_size', []],
+  ['dense_layers', []],
+  ['moe_layers', []],
 ]
 # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
 data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
diff --git a/src/MaxText/layers/attention_mla.py b/src/MaxText/layers/attention_mla.py
index 70ce3cb0c0..812df5ba73 100644
--- a/src/MaxText/layers/attention_mla.py
+++ b/src/MaxText/layers/attention_mla.py
@@ -398,7 +398,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
         out_features_shape=self.q_lora_rank,
         axis=-1,
         kernel_init=self.kernel_init,
-        kernel_axes=("embed", "q_lora"),
+        kernel_axes=("embed", "q_lora_up_proj"),
         dtype=self.dtype,
         weight_dtype=self.weight_dtype,
         quant=self.quant,
@@ -432,7 +432,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
         out_features_shape=self.kv_lora_rank + self.qk_rope_head_dim,
         axis=-1,
         kernel_init=self.kernel_init,
-        kernel_axes=("embed", "kv_lora"),
+        kernel_axes=("embed", "kv_lora_up_proj"),
         dtype=self.dtype,
         weight_dtype=self.weight_dtype,
         quant=self.quant,
diff --git a/src/MaxText/layers/moe.py b/src/MaxText/layers/moe.py
index 23eb545d7d..fbc1d0a830 100644
--- a/src/MaxText/layers/moe.py
+++ b/src/MaxText/layers/moe.py
@@ -300,8 +300,13 @@ def __init__(
     self.quant = quant
     self.rngs = rngs
 
-    self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
-    self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
+    if self.config.fsdp_shard_on_exp:
+      # Special sharding for DeepSeek-v3: shard the expert dimension on the FSDP ("embed_no_exp") axis.
+      self.wi_kernel_axes = ("embed_no_exp", None, "mlp")
+      self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
+    else:
+      self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
+      self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
 
     self.gate = GateLogit(
         in_features_shape=self.config.emb_dim,
@@ -427,6 +432,7 @@ def get_topk(self, gate_logits, pre_bias_logits, rngs=None):
 
     return top_k_weights, top_k_indices
 
+
   def deepseek_scale_weights(self, weights):
     """Scales weights according to DeepSeek's v3 reference implementation."""
     # https://github.com/deepseek-ai/DeepSeek-V3/blob/2f7b80eecebf3d1c84da5a0d465f6639ea175012/inference/model.py#L592-L594.
@@ -900,9 +906,15 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):
 
     # w0, w1, wo needs to be un sharded on fsdp / fsdp_transpose axis, so use
     # mlp_no_fsdp axis
-    w0_pspec = nn.logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
-    w1_pspec = nn.logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
-    wo_pspec = nn.logical_to_mesh_axes(("exp", "mlp_no_fsdp", "embed_tensor_transpose"))
+    if self.config.fsdp_shard_on_exp:
+      # Special sharding for DeepSeek-v3: shard the expert dim on FSDP to remove the overhead between gmm and the all-gather.
+      w0_pspec = nn.logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp"))
+      w1_pspec = nn.logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp"))
+      wo_pspec = nn.logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
+    else:
+      w0_pspec = nn.logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
+      w1_pspec = nn.logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
+      wo_pspec = nn.logical_to_mesh_axes(("exp", "mlp_no_fsdp", "embed_tensor_transpose"))
     if isinstance(w0_kernel, aqt.QTensor):
       w0_pspec = aqt.partition_spec(w0_pspec, (1,), w0_kernel.dtype, use_bias=False)
     if isinstance(w1_kernel, aqt.QTensor):
diff --git a/src/MaxText/pyconfig.py b/src/MaxText/pyconfig.py
index 05356e1996..6ed1aac1d1 100644
--- a/src/MaxText/pyconfig.py
+++ b/src/MaxText/pyconfig.py
@@ -219,6 +219,7 @@ def validate_keys(keys):
   validate_mlp_dim(keys)
   validate_sparse_matmul_parallelism(keys)
   validate_ring_of_experts_parallelism(keys)
+  validate_shard_fsdp_on_expert_parallelism(keys)
   validate_ragged_dot(keys)
   validate_deepseek_moe(keys)
   validate_expert_shard_attention_option(keys["expert_shard_attention_option"])
@@ -1049,6 +1050,11 @@ def validate_ring_of_experts_parallelism(raw_keys):
   if raw_keys["use_ring_of_experts"] and not using_expert_parallelism(raw_keys):
     raise ValueError("Ring-of-experts requires expert-parallelism to be enabled.")
 
+def validate_shard_fsdp_on_expert_parallelism(raw_keys):
+  if raw_keys["fsdp_shard_on_exp"] and raw_keys["num_experts"] % raw_keys["ici_fsdp_parallelism"] != 0:
+    raise ValueError("fsdp_shard_on_exp requires num_experts to be divisible by ici_fsdp_parallelism.")
+  if raw_keys["fsdp_shard_on_exp"] and (using_tensor_parallelism(raw_keys) or using_expert_parallelism(raw_keys)):
+    raise ValueError("fsdp_shard_on_exp requires ici_expert_parallelism = 1 and ici_tensor_parallelism/ici_tensor_transpose_parallelism = 1.")
 
 def validate_ragged_dot(raw_keys):
   if raw_keys["sparse_matmul"] and not raw_keys["megablox"]:
diff --git a/tests/attention_test.py b/tests/attention_test.py
index 183cdf96fe..3f9895d171 100644
--- a/tests/attention_test.py
+++ b/tests/attention_test.py
@@ -26,7 +26,7 @@
 
 import numpy as np
 
-from jax.sharding import Mesh
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
 
 import jax
 import jax.numpy as jnp
@@ -1383,6 +1383,21 @@ def _forward_with_context_expert_parallelism(cfg_cp, mesh_cp, attention_cp, lnx,
   decoder_positions = reordered_batch["inputs_position"]
   # apply attention with sharding
   with mesh_cp, nn_partitioning.axis_rules(cfg_cp.logical_axis_rules):
+    lnx_spec = nn_partitioning.logical_to_mesh_axes(
+        ('activation_batch_no_exp', 'activation_length_no_exp', 'activation_embed'),
+        nn_partitioning.get_axis_rules()
+    )
+    pos_spec = nn_partitioning.logical_to_mesh_axes(
+        ('activation_batch_no_exp', 'activation_length_no_exp'),
+        nn_partitioning.get_axis_rules()
+    )
+    lnx_sharding = NamedSharding(mesh_cp, lnx_spec)
+    pos_sharding = NamedSharding(mesh_cp, pos_spec)
+
+    lnx = jax.device_put(lnx, lnx_sharding)
+    decoder_segment_ids = jax.device_put(decoder_segment_ids, pos_sharding)
+    decoder_positions = jax.device_put(decoder_positions, pos_sharding)
+
     attention_cp_output = attention_cp(
         lnx,
         lnx,
diff --git a/tests/train_compile_test.py b/tests/train_compile_test.py
index 06498e8685..90dc92a832 100644
--- a/tests/train_compile_test.py
+++ b/tests/train_compile_test.py
@@ -493,7 +493,7 @@ def test_moe_deepseek_scanned_bf16(self):
         "megablox=False",
         "per_device_batch_size=2",
         "max_target_length=1024",
-        "attention=dot_product",  # Change to flash attention once it works for MLA
+        "attention=flash",
         "dtype=bfloat16",
         "weight_dtype=bfloat16",
         "scan_layers=True",
@@ -518,7 +518,7 @@ def test_moe_deepseek_unscanned_bf16(self):
         "megablox=False",
         "per_device_batch_size=1",
         "max_target_length=1024",
-        "attention=dot_product",  # Change to flash attention once it works for MLA
+        "attention=flash",
         "dtype=bfloat16",
         "weight_dtype=bfloat16",
         "scan_layers=False",
@@ -541,7 +541,7 @@ def test_moe_deepseek_with_device_limit(self):
         "megablox=False",
         "per_device_batch_size=1",
         "max_target_length=1024",
-        "attention=dot_product",  # Change to flash attention once it works for MLA
+        "attention=flash",
         "dtype=bfloat16",
         "weight_dtype=bfloat16",
         "n_routing_groups=8",
@@ -565,7 +565,7 @@ def test_moe_deepseek_without_device_limit(self):
         "megablox=False",
         "per_device_batch_size=1",
         "max_target_length=1024",
-        "attention=dot_product",  # Change to flash attention once it works for MLA
+        "attention=flash",
        "dtype=bfloat16",
         "weight_dtype=bfloat16",
         "n_routing_groups=-1",
@@ -585,7 +585,7 @@ def test_moe_deepseek_pipeline_subset(self):
         "compile_topology_num_slices=8",
         "use_iota_embed=true",
         "model_name=deepseek3-671b",
-        "megablox=False",  # dropless not yet supported (b/418313093)
+        "megablox=True",
         "sparse_matmul=False",
         "capacity_factor=1",
         "per_device_batch_size=1",
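
Not part of the patch above: a minimal, hypothetical sketch of what `fsdp_shard_on_exp` changes. With the flag off, the FSDP axis shards the embed dimension of the MoE kernels (the `("exp", "embed_no_exp", "mlp")` annotation); with it on, the FSDP axis shards the expert dimension instead (`("embed_no_exp", None, "mlp")`), which is why `num_experts` must be divisible by `ici_fsdp_parallelism`. The one-axis mesh, shapes, and axis names below are illustrative assumptions, not MaxText's actual mesh construction.

```python
# Illustrative sketch only -- a single "fsdp" mesh axis over whatever devices are available,
# and a dummy MoE wi kernel of shape (num_experts, emb_dim, mlp_dim).
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.asarray(jax.devices()), axis_names=("fsdp",))
fsdp = mesh.shape["fsdp"]

num_experts, emb_dim, mlp_dim = 8 * fsdp, 16 * fsdp, 32
wi = jnp.zeros((num_experts, emb_dim, mlp_dim))

# fsdp_shard_on_exp=False (default): FSDP shards the embed dimension,
# analogous to the ("exp", "embed_no_exp", "mlp") kernel annotation.
wi_default = jax.device_put(wi, NamedSharding(mesh, P(None, "fsdp", None)))

# fsdp_shard_on_exp=True: FSDP shards the expert dimension instead,
# analogous to ("embed_no_exp", None, "mlp") -- hence num_experts % fsdp == 0.
assert num_experts % fsdp == 0
wi_on_exp = jax.device_put(wi, NamedSharding(mesh, P("fsdp", None, None)))

print(wi_default.sharding.spec)  # expect the mesh axis on the embed dimension
print(wi_on_exp.sharding.spec)   # expect the mesh axis on the expert dimension
```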
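
In the same spirit, a small check of how `flax.linen.partitioning.logical_to_mesh_axes` (already used by the patch in `moe.py` and `attention_test.py`) resolves the two kernel annotations, under an assumed, reduced set of axis rules rather than the full `logical_axis_rules` from base.yml: with `fsdp_shard_on_exp` the resulting `PartitionSpec` places the FSDP mesh axis on the expert dimension and leaves the embed dimension unconstrained.

```python
# Illustrative sketch only -- these three rules are a stand-in for MaxText's logical_axis_rules.
from flax.linen import partitioning as nn_partitioning

rules = (
    ("exp", "expert"),
    ("embed_no_exp", "fsdp"),
    ("mlp", "tensor"),
)

with nn_partitioning.axis_rules(rules):
  # Default wi kernel annotation: expert dim -> 'expert', embed dim -> 'fsdp'.
  print(nn_partitioning.logical_to_mesh_axes(("exp", "embed_no_exp", "mlp")))
  # -> PartitionSpec('expert', 'fsdp', 'tensor')

  # fsdp_shard_on_exp=True annotation: expert dim -> 'fsdp', embed dim left unsharded.
  print(nn_partitioning.logical_to_mesh_axes(("embed_no_exp", None, "mlp")))
  # -> PartitionSpec('fsdp', None, 'tensor')
```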