start journeying into distributed mixture of experts implementation
1 parent 1f63c54 · commit 085d511 · showing 5 changed files with 347 additions and 18 deletions.
New file (99 lines added) — a multi-process test harness: it runs the Experts module under DistributedDataParallel and asserts that outputs and gradients match a single-device run over the gathered inputs.
import os
from copy import deepcopy

import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from st_moe_pytorch.st_moe_pytorch import Experts, Expert
from st_moe_pytorch.distributed import all_gather_variable_dim

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("gloo", rank = rank, world_size = world_size)

def cleanup():
    dist.destroy_process_group()

def start(
    rank,
    world_size,
    batch_size,
    batch_size_var_len,
    num_experts,
    tokens_per_expert,
    dim,
):
    setup(rank, world_size)

    net = Experts([Expert(dim) for _ in range(num_experts)])

    if batch_size_var_len:
        batch_size = batch_size + rank

    seq = torch.randn(batch_size, num_experts, tokens_per_expert, dim)

    # distributed

    model = DDP(net)
    out = model(seq)
    out.mean().backward()

    ddp_all_out, _ = all_gather_variable_dim(out)

    # on single device

    all_inputs, _ = all_gather_variable_dim(seq)
    copied_net = deepcopy(net)

    single_out = copied_net(
        all_inputs,
        is_distributed = False
    )

    single_out.mean().backward()

    if rank == 0:
        # validate output is the same
        # if done on 1 vs multiple machines

        assert torch.allclose(single_out, ddp_all_out), 'output is not the same'

        # validate backwards and grad

        get_first_expert_grad = lambda t: t.experts[0].net[0].weight.grad

        assert torch.allclose(
            get_first_expert_grad(net),
            get_first_expert_grad(copied_net),
            atol = 1e-2
        ), 'grad is not the same'

        print('✅')

    cleanup()

if __name__ == '__main__':
    world_size = 4
    num_experts = 8
    batch_size = 2
    batch_size_var_len = False

    seq_len = 32
    dim = 8

    mp.spawn(
        start,
        args = (
            world_size,
            batch_size,
            batch_size_var_len,
            num_experts,
            seq_len,
            dim
        ),
        nprocs = world_size,
        join = True
    )
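The harness above runs on CPU with the gloo backend. As a minimal sketch that is not part of the commit (it assumes one GPU per rank and otherwise reuses start() unchanged), the same check could target GPUs by switching the process group to nccl and moving the module and inputs onto the per-rank device before wrapping with DDP:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_gpu(rank, world_size):
    # same rendezvous as setup() above, but with the nccl backend and one GPU per rank
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('nccl', rank = rank, world_size = world_size)
    torch.cuda.set_device(rank)

# inside start(), before wrapping with DDP:
# net = net.cuda(rank)
# seq = seq.cuda(rank)
# model = DDP(net, device_ids = [rank])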
New file (100 lines added) — distributed helpers: a variable-length all-gather with autograd support (AllGatherFunction / AllGather) and a SplitByRank function that hands each rank its own slice.
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Function

import torch.distributed as dist

from einops import rearrange, pack, unpack

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def divisible_by(num, den):
    return (num % den) == 0

def pad_dim_to(t, length, dim = 0):
    # zero-pad tensor t at the end of dimension `dim` up to `length`
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))

def all_gather_same_dim(t):
    world_size = dist.get_world_size()
    gathered_tensors = [torch.empty_like(t, device = t.device, dtype = t.dtype) for i in range(world_size)]
    dist.all_gather(gathered_tensors, t)
    return gathered_tensors

def gather_sizes(t, *, dim):
    # collect the size of dimension `dim` from every rank into one long tensor
    size = torch.tensor(t.shape[dim], device = t.device, dtype = torch.long)
    sizes = all_gather_same_dim(size)
    return torch.stack(sizes)

def has_only_one_value(t):
    return (t == t[0]).all()

def all_gather_variable_dim(t, dim = 0, sizes = None):
    # all-gather tensors whose size may differ across ranks along `dim`,
    # by padding to the max size, gathering, then trimming away the padding
    device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()

    if not exists(sizes):
        sizes = gather_sizes(t, dim = dim)

    if has_only_one_value(sizes):
        gathered_tensors = all_gather_same_dim(t)
        gathered_tensors = torch.cat(gathered_tensors, dim = dim)
        return gathered_tensors, sizes

    max_size = sizes.amax().item()

    padded_t = pad_dim_to(t, max_size, dim = dim)
    gathered_tensors = all_gather_same_dim(padded_t)

    gathered_tensors = torch.cat(gathered_tensors, dim = dim)
    seq = torch.arange(max_size, device = device)

    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')
    seq = torch.arange(mask.shape[-1], device = device)
    indices = seq[mask]

    gathered_tensors = gathered_tensors.index_select(dim, indices)

    return gathered_tensors, sizes

class AllGatherFunction(Function):
    @staticmethod
    def forward(ctx, x, dim, sizes):
        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
        ctx.batch_sizes = batch_sizes.tolist()
        ctx.dim = dim
        return x, batch_sizes

    @staticmethod
    def backward(ctx, grads, _):
        # each rank only receives the gradient slice for the chunk it contributed
        batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
        return grads_by_rank[rank], None, None

class AllGather(nn.Module):
    def __init__(self, *, dim = 0):
        super().__init__()
        self.dim = dim

    def forward(self, x, sizes = None):
        return AllGatherFunction.apply(x, self.dim, sizes)

class SplitByRank(Function):
    @staticmethod
    def forward(ctx, x):
        # each rank keeps only its own slice along the leading dimension
        rank = dist.get_rank()
        return x[rank]

    @staticmethod
    def backward(ctx, grads):
        # regather the per-rank gradient slices to form the gradient for the full input
        grads = rearrange(grads, '... -> 1 ...')
        grads, _ = all_gather_variable_dim(grads)
        return grads

split_by_rank = SplitByRank.apply
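Finally, a usage sketch that is not part of the commit. It assumes the file above is importable as st_moe_pytorch.distributed (the test script's import suggests that path), and the demo entry point is made up for illustration: each rank contributes a different number of rows, AllGather stitches them together with autograd support, and on backward each rank receives only the gradient slice for the rows it contributed.

import os
import torch
import torch.multiprocessing as mp
import torch.distributed as dist

from st_moe_pytorch.distributed import AllGather, split_by_rank

def demo(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('gloo', rank = rank, world_size = world_size)

    # each rank holds a different number of rows along dim 0
    x = torch.randn(rank + 1, 8, requires_grad = True)

    gathered, sizes = AllGather(dim = 0)(x)
    assert sizes.tolist() == [r + 1 for r in range(world_size)]

    # backward hands each rank only the gradient for the rows it contributed
    gathered.sum().backward()
    assert x.grad.shape == x.shape

    # split_by_rank keeps just this rank's slice of a stacked tensor
    stacked = torch.arange(world_size, dtype = torch.float)
    assert split_by_rank(stacked).item() == float(rank)

    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = 4
    mp.spawn(demo, args = (world_size,), nprocs = world_size, join = True)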