 from parameterized import parameterized
 
 from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine
-from dwave.plugins.torch.models.discrete_variational_autoencoder import DiscreteVariationalAutoencoder as DVAE
+from dwave.plugins.torch.models.discrete_variational_autoencoder import (
+    DiscreteVariationalAutoencoder as DVAE,
+)
 from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss
 from dwave.samplers import SimulatedAnnealingSampler
 
@@ -32,10 +34,7 @@ def setUp(self):
         latent_features = 2
 
         # Data in corners of unit square:
-        self.data = torch.tensor([[1.0, 1.0],
-                                  [1.0, 0.0],
-                                  [0.0, 0.0],
-                                  [0.0, 1.0]])
+        self.data = torch.tensor([[1.0, 1.0], [1.0, 0.0], [0.0, 0.0], [0.0, 1.0]])
 
         # The encoder maps input data to logits. We make this encoder without parameters
         # for simplicity. The encoder will map 1s to 10s and 0s to -10s, so that the
@@ -48,16 +47,45 @@ def setUp(self):
         # [1, -1].
 
         class Encoder(torch.nn.Module):
+            def __init__(self, n_latent_dims: int):
+                super().__init__()
+                self.n_latent_dims = n_latent_dims
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                # x is always two-dimensional of shape (batch_size, features_size)
+                dims_to_add = self.n_latent_dims - 1
+                output = x * 20 - 10
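+                # insert one dummy singleton dimension per extra latent dim, just before the feature axis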
+                for _ in range(dims_to_add):
+                    output = output.unsqueeze(-2)
+                return output
+
+        class Decoder(torch.nn.Module):
+            def __init__(self, latent_features: int, input_features: int):
+                super().__init__()
+                self.linear = torch.nn.Linear(latent_features, input_features)
+
             def forward(self, x: torch.Tensor) -> torch.Tensor:
-                return x * 20 - 10
+                # x is of shape (batch_size, replica_size, l1, l2, ...)
+                n_latent_dims_to_remove = x.ndim - 3
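+                # everything beyond (batch_size, replica_size, features) is a dummy singleton dim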
+                for _ in range(n_latent_dims_to_remove):
+                    x = x.squeeze(1)
+                return self.linear(x)
 
-        self.encoder = Encoder()
-        self.decoder = torch.nn.Linear(latent_features, input_features)
+        # self.encoders is a dict keyed by the number of latent dims, with the encoder
+        # models as values
+        self.encoders = {i: Encoder(i) for i in range(1, 3)}
+        # the decoder is independent of the number of latent dims, but we keep a dict as
+        # well so the two cases stay separate
+        self.decoders = {i: Decoder(latent_features, input_features) for i in range(1, 3)}
 
-        self.dvae = DVAE(self.encoder, self.decoder)
+        # self.dvaes is a dict keyed by the number of latent dims, with the DVAE models
+        # as values
+        self.dvaes = {i: DVAE(self.encoders[i], self.decoders[i]) for i in range(1, 3)}
 
         self.boltzmann_machine = GraphRestrictedBoltzmannMachine(
-            nodes=(0, 1), edges=[(0, 1)],
+            nodes=(0, 1),
+            edges=[(0, 1)],
             linear={0: 0.1, 1: -0.2},
             quadratic={(0, 1): -1.2},
         )
@@ -66,8 +94,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
     def test_mappings(self):
         """Test the mapping between data and logits."""
-        # Let's make sure that indeed the maps are correct:
-        _, discretes, _ = self.dvae(self.data, n_samples=1)
+        # Let's make sure that the maps are indeed correct. For this, we use only the first
+        # autoencoder, whose encoder maps data to a single feature dimension. The second
+        # autoencoder maps data to two feature dimensions (the last one is a dummy dimension).
+        _, discretes, _ = self.dvaes[1](self.data, n_samples=1)
+        # squeeze the replica dimension:
         discretes = discretes.squeeze(1)
         # map [1, 1] to [1, 1]:
         torch.testing.assert_close(torch.tensor([1, 1]).float(), discretes[0])
@@ -78,19 +109,18 @@ def test_mappings(self):
         # map [0, 1] to [-1, 1]:
         torch.testing.assert_close(torch.tensor([-1, 1]).float(), discretes[3])
 
-    def test_train(self):
+    @parameterized.expand([1, 2])
+    def test_train(self, n_latent_dims):
         """Test training simple dataset."""
+        dvae = self.dvaes[n_latent_dims]
         optimiser = torch.optim.SGD(
-            list(self.dvae.parameters())
-            + list(self.boltzmann_machine.parameters()),
+            list(dvae.parameters()) + list(self.boltzmann_machine.parameters()),
             lr=0.01,
             momentum=0.9,
         )
         N_SAMPLES = 1
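+        # train for 1000 steps, drawing a single discrete sample per data point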
         for _ in range(1000):
-            latents, discretes, reconstructed_data = self.dvae(
-                self.data, n_samples=N_SAMPLES
-            )
+            latents, discretes, reconstructed_data = dvae(self.data, n_samples=N_SAMPLES)
             true_data = self.data.unsqueeze(1).repeat(1, N_SAMPLES, 1)
 
             # Measure the reconstruction loss
@@ -114,38 +144,43 @@ def test_train(self):
         torch.testing.assert_close(true_data, reconstructed_data)
         # Furthermore, the GRBM should learn that all spin strings of length 2 are
         # equally likely, so the h and J parameters should be close to 0:
-        torch.testing.assert_close(self.boltzmann_machine.linear, torch.zeros(2),
-                                   rtol=1e-2, atol=1e-2)
-        torch.testing.assert_close(self.boltzmann_machine.quadratic, torch.tensor([0.0]).float(),
-                                   rtol=1e-2, atol=1e-2)
-
-    @parameterized.expand([
-        (
-            1,
-            torch.tensor([[[1., 1.]], [[1., -1.]], [[-1., -1.]], [[-1., 1.]]])
-        ),
-        (
-            5,
-            torch.tensor([[[1., 1.]] * 5, [[1., -1.]] * 5, [[-1., -1.]] * 5, [[-1., 1.]] * 5])
-        ),
-    ])
+        torch.testing.assert_close(
+            self.boltzmann_machine.linear, torch.zeros(2), rtol=1e-2, atol=1e-2
+        )
+        torch.testing.assert_close(
+            self.boltzmann_machine.quadratic, torch.tensor([0.0]).float(), rtol=1e-2, atol=1e-2
+        )
+
+    @parameterized.expand(
+        [
+            (1, torch.tensor([[[1.0, 1.0]], [[1.0, -1.0]], [[-1.0, -1.0]], [[-1.0, 1.0]]])),
+            (
+                5,
+                torch.tensor(
+                    [[[1.0, 1.0]] * 5, [[1.0, -1.0]] * 5, [[-1.0, -1.0]] * 5, [[-1.0, 1.0]] * 5]
+                ),
+            ),
+        ]
+    )
     def test_latent_to_discrete(self, n_samples, expected):
         """Test the latent_to_discrete default method."""
-        latents = self.encoder(self.data)
-        discretes = self.dvae.latent_to_discrete(latents, n_samples)
+        # All encoders and dvaes only differ in the number of dummy feature dimensions in
+        # the latent space. For this reason, this test can only be done with the case of
+        # one feature dimension, which corresponds to the first encoder and dvae.
+        latents = self.encoders[1](self.data)
+        discretes = self.dvaes[1].latent_to_discrete(latents, n_samples)
         assert torch.equal(discretes, expected)
 
-    @parameterized.expand([0, 1, 5, 1000])
-    def test_forward(self, n_samples):
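+    # run each sample count for both the one- and two-latent-dim variants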
+    @parameterized.expand([(i, j) for i in range(1, 3) for j in [0, 1, 5, 1000]])
+    def test_forward(self, n_latent_dims, n_samples):
         """Test the forward method."""
-        latents = self.encoder(self.data)
-        discretes = self.dvae.latent_to_discrete(latents, n_samples)
-        reconstructed_x = self.decoder(discretes)
+        latents = self.encoders[n_latent_dims](self.data)
+        discretes = self.dvaes[n_latent_dims].latent_to_discrete(latents, n_samples)
+        reconstructed_x = self.decoders[n_latent_dims](discretes)
 
-        expected_latents, expected_discretes, expected_reconstructed_x = self.dvae.forward(
-            x=self.data,
-            n_samples=n_samples
-        )
+        expected_latents, expected_discretes, expected_reconstructed_x = self.dvaes[
+            n_latent_dims
+        ].forward(x=self.data, n_samples=n_samples)
 
         assert torch.equal(reconstructed_x, expected_reconstructed_x)
         assert torch.equal(discretes, expected_discretes)