-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathcustom_layers.py
302 lines (254 loc) · 9.92 KB
/
custom_layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
r"""
FMoE core layer
"""
import tree
import os
import torch
import torch.nn as nn
from functions import prepare_forward, ensure_comm
from functions import MOEScatter, MOEGather
from functions import AllGather, Slice
from custom_gate import HyperRouterGate
from fastermoe.config import switch_from_env
def mark_module_parallel_comm(module, comm):
    r"""
    Tag every parameter of `module` with the communicator it synchronizes in.

    Each parameter yielded by `module.parameters()` gets a `dp_comm`
    attribute set to `comm`, which is typically one of
    `'world', 'dp', 'none'`.
    """
    for param in module.parameters():
        param.dp_comm = comm
def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size, **kwargs):
    r"""
    Private driver for one complete MoE pass.

    The work is done in four stages:
    * `prepare_forward` counts how many tokens go from each worker to each
    expert and yields the positions used for reordering.
    * Every tensor in `inp` is scattered so that the features routed to each
    expert are contiguous in memory.
    * `expert_fn` runs the experts on the scattered features.
    * The expert outputs are gathered back and reordered as sentences.

    `inp` and the returned value are (possibly nested) structures of tensors
    that are traversed with `tree.map_structure`. Bookkeeping such as the
    expert counts never leaves this function. Extra keyword arguments are
    accepted but not used by this implementation.
    """
    (
        pos,
        local_expert_count,
        global_expert_count,
        fwd_expert_count,
        fwd_batch_size,
    ) = prepare_forward(gate, num_expert, world_size)

    # A 2-D gate holds one column per selected expert; anything else is
    # treated as a single selection per token.
    topk = gate.shape[1] if len(gate.shape) == 2 else 1

    # Each token owns `topk` consecutive positions, so the floor division
    # maps a position back to the row of its source token.
    src_rows = torch.div(pos, topk, rounding_mode='floor')

    def _scatter(tensor):
        return MOEScatter.apply(
            tensor,
            src_rows,
            local_expert_count,
            global_expert_count,
            fwd_batch_size,
            world_size,
        )

    expert_inp = tree.map_structure(_scatter, inp)
    expert_out = expert_fn(expert_inp, fwd_expert_count)

    # One output row per (token, selected expert) pair.
    out_batch_size = tree.flatten(inp)[0].shape[0] * topk

    def _gather(tensor):
        return MOEGather.apply(
            tensor,
            pos,
            local_expert_count,
            global_expert_count,
            out_batch_size,
            world_size,
        )

    return tree.map_structure(_gather, expert_out)
# Opt-in switch for the FasterMoE schedule. When `switch_from_env` reports
# FMOE_FASTER_SCHEDULE_ENABLE as enabled, the module-level name
# `_fmoe_general_global_forward` is rebound to the implementation from
# `fastermoe.schedule`, shadowing the general one defined above.
fmoe_faster_schedule = False
if switch_from_env('FMOE_FASTER_SCHEDULE_ENABLE', False):
    fmoe_faster_schedule = True
    # Absolute import, consistent with `from fastermoe.config import ...` at
    # the top of this file: a relative import cannot be resolved when this
    # module is loaded as a top-level module rather than as part of a package.
    from fastermoe.schedule import _fmoe_general_global_forward
