From 54188734ff338ba5e30b2e4fd51c9dd623782678 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sun, 10 Sep 2023 08:04:54 -0700
Subject: [PATCH] another micro optimization for communication

---
 setup.py                         | 2 +-
 st_moe_pytorch/distributed.py    | 7 ++++---
 st_moe_pytorch/st_moe_pytorch.py | 6 +++---
 3 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/setup.py b/setup.py
index 44cdf40..85918c2 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'st-moe-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.24',
+  version = '0.0.25',
   license='MIT',
   description = 'ST - Mixture of Experts - Pytorch',
   author = 'Phil Wang',
diff --git a/st_moe_pytorch/distributed.py b/st_moe_pytorch/distributed.py
index 3b55d51..7aa6b11 100644
--- a/st_moe_pytorch/distributed.py
+++ b/st_moe_pytorch/distributed.py
@@ -96,11 +96,12 @@ def forward(ctx, x):
         else:
             sizes = (x.shape[1],) * x.shape[0]
 
-        ctx.sizes = torch.tensor(sizes, device = out.device, dtype = torch.long)
-        return out
+        sizes = torch.tensor(sizes, device = out.device, dtype = torch.long)
+        ctx.sizes = sizes
+        return out, sizes
 
     @staticmethod
-    def backward(ctx, grads):
+    def backward(ctx, grads, _):
         grads = rearrange(grads, '... -> 1 ...')
         grads = all_gather_variable_dim(grads, sizes = ctx.sizes)
         return grads
diff --git a/st_moe_pytorch/st_moe_pytorch.py b/st_moe_pytorch/st_moe_pytorch.py
index d0f4cfa..7cbbbe3 100644
--- a/st_moe_pytorch/st_moe_pytorch.py
+++ b/st_moe_pytorch/st_moe_pytorch.py
@@ -225,7 +225,7 @@ def forward(
         if is_distributed:
             x, expert_batch_packed_shape = pack_one(x, '* n d')
             x = rearrange(x, '(r eb) n d -> r eb n d', r = world_size)
-            x = split_by_rank(x)
+            x, experts_per_rank_sizes = split_by_rank(x)
             x = rearrange(x, '(e b) n d -> e b n d', e = num_experts_per_rank)
 
         # get the experts in use
@@ -248,14 +248,14 @@
 
         if is_distributed:
             outs = rearrange(outs, 'e b n d -> (e b) n d')
-            outs, _ = self.all_gather(outs)
+            outs, _ = self.all_gather(outs, sizes = experts_per_rank_sizes)
             outs = unpack_one(outs, expert_batch_packed_shape, '* n d')
 
         outs = rearrange(outs, 'e b n d -> b e n d')
 
         if is_distributed:
             outs = outs.split(batch_sizes.tolist())
-            outs = split_by_rank(outs)
+            outs, _ = split_by_rank(outs)
 
         assert outs.shape == shape
         return outs
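
Note (not part of the patch): the sketch below is a minimal, self-contained illustration of the communication pattern this commit optimizes. Because the rank split already sees every rank's chunk, it already knows all per-rank sizes locally; returning those sizes lets the later variable-sized all-gather skip the extra round of communication that would otherwise be needed just to exchange sizes. The helper names here (gather_variable_dim, split_by_rank) and the shapes are hypothetical stand-ins, not the library's actual implementations, and it assumes a gloo process group launched with something like `torchrun --nproc_per_node=2 sketch.py`.

    # sketch.py - illustrative only, not st-moe-pytorch's implementation
    import torch
    import torch.distributed as dist

    def gather_variable_dim(t, sizes = None):
        # all-gather tensors whose first dimension differs across ranks.
        # without `sizes`, an extra all_gather is needed just to learn the
        # per-rank lengths - that is the round trip the patch removes
        world_size = dist.get_world_size()

        if sizes is None:
            size = torch.tensor([t.shape[0]])
            all_sizes = [torch.empty_like(size) for _ in range(world_size)]
            dist.all_gather(all_sizes, size)   # extra communication
            sizes = torch.cat(all_sizes)

        # pad every rank's tensor to the global max so shapes match
        max_size = int(sizes.max())
        padded = torch.zeros(max_size, *t.shape[1:], dtype = t.dtype)
        padded[:t.shape[0]] = t

        gathered = [torch.empty_like(padded) for _ in range(world_size)]
        dist.all_gather(gathered, padded)

        # strip the padding back off using the known sizes
        return torch.cat([g[:s] for g, s in zip(gathered, sizes.tolist())])

    def split_by_rank(chunks):
        # every rank holds the full list of chunks, so it already knows all
        # sizes locally - return them alongside this rank's slice, mirroring
        # the patched forward that now returns (out, sizes)
        sizes = torch.tensor([c.shape[0] for c in chunks])
        return chunks[dist.get_rank()], sizes

    if __name__ == '__main__':
        dist.init_process_group('gloo')
        world_size = dist.get_world_size()

        # hypothetical uneven per-rank workloads
        chunks = [torch.randn(rank + 1, 4) for rank in range(world_size)]

        local, sizes = split_by_rank(chunks)

        # reuse the sizes from the split - no size exchange before the gather
        out = gather_variable_dim(local, sizes = sizes)

        assert out.shape[0] == int(sizes.sum())
        dist.destroy_process_group()

In the patch itself, the same reuse shows up as `split_by_rank` returning `(out, sizes)` and `self.all_gather(outs, sizes = experts_per_rank_sizes)` accepting them; the extra `_` argument in `backward(ctx, grads, _)` is there because autograd now also passes a gradient for the second returned output (the sizes tensor), which is simply ignored.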