From 9ba8210fb4f21a0a6f446e7fc87aef405df2bb58 Mon Sep 17 00:00:00 2001
From: thomasw21 <24695242+thomasw21@users.noreply.github.com>
Date: Thu, 24 Mar 2022 23:36:27 +0100
Subject: [PATCH 1/9] sync layer norms

---
 megatron/model/fused_layer_norm.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py
index 78645c236..0d8383c76 100644
--- a/megatron/model/fused_layer_norm.py
+++ b/megatron/model/fused_layer_norm.py
@@ -22,6 +22,7 @@ from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib
+from megatron import mpu

 global fused_mix_prec_layer_norm_cuda
 fused_mix_prec_layer_norm_cuda = None
@@ -84,7 +85,12 @@ def reset_parameters(self):

     def forward(self, input):

+        # TODO: temporary hack in order to synchronize all layer norms params despite them being
+        # unsynced at the moment due to a bug in deepspeed's bf16 optimizer
+        if 1:
+            tp_world_size = mpu.get_tensor_model_parallel_world_size()
+            weight = mpu.reduce_from_tensor_model_parallel_region(self.weight) / tp_world_size
+            bias = mpu.reduce_from_tensor_model_parallel_region(self.bias) / tp_world_size
+
         return FusedLayerNormAffineFunction.apply(
-            input, self.weight, self.bias, self.normalized_shape,self.eps)
-
+            input, weight, bias, self.normalized_shape,self.eps)

From fbd47eed09ae653e5e0f31b4db39e6428fb4c915 Mon Sep 17 00:00:00 2001
From: thomasw21 <24695242+thomasw21@users.noreply.github.com>
Date: Fri, 25 Mar 2022 11:05:52 +0100
Subject: [PATCH 2/9] all_reduce is an in_place operation

---
 megatron/model/fused_layer_norm.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py
index 0d8383c76..0ae519a27 100644
--- a/megatron/model/fused_layer_norm.py
+++ b/megatron/model/fused_layer_norm.py
@@ -89,8 +89,10 @@ def forward(self, input):
         # unsynced at the moment due to a bug in deepspeed's bf16 optimizer
         if 1:
             tp_world_size = mpu.get_tensor_model_parallel_world_size()
-            weight = mpu.reduce_from_tensor_model_parallel_region(self.weight) / tp_world_size
-            bias = mpu.reduce_from_tensor_model_parallel_region(self.bias) / tp_world_size
+            weight = torch.clone(self.weight)
+            bias = torch.clone(self.bias)
+            weight = mpu.reduce_from_tensor_model_parallel_region(weight) / tp_world_size
+            bias = mpu.reduce_from_tensor_model_parallel_region(bias) / tp_world_size

         return FusedLayerNormAffineFunction.apply(
             input, weight, bias, self.normalized_shape,self.eps)

From a9fb317e38e8351f92388a4188e1f3c95e184c16 Mon Sep 17 00:00:00 2001
From: Thomas Wang <24695242+thomasw21@users.noreply.github.com>
Date: Wed, 6 Apr 2022 19:10:50 +0200
Subject: [PATCH 3/9] Make dataloader use another random generator (#276)

---
 megatron/data/data_samplers.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/megatron/data/data_samplers.py b/megatron/data/data_samplers.py
index 1cbeac312..b933ff34e 100644
--- a/megatron/data/data_samplers.py
+++ b/megatron/data/data_samplers.py
@@ -52,6 +52,7 @@ def build_pretraining_data_loader(dataset, consumed_samples):
     return torch.utils.data.DataLoader(dataset,
                                        batch_sampler=batch_sampler,
                                        num_workers=args.num_workers,
+                                       generator=torch.Generator().manual_seed(args.seed),
                                        pin_memory=True)

 class MegatronPretrainingSampler:

From 8c1ed22542177fb1818a0135458339d5840d2602 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Wed, 6 Apr 2022 19:17:52 +0200
Subject: [PATCH 4/9] do all_reduce op.AVG directly

---
 megatron/model/fused_layer_norm.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py
index 0ae519a27..6305bcd83 100644
--- a/megatron/model/fused_layer_norm.py
+++ b/megatron/model/fused_layer_norm.py
@@ -85,6 +85,16 @@ def reset_parameters(self):

     def forward(self, input):

+
+        torch.distributed.all_reduce(self.weight, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
+        torch.distributed.all_reduce(self.bias, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
+
+        return FusedLayerNormAffineFunction.apply(
+            input, self.weight, self.bias, self.normalized_shape, self.eps)
+
+
+
+    def forward1(self, input):
         # TODO: temporary hack in order to synchronize all layer norms params despite them being
         # unsynced at the moment due to a bug in deepspeed's bf16 optimizer
         if 1:
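Taken together, patches 1-4 converge on a simple scheme: on every forward pass, each tensor-parallel rank averages its layernorm weight and bias with its peers, so that parameters which drift apart (the bf16-optimizer bug mentioned in the TODO) are pulled back together. The sketch below is not the Megatron-DeepSpeed code itself, only a minimal illustration of that scheme under the assumption of an already-initialized torch.distributed setup; `tp_group` is a placeholder for the group handle the real code gets from `mpu.get_tensor_model_parallel_group()`. It also captures the two lessons encoded in patches 2 and 4: `all_reduce` mutates its input in place, and `ReduceOp.AVG` folds the division by the group size into the collective, removing the earlier clone-then-reduce-then-divide dance.

```python
import torch
import torch.distributed as dist


def tp_average_layernorm_params(weight: torch.Tensor, bias: torch.Tensor, tp_group) -> None:
    """Average layernorm weight/bias across the tensor-parallel group, in place.

    Illustrative only: `tp_group` stands in for the tensor-model-parallel
    process group used by the real code.
    """
    # all_reduce overwrites its input tensors, and ReduceOp.AVG performs the
    # sum followed by the division by the group size inside the collective.
    dist.all_reduce(weight, op=dist.ReduceOp.AVG, group=tp_group)
    dist.all_reduce(bias, op=dist.ReduceOp.AVG, group=tp_group)
```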
From b015ec15577794bfe132051fdf284fc060e69eff Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Wed, 18 May 2022 20:15:40 +0200
Subject: [PATCH 5/9] add eval dataloader deadlock workaround

---
 megatron/data/data_samplers.py | 8 +++++---
 megatron/training.py           | 4 +++-
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/megatron/data/data_samplers.py b/megatron/data/data_samplers.py
index b933ff34e..876ef9053 100644
--- a/megatron/data/data_samplers.py
+++ b/megatron/data/data_samplers.py
@@ -22,7 +22,7 @@ from megatron import mpu


-def build_pretraining_data_loader(dataset, consumed_samples):
+def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
     """Buld dataloader given an input dataset."""

     if dataset is None:
@@ -48,11 +48,13 @@ def build_pretraining_data_loader(dataset, consumed_samples):
         raise Exception('{} dataloader type is not supported.'.format(
             args.dataloader_type))

+    if num_workers is None:
+        num_workers = args.num_workers
+
     # Torch dataloader.
     return torch.utils.data.DataLoader(dataset,
                                        batch_sampler=batch_sampler,
-                                       num_workers=args.num_workers,
-                                       generator=torch.Generator().manual_seed(args.seed),
+                                       num_workers=num_workers,
                                        pin_memory=True)

 class MegatronPretrainingSampler:
diff --git a/megatron/training.py b/megatron/training.py
index 84fd4eb9d..5d98883fd 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -1132,7 +1132,9 @@ def build_train_valid_test_data_iterators(

     # We collapse None and empty list as both should mean we don't run validation
     # args.consumed_valid_samples accumulates the sum of valid steps for every dataset, which are all equal
-    valid_dataloaders = [build_pretraining_data_loader(d, args.consumed_valid_samples // len(valid_ds))
+    # XXX: we get a deadlock in the dataloader on eval, possibly this bug in pytorch https://github.com/pytorch/pytorch/pull/25158
+    # using num_workers=0 to work around it - the training can't use that since it impacts throughput by a few percent
+    valid_dataloaders = [build_pretraining_data_loader(d, args.consumed_valid_samples // len(valid_ds), num_workers=0)
                          for d in valid_ds] \
                          if valid_ds is not None else []
     # We collapse None and empty list as both should mean we don't run test

From 10f50184c69fb9d1dc9c72077064846179c71f97 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Tue, 24 May 2022 17:15:22 +0200
Subject: [PATCH 6/9] revert generator sync

---
 megatron/data/data_samplers.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/megatron/data/data_samplers.py b/megatron/data/data_samplers.py
index 876ef9053..8e97a87a1 100644
--- a/megatron/data/data_samplers.py
+++ b/megatron/data/data_samplers.py
@@ -55,6 +55,7 @@ def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
     return torch.utils.data.DataLoader(dataset,
                                        batch_sampler=batch_sampler,
                                        num_workers=num_workers,
+                                       generator=torch.Generator().manual_seed(args.seed),
                                        pin_memory=True)

 class MegatronPretrainingSampler:

From 76f2fd7ca8f1f00cec19ead30bce8acba58dfc7d Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 4 Jul 2022 14:15:23 -0700
Subject: [PATCH 7/9] make auto-sync configurable; basic test; cleanup

---
 megatron/arguments.py              |  4 ++++
 megatron/model/fused_layer_norm.py | 32 ++++++++++--------------------
 tests/test_training.py             |  1 +
 3 files changed, 15 insertions(+), 22 deletions(-)

diff --git a/megatron/arguments.py b/megatron/arguments.py
index 230bd4d65..ccd38bdad 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -375,6 +375,10 @@ def _add_network_size_args(parser):
                        ', needs to be divisible by TP size and `make-vocab-size-divisible-by`.')
     group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                        help='Layer norm epsilon.')
+    group.add_argument('--layernorm-tp-auto-sync', action='store_true',
+                       help='Force syncing layernorm params across TP ranks in forward. '
+                       'This is a workaround for an unresolved bug leading to TP ranks '
+                       'getting out of sync with each other.')
     group.add_argument('--apply-residual-connection-post-layernorm',
                        action='store_true',
                        help='If set, use original BERT residula connection '
diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py
index 95b2278ce..ed914d9a0 100644
--- a/megatron/model/fused_layer_norm.py
+++ b/megatron/model/fused_layer_norm.py
@@ -19,16 +19,16 @@
 import numbers
+
+from megatron import get_args
+from megatron import mpu
 from packaging import version
-import torch
 from torch import nn
-from torch.nn.parameter import Parameter
-import torch.nn.functional as F
 from torch.nn import init
+from torch.nn.parameter import Parameter
 import importlib
-from megatron import mpu
-
-from megatron import get_args
+import torch
+import torch.nn.functional as F

 global fused_mix_prec_layer_norm_cuda
 fused_mix_prec_layer_norm_cuda = None
@@ -84,6 +84,7 @@ def __init__(self, normalized_shape, eps=1e-5):
         self.reset_parameters()
         args = get_args()
+        self.layernorm_tp_auto_sync = args.layernorm_tp_auto_sync
         self.use_meg_ds_fused_layer_norm = (
             args.bf16  # Current Meg-DS cuda kernel has better throughput than torch.nn.LayerNorm
@@ -99,25 +100,12 @@ def reset_parameters(self):

     def forward(self, input):

-        torch.distributed.all_reduce(self.weight, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
-        torch.distributed.all_reduce(self.bias, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
+        if self.layernorm_tp_auto_sync:
+            torch.distributed.all_reduce(self.weight, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
+            torch.distributed.all_reduce(self.bias, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())

         if self.use_meg_ds_fused_layer_norm:
             return FusedLayerNormAffineFunction.apply(
                 input, self.weight, self.bias, self.normalized_shape, self.eps)
         else:
             return F.layer_norm(input, self.normalized_shape, self.weight, self.bias)
-
-
-    def forward1(self, input):
-        # TODO: temporary hack in order to synchronize all layer norms params despite them being
-        # unsynced at the moment due to a bug in deepspeed's bf16 optimizer
-        if 1:
-            tp_world_size = mpu.get_tensor_model_parallel_world_size()
-            weight = torch.clone(self.weight)
-            bias = torch.clone(self.bias)
-            weight = mpu.reduce_from_tensor_model_parallel_region(weight) / tp_world_size
-            bias = mpu.reduce_from_tensor_model_parallel_region(bias) / tp_world_size
-
-            return FusedLayerNormAffineFunction.apply(
-                input, weight, bias, self.normalized_shape,self.eps)
diff --git a/tests/test_training.py b/tests/test_training.py
index c77cb9af2..79a43c6a2 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -147,6 +147,7 @@ def get_variation_config(self, variation, output_dir, n_samples=None):
         --clip-grad 1.0
         --weight-decay 1e-1
         --embed-layernorm
+        --layernorm-tp-auto-sync

         --fp16

         --log-level debug
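Patches 3, 5 and 6 leave build_pretraining_data_loader with two small but important properties: shuffling is driven by a dedicated, seeded torch.Generator rather than the global RNG, and callers can override num_workers so that evaluation can pass num_workers=0 and sidestep the worker deadlock. The trimmed-down version below only illustrates that end state; `args` and `batch_sampler` are stand-ins for the real Megatron objects, and the function name is hypothetical.

```python
import torch


def build_data_loader(dataset, batch_sampler, args, num_workers=None):
    # Evaluation passes num_workers=0 to avoid the dataloader deadlock; training
    # keeps the configured value, since zero workers costs a few percent throughput.
    if num_workers is None:
        num_workers = args.num_workers
    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        # A private, seeded generator keeps shuffling reproducible and independent
        # of any other consumer of the global RNG state.
        generator=torch.Generator().manual_seed(args.seed),
        pin_memory=True,
    )
```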
From e2ffb5e90213f680974f316b62fbadacbd0faeda Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 4 Jul 2022 15:24:29 -0700
Subject: [PATCH 8/9] test with updated AMI image

---
 .github/workflows/ci.md    |  6 ++++--
 .github/workflows/main.yml | 12 ++++++------
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/.github/workflows/ci.md b/.github/workflows/ci.md
index 2db56c20a..6294f983f 100644
--- a/.github/workflows/ci.md
+++ b/.github/workflows/ci.md
@@ -83,6 +83,8 @@ pip install -r requirements-ms.txt
 - apex - needs a hack to deal with mismatching minor cuda versions (and it takes forever to build), so using this patch:

+XXX: this no longer works - had to manually patch pytorch to avoid mismatch failure
+
 ```
 --- a/setup.py
 +++ b/setup.py
@@ -110,8 +112,8 @@ cd code/apex
 Once the needed things got installed (and every time anything new is installed) a new AMI must be created (this is like an .iso image snapshot)

-1. go to https://us-east-2.console.aws.amazon.com/ec2/v2/home?region=us-east-1#Instances:
-2. choose the image to create a new image from
+1. go to https://us-east-1.console.aws.amazon.com/ec2/v2/home?region=us-east-1#Instances:
+2. choose the instance to create a new image from
 3. Actions -> Image and Templates -> Create Image

 Must ensure it's created in the correct region (same as in script) - or can copy it to the right region.
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index e343df39e..574ab3212 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -40,7 +40,7 @@ jobs:
         with:
           mode: start
           github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
-          ec2-image-id: ami-04933c2edcc56a03a
+          ec2-image-id: ami-0ad997818d90480f2
           ec2-instance-type: g4dn.12xlarge
           security-group-id: sg-f2a4e2fc
           subnet-id: subnet-b7533b96 # us-east-1c
@@ -57,7 +57,7 @@ jobs:
         with:
           mode: start
           github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
-          ec2-image-id: ami-04933c2edcc56a03a
+          ec2-image-id: ami-0ad997818d90480f2
           ec2-instance-type: g4dn.12xlarge
           security-group-id: sg-f2a4e2fc
           subnet-id: subnet-a396b2ad # us-east-1f
@@ -74,7 +74,7 @@ jobs:
         with:
           mode: start
           github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
-          ec2-image-id: ami-04933c2edcc56a03a
+          ec2-image-id: ami-0ad997818d90480f2
           ec2-instance-type: g4dn.12xlarge
           security-group-id: sg-f2a4e2fc
           subnet-id: subnet-df0f6180 # us-east-1a
@@ -92,7 +92,7 @@ jobs:
         with:
           mode: start
           github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
-          ec2-image-id: ami-04933c2edcc56a03a
+          ec2-image-id: ami-0ad997818d90480f2
           ec2-instance-type: p3.8xlarge
           security-group-id: sg-f2a4e2fc
           subnet-id: subnet-b7533b96 # us-east-1c
@@ -109,7 +109,7 @@ jobs:
         with:
           mode: start
           github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
-          ec2-image-id: ami-04933c2edcc56a03a
+          ec2-image-id: ami-0ad997818d90480f2
           ec2-instance-type: p3.8xlarge
           security-group-id: sg-f2a4e2fc
           subnet-id: subnet-a396b2ad # us-east-1f
@@ -125,7 +125,7 @@ jobs:
         with:
           mode: start
           github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
-          ec2-image-id: ami-04933c2edcc56a03a
+          ec2-image-id: ami-0ad997818d90480f2
           ec2-instance-type: p3.8xlarge
           security-group-id: sg-f2a4e2fc
           subnet-id: subnet-df0f6180 # us-east-1a

From 44bc82efe176c5adb77ad0e51610b478d4d22707 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 4 Jul 2022 15:47:36 -0700
Subject: [PATCH 9/9] fix unrelated test

---
 tests/test_model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/test_model.py b/tests/test_model.py
index 6defd784d..fa625d764 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -11,7 +11,7 @@ from packaging import version

 from megatron import initialize_megatron, get_args, get_tokenizer, global_vars
-from megatron.testing_utils import TestCasePlus, mockenv_context, flatten_arguments, torch_assert_equal
+from megatron.testing_utils import TestCasePlus, mockenv_context, flatten_arguments, torch_assert_equal, require_torch_bf16
 from megatron.training import setup_model_and_optimizer
 from pretrain_gpt import model_provider as gpt_model_provider, get_batch_pipe as get_gpt_batch_pipe
 from pretrain_prefix_lm import model_provider as prefix_lm_model_provider, get_batch_pipe as get_prefix_lm_batch_pipe
@@ -270,6 +270,7 @@ def test_gpt_rotary_embeddings(self):

     #TODO: Check all invariants

+    @require_torch_bf16
     def test_fused_layer_norm(self):
         command_args = get_default_args()
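The series never fixes the underlying optimizer bug; it only keeps masking it, so the property worth checking after a training step is that every tensor-parallel rank holds identical layernorm parameters. The snippet below is a sketch of such a check, not code from the repository or its tests; `tp_group` is a placeholder for the tensor-model-parallel process group, and the comparison is exact because the AVG all_reduce leaves every rank with the same values.

```python
import torch
import torch.distributed as dist


def assert_layernorm_params_in_sync(layer_norm: torch.nn.LayerNorm, tp_group) -> None:
    """Gather each rank's layernorm weight and check that they are all identical."""
    world_size = dist.get_world_size(group=tp_group)
    gathered = [torch.empty_like(layer_norm.weight) for _ in range(world_size)]
    dist.all_gather(gathered, layer_norm.weight.detach(), group=tp_group)
    for rank, other in enumerate(gathered[1:], start=1):
        # Any mismatch means the ranks have drifted apart again and the
        # --layernorm-tp-auto-sync workaround (or a real fix) is still needed.
        torch.testing.assert_close(gathered[0], other, rtol=0, atol=0,
                                   msg=f"layernorm weight diverged on rank {rank}")
```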