Skip to content

Commit 673a581

Browse files
Vladimir Vargas Calderón and anahitamansouri committed
support arbitrary number of feature dimensions in latent to discrete
fix number of feature dimensions [docs] consistent notation release ignore files in aim testing different number of latent dimensions extend tests to more feature dimensions Update tests/test_dvae_winci2020.py Co-authored-by: Anahita Mansouri Bigvand <[email protected]> Update tests/test_dvae_winci2020.py Co-authored-by: Anahita Mansouri Bigvand <[email protected]> Update tests/test_dvae_winci2020.py Co-authored-by: Anahita Mansouri Bigvand <[email protected]>
1 parent d3f2989 commit 673a581

File tree

4 files changed

+99
-54
lines changed

4 files changed

+99
-54
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,6 @@ venv.bak/
6464
.mypy_cache/
6565
.dmypy.json
6666
dmypy.json
67+
68+
# aim
69+
*.aim*

dwave/plugins/torch/models/discrete_variational_autoencoder.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,28 +80,26 @@ def __init__(
8080
self._decoder = decoder
8181
if latent_to_discrete is None:
8282

83-
def latent_to_discrete(
84-
logits: torch.Tensor, n_samples: int
85-
) -> torch.Tensor:
86-
# Logits is of shape (batch_size, n_discrete), we assume these logits
83+
def latent_to_discrete(logits: torch.Tensor, n_samples: int) -> torch.Tensor:
84+
# Logits is of shape (batch_size, l1, l2, ...), we assume these logits
8785
# refer to the probability of each discrete variable being 1. To use the
8886
# gumbel softmax function we need to reshape the logits to (batch_size,
89-
# n_discrete, 1), and then stack the logits to a zeros tensor of the
87+
# l1, l2, ..., 1), and then concatenate the logits with a zeros tensor of the
9088
# same shape. This is done to ensure that the gumbel softmax function
9189
# works correctly.
92-
90+
n_feature_dims = logits.dim() - 1
9391
logits = logits.unsqueeze(-1)
9492
logits = torch.cat((logits, torch.zeros_like(logits)), dim=-1)
9593
# We now insert a new dimension right after the batch dimension and repeat the logits n_samples
9694
# times:
97-
logits = logits.unsqueeze(1).repeat(1, n_samples, 1, 1)
98-
one_hots = torch.nn.functional.gumbel_softmax(
99-
logits, tau=1 / 7, hard=True
95+
logits = logits.unsqueeze(1).repeat(
96+
*((1, n_samples) + (1,) * n_feature_dims + (1,))
10097
)
98+
one_hots = torch.nn.functional.gumbel_softmax(logits, tau=1 / 7, hard=True)
10199
# The constant 1/7 is used because it was used in
102100
# https://iopscience.iop.org/article/10.1088/2632-2153/aba220
103101

104-
# one_hots is of shape (batch_size, n_samples, n_discrete, 2), we need
102+
# one_hots is of shape (batch_size, n_samples, l1, l2, ..., 2), we need
105103
# to take the first element of the last dimension and convert it to spin
106104
# variables to make the latent space compatible with QPU models.
107105
return one_hots[..., 0] * 2 - 1
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
---
2+
fixes:
3+
- |
4+
The default ``latent_to_discrete`` transformation in
5+
``dwave.plugins.torch.models.discrete_variational_autoencoder.DiscreteVariationalAutoencoder``
6+
has been fixed to accommodate arbitrary encoders. Before, the default
7+
transformation only allowed encoders whose output shape was (B, l). Now,
8+
encoders can have an arbitrary number of feature dimensions, i.e., the
9+
shape can be (B, l1, l2, ...).

tests/test_dvae_winci2020.py

Lines changed: 79 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
from parameterized import parameterized
1919

2020
from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine
21-
from dwave.plugins.torch.models.discrete_variational_autoencoder import DiscreteVariationalAutoencoder as DVAE
21+
from dwave.plugins.torch.models.discrete_variational_autoencoder import (
22+
DiscreteVariationalAutoencoder as DVAE,
23+
)
2224
from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss
2325
from dwave.samplers import SimulatedAnnealingSampler
2426

@@ -32,10 +34,7 @@ def setUp(self):
3234
latent_features = 2
3335

3436
# Data in corners of unit square:
35-
self.data = torch.tensor([[1.0, 1.0],
36-
[1.0, 0.0],
37-
[0.0, 0.0],
38-
[0.0, 1.0]])
37+
self.data = torch.tensor([[1.0, 1.0], [1.0, 0.0], [0.0, 0.0], [0.0, 1.0]])
3938

4039
# The encoder maps input data to logits. We make this encoder without parameters
4140
# for simplicity. The encoder will map 1s to 10s and 0s to -10s, so that the
@@ -48,16 +47,45 @@ def setUp(self):
4847
# [1, -1].
4948

5049
class Encoder(torch.nn.Module):
50+
def __init__(self, n_latent_dims: int):
51+
super().__init__()
52+
self.n_latent_dims = n_latent_dims
53+
54+
def forward(self, x: torch.Tensor) -> torch.Tensor:
55+
# x is always two-dimensional of shape (batch_size, features_size)
56+
dims_to_add = self.n_latent_dims - 1
57+
output = x * 20 - 10
58+
for _ in range(dims_to_add):
59+
output = output.unsqueeze(-2)
60+
return output
61+
62+
class Decoder(torch.nn.Module):
63+
def __init__(self, latent_features: int, input_features: int):
64+
super().__init__()
65+
self.linear = torch.nn.Linear(latent_features, input_features)
66+
5167
def forward(self, x: torch.Tensor) -> torch.Tensor:
52-
return x * 20 - 10
68+
# x is of shape (batch_size, replica_size, l1, l2, ...)
69+
n_latent_dims_to_remove = x.ndim - 3
70+
for _ in range(n_latent_dims_to_remove):
71+
x = x.squeeze(1)
72+
return self.linear(x)
5373

54-
self.encoder = Encoder()
55-
self.decoder = torch.nn.Linear(latent_features, input_features)
74+
# self.encoders is a dict whose keys are the number of latent dims and the values
75+
# are the models themselves
76+
self.encoders = {i: Encoder(i) for i in range(1, 3)}
77+
# self.decoders is independent of number of latent dims, but we also create a dict to separate
78+
# them
79+
self.decoders = {i: Decoder(latent_features, input_features) for i in range(1, 3)}
5680

57-
self.dvae = DVAE(self.encoder, self.decoder)
81+
# self.dvaes is a dict whose keys are the numbers of latent dims and the values are the models
82+
# themselves
83+
84+
self.dvaes = {i: DVAE(self.encoders[i], self.decoders[i]) for i in range(1, 3)}
5885

5986
self.boltzmann_machine = GraphRestrictedBoltzmannMachine(
60-
nodes=(0, 1), edges=[(0, 1)],
87+
nodes=(0, 1),
88+
edges=[(0, 1)],
6189
linear={0: 0.1, 1: -0.2},
6290
quadratic={(0, 1): -1.2},
6391
)
@@ -66,8 +94,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
6694

6795
def test_mappings(self):
6896
"""Test the mapping between data and logits."""
69-
# Let's make sure that indeed the maps are correct:
70-
_, discretes, _ = self.dvae(self.data, n_samples=1)
97+
# Let's make sure that indeed the maps are correct. For this, we use only the first
98+
# autoencoder, which is the one whose encoder maps data to a single feature dimension. The
99+
# second autoencoder maps data to two feature dimensions (the first one is a dummy dimension of size 1)
100+
_, discretes, _ = self.dvaes[1](self.data, n_samples=1)
101+
# squeeze the replica dimension:
71102
discretes = discretes.squeeze(1)
72103
# map [1, 1] to [1, 1]:
73104
torch.testing.assert_close(torch.tensor([1, 1]).float(), discretes[0])
@@ -78,19 +109,18 @@ def test_mappings(self):
78109
# map [0, 1] to [-1, 1]:
79110
torch.testing.assert_close(torch.tensor([-1, 1]).float(), discretes[3])
80111

81-
def test_train(self):
112+
@parameterized.expand([1, 2])
113+
def test_train(self, n_latent_dims):
82114
"""Test training simple dataset."""
115+
dvae = self.dvaes[n_latent_dims]
83116
optimiser = torch.optim.SGD(
84-
list(self.dvae.parameters())
85-
+ list(self.boltzmann_machine.parameters()),
117+
list(dvae.parameters()) + list(self.boltzmann_machine.parameters()),
86118
lr=0.01,
87119
momentum=0.9,
88120
)
89121
N_SAMPLES = 1
90122
for _ in range(1000):
91-
latents, discretes, reconstructed_data = self.dvae(
92-
self.data, n_samples=N_SAMPLES
93-
)
123+
latents, discretes, reconstructed_data = dvae(self.data, n_samples=N_SAMPLES)
94124
true_data = self.data.unsqueeze(1).repeat(1, N_SAMPLES, 1)
95125

96126
# Measure the reconstruction loss
@@ -114,38 +144,43 @@ def test_train(self):
114144
torch.testing.assert_close(true_data, reconstructed_data)
115145
# Furthermore, the GRBM should learn that all spin strings of length 2 are
116146
# equally likely, so the h and J parameters should be close to 0:
117-
torch.testing.assert_close(self.boltzmann_machine.linear, torch.zeros(2),
118-
rtol=1e-2, atol=1e-2)
119-
torch.testing.assert_close(self.boltzmann_machine.quadratic, torch.tensor([0.0]).float(),
120-
rtol=1e-2, atol=1e-2)
121-
122-
@parameterized.expand([
123-
(
124-
1,
125-
torch.tensor([[[1., 1.]], [[1., -1.]], [[-1., -1.]], [[-1., 1.]]])
126-
),
127-
(
128-
5,
129-
torch.tensor([[[1., 1.]] * 5, [[1., -1.]] * 5, [[-1., -1.]] * 5, [[-1., 1.]] * 5])
130-
),
131-
])
147+
torch.testing.assert_close(
148+
self.boltzmann_machine.linear, torch.zeros(2), rtol=1e-2, atol=1e-2
149+
)
150+
torch.testing.assert_close(
151+
self.boltzmann_machine.quadratic, torch.tensor([0.0]).float(), rtol=1e-2, atol=1e-2
152+
)
153+
154+
@parameterized.expand(
155+
[
156+
(1, torch.tensor([[[1.0, 1.0]], [[1.0, -1.0]], [[-1.0, -1.0]], [[-1.0, 1.0]]])),
157+
(
158+
5,
159+
torch.tensor(
160+
[[[1.0, 1.0]] * 5, [[1.0, -1.0]] * 5, [[-1.0, -1.0]] * 5, [[-1.0, 1.0]] * 5]
161+
),
162+
),
163+
]
164+
)
132165
def test_latent_to_discrete(self, n_samples, expected):
133166
"""Test the latent_to_discrete default method."""
134-
latents = self.encoder(self.data)
135-
discretes = self.dvae.latent_to_discrete(latents, n_samples)
167+
# All encoders and dvaes only differ in the number of dummy feature dimensions in the latent
168+
# space. For this reason, this test can only be done with the case of one feature dimension,
169+
# which corresponds to the first encoder and dvae.
170+
latents = self.encoders[1](self.data)
171+
discretes = self.dvaes[1].latent_to_discrete(latents, n_samples)
136172
assert torch.equal(discretes, expected)
137173

138-
@parameterized.expand([0, 1, 5, 1000])
139-
def test_forward(self, n_samples):
174+
@parameterized.expand([(i, j) for i in range(1, 3) for j in [0, 1, 5, 1000]])
175+
def test_forward(self, n_latent_dims, n_samples):
140176
"""Test the forward method."""
141-
latents = self.encoder(self.data)
142-
discretes = self.dvae.latent_to_discrete(latents, n_samples)
143-
reconstructed_x = self.decoder(discretes)
177+
latents = self.encoders[n_latent_dims](self.data)
178+
discretes = self.dvaes[n_latent_dims].latent_to_discrete(latents, n_samples)
179+
reconstructed_x = self.decoders[n_latent_dims](discretes)
144180

145-
expected_latents, expected_discretes, expected_reconstructed_x = self.dvae.forward(
146-
x=self.data,
147-
n_samples=n_samples
148-
)
181+
expected_latents, expected_discretes, expected_reconstructed_x = self.dvaes[
182+
n_latent_dims
183+
].forward(x=self.data, n_samples=n_samples)
149184

150185
assert torch.equal(reconstructed_x, expected_reconstructed_x)
151186
assert torch.equal(discretes, expected_discretes)

0 commit comments

Comments
 (0)