diff --git a/model/hidden.py b/model/hidden.py
index 718cf6e..cc6aaba 100644
--- a/model/hidden.py
+++ b/model/hidden.py
@@ -66,12 +66,11 @@ def train_on_batch(self, batch: list):
             # ---------------- Train the discriminator -----------------------------
             self.optimizer_discrim.zero_grad()
             # train on cover
-            d_target_label_cover = torch.full((batch_size, 1), self.cover_label, device=self.device)
-            d_target_label_encoded = torch.full((batch_size, 1), self.encoded_label, device=self.device)
-            g_target_label_encoded = torch.full((batch_size, 1), self.cover_label, device=self.device)
-
+            d_target_label_cover = torch.full((batch_size, 1), self.cover_label, device=self.device, dtype=torch.float32)
+            d_target_label_encoded = torch.full((batch_size, 1), self.encoded_label, device=self.device, dtype=torch.float32)
+            g_target_label_encoded = torch.full((batch_size, 1), self.cover_label, device=self.device, dtype=torch.float32)
             d_on_cover = self.discriminator(images)
-            d_loss_on_cover = self.bce_with_logits_loss(d_on_cover, d_target_label_cover)
+            d_loss_on_cover = self.bce_with_logits_loss(d_on_cover, d_target_label_cover.float())
             d_loss_on_cover.backward()
 
             # train on fake
@@ -144,15 +143,15 @@ def validate_on_batch(self, batch: list):
             g_target_label_encoded = torch.full((batch_size, 1), self.cover_label, device=self.device)
 
             d_on_cover = self.discriminator(images)
-            d_loss_on_cover = self.bce_with_logits_loss(d_on_cover, d_target_label_cover)
+            d_loss_on_cover = self.bce_with_logits_loss(d_on_cover, d_target_label_cover.float())
 
             encoded_images, noised_images, decoded_messages = self.encoder_decoder(images, messages)
 
             d_on_encoded = self.discriminator(encoded_images)
-            d_loss_on_encoded = self.bce_with_logits_loss(d_on_encoded, d_target_label_encoded)
+            d_loss_on_encoded = self.bce_with_logits_loss(d_on_encoded, d_target_label_encoded.float())
 
             d_on_encoded_for_enc = self.discriminator(encoded_images)
-            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc, g_target_label_encoded)
+            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc, g_target_label_encoded.float())
 
             if self.vgg_loss is None:
                 g_loss_enc = self.mse_loss(encoded_images, images)
diff --git a/utils.py b/utils.py
index 5bf578d..105c836 100644
--- a/utils.py
+++ b/utils.py
@@ -54,7 +54,15 @@ def save_images(original_images, watermarked_images, epoch, folder, resize_to=None):
 
     stacked_images = torch.cat([images, watermarked_images], dim=0)
     filename = os.path.join(folder, 'epoch-{}.png'.format(epoch))
-    torchvision.utils.save_image(stacked_images, filename, original_images.shape[0], normalize=False)
+
+    # 修复后的save_image调用
+    torchvision.utils.save_image(
+        stacked_images,
+        filename,
+        nrow=int(original_images.shape[0]),  # 每行显示的图像数量
+        normalize=False,
+        format='png'  # 明确指定格式
+    )
 
 
 def sorted_nicely(l):
@@ -181,4 +189,4 @@ def write_losses(file_name, losses_accu, epoch, duration):
             writer.writerow(row_to_write)
         row_to_write = [epoch] + ['{:.4f}'.format(loss_avg.avg) for loss_avg in losses_accu.values()] + [
             '{:.0f}'.format(duration)]
-        writer.writerow(row_to_write)
\ No newline at end of file
+        writer.writerow(row_to_write)