Commit 9e3ad48 (1 parent: 512c4ae)
Added FB15k-237, pytorch code, preprocessing script.
Showing 10 changed files with 557 additions and 0 deletions.
4 binary files changed (contents not shown).
@@ -0,0 +1,166 @@
import json
import torch
import pickle
import numpy as np
import argparse
import sys
import os
import math

from os.path import join
import torch.backends.cudnn as cudnn

from training import ranking_and_hits
from model import ConvE, DistMult, Complex

from spodernet.preprocessing.pipeline import Pipeline, DatasetStreamer
from spodernet.preprocessing.processors import JsonLoaderProcessors, Tokenizer, AddToVocab, SaveLengthsToState, StreamToHDF5, SaveMaxLengthsToState, CustomTokenizer
from spodernet.preprocessing.processors import ConvertTokenToIdx, ApplyFunction, ToLower, DictKey2ListMapper, StreamToBatch
from spodernet.preprocessing.processors import TargetIdx2MultiTarget
from spodernet.preprocessing.batching import StreamBatcher
from spodernet.utils.global_config import Config, Backends
from spodernet.utils.logger import Logger, LogLevel
from spodernet.hooks import LossHook, ETAHook
from spodernet.utils.util import Timer
from spodernet.utils.cuda_utils import CUDATimer

np.set_printoptions(precision=3)

timer = CUDATimer()
cudnn.benchmark = True

# parse console parameters and set global variables
Config.backend = Backends.TORCH
Config.parse_argv(sys.argv)

Config.cuda = True

Config.hidden_size = 1
Config.embedding_dim = 200
#Logger.GLOBAL_LOG_LEVEL = LogLevel.DEBUG

model_name = 'ConvE_{0}_{1}'.format(Config.input_dropout, Config.dropout)
do_process = True
epochs = 1000
Config.batch_size = 128
load = False
#dataset_name = 'YAGO3-10'
#dataset_name = 'WN18RR'
dataset_name = 'FB15k-237'
model_path = 'saved_models/{0}_{1}.model'.format(dataset_name, model_name)

''' Preprocess knowledge graph using spodernet. '''
def preprocess(dataset_name, delete_data=False):
    full_path = 'data/{0}/e1rel_to_e2_full.json'.format(dataset_name)
    train_path = 'data/{0}/e1rel_to_e2_train.json'.format(dataset_name)
    dev_ranking_path = 'data/{0}/e1rel_to_e2_ranking_dev.json'.format(dataset_name)
    test_ranking_path = 'data/{0}/e1rel_to_e2_ranking_test.json'.format(dataset_name)

    keys2keys = {}
    keys2keys['e1'] = 'e1' # entities
    keys2keys['rel'] = 'rel' # relations
    keys2keys['e2'] = 'e1' # entities
    keys2keys['e2_multi1'] = 'e1' # entity
    keys2keys['e2_multi2'] = 'e1' # entity
    input_keys = ['e1', 'rel', 'e2', 'e2_multi1', 'e2_multi2']
    d = DatasetStreamer(input_keys)
    d.add_stream_processor(JsonLoaderProcessors())
    d.add_stream_processor(DictKey2ListMapper(input_keys))

    # process full vocabulary and save it to disk
    d.set_path(full_path)
    p = Pipeline(dataset_name, delete_data, keys=input_keys, skip_transformation=True)
    p.add_sent_processor(ToLower())
    p.add_sent_processor(CustomTokenizer(lambda x: x.split(' ')), keys=['e2_multi1', 'e2_multi2'])
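    # NOTE: 'e2_multi1' and 'e2_multi2' are space-separated lists of target entities,
    # hence the whitespace CustomTokenizer for exactly those keys (keys2keys later maps
    # their tokens onto the 'e1' entity vocabulary when converting tokens to indices).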
    p.add_token_processor(AddToVocab())
    p.execute(d)
    p.save_vocabs()

    # process train, dev and test sets and save them to hdf5
    p.skip_transformation = False
    for path, name in zip([train_path, dev_ranking_path, test_ranking_path], ['train', 'dev_ranking', 'test_ranking']):
        d.set_path(path)
        p.clear_processors()
        p.add_sent_processor(ToLower())
        p.add_sent_processor(CustomTokenizer(lambda x: x.split(' ')), keys=['e2_multi1', 'e2_multi2'])
        p.add_post_processor(ConvertTokenToIdx(keys2keys=keys2keys), keys=['e1', 'rel', 'e2', 'e2_multi1', 'e2_multi2'])
        p.add_post_processor(StreamToHDF5(name, samples_per_file=1000, keys=input_keys))
        p.execute(d)


def main():
    if do_process: preprocess(dataset_name, delete_data=True)
    input_keys = ['e1', 'rel', 'e2', 'e2_multi1', 'e2_multi2']
    p = Pipeline(dataset_name, keys=input_keys)
    p.load_vocabs()
    vocab = p.state['vocab']

    num_entities = vocab['e1'].num_token

    train_batcher = StreamBatcher(dataset_name, 'train', Config.batch_size, randomize=True, keys=input_keys)
    dev_rank_batcher = StreamBatcher(dataset_name, 'dev_ranking', Config.batch_size, randomize=False, loader_threads=4, keys=input_keys, is_volatile=True)
    test_rank_batcher = StreamBatcher(dataset_name, 'test_ranking', Config.batch_size, randomize=False, loader_threads=4, keys=input_keys, is_volatile=True)

    #model = Complex(vocab['e1'].num_token, vocab['rel'].num_token)
    #model = DistMult(vocab['e1'].num_token, vocab['rel'].num_token)
    model = ConvE(vocab['e1'].num_token, vocab['rel'].num_token)

    train_batcher.at_batch_prepared_observers.insert(1, TargetIdx2MultiTarget(num_entities, 'e2_multi1', 'e2_multi1_binary'))
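    # TargetIdx2MultiTarget expands the variable-length list of gold tails in 'e2_multi1'
    # into a 0/1 vector of length num_entities ('e2_multi1_binary'), which is the 1-N
    # target used by the binary cross-entropy loss in the training loop below.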

    eta = ETAHook('train', print_every_x_batches=100)
    train_batcher.subscribe_to_events(eta)
    train_batcher.subscribe_to_start_of_epoch_event(eta)
    train_batcher.subscribe_to_events(LossHook('train', print_every_x_batches=100))

    if Config.cuda:
        model.cuda()
    if load:
        model_params = torch.load(model_path)
        print(model)
        print([(key, value.size()) for key, value in model_params.items()])
        model.load_state_dict(model_params)
        model.eval()
        ranking_and_hits(model, test_rank_batcher, vocab, 'test_evaluation')
        ranking_and_hits(model, dev_rank_batcher, vocab, 'dev_evaluation')
    else:
        model.init()

    opt = torch.optim.Adam(model.parameters(), lr=Config.learning_rate, weight_decay=Config.L2)
    for epoch in range(epochs):
        model.train()
        for i, str2var in enumerate(train_batcher):
            opt.zero_grad()
            e1 = str2var['e1']
            rel = str2var['rel']
            e2_multi = str2var['e2_multi1_binary'].float()
            # label smoothing
            e2_multi = ((1.0 - Config.label_smoothing_epsilon) * e2_multi) + (1.0 / e2_multi.size(1))
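            # with smoothing epsilon eps and N entities, gold targets become (1 - eps) + 1/N
            # and all other entries become 1/N (e.g. for an illustrative eps = 0.1 and
            # N = 10,000: 0.9001 vs. 0.0001)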

            pred = model.forward(e1, rel)
            loss = model.loss(pred, e2_multi)
            loss.backward()
            opt.step()

            train_batcher.state.loss = loss

        print('saving to {0}'.format(model_path))
        torch.save(model.state_dict(), model_path)

        model.eval()
        ranking_and_hits(model, dev_rank_batcher, vocab, 'dev_evaluation')
        if epoch % 3 == 0:
            if epoch > 0:
                ranking_and_hits(model, test_rank_batcher, vocab, 'test_evaluation')


if __name__ == '__main__':
    main()
@@ -0,0 +1,127 @@
import torch
from torch.nn import functional as F, Parameter
from torch.autograd import Variable

from spodernet.utils.global_config import Config
from spodernet.utils.cuda_utils import CUDATimer
from torch.nn.init import xavier_normal, xavier_uniform
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

timer = CUDATimer()

class Complex(torch.nn.Module):
    def __init__(self, num_entities, num_relations):
        super(Complex, self).__init__()
        self.num_entities = num_entities
        self.emb_e_real = torch.nn.Embedding(num_entities, Config.embedding_dim, padding_idx=0)
        self.emb_e_img = torch.nn.Embedding(num_entities, Config.embedding_dim, padding_idx=0)
        self.emb_rel_real = torch.nn.Embedding(num_relations, Config.embedding_dim, padding_idx=0)
        self.emb_rel_img = torch.nn.Embedding(num_relations, Config.embedding_dim, padding_idx=0)
        self.inp_drop = torch.nn.Dropout(Config.input_dropout)
        self.loss = torch.nn.BCELoss()

    def init(self):
        xavier_normal(self.emb_e_real.weight.data)
        xavier_normal(self.emb_e_img.weight.data)
        xavier_normal(self.emb_rel_real.weight.data)
        xavier_normal(self.emb_rel_img.weight.data)

    def forward(self, e1, rel):
        e1_embedded_real = self.inp_drop(self.emb_e_real(e1)).view(Config.batch_size, -1)
        rel_embedded_real = self.inp_drop(self.emb_rel_real(rel)).view(Config.batch_size, -1)
        e1_embedded_img = self.inp_drop(self.emb_e_img(e1)).view(Config.batch_size, -1)
        rel_embedded_img = self.inp_drop(self.emb_rel_img(rel)).view(Config.batch_size, -1)

        e1_embedded_real = self.inp_drop(e1_embedded_real)
        rel_embedded_real = self.inp_drop(rel_embedded_real)
        e1_embedded_img = self.inp_drop(e1_embedded_img)
        rel_embedded_img = self.inp_drop(rel_embedded_img)

        # complex space bilinear product (equivalent to HolE)
        realrealreal = torch.mm(e1_embedded_real*rel_embedded_real, self.emb_e_real.weight.transpose(1,0))
        realimgimg = torch.mm(e1_embedded_real*rel_embedded_img, self.emb_e_img.weight.transpose(1,0))
        imgrealimg = torch.mm(e1_embedded_img*rel_embedded_real, self.emb_e_img.weight.transpose(1,0))
        imgimgreal = torch.mm(e1_embedded_img*rel_embedded_img, self.emb_e_real.weight.transpose(1,0))
        pred = realrealreal + realimgimg + imgrealimg - imgimgreal
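        # the four terms above are the expansion of Re(<e1, rel, conj(e2)>), i.e. the
        # ComplEx scoring function, computed against every candidate entity at once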
        pred = F.sigmoid(pred)

        return pred


class DistMult(torch.nn.Module):
    def __init__(self, num_entities, num_relations):
        super(DistMult, self).__init__()
        self.emb_e = torch.nn.Embedding(num_entities, Config.embedding_dim, padding_idx=0)
        self.emb_rel = torch.nn.Embedding(num_relations, Config.embedding_dim, padding_idx=0)
        self.inp_drop = torch.nn.Dropout(Config.input_dropout)
        self.loss = torch.nn.BCELoss()

    def init(self):
        xavier_normal(self.emb_e.weight.data)
        xavier_normal(self.emb_rel.weight.data)

    def forward(self, e1, rel):
        e1_embedded = self.emb_e(e1)
        rel_embedded = self.emb_rel(rel)
        e1_embedded = e1_embedded.view(-1, Config.embedding_dim)
        rel_embedded = rel_embedded.view(-1, Config.embedding_dim)

        e1_embedded = self.inp_drop(e1_embedded)
        rel_embedded = self.inp_drop(rel_embedded)

        pred = torch.mm(e1_embedded*rel_embedded, self.emb_e.weight.transpose(1,0))
        pred = F.sigmoid(pred)

        return pred


class ConvE(torch.nn.Module):
    def __init__(self, num_entities, num_relations):
        super(ConvE, self).__init__()
        self.emb_e = torch.nn.Embedding(num_entities, Config.embedding_dim, padding_idx=0)
        self.emb_rel = torch.nn.Embedding(num_relations, Config.embedding_dim, padding_idx=0)
        self.inp_drop = torch.nn.Dropout(Config.input_dropout)
        self.hidden_drop = torch.nn.Dropout(Config.dropout)
        self.feature_map_drop = torch.nn.Dropout2d(Config.feature_map_dropout)
        self.loss = torch.nn.BCELoss()

        self.conv1 = torch.nn.Conv2d(1, 8, (3, 3), 1, 0, bias=Config.use_bias)
        self.bn0 = torch.nn.BatchNorm2d(1)
        self.bn1 = torch.nn.BatchNorm2d(8)
        self.bn2 = torch.nn.BatchNorm1d(Config.embedding_dim)
        self.register_parameter('b', Parameter(torch.zeros(num_entities)))
        self.fc = torch.nn.Linear(2592, Config.embedding_dim)
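        # 2592 = 8 feature maps x 18 x 18: the stacked 20x20 input shrinks to 18x18 under
        # the 3x3 convolution with no padding (this assumes embedding_dim = 200, reshaped
        # into 10x20 embedding "images" in forward())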
        print(num_entities, num_relations)

    def init(self):
        xavier_normal(self.emb_e.weight.data)
        xavier_normal(self.emb_rel.weight.data)

    def forward(self, e1, rel):
        e1_embedded = self.emb_e(e1).view(Config.batch_size, 1, 10, 20)
        rel_embedded = self.emb_rel(rel).view(Config.batch_size, 1, 10, 20)

        stacked_inputs = torch.cat([e1_embedded, rel_embedded], 2)
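        # the two 10x20 embedding "images" are stacked along the height dimension (dim 2),
        # giving one 1-channel 20x20 input per (e1, rel) pair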

        stacked_inputs = self.bn0(stacked_inputs)
        x = self.inp_drop(stacked_inputs)
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_map_drop(x)
        x = x.view(Config.batch_size, -1)
        x = self.fc(x)
        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = torch.mm(x, self.emb_e.weight.transpose(1,0))
        x += self.b.expand_as(x)
        pred = F.sigmoid(x)

        return pred

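Below is a minimal, hedged usage sketch of how these models are called by the training loop above. Only embedding_dim = 200 matches the settings used earlier; the batch size, dropout rates, and vocabulary sizes are illustrative placeholders, and it assumes the spodernet dependency pinned below is installed (and, since a CUDATimer is created at import time, most likely a CUDA-capable machine).

# Shape sanity check; illustrative values only, not values from this commit.
import torch
from torch.autograd import Variable
from spodernet.utils.global_config import Config
from model import ConvE

Config.embedding_dim = 200        # same as the training script above
Config.batch_size = 4             # small toy batch
Config.input_dropout = 0.2        # placeholder hyperparameters
Config.dropout = 0.3
Config.feature_map_dropout = 0.2
Config.use_bias = True

model = ConvE(num_entities=1000, num_relations=50)   # toy vocabulary sizes
model.init()
model.eval()                      # disable dropout, use BatchNorm running stats

e1 = Variable(torch.zeros(Config.batch_size, 1).long())    # batch of head-entity ids
rel = Variable(torch.zeros(Config.batch_size, 1).long())   # batch of relation ids
pred = model.forward(e1, rel)     # -> (4, 1000): a score in (0, 1) for every candidate tail
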
@@ -0,0 +1,12 @@
#!/bin/bash
mkdir data
mkdir data/WN18RR
mkdir data/YAGO3-10
mkdir data/FB15k-237
mkdir saved_models
tar -xvf WN18RR.tar.gz -C data/WN18RR
tar -xvf YAGO3-10.tar.gz -C data/YAGO3-10
tar -xvf FB15k-237.tar.gz -C data/FB15k-237
python wrangle_KG.py WN18RR
python wrangle_KG.py YAGO3-10
python wrangle_KG.py FB15k-237
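For reference, a hedged sketch of the JSON-lines format that wrangle_KG.py is expected to produce and that the preprocessing pipeline above consumes: the five keys come from the pipeline code, while the concrete entity and relation identifiers below are made up.

# One illustrative line of data/FB15k-237/e1rel_to_e2_train.json (values invented):
import json

example = {
    'e1': '/m/0f8l9c',                     # head entity
    'rel': '/location/country/capital',    # relation
    'e2': '/m/05qtj',                      # tail entity
    'e2_multi1': '/m/05qtj /m/0xyz',       # all known tails for (e1, rel), space-separated
    'e2_multi2': '/m/0abc',                # second multi-entity field (semantics per wrangle_KG.py; value invented)
}
print(json.dumps(example))                 # -> one line of an e1rel_to_e2_*.json file
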
@@ -0,0 +1 @@
-e git+git@github.com:TimDettmers/spodernet.git#egg=spodernet