Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,6 @@ venv.bak/
.mypy_cache/
.dmypy.json
dmypy.json

# aim
*.aim*
18 changes: 8 additions & 10 deletions dwave/plugins/torch/models/discrete_variational_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,28 +80,26 @@ def __init__(
self._decoder = decoder
if latent_to_discrete is None:

def latent_to_discrete(
logits: torch.Tensor, n_samples: int
) -> torch.Tensor:
# Logits is of shape (batch_size, n_discrete), we assume these logits
def latent_to_discrete(logits: torch.Tensor, n_samples: int) -> torch.Tensor:
# Logits is of shape (batch_size, l1, l2, ...), we assume these logits
# refer to the probability of each discrete variable being 1. To use the
# gumbel softmax function we need to reshape the logits to (batch_size,
# n_discrete, 1), and then stack the logits to a zeros tensor of the
# l1, l2, ..., 1), and then stack the logits to a zeros tensor of the
# same shape. This is done to ensure that the gumbel softmax function
# works correctly.

n_feature_dims = logits.dim() - 1
logits = logits.unsqueeze(-1)
logits = torch.cat((logits, torch.zeros_like(logits)), dim=-1)
# We now create a new leading dimension and repeat the logits n_samples
# times:
logits = logits.unsqueeze(1).repeat(1, n_samples, 1, 1)
one_hots = torch.nn.functional.gumbel_softmax(
logits, tau=1 / 7, hard=True
logits = logits.unsqueeze(1).repeat(
*((1, n_samples) + (1,) * n_feature_dims + (1,))
)
one_hots = torch.nn.functional.gumbel_softmax(logits, tau=1 / 7, hard=True)
# The constant 1/7 is used because it was used in
# https://iopscience.iop.org/article/10.1088/2632-2153/aba220

# one_hots is of shape (batch_size, n_samples, n_discrete, 2), we need
# one_hots is of shape (batch_size, n_samples, f_1, f_2, ..., 2), we need
# to take the first element of the last dimension and convert it to spin
# variables to make the latent space compatible with QPU models.
return one_hots[..., 0] * 2 - 1
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
fixes:
- |
The default ``latent_to_discrete`` transformation in
``dwave.plugins.torch.models.discrete_variational_autoencoder.DiscreteVariationalAutoencoder``
has been fixed to accommodate arbitrary encoders. Before, the default
transformation only allowed encoders whose output shape was (B, l). Now,
encoders can have an arbitrary number of feature dimensions, i.e., the
shape can be (B, l1, l2, ...).
123 changes: 79 additions & 44 deletions tests/test_dvae_winci2020.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from parameterized import parameterized

from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine
from dwave.plugins.torch.models.discrete_variational_autoencoder import DiscreteVariationalAutoencoder as DVAE
from dwave.plugins.torch.models.discrete_variational_autoencoder import (
DiscreteVariationalAutoencoder as DVAE,
)
from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss
from dwave.samplers import SimulatedAnnealingSampler

Expand All @@ -32,10 +34,7 @@ def setUp(self):
latent_features = 2

# Data in corners of unit square:
self.data = torch.tensor([[1.0, 1.0],
[1.0, 0.0],
[0.0, 0.0],
[0.0, 1.0]])
self.data = torch.tensor([[1.0, 1.0], [1.0, 0.0], [0.0, 0.0], [0.0, 1.0]])

# The encoder maps input data to logits. We make this encoder without parameters
# for simplicity. The encoder will map 1s to 10s and 0s to -10s, so that the
Expand All @@ -48,16 +47,45 @@ def setUp(self):
# [1, -1].

class Encoder(torch.nn.Module):
def __init__(self, n_latent_dims: int):
super().__init__()
self.n_latent_dims = n_latent_dims

def forward(self, x: torch.Tensor) -> torch.Tensor:
# x is always two-dimensional of shape (batch_size, features_size)
dims_to_add = self.n_latent_dims - 1
output = x * 20 - 10
for _ in range(dims_to_add):
output = output.unsqueeze(-2)
return output

class Decoder(torch.nn.Module):
def __init__(self, latent_features: int, input_features: int):
super().__init__()
self.linear = torch.nn.Linear(latent_features, input_features)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return x * 20 - 10
# x is of shape (batch_size, n_samples, l1, l2, ...)
n_latent_dims_to_remove = x.ndim - 3
for _ in range(n_latent_dims_to_remove):
x = x.squeeze(1)
return self.linear(x)

self.encoder = Encoder()
self.decoder = torch.nn.Linear(latent_features, input_features)
# self.encoders is a dict whose keys are the number of latent dims and the values
# are the models themselves
self.encoders = {i: Encoder(i) for i in range(1, 3)}
# self.decoders is independent of number of latent dims, but we also create a dict to separate
# them
self.decoders = {i: Decoder(latent_features, input_features) for i in range(1, 3)}

self.dvae = DVAE(self.encoder, self.decoder)
# self.dvaes is a dict whose keys are the numbers of latent dims and the values are the models
# themselves

self.dvaes = {i: DVAE(self.encoders[i], self.decoders[i]) for i in range(1, 3)}

self.boltzmann_machine = GraphRestrictedBoltzmannMachine(
nodes=(0, 1), edges=[(0, 1)],
nodes=(0, 1),
edges=[(0, 1)],
linear={0: 0.1, 1: -0.2},
quadratic={(0, 1): -1.2},
)
Expand All @@ -66,8 +94,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

def test_mappings(self):
"""Test the mapping between data and logits."""
# Let's make sure that indeed the maps are correct:
_, discretes, _ = self.dvae(self.data, n_samples=1)
# Let's make sure that indeed the maps are correct. For this, we use only the first
# autoencoder, which is the one whose encoder maps data to a single feature dimension. The
# second autoencoder maps data to two feature dimensions (the last one is a dummy dimension)
_, discretes, _ = self.dvaes[1](self.data, n_samples=1)
# squeeze the replica dimension:
discretes = discretes.squeeze(1)
# map [1, 1] to [1, 1]:
torch.testing.assert_close(torch.tensor([1, 1]).float(), discretes[0])
Expand All @@ -78,19 +109,18 @@ def test_mappings(self):
# map [0, 1] to [-1, 1]:
torch.testing.assert_close(torch.tensor([-1, 1]).float(), discretes[3])

def test_train(self):
@parameterized.expand([1, 2])
def test_train(self, n_latent_dims):
"""Test training simple dataset."""
dvae = self.dvaes[n_latent_dims]
optimiser = torch.optim.SGD(
list(self.dvae.parameters())
+ list(self.boltzmann_machine.parameters()),
list(dvae.parameters()) + list(self.boltzmann_machine.parameters()),
lr=0.01,
momentum=0.9,
)
N_SAMPLES = 1
for _ in range(1000):
latents, discretes, reconstructed_data = self.dvae(
self.data, n_samples=N_SAMPLES
)
latents, discretes, reconstructed_data = dvae(self.data, n_samples=N_SAMPLES)
true_data = self.data.unsqueeze(1).repeat(1, N_SAMPLES, 1)

# Measure the reconstruction loss
Expand All @@ -114,38 +144,43 @@ def test_train(self):
torch.testing.assert_close(true_data, reconstructed_data)
# Furthermore, the GRBM should learn that all spin strings of length 2 are
# equally likely, so the h and J parameters should be close to 0:
torch.testing.assert_close(self.boltzmann_machine.linear, torch.zeros(2),
rtol=1e-2, atol=1e-2)
torch.testing.assert_close(self.boltzmann_machine.quadratic, torch.tensor([0.0]).float(),
rtol=1e-2, atol=1e-2)

@parameterized.expand([
(
1,
torch.tensor([[[1., 1.]], [[1., -1.]], [[-1., -1.]], [[-1., 1.]]])
),
(
5,
torch.tensor([[[1., 1.]] * 5, [[1., -1.]] * 5, [[-1., -1.]] * 5, [[-1., 1.]] * 5])
),
])
torch.testing.assert_close(
self.boltzmann_machine.linear, torch.zeros(2), rtol=1e-2, atol=1e-2
)
torch.testing.assert_close(
self.boltzmann_machine.quadratic, torch.tensor([0.0]).float(), rtol=1e-2, atol=1e-2
)

@parameterized.expand(
[
(1, torch.tensor([[[1.0, 1.0]], [[1.0, -1.0]], [[-1.0, -1.0]], [[-1.0, 1.0]]])),
(
5,
torch.tensor(
[[[1.0, 1.0]] * 5, [[1.0, -1.0]] * 5, [[-1.0, -1.0]] * 5, [[-1.0, 1.0]] * 5]
),
),
]
)
def test_latent_to_discrete(self, n_samples, expected):
"""Test the latent_to_discrete default method."""
latents = self.encoder(self.data)
discretes = self.dvae.latent_to_discrete(latents, n_samples)
# All encoders and dvaes only differ in the number of dummy feature dimensions in the latent
# space. For this reason, this test can only be done with the case of one feature dimension,
# which corresponds to the first encoder and dvae.
latents = self.encoders[1](self.data)
discretes = self.dvaes[1].latent_to_discrete(latents, n_samples)
assert torch.equal(discretes, expected)

@parameterized.expand([0, 1, 5, 1000])
def test_forward(self, n_samples):
@parameterized.expand([(i, j) for i in range(1, 3) for j in [0, 1, 5, 1000]])
def test_forward(self, n_latent_dims, n_samples):
"""Test the forward method."""
latents = self.encoder(self.data)
discretes = self.dvae.latent_to_discrete(latents, n_samples)
reconstructed_x = self.decoder(discretes)
latents = self.encoders[n_latent_dims](self.data)
discretes = self.dvaes[n_latent_dims].latent_to_discrete(latents, n_samples)
reconstructed_x = self.decoders[n_latent_dims](discretes)

expected_latents, expected_discretes, expected_reconstructed_x = self.dvae.forward(
x=self.data,
n_samples=n_samples
)
expected_latents, expected_discretes, expected_reconstructed_x = self.dvaes[
n_latent_dims
].forward(x=self.data, n_samples=n_samples)

assert torch.equal(reconstructed_x, expected_reconstructed_x)
assert torch.equal(discretes, expected_discretes)
Expand Down