
Commit 0b673b2

stuff
SoufianeNoubir committed Dec 24, 2024
1 parent fb51355 commit 0b673b2
Showing 2 changed files with 704 additions and 3 deletions.
8 changes: 5 additions & 3 deletions gbmi/exp_indhead/finetunebound.py
@@ -409,7 +409,7 @@ def least_attention(a, i_1, i_2, j, dic):
                    i_1 - 1: a,
                }
                for i in range(8):
-                   dic.setdefault(i, torch.arange(8)[torch.arange(8) != a])
+                   dic.setdefault(i, torch.arange(26)[torch.arange(26) != a])
                bound[a, i_2, i_1, j] = least_attention(a, i_1, i_2, j, dic)

    bound_soft = bound.softmax(dim=-1)
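
Note: torch.arange(26) matches the torch.arange(d_voc) pattern used elsewhere in this file, so the change appears to widen the default candidate set for each position from tokens 0-7 to the full vocabulary (presumably d_voc = 26). A minimal sketch of the setdefault pattern with hypothetical values, not the module's actual setup:

import torch

a = 3               # hypothetical constrained token
dic = {5: a, 4: a}  # positions already pinned to token a
for i in range(8):  # 8 positions, as in the loop above
    # Every unpinned position may hold any of the 26 tokens except a;
    # setdefault leaves the pinned entries untouched.
    dic.setdefault(i, torch.arange(26)[torch.arange(26) != a])

print(dic[0].numel())  # 25: all tokens except a
print(dic[5])          # 3: the pinned position keeps its scalar value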
@@ -425,7 +425,9 @@ def loss_diff_1(b, i_1, i_2, dic):

        n = torch.arange(d_voc)[torch.arange(d_voc) != b]

-       return (term_5[i, :, n] - term_5[i, :, b].unsqueeze(dim=-1)).max()
+       return (
+           term_5[i_2, dic[i_2]][..., n] - term_5[i_2, :, b].unsqueeze(dim=-1)
+       ).max()

    def loss_diff_2(b, i_1, i_2, dic):
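
Note: the old line indexed term_5[i, :, n] with i undefined in this function (apparently a stale loop variable) and took the max over all rows; the fix indexes position i_2 and restricts the rows to the candidate tokens recorded in dic[i_2]. A sketch under assumed shapes (term_5 as (position, token, output token) with d_voc = 26); the sizes here are illustrative, not taken from the file:

import torch

d_voc, n_pos = 26, 8
term_5 = torch.randn(n_pos, d_voc, d_voc)  # hypothetical tensor

b, i_2 = 7, 5
n = torch.arange(d_voc)[torch.arange(d_voc) != b]     # every output token except b
dic = {i: torch.arange(d_voc) for i in range(n_pos)}  # unconstrained positions

# Rows are limited to the tokens allowed at position i_2 before taking the max,
# rather than being indexed with the undefined i:
diff = (term_5[i_2, dic[i_2]][..., n] - term_5[i_2, :, b].unsqueeze(dim=-1)).max()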

@@ -563,7 +565,7 @@ def total_bound(b, i_1, i_2, dic):
        if (i_1 < i_2) & (i_1 > 0):
            dic = {i_1: b}
            for i in range(8):
-               dic.setdefault(i, torch.arange(8))
+               dic.setdefault(i, torch.arange(26))

            out[b, i_2, i_1] = total_bound(b, i_1, i_2, dic)
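
Note: the same widening as in the first hunk; here the unconstrained positions default to the full 26-token range, while dic[i_1] stays pinned to b via the dict literal above. A quick check of that behaviour with hypothetical values:

import torch

b, i_1 = 2, 3  # hypothetical target token and pinned position
dic = {i_1: b}
for i in range(8):
    dic.setdefault(i, torch.arange(26))

assert dic[i_1] == b         # the pinned position keeps its scalar constraint
assert dic[0].numel() == 26  # every other position covers the whole vocabulary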

