@@ -42,7 +42,7 @@ def call(self, inputs, mask=None, **kwargs):
4242 s = q @ tf .transpose (k , [0 , 1 , 3 , 2 ]) / (tf .math .sqrt (self .k_f ) + 1e-8 )
4343 if mask is not None :
4444 s += mask * - 1e9
45- a = tf .nn .softmax (s ) # [b,h,attention,s]
45+ a = tf .nn .softmax (s )
4646 self .attention = a
4747 b = a @ v
4848 o = tf .concat (tf .unstack (b , axis = 1 ), 2 ) @ self .wo
@@ -188,7 +188,7 @@ def call(self, inputs, training=None, **kwargs):
188188 pad_mask = self ._pad_mask (x )
189189 encoded_z = self .encoder (x_embed , mask = pad_mask )
190190 decoded_z = self .decoder (
191- (encoded_z , y_embed ), look_ahead_mask = self ._look_ahead_mask (x ), pad_mask = pad_mask )
191+ (encoded_z , y_embed ), look_ahead_mask = self ._look_ahead_mask (y ), pad_mask = pad_mask )
192192 o = self .o (decoded_z )
193193 return o
194194
@@ -213,7 +213,7 @@ def translate(self, src, i2v, v2i):
213213 y = tgt [:, :- 1 ]
214214 y_embed = self .embed (y )
215215 decoded_z = self .decoder (
216- (encoded_z , y_embed ), look_ahead_mask = self ._look_ahead_mask (src_pad ), pad_mask = self ._pad_mask (src_pad ))
216+ (encoded_z , y_embed ), look_ahead_mask = self ._look_ahead_mask (y ), pad_mask = self ._pad_mask (src_pad ))
217217 logit = self .o (decoded_z )[:, tgti , :].numpy ()
218218 idx = np .argmax (logit , 1 )
219219 tgti += 1
0 commit comments