diff --git a/ArtExtract_Soyoung/utils/metrics.py b/ArtExtract_Soyoung/utils/metrics.py index 17b91515..e40961c4 100644 --- a/ArtExtract_Soyoung/utils/metrics.py +++ b/ArtExtract_Soyoung/utils/metrics.py @@ -24,6 +24,10 @@ def __init__(self, feature_extractor, size_average=True): self.ssim_metric = ssim(data_range=1.0) def psnr(self, output, target): + if output.shape != target.shape: + raise ValueError( + f"Shape mismatch: output has shape {output.shape}, target has shape {target.shape}" + ) # Compute MSE per channel mse_per_channel = torch.mean((target - output) ** 2, dim=[0, 2, 3]) # MSE per channel @@ -72,4 +76,4 @@ def forward(self, output, target): psnr_value = self.psnr(output, target) lpips_value = self.lpips(output, target) ssim_value = self.ssim(output, target) - return psnr_value, lpips_value, ssim_value \ No newline at end of file + return psnr_value, lpips_value, ssim_value