diff --git a/tests/test_vae.py b/tests/test_vae.py index beb30b44..31d21a26 100644 --- a/tests/test_vae.py +++ b/tests/test_vae.py @@ -3,9 +3,7 @@ from models import VanillaVAE from torchsummary import summary - class TestVAE(unittest.TestCase): - def setUp(self) -> None: # self.model2 = VAE(3, 10) self.model = VanillaVAE(3, 10) @@ -22,11 +20,15 @@ def test_forward(self): def test_loss(self): x = torch.randn(16, 3, 64, 64) - result = self.model(x) loss = self.model.loss_function(*result, M_N = 0.005) print(loss) - + + def test_reconstruction(self): + x = torch.randn(1, 3, 64, 64) + reconstructed, _, _ = self.model(x) + mse = torch.nn.functional.mse_loss(reconstructed, x) + print(f"Reconstruction MSE: {mse.item()}") if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main()