2 changes: 1 addition & 1 deletion src/MaxText/common_types.py
@@ -32,8 +32,8 @@

BATCH = "activation_batch"
BATCH_NO_EXP = "activation_batch_no_exp"
+LENGTH_WITH_EXP = "activation_length_with_exp"
LENGTH = "activation_length"
-LENGTH_NO_EXP = "activation_length_no_exp"
PREFILL_LENGTH = "prefill_activation_length"
Q_LENGTH = "activation_q_length"
Q_LENGTH_NO_EXP = "activation_q_length_no_exp"
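
Note (not part of the diff): the rename swaps the roles of the length constants rather than only adding one. The expert-sharded logical axis moves from LENGTH to the new LENGTH_WITH_EXP, while LENGTH takes over the role of the removed LENGTH_NO_EXP. A minimal summary of that mapping, illustrative only:

```python
# How the old common_types length constants map onto the new ones
# (roles, not string values; "activation_length" keeps its name but
# loses the 'expert' mesh axis in base.yml).
OLD_TO_NEW_ROLE = {
    "LENGTH": "LENGTH_WITH_EXP",   # expert-sharded length axis
    "LENGTH_NO_EXP": "LENGTH",     # length axis without expert sharding
}
```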
8 changes: 4 additions & 4 deletions src/MaxText/configs/base.yml
@@ -337,10 +337,10 @@ logical_axis_rules: [
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
-['activation_length', ['sequence', 'context', 'expert']],
-['activation_length', ['context', 'expert']],
-['activation_length_no_exp', ['sequence', 'context']],
-['activation_length_no_exp', ['context']],
+['activation_length_with_exp', ['sequence', 'context', 'expert']],
+['activation_length_with_exp', ['context', 'expert']],
+['activation_length', ['sequence', 'context']],
+['activation_length', ['context']],
['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
['activation_q_length', ['context', 'expert']],
['activation_q_length_no_exp', ['context']],
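
Note (not part of the diff): under the updated rules, plain `activation_length` no longer claims the `expert` mesh axis; the expert-sharded variant is now `activation_length_with_exp`. A minimal sketch of how these rules resolve through `nn.logical_to_mesh_axes` (the same helper moe.py calls); the rule list below is a trimmed, partly assumed copy of base.yml, with only the length entries taken from this diff:

```python
import flax.linen as nn

# Trimmed, partly assumed rule set for illustration.
rules = (
    ("activation_length_with_exp", ("sequence", "context", "expert")),
    ("activation_length_with_exp", ("context", "expert")),
    ("activation_length", ("sequence", "context")),
    ("activation_length", ("context",)),
    ("activation_embed", "tensor"),
)

# New behaviour: the length dim resolves without 'expert'...
print(nn.logical_to_mesh_axes(("activation_length", "activation_embed"), rules))
# PartitionSpec(('sequence', 'context'), 'tensor')

# ...while the _with_exp variant still carries it (expert-parallel path).
print(nn.logical_to_mesh_axes(("activation_length_with_exp", "activation_embed"), rules))
# PartitionSpec(('sequence', 'context', 'expert'), 'tensor')
```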
46 changes: 22 additions & 24 deletions src/MaxText/layers/attention_mla.py
@@ -42,8 +42,8 @@
KV_BATCH_NO_EXP,
KV_HEAD,
KV_HEAD_DIM,
+LENGTH_WITH_EXP,
LENGTH,
-LENGTH_NO_EXP,
MODEL_MODE_PREFILL,
MODEL_MODE_TRAIN,
PREFILL_KV_BATCH,
@@ -100,16 +100,16 @@ def mla_as_linen(
prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
-query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED),
-ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED),
-out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV),
-ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV),
+query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+input_axis_names: AxisNames = (BATCH, LENGTH, EMBED),
+ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH_WITH_EXP, EMBED),
+out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
+ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH_WITH_EXP, HEAD, D_KV),
prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
@@ -248,16 +248,16 @@ def __init__(
prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
-query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED),
-ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED),
-out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV),
-ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV),
+query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+input_axis_names: AxisNames = (BATCH, LENGTH, EMBED),
+ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH_WITH_EXP, EMBED),
+out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
+ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH_WITH_EXP, HEAD, D_KV),
prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
@@ -647,9 +647,7 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
if self.config.mla_naive_kvcache:
cached_values = self.update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk)
else:
-cached_values = self.update_mla_kv_caches(
-low_rank_main, key_rope, decoder_segment_ids, model_mode, previous_chunk
-)
+cached_values = self.update_mla_kv_caches(low_rank_main, key_rope, decoder_segment_ids, model_mode, previous_chunk)

return key, value, cached_values

18 changes: 9 additions & 9 deletions src/MaxText/layers/attention_op.py
@@ -57,8 +57,8 @@
CACHE_SCALE_HEADS,
CACHE_SCALE_KV,
AxisIdxes,
+LENGTH_WITH_EXP,
LENGTH,
-LENGTH_NO_EXP,
DType,
Config,
Array,
@@ -290,12 +290,12 @@ def attention_op_as_linen(
float32_qk_product: bool = False,
max_prefill_predict_length: int = -1,
float32_logits: bool = False,
-flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH_NO_EXP, D_KV),
-flash_axis_names_q_ep: AxisNames = (BATCH_NO_EXP, HEAD, LENGTH, D_KV),
+flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
+flash_axis_names_q_ep: AxisNames = (BATCH_NO_EXP, HEAD, LENGTH_WITH_EXP, D_KV),
flash_axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV),
flash_axis_names_kv_ep: AxisNames = (BATCH_NO_EXP, HEAD, KV_LENGTH, D_KV),
-flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH_NO_EXP),
-flash_axis_names_splash_kernel_ep: AxisNames = (HEAD, LENGTH),
+flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH),
+flash_axis_names_splash_kernel_ep: AxisNames = (HEAD, LENGTH_WITH_EXP),
prefill_cache_logical_axis_names: AxisNames = (CACHE_BATCH_PREFILL, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV),
cache_logical_axis_names: AxisNames = (CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV),
cache_scale_logical_axis_names: AxisNames = (
@@ -378,12 +378,12 @@ def __init__(
float32_qk_product: bool = False,
max_prefill_predict_length: int = -1,
float32_logits: bool = False,
-flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH_NO_EXP, D_KV),
-flash_axis_names_q_ep: AxisNames = (BATCH_NO_EXP, HEAD, LENGTH, D_KV),
+flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
+flash_axis_names_q_ep: AxisNames = (BATCH_NO_EXP, HEAD, LENGTH_WITH_EXP, D_KV),
flash_axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV),
flash_axis_names_kv_ep: AxisNames = (BATCH_NO_EXP, HEAD, KV_LENGTH, D_KV),
-flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH_NO_EXP),
-flash_axis_names_splash_kernel_ep: AxisNames = (HEAD, LENGTH),
+flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH),
+flash_axis_names_splash_kernel_ep: AxisNames = (HEAD, LENGTH_WITH_EXP),
prefill_cache_logical_axis_names: AxisNames = (CACHE_BATCH_PREFILL, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV),
cache_logical_axis_names: AxisNames = (CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV),
cache_scale_logical_axis_names: AxisNames = (
42 changes: 21 additions & 21 deletions src/MaxText/layers/attentions.py
@@ -35,8 +35,8 @@
D_KV,
AxisNames,
AxisIdxes,
+LENGTH_WITH_EXP,
LENGTH,
-LENGTH_NO_EXP,
DType,
Config,
Array,
@@ -139,16 +139,16 @@ def attention_as_linen(
prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
-query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED),
-ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED),
-out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV),
-ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV),
+query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+input_axis_names: AxisNames = (BATCH, LENGTH, EMBED),
+ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH_WITH_EXP, EMBED),
+out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
+ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH_WITH_EXP, HEAD, D_KV),
prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
@@ -298,16 +298,16 @@ def __init__(
prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
-query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED),
-ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED),
-out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV),
-ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV),
+query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
+ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH_WITH_EXP, KV_HEAD, KV_HEAD_DIM),
+input_axis_names: AxisNames = (BATCH, LENGTH, EMBED),
+ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH_WITH_EXP, EMBED),
+out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
+ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH_WITH_EXP, HEAD, D_KV),
prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
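
Note (not part of the diff): with the renamed constants, the regular and EP attention defaults differ only in where the `expert` mesh axis lands. The regular path keeps it on the batch dimension (BATCH, LENGTH, EMBED); the EP path drops it from the batch and picks it up on the sequence dimension (BATCH_NO_EXP, LENGTH_WITH_EXP, EMBED). A small check of that invariant; the batch rules below are assumptions, only the length rules come from this diff:

```python
import flax.linen as nn

# Assumed, simplified rules: batch entries are illustrative placeholders,
# length entries follow the updated base.yml in this PR.
rules = (
    ("activation_batch", ("data", "expert")),
    ("activation_batch_no_exp", ("data",)),
    ("activation_length_with_exp", ("sequence", "context", "expert")),
    ("activation_length", ("sequence", "context")),
    ("activation_embed", "tensor"),
)

regular = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"), rules)
ep = nn.logical_to_mesh_axes(("activation_batch_no_exp", "activation_length_with_exp", "activation_embed"), rules)

def mesh_axes(spec):
  # Flatten a PartitionSpec into the set of mesh axes it uses.
  out = set()
  for entry in spec:
    if entry is None:
      continue
    out.update((entry,) if isinstance(entry, str) else entry)
  return out

# 'expert' appears in both specs, just on different dimensions.
assert "expert" in mesh_axes(regular) and "expert" in mesh_axes(ep)
print(regular)  # expert on the batch dim
print(ep)       # expert on the length dim
```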
32 changes: 16 additions & 16 deletions src/MaxText/layers/moe.py
@@ -856,19 +856,19 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):
batch_logical_axis = "activation_batch_no_exp"

if self.get_tensor_transpose_parallelism_size() > 1:
-input_partition_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed"))
+input_partition_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_length", "activation_embed"))
w0_bias_pspec = nn.logical_to_mesh_axes(("exp", None))
w1_bias_pspec = nn.logical_to_mesh_axes(("exp", None))
wo_bias_pspec = nn.logical_to_mesh_axes(("exp", "activation_embed"))
else:
-input_partition_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+input_partition_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_length", None))
w0_bias_pspec = nn.logical_to_mesh_axes(("exp", "activation_mlp"))
w1_bias_pspec = nn.logical_to_mesh_axes(("exp", "activation_mlp"))
wo_bias_pspec = nn.logical_to_mesh_axes(("exp", "activation_embed"))

-gate_logits_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+gate_logits_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_length", None))
if self.config.model_name.startswith("deepseek3"):
-pre_bias_logits_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+pre_bias_logits_pspec = nn.logical_to_mesh_axes((batch_logical_axis, "activation_length", None))
else:
# pre_bias_logits is None for non-DeepSeek v3 models
pre_bias_logits_pspec = None
@@ -896,7 +896,7 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):
w1_bias_pspec,
wo_bias_pspec,
),
out_specs=(nn.logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed"))),
out_specs=(nn.logical_to_mesh_axes((batch_logical_axis, "activation_length", "activation_embed"))),
check_rep=False,
)
def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias):
@@ -1135,7 +1135,7 @@ def generate_masks_subgroup(self, top_k_indices, softmax_probs):
)
expert_token_count = nn.with_logical_constraint(
expert_token_count,
("activation_batch", "activation_norm_length", None, None, None),
("activation_batch", "activation_length", None, None, None),
)
trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
combined_expert_mask = jnp.sum(trunc_expert_mask, axis=3)
@@ -1223,7 +1223,7 @@ def generate_masks(self, top_k_indices, softmax_probs):
)
expert_token_count = nn.with_logical_constraint(
expert_token_count,
("activation_batch", "activation_norm_length", None, None),
("activation_batch", "activation_length", None, None),
)
trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
combined_expert_mask = jnp.sum(trunc_expert_mask, axis=2)
@@ -1321,10 +1321,10 @@ def dense_matmul(
) -> tuple[jax.Array, Optional[jax.Array]]:
"""Dense matrix multiplication."""
# gate_logits: batch, length, expert
-gate_logits = nn.with_logical_constraint(gate_logits, ("activation_batch", "activation_norm_length", None))
+gate_logits = nn.with_logical_constraint(gate_logits, ("activation_batch", "activation_length", None))
if self.config.model_name.startswith("deepseek3"):
# pre_bias_logits is None for non-DeepSeek v3 models
-pre_bias_logits = nn.with_logical_constraint(pre_bias_logits, ("activation_batch", "activation_norm_length", None))
+pre_bias_logits = nn.with_logical_constraint(pre_bias_logits, ("activation_batch", "activation_length", None))
top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits)
is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4
if is_llama4_decoder_layer:
@@ -1351,7 +1351,7 @@ def dense_matmul(
dispatch_mask, combine_mask = self.generate_masks(
top_k_indices, weights # pylint: disable=undefined-variable,possibly-used-before-assignment
)
mask_axes = ("activation_batch", "activation_norm_length", None, None)
mask_axes = ("activation_batch", "activation_length", None, None)
dispatch_axis = (
"activation_exp",
"activation_batch_no_exp",
Expand All @@ -1375,14 +1375,14 @@ def dense_matmul(
dispatch_mask, combine_mask = self.generate_masks_subgroup(top_k_indices, softmax_probs)
if self.get_context_autoregressive_parallelism_size() > 0 and cp == 1:
mask_axes = (
"activation_norm_length",
"activation_length",
"activation_batch",
None,
None,
None,
)
input_axis = (
"activation_norm_length",
"activation_length",
"activation_batch",
None,
"activation_embed",
@@ -1404,14 +1404,14 @@ def dense_matmul(
else:
mask_axes = (
"activation_batch",
"activation_norm_length",
"activation_length",
None,
None,
None,
)
input_axis = (
"activation_batch",
"activation_norm_length",
"activation_length",
None,
"activation_embed",
)
@@ -1451,7 +1451,7 @@ def dense_matmul(
(
None,
"activation_batch_no_exp",
"activation_norm_length",
"activation_length",
None,
"activation_embed",
),
@@ -1540,7 +1540,7 @@ def dense_matmul(
)
return output, loss
else:
-inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
+inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed"))
with jax.named_scope("wi_0"):
layer_w0 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)(
"BSM,EMH -> BSEH", inputs, w0_kernel, precision=matmul_precision
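
Note (not part of the diff): in moe.py the constraint strings move from `activation_norm_length` to the renamed `activation_length`, including the specs fed to the `shard_map`-wrapped grouped-matmul path. A minimal, self-contained sketch of that pattern; the single-device mesh, toy shapes, and rule set are assumptions, and the real wrapper also passes kernels and biases:

```python
import functools
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.sharding import Mesh
from jax.experimental.shard_map import shard_map

# Assumed trimmed rules; 'activation_length' no longer maps to 'expert'.
rules = (
    ("activation_batch_no_exp", ("data",)),
    ("activation_length", ("context",)),
    ("activation_embed", "tensor"),
)

# Toy 1x1x1 mesh so the sketch runs on a single host.
devices = np.array(jax.devices()[:1]).reshape(1, 1, 1)
mesh = Mesh(devices, ("data", "context", "tensor"))

in_spec = nn.logical_to_mesh_axes(
    ("activation_batch_no_exp", "activation_length", "activation_embed"), rules
)

@functools.partial(
    shard_map, mesh=mesh, in_specs=(in_spec,), out_specs=in_spec, check_rep=False
)
def block(x):
  # Stand-in for the grouped matmul body; just passes activations through.
  return x

x = jnp.ones((2, 4, 8))
print(block(x).shape)  # (2, 4, 8)
```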