handle variable sequence lengths if allow_var_seq_len = True on `Experts`
lucidrains committed Sep 11, 2023
1 parent 6982343 commit 2bb762d
Showing 4 changed files with 54 additions and 23 deletions.
README.md (4 changes: 3 additions & 1 deletion)
@@ -72,7 +72,9 @@ out, balance_loss, router_z_loss = moe_block(inputs) # (4, 1024, 512), (1,), (1,)
 - [x] handle any world size less than number of experts
 - [x] handle any world size greater than number of experts - for now, just have remainder machines do nothing
 - [x] support variable batch sizes
-- [ ] support variable seq lengths
+- [x] support variable seq lengths
+- [ ] figure out how to move assert.py to pytests
+- [ ] simplify the variable sequence length test code from another folder and move in so other researchers gain confidence
 - [ ] optimize
 - [ ] figure out what is faster, all gather, or broadcast with async followed by barrier
 - [ ] make all distributed code pluggable, for different strategies
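For context, the newly checked item corresponds to the `allow_var_seq_len` flag this commit threads through `Experts` and the MoE `__init__` (see the `st_moe_pytorch/st_moe_pytorch.py` diff below). A minimal usage sketch, reusing the `moe_block` call from the README context line above; the class name, import path, and constructor arguments other than the new `allow_var_seq_len` are assumptions rather than something this diff shows:

```python
import torch
from st_moe_pytorch import MoE  # assumed import path

# enable the new flag so that, when running distributed, each rank may feed a
# different sequence length and Experts will pad + all-gather internally
moe_block = MoE(
    dim = 512,
    num_experts = 4,
    allow_var_seq_len = True
)

inputs = torch.randn(4, 1024, 512)
out, balance_loss, router_z_loss = moe_block(inputs)  # (4, 1024, 512), (1,), (1,)
```

On a single device, or when all ranks agree on sequence length, the flag changes nothing; the padding and the later `outs[..., :seq_len, :]` slice only kick in on the distributed path when ranks disagree.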
assert.py (20 changes: 10 additions & 10 deletions)
@@ -45,41 +45,41 @@ def start(
 
     # on single device
 
-    all_inputs, _ = all_gather_variable_dim(seq)
-    copied_net = deepcopy(net)
+    local_inputs, _ = all_gather_variable_dim(seq)
+    local_net = deepcopy(net)
 
-    single_out = copied_net(
-        all_inputs,
+    local_out = local_net(
+        local_inputs,
         is_distributed = False
     )
 
-    single_out.mean().backward()
+    local_out.mean().backward()
 
     if rank == 0:
         # validate output is the same
         # if done on 1 vs multiple machines
 
-        assert torch.allclose(single_out, ddp_all_out), 'output is not the same'
+        assert torch.allclose(local_out, ddp_all_out), 'output is not the same'
 
         # validate backwards and grad
 
         get_first_expert_grad = lambda t: t.experts[0].net[0].weight.grad
 
         assert torch.allclose(
             get_first_expert_grad(net),
-            get_first_expert_grad(copied_net),
+            get_first_expert_grad(local_net),
             atol = 1e-2
         ), 'grad is not the same'
 
-        print('✅')
+        print('✅ outputs and gradients are same between local and ddp')
 
     cleanup()
 
 if __name__ == '__main__':
-    world_size = 9
+    world_size = 13
     num_experts = 4
     batch_size = 2
-    batch_size_var_len = False
+    batch_size_var_len = True
 
     seq_len = 32
     dim = 8
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'st-moe-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.29',
+  version = '0.0.30',
   license='MIT',
   description = 'ST - Mixture of Experts - Pytorch',
   author = 'Phil Wang',
st_moe_pytorch/st_moe_pytorch.py (51 changes: 40 additions & 11 deletions)
@@ -152,17 +152,25 @@ class Experts(nn.Module):
     def __init__(
         self,
         experts,
-        is_distributed = None
+        is_distributed = None,
+        allow_var_seq_len = False # whether to handle variable sequence length
     ):
         super().__init__()
         self.num_experts = len(experts)
         self.experts = nn.ModuleList(experts)
+
+        # distributed related settings
+
         self.is_distributed = is_distributed
         if not exists(self.is_distributed):
             self.is_distributed = dist.is_initialized() and dist.get_world_size() > 1
 
         self.all_gather = AllGather()
 
+        self.allow_var_seq_len = allow_var_seq_len
+
+        # device tracker, since need to manually move experts not in use to CPU in distributed
+
         self.register_buffer('dummy', torch.ones(1), persistent = False)
 
     @property
@@ -197,26 +205,42 @@ def forward(
         d - feature dimension
         """
 
+        # declare some variables
+
         is_distributed = default(is_distributed, self.is_distributed)
         shape, num_experts = x.shape, self.num_experts
+        seq_len = shape[-2]
 
         # for now naively all gather across batch dimension if distributed, optimize later
 
+        world_size = 1
+        rank = 0
+
         if is_distributed:
             seq_sizes = gather_sizes(x, dim = -2)
-            assert has_only_one_value(seq_sizes), 'number of tokens per expert must be the same'
+            var_seq_len = not has_only_one_value(seq_sizes)
+
+            assert self.allow_var_seq_len or not var_seq_len, 'number of tokens per expert must be the same - if you want the framework to handle it, set `allow_var_seq_len = True` on `Experts`'
+
+            # if variable sequence length, pad
+
+            if var_seq_len:
+                max_seq_size = seq_sizes.amax().item()
+                x = pad_dim_to(x, max_seq_size, dim = -2)
+
+            # gather and concat across batches, accounting for variable batch sizes
 
             x, batch_sizes = self.all_gather(x)
             total_batch_size = batch_sizes.sum().item()
 
             world_size = dist.get_world_size()
             rank = dist.get_rank()
-        else:
-            world_size = 1
-            rank = 0
 
         # the experts in use on the rank
 
+        num_experts_per_rank = num_experts
+        expert_slice = slice(0, num_experts)
+
         if is_distributed:
             if world_size <= num_experts:
                 num_experts_across_ranks = chunk_num(num_experts, world_size)
@@ -245,9 +269,6 @@ def forward(
             assert len(num_experts_batches_across_ranks) == world_size
 
             expert_slice = slice(expert_start_index, expert_start_index + num_experts_per_rank)
-        else:
-            num_experts_per_rank = num_experts
-            expert_slice = slice(0, num_experts)
 
         # if distributed, each machine only handles subset of experts and batch
 
@@ -281,7 +302,7 @@ def forward(
         if len(outs) > 0:
             outs = torch.stack(outs)
         else:
-            outs = torch.empty_like(x).requires_grad_()
+            outs = torch.empty_like(x, requires_grad = self.training)
 
         # all gather across merged expert batches dimensions
         # then split the batch dimension back
@@ -297,6 +318,9 @@ def forward(
             outs = outs.split(batch_sizes.tolist())
             outs, _ = split_by_rank(outs)
 
+            # account for padded sequence length
+            outs = outs[..., :seq_len, :]
+
         assert outs.shape == shape
         return outs
 
@@ -527,7 +551,8 @@ def __init__(self,
         straight_through_dispatch_tensor = True,
         differentiable_topk = False,
         differentiable_topk_fused = True,
-        is_distributed = None
+        is_distributed = None,
+        allow_var_seq_len = False
     ):
         super().__init__()
         self.dim = dim
@@ -547,7 +572,11 @@ def __init__(self,
 
         experts = default(experts, lambda: [Expert(dim = dim, hidden_mult = expert_hidden_mult) for _ in range(num_experts)])
 
-        self.experts = Experts(experts, is_distributed = is_distributed)
+        self.experts = Experts(
+            experts,
+            is_distributed = is_distributed,
+            allow_var_seq_len = allow_var_seq_len
+        )
 
         self.loss_coef = loss_coef
         self.router_z_loss_coef = router_z_loss_coef
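The heart of the change in `Experts.forward` above is a pad-to-longest, all-gather, then slice-back pattern: each rank reports its sequence length, everything is right-padded to the maximum before the gather, and the local output is cut back to `seq_len` at the end. Below is a standalone sketch of that idea using plain `torch.distributed` collectives. It is an illustration only, not the repository's `gather_sizes` / `pad_dim_to` / `AllGather` helpers, which additionally handle variable batch sizes and gradient flow through the gather:

```python
import torch
import torch.nn.functional as F
import torch.distributed as dist

def gather_variable_seq_len(x):
    # x: (batch, seq, dim) on each rank, where seq may differ between ranks
    # (batch is assumed equal across ranks here, unlike the repository's AllGather)
    world_size = dist.get_world_size()

    # 1. every rank shares its sequence length
    local_len = torch.tensor([x.shape[-2]], device = x.device)
    all_lens = [torch.zeros_like(local_len) for _ in range(world_size)]
    dist.all_gather(all_lens, local_len)
    all_lens = torch.cat(all_lens)

    # 2. right-pad the sequence dimension to the longest length seen anywhere
    max_len = int(all_lens.amax().item())
    x = F.pad(x, (0, 0, 0, max_len - x.shape[-2]))

    # 3. gather the now uniformly shaped tensors from every rank
    gathered = [torch.zeros_like(x) for _ in range(world_size)]
    dist.all_gather(gathered, x)

    # 4. the caller slices each rank's chunk back to its true length afterwards,
    #    mirroring `outs = outs[..., :seq_len, :]` in the diff above
    return torch.cat(gathered), all_lens
```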
