
Commit 00be346

any combination of number of experts and world size should not break
1 parent 52b5c8a commit 00be346

File tree

4 files changed (+34 -17 lines)


README.md

+1 -1

@@ -70,7 +70,7 @@ out, balance_loss, router_z_loss = moe_block(inputs) # (4, 1024, 512), (1,), (1,)
 
 - [ ] distributed
 - [x] handle any world size less than number of experts
-- [ ] handle any world size greater than number of experts
+- [x] handle any world size greater than number of experts - for now, just have remainder machines do nothing
 - [ ] optimize
 - [ ] figure out what is faster, all gather, or broadcast with async followed by barrier
 - [ ] make all distributed code pluggable, for different strategies

assert.py

+2 -2

@@ -76,8 +76,8 @@ def start(
     cleanup()
 
 if __name__ == '__main__':
-    world_size = 8
-    num_experts = 8
+    world_size = 9
+    num_experts = 4
     batch_size = 2
     batch_size_var_len = False
 
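
The new defaults in assert.py deliberately exercise a world size (9) that is neither equal to nor a multiple of the number of experts (4). A minimal sketch of the rank layout this implies, assuming the chunking logic added in st_moe_pytorch.py below (rank_layout is a hypothetical helper for illustration, not part of the repo):

def rank_layout(world_size: int, num_experts: int):
    # when there are more ranks than experts, each expert is replicated across
    # world_size // num_experts batch chunks; leftover ranks sit idle
    num_batch_chunks = world_size // num_experts          # 9 // 4 = 2
    total_ranks_in_use = num_batch_chunks * num_experts   # 8
    layout = []
    for rank in range(world_size):
        active = rank < total_ranks_in_use
        expert = rank // num_batch_chunks if active else None
        layout.append((rank, expert, active))
    return layout

for rank, expert, active in rank_layout(9, 4):
    print(f'rank {rank}: expert {expert}, active = {active}')

# ranks 0-7 each serve one expert (two ranks per expert), rank 8 does nothing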

setup.py

+1 -1

@@ -3,7 +3,7 @@
 setup(
   name = 'st-moe-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.28',
+  version = '0.0.29',
   license='MIT',
   description = 'ST - Mixture of Experts - Pytorch',
   author = 'Phil Wang',

st_moe_pytorch/st_moe_pytorch.py

+30 -13

@@ -207,6 +207,7 @@ def forward(
             assert has_only_one_value(seq_sizes), 'number of tokens per expert must be the same'
 
             x, batch_sizes = self.all_gather(x)
+            total_batch_size = batch_sizes.sum().item()
 
             world_size = dist.get_world_size()
             rank = dist.get_rank()
@@ -220,14 +221,28 @@ def forward(
             if world_size <= num_experts:
                 num_experts_across_ranks = chunk_num(num_experts, world_size)
                 start_indices = cumsum_exclusive(torch.tensor(num_experts_across_ranks), dim = -1)
+
                 num_experts_per_rank = num_experts_across_ranks[rank]
+                num_experts_batches_across_ranks = tuple(i * total_batch_size for i in num_experts_across_ranks)
+
                 expert_start_index = start_indices[rank].item()
             else:
-                # for now, make sure number of machines is right multiple
+                num_batch_chunks = world_size // num_experts
+                total_ranks_in_use = num_batch_chunks * num_experts
+
+                expert_start_index = rank // num_batch_chunks
+
+                batch_splits = chunk_num(total_batch_size, num_batch_chunks)
+                num_experts_batches_across_ranks = batch_splits * num_experts
+
+                # for now, remaining machines just process nothing
+
+                remain_ranks = world_size % num_experts
+                num_experts_batches_across_ranks += (0,) * remain_ranks
+
+                num_experts_per_rank = int(rank < total_ranks_in_use)
 
-                assert divisible_by(world_size, num_experts), 'if number of machines is greater than number of experts, machines must be divisible by number of experts, so experts are evenly distributed'
-                num_experts_per_rank = 1
-                expert_start_index = rank // (world_size // num_experts)
+            assert len(num_experts_batches_across_ranks) == world_size
 
             expert_slice = slice(expert_start_index, expert_start_index + num_experts_per_rank)
         else:
@@ -241,15 +256,13 @@ def forward(
         if is_distributed:
             x, expert_batch_packed_shape = pack_one(x, '* n d')
 
-            if world_size <= num_experts:
-                x = rearrange(x, '(e b) n d -> e b n d', e = num_experts)
-                x = x.split(num_experts_across_ranks, dim = 0)
-                x = tuple(rearrange(t, 'e b n d -> (e b) n d') for t in x)
-            else:
-                x = rearrange(x, '(r eb) n d -> r eb n d', r = world_size)
-
+            x = x.split(num_experts_batches_across_ranks, dim = 0)
             x, experts_per_rank_sizes = split_by_rank(x)
-            x = rearrange(x, '(e b) n d -> e b n d', e = num_experts_per_rank)
+
+            if num_experts_per_rank > 0:
+                x = rearrange(x, '(e b) n d -> e b n d', e = num_experts_per_rank)
+            else:
+                x = x.reshape(num_experts, *x.shape)
 
         # get the experts in use
 
@@ -260,11 +273,15 @@
         # route tokens to appropriate experts
 
         outs = []
+
         for expert, expert_input in zip(experts, x):
             out = expert(expert_input)
             outs.append(out)
 
-        outs = torch.stack(outs)
+        if len(outs) > 0:
+            outs = torch.stack(outs)
+        else:
+            outs = torch.empty_like(x).requires_grad_()
 
         # all gather across merged expert batches dimensions
         # then split the batch dimension back

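The key change above is that both branches now produce num_experts_batches_across_ranks, a per-rank tuple of split sizes over the packed (expert, batch) dimension, with zeros for idle ranks, so the subsequent x.split(...) and the length assertion hold for any world size. A standalone sketch of that arithmetic (chunk_num here is a stand-in assumed to split an integer into near-equal parts, like the repo helper of the same name):

def chunk_num(num, chunks):
    # split an integer into `chunks` near-equal parts (assumption about the repo helper)
    base, remainder = divmod(num, chunks)
    return tuple(base + (i < remainder) for i in range(chunks))

def per_rank_split_sizes(world_size, num_experts, total_batch_size):
    if world_size <= num_experts:
        # each rank holds one or more whole experts, so it receives
        # (experts on this rank) * total_batch_size rows of the packed tensor
        experts_across_ranks = chunk_num(num_experts, world_size)
        sizes = tuple(e * total_batch_size for e in experts_across_ranks)
    else:
        # each expert is replicated across batch chunks; leftover ranks get 0 rows
        num_batch_chunks = world_size // num_experts
        batch_splits = chunk_num(total_batch_size, num_batch_chunks)
        sizes = batch_splits * num_experts
        sizes += (0,) * (world_size % num_experts)
    assert len(sizes) == world_size
    return sizes

print(per_rank_split_sizes(3, 8, 6))  # (18, 18, 12): ranks hold 3, 3, 2 experts
print(per_rank_split_sizes(9, 4, 6))  # (3, 3, 3, 3, 3, 3, 3, 3, 0): rank 8 idle

On an idle rank the split yields an empty chunk, which is why the forward pass now reshapes it to a zero-sized expert batch and builds an empty, grad-enabled output instead of calling torch.stack on an empty list.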