Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions model/hidden.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,11 @@ def train_on_batch(self, batch: list):
# ---------------- Train the discriminator -----------------------------
self.optimizer_discrim.zero_grad()
# train on cover
d_target_label_cover = torch.full((batch_size, 1), self.cover_label, device=self.device)
d_target_label_encoded = torch.full((batch_size, 1), self.encoded_label, device=self.device)
g_target_label_encoded = torch.full((batch_size, 1), self.cover_label, device=self.device)

d_target_label_cover = torch.full((batch_size, 1), self.cover_label, device=self.device, dtype=torch.float32)
d_target_label_encoded = torch.full((batch_size, 1), self.encoded_label, device=self.device, dtype=torch.float32)
g_target_label_encoded = torch.full((batch_size, 1), self.cover_label, device=self.device, dtype=torch.float32)
d_on_cover = self.discriminator(images)
d_loss_on_cover = self.bce_with_logits_loss(d_on_cover, d_target_label_cover)
d_loss_on_cover = self.bce_with_logits_loss(d_on_cover, d_target_label_cover.float())
d_loss_on_cover.backward()

# train on fake
Expand Down Expand Up @@ -144,15 +143,15 @@ def validate_on_batch(self, batch: list):
g_target_label_encoded = torch.full((batch_size, 1), self.cover_label, device=self.device)

d_on_cover = self.discriminator(images)
d_loss_on_cover = self.bce_with_logits_loss(d_on_cover, d_target_label_cover)
d_loss_on_cover = self.bce_with_logits_loss(d_on_cover, d_target_label_cover.float())

encoded_images, noised_images, decoded_messages = self.encoder_decoder(images, messages)

d_on_encoded = self.discriminator(encoded_images)
d_loss_on_encoded = self.bce_with_logits_loss(d_on_encoded, d_target_label_encoded)
d_loss_on_encoded = self.bce_with_logits_loss(d_on_encoded, d_target_label_encoded.float())

d_on_encoded_for_enc = self.discriminator(encoded_images)
g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc, g_target_label_encoded)
g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc, g_target_label_encoded.float())

if self.vgg_loss is None:
g_loss_enc = self.mse_loss(encoded_images, images)
Expand Down
12 changes: 10 additions & 2 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,15 @@ def save_images(original_images, watermarked_images, epoch, folder, resize_to=No

stacked_images = torch.cat([images, watermarked_images], dim=0)
filename = os.path.join(folder, 'epoch-{}.png'.format(epoch))
torchvision.utils.save_image(stacked_images, filename, original_images.shape[0], normalize=False)

# Fixed save_image call
torchvision.utils.save_image(
stacked_images,
filename,
    nrow=int(original_images.shape[0]),  # number of images to display per row
normalize=False,
    format='png'  # explicitly specify the image format
)


def sorted_nicely(l):
Expand Down Expand Up @@ -181,4 +189,4 @@ def write_losses(file_name, losses_accu, epoch, duration):
writer.writerow(row_to_write)
row_to_write = [epoch] + ['{:.4f}'.format(loss_avg.avg) for loss_avg in losses_accu.values()] + [
'{:.0f}'.format(duration)]
writer.writerow(row_to_write)
writer.writerow(row_to_write)