Uploaded correct WN18RR version; added reverse rule script.
TimDettmers committed Aug 22, 2017
1 parent 4832649 commit bc288dd
Showing 6 changed files with 201 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
*.swp
Binary file modified WN18RR.tar.gz
Binary file not shown.
41 changes: 41 additions & 0 deletions create_WN18RR.py
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

# The paths below do not exist (we do not want to encourage the use of WN18) and the filenames are off,
# but if you drop in the WN18 files and adjust the paths, this script will generate WN18RR.

predicates_to_remove = [
'_member_of_domain_topic',
'_synset_domain_usage_of',
'_instance_hyponym',
'_hyponym',
'_member_holonym',
'_synset_domain_region_of',
'_part_of'
]


def read_triples(path):
    triples = []
    with open(path, 'rt') as f:
        for line in f.readlines():
            s, p, o = line.split('\t')
            triples += [(s.strip(), p.strip(), o.strip())]
    return triples


def write_triples(triples, path):
    with open(path, 'wt') as f:
        for (s, p, o) in triples:
            f.write('{}\t{}\t{}\n'.format(s, p, o))

train_triples = read_triples('original/wordnet-mlj12-train.txt')
valid_triples = read_triples('original/wordnet-mlj12-valid.txt')
test_triples = read_triples('original/wordnet-mlj12-test.txt')

filtered_train_triples = [(s, p, o) for (s, p, o) in train_triples if p not in predicates_to_remove]
filtered_valid_triples = [(s, p, o) for (s, p, o) in valid_triples if p not in predicates_to_remove]
filtered_test_triples = [(s, p, o) for (s, p, o) in test_triples if p not in predicates_to_remove]

write_triples(filtered_train_triples, 'wn18-train.tsv')
write_triples(filtered_valid_triples, 'wn18-valid.tsv')
write_triples(filtered_test_triples, 'wn18-test.tsv')
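
For a quick sanity check of the filtering step (not part of the commit), one could append a few lines like the following to the end of create_WN18RR.py; they reuse the script's own read_triples and predicates_to_remove, and the original/ paths are the same placeholder paths the script already assumes:

# Hypothetical sanity check: count how many triples each split keeps
# after removing the filtered predicates.
for split in ['train', 'valid', 'test']:
    triples = read_triples('original/wordnet-mlj12-{}.txt'.format(split))
    kept = [t for t in triples if t[1] not in predicates_to_remove]
    print('{0}: kept {1} of {2} triples'.format(split, len(kept), len(triples)))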
2 changes: 1 addition & 1 deletion training.py → evaluation.py
@@ -9,7 +9,7 @@
from sklearn import metrics

#timer = CUDATimer()
-log = Logger('training_{0}.py.txt'.format(datetime.datetime.now()))
+log = Logger('evaluation{0}.py.txt'.format(datetime.datetime.now()))

def ranking_and_hits(model, dev_rank_batcher, vocab, name):
log.info('')
2 changes: 1 addition & 1 deletion main.py
@@ -10,7 +10,7 @@
from os.path import join
import torch.backends.cudnn as cudnn

-from training import ranking_and_hits
+from evaluation import ranking_and_hits
from model import ConvE, DistMult, Complex

from spodernet.preprocessing.pipeline import Pipeline, DatasetStreamer
157 changes: 157 additions & 0 deletions reverse_rule.py
@@ -0,0 +1,157 @@
from __future__ import print_function
from os.path import join

import os
import numpy as np
import itertools
import sys


if len(sys.argv) > 1:
    dataset_name = sys.argv[1]
else:
    #dataset_name = 'FB15k-237'
    #dataset_name = 'YAGO3-10'
    #dataset_name = 'WN18'
    #dataset_name = 'FB15k'
    dataset_name = 'WN18RR'

threshold = 0.99
print(threshold)

base_path = 'data/{0}/'.format(dataset_name)
files = ['train.txt', 'valid.txt', 'test.txt']

data = []
for p in files:
    with open(join(base_path, p)) as f:
        data = f.readlines() + data

e_set = set()
rel_set = set()
test_cases = {}
rel_to_tuple = {}
e1rel2e2 = {}
existing_triples = set()
for p in files:
    test_cases[p] = []


for p in files:
    with open(join(base_path, p)) as f:
        for i, line in enumerate(f):
            e1, rel, e2 = line.split('\t')
            e1 = e1.strip()
            e2 = e2.strip()
            rel = rel.strip()
            e_set.add(e1)
            e_set.add(e2)
            rel_set.add(rel)
            existing_triples.add((e1, rel, e2))

            if (e1, rel) not in e1rel2e2: e1rel2e2[(e1, rel)] = set()
            e1rel2e2[(e1, rel)].add(e2)

            if rel not in rel_to_tuple:
                rel_to_tuple[rel] = set()

            rel_to_tuple[rel].add((e1, e2))
            test_cases[p].append([e1, rel, e2])


def check_for_reversible_relations(rel_to_tuple, threshold=0.80):
    rel2reversal_rel = {}
    pairs = set()
    for i, rel1 in enumerate(rel_to_tuple):
        if i % 100 == 0:
            print('Processed {0} relations...'.format(i))
        for rel2 in rel_to_tuple:
            tuples2 = rel_to_tuple[rel2]
            tuples1 = rel_to_tuple[rel1]
            # Check whether the entire set of (e1, e2) pairs of one relation is contained,
            # in reversed order, in the pair set of the other relation: that is, ALL
            # (e1, e2) -> (e2, e1) for rel1 are contained in the pair set of rel2 (or vice versa).
            # If the sets overlap (almost) completely, add the rule
            # (e1, rel1, e2) == (e2, rel2, e1).
            n1 = float(len(tuples1))
            n2 = float(len(tuples2))
            left = np.sum([(e2, e1) in tuples2 for (e1, e2) in tuples1]) / n1
            right = np.sum([(e1, e2) in tuples1 for (e2, e1) in tuples2]) / n2
            if left >= threshold or right >= threshold:
                print(left, right, rel1, rel2, n1, n2)
                rel2reversal_rel[rel1] = rel2
                rel2reversal_rel[rel2] = rel1
                if (rel2, rel1) not in pairs:
                    pairs.add((rel1, rel2))
                #print(rel1, rel2, left, right)
    return rel2reversal_rel, pairs

rel2reversal_rel, banned_pairs = check_for_reversible_relations(rel_to_tuple, threshold)

print(rel2reversal_rel)
print(len(rel2reversal_rel))
evaluate = True
if evaluate:
    all_cases = []
    rel2tuples = {}
    train_dev = test_cases['train.txt'] + test_cases['valid.txt']
    for e1, rel, e2 in train_dev:
        if rel not in rel2tuples: rel2tuples[rel] = set()
        rel2tuples[rel].add((e1, e2))
        if rel in rel2reversal_rel:
            rel2 = rel2reversal_rel[rel]
            if rel2 not in rel2tuples: rel2tuples[rel2] = set()
            rel2tuples[rel2].add((e2, e1))

    num_entities = len(e_set)
    ranks = []
    for i, (e1, rel, e2) in enumerate(test_cases['test.txt']):
        if i % 1000 == 0: print(i)
        if rel in rel2reversal_rel:
            rel2 = rel2reversal_rel[rel]
            if (e2, e1) in rel2tuples[rel]: ranks.append(1)
            elif (e2, e1) in rel2tuples[rel2]: ranks.append(1)
            elif (e1, e2) in rel2tuples[rel2]: ranks.append(1)
            else:
                ranks.append(np.random.randint(1, num_entities+1))

            for e2_neg in e_set:
                if (e1, rel, e2_neg) in existing_triples: continue
                if (e1, e2_neg) in rel2tuples[rel]:
                    ranks[-1] += 1
                if (e2_neg, e1) in rel2tuples[rel2]:
                    ranks[-1] += 1

            for e1_neg in e_set:
                if (e1_neg, rel, e2) in existing_triples: continue
                if (e2, e1_neg) in rel2tuples[rel]:
                    ranks[-1] += 1
                if (e2, e1_neg) in rel2tuples[rel2]:
                    ranks[-1] += 1
        else:
            if rel not in rel2tuples:
                ranks.append(np.random.randint(1, num_entities+1))
            elif (e2, e1) in rel2tuples[rel]:
                ranks.append(1)

                for e2_neg in e_set:
                    if (e1, rel, e2_neg) in existing_triples: continue
                    if (e1, e2_neg) in rel2tuples[rel]:
                        ranks[-1] += 1

                for e1_neg in e_set:
                    if (e1_neg, rel, e2) in existing_triples: continue
                    if (e2, e1_neg) in rel2tuples[rel]:
                        ranks[-1] += 1
            else:
                ranks.append(np.random.randint(1, num_entities+1))

    n = float(len(ranks))
    print(n)
    ranks = np.array(ranks)
    for i in range(10):
        print('Hits@{0}: {1:.7}'.format(i+1, np.sum(ranks <= i+1)/n))
    print("MR: {0}".format(np.mean(ranks)))
    print("MRR: {0}".format(np.mean(1.0/ranks)))

print(threshold)
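
To make the reversal test in check_for_reversible_relations concrete, here is a self-contained toy illustration (not in the repository; the relation names '_hyper' and '_hypo' and their pairs are invented) of the same overlap computation:

# Toy example of the overlap test: the (e1, e2) pairs of '_hyper' are exact
# reverses of the pairs of '_hypo', so the overlap ratio reaches 1.0.
import numpy as np

toy_rel_to_tuple = {
    '_hyper': {('dog', 'animal'), ('oak', 'tree')},
    '_hypo': {('animal', 'dog'), ('tree', 'oak')},
}
tuples1 = toy_rel_to_tuple['_hyper']
tuples2 = toy_rel_to_tuple['_hypo']
left = np.sum([(e2, e1) in tuples2 for (e1, e2) in tuples1]) / float(len(tuples1))
print(left)  # 1.0 >= threshold, so (e1, _hyper, e2) == (e2, _hypo, e1) would be recorded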
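
The metrics printed at the end follow the usual link-prediction definitions; a tiny worked example with invented ranks (not results from WN18RR):

# Worked example of MR, MRR, and Hits@1 on three invented ranks.
import numpy as np

ranks = np.array([1, 2, 10])
print(np.mean(ranks))            # MR  = (1 + 2 + 10) / 3 ~= 4.33
print(np.mean(1.0 / ranks))      # MRR = (1 + 0.5 + 0.1) / 3 ~= 0.53
print(np.sum(ranks <= 1) / 3.0)  # Hits@1 = 1 / 3 ~= 0.33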
