# GET SENTENCE EMBEDDINGS (BERT sentence embeddings, used for both post texts and spans)
import torch
from transformers import BertTokenizer, BertModel
from datetime import datetime
import extraction as ex
def get_sentence_embedding(input_string, tokenizer, emb_model):
    # Returns a single [768] embedding: the mean of the token vectors
    # of the second-to-last hidden layer.
    e = tokenizer.encode(input_string, truncation=True, return_tensors='pt')
    emb_model.eval()
    with torch.no_grad():
        outputs = emb_model(e)
        hidden_states = outputs[2]  # all hidden layers (requires output_hidden_states=True)
    # Token-level embeddings, shape [tokens, layers, hidden]; kept for reference, not used below.
    token_embeddings = torch.stack(hidden_states, dim=0)
    token_embeddings = torch.squeeze(token_embeddings, dim=1)
    token_embeddings = token_embeddings.permute(1, 0, 2)
    # SENTENCE EMBEDDING: average the token vectors of the second-to-last hidden layer.
    token_vecs = hidden_states[-2][0]
    sentence_embeddings = torch.mean(token_vecs, dim=0)
    return sentence_embeddings  # shape [768]
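
# Example (sketch, not part of the pipeline): embedding a single string.
# The tokenizer/model setup mirrors the one used in embed_posts() below.
#   tok = BertTokenizer.from_pretrained('bert-base-uncased')
#   mdl = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)
#   vec = get_sentence_embedding("I went for a run today.", tok, mdl)
#   print(vec.shape)  # torch.Size([768])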
def construct_input(spans, tokenizer, emb_model):  # spans: list of (post_text, span_text, label)
    X = []
    i = 0
    for span in spans:
        # the text of the post the span comes from (currently unused)
        # post_text_emb = get_sentence_embedding(span[0], tokenizer, emb_model)
        # the span itself
        span_text_emb = get_sentence_embedding(span[1], tokenizer, emb_model)
        # X.append(torch.cat((post_text_emb, span_text_emb)))
        X.append(span_text_emb)
        i += 1
        if i % 100 == 0:
            print("Data point " + str(i) + " done")
    return X
# A more time-efficient variant: each post is embedded once and that embedding is
# reused for every span in the post (running BERT on the post text is the slow part).
def construct_input_eff(docs, tokenizer, emb_model):  # docs: spaCy Docs whose .ents are the spans
    X = []
    i = 0
    now = datetime.now()
    current_time = now.strftime("%H:%M:%S")
    print(current_time)
    for doc in docs:
        post_text = str(doc)  # the text of the post: str
        spans = doc.ents      # the spans of the post: tuple of spaCy Spans
        # the text of the post the spans come from
        post_text_emb = get_sentence_embedding(post_text, tokenizer, emb_model)
        # each data point: element-wise sum of post embedding and span embedding
        for span in spans:
            span_text_emb = get_sentence_embedding(str(span), tokenizer, emb_model)
            X.append(post_text_emb + span_text_emb)
            i += 1
            if i % 100 == 0:
                now = datetime.now()
                current_time = now.strftime("%H:%M:%S")
                print("Data point " + str(i) + " done at " + current_time)
    return X
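
# Example (sketch, assumes spaCy is installed and that the spans of interest are the
# entities found by a spaCy pipeline; 'en_core_web_sm' is only an illustrative model,
# and posts/tokenizer/model are set up as in embed_posts() below):
#   import spacy
#   nlp = spacy.load("en_core_web_sm")
#   docs = [nlp(p.text) for p in posts]
#   features = construct_input_eff(docs, tokenizer, model)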
def embed_posts():
    # Embed every training post once and cache the embeddings to disk for later reuse.
    embeds = []
    count = 0
    posts = ex.read_posts('st1_public_data/st1_train_inc_text.csv')
    l = len(posts)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased',
                                      output_hidden_states=True)  # return all hidden states
    for p in posts:
        count += 1
        print(str(count) + " /" + str(l))
        embeds.append(get_sentence_embedding(p.text, tokenizer, model))
    torch.save(embeds, 'post_embeddings.pt')
def generate_embeddings_binary_classifier(data_points):
    X = []
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased',
                                      output_hidden_states=True)  # return all hidden states
    posts = ex.read_posts('st1_public_data/st1_train_inc_text.csv')
    # Post embeddings dictionary: post text -> cached embedding.
    # Computing the post embeddings takes a while, so they were computed once and saved
    # in "post_embeddings.pt" (origin: embeddings.py, embed_posts()).
    embeds_posts = dict()
    em = torch.load("post_embeddings.pt")
    for i in range(len(posts)):
        embeds_posts[posts[i].text] = em[i]
    l = len(data_points)
    count = 0
    for x in data_points:  # x = (post_text, span_text, ...)
        count += 1
        if count % 100 == 0:
            print(str(count) + " /" + str(l))
        text_key = x[0]
        if text_key in embeds_posts:
            post_text_emb = embeds_posts[text_key]  # cached post embedding
            span_emb = get_sentence_embedding(x[1], tokenizer, model)
            X.append(post_text_emb + span_emb)
    return X
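
# Minimal driver sketch (assumes the training CSV path used above exists locally and
# that extraction.read_posts returns objects with a .text attribute, as used throughout).
if __name__ == "__main__":
    # Cache one BERT sentence embedding per post to 'post_embeddings.pt';
    # generate_embeddings_binary_classifier() reloads this cache afterwards.
    embed_posts()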