diff --git a/launcher_scripts/conf/training/mixtral/mixtral_8x7b.yaml b/launcher_scripts/conf/training/mixtral/mixtral_8x7b.yaml
index 05daef455b..c7bfc9242a 100644
--- a/launcher_scripts/conf/training/mixtral/mixtral_8x7b.yaml
+++ b/launcher_scripts/conf/training/mixtral/mixtral_8x7b.yaml
@@ -52,9 +52,9 @@ model:
   micro_batch_size: 1
   global_batch_size: 256
   rampup_batch_size: null
-  tensor_model_parallel_size: 8
-  pipeline_model_parallel_size: 4
-  expert_model_parallel_size: 1
+  tensor_model_parallel_size: 2
+  pipeline_model_parallel_size: 1
+  expert_model_parallel_size: 8
   virtual_pipeline_model_parallel_size: null
   encoder_seq_length: 4096
   max_position_embeddings: 32768
@@ -145,7 +145,9 @@ model:
     - 0
     gen_shape: false
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
+    overlap_grad_sync: true
+    overlap_param_sync: true
     lr: 0.0001
     weight_decay: 0.1
     betas:
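
For context, below is a minimal sanity-check sketch of the new parallelism layout, not part of the diff. It assumes Megatron-LM's usual conventions: world_size = TP * PP * DP, with the expert-parallel group carved out of the data-parallel dimension (so EP must divide DP), and with the number of experts divisible by EP. Mixtral 8x7B has 8 experts per MoE layer, so EP=8 places one expert per expert-parallel rank. The WORLD_SIZE of 64 GPUs is a hypothetical example, not a value taken from this config.

# Hypothetical sanity check for the new TP/PP/EP values (Python sketch).
TP, PP, EP = 2, 1, 8                 # new values from the diff
GLOBAL_BATCH, MICRO_BATCH = 256, 1   # unchanged values from the config
NUM_EXPERTS = 8                      # Mixtral 8x7B: 8 experts per MoE layer
WORLD_SIZE = 64                      # hypothetical cluster size

assert WORLD_SIZE % (TP * PP) == 0, "world size must be divisible by TP*PP"
DP = WORLD_SIZE // (TP * PP)         # data-parallel size: 64 / 2 = 32
assert DP % EP == 0, "EP ranks are drawn from the data-parallel group"
assert NUM_EXPERTS % EP == 0, "experts must divide evenly across EP ranks"

experts_per_rank = NUM_EXPERTS // EP              # 1 expert per EP rank
grad_accum = GLOBAL_BATCH // (DP * MICRO_BATCH)   # 256 / 32 = 8 micro-batches
print(f"DP={DP}, experts/rank={experts_per_rank}, grad_accum={grad_accum}")

On the optimizer hunk: overlap_grad_sync and overlap_param_sync let the Megatron-core distributed optimizer overlap gradient reduce-scatter and parameter all-gather communication with compute, which is the usual motivation for switching from distributed_fused_adam to mcore_distributed_optim here.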