Skip to content

Commit

Permalink
chip away at edge cases
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 10, 2023
1 parent 5418873 commit 83d75b8
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 13 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ out, balance_loss, router_z_loss = moe_block(inputs) # (4, 1024, 512), (1,), (1,
- [x] make first naive non-optimized attempt at distributed code for mixture of experts

- [ ] distributed
- [x] handle any world size less than number of experts
- [ ] handle any world size greater than number of experts
- [ ] optimize
- [ ] figure out what is faster, all gather, or broadcast with async followed by barrier
- [ ] make all distributed code pluggable, for different strategies
- [ ] figure out why there is tiny error in gradients
- [ ] support variable batch and sequence lengths
Expand Down
2 changes: 1 addition & 1 deletion assert.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def start(
cleanup()

if __name__ == '__main__':
world_size = 4
world_size = 5
num_experts = 8
batch_size = 2
batch_size_var_len = False
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'st-moe-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.25',
version = '0.0.27',
license='MIT',
description = 'ST - Mixture of Experts - Pytorch',
author = 'Phil Wang',
Expand Down
45 changes: 34 additions & 11 deletions st_moe_pytorch/st_moe_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ def default(val, default):
def divisible_by(num, den):
    """Return True when `num` is an exact multiple of `den`."""
    remainder = num % den
    return remainder == 0

def chunk_num(num, chunks):
    """Split the integer `num` into `chunks` near-equal integer parts.

    The parts always sum to `num` and differ by at most 1: the first
    `num % chunks` parts each receive one extra unit.
    e.g. chunk_num(8, 3) -> [3, 3, 2]

    Used to assign a (possibly uneven) number of experts to each rank
    when the world size does not divide the expert count.
    """
    num_per_chunk, remainder = divmod(num, chunks)

    # first `remainder` chunks absorb one leftover unit apiece
    return [num_per_chunk + int(i < remainder) for i in range(chunks)]

def pack_one(t, pattern):
    # wrap the single tensor in a list so the `pack` helper
    # (imported elsewhere in this file, presumably einops — TODO confirm)
    # can handle it; callers unpack the (packed, shape) result
    single = [t]
    return pack(single, pattern)

Expand Down Expand Up @@ -205,26 +215,39 @@ def forward(
rank = 0

# the experts in use on the rank
# for now, make sure number of machines is right multiple

if world_size <= num_experts:
assert divisible_by(num_experts, world_size), 'if number of machines is less than the number of experts, the number of experts must be divisible by number of machines'
num_experts_per_rank = num_experts // world_size
expert_start_index = rank * num_experts_per_rank
if is_distributed:
if world_size <= num_experts:
num_experts_across_ranks = chunk_num(num_experts, world_size)
start_indices = cumsum_exclusive(torch.tensor(num_experts_across_ranks), dim = -1)
num_experts_per_rank = num_experts_across_ranks[rank]
expert_start_index = start_indices[rank].item()
else:
# for now, make sure number of machines is right multiple

assert divisible_by(world_size, num_experts), 'if number of machines is greater than number of experts, machines must be divisible by number of experts, so experts are evenly distributed'
num_experts_per_rank = 1
expert_start_index = rank // num_experts

expert_slice = slice(expert_start_index, expert_start_index + num_experts_per_rank)
else:
assert divisible_by(world_size, num_experts), 'if number of machines is greater than number of experts, machines must be divisible by number of experts, so experts are evenly distributed'
num_experts_per_rank = 1
expert_start_index = rank // num_experts

expert_slice = slice(expert_start_index, expert_start_index + num_experts_per_rank)
num_experts_per_rank = num_experts
expert_slice = slice(0, num_experts)

# if distributed, each machine only handles subset of experts and batch

x = rearrange(x, 'b e n d -> e b n d')

if is_distributed:
x, expert_batch_packed_shape = pack_one(x, '* n d')
x = rearrange(x, '(r eb) n d -> r eb n d', r = world_size)

if world_size <= num_experts:
x = rearrange(x, '(e b) n d -> e b n d', e = num_experts)
x = x.split(num_experts_across_ranks, dim = 0)
x = tuple(rearrange(t, 'e b n d -> (e b) n d') for t in x)
else:
x = rearrange(x, '(r eb) n d -> r eb n d', r = world_size)

x, experts_per_rank_sizes = split_by_rank(x)
x = rearrange(x, '(e b) n d -> e b n d', e = num_experts_per_rank)

Expand Down

0 comments on commit 83d75b8

Please sign in to comment.