diff --git a/lang/generate.py b/lang/generate.py index 65e3c87..ffc5fd2 100644 --- a/lang/generate.py +++ b/lang/generate.py @@ -40,7 +40,7 @@ def interpolate(ae, gg, z1, z2, vocab, if type(z1) == Variable: noise1 = z1 noise2 = z2 - elif type(z1) == torch.FloatTensor or type(z1) == torch.cuda.FloatTensor: + elif type(z1) == torch.Tensor or type(z1) == torch.cuda.Tensor: noise1 = Variable(z1, volatile=True) noise2 = Variable(z2, volatile=True) elif type(z1) == np.ndarray: