diff --git a/.gitignore b/.gitignore index b2a15ab..39c649a 100644 --- a/.gitignore +++ b/.gitignore @@ -64,3 +64,6 @@ venv.bak/ .mypy_cache/ .dmypy.json dmypy.json + +# aim +*.aim* \ No newline at end of file diff --git a/dwave/plugins/torch/models/discrete_variational_autoencoder.py b/dwave/plugins/torch/models/discrete_variational_autoencoder.py index 0195c1b..ba145f2 100644 --- a/dwave/plugins/torch/models/discrete_variational_autoencoder.py +++ b/dwave/plugins/torch/models/discrete_variational_autoencoder.py @@ -80,28 +80,26 @@ def __init__( self._decoder = decoder if latent_to_discrete is None: - def latent_to_discrete( - logits: torch.Tensor, n_samples: int - ) -> torch.Tensor: - # Logits is of shape (batch_size, n_discrete), we assume these logits + def latent_to_discrete(logits: torch.Tensor, n_samples: int) -> torch.Tensor: + # Logits is of shape (batch_size, l1, l2, ...), we assume these logits # refer to the probability of each discrete variable being 1. To use the # gumbel softmax function we need to reshape the logits to (batch_size, - # n_discrete, 1), and then stack the logits to a zeros tensor of the + # l1, l2, ..., 1), and then stack the logits to a zeros tensor of the # same shape. This is done to ensure that the gumbel softmax function # works correctly. 
- + n_feature_dims = logits.dim() - 1 logits = logits.unsqueeze(-1) logits = torch.cat((logits, torch.zeros_like(logits)), dim=-1) # We now create a new leading dimension and repeat the logits n_samples # times: - logits = logits.unsqueeze(1).repeat(1, n_samples, 1, 1) - one_hots = torch.nn.functional.gumbel_softmax( - logits, tau=1 / 7, hard=True + logits = logits.unsqueeze(1).repeat( + *((1, n_samples) + (1,) * n_feature_dims + (1,)) ) + one_hots = torch.nn.functional.gumbel_softmax(logits, tau=1 / 7, hard=True) # The constant 1/7 is used because it was used in # https://iopscience.iop.org/article/10.1088/2632-2153/aba220 - # one_hots is of shape (batch_size, n_samples, n_discrete, 2), we need + # one_hots is of shape (batch_size, n_samples, l1, l2, ..., 2), we need # to take the first element of the last dimension and convert it to spin # variables to make the latent space compatible with QPU models. return one_hots[..., 0] * 2 - 1 diff --git a/releasenotes/notes/arbitrary-feature-dimension-in-latent-to-discrete-57917ab12f34bdd8.yaml b/releasenotes/notes/arbitrary-feature-dimension-in-latent-to-discrete-57917ab12f34bdd8.yaml new file mode 100644 index 0000000..f893a09 --- /dev/null +++ b/releasenotes/notes/arbitrary-feature-dimension-in-latent-to-discrete-57917ab12f34bdd8.yaml @@ -0,0 +1,9 @@ +--- +fixes: + - | + The default ``latent_to_discrete`` transformation in + ``dwave.plugins.torch.models.discrete_variational_autoencoder.DiscreteVariationalAutoencoder`` + has been fixed to accommodate arbitrary encoders. Before, the default + transformation only allowed encoders whose output shape was (B, l). Now, + encoders can have an arbitrary number of feature dimensions, i.e., the + shape can be (B, l1, l2, ...). 
\ No newline at end of file diff --git a/tests/test_dvae_winci2020.py b/tests/test_dvae_winci2020.py index 06b03d3..6387978 100644 --- a/tests/test_dvae_winci2020.py +++ b/tests/test_dvae_winci2020.py @@ -18,7 +18,9 @@ from parameterized import parameterized from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine -from dwave.plugins.torch.models.discrete_variational_autoencoder import DiscreteVariationalAutoencoder as DVAE +from dwave.plugins.torch.models.discrete_variational_autoencoder import ( + DiscreteVariationalAutoencoder as DVAE, +) from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss from dwave.samplers import SimulatedAnnealingSampler @@ -32,10 +34,7 @@ def setUp(self): latent_features = 2 # Data in corners of unit square: - self.data = torch.tensor([[1.0, 1.0], - [1.0, 0.0], - [0.0, 0.0], - [0.0, 1.0]]) + self.data = torch.tensor([[1.0, 1.0], [1.0, 0.0], [0.0, 0.0], [0.0, 1.0]]) # The encoder maps input data to logits. We make this encoder without parameters # for simplicity. The encoder will map 1s to 10s and 0s to -10s, so that the @@ -48,16 +47,45 @@ def setUp(self): # [1, -1]. class Encoder(torch.nn.Module): + def __init__(self, n_latent_dims: int): + super().__init__() + self.n_latent_dims = n_latent_dims + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x is always two-dimensional of shape (batch_size, features_size) + dims_to_add = self.n_latent_dims - 1 + output = x * 20 - 10 + for _ in range(dims_to_add): + output = output.unsqueeze(-2) + return output + + class Decoder(torch.nn.Module): + def __init__(self, latent_features: int, input_features: int): + super().__init__() + self.linear = torch.nn.Linear(latent_features, input_features) + def forward(self, x: torch.Tensor) -> torch.Tensor: - return x * 20 - 10 + # x is of shape (batch_size, n_samples, l1, l2, ...) 
+ n_latent_dims_to_remove = x.ndim - 3 + for _ in range(n_latent_dims_to_remove): + x = x.squeeze(2)  # dim 1 is n_samples; dummy feature dims start at dim 2 + return self.linear(x) - self.encoder = Encoder() - self.decoder = torch.nn.Linear(latent_features, input_features) + # self.encoders is a dict whose keys are the number of latent dims and the values + # are the models themselves + self.encoders = {i: Encoder(i) for i in range(1, 3)} + # self.decoders is independent of number of latent dims, but we also create a dict to separate + # them + self.decoders = {i: Decoder(latent_features, input_features) for i in range(1, 3)} - self.dvae = DVAE(self.encoder, self.decoder) + # self.dvaes is a dict whose keys are the numbers of latent dims and the values are the models + # themselves + + self.dvaes = {i: DVAE(self.encoders[i], self.decoders[i]) for i in range(1, 3)} self.boltzmann_machine = GraphRestrictedBoltzmannMachine( - nodes=(0, 1), edges=[(0, 1)], + nodes=(0, 1), + edges=[(0, 1)], linear={0: 0.1, 1: -0.2}, quadratic={(0, 1): -1.2}, ) @@ -66,8 +94,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def test_mappings(self): """Test the mapping between data and logits.""" - # Let's make sure that indeed the maps are correct: - _, discretes, _ = self.dvae(self.data, n_samples=1) + # Let's make sure that indeed the maps are correct. For this, we use only the first + # autoencoder, which is the one whose encoder maps data to a single feature dimension. 
The + # second autoencoder maps data to two feature dimensions (the first is a dummy dimension) + _, discretes, _ = self.dvaes[1](self.data, n_samples=1) + # squeeze the replica dimension: discretes = discretes.squeeze(1) # map [1, 1] to [1, 1]: torch.testing.assert_close(torch.tensor([1, 1]).float(), discretes[0]) @@ -78,19 +109,18 @@ def test_mappings(self): # map [0, 1] to [-1, 1]: torch.testing.assert_close(torch.tensor([-1, 1]).float(), discretes[3]) - def test_train(self): + @parameterized.expand([1, 2]) + def test_train(self, n_latent_dims): """Test training simple dataset.""" + dvae = self.dvaes[n_latent_dims] optimiser = torch.optim.SGD( - list(self.dvae.parameters()) - + list(self.boltzmann_machine.parameters()), + list(dvae.parameters()) + list(self.boltzmann_machine.parameters()), lr=0.01, momentum=0.9, ) N_SAMPLES = 1 for _ in range(1000): - latents, discretes, reconstructed_data = self.dvae( - self.data, n_samples=N_SAMPLES - ) + latents, discretes, reconstructed_data = dvae(self.data, n_samples=N_SAMPLES) true_data = self.data.unsqueeze(1).repeat(1, N_SAMPLES, 1) # Measure the reconstruction loss @@ -114,38 +144,43 @@ def test_train(self): torch.testing.assert_close(true_data, reconstructed_data) # Furthermore, the GRBM should learn that all spin strings of length 2 are # equally likely, so the h and J parameters should be close to 0: - torch.testing.assert_close(self.boltzmann_machine.linear, torch.zeros(2), - rtol=1e-2, atol=1e-2) - torch.testing.assert_close(self.boltzmann_machine.quadratic, torch.tensor([0.0]).float(), - rtol=1e-2, atol=1e-2) - - @parameterized.expand([ - ( - 1, - torch.tensor([[[1., 1.]], [[1., -1.]], [[-1., -1.]], [[-1., 1.]]]) - ), - ( - 5, - torch.tensor([[[1., 1.]] * 5, [[1., -1.]] * 5, [[-1., -1.]] * 5, [[-1., 1.]] * 5]) - ), - ]) + torch.testing.assert_close( + self.boltzmann_machine.linear, torch.zeros(2), rtol=1e-2, atol=1e-2 + ) + torch.testing.assert_close( + self.boltzmann_machine.quadratic, 
torch.tensor([0.0]).float(), rtol=1e-2, atol=1e-2 + ) + + @parameterized.expand( + [ + (1, torch.tensor([[[1.0, 1.0]], [[1.0, -1.0]], [[-1.0, -1.0]], [[-1.0, 1.0]]])), + ( + 5, + torch.tensor( + [[[1.0, 1.0]] * 5, [[1.0, -1.0]] * 5, [[-1.0, -1.0]] * 5, [[-1.0, 1.0]] * 5] + ), + ), + ] + ) def test_latent_to_discrete(self, n_samples, expected): """Test the latent_to_discrete default method.""" - latents = self.encoder(self.data) - discretes = self.dvae.latent_to_discrete(latents, n_samples) + # All encoders and dvaes only differ in the number of dummy feature dimensions in the latent + # space. For this reason, this test can only be done with the case of one feature dimension, + # which corresponds to the first encoder and dvae. + latents = self.encoders[1](self.data) + discretes = self.dvaes[1].latent_to_discrete(latents, n_samples) assert torch.equal(discretes, expected) - @parameterized.expand([0, 1, 5, 1000]) - def test_forward(self, n_samples): + @parameterized.expand([(i, j) for i in range(1, 3) for j in [0, 1, 5, 1000]]) + def test_forward(self, n_latent_dims, n_samples): """Test the forward method.""" - latents = self.encoder(self.data) - discretes = self.dvae.latent_to_discrete(latents, n_samples) - reconstructed_x = self.decoder(discretes) + latents = self.encoders[n_latent_dims](self.data) + discretes = self.dvaes[n_latent_dims].latent_to_discrete(latents, n_samples) + reconstructed_x = self.decoders[n_latent_dims](discretes) - expected_latents, expected_discretes, expected_reconstructed_x = self.dvae.forward( - x=self.data, - n_samples=n_samples - ) + expected_latents, expected_discretes, expected_reconstructed_x = self.dvaes[ + n_latent_dims + ].forward(x=self.data, n_samples=n_samples) assert torch.equal(reconstructed_x, expected_reconstructed_x) assert torch.equal(discretes, expected_discretes)