diff --git a/README.md b/README.md
index 4eec00a..fe0338b 100644
--- a/README.md
+++ b/README.md
@@ -69,7 +69,10 @@ out, balance_loss, router_z_loss = moe_block(inputs) # (4, 1024, 512), (1,), (1,
 - [x] make first naive non-optimized attempt at distributed code for mixture of experts
 - [ ] distributed
+    - [x] handle any world size less than number of experts
+    - [ ] handle any world size greater than number of experts
 - [ ] optimize
+    - [ ] figure out what is faster, all gather, or broadcast with async followed by barrier
 - [ ] make all distributed code pluggable, for different strategies
 - [ ] figure out why there is tiny error in gradients
 - [ ] support variable batch and sequence lengths
diff --git a/assert.py b/assert.py
index 6491b6d..18c9b71 100644
--- a/assert.py
+++ b/assert.py
@@ -76,7 +76,7 @@ def start(
     cleanup()

 if __name__ == '__main__':
-    world_size = 4
+    world_size = 5
     num_experts = 8
     batch_size = 2
     batch_size_var_len = False
diff --git a/setup.py b/setup.py
index 85918c2..e9e3692 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ setup(
   name = 'st-moe-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.25',
+  version = '0.0.27',
   license='MIT',
   description = 'ST - Mixture of Experts - Pytorch',
   author = 'Phil Wang',
diff --git a/st_moe_pytorch/st_moe_pytorch.py b/st_moe_pytorch/st_moe_pytorch.py
index 7cbbbe3..df4ecb1 100644
--- a/st_moe_pytorch/st_moe_pytorch.py
+++ b/st_moe_pytorch/st_moe_pytorch.py
@@ -46,6 +46,16 @@ def default(val, default):
 def divisible_by(num, den):
     return (num % den) == 0

+def chunk_num(num, chunks):
+    num_per_chunk, remainder = divmod(num, chunks)
+
+    out = []
+    for i in range(chunks):
+        n = num_per_chunk
+        out.append(n + int(i < remainder))
+
+    return out
+
 def pack_one(t, pattern):
     return pack([t], pattern)

@@ -205,18 +215,24 @@ def forward(
             rank = 0

         # the experts in use on the rank

-        # for now, make sure number of machines is right multiple
-        if world_size <= num_experts:
-            assert divisible_by(num_experts, world_size), 'if number of machines is less than the number of experts, the number of experts must be divisible by number of machines'
-            num_experts_per_rank = num_experts // world_size
-            expert_start_index = rank * num_experts_per_rank
+        if is_distributed:
+            if world_size <= num_experts:
+                num_experts_across_ranks = chunk_num(num_experts, world_size)
+                start_indices = cumsum_exclusive(torch.tensor(num_experts_across_ranks), dim = -1)
+                num_experts_per_rank = num_experts_across_ranks[rank]
+                expert_start_index = start_indices[rank].item()
+            else:
+                # for now, make sure number of machines is right multiple
+
+                assert divisible_by(world_size, num_experts), 'if number of machines is greater than number of experts, machines must be divisible by number of experts, so experts are evenly distributed'
+                num_experts_per_rank = 1
+                expert_start_index = rank // num_experts
+
+            expert_slice = slice(expert_start_index, expert_start_index + num_experts_per_rank)
         else:
-            assert divisible_by(world_size, num_experts), 'if number of machines is greater than number of experts, machines must be divisible by number of experts, so experts are evenly distributed'
-            num_experts_per_rank = 1
-            expert_start_index = rank // num_experts
-
-        expert_slice = slice(expert_start_index, expert_start_index + num_experts_per_rank)
+            num_experts_per_rank = num_experts
+            expert_slice = slice(0, num_experts)

         # if distributed, each machine only handles subset of experts and batch

@@ -224,7 +240,14 @@ if is_distributed:
             x, expert_batch_packed_shape = pack_one(x, '* n d')
-            x = rearrange(x, '(r eb) n d -> r eb n d', r = world_size)
+
+            if world_size <= num_experts:
+                x = rearrange(x, '(e b) n d -> e b n d', e = num_experts)
+                x = x.split(num_experts_across_ranks, dim = 0)
+                x = tuple(rearrange(t, 'e b n d -> (e b) n d') for t in x)
+            else:
+                x = rearrange(x, '(r eb) n d -> r eb n d', r = world_size)
+
             x, experts_per_rank_sizes = split_by_rank(x)

             x = rearrange(x, '(e b) n d -> e b n d', e = num_experts_per_rank)
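
For readers following the second file change: a small, self-contained sketch of how the new chunk_num helper, combined with an exclusive cumsum of its output, assigns experts to ranks when world_size <= num_experts. Note that exclusive_cumsum below is a plain-Python stand-in for the library's cumsum_exclusive (whose implementation is not part of this diff), and the numbers use the world_size = 5, num_experts = 8 setting from the updated assert.py.

```python
# Illustration only: mapping experts to ranks for world_size <= num_experts.
# exclusive_cumsum is an assumed stand-in for the package's cumsum_exclusive helper.

def chunk_num(num, chunks):
    # same splitting rule as the helper added in this diff:
    # front-load the remainder so earlier ranks get at most one extra expert
    num_per_chunk, remainder = divmod(num, chunks)
    return [num_per_chunk + int(i < remainder) for i in range(chunks)]

def exclusive_cumsum(sizes):
    # running total shifted right by one, giving each rank's expert_start_index
    out, total = [], 0
    for s in sizes:
        out.append(total)
        total += s
    return out

num_experts, world_size = 8, 5  # matches the updated assert.py settings

num_experts_across_ranks = chunk_num(num_experts, world_size)   # [2, 2, 2, 1, 1]
start_indices = exclusive_cumsum(num_experts_across_ranks)      # [0, 2, 4, 6, 7]

for rank, (start, n) in enumerate(zip(start_indices, num_experts_across_ranks)):
    print(f'rank {rank}: experts {list(range(start, start + n))}')

# rank 0: experts [0, 1]
# rank 1: experts [2, 3]
# rank 2: experts [4, 5]
# rank 3: experts [6]
# rank 4: experts [7]
```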
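Similarly, for the last hunk: a sketch of the uneven split of the expert-sharded batch, run outside any torch.distributed setup and with made-up shapes (b, n, d below are arbitrary). The sizes mirror chunk_num(8, 5); split_by_rank itself is not reproduced here.

```python
# Illustration only: splitting the packed (e b) n d tensor into one chunk per rank
# when the experts are spread unevenly across ranks.

import torch
from einops import rearrange

num_experts, world_size = 8, 5
num_experts_across_ranks = [2, 2, 2, 1, 1]  # chunk_num(8, 5)

b, n, d = 2, 4, 16                          # example batch per expert, seq len, dim
x = torch.randn(num_experts * b, n, d)      # packed as (e b) n d, as in the forward pass

x = rearrange(x, '(e b) n d -> e b n d', e = num_experts)
chunks = x.split(num_experts_across_ranks, dim = 0)                   # one chunk per rank
chunks = tuple(rearrange(t, 'e b n d -> (e b) n d') for t in chunks)  # re-pack each chunk

for rank, t in enumerate(chunks):
    print(rank, tuple(t.shape))

# ranks 0-2 each receive two experts' worth: (4, 4, 16)
# ranks 3-4 each receive one expert's worth: (2, 4, 16)
# in the diff above, split_by_rank then hands each rank its own chunk (and the sizes)
```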