model.py
import torch
import torch.nn as nn
import transformers


class Transformer(nn.Module):
    def __init__(self, transformer_name, output_dim, freeze):
        super().__init__()
        # load a pretrained encoder by name from the Hugging Face hub
        self.transformer = transformers.AutoModel.from_pretrained(transformer_name)
        hidden_dim = self.transformer.config.hidden_size
        # classification head mapping the [CLS] representation to output_dim logits
        self.fc = nn.Linear(hidden_dim, output_dim)
        if freeze:
            # optionally freeze the encoder so only the classification head is trained
            for param in self.transformer.parameters():
                param.requires_grad = False

    def forward(self, ids):
        # ids = [batch size, seq len]
        output = self.transformer(ids, output_attentions=True)
        hidden = output.last_hidden_state
        # hidden = [batch size, seq len, hidden dim]
        attention = output.attentions[-1]
        # attention = [batch size, n heads, seq len, seq len]
        cls_hidden = hidden[:, 0, :]
        prediction = self.fc(torch.tanh(cls_hidden))
        # prediction = [batch size, output dim]
        return prediction
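
A minimal usage sketch (not part of model.py): the checkpoint name "bert-base-uncased", the two-class output_dim, and the example sentences below are illustrative assumptions.

# illustrative usage only; model name and output_dim are assumptions
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
model = Transformer("bert-base-uncased", output_dim=2, freeze=True)

# tokenize a small batch and run a forward pass
batch = tokenizer(["an example sentence", "another one"],
                  padding=True, return_tensors="pt")
logits = model(batch["input_ids"])
# logits = [batch size, output dim]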