
Commit 05cde4d

MoE for mixtral 8x7b (#2535)
* MoE for mixtral 8x7b
* removing bnb_sparse for now
1 parent 2509d93 commit 05cde4d

4 files changed: +128, -13 lines


onmt/decoders/transformer.py (+45, -12)
@@ -9,6 +9,7 @@
 from onmt.modules import MultiHeadedAttention, AverageAttention
 from onmt.modules.position_ffn import PositionwiseFeedForward
 from onmt.modules.position_ffn import ActivationFunction
+from onmt.modules.moe import MoE
 from onmt.utils.misc import sequence_mask
 
 try:
@@ -43,6 +44,8 @@ def __init__(
         parallel_gpu=1,
         sliding_window=0,
         rotary_interleave=True,
+        num_experts=0,
+        num_experts_per_tok=2,
     ):
         """
         Args:
@@ -109,18 +112,34 @@ def __init__(
                 d_model, dropout=attention_dropout, aan_useffn=aan_useffn
             )
 
-        self.feed_forward = PositionwiseFeedForward(
-            d_model,
-            d_ff,
-            dropout,
-            pos_ffn_activation_fn,
-            add_ffnbias,
-            parallel_residual,
-            layer_norm,
-            norm_eps,
-            use_ckpting=use_ckpting,
-            parallel_gpu=parallel_gpu,
-        )
+        if num_experts > 0:
+            self.feed_forward = MoE(
+                num_experts,
+                num_experts_per_tok,
+                d_model,
+                d_ff,
+                dropout,
+                pos_ffn_activation_fn,
+                add_ffnbias,
+                parallel_residual,
+                layer_norm,
+                norm_eps,
+                use_ckpting=use_ckpting,
+                parallel_gpu=parallel_gpu,
+            )
+        else:
+            self.feed_forward = PositionwiseFeedForward(
+                d_model,
+                d_ff,
+                dropout,
+                pos_ffn_activation_fn,
+                add_ffnbias,
+                parallel_residual,
+                layer_norm,
+                norm_eps,
+                use_ckpting=use_ckpting,
+                parallel_gpu=parallel_gpu,
+            )
         self.parallel_residual = parallel_residual
         self.shared_layer_norm = shared_layer_norm
         if layer_norm == "standard":
@@ -261,6 +280,8 @@ def __init__(
         parallel_gpu=1,
         sliding_window=0,
         rotary_interleave=True,
+        num_experts=0,
+        num_experts_per_tok=2,
     ):
         """
         Args:
@@ -290,6 +311,8 @@ def __init__(
             parallel_gpu=parallel_gpu,
             sliding_window=sliding_window,
             rotary_interleave=rotary_interleave,
+            num_experts=num_experts,
+            num_experts_per_tok=num_experts_per_tok,
         )
         self.context_attn = MultiHeadedAttention(
             heads,
@@ -450,6 +473,8 @@ def from_opt(cls, opt, embeddings):
             else 1,
             sliding_window=opt.sliding_window,
             rotary_interleave=opt.rotary_interleave,
+            num_experts=opt.num_experts,
+            num_experts_per_tok=opt.num_experts_per_tok,
         )
 
     def init_state(self, src, enc_out, enc_final_hs):
@@ -569,6 +594,8 @@ def __init__(
         parallel_gpu=1,
         sliding_window=0,
         rotary_interleave=True,
+        num_experts=0,
+        num_experts_per_tok=2,
     ):
         super(TransformerDecoder, self).__init__(
             d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
@@ -600,6 +627,8 @@ def __init__(
                     parallel_gpu=parallel_gpu,
                     sliding_window=sliding_window,
                     rotary_interleave=rotary_interleave,
+                    num_experts=num_experts,
+                    num_experts_per_tok=num_experts_per_tok,
                 )
                 for i in range(num_layers)
             ]
@@ -836,6 +865,8 @@ def __init__(
         parallel_gpu=1,
         sliding_window=0,
         rotary_interleave=True,
+        num_experts=0,
+        num_experts_per_tok=2,
     ):
         super(TransformerLMDecoder, self).__init__(
             d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
@@ -866,6 +897,8 @@ def __init__(
                     parallel_gpu=parallel_gpu,
                     sliding_window=sliding_window,
                     rotary_interleave=rotary_interleave,
+                    num_experts=num_experts,
+                    num_experts_per_tok=num_experts_per_tok,
                 )
                 for i in range(num_layers)
            ]
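
In short, the change to TransformerDecoderLayerBase swaps the dense position-wise feed-forward block for a mixture-of-experts block whenever num_experts > 0, and leaves the dense path untouched otherwise. The standalone sketch below is not part of the commit; its hyperparameters and the trailing positional values (add_ffnbias=True, parallel_residual=False, layer_norm="standard", norm_eps=1e-6) are illustrative assumptions rather than values taken from a real config.

# Illustrative sketch (not from the commit) of the FFN/MoE swap the patch adds
# to TransformerDecoderLayerBase.__init__, written as a standalone snippet.
import torch
from onmt.modules.moe import MoE
from onmt.modules.position_ffn import PositionwiseFeedForward, ActivationFunction

num_experts, num_experts_per_tok = 8, 2   # Mixtral-8x7B-style routing
d_model, d_ff, dropout = 512, 2048, 0.1   # placeholder sizes

if num_experts > 0:
    # sparse FFN: each token is routed to its top-k experts
    feed_forward = MoE(
        num_experts, num_experts_per_tok, d_model, d_ff, dropout,
        ActivationFunction.relu, True, False, "standard", 1e-6,
    )
else:
    # dense FFN: the original position-wise feed-forward block
    feed_forward = PositionwiseFeedForward(
        d_model, d_ff, dropout, ActivationFunction.relu,
        True, False, "standard", 1e-6,
    )

x = torch.randn(2, 10, d_model)
assert feed_forward(x).shape == x.shape  # drop-in replacement: shape is preserved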

onmt/modules/bnb_linear.py (+6, -1)
@@ -7,7 +7,12 @@
 try:
     os.environ["BITSANDBYTES_NOWELCOME"] = "1"
     from bitsandbytes import MatmulLtState
-    from bitsandbytes.nn import Linear4bit, Linear8bitLt, Params4bit, Int8Params
+    from bitsandbytes.nn import (
+        Linear4bit,
+        Linear8bitLt,
+        Params4bit,
+        Int8Params,
+    )
 except ImportError:
     raise ImportError("Install bitsandbytes to use 4/8bit compression")

onmt/modules/moe.py (+63)
@@ -0,0 +1,63 @@
+"""MoE mixture of experts."""
+import torch
+import torch.nn as nn
+from onmt.modules.position_ffn import PositionwiseFeedForward
+
+
+class MoE(nn.Module):
+    def __init__(
+        self,
+        num_experts,
+        num_experts_per_tok,
+        d_model,
+        d_ff,
+        dropout,
+        pos_ffn_activation_fn,
+        add_ffnbias,
+        parallel_residual,
+        layer_norm,
+        norm_eps,
+        use_ckpting=[],
+        parallel_gpu=1,
+    ):
+        super().__init__()
+        self.experts = nn.ModuleList(
+            [
+                PositionwiseFeedForward(
+                    d_model,
+                    d_ff,
+                    dropout,
+                    pos_ffn_activation_fn,
+                    add_ffnbias,
+                    parallel_residual,
+                    layer_norm,
+                    norm_eps,
+                    use_ckpting=use_ckpting,
+                    parallel_gpu=parallel_gpu,
+                )
+                for i in range(num_experts)
+            ]
+        )
+        self.gate = nn.Linear(d_model, num_experts, bias=False)
+        self.num_experts_per_tok = num_experts_per_tok
+
+    def forward(self, x):
+        orig_shape = x.shape
+        x = x.view(-1, x.shape[-1])
+
+        scores = self.gate(x)
+        expert_weights, expert_indices = torch.topk(
+            scores, self.num_experts_per_tok, dim=-1
+        )
+        expert_weights = expert_weights.softmax(dim=-1)
+        flat_expert_indices = expert_indices.view(-1)
+
+        x = x.repeat_interleave(self.num_experts_per_tok, dim=0)
+        y = torch.empty_like(x)
+        for i, expert in enumerate(self.experts):
+            if torch.any(flat_expert_indices == i):
+                y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])
+        y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(
+            dim=1
+        )
+        return y.view(*orig_shape)
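
A small usage sketch (again, not part of the commit) may help when reading forward(): the gate scores every token against all experts, topk keeps the num_experts_per_tok highest scores, a softmax renormalizes them, and each token's output is the weighted sum of the outputs of its selected experts. The toy sizes below are arbitrary.

# Toy walk-through of the routing in MoE.forward (hypothetical values).
import torch
from onmt.modules.moe import MoE
from onmt.modules.position_ffn import ActivationFunction

moe = MoE(
    num_experts=4,
    num_experts_per_tok=2,
    d_model=16,
    d_ff=32,
    dropout=0.0,
    pos_ffn_activation_fn=ActivationFunction.relu,
    add_ffnbias=True,
    parallel_residual=False,
    layer_norm="standard",
    norm_eps=1e-6,
)

x = torch.randn(3, 5, 16)                      # (batch, seq_len, d_model)
scores = moe.gate(x.view(-1, 16))              # (15, 4): one score per token and expert
weights, indices = torch.topk(scores, 2, -1)   # keep the top-2 experts per token
print(indices[0], weights.softmax(-1)[0])      # experts chosen for token 0 and their mix

y = moe(x)                                     # weighted sum of the selected experts
assert y.shape == x.shape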

onmt/opts.py (+14)
@@ -901,6 +901,20 @@ def model_opts(parser):
         default=2048,
         help="Size of hidden transformer feed-forward",
     )
+    group.add(
+        "--num_experts",
+        "-num_experts",
+        type=int,
+        default=0,
+        help="Number of experts",
+    )
+    group.add(
+        "--num_experts_per_tok",
+        "-num_experts_per_tok",
+        type=int,
+        default=2,
+        help="Number of experts per token",
+    )
     group.add(
         "--aan_useffn",
         "-aan_useffn",

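For context, the defaults keep MoE disabled (num_experts=0), while a Mixtral-8x7B-style model corresponds to num_experts=8 with num_experts_per_tok=2. The snippet below is a hypothetical mirror of the two new options using plain argparse rather than OpenNMT's own parser; it only illustrates typical values.

# Hypothetical mirror of the two new flags (plain argparse, not onmt's parser).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num_experts", type=int, default=0, help="Number of experts")
parser.add_argument(
    "--num_experts_per_tok", type=int, default=2, help="Number of experts per token"
)

# Mixtral 8x7B: 8 experts, 2 active per token; the defaults keep the FFN dense.
opt = parser.parse_args(["--num_experts", "8", "--num_experts_per_tok", "2"])
print(opt.num_experts, opt.num_experts_per_tok)  # 8 2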