From 00d651ad06ca0a16e23be05b0d36ddda46fada3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=99=AF=E5=AE=87?= <2537738252@qq.com> Date: Mon, 18 Aug 2025 19:18:33 +0800 Subject: [PATCH 1/2] fix test/ops/self_attention.py --- test/ops/self_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ops/self_attention.py b/test/ops/self_attention.py index 0c61d96..aa6f831 100644 --- a/test/ops/self_attention.py +++ b/test/ops/self_attention.py @@ -15,7 +15,7 @@ def torch_self_attention(attn_val, query, key, value, scale): L, S = query.size(-2), key.size(-2) attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) - temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + temp_mask = torch.ones(S, S, dtype=torch.bool).tril(diagonal=0)[-L:, ] attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) From 896616b44f04ab009bb8df44a16d6e66c08454c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=99=AF=E5=AE=87?= <2537738252@qq.com> Date: Tue, 19 Aug 2025 10:50:25 +0800 Subject: [PATCH 2/2] fix test/ops/self_attention.py --- test/ops/self_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ops/self_attention.py b/test/ops/self_attention.py index aa6f831..a042b51 100644 --- a/test/ops/self_attention.py +++ b/test/ops/self_attention.py @@ -15,7 +15,7 @@ def torch_self_attention(attn_val, query, key, value, scale): L, S = query.size(-2), key.size(-2) attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) - temp_mask = torch.ones(S, S, dtype=torch.bool).tril(diagonal=0)[-L:, ] + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=S-L) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype)