diff --git a/test/ops/self_attention.py b/test/ops/self_attention.py
index 0c61d96..a042b51 100644
--- a/test/ops/self_attention.py
+++ b/test/ops/self_attention.py
@@ -15,7 +15,7 @@ def torch_self_attention(attn_val, query, key, value, scale):
     L, S = query.size(-2), key.size(-2)
     attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
-    temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+    temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=S-L)
    attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    attn_bias.to(query.dtype)
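
Note on the change (not part of the patch): with diagonal=0 the causal mask implicitly assumes L == S. When the query is shorter than the key sequence, as in incremental decoding against a KV cache where the last L queries attend over all S cached keys, the mask should be aligned to the bottom-right so the final query position can see every key. A minimal standalone sketch of the difference, using illustrative sizes L=2, S=5:

import torch

L, S = 2, 5  # illustrative sizes; any L < S shows the same effect

# diagonal=0 keeps entries where key index j <= query index i, so the
# trailing queries lose access to most of the cached keys:
top_left = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
# tensor([[ True, False, False, False, False],
#         [ True,  True, False, False, False]])

# diagonal=S-L shifts the boundary so query i is aligned with key
# position i + (S - L); the last query then sees all S keys:
bottom_right = torch.ones(L, S, dtype=torch.bool).tril(diagonal=S - L)
# tensor([[ True,  True,  True,  True, False],
#         [ True,  True,  True,  True,  True]])

print(top_left)
print(bottom_right)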