att.py
# -*- coding: utf-8 -*-
from keras import backend as K
from keras.engine.topology import Layer


# Attention layer built with Keras
class Attention(Layer):

    def __init__(self, attention_size, **kwargs):
        self.attention_size = attention_size
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        # input_shape: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE)
        # W: (EMBED_SIZE, ATTENTION_SIZE)
        # b: (MAX_TIMESTEPS, 1)
        # u: (ATTENTION_SIZE, 1)
        self.W = self.add_weight(name="W_{:s}".format(self.name),
                                 shape=(input_shape[-1], self.attention_size),
                                 initializer="glorot_normal",
                                 trainable=True)
        self.b = self.add_weight(name="b_{:s}".format(self.name),
                                 shape=(input_shape[1], 1),
                                 initializer="zeros",
                                 trainable=True)
        self.u = self.add_weight(name="u_{:s}".format(self.name),
                                 shape=(self.attention_size, 1),
                                 initializer="glorot_normal",
                                 trainable=True)
        super(Attention, self).build(input_shape)

    def call(self, x, mask=None):
        # input: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE)
        # et: (BATCH_SIZE, MAX_TIMESTEPS, ATTENTION_SIZE)
        et = K.tanh(K.dot(x, self.W) + self.b)
        # at: (BATCH_SIZE, MAX_TIMESTEPS), one attention weight per timestep
        at = K.softmax(K.squeeze(K.dot(et, self.u), axis=-1))
        if mask is not None:
            # zero out the weights of masked (padded) timesteps
            at *= K.cast(mask, K.floatx())
        # ot: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE)
        atx = K.expand_dims(at, axis=-1)
        ot = atx * x
        # output: weighted sum over timesteps, (BATCH_SIZE, EMBED_SIZE)
        output = K.sum(ot, axis=1)
        return output

    def compute_mask(self, input, input_mask=None):
        # the timestep dimension is summed away, so no mask is passed on
        return None

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[-1])

    # needed so the custom layer can be saved and reloaded with the model
    def get_config(self):
        config = {"attention_size": self.attention_size}
        base_config = super(Attention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
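For reference (this note is not part of the original file), the forward pass in call() is a standard additive attention pooling. With x_t the embedding at timestep t, W and u the weights created in build(), and b_t the per-timestep scalar bias broadcast across the attention dimensions, the layer computes, in LaTeX notation:

    e_t = \tanh(W^\top x_t + b_t), \qquad
    a_t = \frac{\exp(u^\top e_t)}{\sum_{t'} \exp(u^\top e_{t'})}, \qquad
    \text{output} = \sum_t a_t \, x_t

so each example is reduced from (MAX_TIMESTEPS, EMBED_SIZE) to a single EMBED_SIZE vector; when a mask is supplied, the weights a_t of padded timesteps are simply zeroed out without renormalization.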
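Below is a minimal usage sketch, not taken from the repository: it assumes a binary text-classification setup, placeholder hyperparameters (VOCAB_SIZE, MAX_TIMESTEPS, EMBED_SIZE, ATTENTION_SIZE), and the same Keras generation the layer targets (keras.engine.topology). The attention layer sits after an LSTM that returns the full sequence.

from keras.layers import Input, Embedding, LSTM, Dense
from keras.models import Model

# placeholder hyperparameters, for illustration only
VOCAB_SIZE, MAX_TIMESTEPS, EMBED_SIZE, ATTENTION_SIZE = 10000, 100, 128, 64

inputs = Input(shape=(MAX_TIMESTEPS,))
# (BATCH_SIZE, MAX_TIMESTEPS) -> (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE);
# mask_zero=True produces the mask that Attention.call() can consume
x = Embedding(VOCAB_SIZE, EMBED_SIZE, mask_zero=True)(inputs)
# return_sequences=True keeps every timestep for the attention layer
x = LSTM(EMBED_SIZE, return_sequences=True)(x)
# (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE) -> (BATCH_SIZE, EMBED_SIZE)
x = Attention(ATTENTION_SIZE)(x)
outputs = Dense(1, activation="sigmoid")(x)

model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

Because get_config() serializes attention_size, a saved model can be restored with keras.models.load_model(path, custom_objects={"Attention": Attention}).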