diff --git a/lang/generate.py b/lang/generate.py index 65e3c87..d744290 100644 --- a/lang/generate.py +++ b/lang/generate.py @@ -81,11 +81,11 @@ def load_models(load_path): gan_gen = gan_gen.cuda() gan_disc = gan_disc.cuda() - word2idx = json.load(open(os.path.join(args.save, 'vocab.json'), 'r')) + word2idx = json.load(open(os.path.join(load_path, 'vocab.json'), 'r')) idx2word = {v: k for k, v in word2idx.items()} print('Loading models from {}'.format(args.save)) - loaded = torch.load(os.path.join(args.save, "model.pt")) + loaded = torch.load(os.path.join(load_path, "model.pt")) autoencoder.load_state_dict(loaded.get('ae')) gan_gen.load_state_dict(loaded.get('gan_g')) gan_disc.load_state_dict(loaded.get('gan_d'))