79 changes: 78 additions & 1 deletion megatron/arguments.py
@@ -317,6 +317,10 @@ def parse_args(extra_args_provider=None, defaults={},
            raise ModuleNotFoundError("Please install bitsandbytes from https://github.com/facebookresearch/bitsandbytes.")

    _print_args(args)

    # Extra sanity-checks and configuration summary (adapted from GPT-NeoX)
    _validate_and_summarize_args(args)

    return args


@@ -432,7 +436,7 @@ def _add_logging_args(parser):
                       help='Report to tensorboard interval.')
    group.add_argument('--tensorboard-queue-size', type=int, default=1000,
                       help='Size of the tensorboard queue for pending events '
                       'and summaries before one of the add calls forces a '
                       "and summaries before one of the 'add' calls forces a "
                       'flush to disk.')
    group.add_argument('--log-timers-to-tensorboard', action='store_true',
                       help='If set, write timers to tensorboard.')
@@ -1078,3 +1082,76 @@ def _add_activation_checkpoint_args(parser):
    group.add_argument('--profile-backward', action='store_true',
                       help='Enables backward pass profiling for checkpointed layers.')
    return parser

# -----------------------------------------------------------------------------
# Consistency checks & startup summary
# -----------------------------------------------------------------------------


def _rank0_print(msg=""):
"""Utility: only print from global rank 0."""
if int(os.getenv("RANK", "0")) == 0:
print(msg, flush=True)


def _validate_and_summarize_args(args):
"""Run inexpensive consistency checks and print a brief config table.

Raises
------
ValueError
If any sanity check fails so the job aborts early.
"""

    # ---------------- Consistency checks ----------------
    checks = []

    # hidden_size vs heads
    if args.hidden_size % args.num_attention_heads != 0:
        raise ValueError(
            f"hidden_size ({args.hidden_size}) must be divisible by num_attention_heads ({args.num_attention_heads})."
        )
    checks.append("hidden_size/head ratio OK")

    # global batch divisibility
    if args.global_batch_size % args.data_parallel_size != 0:
        raise ValueError(
            f"global_batch_size ({args.global_batch_size}) must be divisible by data_parallel_size ({args.data_parallel_size})."
        )
    checks.append("batch sizes divisible")

    # vocab vs TP size
    if args.pad_vocab_size_to is not None and (
        args.pad_vocab_size_to % args.tensor_model_parallel_size != 0
    ):
        raise ValueError(
            f"pad_vocab_size_to ({args.pad_vocab_size_to}) must be divisible by tensor_model_parallel_size ({args.tensor_model_parallel_size})."
        )
    checks.append("vocab divisible by TP size")

    # fp16/bf16 mutual exclusivity is already enforced earlier in parse_args; echoed here for the summary
    checks.append("dtype flags consistent")

    # ---------------- Summary table ----------------
    header = (
        "\n================ Megatron-DeepSpeed configuration summary ================"
    )
    lines = [header]
    lines.append(
        f"GPUs total : {args.world_size} (DP={args.data_parallel_size}, TP={args.tensor_model_parallel_size}, PP={args.pipeline_model_parallel_size})"
    )
    lines.append(f"Model layers : {args.num_layers}")
    lines.append(
        f"Hidden size : {args.hidden_size} (Heads={args.num_attention_heads}, Dim/head={args.hidden_size // args.num_attention_heads})"
    )
    lines.append(f"Seq length : {args.seq_length}")
    lines.append(
        f"Global batch : {args.global_batch_size} (Micro={args.micro_batch_size})"
    )
    lines.append(
        f"Precision : {'fp16' if args.fp16 else ('bf16' if args.bf16 else 'fp32')}"
    )
    lines.append("Checks passed : " + ", ".join(checks))
    lines.append("==========================================================================\n")

    _rank0_print("\n".join(lines))
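
Below is a minimal, illustrative sketch (not part of the diff) of how the new _validate_and_summarize_args helper behaves when driven standalone. The SimpleNamespace object and its field values are hypothetical stand-ins for Megatron's parsed args.

# Illustrative usage only, assuming _validate_and_summarize_args as defined above;
# the config values below are invented for the example.
from types import SimpleNamespace

fake_args = SimpleNamespace(
    world_size=64,
    data_parallel_size=8,
    tensor_model_parallel_size=4,
    pipeline_model_parallel_size=2,
    num_layers=40,
    hidden_size=5120,
    num_attention_heads=40,
    seq_length=2048,
    global_batch_size=512,
    micro_batch_size=4,
    pad_vocab_size_to=None,
    fp16=True,
    bf16=False,
)

# Prints the configuration summary on rank 0; a misconfigured value
# (e.g. hidden_size=5121) would raise ValueError before any expensive setup runs.
_validate_and_summarize_args(fake_args)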
4 changes: 4 additions & 0 deletions megatron/training.py
@@ -142,6 +142,10 @@ def pretrain(train_valid_test_dataset_provider,
    args.parameters_in_billions_no_embedding = get_parameters_in_billions(model, exclude_embeddings=True)
    print_rank_0(f'estimated model parameters: {get_parameters_in_billions(model)}')
    print_rank_0(f'estimated model parameters without embeddings: {get_parameters_in_billions(model, exclude_embeddings=True)}')
    if args.rank == 0:
        total_params_b = get_parameters_in_billions(model)
        total_params = int(total_params_b * 1e9)
        print(f"Model size: {round(total_params_b)}B ({total_params} params)", flush=True)
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')
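
For reference, a sketch of what the new rank-0 print produces for a hypothetical 176.25B-parameter model; the count is invented here, whereas in the PR it comes from get_parameters_in_billions(model).

# Illustrative only: reproducing the formatting of the rank-0 print above
# with an invented parameter count.
total_params_b = 176.25
total_params = int(total_params_b * 1e9)   # 176250000000
print(f"Model size: {round(total_params_b)}B ({total_params} params)", flush=True)
# Output: Model size: 176B (176250000000 params)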