Skip to content

Commit 0baff00

Browse files
Merge pull request #2369 from AI-Hypercomputer:deepseek_sharding
PiperOrigin-RevId: 810060338
2 parents 08d9f20 + 66132ab commit 0baff00

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

src/MaxText/configs/base.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,10 +385,12 @@ logical_axis_rules: [
385385
['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
386386
['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
387387
['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
388+
['q_lora_up_proj', []],
388389
['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
389390
['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
390391
['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
391392
['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
393+
['kv_lora_up_proj', []],
392394
['norm', ['tensor', 'tensor_transpose']],
393395
['layers', 'stage'],
394396
['kv', []],
@@ -405,6 +407,8 @@ logical_axis_rules: [
405407
['num_pages', []],
406408
['tokens_per_page', []],
407409
['paged_kv_head_dim_size', []],
410+
['dense_layers', []],
411+
['moe_layers', []],
408412
]
409413
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
410414
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]

src/MaxText/layers/attention_mla.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
398398
out_features_shape=self.q_lora_rank,
399399
axis=-1,
400400
kernel_init=self.kernel_init,
401-
kernel_axes=("embed", "q_lora"),
401+
kernel_axes=("embed", "q_lora_up_proj"),
402402
dtype=self.dtype,
403403
weight_dtype=self.weight_dtype,
404404
quant=self.quant,
@@ -432,7 +432,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
432432
out_features_shape=self.kv_lora_rank + self.qk_rope_head_dim,
433433
axis=-1,
434434
kernel_init=self.kernel_init,
435-
kernel_axes=("embed", "kv_lora"),
435+
kernel_axes=("embed", "kv_lora_up_proj"),
436436
dtype=self.dtype,
437437
weight_dtype=self.weight_dtype,
438438
quant=self.quant,

0 commit comments

Comments (0)