Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions flagscale/train/megatron/training/arguments_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,44 @@ def _parse_recompute_refined_config(recom_config, recom_config_name):
assert args.recompute_method is None and args.recompute_granularity is None and args.recompute_num_layers is None, "PEFT will raise comfilcts with recompute currently"
assert args.ckpt_format == 'torch', "PEFT is only tested with torch format checkpoint"

# DualPipe (bidirectional) related
if getattr(args, 'use_dualpipe', False):
assert args.pipeline_model_parallel_size > 1, (
"DualPipe requires pipeline parallelism (--pipeline-model-parallel-size > 1)."
)
assert args.pipeline_model_parallel_size % 2 == 0, (
"DualPipe requires an even pipeline-model-parallel-size; "
f"got {args.pipeline_model_parallel_size}."
)
assert getattr(args, 'virtual_pipeline_model_parallel_size', None) is None, (
"DualPipe is incompatible with virtual pipeline parallelism "
"(--num-layers-per-virtual-pipeline-stage)."
)
assert not getattr(args, 'use_dualpipev', False), (
"DualPipe (--use-dualpipe) and DualPipeV (--use-dualpipev) "
"cannot be enabled simultaneously."
)
assert getattr(args, 'untie_embeddings_and_output_weights', True) is True, (
"DualPipe is not supported with shared embeddings and output weights "
"(set --untie-embeddings-and-output-weights)."
)
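# Derive num_microbatches from global_batch_size for early validation.
# Both the evenness and the minimum-count constraints are checked here when
# micro_batch_size and data_parallel_size are already known; the exact value
# is available at runtime.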
if args.micro_batch_size is not None and args.data_parallel_size is not None:
num_microbatches = args.global_batch_size // (
args.micro_batch_size * args.data_parallel_size
)
assert num_microbatches % 2 == 0, (
f"DualPipe requires an even number of micro-batches, "
f"got {num_microbatches}. Adjust --global-batch-size, "
f"--micro-batch-size, or --data-parallel-size."
)
assert num_microbatches >= args.pipeline_model_parallel_size * 2, (
f"DualPipe requires num_microbatches ({num_microbatches}) >= "
f"pipeline_model_parallel_size * 2 "
f"({args.pipeline_model_parallel_size * 2})."
)


def _add_hetero_args(parser):
"""Add heterogeneous training related arguments (FlagScale specific)."""
Expand Down Expand Up @@ -656,6 +694,10 @@ def _add_training_args(parser):
help='Use DualPipeV pipeline schedule method')
group.add_argument('--moe-fb-overlap', action='store_true',
help='DualPipeV overlapping of moe a2a communication and forward/backward computation')
group.add_argument('--use-dualpipe', action='store_true',
help='Use the bidirectional DualPipe pipeline schedule (DeepSeek-V3). '
'Requires an even pipeline-model-parallel-size > 1 and '
'num_microbatches >= pipeline-model-parallel-size * 2.')
return parser


Expand Down
Loading