diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py index 784cb15a..5da01aa1 100644 --- a/models/pix2pix_model.py +++ b/models/pix2pix_model.py @@ -236,10 +236,10 @@ def divide_pred(self, pred): def get_edges(self, t): edge = self.ByteTensor(t.size()).zero_() - edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1]) - edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1]) - edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) - edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) + edge[:, :, :, 1:] = edge[:, :, :, 1:] | ((t[:, :, :, 1:] != t[:, :, :, :-1]).byte()) + edge[:, :, :, :-1] = edge[:, :, :, :-1] | ((t[:, :, :, 1:] != t[:, :, :, :-1]).byte()) + edge[:, :, 1:, :] = edge[:, :, 1:, :] | ((t[:, :, 1:, :] != t[:, :, :-1, :]).byte()) + edge[:, :, :-1, :] = edge[:, :, :-1, :] | ((t[:, :, 1:, :] != t[:, :, :-1, :]).byte()) return edge.float() def reparameterize(self, mu, logvar):