Commit 7dd9d16

Author: Vladimir Vargas Calderón

support arbitrary number of feature dimensions in latent to discrete

* fix number of feature dimensions
* [docs] consistent notation
* release
* ignore files in aim

1 parent d3f2989 · commit 7dd9d16

File tree

3 files changed: +20, -10 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions

@@ -64,3 +64,6 @@ venv.bak/
 .mypy_cache/
 .dmypy.json
 dmypy.json
+
+# aim
+*.aim*

dwave/plugins/torch/models/discrete_variational_autoencoder.py

Lines changed: 8 additions & 10 deletions

@@ -80,28 +80,26 @@ def __init__(
         self._decoder = decoder
         if latent_to_discrete is None:
 
-            def latent_to_discrete(
-                logits: torch.Tensor, n_samples: int
-            ) -> torch.Tensor:
-                # Logits is of shape (batch_size, n_discrete), we assume these logits
+            def latent_to_discrete(logits: torch.Tensor, n_samples: int) -> torch.Tensor:
+                # Logits is of shape (batch_size, l1, l2, ...), we assume these logits
                 # refer to the probability of each discrete variable being 1. To use the
                 # gumbel softmax function we need to reshape the logits to (batch_size,
-                # n_discrete, 1), and then stack the logits to a zeros tensor of the
+                # l1, l2, ..., 1), and then stack the logits to a zeros tensor of the
                 # same shape. This is done to ensure that the gumbel softmax function
                 # works correctly.
-
+                n_feature_dims = logits.dim() - 1
                 logits = logits.unsqueeze(-1)
                 logits = torch.cat((logits, torch.zeros_like(logits)), dim=-1)
                 # We now create a new leading dimension and repeat the logits n_samples
                 # times:
-                logits = logits.unsqueeze(1).repeat(1, n_samples, 1, 1)
-                one_hots = torch.nn.functional.gumbel_softmax(
-                    logits, tau=1 / 7, hard=True
+                logits = logits.unsqueeze(1).repeat(
+                    *((1, n_samples) + (1,) * n_feature_dims + (1,))
                 )
+                one_hots = torch.nn.functional.gumbel_softmax(logits, tau=1 / 7, hard=True)
                 # The constant 1/7 is used because it was used in
                 # https://iopscience.iop.org/article/10.1088/2632-2153/aba220
 
-                # one_hots is of shape (batch_size, n_samples, n_discrete, 2), we need
+                # one_hots is of shape (batch_size, n_samples, f_1, f_2, ..., 2), we need
                 # to take the first element of the last dimension and convert it to spin
                 # variables to make the latent space compatible with QPU models.
                 return one_hots[..., 0] * 2 - 1
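For reference, here is a minimal standalone sketch of the patched default transformation. The body is lifted directly from the new side of the diff above; only the import and the module-level placement (outside the DiscreteVariationalAutoencoder constructor) are added here. The point of the change is that the repeat argument (1, n_samples) + (1,) * n_feature_dims + (1,) is a tuple of ones, with n_samples in the new sample slot, whose length always matches the tensor's dimensionality, whereas the old hard-coded repeat(1, n_samples, 1, 1) only fit a four-dimensional tensor, i.e., a single feature dimension.

import torch

def latent_to_discrete(logits: torch.Tensor, n_samples: int) -> torch.Tensor:
    # Feature dimensions are everything after the leading batch dimension.
    n_feature_dims = logits.dim() - 1
    # Pair each logit with a zero so gumbel_softmax sees two classes:
    # (batch_size, l1, l2, ...) -> (batch_size, l1, l2, ..., 2).
    logits = logits.unsqueeze(-1)
    logits = torch.cat((logits, torch.zeros_like(logits)), dim=-1)
    # Add a sample dimension and repeat it n_samples times; every other
    # dimension is repeated once, whatever their number.
    logits = logits.unsqueeze(1).repeat(*((1, n_samples) + (1,) * n_feature_dims + (1,)))
    one_hots = torch.nn.functional.gumbel_softmax(logits, tau=1 / 7, hard=True)
    # Keep the first class and map {0, 1} one-hots to {-1, +1} spins.
    return one_hots[..., 0] * 2 - 1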
Release note (new file)

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+---
+fixes:
+  - |
+    The default ``latent_to_discrete`` transformation in
+    ``dwave.plugins.torch.models.discrete_variational_autoencoder.DiscreteVariationalAutoencoder``
+    has been fixed to accommodate arbitrary encoders. Before, the default
+    transformation only allowed encoders whose output shape was (B, l). Now,
+    encoders can have an arbitrary number of feature dimensions, i.e., the
+    shape can be (B, l1, l2, ...).
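As a quick sanity check of the shapes described above, the snippet below exercises the standalone sketch from the previous section (batch size, latent sizes, and sample count are illustrative, not from the commit):

import torch

# Flat latent, shape (B, l): this case worked before and after the fix.
flat = latent_to_discrete(torch.randn(8, 16), n_samples=5)
assert flat.shape == (8, 5, 16)

# Multi-dimensional latent, shape (B, l1, l2): the old hard-coded
# repeat(1, n_samples, 1, 1) raised a RuntimeError here; the fixed
# version accepts it.
grid = latent_to_discrete(torch.randn(8, 4, 4), n_samples=5)
assert grid.shape == (8, 5, 4, 4)

# Outputs are spin-valued (-1.0 or +1.0), compatible with QPU models.
assert set(flat.unique().tolist()) <= {-1.0, 1.0}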
