diff --git a/lang/models.py b/lang/models.py index 3bc7cac..3f4ad6c 100644 --- a/lang/models.py +++ b/lang/models.py @@ -273,7 +273,7 @@ def generate(autoencoder, gan_gen, z, vocab, sample, maxlen): """ if type(z) == Variable: noise = z - elif type(z) == torch.FloatTensor or type(z) == torch.cuda.FloatTensor: + elif type(z) == torch.Tensor or type(z) == torch.cuda.Tensor: noise = Variable(z, volatile=True) elif type(z) == np.ndarray: noise = Variable(torch.from_numpy(z).float(), volatile=True)