
Commit e1c479e

stas00 and thomasw21 authored
sync layer norms (#272)
* sync layer norms
* all_reduce is an in_place operation
* Make dataloader use another random generator (#276)
* do all_reduce op.AVG directly
* add eval dataloader deadlock workaround
* revert generator sync
* make auto-sync configurable; basic test; cleanup
* test with updated AMI image
* fix unrelated test

Co-authored-by: thomasw21 <[email protected]>
1 parent 0cb043c commit e1c479e
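The fix averages each layer norm's parameters across the tensor-parallel group with an in-place all_reduce (hence the "all_reduce is an in_place operation" and "do all_reduce op.AVG directly" notes above). A minimal standalone sketch of that pattern, with a hypothetical helper name, assuming torch.distributed is already initialized and that `group` stands in for the tensor-parallel group:

```python
# Sketch only: average one parameter tensor across a process group.
# Assumes torch.distributed is initialized (e.g. via torchrun); ReduceOp.AVG
# needs a reasonably recent PyTorch with the NCCL backend.
import torch
import torch.distributed as dist

def average_across_group(param: torch.Tensor, group=None) -> None:
    # all_reduce modifies `param` in place; after the call every rank in
    # `group` holds the element-wise mean of all ranks' copies.
    dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group)

# e.g. average_across_group(layernorm.weight.data, group=tp_group)
```

Averaging with ReduceOp.AVG, rather than broadcasting one rank's copy, keeps the operation symmetric across ranks and costs a single collective per tensor.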

4 files changed (+19 −6 lines)


megatron/arguments.py

Lines changed: 4 additions & 0 deletions
@@ -375,6 +375,10 @@ def _add_network_size_args(parser):
                        ', needs to be divisible by TP size and `make-vocab-size-divisible-by`.')
     group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                        help='Layer norm epsilon.')
+    group.add_argument('--layernorm-tp-auto-sync', action='store_true',
+                       help='Force syncing layernorm params across TP ranks in forward. '
+                       'This is a workaround for an unresolved bug leading to TP ranks '
+                       'getting out of sync with each other.')
     group.add_argument('--apply-residual-connection-post-layernorm',
                        action='store_true',
                        help='If set, use original BERT residula connection '
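Because the new option uses `action='store_true'`, it defaults to False and only flips on when the flag is passed. A tiny standalone argparse round-trip (illustrative, not Megatron code) shows that behaviour:

```python
# Standalone illustration of the new flag's argparse behaviour.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--layernorm-tp-auto-sync', action='store_true',
                    help='Force syncing layernorm params across TP ranks in forward.')

assert parser.parse_args([]).layernorm_tp_auto_sync is False
assert parser.parse_args(['--layernorm-tp-auto-sync']).layernorm_tp_auto_sync is True
```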

megatron/data/data_samplers.py

Lines changed: 2 additions & 1 deletion
@@ -40,7 +40,7 @@ def pack_samples(items, max_seq_len: int, micro_batch_size: int, pad_token: int)
             'target_tokens': array([5])
         }
     ]
-
+
     Output:
     decoder_target_tokens = [[6, 7, 8, 3, 4, 5, <pad>]]: Concatenation of tokens followed with padding tokens.
     decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]: Segment ids determine original documents.
@@ -139,6 +139,7 @@ def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
         dataset,
         batch_sampler=batch_sampler,
         num_workers=num_workers,
+        generator=torch.Generator().manual_seed(args.seed),
         collate_fn=collate_fn,
         pin_memory=True
     )
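Passing a dedicated, seed-initialized `torch.Generator` gives the DataLoader its own RNG stream (used for sampling order and worker base seeds) instead of the global one, so every rank builds an identically seeded loader. A small self-contained sketch of the idea, with a toy dataset and a hypothetical seed standing in for `args.seed`:

```python
# Toy demonstration: a loader-private generator makes shuffling reproducible
# and independent of the global torch RNG. `1234` stands in for args.seed.
import torch
from torch.utils.data import DataLoader, TensorDataset

def build_loader(seed: int) -> DataLoader:
    dataset = TensorDataset(torch.arange(16))
    return DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,                                   # sampling order drawn from the generator below
        generator=torch.Generator().manual_seed(seed),  # loader-private RNG
    )

if __name__ == "__main__":
    first = [batch[0].tolist() for batch in build_loader(1234)]
    torch.manual_seed(0)          # perturb the global RNG
    second = [batch[0].tolist() for batch in build_loader(1234)]
    assert first == second        # same seed -> same order; global RNG state is irrelevant
```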

megatron/model/fused_layer_norm.py

Lines changed: 12 additions & 5 deletions
@@ -19,15 +19,16 @@
 
 import numbers
 
+
+from megatron import get_args
+from megatron import mpu
 from packaging import version
-import torch
 from torch import nn
-from torch.nn.parameter import Parameter
-import torch.nn.functional as F
 from torch.nn import init
+from torch.nn.parameter import Parameter
 import importlib
-
-from megatron import get_args
+import torch
+import torch.nn.functional as F
 
 global fused_mix_prec_layer_norm_cuda
 fused_mix_prec_layer_norm_cuda = None
@@ -83,6 +84,7 @@ def __init__(self, normalized_shape, eps=1e-5):
         self.reset_parameters()
 
         args = get_args()
+        self.layernorm_tp_auto_sync = args.layernorm_tp_auto_sync
 
         self.use_meg_ds_fused_layer_norm = (
             args.bf16  # Current Meg-DS cuda kernel has better throughput than torch.nn.LayerNorm
@@ -97,6 +99,11 @@ def reset_parameters(self):
 
 
     def forward(self, input):
+
+        if self.layernorm_tp_auto_sync:
+            torch.distributed.all_reduce(self.weight, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
+            torch.distributed.all_reduce(self.bias, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
+
         if self.use_meg_ds_fused_layer_norm:
            return FusedLayerNormAffineFunction.apply(
                input, self.weight, self.bias, self.normalized_shape, self.eps)
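The forward-time sync averages the affine parameters over the TP group before they are used, so any drift between the replicated copies is washed out on every call. A hedged sketch of the same pattern on a plain nn.LayerNorm subclass (illustrative class name and constructor; the real change lives in the fused layer norm's forward shown above):

```python
# Hedged sketch: the auto-sync pattern applied to a plain nn.LayerNorm subclass,
# not the Meg-DS fused layer norm itself.
import torch
import torch.distributed as dist
from torch import nn

class TPSyncedLayerNorm(nn.LayerNorm):
    def __init__(self, normalized_shape, eps=1e-5, tp_group=None, auto_sync=False):
        super().__init__(normalized_shape, eps=eps)
        self.tp_group = tp_group      # stands in for mpu.get_tensor_model_parallel_group()
        self.auto_sync = auto_sync    # mirrors --layernorm-tp-auto-sync

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.auto_sync and dist.is_initialized():
            # In-place average over the TP group: every replica ends up with
            # identical weight/bias even if the copies had drifted apart.
            dist.all_reduce(self.weight.data, op=dist.ReduceOp.AVG, group=self.tp_group)
            dist.all_reduce(self.bias.data, op=dist.ReduceOp.AVG, group=self.tp_group)
        return super().forward(x)
```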

tests/test_training.py

Lines changed: 1 addition & 0 deletions
@@ -147,6 +147,7 @@ def get_variation_config(self, variation, output_dir, n_samples=None):
         --clip-grad 1.0
         --weight-decay 1e-1
         --embed-layernorm
+        --layernorm-tp-auto-sync
         --fp16
 
         --log-level debug
