Commit 5d5f380

Fix and test autograd with hidden units (#29)
1 parent d3f2989 commit 5d5f380

3 files changed: 56 additions & 4 deletions

dwave/plugins/torch/models/boltzmann_machine.py

Lines changed: 7 additions & 2 deletions
@@ -396,6 +396,11 @@ def quasi_objective(
                     'each other.'
                 )
                 raise ValueError(err_msg)
+            # NOTE: this method relies on hidden units being disconnected. The calculations
+            # depend on this assumption in **two** ways. The obvious one is marginalization. The
+            # less obvious dependence is the linearity of expectation and sufficient statistics.
+            # Because hidden units are disconnected, we can average their spins before computing
+            # the sufficient statistics, which is then passed into the quasi objective function.
             obs = self._compute_expectation_disconnected(s_observed)
         elif kind == "sampling":
             obs = self._approximate_expectation_sampling(
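
As an aside on the NOTE above: because each hidden unit couples only to visible units, its conditional mean can be pushed inside the pairwise sufficient statistics. A minimal editorial sketch of that identity (toy numbers, a single hidden spin, and a Boltzmann weight exp(-h_eff * h) as the assumed sign convention; none of this is the library's code):

import torch

h_eff = torch.tensor(0.3)                          # effective field acting on the hidden spin h
p_plus = torch.exp(-h_eff) / (torch.exp(-h_eff) + torch.exp(h_eff))
p_minus = 1.0 - p_plus
mean_h = p_plus * 1.0 + p_minus * (-1.0)           # E[h | visible]

v = torch.tensor([1.0, -1.0])                      # observed visible spins

exact = p_plus * (v * 1.0) + p_minus * (v * -1.0)  # E[v_i * h | visible], summed over h = +/-1
averaged = v * mean_h                              # statistics of the averaged spin
assert torch.allclose(exact, averaged)             # linearity of expectation

This is why averaging the hidden spins first, and only then forming the sufficient statistics, gives the correct expectations when the hidden units are disconnected from each other.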
@@ -436,10 +441,10 @@ def _compute_effective_field(self, padded: torch.Tensor) -> torch.Tensor:
         # by the corresponding edges. Transforming this contribution vector by a
         # cumulative sum yields cumulative contributions to effective fields.
         # Differencing removes the extra gobbledygook.
-        contribution = padded[:, self._flat_adj] * self._quadratic[self._flat_j_idx]
+        contribution = padded[:, self._flat_adj] * self._quadratic[self._flat_j_idx].detach()
         cumulative_contribution = contribution.cumsum(1)
         # Don't forget to add the linear fields!
-        h_eff = self._linear[self.hidden_idx] + cumulative_contribution[
+        h_eff = self._linear[self.hidden_idx].detach() + cumulative_contribution[
             :, self._bin_idx
         ].diff(dim=1, prepend=torch.zeros(bs, device=padded.device).unsqueeze(1))
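
The cumulative-sum-then-difference step above is a segmented sum over each hidden unit's neighbours. An editorial sketch with made-up stand-ins for the flattened buffers (the real model's _flat_adj, _flat_j_idx and _bin_idx are internal index tensors; the values below are illustrative only):

import torch

# Per-(hidden unit, neighbour) contributions, flattened over all hidden units:
# hidden unit 0 has two neighbours, hidden unit 1 has three.
contribution = torch.tensor([[0.1, -0.2, 0.4, 0.3, -0.1]])
bin_idx = torch.tensor([1, 4])       # index of the last entry belonging to each hidden unit

cumulative = contribution.cumsum(1)  # running total over the flattened neighbours
bs = contribution.shape[0]
# Reading the running total at each segment boundary and differencing adjacent
# boundaries recovers the per-hidden-unit sums.
per_hidden = cumulative[:, bin_idx].diff(dim=1, prepend=torch.zeros(bs, 1))
assert torch.allclose(per_hidden, torch.tensor([[0.1 - 0.2, 0.4 + 0.3 - 0.1]]))

The .detach() calls added by this commit keep this marginalization step from routing gradients through the linear and quadratic parameters; as the release note below explains, the intended gradient flows only through the sufficient statistics.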

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+---
+fixes:
+  - |
+    Fix the (automatic) gradient computation when hidden units are present. The issue was that the
+    parameters, linear and quadratic weights, were used in the marginalization without being
+    detached from the computation graph. The fix was to detach the parameters when computing
+    effective fields.
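
A toy, editorial illustration of the failure mode described here (not the library's code): when a probability produced during marginalization keeps its dependence on a parameter, backward() picks up an extra term through that path, and detaching it recovers the intended gradient.

import torch

theta = torch.tensor(0.5, requires_grad=True)

# Stand-in for the conditional hidden-unit probability computed from an effective field.
p = torch.sigmoid(theta)

# Without detaching: d/dtheta [p * theta] = p + theta * p * (1 - p)  (extra term)
(p * theta).backward()
grad_without_detach = theta.grad.clone()

theta.grad = None
# With detaching: p is treated as a constant weight, so the gradient is just p.
(p.detach() * theta).backward()
grad_with_detach = theta.grad.clone()

print(grad_without_detach, grad_with_detach)  # the two differ by theta * p * (1 - p)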

tests/test_boltzmann_machine.py

Lines changed: 42 additions & 2 deletions
@@ -168,7 +168,7 @@ def test_compute_effective_field_unordered(self):
         # = 0.13 * [-1] - 0.17 * [1] + 0.4 = .1

         grbm._linear.data = torch.tensor([-0.1, -0.2, 0.4, 0.2])
-        grbm._quadratic.data = torch.tensor([-.15, -0.7, 0.15, 0.13, -0.17 ])
+        grbm._quadratic.data = torch.tensor([-.15, -0.7, 0.15, 0.13, -0.17])
         padded = torch.tensor([[-1.0, float("nan"), float("nan"), 1.0]])
         h_eff = grbm._compute_effective_field(padded)
         self.assertTrue(torch.allclose(h_eff.data, torch.tensor([-0.5000, 0.1000]), atol=1e-6))
@@ -322,7 +322,7 @@ def test_sample_return_sampleset(self):
         self.assertEqual(4, len(sampleset.variables))
         self.assertEqual(set(grbm.nodes), set(sampleset.variables))

-    def test_objective(self):
+    def test_quasi_objective(self):
         # Create a triangle graph with an additional dangling vertex
         self.nodes = list("abcd")
         self.edges = [["a", "b"], ["a", "c"], ["a", "d"], ["b", "c"]]
@@ -356,6 +356,46 @@ def test_objective(self):
         self.assertEqual(-1, grbm.quasi_objective(s1, s2).item())
         self.assertEqual(-1, grbm.quasi_objective(s1, s3))

+    def test_quasi_objective_gradient_hidden_units(self):
+        grbm = GRBM([1, 2, 3],
+                    [(1, 2), (1, 3), (2, 3)],
+                    [1],
+                    {1: 0.2, 2: 0.2, 3: 0.3},
+                    {(1, 2): 0.2, (1, 3): 0.3, (2, 3): 0.6})
+        # NOTE: in the diagram below, linear biases are shown using <>
+        # and quadratic biases using ().
+        #                    (0.2)
+        # Model: v1 <0.2> ----------- v2 <0.2>
+        #              \             /
+        #         (0.3) \           / (0.6)
+        #                \         /
+        #                 v3 <0.3>
+        s_observed = torch.tensor([[1.0, -1.0]])
+        s_model = torch.tensor([[1.0, -1.0, 1.0]])
+        quasi = grbm.quasi_objective(s_observed, s_model, "exact-disc")
+        quasi.backward()
+        # Compute gradients manually.
+        # Compute the unnormalized density:
+        # h1v1 + h2v2 + h3v3 + J23v2v3 + J12v1v2 + J13v1v3
+        q_plus = torch.exp(-torch.tensor(0.2 + 0.2 - 0.3 - 0.6 + 0.2 - 0.3))
+        q_minus = torch.exp(-torch.tensor(-0.2 - 0.2 + 0.3 - 0.6 + 0.2 - 0.3))
+        # Normalize it
+        z_cond = q_plus + q_minus
+        p_plus = q_plus / z_cond
+        p_minus = q_minus / z_cond
+        # t = sufficient statistics = (v1 v2 v3 v1v2 v1v3 v2v3)
+        t_plus = torch.tensor([1, 1, -1, 1, -1, -1]).float()
+        t_minus = torch.tensor([-1, 1, -1, -1, 1, -1]).float()
+        t_model = torch.tensor([1, -1, 1, -1, 1, -1]).float()
+        # Compute the expected statistics
+        t_cond = t_plus * p_plus + t_minus * p_minus
+        grad = t_cond - t_model
+        grad_auto = torch.cat([grbm.linear.grad, grbm.quadratic.grad])
+        # NOTE: this test relies on the hidden units being disconnected. This assumption gives
+        # rise to linearity of expectation of the sufficient statistics: average the spins first,
+        # then calculate the sufficient statistics of the averaged spins.
+        torch.testing.assert_close(grad, grad_auto)
+

 if __name__ == "__main__":
     unittest.main()
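
To run just the new test locally, something along these lines should work (pytest selects by method name, so the TestCase class name, which this hunk does not show, is not needed):

pytest tests/test_boltzmann_machine.py -k test_quasi_objective_gradient_hidden_units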
