From 66132ab4d5e53bbc1db78f26dfd84b9766471d8a Mon Sep 17 00:00:00 2001
From: Surbhi Jain
Date: Fri, 19 Sep 2025 18:11:23 +0000
Subject: [PATCH] Remove sharding rules for q_lora and kv_lora from base.yml

---
 src/MaxText/configs/base.yml        | 4 ++++
 src/MaxText/layers/attention_mla.py | 4 ++--
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml
index 472673af9f..705eb91719 100644
--- a/src/MaxText/configs/base.yml
+++ b/src/MaxText/configs/base.yml
@@ -385,10 +385,12 @@ logical_axis_rules: [
   ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
   ['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
+  ['q_lora_up_proj', []],
   ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
   ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
+  ['kv_lora_up_proj', []],
   ['norm', ['tensor', 'tensor_transpose']],
   ['layers', 'stage'],
   ['kv', []],
@@ -405,6 +407,8 @@ logical_axis_rules: [
   ['num_pages', []],
   ['tokens_per_page', []],
   ['paged_kv_head_dim_size', []],
+  ['dense_layers', []],
+  ['moe_layers', []],
 ]
 # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
 data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
diff --git a/src/MaxText/layers/attention_mla.py b/src/MaxText/layers/attention_mla.py
index 70ce3cb0c0..812df5ba73 100644
--- a/src/MaxText/layers/attention_mla.py
+++ b/src/MaxText/layers/attention_mla.py
@@ -398,7 +398,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
         out_features_shape=self.q_lora_rank,
         axis=-1,
         kernel_init=self.kernel_init,
-        kernel_axes=("embed", "q_lora"),
+        kernel_axes=("embed", "q_lora_up_proj"),
         dtype=self.dtype,
         weight_dtype=self.weight_dtype,
         quant=self.quant,
@@ -432,7 +432,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
         out_features_shape=self.kv_lora_rank + self.qk_rope_head_dim,
         axis=-1,
         kernel_init=self.kernel_init,
-        kernel_axes=("embed", "kv_lora"),
+        kernel_axes=("embed", "kv_lora_up_proj"),
         dtype=self.dtype,
         weight_dtype=self.weight_dtype,
         quant=self.quant,
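
Reviewer note (not part of the patch): the new empty rules map any kernel dimension tagged q_lora_up_proj or kv_lora_up_proj to no mesh axes, so the MLA up-projection kernels are left unsharded along that dimension instead of inheriting the q_lora/kv_lora sharding. A minimal sketch of that resolution follows, using the public Flax helper nn.logical_to_mesh_axes; the rule list below is a trimmed, hypothetical stand-in for base.yml, not the full MaxText rule set.

    # Illustrative sketch only; the two rules here are assumptions, not base.yml.
    from flax import linen as nn

    rules = (
        ("embed", ("fsdp", "sequence")),  # shard the 'embed' dim over these mesh axes
        ("q_lora_up_proj", ()),           # empty rule: leave this dim unsharded (replicated)
    )

    # kernel_axes=("embed", "q_lora_up_proj") resolves to a PartitionSpec whose
    # second entry carries no mesh axes, so the q up-projection kernel is no
    # longer partitioned the way the 'q_lora' rules would have partitioned it.
    spec = nn.logical_to_mesh_axes(("embed", "q_lora_up_proj"), rules=rules)
    print(spec)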