@@ -385,10 +385,12 @@ logical_axis_rules: [
385
385
['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
386
386
['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
387
387
['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
388
+ ["q_lora_up_proj",[]],
388
389
['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
389
390
['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
390
391
['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
391
392
['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
393
+ ["kv_lora_up_proj",[]],
392
394
['norm', ['tensor', 'tensor_transpose']],
393
395
['layers', 'stage'],
394
396
['kv', []],
@@ -405,6 +407,8 @@ logical_axis_rules: [
405
407
['num_pages', []],
406
408
['tokens_per_page', []],
407
409
['paged_kv_head_dim_size', []],
410
+ ['dense_layers', []],
411
+ ['moe_layers', []],
408
412
]
409
413
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
410
414
data_sharding : [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
0 commit comments