class FMoE(nn.Module):
    r"""
    A general MoE implementation that supports an arbitrary module as the
    expert.

    * `num_expert` stands for the number of experts on **each** worker. When
    `expert` is a list, it is replaced by the length of that list.
    * `d_model` is passed to every expert constructor and to the gate.
    * `world_size` stands for the total number of workers that contain
    different experts.
    * `mp_group` is a deprecated alias of `slice_group`; when given, it takes
    precedence and a warning is printed.
    * `slice_group` can be a torch communication group, indicating that
    specific model parallel is applied across the group, and workers in the
    group hold the same copy of the input feature and require the same copy
    of the output. For each worker, FMoE only computes the output of a
    certain slice of the input batch, and all-gathers the outputs after
    computation.
    * `moe_group` is handed to `ensure_comm` for every input tensor when
    `world_size > 1`.
    * `top_k` stands for the number of experts each token is going to. It is
    overwritten on every forward pass by `gate.dynamic_top_k` when the gate
    exposes that attribute.
    * `gate` is a gate class. It is instantiated as
    `gate(d_model, num_expert, world_size, top_k)`, with `hyper_size`
    appended as a fifth argument when `hyper_size` is given and `gate` is
    `HyperRouterGate` or a subclass of it.
    * `expert` can be a module class, used to generate `num_expert` expert
    modules, or a list of module classes, one per expert. When it is `None`
    the layer runs in fused mode and `self.experts` is expected to be
    assigned separately (e.g. by a child class).
    * `gate_hook`, when given, is called as
    `gate_hook(gate_top_k_idx, gate_score, None)` after gating.
    * `mask` and `mask_dict` enable token masking when both are given: rows
    whose flattened mask value is 0 are sent through the experts, and every
    other row of the output is filled from `mask_dict[mask_value]`.
    """

    def __init__(
        self,
        num_expert=32,
        d_model=1024,
        world_size=1,
        mp_group=None,  # being deprecated
        slice_group=None,
        moe_group=None,
        top_k=2,
        gate=HyperRouterGate,
        expert=None,
        gate_hook=None,
        mask=None,
        mask_dict=None,
        hyper_size=None,
    ):
        super().__init__()

        self.num_expert = num_expert
        self.d_model = d_model
        self.world_size = world_size

        self.slice_group = slice_group
        if mp_group is not None:
            print("[Warning] mp_group is being deprecated")
            self.slice_group = mp_group
        if self.slice_group is None:
            self.slice_size = 1
            self.slice_rank = 0
        else:
            self.slice_size = self.slice_group.size()
            self.slice_rank = self.slice_group.rank()

        self.top_k = top_k
        if isinstance(expert, list):
            # One constructor per expert; the list length wins over the
            # `num_expert` argument.
            self.experts = nn.ModuleList([e(d_model) for e in expert])
            self.experts_fused = False
            self.num_expert = num_expert = len(expert)
        elif expert is not None:
            self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)])
            self.experts_fused = False
        else:
            # Fused mode: `self.experts` is not created here.
            self.experts_fused = True

        # `gate` is still a class at this point, so the check has to be a
        # subclass test on the argument itself. Only HyperRouterGate (and
        # subclasses) receive the extra `hyper_size` argument.
        gate_takes_hyper_size = (
            hyper_size is not None
            and isinstance(gate, type)
            and issubclass(gate, HyperRouterGate)
        )
        if gate_takes_hyper_size:
            self.gate = gate(d_model, num_expert, world_size, top_k, hyper_size)
        else:
            self.gate = gate(d_model, num_expert, world_size, top_k)

        self.gate_hook = gate_hook
        self.mask = mask
        self.mask_dict = mask_dict
        self.moe_group = moe_group

    def expert_fn(self, inp, fwd_expert_count):
        r"""
        The default expert function which either calls the experts as a whole
        or as separate experts.

        In fused mode `self.experts` is called once with `inp` and
        `fwd_expert_count`. Otherwise `inp` is cut along dim 0 into
        consecutive slices of `fwd_expert_count[i]` rows, slice `i` is fed to
        expert `i`, and the outputs are concatenated along dim 0.
        """
        if self.experts_fused:
            return self.experts(inp, fwd_expert_count)
        if isinstance(fwd_expert_count, torch.Tensor):
            # Move the counts to host memory so they can be used as plain
            # slice bounds below.
            fwd_expert_count = fwd_expert_count.cpu().numpy()
        outputs = []
        base_idx = 0
        for i in range(self.num_expert):
            batch_size = fwd_expert_count[i]
            inp_slice = inp[base_idx : base_idx + batch_size]
            outputs.append(self.experts[i](inp_slice))
            base_idx += batch_size
        return torch.cat(outputs, dim=0)

    def mark_parallel_comm(self, expert_dp_comm="none"):
        r"""
        Automatically mark the data parallel comms of the parameters within
        the module. This can typically be called at the end of the __init__
        function in child classes.

        Expert parameters are tagged with `expert_dp_comm`; gate parameters
        are always tagged with `"gate"`.
        """
        # In fused mode `__init__` does not create `self.experts`; treat a
        # missing attribute like `None` instead of raising AttributeError.
        experts = getattr(self, "experts", None)
        if experts is not None:
            comm = expert_dp_comm
            if isinstance(experts, list):
                for e in experts:
                    mark_module_parallel_comm(e, comm)
            else:
                mark_module_parallel_comm(experts, comm)
        mark_module_parallel_comm(self.gate, "gate")

    def forward(self, moe_inp):
        r"""
        The FMoE module first computes gate output, and then conducts MoE
        forward according to the gate. The score of the selected gate given
        by the expert is multiplied to the experts' output tensors as a
        weight.

        `moe_inp` may be a single tensor or a nested structure of tensors
        (handled through `tree.map_structure`); all of them must share the
        same size along dim 0. The result has the same structure.
        """
        moe_inp_batch_size = tree.flatten(
            tree.map_structure(lambda tensor: tensor.shape[0], moe_inp)
        )
        assert all(
            batch_size == moe_inp_batch_size[0] for batch_size in moe_inp_batch_size
        ), "MoE inputs must have the same batch size"

        if self.world_size > 1:

            def ensure_comm_func(tensor):
                ensure_comm(tensor, self.moe_group)

            tree.map_structure(ensure_comm_func, moe_inp)
        if self.slice_size > 1:

            def slice_func(tensor):
                return Slice.apply(
                    tensor, self.slice_rank, self.slice_size, self.slice_group
                )

            moe_inp = tree.map_structure(slice_func, moe_inp)

        gate_top_k_idx, gate_score = self.gate(moe_inp)
        # Gates that choose k at run time publish it as `dynamic_top_k`; the
        # reshapes below must use that value.
        if hasattr(self.gate, 'dynamic_top_k'):
            self.top_k = self.gate.dynamic_top_k

        if self.gate_hook is not None:
            self.gate_hook(gate_top_k_idx, gate_score, None)

        # delete masked tensors
        if self.mask is not None and self.mask_dict is not None:
            # TODO: to fix
            def delete_mask_func(tensor):
                # to: (BxL') x d_model
                tensor = tensor[mask == 0, :]
                return tensor

            mask = self.mask.view(-1)
            moe_inp = tree.map_structure(delete_mask_func, moe_inp)
            gate_top_k_idx = gate_top_k_idx[mask == 0, :]

        fwd = _fmoe_general_global_forward(
            moe_inp, gate_top_k_idx, self.expert_fn,
            self.num_expert, self.world_size,
            experts=self.experts
        )

        # recover deleted tensors
        if self.mask is not None and self.mask_dict is not None:

            def recover_func(tensor):
                # to: (BxL') x top_k x dim
                dim = tensor.shape[-1]
                tensor = tensor.view(-1, self.top_k, dim)
                # to: (BxL) x top_k x d_model
                x = torch.zeros(
                    mask.shape[0],
                    self.top_k,
                    dim,
                    device=tensor.device,
                    dtype=tensor.dtype,
                )
                # recover
                x[mask == 0] = tensor
                # Rows that skipped the experts get the value registered
                # for their mask id.
                for k, v in self.mask_dict.items():
                    x[mask == k] = v
                return x

            moe_outp = tree.map_structure(recover_func, fwd)
        else:

            def view_func(tensor):
                dim = tensor.shape[-1]
                tensor = tensor.view(-1, self.top_k, dim)
                return tensor

            moe_outp = tree.map_structure(view_func, fwd)

        # (N x 1 x top_k) @ (N x top_k x dim) -> weighted sum over the
        # selected experts for every token.
        gate_score = gate_score.view(-1, 1, self.top_k)

        def bmm_func(tensor):
            dim = tensor.shape[-1]
            tensor = torch.bmm(gate_score, tensor).reshape(-1, dim)
            return tensor

        moe_outp = tree.map_structure(bmm_func, moe_outp)

        if self.slice_size > 1:

            def all_gather_func(tensor):
                return AllGather.apply(
                    tensor, self.slice_rank, self.slice_size, self.slice_group
                )

            moe_outp = tree.map_structure(all_gather_func, moe_outp)

        moe_outp_batch_size = tree.flatten(
            tree.map_structure(lambda tensor: tensor.shape[0], moe_outp)
        )
        assert all(
            batch_size == moe_outp_batch_size[0] for batch_size in moe_outp_batch_size
        ), "MoE outputs must have the same batch size"
        return moe_outp