@@ -207,6 +207,7 @@ def forward(
             assert has_only_one_value(seq_sizes), 'number of tokens per expert must be the same'

             x, batch_sizes = self.all_gather(x)
+            total_batch_size = batch_sizes.sum().item()

             world_size = dist.get_world_size()
             rank = dist.get_rank()
@@ -220,14 +221,28 @@ def forward(
             if world_size <= num_experts:
                 num_experts_across_ranks = chunk_num(num_experts, world_size)
                 start_indices = cumsum_exclusive(torch.tensor(num_experts_across_ranks), dim = -1)
+
                 num_experts_per_rank = num_experts_across_ranks[rank]
+                num_experts_batches_across_ranks = tuple(i * total_batch_size for i in num_experts_across_ranks)
+
                 expert_start_index = start_indices[rank].item()
             else:
-                # for now, make sure number of machines is right multiple
+                num_batch_chunks = world_size // num_experts
+                total_ranks_in_use = num_batch_chunks * num_experts
+
+                expert_start_index = rank // num_batch_chunks
+
+                batch_splits = chunk_num(total_batch_size, num_batch_chunks)
+                num_experts_batches_across_ranks = batch_splits * num_experts
+
+                # for now, remaining machines just process nothing
+
+                remain_ranks = world_size % num_experts
+                num_experts_batches_across_ranks += (0,) * remain_ranks
+
+                num_experts_per_rank = int(rank < total_ranks_in_use)

-                assert divisible_by(world_size, num_experts), 'if number of machines is greater than number of experts, machines must be divisible by number of experts, so experts are evenly distributed'
-                num_experts_per_rank = 1
-                expert_start_index = rank // (world_size // num_experts)
+            assert len(num_experts_batches_across_ranks) == world_size

             expert_slice = slice(expert_start_index, expert_start_index + num_experts_per_rank)
         else:
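As an aside, here is a minimal sketch of how the uneven case above plays out. It assumes chunk_num splits an integer into near-equal parts and that the gathered batch is laid out expert-major (both inferred from usage in this diff, not stated by it); the concrete numbers and the local chunk_num stand-in are purely illustrative.

# illustrative stand-in for the repo's chunk_num helper (assumed semantics)
def chunk_num(num, chunks):
    base, remainder = divmod(num, chunks)
    return tuple(base + int(i < remainder) for i in range(chunks))

world_size, num_experts, total_batch_size = 8, 3, 10

num_batch_chunks = world_size // num_experts                    # 2 ranks share each expert
total_ranks_in_use = num_batch_chunks * num_experts             # 6 ranks actually hold an expert

batch_splits = chunk_num(total_batch_size, num_batch_chunks)    # (5, 5)
num_experts_batches_across_ranks = batch_splits * num_experts   # (5, 5, 5, 5, 5, 5)
num_experts_batches_across_ranks += (0,) * (world_size % num_experts)  # 2 idle ranks get nothing

assert len(num_experts_batches_across_ranks) == world_size
assert sum(num_experts_batches_across_ranks) == num_experts * total_batch_size

# each active rank ends up with exactly one expert; idle ranks get an empty expert slice
for rank in range(world_size):
    expert_start_index = rank // num_batch_chunks
    num_experts_per_rank = int(rank < total_ranks_in_use)
    print(rank, expert_start_index, num_experts_per_rank)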
@@ -241,15 +256,13 @@ def forward(
         if is_distributed:
             x, expert_batch_packed_shape = pack_one(x, '* n d')

-            if world_size <= num_experts:
-                x = rearrange(x, '(e b) n d -> e b n d', e = num_experts)
-                x = x.split(num_experts_across_ranks, dim = 0)
-                x = tuple(rearrange(t, 'e b n d -> (e b) n d') for t in x)
-            else:
-                x = rearrange(x, '(r eb) n d -> r eb n d', r = world_size)
-
+            x = x.split(num_experts_batches_across_ranks, dim = 0)
             x, experts_per_rank_sizes = split_by_rank(x)
-            x = rearrange(x, '(e b) n d -> e b n d', e = num_experts_per_rank)
+
+            if num_experts_per_rank > 0:
+                x = rearrange(x, '(e b) n d -> e b n d', e = num_experts_per_rank)
+            else:
+                x = x.reshape(num_experts, *x.shape)

         # get the experts in use
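A rough shape check for the split and reshape above, done locally rather than through split_by_rank. It assumes the packed tensor is expert-major with a leading dimension of num_experts * total_batch_size, matching the '(e b) n d' pattern, and it reuses the illustrative splits from the sketch after the previous hunk.

import torch
from einops import rearrange

num_experts, total_batch_size, n, d = 3, 10, 4, 8
x = torch.randn(num_experts * total_batch_size, n, d)         # packed '* n d'

splits = (5, 5, 5, 5, 5, 5, 0, 0)                             # num_experts_batches_across_ranks
chunks = x.split(splits, dim = 0)                             # one chunk per rank

# an active rank holds one expert and five batches
shard = rearrange(chunks[0], '(e b) n d -> e b n d', e = 1)
assert shard.shape == (1, 5, 4, 8)

# an idle rank receives a zero-length chunk and only gains an expert dimension
idle = chunks[-1]
idle = idle.reshape(num_experts, *idle.shape)
assert idle.shape == (3, 0, 4, 8)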
@@ -260,11 +273,15 @@ def forward(
         # route tokens to appropriate experts

         outs = []
+
         for expert, expert_input in zip(experts, x):
             out = expert(expert_input)
             outs.append(out)

-        outs = torch.stack(outs)
+        if len(outs) > 0:
+            outs = torch.stack(outs)
+        else:
+            outs = torch.empty_like(x).requires_grad_()

         # all gather across merged expert batches dimensions
         # then split the batch dimension back
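Finally, a small standalone check of the idle-rank path in the loop above: with no experts assigned, outs stays empty, and the fallback presumably exists so the rank still contributes a zero-element, differentiable tensor to the later all gather. The shapes below are illustrative.

import torch

# on an idle rank the routed input has a zero-sized batch, e.g. (num_experts, 0, n, d)
x = torch.randn(3, 0, 4, 8)
experts = []                                  # no experts live on this rank

outs = []
for expert, expert_input in zip(experts, x):  # never iterates
    outs.append(expert(expert_input))

if len(outs) > 0:
    outs = torch.stack(outs)
else:
    # zero-element placeholder matching the input's shape, dtype and device
    outs = torch.empty_like(x).requires_grad_()

assert outs.shape == (3, 0, 4, 8) and outs.requires_grad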