@@ -80,28 +80,26 @@ def __init__(
         self._decoder = decoder
         if latent_to_discrete is None:

-            def latent_to_discrete(
-                logits: torch.Tensor, n_samples: int
-            ) -> torch.Tensor:
-                # Logits is of shape (batch_size, n_discrete), we assume these logits
+            def latent_to_discrete(logits: torch.Tensor, n_samples: int) -> torch.Tensor:
+                # Logits is of shape (batch_size, l1, l2, ...), we assume these logits
                 # refer to the probability of each discrete variable being 1. To use the
                 # gumbel softmax function we need to reshape the logits to (batch_size,
-                # n_discrete, 1), and then stack the logits to a zeros tensor of the
+                # l1, l2, ..., 1), and then stack the logits to a zeros tensor of the
                 # same shape. This is done to ensure that the gumbel softmax function
                 # works correctly.
-
+                n_feature_dims = logits.dim() - 1
                 logits = logits.unsqueeze(-1)
                 logits = torch.cat((logits, torch.zeros_like(logits)), dim=-1)
                 # We now create a new leading dimension and repeat the logits n_samples
                 # times:
-                logits = logits.unsqueeze(1).repeat(1, n_samples, 1, 1)
-                one_hots = torch.nn.functional.gumbel_softmax(
-                    logits, tau=1 / 7, hard=True
+                logits = logits.unsqueeze(1).repeat(
+                    *((1, n_samples) + (1,) * n_feature_dims + (1,))
                 )
+                one_hots = torch.nn.functional.gumbel_softmax(logits, tau=1 / 7, hard=True)
                 # The constant 1/7 is used because it was used in
                 # https://iopscience.iop.org/article/10.1088/2632-2153/aba220

-                # one_hots is of shape (batch_size, n_samples, n_discrete, 2), we need
+                # one_hots is of shape (batch_size, n_samples, f_1, f_2, ..., 2), we need
                 # to take the first element of the last dimension and convert it to spin
                 # variables to make the latent space compatible with QPU models.
                 return one_hots[..., 0] * 2 - 1
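
The change generalizes the default `latent_to_discrete` from a flat `(batch_size, n_discrete)` latent to an arbitrary multi-dimensional latent shape by building the `repeat` arguments dynamically. The following is a minimal standalone sketch (not repository code) that reproduces the reshaping logic above with an assumed latent shape, to illustrate how the dimensions propagate:

```python
import torch

def latent_to_discrete(logits: torch.Tensor, n_samples: int) -> torch.Tensor:
    # logits: (batch_size, l1, l2, ...), the logit of each variable being 1.
    n_feature_dims = logits.dim() - 1
    # Pair each logit with a zero so gumbel_softmax sees two-class logits:
    # shape becomes (batch_size, l1, l2, ..., 2).
    logits = logits.unsqueeze(-1)
    logits = torch.cat((logits, torch.zeros_like(logits)), dim=-1)
    # Insert a sample dimension and repeat it n_samples times:
    # shape becomes (batch_size, n_samples, l1, l2, ..., 2).
    logits = logits.unsqueeze(1).repeat(
        *((1, n_samples) + (1,) * n_feature_dims + (1,))
    )
    one_hots = torch.nn.functional.gumbel_softmax(logits, tau=1 / 7, hard=True)
    # Keep the "is 1" channel and map {0, 1} to {-1, +1} spin variables.
    return one_hots[..., 0] * 2 - 1

# Assumed example shapes: a latent with two feature dimensions.
logits = torch.randn(4, 3, 5)              # (batch_size, l1, l2)
spins = latent_to_discrete(logits, n_samples=10)
print(spins.shape)                         # torch.Size([4, 10, 3, 5])
print(spins.unique())                      # values drawn from {-1., 1.}
```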