@@ -168,7 +168,7 @@ def test_compute_effective_field_unordered(self):
         # = 0.13 * [-1] - 0.17 * [1] + 0.4 = .1
 
         grbm._linear.data = torch.tensor([-0.1, -0.2, 0.4, 0.2])
-        grbm._quadratic.data = torch.tensor([-.15, -0.7, 0.15, 0.13, -0.17])
+        grbm._quadratic.data = torch.tensor([-.15, -0.7, 0.15, 0.13, -0.17])
         padded = torch.tensor([[-1.0, float("nan"), float("nan"), 1.0]])
         h_eff = grbm._compute_effective_field(padded)
         self.assertTrue(torch.allclose(h_eff.data, torch.tensor([-0.5000, 0.1000]), atol=1e-6))
@@ -322,7 +322,7 @@ def test_sample_return_sampleset(self):
         self.assertEqual(4, len(sampleset.variables))
         self.assertEqual(set(grbm.nodes), set(sampleset.variables))
 
-    def test_objective(self):
+    def test_quasi_objective(self):
         # Create a triangle graph with an additional dangling vertex
         self.nodes = list("abcd")
         self.edges = [["a", "b"], ["a", "c"], ["a", "d"], ["b", "c"]]
@@ -356,6 +356,54 @@ def test_objective(self):
         self.assertEqual(-1, grbm.quasi_objective(s1, s2).item())
         self.assertEqual(-1, grbm.quasi_objective(s1, s3))
 
+    def test_quasi_objective_gradient_hidden_units(self):
+        grbm = GRBM([1, 2, 3],
+                    [(1, 2), (1, 3), (2, 3)],
+                    [1],
+                    {1: 0.2, 2: 0.2, 3: 0.3},
+                    {(1, 2): 0.2, (1, 3): 0.3, (2, 3): 0.6})
+        # Note: In the diagram below, linear biases are shown using <>
+        # and quadratic biases using ()
+        #                (0.2)
+        # Model: v1 <0.2> ----- v2 <0.2>
+        #            \             /
+        #      (0.3)  \           /  (0.6)
+        #              \         /
+        #               v3 <0.3>
+        s_observed = torch.tensor([[1.0, -1.0]])
+        s_model = torch.tensor([[1.0, -1.0, 1.0]])
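+        # Node 1 is hidden: s_observed holds spins for the visible nodes 2 and 3,
+        # while s_model assigns a spin to every node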
+        quasi = grbm.quasi_objective(s_observed, s_model, "exact-disc")
+        quasi.backward()
+        # Compute gradients manually
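+        # The expected gradient is E[t | observed spins] - t(s_model), where t is
+        # the vector of sufficient statistics; that is what is built up below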
+        # Compute unnormalized density
+        # h1v1 + h2v2 + h3v3 + J23v2v3 + J12v1v2 + J13v1v3
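+        # q_plus / q_minus: unnormalized probabilities of the hidden spin v1 = +1 / -1
+        # with the observed spins fixed at v2 = +1, v3 = -1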
+        q_plus = torch.exp(-torch.tensor(0.2 + 0.2 - 0.3 - 0.6 + 0.2 - 0.3))
+        q_minus = torch.exp(-torch.tensor(-0.2 + 0.2 - 0.3 - 0.6 - 0.2 + 0.3))
+        # Normalize it
+        z_cond = q_plus + q_minus
+        p_plus = q_plus / z_cond
+        p_minus = q_minus / z_cond
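+        # p_plus and p_minus form the conditional distribution of the hidden unit
+        # given the observed spins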
+        # t = sufficient statistics = (v1 v2 v3 v1v2 v1v3 v2v3)
+        t_plus = torch.tensor([1, 1, -1, 1, -1, -1]).float()
+        t_minus = torch.tensor([-1, 1, -1, -1, 1, -1]).float()
+        t_model = torch.tensor([1, -1, 1, -1, 1, -1]).float()
+        # Compute expected stat
+        t_cond = t_plus * p_plus + t_minus * p_minus
+        grad = t_cond - t_model
+        grad_auto = torch.cat([grbm.linear.grad, grbm.quadratic.grad])
+        # NOTE: this test relies on the hidden units being disconnected, which
+        # makes the expectation of the sufficient statistics linear: take the
+        # average spin of each hidden unit, then compute the statistics of those averages.
+        torch.testing.assert_close(grad, grad_auto)
+
 
 if __name__ == "__main__":
     unittest.main()