diff --git a/pixelcnn/model.py b/pixelcnn/model.py index b8ef4d0..7a0ee61 100644 --- a/pixelcnn/model.py +++ b/pixelcnn/model.py @@ -142,7 +142,7 @@ def sample(self, shape, count, device='cuda'): for c in range(channels): unnormalized_probs = self.forward(samples) pixel_probs = torch.softmax(unnormalized_probs[:, :, c, i, j], dim=1) - sampled_levels = torch.multinomial(pixel_probs, 1).squeeze() + sampled_levels = torch.multinomial(pixel_probs, 1).squeeze() / (self.color_levels - 1) samples[:, c, i, j] = sampled_levels return samples