Skip to content

Commit

Permalink
Added FB15k-237, pytorch code, preprocessing script.
Browse files Browse the repository at this point in the history
  • Loading branch information
TimDettmers committed Aug 9, 2017
1 parent 512c4ae commit 9e3ad48
Show file tree
Hide file tree
Showing 10 changed files with 557 additions and 0 deletions.
Binary file added FB15k-237.tar.gz
Binary file not shown.
Binary file not shown.
Binary file removed YAGO3-10/yago3_mte10-test.tsv.gz
Binary file not shown.
Binary file removed YAGO3-10/yago3_mte10-valid.tsv.gz
Binary file not shown.
166 changes: 166 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import json
import torch
import pickle
import numpy as np
import argparse
import sys
import os
import math

from os.path import join
import torch.backends.cudnn as cudnn

from training import ranking_and_hits
from model import ConvE, DistMult, Complex

from spodernet.preprocessing.pipeline import Pipeline, DatasetStreamer
from spodernet.preprocessing.processors import JsonLoaderProcessors, Tokenizer, AddToVocab, SaveLengthsToState, StreamToHDF5, SaveMaxLengthsToState, CustomTokenizer
from spodernet.preprocessing.processors import ConvertTokenToIdx, ApplyFunction, ToLower, DictKey2ListMapper, ApplyFunction, StreamToBatch
from spodernet.utils.global_config import Config, Backends
from spodernet.utils.logger import Logger, LogLevel
from spodernet.preprocessing.batching import StreamBatcher
from spodernet.preprocessing.pipeline import Pipeline
from spodernet.preprocessing.processors import TargetIdx2MultiTarget
from spodernet.hooks import LossHook, ETAHook
from spodernet.utils.util import Timer
from spodernet.utils.cuda_utils import CUDATimer
from spodernet.utils.cuda_utils import CUDATimer
from spodernet.preprocessing.processors import TargetIdx2MultiTarget
np.set_printoptions(precision=3)

timer = CUDATimer()
cudnn.benchmark = True

# parse console parameters and set global variables
Config.backend = Backends.TORCH
Config.parse_argv(sys.argv)

Config.cuda = True

Config.hidden_size = 1
Config.embedding_dim = 200
#Logger.GLOBAL_LOG_LEVEL = LogLevel.DEBUG


model_name = 'ConvE_{0}_{1}'.format(Config.input_dropout, Config.dropout)
do_process = True
epochs = 1000
Config.batch_size = 128
load = False
#dataset_name = 'YAGO3-10'
#dataset_name = 'WN18RR'
dataset_name = 'FB15k-237'
model_path = 'saved_models/{0}_{1}.model'.format(dataset_name, model_name)


''' Preprocess knowledge graph using spodernet. '''
def preprocess(dataset_name, delete_data=False):
full_path = 'data/{0}/e1rel_to_e2_full.json'.format(dataset_name)
train_path = 'data/{0}/e1rel_to_e2_train.json'.format(dataset_name)
dev_ranking_path = 'data/{0}/e1rel_to_e2_ranking_dev.json'.format(dataset_name)
test_ranking_path = 'data/{0}/e1rel_to_e2_ranking_test.json'.format(dataset_name)

keys2keys = {}
keys2keys['e1'] = 'e1' # entities
keys2keys['rel'] = 'rel' # relations
keys2keys['e2'] = 'e1' # entities
keys2keys['e2_multi1'] = 'e1' # entity
keys2keys['e2_multi2'] = 'e1' # entity
input_keys = ['e1', 'rel', 'e2', 'e2_multi1', 'e2_multi2']
d = DatasetStreamer(input_keys)
d.add_stream_processor(JsonLoaderProcessors())
d.add_stream_processor(DictKey2ListMapper(input_keys))

# process full vocabulary and save it to disk
d.set_path(full_path)
p = Pipeline(dataset_name, delete_data, keys=input_keys, skip_transformation=True)
p.add_sent_processor(ToLower())
p.add_sent_processor(CustomTokenizer(lambda x: x.split(' ')),keys=['e2_multi1', 'e2_multi2'])
p.add_token_processor(AddToVocab())
p.execute(d)
p.save_vocabs()


# process train, dev and test sets and save them to hdf5
p.skip_transformation = False
for path, name in zip([train_path, dev_ranking_path, test_ranking_path], ['train', 'dev_ranking', 'test_ranking']):
d.set_path(path)
p.clear_processors()
p.add_sent_processor(ToLower())
p.add_sent_processor(CustomTokenizer(lambda x: x.split(' ')),keys=['e2_multi1', 'e2_multi2'])
p.add_post_processor(ConvertTokenToIdx(keys2keys=keys2keys), keys=['e1', 'rel', 'e2', 'e2_multi1', 'e2_multi2'])
p.add_post_processor(StreamToHDF5(name, samples_per_file=1000, keys=input_keys))
p.execute(d)


def main():
if do_process: preprocess(dataset_name, delete_data=True)
input_keys = ['e1', 'rel', 'e2', 'e2_multi1', 'e2_multi2']
p = Pipeline(dataset_name, keys=input_keys)
p.load_vocabs()
vocab = p.state['vocab']

num_entities = vocab['e1'].num_token

train_batcher = StreamBatcher(dataset_name, 'train', Config.batch_size, randomize=True, keys=input_keys)
dev_rank_batcher = StreamBatcher(dataset_name, 'dev_ranking', Config.batch_size, randomize=False, loader_threads=4, keys=input_keys, is_volatile=True)
test_rank_batcher = StreamBatcher(dataset_name, 'test_ranking', Config.batch_size, randomize=False, loader_threads=4, keys=input_keys, is_volatile=True)


#model = Complex(vocab['e1'].num_token, vocab['rel'].num_token)
#model = DistMult(vocab['e1'].num_token, vocab['rel'].num_token)
model = ConvE(vocab['e1'].num_token, vocab['rel'].num_token)

train_batcher.at_batch_prepared_observers.insert(1,TargetIdx2MultiTarget(num_entities, 'e2_multi1', 'e2_multi1_binary'))


eta = ETAHook('train', print_every_x_batches=100)
train_batcher.subscribe_to_events(eta)
train_batcher.subscribe_to_start_of_epoch_event(eta)
train_batcher.subscribe_to_events(LossHook('train', print_every_x_batches=100))

if Config.cuda:
model.cuda()
if load:
model_params = torch.load(model_path)
print(model)
print([(key, value.size()) for key, value in model_params.items()])
model.load_state_dict(model_params)
model.eval()
ranking_and_hits(model, test_rank_batcher, vocab, 'test_evaluation')
ranking_and_hits(model, dev_rank_batcher, vocab, 'dev_evaluation')
else:
model.init()


opt = torch.optim.Adam(model.parameters(), lr=Config.learning_rate, weight_decay=Config.L2)
for epoch in range(epochs):
model.train()
for i, str2var in enumerate(train_batcher):
opt.zero_grad()
e1 = str2var['e1']
rel = str2var['rel']
e2_multi = str2var['e2_multi1_binary'].float()
# label smoothing
e2_multi = ((1.0-Config.label_smoothing_epsilon)*e2_multi) + (1.0/e2_multi.size(1))

pred = model.forward(e1, rel)
loss = model.loss(pred, e2_multi)
loss.backward()
opt.step()

train_batcher.state.loss = loss


print('saving to {0}'.format(model_path))
torch.save(model.state_dict(), model_path)

model.eval()
ranking_and_hits(model, dev_rank_batcher, vocab, 'dev_evaluation')
if epoch % 3 == 0:
if epoch > 0:
ranking_and_hits(model, test_rank_batcher, vocab, 'test_evaluation')


if __name__ == '__main__':
main()
127 changes: 127 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import torch
from torch.nn import functional as F, Parameter
from torch.autograd import Variable


from spodernet.utils.global_config import Config
from spodernet.utils.cuda_utils import CUDATimer
from torch.nn.init import xavier_normal, xavier_uniform
from spodernet.utils.cuda_utils import CUDATimer
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

timer = CUDATimer()


class Complex(torch.nn.Module):
def __init__(self, num_entities, num_relations):
super(Complex, self).__init__()
self.num_entities = num_entities
self.emb_e_real = torch.nn.Embedding(num_entities, Config.embedding_dim, padding_idx=0)
self.emb_e_img = torch.nn.Embedding(num_entities, Config.embedding_dim, padding_idx=0)
self.emb_rel_real = torch.nn.Embedding(num_relations, Config.embedding_dim, padding_idx=0)
self.emb_rel_img = torch.nn.Embedding(num_relations, Config.embedding_dim, padding_idx=0)
self.inp_drop = torch.nn.Dropout(Config.input_dropout)
self.loss = torch.nn.BCELoss()

def init(self):
xavier_normal(self.emb_e_real.weight.data)
xavier_normal(self.emb_e_img.weight.data)
xavier_normal(self.emb_rel_real.weight.data)
xavier_normal(self.emb_rel_img.weight.data)

def forward(self, e1, rel):

e1_embedded_real = self.inp_drop(self.emb_e_real(e1)).view(Config.batch_size, -1)
rel_embedded_real = self.inp_drop(self.emb_rel_real(rel)).view(Config.batch_size, -1)
e1_embedded_img = self.inp_drop(self.emb_e_img(e1)).view(Config.batch_size, -1)
rel_embedded_img = self.inp_drop(self.emb_rel_img(rel)).view(Config.batch_size, -1)

e1_embedded_real = self.inp_drop(e1_embedded_real)
rel_embedded_real = self.inp_drop(rel_embedded_real)
e1_embedded_img = self.inp_drop(e1_embedded_img)
rel_embedded_img = self.inp_drop(rel_embedded_img)

# complex space bilinear product (equivalent to HolE)
realrealreal = torch.mm(e1_embedded_real*rel_embedded_real, self.emb_e_real.weight.transpose(1,0))
realimgimg = torch.mm(e1_embedded_real*rel_embedded_img, self.emb_e_img.weight.transpose(1,0))
imgrealimg = torch.mm(e1_embedded_img*rel_embedded_real, self.emb_e_img.weight.transpose(1,0))
imgimgreal = torch.mm(e1_embedded_img*rel_embedded_img, self.emb_e_real.weight.transpose(1,0))
pred = realrealreal + realimgimg + imgrealimg - imgimgreal
pred = F.sigmoid(pred)

return pred


class DistMult(torch.nn.Module):
def __init__(self, num_entities, num_relations):
super(DistMult, self).__init__()
self.emb_e = torch.nn.Embedding(num_entities, Config.embedding_dim, padding_idx=0)
self.emb_rel = torch.nn.Embedding(num_relations, Config.embedding_dim, padding_idx=0)
self.inp_drop = torch.nn.Dropout(Config.input_dropout)
self.loss = torch.nn.BCELoss()

def init(self):
xavier_normal(self.emb_e.weight.data)
xavier_normal(self.emb_rel.weight.data)

def forward(self, e1, rel):
e1_embedded= self.emb_e(e1)
rel_embedded= self.emb_rel(rel)
e1_embedded = e1_embedded.view(-1, Config.embedding_dim)
rel_embedded = rel_embedded.view(-1, Config.embedding_dim)

e1_embedded = self.inp_drop(e1_embedded)
rel_embedded = self.inp_drop(rel_embedded)

pred = torch.mm(e1_embedded*rel_embedded, self.emb_e.weight.transpose(1,0))
pred = F.sigmoid(pred)

return pred



class ConvE(torch.nn.Module):
def __init__(self, num_entities, num_relations):
super(ConvE, self).__init__()
self.emb_e = torch.nn.Embedding(num_entities, Config.embedding_dim, padding_idx=0)
self.emb_rel = torch.nn.Embedding(num_relations, Config.embedding_dim, padding_idx=0)
self.inp_drop = torch.nn.Dropout(Config.input_dropout)
self.hidden_drop = torch.nn.Dropout(Config.dropout)
self.feature_map_drop = torch.nn.Dropout2d(Config.feature_map_dropout)
self.loss = torch.nn.BCELoss()

self.conv1 = torch.nn.Conv2d(1, 8, (3, 3), 1, 0, bias=Config.use_bias)
self.bn0 = torch.nn.BatchNorm2d(1)
self.bn1 = torch.nn.BatchNorm2d(8)
self.bn2 = torch.nn.BatchNorm1d(Config.embedding_dim)
self.register_parameter('b', Parameter(torch.zeros(num_entities)))
self.fc = torch.nn.Linear(2592,Config.embedding_dim)
print(num_entities, num_relations)

def init(self):
xavier_normal(self.emb_e.weight.data)
xavier_normal(self.emb_rel.weight.data)

def forward(self, e1, rel):
e1_embedded= self.emb_e(e1).view(Config.batch_size, 1, 10, 20)
rel_embedded = self.emb_rel(rel).view(Config.batch_size, 1, 10, 20)

stacked_inputs = torch.cat([e1_embedded, rel_embedded], 2)

stacked_inputs = self.bn0(stacked_inputs)
x= self.inp_drop(stacked_inputs)
x= self.conv1(x)
x= self.bn1(x)
x= F.relu(x)
x = self.feature_map_drop(x)
x = x.view(Config.batch_size, -1)
x = self.fc(x)
x = self.hidden_drop(x)
x = self.bn2(x)
x = F.relu(x)
x = torch.mm(x, self.emb_e.weight.transpose(1,0))
x += self.b.expand_as(x)
pred = F.sigmoid(x)

return pred

12 changes: 12 additions & 0 deletions preprocess.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/bin/bash
mkdir data
mkdir data/WN18RR
mkdir data/YAGO3-10
mkdir data/FB15k-237
mkdir saved_models
tar -xvf WN18RR.tar.gz -C data/WN18RR
tar -xvf YAGO3-10.tar.gz -C data/YAGO3-10
tar -xvf FB15k-237.tar.gz -C data/FB15k-237
python wrangle_KG.py WN18RR
python wrangle_KG.py YAGO3-10
python wrangle_KG.py FB15k-237
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
-e [email protected]:TimDettmers/spodernet.git#egg=spodernet
Loading

0 comments on commit 9e3ad48

Please sign in to comment.