51 changes: 29 additions & 22 deletions megatron/__init__.py
@@ -12,35 +12,42 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

import torch
import os

logger = logging.getLogger(__name__)

from .package_info import (
__description__,
__contact_names__,
__url__,
__download_url__,
__keywords__,
__license__,
__package_name__,
__version__,
)

from .global_vars import get_args
from .global_vars import get_current_global_batch_size
from .global_vars import get_num_microbatches
from .global_vars import get_signal_handler
from .global_vars import update_num_microbatches
from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer
from .global_vars import get_adlr_autoresume
from .global_vars import get_timers
from .initialize import initialize_megatron
if "MEGATRON_SETUP" not in os.environ:
from .global_vars import get_args
from .global_vars import get_current_global_batch_size
from .global_vars import get_num_microbatches
from .global_vars import get_signal_handler
from .global_vars import update_num_microbatches
from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer
from .global_vars import get_adlr_autoresume
from .global_vars import get_timers
from .initialize import initialize_megatron

def print_rank_0(message):
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(message, flush=True)
else:
print(message, flush=True)
logger.info(str(message))

def is_last_rank():
return torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1)

def print_rank_last(message):
"""If distributed is initialized, print only on last rank."""
if torch.distributed.is_initialized():
if is_last_rank():
print(message, flush=True)
else:
print(message, flush=True)
logger.info(str(message))
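
Aside: the new MEGATRON_SETUP guard above defers the global_vars and initialize imports, presumably so that packaging tooling can import megatron.package_info without pulling in the parts of the package that expect a fully configured runtime. A minimal sketch of how a build script might opt in; the file name and contents below are illustrative and not taken from this diff:

# setup.py (illustrative sketch, assuming the guard is meant for packaging)
import os

# Must be set before `megatron` is imported so the guarded imports are skipped.
os.environ['MEGATRON_SETUP'] = '1'

from setuptools import setup, find_packages
from megatron.package_info import (
    __package_name__, __version__, __description__, __license__,
)

setup(
    name=__package_name__,
    version=__version__,
    description=__description__,
    license=__license__,
    packages=find_packages(),
)
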
63 changes: 31 additions & 32 deletions megatron/arguments.py
@@ -16,10 +16,13 @@
"""Megatron arguments."""

import argparse
import logging
import os

import torch

logger = logging.getLogger(__name__)

def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse all arguments."""
@@ -74,13 +77,12 @@ def parse_args(extra_args_provider=None, defaults={},
'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
args.pipeline_model_parallel_size)
args.data_parallel_size = args.world_size // model_parallel_size
if args.rank == 0:
print('using world size: {}, data-parallel-size: {}, '
'tensor-model-parallel size: {}, '
'pipeline-model-parallel size: {} '.format(
args.world_size, args.data_parallel_size,
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size), flush=True)
logger.info('using world size: {}, data-parallel-size: {}, '
'tensor-model-parallel size: {}, '
'pipeline-model-parallel size: {} '.format(
args.world_size, args.data_parallel_size,
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size))
if args.pipeline_model_parallel_size > 1:
if args.pipeline_model_parallel_split_rank is not None:
assert args.pipeline_model_parallel_split_rank < \
@@ -112,11 +114,9 @@ def parse_args(extra_args_provider=None, defaults={},
# arguments that are passed to the program. We check this by
# ensuring the arg is set to None.
if getattr(args, key) is not None:
if args.rank == 0:
print('WARNING: overriding default arguments for {key}:{v} \
with {key}:{v2}'.format(key=key, v=defaults[key],
v2=getattr(args, key)),
flush=True)
logger.warning('Overriding default arguments for {key}:{v} '
'with {key}:{v2}'.format(key=key, v=defaults[key],
v2=getattr(args, key)))
else:
setattr(args, key, defaults[key])

@@ -125,9 +125,8 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.micro_batch_size > 0
if args.global_batch_size is None:
args.global_batch_size = args.micro_batch_size * args.data_parallel_size
if args.rank == 0:
print('setting global batch size to {}'.format(
args.global_batch_size), flush=True)
logger.info('setting global batch size to {}'.format(
args.global_batch_size))
assert args.global_batch_size > 0
if args.num_layers_per_virtual_pipeline_stage is not None:
assert args.pipeline_model_parallel_size > 2, \
@@ -154,13 +153,10 @@ def parse_args(extra_args_provider=None, defaults={},
# be done in fp32.
if not args.accumulate_allreduce_grads_in_fp32:
args.accumulate_allreduce_grads_in_fp32 = True
if args.rank == 0:
print('accumulate and all-reduce gradients in fp32 for '
'bfloat16 data type.', flush=True)
logger.info('accumulate and all-reduce gradients in fp32 for '
'bfloat16 data type.')

if args.rank == 0:
print('using {} for parameters ...'.format(args.params_dtype),
flush=True)
logger.info('using {} for parameters ...'.format(args.params_dtype))

# If we do accumulation and all-reduces in fp32, we need to have local DDP
# and we should make sure use-contiguous-buffers-in-local-ddp is not off.
@@ -275,17 +271,14 @@ def parse_args(extra_args_provider=None, defaults={},

def _print_args(args):
"""Print arguments."""
if args.rank == 0:
print('------------------------ arguments ------------------------',
flush=True)
str_list = []
for arg in vars(args):
dots = '.' * (48 - len(arg))
str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
for arg in sorted(str_list, key=lambda x: x.lower()):
print(arg, flush=True)
print('-------------------- end of arguments ---------------------',
flush=True)
logger.info('------------------------ arguments ------------------------')
str_list = []
for arg in vars(args):
dots = '.' * (48 - len(arg))
str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
for arg in sorted(str_list, key=lambda x: x.lower()):
logger.info(arg)
logger.info('-------------------- end of arguments ---------------------')


def _check_arg_is_not_none(args, arg):
@@ -350,8 +343,12 @@ def _add_network_size_args(parser):
def _add_logging_args(parser):
group = parser.add_argument_group(title='logging')

group.add_argument('--name', type=str, default=None,
help='A name for the experiment.')
group.add_argument('--log-params-norm', action='store_true',
help='If set, calculate and log parameters norm.')
group.add_argument('--log-scales', action='store_true',
help='Log the scales of parameters, gradients and activations.')
group.add_argument('--log-num-zeros-in-grad', action='store_true',
help='If set, calculate and log the number of zeros in gradient.')
group.add_argument('--tensorboard-log-interval', type=int, default=1,
@@ -708,6 +705,8 @@ def _add_data_args(parser):
'1) a single data path, 2) multiple datasets in the'
'form: dataset1-weight dataset1-path dataset2-weight '
'dataset2-path ...')
group.add_argument('--indexmap-path', type=str, default=None,
help='Path for intermediate data files')
group.add_argument('--split', type=str, default='969, 30, 1',
help='Comma-separated list of proportions for training,'
' validation, and test split. For example the split '
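
Note on the logging conversions above: print_rank_0 in megatron/__init__.py and the messages in parse_args and _print_args now go through module loggers without an explicit `if args.rank == 0` check, which suggests rank filtering is expected to happen in the logging configuration instead. A minimal sketch of such a configuration; the function and filter names are hypothetical and not part of this diff:

import logging

import torch
import torch.distributed


class Rank0Filter(logging.Filter):
    """Pass records only on rank 0 once torch.distributed is initialized."""
    def filter(self, record):
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            return torch.distributed.get_rank() == 0
        return True


def configure_megatron_logging(level=logging.INFO):
    handler = logging.StreamHandler()
    handler.addFilter(Rank0Filter())
    handler.setFormatter(logging.Formatter('%(name)s: %(message)s'))
    megatron_logger = logging.getLogger('megatron')
    megatron_logger.addHandler(handler)
    megatron_logger.setLevel(level)

Without something along these lines, the logger.info calls would either be dropped under Python logging's default WARNING level or, once a handler is attached, emitted once per rank.
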
2 changes: 1 addition & 1 deletion megatron/data/bert_dataset.py
@@ -118,7 +118,7 @@ def build_training_sample(sample,
if binary_head:
# We assume that we have at least two sentences in the sample
assert len(sample) > 1
assert target_seq_length <= max_seq_length
assert target_seq_length <= max_seq_length - 2

# Divide sample into two segments (A and B).
if binary_head:
24 changes: 17 additions & 7 deletions megatron/data/biencoder_dataset_utils.py
@@ -1,8 +1,10 @@
import os
import time
from pathlib import Path

import numpy as np
import torch
import torch.distributed

from megatron import get_args, get_tokenizer, mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, \
@@ -146,6 +148,12 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
indexmap_filename += '_1sentok'
indexmap_filename += '.npy'

args = get_args()
if args.indexmap_path is not None:
indexmap_path = Path(args.indexmap_path).resolve()
indexmap_path.mkdir(parents=True, exist_ok=True)
indexmap_filename = indexmap_path / Path(indexmap_filename).name

# Build the indexed mapping if not exist.
if mpu.get_data_parallel_rank() == 0 and \
not os.path.isfile(indexmap_filename):
@@ -184,13 +192,15 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
'(seconds): {:4f}'.format(
time.time() - start_time))

# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
assert counts[0].item() == torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
# Wait until rank 0 generates the index file.
torch.distributed.barrier(device_ids=[int(os.environ['LOCAL_RANK'])], group=mpu.get_data_parallel_group())
# It can take some time for the file to become visible on other nodes.
for i in range(120):
if os.path.isfile(indexmap_filename):
break
if i % 10 == 0:
print_rank_0(" Waiting for index file...")
time.sleep(1.0)

# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
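
The args.indexmap_path block above, enabled by the new --indexmap-path flag, is repeated with small variations in megatron/data/dataset_utils.py and megatron/data/gpt_dataset.py below. It could be factored into a shared helper along the following lines; the helper name is hypothetical and not part of this diff:

from pathlib import Path


def redirect_indexmap_file(filename, indexmap_path):
    """Relocate an index-map file into `indexmap_path` if one was given, else keep it."""
    if indexmap_path is None:
        return filename
    target_dir = Path(indexmap_path).resolve()
    target_dir.mkdir(parents=True, exist_ok=True)
    # Return a plain string so downstream os.path.isfile() checks work unchanged.
    return str(target_dir / Path(filename).name)


# Hypothetical call site mirroring the change above:
# indexmap_filename = redirect_indexmap_file(indexmap_filename, get_args().indexmap_path)
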
30 changes: 20 additions & 10 deletions megatron/data/dataset_utils.py
@@ -22,9 +22,11 @@
import os
import time
import collections
from pathlib import Path

import numpy as np
import torch
import torch.distributed

from megatron import (
get_args,
@@ -446,7 +448,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
prefixes[i], data_impl, splits_string,
datasets_train_valid_test_num_samples[i],
max_seq_length, masked_lm_prob, short_seq_prob,
seed, skip_warmup, binary_head, dataset_type=dataset_type)
seed, skip_warmup, binary_head, max_seq_length_dec, dataset_type=dataset_type)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
@@ -661,6 +663,12 @@ def get_samples_mapping(indexed_dataset,
indexmap_filename += '_{}s'.format(seed)
indexmap_filename += '.npy'

args = get_args()
if args.indexmap_path is not None:
indexmap_path = Path(args.indexmap_path).resolve()
indexmap_path.mkdir(parents=True, exist_ok=True)
indexmap_filename = indexmap_path / Path(indexmap_filename).name

# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0 and \
not os.path.isfile(indexmap_filename):
@@ -696,15 +704,17 @@ def get_samples_mapping(indexed_dataset,
print_rank_0(' > elapsed time to build and save samples mapping '
'(seconds): {:4f}'.format(
time.time() - start_time))
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))

# Wait until rank 0 generates the index file.
print_rank_0(f"Barrier device {int(os.environ['LOCAL_RANK'])}")
torch.distributed.barrier(device_ids=[int(os.environ['LOCAL_RANK'])], group=mpu.get_data_parallel_group())
# It can take some time for the file to become visible on other nodes.
for i in range(120):
if os.path.isfile(indexmap_filename):
break
if i % 10 == 0:
print_rank_0(" Waiting for index file...")
time.sleep(1.0)

# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
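
The barrier-plus-polling pattern above, a data-parallel barrier pinned to LOCAL_RANK followed by up to roughly two minutes of polling for the index file to become visible on shared storage, also appears in biencoder_dataset_utils.py above and gpt_dataset.py below. A sketch of it as a reusable helper; the name and the timeout error are additions of this sketch, whereas the diff simply falls through after the last poll:

import os
import time

import torch
import torch.distributed


def wait_for_index_files(filenames, group=None, timeout_s=120.0, poll_s=1.0):
    """Barrier on `group`, then poll until every file in `filenames` is visible."""
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.distributed.barrier(device_ids=[local_rank], group=group)
    deadline = time.time() + timeout_s
    while not all(os.path.isfile(f) for f in filenames):
        if time.time() > deadline:
            raise RuntimeError('Timed out waiting for index files: {}'.format(filenames))
        time.sleep(poll_s)


# Hypothetical call site mirroring the change above:
# wait_for_index_files([indexmap_filename], group=mpu.get_data_parallel_group())
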
31 changes: 21 additions & 10 deletions megatron/data/gpt_dataset.py
@@ -17,11 +17,13 @@

import os
import time
from pathlib import Path

import numpy as np
import torch
import torch.distributed

from megatron import mpu, print_rank_0
from megatron import mpu, print_rank_0, get_args
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.dataset_utils import get_datasets_weights_and_num_samples
from megatron.data.dataset_utils import get_train_valid_test_split_
@@ -211,6 +213,14 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
sample_idx_filename = _filename + '_sample_idx.npy'
shuffle_idx_filename = _filename + '_shuffle_idx.npy'

args = get_args()
if args.indexmap_path is not None:
indexmap_path = Path(args.indexmap_path).resolve()
indexmap_path.mkdir(parents=True, exist_ok=True)
doc_idx_filename = indexmap_path / Path(doc_idx_filename).name
sample_idx_filename = indexmap_path / Path(sample_idx_filename).name
shuffle_idx_filename = indexmap_path / Path(shuffle_idx_filename).name

# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0:
if (not os.path.isfile(doc_idx_filename)) or \
@@ -293,15 +303,16 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
print_rank_0(' > elapsed time to build and save shuffle-idx mapping'
' (seconds): {:4f}'.format(time.time() - start_time))

# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
# Wait until rank 0 generates the index files.
torch.distributed.barrier(device_ids=[int(os.environ['LOCAL_RANK'])], group=mpu.get_data_parallel_group())

# It can take some time for the files to become visible on other nodes.
for i in range(120):
if all(os.path.isfile(f) for f in (doc_idx_filename, sample_idx_filename, shuffle_idx_filename)):
break
if i % 10 == 0:
print_rank_0(" Waiting for index files...")
time.sleep(1.0)

# Load mappings.
start_time = time.time()