diff --git a/README.md b/README.md
index fe0338b..60465eb 100644
--- a/README.md
+++ b/README.md
@@ -70,7 +70,7 @@ out, balance_loss, router_z_loss = moe_block(inputs) # (4, 1024, 512), (1,), (1,)
 - [ ] distributed
     - [x] handle any world size less than number of experts
-    - [ ] handle any world size greater than number of experts
+    - [x] handle any world size greater than number of experts
         - for now, just have remainder machines do nothing
 - [ ] optimize
     - [ ] figure out what is faster, all gather, or broadcast with async followed by barrier
     - [ ] make all distributed code pluggable, for different strategies
diff --git a/assert.py b/assert.py
index a9eafa1..1699937 100644
--- a/assert.py
+++ b/assert.py
@@ -76,8 +76,8 @@ def start(
     cleanup()
 
 if __name__ == '__main__':
-    world_size = 8
-    num_experts = 8
+    world_size = 9
+    num_experts = 4
     batch_size = 2
     batch_size_var_len = False
 
diff --git a/setup.py b/setup.py
index ac593d7..51f8adc 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'st-moe-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.28',
+  version = '0.0.29',
   license='MIT',
   description = 'ST - Mixture of Experts - Pytorch',
   author = 'Phil Wang',
diff --git a/st_moe_pytorch/st_moe_pytorch.py b/st_moe_pytorch/st_moe_pytorch.py
index be21efc..06fd82e 100644
--- a/st_moe_pytorch/st_moe_pytorch.py
+++ b/st_moe_pytorch/st_moe_pytorch.py
@@ -207,6 +207,7 @@ def forward(
             assert has_only_one_value(seq_sizes), 'number of tokens per expert must be the same'
 
             x, batch_sizes = self.all_gather(x)
+            total_batch_size = batch_sizes.sum().item()
 
             world_size = dist.get_world_size()
             rank = dist.get_rank()
@@ -220,14 +221,28 @@ def forward(
             if world_size <= num_experts:
                 num_experts_across_ranks = chunk_num(num_experts, world_size)
                 start_indices = cumsum_exclusive(torch.tensor(num_experts_across_ranks), dim = -1)
 
+                num_experts_per_rank = num_experts_across_ranks[rank]
+                num_experts_batches_across_ranks = tuple(i * total_batch_size for i in num_experts_across_ranks)
 
+                expert_start_index = start_indices[rank].item()
             else:
-                # for now, make sure number of machines is right multiple
+                num_batch_chunks = world_size // num_experts
+                total_ranks_in_use = num_batch_chunks * num_experts
+
+                expert_start_index = rank // num_batch_chunks
+
+                batch_splits = chunk_num(total_batch_size, num_batch_chunks)
+                num_experts_batches_across_ranks = batch_splits * num_experts
+
+                # for now, remaining machines just process nothing
+
+                remain_ranks = world_size % num_experts
+                num_experts_batches_across_ranks += (0,) * remain_ranks
+
+                num_experts_per_rank = int(rank < total_ranks_in_use)
 
-                assert divisible_by(world_size, num_experts), 'if number of machines is greater than number of experts, machines must be divisible by number of experts, so experts are evenly distributed'
-                num_experts_per_rank = 1
-                expert_start_index = rank // (world_size // num_experts)
+            assert len(num_experts_batches_across_ranks) == world_size
 
             expert_slice = slice(expert_start_index, expert_start_index + num_experts_per_rank)
         else:
@@ -241,15 +256,13 @@ def forward(
         if is_distributed:
             x, expert_batch_packed_shape = pack_one(x, '* n d')
 
-            if world_size <= num_experts:
-                x = rearrange(x, '(e b) n d -> e b n d', e = num_experts)
-                x = x.split(num_experts_across_ranks, dim = 0)
-                x = tuple(rearrange(t, 'e b n d -> (e b) n d') for t in x)
-            else:
-                x = rearrange(x, '(r eb) n d -> r eb n d', r = world_size)
-
+            x = x.split(num_experts_batches_across_ranks, dim = 0)
             x, experts_per_rank_sizes = split_by_rank(x)
-            x = rearrange(x, '(e b) n d -> e b n d', e = num_experts_per_rank)
+
+            if num_experts_per_rank > 0:
+                x = rearrange(x, '(e b) n d -> e b n d', e = num_experts_per_rank)
+            else:
+                x = x.reshape(num_experts, *x.shape)
 
         # get the experts in use
 
@@ -260,11 +273,15 @@ def forward(
         # route tokens to appropriate experts
 
         outs = []
+
        for expert, expert_input in zip(experts, x):
             out = expert(expert_input)
             outs.append(out)
 
-        outs = torch.stack(outs)
+        if len(outs) > 0:
+            outs = torch.stack(outs)
+        else:
+            outs = torch.empty_like(x).requires_grad_()
 
         # all gather across merged expert batches dimensions
         # then split the batch dimension back
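
Note on the new rank assignment (a reviewer sketch, not part of the commit): the arithmetic in the `@@ -220,14 +221,28 @@` hunk is easier to audit in isolation. Below is a hypothetical, torch-free reimplementation; `chunk_num` is an assumed stand-in for the repository helper (split an integer into near-equal parts), and `describe_rank` is an invented name. Numbers mirror the new `assert.py` config: `world_size = 9`, `num_experts = 4`, so `total_batch_size = 18` at `batch_size = 2` per rank.

```python
def chunk_num(num, chunks):
    # assumed behavior of the repo helper: split `num` into `chunks`
    # near-equal integer parts, e.g. chunk_num(9, 4) -> (3, 2, 2, 2)
    base, remainder = divmod(num, chunks)
    return tuple(base + int(i < remainder) for i in range(chunks))

def describe_rank(rank, world_size, num_experts, total_batch_size):
    # returns (expert_start_index, num_experts_per_rank, rows this rank
    # receives when the gathered (expert * batch) tensor is split across ranks)
    if world_size <= num_experts:
        # every rank hosts one or more whole experts and the full gathered batch
        experts_across_ranks = chunk_num(num_experts, world_size)
        rows_across_ranks = tuple(e * total_batch_size for e in experts_across_ranks)
        expert_start_index = sum(experts_across_ranks[:rank])
        num_experts_per_rank = experts_across_ranks[rank]
    else:
        # more ranks than experts: each expert is replicated across
        # `num_batch_chunks` ranks, each replica taking a slice of the batch
        num_batch_chunks = world_size // num_experts
        total_ranks_in_use = num_batch_chunks * num_experts

        batch_splits = chunk_num(total_batch_size, num_batch_chunks)
        rows_across_ranks = batch_splits * num_experts             # tuple repetition
        rows_across_ranks += (0,) * (world_size % num_experts)     # remainder ranks get nothing

        expert_start_index = rank // num_batch_chunks
        num_experts_per_rank = int(rank < total_ranks_in_use)

    assert len(rows_across_ranks) == world_size
    return expert_start_index, num_experts_per_rank, rows_across_ranks[rank]

for rank in range(9):
    print(rank, describe_rank(rank, world_size = 9, num_experts = 4, total_batch_size = 18))

# ranks 0-1 serve expert 0 on 9 batch rows each, ranks 2-3 expert 1, and so on;
# rank 8 prints (4, 0, 0): an empty expert slice and zero rows, i.e. it does nothing
```

The design choice is to let remainder ranks idle instead of keeping the old `divisible_by(world_size, num_experts)` assert, so any world size now works, at the cost of leaving `world_size % num_experts` machines unused.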
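
The zero-expert path in the last two hunks is subtle: an idle rank still flows a tensor through the expert loop (which iterates zero times, since its `experts` slice is empty) and into the stack fallback. A minimal sketch of what such a rank ends up holding, assuming it received a zero-row split; shapes here are invented for illustration:

```python
import torch

num_experts, seq_len, dim = 4, 16, 512

# an idle rank's split of the gathered tensor has zero expert-batch rows
x = torch.randn(0, seq_len, dim)

# mirrors `x = x.reshape(num_experts, *x.shape)`: valid because both sides
# have zero elements, yielding shape (4, 0, 16, 512)
x = x.reshape(num_experts, *x.shape)

# mirrors the new fallback: a same-shaped empty tensor, grad-enabled
outs = torch.empty_like(x).requires_grad_()
print(outs.shape)  # torch.Size([4, 0, 16, 512])
```

Producing a correctly shaped, grad-enabled empty result rather than returning early presumably keeps idle ranks participating in the all-gather that follows, since collective ops must be entered by every rank in the group.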