@@ -385,10 +385,12 @@ logical_axis_rules: [
385385 ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
386386 ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
387387 ['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
388+ ["q_lora_up_proj",[]],
388389 ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
389390 ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
390391 ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
391392 ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
393+ ["kv_lora_up_proj",[]],
392394 ['norm', ['tensor', 'tensor_transpose']],
393395 ['layers', 'stage'],
394396 ['kv', []],
@@ -405,6 +407,8 @@ logical_axis_rules: [
405407 ['num_pages', []],
406408 ['tokens_per_page', []],
407409 ['paged_kv_head_dim_size', []],
410+ ['dense_layers', []],
411+ ['moe_layers', []],
408412 ]
409413# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
410414data_sharding : [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
0 commit comments