diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..1377554
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+*.swp
diff --git a/WN18RR.tar.gz b/WN18RR.tar.gz
index 2fc4684..37f00f2 100644
Binary files a/WN18RR.tar.gz and b/WN18RR.tar.gz differ
diff --git a/create_WN18RR.py b/create_WN18RR.py
new file mode 100755
index 0000000..ed36822
--- /dev/null
+++ b/create_WN18RR.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+
+# the paths do not exist (we do not want to encourage the usage of WN18) and the filenames are off, but if you put in the WN18 files
+# and adjust the path you will generate WN18RR.
+
+predicates_to_remove = [
+    '_member_of_domain_topic',
+    '_synset_domain_usage_of',
+    '_instance_hyponym',
+    '_hyponym',
+    '_member_holonym',
+    '_synset_domain_region_of',
+    '_part_of'
+]
+
+
+def read_triples(path):
+    triples = []
+    with open(path, 'rt') as f:
+        for line in f.readlines():
+            s, p, o = line.split('\t')
+            triples += [(s.strip(), p.strip(), o.strip())]
+    return triples
+
+
+def write_triples(triples, path):
+    with open(path, 'wt') as f:
+        for (s, p, o) in triples:
+            f.write('{}\t{}\t{}\n'.format(s, p, o))
+
+train_triples = read_triples('original/wordnet-mlj12-train.txt')
+valid_triples = read_triples('original/wordnet-mlj12-valid.txt')
+test_triples = read_triples('original/wordnet-mlj12-test.txt')
+
+filtered_train_triples = [(s, p, o) for (s, p, o) in train_triples if p not in predicates_to_remove]
+filtered_valid_triples = [(s, p, o) for (s, p, o) in valid_triples if p not in predicates_to_remove]
+filtered_test_triples = [(s, p, o) for (s, p, o) in test_triples if p not in predicates_to_remove]
+
+write_triples(filtered_train_triples, 'wn18-train.tsv')
+write_triples(filtered_valid_triples, 'wn18-valid.tsv')
+write_triples(filtered_test_triples, 'wn18-test.tsv')
diff --git a/training.py b/evaluation.py
similarity index 98%
rename from training.py
rename to evaluation.py
index 5d6a3e1..e91a8f1 100644
--- a/training.py
+++ b/evaluation.py
@@ -9,7 +9,7 @@ from sklearn import metrics
 
 
 #timer = CUDATimer()
-log = Logger('training_{0}.py.txt'.format(datetime.datetime.now()))
+log = Logger('evaluation{0}.py.txt'.format(datetime.datetime.now()))
 
 def ranking_and_hits(model, dev_rank_batcher, vocab, name):
     log.info('')
diff --git a/main.py b/main.py
index 58e5feb..ea1bf54 100644
--- a/main.py
+++ b/main.py
@@ -10,7 +10,7 @@ from os.path import join
 
 import torch.backends.cudnn as cudnn
 
-from training import ranking_and_hits
+from evaluation import ranking_and_hits
 from model import ConvE, DistMult, Complex
 
 from spodernet.preprocessing.pipeline import Pipeline, DatasetStreamer
diff --git a/reverse_rule.py b/reverse_rule.py
new file mode 100644
index 0000000..c1e65df
--- /dev/null
+++ b/reverse_rule.py
@@ -0,0 +1,157 @@
+from __future__ import print_function
+from os.path import join
+
+import os
+import numpy as np
+import itertools
+import sys
+
+
+if len(sys.argv) > 1:
+    dataset_name = sys.argv[1]
+else:
+    #dataset_name = 'FB15k-237'
+    #dataset_name = 'YAGO3-10'
+    #dataset_name = 'WN18'
+    #dataset_name = 'FB15k'
+    dataset_name = 'WN18RR'
+
+threshold = 0.99
+print(threshold)
+
+base_path = 'data/{0}/'.format(dataset_name)
+files = ['train.txt', 'valid.txt', 'test.txt']
+
+data = []
+for p in files:
+    with open(join(base_path, p)) as f:
+        data = f.readlines() + data
+
+e_set = set()
+rel_set = set()
+test_cases = {}
+rel_to_tuple = {}
+e1rel2e2 = {}
+existing_triples = set()
+for p in files:
+    test_cases[p] = []
+
+
+for p in files:
+    with open(join(base_path, p)) as f:
+        for i, line in enumerate(f):
+            e1, rel, e2 = line.split('\t')
+            e1 = e1.strip()
+            e2 = e2.strip()
+            rel = rel.strip()
+            e_set.add(e1)
+            e_set.add(e2)
+            rel_set.add(rel)
+            existing_triples.add((e1, rel, e2))
+
+            if (e1, rel) not in e1rel2e2: e1rel2e2[(e1, rel)] = set()
+            e1rel2e2[(e1, rel)].add(e2)
+
+            if rel not in rel_to_tuple:
+                rel_to_tuple[rel] = set()
+
+            rel_to_tuple[rel].add((e1,e2))
+            test_cases[p].append([e1, rel, e2])
+
+
+def check_for_reversible_relations(rel_to_tuple, threshold=0.80):
+    rel2reversal_rel = {}
+    pairs = set()
+    for i, rel1 in enumerate(rel_to_tuple):
+        if i % 100 == 0:
+            print('Processed {0} relations...'.format(i))
+        for rel2 in rel_to_tuple:
+            tuples2 = rel_to_tuple[rel2]
+            tuples1 = rel_to_tuple[rel1]
+            # check if the entire set of (e1, e2) is contained in the set of the
+            # other relation, but in a reversed manner
+            # that is ALL (e1, e2) -> (e2, e1) for rel 1 are contained in set entity tuple set of rel2 (and vice versa)
+            # if this is true for ALL entities, that is the sets completely overlap, then add a rule that
+            # (e1, rel1, e2) == (e2, rel2, e1)
+            n1 = float(len(tuples1))
+            n2 = float(len(tuples2))
+            left = np.sum([(e2,e1) in tuples2 for (e1,e2) in tuples1])/n1
+            right = np.sum([(e1,e2) in tuples1 for (e2,e1) in tuples2])/n2
+            if left >= threshold or right >= threshold:
+                print(left, right, rel1, rel2, n1, n2)
+                rel2reversal_rel[rel1] = rel2
+                rel2reversal_rel[rel2] = rel1
+                if (rel2, rel1) not in pairs:
+                    pairs.add((rel1, rel2))
+                #print(rel1, rel2, left, right)
+    return rel2reversal_rel, pairs
+
+rel2reversal_rel, banned_pairs = check_for_reversible_relations(rel_to_tuple, threshold)
+
+print(rel2reversal_rel)
+print(len(rel2reversal_rel))
+evaluate = True
+if evaluate:
+    all_cases = []
+    rel2tuples = {}
+    train_dev = test_cases['train.txt'] + test_cases['valid.txt']
+    for e1, rel, e2 in train_dev:
+        if rel not in rel2tuples: rel2tuples[rel] = set()
+        rel2tuples[rel].add((e1, e2))
+        if rel in rel2reversal_rel:
+            rel2 = rel2reversal_rel[rel]
+            if rel2 not in rel2tuples: rel2tuples[rel2] = set()
+            rel2tuples[rel2].add((e2, e1))
+
+    num_entities = len(e_set)
+    ranks = []
+    for i, (e1, rel, e2) in enumerate(test_cases['test.txt']):
+        if i % 1000 == 0: print(i)
+        if rel in rel2reversal_rel:
+            rel2 = rel2reversal_rel[rel]
+            if (e2, e1) in rel2tuples[rel]: ranks.append(1)
+            elif (e2, e1) in rel2tuples[rel2]: ranks.append(1)
+            elif (e1, e2) in rel2tuples[rel2]: ranks.append(1)
+            else:
+                ranks.append(np.random.randint(1, num_entities+1))
+
+            for e2_neg in e_set:
+                if (e1, rel, e2_neg) in existing_triples: continue
+                if (e1, e2_neg) in rel2tuples[rel]:
+                    ranks[-1] += 1
+                if (e2_neg, e1) in rel2tuples[rel2]:
+                    ranks[-1] += 1
+
+            for e1_neg in e_set:
+                if (e1_neg, rel, e2) in existing_triples: continue
+                if (e2, e1_neg) in rel2tuples[rel]:
+                    ranks[-1] += 1
+                if (e2, e1_neg) in rel2tuples[rel2]:
+                    ranks[-1] += 1
+        else:
+            if rel not in rel2tuples:
+                ranks.append(np.random.randint(1, num_entities+1))
+            elif (e2, e1) in rel2tuples[rel]:
+                ranks.append(1)
+
+                for e2_neg in e_set:
+                    if (e1, rel, e2_neg) in existing_triples: continue
+                    if (e1, e2_neg) in rel2tuples[rel]:
+                        ranks[-1] += 1
+
+                for e1_neg in e_set:
+                    if (e1_neg, rel, e2) in existing_triples: continue
+                    if (e2, e1_neg) in rel2tuples[rel]:
+                        ranks[-1] += 1
+            else:
+                ranks.append(np.random.randint(1, num_entities+1))
+
+    n = float(len(ranks))
+    print(n)
+    ranks = np.array(ranks)
+    for i in range(10):
+        print('Hits@{0}: {1:.7}'.format(i+1, np.sum(ranks <= i+1)/n))
+    print("MR: {0}".format(np.mean(ranks)))
+    print("MRR: {0}".format(np.mean(1.0/ranks)))
+
+print(threshold)
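Note (not part of the diff above): a minimal toy sketch of the containment test that check_for_reversible_relations applies to every pair of relations. The entities and the relation names 'rel_a'/'rel_b' below are made-up examples, not WN18 data; only numpy is assumed, as in reverse_rule.py.

    import numpy as np

    # Hypothetical tuple sets for two relations; in reverse_rule.py these come
    # from rel_to_tuple, built over train + valid + test.
    rel_to_tuple = {
        'rel_a': {('dog', 'canine'), ('cat', 'feline')},
        'rel_b': {('canine', 'dog'), ('feline', 'cat'), ('feline', 'lynx')},
    }
    threshold = 0.99

    tuples1 = rel_to_tuple['rel_a']
    tuples2 = rel_to_tuple['rel_b']
    # Fraction of rel_a pairs whose reversal appears under rel_b, and vice versa.
    left = np.sum([(e2, e1) in tuples2 for (e1, e2) in tuples1]) / float(len(tuples1))   # 1.0
    right = np.sum([(e1, e2) in tuples1 for (e2, e1) in tuples2]) / float(len(tuples2))  # ~0.67
    # One-sided containment above the threshold is enough to flag the pair as reversible.
    print(left >= threshold or right >= threshold)  # True

When the test fires, reverse_rule.py records the two relations as each other's reversal; at evaluation time a test triple (e1, rel, e2) is then answered by looking up the reversed pair (e2, e1) under rel or its recorded reversal in the train+valid tuples, and a random rank is assigned when no rule applies.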