diff --git a/src/MaxText/common_types.py b/src/MaxText/common_types.py
index fd7185a83f..4a3b5abc9f 100644
--- a/src/MaxText/common_types.py
+++ b/src/MaxText/common_types.py
@@ -32,8 +32,8 @@
 BATCH = "activation_batch"
 BATCH_NO_EXP = "activation_batch_no_exp"
+LENGTH_WITH_EXP = "activation_length_with_exp"
 LENGTH = "activation_length"
-LENGTH_NO_EXP = "activation_length_no_exp"
 PREFILL_LENGTH = "prefill_activation_length"
 Q_LENGTH = "activation_q_length"
 Q_LENGTH_NO_EXP = "activation_q_length_no_exp"
diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml
index 4d5510d115..c5cae68d3a 100644
--- a/src/MaxText/configs/base.yml
+++ b/src/MaxText/configs/base.yml
@@ -337,10 +337,10 @@ logical_axis_rules: [
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
   ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
-  ['activation_length', ['sequence', 'context', 'expert']],
-  ['activation_length', ['context', 'expert']],
-  ['activation_length_no_exp', ['sequence', 'context']],
-  ['activation_length_no_exp', ['context']],
+  ['activation_length_with_exp', ['sequence', 'context', 'expert']],
+  ['activation_length_with_exp', ['context', 'expert']],
+  ['activation_length', ['sequence', 'context']],
+  ['activation_length', ['context']],
   ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
   ['activation_q_length', ['context', 'expert']],
   ['activation_q_length_no_exp', ['context']],
diff --git a/src/MaxText/layers/attention_mla.py b/src/MaxText/layers/attention_mla.py
index e7cb3652c6..fcbf2e26ab 100644
--- a/src/MaxText/layers/attention_mla.py
+++ b/src/MaxText/layers/attention_mla.py
@@ -42,8 +42,8 @@
     KV_BATCH_NO_EXP,
     KV_HEAD,
     KV_HEAD_DIM,
+    LENGTH_WITH_EXP,
     LENGTH,
-    LENGTH_NO_EXP,
     MODEL_MODE_PREFILL,
     MODEL_MODE_TRAIN,
     PREFILL_KV_BATCH,
@@ -100,16 +100,16 @@ def mla_as_linen(
     prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
     prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
     prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
-    query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-    key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-    value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-    ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-    ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-    ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-    input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED),
-    ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED),
-    out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV),
-    ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV),
+    query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+    key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+    value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+    ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+    ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+    ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+    input_axis_names: AxisNames = (BATCH, LENGTH, EMBED),
+    ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH_WITH_EXP, EMBED),
+    out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
+    ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH_WITH_EXP, HEAD, D_KV),
     prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
     decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
     prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
@@ -248,16 +248,16 @@ def __init__(
       prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
       prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
       prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
-      query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-      key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-      value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-      ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-      ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-      ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-      input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED),
-      ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED),
-      out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV),
-      ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV),
+      query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+      key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+      value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+      ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+      ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+      ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+      input_axis_names: AxisNames = (BATCH, LENGTH, EMBED),
+      ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH_WITH_EXP, EMBED),
+      out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
+      ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH_WITH_EXP, HEAD, D_KV),
       prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
       decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
       prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
@@ -647,9 +647,7 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
     if self.config.mla_naive_kvcache:
       cached_values = self.update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk)
     else:
-      cached_values = self.update_mla_kv_caches(
-          low_rank_main, key_rope, decoder_segment_ids, model_mode, previous_chunk
-      )
+      cached_values = self.update_mla_kv_caches(low_rank_main, key_rope, decoder_segment_ids, model_mode, previous_chunk)
 
     return key, value, cached_values
 
diff --git a/src/MaxText/layers/attention_op.py b/src/MaxText/layers/attention_op.py
index 75cc5cf3cb..351145bcc5 100644
--- a/src/MaxText/layers/attention_op.py
+++ b/src/MaxText/layers/attention_op.py
@@ -57,8 +57,8 @@
     CACHE_SCALE_HEADS,
     CACHE_SCALE_KV,
     AxisIdxes,
+    LENGTH_WITH_EXP,
     LENGTH,
-    LENGTH_NO_EXP,
     DType,
     Config,
     Array,
@@ -290,12 +290,12 @@ def attention_op_as_linen(
     float32_qk_product: bool = False,
     max_prefill_predict_length: int = -1,
     float32_logits: bool = False,
-    flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH_NO_EXP, D_KV),
-    flash_axis_names_q_ep: AxisNames = (BATCH_NO_EXP, HEAD, LENGTH, D_KV),
+    flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
+    flash_axis_names_q_ep: AxisNames = (BATCH_NO_EXP, HEAD, LENGTH_WITH_EXP, D_KV),
     flash_axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV),
     flash_axis_names_kv_ep: AxisNames = (BATCH_NO_EXP, HEAD, KV_LENGTH, D_KV),
-    flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH_NO_EXP),
-    flash_axis_names_splash_kernel_ep: AxisNames = (HEAD, LENGTH),
+    flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH),
+    flash_axis_names_splash_kernel_ep: AxisNames = (HEAD, LENGTH_WITH_EXP),
     prefill_cache_logical_axis_names: AxisNames = (CACHE_BATCH_PREFILL, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV),
     cache_logical_axis_names: AxisNames = (CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV),
     cache_scale_logical_axis_names: AxisNames = (
@@ -378,12 +378,12 @@ def __init__(
       float32_qk_product: bool = False,
       max_prefill_predict_length: int = -1,
      float32_logits: bool = False,
-      flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH_NO_EXP, D_KV),
-      flash_axis_names_q_ep: AxisNames = (BATCH_NO_EXP, HEAD, LENGTH, D_KV),
+      flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
+      flash_axis_names_q_ep: AxisNames = (BATCH_NO_EXP, HEAD, LENGTH_WITH_EXP, D_KV),
       flash_axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV),
       flash_axis_names_kv_ep: AxisNames = (BATCH_NO_EXP, HEAD, KV_LENGTH, D_KV),
-      flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH_NO_EXP),
-      flash_axis_names_splash_kernel_ep: AxisNames = (HEAD, LENGTH),
+      flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH),
+      flash_axis_names_splash_kernel_ep: AxisNames = (HEAD, LENGTH_WITH_EXP),
       prefill_cache_logical_axis_names: AxisNames = (CACHE_BATCH_PREFILL, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV),
       cache_logical_axis_names: AxisNames = (CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV),
       cache_scale_logical_axis_names: AxisNames = (
diff --git a/src/MaxText/layers/attentions.py b/src/MaxText/layers/attentions.py
index 40698320d8..017101afa8 100644
--- a/src/MaxText/layers/attentions.py
+++ b/src/MaxText/layers/attentions.py
@@ -35,8 +35,8 @@
     D_KV,
     AxisNames,
     AxisIdxes,
+    LENGTH_WITH_EXP,
     LENGTH,
-    LENGTH_NO_EXP,
     DType,
     Config,
     Array,
@@ -139,16 +139,16 @@ def attention_as_linen(
     prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
     prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
     prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
-    query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-    key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-    value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-    ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-    ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-    ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-    input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED),
-    ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED),
-    out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV),
-    ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV),
+    query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+    key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+    value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+    ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+    ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+    ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+    input_axis_names: AxisNames = (BATCH, LENGTH, EMBED),
+    ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH_WITH_EXP, EMBED),
+    out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
+    ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH_WITH_EXP, HEAD, D_KV),
     prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
     decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
     prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
@@ -298,16 +298,16 @@ def __init__(
       prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
       prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
       prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
-      query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-      key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-      value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-      ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-      ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-      ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-      input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED),
-      ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED),
-      out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV),
-      ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV),
+      query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+      key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+      value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+      ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+      ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+      ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+      input_axis_names: AxisNames = (BATCH, LENGTH, EMBED),
+      ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH_WITH_EXP, EMBED),
+      out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
+      ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH_WITH_EXP, HEAD, D_KV),
       prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
       decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
       prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
diff --git a/src/MaxText/layers/moe.py b/src/MaxText/layers/moe.py
index dfcaf0992a..35b67fea92 100644
--- a/src/MaxText/layers/moe.py
+++ b/src/MaxText/layers/moe.py
@@ -856,19 +856,19 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):
     batch_logical_axis = "activation_batch_no_exp"
 
     if self.get_tensor_transpose_parallelism_size() > 1:
-      input_partition_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed"))
+      input_partition_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_length", "activation_embed"))
       w0_bias_pspec = nn.logical_to_mesh_axes(("exp", None))
       w1_bias_pspec = nn.logical_to_mesh_axes(("exp", None))
       wo_bias_pspec = nn.logical_to_mesh_axes(("exp", "activation_embed"))
     else:
-      input_partition_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+      input_partition_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_length", None))
       w0_bias_pspec = nn.logical_to_mesh_axes(("exp", "activation_mlp"))
       w1_bias_pspec = nn.logical_to_mesh_axes(("exp", "activation_mlp"))
       wo_bias_pspec = nn.logical_to_mesh_axes(("exp", "activation_embed"))
 
-    gate_logits_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+    gate_logits_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_length", None))
     if self.config.model_name.startswith("deepseek3"):
-      pre_bias_logits_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+      pre_bias_logits_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_length", None))
     else:
       # pre_bias_logits is None for non-DeepSeek v3 models
       pre_bias_logits_pspec = None
@@ -896,7 +896,7 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):
             w1_bias_pspec,
             wo_bias_pspec,
         ),
-        out_specs=(nn.logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed"))),
+        out_specs=(nn.logical_to_mesh_axes((batch_logical_axis, "activation_length", "activation_embed"))),
         check_rep=False,
     )
     def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias):
@@ -1135,7 +1135,7 @@ def generate_masks_subgroup(self, top_k_indices, softmax_probs):
     )
     expert_token_count = nn.with_logical_constraint(
         expert_token_count,
-        ("activation_batch", "activation_norm_length", None, None, None),
+        ("activation_batch", "activation_length", None, None, None),
     )
     trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
     combined_expert_mask = jnp.sum(trunc_expert_mask, axis=3)
@@ -1223,7 +1223,7 @@ def generate_masks(self, top_k_indices, softmax_probs):
     )
     expert_token_count = nn.with_logical_constraint(
         expert_token_count,
-        ("activation_batch", "activation_norm_length", None, None),
+        ("activation_batch", "activation_length", None, None),
     )
     trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
     combined_expert_mask = jnp.sum(trunc_expert_mask, axis=2)
@@ -1321,10 +1321,10 @@ def dense_matmul(
   ) -> tuple[jax.Array, Optional[jax.Array]]:
     """Dense matrix multiplication."""
     # gate_logits: batch, length, expert
-    gate_logits = nn.with_logical_constraint(gate_logits, ("activation_batch", "activation_norm_length", None))
+    gate_logits = nn.with_logical_constraint(gate_logits, ("activation_batch", "activation_length", None))
     if self.config.model_name.startswith("deepseek3"):
       # pre_bias_logits is None for non-DeepSeek v3 models
-      pre_bias_logits = nn.with_logical_constraint(pre_bias_logits, ("activation_batch", "activation_norm_length", None))
+      pre_bias_logits = nn.with_logical_constraint(pre_bias_logits, ("activation_batch", "activation_length", None))
     top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits)
     is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4
     if is_llama4_decoder_layer:
@@ -1351,7 +1351,7 @@ def dense_matmul(
       dispatch_mask, combine_mask = self.generate_masks(
          top_k_indices, weights  # pylint: disable=undefined-variable,possibly-used-before-assignment
       )
-      mask_axes = ("activation_batch", "activation_norm_length", None, None)
+      mask_axes = ("activation_batch", "activation_length", None, None)
       dispatch_axis = (
           "activation_exp",
          "activation_batch_no_exp",
@@ -1375,14 +1375,14 @@ def dense_matmul(
       dispatch_mask, combine_mask = self.generate_masks_subgroup(top_k_indices, softmax_probs)
       if self.get_context_autoregressive_parallelism_size() > 0 and cp == 1:
         mask_axes = (
-            "activation_norm_length",
+            "activation_length",
             "activation_batch",
             None,
             None,
             None,
         )
         input_axis = (
-            "activation_norm_length",
+            "activation_length",
             "activation_batch",
             None,
             "activation_embed",
@@ -1404,14 +1404,14 @@ def dense_matmul(
       else:
         mask_axes = (
             "activation_batch",
-            "activation_norm_length",
+            "activation_length",
             None,
             None,
             None,
         )
         input_axis = (
             "activation_batch",
-            "activation_norm_length",
+            "activation_length",
             None,
             "activation_embed",
         )
@@ -1451,7 +1451,7 @@ def dense_matmul(
         (
             None,
             "activation_batch_no_exp",
-            "activation_norm_length",
+            "activation_length",
             None,
             "activation_embed",
         ),
@@ -1540,7 +1540,7 @@ def dense_matmul(
       )
       return output, loss
     else:
-      inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
+      inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed"))
       with jax.named_scope("wi_0"):
         layer_w0 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)(
             "BSM,EMH -> BSEH", inputs, w0_kernel, precision=matmul_precision
diff --git a/src/MaxText/layers/qwen3.py b/src/MaxText/layers/qwen3.py
index 74bdfdce41..e10066403a 100644
--- a/src/MaxText/layers/qwen3.py
+++ b/src/MaxText/layers/qwen3.py
@@ -60,7 +60,7 @@ def self_attention_with_norm(
       epsilon=cfg.normalization_layer_epsilon,
       kernel_axes=("norm",),
   )(inputs_checkpoint)
-  lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed"))
+  lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
 
   # Self-attention block
   attention_layer = attentions.attention_as_linen(
@@ -94,7 +94,7 @@ def self_attention_with_norm(
       model_mode=model_mode,
   )
   attention_output = nn.with_logical_constraint(
-      attention_output, ("activation_batch", "activation_length", "activation_embed")
+      attention_output, ("activation_batch", "activation_norm_length", "activation_embed")
   )
 
   # Residual connection after attention
@@ -109,7 +109,9 @@ def self_attention_with_norm(
       epsilon=cfg.normalization_layer_epsilon,
       kernel_axes=("norm",),
   )(residual_after_attention)
-  hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed"))
+  hidden_states = nn.with_logical_constraint(
+      hidden_states, ("activation_batch", "activation_norm_length", "activation_embed")
+  )
 
   return hidden_states, residual_after_attention
 
@@ -167,7 +169,7 @@ def __call__(
     layer_output = residual_after_attention + mlp_output
     layer_output = nn.with_logical_constraint(
         layer_output,
-        ("activation_batch", "activation_length", "activation_embed"),
+        ("activation_batch", "activation_norm_length", "activation_embed"),
     )
 
     if cfg.scan_layers:
@@ -230,13 +232,13 @@ def __call__(
     if load_balance_loss is not None:
       self.sow("intermediates", "moe_lb_loss", load_balance_loss)
 
-    mlp_output = nn.with_logical_constraint(mlp_output, ("activation_batch", "activation_length", "activation_embed"))
+    mlp_output = nn.with_logical_constraint(mlp_output, ("activation_batch", "activation_norm_length", "activation_embed"))
 
     # Final residual connection
     layer_output = residual_after_attention + mlp_output
     layer_output = nn.with_logical_constraint(
         layer_output,
-        ("activation_batch", "activation_length", "activation_embed"),
+        ("activation_batch", "activation_norm_length", "activation_embed"),
     )
 
     if cfg.scan_layers:
diff --git a/src/MaxText/pyconfig.py b/src/MaxText/pyconfig.py
index 09a32af1dd..e4a06d14dd 100644
--- a/src/MaxText/pyconfig.py
+++ b/src/MaxText/pyconfig.py
@@ -1177,12 +1177,12 @@ def using_tensor_parallelism(raw_keys) -> bool:
 
 
 def using_sequence_parallelism(raw_keys) -> bool:
-  if int(raw_keys["ici_expert_parallelism"]) > 1 and int(raw_keys["dcn_expert_parallelism"]) > 1:
-    raise ValueError("Expert parallelism can only be enabled on ICI or DCN, not both.")
   return int(raw_keys["ici_sequence_parallelism"]) > 1 or int(raw_keys["dcn_sequence_parallelism"]) > 1
 
 
 def using_expert_parallelism(raw_keys) -> bool:
+  if int(raw_keys["ici_expert_parallelism"]) > 1 and int(raw_keys["dcn_expert_parallelism"]) > 1:
+    raise ValueError("Expert parallelism can only be enabled on ICI or DCN, not both.")
   return int(raw_keys["ici_expert_parallelism"]) > 1 or int(raw_keys["dcn_expert_parallelism"]) > 1
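
Reviewer note (illustrative, not part of the patch): a minimal sketch of how logical axis names such as the renamed activation_length / activation_length_with_exp resolve to mesh axes through rules shaped like the base.yml entries above, using Flax's logical-axis API. The rule list below is trimmed for illustration and is not the full MaxText configuration.

import flax.linen as nn

# Trimmed rules in the post-rename scheme: plain "activation_length" is never
# sharded over the "expert" mesh axis; the expert-sharded variant is explicit.
rules = (
    ("activation_batch", ("data", "fsdp", "expert")),
    ("activation_batch_no_exp", ("data", "fsdp")),
    ("activation_length", ("sequence", "context")),
    ("activation_length_with_exp", ("sequence", "context", "expert")),
    ("activation_embed", "tensor"),
)

with nn.logical_axis_rules(rules):
  # Default (non-EP) attention activations: "expert" shards the batch axis.
  print(nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed")))
  # Expert-parallel variant: "expert" moves from the batch axis to the length axis.
  print(nn.logical_to_mesh_axes(("activation_batch_no_exp", "activation_length_with_exp", "activation_embed")))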