start journeying into distributed mixture of experts implementation
lucidrains committed Sep 9, 2023
1 parent 1f63c54 commit 085d511
Showing 5 changed files with 347 additions and 18 deletions.
9 changes: 8 additions & 1 deletion README.md
@@ -64,8 +64,15 @@ out, balance_loss, router_z_loss = moe_block(inputs) # (4, 1024, 512), (1,), (1,)
- [x] figure out if there was an error in <a href="https://github.com/lucidrains/mixture-of-experts/blob/master/mixture_of_experts/mixture_of_experts.py#L210">a previous transcription</a> - no there was not an error
- [x] allow for different thresholds for second vs third routed expert
- [x] add coordinate descent based routing
- [x] make first naive non-optimized attempt at distributed code for mixture of experts

- [ ] take care of scatter gather, and once done, port over to <a href="https://github.com/lucidrains/soft-moe-pytorch">soft moe</a>
- [ ] distributed
- [ ] optimize
- [ ] make all distributed code pluggable, for different strategies
- [ ] figure out why there is a tiny error in gradients
- [ ] support variable batch and sequence lengths

- [ ] improvise a `Top2GatingWithCoordinateDescent` for `MoE` without `importance`

## Citations

99 changes: 99 additions & 0 deletions assert.py
@@ -0,0 +1,99 @@
import os
from copy import deepcopy

import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from st_moe_pytorch.st_moe_pytorch import Experts, Expert
from st_moe_pytorch.distributed import all_gather_variable_dim

def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group("gloo", rank = rank, world_size = world_size)

def cleanup():
dist.destroy_process_group()

def start(
rank,
world_size,
batch_size,
batch_size_var_len,
num_experts,
tokens_per_expert,
dim,
):
setup(rank, world_size)

net = Experts([Expert(dim) for _ in range(num_experts)])

if batch_size_var_len:
batch_size = batch_size + rank

seq = torch.randn(batch_size, num_experts, tokens_per_expert, dim)

# distributed

model = DDP(net)
out = model(seq)
out.mean().backward()

ddp_all_out, _ = all_gather_variable_dim(out)

# on single device

all_inputs, _ = all_gather_variable_dim(seq)
copied_net = deepcopy(net)

single_out = copied_net(
all_inputs,
is_distributed = False
)

single_out.mean().backward()

if rank == 0:
# validate output is the same
# if done on 1 vs multiple machines

assert torch.allclose(single_out, ddp_all_out), 'output is not the same'

# validate backwards and grad

get_first_expert_grad = lambda t: t.experts[0].net[0].weight.grad

assert torch.allclose(
get_first_expert_grad(net),
get_first_expert_grad(copied_net),
atol = 1e-2
), 'grad is not the same'

print('✅')

cleanup()

if __name__ == '__main__':
world_size = 4
num_experts = 8
batch_size = 2
batch_size_var_len = False

seq_len = 32
dim = 8

mp.spawn(
start,
args = (
world_size,
batch_size,
batch_size_var_len,
num_experts,
seq_len,
dim
),
nprocs = world_size,
join = True
)
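For reference, the single-device path of this test can be exercised on its own, without spawning processes. Below is a minimal sketch (my own illustration, assuming `Experts` accepts the `is_distributed` flag as used above and returns a tensor of the same shape as its input):

import torch
from st_moe_pytorch.st_moe_pytorch import Experts, Expert

num_experts, tokens_per_expert, dim = 8, 32, 8

# 8 rows = world_size * batch_size from the defaults above, i.e. the gathered batch
experts = Experts([Expert(dim) for _ in range(num_experts)])
inputs = torch.randn(8, num_experts, tokens_per_expert, dim)

out = experts(inputs, is_distributed = False)   # expected shape (8, 8, 32, 8)
out.mean().backward()                           # grads accumulate on each expert's weights

This is exactly the reference computation that rank 0 compares against the DDP output gathered from all workers.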
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'st-moe-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.22',
version = '0.0.23',
license='MIT',
description = 'ST - Mixture of Experts - Pytorch',
author = 'Phil Wang',
100 changes: 100 additions & 0 deletions st_moe_pytorch/distributed.py
@@ -0,0 +1,100 @@
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Function

import torch.distributed as dist

from einops import rearrange, pack, unpack

def exists(val):
return val is not None

def default(val, d):
return val if exists(val) else d

def divisible_by(num, den):
return (num % den) == 0

def pad_dim_to(t, length, dim = 0):
pad_length = length - t.shape[dim]
zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))

def all_gather_same_dim(t):
world_size = dist.get_world_size()
gathered_tensors = [torch.empty_like(t, device = t.device, dtype = t.dtype) for i in range(world_size)]
dist.all_gather(gathered_tensors, t)
return gathered_tensors

def gather_sizes(t, *, dim):
size = torch.tensor(t.shape[dim], device = t.device, dtype = torch.long)
sizes = all_gather_same_dim(size)
return torch.stack(sizes)

def has_only_one_value(t):
return (t == t[0]).all()

def all_gather_variable_dim(t, dim = 0, sizes = None):
device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()

if not exists(sizes):
sizes = gather_sizes(t, dim = dim)

if has_only_one_value(sizes):
gathered_tensors = all_gather_same_dim(t)
gathered_tensors = torch.cat(gathered_tensors, dim = dim)
return gathered_tensors, sizes

max_size = sizes.amax().item()

padded_t = pad_dim_to(t, max_size, dim = dim)
gathered_tensors = all_gather_same_dim(padded_t)

gathered_tensors = torch.cat(gathered_tensors, dim = dim)
seq = torch.arange(max_size, device = device)

mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
mask = rearrange(mask, 'i j -> (i j)')
seq = torch.arange(mask.shape[-1], device = device)
indices = seq[mask]

gathered_tensors = gathered_tensors.index_select(dim, indices)

return gathered_tensors, sizes

class AllGatherFunction(Function):
@staticmethod
def forward(ctx, x, dim, sizes):
x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
ctx.batch_sizes = batch_sizes.tolist()
ctx.dim = dim
return x, batch_sizes

@staticmethod
def backward(ctx, grads, _):
batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
return grads_by_rank[rank], None, None

class AllGather(nn.Module):
def __init__(self, *, dim = 0):
super().__init__()
self.dim = dim

def forward(self, x, sizes = None):
return AllGatherFunction.apply(x, self.dim, sizes)

class SplitByRank(Function):
@staticmethod
def forward(ctx, x):
rank = dist.get_rank()
return x[rank]

@staticmethod
def backward(ctx, grads):
grads = rearrange(grads, '... -> 1 ...')
grads = all_gather_variable_dim(grads)
return grads

split_by_rank = SplitByRank.apply
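
A minimal usage sketch for the helpers above (a hypothetical example of mine, run inside each worker after a process group has been initialized, e.g. via the gloo setup in assert.py):

import torch
import torch.distributed as dist

from st_moe_pytorch.distributed import AllGather, split_by_rank

rank, world_size = dist.get_rank(), dist.get_world_size()

# each rank contributes a different number of rows; the gather pads to the
# longest, concatenates across ranks, then strips the padding back out
local = torch.randn(2 + rank, 8)
gathered, sizes = AllGather(dim = 0)(local)     # (sum of per-rank rows, 8), plus per-rank row counts

# split_by_rank goes the other way: keep only this rank's slice of a tensor
# whose leading dimension is indexed by rank, all-gathering gradients on backward
stacked = torch.randn(world_size, 4, 8, requires_grad = True)
mine = split_by_rank(stacked)                   # (4, 8)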