Finetuning argparse interface #52

Open · wants to merge 2 commits into base: master
94 changes: 65 additions & 29 deletions elmoformanylangs/biLM.py
@@ -16,6 +16,7 @@
import torch.optim as optim
from torch.autograd import Variable
from .modules.elmo import ElmobiLm
from .elmo import Embedder
from .modules.lstm import LstmbiLm
from .modules.token_embedder import ConvTokenEmbedder, LstmTokenEmbedder
from .modules.embedding_layer import EmbeddingLayer
@@ -280,8 +281,14 @@ def save_model(self, path, save_classify_layer):
torch.save(self.classify_layer.state_dict(), os.path.join(path, 'classifier.pkl'))

def load_model(self, path):
self.token_embedder.load_state_dict(torch.load(os.path.join(path, 'token_embedder.pkl')))
self.encoder.load_state_dict(torch.load(os.path.join(path, 'encoder.pkl')))

self.token_embedder.load_state_dict(torch.load(os.path.join(path, 'token_embedder.pkl'),
map_location=lambda storage, loc: storage))
self.encoder.load_state_dict(torch.load(os.path.join(path, 'encoder.pkl'),
map_location=lambda storage, loc: storage))

self.classify_layer.load_state_dict(torch.load(os.path.join(path, 'classifier.pkl')))
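
The map_location argument added above is what allows a checkpoint saved on GPU to be loaded on a CPU-only machine. A minimal self-contained sketch of the same pattern ('encoder.pkl' is a placeholder path, not a file shipped with the repo):

```python
import torch

# Identity map: keep each deserialized tensor on the storage it was loaded
# into (CPU), instead of trying to restore it to its original CUDA device.
state = torch.load('encoder.pkl', map_location=lambda storage, loc: storage)

# Equivalent, shorter spelling accepted by torch.load:
state = torch.load('encoder.pkl', map_location='cpu')
```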


@@ -340,7 +347,7 @@ def train_model(epoch, opt, model, optimizer,
loss_forward, loss_backward = model.forward(w, c, masks)

loss = (loss_forward + loss_backward) / 2.0
total_loss += loss_forward.data[0]
total_loss += loss_forward.item()
n_tags = sum(lens)
total_tag += n_tags
loss.backward()
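
For context on the one-line change above: indexing a 0-dim loss tensor with .data[0] stopped working in PyTorch 0.4+, while .item() returns a plain Python number. A tiny runnable illustration:

```python
import torch

loss = torch.tensor(1.5)   # 0-dim tensor, like the averaged loss above
total = 0.0
total += loss.item()       # returns a Python float
# total += loss.data[0]    # IndexError on PyTorch >= 0.4
```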
@@ -440,6 +447,12 @@ def train():

cmd.add_argument('--valid_size', type=int, default=0, help="size of validation dataset when there's no valid.")
cmd.add_argument('--eval_steps', required=False, type=int, help='report every xx batches.')


cmd.add_argument('--fine_tune', required=False, action="store_true", help='finetune base model')
cmd.add_argument('--old_model_folder', required=False, type=str, help='path to base model for finetuning')



opt = cmd.parse_args(sys.argv[2:])
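
A self-contained sketch of how the two new flags parse (it mirrors the cmd.add_argument calls added above; the example values are made up):

```python
import argparse

cmd = argparse.ArgumentParser('finetune-flags-sketch')
cmd.add_argument('--fine_tune', required=False, action="store_true",
                 help='finetune base model')
cmd.add_argument('--old_model_folder', required=False, type=str,
                 help='path to base model for finetuning')

opt = cmd.parse_args(['--fine_tune', '--old_model_folder', 'zhs.model/'])
assert opt.fine_tune is True
assert opt.old_model_folder == 'zhs.model/'
```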

@@ -501,46 +514,60 @@ def train():
len(test_data), sum([len(s) - 1 for s in test_data])))
else:
test_data = None

if opt.word_embedding is not None:
embs = load_embedding(opt.word_embedding)
word_lexicon = {word: i for i, word in enumerate(embs[0])}
else:
embs = None
word_lexicon = {}


if opt.fine_tune:
embedder = Embedder(opt.old_model_folder)
word_lexicon = embedder.word_lexicon
char_lexicon = embedder.char_lexicon
label_to_ix = word_lexicon
embs = None


# Maintain the vocabulary. vocabulary is used in either WordEmbeddingInput or softmax classification
vocab = get_truncated_vocab(train_data, opt.min_count)
if not opt.fine_tune:
if opt.word_embedding is not None:
embs = load_embedding(opt.word_embedding)
word_lexicon = {word: i for i, word in enumerate(embs[0])}
else:
embs = None
word_lexicon = {}

# Ensure index of '<oov>' is 0
for special_word in ['<oov>', '<bos>', '<eos>', '<pad>']:
if special_word not in word_lexicon:
word_lexicon[special_word] = len(word_lexicon)

for word, _ in vocab:
if word not in word_lexicon:
word_lexicon[word] = len(word_lexicon)
# Ensure index of '<oov>' is 0
for special_word in ['<oov>', '<bos>', '<eos>', '<pad>']:
if special_word not in word_lexicon:
word_lexicon[special_word] = len(word_lexicon)

for word, _ in vocab:
if word not in word_lexicon:
word_lexicon[word] = len(word_lexicon)


# Word Embedding
if config['token_embedder']['word_dim'] > 0:
word_emb_layer = EmbeddingLayer(config['token_embedder']['word_dim'], word_lexicon, fix_emb=False, embs=embs)
#print(word_emb_layer)
logging.info('Word embedding size: {0}'.format(len(word_emb_layer.word2id)))
else:
word_emb_layer = None
logging.info('Vocabulary size: {0}'.format(len(word_lexicon)))

# Character Lexicon
if config['token_embedder']['char_dim'] > 0:
char_lexicon = {}
for sentence in train_data:
for word in sentence:
for ch in word:
if ch not in char_lexicon:
char_lexicon[ch] = len(char_lexicon)

for special_char in ['<bos>', '<eos>', '<oov>', '<pad>', '<bow>', '<eow>']:
if special_char not in char_lexicon:
char_lexicon[special_char] = len(char_lexicon)

if not opt.fine_tune:

char_lexicon = {}
for sentence in train_data:
for word in sentence:
for ch in word:
if ch not in char_lexicon:
char_lexicon[ch] = len(char_lexicon)

for special_char in ['<bos>', '<eos>', '<oov>', '<pad>', '<bow>', '<eow>']:
if special_char not in char_lexicon:
char_lexicon[special_char] = len(char_lexicon)

char_emb_layer = EmbeddingLayer(config['token_embedder']['char_dim'], char_lexicon, fix_emb=False)
logging.info('Char embedding size: {0}'.format(len(char_emb_layer.word2id)))
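
The hunk above splits vocabulary construction into two paths: with --fine_tune the word and character lexicons come from the pre-trained Embedder, otherwise they are built from the training corpus as before. A self-contained sketch of the corpus-built path on a toy corpus (the real code reads train_data from the training file and applies min_count truncation, which is omitted here):

```python
# Toy corpus standing in for train_data.
train_data = [['hello', 'world'], ['hello', 'again']]

word_lexicon, char_lexicon = {}, {}

# Special word symbols first, so '<oov>' gets index 0.
for special_word in ['<oov>', '<bos>', '<eos>', '<pad>']:
    word_lexicon.setdefault(special_word, len(word_lexicon))

for sentence in train_data:
    for word in sentence:
        word_lexicon.setdefault(word, len(word_lexicon))
        for ch in word:
            char_lexicon.setdefault(ch, len(char_lexicon))

for special_char in ['<bos>', '<eos>', '<oov>', '<pad>', '<bow>', '<eow>']:
    char_lexicon.setdefault(special_char, len(char_lexicon))

print(word_lexicon['<oov>'])   # 0
print(len(char_lexicon))       # corpus characters + 6 specials
```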
@@ -572,7 +599,15 @@ def train():

nclasses = len(label_to_ix)

model = Model(config, word_emb_layer, char_emb_layer, nclasses, use_cuda)
if not opt.fine_tune:
model = Model(config, word_emb_layer, char_emb_layer, nclasses, use_cuda)
else:
model = Model(embedder.config, word_emb_layer, char_emb_layer, nclasses, use_cuda)
model.token_embedder = embedder.model.token_embedder
model.encoder = embedder.model.encoder



logging.info(str(model))
if use_cuda:
model = model.cuda()
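
The assignments of embedder.model.token_embedder and embedder.model.encoder above are plain attribute replacement on an nn.Module, so no copy is made and later training updates the pre-trained weights in place. A toy illustration of that pattern (Toy is a made-up module, not part of the repo):

```python
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)

pretrained = Toy()   # stands in for the loaded base model
fresh = Toy()        # stands in for the newly constructed Model

# Adopt the pre-trained sub-module; the weights are shared, not copied.
fresh.encoder = pretrained.encoder
assert fresh.encoder.weight is pretrained.encoder.weight
```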
@@ -671,6 +706,7 @@ def test():

model = Model(config, word_emb_layer, char_emb_layer, len(word_lexicon), use_cuda)


if use_cuda:
model.cuda()
