diff --git a/xlm/model/transformer.py b/xlm/model/transformer.py index 53660e16..6a61ab0e 100755 --- a/xlm/model/transformer.py +++ b/xlm/model/transformer.py @@ -206,8 +206,8 @@ def unshape(x): k, v = cache[self.layer_id] cache[self.layer_id] = (k, v) - q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head) scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen) + scores = scores / math.sqrt(dim_per_head) # (bs, n_heads, qlen, klen) mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen) scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen)