Uploaded correct WN18RR version; added reverse rule script.
1 parent 4832649 · commit bc288dd
Showing 6 changed files with 201 additions and 2 deletions.
@@ -0,0 +1 @@
*.swp
Binary file not shown.
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

# Build WN18RR from the original WN18 (wordnet-mlj12) splits by removing the
# predicates listed below. The paths below do not exist (we do not want to
# encourage the use of WN18) and the filenames are off, but if you supply the
# WN18 files and adjust the paths, this script will generate WN18RR.

predicates_to_remove = [
    '_member_of_domain_topic',
    '_synset_domain_usage_of',
    '_instance_hyponym',
    '_hyponym',
    '_member_holonym',
    '_synset_domain_region_of',
    '_part_of'
]


def read_triples(path):
    # Read tab-separated (subject, predicate, object) triples from path.
    triples = []
    with open(path, 'rt') as f:
        for line in f.readlines():
            s, p, o = line.split('\t')
            triples += [(s.strip(), p.strip(), o.strip())]
    return triples


def write_triples(triples, path):
    # Write triples back out as tab-separated lines.
    with open(path, 'wt') as f:
        for (s, p, o) in triples:
            f.write('{}\t{}\t{}\n'.format(s, p, o))


train_triples = read_triples('original/wordnet-mlj12-train.txt')
valid_triples = read_triples('original/wordnet-mlj12-valid.txt')
test_triples = read_triples('original/wordnet-mlj12-test.txt')

# Drop every triple whose predicate is in the removal list.
filtered_train_triples = [(s, p, o) for (s, p, o) in train_triples if p not in predicates_to_remove]
filtered_valid_triples = [(s, p, o) for (s, p, o) in valid_triples if p not in predicates_to_remove]
filtered_test_triples = [(s, p, o) for (s, p, o) in test_triples if p not in predicates_to_remove]

write_triples(filtered_train_triples, 'wn18-train.tsv')
write_triples(filtered_valid_triples, 'wn18-valid.tsv')
write_triples(filtered_test_triples, 'wn18-test.tsv')
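As a quick sanity check (illustrative only, not part of this commit), the generated splits can be re-read and summarized; the file names are the ones written by the script above:

# Hypothetical check (not in the repository): count triples and distinct predicates
# in each generated split to confirm the seven predicates above were removed.
def count_split(path):
    with open(path, 'rt') as f:
        triples = [tuple(field.strip() for field in line.split('\t')) for line in f]
    predicates = {p for (_, p, _) in triples}
    return len(triples), len(predicates)

for name in ['wn18-train.tsv', 'wn18-valid.tsv', 'wn18-test.tsv']:
    n_triples, n_predicates = count_split(name)
    print('{0}: {1} triples, {2} distinct predicates'.format(name, n_triples, n_predicates))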
@@ -0,0 +1,157 @@
from __future__ import print_function
from os.path import join

import os
import numpy as np
import itertools
import sys

# Rule-based "inverse model": detect pairs of (approximately) inverse relations in the
# data and evaluate how well test triples can be ranked using that rule alone.

if len(sys.argv) > 1:
    dataset_name = sys.argv[1]
else:
    #dataset_name = 'FB15k-237'
    #dataset_name = 'YAGO3-10'
    #dataset_name = 'WN18'
    #dataset_name = 'FB15k'
    dataset_name = 'WN18RR'

threshold = 0.99
print(threshold)

base_path = 'data/{0}/'.format(dataset_name)
files = ['train.txt', 'valid.txt', 'test.txt']

data = []
for p in files:
    with open(join(base_path, p)) as f:
        data = f.readlines() + data

e_set = set()
rel_set = set()
test_cases = {}
rel_to_tuple = {}
e1rel2e2 = {}
existing_triples = set()
for p in files:
    test_cases[p] = []


# Index all splits: entities, relations, (e1, e2) tuples per relation, and the full
# set of existing triples (used later for filtered ranking).
for p in files:
    with open(join(base_path, p)) as f:
        for i, line in enumerate(f):
            e1, rel, e2 = line.split('\t')
            e1 = e1.strip()
            e2 = e2.strip()
            rel = rel.strip()
            e_set.add(e1)
            e_set.add(e2)
            rel_set.add(rel)
            existing_triples.add((e1, rel, e2))

            if (e1, rel) not in e1rel2e2: e1rel2e2[(e1, rel)] = set()
            e1rel2e2[(e1, rel)].add(e2)

            if rel not in rel_to_tuple:
                rel_to_tuple[rel] = set()

            rel_to_tuple[rel].add((e1, e2))
            test_cases[p].append([e1, rel, e2])


def check_for_reversible_relations(rel_to_tuple, threshold=0.80):
    rel2reversal_rel = {}
    pairs = set()
    for i, rel1 in enumerate(rel_to_tuple):
        if i % 100 == 0:
            print('Processed {0} relations...'.format(i))
        for rel2 in rel_to_tuple:
            tuples2 = rel_to_tuple[rel2]
            tuples1 = rel_to_tuple[rel1]
            # Check whether the set of (e1, e2) pairs of one relation is contained in the
            # set of the other relation in reversed order, i.e. whether (almost) every
            # (e1, e2) of rel1 appears as (e2, e1) in rel2, or vice versa. If the overlap
            # in either direction reaches the threshold, add the rule
            # (e1, rel1, e2) == (e2, rel2, e1).
            n1 = float(len(tuples1))
            n2 = float(len(tuples2))
            left = np.sum([(e2, e1) in tuples2 for (e1, e2) in tuples1]) / n1
            right = np.sum([(e1, e2) in tuples1 for (e2, e1) in tuples2]) / n2
            if left >= threshold or right >= threshold:
                print(left, right, rel1, rel2, n1, n2)
                rel2reversal_rel[rel1] = rel2
                rel2reversal_rel[rel2] = rel1
                if (rel2, rel1) not in pairs:
                    pairs.add((rel1, rel2))
                #print(rel1, rel2, left, right)
    return rel2reversal_rel, pairs


rel2reversal_rel, banned_pairs = check_for_reversible_relations(rel_to_tuple, threshold)

print(rel2reversal_rel)
print(len(rel2reversal_rel))
evaluate = True
if evaluate:
    all_cases = []
    # Tuples seen in train+valid, plus the tuples implied by the detected inverse rules.
    rel2tuples = {}
    train_dev = test_cases['train.txt'] + test_cases['valid.txt']
    for e1, rel, e2 in train_dev:
        if rel not in rel2tuples: rel2tuples[rel] = set()
        rel2tuples[rel].add((e1, e2))
        if rel in rel2reversal_rel:
            rel2 = rel2reversal_rel[rel]
            if rel2 not in rel2tuples: rel2tuples[rel2] = set()
            rel2tuples[rel2].add((e2, e1))

    # For each test triple: rank 1 if the rule (or a previously seen tuple) predicts it,
    # otherwise a random rank; competing candidate entities that the rule also predicts
    # (and that do not form existing triples) increase the rank.
    num_entities = len(e_set)
    ranks = []
    for i, (e1, rel, e2) in enumerate(test_cases['test.txt']):
        if i % 1000 == 0: print(i)
        if rel in rel2reversal_rel:
            rel2 = rel2reversal_rel[rel]
            if (e2, e1) in rel2tuples[rel]: ranks.append(1)
            elif (e2, e1) in rel2tuples[rel2]: ranks.append(1)
            elif (e1, e2) in rel2tuples[rel2]: ranks.append(1)
            else:
                ranks.append(np.random.randint(1, num_entities+1))

            for e2_neg in e_set:
                if (e1, rel, e2_neg) in existing_triples: continue
                if (e1, e2_neg) in rel2tuples[rel]:
                    ranks[-1] += 1
                if (e2_neg, e1) in rel2tuples[rel2]:
                    ranks[-1] += 1

            for e1_neg in e_set:
                if (e1_neg, rel, e2) in existing_triples: continue
                if (e2, e1_neg) in rel2tuples[rel]:
                    ranks[-1] += 1
                if (e2, e1_neg) in rel2tuples[rel2]:
                    ranks[-1] += 1
        else:
            if rel not in rel2tuples:
                ranks.append(np.random.randint(1, num_entities+1))
            elif (e2, e1) in rel2tuples[rel]:
                ranks.append(1)

                for e2_neg in e_set:
                    if (e1, rel, e2_neg) in existing_triples: continue
                    if (e1, e2_neg) in rel2tuples[rel]:
                        ranks[-1] += 1

                for e1_neg in e_set:
                    if (e1_neg, rel, e2) in existing_triples: continue
                    if (e2, e1_neg) in rel2tuples[rel]:
                        ranks[-1] += 1
            else:
                ranks.append(np.random.randint(1, num_entities+1))

    n = float(len(ranks))
    print(n)
    ranks = np.array(ranks)
    for i in range(10):
        print('Hits@{0}: {1:.7}'.format(i+1, np.sum(ranks <= i+1)/n))
    print("MR: {0}".format(np.mean(ranks)))
    print("MRR: {0}".format(np.mean(1.0/ranks)))

print(threshold)
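For intuition, here is a small, self-contained illustration (not part of the commit) of the reversed-containment test that check_for_reversible_relations performs: two relations are flagged when nearly all (e1, e2) pairs of one appear reversed in the other. The relation names echo WordNet, but the tiny entity pairs are invented for the example.

# Toy illustration (not in the repository) of the reversed-containment test used above.
import numpy as np

rel_to_tuple = {
    '_hypernym': {('dog', 'canine'), ('cat', 'feline')},
    '_hyponym': {('canine', 'dog'), ('feline', 'cat')},
}
threshold = 0.99

for rel1 in rel_to_tuple:
    for rel2 in rel_to_tuple:
        if rel1 == rel2:
            continue
        tuples1, tuples2 = rel_to_tuple[rel1], rel_to_tuple[rel2]
        # Fraction of rel1 pairs that appear reversed in rel2, and vice versa.
        left = np.mean([(e2, e1) in tuples2 for (e1, e2) in tuples1])
        right = np.mean([(e1, e2) in tuples1 for (e2, e1) in tuples2])
        if left >= threshold or right >= threshold:
            print('{0} looks like the inverse of {1}'.format(rel1, rel2))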