import spacy
import torch  # only needed by the commented-out constituency function below

import extraction as ex  # only needed by the commented-out constituency function below
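# Assumed input shape, inferred from the usage below (RedditPost itself is
# defined in extraction.py; this sketch is illustrative only): each post
# carries a .text string and a .arguments list of offset dicts, e.g.
#   post.text = "some reddit post.."
#   post.arguments = [{'start_offset': 0, 'end_offset': 4, 'label': 'claim'}]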
def convert_posts_to_docs(posts):
    """Convert each RedditPost into a spaCy Doc, attaching its argument spans as entities."""
    nlp = spacy.blank("en")
    # Create the dataset:
    dataset = []  # list of Docs
    # Converting each RedditPost object into a spaCy Doc object makes the spans easier to extract and manage.
    for post in posts:
        spans = []  # the spans of this post are collected here
        doc = nlp(post.text)  # initialize a new Doc object with the text of the Reddit post
        for a in post.arguments:  # extract spans with Doc.char_span, using the character offsets
            # NOTE: alignment_mode='contract' keeps only the tokens fully contained in the range
            # defined by the character offsets. The other options are 'strict' (the offsets must
            # fall exactly on token boundaries) and 'expand' (tokens only partially covered by
            # the range are also included).
            span = doc.char_span(a['start_offset'], a['end_offset'],
                                 label=a['label'], alignment_mode='contract')
            if span is not None:
                spans.append(span)  # collect all of the post's spans here
        doc.set_ents(spans)  # attach the spans to the Doc
        dataset.append(doc)  # collect the Docs here
    return dataset
# print(dataset[1].ents[0].label_)  # get the label of the first span of the second post
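# Minimal sketch of how alignment_mode affects Doc.char_span; the sentence and
# offsets here are illustrative only, not taken from the dataset.
def _demo_alignment_modes():
    nlp = spacy.blank("en")
    doc = nlp("I have severe headaches")
    # chars 7..11 cover "seve", a partial token:
    print(doc.char_span(7, 11, alignment_mode='strict'))    # None: offsets not on token boundaries
    print(doc.char_span(7, 11, alignment_mode='contract'))  # None: no token fully inside the range
    print(doc.char_span(7, 11, alignment_mode='expand'))    # "severe": partially covered token is included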
def generate_dataset(docs):
    """Flatten the Docs into datapoints: one [post text, span text, one-hot labels] per span."""
    X = []  # the dataset
    keys = ["question", "per_exp", "claim", "claim_per_exp"]
    for doc_post in docs:
        post_text = str(doc_post)  # the text of the post: str
        spans = doc_post.ents  # the spans of the post: Tuple[Span]
        for span in spans:
            # initialize the dict, marking all labels with 0:
            labels = dict.fromkeys(keys, 0)
            labels[span.label_] = 1  # set the value of the corresponding label to 1
            data_point = [post_text, str(span), list(labels.values())]  # how the datapoint looks
            X.append(data_point)  # the grain here is the span, with the post text and a one-hot encoding of its label
    return X  # X[0] = ['some reddit post..', 'a span..', [1, 0, 0, 0]]; X = [[], [], ...] - datapoints like X[0]
# print(len(X))  # 12168
# def generate_constituency_based_dataset():  # this will be fed to the model, so the test set
#     # needs the same manipulation; so far this covers only the training data
#     filtered_spans = torch.load('new_constituency_spans.pt')
#     posts = ex.read_posts('st1_public_data/st1_train_inc_text.csv')  # list[RedditPost]
#     X = []  # dataset
#     # generate dataset:
#     length = len(posts) - 1
#     for i in range(length):
#         data_point = [posts[i].text, ]
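if __name__ == "__main__":
    # Hedged usage sketch: assumes extraction.read_posts and the training CSV
    # referenced by the commented-out function above.
    posts = ex.read_posts('st1_public_data/st1_train_inc_text.csv')
    docs = convert_posts_to_docs(posts)
    X = generate_dataset(docs)
    print(len(X))  # 12168 on the training file, per the comment above
    print(X[0])    # e.g. ['some reddit post..', 'a span..', [1, 0, 0, 0]]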