 from parameterized import parameterized

 from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine
-from dwave.plugins.torch.models.discrete_variational_autoencoder import DiscreteVariationalAutoencoder as DVAE
+from dwave.plugins.torch.models.discrete_variational_autoencoder import (
+    DiscreteVariationalAutoencoder as DVAE,
+)
 from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss
 from dwave.samplers import SimulatedAnnealingSampler

@@ -32,10 +34,7 @@ def setUp(self):
         latent_features = 2

         # Data in corners of unit square:
-        self.data = torch.tensor([[1.0, 1.0],
-                                  [1.0, 0.0],
-                                  [0.0, 0.0],
-                                  [0.0, 1.0]])
+        self.data = torch.tensor([[1.0, 1.0], [1.0, 0.0], [0.0, 0.0], [0.0, 1.0]])

         # The encoder maps input data to logits. We make this encoder without parameters
         # for simplicity. The encoder will map 1s to 10s and 0s to -10s, so that the
@@ -48,16 +47,46 @@ def setUp(self):
         # [1, -1].

         class Encoder(torch.nn.Module):
+            def __init__(self, n_latent_dims: int):
+                super().__init__()
+                self.n_latent_dims = n_latent_dims
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                # x is always two-dimensional of shape (batch_size, features_size)
+                dims_to_add = self.n_latent_dims - 1
+                output = x * 20 - 10
+                for _ in range(dims_to_add):
+                    output = output.unsqueeze(-2)
+                return output
+
+        class Decoder(torch.nn.Module):
+            def __init__(self, latent_features: int, input_features: int):
+                super().__init__()
+                self.linear = torch.nn.Linear(latent_features, input_features)
+
             def forward(self, x: torch.Tensor) -> torch.Tensor:
-                return x * 20 - 10
+                # x is of shape (batch_size, n_samples, l1, l2, ...)
+                n_latent_dims_to_remove = x.ndim - 3
+                for _ in range(n_latent_dims_to_remove):
+                    x = x.squeeze(1)
+                return self.linear(x)

-        self.encoder = Encoder()
-        self.decoder = torch.nn.Linear(latent_features, input_features)
+        # self.encoders is a dict whose keys are the numbers of latent dims and whose values
+        # are the encoder models themselves
+        latent_dims_list = [1, 2]
+        self.encoders = {i: Encoder(i) for i in latent_dims_list}
+        # self.decoders is independent of the number of latent dims, but we also create a dict
+        # to keep them separate
+        self.decoders = {i: Decoder(latent_features, input_features) for i in latent_dims_list}

-        self.dvae = DVAE(self.encoder, self.decoder)
+        # self.dvaes is a dict whose keys are the numbers of latent dims and whose values are
+        # the models themselves
+
+        self.dvaes = {i: DVAE(self.encoders[i], self.decoders[i]) for i in latent_dims_list}

         self.boltzmann_machine = GraphRestrictedBoltzmannMachine(
-            nodes=(0, 1), edges=[(0, 1)],
+            nodes=(0, 1),
+            edges=[(0, 1)],
             linear={0: 0.1, 1: -0.2},
             quadratic={(0, 1): -1.2},
         )
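For intuition about the dummy latent dimension the new Encoder introduces, here is a minimal standalone sketch. It restates the parameter-free Encoder from setUp above; the sample input is arbitrary and the printed shapes follow directly from output = x * 20 - 10 plus the trailing unsqueeze calls.

import torch

class Encoder(torch.nn.Module):
    # Same parameter-free encoder as in setUp: maps 1 -> 10 and 0 -> -10,
    # then adds (n_latent_dims - 1) dummy feature dimensions.
    def __init__(self, n_latent_dims: int):
        super().__init__()
        self.n_latent_dims = n_latent_dims

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = x * 20 - 10
        for _ in range(self.n_latent_dims - 1):
            output = output.unsqueeze(-2)
        return output

x = torch.tensor([[1.0, 0.0]])       # one corner of the unit square
print(Encoder(1)(x).shape)           # torch.Size([1, 2])    -> logits [[10., -10.]]
print(Encoder(2)(x).shape)           # torch.Size([1, 1, 2]) -> same logits plus a dummy dim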
@@ -66,8 +95,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

     def test_mappings(self):
         """Test the mapping between data and logits."""
-        # Let's make sure that indeed the maps are correct:
-        _, discretes, _ = self.dvae(self.data, n_samples=1)
+        # Let's make sure that the maps are indeed correct. For this, we use only the first
+        # autoencoder, whose encoder maps data to a single feature dimension. The second
+        # autoencoder maps data to two feature dimensions (the last one is a dummy dimension).
+        _, discretes, _ = self.dvaes[1](self.data, n_samples=1)
+        # squeeze the replica dimension:
         discretes = discretes.squeeze(1)
         # map [1, 1] to [1, 1]:
         torch.testing.assert_close(torch.tensor([1, 1]).float(), discretes[0])
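The assertions in test_mappings can be summarized without running the DVAE. The following is only a sketch of the behaviour being asserted, not of latent_to_discrete's actual implementation; the sigmoid and the {0, 1} -> {-1, +1} spin convention are assumptions.

import torch

data = torch.tensor([[1.0, 1.0], [1.0, 0.0], [0.0, 0.0], [0.0, 1.0]])
logits = data * 20 - 10                        # the setUp encoder: 1 -> 10, 0 -> -10
probs = torch.sigmoid(logits)                  # ~1.0 for logit 10, ~0.0 for logit -10
spins = torch.where(probs > 0.5, 1.0, -1.0)    # assumed {0, 1} -> {-1, +1} convention
print(spins)   # [[ 1.,  1.], [ 1., -1.], [-1., -1.], [-1.,  1.]], matching the asserts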
@@ -78,19 +110,18 @@ def test_mappings(self):
         # map [0, 1] to [-1, 1]:
         torch.testing.assert_close(torch.tensor([-1, 1]).float(), discretes[3])

-    def test_train(self):
+    @parameterized.expand([1, 2])
+    def test_train(self, n_latent_dims):
         """Test training simple dataset."""
+        dvae = self.dvaes[n_latent_dims]
         optimiser = torch.optim.SGD(
-            list(self.dvae.parameters())
-            + list(self.boltzmann_machine.parameters()),
+            list(dvae.parameters()) + list(self.boltzmann_machine.parameters()),
             lr=0.01,
             momentum=0.9,
         )
         N_SAMPLES = 1
         for _ in range(1000):
-            latents, discretes, reconstructed_data = self.dvae(
-                self.data, n_samples=N_SAMPLES
-            )
+            latents, discretes, reconstructed_data = dvae(self.data, n_samples=N_SAMPLES)
             true_data = self.data.unsqueeze(1).repeat(1, N_SAMPLES, 1)

             # Measure the reconstruction loss
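As a quick check of the shape bookkeeping in the loop above (purely illustrative, using the same data tensor as setUp):

import torch

data = torch.tensor([[1.0, 1.0], [1.0, 0.0], [0.0, 0.0], [0.0, 1.0]])
N_SAMPLES = 1
true_data = data.unsqueeze(1).repeat(1, N_SAMPLES, 1)
print(true_data.shape)   # torch.Size([4, 1, 2]): a samples axis is inserted so that
                         # true_data lines up with reconstructed_data from the DVAE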
@@ -114,37 +145,44 @@ def test_train(self):
         torch.testing.assert_close(true_data, reconstructed_data)
         # Furthermore, the GRBM should learn that all spin strings of length 2 are
         # equally likely, so the h and J parameters should be close to 0:
-        torch.testing.assert_close(self.boltzmann_machine.linear, torch.zeros(2),
-                                   rtol=1e-2, atol=1e-2)
-        torch.testing.assert_close(self.boltzmann_machine.quadratic, torch.tensor([0.0]).float(),
-                                   rtol=1e-2, atol=1e-2)
-
-    @parameterized.expand([
-        (
-            1,
-            torch.tensor([[[1., 1.]], [[1., -1.]], [[-1., -1.]], [[-1., 1.]]])
-        ),
-        (
-            5,
-            torch.tensor([[[1., 1.]] * 5, [[1., -1.]] * 5, [[-1., -1.]] * 5, [[-1., 1.]] * 5])
-        ),
-    ])
+        torch.testing.assert_close(
+            self.boltzmann_machine.linear, torch.zeros(2), rtol=1e-2, atol=1e-2
+        )
+        torch.testing.assert_close(
+            self.boltzmann_machine.quadratic, torch.tensor([0.0]).float(), rtol=1e-2, atol=1e-2
+        )
+
+    @parameterized.expand(
+        [
+            (1, torch.tensor([[[1.0, 1.0]], [[1.0, -1.0]], [[-1.0, -1.0]], [[-1.0, 1.0]]])),
+            (
+                5,
+                torch.tensor(
+                    [[[1.0, 1.0]] * 5, [[1.0, -1.0]] * 5, [[-1.0, -1.0]] * 5, [[-1.0, 1.0]] * 5]
+                ),
+            ),
+        ]
+    )
     def test_latent_to_discrete(self, n_samples, expected):
         """Test the latent_to_discrete default method."""
-        latents = self.encoder(self.data)
-        discretes = self.dvae.latent_to_discrete(latents, n_samples)
+        # All encoders and dvaes differ only in the number of dummy feature dimensions in the
+        # latent space. For this reason, this test can only be run with the case of one feature
+        # dimension, which corresponds to the first encoder and dvae.
+        latents = self.encoders[1](self.data)
+        discretes = self.dvaes[1].latent_to_discrete(latents, n_samples)
         assert torch.equal(discretes, expected)

-    @parameterized.expand([0, 1, 5, 1000])
-    def test_forward(self, n_samples):
+    @parameterized.expand([(i, j) for i in range(1, 3) for j in [0, 1, 5, 1000]])
+    def test_forward(self, n_latent_dims, n_samples):
         """Test the forward method."""
-        latents = self.encoder(self.data)
-        discretes = self.dvae.latent_to_discrete(latents, n_samples)
-        reconstructed_x = self.decoder(discretes)
+        expected_latents = self.encoders[n_latent_dims](self.data)
+        expected_discretes = self.dvaes[n_latent_dims].latent_to_discrete(
+            expected_latents, n_samples
+        )
+        expected_reconstructed_x = self.decoders[n_latent_dims](expected_discretes)

-        expected_latents, expected_discretes, expected_reconstructed_x = self.dvae.forward(
-            x=self.data,
-            n_samples=n_samples
+        latents, discretes, reconstructed_x = self.dvaes[n_latent_dims].forward(
+            x=self.data, n_samples=n_samples
         )

         assert torch.equal(reconstructed_x, expected_reconstructed_x)
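For reference, the comprehension in the new @parameterized.expand on test_forward expands to eight (n_latent_dims, n_samples) cases:

cases = [(i, j) for i in range(1, 3) for j in [0, 1, 5, 1000]]
# [(1, 0), (1, 1), (1, 5), (1, 1000), (2, 0), (2, 1), (2, 5), (2, 1000)]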