-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtransformer.py
60 lines (48 loc) · 2.18 KB
/
transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from tensorflow import keras
from models import EncoderModel, DecoderModel
class Transformer(keras.Model):
    """Encoder-decoder model: EncoderModel + DecoderModel + a Dense head.

    Positional order is (num_layers, input_vocab_size, target_vocab_size,
    max_length, d_model, num_heads, dff, rate). Several of these are plain
    ints that are easy to transpose, so prefer passing them by keyword.

    Args:
        num_layers: Forwarded to both EncoderModel and DecoderModel
            (presumably the number of stacked layers in each).
        input_vocab_size: Forwarded to EncoderModel.
        target_vocab_size: Forwarded to DecoderModel; also the number of
            units of the final Dense projection, i.e. the size of the last
            axis of the predictions.
        max_length: Forwarded to both sub-models (presumably the maximum
            supported sequence length -- confirm in models.py).
        d_model: Forwarded to both sub-models; per the shape notes in
            ``call`` it is the last dimension of the encoder/decoder outputs.
        num_heads: Forwarded to both sub-models (presumably attention heads).
        dff: Forwarded to both sub-models (presumably the feed-forward
            inner width).
        rate: Forwarded to both sub-models (presumably the dropout rate).
        **kwargs: Passed through to ``keras.Model`` (e.g. ``name``).
    """

    def __init__(self, num_layers, input_vocab_size,
                 target_vocab_size, max_length, d_model, num_heads, dff,
                 rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Both sub-models receive identical architecture hyperparameters;
        # only the vocabulary size differs between source and target side.
        self.encoder = EncoderModel(
            num_layers, input_vocab_size, max_length,
            d_model, num_heads, dff, rate)
        self.decoder = DecoderModel(
            num_layers, target_vocab_size, max_length,
            d_model, num_heads, dff, rate)
        # No activation given, so this yields unnormalized per-token scores
        # over the target vocabulary.
        self.final_layer = keras.layers.Dense(target_vocab_size)

    def call(self, inp, tar, training, encoder_padding_mask,
             decoder_mask, encoder_decoder_padding_mask):
        """Run one encoder-decoder forward pass.

        Args:
            inp: Source tokens, shape (batch_size, inp_seq_len).
            tar: Target-side decoder input, presumably
                (batch_size, tar_seq_len).
            training: Forwarded to both the encoder and the decoder.
            encoder_padding_mask: Mask forwarded to the encoder.
            decoder_mask: Mask forwarded to the decoder (presumably the
                look-ahead/padding mask for target self-attention -- confirm).
            encoder_decoder_padding_mask: Second mask forwarded to the
                decoder (presumably for attending over encoder outputs).

        Returns:
            A ``(predictions, attention_weights)`` tuple. ``predictions`` has
            shape (batch_size, tar_seq_len, target_vocab_size);
            ``attention_weights`` is the decoder's second output, returned
            unchanged.
        """
        # encoding_outputs.shape: (batch_size, inp_seq_len, d_model)
        encoding_outputs = self.encoder(
            inp, training, encoder_padding_mask)
        # decoding_outputs.shape: (batch_size, tar_seq_len, d_model)
        decoding_outputs, attention_weights = self.decoder(
            tar, encoding_outputs, training,
            decoder_mask, encoder_decoder_padding_mask)
        # predictions.shape: (batch_size, tar_seq_len, target_vocab_size)
        predictions = self.final_layer(decoding_outputs)
        return predictions, attention_weights
if __name__ == '__main__':
    import tensorflow as tf

    # Smoke test: build a small model and run one forward pass on random
    # token ids.
    input_vocab_size = 8216
    target_vocab_size = 8089
    # Pass hyperparameters by keyword so they cannot be transposed against
    # the constructor's positional order (num_layers, input_vocab_size,
    # target_vocab_size, max_length, d_model, num_heads, dff).
    sample_transformer = Transformer(
        num_layers=4,
        input_vocab_size=input_vocab_size,
        target_vocab_size=target_vocab_size,
        max_length=40,
        d_model=128,
        num_heads=8,
        dff=512,
        rate=0.1)

    # tf.random.uniform defaults to floats in [0, 1), which are not valid
    # token ids; draw integer ids in [0, vocab_size) instead. Sequence
    # lengths (26, 31) stay below max_length=40, which presumably bounds
    # the positional encoding -- confirm in models.py.
    temp_input = tf.random.uniform(
        (64, 26), minval=0, maxval=input_vocab_size, dtype=tf.int64)
    temp_target = tf.random.uniform(
        (64, 31), minval=0, maxval=target_vocab_size, dtype=tf.int64)

    # Forward pass with all masks disabled.
    predictions, attention_weights = sample_transformer(
        temp_input, temp_target, training=False,
        encoder_padding_mask=None,
        decoder_mask=None,
        encoder_decoder_padding_mask=None)

    # Expected: (64, 31, target_vocab_size)
    print(predictions.shape)
    print('-' * 50)
    # Print each attention-weight tensor's shape, in preparation for
    # plotting them later.
    for key in attention_weights:
        print(key, attention_weights[key].shape)
    print('-' * 50)
    sample_transformer.summary()