r"""
FMoE's parallel linear layer
"""
import torch
import torch.nn as nn
from torch.autograd import Function
import math
import fmoe_cuda  # FastMoE's compiled CUDA extension with the per-expert linear kernels


class MOELinear(Function):
    r"""
    Computes linear operators within one GPU on different experts simultaneously.
    """

    @staticmethod
    def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
        global_output_buf = fmoe_cuda.linear_forward(
            global_input_buf, fwd_expert_count, weight, bias
        )
        variables = (global_input_buf, fwd_expert_count, weight, bias)
        ctx.save_for_backward(*variables)
        return global_output_buf

    @staticmethod
    def backward(ctx, grad_out):
        (input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
        grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward(
            grad_out, input_buf, fwd_expert_count, weight, bias
        )

        # fwd_expert_count is an integer tensor and needs no gradient; the
        # bias gradient is dropped when the layer was built without a bias.
        if not torch.is_tensor(bias):
            grad_bias = None

        return grad_inp_buf, None, grad_weight, grad_bias
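

# A plain-PyTorch reference sketch (a hypothetical helper, for illustration
# only) of what fmoe_cuda.linear_forward is assumed to compute: rows of the
# input buffer are grouped by expert, with fwd_expert_count[i] consecutive
# rows belonging to expert i, and each group goes through that expert's
# linear map. The CUDA kernel is expected to fuse these per-expert GEMMs.
def _moe_linear_reference(input_buf, fwd_expert_count, weight, bias=None):
    outputs = []
    base = 0
    for i, count in enumerate(fwd_expert_count.tolist()):
        # Rows [base, base + count) were routed to expert i.
        segment = input_buf[base : base + count]
        out = segment @ weight[i].t()
        if bias is not None:
            out = out + bias[i]
        outputs.append(out)
        base += count
    return torch.cat(outputs, dim=0)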


class FMoELinear(nn.Module):
    r"""
    A linear layer that contains multiple experts.
    As multiple experts can be placed on the same worker, the computation can be
    performed in parallel to increase the performance.
    The FMoELinear module provides such a function.
    """

    def __init__(
        self,
        num_expert: int,
        in_feat: int,
        out_feat: int,
        bias: bool = True,
        rank: int = 0,
    ):
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.rank = rank
        self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
        if bias:
            self.bias = nn.Parameter(torch.zeros(num_expert, out_feat))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def forward(self, inp, fwd_expert_count):
        r"""
        Call MOE function
        """
        x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
        return x

    def extra_repr(self) -> str:
        return "num_expert={}, in_features={}, out_features={}, bias={}, rank={}".format(
            self.num_expert,
            self.in_feat,
            self.out_feat,
            self.bias is not None,
            self.rank,
        )

    def reset_parameters(self):
        # Approach is the same as in torch.nn.Linear
        # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
        # The bias is left as zero, similar to Megatron.
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
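

# A hedged usage sketch (for illustration only), assuming the fmoe_cuda
# extension has been built and a CUDA device is available. Tokens routed to
# the same expert are packed contiguously in the input buffer, and
# fwd_expert_count records how many rows each expert receives.
if __name__ == "__main__":
    experts = FMoELinear(num_expert=4, in_feat=16, out_feat=32).cuda()
    # 10 tokens total: 3 for expert 0, 2 for expert 1, 4 for expert 2, 1 for expert 3.
    # Assumption: the extension accepts the count tensor on the host.
    fwd_expert_count = torch.tensor([3, 2, 4, 1], dtype=torch.long)
    inp = torch.randn(10, 16, device="cuda")
    out = experts(inp, fwd_expert_count)
    print(out.shape)  # expected: torch.Size([10, 32])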