
Commit

another micro optimization for communication
lucidrains committed Sep 10, 2023
1 parent 666d2fd commit 5418873
Showing 3 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'st-moe-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.24',
+  version = '0.0.25',
   license='MIT',
   description = 'ST - Mixture of Experts - Pytorch',
   author = 'Phil Wang',
7 changes: 4 additions & 3 deletions st_moe_pytorch/distributed.py
@@ -96,11 +96,12 @@ def forward(ctx, x):
         else:
             sizes = (x.shape[1],) * x.shape[0]

-        ctx.sizes = torch.tensor(sizes, device = out.device, dtype = torch.long)
-        return out
+        sizes = torch.tensor(sizes, device = out.device, dtype = torch.long)
+        ctx.sizes = sizes
+        return out, sizes

     @staticmethod
-    def backward(ctx, grads):
+    def backward(ctx, grads, _):
         grads = rearrange(grads, '... -> 1 ...')
         grads = all_gather_variable_dim(grads, sizes = ctx.sizes)
         return grads
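
For context: the variable-size gather this backward calls is not shown in the hunk. The sketch below is a hypothetical reconstruction of such a helper (not the repository's exact all_gather_variable_dim), included only to show why returning sizes from the forward pays off. When the caller can supply the per-rank sizes, the gather can skip the extra collective it would otherwise need just to learn them.

# rough sketch of a variable-size all-gather, assuming torch.distributed is
# already initialized; hypothetical, not the repository's exact implementation

import torch
import torch.nn.functional as F
import torch.distributed as dist

def all_gather_variable_dim(t, dim = 0, sizes = None):
    world_size = dist.get_world_size()

    # round 1 (the round this commit lets callers skip): a collective whose
    # only purpose is to learn how many elements every rank holds along `dim`
    if sizes is None:
        size = torch.tensor([t.shape[dim]], device = t.device, dtype = torch.long)
        all_sizes = [torch.empty_like(size) for _ in range(world_size)]
        dist.all_gather(all_sizes, size)
        sizes = torch.cat(all_sizes)

    # round 2: pad every rank's tensor to the largest size, gather, trim the
    # padding back off, and hand the sizes back so they can be reused later
    max_size = int(sizes.amax().item())
    pad = [0, 0] * (t.ndim - dim - 1) + [0, max_size - t.shape[dim]]
    padded = F.pad(t, pad)

    gathered = [torch.empty_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded)

    pieces = [g.narrow(dim, 0, int(s)) for g, s in zip(gathered, sizes)]
    return torch.cat(pieces, dim = dim), sizes

Under that assumption, having the split's forward return the sizes it already computed means the later gather performs only the one unavoidable collective on the padded tensors, which is the communication saving the commit title refers to.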
6 changes: 3 additions & 3 deletions st_moe_pytorch/st_moe_pytorch.py
@@ -225,7 +225,7 @@ def forward(
         if is_distributed:
             x, expert_batch_packed_shape = pack_one(x, '* n d')
             x = rearrange(x, '(r eb) n d -> r eb n d', r = world_size)
-            x = split_by_rank(x)
+            x, experts_per_rank_sizes = split_by_rank(x)
             x = rearrange(x, '(e b) n d -> e b n d', e = num_experts_per_rank)

         # get the experts in use
@@ -248,14 +248,14 @@

         if is_distributed:
             outs = rearrange(outs, 'e b n d -> (e b) n d')
-            outs, _ = self.all_gather(outs)
+            outs, _ = self.all_gather(outs, sizes = experts_per_rank_sizes)
             outs = unpack_one(outs, expert_batch_packed_shape, '* n d')

         outs = rearrange(outs, 'e b n d -> b e n d')

         if is_distributed:
             outs = outs.split(batch_sizes.tolist())
-            outs = split_by_rank(outs)
+            outs, _ = split_by_rank(outs)

         assert outs.shape == shape
         return outs
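
The caller-side pattern above can be illustrated with a small toy (hypothetical helpers, not the repository's API): the split hands each rank its chunk together with every rank's chunk size, and those sizes are what later get passed into the gather as sizes = experts_per_rank_sizes.

# toy illustration of the split-then-gather round trip; assumes
# torch.distributed is initialized and reuses the all_gather_variable_dim
# sketch above

import torch
import torch.distributed as dist

def split_by_rank(x):
    # every rank derives the full list of chunk sizes locally, so no extra
    # communication is needed to recover them at gather time
    rank, world_size = dist.get_rank(), dist.get_world_size()
    pieces = torch.tensor_split(x, world_size, dim = 0)
    sizes = torch.tensor([p.shape[0] for p in pieces], dtype = torch.long, device = x.device)
    return pieces[rank], sizes

# mirrors the diff above:
#   x, experts_per_rank_sizes = split_by_rank(x)   # sizes come along for free
#   ... run this rank's experts on x ...
#   outs, _ = all_gather_variable_dim(outs, sizes = experts_per_rank_sizes)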
