9 | 9 | from onmt.modules import MultiHeadedAttention, AverageAttention
10 | 10 | from onmt.modules.position_ffn import PositionwiseFeedForward
11 | 11 | from onmt.modules.position_ffn import ActivationFunction
| 12 | +from onmt.modules.moe import MoE
12 | 13 | from onmt.utils.misc import sequence_mask
13 | 14 |
14 | 15 | try:
@@ -43,6 +44,8 @@ def __init__(
43 | 44 |         parallel_gpu=1,
44 | 45 |         sliding_window=0,
45 | 46 |         rotary_interleave=True,
| 47 | +        num_experts=0,
| 48 | +        num_experts_per_tok=2,
46 | 49 |     ):
47 | 50 |         """
48 | 51 |         Args:
@@ -109,18 +112,34 @@ def __init__(
109 | 112 |                 d_model, dropout=attention_dropout, aan_useffn=aan_useffn
110 | 113 |             )
111 | 114 |
112 | | -        self.feed_forward = PositionwiseFeedForward(
113 | | -            d_model,
114 | | -            d_ff,
115 | | -            dropout,
116 | | -            pos_ffn_activation_fn,
117 | | -            add_ffnbias,
118 | | -            parallel_residual,
119 | | -            layer_norm,
120 | | -            norm_eps,
121 | | -            use_ckpting=use_ckpting,
122 | | -            parallel_gpu=parallel_gpu,
123 | | -        )
| 115 | +        if num_experts > 0:
| 116 | +            self.feed_forward = MoE(
| 117 | +                num_experts,
| 118 | +                num_experts_per_tok,
| 119 | +                d_model,
| 120 | +                d_ff,
| 121 | +                dropout,
| 122 | +                pos_ffn_activation_fn,
| 123 | +                add_ffnbias,
| 124 | +                parallel_residual,
| 125 | +                layer_norm,
| 126 | +                norm_eps,
| 127 | +                use_ckpting=use_ckpting,
| 128 | +                parallel_gpu=parallel_gpu,
| 129 | +            )
| 130 | +        else:
| 131 | +            self.feed_forward = PositionwiseFeedForward(
| 132 | +                d_model,
| 133 | +                d_ff,
| 134 | +                dropout,
| 135 | +                pos_ffn_activation_fn,
| 136 | +                add_ffnbias,
| 137 | +                parallel_residual,
| 138 | +                layer_norm,
| 139 | +                norm_eps,
| 140 | +                use_ckpting=use_ckpting,
| 141 | +                parallel_gpu=parallel_gpu,
| 142 | +            )
124 | 143 |         self.parallel_residual = parallel_residual
125 | 144 |         self.shared_layer_norm = shared_layer_norm
126 | 145 |         if layer_norm == "standard":
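The hunk above swaps the dense `PositionwiseFeedForward` for `MoE` whenever `num_experts > 0`. The `onmt.modules.moe.MoE` implementation itself is not part of this excerpt, so the following is only a minimal, self-contained sketch of a top-k-routed mixture of feed-forward experts with illustrative names and a simplified constructor; the real call above also forwards `pos_ffn_activation_fn`, `add_ffnbias`, `layer_norm`, `use_ckpting`, `parallel_gpu`, and so on.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MoESketch(nn.Module):
    """Illustrative top-k mixture-of-experts feed-forward block.

    Each token is scored by a linear gate, routed to its top
    ``num_experts_per_tok`` experts, and the expert outputs are combined
    with softmax-normalized gate weights. This is a sketch only, not the
    PR's ``onmt.modules.moe.MoE``.
    """

    def __init__(self, num_experts, num_experts_per_tok, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Hypothetical experts: plain two-layer FFNs standing in for
        # PositionwiseFeedForward instances.
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(d_ff, d_model),
            )
            for _ in range(num_experts)
        )
        self.gate = nn.Linear(d_model, num_experts, bias=False)
        self.num_experts_per_tok = num_experts_per_tok

    def forward(self, x):
        # x: (batch, seq_len, d_model) -> flatten to (tokens, d_model) for routing
        orig_shape = x.shape
        x = x.reshape(-1, x.shape[-1])
        scores = self.gate(x)  # (tokens, num_experts)
        weights, selected = torch.topk(scores, self.num_experts_per_tok, dim=-1)
        weights = F.softmax(weights, dim=-1, dtype=torch.float).to(x.dtype)
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            # Tokens (rows) that picked expert i, and at which top-k slot.
            token_idx, kth = torch.where(selected == i)
            if token_idx.numel() == 0:
                continue
            out[token_idx] += weights[token_idx, kth, None] * expert(x[token_idx])
        return out.reshape(orig_shape)
```

With the default `num_experts=0`, the `else` branch keeps the original dense `PositionwiseFeedForward`, so existing configurations behave exactly as before.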
@@ -261,6 +280,8 @@ def __init__(
261 | 280 |         parallel_gpu=1,
262 | 281 |         sliding_window=0,
263 | 282 |         rotary_interleave=True,
| 283 | +        num_experts=0,
| 284 | +        num_experts_per_tok=2,
264 | 285 |     ):
265 | 286 |         """
266 | 287 |         Args:
@@ -290,6 +311,8 @@ def __init__(
290 | 311 |             parallel_gpu=parallel_gpu,
291 | 312 |             sliding_window=sliding_window,
292 | 313 |             rotary_interleave=rotary_interleave,
| 314 | +            num_experts=num_experts,
| 315 | +            num_experts_per_tok=num_experts_per_tok,
293 | 316 |         )
294 | 317 |         self.context_attn = MultiHeadedAttention(
295 | 318 |             heads,
@@ -450,6 +473,8 @@ def from_opt(cls, opt, embeddings):
450 | 473 |             else 1,
451 | 474 |             sliding_window=opt.sliding_window,
452 | 475 |             rotary_interleave=opt.rotary_interleave,
| 476 | +            num_experts=opt.num_experts,
| 477 | +            num_experts_per_tok=opt.num_experts_per_tok,
453 | 478 |         )
454 | 479 |
455 | 480 |     def init_state(self, src, enc_out, enc_final_hs):
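`from_opt` above now reads `opt.num_experts` and `opt.num_experts_per_tok`, which assumes matching model options are registered elsewhere in the PR (e.g. in `onmt/opts.py`, not shown in this excerpt). A hypothetical argparse stand-in with the defaults implied by the constructor signatures:

```python
# Hypothetical stand-in for the option definitions that from_opt() now expects.
# The real options would live elsewhere in the PR (e.g. onmt/opts.py); only the
# names and defaults below are implied by this diff.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--num_experts",
    type=int,
    default=0,
    help="Number of MoE experts; 0 keeps the dense PositionwiseFeedForward.",
)
parser.add_argument(
    "--num_experts_per_tok",
    type=int,
    default=2,
    help="Number of experts each token is routed to.",
)
opt = parser.parse_args([])  # defaults: MoE disabled
```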
@@ -569,6 +594,8 @@ def __init__(
569 | 594 |         parallel_gpu=1,
570 | 595 |         sliding_window=0,
571 | 596 |         rotary_interleave=True,
| 597 | +        num_experts=0,
| 598 | +        num_experts_per_tok=2,
572 | 599 |     ):
573 | 600 |         super(TransformerDecoder, self).__init__(
574 | 601 |             d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
@@ -600,6 +627,8 @@ def __init__(
600 | 627 |                     parallel_gpu=parallel_gpu,
601 | 628 |                     sliding_window=sliding_window,
602 | 629 |                     rotary_interleave=rotary_interleave,
| 630 | +                    num_experts=num_experts,
| 631 | +                    num_experts_per_tok=num_experts_per_tok,
603 | 632 |                 )
604 | 633 |                 for i in range(num_layers)
605 | 634 |             ]
@@ -836,6 +865,8 @@ def __init__(
836 | 865 |         parallel_gpu=1,
837 | 866 |         sliding_window=0,
838 | 867 |         rotary_interleave=True,
| 868 | +        num_experts=0,
| 869 | +        num_experts_per_tok=2,
839 | 870 |     ):
840 | 871 |         super(TransformerLMDecoder, self).__init__(
841 | 872 |             d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
@@ -866,6 +897,8 @@ def __init__(
866 | 897 |                     parallel_gpu=parallel_gpu,
867 | 898 |                     sliding_window=sliding_window,
868 | 899 |                     rotary_interleave=rotary_interleave,
| 900 | +                    num_experts=num_experts,
| 901 | +                    num_experts_per_tok=num_experts_per_tok,
869 | 902 |                 )
870 | 903 |                 for i in range(num_layers)
871 | 904 |             ]