2 changes: 1 addition & 1 deletion test/ops/self_attention.py
@@ -15,7 +15,7 @@ def torch_self_attention(attn_val, query, key, value, scale):
     L, S = query.size(-2), key.size(-2)
     attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)

-    temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+    temp_mask = torch.ones(S, S, dtype=torch.bool).tril(diagonal=0)[-L:, ]
     attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
     attn_bias.to(query.dtype)

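For context, a minimal sketch of the effect of this one-line change, assuming the intent is to anchor the causal mask at the bottom-right so that when the query length L is shorter than the key length S (e.g. scoring only the newest positions against a cached prefix) the last query row can still attend to every key. The values of L and S below are illustrative only:

import torch

L, S = 2, 4  # query shorter than key, e.g. new tokens attending over a cached prefix

# Previous construction: diagonal anchored at the top-left, so the later
# query rows are cut off from the most recent keys.
old_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
# tensor([[ True, False, False, False],
#         [ True,  True, False, False]])

# New construction: build the full S x S causal mask, then keep its last
# L rows, which anchors the diagonal at the bottom-right.
new_mask = torch.ones(S, S, dtype=torch.bool).tril(diagonal=0)[-L:, ]
# tensor([[ True,  True,  True, False],
#         [ True,  True,  True,  True]])

When L == S the two constructions produce the same mask; the behavior only differs when the query covers fewer positions than the key.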