
Commit ef58511

KimmiShi authored and zigzagcai committed
fix(moe): fix moe act late release (#387)
Co-authored-by: shidongxing <shidongxing@>
1 parent 1339fc8 commit ef58511

6 files changed: +27 -16 lines changed

internlm/model/model_ops/modules/linear.py (+13 -4)
@@ -338,6 +338,7 @@ def forward(
             raise NotImplementedError(f"Invalid backend: {backend}")

         input_numel = x.numel()
+        ctx.input_numel = input_numel
         if input_numel == 0:
             backend = "bmm"

@@ -357,9 +358,11 @@ def forward(
         if input_numel == 0:
             # if inp is empty, reshape to make grad flow.
             # inp shape: (0, hdim)
-            weight = weight.view(x.shape[-1], -1)
+            output = torch.matmul(x, weight.view(x.shape[-1], -1))
+        else:
+            output = torch.matmul(x, weight)

-        output = torch.matmul(x, weight)
+        assert len(output.shape) == len(x.shape)

         assert len(output.shape) == len(x.shape)

@@ -387,14 +390,20 @@ def backward(ctx, grad_output):
         if backend == "gmm":
             grad_input, grad_weight = gmm_backward_op(x, grad_output, batch_sizes, input_weight=weight)
         else:
-            grad_weight = torch.matmul(x.transpose(-1, -2), grad_output)
+            if ctx.input_numel == 0:
+                grad_weight = torch.zeros_like(weight)
+            else:
+                grad_weight = torch.matmul(x.transpose(-1, -2), grad_output)

         if ctx.needs_input_grad[0]:
             if backend == "gmm":
                 if grad_input is None:
                     grad_input, _ = gmm_backward_op(grad_output, weight, batch_sizes, is_grad_input=True)
             else:
-                grad_input = torch.matmul(grad_output, weight.transpose(-1, -2))
+                if ctx.input_numel == 0:
+                    grad_input = torch.zeros_like(x)
+                else:
+                    grad_input = torch.matmul(grad_output, weight.transpose(-1, -2))

         return grad_input, grad_weight, None, None, None, None, None
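The hunks above cache x.numel() on ctx during forward so that backward can return explicit zero gradients when an expert received no tokens, instead of multiplying zero-sized tensors. A minimal self-contained sketch of the same pattern (the ExpertMatmul class and 2-D weight shape below are illustrative, not the repo's actual grouped-linear implementation):

# Minimal sketch of the empty-input pattern: remember x.numel() in forward;
# if no tokens were routed here, return exact zero gradients in backward.
import torch

class ExpertMatmul(torch.autograd.Function):  # hypothetical name
    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        ctx.input_numel = x.numel()  # remembered so backward can short-circuit
        if ctx.input_numel == 0:
            # empty input: flatten weight so the output still has shape
            # (0, out_dim) and stays attached to the graph so grads can flow
            return torch.matmul(x, weight.view(x.shape[-1], -1))
        return torch.matmul(x, weight)

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        if ctx.input_numel == 0:
            # no tokens here: both gradients are exactly zero
            return torch.zeros_like(x), torch.zeros_like(weight)
        grad_input = torch.matmul(grad_output, weight.transpose(-1, -2))
        grad_weight = torch.matmul(x.transpose(-1, -2), grad_output)
        return grad_input, grad_weight

# e.g. an expert that received zero tokens this step:
x = torch.randn(0, 8, requires_grad=True)
w = torch.randn(8, 16, requires_grad=True)
ExpertMatmul.apply(x, w).sum().backward()  # w.grad is all zeros, no NaNs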

internlm/model/model_ops/moe/base_layer.py (-3)
@@ -1,12 +1,10 @@
 from typing import TYPE_CHECKING, Union

-import torch
 from torch import Tensor
 from torch.nn import Module, ModuleList

 from internlm.core.context import global_context as gpc
 from internlm.model.model_ops.moe.experts import Experts
-from internlm.utils.common import get_current_device

 if TYPE_CHECKING:
     Base = Module[Tensor]
@@ -32,7 +30,6 @@ def __init__(
         self.ep_group = ep_group
         self.ep_size = ep_size
         self.num_local_experts = num_local_experts
-        self.l_aux = torch.tensor(0.0, device=get_current_device(), dtype=gpc.config.model.get("dtype"))
         self.exp_counts = None

         for _, param in self.gate.named_parameters():

internlm/model/model_ops/moe/dropless_layer.py (+6 -1)
@@ -288,7 +288,12 @@ def forward(self, *inputs: Tensor) -> Tensor:

         # Reshape the output tensor
         output = output.view(self.hidden_shape)
-        return output
+
+        # Note: 1. we need to release self.l_aux and its compute graph; 2. we need self.l_aux to simplify code,
+        # so we first use self.l_aux and then reset it.
+        l_aux = self.l_aux
+        self.l_aux = None
+        return output, l_aux

     def topk_softmax_with_capacity(self, gates):
         expert_weights, indices = torch.topk(gates, self.topk, dim=1)
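The note in the hunk above is the heart of the fix: a loss tensor stored on a module keeps its entire autograd graph, and the activations that graph saved, alive until the attribute is overwritten on the next forward, which is one step too late. Returning l_aux and nulling the attribute lets everything be freed with the current step. A hedged stand-alone illustration (the Gate module below is a made-up stand-in, not InternLM code):

# Why storing a loss tensor on the module retains activations: module.l_aux
# holds the graph behind it until the *next* forward overwrites it.
import torch

class Gate(torch.nn.Module):  # illustrative stand-in for an MoE gate
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 4)
        self.l_aux = None

    def forward(self, x):
        out = self.proj(x)
        self.l_aux = out.float().var()  # graph now hangs off self.l_aux
        return out

gate = Gate()
out = gate(torch.randn(2, 8))

# pattern from this commit: take the loss, then drop the module's reference
l_aux = gate.l_aux
gate.l_aux = None  # graph (and its saved activations) can be freed after backward
(out.sum() + l_aux).backward()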

internlm/model/model_ops/moe/gshard_layer.py (+2 -2)
@@ -555,7 +555,7 @@ def forward(self, *inputs: Tensor) -> Tensor:
         # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
         reshaped_inputs = inputs[0].reshape(-1, d_model)

-        self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_inputs, inputs[1])
+        l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_inputs, inputs[1])
         dispatched_inputs = einsum(
             "sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs
         )  # TODO: heavy memory usage due to long sequence length
@@ -608,4 +608,4 @@ def forward(self, *inputs: Tensor) -> Tensor:
         timer("moe").stop()
         self.time_moe = timer("moe").elapsed(reset=False)

-        return out
+        return out, l_aux

internlm/model/model_ops/moe/megablocks/megablock_moe.py (+2 -2)
@@ -303,6 +303,6 @@ def forward(self, *inputs) -> torch.Tensor:

         x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)

-        self.l_aux = self.load_balancing_loss(tokens_per_expert, all_probs)
+        l_aux = self.load_balancing_loss(tokens_per_expert, all_probs)

-        return x.view(*input_shape)
+        return x.view(*input_shape), l_aux

internlm/model/model_ops/moe/moe.py (+4 -4)
@@ -181,7 +181,7 @@ def forward(self, hidden_states, used_token=None):
            * exp_counts (int): expert count
        """
-        output = self.moe_layer(hidden_states, used_token)
+        output, l_aux = self.moe_layer(hidden_states, used_token)
         if self.num_shared_experts > 0:
             # Residual MoE
             output_mlp = self.residual_mlp(hidden_states)
@@ -190,7 +190,7 @@ def forward(self, hidden_states, used_token=None):
             coef = self.coefficient(hidden_states)
             coef = torch.nn.functional.softmax(coef, dim=-1)
             output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
-        return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
+        return output, l_aux, self.moe_layer.exp_counts


 class Qwen2MoE(MoEBase):
@@ -264,7 +264,7 @@ def forward(self, hidden_states, used_token=None):
            * exp_counts (int): expert count
        """
-        output = self.moe_layer(hidden_states, used_token)
+        output, l_aux = self.moe_layer(hidden_states, used_token)
         if self.num_shared_experts > 0:
             # Residual MoE
             output_mlp = self.residual_mlp(hidden_states)
@@ -273,4 +273,4 @@ def forward(self, hidden_states, used_token=None):
             coef = self.coefficient(hidden_states)
             output_mlp = F.sigmoid(coef) * output_mlp
             output = output + output_mlp
-        return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
+        return output, l_aux, self.moe_layer.exp_counts
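On the caller side, l_aux now travels through the return tuple instead of being read back off self.moe_layer afterwards, so nothing on the module outlives the step. A toy sketch of the resulting contract (TinyMoE and the 0.01 aux-loss weight are invented for illustration):

# Toy module matching the (output, l_aux, exp_counts) return contract;
# the training loop folds l_aux into the step's loss, then the graph dies.
import torch

class TinyMoE(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ffn = torch.nn.Linear(8, 8)

    def forward(self, x):
        out = self.ffn(x)
        l_aux = out.float().pow(2).mean()  # stand-in balancing loss
        return out, l_aux, None  # (output, l_aux, exp_counts)

moe = TinyMoE()
output, l_aux, _ = moe(torch.randn(4, 8))
loss = output.sum() + 0.01 * l_aux  # aux loss consumed within the step
loss.backward()  # nothing lingers on the module afterwards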
