Skip to content

Commit 0180929

Browse files
Vladimir Vargas Calderónanahitamansouri
andcommitted
support arbitrary number of feature dimensions in latent to discrete
fix number of feature dimensions [docs] consistent notation release ignore files in aim testing different number of latent dimensions extend tests to more feature dimensions Update tests/test_dvae_winci2020.py Co-authored-by: Anahita Mansouri Bigvand <[email protected]> Update tests/test_dvae_winci2020.py Co-authored-by: Anahita Mansouri Bigvand <[email protected]> Update tests/test_dvae_winci2020.py Co-authored-by: Anahita Mansouri Bigvand <[email protected]>
1 parent d3f2989 commit 0180929

File tree

4 files changed

+101
-53
lines changed

4 files changed

+101
-53
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,6 @@ venv.bak/
6464
.mypy_cache/
6565
.dmypy.json
6666
dmypy.json
67+
68+
# aim
69+
*.aim*

dwave/plugins/torch/models/discrete_variational_autoencoder.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,28 +80,26 @@ def __init__(
8080
self._decoder = decoder
8181
if latent_to_discrete is None:
8282

83-
def latent_to_discrete(
84-
logits: torch.Tensor, n_samples: int
85-
) -> torch.Tensor:
86-
# Logits is of shape (batch_size, n_discrete), we assume these logits
83+
def latent_to_discrete(logits: torch.Tensor, n_samples: int) -> torch.Tensor:
84+
# Logits is of shape (batch_size, l1, l2, ...), we assume these logits
8785
# refer to the probability of each discrete variable being 1. To use the
8886
# gumbel softmax function we need to reshape the logits to (batch_size,
89-
# n_discrete, 1), and then stack the logits to a zeros tensor of the
87+
# l1, l2, ..., 1), and then stack the logits to a zeros tensor of the
9088
# same shape. This is done to ensure that the gumbel softmax function
9189
# works correctly.
92-
90+
n_feature_dims = logits.dim() - 1
9391
logits = logits.unsqueeze(-1)
9492
logits = torch.cat((logits, torch.zeros_like(logits)), dim=-1)
9593
# We now create a new leading dimension and repeat the logits n_samples
9694
# times:
97-
logits = logits.unsqueeze(1).repeat(1, n_samples, 1, 1)
98-
one_hots = torch.nn.functional.gumbel_softmax(
99-
logits, tau=1 / 7, hard=True
95+
logits = logits.unsqueeze(1).repeat(
96+
*((1, n_samples) + (1,) * n_feature_dims + (1,))
10097
)
98+
one_hots = torch.nn.functional.gumbel_softmax(logits, tau=1 / 7, hard=True)
10199
# The constant 1/7 is used because it was used in
102100
# https://iopscience.iop.org/article/10.1088/2632-2153/aba220
103101

104-
# one_hots is of shape (batch_size, n_samples, n_discrete, 2), we need
102+
# one_hots is of shape (batch_size, n_samples, f_1, f_2, ..., 2), we need
105103
# to take the first element of the last dimension and convert it to spin
106104
# variables to make the latent space compatible with QPU models.
107105
return one_hots[..., 0] * 2 - 1
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
---
2+
fixes:
3+
- |
4+
The default ``latent_to_discrete`` transformation in
5+
``dwave.plugins.torch.models.discrete_variational_autoencoder.DiscreteVariationalAutoencoder``
6+
has been fixed to accommodate arbitrary encoders. Before, the default
7+
transformation only allowed encoders whose output shape was (B, l). Now,
8+
encoders can have an arbitrary number of feature dimensions, i.e., the
9+
shape can be (B, l1, l2, ...).

tests/test_dvae_winci2020.py

Lines changed: 81 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
from parameterized import parameterized
1919

2020
from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine
21-
from dwave.plugins.torch.models.discrete_variational_autoencoder import DiscreteVariationalAutoencoder as DVAE
21+
from dwave.plugins.torch.models.discrete_variational_autoencoder import (
22+
DiscreteVariationalAutoencoder as DVAE,
23+
)
2224
from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss
2325
from dwave.samplers import SimulatedAnnealingSampler
2426

@@ -32,10 +34,7 @@ def setUp(self):
3234
latent_features = 2
3335

3436
# Data in corners of unit square:
35-
self.data = torch.tensor([[1.0, 1.0],
36-
[1.0, 0.0],
37-
[0.0, 0.0],
38-
[0.0, 1.0]])
37+
self.data = torch.tensor([[1.0, 1.0], [1.0, 0.0], [0.0, 0.0], [0.0, 1.0]])
3938

4039
# The encoder maps input data to logits. We make this encoder without parameters
4140
# for simplicity. The encoder will map 1s to 10s and 0s to -10s, so that the
@@ -48,16 +47,46 @@ def setUp(self):
4847
# [1, -1].
4948

5049
class Encoder(torch.nn.Module):
50+
def __init__(self, n_latent_dims: int):
51+
super().__init__()
52+
self.n_latent_dims = n_latent_dims
53+
54+
def forward(self, x: torch.Tensor) -> torch.Tensor:
55+
# x is always two-dimensional of shape (batch_size, features_size)
56+
dims_to_add = self.n_latent_dims - 1
57+
output = x * 20 - 10
58+
for _ in range(dims_to_add):
59+
output = output.unsqueeze(-2)
60+
return output
61+
62+
class Decoder(torch.nn.Module):
63+
def __init__(self, latent_features: int, input_features: int):
64+
super().__init__()
65+
self.linear = torch.nn.Linear(latent_features, input_features)
66+
5167
def forward(self, x: torch.Tensor) -> torch.Tensor:
52-
return x * 20 - 10
68+
# x is of shape (batch_size, n_samples, l1, l2, ...)
69+
n_latent_dims_to_remove = x.ndim - 3
70+
for _ in range(n_latent_dims_to_remove):
71+
x = x.squeeze(1)
72+
return self.linear(x)
5373

54-
self.encoder = Encoder()
55-
self.decoder = torch.nn.Linear(latent_features, input_features)
74+
# self.encoders is a dict whose keys are the number of latent dims and the values
75+
# are the models themselves
76+
latent_dims_list = [1, 2]
77+
self.encoders = {i: Encoder(i) for i in latent_dims_list}
78+
# self.decoders is independent of number of latent dims, but we also create a dict to separate
79+
# them
80+
self.decoders = {i: Decoder(latent_features, input_features) for i in latent_dims_list}
5681

57-
self.dvae = DVAE(self.encoder, self.decoder)
82+
# self.dvaes is a dict whose keys are the numbers of latent dims and the values are the models
83+
# themselves
84+
85+
self.dvaes = {i: DVAE(self.encoders[i], self.decoders[i]) for i in latent_dims_list}
5886

5987
self.boltzmann_machine = GraphRestrictedBoltzmannMachine(
60-
nodes=(0, 1), edges=[(0, 1)],
88+
nodes=(0, 1),
89+
edges=[(0, 1)],
6190
linear={0: 0.1, 1: -0.2},
6291
quadratic={(0, 1): -1.2},
6392
)
@@ -66,8 +95,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
6695

6796
def test_mappings(self):
6897
"""Test the mapping between data and logits."""
69-
# Let's make sure that indeed the maps are correct:
70-
_, discretes, _ = self.dvae(self.data, n_samples=1)
98+
# Let's make sure that indeed the maps are correct. For this, we use only the first
99+
# autoencoder, which is the one whose encoder maps data to a single feature dimension. The
100+
# second autoencoder maps data to two feature dimensions (the last one is a dummy dimension)
101+
_, discretes, _ = self.dvaes[1](self.data, n_samples=1)
102+
# squeeze the replica dimension:
71103
discretes = discretes.squeeze(1)
72104
# map [1, 1] to [1, 1]:
73105
torch.testing.assert_close(torch.tensor([1, 1]).float(), discretes[0])
@@ -78,19 +110,18 @@ def test_mappings(self):
78110
# map [0, 1] to [-1, 1]:
79111
torch.testing.assert_close(torch.tensor([-1, 1]).float(), discretes[3])
80112

81-
def test_train(self):
113+
@parameterized.expand([1, 2])
114+
def test_train(self, n_latent_dims):
82115
"""Test training simple dataset."""
116+
dvae = self.dvaes[n_latent_dims]
83117
optimiser = torch.optim.SGD(
84-
list(self.dvae.parameters())
85-
+ list(self.boltzmann_machine.parameters()),
118+
list(dvae.parameters()) + list(self.boltzmann_machine.parameters()),
86119
lr=0.01,
87120
momentum=0.9,
88121
)
89122
N_SAMPLES = 1
90123
for _ in range(1000):
91-
latents, discretes, reconstructed_data = self.dvae(
92-
self.data, n_samples=N_SAMPLES
93-
)
124+
latents, discretes, reconstructed_data = dvae(self.data, n_samples=N_SAMPLES)
94125
true_data = self.data.unsqueeze(1).repeat(1, N_SAMPLES, 1)
95126

96127
# Measure the reconstruction loss
@@ -114,37 +145,44 @@ def test_train(self):
114145
torch.testing.assert_close(true_data, reconstructed_data)
115146
# Furthermore, the GRBM should learn that all spin strings of length 2 are
116147
# equally likely, so the h and J parameters should be close to 0:
117-
torch.testing.assert_close(self.boltzmann_machine.linear, torch.zeros(2),
118-
rtol=1e-2, atol=1e-2)
119-
torch.testing.assert_close(self.boltzmann_machine.quadratic, torch.tensor([0.0]).float(),
120-
rtol=1e-2, atol=1e-2)
121-
122-
@parameterized.expand([
123-
(
124-
1,
125-
torch.tensor([[[1., 1.]], [[1., -1.]], [[-1., -1.]], [[-1., 1.]]])
126-
),
127-
(
128-
5,
129-
torch.tensor([[[1., 1.]] * 5, [[1., -1.]] * 5, [[-1., -1.]] * 5, [[-1., 1.]] * 5])
130-
),
131-
])
148+
torch.testing.assert_close(
149+
self.boltzmann_machine.linear, torch.zeros(2), rtol=1e-2, atol=1e-2
150+
)
151+
torch.testing.assert_close(
152+
self.boltzmann_machine.quadratic, torch.tensor([0.0]).float(), rtol=1e-2, atol=1e-2
153+
)
154+
155+
@parameterized.expand(
156+
[
157+
(1, torch.tensor([[[1.0, 1.0]], [[1.0, -1.0]], [[-1.0, -1.0]], [[-1.0, 1.0]]])),
158+
(
159+
5,
160+
torch.tensor(
161+
[[[1.0, 1.0]] * 5, [[1.0, -1.0]] * 5, [[-1.0, -1.0]] * 5, [[-1.0, 1.0]] * 5]
162+
),
163+
),
164+
]
165+
)
132166
def test_latent_to_discrete(self, n_samples, expected):
133167
"""Test the latent_to_discrete default method."""
134-
latents = self.encoder(self.data)
135-
discretes = self.dvae.latent_to_discrete(latents, n_samples)
168+
# All encoders and dvaes only differ in the number of dummy feature dimensions in the latent
169+
# space. For this reason, this test can only be done with the case of one feature dimension,
170+
# which corresponds to the first encoder and dvae.
171+
latents = self.encoders[1](self.data)
172+
discretes = self.dvaes[1].latent_to_discrete(latents, n_samples)
136173
assert torch.equal(discretes, expected)
137174

138-
@parameterized.expand([0, 1, 5, 1000])
139-
def test_forward(self, n_samples):
175+
@parameterized.expand([(i, j) for i in range(1, 3) for j in [0, 1, 5, 1000]])
176+
def test_forward(self, n_latent_dims, n_samples):
140177
"""Test the forward method."""
141-
latents = self.encoder(self.data)
142-
discretes = self.dvae.latent_to_discrete(latents, n_samples)
143-
reconstructed_x = self.decoder(discretes)
178+
expected_latents = self.encoders[n_latent_dims](self.data)
179+
expected_discretes = self.dvaes[n_latent_dims].latent_to_discrete(
180+
expected_latents, n_samples
181+
)
182+
expected_reconstructed_x = self.decoders[n_latent_dims](expected_discretes)
144183

145-
expected_latents, expected_discretes, expected_reconstructed_x = self.dvae.forward(
146-
x=self.data,
147-
n_samples=n_samples
184+
latents, discretes, reconstructed_x = self.dvaes[n_latent_dims].forward(
185+
x=self.data, n_samples=n_samples
148186
)
149187

150188
assert torch.equal(reconstructed_x, expected_reconstructed_x)

0 commit comments

Comments
 (0)